SAMZA-123: Move topic partition grouping to the AM and generalize
authorJakob Homan <jghoman@apache.org>
Mon, 28 Jul 2014 22:11:57 +0000 (15:11 -0700)
committerJakob Homan <jghoman@apache.org>
Mon, 28 Jul 2014 22:11:57 +0000 (15:11 -0700)
65 files changed:
.gitignore
build.gradle
docs/learn/documentation/0.7.0/container/samza-container.md
samza-api/src/main/java/org/apache/samza/checkpoint/Checkpoint.java
samza-api/src/main/java/org/apache/samza/checkpoint/CheckpointManager.java
samza-api/src/main/java/org/apache/samza/container/SamzaContainerContext.java
samza-api/src/main/java/org/apache/samza/container/SystemStreamPartitionGrouper.java [new file with mode: 0644]
samza-api/src/main/java/org/apache/samza/container/SystemStreamPartitionGrouperFactory.java [new file with mode: 0644]
samza-api/src/main/java/org/apache/samza/container/TaskName.java [new file with mode: 0644]
samza-api/src/main/java/org/apache/samza/job/CommandBuilder.java
samza-api/src/main/java/org/apache/samza/task/TaskContext.java
samza-core/src/main/scala/org/apache/samza/checkpoint/CheckpointTool.scala
samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
samza-core/src/main/scala/org/apache/samza/checkpoint/file/FileSystemCheckpointManager.scala
samza-core/src/main/scala/org/apache/samza/config/JobConfig.scala
samza-core/src/main/scala/org/apache/samza/config/ShellCommandConfig.scala
samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
samza-core/src/main/scala/org/apache/samza/container/SystemStreamPartitionTaskNameGrouper.scala [new file with mode: 0644]
samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
samza-core/src/main/scala/org/apache/samza/container/TaskNamesToSystemStreamPartitions.scala [new file with mode: 0644]
samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/groupers/GroupByPartition.scala [new file with mode: 0644]
samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/groupers/GroupBySystemStreamPartition.scala [new file with mode: 0644]
samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/taskname/groupers/SimpleSystemStreamPartitionTaskNameGrouper.scala [new file with mode: 0644]
samza-core/src/main/scala/org/apache/samza/job/ShellCommandBuilder.scala
samza-core/src/main/scala/org/apache/samza/job/local/LocalJobFactory.scala
samza-core/src/main/scala/org/apache/samza/serializers/CheckpointSerde.scala
samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
samza-core/src/main/scala/org/apache/samza/task/ReadableCoordinator.scala
samza-core/src/main/scala/org/apache/samza/util/Util.scala
samza-core/src/test/scala/org/apache/samza/checkpoint/TestCheckpointTool.scala
samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
samza-core/src/test/scala/org/apache/samza/checkpoint/file/TestFileSystemCheckpointManager.scala
samza-core/src/test/scala/org/apache/samza/container/SystemStreamPartitionGrouperTestBase.scala [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
samza-core/src/test/scala/org/apache/samza/container/TestTaskNamesToSystemStreamPartitions.scala [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/groupers/TestGroupByPartition.scala [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/groupers/TestGroupBySystemStreamPartition.scala [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/taskname/groupers/TestSimpleSystemStreamPartitionTaskNameGrouper.scala [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/job/TestJobRunner.scala
samza-core/src/test/scala/org/apache/samza/job/TestShellCommandBuilder.scala [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/metrics/TestJmxServer.scala
samza-core/src/test/scala/org/apache/samza/serializers/TestCheckpointSerde.scala
samza-core/src/test/scala/org/apache/samza/system/filereader/TestFileReaderSystemConsumer.scala
samza-core/src/test/scala/org/apache/samza/task/TestReadableCoordinator.scala
samza-core/src/test/scala/org/apache/samza/util/TestUtil.scala
samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointLogKey.scala [new file with mode: 0644]
samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManager.scala
samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManagerFactory.scala
samza-kafka/src/main/scala/org/apache/samza/system/kafka/TopicMetadataCache.scala
samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointLogKey.scala [new file with mode: 0644]
samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemAdmin.scala
samza-kv-leveldb/src/main/scala/org/apache/samza/storage/kv/LevelDbKeyValueStore.scala
samza-test/src/main/java/org/apache/samza/test/integration/join/Emitter.java
samza-test/src/main/scala/org/apache/samza/test/performance/TestKeyValuePerformance.scala
samza-test/src/test/scala/org/apache/samza/test/integration/TestStatefulTask.scala
samza-yarn/src/main/resources/scalate/WEB-INF/views/index.scaml
samza-yarn/src/main/scala/org/apache/samza/job/yarn/SamzaAppMasterState.scala
samza-yarn/src/main/scala/org/apache/samza/job/yarn/SamzaAppMasterTaskManager.scala
samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterRestServlet.scala
samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestSamzaAppMasterTaskManager.scala

index db9d3ec..7cbbfd6 100644 (file)
@@ -24,3 +24,5 @@ build
 samza-test/state
 docs/learn/documentation/0.7.0/api/javadocs
 .DS_Store
+out/
+*.patch
index d1ee5e8..c262b5f 100644 (file)
@@ -144,6 +144,7 @@ project(":samza-kafka_$scalaVersion") {
     compile "org.apache.zookeeper:zookeeper:$zookeeperVersion"
     compile "org.codehaus.jackson:jackson-jaxrs:$jacksonVersion"
     compile "org.apache.kafka:kafka_$scalaVersion:$kafkaVersion"
+    compile "org.codehaus.jackson:jackson-jaxrs:$jacksonVersion"
     testCompile "org.apache.kafka:kafka_$scalaVersion:$kafkaVersion:test"
     testCompile "junit:junit:$junitVersion"
     testCompile "org.mockito:mockito-all:$mockitoVersion"
index a96ab4a..ab4f0e4 100644 (file)
@@ -45,7 +45,7 @@ public interface InitableTask {
 }
 {% endhighlight %}
 
-How many instances of your task class are created depends on the number of partitions in the job's input streams. If your Samza job has ten partitions, there will be ten instantiations of your task class: one for each partition. The first task instance will receive all messages for partition one, the second instance will receive all messages for partition two, and so on.
+By default, how many instances of your task class are created depends on the number of partitions in the job's input streams. If your Samza job has ten partitions, there will be ten instantiations of your task class: one for each partition. The first task instance will receive all messages for partition one, the second instance will receive all messages for partition two, and so on.
 
 <img src="/img/0.7.0/learn/documentation/container/tasks-and-partitions.svg" alt="Illustration of tasks consuming partitions" class="diagram-large">
 
@@ -53,7 +53,13 @@ The number of partitions in the input streams is determined by the systems from
 
 If a Samza job has more than one input stream, the number of task instances for the Samza job is the maximum number of partitions across all input streams. For example, if a Samza job is reading from PageViewEvent (12 partitions), and ServiceMetricEvent (14 partitions), then the Samza job would have 14 task instances (numbered 0 through 13). Task instances 12 and 13 only receive events from ServiceMetricEvent, because there is no corresponding PageViewEvent partition.
 
-There is [work underway](https://issues.apache.org/jira/browse/SAMZA-71) to make the assignment of partitions to tasks more flexible in future versions of Samza.
+With this default approach to assigning input streams to task instances, Samza is effectively performing a group-by operation on the input streams with their partitions as the key. Other strategies for grouping input stream partitions are possible by implementing a new [SystemStreamPartitionGrouper](../api/javadocs/org/apache/samza/container/SystemStreamPartitionGrouper.html) and factory, and configuring the job to use it via the job.systemstreampartition.grouper.factory configuration value.
+
+Samza provides the above-discussed per-partition grouper as well as the [GroupBySystemStreamPartitionGrouper](../api/javadocs/org/apache/samza/container/systemstreampartition/groupers/GroupBySystemStreamPartition), which provides a separate task class instance for every input stream partition, effectively grouping by the input stream itself. This provides maximum scalability in terms of how many containers can be used to process those input streams and is appropriate for very high volume jobs that need no grouping of the input streams.
+
+Considering the above example of a PageViewEvent partitioned 12 ways and a ServiceMetricEvent partitioned 14 ways, the GroupBySystemStreamPartitionGrouper would create 12 + 14 = 26 task instances, which would then be distributed across the number of containers configured, as discussed below.
+
+Note that once a job has been started using a particular SystemStreamPartitionGrouper and that job is using state or checkpointing, it is not possible to change that grouping in subsequent job starts, as the previous checkpoints and state information would likely be incorrect under the new grouping approach.
 
 ### Containers and resource allocation
 
index 6fad1fa..593d118 100644 (file)
 
 package org.apache.samza.checkpoint;
 
+import org.apache.samza.system.SystemStreamPartition;
+
 import java.util.Collections;
 import java.util.Map;
 
-import org.apache.samza.system.SystemStream;
-
 /**
  * A checkpoint is a mapping of all the streams a job is consuming and the most recent current offset for each.
  * It is used to restore a {@link org.apache.samza.task.StreamTask}, either as part of a job restart or as part
  * of restarting a failed container within a running job.
  */
 public class Checkpoint {
-  private final Map<SystemStream, String> offsets;
+  private final Map<SystemStreamPartition, String> offsets;
 
   /**
    * Constructs a new checkpoint based off a map of Samza stream offsets.
    * @param offsets Map of Samza streams to their current offset.
    */
-  public Checkpoint(Map<SystemStream, String> offsets) {
+  public Checkpoint(Map<SystemStreamPartition, String> offsets) {
     this.offsets = offsets;
   }
 
@@ -44,33 +44,25 @@ public class Checkpoint {
    * Gets a unmodifiable view of the current Samza stream offsets.
    * @return A unmodifiable view of a Map of Samza streams to their recorded offsets.
    */
-  public Map<SystemStream, String> getOffsets() {
+  public Map<SystemStreamPartition, String> getOffsets() {
     return Collections.unmodifiableMap(offsets);
   }
 
   @Override
-  public int hashCode() {
-    final int prime = 31;
-    int result = 1;
-    result = prime * result + ((offsets == null) ? 0 : offsets.hashCode());
-    return result;
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (!(o instanceof Checkpoint)) return false;
+
+    Checkpoint that = (Checkpoint) o;
+
+    if (offsets != null ? !offsets.equals(that.offsets) : that.offsets != null) return false;
+
+    return true;
   }
 
   @Override
-  public boolean equals(Object obj) {
-    if (this == obj)
-      return true;
-    if (obj == null)
-      return false;
-    if (getClass() != obj.getClass())
-      return false;
-    Checkpoint other = (Checkpoint) obj;
-    if (offsets == null) {
-      if (other.offsets != null)
-        return false;
-    } else if (!offsets.equals(other.offsets))
-      return false;
-    return true;
+  public int hashCode() {
+    return offsets != null ? offsets.hashCode() : 0;
   }
 
   @Override
index a6e1ba6..092cb91 100644 (file)
 
 package org.apache.samza.checkpoint;
 
-import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.container.TaskName;
+
+import java.util.Map;
 
 /**
  * CheckpointManagers read and write {@link org.apache.samza.checkpoint.Checkpoint} to some
@@ -30,23 +33,38 @@ public interface CheckpointManager {
 
   /**
    * Registers this manager to write checkpoints of a specific Samza stream partition.
-   * @param partition Specific Samza stream partition of which to write checkpoints for.
+   * @param taskName Specific Samza taskName of which to write checkpoints for.
    */
-  public void register(Partition partition);
+  public void register(TaskName taskName);
 
   /**
    * Writes a checkpoint based on the current state of a Samza stream partition.
-   * @param partition Specific Samza stream partition of which to write a checkpoint of.
+   * @param taskName Specific Samza taskName of which to write a checkpoint of.
    * @param checkpoint Reference to a Checkpoint object to store offset data in.
    */
-  public void writeCheckpoint(Partition partition, Checkpoint checkpoint);
+  public void writeCheckpoint(TaskName taskName, Checkpoint checkpoint);
 
   /**
-   * Returns the last recorded checkpoint for a specified Samza stream partition.
-   * @param partition Specific Samza stream partition for which to get the last checkpoint of.
+   * Returns the last recorded checkpoint for a specified taskName.
+   * @param taskName Specific Samza taskName for which to get the last checkpoint of.
    * @return A Checkpoint object with the recorded offset data of the specified partition.
    */
-  public Checkpoint readLastCheckpoint(Partition partition);
+  public Checkpoint readLastCheckpoint(TaskName taskName);
+
+  /**
+   * Read the taskName to partition mapping that is being maintained by this CheckpointManager
+   *
+   * @return TaskName to task log partition mapping, or an empty map if there were no messages.
+   */
+  public Map<TaskName, Integer> readChangeLogPartitionMapping();
+
+  /**
+   * Write the taskName to partition mapping that is being maintained by this CheckpointManager
+   *
+   * @param mapping Each TaskName's partition within the changelog
+   */
+  public void writeChangeLogPartitionMapping(Map<TaskName, Integer> mapping);
 
   public void stop();
+
 }
index 78d56a9..c8693c8 100644 (file)
 
 package org.apache.samza.container;
 
-import java.util.Collection;
-
-import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 
+import java.util.Collection;
+import java.util.Collections;
+
 /**
  * A SamzaContainerContext maintains per-container information for the tasks it executes.
  */
 public class SamzaContainerContext {
   public final String name;
   public final Config config;
-  public final Collection<Partition> partitions;
+  public final Collection<TaskName> taskNames;
 
   /**
    * An immutable context object that can passed to tasks to give them information
    * about the container in which they are executing.
    * @param name The name of the container (either a YARN AM or SamzaContainer).
    * @param config The job configuration.
-   * @param partitions The set of input partitions assigned to this container.
+   * @param taskNames The set of taskName keys for which this container is responsible.
    */
   public SamzaContainerContext(
       String name,
       Config config,
-      Collection<Partition> partitions) {
+      Collection<TaskName> taskNames) {
     this.name = name;
     this.config = config;
-    this.partitions = partitions;
+    this.taskNames = Collections.unmodifiableCollection(taskNames);
   }
 }
diff --git a/samza-api/src/main/java/org/apache/samza/container/SystemStreamPartitionGrouper.java b/samza-api/src/main/java/org/apache/samza/container/SystemStreamPartitionGrouper.java
new file mode 100644 (file)
index 0000000..897d9f5
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import org.apache.samza.system.SystemStreamPartition;
+
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Group a set of SystemStreamPartitions into logical taskNames that share a common characteristic, defined
+ * by the implementation.  Each taskName has a key that uniquely describes what sets may be in it, but does
+ * not generally enumerate the elements of those sets.  For example, a SystemStreamPartitionGrouper that
+ * groups SystemStreamPartitions (each with 4 partitions) by their partition, would end up generating
+ * four TaskNames: 0, 1, 2, 3.  These TaskNames describe the partitions but do not list all of the
+ * SystemStreamPartitions, which allows new SystemStreamPartitions to be added later without changing
+ * the definition of the TaskNames, assuming these new SystemStreamPartitions do not have more than
+ * four partitions.  On the other hand, a SystemStreamPartitionGrouper that wanted each SystemStreamPartition
+ * to be its own, unique group would use the SystemStreamPartition's entire description to generate
+ * the TaskNames.
+ */
+public interface SystemStreamPartitionGrouper {
+  public Map<TaskName, Set<SystemStreamPartition>> group(Set<SystemStreamPartition> ssps);
+}
diff --git a/samza-api/src/main/java/org/apache/samza/container/SystemStreamPartitionGrouperFactory.java b/samza-api/src/main/java/org/apache/samza/container/SystemStreamPartitionGrouperFactory.java
new file mode 100644 (file)
index 0000000..10ac6e2
--- /dev/null
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import org.apache.samza.config.Config;
+
+/**
+ * Return an instance a SystemStreamPartitionGrouper per the particular implementation
+ */
+public interface SystemStreamPartitionGrouperFactory {
+  public SystemStreamPartitionGrouper getSystemStreamPartitionGrouper(Config config);
+}
diff --git a/samza-api/src/main/java/org/apache/samza/container/TaskName.java b/samza-api/src/main/java/org/apache/samza/container/TaskName.java
new file mode 100644 (file)
index 0000000..13a1206
--- /dev/null
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+/**
+ * A unique identifier of a set of a SystemStreamPartitions that have been grouped by
+ * a {@link org.apache.samza.container.SystemStreamPartitionGrouper}.  The
+ * SystemStreamPartitionGrouper determines the TaskName for each set it creates.
+ */
+public class TaskName implements Comparable<TaskName> {
+  private final String taskName;
+
+  public String getTaskName() {
+    return taskName;
+  }
+
+  public TaskName(String taskName) {
+    this.taskName = taskName;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    TaskName taskName1 = (TaskName) o;
+
+    if (!taskName.equals(taskName1.taskName)) return false;
+
+    return true;
+  }
+
+  @Override
+  public int hashCode() {
+    return taskName.hashCode();
+  }
+
+  @Override
+  public String toString() {
+    return taskName;
+  }
+
+  @Override
+  public int compareTo(TaskName that) {
+    return taskName.compareTo(that.taskName);
+  }
+}
index cb40092..f510ce5 100644 (file)
@@ -20,6 +20,7 @@
 package org.apache.samza.job;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.system.SystemStreamPartition;
 
 import java.util.Map;
@@ -30,12 +31,13 @@ import java.util.Set;
  * such as YARN or the LocalJobRunner.
  */
 public abstract class CommandBuilder {
-  protected Set<SystemStreamPartition> systemStreamPartitions;
+  protected Map<TaskName, Set<SystemStreamPartition>> taskNameToSystemStreamPartitionsMapping;
+  protected Map<TaskName, Integer> taskNameToChangeLogPartitionMapping;
   protected String name;
   protected Config config;
 
-  public CommandBuilder setStreamPartitions(Set<SystemStreamPartition> ssp) {
-    this.systemStreamPartitions = ssp;
+  public CommandBuilder setTaskNameToSystemStreamPartitionsMapping(Map<TaskName, Set<SystemStreamPartition>> systemStreamPartitionTaskNames) {
+    this.taskNameToSystemStreamPartitionsMapping = systemStreamPartitionTaskNames;
     return this;
   }
   
@@ -54,6 +56,11 @@ public abstract class CommandBuilder {
     return this;
   }
 
+  public CommandBuilder setTaskNameToChangeLogPartitionMapping(Map<TaskName, Integer> mapping) {
+    this.taskNameToChangeLogPartitionMapping = mapping;
+    return this;
+  }
+
   public abstract String buildCommand();
 
   public abstract Map<String, String> buildEnvironment();
index 7c1b085..35de8cc 100644 (file)
 
 package org.apache.samza.task;
 
-import org.apache.samza.Partition;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.system.SystemStreamPartition;
+
+import java.util.Set;
 
 /**
  * A TaskContext provides resources about the {@link org.apache.samza.task.StreamTask}, particularly during
@@ -29,7 +32,9 @@ import org.apache.samza.metrics.MetricsRegistry;
 public interface TaskContext {
   MetricsRegistry getMetricsRegistry();
 
-  Partition getPartition();
+  Set<SystemStreamPartition> getSystemStreamPartitions();
 
   Object getStore(String name);
+
+  TaskName getTaskName();
 }
index 5735a39..84ea4ca 100644 (file)
 
 package org.apache.samza.checkpoint
 
+import grizzled.slf4j.Logging
 import java.net.URI
 import java.util.regex.Pattern
 import joptsimple.OptionSet
-import org.apache.samza.{Partition, SamzaException}
-import org.apache.samza.config.{Config, StreamConfig}
+import org.apache.samza.checkpoint.CheckpointTool.TaskNameToCheckpointMap
 import org.apache.samza.config.TaskConfig.Config2Task
+import org.apache.samza.config.{Config, StreamConfig}
+import org.apache.samza.container.TaskName
 import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.util.{CommandLine, Util}
+import org.apache.samza.{Partition, SamzaException}
 import scala.collection.JavaConversions._
-import grizzled.slf4j.Logging
 
 /**
  * Command-line tool for inspecting and manipulating the checkpoints for a job.
@@ -44,7 +46,10 @@ import grizzled.slf4j.Logging
  * containing the offsets you want. It needs to be in the same format as the tool
  * prints out the latest checkpoint:
  *
- *   systems.<system>.streams.<topic>.partitions.<partition>=<offset>
+ *   tasknames.<taskname>.systems.<system>.streams.<topic>.partitions.<partition>=<offset>
+ *
+ * The provided offset definitions will be grouped by <taskname> and written to
+ * individual checkpoint entries for each <taskname>
  *
  * NOTE: A job only reads its checkpoint when it starts up. Therefore, if you want
  * your checkpoint change to take effect, you have to first stop the job, then
@@ -56,8 +61,10 @@ import grizzled.slf4j.Logging
  */
 object CheckpointTool {
   /** Format in which SystemStreamPartition is represented in a properties file */
-  val SSP_PATTERN = StreamConfig.STREAM_PREFIX + "partitions.%d"
-  val SSP_REGEX = Pattern.compile("systems\\.(.+)\\.streams\\.(.+)\\.partitions\\.([0-9]+)")
+  val SSP_PATTERN = "tasknames.%s." + StreamConfig.STREAM_PREFIX + "partitions.%d"
+  val SSP_REGEX = Pattern.compile("tasknames\\.(.+)\\.systems\\.(.+)\\.streams\\.(.+)\\.partitions\\.([0-9]+)")
+
+  type TaskNameToCheckpointMap = Map[TaskName, Map[SystemStreamPartition, String]]
 
   class CheckpointToolCommandLine extends CommandLine with Logging {
     val newOffsetsOpt =
@@ -68,20 +75,31 @@ object CheckpointTool {
             .ofType(classOf[URI])
             .describedAs("path")
 
-    var newOffsets: Map[SystemStreamPartition, String] = null
+    var newOffsets: TaskNameToCheckpointMap = null
 
-    def parseOffsets(propertiesFile: Config): Map[SystemStreamPartition, String] = {
-      propertiesFile.entrySet.flatMap(entry => {
+    def parseOffsets(propertiesFile: Config): TaskNameToCheckpointMap = {
+      val taskNameSSPPairs = propertiesFile.entrySet.flatMap(entry => {
         val matcher = SSP_REGEX.matcher(entry.getKey)
         if (matcher.matches) {
-          val partition = new Partition(Integer.parseInt(matcher.group(3)))
-          val ssp = new SystemStreamPartition(matcher.group(1), matcher.group(2), partition)
-          Some(ssp -> entry.getValue)
+          val taskname = new TaskName(matcher.group(1))
+          val partition = new Partition(Integer.parseInt(matcher.group(4)))
+          val ssp = new SystemStreamPartition(matcher.group(2), matcher.group(3), partition)
+          Some(taskname -> Map(ssp -> entry.getValue))
         } else {
           warn("Warning: ignoring unrecognised property: %s = %s" format (entry.getKey, entry.getValue))
           None
         }
-      }).toMap
+      }).toList
+
+      if(taskNameSSPPairs.isEmpty) {
+        return null
+      }
+
+      // Need to turn taskNameSSPPairs List[(taskname, Map[SystemStreamPartition, Offset])] to Map[TaskName, Map[SSP, Offset]]
+      taskNameSSPPairs                      // List[(taskname, Map[SystemStreamPartition, Offset])]
+        .groupBy(_._1)                      // Group by taskname
+        .mapValues(m => m.map(_._2))        // Drop the extra taskname that we grouped on
+        .mapValues(m => m.reduce( _ ++ _))  // Merge all the maps of SSPs->Offset into one for the whole taskname
     }
 
     override def loadConfig(options: OptionSet) = {
@@ -103,7 +121,7 @@ object CheckpointTool {
   }
 }
 
-class CheckpointTool(config: Config, newOffsets: Map[SystemStreamPartition, String]) extends Logging {
+class CheckpointTool(config: Config, newOffsets: TaskNameToCheckpointMap) extends Logging {
   val manager = config.getCheckpointManagerFactory match {
     case Some(className) =>
       Util.getObj[CheckpointManagerFactory](className).getCheckpointManager(config, new MetricsRegistryMap)
@@ -113,52 +131,48 @@ class CheckpointTool(config: Config, newOffsets: Map[SystemStreamPartition, Stri
 
   // The CheckpointManagerFactory needs to perform this same operation when initializing
   // the manager. TODO figure out some way of avoiding duplicated work.
-  val partitions = Util.getInputStreamPartitions(config).map(_.getPartition).toSet
 
   def run {
     info("Using %s" format manager)
-    partitions.foreach(manager.register)
+
+    // Find all the TaskNames that would be generated for this job config
+    val taskNames = Util.assignContainerToSSPTaskNames(config, 1).get(0).get.keys.toSet
+
+    taskNames.foreach(manager.register)
     manager.start
-    val lastCheckpoint = readLastCheckpoint
 
-    logCheckpoint(lastCheckpoint, "Current checkpoint")
+    val lastCheckpoints = taskNames.map(tn => tn -> readLastCheckpoint(tn)).toMap
+
+    lastCheckpoints.foreach(lcp => logCheckpoint(lcp._1, lcp._2, "Current checkpoint for taskname "+ lcp._1))
 
     if (newOffsets != null) {
-      logCheckpoint(newOffsets, "New offset to be written")
-      writeNewCheckpoint(newOffsets)
-      manager.stop
-      info("Ok, new checkpoint has been written.")
+      newOffsets.foreach(no => {
+        logCheckpoint(no._1, no._2, "New offset to be written for taskname " + no._1)
+        writeNewCheckpoint(no._1, no._2)
+        info("Ok, new checkpoint has been written for taskname " + no._1)
+      })
     }
+
+    manager.stop
   }
 
-  /** Load the most recent checkpoint state for all partitions. */
-  def readLastCheckpoint: Map[SystemStreamPartition, String] = {
-    partitions.flatMap(partition => {
-      manager.readLastCheckpoint(partition)
-        .getOffsets
-        .map { case (systemStream, offset) =>
-          new SystemStreamPartition(systemStream, partition) -> offset
-        }
-    }).toMap
+  /** Load the most recent checkpoint state for all a specified TaskName. */
+  def readLastCheckpoint(taskName:TaskName): Map[SystemStreamPartition, String] = {
+    manager.readLastCheckpoint(taskName).getOffsets.toMap
   }
 
   /**
-   * Store a new checkpoint state for all given partitions, overwriting the
-   * current state. Any partitions that are not mentioned will not
-   * be changed.
+   * Store a new checkpoint state for specified TaskName, overwriting any previous
+   * checkpoint for that TaskName
    */
-  def writeNewCheckpoint(newOffsets: Map[SystemStreamPartition, String]) {
-    newOffsets.groupBy(_._1.getPartition).foreach {
-      case (partition, offsets) =>
-        val streamOffsets = offsets.map { case (ssp, offset) => ssp.getSystemStream -> offset }.toMap
-        val checkpoint = new Checkpoint(streamOffsets)
-        manager.writeCheckpoint(partition, checkpoint)
-    }
+  def writeNewCheckpoint(tn: TaskName, newOffsets: Map[SystemStreamPartition, String]) {
+    val checkpoint = new Checkpoint(newOffsets)
+    manager.writeCheckpoint(tn, checkpoint)
   }
 
-  def logCheckpoint(checkpoint: Map[SystemStreamPartition, String], prefix: String) {
-    checkpoint.map { case (ssp, offset) =>
-      (CheckpointTool.SSP_PATTERN + " = %s") format (ssp.getSystem, ssp.getStream, ssp.getPartition.getPartitionId, offset)
-    }.toList.sorted.foreach(line => info(prefix + ": " + line))
+  def logCheckpoint(tn: TaskName, checkpoint: Map[SystemStreamPartition, String], prefix: String) {
+    def logLine(tn:TaskName, ssp:SystemStreamPartition, offset:String) = (prefix + ": " + CheckpointTool.SSP_PATTERN + " = %s") format (tn.toString, ssp.getSystem, ssp.getStream, ssp.getPartition.getPartitionId, offset)
+
+    checkpoint.keys.toList.sorted.foreach(ssp => info(logLine(tn, ssp, checkpoint.get(ssp).get)))
   }
 }
index 9487b58..4efe997 100644 (file)
@@ -20,7 +20,6 @@
 package org.apache.samza.checkpoint
 
 import org.apache.samza.system.SystemStream
-import org.apache.samza.Partition
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.system.SystemStreamMetadata
 import org.apache.samza.system.SystemStreamMetadata.OffsetType
@@ -31,6 +30,8 @@ import org.apache.samza.config.Config
 import org.apache.samza.config.StreamConfig.Config2Stream
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.system.SystemAdmin
+import org.apache.samza.container.TaskName
+import scala.collection._
 
 /**
  * OffsetSetting encapsulates a SystemStream's metadata, default offset, and
@@ -150,13 +151,13 @@ class OffsetManager(
 
   /**
    * The set of system stream partitions that have been registered with the
-   * OffsetManager. These are the SSPs that will be tracked within the offset
-   * manager.
+   * OffsetManager, grouped by the taskName they belong to. These are the SSPs
+   * that will be tracked within the offset manager.
    */
-  var systemStreamPartitions = Set[SystemStreamPartition]()
+  val systemStreamPartitions = mutable.Map[TaskName, mutable.Set[SystemStreamPartition]]()
 
-  def register(systemStreamPartition: SystemStreamPartition) {
-    systemStreamPartitions += systemStreamPartition
+  def register(taskName: TaskName, systemStreamPartitionsToRegister: Set[SystemStreamPartition]) {
+    systemStreamPartitions.getOrElseUpdate(taskName, mutable.Set[SystemStreamPartition]()).addAll(systemStreamPartitionsToRegister)
   }
 
   def start {
@@ -193,20 +194,18 @@ class OffsetManager(
   }
 
   /**
-   * Checkpoint all offsets for a given partition using the CheckpointManager.
+   * Checkpoint all offsets for a given TaskName using the CheckpointManager.
    */
-  def checkpoint(partition: Partition) {
+  def checkpoint(taskName: TaskName) {
     if (checkpointManager != null) {
-      debug("Checkpointing offsets for partition %s." format partition)
+      debug("Checkpointing offsets for taskName %s." format taskName)
 
-      val partitionOffsets = lastProcessedOffsets
-        .filterKeys(_.getPartition.equals(partition))
-        .map { case (systemStreamPartition, offset) => (systemStreamPartition.getSystemStream, offset) }
-        .toMap
+      val sspsForTaskName = systemStreamPartitions.getOrElse(taskName, throw new SamzaException("No such SystemStreamPartition set " + taskName + " registered for this checkpointmanager")).toSet
+      val partitionOffsets = lastProcessedOffsets.filterKeys(sspsForTaskName.contains(_))
 
-      checkpointManager.writeCheckpoint(partition, new Checkpoint(partitionOffsets))
+      checkpointManager.writeCheckpoint(taskName, new Checkpoint(partitionOffsets))
     } else {
-      debug("Skipping checkpointing for partition %s because no checkpoint manager is defined." format partition)
+      debug("Skipping checkpointing for taskName %s because no checkpoint manager is defined." format taskName)
     }
   }
 
@@ -221,23 +220,13 @@ class OffsetManager(
   }
 
   /**
-   * Returns a set of partitions that have been registered with this offset
-   * manager.
-   */
-  private def getPartitions = {
-    systemStreamPartitions
-      .map(_.getPartition)
-      .toSet
-  }
-
-  /**
    * Register all partitions with the CheckpointManager.
    */
   private def registerCheckpointManager {
     if (checkpointManager != null) {
       debug("Registering checkpoint manager.")
 
-      getPartitions.foreach(checkpointManager.register)
+      systemStreamPartitions.keys.foreach(checkpointManager.register)
     } else {
       debug("Skipping checkpoint manager registration because no manager was defined.")
     }
@@ -253,9 +242,8 @@ class OffsetManager(
 
       checkpointManager.start
 
-      lastProcessedOffsets ++= getPartitions
-        .flatMap(restoreOffsetsFromCheckpoint(_))
-        .filter {
+      lastProcessedOffsets ++= systemStreamPartitions.keys
+        .flatMap(restoreOffsetsFromCheckpoint(_)).filter {
           case (systemStreamPartition, offset) =>
             val shouldKeep = offsetSettings.contains(systemStreamPartition.getSystemStream)
             if (!shouldKeep) {
@@ -269,20 +257,17 @@ class OffsetManager(
   }
 
   /**
-   * Loads last processed offsets for a single partition.
+   * Loads last processed offsets for a single taskName.
    */
-  private def restoreOffsetsFromCheckpoint(partition: Partition): Map[SystemStreamPartition, String] = {
-    debug("Loading checkpoints for partition: %s." format partition)
+  private def restoreOffsetsFromCheckpoint(taskName: TaskName): Map[SystemStreamPartition, String] = {
+    debug("Loading checkpoints for taskName: %s." format taskName)
 
-    val checkpoint = checkpointManager.readLastCheckpoint(partition)
+    val checkpoint = checkpointManager.readLastCheckpoint(taskName)
 
     if (checkpoint != null) {
-      checkpoint
-        .getOffsets
-        .map { case (systemStream, offset) => (new SystemStreamPartition(systemStream, partition), offset) }
-        .toMap
+      checkpoint.getOffsets.toMap
     } else {
-      info("Did not receive a checkpoint for partition %s. Proceeding without a checkpoint." format partition)
+      info("Did not receive a checkpoint for taskName %s. Proceeding without a checkpoint." format taskName)
 
       Map()
     }
@@ -338,7 +323,12 @@ class OffsetManager(
    * that was registered, but has no offset.
    */
   private def loadDefaults {
-    systemStreamPartitions.foreach(systemStreamPartition => {
+    val allSSPs: Set[SystemStreamPartition] = systemStreamPartitions
+      .values
+      .flatten
+      .toSet
+
+    allSSPs.foreach(systemStreamPartition => {
       if (!startingOffsets.contains(systemStreamPartition)) {
         val systemStream = systemStreamPartition.getSystemStream
         val partition = systemStreamPartition.getPartition
index 364e489..2a87a6e 100644 (file)
 package org.apache.samza.checkpoint.file
 
 import java.io.File
+import java.io.FileNotFoundException
 import java.io.FileOutputStream
-import scala.collection.JavaConversions._
-import scala.io.Source
+import java.util
 import org.apache.samza.SamzaException
-import org.apache.samza.serializers.CheckpointSerde
-import org.apache.samza.metrics.MetricsRegistry
+import org.apache.samza.checkpoint.Checkpoint
+import org.apache.samza.checkpoint.CheckpointManager
+import org.apache.samza.checkpoint.CheckpointManagerFactory
 import org.apache.samza.config.Config
-import org.apache.samza.Partition
-import org.apache.samza.config.JobConfig.Config2Job
 import org.apache.samza.config.FileSystemCheckpointManagerConfig.Config2FSCP
-import org.apache.samza.checkpoint.CheckpointManagerFactory
-import org.apache.samza.checkpoint.CheckpointManager
-import org.apache.samza.checkpoint.Checkpoint
-import java.io.FileNotFoundException
+import org.apache.samza.config.JobConfig.Config2Job
+import org.apache.samza.container.TaskName
+import org.apache.samza.metrics.MetricsRegistry
+import org.apache.samza.serializers.CheckpointSerde
+import scala.io.Source
 
 class FileSystemCheckpointManager(
   jobName: String,
   root: File,
   serde: CheckpointSerde = new CheckpointSerde) extends CheckpointManager {
 
-  def register(partition: Partition) {
-  }
+  override def register(taskName: TaskName):Unit = Unit
 
-  def writeCheckpoint(partition: Partition, checkpoint: Checkpoint) {
+  def getCheckpointFile(taskName: TaskName) = getFile(jobName, taskName, "checkpoints")
+
+  def writeCheckpoint(taskName: TaskName, checkpoint: Checkpoint) {
     val bytes = serde.toBytes(checkpoint)
-    val fos = new FileOutputStream(getFile(jobName, partition))
+    val fos = new FileOutputStream(getCheckpointFile(taskName))
 
     fos.write(bytes)
     fos.close
   }
 
-  def readLastCheckpoint(partition: Partition): Checkpoint = {
+  def readLastCheckpoint(taskName: TaskName): Checkpoint = {
     try {
-      val bytes = Source.fromFile(getFile(jobName, partition)).map(_.toByte).toArray
+      val bytes = Source.fromFile(getCheckpointFile(taskName)).map(_.toByte).toArray
 
       serde.fromBytes(bytes)
     } catch {
@@ -69,8 +70,28 @@ class FileSystemCheckpointManager(
 
   def stop {}
 
-  private def getFile(jobName: String, partition: Partition) =
-    new File(root, "%s-%d" format (jobName, partition.getPartitionId))
+  private def getFile(jobName: String, taskName: TaskName, fileType:String) =
+    new File(root, "%s-%s-%s" format (jobName, taskName, fileType))
+
+  private def getChangeLogPartitionMappingFile() = getFile(jobName, new TaskName("partition-mapping"), "changelog-partition-mapping")
+
+  override def readChangeLogPartitionMapping(): util.Map[TaskName, java.lang.Integer] = {
+    try {
+      val bytes = Source.fromFile(getChangeLogPartitionMappingFile()).map(_.toByte).toArray
+      serde.changelogPartitionMappingFromBytes(bytes)
+    } catch {
+      case e: FileNotFoundException => new util.HashMap[TaskName, java.lang.Integer]()
+    }
+  }
+
+  def writeChangeLogPartitionMapping(mapping: util.Map[TaskName, java.lang.Integer]): Unit = {
+    val hashmap = new util.HashMap[TaskName, java.lang.Integer](mapping)
+    val bytes = serde.changelogPartitionMappingToBytes(hashmap)
+    val fos = new FileOutputStream(getChangeLogPartitionMappingFile())
+
+    fos.write(bytes)
+    fos.close
+  }
 }
 
 class FileSystemCheckpointManagerFactory extends CheckpointManagerFactory {
index fcafe83..f84aeea 100644 (file)
@@ -19,6 +19,8 @@
 
 package org.apache.samza.config
 
+import org.apache.samza.container.systemstreampartition.groupers.GroupByPartitionFactory
+
 object JobConfig {
   // job config constants
   val STREAM_JOB_FACTORY_CLASS = "job.factory.class" // streaming.job_factory_class
@@ -34,6 +36,8 @@ object JobConfig {
   val JOB_NAME = "job.name" // streaming.job_name
   val JOB_ID = "job.id" // streaming.job_id
 
+  val SSP_GROUPER_FACTORY = "job.systemstreampartition.grouper.factory"
+
   implicit def Config2Job(config: Config) = new JobConfig(config)
 }
 
@@ -47,4 +51,6 @@ class JobConfig(config: Config) extends ScalaMapConfig(config) {
   def getConfigRewriters = getOption(JobConfig.CONFIG_REWRITERS)
 
   def getConfigRewriterClass(name: String) = getOption(JobConfig.CONFIG_REWRITER_CLASS format name)
+
+  def getSystemStreamPartitionGrouperFactory = getOption(JobConfig.SSP_GROUPER_FACTORY).getOrElse(classOf[GroupByPartitionFactory].getCanonicalName)
 }
index 0cdc0d1..e4197fa 100644 (file)
@@ -26,7 +26,12 @@ object ShellCommandConfig {
   val ENV_CONFIG = "SAMZA_CONFIG"
 
   /**
-   * An encoded list of the streams and partitions this container is responsible for. Encoded by 
+   * All taskNames across the job; used to calculate state store partition mapping
+   */
+  val ENV_TASK_NAME_TO_CHANGELOG_PARTITION_MAPPING = "TASK_NAME_TO_CHANGELOG_PARTITION_MAPPING"
+
+  /**
+   * A serialized list of the streams and partitions this container is responsible for. Encoded by
    * {@link org.apache.samza.util.Util#createStreamPartitionString}
    */
   val ENV_SYSTEM_STREAMS = "SAMZA_SYSTEM_STREAMS"
index 4ca340c..7fb4763 100644 (file)
@@ -20,8 +20,7 @@
 package org.apache.samza.container
 
 import grizzled.slf4j.Logging
-import org.apache.samza.Partition
-import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.{SystemStreamPartition, SystemConsumers}
 import org.apache.samza.task.ReadableCoordinator
 
 /**
@@ -34,7 +33,7 @@ import org.apache.samza.task.ReadableCoordinator
  * be done when.
  */
 class RunLoop(
-  val taskInstances: Map[Partition, TaskInstance],
+  val taskInstances: Map[TaskName, TaskInstance],
   val consumerMultiplexer: SystemConsumers,
   val metrics: SamzaContainerMetrics,
   val windowMs: Long = -1,
@@ -43,10 +42,18 @@ class RunLoop(
 
   private var lastWindowMs = 0L
   private var lastCommitMs = 0L
-  private var taskShutdownRequests: Set[Partition] = Set()
-  private var taskCommitRequests: Set[Partition] = Set()
+  private var taskShutdownRequests: Set[TaskName] = Set()
+  private var taskCommitRequests: Set[TaskName] = Set()
   private var shutdownNow = false
 
+  // Messages come from the chooser with no connection to the TaskInstance they're bound for.
+  // Keep a mapping of SystemStreamPartition to TaskInstance to efficiently route them.
+  val systemStreamPartitionToTaskInstance: Map[SystemStreamPartition, TaskInstance] = {
+    // We could just pass in the SystemStreamPartitionMap during construction, but it's safer and cleaner to derive the information directly
+    def getSystemStreamPartitionToTaskInstance(taskInstance: TaskInstance) = taskInstance.systemStreamPartitions.map(_ -> taskInstance).toMap
+
+    taskInstances.values.map{ getSystemStreamPartitionToTaskInstance }.flatten.toMap
+  }
 
   /**
    * Starts the run loop. Blocks until either the tasks request shutdown, or an
@@ -61,7 +68,6 @@ class RunLoop(
     }
   }
 
-
   /**
    * Chooses a message from an input stream to process, and calls the
    * process() method on the appropriate StreamTask to handle it.
@@ -73,13 +79,15 @@ class RunLoop(
     val envelope = consumerMultiplexer.choose
 
     if (envelope != null) {
-      val partition = envelope.getSystemStreamPartition.getPartition
+      val ssp = envelope.getSystemStreamPartition
 
-      trace("Processing incoming message envelope for partition %s." format partition)
+      trace("Processing incoming message envelope for SSP %s." format ssp)
       metrics.envelopes.inc
 
-      val coordinator = new ReadableCoordinator(partition)
-      taskInstances(partition).process(envelope, coordinator)
+      val taskInstance = systemStreamPartitionToTaskInstance(ssp)
+
+      val coordinator = new ReadableCoordinator(taskInstance.taskName)
+      taskInstance.process(envelope, coordinator)
       checkCoordinator(coordinator)
     } else {
       trace("No incoming message envelope was available.")
@@ -87,7 +95,6 @@ class RunLoop(
     }
   }
 
-
   /**
    * Invokes WindowableTask.window on all tasks if it's time to do so.
    */
@@ -97,8 +104,8 @@ class RunLoop(
       lastWindowMs = clock()
       metrics.windows.inc
 
-      taskInstances.foreach { case (partition, task) =>
-        val coordinator = new ReadableCoordinator(partition)
+      taskInstances.foreach { case (taskName, task) =>
+        val coordinator = new ReadableCoordinator(taskName)
         task.window(coordinator)
         checkCoordinator(coordinator)
       }
@@ -129,8 +136,8 @@ class RunLoop(
     } else if (!taskCommitRequests.isEmpty) {
       trace("Committing due to explicit commit request.")
       metrics.commits.inc
-      taskCommitRequests.foreach(partition => {
-        taskInstances(partition).commit
+      taskCommitRequests.foreach(taskName => {
+        taskInstances(taskName).commit
       })
     }
 
@@ -146,17 +153,17 @@ class RunLoop(
    */
   private def checkCoordinator(coordinator: ReadableCoordinator) {
     if (coordinator.requestedCommitTask) {
-      debug("Task %s requested commit for current task only" format coordinator.partition)
-      taskCommitRequests += coordinator.partition
+      debug("Task %s requested commit for current task only" format coordinator.taskName)
+      taskCommitRequests += coordinator.taskName
     }
 
     if (coordinator.requestedCommitAll) {
-      debug("Task %s requested commit for all tasks in the container" format coordinator.partition)
+      debug("Task %s requested commit for all tasks in the container" format coordinator.taskName)
       taskCommitRequests ++= taskInstances.keys
     }
 
     if (coordinator.requestedShutdownOnConsensus) {
-      taskShutdownRequests += coordinator.partition
+      taskShutdownRequests += coordinator.taskName
       info("Shutdown has now been requested by tasks: %s" format taskShutdownRequests)
     }
 
index a7142b2..d574ac4 100644 (file)
 
 package org.apache.samza.container
 
-import java.io.File
 import grizzled.slf4j.Logging
+import java.io.File
 import org.apache.samza.Partition
 import org.apache.samza.SamzaException
-import org.apache.samza.checkpoint.CheckpointManager
-import org.apache.samza.checkpoint.CheckpointManagerFactory
+import org.apache.samza.checkpoint.{CheckpointManagerFactory, OffsetManager}
 import org.apache.samza.config.Config
 import org.apache.samza.config.MetricsConfig.Config2Metrics
 import org.apache.samza.config.SerializerConfig.Config2Serializer
@@ -34,35 +33,34 @@ import org.apache.samza.config.StreamConfig.Config2Stream
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
 import org.apache.samza.config.serializers.JsonConfigSerializer
+import org.apache.samza.job.ShellCommandBuilder
 import org.apache.samza.metrics.JmxServer
 import org.apache.samza.metrics.JvmMetrics
+import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.metrics.MetricsReporter
 import org.apache.samza.metrics.MetricsReporterFactory
 import org.apache.samza.serializers.SerdeFactory
 import org.apache.samza.serializers.SerdeManager
 import org.apache.samza.storage.StorageEngineFactory
 import org.apache.samza.storage.TaskStorageManager
+import org.apache.samza.system.StreamMetadataCache
+import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemConsumersMetrics
 import org.apache.samza.system.SystemFactory
+import org.apache.samza.system.SystemProducers
+import org.apache.samza.system.SystemProducersMetrics
 import org.apache.samza.system.SystemStream
+import org.apache.samza.system.SystemStreamMetadata
 import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.system.chooser.DefaultChooser
+import org.apache.samza.system.chooser.MessageChooserFactory
+import org.apache.samza.system.chooser.RoundRobinChooserFactory
+import org.apache.samza.task.ReadableCollector
 import org.apache.samza.task.StreamTask
 import org.apache.samza.task.TaskLifecycleListener
 import org.apache.samza.task.TaskLifecycleListenerFactory
 import org.apache.samza.util.Util
-import org.apache.samza.system.SystemProducers
-import org.apache.samza.task.ReadableCollector
-import org.apache.samza.system.SystemConsumers
-import org.apache.samza.system.chooser.MessageChooserFactory
-import org.apache.samza.system.SystemProducersMetrics
-import org.apache.samza.system.SystemConsumersMetrics
-import org.apache.samza.metrics.MetricsRegistryMap
-import org.apache.samza.system.chooser.DefaultChooser
-import org.apache.samza.system.chooser.RoundRobinChooserFactory
 import scala.collection.JavaConversions._
-import org.apache.samza.system.SystemAdmin
-import org.apache.samza.system.SystemStreamMetadata
-import org.apache.samza.checkpoint.OffsetManager
-import org.apache.samza.system.StreamMetadataCache
 
 object SamzaContainer extends Logging {
 
@@ -92,29 +90,46 @@ object SamzaContainer extends Logging {
      * properties. Note: This is a temporary workaround to reduce the size of the config and hence size
      * of the environment variable(s) exported while starting a Samza container (SAMZA-337)
      */
-    val isCompressed = if (System.getenv(ShellCommandConfig.ENV_COMPRESS_CONFIG).equals("TRUE")) true else false
+    val isCompressed = System.getenv(ShellCommandConfig.ENV_COMPRESS_CONFIG).equals("TRUE")
     val configStr = getParameter(System.getenv(ShellCommandConfig.ENV_CONFIG), isCompressed)
     val config = JsonConfigSerializer.fromJson(configStr)
-    val encodedStreamsAndPartitions = getParameter(System.getenv(ShellCommandConfig.ENV_SYSTEM_STREAMS), isCompressed)
-    val partitions = Util.deserializeSSPSetFromJSON(encodedStreamsAndPartitions)
-
-    if (partitions.isEmpty) {
-      throw new SamzaException("No partitions for this task. Can't run a task without partition assignments. It's likely that the partition manager for this system doesn't know about the stream you're trying to read.")
-    }
+    val sspTaskNames = getTaskNameToSystemStreamPartition(getParameter(System.getenv(ShellCommandConfig.ENV_SYSTEM_STREAMS), isCompressed))
+    val taskNameToChangeLogPartitionMapping = getTaskNameToChangeLogPartitionMapping(getParameter(System.getenv(ShellCommandConfig.ENV_TASK_NAME_TO_CHANGELOG_PARTITION_MAPPING), isCompressed))
 
     try {
-      SamzaContainer(containerName, partitions, config).run
+      SamzaContainer(containerName, sspTaskNames, taskNameToChangeLogPartitionMapping, config).run
     } finally {
       jmxServer.stop
     }
   }
 
-  def apply(containerName: String, inputStreams: Set[SystemStreamPartition], config: Config) = {
+  def getTaskNameToSystemStreamPartition(SSPTaskNamesJSON: String) = {
+    // Covert into a standard Java map
+    val sspTaskNamesAsJava: Map[TaskName, Set[SystemStreamPartition]] = ShellCommandBuilder.deserializeSystemStreamPartitionSetFromJSON(SSPTaskNamesJSON)
+
+    // From that map build the TaskNamesToSystemStreamPartitions
+    val sspTaskNames = TaskNamesToSystemStreamPartitions(sspTaskNamesAsJava)
+
+    if (sspTaskNames.isEmpty) {
+      throw new SamzaException("No SystemStreamPartitions for this task. Can't run a task without SystemStreamPartition assignments.")
+    }
+
+    sspTaskNames
+  }
+
+  def getTaskNameToChangeLogPartitionMapping(taskNameToChangeLogPartitionMappingJSON: String) = {
+    // Convert that mapping into a Map
+    val taskNameToChangeLogPartitionMapping = ShellCommandBuilder.deserializeTaskNameToChangeLogPartitionMapping(taskNameToChangeLogPartitionMappingJSON).map(kv => kv._1 -> Integer.valueOf(kv._2))
+
+    taskNameToChangeLogPartitionMapping
+  }
+
+  def apply(containerName: String, sspTaskNames: TaskNamesToSystemStreamPartitions, taskNameToChangeLogPartitionMapping: Map[TaskName, java.lang.Integer], config: Config) = {
     val containerPID = Util.getContainerPID
 
     info("Setting up Samza container: %s" format containerName)
+    info("Using SystemStreamPartition taskNames %s" format sspTaskNames)
     info("Samza container PID: %s" format containerPID)
-    info("Using streams and partitions: %s" format inputStreams)
     info("Using configuration: %s" format config)
 
     val registry = new MetricsRegistryMap(containerName)
@@ -122,7 +137,7 @@ object SamzaContainer extends Logging {
     val systemProducersMetrics = new SystemProducersMetrics(registry)
     val systemConsumersMetrics = new SystemConsumersMetrics(registry)
 
-    val inputSystems = inputStreams.map(_.getSystem)
+    val inputSystems = sspTaskNames.getAllSystems()
 
     val systemNames = config.getSystemNames
 
@@ -150,7 +165,7 @@ object SamzaContainer extends Logging {
     info("Got system factories: %s" format systemFactories.keys)
 
     val streamMetadataCache = new StreamMetadataCache(systemAdmins)
-    val inputStreamMetadata = streamMetadataCache.getStreamMetadata(inputStreams.map(_.getSystemStream))
+    val inputStreamMetadata = streamMetadataCache.getStreamMetadata(sspTaskNames.getAllSystemStreams)
 
     info("Got input stream metadata: %s" format inputStreamMetadata)
 
@@ -219,7 +234,7 @@ object SamzaContainer extends Logging {
      * A Helper function to build a Map[SystemStream, Serde] for streams defined in the config. This is useful to build both key and message serde maps.
      */
     val buildSystemStreamSerdeMap = (getSerdeName: (SystemStream) => Option[String]) => {
-      (serdeStreams ++ inputStreams)
+      (serdeStreams ++ sspTaskNames.getAllSSPs())
         .filter(systemStream => getSerdeName(systemStream).isDefined)
         .map(systemStream => {
           val serdeName = getSerdeName(systemStream).get
@@ -384,20 +399,20 @@ object SamzaContainer extends Logging {
 
     info("Got commit milliseconds: %s" format taskCommitMs)
 
-    // Wire up all task-level (unshared) objects.
+    // Wire up all task-instance-level (unshared) objects.
 
-    val partitions = inputStreams.map(_.getPartition).toSet
+    val taskNames = sspTaskNames.keys.toSet
 
-    val containerContext = new SamzaContainerContext(containerName, config, partitions)
+    val containerContext = new SamzaContainerContext(containerName, config, taskNames)
 
-    val taskInstances = partitions.map(partition => {
-      debug("Setting up task instance: %s" format partition)
+    val taskInstances: Map[TaskName, TaskInstance] = taskNames.map(taskName => {
+      debug("Setting up task instance: %s" format taskName)
 
       val task = Util.getObj[StreamTask](taskClassName)
 
       val collector = new ReadableCollector
 
-      val taskInstanceMetrics = new TaskInstanceMetrics("Partition-%s" format partition.getPartitionId)
+      val taskInstanceMetrics = new TaskInstanceMetrics("TaskName-%s" format taskName)
 
       val storeConsumers = changeLogSystemStreams
         .map {
@@ -410,11 +425,13 @@ object SamzaContainer extends Logging {
 
       info("Got store consumers: %s" format storeConsumers)
 
+      val partitionForThisTaskName = new Partition(taskNameToChangeLogPartitionMapping(taskName))
+
       val taskStores = storageEngineFactories
         .map {
           case (storeName, storageEngineFactory) =>
             val changeLogSystemStreamPartition = if (changeLogSystemStreams.contains(storeName)) {
-              new SystemStreamPartition(changeLogSystemStreams(storeName), partition)
+              new SystemStreamPartition(changeLogSystemStreams(storeName), partitionForThisTaskName)
             } else {
               null
             }
@@ -426,7 +443,7 @@ object SamzaContainer extends Logging {
               case Some(msgSerde) => serdes(msgSerde)
               case _ => null
             }
-            val storePartitionDir = TaskStorageManager.getStorePartitionDir(storeBaseDir, storeName, partition)
+            val storePartitionDir = TaskStorageManager.getStorePartitionDir(storeBaseDir, storeName, taskName)
             val storageEngine = storageEngineFactory.getStorageEngine(
               storeName,
               storePartitionDir,
@@ -441,25 +458,26 @@ object SamzaContainer extends Logging {
 
       info("Got task stores: %s" format taskStores)
 
-      val changeLogOldestOffsets = getChangeLogOldestOffsetsForPartition(partition, changeLogMetadata)
+      val changeLogOldestOffsets = getChangeLogOldestOffsetsForPartition(partitionForThisTaskName, changeLogMetadata)
 
-      info("Assigning oldest change log offsets for partition %s: %s" format (partition, changeLogOldestOffsets))
+      info("Assigning oldest change log offsets for taskName %s: %s" format (taskName, changeLogOldestOffsets))
 
       val storageManager = new TaskStorageManager(
-        partition = partition,
+        taskName = taskName,
         taskStores = taskStores,
         storeConsumers = storeConsumers,
         changeLogSystemStreams = changeLogSystemStreams,
         changeLogOldestOffsets = changeLogOldestOffsets,
-        storeBaseDir = storeBaseDir)
+        storeBaseDir = storeBaseDir,
+        partitionForThisTaskName)
 
-      val inputStreamsForThisPartition = inputStreams.filter(_.getPartition.equals(partition)).map(_.getSystemStream)
+      val systemStreamPartitions: Set[SystemStreamPartition] = sspTaskNames.getOrElse(taskName, throw new SamzaException("Can't find taskName " + taskName + " in map of SystemStreamPartitions: " + sspTaskNames))
 
-      info("Assigning SystemStreams " + inputStreamsForThisPartition + " to " + partition)
+      info("Retrieved SystemStreamPartitions " + systemStreamPartitions + " for " + taskName)
 
       val taskInstance = new TaskInstance(
         task = task,
-        partition = partition,
+        taskName = taskName,
         config = config,
         metrics = taskInstanceMetrics,
         consumerMultiplexer = consumerMultiplexer,
@@ -468,10 +486,10 @@ object SamzaContainer extends Logging {
         storageManager = storageManager,
         reporters = reporters,
         listeners = listeners,
-        inputStreams = inputStreamsForThisPartition,
+        systemStreamPartitions = systemStreamPartitions,
         collector = collector)
 
-      (partition, taskInstance)
+      (taskName, taskInstance)
     }).toMap
 
     val runLoop = new RunLoop(
@@ -506,7 +524,7 @@ object SamzaContainer extends Logging {
 }
 
 class SamzaContainer(
-  taskInstances: Map[Partition, TaskInstance],
+  taskInstances: Map[TaskName, TaskInstance],
   runLoop: RunLoop,
   consumerMultiplexer: SystemConsumers,
   producerMultiplexer: SystemProducers,
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SystemStreamPartitionTaskNameGrouper.scala b/samza-core/src/main/scala/org/apache/samza/container/SystemStreamPartitionTaskNameGrouper.scala
new file mode 100644 (file)
index 0000000..a8c93ac
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container
+
+/**
+ * After the input SystemStreamPartitions have been mapped to their TaskNames by an implementation of
+ * {@link org.apache.samza.container.SystemStreamPartitionGrouper}, we can then map those groupings onto
+ * the {@link org.apache.samza.container.SamzaContainer}s on which they will run.  This class takes
+ * those groupings-of-SSPs and groups them together on which container each should run on.  A simple
+ * implementation could assign each TaskNamesToSystemStreamPartition to a separate container.  More
+ * advanced implementations could examine the TaskNamesToSystemStreamPartition to group by them
+ * by data locality, anti-affinity, even distribution of expected bandwidth consumption, etc.
+ */
+trait SystemStreamPartitionTaskNameGrouper {
+  /**
+   * Group TaskNamesToSystemStreamPartitions onto the containers they will share
+   *
+   * @param taskNames Pre-grouped SSPs
+   * @return Mapping of container ID to set if TaskNames it will run
+   */
+  def groupTaskNames(taskNames: TaskNamesToSystemStreamPartitions): Map[Int, TaskNamesToSystemStreamPartitions]
+}
index 99a9841..9484ddb 100644 (file)
@@ -21,33 +21,27 @@ package org.apache.samza.container
 
 import org.apache.samza.metrics.MetricsReporter
 import org.apache.samza.config.Config
-import org.apache.samza.Partition
 import grizzled.slf4j.Logging
-import scala.collection.JavaConversions._
 import org.apache.samza.storage.TaskStorageManager
-import org.apache.samza.config.StreamConfig.Config2Stream
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.task.TaskContext
 import org.apache.samza.task.ClosableTask
 import org.apache.samza.task.InitableTask
 import org.apache.samza.system.IncomingMessageEnvelope
 import org.apache.samza.task.WindowableTask
-import org.apache.samza.checkpoint.CheckpointManager
 import org.apache.samza.task.TaskLifecycleListener
 import org.apache.samza.task.StreamTask
-import org.apache.samza.system.SystemStream
-import org.apache.samza.checkpoint.Checkpoint
 import org.apache.samza.task.ReadableCollector
 import org.apache.samza.system.SystemConsumers
 import org.apache.samza.system.SystemProducers
 import org.apache.samza.task.ReadableCoordinator
-import org.apache.samza.metrics.Gauge
 import org.apache.samza.checkpoint.OffsetManager
 import org.apache.samza.SamzaException
+import scala.collection.JavaConversions._
 
 class TaskInstance(
   task: StreamTask,
-  partition: Partition,
+  val taskName: TaskName,
   config: Config,
   metrics: TaskInstanceMetrics,
   consumerMultiplexer: SystemConsumers,
@@ -56,15 +50,14 @@ class TaskInstance(
   storageManager: TaskStorageManager = null,
   reporters: Map[String, MetricsReporter] = Map(),
   listeners: Seq[TaskLifecycleListener] = Seq(),
-  inputStreams: Set[SystemStream] = Set(),
+  val systemStreamPartitions: Set[SystemStreamPartition] = Set(),
   collector: ReadableCollector = new ReadableCollector) extends Logging {
-
   val isInitableTask = task.isInstanceOf[InitableTask]
   val isWindowableTask = task.isInstanceOf[WindowableTask]
   val isClosableTask = task.isInstanceOf[ClosableTask]
   val context = new TaskContext {
     def getMetricsRegistry = metrics.registry
-    def getPartition = partition
+    def getSystemStreamPartitions = systemStreamPartitions
     def getStore(storeName: String) = if (storageManager != null) {
       storageManager(storeName)
     } else {
@@ -72,29 +65,28 @@ class TaskInstance(
 
       null
     }
+    def getTaskName = taskName
   }
 
   def registerMetrics {
-    debug("Registering metrics for partition: %s." format partition)
+    debug("Registering metrics for taskName: %s" format taskName)
 
     reporters.values.foreach(_.register(metrics.source, metrics.registry))
   }
 
   def registerOffsets {
-    debug("Registering offsets for partition: %s." format partition)
+    debug("Registering offsets for taskName: %s" format taskName)
 
-    inputStreams.foreach(systemStream => {
-      offsetManager.register(new SystemStreamPartition(systemStream, partition))
-    })
+    offsetManager.register(taskName, systemStreamPartitions)
   }
 
   def startStores {
     if (storageManager != null) {
-      debug("Starting storage manager for partition: %s." format partition)
+      debug("Starting storage manager for taskName: %s" format taskName)
 
       storageManager.init
     } else {
-      debug("Skipping storage manager initialization for partition: %s." format partition)
+      debug("Skipping storage manager initialization for taskName: %s" format taskName)
     }
   }
 
@@ -102,31 +94,30 @@ class TaskInstance(
     listeners.foreach(_.beforeInit(config, context))
 
     if (isInitableTask) {
-      debug("Initializing task for partition: %s." format partition)
+      debug("Initializing task for taskName: %s" format taskName)
 
       task.asInstanceOf[InitableTask].init(config, context)
     } else {
-      debug("Skipping task initialization for partition: %s." format partition)
+      debug("Skipping task initialization for taskName: %s" format taskName)
     }
 
     listeners.foreach(_.afterInit(config, context))
   }
 
   def registerProducers {
-    debug("Registering producers for partition: %s." format partition)
+    debug("Registering producers for taskName: %s" format taskName)
 
     producerMultiplexer.register(metrics.source)
   }
 
   def registerConsumers {
-    debug("Registering consumers for partition: %s." format partition)
+    debug("Registering consumers for taskName: %s" format taskName)
 
-    inputStreams.foreach(systemStream => {
-      val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
+    systemStreamPartitions.foreach(systemStreamPartition => {
       val offset = offsetManager.getStartingOffset(systemStreamPartition)
-        .getOrElse(throw new SamzaException("No offset defined for partition %s: %s" format (partition, systemStream)))
+        .getOrElse(throw new SamzaException("No offset defined for SystemStreamPartition: %s" format systemStreamPartition))
       consumerMultiplexer.register(systemStreamPartition, offset)
-      metrics.addOffsetGauge(systemStream, () => {
+      metrics.addOffsetGauge(systemStreamPartition, () => {
         offsetManager
           .getLastProcessedOffset(systemStreamPartition)
           .getOrElse(null)
@@ -139,20 +130,20 @@ class TaskInstance(
 
     listeners.foreach(_.beforeProcess(envelope, config, context))
 
-    trace("Processing incoming message envelope for partition: %s, %s" format (partition, envelope.getSystemStreamPartition))
+    trace("Processing incoming message envelope for taskName and SSP: %s, %s" format (taskName, envelope.getSystemStreamPartition))
 
     task.process(envelope, collector, coordinator)
 
     listeners.foreach(_.afterProcess(envelope, config, context))
 
-    trace("Updating offset map for partition: %s, %s, %s" format (partition, envelope.getSystemStreamPartition, envelope.getOffset))
+    trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
 
     offsetManager.update(envelope.getSystemStreamPartition, envelope.getOffset)
   }
 
   def window(coordinator: ReadableCoordinator) {
     if (isWindowableTask) {
-      trace("Windowing for partition: %s" format partition)
+      trace("Windowing for taskName: %s" format taskName)
 
       metrics.windows.inc
 
@@ -162,48 +153,48 @@ class TaskInstance(
 
   def send {
     if (collector.envelopes.size > 0) {
-      trace("Sending messages for partition: %s, %s" format (partition, collector.envelopes.size))
+      trace("Sending messages for taskName: %s, %s" format (taskName, collector.envelopes.size))
 
       metrics.sends.inc
       metrics.messagesSent.inc(collector.envelopes.size)
 
       collector.envelopes.foreach(envelope => producerMultiplexer.send(metrics.source, envelope))
 
-      trace("Resetting collector for partition: %s" format partition)
+      trace("Resetting collector for taskName: %s" format taskName)
 
       collector.reset
     } else {
-      trace("Skipping send for partition %s because no messages were collected." format partition)
+      trace("Skipping send for taskName %s because no messages were collected." format taskName)
 
       metrics.sendsSkipped.inc
     }
   }
 
   def commit {
-    trace("Flushing state stores for partition: %s" format partition)
+    trace("Flushing state stores for taskName: %s" format taskName)
 
     metrics.commits.inc
 
     storageManager.flush
 
-    trace("Flushing producers for partition: %s" format partition)
+    trace("Flushing producers for taskName: %s" format taskName)
 
     producerMultiplexer.flush(metrics.source)
 
-    trace("Committing offset manager for partition: %s" format partition)
+    trace("Committing offset manager for taskName: %s" format taskName)
 
-    offsetManager.checkpoint(partition)
+    offsetManager.checkpoint(taskName)
   }
 
   def shutdownTask {
     listeners.foreach(_.beforeClose(config, context))
 
     if (task.isInstanceOf[ClosableTask]) {
-      debug("Shutting down stream task for partition: %s" format partition)
+      debug("Shutting down stream task for taskName: %s" format taskName)
 
       task.asInstanceOf[ClosableTask].close
     } else {
-      debug("Skipping stream task shutdown for partition: %s" format partition)
+      debug("Skipping stream task shutdown for taskName: %s" format taskName)
     }
 
     listeners.foreach(_.afterClose(config, context))
@@ -211,16 +202,16 @@ class TaskInstance(
 
   def shutdownStores {
     if (storageManager != null) {
-      debug("Shutting down storage manager for partition: %s" format partition)
+      debug("Shutting down storage manager for taskName: %s" format taskName)
 
       storageManager.stop
     } else {
-      debug("Skipping storage manager shutdown for partition: %s" format partition)
+      debug("Skipping storage manager shutdown for taskName: %s" format taskName)
     }
   }
 
-  override def toString() = "TaskInstance for class %s and partition %s." format (task.getClass.getName, partition)
+  override def toString() = "TaskInstance for class %s and taskName %s." format (task.getClass.getName, taskName)
 
-  def toDetailedString() = "TaskInstance [windowable=%s, closable=%s, collector_size=%s]" format (isWindowableTask, isClosableTask, collector.envelopes.size)
+  def toDetailedString() = "TaskInstance [taskName = %s, windowable=%s, closable=%s, collector_size=%s]" format (taskName, isWindowableTask, isClosableTask, collector.envelopes.size)
 
 }
index 7502124..aae3f87 100644 (file)
@@ -21,10 +21,8 @@ package org.apache.samza.container
 
 import org.apache.samza.metrics.ReadableMetricsRegistry
 import org.apache.samza.metrics.MetricsRegistryMap
-import org.apache.samza.Partition
 import org.apache.samza.metrics.MetricsHelper
-import org.apache.samza.system.SystemStream
-import org.apache.samza.metrics.Gauge
+import org.apache.samza.system.SystemStreamPartition
 
 class TaskInstanceMetrics(
   val source: String = "unknown",
@@ -37,7 +35,7 @@ class TaskInstanceMetrics(
   val sendsSkipped = newCounter("send-skipped")
   val messagesSent = newCounter("messages-sent")
 
-  def addOffsetGauge(systemStream: SystemStream, getValue: () => String) {
-    newGauge("%s-%s-offset" format (systemStream.getSystem, systemStream.getStream), getValue)
+  def addOffsetGauge(systemStreamPartition: SystemStreamPartition, getValue: () => String) {
+    newGauge("%s-%s-%d-offset" format (systemStreamPartition.getSystem, systemStreamPartition.getStream, systemStreamPartition.getPartition.getPartitionId), getValue)
   }
 }
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskNamesToSystemStreamPartitions.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskNamesToSystemStreamPartitions.scala
new file mode 100644 (file)
index 0000000..427119e
--- /dev/null
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container
+
+import grizzled.slf4j.Logging
+import org.apache.samza.SamzaException
+import org.apache.samza.system.{SystemStream, SystemStreamPartition}
+import scala.collection.{immutable, Map, MapLike}
+
+/**
+ * Map of {@link TaskName} to its set of {@link SystemStreamPartition}s with additional methods for aggregating
+ * those SystemStreamPartitions' individual system, streams and partitions.  Is useful for highlighting this
+ * particular, heavily used map within the code.
+ *
+ * @param m Original map of TaskNames to SystemStreamPartitions
+ */
+class TaskNamesToSystemStreamPartitions(m:Map[TaskName, Set[SystemStreamPartition]] = Map[TaskName, Set[SystemStreamPartition]]())
+  extends Map[TaskName, Set[SystemStreamPartition]]
+  with MapLike[TaskName, Set[SystemStreamPartition], TaskNamesToSystemStreamPartitions] with Logging {
+
+  // Constructor
+  validate
+
+  // Methods
+
+  // TODO: Get rid of public constructor, rely entirely on the companion object
+  override def -(key: TaskName): TaskNamesToSystemStreamPartitions = new TaskNamesToSystemStreamPartitions(m - key)
+
+  override def +[B1 >: Set[SystemStreamPartition]](kv: (TaskName, B1)): Map[TaskName, B1] = new TaskNamesToSystemStreamPartitions(m + kv.asInstanceOf[(TaskName, Set[SystemStreamPartition])])
+
+  override def iterator: Iterator[(TaskName, Set[SystemStreamPartition])] = m.iterator
+
+  override def get(key: TaskName): Option[Set[SystemStreamPartition]] = m.get(key)
+
+  override def empty: TaskNamesToSystemStreamPartitions = new TaskNamesToSystemStreamPartitions()
+
+  override def seq: Map[TaskName, Set[SystemStreamPartition]] = m.seq
+
+  override def foreach[U](f: ((TaskName, Set[SystemStreamPartition])) => U): Unit = m.foreach(f)
+
+  override def size: Int = m.size
+
+  /**
+   * Validate that this is a legal mapping of TaskNames to SystemStreamPartitions.  At the moment,
+   * we only check that an SSP is included in the mapping at most once.  We could add other,
+   * pluggable validations here, or if we decided to allow an SSP to appear in the mapping more than
+   * once, remove this limitation.
+   */
+  def validate():Unit = {
+    // Convert sets of SSPs to lists, to preserve duplicates
+    val allSSPs: List[SystemStreamPartition] = m.values.toList.map(_.toList).flatten
+    val sspCountMap = allSSPs.groupBy(ssp => ssp)  // Group all the SSPs together
+      .map(ssp => (ssp._1 -> ssp._2.size))         // Turn into map -> count of that SSP
+      .filter(ssp => ssp._2 != 1)                  // Filter out those that appear once
+
+    if(!sspCountMap.isEmpty) {
+      throw new SamzaException("Assigning the same SystemStreamPartition to multiple TaskNames is not currently supported." +
+        "  Out of compliance SystemStreamPartitions and counts: " + sspCountMap)
+    }
+
+    debug("Successfully validated TaskName to SystemStreamPartition set mapping:" + m)
+  }
+
+  /**
+   * Return a set of all the SystemStreamPartitions for all the keys.
+   *
+   * @return All SystemStreamPartitions within this map
+   */
+  def getAllSSPs(): Iterable[SystemStreamPartition] = m.values.flatten
+
+  /**
+   * Return a set of all the Systems presents in the SystemStreamPartitions across all the keys
+   *
+   * @return All Systems within this map
+   */
+  def getAllSystems(): Set[String] = getAllSSPs.map(_.getSystemStream.getSystem).toSet
+
+  /**
+   * Return a set of all the Partition IDs in the SystemStreamPartitions across all the keys
+   *
+   * @return All Partition IDs within this map
+   */
+  def getAllPartitionIds(): Set[Int] = getAllSSPs.map(_.getPartition.getPartitionId).toSet
+
+  /**
+   * Return a set of all the Streams in the SystemStreamPartitions across all the keys
+   *
+   * @return All Streams within this map
+   */
+  def getAllStreams(): Set[String] = getAllSSPs.map(_.getSystemStream.getStream).toSet
+
+  /**
+   * Return a set of all the SystemStreams in the SystemStreamPartitions across all the keys
+   *
+   * @return All SystemStreams within this map
+   */
+  def getAllSystemStreams: Set[SystemStream] = getAllSSPs().map(_.getSystemStream).toSet
+
+  // CommandBuilder needs to get a copy of this map and is a Java interface, therefore we can't just go straight
+  // from this type to JSON (for passing into the command option.
+  // Not super crazy about having the Java -> Scala and Scala -> Java methods in two different (but close) places:
+  // here and in the apply method on the companion object.  May be better to just have a conversion util, but would
+  // be less clean.  Life is cruel on the border of Scalapolis and Javatown.
+  def getJavaFriendlyType: java.util.Map[TaskName, java.util.Set[SystemStreamPartition]] = {
+    import scala.collection.JavaConverters._
+
+    m.map({case(k,v) => k -> v.asJava}).toMap.asJava
+  }
+}
+
+object TaskNamesToSystemStreamPartitions {
+  def apply() = new TaskNamesToSystemStreamPartitions()
+
+  def apply(m: Map[TaskName, Set[SystemStreamPartition]]) = new TaskNamesToSystemStreamPartitions(m)
+
+  /**
+   * Convert from Java-happy type we obtain from the SSPTaskName factory
+   *
+   * @param m Java version of a map of sets of strings
+   * @return Populated SSPTaskName map
+   */
+  def apply(m: java.util.Map[TaskName, java.util.Set[SystemStreamPartition]]) = {
+    import scala.collection.JavaConversions._
+
+    val rightType: immutable.Map[TaskName, Set[SystemStreamPartition]] = m.map({case(k,v) => k -> v.toSet}).toMap
+
+    new TaskNamesToSystemStreamPartitions(rightType)
+  }
+}
diff --git a/samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/groupers/GroupByPartition.scala b/samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/groupers/GroupByPartition.scala
new file mode 100644 (file)
index 0000000..223862f
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.systemstreampartition.groupers
+
+import org.apache.samza.container.{TaskName, SystemStreamPartitionGrouperFactory, SystemStreamPartitionGrouper}
+import java.util
+import org.apache.samza.system.SystemStreamPartition
+import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
+import org.apache.samza.config.Config
+
+/**
+ * Group the {@link org.apache.samza.system.SystemStreamPartition}s by their Partition, with the key being
+ * the string representation of the Partition.
+ */
+class GroupByPartition extends SystemStreamPartitionGrouper {
+  override def group(ssps: util.Set[SystemStreamPartition]) = {
+    ssps.groupBy( s => new TaskName("Partition " + s.getPartition.getPartitionId) )
+      .map(r => r._1 -> r._2.asJava)
+  }
+}
+
+class GroupByPartitionFactory extends SystemStreamPartitionGrouperFactory {
+  override def getSystemStreamPartitionGrouper(config: Config): SystemStreamPartitionGrouper = new GroupByPartition
+}
diff --git a/samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/groupers/GroupBySystemStreamPartition.scala b/samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/groupers/GroupBySystemStreamPartition.scala
new file mode 100644 (file)
index 0000000..a2bcfee
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.systemstreampartition.groupers
+
+import org.apache.samza.container.{TaskName, SystemStreamPartitionGrouperFactory, SystemStreamPartitionGrouper}
+import java.util
+import org.apache.samza.system.SystemStreamPartition
+import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
+import org.apache.samza.config.Config
+
+/**
+ * Group the {@link org.apache.samza.system.SystemStreamPartition}s by themselves, effectively putting each
+ * SystemStreamPartition into its own group, with the key being the string representation of the SystemStringPartition
+ */
+class GroupBySystemStreamPartition extends SystemStreamPartitionGrouper {
+  override def group(ssps: util.Set[SystemStreamPartition]) = ssps.groupBy({ s=> new TaskName(s.toString)}).map(r => r._1 -> r._2.asJava)
+}
+
+class GroupBySystemStreamPartitionFactory extends SystemStreamPartitionGrouperFactory {
+  override def getSystemStreamPartitionGrouper(config: Config): SystemStreamPartitionGrouper = new GroupBySystemStreamPartition
+}
diff --git a/samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/taskname/groupers/SimpleSystemStreamPartitionTaskNameGrouper.scala b/samza-core/src/main/scala/org/apache/samza/container/systemstreampartition/taskname/groupers/SimpleSystemStreamPartitionTaskNameGrouper.scala
new file mode 100644 (file)
index 0000000..7913294
--- /dev/null
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.systemstreampartition.taskname.groupers
+
+import org.apache.samza.container.{TaskName, SystemStreamPartitionTaskNameGrouper, TaskNamesToSystemStreamPartitions}
+import org.apache.samza.system.SystemStreamPartition
+
+/**
+ * Group the SSP taskNames by dividing the number of taskNames into the number of containers (n) and assigning n taskNames
+ * to each container as returned by iterating over the keys in the map of taskNames (whatever that ordering happens to be).
+ * No consideration is given towards locality, even distribution of aggregate SSPs within a container, even distribution
+ * of the number of taskNames between containers, etc.
+ */
+class SimpleSystemStreamPartitionTaskNameGrouper(numContainers:Int) extends SystemStreamPartitionTaskNameGrouper {
+  require(numContainers > 0, "Must have at least one container")
+
+  override def groupTaskNames(taskNames: TaskNamesToSystemStreamPartitions): Map[Int, TaskNamesToSystemStreamPartitions] = {
+    val keySize = taskNames.keySet.size
+    require(keySize > 0, "Must have some SSPs to group, but found none")
+
+    // Iterate through the taskNames, round-robining them per container
+    val byContainerNum = (0 until numContainers).map(_ -> scala.collection.mutable.Map[TaskName, Set[SystemStreamPartition]]()).toMap
+    var idx = 0
+    for(taskName <- taskNames.iterator) {
+      val currMap = byContainerNum.get(idx).get // safe to use simple get since we populated everybody above
+      idx = (idx + 1) % numContainers
+
+      currMap += taskName
+    }
+
+    byContainerNum.map(kv => kv._1 -> TaskNamesToSystemStreamPartitions(kv._2)).toMap
+  }
+}
+
index 4635bb2..8c5533c 100644 (file)
 
 package org.apache.samza.job
 
-import scala.collection.JavaConversions._
+import java.util
+import org.apache.samza.Partition
 import org.apache.samza.config.ShellCommandConfig
 import org.apache.samza.config.ShellCommandConfig.Config2ShellCommand
 import org.apache.samza.config.serializers.JsonConfigSerializer
+import org.apache.samza.container.TaskName
+import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.util.Util
+import org.codehaus.jackson.`type`.TypeReference
+import org.codehaus.jackson.map.ObjectMapper
+import scala.collection.JavaConversions._
+import scala.reflect.BeanProperty
+
+object ShellCommandBuilder {
+  /**
+   * Jackson really hates Scala's classes, so we need to wrap up the SSP in a form Jackson will take
+   */
+  private class SSPWrapper(@BeanProperty var partition:java.lang.Integer = null,
+                           @BeanProperty var Stream:java.lang.String = null,
+                           @BeanProperty var System:java.lang.String = null) {
+    def this() { this(null, null, null) }
+    def this(ssp:SystemStreamPartition) { this(ssp.getPartition.getPartitionId, ssp.getSystemStream.getStream, ssp.getSystemStream.getSystem)}
+  }
+
+  def serializeSystemStreamPartitionSetToJSON(sspTaskNames: java.util.Map[TaskName,java.util.Set[SystemStreamPartition]]): String = {
+    val map = new util.HashMap[TaskName, util.ArrayList[SSPWrapper]]()
+    for((key, ssps) <- sspTaskNames) {
+      val al = new util.ArrayList[SSPWrapper](ssps.size)
+      for(ssp <- ssps) { al.add(new SSPWrapper(ssp)) }
+      map.put(key, al)
+    }
+
+    new ObjectMapper().writeValueAsString(map)
+  }
+
+  def deserializeSystemStreamPartitionSetFromJSON(sspsAsJSON: String): Map[TaskName, Set[SystemStreamPartition]] = {
+    val om = new ObjectMapper()
+
+    val asWrapper = om.readValue(sspsAsJSON, new TypeReference[util.HashMap[String, util.ArrayList[SSPWrapper]]]() { }).asInstanceOf[util.HashMap[String, util.ArrayList[SSPWrapper]]]
+
+    val taskName = for((key, sspsWrappers) <- asWrapper;
+                       taskName = new TaskName(key);
+                       ssps = sspsWrappers.map(w => new SystemStreamPartition(w.getSystem, w.getStream, new Partition(w.getPartition))).toSet
+    ) yield(taskName -> ssps)
 
+    taskName.toMap // to get an immutable map rather than mutable...
+  }
+
+  def serializeTaskNameToChangeLogPartitionMapping(mapping:Map[TaskName, Int]) = {
+    val javaMap = new util.HashMap[TaskName, java.lang.Integer]()
+
+    mapping.foreach(kv => javaMap.put(kv._1, Integer.valueOf(kv._2)))
+
+    new ObjectMapper().writeValueAsString(javaMap)
+  }
+
+  def deserializeTaskNameToChangeLogPartitionMapping(taskNamesAsJSON: String): Map[TaskName, Int] = {
+    val om = new ObjectMapper()
+
+    val asMap = om.readValue(taskNamesAsJSON, new TypeReference[util.HashMap[String, java.lang.Integer]] {}).asInstanceOf[util.HashMap[String, java.lang.Integer]]
+
+    asMap.map(kv => new TaskName(kv._1) -> kv._2.intValue()).toMap
+  }
+}
 class ShellCommandBuilder extends CommandBuilder {
   def buildCommand() = config.getCommand
 
   def buildEnvironment(): java.util.Map[String, String] = {
-    var streamsAndPartsString = Util.serializeSSPSetToJSON(systemStreamPartitions.toSet) // Java to Scala set conversion
+    var streamsAndPartsString = ShellCommandBuilder.serializeSystemStreamPartitionSetToJSON(taskNameToSystemStreamPartitionsMapping) // Java to Scala set conversion
+    var taskNameToChangeLogPartitionMappingString = ShellCommandBuilder.serializeTaskNameToChangeLogPartitionMapping(taskNameToChangeLogPartitionMapping.map(kv => kv._1 -> kv._2.toInt).toMap)
     var envConfig = JsonConfigSerializer.toJson(config)
     val isCompressed = if(config.isEnvConfigCompressed) "TRUE" else "FALSE"
 
@@ -40,14 +99,17 @@ class ShellCommandBuilder extends CommandBuilder {
        * of the environment variable(s) exported while starting a Samza container (SAMZA-337)
        */
       streamsAndPartsString = Util.compress(streamsAndPartsString)
+      taskNameToChangeLogPartitionMappingString = Util.compress(taskNameToChangeLogPartitionMappingString)
       envConfig = Util.compress(envConfig)
     }
 
     Map(
       ShellCommandConfig.ENV_CONTAINER_NAME -> name,
       ShellCommandConfig.ENV_SYSTEM_STREAMS -> streamsAndPartsString,
+      ShellCommandConfig.ENV_TASK_NAME_TO_CHANGELOG_PARTITION_MAPPING -> taskNameToChangeLogPartitionMappingString,
       ShellCommandConfig.ENV_CONFIG -> envConfig,
       ShellCommandConfig.ENV_JAVA_OPTS -> config.getTaskOpts.getOrElse(""),
       ShellCommandConfig.ENV_COMPRESS_CONFIG -> isCompressed)
+
   }
 }
index e20e7c1..713bded 100644 (file)
@@ -24,24 +24,32 @@ import org.apache.samza.config.ShellCommandConfig._
 import org.apache.samza.job.CommandBuilder
 import org.apache.samza.job.StreamJob
 import org.apache.samza.job.StreamJobFactory
-import scala.collection.JavaConversions._
 import grizzled.slf4j.Logging
 import org.apache.samza.SamzaException
-import org.apache.samza.container.SamzaContainer
+import org.apache.samza.container.{TaskNamesToSystemStreamPartitions, SamzaContainer}
 import org.apache.samza.util.Util
 import org.apache.samza.job.ShellCommandBuilder
+import scala.collection.JavaConversions._
 
 class LocalJobFactory extends StreamJobFactory with Logging {
   def getJob(config: Config): StreamJob = {
-    val taskName = "local-task"
-    val partitions = Util.getInputStreamPartitions(config)
+    val jobName = "local-container"
 
-    info("got partitions for job %s" format partitions)
+    // Since we're local, there will only be a single task into which all the SSPs will be processed
+    val taskToTaskNames: Map[Int, TaskNamesToSystemStreamPartitions] = Util.assignContainerToSSPTaskNames(config, 1)
+    if(taskToTaskNames.size != 1) {
+      throw new SamzaException("Should only have a single task count but somehow got more " + taskToTaskNames.size)
+    }
 
-    if (partitions.size <= 0) {
-      throw new SamzaException("No partitions were detected for your input streams. It's likely that the system(s) specified don't know about the input streams: %s" format config.getInputStreams)
+    // So pull out that single TaskNamesToSystemStreamPartitions
+    val sspTaskName: TaskNamesToSystemStreamPartitions = taskToTaskNames.getOrElse(0, throw new SamzaException("Should have a 0 task number for the SSPs but somehow do not: " + taskToTaskNames))
+    if (sspTaskName.size <= 0) {
+      throw new SamzaException("No SystemStreamPartitions to process were detected for your input streams. It's likely that the system(s) specified don't know about the input streams: %s" format config.getInputStreams)
     }
 
+    val taskNameToChangeLogPartitionMapping = Util.getTaskNameToChangeLogPartitionMapping(config, taskToTaskNames).map(kv => kv._1 -> Integer.valueOf(kv._2))
+    info("got taskName for job %s" format sspTaskName)
+
     config.getCommandClass match {
       case Some(cmdBuilderClassName) => {
         // A command class was specified, so we need to use a process job to
@@ -50,8 +58,8 @@ class LocalJobFactory extends StreamJobFactory with Logging {
         
         cmdBuilder
           .setConfig(config)
-          .setName(taskName)
-          .setStreamPartitions(partitions)
+          .setName(jobName)
+          .setTaskNameToSystemStreamPartitionsMapping(sspTaskName.getJavaFriendlyType)
 
         val processBuilder = new ProcessBuilder(cmdBuilder.buildCommand.split(" ").toList)
 
@@ -72,7 +80,7 @@ class LocalJobFactory extends StreamJobFactory with Logging {
 
         // No command class was specified, so execute the job in this process
         // using a threaded job.
-        new ThreadJob(SamzaContainer(taskName, partitions, config))
+        new ThreadJob(SamzaContainer(jobName, sspTaskName, taskNameToChangeLogPartitionMapping, config))
       }
     }
   }
index 3d0a484..34c846c 100644 (file)
 
 package org.apache.samza.serializers
 
-import scala.collection.JavaConversions._
-import org.codehaus.jackson.map.ObjectMapper
-import org.apache.samza.system.SystemStream
-import org.apache.samza.checkpoint.Checkpoint
-import org.apache.samza.SamzaException
 import grizzled.slf4j.Logging
+import java.util
+import org.apache.samza.checkpoint.Checkpoint
+import org.apache.samza.container.TaskName
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.{SamzaException, Partition}
+import org.codehaus.jackson.map.ObjectMapper
+import scala.collection.JavaConversions._
+import org.codehaus.jackson.`type`.TypeReference
 
+/**
+ * Write out the Checkpoint object in JSON.  The underlying map of SSP => Offset cannot be stored directly because
+ * JSON only allows strings as map types, so we would need to separately serialize the SSP to a string that doesn't
+ * then interfere with JSON's decoding of the overall map.  We'll sidestep the whole issue by turning the
+ * map into a list[String] of (System, Stream, Partition, Offset) serializing that.
+ */
 class CheckpointSerde extends Serde[Checkpoint] with Logging {
+  import CheckpointSerde._
+  // TODO: Elucidate the CheckpointSerde relationshiop to Serde. Should Serde also have keyTo/FromBytes? Should
+  // we just take CheckpointSerde here as interface and have this be JSONCheckpointSerde?
+  // TODO: Add more tests.  This class currently only has direct test and is mainly tested by the other checkpoint managers
   val jsonMapper = new ObjectMapper()
 
+  // Jackson absolutely hates Scala types and hidden conversions hate you, so we're going to be very, very
+  // explicit about the Java (not Scala) types used here and never let Scala get its grubby little hands
+  // on any instance.
+
+  // Store checkpoint as maps keyed of the SSP.toString to the another map of the constituent SSP components
+  // and offset.  Jackson can't automatically serialize the SSP since it's not a POJO and this avoids
+  // having to wrap it another class while maintaing readability.
+
   def fromBytes(bytes: Array[Byte]): Checkpoint = {
     try {
-      val checkpointMap = jsonMapper
-        .readValue(bytes, classOf[java.util.Map[String, java.util.Map[String, String]]])
-        .flatMap {
-          case (systemName, streamToOffsetMap) =>
-            streamToOffsetMap.map { case (streamName, offset) => (new SystemStream(systemName, streamName), offset) }
-        }
-      return new Checkpoint(checkpointMap)
-    } catch {
+      val jMap = jsonMapper.readValue(bytes, classOf[util.HashMap[String, util.HashMap[String, String]]])
+
+      def deserializeJSONMap(m:util.HashMap[String, String]) = {
+        require(m.size() == 4, "All JSON-encoded SystemStreamPartitions must have four keys")
+        val system = m.get("system")
+        require(system != null, "System must be present in JSON-encoded SystemStreamPartition")
+        val stream = m.get("stream")
+        require(stream != null, "Stream must be present in JSON-encoded SystemStreamPartition")
+        val partition = m.get("partition")
+        require(partition != null, "Partition must be present in JSON-encoded SystemStreamPartition")
+        val offset = m.get("offset")
+        require(offset != null, "Offset must be present in JSON-encoded SystemStreamPartition")
+
+        new SystemStreamPartition(system, stream, new Partition(partition.toInt)) -> offset
+      }
+
+      val cpMap = jMap.values.map(deserializeJSONMap).toMap
+
+      return new Checkpoint(cpMap)
+    }catch {
       case e : Exception =>
         warn("Exception while deserializing checkpoint: " + e)
         debug("Exception detail:", e)
@@ -46,34 +79,38 @@ class CheckpointSerde extends Serde[Checkpoint] with Logging {
     }
   }
 
-  def toBytes(checkpoint: Checkpoint) = {
-    val offsetMap = mapAsJavaMap(checkpoint
-      .getOffsets
-      // Convert Map[SystemStream, String] offset map to a iterable of tuples (system, stream, offset)
-      .map { case (systemStream, offset) => (systemStream.getSystem, systemStream.getStream, offset) }
-      // Group into a Map[String, (String, String, String)] by system
-      .groupBy(_._1)
-      // Group the tuples for each system into a Map[String, String] for stream to offsets
-      .map {
-        case (systemName, tuples) =>
-          val streamToOffestMap = mapAsJavaMap(tuples
-            // Group the tuples by stream name
-            .groupBy(_._2)
-            // There should only ever be one SystemStream to offset mapping, so just 
-            // grab the first element from the tuple list for each stream.
-            .map {
-              case (streamName, tuples) => {
-                // If there's more than one offset, something is seriously wrong.
-                if (tuples.size != 1) {
-                  throw new SamzaException("Got %s offsets for %s. Expected only one offset, so failing." format (tuples.size, streamName))
-                }
-                (streamName, tuples.head._3)
-              }
-            }
-            .toMap)
-          (systemName, streamToOffestMap)
-      }.toMap)
-
-    jsonMapper.writeValueAsBytes(offsetMap)
+  def toBytes(checkpoint: Checkpoint): Array[Byte] = {
+    val offsets = checkpoint.getOffsets
+    val asMap = new util.HashMap[String, util.HashMap[String, String]](offsets.size())
+
+    offsets.foreach {
+      case (ssp, offset) =>
+        val jMap = new util.HashMap[String, String](4)
+        jMap.put("system", ssp.getSystemStream.getSystem)
+        jMap.put("stream", ssp.getSystemStream.getStream)
+        jMap.put("partition", ssp.getPartition.getPartitionId.toString)
+        jMap.put("offset", offset)
+
+        asMap.put(ssp.toString, jMap)
+    }
+
+    jsonMapper.writeValueAsBytes(asMap)
+  }
+
+  def changelogPartitionMappingFromBytes(bytes: Array[Byte]): util.Map[TaskName, java.lang.Integer] = {
+    try {
+      jsonMapper.readValue(bytes, PARTITION_MAPPING_TYPEREFERENCE)
+    } catch {
+      case e : Exception =>
+        throw new SamzaException("Exception while deserializing changelog partition mapping", e)
+    }
   }
+
+  def changelogPartitionMappingToBytes(mapping: util.Map[TaskName, java.lang.Integer]) = {
+    jsonMapper.writeValueAsBytes(new util.HashMap[TaskName, java.lang.Integer](mapping))
+  }
+}
+
+object CheckpointSerde {
+  val PARTITION_MAPPING_TYPEREFERENCE = new TypeReference[util.HashMap[TaskName, java.lang.Integer]]() {}
 }
index 7214151..0cfdbb3 100644 (file)
@@ -29,14 +29,16 @@ import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.system.SystemStreamPartitionIterator
 import org.apache.samza.util.Util
 import org.apache.samza.SamzaException
+import org.apache.samza.container.TaskName
 
 object TaskStorageManager {
   def getStoreDir(storeBaseDir: File, storeName: String) = {
     new File(storeBaseDir, storeName)
   }
 
-  def getStorePartitionDir(storeBaseDir: File, storeName: String, partition: Partition) = {
-    new File(storeBaseDir, storeName + File.separator + partition.getPartitionId)
+  def getStorePartitionDir(storeBaseDir: File, storeName: String, taskName: TaskName) = {
+    // TODO: Sanitize, check and clean taskName string as a valid value for a file
+    new File(storeBaseDir, storeName + File.separator + taskName)
   }
 }
 
@@ -44,12 +46,13 @@ object TaskStorageManager {
  * Manage all the storage engines for a given task
  */
 class TaskStorageManager(
-  partition: Partition,
+  taskName: TaskName,
   taskStores: Map[String, StorageEngine] = Map(),
   storeConsumers: Map[String, SystemConsumer] = Map(),
   changeLogSystemStreams: Map[String, SystemStream] = Map(),
   changeLogOldestOffsets: Map[SystemStream, String] = Map(),
-  storeBaseDir: File = new File(System.getProperty("user.dir"), "state")) extends Logging {
+  storeBaseDir: File = new File(System.getProperty("user.dir"), "state"),
+  partition: Partition) extends Logging {
 
   var taskStoresToRestore = taskStores
 
@@ -66,7 +69,7 @@ class TaskStorageManager(
     debug("Cleaning base directories for stores.")
 
     taskStores.keys.foreach(storeName => {
-      val storagePartitionDir = TaskStorageManager.getStorePartitionDir(storeBaseDir, storeName, partition)
+      val storagePartitionDir = TaskStorageManager.getStorePartitionDir(storeBaseDir, storeName, taskName)
 
       debug("Cleaning %s for store %s." format (storagePartitionDir, storeName))
 
index 4ccd604..6e1134d 100644 (file)
 package org.apache.samza.task
 
 import org.apache.samza.task.TaskCoordinator.RequestScope
-import org.apache.samza.Partition
+import org.apache.samza.container.TaskName
 
 /**
  * An in-memory implementation of TaskCoordinator that is specific to a single TaskInstance.
  */
-class ReadableCoordinator(val partition: Partition) extends TaskCoordinator {
+class ReadableCoordinator(val taskName: TaskName) extends TaskCoordinator {
   var commitRequest: Option[RequestScope] = None
   var shutdownRequest: Option[RequestScope] = None
 
index 60d96c9..32c2647 100644 (file)
 
 package org.apache.samza.util
 
+import grizzled.slf4j.Logging
 import java.io._
 import java.lang.management.ManagementFactory
+import java.util
 import java.util.Random
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import grizzled.slf4j.Logging
 import org.apache.commons.codec.binary.Base64
-import org.apache.samza.{ Partition, SamzaException }
+import org.apache.samza.SamzaException
+import org.apache.samza.checkpoint.CheckpointManagerFactory
 import org.apache.samza.config.Config
+import org.apache.samza.config.StorageConfig.Config2Storage
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
-import scala.collection.JavaConversions._
+import org.apache.samza.container.systemstreampartition.groupers.GroupByPartitionFactory
+import org.apache.samza.container.systemstreampartition.taskname.groupers.SimpleSystemStreamPartitionTaskNameGrouper
+import org.apache.samza.container.{TaskName, SystemStreamPartitionTaskNameGrouper, TaskNamesToSystemStreamPartitions, SystemStreamPartitionGrouperFactory}
+import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.system.{SystemStreamPartition, SystemFactory, StreamMetadataCache, SystemStream}
-import org.codehaus.jackson.map.ObjectMapper
-import org.codehaus.jackson.`type`.TypeReference
-import java.util
-import scala.reflect.BeanProperty
-
+import scala.collection.JavaConversions._
+import scala.collection
 
 object Util extends Logging {
   val random = new Random
@@ -82,7 +85,7 @@ object Util extends Logging {
 
   /**
    * For each input stream specified in config, exactly determine its partitions, returning a set of SystemStreamPartitions
-   * corresponding to them all
+   * containing them all
    *
    * @param config Source of truth for systems and inputStreams
    * @return Set of SystemStreamPartitions, one for each unique system, stream and partition
@@ -112,6 +115,50 @@ object Util extends Logging {
   }
 
   /**
+   * Assign mapping of which TaskNames go to which container
+   *
+   * @param config For factories for Grouper and TaskNameGrouper
+   * @param containerCount How many tasks are we we working with
+   * @return Map of int (taskId) to SSPTaskNameMap that taskID is responsible for
+   */
+  def assignContainerToSSPTaskNames(config:Config, containerCount:Int): Map[Int, TaskNamesToSystemStreamPartitions] = {
+    import org.apache.samza.config.JobConfig.Config2Job
+
+    val allSystemStreamPartitions: Set[SystemStreamPartition] = Util.getInputStreamPartitions(config)
+
+    val sspTaskNamesAsJava: util.Map[TaskName, util.Set[SystemStreamPartition]] = {
+      val factoryString = config.getSystemStreamPartitionGrouperFactory
+
+      info("Instantiating type " + factoryString + " to build SystemStreamPartition groupings")
+
+      val factory = Util.getObj[SystemStreamPartitionGrouperFactory](factoryString)
+
+      val grouper = factory.getSystemStreamPartitionGrouper(config)
+
+      val groups = grouper.group(allSystemStreamPartitions)
+
+      info("SystemStreamPartitionGrouper " + grouper + " has grouped the SystemStreamPartitions into the following taskNames:")
+      groups.foreach(g => info("TaskName: " + g._1 + " => " + g._2))
+
+      groups
+    }
+
+    val sspTaskNames = TaskNamesToSystemStreamPartitions(sspTaskNamesAsJava)
+
+    info("Assigning " + sspTaskNames.keySet.size + " SystemStreamPartitions taskNames to " + containerCount + " containers.")
+
+    // Here is where we should put in a pluggable option for the SSPTaskNameGrouper for locality, load-balancing, etc.
+    val sspTaskNameGrouper = new SimpleSystemStreamPartitionTaskNameGrouper(containerCount)
+
+    val containersToTaskNames = sspTaskNameGrouper.groupTaskNames(sspTaskNames).toMap
+
+    info("Grouped SystemStreamPartition TaskNames (size = " + containersToTaskNames.size + "): ")
+    containersToTaskNames.foreach(t => info("Container number: " + t._1 + " => " + t._2))
+
+    containersToTaskNames
+  }
+
+  /**
    * Returns a SystemStream object based on the system stream name given. For
    * example, kafka.topic would return new SystemStream("kafka", "topic").
    */
@@ -131,40 +178,120 @@ object Util extends Logging {
   }
 
   /**
-   * For specified containerId, create a list of of the streams and partitions that task should handle,
-   * based on the number of tasks in the job
+   * Using previous taskName to partition mapping and current taskNames for this job run, create a new mapping that preserves
+   * the previous order and deterministically assigns any new taskNames to changelog partitions.  Be chatty about new or
+   * missing taskNames.
    *
-   * @param containerId TaskID to determine work for
-   * @param containerCount Total number of tasks in the job
-   * @param ssp All SystemStreamPartitions
-   * @return Collection of streams and partitions for this particular containerId
+   * @param currentTaskNames All the taskNames the current job is processing
+   * @param previousTaskNameMapping Previous mapping of taskNames to partition
+   * @return New mapping of taskNames to partitions for the changelog
    */
-  def getStreamsAndPartitionsForContainer(containerId: Int, containerCount: Int, ssp: Set[SystemStreamPartition]): Set[SystemStreamPartition] = {
-    ssp.filter(_.getPartition.getPartitionId % containerCount == containerId)
+  def resolveTaskNameToChangelogPartitionMapping(currentTaskNames:Set[TaskName],
+    previousTaskNameMapping:Map[TaskName, Int]): Map[TaskName, Int] = {
+    info("Previous mapping of taskNames to partition: " + previousTaskNameMapping.toList.sorted)
+    info("Current set of taskNames: " + currentTaskNames.toList.sorted)
+
+    val previousTaskNames: Set[TaskName] = previousTaskNameMapping.keySet
+
+    if(previousTaskNames.equals(currentTaskNames)) {
+      info("No change in TaskName sets from previous run. Returning previous mapping.")
+      return previousTaskNameMapping
+    }
+
+    if(previousTaskNames.isEmpty) {
+      warn("No previous taskName mapping defined.  This is OK if it's the first time the job is being run, otherwise data may have been lost.")
+    }
+
+    val missingTaskNames = previousTaskNames -- currentTaskNames
+
+    if(missingTaskNames.isEmpty) {
+      info("No taskNames are missing between this run and previous")
+    } else {
+      warn("The following taskNames were previously defined and are no longer present: " + missingTaskNames)
+    }
+
+    val newTaskNames = currentTaskNames -- previousTaskNames
+
+    if(newTaskNames.isEmpty) {
+      info("No new taskNames have been added between this run and the previous")
+      previousTaskNameMapping // Return the old mapping since there are no new taskNames for which to account
+
+    } else {
+      warn("The following new taskNames have been added in this job run: " + newTaskNames)
+
+      // Sort the new taskNames and assign them partitions (starting at 0 for now)
+      val sortedNewTaskNames = newTaskNames.toList.sortWith { (a,b) => a.getTaskName < b.getTaskName }.zipWithIndex.toMap
+
+      // Find next largest partition to use based on previous mapping
+      val nextPartitionToUse = if(previousTaskNameMapping.size == 0) 0
+                               else previousTaskNameMapping.foldLeft(0)((a,b) => math.max(a, b._2)) + 1
+
+      // Bump up the partition values
+      val newTaskNamesWithTheirPartitions = sortedNewTaskNames.map(c => c._1 -> (c._2 + nextPartitionToUse))
+      
+      // Merge old and new
+      val newMapping = previousTaskNameMapping ++ newTaskNamesWithTheirPartitions
+      
+      info("New taskName to partition mapping: " + newMapping.toList.sortWith{ (a,b) => a._2 < b._2})
+      
+      newMapping
+    }
   }
 
   /**
-   * Jackson really hates Scala's classes, so we need to wrap up the SSP in a form Jackson will take
+   * Read the TaskName to changelog partition mapping from the checkpoint manager, if one exists.
+   *
+   * @param config To pull out values for instantiating checkpoint manager
+   * @param tasksToSSPTaskNames Current TaskNames for the current job run
+   * @return Current mapping of TaskNames to changelog partitions
    */
-  private class SSPWrapper(@BeanProperty var partition:java.lang.Integer = null,
-                           @BeanProperty var Stream:java.lang.String = null,
-                           @BeanProperty var System:java.lang.String = null) {
-    def this() { this(null, null, null) }
-    def this(ssp:SystemStreamPartition) { this(ssp.getPartition.getPartitionId, ssp.getSystemStream.getStream, ssp.getSystemStream.getSystem)}
-  }
+  def getTaskNameToChangeLogPartitionMapping(config: Config, tasksToSSPTaskNames: Map[Int, TaskNamesToSystemStreamPartitions]) = {
+    val taskNameMaps: Set[TaskNamesToSystemStreamPartitions] = tasksToSSPTaskNames.map(_._2).toSet
+    val currentTaskNames: Set[TaskName] = taskNameMaps.map(_.keys).toSet.flatten
+
+    // We need to oh so quickly instantiate a checkpoint manager and grab the partition mapping from the log, then toss the manager aside
+    val checkpointManager = config.getCheckpointManagerFactory match {
+      case Some(checkpointFactoryClassName) =>
+        Util
+          .getObj[CheckpointManagerFactory](checkpointFactoryClassName)
+          .getCheckpointManager(config, new MetricsRegistryMap)
+      case _ => null
+    }
 
-  def serializeSSPSetToJSON(ssps: Set[SystemStreamPartition]): String = {
-    val al = new util.ArrayList[SSPWrapper](ssps.size)
-    for(ssp <- ssps) { al.add(new SSPWrapper(ssp)) }
+    if(checkpointManager == null) {
+      // Check if we have a changelog configured, which requires a checkpoint manager
 
-    new ObjectMapper().writeValueAsString(al)
-  }
+      if(!config.getStoreNames.isEmpty) {
+        throw new SamzaException("Storage factories configured, but no checkpoint manager has been specified.  " +
+          "Unable to start job as there would be no place to store changelog partition mapping.")
+      }
+      // No need to do any mapping, just use what has been provided
+      Util.resolveTaskNameToChangelogPartitionMapping(currentTaskNames, Map[TaskName, Int]())
+    } else {
+
+      info("Got checkpoint manager: %s" format checkpointManager)
+
+      // Always put in a call to create so the log is available for the tasks on startup.
+      // Reasonably lame to hide it in here.  TODO: Pull out to more visible location.
+      checkpointManager.start
 
-  def deserializeSSPSetFromJSON(ssp: String) = {
-    val om = new ObjectMapper()
+      val previousMapping: Map[TaskName, Int] = {
+        val fromCM = checkpointManager.readChangeLogPartitionMapping()
 
-    val asWrapper = om.readValue(ssp, new TypeReference[util.ArrayList[SSPWrapper]]() { }).asInstanceOf[util.ArrayList[SSPWrapper]]
-    asWrapper.map(w => new SystemStreamPartition(w.getSystem, w.getStream(), new Partition(w.getPartition()))).toSet
+        fromCM.map(kv => kv._1 -> kv._2.intValue()).toMap // Java to Scala interop!!!
+      }
+
+      checkpointManager.stop
+
+      val newMapping = Util.resolveTaskNameToChangelogPartitionMapping(currentTaskNames, previousMapping)
+
+      if (newMapping != null) {
+        info("Writing new changelog partition mapping to checkpoint manager.")
+        checkpointManager.writeChangeLogPartitionMapping(newMapping.map(kv => kv._1 -> new java.lang.Integer(kv._2))) //Java to Scala interop!!!
+      }
+
+      newMapping
+    }
   }
 
   /**
index bc54f9e..1eb3995 100644 (file)
 
 package org.apache.samza.checkpoint
 
+import org.apache.samza.Partition
+import org.apache.samza.checkpoint.TestCheckpointTool.{MockCheckpointManagerFactory, MockSystemFactory}
+import org.apache.samza.config.{Config, MapConfig, SystemConfig, TaskConfig}
+import org.apache.samza.metrics.MetricsRegistry
+import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
+import org.apache.samza.system.{SystemAdmin, SystemConsumer, SystemFactory, SystemProducer, SystemStream, SystemStreamMetadata, SystemStreamPartition}
 import org.junit.{Before, Test}
 import org.mockito.Matchers._
 import org.mockito.Mockito._
 import org.scalatest.junit.AssertionsForJUnit
 import org.scalatest.mock.MockitoSugar
 import scala.collection.JavaConversions._
-import org.apache.samza.Partition
-import org.apache.samza.checkpoint.TestCheckpointTool.{MockCheckpointManagerFactory, MockSystemFactory}
-import org.apache.samza.config.{Config, MapConfig, SystemConfig, TaskConfig}
-import org.apache.samza.metrics.MetricsRegistry
-import org.apache.samza.system.{SystemAdmin, SystemConsumer, SystemFactory, SystemProducer, SystemStream, SystemStreamMetadata, SystemStreamPartition}
-import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
+import org.apache.samza.container.TaskName
 
 object TestCheckpointTool {
   var checkpointManager: CheckpointManager = null
@@ -52,6 +53,11 @@ object TestCheckpointTool {
 class TestCheckpointTool extends AssertionsForJUnit with MockitoSugar {
   var config: MapConfig = null
 
+  val tn0 = new TaskName("Partition 0")
+  val tn1 = new TaskName("Partition 1")
+  val p0 = new Partition(0)
+  val p1 = new Partition(1)
+
   @Before
   def setup {
     config = new MapConfig(Map(
@@ -68,29 +74,30 @@ class TestCheckpointTool extends AssertionsForJUnit with MockitoSugar {
     TestCheckpointTool.systemAdmin = mock[SystemAdmin]
     when(TestCheckpointTool.systemAdmin.getSystemStreamMetadata(Set("foo")))
       .thenReturn(Map("foo" -> metadata))
-    when(TestCheckpointTool.checkpointManager.readLastCheckpoint(new Partition(0)))
-      .thenReturn(new Checkpoint(Map(new SystemStream("test", "foo") -> "1234")))
-    when(TestCheckpointTool.checkpointManager.readLastCheckpoint(new Partition(1)))
-      .thenReturn(new Checkpoint(Map(new SystemStream("test", "foo") -> "4321")))
+    when(TestCheckpointTool.checkpointManager.readLastCheckpoint(tn0))
+      .thenReturn(new Checkpoint(Map(new SystemStreamPartition("test", "foo", p0) -> "1234")))
+    when(TestCheckpointTool.checkpointManager.readLastCheckpoint(tn1))
+      .thenReturn(new Checkpoint(Map(new SystemStreamPartition("test", "foo", p1) -> "4321")))
+
   }
 
   @Test
   def testReadLatestCheckpoint {
     new CheckpointTool(config, null).run
-    verify(TestCheckpointTool.checkpointManager).readLastCheckpoint(new Partition(0))
-    verify(TestCheckpointTool.checkpointManager).readLastCheckpoint(new Partition(1))
+    verify(TestCheckpointTool.checkpointManager).readLastCheckpoint(tn0)
+    verify(TestCheckpointTool.checkpointManager).readLastCheckpoint(tn1)
     verify(TestCheckpointTool.checkpointManager, never()).writeCheckpoint(any(), any())
   }
 
   @Test
   def testOverwriteCheckpoint {
-    new CheckpointTool(config, Map(
-      new SystemStreamPartition("test", "foo", new Partition(0)) -> "42",
-      new SystemStreamPartition("test", "foo", new Partition(1)) -> "43"
-    )).run
+    val toOverwrite = Map(tn0 -> Map(new SystemStreamPartition("test", "foo", p0) -> "42"),
+      tn1 -> Map(new SystemStreamPartition("test", "foo", p1) -> "43"))
+
+    new CheckpointTool(config, toOverwrite).run
     verify(TestCheckpointTool.checkpointManager)
-      .writeCheckpoint(new Partition(0), new Checkpoint(Map(new SystemStream("test", "foo") -> "42")))
+      .writeCheckpoint(tn0, new Checkpoint(Map(new SystemStreamPartition("test", "foo", p0) -> "42")))
     verify(TestCheckpointTool.checkpointManager)
-      .writeCheckpoint(new Partition(1), new Checkpoint(Map(new SystemStream("test", "foo") -> "43")))
+      .writeCheckpoint(tn1, new Checkpoint(Map(new SystemStreamPartition("test", "foo", p1) -> "43")))
   }
 }
index 94f6f4c..44a98a5 100644 (file)
@@ -23,19 +23,21 @@ import scala.collection.JavaConversions._
 import org.apache.samza.Partition
 import org.apache.samza.system.SystemStream
 import org.apache.samza.system.SystemStreamMetadata
-import org.apache.samza.system.SystemStreamMetadata.OffsetType
-import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
+import org.apache.samza.system.SystemStreamMetadata.{OffsetType, SystemStreamPartitionMetadata}
 import org.apache.samza.system.SystemStreamPartition
 import org.junit.Assert._
 import org.junit.{Ignore, Test}
 import org.apache.samza.SamzaException
 import org.apache.samza.config.MapConfig
 import org.apache.samza.system.SystemAdmin
+import java.util
+import org.apache.samza.container.TaskName
 import org.scalatest.Assertions.intercept
 
 class TestOffsetManager {
   @Test
   def testSystemShouldUseDefaults {
+    val taskName = new TaskName("c")
     val systemStream = new SystemStream("test-system", "test-stream")
     val partition = new Partition(0)
     val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
@@ -43,7 +45,7 @@ class TestOffsetManager {
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
     val config = new MapConfig(Map("systems.test-system.samza.offset.default" -> "oldest"))
     val offsetManager = OffsetManager(systemStreamMetadata, config)
-    offsetManager.register(systemStreamPartition)
+    offsetManager.register(taskName, Set(systemStreamPartition))
     offsetManager.start
     assertTrue(!offsetManager.getLastProcessedOffset(systemStreamPartition).isDefined)
     assertTrue(offsetManager.getStartingOffset(systemStreamPartition).isDefined)
@@ -52,21 +54,22 @@ class TestOffsetManager {
 
   @Test
   def testShouldLoadFromAndSaveWithCheckpointManager {
+    val taskName = new TaskName("c")
     val systemStream = new SystemStream("test-system", "test-stream")
     val partition = new Partition(0)
     val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
     val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")))
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
     val config = new MapConfig
-    val checkpointManager = getCheckpointManager(systemStreamPartition)
+    val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
     val systemAdmins = Map("test-system" -> getSystemAdmin)
     val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, systemAdmins)
-    offsetManager.register(systemStreamPartition)
+    offsetManager.register(taskName, Set(systemStreamPartition))
     offsetManager.start
     assertTrue(checkpointManager.isStarted)
     assertEquals(1, checkpointManager.registered.size)
-    assertEquals(partition, checkpointManager.registered.head)
-    assertEquals(checkpointManager.checkpoints.head._2, checkpointManager.readLastCheckpoint(partition))
+    assertEquals(taskName, checkpointManager.registered.head)
+    assertEquals(checkpointManager.checkpoints.head._2, checkpointManager.readLastCheckpoint(taskName))
     // Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
     assertEquals("46", offsetManager.getStartingOffset(systemStreamPartition).get)
     assertEquals("45", offsetManager.getLastProcessedOffset(systemStreamPartition).get)
@@ -76,31 +79,31 @@ class TestOffsetManager {
     assertEquals("47", offsetManager.getLastProcessedOffset(systemStreamPartition).get)
     // Should never update starting offset.
     assertEquals("46", offsetManager.getStartingOffset(systemStreamPartition).get)
-    offsetManager.checkpoint(partition)
-    val expectedCheckpoint = new Checkpoint(Map(systemStream -> "47"))
-    assertEquals(expectedCheckpoint, checkpointManager.readLastCheckpoint(partition))
+    offsetManager.checkpoint(taskName)
+    val expectedCheckpoint = new Checkpoint(Map(systemStreamPartition -> "47"))
+    assertEquals(expectedCheckpoint, checkpointManager.readLastCheckpoint(taskName))
   }
 
   @Test
   def testShouldResetStreams {
+    val taskName = new TaskName("c")
     val systemStream = new SystemStream("test-system", "test-stream")
     val partition = new Partition(0)
     val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
     val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")))
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
-    val defaultOffsets = Map(systemStream -> OffsetType.OLDEST)
-    val checkpoint = new Checkpoint(Map(systemStream -> "45"))
-    val checkpointManager = getCheckpointManager(systemStreamPartition)
+    val checkpoint = new Checkpoint(Map(systemStreamPartition -> "45"))
+    val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
     val config = new MapConfig(Map(
       "systems.test-system.samza.offset.default" -> "oldest",
       "systems.test-system.streams.test-stream.samza.reset.offset" -> "true"))
     val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager)
-    offsetManager.register(systemStreamPartition)
+    offsetManager.register(taskName, Set(systemStreamPartition))
     offsetManager.start
     assertTrue(checkpointManager.isStarted)
     assertEquals(1, checkpointManager.registered.size)
-    assertEquals(partition, checkpointManager.registered.head)
-    assertEquals(checkpoint, checkpointManager.readLastCheckpoint(partition))
+    assertEquals(taskName, checkpointManager.registered.head)
+    assertEquals(checkpoint, checkpointManager.readLastCheckpoint(taskName))
     // Should be zero even though the checkpoint has an offset of 45, since we're forcing a reset.
     assertEquals("0", offsetManager.getStartingOffset(systemStreamPartition).get)
   }
@@ -110,36 +113,38 @@ class TestOffsetManager {
     val systemStream = new SystemStream("test-system", "test-stream")
     val partition1 = new Partition(0)
     val partition2 = new Partition(1)
+    val taskName1 = new TaskName("P0")
+    val taskName2 = new TaskName("P1")
     val systemStreamPartition1 = new SystemStreamPartition(systemStream, partition1)
     val systemStreamPartition2 = new SystemStreamPartition(systemStream, partition2)
     val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(
       partition1 -> new SystemStreamPartitionMetadata("0", "1", "2"),
       partition2 -> new SystemStreamPartitionMetadata("3", "4", "5")))
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
-    val defaultOffsets = Map(systemStream -> OffsetType.OLDEST)
-    val checkpoint = new Checkpoint(Map(systemStream -> "45"))
+    val checkpoint = new Checkpoint(Map(systemStreamPartition1 -> "45"))
     // Checkpoint manager only has partition 1.
-    val checkpointManager = getCheckpointManager(systemStreamPartition1)
+    val checkpointManager = getCheckpointManager(systemStreamPartition1, taskName1)
     val systemAdmins = Map("test-system" -> getSystemAdmin)
     val config = new MapConfig
     val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, systemAdmins)
     // Register both partitions. Partition 2 shouldn't have a checkpoint.
-    offsetManager.register(systemStreamPartition1)
-    offsetManager.register(systemStreamPartition2)
+    offsetManager.register(taskName1, Set(systemStreamPartition1))
+    offsetManager.register(taskName2, Set(systemStreamPartition2))
     offsetManager.start
     assertTrue(checkpointManager.isStarted)
     assertEquals(2, checkpointManager.registered.size)
-    assertEquals(checkpoint, checkpointManager.readLastCheckpoint(partition1))
-    assertNull(checkpointManager.readLastCheckpoint(partition2))
+    assertEquals(checkpoint, checkpointManager.readLastCheckpoint(taskName1))
+    assertNull(checkpointManager.readLastCheckpoint(taskName2))
   }
 
   @Test
   def testShouldFailWhenMissingMetadata {
+    val taskName = new TaskName("c")
     val systemStream = new SystemStream("test-system", "test-stream")
     val partition = new Partition(0)
     val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
     val offsetManager = new OffsetManager
-    offsetManager.register(systemStreamPartition)
+    offsetManager.register(taskName, Set(systemStreamPartition))
 
     intercept[SamzaException] {
       offsetManager.start
@@ -148,13 +153,14 @@ class TestOffsetManager {
 
   @Ignore("OffsetManager.start is supposed to throw an exception - but it doesn't") @Test
   def testShouldFailWhenMissingDefault {
+    val taskName = new TaskName("c")
     val systemStream = new SystemStream("test-system", "test-stream")
     val partition = new Partition(0)
     val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
     val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")))
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
     val offsetManager = OffsetManager(systemStreamMetadata, new MapConfig(Map[String, String]()))
-    offsetManager.register(systemStreamPartition)
+    offsetManager.register(taskName, Set(systemStreamPartition))
 
     intercept[SamzaException] {
       offsetManager.start
@@ -190,6 +196,7 @@ class TestOffsetManager {
 
   @Test
   def testOutdatedStreamInCheckpoint {
+    val taskName = new TaskName("c")
     val systemStream0 = new SystemStream("test-system-0", "test-stream")
     val systemStream1 = new SystemStream("test-system-1", "test-stream")
     val partition0 = new Partition(0)
@@ -200,26 +207,32 @@ class TestOffsetManager {
     val offsetSettings = Map(systemStream0 -> OffsetSetting(testStreamMetadata, OffsetType.UPCOMING, false))
     val checkpointManager = getCheckpointManager(systemStreamPartition1)
     val offsetManager = new OffsetManager(offsetSettings, checkpointManager)
-    offsetManager.register(systemStreamPartition0)
+    offsetManager.register(taskName, Set(systemStreamPartition0))
     offsetManager.start
     assertTrue(checkpointManager.isStarted)
     assertEquals(1, checkpointManager.registered.size)
     assertNull(offsetManager.getLastProcessedOffset(systemStreamPartition1).getOrElse(null))
   }
 
-  private def getCheckpointManager(systemStreamPartition: SystemStreamPartition) = {
-    val checkpoint = new Checkpoint(Map(systemStreamPartition.getSystemStream -> "45"))
+  private def getCheckpointManager(systemStreamPartition: SystemStreamPartition, taskName:TaskName = new TaskName("taskName")) = {
+    val checkpoint = new Checkpoint(Map(systemStreamPartition -> "45"))
 
     new CheckpointManager {
       var isStarted = false
       var isStopped = false
-      var registered = Set[Partition]()
-      var checkpoints = Map(systemStreamPartition.getPartition -> checkpoint)
+      var registered = Set[TaskName]()
+      var checkpoints: Map[TaskName, Checkpoint] = Map(taskName -> checkpoint)
+      var taskNameToPartitionMapping: util.Map[TaskName, java.lang.Integer] = new util.HashMap[TaskName, java.lang.Integer]()
       def start { isStarted = true }
-      def register(partition: Partition) { registered += partition }
-      def writeCheckpoint(partition: Partition, checkpoint: Checkpoint) { checkpoints += partition -> checkpoint }
-      def readLastCheckpoint(partition: Partition) = checkpoints.getOrElse(partition, null)
+      def register(taskName: TaskName) { registered += taskName }
+      def writeCheckpoint(taskName: TaskName, checkpoint: Checkpoint) { checkpoints += taskName -> checkpoint }
+      def readLastCheckpoint(taskName: TaskName) = checkpoints.getOrElse(taskName, null)
       def stop { isStopped = true }
+
+      override def writeChangeLogPartitionMapping(mapping: util.Map[TaskName, java.lang.Integer]): Unit = taskNameToPartitionMapping = mapping
+
+      override def readChangeLogPartitionMapping(): util.Map[TaskName, java.lang.Integer] = taskNameToPartitionMapping
+
     }
   }
 
@@ -232,4 +245,4 @@ class TestOffsetManager {
         Map[String, SystemStreamMetadata]()
     }
   }
-}
\ No newline at end of file
+}
index 50d9a05..10ff1f4 100644 (file)
@@ -23,34 +23,53 @@ import java.io.File
 import scala.collection.JavaConversions._
 import java.util.Random
 import org.junit.Assert._
-import org.junit.Test
+import org.junit.{After, Before, Test}
 import org.apache.samza.SamzaException
 import org.apache.samza.Partition
 import org.apache.samza.checkpoint.Checkpoint
-import org.apache.samza.system.SystemStream
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.container.TaskName
+import org.junit.rules.TemporaryFolder
 
-class TestFileSystemCheckpointManager {
-  val checkpointRoot = System.getProperty("java.io.tmpdir")
+class TestFileSystemCheckpointManager  {
+  val checkpointRoot = System.getProperty("java.io.tmpdir") // TODO: Move this out of tmp, into our build dir
+  val taskName = new TaskName("Warwickshire")
+  val baseFileLocation = new File(checkpointRoot)
+
+  val tempFolder = new TemporaryFolder
+
+  @Before
+  def createTempFolder = tempFolder.create()
+
+  @After
+  def deleteTempFolder = tempFolder.delete()
 
   @Test
   def testReadForCheckpointFileThatDoesNotExistShouldReturnNull {
-    val cpm = new FileSystemCheckpointManager("some-job-name", new File(checkpointRoot))
-    assert(cpm.readLastCheckpoint(new Partition(1)) == null)
+    val cpm = new FileSystemCheckpointManager("some-job-name", tempFolder.getRoot)
+    assertNull(cpm.readLastCheckpoint(taskName))
   }
 
   @Test
   def testReadForCheckpointFileThatDoesExistShouldReturnProperCheckpoint {
-    val cpm = new FileSystemCheckpointManager("some-job-name", new File(checkpointRoot))
-    val partition = new Partition(2)
     val cp = new Checkpoint(Map(
-      new SystemStream("a", "b") -> "c",
-      new SystemStream("a", "c") -> "d",
-      new SystemStream("b", "d") -> "e"))
+          new SystemStreamPartition("a", "b", new Partition(0)) -> "c",
+          new SystemStreamPartition("a", "c", new Partition(1)) -> "d",
+          new SystemStreamPartition("b", "d", new Partition(2)) -> "e"))
+
+    var readCp:Checkpoint = null
+    val cpm =  new FileSystemCheckpointManager("some-job-name", tempFolder.getRoot)
+
     cpm.start
-    cpm.writeCheckpoint(partition, cp)
-    val readCp = cpm.readLastCheckpoint(partition)
+    cpm.writeCheckpoint(taskName, cp)
+    readCp = cpm.readLastCheckpoint(taskName)
     cpm.stop
-    assert(readCp.equals(cp))
+
+    assertNotNull(readCp)
+    cp.equals(readCp)
+    assertEquals(cp.getOffsets.keySet(), readCp.getOffsets.keySet())
+    assertEquals(cp.getOffsets, readCp.getOffsets)
+    assertEquals(cp, readCp)
   }
 
   @Test
diff --git a/samza-core/src/test/scala/org/apache/samza/container/SystemStreamPartitionGrouperTestBase.scala b/samza-core/src/test/scala/org/apache/samza/container/SystemStreamPartitionGrouperTestBase.scala
new file mode 100644 (file)
index 0000000..3032b00
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container
+
+import org.apache.samza.Partition
+import org.apache.samza.system.SystemStreamPartition
+import org.junit.Test
+import java.util.HashSet
+import java.util.Map
+import java.util.Set
+import org.junit.Assert.assertEquals
+import org.junit.Assert.assertTrue
+import java.util.Collections
+
+object SystemStreamPartitionGrouperTestBase {
+  val aa0 = new SystemStreamPartition("SystemA", "StreamA", new Partition(0))
+  val aa1 = new SystemStreamPartition("SystemA", "StreamA", new Partition(1))
+  val aa2 = new SystemStreamPartition("SystemA", "StreamA", new Partition(2))
+  val ab1 = new SystemStreamPartition("SystemA", "StreamB", new Partition(1))
+  val ab2 = new SystemStreamPartition("SystemA", "StreamB", new Partition(2))
+  val ac0 = new SystemStreamPartition("SystemA", "StreamB", new Partition(0))
+  val allSSPs = new HashSet[SystemStreamPartition]
+  Collections.addAll(allSSPs, aa0, aa1, aa2, ab1, ab2, ac0)
+}
+
+abstract class SystemStreamPartitionGrouperTestBase {
+  def getGrouper: SystemStreamPartitionGrouper
+
+  @Test
+  def emptySetReturnsEmptyMap {
+    val grouper: SystemStreamPartitionGrouper = getGrouper
+    val result: Map[TaskName, Set[SystemStreamPartition]] = grouper.group(new HashSet[SystemStreamPartition])
+    assertTrue(result.isEmpty)
+  }
+
+  def verifyGroupGroupsCorrectly(input: Set[SystemStreamPartition], output: Map[TaskName, Set[SystemStreamPartition]]) {
+    val grouper: SystemStreamPartitionGrouper = getGrouper
+    val result: Map[TaskName, Set[SystemStreamPartition]] = grouper.group(input)
+    assertEquals(output, result)
+  }
+}
\ No newline at end of file
index fa10231..d4ceffc 100644 (file)
@@ -37,29 +37,45 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ShouldMatche
 
   val p0 = new Partition(0)
   val p1 = new Partition(1)
+  val taskName0 = new TaskName(p0.toString)
+  val taskName1 = new TaskName(p1.toString)
   val ssp0 = new SystemStreamPartition("testSystem", "testStream", p0)
   val ssp1 = new SystemStreamPartition("testSystem", "testStream", p1)
   val envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0")
   val envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1")
 
+  def getMockTaskInstances: Map[TaskName, TaskInstance] = {
+    val ti0 = mock[TaskInstance]
+    when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
+    when(ti0.taskName).thenReturn(taskName0)
+
+    val ti1 = mock[TaskInstance]
+    when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
+    when(ti1.taskName).thenReturn(taskName1)
+
+    Map(taskName0 -> ti0, taskName1 -> ti1)
+  }
+
   @Test
   def testProcessMessageFromChooser {
-    val taskInstances = Map(p0 -> mock[TaskInstance], p1 -> mock[TaskInstance])
+    val taskInstances = getMockTaskInstances
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics)
 
     when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
     intercept[StopRunLoop] { runLoop.run }
-    verify(taskInstances(p0)).process(Matchers.eq(envelope0), anyObject)
-    verify(taskInstances(p1)).process(Matchers.eq(envelope1), anyObject)
+    verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject)
+    verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject)
     runLoop.metrics.envelopes.getCount should equal(2L)
     runLoop.metrics.nullEnvelopes.getCount should equal(0L)
   }
 
+
   @Test
   def testNullMessageFromChooser {
     val consumers = mock[SystemConsumers]
-    val runLoop = new RunLoop(Map(p0 -> mock[TaskInstance]), consumers, new SamzaContainerMetrics)
+    val map = getMockTaskInstances - taskName1 // This test only needs p0
+    val runLoop = new RunLoop(map, consumers, new SamzaContainerMetrics)
     when(consumers.choose).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop)
     intercept[StopRunLoop] { runLoop.run }
     runLoop.metrics.envelopes.getCount should equal(0L)
@@ -73,7 +89,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ShouldMatche
     when(consumers.choose).thenReturn(envelope0)
 
     val runLoop = new RunLoop(
-      taskInstances = Map(p0 -> mock[TaskInstance], p1 -> mock[TaskInstance]),
+      taskInstances = getMockTaskInstances,
       consumerMultiplexer = consumers,
       metrics = new SamzaContainerMetrics,
       windowMs = 60000, // call window once per minute
@@ -86,67 +102,67 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ShouldMatche
 
     intercept[StopRunLoop] { runLoop.run }
 
-    verify(runLoop.taskInstances(p0), times(5)).window(anyObject)
-    verify(runLoop.taskInstances(p1), times(5)).window(anyObject)
-    verify(runLoop.taskInstances(p0), times(10)).commit
-    verify(runLoop.taskInstances(p1), times(10)).commit
+    verify(runLoop.taskInstances(taskName0), times(5)).window(anyObject)
+    verify(runLoop.taskInstances(taskName1), times(5)).window(anyObject)
+    verify(runLoop.taskInstances(taskName0), times(10)).commit
+    verify(runLoop.taskInstances(taskName1), times(10)).commit
   }
 
   @Test
   def testCommitCurrentTaskManually {
-    val taskInstances = Map(p0 -> mock[TaskInstance], p1 -> mock[TaskInstance])
+    val taskInstances = getMockTaskInstances
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
     when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
-    stubProcess(taskInstances(p0), (envelope, coordinator) => coordinator.commit(RequestScope.CURRENT_TASK))
+    stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.CURRENT_TASK))
 
     intercept[StopRunLoop] { runLoop.run }
-    verify(taskInstances(p0), times(1)).commit
-    verify(taskInstances(p1), times(0)).commit
+    verify(taskInstances(taskName0), times(1)).commit
+    verify(taskInstances(taskName1), times(0)).commit
   }
 
   @Test
   def testCommitAllTasksManually {
-    val taskInstances = Map(p0 -> mock[TaskInstance], p1 -> mock[TaskInstance])
+    val taskInstances = getMockTaskInstances
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
     when(consumers.choose).thenReturn(envelope0).thenThrow(new StopRunLoop)
-    stubProcess(taskInstances(p0), (envelope, coordinator) => coordinator.commit(RequestScope.ALL_TASKS_IN_CONTAINER))
+    stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.ALL_TASKS_IN_CONTAINER))
 
     intercept[StopRunLoop] { runLoop.run }
-    verify(taskInstances(p0), times(1)).commit
-    verify(taskInstances(p1), times(1)).commit
+    verify(taskInstances(taskName0), times(1)).commit
+    verify(taskInstances(taskName1), times(1)).commit
   }
 
   @Test
   def testShutdownOnConsensus {
-    val taskInstances = Map(p0 -> mock[TaskInstance], p1 -> mock[TaskInstance])
+    val taskInstances = getMockTaskInstances
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
     when(consumers.choose).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1)
-    stubProcess(taskInstances(p0), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
-    stubProcess(taskInstances(p1), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
+    stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
+    stubProcess(taskInstances(taskName1), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
 
     runLoop.run
-    verify(taskInstances(p0), times(2)).process(Matchers.eq(envelope0), anyObject)
-    verify(taskInstances(p1), times(1)).process(Matchers.eq(envelope1), anyObject)
+    verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject)
+    verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject)
   }
 
   @Test
   def testShutdownNow {
-    val taskInstances = Map(p0 -> mock[TaskInstance], p1 -> mock[TaskInstance])
+    val taskInstances = getMockTaskInstances
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
     when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1)
-    stubProcess(taskInstances(p0), (envelope, coordinator) => coordinator.shutdown(RequestScope.ALL_TASKS_IN_CONTAINER))
+    stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.ALL_TASKS_IN_CONTAINER))
 
     runLoop.run
-    verify(taskInstances(p0), times(1)).process(anyObject, anyObject)
-    verify(taskInstances(p1), times(0)).process(anyObject, anyObject)
+    verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject)
+    verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject)
   }
 
   def anyObject[T] = Matchers.anyObject.asInstanceOf[T]
index 190bdfe..e3c7fe3 100644 (file)
 
 package org.apache.samza.container
 
-import java.io.File
 import org.apache.samza.config.Config
 import org.junit.Assert._
 import org.junit.Test
 import org.apache.samza.Partition
 import org.apache.samza.config.MapConfig
-import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.system.SystemConsumers
 import org.apache.samza.system.chooser.RoundRobinChooser
 import org.apache.samza.system.SystemConsumer
@@ -81,8 +79,7 @@ class TestSamzaContainer {
       }
     }
     val config = new MapConfig
-    val partition = new Partition(0)
-    val containerName = "test-container"
+    val taskName = new TaskName("taskName")
     val consumerMultiplexer = new SystemConsumers(
       new RoundRobinChooser,
       Map[String, SystemConsumer]())
@@ -91,18 +88,18 @@ class TestSamzaContainer {
       new SerdeManager)
     val taskInstance: TaskInstance = new TaskInstance(
       task,
-      partition,
+      taskName,
       config,
       new TaskInstanceMetrics,
       consumerMultiplexer: SystemConsumers,
       producerMultiplexer: SystemProducers)
     val runLoop = new RunLoop(
-      taskInstances = Map(partition -> taskInstance),
+      taskInstances = Map(taskName -> taskInstance),
       consumerMultiplexer = consumerMultiplexer,
       metrics = new SamzaContainerMetrics
     )
     val container = new SamzaContainer(
-      Map(partition -> taskInstance),
+      Map(taskName -> taskInstance),
       runLoop,
       consumerMultiplexer,
       producerMultiplexer,
index 1f5e3bb..9d5ff13 100644 (file)
@@ -50,7 +50,6 @@ class TestTaskInstance {
     }
     val config = new MapConfig
     val partition = new Partition(0)
-    val containerName = "test-container"
     val consumerMultiplexer = new SystemConsumers(
       new RoundRobinChooser,
       Map[String, SystemConsumer]())
@@ -62,16 +61,17 @@ class TestTaskInstance {
     // Pretend our last checkpointed (next) offset was 2.
     val testSystemStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")))
     val offsetManager = OffsetManager(Map(systemStream -> testSystemStreamMetadata), config)
+    val taskName = new TaskName("taskName")
     val taskInstance: TaskInstance = new TaskInstance(
       task,
-      partition,
+      taskName,
       config,
       new TaskInstanceMetrics,
       consumerMultiplexer,
       producerMultiplexer,
       offsetManager)
     // Pretend we got a message with offset 2 and next offset 3.
-    val coordinator = new ReadableCoordinator(partition)
+    val coordinator = new ReadableCoordinator(taskName)
     taskInstance.process(new IncomingMessageEnvelope(systemStreamPartition, "2", null, null), coordinator)
     // Check to see if the offset manager has been properly updated with offset 3.
     val lastProcessedOffset = offsetManager.getLastProcessedOffset(systemStreamPartition)
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskNamesToSystemStreamPartitions.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskNamesToSystemStreamPartitions.scala
new file mode 100644 (file)
index 0000000..d680b20
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container
+
+import org.junit.Test
+import org.junit.Assert._
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.{SamzaException, Partition}
+
+class TestTaskNamesToSystemStreamPartitions {
+  var sspCounter = 0
+  def makeSSP(stream:String) = new SystemStreamPartition("system", stream, new Partition(42))
+
+  @Test
+  def toSetWorksCorrectly() {
+    val map = Map(new TaskName("tn1") -> Set(makeSSP("tn1-1"), makeSSP("tn1-2")),
+                  new TaskName("tn2") -> Set(makeSSP("tn2-1"), makeSSP("tn2-2")))
+    val tntssp = TaskNamesToSystemStreamPartitions(map)
+
+    val asSet = tntssp.toSet
+    val expected = Set(new TaskName("tn1") -> Set(makeSSP("tn1-1"), makeSSP("tn1-2")),
+                      (new TaskName("tn2") -> Set(makeSSP("tn2-1"), makeSSP("tn2-2"))))
+    assertEquals(expected , asSet)
+  }
+
+  @Test
+  def validateMethodCatchesDuplicatedSSPs() {
+    val duplicatedSSP1 = new SystemStreamPartition("sys", "str", new Partition(42))
+    val duplicatedSSP2 = new SystemStreamPartition("sys", "str", new Partition(42))
+    val notDuplicatedSSP1 = new SystemStreamPartition("sys", "str2", new Partition(42))
+    val notDuplicatedSSP2 = new SystemStreamPartition("sys", "str3", new Partition(42))
+
+    val badMapping = Map(new TaskName("a") -> Set(notDuplicatedSSP1, duplicatedSSP1), new TaskName("b") -> Set(notDuplicatedSSP2, duplicatedSSP2))
+
+    var caughtException = false
+    try {
+      TaskNamesToSystemStreamPartitions(badMapping)
+    } catch {
+      case se: SamzaException => assertEquals("Assigning the same SystemStreamPartition to multiple " +
+        "TaskNames is not currently supported.  Out of compliance SystemStreamPartitions and counts: " +
+        "Map(SystemStreamPartition [sys, str, 42] -> 2)", se.getMessage)
+        caughtException = true
+      case _: Throwable       =>
+    }
+    assertTrue("TaskNamesToSystemStreamPartitions should have rejected this mapping but didn't", caughtException)
+  }
+
+  @Test
+  def validateMethodAllowsUniqueSSPs() {
+    val sspSet1 = (0 to 10).map(p => new SystemStreamPartition("sys", "str", new Partition(p))).toSet
+    val sspSet2 = (0 to 10).map(p => new SystemStreamPartition("sys", "str2", new Partition(p))).toSet
+
+    TaskNamesToSystemStreamPartitions(Map(new TaskName("set1") -> sspSet1, new TaskName("set2") -> sspSet2))
+  }
+}
diff --git a/samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/groupers/TestGroupByPartition.scala b/samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/groupers/TestGroupByPartition.scala
new file mode 100644 (file)
index 0000000..733be20
--- /dev/null
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.systemstreampartition.groupers
+
+import org.apache.samza.container.{TaskName, SystemStreamPartitionGrouperTestBase, SystemStreamPartitionGrouper}
+import scala.collection.JavaConverters._
+import org.junit.Test
+
+class TestGroupByPartition extends SystemStreamPartitionGrouperTestBase {
+  import SystemStreamPartitionGrouperTestBase._
+
+  val expected /* from base class provided set */ =  Map(new TaskName("Partition 0") -> Set(aa0, ac0).asJava,
+                                                         new TaskName("Partition 1") -> Set(aa1, ab1).asJava,
+                                                         new TaskName("Partition 2") -> Set(aa2, ab2).asJava).asJava
+
+  override def getGrouper: SystemStreamPartitionGrouper = new GroupByPartition
+
+  @Test def groupingWorks() {
+    verifyGroupGroupsCorrectly(allSSPs, expected)
+  }
+}
diff --git a/samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/groupers/TestGroupBySystemStreamPartition.scala b/samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/groupers/TestGroupBySystemStreamPartition.scala
new file mode 100644 (file)
index 0000000..e9c15a5
--- /dev/null
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.systemstreampartition.groupers
+
+import org.apache.samza.container.{TaskName, SystemStreamPartitionGrouperTestBase, SystemStreamPartitionGrouper}
+import scala.collection.JavaConverters._
+import org.junit.Test
+
+class TestGroupBySystemStreamPartition extends SystemStreamPartitionGrouperTestBase {
+  import SystemStreamPartitionGrouperTestBase._
+
+  // Building manually to avoid just duplicating a logic potential logic error here and there
+  val expected /* from base class provided set */ =  Map(new TaskName(aa0.toString) -> Set(aa0).asJava,
+    new TaskName(aa1.toString) -> Set(aa1).asJava,
+    new TaskName(aa2.toString) -> Set(aa2).asJava,
+    new TaskName(ab1.toString) -> Set(ab1).asJava,
+    new TaskName(ab2.toString) -> Set(ab2).asJava,
+    new TaskName(ac0.toString) -> Set(ac0).asJava).asJava
+
+  override def getGrouper: SystemStreamPartitionGrouper = new GroupBySystemStreamPartition
+
+  @Test def groupingWorks() {
+    verifyGroupGroupsCorrectly(allSSPs, expected)
+  }
+}
diff --git a/samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/taskname/groupers/TestSimpleSystemStreamPartitionTaskNameGrouper.scala b/samza-core/src/test/scala/org/apache/samza/container/systemstreampartition/taskname/groupers/TestSimpleSystemStreamPartitionTaskNameGrouper.scala
new file mode 100644 (file)
index 0000000..7ea09cd
--- /dev/null
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.systemstreampartition.taskname.groupers
+
+import org.apache.samza.container.{TaskName, TaskNamesToSystemStreamPartitions}
+import org.apache.samza.system.SystemStreamPartition
+import org.junit.Assert._
+import org.junit.Test
+
+class TestSimpleSystemStreamPartitionTaskNameGrouper {
+  val emptySSPSet = Set[SystemStreamPartition]()
+
+  @Test
+  def weGetAsExactlyManyGroupsAsWeAskFor() {
+    // memoize the maps used in the test to avoid an O(n^3) loop
+    val tntsspCache = scala.collection.mutable.Map[Int, TaskNamesToSystemStreamPartitions]()
+
+    def tntsspOfSize(size:Int) = {
+      def getMap(size:Int) = TaskNamesToSystemStreamPartitions((0 until size).map(z => new TaskName("tn" + z) -> emptySSPSet).toMap)
+
+      tntsspCache.getOrElseUpdate(size, getMap(size))
+    }
+
+    val maxTNTSSPSize = 1000
+    val maxNumGroups = 140
+    for(numGroups <- 1 to maxNumGroups) {
+      val grouper = new SimpleSystemStreamPartitionTaskNameGrouper(numGroups)
+
+      for (tntsspSize <- numGroups to maxTNTSSPSize) {
+        val map = tntsspOfSize(tntsspSize)
+        assertEquals(tntsspSize, map.size)
+
+        val grouped = grouper.groupTaskNames(map)
+        assertEquals("Asked for " + numGroups + " but got " + grouped.size, numGroups, grouped.size)
+      }
+    }
+  }
+}
index 21d8a78..258ccc1 100644 (file)
@@ -20,7 +20,6 @@
 package org.apache.samza.job
 import java.io.File
 import org.apache.samza.config.Config
-import org.junit.Assert._
 import org.junit.Test
 
 object TestJobRunner {
diff --git a/samza-core/src/test/scala/org/apache/samza/job/TestShellCommandBuilder.scala b/samza-core/src/test/scala/org/apache/samza/job/TestShellCommandBuilder.scala
new file mode 100644 (file)
index 0000000..f8a535a
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.job
+
+import org.junit.Test
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.Partition
+import org.apache.samza.util.Util._
+import org.apache.samza.container.{TaskName, TaskNamesToSystemStreamPartitions}
+import org.junit.Assert._
+
+class TestShellCommandBuilder {
+
+  @Test
+  def testJsonCreateStreamPartitionStringRoundTrip() {
+    val getPartitions: Set[SystemStreamPartition] = {
+      // Build a heavily skewed set of partitions.
+      def partitionSet(max:Int) = (0 until max).map(new Partition(_)).toSet
+      val system = "all-same-system."
+      val lotsOfParts = Map(system + "topic-with-many-parts-a" -> partitionSet(128),
+        system + "topic-with-many-parts-b" -> partitionSet(128), system + "topic-with-many-parts-c" -> partitionSet(64))
+      val fewParts = ('c' to 'z').map(l => system + l.toString -> partitionSet(4)).toMap
+      val streamsMap = (lotsOfParts ++ fewParts)
+      (for(s <- streamsMap.keys;
+           part <- streamsMap.getOrElse(s, Set.empty)) yield new SystemStreamPartition(getSystemStreamFromNames(s), part)).toSet
+    }
+
+    // Group by partition...
+    val sspTaskNameMap = TaskNamesToSystemStreamPartitions(getPartitions.groupBy(p => new TaskName(p.getPartition.toString)).toMap)
+
+    val asString = ShellCommandBuilder.serializeSystemStreamPartitionSetToJSON(sspTaskNameMap.getJavaFriendlyType)
+
+    val backFromSSPTaskNameMap = TaskNamesToSystemStreamPartitions(ShellCommandBuilder.deserializeSystemStreamPartitionSetFromJSON(asString))
+    assertEquals(sspTaskNameMap, backFromSSPTaskNameMap)
+  }
+}
index 4f7ddcd..d425e86 100644 (file)
@@ -21,7 +21,6 @@ package org.apache.samza.metrics
 
 import org.junit.Assert._
 import org.junit.Test
-import org.apache.samza.config.MapConfig
 import grizzled.slf4j.Logging
 import javax.management.remote.{JMXConnector, JMXConnectorFactory, JMXServiceURL}
 import java.io.IOException
index 70d8c80..0d07314 100644 (file)
 
 package org.apache.samza.serializers
 
+import java.util
+import org.apache.samza.Partition
+import org.apache.samza.checkpoint.Checkpoint
+import org.apache.samza.container.TaskName
+import org.apache.samza.system.SystemStreamPartition
 import org.junit.Assert._
 import org.junit.Test
-import org.apache.samza.system.SystemStream
-import org.apache.samza.checkpoint.Checkpoint
 import scala.collection.JavaConversions._
-import org.apache.samza.system.SystemStreamPartition
-import org.apache.samza.SamzaException
-import org.apache.samza.Partition
 
 class TestCheckpointSerde {
   @Test
   def testExactlyOneOffset {
     val serde = new CheckpointSerde
-    var offsets = Map[SystemStream, String]()
-    val systemStream = new SystemStream("test-system", "test-stream")
-    offsets += systemStream -> "1"
+    var offsets = Map[SystemStreamPartition, String]()
+    val systemStreamPartition = new SystemStreamPartition("test-system", "test-stream", new Partition(777))
+    offsets += systemStreamPartition -> "1"
     val deserializedOffsets = serde.fromBytes(serde.toBytes(new Checkpoint(offsets)))
-    assertEquals("1", deserializedOffsets.getOffsets.get(systemStream))
+    assertEquals("1", deserializedOffsets.getOffsets.get(systemStreamPartition))
     assertEquals(1, deserializedOffsets.getOffsets.size)
   }
 
   @Test
-  def testMoreThanOneOffsetShouldFail {
-    val serde = new CheckpointSerde
-    var offsets = Map[SystemStream, String]()
-    // Since SS != SSP with same system and stream name, this should result in 
-    // two offsets for one system stream in the serde.
-    offsets += new SystemStream("test-system", "test-stream") -> "1"
-    offsets += new SystemStreamPartition("test-system", "test-stream", new Partition(0)) -> "2"
-    try {
-      serde.toBytes(new Checkpoint(offsets))
-      fail("Expected to fail with more than one offset for a single SystemStream.")
-    } catch {
-      case e: SamzaException => // expected this
-    }
+  def testChangelogPartitionMappingRoundTrip {
+    val mapping = new util.HashMap[TaskName, java.lang.Integer]()
+    mapping.put(new TaskName("Ted"), 0)
+    mapping.put(new TaskName("Dougal"), 1)
+    mapping.put(new TaskName("Jack"), 2)
+
+    val checkpointSerde = new CheckpointSerde
+    val asBytes = checkpointSerde.changelogPartitionMappingToBytes(mapping)
+    val backToMap = checkpointSerde.changelogPartitionMappingFromBytes(asBytes)
+
+    assertEquals(mapping, backToMap)
+    assertNotSame(mapping, backToMap)
   }
-}
\ No newline at end of file
+
+}
index d31c3ce..f505eb1 100644 (file)
 
 package org.apache.samza.system.filereader
 
-import org.junit.Test
-import org.junit.Assert._
-import org.apache.samza.system.SystemStreamPartition
-import org.junit.AfterClass
-import java.io.PrintWriter
 import java.io.File
+import java.io.FileWriter
+import java.io.PrintWriter
 import org.apache.samza.Partition
+import org.apache.samza.system.SystemStreamPartition
+import org.junit.AfterClass
+import org.junit.Assert._
+import org.junit.BeforeClass
+import org.junit.Test
 import scala.collection.JavaConversions._
 import scala.collection.mutable.HashMap
-import org.junit.BeforeClass
-import java.io.FileWriter
 
 object TestFileReaderSystemConsumer {
   val consumer = new FileReaderSystemConsumer("file-reader", null)
index 12f1e03..7cfeb5a 100644 (file)
@@ -21,13 +21,15 @@ package org.apache.samza.task
 
 import org.junit.Assert._
 import org.junit.Test
-import org.apache.samza.Partition
 import org.apache.samza.task.TaskCoordinator.RequestScope
+import org.apache.samza.container.TaskName
 
 class TestReadableCoordinator {
+  val taskName = new TaskName("P0")
+
   @Test
   def testCommitTask {
-    val coord = new ReadableCoordinator(new Partition(0))
+    val coord = new ReadableCoordinator(taskName)
     assertFalse(coord.requestedCommitTask)
     assertFalse(coord.requestedCommitAll)
     coord.commit(RequestScope.CURRENT_TASK)
@@ -37,7 +39,7 @@ class TestReadableCoordinator {
 
   @Test
   def testCommitAll {
-    val coord = new ReadableCoordinator(new Partition(0))
+    val coord = new ReadableCoordinator(taskName)
     assertFalse(coord.requestedCommitTask)
     assertFalse(coord.requestedCommitAll)
     coord.commit(RequestScope.ALL_TASKS_IN_CONTAINER)
@@ -47,7 +49,7 @@ class TestReadableCoordinator {
 
   @Test
   def testShutdownNow {
-    val coord = new ReadableCoordinator(new Partition(0))
+    val coord = new ReadableCoordinator(taskName)
     assertFalse(coord.requestedShutdownOnConsensus)
     assertFalse(coord.requestedShutdownNow)
     coord.shutdown(RequestScope.ALL_TASKS_IN_CONTAINER)
@@ -57,7 +59,7 @@ class TestReadableCoordinator {
 
   @Test
   def testShutdownRequest {
-    val coord = new ReadableCoordinator(new Partition(0))
+    val coord = new ReadableCoordinator(taskName)
     assertFalse(coord.requestedShutdownOnConsensus)
     assertFalse(coord.requestedShutdownNow)
     coord.shutdown(RequestScope.CURRENT_TASK)
index ad6d2da..7c314ce 100644 (file)
@@ -20,9 +20,13 @@ package org.apache.samza.util
 
 import org.apache.samza.Partition
 import org.apache.samza.config.Config
+import org.apache.samza.config.Config
 import org.apache.samza.config.MapConfig
+import org.apache.samza.container.{TaskName, TaskNamesToSystemStreamPartitions}
+import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.system.SystemFactory
+import org.apache.samza.system.SystemFactory
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.util.Util._
 import org.junit.Assert._
@@ -54,48 +58,32 @@ class TestUtil {
   }
 
   @Test
-  def testGetTopicPartitionsForTask() {
-    def partitionSet(max:Int) = (0 until max).map(new Partition(_)).toSet
-
-    val taskCount = 4
-    val streamsMap = Map("kafka.a" -> partitionSet(4), "kafka.b" -> partitionSet(18), "timestream.c" -> partitionSet(24))
-    val streamsAndParts = (for(s <- streamsMap.keys;
-                               part <- streamsMap.getOrElse(s, Set.empty))
-    yield new SystemStreamPartition(getSystemStreamFromNames(s), part)).toSet
+  def testResolveTaskNameToChangelogPartitionMapping {
+    def testRunner(description:String, currentTaskNames:Set[TaskName], previousTaskNameMapping:Map[TaskName, Int],
+                   result:Map[TaskName, Int]) {
+      assertEquals("Failed: " + description, result,
+        Util.resolveTaskNameToChangelogPartitionMapping(currentTaskNames, previousTaskNameMapping))
+    }
 
-    for(i <- 0 until taskCount) {
-      val result: Set[SystemStreamPartition] = Util.getStreamsAndPartitionsForContainer(i, taskCount, streamsAndParts)
-      // b -> 18 % 4 = 2 therefore first two results should have an extra element
-      if(i < 2) {
-        assertEquals(12, result.size)
-      } else {
-        assertEquals(11, result.size)
-      }
+    testRunner("No change between runs",
+      Set(new TaskName("Partition 0")),
+      Map(new TaskName("Partition 0") -> 0),
+      Map(new TaskName("Partition 0") -> 0))
 
-      result.foreach(r => assertEquals(i, r.getPartition.getPartitionId % taskCount))
-    }
-  }
-  
-  @Test
-  def testJsonCreateStreamPartitionStringRoundTrip() {
-    val getPartitions: Set[SystemStreamPartition] = {
-      // Build a heavily skewed set of partitions.
-      def partitionSet(max:Int) = (0 until max).map(new Partition(_)).toSet
-      val system = "all-same-system."
-      val lotsOfParts = Map(system + "topic-with-many-parts-a" -> partitionSet(128),
-        system + "topic-with-many-parts-b" -> partitionSet(128), system + "topic-with-many-parts-c" -> partitionSet(64))
-      val fewParts = ('c' to 'z').map(l => system + l.toString -> partitionSet(4)).toMap
-      val streamsMap = (lotsOfParts ++ fewParts)
-      (for(s <- streamsMap.keys;
-           part <- streamsMap.getOrElse(s, Set.empty)) yield new SystemStreamPartition(getSystemStreamFromNames(s), part)).toSet
-    }
+    testRunner("New TaskName added, none missing this run",
+      Set(new TaskName("Partition 0"), new TaskName("Partition 1")),
+      Map(new TaskName("Partition 0") -> 0),
+      Map(new TaskName("Partition 0") -> 0, new TaskName("Partition 1") -> 1))
 
-    val streamsAndParts: Set[SystemStreamPartition] = getStreamsAndPartitionsForContainer(0, 4, getPartitions).toSet
-    println(streamsAndParts)
-    val asString = serializeSSPSetToJSON(streamsAndParts)
+    testRunner("New TaskName added, one missing this run",
+      Set(new TaskName("Partition 0"), new TaskName("Partition 2")),
+      Map(new TaskName("Partition 0") -> 0, new TaskName("Partition 1") -> 1),
+      Map(new TaskName("Partition 0") -> 0, new TaskName("Partition 1") -> 1, new TaskName("Partition 2") -> 2))
 
-    val backToStreamsAndParts = deserializeSSPSetFromJSON(asString)
-    assertEquals(streamsAndParts, backToStreamsAndParts)
+    testRunner("New TaskName added, all previous missing this run",
+      Set(new TaskName("Partition 3"), new TaskName("Partition 4")),
+      Map(new TaskName("Partition 0") -> 0, new TaskName("Partition 1") -> 1, new TaskName("Partition 2") -> 2),
+      Map(new TaskName("Partition 0") -> 0, new TaskName("Partition 1") -> 1, new TaskName("Partition 2") -> 2, new TaskName("Partition 3") -> 3, new TaskName("Partition 4") -> 4))
   }
 
   /**
diff --git a/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointLogKey.scala b/samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointLogKey.scala
new file mode 100644 (file)
index 0000000..5d8ee4f
--- /dev/null
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.checkpoint.kafka
+
+import java.util
+import org.apache.samza.SamzaException
+import org.apache.samza.container.TaskName
+import org.codehaus.jackson.`type`.TypeReference
+import org.codehaus.jackson.map.ObjectMapper
+import scala.collection.JavaConversions._
+
+/**
+ * Kafka Checkpoint Log-specific key used to identify what type of entry is
+ * written for any particular log entry.
+ *
+ * @param map Backing map to hold key values
+ */
+class KafkaCheckpointLogKey private (val map: Map[String, String]) {
+  // This might be better as a case class...
+  import KafkaCheckpointLogKey._
+
+  /**
+   * Serialize this key to bytes
+   * @return Key as bytes
+   */
+  def toBytes(): Array[Byte] = {
+    val jMap = new util.HashMap[String, String](map.size)
+    jMap.putAll(map)
+
+    JSON_MAPPER.writeValueAsBytes(jMap)
+  }
+
+  private def getKey = map.getOrElse(CHECKPOINT_KEY_KEY, throw new SamzaException("No " + CHECKPOINT_KEY_KEY  + " in map for Kafka Checkpoint log key"))
+
+  /**
+   * Is this key for a checkpoint entry?
+   *
+   * @return true iff this key's entry is for a checkpoint
+   */
+  def isCheckpointKey = getKey.equals(CHECKPOINT_KEY_TYPE)
+
+  /**
+   * Is this key for a changelog partition mapping?
+   *
+   * @return true iff this key's entry is for a changelog partition mapping
+   */
+  def isChangelogPartitionMapping = getKey.equals(CHANGELOG_PARTITION_KEY_TYPE)
+
+  /**
+   * If this Key is for a checkpoint entry, return its associated TaskName.
+   *
+   * @return TaskName for this checkpoint or throw an exception if this key does not have a TaskName entry
+   */
+  def getCheckpointTaskName = {
+    val asString = map.getOrElse(CHECKPOINT_TASKNAME_KEY, throw new SamzaException("No TaskName in checkpoint key: " + this))
+    new TaskName(asString)
+  }
+
+  def canEqual(other: Any): Boolean = other.isInstanceOf[KafkaCheckpointLogKey]
+
+  override def equals(other: Any): Boolean = other match {
+    case that: KafkaCheckpointLogKey =>
+      (that canEqual this) &&
+        map == that.map
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    val state = Seq(map)
+    state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
+  }
+}
+
+object KafkaCheckpointLogKey {
+  /**
+   *  Messages in the checkpoint log have keys associated with them. These keys are maps that describe the message's
+   *  type, either a checkpoint or a changelog-partition-mapping.
+   */
+  val CHECKPOINT_KEY_KEY = "type"
+  val CHECKPOINT_KEY_TYPE = "checkpoint"
+  val CHANGELOG_PARTITION_KEY_TYPE = "changelog-partition-mapping"
+  val CHECKPOINT_TASKNAME_KEY = "taskName"
+  val SYSTEMSTREAMPARTITION_GROUPER_FACTORY_KEY = "systemstreampartition-grouper-factory"
+
+  /**
+   * Partition mapping keys have no dynamic values, so we just need one instance.
+   */
+  val CHANGELOG_PARTITION_MAPPING_KEY = new KafkaCheckpointLogKey(Map(CHECKPOINT_KEY_KEY -> CHANGELOG_PARTITION_KEY_TYPE))
+
+  private val JSON_MAPPER = new ObjectMapper()
+  val KEY_TYPEREFERENCE = new TypeReference[util.HashMap[String, String]]() {}
+
+  var systemStreamPartitionGrouperFactoryString:Option[String] = None
+
+  /**
+   * Set the name of the factory configured to provide the SystemStreamPartition grouping
+   * so it be included in the key.
+   *
+   * @param str Config value of SystemStreamPartition Grouper Factory
+   */
+  def setSystemStreamPartitionGrouperFactoryString(str:String) = {
+    systemStreamPartitionGrouperFactoryString = Some(str)
+  }
+
+  /**
+   * Get the name of the factory configured to provide the SystemStreamPartition grouping
+   * so it be included in the key
+   */
+  def getSystemStreamPartitionGrouperFactoryString = systemStreamPartitionGrouperFactoryString.getOrElse(throw new SamzaException("No SystemStreamPartition grouping factory string has been set."))
+
+  /**
+   * Build a key for a a checkpoint log entry for a particular TaskName
+   * @param taskName TaskName to build for this checkpoint entry
+   *
+   * @return Key for checkpoint log entry
+   */
+  def getCheckpointKey(taskName:TaskName) = {
+    val map = Map(CHECKPOINT_KEY_KEY -> CHECKPOINT_KEY_TYPE,
+      CHECKPOINT_TASKNAME_KEY -> taskName.getTaskName,
+      SYSTEMSTREAMPARTITION_GROUPER_FACTORY_KEY -> getSystemStreamPartitionGrouperFactoryString)
+
+    new KafkaCheckpointLogKey(map)
+  }
+
+  /**
+   * Build a key for a changelog partition mapping entry
+   *
+   * @return Key for changelog partition mapping entry
+   */
+  def getChangelogPartitionMappingKey() = CHANGELOG_PARTITION_MAPPING_KEY
+
+  /**
+   * Deserialize a Kafka checkpoint log key
+   * @param bytes Serialized (via JSON) Kafka checkpoint log key
+   * @return Checkpoint log key
+   */
+  def fromBytes(bytes: Array[Byte]): KafkaCheckpointLogKey = {
+    try {
+      val jmap: util.HashMap[String, String] = JSON_MAPPER.readValue(bytes, KEY_TYPEREFERENCE)
+
+      if(!jmap.containsKey(CHECKPOINT_KEY_KEY)) {
+        throw new SamzaException("No type entry in checkpoint key: " + jmap)
+      }
+
+      // Only checkpoint keys have ssp grouper factory keys
+      if(jmap.get(CHECKPOINT_KEY_KEY).equals(CHECKPOINT_KEY_TYPE)) {
+        val sspGrouperFactory = jmap.get(SYSTEMSTREAMPARTITION_GROUPER_FACTORY_KEY)
+
+        if (sspGrouperFactory == null) {
+          throw new SamzaException("No SystemStreamPartition Grouper factory entry in checkpoint key: " + jmap)
+        }
+
+        if (!sspGrouperFactory.equals(getSystemStreamPartitionGrouperFactoryString)) {
+          throw new DifferingSystemStreamPartitionGrouperFactoryValues(sspGrouperFactory, getSystemStreamPartitionGrouperFactoryString)
+        }
+      }
+
+      new KafkaCheckpointLogKey(jmap.toMap)
+    } catch {
+      case e: Exception =>
+        throw new SamzaException("Exception while deserializing checkpoint key", e)
+    }
+  }
+}
+
+class DifferingSystemStreamPartitionGrouperFactoryValues(inKey:String, inConfig:String) extends SamzaException {
+  override def getMessage() = "Checkpoint key's SystemStreamPartition Grouper factory (" + inKey +
+    ") does not match value from current configuration (" + inConfig + ").  " +
+    "This likely means the SystemStreamPartitionGrouper was changed between job runs, which is not supported."
+}
\ No newline at end of file
index 15245d4..fff62e4 100644 (file)
 
 package org.apache.samza.checkpoint.kafka
 
-import org.I0Itec.zkclient.ZkClient
 import grizzled.slf4j.Logging
+import java.nio.ByteBuffer
+import java.util
 import kafka.admin.AdminUtils
-import kafka.api.FetchRequestBuilder
-import kafka.api.OffsetRequest
-import kafka.api.PartitionOffsetRequestInfo
+import kafka.api._
 import kafka.common.ErrorMapping
+import kafka.common.InvalidMessageSizeException
 import kafka.common.TopicAndPartition
 import kafka.common.TopicExistsException
-import kafka.common.InvalidMessageSizeException
 import kafka.common.UnknownTopicOrPartitionException
 import kafka.consumer.SimpleConsumer
+import kafka.message.InvalidMessageException
 import kafka.producer.KeyedMessage
-import kafka.producer.Partitioner
 import kafka.producer.Producer
-import kafka.serializer.Decoder
-import kafka.serializer.Encoder
 import kafka.utils.Utils
-import kafka.utils.VerifiableProperties
-import kafka.message.InvalidMessageException
-import org.apache.samza.Partition
+import org.I0Itec.zkclient.ZkClient
 import org.apache.samza.SamzaException
 import org.apache.samza.checkpoint.Checkpoint
 import org.apache.samza.checkpoint.CheckpointManager
+import org.apache.samza.container.TaskName
 import org.apache.samza.serializers.CheckpointSerde
-import org.apache.samza.serializers.Serde
 import org.apache.samza.system.kafka.TopicMetadataCache
-import org.apache.samza.util.TopicMetadataStore
 import org.apache.samza.util.ExponentialSleepStrategy
+import org.apache.samza.util.TopicMetadataStore
+import scala.collection.mutable
 
 /**
- * Kafka checkpoint manager is used to store checkpoints in a Kafka topic that
- * is uniquely identified by a job/partition combination. To read a checkpoint
- * for a given job and partition combination (e.g. my-job, partition 1), we
- * simply read the last message from the topic: __samza_checkpoint_my-job_1. If
- * the topic does not yet exist, we assume that there is not yet any state for
- * this job/partition pair, and return an empty checkpoint.
+ * Kafka checkpoint manager is used to store checkpoints in a Kafka topic.
+ * To read a checkpoint for a specific taskName, we find the newest message
+ * keyed to that taskName. If there is no such message, no checkpoint data
+ * exists.  The underlying log has a single partition into which all
+ * checkpoints and TaskName to changelog partition mappings are written.
  */
 class KafkaCheckpointManager(
   clientId: String,
   checkpointTopic: String,
   systemName: String,
-  totalPartitions: Int,
   replicationFactor: Int,
   socketTimeout: Int,
   bufferSize: Int,
   fetchSize: Int,
   metadataStore: TopicMetadataStore,
-  connectProducer: () => Producer[Partition, Array[Byte]],
+  connectProducer: () => Producer[Array[Byte], Array[Byte]],
   connectZk: () => ZkClient,
+  systemStreamPartitionGrouperFactoryString: String,
   retryBackoff: ExponentialSleepStrategy = new ExponentialSleepStrategy,
-  serde: Serde[Checkpoint] = new CheckpointSerde) extends CheckpointManager with Logging {
+  serde: CheckpointSerde = new CheckpointSerde) extends CheckpointManager with Logging {
+  import KafkaCheckpointManager._
+
+  var taskNames = Set[TaskName]()
+  var producer: Producer[Array[Byte], Array[Byte]] = null
+  var taskNamesToOffsets: Map[TaskName, Checkpoint] = null
 
-  var partitions = Set[Partition]()
-  var producer: Producer[Partition, Array[Byte]] = null
+  var startingOffset: Option[Long] = None // Where to start reading for each subsequent call of readCheckpoint
 
-  info("Creating KafkaCheckpointManager with: clientId=%s, checkpointTopic=%s, systemName=%s" format (clientId, checkpointTopic, systemName))
+  KafkaCheckpointLogKey.setSystemStreamPartitionGrouperFactoryString(systemStreamPartitionGrouperFactoryString)
 
-  def writeCheckpoint(partition: Partition, checkpoint: Checkpoint) {
+  info("Creating KafkaCheckpointManager with: clientId=%s, checkpointTopic=%s, systemName=%s" format(clientId, checkpointTopic, systemName))
+
+  /**
+   * Write Checkpoint for specified taskName to log
+   *
+   * @param taskName Specific Samza taskName of which to write a checkpoint of.
+   * @param checkpoint Reference to a Checkpoint object to store offset data in.
+   **/
+  override def writeCheckpoint(taskName: TaskName, checkpoint: Checkpoint) {
+    val key = KafkaCheckpointLogKey.getCheckpointKey(taskName)
+    val keyBytes = key.toBytes()
+    val msgBytes = serde.toBytes(checkpoint)
+
+    writeLog(CHECKPOINT_LOG4J_ENTRY, keyBytes, msgBytes)
+  }
+
+  /**
+   * Write the taskName to partition mapping that is being maintained by this CheckpointManager
+   *
+   * @param changelogPartitionMapping Each TaskName's partition within the changelog
+   */
+  override def writeChangeLogPartitionMapping(changelogPartitionMapping: util.Map[TaskName, java.lang.Integer]) {
+    val key = KafkaCheckpointLogKey.getChangelogPartitionMappingKey()
+    val keyBytes = key.toBytes()
+    val msgBytes = serde.changelogPartitionMappingToBytes(changelogPartitionMapping)
+
+    writeLog(CHANGELOG_PARTITION_MAPPING_LOG4j, keyBytes, msgBytes)
+  }
+
+  /**
+   * Common code for writing either checkpoints or changelog-partition-mappings to the log
+   *
+   * @param logType Type of entry that is being written, for logging
+   * @param key pre-serialized key for message
+   * @param msg pre-serialized message to write to log
+   */
+  private def writeLog(logType:String, key: Array[Byte], msg: Array[Byte]) {
     retryBackoff.run(
       loop => {
         if (producer == null) {
           producer = connectProducer()
         }
-        producer.send(new KeyedMessage(checkpointTopic, null, partition, serde.toBytes(checkpoint)))
+
+        producer.send(new KeyedMessage(checkpointTopic, key, 0, msg))
         loop.done
       },
 
       (exception, loop) => {
-        warn("Failed to send checkpoint %s for partition %s: %s. Retrying." format (checkpoint, partition, exception))
+        warn("Failed to write %s partition entry %s: %s. Retrying." format(logType, key, exception))
         debug("Exception detail:", exception)
         if (producer != null) {
           producer.close
@@ -98,124 +134,219 @@ class KafkaCheckpointManager(
     )
   }
 
-  def readLastCheckpoint(partition: Partition): Checkpoint = {
-    info("Reading checkpoint for partition %s." format partition.getPartitionId)
+  private def getConsumer(): SimpleConsumer = {
+    val metadataMap = TopicMetadataCache.getTopicMetadata(Set(checkpointTopic), systemName, (topics: Set[String]) => metadataStore.getTopicInfo(topics))
+    val metadata = metadataMap(checkpointTopic)
+    val partitionMetadata = metadata.partitionsMetadata
+      .filter(_.partitionId == 0)
+      .headOption
+      .getOrElse(throw new KafkaCheckpointException("Tried to find partition information for partition 0 for checkpoint topic, but it didn't exist in Kafka."))
+    val leader = partitionMetadata
+      .leader
+      .getOrElse(throw new SamzaException("No leader available for topic %s" format checkpointTopic))
+
+    info("Connecting to leader %s:%d for topic %s and to fetch all checkpoint messages." format(leader.host, leader.port, checkpointTopic))
+
+    new SimpleConsumer(leader.host, leader.port, socketTimeout, bufferSize, clientId)
+  }
+
+  private def getEarliestOffset(consumer: SimpleConsumer, topicAndPartition: TopicAndPartition): Long = consumer.earliestOrLatestOffset(topicAndPartition, OffsetRequest.EarliestTime, -1)
+
+  private def getOffset(consumer: SimpleConsumer, topicAndPartition: TopicAndPartition, earliestOrLatest: Long): Long = {
+    val offsetResponse = consumer.getOffsetsBefore(new OffsetRequest(Map(topicAndPartition -> PartitionOffsetRequestInfo(earliestOrLatest, 1))))
+      .partitionErrorAndOffsets
+      .get(topicAndPartition)
+      .getOrElse(throw new KafkaCheckpointException("Unable to find offset information for %s:0" format checkpointTopic))
+    // Fail or retry if there was an an issue with the offset request.
+    ErrorMapping.maybeThrowException(offsetResponse.error)
+
+    val offset: Long = offsetResponse
+      .offsets
+      .headOption
+      .getOrElse(throw new KafkaCheckpointException("Got response, but no offsets defined for %s:0" format checkpointTopic))
+
+    offset
+  }
+
+  /**
+   * Read the last checkpoint for specified TaskName
+   *
+   * @param taskName Specific Samza taskName for which to get the last checkpoint of.
+   **/
+  override def readLastCheckpoint(taskName: TaskName): Checkpoint = {
+    if (!taskNames.contains(taskName)) {
+      throw new SamzaException(taskName + " not registered with this CheckpointManager")
+    }
+
+    info("Reading checkpoint for taskName " + taskName)
+
+    if (taskNamesToOffsets == null) {
+      info("No TaskName to checkpoint mapping provided.  Reading for first time.")
+      taskNamesToOffsets = readCheckpointsFromLog()
+    } else {
+      info("Already existing checkpoint mapping.  Merging new offsets")
+      taskNamesToOffsets ++= readCheckpointsFromLog()
+    }
+
+    val checkpoint = taskNamesToOffsets.get(taskName).getOrElse(null)
+
+    info("Got checkpoint state for taskName %s: %s" format(taskName, checkpoint))
+
+    checkpoint
+  }
+
+  /**
+   * Read through entire log, discarding changelog mapping, and building map of TaskNames to Checkpoints
+   */
+  private def readCheckpointsFromLog(): Map[TaskName, Checkpoint] = {
+    val checkpoints = mutable.Map[TaskName, Checkpoint]()
+
+    def shouldHandleEntry(key: KafkaCheckpointLogKey) = key.isCheckpointKey
+
+    def handleCheckpoint(payload: ByteBuffer, checkpointKey:KafkaCheckpointLogKey): Unit = {
+      val taskName = checkpointKey.getCheckpointTaskName
+
+      if (taskNames.contains(taskName)) {
+        val checkpoint = serde.fromBytes(Utils.readBytes(payload))
+
+        debug("Adding checkpoint " + checkpoint + " for taskName " + taskName)
+
+        checkpoints.put(taskName, checkpoint) // replacing any existing, older checkpoints as we go
+      }
+    }
+
+    readLog(CHECKPOINT_LOG4J_ENTRY, shouldHandleEntry, handleCheckpoint)
+
+    checkpoints.toMap /* of the immutable kind */
+  }
+
+  /**
+   * Read through entire log, discarding checkpoints, finding latest changelogPartitionMapping
+   *
+   * Lots of duplicated code from the checkpoint method, but will be better to refactor this code into AM-based
+   * checkpoint log reading
+   */
+  override def readChangeLogPartitionMapping(): util.Map[TaskName, java.lang.Integer] = {
+    var changelogPartitionMapping: util.Map[TaskName, java.lang.Integer] = new util.HashMap[TaskName, java.lang.Integer]()
+
+    def shouldHandleEntry(key: KafkaCheckpointLogKey) = key.isChangelogPartitionMapping
+
+    def handleCheckpoint(payload: ByteBuffer, checkpointKey:KafkaCheckpointLogKey): Unit = {
+      changelogPartitionMapping = serde.changelogPartitionMappingFromBytes(Utils.readBytes(payload))
+
+      debug("Adding changelog partition mapping" + changelogPartitionMapping)
+    }
 
-    val checkpoint = retryBackoff.run(
+    readLog(CHANGELOG_PARTITION_MAPPING_LOG4j, shouldHandleEntry, handleCheckpoint)
+
+    changelogPartitionMapping
+  }
+
+  /**
+   * Common code for reading both changelog partition mapping and change log
+   *
+   * @param entryType What type of entry to look for within the log key's
+   * @param handleEntry Code to handle an entry in the log once it's found
+   */
+  private def readLog(entryType:String, shouldHandleEntry: (KafkaCheckpointLogKey) => Boolean,
+                      handleEntry: (ByteBuffer, KafkaCheckpointLogKey) => Unit): Unit = {
+    retryBackoff.run[Unit](
       loop => {
-        // Assume checkpoint topic exists with correct partitions, since it should be verified on start.
-        // Fetch the metadata for this checkpoint topic/partition pair.
-        val metadataMap = TopicMetadataCache.getTopicMetadata(Set(checkpointTopic), systemName, (topics: Set[String]) => metadataStore.getTopicInfo(topics))
-        val metadata = metadataMap(checkpointTopic)
-        val partitionMetadata = metadata.partitionsMetadata
-          .filter(_.partitionId == partition.getPartitionId)
-          .headOption
-          .getOrElse(throw new KafkaCheckpointException("Tried to find partition information for partition %d, but it didn't exist in Kafka." format partition.getPartitionId))
-        val partitionId = partitionMetadata.partitionId
-        val leader = partitionMetadata
-          .leader
-          .getOrElse(throw new SamzaException("No leader available for topic %s" format checkpointTopic))
-
-        info("Connecting to leader %s:%d for topic %s and partition %s to fetch last checkpoint message." format (leader.host, leader.port, checkpointTopic, partitionId))
-
-        val consumer = new SimpleConsumer(
-          leader.host,
-          leader.port,
-          socketTimeout,
-          bufferSize,
-          clientId)
+        val consumer = getConsumer()
+
+        val topicAndPartition = new TopicAndPartition(checkpointTopic, 0)
+
         try {
-          val topicAndPartition = new TopicAndPartition(checkpointTopic, partitionId)
-          val offsetResponse = consumer.getOffsetsBefore(new OffsetRequest(Map(topicAndPartition -> PartitionOffsetRequestInfo(OffsetRequest.LatestTime, 1))))
-            .partitionErrorAndOffsets
-            .get(topicAndPartition)
-            .getOrElse(throw new KafkaCheckpointException("Unable to find offset information for %s:%d" format (checkpointTopic, partitionId)))
+          var offset = startingOffset.getOrElse(getEarliestOffset(consumer, topicAndPartition))
 
-          // Fail or retry if there was an an issue with the offset request.
-          ErrorMapping.maybeThrowException(offsetResponse.error)
+          info("Got offset %s for topic %s and partition 0. Attempting to fetch messages for %s." format(offset, checkpointTopic, entryType))
 
-          val offset = offsetResponse
-            .offsets
-            .headOption
-            .getOrElse(throw new KafkaCheckpointException("Got response, but no offsets defined for %s:%d" format (checkpointTopic, partitionId)))
+          val latestOffset = getOffset(consumer, topicAndPartition, OffsetRequest.LatestTime)
 
-          info("Got offset %s for topic %s and partition %s. Attempting to fetch message." format (offset, checkpointTopic, partitionId))
+          info("Get latest offset %s for topic %s and partition 0." format(latestOffset, checkpointTopic))
 
-          if (offset <= 0) {
-            info("Got offset 0 (no messages in checkpoint topic) for topic %s and partition %s, so returning null. If you expected the checkpoint topic to have messages, you're probably going to lose data." format (checkpointTopic, partition))
-            return null
+          if (offset < 0) {
+            info("Got offset 0 (no messages in %s) for topic %s and partition 0, so returning empty collection. If you expected the checkpoint topic to have messages, you're probably going to lose data." format (entryType, checkpointTopic))
+            return
           }
 
-          val request = new FetchRequestBuilder()
-            // Kafka returns 1 greater than the offset of the last message in
-            // the topic, so subtract one to fetch the last message.
-            .addFetch(checkpointTopic, partitionId, offset - 1, fetchSize)
-            .maxWait(500)
-            .minBytes(1)
-            .clientId(clientId)
-            .build
-          val messageSet = consumer.fetch(request)
-          if (messageSet.hasError) {
-            warn("Got error code from broker for %s: %s" format (checkpointTopic, messageSet.errorCode(checkpointTopic, partitionId)))
-            val errorCode = messageSet.errorCode(checkpointTopic, partitionId)
-            if (ErrorMapping.OffsetOutOfRangeCode.equals(errorCode)) {
-              warn("Got an offset out of range exception while getting last checkpoint for topic %s and partition %s, so returning a null offset to the KafkaConsumer. Let it decide what to do based on its autooffset.reset setting." format (checkpointTopic, partitionId))
-              return null
+          while (offset < latestOffset) {
+            val request = new FetchRequestBuilder()
+              .addFetch(checkpointTopic, 0, offset, fetchSize)
+              .maxWait(500)
+              .minBytes(1)
+              .clientId(clientId)
+              .build
+
+            val fetchResponse = consumer.fetch(request)
+            if (fetchResponse.hasError) {
+              warn("Got error code from broker for %s: %s" format(checkpointTopic, fetchResponse.errorCode(checkpointTopic, 0)))
+              val errorCode = fetchResponse.errorCode(checkpointTopic, 0)
+              if (ErrorMapping.OffsetOutOfRangeCode.equals(errorCode)) {
+                warn("Got an offset out of range exception while getting last entry in %s for topic %s and partition 0, so returning a null offset to the KafkaConsumer. Let it decide what to do based on its autooffset.reset setting." format (entryType, checkpointTopic))
+                return
+              }
+              ErrorMapping.maybeThrowException(errorCode)
             }
-            ErrorMapping.maybeThrowException(errorCode)
-          }
-          val messages = messageSet.messageSet(checkpointTopic, partitionId).toList
 
-          if (messages.length != 1) {
-            throw new KafkaCheckpointException("Something really unexpected happened. Got %s "
-              + "messages back when fetching from checkpoint topic %s and partition %s. "
-              + "Expected one message. It would be unsafe to go on without the latest checkpoint, "
-              + "so failing." format (messages.length, checkpointTopic, partition))
-          }
+            for (response <- fetchResponse.messageSet(checkpointTopic, 0)) {
+              offset = response.nextOffset
+              startingOffset = Some(offset) // For next time we call
+
+              if (!response.message.hasKey) {
+                throw new KafkaCheckpointException("Encountered message without key.")
+              }
 
-          // Some back bending to go from message to checkpoint.
-          val checkpoint = serde.fromBytes(Utils.readBytes(messages(0).message.payload))
-          loop.done
-          checkpoint
+              val checkpointKey = KafkaCheckpointLogKey.fromBytes(Utils.readBytes(response.message.key))
+
+              if (!shouldHandleEntry(checkpointKey)) {
+                debug("Skipping " + entryType + " entry with key " + checkpointKey)
+              } else {
+                handleEntry(response.message.payload, checkpointKey)
+              }
+            }
+          }
         } finally {
-          consumer.close
+          consumer.close()
         }
+
+        loop.done
+        Unit
       },
 
       (exception, loop) => {
         exception match {
-          case e: InvalidMessageException => throw new KafkaCheckpointException ("Got InvalidMessageException from Kafka, which is unrecoverable, so fail the samza job", e)
-          case e: InvalidMessageSizeException => throw new KafkaCheckpointException ("Got InvalidMessageSizeException from Kafka, which is unrecoverable, so fail the samza job", e)
-          case e: UnknownTopicOrPartitionException => throw new KafkaCheckpointException ("Got UnknownTopicOrPartitionException from Kafka, which is unrecoverable, so fail the samza job", e)
+          case e: InvalidMessageException => throw new KafkaCheckpointException("Got InvalidMessageException from Kafka, which is unrecoverable, so fail the samza job", e)
+          case e: InvalidMessageSizeException => throw new KafkaCheckpointException("Got InvalidMessageSizeException from Kafka, which is unrecoverable, so fail the samza job", e)
+          case e: UnknownTopicOrPartitionException => throw new KafkaCheckpointException("Got UnknownTopicOrPartitionException from Kafka, which is unrecoverable, so fail the samza job", e)
           case e: KafkaCheckpointException => throw e
           case e: Exception =>
-            warn("While trying to read last checkpoint for topic %s and partition %s: %s. Retrying." format (checkpointTopic, partition, e))
+            warn("While trying to read last %s entry for topic %s and partition 0: %s. Retrying." format(entryType, checkpointTopic, e))
             debug("Exception detail:", e)
         }
       }
-    ).getOrElse(throw new SamzaException("Failed to get checkpoint for partition %s" format partition.getPartitionId))
+    ).getOrElse(throw new SamzaException("Failed to get entries for " + entryType + " from topic " + checkpointTopic))
 
-    info("Got checkpoint state for partition %s: %s" format (partition.getPartitionId, checkpoint))
-    checkpoint
   }
 
   def start {
-    createTopic
+    create
     validateTopic
   }
 
-  def register(partition: Partition) {
-    partitions += partition
+  def register(taskName: TaskName) {
+    debug("Adding taskName " + taskName + " to " + this)
+    taskNames += taskName
   }
 
   def stop = {
-    if(producer != null) {
+    if (producer != null) {
       producer.close
     }
   }
 
-  private def createTopic {
-    info("Attempting to create checkpoint topic %s with %s partitions." format (checkpointTopic, totalPartitions))
+  def create {
+    info("Attempting to create checkpoint topic %s." format checkpointTopic)
     retryBackoff.run(
       loop => {
         val zkClient = connectZk()
@@ -223,7 +354,7 @@ class KafkaCheckpointManager(
           AdminUtils.createTopic(
             zkClient,
             checkpointTopic,
-            totalPartitions,
+            1,
             replicationFactor)
         } finally {
           zkClient.close
@@ -239,7 +370,7 @@ class KafkaCheckpointManager(
             info("Checkpoint topic %s already exists." format checkpointTopic)
             loop.done
           case e: Exception =>
-            warn("Failed to create topic %s: %s. Retrying." format (checkpointTopic, e))
+            warn("Failed to create topic %s: %s. Retrying." format(checkpointTopic, e))
             debug("Exception detail:", e)
         }
       }
@@ -255,8 +386,8 @@ class KafkaCheckpointManager(
         ErrorMapping.maybeThrowException(topicMetadata.errorCode)
 
         val partitionCount = topicMetadata.partitionsMetadata.length
-        if (partitionCount != totalPartitions) {
-          throw new KafkaCheckpointException("Checkpoint topic validation failed for topic %s because partition count %s did not match expected partition count %s." format (checkpointTopic, topicMetadata.partitionsMetadata.length, totalPartitions))
+        if (partitionCount != 1) {
+          throw new KafkaCheckpointException("Checkpoint topic validation failed for topic %s because partition count %s did not match expected partition count of 1." format(checkpointTopic, topicMetadata.partitionsMetadata.length))
         }
 
         info("Successfully validated checkpoint topic %s." format checkpointTopic)
@@ -267,14 +398,19 @@ class KafkaCheckpointManager(
         exception match {
           case e: KafkaCheckpointException => throw e
           case e: Exception =>
-            warn("While trying to validate topic %s: %s. Retrying." format (checkpointTopic, e))
+            warn("While trying to validate topic %s: %s. Retrying." format(checkpointTopic, e))
             debug("Exception detail:", e)
         }
       }
     )
   }
 
-  override def toString = "KafkaCheckpointManager [systemName=%s, checkpointTopic=%s]" format (systemName, checkpointTopic)
+  override def toString = "KafkaCheckpointManager [systemName=%s, checkpointTopic=%s]" format(systemName, checkpointTopic)
+}
+
+object KafkaCheckpointManager {
+  val CHECKPOINT_LOG4J_ENTRY = "checkpoint log"
+  val CHANGELOG_PARTITION_MAPPING_LOG4j = "changelog partition mapping"
 }
 
 /**
index cb6dbdf..087c6ad 100644 (file)
 
 package org.apache.samza.checkpoint.kafka
 
-import org.apache.samza.config.{ KafkaConfig, Config }
-import org.apache.samza.SamzaException
-import java.util.Properties
+import grizzled.slf4j.Logging
 import kafka.producer.Producer
-import org.apache.samza.config.SystemConfig.Config2System
-import org.apache.samza.config.StreamConfig.Config2Stream
-import org.apache.samza.config.TaskConfig.Config2Task
-import org.apache.samza.config.KafkaConfig.Config2Kafka
+import kafka.utils.ZKStringSerializer
+import org.I0Itec.zkclient.ZkClient
+import org.apache.samza.SamzaException
+import org.apache.samza.checkpoint.CheckpointManager
+import org.apache.samza.checkpoint.CheckpointManagerFactory
+import org.apache.samza.config.Config
 import org.apache.samza.config.JobConfig.Config2Job
-import org.apache.samza.Partition
-import grizzled.slf4j.Logging
+import org.apache.samza.config.KafkaConfig.Config2Kafka
 import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.util.{ KafkaUtil, ClientUtilTopicMetadataStore }
-import org.apache.samza.util.Util
-import org.I0Itec.zkclient.ZkClient
-import kafka.utils.ZKStringSerializer
-import org.apache.samza.checkpoint.CheckpointManagerFactory
-import org.apache.samza.checkpoint.CheckpointManager
 
+object KafkaCheckpointManagerFactory {
+  /**
+   * Version number to track the format of the checkpoint log
+   */
+  val CHECKPOINT_LOG_VERSION_NUMBER = 1
+}
 class KafkaCheckpointManagerFactory extends CheckpointManagerFactory with Logging {
+  import KafkaCheckpointManagerFactory._
+
   def getCheckpointManager(config: Config, registry: MetricsRegistry): CheckpointManager = {
     val clientId = KafkaUtil.getClientId("samza-checkpoint-manager", config)
     val systemName = config
@@ -60,7 +62,7 @@ class KafkaCheckpointManagerFactory extends CheckpointManagerFactory with Loggin
     val fetchSize = consumerConfig.fetchMessageMaxBytes // must be > buffer size
 
     val connectProducer = () => {
-      new Producer[Partition, Array[Byte]](producerConfig)
+      new Producer[Array[Byte], Array[Byte]](producerConfig)
     }
     val zkConnect = Option(consumerConfig.zkConnect)
       .getOrElse(throw new SamzaException("no zookeeper.connect defined in config"))
@@ -73,24 +75,24 @@ class KafkaCheckpointManagerFactory extends CheckpointManagerFactory with Loggin
       .getOrElse(throw new SamzaException("No broker list defined in config for %s." format systemName))
     val metadataStore = new ClientUtilTopicMetadataStore(brokersListString, clientId, socketTimeout)
     val checkpointTopic = getTopic(jobName, jobId)
-    
-    // This is a reasonably expensive operation and the TaskInstance already knows the answer. Should use that info.
-    val totalPartitions = Util.getInputStreamPartitions(config).map(_.getPartition).toSet.size
+
+    // Find out the SSPGrouperFactory class so it can be included/verified in the key
+    val systemStreamPartitionGrouperFactoryString = config.getSystemStreamPartitionGrouperFactory
 
     new KafkaCheckpointManager(
       clientId,
       checkpointTopic,
       systemName,
-      totalPartitions,
       replicationFactor,
       socketTimeout,
       bufferSize,
       fetchSize,
       metadataStore,
       connectProducer,
-      connectZk)
+      connectZk,
+      systemStreamPartitionGrouperFactoryString)
   }
 
   private def getTopic(jobName: String, jobId: String) =
-    "__samza_checkpoint_%s_%s" format (jobName.replaceAll("_", "-"), jobId.replaceAll("_", "-"))
+    "__samza_checkpoint_ver_%d_for_%s_%s" format (CHECKPOINT_LOG_VERSION_NUMBER, jobName.replaceAll("_", "-"), jobId.replaceAll("_", "-"))
 }
index 8a8834f..9553050 100644 (file)
@@ -19,8 +19,6 @@
 
 package org.apache.samza.system.kafka
 
-import scala.annotation.implicitNotFound
-
 import grizzled.slf4j.Logging
 import kafka.api.TopicMetadata
 import kafka.common.ErrorMapping
diff --git a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointLogKey.scala b/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointLogKey.scala
new file mode 100644 (file)
index 0000000..7a23041
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.checkpoint.kafka
+
+import org.apache.samza.container.TaskName
+import org.junit.Assert._
+import org.junit.{Before, Test}
+import org.apache.samza.SamzaException
+
+class TestKafkaCheckpointLogKey {
+  @Before
+  def setSSPGrouperFactoryString() {
+    KafkaCheckpointLogKey.setSystemStreamPartitionGrouperFactoryString("hello")
+  }
+
+  @Test
+  def checkpointKeySerializationRoundTrip() {
+    val checkpointKey = KafkaCheckpointLogKey.getCheckpointKey(new TaskName("TN"))
+    val asBytes = checkpointKey.toBytes()
+    val backFromBytes = KafkaCheckpointLogKey.fromBytes(asBytes)
+
+    assertEquals(checkpointKey, backFromBytes)
+    assertNotSame(checkpointKey, backFromBytes)
+  }
+
+  @Test
+  def changelogPartitionMappingKeySerializationRoundTrip() {
+    val key = KafkaCheckpointLogKey.getChangelogPartitionMappingKey()
+    val asBytes = key.toBytes()
+    val backFromBytes = KafkaCheckpointLogKey.fromBytes(asBytes)
+
+    assertEquals(key, backFromBytes)
+    assertNotSame(key, backFromBytes)
+  }
+
+  @Test
+  def differingSSPGrouperFactoriesCauseException() {
+
+    val checkpointKey = KafkaCheckpointLogKey.getCheckpointKey(new TaskName("TN"))
+
+    val asBytes = checkpointKey.toBytes()
+
+    KafkaCheckpointLogKey.setSystemStreamPartitionGrouperFactoryString("goodbye")
+
+    var gotException = false
+    try {
+      KafkaCheckpointLogKey.fromBytes(asBytes)
+    } catch {
+      case se:SamzaException => assertEquals(new DifferingSystemStreamPartitionGrouperFactoryValues("hello", "goodbye").getMessage(), se.getCause.getMessage)
+                                gotException = true
+    }
+
+    assertTrue("Should have had an exception since ssp grouper factories didn't match", gotException)
+  }
+}
index 92ac61e..cddee13 100644 (file)
 
 package org.apache.samza.checkpoint.kafka
 
-import org.I0Itec.zkclient.ZkClient
-import org.junit.Assert._
-import org.junit.AfterClass
-import org.junit.BeforeClass
-import org.junit.Test
+import kafka.common.InvalidMessageSizeException
+import kafka.common.UnknownTopicOrPartitionException
+import kafka.message.InvalidMessageException
 import kafka.producer.Producer
 import kafka.producer.ProducerConfig
 import kafka.server.KafkaConfig
@@ -31,20 +29,20 @@ import kafka.server.KafkaServer
 import kafka.utils.TestUtils
 import kafka.utils.TestZKUtils
 import kafka.utils.Utils
+import kafka.utils.ZKStringSerializer
 import kafka.zk.EmbeddedZookeeper
-import org.apache.samza.metrics.MetricsRegistryMap
-import org.apache.samza.Partition
-import scala.collection._
-import scala.collection.JavaConversions._
-import org.apache.samza.util.{ ClientUtilTopicMetadataStore, TopicMetadataStore }
-import org.apache.samza.config.MapConfig
+import org.I0Itec.zkclient.ZkClient
 import org.apache.samza.checkpoint.Checkpoint
+import org.apache.samza.container.TaskName
 import org.apache.samza.serializers.CheckpointSerde
-import org.apache.samza.system.SystemStream
-import kafka.utils.ZKStringSerializer
-import kafka.message.InvalidMessageException
-import kafka.common.InvalidMessageSizeException
-import kafka.common.UnknownTopicOrPartitionException
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.util.{ ClientUtilTopicMetadataStore, TopicMetadataStore }
+import org.apache.samza.{SamzaException, Partition}
+import org.junit.Assert._
+import org.junit.{AfterClass, BeforeClass, Test}
+import scala.collection.JavaConversions._
+import scala.collection._
+import org.apache.samza.container.systemstreampartition.groupers.GroupByPartitionFactory
 
 object TestKafkaCheckpointManager {
   val zkConnect: String = TestZKUtils.zookeeperConnect
@@ -72,14 +70,16 @@ object TestKafkaCheckpointManager {
   config.put("request.required.acks", "-1")
   val producerConfig = new ProducerConfig(config)
   val partition = new Partition(0)
-  val cp1 = new Checkpoint(Map(new SystemStream("kafka", "topic") -> "123"))
-  val cp2 = new Checkpoint(Map(new SystemStream("kafka", "topic") -> "12345"))
+  val cp1 = new Checkpoint(Map(new SystemStreamPartition("kafka", "topic", partition) -> "123"))
+  val cp2 = new Checkpoint(Map(new SystemStreamPartition("kafka", "topic", partition) -> "12345"))
   var zookeeper: EmbeddedZookeeper = null
   var server1: KafkaServer = null
   var server2: KafkaServer = null
   var server3: KafkaServer = null
   var metadataStore: TopicMetadataStore = null
 
+  val systemStreamPartitionGrouperFactoryString = classOf[GroupByPartitionFactory].getCanonicalName
+
   @BeforeClass
   def beforeSetupServers {
     zookeeper = new EmbeddedZookeeper(zkConnect)
@@ -108,42 +108,45 @@ class TestKafkaCheckpointManager {
   import TestKafkaCheckpointManager._
 
   @Test
-  def testCheckpointShouldBeNullIfcheckpointTopicDoesNotExistShouldBeCreatedOnWriteAndShouldBeReadableAfterWrite {
+  def testCheckpointShouldBeNullIfCheckpointTopicDoesNotExistShouldBeCreatedOnWriteAndShouldBeReadableAfterWrite {
     val kcm = getKafkaCheckpointManager
-    kcm.register(partition)
+    val taskName = new TaskName(partition.toString)
+    kcm.register(taskName)
     kcm.start
-    var readCp = kcm.readLastCheckpoint(partition)
+    var readCp = kcm.readLastCheckpoint(taskName)
     // read before topic exists should result in a null checkpoint
-    assert(readCp == null)
+    assertNull(readCp)
     // create topic the first time around
-    kcm.writeCheckpoint(partition, cp1)
-    readCp = kcm.readLastCheckpoint(partition)
-    assert(cp1.equals(readCp))
+    kcm.writeCheckpoint(taskName, cp1)
+    readCp = kcm.readLastCheckpoint(taskName)
+    assertEquals(cp1, readCp)
     // should get an exception if partition doesn't exist
     try {
-      readCp = kcm.readLastCheckpoint(new Partition(1))
+      readCp = kcm.readLastCheckpoint(new TaskName(new Partition(1).toString))
       fail("Expected a SamzaException, since only one partition (partition 0) should exist.")
     } catch {
-      case e: Exception => None // expected
+      case e: SamzaException => None // expected
+      case _: Exception => fail("Expected a SamzaException, since only one partition (partition 0) should exist.")
     }
     // writing a second message should work, too
-    kcm.writeCheckpoint(partition, cp2)
-    readCp = kcm.readLastCheckpoint(partition)
-    assert(cp2.equals(readCp))
+    kcm.writeCheckpoint(taskName, cp2)
+    readCp = kcm.readLastCheckpoint(taskName)
+    assertEquals(cp2, readCp)
     kcm.stop
   }
 
   @Test
-  def testUnrecovableKafkaErrorShouldThrowKafkaCheckpointManagerException {
+  def testUnrecoverableKafkaErrorShouldThrowKafkaCheckpointManagerException {
     val exceptions = List("InvalidMessageException", "InvalidMessageSizeException", "UnknownTopicOrPartitionException")
     exceptions.foreach { exceptionName =>
       val kcm = getKafkaCheckpointManagerWithInvalidSerde(exceptionName)
-      kcm.register(partition)
+      val taskName = new TaskName(partition.toString)
+      kcm.register(taskName)
       kcm.start
-      kcm.writeCheckpoint(partition, cp1)
+      kcm.writeCheckpoint(taskName, cp1)
       // because serde will throw unrecoverable errors, it should result a KafkaCheckpointException
       try {
-        val readCpInvalide = kcm.readLastCheckpoint(partition)
+        kcm.readLastCheckpoint(taskName)
         fail("Expected a KafkaCheckpointException.")
       } catch {
         case e: KafkaCheckpointException => None
@@ -156,28 +159,28 @@ class TestKafkaCheckpointManager {
     clientId = "some-client-id",
     checkpointTopic = "checkpoint-topic",
     systemName = "kafka",
-    totalPartitions = 1,
     replicationFactor = 3,
     socketTimeout = 30000,
     bufferSize = 64 * 1024,
     fetchSize = 300 * 1024,
     metadataStore = metadataStore,
-    connectProducer = () => new Producer[Partition, Array[Byte]](producerConfig),
-    connectZk = () => new ZkClient(zkConnect, 6000, 6000, ZKStringSerializer))
+    connectProducer = () => new Producer[Array[Byte], Array[Byte]](producerConfig),
+    connectZk = () => new ZkClient(zkConnect, 6000, 6000, ZKStringSerializer),
+    systemStreamPartitionGrouperFactoryString = systemStreamPartitionGrouperFactoryString)
 
   // inject serde. Kafka exceptions will be thrown when serde.fromBytes is called
   private def getKafkaCheckpointManagerWithInvalidSerde(exception: String) = new KafkaCheckpointManager(
     clientId = "some-client-id-invalid-serde",
     checkpointTopic = "checkpoint-topic-invalid-serde",
     systemName = "kafka",
-    totalPartitions = 1,
     replicationFactor = 3,
     socketTimeout = 30000,
     bufferSize = 64 * 1024,
     fetchSize = 300 * 1024,
     metadataStore = metadataStore,
-    connectProducer = () => new Producer[Partition, Array[Byte]](producerConfig),
+    connectProducer = () => new Producer[Array[Byte], Array[Byte]](producerConfig),
     connectZk = () => new ZkClient(zkConnect, 6000, 6000, ZKStringSerializer),
+    systemStreamPartitionGrouperFactoryString = systemStreamPartitionGrouperFactoryString,
     serde = new InvalideSerde(exception))
 
   class InvalideSerde(exception: String) extends CheckpointSerde {
index 6be9732..be1670c 100644 (file)
 package org.apache.samza.system.kafka
 
 import org.junit.Assert._
-import org.junit.Test
+import org.junit.{Test, BeforeClass, AfterClass}
 import kafka.zk.EmbeddedZookeeper
-import org.junit.BeforeClass
-import org.junit.AfterClass
 import org.apache.samza.util.ClientUtilTopicMetadataStore
 import org.I0Itec.zkclient.ZkClient
 import kafka.admin.AdminUtils
index 751fe4c..6652f6b 100644 (file)
 
 package org.apache.samza.storage.kv
 
-import java.nio.ByteBuffer
 import org.iq80.leveldb._
-import org.fusesource.leveldbjni.internal.NativeComparator
 import org.fusesource.leveldbjni.JniDBFactory._
 import java.io._
-import java.util.Iterator
-import java.lang.Iterable
 import org.apache.samza.config.Config
 import org.apache.samza.container.SamzaContainerContext
 import grizzled.slf4j.{ Logger, Logging }
@@ -39,8 +35,8 @@ object LevelDbKeyValueStore {
     val options = new Options
 
     // Cache size and write buffer size are specified on a per-container basis.
-    options.cacheSize(cacheSize / containerContext.partitions.size)
-    options.writeBufferSize((writeBufSize / containerContext.partitions.size).toInt)
+    options.cacheSize(cacheSize / containerContext.taskNames.size)
+    options.writeBufferSize((writeBufSize / containerContext.taskNames.size).toInt)
     options.blockSize(storeConfig.getInt("leveldb.block.size.bytes", 4096))
     options.compressionType(
       storeConfig.get("leveldb.compression", "snappy") match {
index 222c130..f20bb7f 100644 (file)
@@ -20,6 +20,7 @@
 package org.apache.samza.test.integration.join;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.storage.kv.KeyValueStore;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.OutgoingMessageEnvelope;
@@ -45,12 +46,12 @@ public class Emitter implements StreamTask, InitableTask, WindowableTask {
   
   private KeyValueStore<String, String> state;
   private int max;
-  private String partition;
+  private TaskName taskName;
 
   @Override
   public void init(Config config, TaskContext context) {
     this.state = (KeyValueStore<String, String>) context.getStore("emitter-state");
-    this.partition = Integer.toString(context.getPartition().getPartitionId());
+    this.taskName = context.getTaskName();
     this.max = config.getInt("count");
   }
 
@@ -79,7 +80,7 @@ public class Emitter implements StreamTask, InitableTask, WindowableTask {
     }
     int counter = getInt(COUNT);
     if(counter < max) {
-      OutgoingMessageEnvelope envelope = new OutgoingMessageEnvelope(new SystemStream("kafka", "emitted"), Integer.toString(counter), epoch + "-" + partition);
+      OutgoingMessageEnvelope envelope = new OutgoingMessageEnvelope(new SystemStream("kafka", "emitted"), Integer.toString(counter), epoch + "-" + taskName);
       collector.send(envelope);
       this.state.put(COUNT, Integer.toString(getInt(COUNT) + 1));
     } else {
index c0ac5dd..7d0b8db 100644 (file)
@@ -22,7 +22,7 @@ package org.apache.samza.test.performance
 import grizzled.slf4j.Logging
 import org.apache.samza.config.Config
 import org.apache.samza.config.StorageConfig._
-import org.apache.samza.container.SamzaContainerContext
+import org.apache.samza.container.{TaskName, SamzaContainerContext}
 import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.storage.kv.KeyValueStore
 import org.apache.samza.storage.kv.KeyValueStorageEngine
@@ -34,7 +34,6 @@ import org.apache.samza.serializers.ByteSerde
 import org.apache.samza.Partition
 import org.apache.samza.SamzaException
 import java.io.File
-import scala.collection.JavaConversions._
 import java.util.UUID
 
 /**
@@ -79,7 +78,9 @@ object TestKeyValuePerformance extends Logging {
     val numLoops = config.getInt("test.num.loops", 100)
     val messagesPerBatch = config.getInt("test.messages.per.batch", 10000)
     val messageSizeBytes = config.getInt("test.message.size.bytes", 200)
-    val partitions = (0 until partitionCount).map(new Partition(_))
+    val taskNames = new java.util.ArrayList[TaskName]()
+
+    (0 until partitionCount).map(p => taskNames.add(new TaskName(new Partition(p).toString)))
 
     info("Using partition count: %s" format partitionCount)
     info("Using num loops: %s" format numLoops)
@@ -109,7 +110,7 @@ object TestKeyValuePerformance extends Logging {
         new ReadableCollector,
         new MetricsRegistryMap,
         null,
-        new SamzaContainerContext("test", config, partitions))
+        new SamzaContainerContext("test", config, taskNames))
 
       val db = if (!engine.isInstanceOf[KeyValueStorageEngine[_, _]]) {
         throw new SamzaException("This test can only run with KeyValueStorageEngine configured as store factory.")
index dc44a99..3ed8b5c 100644 (file)
 
 package org.apache.samza.test.integration
 
-import org.apache.samza.task.StreamTask
-import org.apache.samza.task.TaskContext
-import org.apache.samza.task.InitableTask
-import org.apache.samza.config.Config
-import scala.collection.JavaConversions._
-import org.apache.samza.task.TaskCoordinator
-import org.apache.samza.task.MessageCollector
-import org.apache.samza.system.IncomingMessageEnvelope
-import org.apache.samza.checkpoint.Checkpoint
-import org.junit.BeforeClass
-import org.junit.AfterClass
-import kafka.zk.EmbeddedZookeeper
-import kafka.utils.TestUtils
-import org.apache.samza.system.SystemStream
-import kafka.utils.TestZKUtils
-import kafka.server.KafkaConfig
-import org.I0Itec.zkclient.ZkClient
+import java.util.Properties
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+import kafka.admin.AdminUtils
+import kafka.common.ErrorMapping
+import kafka.consumer.Consumer
+import kafka.consumer.ConsumerConfig
+import kafka.message.MessageAndMetadata
+import kafka.producer.KeyedMessage
+import kafka.producer.Producer
 import kafka.producer.ProducerConfig
+import kafka.server.KafkaConfig
 import kafka.server.KafkaServer
+import kafka.utils.TestUtils
+import kafka.utils.TestZKUtils
 import kafka.utils.Utils
-import org.apache.samza.storage.kv.KeyValueStore
-import org.apache.samza.util._
-import org.junit.Test
-import kafka.admin.AdminUtils
-import kafka.common.ErrorMapping
-import org.junit.Assert._
 import kafka.utils.ZKStringSerializer
-import scala.collection.mutable.ArrayBuffer
-import org.apache.samza.job.local.LocalJobFactory
-import org.apache.samza.job.ApplicationStatus
-import java.util.concurrent.CountDownLatch
-import org.apache.samza.job.local.ThreadJob
-import org.apache.samza.util.TopicMetadataStore
-import org.apache.samza.util.ClientUtilTopicMetadataStore
+import kafka.zk.EmbeddedZookeeper
+import org.I0Itec.zkclient.ZkClient
+import org.apache.samza.Partition
+import org.apache.samza.checkpoint.Checkpoint
+import org.apache.samza.config.Config
 import org.apache.samza.config.MapConfig
+import org.apache.samza.container.TaskName
+import org.apache.samza.job.ApplicationStatus
+import org.apache.samza.job.StreamJob
+import org.apache.samza.job.local.LocalJobFactory
+import org.apache.samza.storage.kv.KeyValueStore
 import org.apache.samza.system.kafka.TopicMetadataCache
-import org.apache.samza.container.SamzaContainer
+import org.apache.samza.system.{SystemStreamPartition, IncomingMessageEnvelope}
+import org.apache.samza.task.InitableTask
+import org.apache.samza.task.MessageCollector
+import org.apache.samza.task.StreamTask
+import org.apache.samza.task.TaskContext
+import org.apache.samza.task.TaskCoordinator
+import org.apache.samza.task.TaskCoordinator.RequestScope
+import org.apache.samza.util.ClientUtilTopicMetadataStore
+import org.apache.samza.util.TopicMetadataStore
+import org.junit.Assert._
+import org.junit.{BeforeClass, AfterClass, Test}
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.SynchronizedMap
-import org.apache.samza.Partition
-import java.util.concurrent.TimeUnit
-import kafka.producer.Producer
-import kafka.producer.KeyedMessage
-import java.util.concurrent.Semaphore
-import java.util.concurrent.CyclicBarrier
-import kafka.consumer.Consumer
-import kafka.consumer.ConsumerConfig
-import java.util.Properties
-import java.util.concurrent.Executors
-import kafka.message.MessageAndOffset
-import kafka.message.MessageAndMetadata
-import org.apache.samza.job.StreamJob
-import org.apache.samza.task.TaskCoordinator.RequestScope
 
 object TestStatefulTask {
   val INPUT_TOPIC = "input"
   val STATE_TOPIC = "mystore"
-  val TOTAL_PARTITIONS = 1
+  val TOTAL_TASK_NAMES = 1
   val REPLICATION_FACTOR = 3
 
   val zkConnect: String = TestZKUtils.zookeeperConnect
@@ -102,8 +93,8 @@ object TestStatefulTask {
   config.put("serializer.class", "kafka.serializer.StringEncoder");
   val producerConfig = new ProducerConfig(config)
   var producer: Producer[String, String] = null
-  val cp1 = new Checkpoint(Map(new SystemStream("kafka", "topic") -> "123"))
-  val cp2 = new Checkpoint(Map(new SystemStream("kafka", "topic") -> "12345"))
+  val cp1 = new Checkpoint(Map(new SystemStreamPartition("kafka", "topic", new Partition(0)) -> "123"))
+  val cp2 = new Checkpoint(Map(new SystemStreamPartition("kafka", "topic", new Partition(0)) -> "12345"))
   var zookeeper: EmbeddedZookeeper = null
   var server1: KafkaServer = null
   var server2: KafkaServer = null
@@ -128,13 +119,13 @@ object TestStatefulTask {
     AdminUtils.createTopic(
       zkClient,
       INPUT_TOPIC,
-      TOTAL_PARTITIONS,
+      TOTAL_TASK_NAMES,
       REPLICATION_FACTOR)
 
     AdminUtils.createTopic(
       zkClient,
       STATE_TOPIC,
-      TOTAL_PARTITIONS,
+      TOTAL_TASK_NAMES,
       REPLICATION_FACTOR)
   }
 
@@ -221,7 +212,14 @@ class TestStatefulTask {
     "systems.kafka.consumer.auto.offset.reset" -> "smallest", // applies to an empty topic
     "systems.kafka.samza.msg.serde" -> "string",
     "systems.kafka.consumer.zookeeper.connect" -> zkConnect,
-    "systems.kafka.producer.metadata.broker.list" -> ("localhost:%s" format port1))
+    "systems.kafka.producer.metadata.broker.list" -> ("localhost:%s" format port1),
+    // Since using state, need a checkpoint manager
+    "task.checkpoint.factory" -> "org.apache.samza.checkpoint.kafka.KafkaCheckpointManagerFactory",
+    "task.checkpoint.system" -> "kafka",
+    "task.checkpoint.replication.factor" -> "1",
+    // However, don't have the inputs use the checkpoint manager
+    // since the second part of the test expects to replay the input streams.
+    "systems.kafka.streams.input.samza.reset.offset" -> "true")
 
   @Test
   def testShouldStartAndRestore {
@@ -278,7 +276,7 @@ class TestStatefulTask {
       count += 1
     }
 
-    assertTrue(count < 100)
+    assertTrue("Timed out waiting to received messages. Received thus far: " + task.received.size, count < 100)
 
     // Reset the count down latch after the 4 messages come in.
     task.awaitMessage
@@ -328,8 +326,7 @@ class TestStatefulTask {
     TestTask.awaitTaskRegistered
     val tasks = TestTask.tasks
 
-    // Should only have one partition.
-    assertEquals(1, tasks.size)
+    assertEquals("Should only have a single partition in this task", 1, tasks.size)
 
     val task = tasks.values.toList.head
 
@@ -392,25 +389,25 @@ class TestStatefulTask {
 }
 
 object TestTask {
-  val tasks = new HashMap[Partition, TestTask] with SynchronizedMap[Partition, TestTask]
-  @volatile var allTasksRegistered = new CountDownLatch(TestStatefulTask.TOTAL_PARTITIONS)
+  val tasks = new HashMap[TaskName, TestTask] with SynchronizedMap[TaskName, TestTask]
+  @volatile var allTasksRegistered = new CountDownLatch(TestStatefulTask.TOTAL_TASK_NAMES)
 
   /**
    * Static method that tasks can use to register themselves with. Useful so
    * we don't have to sneak into the ThreadJob/SamzaContainer to get our test
    * tasks.
    */
-  def register(partition: Partition, task: TestTask) {
-    tasks += partition -> task
+  def register(taskName: TaskName, task: TestTask) {
+    tasks += taskName -> task
     allTasksRegistered.countDown
   }
 
   def awaitTaskRegistered {
     allTasksRegistered.await(60, TimeUnit.SECONDS)
     assertEquals(0, allTasksRegistered.getCount)
-    assertEquals(TestStatefulTask.TOTAL_PARTITIONS, tasks.size)
+    assertEquals(TestStatefulTask.TOTAL_TASK_NAMES, tasks.size)
     // Reset the registered latch, so we can use it again every time we start a new job.
-    TestTask.allTasksRegistered = new CountDownLatch(TestStatefulTask.TOTAL_PARTITIONS)
+    TestTask.allTasksRegistered = new CountDownLatch(TestStatefulTask.TOTAL_TASK_NAMES)
   }
 }
 
@@ -422,7 +419,7 @@ class TestTask extends StreamTask with InitableTask {
   var gotMessage = new CountDownLatch(1)
 
   def init(config: Config, context: TaskContext) {
-    TestTask.register(context.getPartition, this)
+    TestTask.register(context.getTaskName, this)
     store = context
       .getStore(TestStatefulTask.STATE_TOPIC)
       .asInstanceOf[KeyValueStore[String, String]]
index ea6f03b..86fc0fd 100644 (file)
             %td.key Finished
             %td= state.finishedTasks.size.toString
 
-      %h3 Partition Assignment
-      %table.table.table-striped.table-bordered.tablesorter#partitions-table
+      %h3 TaskName Assignment
+      %table.table.table-striped.table-bordered.tablesorter#taskids-table
         %thead
           %tr
             %th Task ID
-            %th Partitions
+            %th TaskName
+            %th SystemStreamPartitions
             %th Container
         %tbody
-          - for((taskId, partitions) <- state.taskPartitions)
-            %tr
-              %td= taskId
-              %td= partitions.map(_.getPartitionId).toList.sorted.mkString(", ")
-              %td
-                - val container = state.runningTasks(taskId)
-                %a(target="_blank" href="http://#{container.nodeHttpAddress}/node/containerlogs/#{container.id.toString}/#{username}")= container.id.toString
+          - for((taskId, taskNames) <- state.taskToTaskNames)
+            - for((taskName, ssps) <- taskNames)
+              %tr
+                %td= taskId
+                %td= taskName
+                %td= ssps.map(_.toString).toList.sorted.mkString(", ")
+                %td
+                  - val container = state.runningTasks(taskId)
+                  %a(target="_blank" href="http://#{container.nodeHttpAddress}/node/containerlogs/#{container.id.toString}/#{username}")= container.id.toString
 
     %div.tab-pane#config
       %h2 Config
     :javascript
       $(document).ready(function() {
         $("#containers-table").tablesorter();
-        $("#partitions-table").tablesorter();
+        $("#taskids-table").tablesorter();
         $("#config-table").tablesorter();
         $("#config-table-filter").keyup(function() {
           var regex = new RegExp($(this).val(), 'i');
index 01a2683..d9dfbc6 100644 (file)
  */
 
 package org.apache.samza.job.yarn
-import org.apache.samza.config.Config
 import grizzled.slf4j.Logging
 import org.apache.hadoop.yarn.api.records.FinalApplicationStatus
-import org.apache.samza.Partition
-import org.apache.hadoop.yarn.util.ConverterUtils
 import org.apache.hadoop.yarn.api.records.ContainerId
+import java.util
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.container.TaskName
 
 /**
  * Samza's application master has state that is usually manipulated based on
@@ -40,7 +40,7 @@ class SamzaAppMasterState(val taskId: Int, val containerId: ContainerId, val nod
   var unclaimedTasks = Set[Int]()
   var finishedTasks = Set[Int]()
   var runningTasks = Map[Int, YarnContainer]()
-  var taskPartitions = Map[Int, Set[Partition]]()
+  var taskToTaskNames = Map[Int, util.Map[TaskName, util.Set[SystemStreamPartition]]]()
   var status = FinalApplicationStatus.UNDEFINED
 
   // controlled by the service
index eb1ff54..0dd244d 100644 (file)
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 import java.util.Collections
 
 import scala.collection.JavaConversions._
+import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.DataOutputBuffer
@@ -46,6 +47,7 @@ import org.apache.samza.job.ShellCommandBuilder
 import org.apache.samza.util.Util
 
 import grizzled.slf4j.Logging
+import org.apache.samza.container.TaskNamesToSystemStreamPartitions
 
 object SamzaAppMasterTaskManager {
   val DEFAULT_CONTAINER_MEM = 1024
@@ -72,7 +74,10 @@ class SamzaAppMasterTaskManager(clock: () => Long, config: Config, state: SamzaA
       1
   }
 
-  val allSystemStreamPartitions = Util.getInputStreamPartitions(config)
+  val tasksToSSPTaskNames: Map[Int, TaskNamesToSystemStreamPartitions] = Util.assignContainerToSSPTaskNames(config, state.taskCount)
+
+  val taskNameToChangeLogPartitionMapping = Util.getTaskNameToChangeLogPartitionMapping(config, tasksToSSPTaskNames)
+
   var taskFailures = Map[Int, TaskFailure]()
   var tooManyFailedContainers = false
   // TODO we might want to use NMClientAsync as well
@@ -106,13 +111,14 @@ class SamzaAppMasterTaskManager(clock: () => Long, config: Config, state: SamzaA
     state.unclaimedTasks.headOption match {
       case Some(taskId) => {
         info("Got available task id (%d) for container: %s" format (taskId, container))
-        val streamsAndPartitionsForTask = Util.getStreamsAndPartitionsForContainer(taskId, state.taskCount, allSystemStreamPartitions)
-        info("Claimed partitions %s for container ID %s" format (streamsAndPartitionsForTask, taskId))
+        val sspTaskNames: TaskNamesToSystemStreamPartitions = tasksToSSPTaskNames.getOrElse(taskId, TaskNamesToSystemStreamPartitions())
+        info("Claimed SSP taskNames %s for container ID %s" format (sspTaskNames, taskId))
         val cmdBuilderClassName = config.getCommandClass.getOrElse(classOf[ShellCommandBuilder].getName)
         val cmdBuilder = Class.forName(cmdBuilderClassName).newInstance.asInstanceOf[CommandBuilder]
           .setConfig(config)
           .setName("samza-container-%s" format taskId)
-          .setStreamPartitions(streamsAndPartitionsForTask)
+          .setTaskNameToSystemStreamPartitionsMapping(sspTaskNames.getJavaFriendlyType)
+          .setTaskNameToChangeLogPartitionMapping(taskNameToChangeLogPartitionMapping.map(kv => kv._1 -> Integer.valueOf(kv._2)).asJava)
         val command = cmdBuilder.buildCommand
         info("Task ID %s using command %s" format (taskId, command))
         val env = cmdBuilder.buildEnvironment.map { case (k, v) => (k, Util.envVarEscape(v)) }
@@ -129,7 +135,7 @@ class SamzaAppMasterTaskManager(clock: () => Long, config: Config, state: SamzaA
         state.neededContainers -= 1
         state.runningTasks += taskId -> new YarnContainer(container)
         state.unclaimedTasks -= taskId
-        state.taskPartitions += taskId -> streamsAndPartitionsForTask.map(_.getPartition).toSet
+        state.taskToTaskNames += taskId -> sspTaskNames.getJavaFriendlyType
 
         info("Claimed task ID %s for container %s on node %s (http://%s/node/containerlogs/%s)." format (taskId, containerIdStr, container.getNodeId.getHost, container.getNodeHttpAddress, containerIdStr))
 
@@ -151,7 +157,7 @@ class SamzaAppMasterTaskManager(clock: () => Long, config: Config, state: SamzaA
     taskId match {
       case Some(taskId) => {
         state.runningTasks -= taskId
-        state.taskPartitions -= taskId
+        state.taskToTaskNames -= taskId
       }
       case _ => None
     }
@@ -315,4 +321,5 @@ class SamzaAppMasterTaskManager(clock: () => Long, config: Config, state: SamzaA
     capability.setVirtualCores(cpuCores)
     (0 until containers).foreach(idx => amClient.addContainerRequest(new ContainerRequest(capability, null, null, priority)))
   }
+
 }
index 520f784..d10dc38 100644 (file)
@@ -79,13 +79,14 @@ class ApplicationMasterRestServlet(config: Config, state: SamzaAppMasterState, r
     state.runningTasks.values.foreach(c => {
       val containerIdStr = c.id.toString
       val containerMap = new HashMap[String, Object]
+
       val taskId = state.runningTasks.filter { case (_, container) => container.id.toString.equals(containerIdStr) }.keys.head
-      var partitions = new java.util.ArrayList(state.taskPartitions.get(taskId).get)
+      val taskNames = new java.util.ArrayList(state.taskToTaskNames.get(taskId).get.toList)
 
       containerMap.put("yarn-address", c.nodeHttpAddress)
       containerMap.put("start-time", c.startTime.toString)
       containerMap.put("up-time", c.upTime.toString)
-      containerMap.put("partitions", partitions)
+      containerMap.put("task-names", taskNames)
       containerMap.put("task-id", taskId.toString)
       containers.put(containerIdStr, containerMap)
     })
index f1139f5..685620f 100644 (file)
@@ -19,8 +19,6 @@
 
 package org.apache.samza.job.yarn
 
-import scala.annotation.elidable
-import scala.annotation.elidable.ASSERTION
 import scala.collection.JavaConversions._
 
 import org.apache.hadoop.conf.Configuration
@@ -41,7 +39,6 @@ import org.apache.samza.system.SystemFactory
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.util.SinglePartitionWithoutOffsetsSystemAdmin
 import org.apache.samza.util.Util
-import org.junit.Assert._
 import org.junit.Test
 
 import TestSamzaAppMasterTaskManager._
@@ -230,7 +227,7 @@ class TestSamzaAppMasterTaskManager {
     taskManager.onContainerAllocated(getContainer(container2))
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 1)
-    assert(state.taskPartitions.size == 1)
+    assert(state.taskToTaskNames.size == 1)
     assert(state.unclaimedTasks.size == 0)
     assert(containersRequested == 1)
     assert(containersStarted == 1)
@@ -239,7 +236,7 @@ class TestSamzaAppMasterTaskManager {
     taskManager.onContainerAllocated(getContainer(container3))
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 1)
-    assert(state.taskPartitions.size == 1)
+    assert(state.taskToTaskNames.size == 1)
     assert(state.unclaimedTasks.size == 0)
     assert(amClient.getClient.requests.size == 1)
     assert(amClient.getClient.getRelease.size == 1)
@@ -255,7 +252,7 @@ class TestSamzaAppMasterTaskManager {
     assert(taskManager.shouldShutdown == false)
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 1)
-    assert(state.taskPartitions.size == 1)
+    assert(state.taskToTaskNames.size == 1)
     assert(state.unclaimedTasks.size == 0)
     assert(amClient.getClient.requests.size == 0)
     assert(amClient.getClient.getRelease.size == 0)
@@ -293,13 +290,13 @@ class TestSamzaAppMasterTaskManager {
     taskManager.onContainerAllocated(getContainer(container2))
     assert(state.neededContainers == 1)
     assert(state.runningTasks.size == 1)
-    assert(state.taskPartitions.size == 1)
+    assert(state.taskToTaskNames.size == 1)
     assert(state.unclaimedTasks.size == 1)
     assert(containersStarted == 1)
     taskManager.onContainerAllocated(getContainer(container3))
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 2)
-    assert(state.taskPartitions.size == 2)
+    assert(state.taskToTaskNames.size == 2)
     assert(state.unclaimedTasks.size == 0)
     assert(containersStarted == 2)
 
@@ -307,7 +304,7 @@ class TestSamzaAppMasterTaskManager {
     taskManager.onContainerCompleted(getContainerStatus(container2, 0, ""))
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 1)
-    assert(state.taskPartitions.size == 1)
+    assert(state.taskToTaskNames.size == 1)
     assert(state.unclaimedTasks.size == 0)
     assert(state.completedTasks == 1)
 
@@ -315,7 +312,7 @@ class TestSamzaAppMasterTaskManager {
     taskManager.onContainerCompleted(getContainerStatus(container3, 1, "expected failure here"))
     assert(state.neededContainers == 1)
     assert(state.runningTasks.size == 0)
-    assert(state.taskPartitions.size == 0)
+    assert(state.taskToTaskNames.size == 0)
     assert(state.unclaimedTasks.size == 1)
     assert(state.completedTasks == 1)
     assert(taskManager.shouldShutdown == false)
@@ -324,7 +321,7 @@ class TestSamzaAppMasterTaskManager {
     taskManager.onContainerAllocated(getContainer(container3))
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 1)
-    assert(state.taskPartitions.size == 1)
+    assert(state.taskToTaskNames.size == 1)
     assert(state.unclaimedTasks.size == 0)
     assert(containersStarted == 3)
 
@@ -332,7 +329,7 @@ class TestSamzaAppMasterTaskManager {
     taskManager.onContainerCompleted(getContainerStatus(container3, 0, ""))
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 0)
-    assert(state.taskPartitions.size == 0)
+    assert(state.taskToTaskNames.size == 0)
     assert(state.unclaimedTasks.size == 0)
     assert(state.completedTasks == 2)
     assert(taskManager.shouldShutdown == true)
@@ -364,19 +361,19 @@ class TestSamzaAppMasterTaskManager {
     assert(amClient.getClient.getRelease.size == 0)
     assert(state.neededContainers == 1)
     assert(state.runningTasks.size == 0)
-    assert(state.taskPartitions.size == 0)
+    assert(state.taskToTaskNames.size == 0)
     assert(state.unclaimedTasks.size == 1)
     taskManager.onContainerAllocated(getContainer(container2))
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 1)
-    assert(state.taskPartitions.size == 1)
+    assert(state.taskToTaskNames.size == 1)
     assert(state.unclaimedTasks.size == 0)
     assert(containersRequested == 1)
     assert(containersStarted == 1)
     taskManager.onContainerAllocated(getContainer(container3))
     assert(state.neededContainers == 0)
     assert(state.runningTasks.size == 1)
-    assert(state.taskPartitions.size == 1)
+    assert(state.taskToTaskNames.size == 1)
     assert(state.unclaimedTasks.size == 0)
     assert(containersRequested == 1)
     assert(containersStarted == 1)
@@ -385,38 +382,6 @@ class TestSamzaAppMasterTaskManager {
     assert(amClient.getClient.getRelease.head.equals(container3))
   }
 
-  @Test
-  def testPartitionsShouldWorkWithMoreTasksThanPartitions {
-    val onePartition = Set(new SystemStreamPartition("system", "stream", new Partition(0)))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(0, 2, onePartition), Set(new SystemStreamPartition("system", "stream", new Partition(0))))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(1, 2, onePartition), Set())
-  }
-
-  @Test
-  def testPartitionsShouldWorkWithMorePartitionsThanTasks {
-    val fivePartitions = (0 until 5).map(p => new SystemStreamPartition("system", "stream", new Partition(p))).toSet
-    assertEquals(Util.getStreamsAndPartitionsForContainer(0, 2, fivePartitions), Set(new SystemStreamPartition("system", "stream", new Partition(0)), new SystemStreamPartition("system", "stream", new Partition(2)), new SystemStreamPartition("system", "stream", new Partition(4))))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(1, 2, fivePartitions), Set(new SystemStreamPartition("system", "stream", new Partition(1)), new SystemStreamPartition("system", "stream", new Partition(3))))
-  }
-
-  @Test
-  def testPartitionsShouldWorkWithTwelvePartitionsAndFiveContainers {
-    val fivePartitions = (0 until 12).map(p => new SystemStreamPartition("system", "stream", new Partition(p))).toSet
-    assertEquals(Util.getStreamsAndPartitionsForContainer(0, 5, fivePartitions), Set(new SystemStreamPartition("system", "stream", new Partition(0)), new SystemStreamPartition("system", "stream", new Partition(5)), new SystemStreamPartition("system", "stream", new Partition(10))))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(1, 5, fivePartitions), Set(new SystemStreamPartition("system", "stream", new Partition(1)), new SystemStreamPartition("system", "stream", new Partition(6)), new SystemStreamPartition("system", "stream", new Partition(11))))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(2, 5, fivePartitions), Set(new SystemStreamPartition("system", "stream", new Partition(2)), new SystemStreamPartition("system", "stream", new Partition(7))))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(3, 5, fivePartitions), Set(new SystemStreamPartition("system", "stream", new Partition(3)), new SystemStreamPartition("system", "stream", new Partition(8))))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(4, 5, fivePartitions), Set(new SystemStreamPartition("system", "stream", new Partition(4)), new SystemStreamPartition("system", "stream", new Partition(9))))
-  }
-
-  @Test
-  def testPartitionsShouldWorkWithEqualPartitionsAndTasks {
-    val twoPartitions = (0 until 2).map(p => new SystemStreamPartition("system", "stream", new Partition(p))).toSet
-    assertEquals(Util.getStreamsAndPartitionsForContainer(0, 2, twoPartitions), Set(new SystemStreamPartition("system", "stream", new Partition(0))))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(1, 2, twoPartitions), Set(new SystemStreamPartition("system", "stream", new Partition(1))))
-    assertEquals(Util.getStreamsAndPartitionsForContainer(0, 1, Set(new SystemStreamPartition("system", "stream", new Partition(0)))), Set(new SystemStreamPartition("system", "stream", new Partition(0))))
-  }
-
   val clock = () => System.currentTimeMillis
 }