Initial commit.
author Shanthoosh Venkataraman <spvenkat@usc.edu>
Fri, 7 Dec 2018 02:04:18 +0000 (18:04 -0800)
committer Shanthoosh Venkataraman <spvenkat@usc.edu>
Tue, 12 Feb 2019 16:20:24 +0000 (08:20 -0800)
18 files changed:
samza-api/src/main/java/org/apache/samza/container/grouper/stream/GrouperContext.java [deleted file]
samza-core/src/main/java/org/apache/samza/container/LocalityManager.java
samza-core/src/main/java/org/apache/samza/container/grouper/stream/SSPGrouperProxy.java
samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskPartitionAssignmentManager.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/coordinator/stream/CoordinatorStreamKeySerde.java
samza-core/src/main/java/org/apache/samza/coordinator/stream/CoordinatorStreamValueSerde.java
samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetContainerHostMapping.java
samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetTaskContainerMapping.java
samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetTaskPartitionMapping.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/serializers/model/SamzaObjectMapper.java
samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
samza-core/src/test/java/org/apache/samza/container/grouper/stream/TestGroupByPartition.java
samza-core/src/test/java/org/apache/samza/container/grouper/stream/TestGroupBySystemStreamPartition.java
samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskPartitionAssignmentManager.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
samza-test/src/test/java/org/apache/samza/test/processor/TestStreamApplication.java
samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java

diff --git a/samza-api/src/main/java/org/apache/samza/container/grouper/stream/GrouperContext.java b/samza-api/src/main/java/org/apache/samza/container/grouper/stream/GrouperContext.java
deleted file mode 100644 (file)
index d1118e4..0000000
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.container.grouper.stream;
-
-import java.util.Collections;
-import java.util.ArrayList;
-import java.util.Map;
-import java.util.List;
-import java.util.Set;
-import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.runtime.LocationId;
-import org.apache.samza.system.SystemStreamPartition;
-
-/**
- * A Wrapper class that holds the necessary historical metadata of the samza job which is used
- * by the {@link org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouper}
- * to generate optimal task assignments.
- */
-@InterfaceStability.Evolving
-public class GrouperContext {
-  private Map<String, LocationId> processorLocality;
-  private Map<TaskName, LocationId> taskLocality;
-  private Map<TaskName, Set<SystemStreamPartition>> previousTaskToSSPAssignment;
-  private Map<TaskName, String> previousTaskToContainerAssignment;
-
-  public GrouperContext(Map<String, LocationId> processorLocality, Map<TaskName, LocationId> taskLocality, Map<TaskName, Set<SystemStreamPartition>> previousTaskToSSPAssignments, Map<TaskName, String> previousTaskToContainerAssignment) {
-    this.processorLocality = processorLocality;
-    this.taskLocality = taskLocality;
-    this.previousTaskToSSPAssignment = previousTaskToSSPAssignments;
-    this.previousTaskToContainerAssignment = previousTaskToContainerAssignment;
-  }
-
-  public Map<String, LocationId> getProcessorLocality() {
-    return Collections.unmodifiableMap(processorLocality);
-  }
-
-  public Map<TaskName, LocationId> getTaskLocality() {
-    return Collections.unmodifiableMap(taskLocality);
-  }
-
-  public Map<TaskName, Set<SystemStreamPartition>> getPreviousTaskToSSPAssignment() {
-    return Collections.unmodifiableMap(previousTaskToSSPAssignment);
-  }
-
-  public List<String> getProcessorIds() {
-    return new ArrayList<>(processorLocality.keySet());
-  }
-
-  public Map<TaskName, String> getPreviousTaskToContainerAssignment() {
-    return Collections.unmodifiableMap(this.previousTaskToContainerAssignment);
-  }
-}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/container/LocalityManager.java b/samza-core/src/main/java/org/apache/samza/container/LocalityManager.java
index fe076ee..24ddc23 100644 (file)
@@ -25,7 +25,6 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
-import org.apache.samza.coordinator.stream.CoordinatorStreamKeySerde;
 import org.apache.samza.coordinator.stream.CoordinatorStreamValueSerde;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
 import org.apache.samza.metadatastore.MetadataStore;
@@ -37,28 +36,24 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * Locality Manager is used to persist and read the container-to-host
- * assignment information from the coordinator stream.
+ * Used for persisting and reading the container-to-host assignment information to and from the metadata store.
  * */
 public class LocalityManager {
   private static final Logger LOG = LoggerFactory.getLogger(LocalityManager.class);
 
-  private final Config config;
-  private final Serde<String> keySerde;
   private final Serde<String> valueSerde;
   private final MetadataStore metadataStore;
 
   /**
    * Builds the LocalityManager based upon {@link Config} and {@link MetricsRegistry}.
-   * Uses {@link CoordinatorStreamKeySerde} and {@link CoordinatorStreamValueSerde} to
-   * serialize messages before reading/writing into coordinator stream.
+   * Uses the {@link CoordinatorStreamValueSerde} to serialize messages before
+   * reading from and writing to the metadata store.
    *
    * @param config the configuration required for setting up metadata store.
    * @param metricsRegistry the registry for reporting metrics.
    */
   public LocalityManager(Config config, MetricsRegistry metricsRegistry) {
-    this(config, metricsRegistry, new CoordinatorStreamKeySerde(SetContainerHostMapping.TYPE),
-         new CoordinatorStreamValueSerde(SetContainerHostMapping.TYPE));
+    this(config, metricsRegistry, new CoordinatorStreamValueSerde(SetContainerHostMapping.TYPE));
   }
 
   /**
@@ -69,15 +64,12 @@ public class LocalityManager {
    * Key and value serializer are different for yarn (uses CoordinatorStreamMessage) and standalone (native ObjectOutputStream for serialization) modes.
    * @param config the configuration required for setting up metadata store.
    * @param metricsRegistry the registry for reporting metrics.
-   * @param keySerde the key serializer.
    * @param valueSerde the value serializer.
    */
-  LocalityManager(Config config, MetricsRegistry metricsRegistry, Serde<String> keySerde, Serde<String> valueSerde) {
-    this.config = config;
+  LocalityManager(Config config, MetricsRegistry metricsRegistry, Serde<String> valueSerde) {
     MetadataStoreFactory metadataStoreFactory = Util.getObj(new JobConfig(config).getMetadataStoreFactory(), MetadataStoreFactory.class);
     this.metadataStore = metadataStoreFactory.getMetadataStore(SetContainerHostMapping.TYPE, config, metricsRegistry);
     this.metadataStore.init();
-    this.keySerde = keySerde;
     this.valueSerde = valueSerde;
   }
 
@@ -115,7 +107,7 @@ public class LocalityManager {
     Map<String, String> existingMappings = containerToHostMapping.get(containerId);
     String existingHostMapping = existingMappings != null ? existingMappings.get(SetContainerHostMapping.HOST_KEY) : null;
     if (existingHostMapping != null && !existingHostMapping.equals(hostName)) {
-      LOG.info("Container {} moved from {} to {}", new Object[]{containerId, existingHostMapping, hostName});
+      LOG.info("Container {} moved from {} to {}", containerId, existingHostMapping, hostName);
     } else {
       LOG.info("Container {} started at {}", containerId, hostName);
     }
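
With the key serde gone, the public constructor needs only a Config and a MetricsRegistry. A minimal usage sketch; the config contents are hypothetical, and the writeContainerToHostMapping method name is an assumption based on the body shown in the last hunk above:

    Config config = new MapConfig(Collections.singletonMap("job.name", "sample-job")); // hypothetical config
    LocalityManager localityManager = new LocalityManager(config, new MetricsRegistryMap());
    // Persists the container-to-host assignment; logs "moved" vs. "started" as shown above.
    localityManager.writeContainerToHostMapping("0", "host-1.example.com");
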
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/stream/SSPGrouperProxy.java b/samza-core/src/main/java/org/apache/samza/container/grouper/stream/SSPGrouperProxy.java
index 0330700..1ef7269 100644 (file)
  */
 package org.apache.samza.container.grouper.stream;
 
+import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Map;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Set;
 import java.util.HashSet;
 import com.google.common.base.Preconditions;
@@ -27,6 +30,7 @@ import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.TaskConfigJava;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.container.grouper.task.GrouperMetadata;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamPartition;
@@ -60,26 +64,31 @@ public class SSPGrouperProxy {
    * 3. Uses the previous, current task to partition assignments and result of {@link SystemStreamPartitionMapper} to redistribute the expanded {@link SystemStreamPartition}'s
    * to correct tasks after the stream expansion or contraction.
    * @param ssps the input system stream partitions of the job.
-   * @param grouperContext the grouper context holding metadata for the previous job execution.
+   * @param grouperMetadata provides the metadata for the previous samza job execution.
    * @return the grouped {@link TaskName} to {@link SystemStreamPartition} assignments.
    */
-  public Map<TaskName, Set<SystemStreamPartition>> group(Set<SystemStreamPartition> ssps, GrouperContext grouperContext) {
-    Map<TaskName, Set<SystemStreamPartition>> currentTaskAssignments = grouper.group(ssps);
+  public Map<TaskName, Set<SystemStreamPartition>> group(Set<SystemStreamPartition> ssps, GrouperMetadata grouperMetadata) {
+    Map<TaskName, Set<SystemStreamPartition>> groupedResult = grouper.group(ssps);
 
-    if (grouperContext.getPreviousTaskToSSPAssignment().isEmpty()) {
+    if (grouperMetadata.getPreviousTaskToSSPAssignment().isEmpty()) {
       LOGGER.info("Previous task to partition assignment does not exist. Using the result from the group method.");
-      return currentTaskAssignments;
+      return groupedResult;
     }
 
-    Map<SystemStreamPartition, TaskName> previousSSPToTask = getPreviousSSPToTaskMapping(grouperContext);
+    Map<TaskName, List<SystemStreamPartition>> currentTaskAssignments = new HashMap<>();
+    for (Map.Entry<TaskName, Set<SystemStreamPartition>> entry : groupedResult.entrySet()) {
+      currentTaskAssignments.put(entry.getKey(), new ArrayList<>(entry.getValue()));
+    }
+
+    Map<SystemStreamPartition, TaskName> previousSSPToTask = getPreviousSSPToTaskMapping(grouperMetadata);
 
     Map<TaskName, PartitionGroup> taskToPartitionGroup = new HashMap<>();
     currentTaskAssignments.forEach((taskName, systemStreamPartitions) -> taskToPartitionGroup.put(taskName, new PartitionGroup(taskName, systemStreamPartitions)));
 
-    Map<SystemStream, Integer> previousStreamToPartitionCount = getSystemStreamToPartitionCount(grouperContext.getPreviousTaskToSSPAssignment());
+    Map<SystemStream, Integer> previousStreamToPartitionCount = getSystemStreamToPartitionCount(grouperMetadata.getPreviousTaskToSSPAssignment());
     Map<SystemStream, Integer> currentStreamToPartitionCount = getSystemStreamToPartitionCount(currentTaskAssignments);
 
-    for (Map.Entry<TaskName, Set<SystemStreamPartition>> entry : currentTaskAssignments.entrySet()) {
+    for (Map.Entry<TaskName, List<SystemStreamPartition>> entry : currentTaskAssignments.entrySet()) {
       TaskName currentlyAssignedTask = entry.getKey();
       for (SystemStreamPartition currentSystemStreamPartition : entry.getValue()) {
         if (broadcastSystemStreamPartitions.contains(currentSystemStreamPartition)) {
@@ -121,7 +130,7 @@ public class SSPGrouperProxy {
    * @param taskToSSPAssignment the {@link TaskName} to {@link SystemStreamPartition}'s assignment of the job.
    * @return a mapping from {@link SystemStream} to the number of partitions of the stream.
    */
-  private Map<SystemStream, Integer> getSystemStreamToPartitionCount(Map<TaskName, Set<SystemStreamPartition>> taskToSSPAssignment) {
+  private Map<SystemStream, Integer> getSystemStreamToPartitionCount(Map<TaskName, List<SystemStreamPartition>> taskToSSPAssignment) {
     Map<SystemStream, Integer> systemStreamToPartitionCount = new HashMap<>();
     taskToSSPAssignment.forEach((taskName, systemStreamPartitions) -> {
         systemStreamPartitions.forEach(systemStreamPartition -> {
@@ -134,13 +143,13 @@ public class SSPGrouperProxy {
   }
 
   /**
-   * Computes a mapping from the {@link SystemStreamPartition} to {@link TaskName} using the provided {@param grouperContext}
-   * @param grouperContext the grouper context that contains relevant historical metadata about the job.
+   * Computes a mapping from the {@link SystemStreamPartition} to the {@link TaskName} using the provided {@code grouperMetadata}.
+   * @param grouperMetadata provides relevant historical metadata about the job.
    * @return a mapping from {@link SystemStreamPartition} to {@link TaskName}.
    */
-  private Map<SystemStreamPartition, TaskName> getPreviousSSPToTaskMapping(GrouperContext grouperContext) {
+  private Map<SystemStreamPartition, TaskName> getPreviousSSPToTaskMapping(GrouperMetadata grouperMetadata) {
     Map<SystemStreamPartition, TaskName> sspToTaskMapping = new HashMap<>();
-    Map<TaskName, Set<SystemStreamPartition>> previousTaskToSSPAssignment = grouperContext.getPreviousTaskToSSPAssignment();
+    Map<TaskName, List<SystemStreamPartition>> previousTaskToSSPAssignment = grouperMetadata.getPreviousTaskToSSPAssignment();
     previousTaskToSSPAssignment.forEach((taskName, systemStreamPartitions) -> {
         systemStreamPartitions.forEach(systemStreamPartition -> {
             if (!broadcastSystemStreamPartitions.contains(systemStreamPartition)) {
@@ -172,7 +181,7 @@ public class SSPGrouperProxy {
     private TaskName taskName;
     private Set<SystemStreamPartition> systemStreamPartitions;
 
-    PartitionGroup(TaskName taskName, Set<SystemStreamPartition> systemStreamPartitions) {
+    PartitionGroup(TaskName taskName, Collection<SystemStreamPartition> systemStreamPartitions) {
       Preconditions.checkNotNull(taskName);
       Preconditions.checkNotNull(systemStreamPartitions);
       this.taskName = taskName;
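
Callers now pass the proxy a GrouperMetadata in place of the deleted GrouperContext. A condensed sketch of the new contract, mirroring the updated tests further down (imports as in those tests; GrouperMetadataImpl takes processor locality, task locality, previous task-to-SSP assignments, and previous task-to-processor assignments):

    SSPGrouperProxy proxy = new SSPGrouperProxy(new MapConfig(), new GroupByPartition(new MapConfig()));
    // Previous run: task "Partition 0" owned partition 0 of stream PVE.
    GrouperMetadata metadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(),
        ImmutableMap.of(new TaskName("Partition 0"),
            ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)))),
        new HashMap<>());
    Set<SystemStreamPartition> currentSsps = ImmutableSet.of(
        new SystemStreamPartition("kafka", "PVE", new Partition(0)),
        new SystemStreamPartition("kafka", "PVE", new Partition(1)));
    Map<TaskName, Set<SystemStreamPartition>> grouping = proxy.group(currentSsps, metadata);
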
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
index 669ed57..0bac04d 100644 (file)
@@ -44,7 +44,6 @@ import org.slf4j.LoggerFactory;
 public class TaskAssignmentManager {
   private static final Logger LOG = LoggerFactory.getLogger(TaskAssignmentManager.class);
 
-  private final Config config;
   private final Map<String, String> taskNameToContainerId = new HashMap<>();
   private final Serde<String> keySerde;
   private final Serde<String> containerIdSerde;
@@ -55,8 +54,8 @@ public class TaskAssignmentManager {
 
   /**
    * Builds the TaskAssignmentManager based upon {@link Config} and {@link MetricsRegistry}.
-   * Uses {@link CoordinatorStreamKeySerde} and {@link CoordinatorStreamValueSerde} to
-   * serialize messages before reading/writing into coordinator stream.
+   * Uses {@link CoordinatorStreamValueSerde} to serialize messages before reading
+   * from and writing to the metadata store.
    *
    * @param config the configuration required for setting up metadata store.
    * @param metricsRegistry the registry for reporting metrics.
@@ -81,7 +80,6 @@ public class TaskAssignmentManager {
    * @param taskModeSerde the task-mode serializer.
    */
   public TaskAssignmentManager(Config config, MetricsRegistry metricsRegistry, Serde<String> keySerde, Serde<String> containerIdSerde, Serde<String> taskModeSerde) {
-    this.config = config;
     this.keySerde = keySerde;
     this.containerIdSerde = containerIdSerde;
     this.taskModeSerde = taskModeSerde;
diff --git a/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskPartitionAssignmentManager.java b/samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskPartitionAssignmentManager.java
new file mode 100644 (file)
index 0000000..c38bcd4
--- /dev/null
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.coordinator.stream.CoordinatorStreamValueSerde;
+import org.apache.samza.coordinator.stream.messages.SetTaskPartitionMapping;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.apache.samza.metadatastore.MetadataStoreFactory;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.serializers.model.SamzaObjectMapper;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.util.Util;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.codehaus.jackson.type.TypeReference;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Used for persisting and reading the task-to-partition assignment information
+ * to and from the metadata store.
+ */
+public class TaskPartitionAssignmentManager {
+
+  private static final Logger LOG = LoggerFactory.getLogger(TaskPartitionAssignmentManager.class);
+
+  private final ObjectMapper taskNameMapper = SamzaObjectMapper.getObjectMapper();
+  private final ObjectMapper sspMapper = SamzaObjectMapper.getObjectMapper();
+
+  private final Serde<String> valueSerde;
+  private final MetadataStore metadataStore;
+
+  /**
+   * Instantiates the task partition assignment manager with the provided metricsRegistry
+   * and config.
+   * @param config the configuration required for connecting with the metadata store.
+   * @param metricsRegistry the registry to create and report custom metrics.
+   */
+  public TaskPartitionAssignmentManager(Config config, MetricsRegistry metricsRegistry) {
+    this(config, metricsRegistry, new CoordinatorStreamValueSerde(SetTaskPartitionMapping.TYPE));
+  }
+
+  TaskPartitionAssignmentManager(Config config, MetricsRegistry metricsRegistry, Serde<String> valueSerde) {
+    this.valueSerde = valueSerde;
+    MetadataStoreFactory metadataStoreFactory = Util.getObj(new JobConfig(config).getMetadataStoreFactory(), MetadataStoreFactory.class);
+    this.metadataStore = metadataStoreFactory.getMetadataStore(SetTaskPartitionMapping.TYPE, config, metricsRegistry);
+    this.metadataStore.init();
+  }
+
+  /**
+   * Stores the task-to-partition assignments in the metadata store.
+   * @param partition the system stream partition.
+   * @param taskNames the task names to which the partition is assigned.
+   */
+  public void writeTaskPartitionAssignment(SystemStreamPartition partition, List<String> taskNames) {
+    String serializedKey = getKey(partition);
+    if (taskNames == null || taskNames.isEmpty()) {
+      LOG.info("Deleting the partition: {} from the metadata store.", partition);
+      metadataStore.delete(serializedKey);
+    } else {
+      try {
+        String taskNameAsString = taskNameMapper.writeValueAsString(taskNames);
+        byte[] taskNameAsBytes = valueSerde.toBytes(taskNameAsString);
+        LOG.info("Storing the partition: {} and taskNames: {} into the metadata store.", serializedKey, taskNames);
+        metadataStore.put(serializedKey, taskNameAsBytes);
+      } catch (Exception e) {
+        throw new SamzaException("Exception occurred when writing task to partition assignment.", e);
+      }
+    }
+  }
+
+  /**
+   * Reads the task partition assignments from the underlying storage layer.
+   * @return the task partition assignments.
+   */
+  public Map<SystemStreamPartition, List<String>> readTaskPartitionAssignments() {
+    try {
+      Map<SystemStreamPartition, List<String>> sspToTaskNamesMap = new HashMap<>();
+      Map<String, byte[]> allMetadataEntries = metadataStore.all();
+      for (Map.Entry<String, byte[]> entry : allMetadataEntries.entrySet()) {
+        LOG.info("Trying to deserialize the system stream partition: {}", entry.getKey());
+        SystemStreamPartition systemStreamPartition = getSystemStreamPartition(entry.getKey());
+        String taskNameAsJson = valueSerde.fromBytes(entry.getValue());
+        List<String> taskNames = taskNameMapper.readValue(taskNameAsJson, new TypeReference<List<String>>() { });
+        sspToTaskNamesMap.put(systemStreamPartition, taskNames);
+      }
+      return sspToTaskNamesMap;
+    } catch (Exception e) {
+      throw new SamzaException("Exception occurred when reading task partition assignments.", e);
+    }
+  }
+
+  /**
+   * Deletes the system stream partitions from the underlying metadata store.
+   * @param systemStreamPartitions the system stream partitions to delete.
+   */
+  public void delete(Iterable<SystemStreamPartition> systemStreamPartitions) {
+    for (SystemStreamPartition systemStreamPartition : systemStreamPartitions) {
+      LOG.info("Deleting the partition: {} from store.", systemStreamPartition);
+      String sspKey = getKey(systemStreamPartition);
+      metadataStore.delete(sspKey);
+    }
+  }
+
+  /**
+   * Closes the connections with the underlying metadata store.
+   */
+  public void close() {
+    metadataStore.close();
+  }
+
+  private String getKey(SystemStreamPartition systemStreamPartition) {
+    try {
+      return sspMapper.writeValueAsString(systemStreamPartition);
+    } catch (IOException e) {
+      throw new SamzaException(String.format("Exception occurred when serializing the partition: %s", systemStreamPartition), e);
+    }
+  }
+
+  private SystemStreamPartition getSystemStreamPartition(String partitionAsString) {
+    try {
+      return sspMapper.readValue(partitionAsString, SystemStreamPartition.class);
+    } catch (IOException e) {
+      throw new SamzaException(String.format("Exception occurred when deserializing the partition: %s", partitionAsString), e);
+    }
+  }
+}
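
The manager's full lifecycle, as a usage sketch; config is a hypothetical job Config from which the metadata store factory is resolved via JobConfig, as in the constructor above:

    TaskPartitionAssignmentManager manager = new TaskPartitionAssignmentManager(config, new MetricsRegistryMap());
    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "PVE", new Partition(0));
    manager.writeTaskPartitionAssignment(ssp, ImmutableList.of("Partition 0")); // stores the JSON-serialized task list
    Map<SystemStreamPartition, List<String>> assignments = manager.readTaskPartitionAssignments();
    manager.writeTaskPartitionAssignment(ssp, null); // a null or empty task list deletes the mapping
    manager.close();
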
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/stream/CoordinatorStreamKeySerde.java b/samza-core/src/main/java/org/apache/samza/coordinator/stream/CoordinatorStreamKeySerde.java
index 4eb9024..8e3f229 100644 (file)
@@ -19,7 +19,6 @@
 package org.apache.samza.coordinator.stream;
 
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.List;
 import org.apache.samza.coordinator.stream.messages.CoordinatorStreamMessage;
 import org.apache.samza.serializers.JsonSerde;
@@ -30,6 +29,8 @@ import org.apache.samza.serializers.Serde;
  */
 public class CoordinatorStreamKeySerde implements Serde<String> {
 
+  private static final int KEY_INDEX = 2;
+
   private final Serde<List<?>> keySerde;
   private final String type;
 
@@ -40,8 +41,8 @@ public class CoordinatorStreamKeySerde implements Serde<String> {
 
   @Override
   public String fromBytes(byte[] bytes) {
-    CoordinatorStreamMessage message = new CoordinatorStreamMessage(keySerde.fromBytes(bytes).toArray(), new HashMap<>());
-    return message.getKey();
+    Object[] keyArray = keySerde.fromBytes(bytes).toArray();
+    return (String) keyArray[KEY_INDEX];
   }
 
   @Override
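Coordinator stream keys serialize as a JSON array of [version, type, key] (compare the [1, "set-task-partition-assignment", ...] example in the SetTaskPartitionMapping javadoc below), so KEY_INDEX = 2 selects the key directly without materializing a full CoordinatorStreamMessage. A round-trip sketch:

    Serde<String> keySerde = new CoordinatorStreamKeySerde(SetTaskPartitionMapping.TYPE);
    byte[] bytes = keySerde.toBytes("some-key"); // -> [<version>, "set-task-partition-assignment", "some-key"]
    String key = keySerde.fromBytes(bytes);      // reads keyArray[KEY_INDEX] -> "some-key"
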
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/stream/CoordinatorStreamValueSerde.java b/samza-core/src/main/java/org/apache/samza/coordinator/stream/CoordinatorStreamValueSerde.java
index 82dcf81..07e0985 100644 (file)
@@ -25,6 +25,7 @@ import org.apache.samza.coordinator.stream.messages.SetChangelogMapping;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
 import org.apache.samza.coordinator.stream.messages.SetTaskContainerMapping;
 import org.apache.samza.coordinator.stream.messages.SetConfig;
+import org.apache.samza.coordinator.stream.messages.SetTaskPartitionMapping;
 import org.apache.samza.SamzaException;
 import org.apache.samza.coordinator.stream.messages.SetTaskModeMapping;
 import org.apache.samza.serializers.JsonSerde;
@@ -65,6 +66,9 @@ public class CoordinatorStreamValueSerde implements Serde<String> {
     } else if (type.equalsIgnoreCase(SetTaskModeMapping.TYPE)) {
       SetTaskModeMapping setTaskModeMapping = new SetTaskModeMapping(message);
       return String.valueOf(setTaskModeMapping.getTaskMode());
+    } else if (type.equalsIgnoreCase(SetTaskPartitionMapping.TYPE)) {
+      SetTaskPartitionMapping setTaskPartitionMapping = new SetTaskPartitionMapping(message);
+      return setTaskPartitionMapping.getTaskName();
     } else {
       throw new SamzaException(String.format("Unknown coordinator stream message type: %s", type));
     }
@@ -87,6 +91,9 @@ public class CoordinatorStreamValueSerde implements Serde<String> {
     } else if (type.equalsIgnoreCase(SetConfig.TYPE)) {
       SetConfig setConfig = new SetConfig(SOURCE, "", value);
       return messageSerde.toBytes(setConfig.getMessageMap());
+    } else if (type.equalsIgnoreCase(SetTaskPartitionMapping.TYPE)) {
+      SetTaskPartitionMapping setTaskPartitionMapping = new SetTaskPartitionMapping(SOURCE, "", value);
+      return messageSerde.toBytes(setTaskPartitionMapping.getMessageMap());
     } else {
       throw new SamzaException(String.format("Unknown coordinator stream message type: %s", type));
     }
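
With the new branches in place, set-task-partition-assignment values round-trip through the serde like the other message types; the string payload is the JSON task-name list that TaskPartitionAssignmentManager writes. A sketch:

    Serde<String> valueSerde = new CoordinatorStreamValueSerde(SetTaskPartitionMapping.TYPE);
    byte[] bytes = valueSerde.toBytes("[\"task-1\",\"task-2\"]"); // wrapped in a SetTaskPartitionMapping message map
    String taskNames = valueSerde.fromBytes(bytes);               // recovered via getTaskName()
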
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetContainerHostMapping.java b/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetContainerHostMapping.java
index da67346..82c6c1e 100644 (file)
@@ -70,21 +70,4 @@ public class SetContainerHostMapping extends CoordinatorStreamMessage {
   public String getHostLocality() {
     return getMessageValue(HOST_KEY);
   }
-
-  /**
-   * Returns the JMX url of the container.
-   * @return the JMX url
-   */
-  public String getJmxUrl() {
-    return getMessageValue(JMX_URL_KEY);
-  }
-
-  /**
-   * Returns the JMX tunneling url of the container
-   * @return the JMX tunneling url
-   */
-  public String getJmxTunnelingUrl() {
-    return getMessageValue(JMX_TUNNELING_URL_KEY);
-  }
-
 }
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetTaskPartitionMapping.java b/samza-core/src/main/java/org/apache/samza/coordinator/stream/messages/SetTaskPartitionMapping.java
new file mode 100644 (file)
index 0000000..3b88293
--- /dev/null
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.coordinator.stream.messages;
+
+/**
+ * SetTaskPartitionMapping is a {@link CoordinatorStreamMessage} used in samza
+ * to persist the task to partition assignments of a samza job.
+ *
+ * Structure of the message:
+ *
+ * <pre>
+ * key =&gt; [1, "set-task-partition-assignment", "{"partition": "1", "system": "test-system", "stream": "test-stream"}"]
+ *
+ * message =&gt; {
+ *     "host" : "192.168.0.1",
+ *     "source" : "TaskPartitionAssignmentManager",
+ *     "username" :"app",
+ *     "timestamp" : 1456177487325,
+ *     "values" : {
+ *         "taskNames" : ["task-1", "task2", "task3"]
+ *     }
+ * }
+ * </pre>
+ */
+public class SetTaskPartitionMapping extends CoordinatorStreamMessage {
+
+  private static final String TASK_NAME_KEY = "taskNames";
+
+  public static final String TYPE = "set-task-partition-assignment";
+
+  /**
+   * SetTaskPartitionMapping is the data format for persisting the partition to task assignments
+   * of a samza job to the coordinator stream.
+   * @param message which holds the partition to task information.
+   */
+  public SetTaskPartitionMapping(CoordinatorStreamMessage message) {
+    super(message.getKeyArray(), message.getMessageMap());
+  }
+
+  /**
+   * SetTaskPartitionMapping is the data format for persisting the partition to task assignments
+   * of a samza job to the coordinator stream.
+   * @param source      the source of the message.
+   * @param partition   the system stream partition serialized as a string.
+   * @param taskName    the name of the task mapped to the system stream partition.
+   */
+  public SetTaskPartitionMapping(String source, String partition, String taskName) {
+    super(source);
+    setType(TYPE);
+    setKey(partition);
+    putMessageValue(TASK_NAME_KEY, taskName);
+  }
+
+  public String getTaskName() {
+    return getMessageValue(TASK_NAME_KEY);
+  }
+}
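
A construction sketch matching the javadoc above: the key is the SSP serialized to JSON (as TaskPartitionAssignmentManager does with SamzaObjectMapper) and the value is the JSON-serialized task-name list:

    SetTaskPartitionMapping mapping = new SetTaskPartitionMapping(
        "TaskPartitionAssignmentManager",                                                   // source
        "{\"partition\": \"1\", \"system\": \"test-system\", \"stream\": \"test-stream\"}", // partition, serialized as the key
        "[\"task-1\", \"task-2\"]");                                                        // task names, JSON-serialized
    String taskNames = mapping.getTaskName();
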
diff --git a/samza-core/src/main/java/org/apache/samza/serializers/model/SamzaObjectMapper.java b/samza-core/src/main/java/org/apache/samza/serializers/model/SamzaObjectMapper.java
index a247fb3..694987f 100644 (file)
@@ -83,6 +83,7 @@ public class SamzaObjectMapper {
     module.addKeySerializer(SystemStreamPartition.class, new SystemStreamPartitionKeySerializer());
     module.addSerializer(TaskName.class, new TaskNameSerializer());
     module.addSerializer(TaskMode.class, new TaskModeSerializer());
+    module.addDeserializer(TaskName.class, new TaskNameDeserializer());
     module.addDeserializer(Partition.class, new PartitionDeserializer());
     module.addDeserializer(SystemStreamPartition.class, new SystemStreamPartitionDeserializer());
     module.addKeyDeserializer(SystemStreamPartition.class, new SystemStreamPartitionKeyDeserializer());
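
The new TaskName deserializer makes TaskName symmetric under SamzaObjectMapper. A sketch, assuming (per the registered TaskNameSerializer) that a TaskName serializes to its raw string:

    ObjectMapper mapper = SamzaObjectMapper.getObjectMapper();
    String json = mapper.writeValueAsString(new TaskName("Partition 0")); // -> "\"Partition 0\""
    TaskName roundTripped = mapper.readValue(json, TaskName.class);       // now supported
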
diff --git a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
index 0695cd7..c4a5db7 100644 (file)
@@ -28,6 +28,8 @@ import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
 import org.apache.samza.config.{Config, _}
 import org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouperFactory
+import org.apache.samza.config.JobConfig.Config2Job
+import org.apache.samza.container.grouper.stream.SSPGrouperProxy
 import org.apache.samza.container.grouper.task._
 import org.apache.samza.container.{LocalityManager, TaskName}
 import org.apache.samza.coordinator.server.{HttpServer, JobServlet}
@@ -67,16 +73,17 @@ object JobModelManager extends Logging {
   def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, Integer], metricsRegistry: MetricsRegistry = new MetricsRegistryMap()): JobModelManager = {
     val localityManager = new LocalityManager(config, metricsRegistry)
     val taskAssignmentManager = new TaskAssignmentManager(config, metricsRegistry)
+    val taskPartitionAssignmentManager = new TaskPartitionAssignmentManager(config, metricsRegistry)
     val systemAdmins = new SystemAdmins(config)
     try {
       systemAdmins.start()
       val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
-      val grouperMetadata: GrouperMetadata = getGrouperMetadata(config, localityManager, taskAssignmentManager)
+      val grouperMetadata: GrouperMetadata = getGrouperMetadata(config, localityManager, taskAssignmentManager, taskPartitionAssignmentManager)
 
       val jobModel: JobModel = readJobModel(config, changelogPartitionMapping, streamMetadataCache, grouperMetadata)
       jobModelRef.set(new JobModel(jobModel.getConfig, jobModel.getContainers, localityManager))
 
-      updateTaskAssignments(jobModel, taskAssignmentManager, grouperMetadata)
+      updateTaskAssignments(jobModel, taskAssignmentManager, taskPartitionAssignmentManager, grouperMetadata)
 
       val server = new HttpServer
       server.addServlet("/", new JobServlet(jobModelRef))
@@ -84,6 +91,7 @@ object JobModelManager extends Logging {
       currentJobModelManager = new JobModelManager(jobModelRef.get(), server, localityManager)
       currentJobModelManager
     } finally {
+      taskPartitionAssignmentManager.close()
       taskAssignmentManager.close()
       systemAdmins.stop()
       // Not closing localityManager, since {@code ClusterBasedJobCoordinator} uses it to read container locality through {@code JobModel}.
@@ -95,9 +103,10 @@ object JobModelManager extends Logging {
     * @param config represents the configurations defined by the user.
     * @param localityManager provides the processor to host mapping persisted to the metadata store.
     * @param taskAssignmentManager provides the processor to task assignments persisted to the metadata store.
+    * @param taskPartitionAssignmentManager provides the task to partition assignments persisted to the metadata store.
     * @return the instantiated {@see GrouperMetadata}.
     */
-  def getGrouperMetadata(config: Config, localityManager: LocalityManager, taskAssignmentManager: TaskAssignmentManager) = {
+  def getGrouperMetadata(config: Config, localityManager: LocalityManager, taskAssignmentManager: TaskAssignmentManager, taskPartitionAssignmentManager: TaskPartitionAssignmentManager) = {
     val processorLocality: util.Map[String, LocationId] = getProcessorLocality(config, localityManager)
     val taskModes: util.Map[TaskName, TaskMode] = taskAssignmentManager.readTaskModes()
 
@@ -111,13 +120,26 @@ object JobModelManager extends Logging {
       taskNameToProcessorId.put(new TaskName(taskName), processorId)
     }
 
-    val taskLocality:util.Map[TaskName, LocationId] = new util.HashMap[TaskName, LocationId]()
+    val taskLocality: util.Map[TaskName, LocationId] = new util.HashMap[TaskName, LocationId]()
     for ((taskName, processorId) <- taskAssignment) {
       if (processorLocality.containsKey(processorId)) {
         taskLocality.put(new TaskName(taskName), processorLocality.get(processorId))
       }
     }
-    new GrouperMetadataImpl(processorLocality, taskLocality, new util.HashMap[TaskName, util.List[SystemStreamPartition]](), taskNameToProcessorId)
+
+    val sspToTaskMapping: util.Map[SystemStreamPartition, util.List[String]] = taskPartitionAssignmentManager.readTaskPartitionAssignments()
+    val taskPartitionAssignments: util.Map[TaskName, util.List[SystemStreamPartition]] = new util.HashMap[TaskName, util.List[SystemStreamPartition]]()
+
+    sspToTaskMapping foreach { case (systemStreamPartition: SystemStreamPartition, taskNames: util.List[String]) =>
+      for (task <- taskNames) {
+        val taskName: TaskName = new TaskName(task)
+        if (!taskPartitionAssignments.containsKey(taskName)) {
+          taskPartitionAssignments.put(taskName, new util.ArrayList[SystemStreamPartition]())
+        }
+        taskPartitionAssignments.get(taskName).add(systemStreamPartition)
+      }
+    }
+    new GrouperMetadataImpl(processorLocality, taskLocality, taskPartitionAssignments, taskNameToProcessorId)
   }
 
   /**
@@ -150,12 +172,18 @@ object JobModelManager extends Logging {
     * 2. Saves the newly generated task assignments to the storage layer through the {@param TaskAssignementManager}.
     *
     * @param jobModel              represents the {@see JobModel} of the samza job.
-    * @param taskAssignmentManager required to persist the processor to task assignments to the storage layer.
-    * @param grouperMetadata       provides the historical metadata of the application.
+    * @param taskAssignmentManager required to persist the processor to task assignments to the metadata store.
+    * @param taskPartitionAssignmentManager required to persist the task to partition assignments to the metadata store.
+    * @param grouperMetadata       provides the historical metadata of the samza application.
     */
-  def updateTaskAssignments(jobModel: JobModel, taskAssignmentManager: TaskAssignmentManager, grouperMetadata: GrouperMetadata): Unit = {
+  def updateTaskAssignments(jobModel: JobModel,
+                            taskAssignmentManager: TaskAssignmentManager,
+                            taskPartitionAssignmentManager: TaskPartitionAssignmentManager,
+                            grouperMetadata: GrouperMetadata): Unit = {
+    info("Storing the task assignments into metadata store.")
     val activeTaskNames: util.Set[String] = new util.HashSet[String]()
     val standbyTaskNames: util.Set[String] = new util.HashSet[String]()
+    val systemStreamPartitions: util.Set[SystemStreamPartition] = new util.HashSet[SystemStreamPartition]()
     for (container <- jobModel.getContainers.values()) {
       for (taskModel <- container.getTasks.values()) {
         if(taskModel.getTaskMode.eq(TaskMode.Active)) {
@@ -165,6 +193,7 @@ object JobModelManager extends Logging {
         if(taskModel.getTaskMode.eq(TaskMode.Standby)) {
           standbyTaskNames.add(taskModel.getTaskName.getTaskName)
         }
+        systemStreamPartitions.addAll(taskModel.getSystemStreamPartitions)
       }
     }
 
@@ -178,6 +207,7 @@ object JobModelManager extends Logging {
       // data associated with the incoming keys. Warn the user and default to grouper
       // In this scenario the tasks may have been reduced, so we need to delete all the existing messages
       taskAssignmentManager.deleteTaskContainerMappings(previousTaskToContainerId.keys.map(taskName => taskName.getTaskName).asJava)
+      taskPartitionAssignmentManager.delete(systemStreamPartitions)
     }
 
     // if the set of standby tasks has changed, e.g., when the replication-factor changed, or the active-tasks-set has
@@ -188,11 +218,25 @@ object JobModelManager extends Logging {
       taskAssignmentManager.deleteTaskContainerMappings(previousStandbyTasks.map(x => x._1.getTaskName).asJava)
     }
 
+    val sspToTaskNameMap: util.Map[SystemStreamPartition, util.List[String]] = new util.HashMap[SystemStreamPartition, util.List[String]]()
+
     for (container <- jobModel.getContainers.values()) {
-      for (taskName <- container.getTasks.keySet) {
+      for ((taskName, taskModel) <- container.getTasks) {
+        info("Storing task: %s and container ID: %s into metadata store" format(taskName.getTaskName, container.getId))
         taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName, container.getId, container.getTasks.get(taskName).getTaskMode)
+        for (partition <- taskModel.getSystemStreamPartitions) {
+          if (!sspToTaskNameMap.containsKey(partition)) {
+            sspToTaskNameMap.put(partition, new util.ArrayList[String]())
+          }
+          sspToTaskNameMap.get(partition).add(taskName.getTaskName)
+        }
       }
     }
+
+    for ((ssp, taskNames) <- sspToTaskNameMap) {
+      info("Storing ssp: %s and tasks: %s into metadata store" format(ssp, taskNames))
+      taskPartitionAssignmentManager.writeTaskPartitionAssignment(ssp, taskNames)
+    }
   }
 
   /**
@@ -255,7 +299,6 @@ object JobModelManager extends Logging {
     factory.getSystemStreamPartitionGrouper(config)
   }
 
-
   /**
     * Does the following:
     * 1. Fetches metadata of the input streams defined in configuration through {@param streamMetadataCache}.
@@ -280,11 +323,17 @@ object JobModelManager extends Logging {
     configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", grouperMetadata.getProcessorLocality.keySet()))
     val grouper = getSystemStreamPartitionGrouper(new MapConfig(configMap))
 
-    val groups = grouper.group(allSystemStreamPartitions)
-    info("SystemStreamPartitionGrouper %s has grouped the SystemStreamPartitions into %d tasks with the following taskNames: %s" format(grouper, groups.size(), groups.keySet()))
-
     val isHostAffinityEnabled = new ClusterManagerConfig(config).getHostAffinityEnabled
 
+    val groups: util.Map[TaskName, util.Set[SystemStreamPartition]] =
+      if (isHostAffinityEnabled) {
+        val sspGrouperProxy: SSPGrouperProxy = new SSPGrouperProxy(config, grouper)
+        sspGrouperProxy.group(allSystemStreamPartitions, grouperMetadata)
+      } else {
+        grouper.group(allSystemStreamPartitions)
+      }
+    info("SystemStreamPartitionGrouper %s has grouped the SystemStreamPartitions into %d tasks with the following taskNames: %s" format(grouper, groups.size(), groups))
+
     // If no mappings are present(first time the job is running) we return -1, this will allow 0 to be the first change
     // mapping.
     var maxChangelogPartitionId = changeLogPartitionMapping.asScala.values.map(_.toInt).toList.sorted.lastOption.getOrElse(-1)
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/stream/TestGroupByPartition.java b/samza-core/src/test/java/org/apache/samza/container/grouper/stream/TestGroupByPartition.java
index b0e0ccd..265233b 100644 (file)
@@ -21,17 +21,20 @@ package org.apache.samza.container.grouper.stream;
 
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.container.grouper.task.GrouperMetadata;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.system.SystemStreamPartition;
 import org.junit.Assert;
 import org.junit.Test;
@@ -93,11 +96,11 @@ public class TestGroupByPartition {
 
   @Test
   public void testSingleStreamRepartitioning() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithSingleStream = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("Partition 0"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
-            .put(new TaskName("Partition 1"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
-            .put(new TaskName("Partition 2"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
-            .put(new TaskName("Partition 3"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithSingleStream = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("Partition 0"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
+            .put(new TaskName("Partition 1"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
+            .put(new TaskName("Partition 2"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
+            .put(new TaskName("Partition 3"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -116,18 +119,18 @@ public class TestGroupByPartition {
             .build();
 
     SSPGrouperProxy groupByPartition = new SSPGrouperProxy(new MapConfig(), new GroupByPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithSingleStream, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithSingleStream, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 
   @Test
   public void testMultipleStreamsWithSingleStreamRepartitioning() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("Partition 0"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)), new SystemStreamPartition("kafka", "URE", new Partition(0))))
-            .put(new TaskName("Partition 1"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1)), new SystemStreamPartition("kafka", "URE", new Partition(1))))
-            .put(new TaskName("Partition 2"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2)), new SystemStreamPartition("kafka", "URE", new Partition(2))))
-            .put(new TaskName("Partition 3"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3)), new SystemStreamPartition("kafka", "URE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("Partition 0"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)), new SystemStreamPartition("kafka", "URE", new Partition(0))))
+            .put(new TaskName("Partition 1"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1)), new SystemStreamPartition("kafka", "URE", new Partition(1))))
+            .put(new TaskName("Partition 2"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2)), new SystemStreamPartition("kafka", "URE", new Partition(2))))
+            .put(new TaskName("Partition 3"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3)), new SystemStreamPartition("kafka", "URE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -163,18 +166,18 @@ public class TestGroupByPartition {
             .build();
 
     SSPGrouperProxy groupByPartition = new SSPGrouperProxy(new MapConfig(), new GroupByPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 
   @Test
   public void testOnlyNewlyAddedStreams() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("Partition 0"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)), new SystemStreamPartition("kafka", "URE", new Partition(0))))
-            .put(new TaskName("Partition 1"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1)), new SystemStreamPartition("kafka", "URE", new Partition(1))))
-            .put(new TaskName("Partition 2"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2)), new SystemStreamPartition("kafka", "URE", new Partition(2))))
-            .put(new TaskName("Partition 3"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3)), new SystemStreamPartition("kafka", "URE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("Partition 0"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)), new SystemStreamPartition("kafka", "URE", new Partition(0))))
+            .put(new TaskName("Partition 1"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1)), new SystemStreamPartition("kafka", "URE", new Partition(1))))
+            .put(new TaskName("Partition 2"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2)), new SystemStreamPartition("kafka", "URE", new Partition(2))))
+            .put(new TaskName("Partition 3"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3)), new SystemStreamPartition("kafka", "URE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -194,19 +197,19 @@ public class TestGroupByPartition {
             .build();
 
     SSPGrouperProxy groupByPartition = new SSPGrouperProxy(new MapConfig(), new GroupByPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 
 
   @Test
   public void testRemovalAndAdditionOfStreamsWithRepartitioning() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("Partition 0"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)), new SystemStreamPartition("kafka", "URE", new Partition(0))))
-            .put(new TaskName("Partition 1"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1)), new SystemStreamPartition("kafka", "URE", new Partition(1))))
-            .put(new TaskName("Partition 2"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2)), new SystemStreamPartition("kafka", "URE", new Partition(2))))
-            .put(new TaskName("Partition 3"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3)), new SystemStreamPartition("kafka", "URE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("Partition 0"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)), new SystemStreamPartition("kafka", "URE", new Partition(0))))
+            .put(new TaskName("Partition 1"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1)), new SystemStreamPartition("kafka", "URE", new Partition(1))))
+            .put(new TaskName("Partition 2"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2)), new SystemStreamPartition("kafka", "URE", new Partition(2))))
+            .put(new TaskName("Partition 3"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3)), new SystemStreamPartition("kafka", "URE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -237,18 +240,18 @@ public class TestGroupByPartition {
             .build();
 
     SSPGrouperProxy groupByPartition = new SSPGrouperProxy(new MapConfig(), new GroupByPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 
   @Test
   public void testMultipleStreamRepartitioningWithNewStreams() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("Partition 0"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)), new SystemStreamPartition("kafka", "URE", new Partition(0))))
-            .put(new TaskName("Partition 1"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1)), new SystemStreamPartition("kafka", "URE", new Partition(1))))
-            .put(new TaskName("Partition 2"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2)), new SystemStreamPartition("kafka", "URE", new Partition(2))))
-            .put(new TaskName("Partition 3"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3)), new SystemStreamPartition("kafka", "URE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("Partition 0"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0)), new SystemStreamPartition("kafka", "URE", new Partition(0))))
+            .put(new TaskName("Partition 1"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1)), new SystemStreamPartition("kafka", "URE", new Partition(1))))
+            .put(new TaskName("Partition 2"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2)), new SystemStreamPartition("kafka", "URE", new Partition(2))))
+            .put(new TaskName("Partition 3"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3)), new SystemStreamPartition("kafka", "URE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = new HashSet<>();
@@ -288,8 +291,8 @@ public class TestGroupByPartition {
 
 
     SSPGrouperProxy groupByPartition = new SSPGrouperProxy(new MapConfig(), new GroupByPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupByPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 }
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/stream/TestGroupBySystemStreamPartition.java b/samza-core/src/test/java/org/apache/samza/container/grouper/stream/TestGroupBySystemStreamPartition.java
index 385083b..a31eaa1 100644 (file)
@@ -25,12 +25,15 @@ import java.util.*;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.container.grouper.task.GrouperMetadata;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.system.SystemStreamPartition;
 import org.junit.Assert;
 import org.junit.Test;
@@ -77,11 +80,11 @@ public class TestGroupBySystemStreamPartition {
 
   @Test
   public void testSingleStreamRepartitioning() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithSingleStream = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithSingleStream = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -100,22 +103,22 @@ public class TestGroupBySystemStreamPartition {
             .build();
 
     SSPGrouperProxy groupBySystemStreamPartition = new SSPGrouperProxy(new MapConfig(), new GroupBySystemStreamPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithSingleStream, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithSingleStream, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 
   @Test
   public void testMultipleStreamsWithSingleStreamRepartitioning() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -148,22 +151,22 @@ public class TestGroupBySystemStreamPartition {
             .build();
 
     SSPGrouperProxy groupBySystemStreamPartition = new SSPGrouperProxy(new MapConfig(), new GroupBySystemStreamPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 
   @Test
   public void testOnlyNewlyAddedStreams() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -182,23 +185,23 @@ public class TestGroupBySystemStreamPartition {
             .build();
 
     SSPGrouperProxy groupBySystemStreamPartition = new SSPGrouperProxy(new MapConfig(), new GroupBySystemStreamPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 
 
   @Test
   public void testRemovalAndAdditionOfStreamsWithRepartitioning() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -226,22 +229,22 @@ public class TestGroupBySystemStreamPartition {
             .build();
 
     SSPGrouperProxy groupBySystemStreamPartition = new SSPGrouperProxy(new MapConfig(), new GroupBySystemStreamPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 
   @Test
   public void testMultipleStreamRepartitioningWithNewStreams() {
-    Map<TaskName, Set<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, Set<SystemStreamPartition>>builder()
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 0]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(0))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 1]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(1))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 2]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(2))))
-            .put(new TaskName("SystemStreamPartition [kafka, URE, 3]"), ImmutableSet.of(new SystemStreamPartition("kafka", "URE", new Partition(3))))
+    Map<TaskName, List<SystemStreamPartition>> prevGroupingWithMultipleStreams = ImmutableMap.<TaskName, List<SystemStreamPartition>>builder()
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, PVE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "PVE", new Partition(3))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 0]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(0))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 1]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(1))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 2]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(2))))
+            .put(new TaskName("SystemStreamPartition [kafka, URE, 3]"), ImmutableList.of(new SystemStreamPartition("kafka", "URE", new Partition(3))))
             .build();
 
     Set<SystemStreamPartition> currSsps = IntStream.range(0, 8)
@@ -278,8 +281,8 @@ public class TestGroupBySystemStreamPartition {
             .build();
 
     SSPGrouperProxy groupBySystemStreamPartition = new SSPGrouperProxy(new MapConfig(), new GroupBySystemStreamPartition(new MapConfig()));
-    GrouperContext grouperContext = new GrouperContext(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
-    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperContext);
+    GrouperMetadata grouperMetadata = new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), prevGroupingWithMultipleStreams, new HashMap<>());
+    Map<TaskName, Set<SystemStreamPartition>> finalGrouping = groupBySystemStreamPartition.group(currSsps, grouperMetadata);
     Assert.assertEquals(expectedGrouping, finalGrouping);
   }
 }
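The expected groupings in these repartitioning tests encode a single invariant: when a stream grows from N to 2N partitions, new partition p must be grouped with the task that previously owned partition p % N, so keyed state stays colocated with its task. A standalone illustration of that arithmetic (the modulo rule is inferred from the expected groupings, not quoted from the grouper implementation):

    // Partition-to-task routing after an expansion from 4 to 8 partitions.
    int previousPartitionCount = 4;
    for (int partition = 0; partition < 8; partition++) {
      int owningPartition = partition % previousPartitionCount;
      System.out.printf("new partition %d -> task of previous partition %d%n",
          partition, owningPartition);
    }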
diff --git a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskPartitionAssignmentManager.java b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskPartitionAssignmentManager.java
new file mode 100644 (file)
index 0000000..a141ee3
--- /dev/null
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container.grouper.task;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.coordinator.stream.MockCoordinatorStreamSystemFactory;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.Partition;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.util.CoordinatorStreamUtil;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.api.mockito.PowerMockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+import static org.mockito.Matchers.anyObject;
+import static org.mockito.Mockito.when;
+
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(CoordinatorStreamUtil.class)
+public class TestTaskPartitionAssignmentManager {
+
+  private static final String TEST_SYSTEM = "system";
+  private static final String TEST_STREAM = "stream";
+  private static final Partition PARTITION = new Partition(0);
+
+  private final Config config = new MapConfig(ImmutableMap.of("job.name", "test-job", "job.coordinator.system", "test-kafka"));
+  private final SystemStreamPartition testSystemStreamPartition = new SystemStreamPartition(TEST_SYSTEM, TEST_STREAM, PARTITION);
+
+  private MockCoordinatorStreamSystemFactory mockCoordinatorStreamSystemFactory;
+  private TaskPartitionAssignmentManager taskPartitionAssignmentManager;
+
+  @Before
+  public void setup() {
+    mockCoordinatorStreamSystemFactory = new MockCoordinatorStreamSystemFactory();
+    MockCoordinatorStreamSystemFactory.enableMockConsumerCache();
+    PowerMockito.mockStatic(CoordinatorStreamUtil.class);
+    when(CoordinatorStreamUtil.getCoordinatorSystemFactory(anyObject())).thenReturn(mockCoordinatorStreamSystemFactory);
+    when(CoordinatorStreamUtil.getCoordinatorSystemStream(anyObject())).thenReturn(new SystemStream("test-kafka", "test"));
+    when(CoordinatorStreamUtil.getCoordinatorStreamName(anyObject(), anyObject())).thenReturn("test");
+    taskPartitionAssignmentManager = new TaskPartitionAssignmentManager(config, new MetricsRegistryMap());
+  }
+
+  @After
+  public void tearDown() {
+    MockCoordinatorStreamSystemFactory.disableMockConsumerCache();
+    taskPartitionAssignmentManager.close();
+  }
+
+  @Test
+  public void testReadAfterWrite() {
+    List<String> testTaskNames = ImmutableList.of("test-task1", "test-task2", "test-task3");
+    taskPartitionAssignmentManager.writeTaskPartitionAssignment(testSystemStreamPartition, testTaskNames);
+
+    Map<SystemStreamPartition, List<String>> expectedMapping = ImmutableMap.of(testSystemStreamPartition, testTaskNames);
+    Map<SystemStreamPartition, List<String>> actualMapping = taskPartitionAssignmentManager.readTaskPartitionAssignments();
+
+    Assert.assertEquals(expectedMapping, actualMapping);
+  }
+
+  @Test
+  public void testDeleteAfterWrite() {
+    List<String> testTaskNames = ImmutableList.of("test-task1", "test-task2", "test-task3");
+    taskPartitionAssignmentManager.writeTaskPartitionAssignment(testSystemStreamPartition, testTaskNames);
+
+    Map<SystemStreamPartition, List<String>> actualMapping = taskPartitionAssignmentManager.readTaskPartitionAssignments();
+    Assert.assertEquals(1, actualMapping.size());
+
+    taskPartitionAssignmentManager.delete(ImmutableList.of(testSystemStreamPartition));
+
+    actualMapping = taskPartitionAssignmentManager.readTaskPartitionAssignments();
+    Assert.assertEquals(0, actualMapping.size());
+  }
+
+  @Test
+  public void testReadPartitionAssignments() {
+    SystemStreamPartition testSystemStreamPartition1 = new SystemStreamPartition(TEST_SYSTEM, TEST_STREAM, PARTITION);
+    List<String> testTaskNames1 = ImmutableList.of("test-task1", "test-task2", "test-task3");
+    SystemStreamPartition testSystemStreamPartition2 = new SystemStreamPartition(TEST_SYSTEM, "stream-2", PARTITION);
+    List<String> testTaskNames2 = ImmutableList.of("test-task4", "test-task5");
+    SystemStreamPartition testSystemStreamPartition3 = new SystemStreamPartition(TEST_SYSTEM, "stream-3", PARTITION);
+    List<String> testTaskNames3 = ImmutableList.of("test-task6", "test-task7", "test-task8");
+
+    taskPartitionAssignmentManager.writeTaskPartitionAssignment(testSystemStreamPartition1, testTaskNames1);
+    taskPartitionAssignmentManager.writeTaskPartitionAssignment(testSystemStreamPartition2, testTaskNames2);
+    taskPartitionAssignmentManager.writeTaskPartitionAssignment(testSystemStreamPartition3, testTaskNames3);
+
+    Map<SystemStreamPartition, List<String>> expectedMapping = ImmutableMap.of(testSystemStreamPartition1, testTaskNames1,
+            testSystemStreamPartition2, testTaskNames2, testSystemStreamPartition3, testTaskNames3);
+    Map<SystemStreamPartition, List<String>> actualMapping = taskPartitionAssignmentManager.readTaskPartitionAssignments();
+
+    Assert.assertEquals(expectedMapping, actualMapping);
+  }
+
+  @Test
+  public void testMultipleUpdatesReturnsTheMostRecentValue() {
+    List<String> testTaskNames1 = ImmutableList.of("test-task1", "test-task2", "test-task3");
+
+    taskPartitionAssignmentManager.writeTaskPartitionAssignment(testSystemStreamPartition, testTaskNames1);
+
+    List<String> testTaskNames2 = ImmutableList.of("test-task4", "test-task5");
+    taskPartitionAssignmentManager.writeTaskPartitionAssignment(testSystemStreamPartition, testTaskNames2);
+
+    List<String> testTaskNames3 = ImmutableList.of("test-task6", "test-task7", "test-task8");
+    taskPartitionAssignmentManager.writeTaskPartitionAssignment(testSystemStreamPartition, testTaskNames3);
+
+    Map<SystemStreamPartition, List<String>> expectedMapping = ImmutableMap.of(testSystemStreamPartition, testTaskNames3);
+    Map<SystemStreamPartition, List<String>> actualMapping = taskPartitionAssignmentManager.readTaskPartitionAssignments();
+    Assert.assertEquals(expectedMapping, actualMapping);
+  }
+}
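Taken together, these tests exercise the entire surface the new manager needs: write one SSP-to-tasks record, read all records back, and delete records for stale SSPs. A usage sketch restricted to exactly the calls shown above (nothing else about the class is assumed):

    TaskPartitionAssignmentManager manager =
        new TaskPartitionAssignmentManager(config, new MetricsRegistryMap());
    SystemStreamPartition ssp = new SystemStreamPartition("system", "stream", new Partition(0));

    manager.writeTaskPartitionAssignment(ssp, ImmutableList.of("task-0", "task-1"));
    Map<SystemStreamPartition, List<String>> assignments = manager.readTaskPartitionAssignments();
    manager.delete(ImmutableList.of(ssp));  // remove the mapping once it is stale
    manager.close();

Note that the read-after-write and multiple-update tests above also pin down last-writer-wins semantics for repeated writes against the same SystemStreamPartition.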
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
index 7eb5768..6e627b0 100644 (file)
 
 package org.apache.samza.coordinator;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
-import java.util.HashSet;
-import java.util.Set;
+
+import java.util.*;
+
+import com.google.common.collect.ImmutableSet;
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
@@ -30,6 +33,7 @@ import org.apache.samza.container.TaskName;
 import org.apache.samza.container.grouper.task.GroupByContainerCount;
 import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.container.grouper.task.TaskAssignmentManager;
+import org.apache.samza.container.grouper.task.TaskPartitionAssignmentManager;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
 import org.apache.samza.job.model.ContainerModel;
@@ -40,6 +44,7 @@ import org.apache.samza.runtime.LocationId;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.testUtils.MockHttpServer;
 import org.eclipse.jetty.servlet.DefaultServlet;
 import org.eclipse.jetty.servlet.ServletHolder;
@@ -47,11 +52,6 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Collections;
-
-import static org.apache.samza.coordinator.JobModelManager.*;
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Matchers.argThat;
@@ -165,26 +165,42 @@ public class TestJobModelManager {
     // Mocking setup.
     LocalityManager mockLocalityManager = mock(LocalityManager.class);
     TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+    TaskPartitionAssignmentManager mockTaskPartitionAssignmentManager = Mockito.mock(TaskPartitionAssignmentManager.class);
+
+    SystemStreamPartition testSystemStreamPartition1 = new SystemStreamPartition(new SystemStream("test-system-0", "test-stream-0"), new Partition(1));
+    SystemStreamPartition testSystemStreamPartition2 = new SystemStreamPartition(new SystemStream("test-system-1", "test-stream-1"), new Partition(2));
 
     Map<String, Map<String, String>> localityMappings = new HashMap<>();
     localityMappings.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity"));
 
+    Map<SystemStreamPartition, List<String>> taskToSSPAssignments = ImmutableMap.of(testSystemStreamPartition1, ImmutableList.of("task-0", "task-1"),
+                                                                                    testSystemStreamPartition2, ImmutableList.of("task-2", "task-3"));
+
     Map<String, String> taskAssignment = ImmutableMap.of("task-0", "0");
 
     // Mock the container locality assignment.
     when(mockLocalityManager.readContainerLocality()).thenReturn(localityMappings);
 
+    // Mock the task-to-partition assignments.
+    when(mockTaskPartitionAssignmentManager.readTaskPartitionAssignments()).thenReturn(taskToSSPAssignments);
+
     // Mock the container to task assignment.
     when(mockTaskAssignmentManager.readTaskAssignment()).thenReturn(taskAssignment);
     when(mockTaskAssignmentManager.readTaskModes()).thenReturn(Collections.singletonMap(new TaskName("task-0"), TaskMode.Active));
 
-    GrouperMetadataImpl grouperMetadata = JobModelManager.getGrouperMetadata(new MapConfig(), mockLocalityManager, mockTaskAssignmentManager);
+    GrouperMetadataImpl grouperMetadata = JobModelManager.getGrouperMetadata(new MapConfig(), mockLocalityManager, mockTaskAssignmentManager, mockTaskPartitionAssignmentManager);
 
     Mockito.verify(mockLocalityManager).readContainerLocality();
     Mockito.verify(mockTaskAssignmentManager).readTaskAssignment();
 
     Assert.assertEquals(ImmutableMap.of("0", new LocationId("abc-affinity"), "1", new LocationId("ANY_HOST")), grouperMetadata.getProcessorLocality());
     Assert.assertEquals(ImmutableMap.of(new TaskName("task-0"), new LocationId("abc-affinity")), grouperMetadata.getTaskLocality());
+
+    Map<TaskName, List<SystemStreamPartition>> expectedTaskToSSPAssignments = ImmutableMap.of(new TaskName("task-0"), ImmutableList.of(testSystemStreamPartition1),
+                                                                                              new TaskName("task-1"), ImmutableList.of(testSystemStreamPartition1),
+                                                                                              new TaskName("task-2"), ImmutableList.of(testSystemStreamPartition2),
+                                                                                              new TaskName("task-3"), ImmutableList.of(testSystemStreamPartition2));
+    Assert.assertEquals(expectedTaskToSSPAssignments, grouperMetadata.getPreviousTaskToSSPAssignment());
   }
 
   @Test
@@ -210,12 +226,18 @@ public class TestJobModelManager {
     JobModel mockJobModel = Mockito.mock(JobModel.class);
     GrouperMetadataImpl mockGrouperMetadata = Mockito.mock(GrouperMetadataImpl.class);
     TaskAssignmentManager mockTaskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+    TaskPartitionAssignmentManager mockTaskPartitionAssignmentManager = Mockito.mock(TaskPartitionAssignmentManager.class);
+
+    SystemStreamPartition testSystemStreamPartition1 = new SystemStreamPartition(new SystemStream("test-system-0", "test-stream-0"), new Partition(1));
+    SystemStreamPartition testSystemStreamPartition2 = new SystemStreamPartition(new SystemStream("test-system-1", "test-stream-1"), new Partition(2));
+    SystemStreamPartition testSystemStreamPartition3 = new SystemStreamPartition(new SystemStream("test-system-2", "test-stream-2"), new Partition(1));
+    SystemStreamPartition testSystemStreamPartition4 = new SystemStreamPartition(new SystemStream("test-system-3", "test-stream-3"), new Partition(2));
 
     Map<TaskName, TaskModel> taskModelMap = new HashMap<>();
-    taskModelMap.put(new TaskName("task-1"), new TaskModel(new TaskName("task-1"), new HashSet<>(), new Partition(0)));
-    taskModelMap.put(new TaskName("task-2"), new TaskModel(new TaskName("task-2"), new HashSet<>(), new Partition(1)));
-    taskModelMap.put(new TaskName("task-3"), new TaskModel(new TaskName("task-3"), new HashSet<>(), new Partition(2)));
-    taskModelMap.put(new TaskName("task-4"), new TaskModel(new TaskName("task-4"), new HashSet<>(), new Partition(3)));
+    taskModelMap.put(new TaskName("task-1"), new TaskModel(new TaskName("task-1"), ImmutableSet.of(testSystemStreamPartition1), new Partition(0)));
+    taskModelMap.put(new TaskName("task-2"), new TaskModel(new TaskName("task-2"), ImmutableSet.of(testSystemStreamPartition2), new Partition(1)));
+    taskModelMap.put(new TaskName("task-3"), new TaskModel(new TaskName("task-3"), ImmutableSet.of(testSystemStreamPartition3), new Partition(2)));
+    taskModelMap.put(new TaskName("task-4"), new TaskModel(new TaskName("task-4"), ImmutableSet.of(testSystemStreamPartition4), new Partition(3)));
     ContainerModel containerModel = new ContainerModel("test-container-id", taskModelMap);
     Map<String, ContainerModel> containerMapping = ImmutableMap.of("test-container-id", containerModel);
 
@@ -223,7 +245,7 @@ public class TestJobModelManager {
     when(mockGrouperMetadata.getPreviousTaskToProcessorAssignment()).thenReturn(new HashMap<>());
     Mockito.doNothing().when(mockTaskAssignmentManager).writeTaskContainerMapping(Mockito.any(), Mockito.any(), Mockito.any());
 
-    JobModelManager.updateTaskAssignments(mockJobModel, mockTaskAssignmentManager, mockGrouperMetadata);
+    JobModelManager.updateTaskAssignments(mockJobModel, mockTaskAssignmentManager, mockTaskPartitionAssignmentManager, mockGrouperMetadata);
 
     Set<String> taskNames = new HashSet<String>();
     taskNames.add("task-4");
@@ -231,6 +253,12 @@ public class TestJobModelManager {
     taskNames.add("task-3");
     taskNames.add("task-1");
 
+    Set<SystemStreamPartition> systemStreamPartitions = new HashSet<>();
+    systemStreamPartitions.add(new SystemStreamPartition(new SystemStream("test-system-0", "test-stream-0"), new Partition(1)));
+    systemStreamPartitions.add(new SystemStreamPartition(new SystemStream("test-system-1", "test-stream-1"), new Partition(2)));
+    systemStreamPartitions.add(new SystemStreamPartition(new SystemStream("test-system-2", "test-stream-2"), new Partition(1)));
+    systemStreamPartitions.add(new SystemStreamPartition(new SystemStream("test-system-3", "test-stream-3"), new Partition(2)));
+
     // Verifications
     Mockito.verify(mockJobModel, atLeast(1)).getContainers();
     Mockito.verify(mockTaskAssignmentManager).deleteTaskContainerMappings(Mockito.any());
@@ -238,5 +266,14 @@ public class TestJobModelManager {
     Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-2", "test-container-id", TaskMode.Active);
     Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-3", "test-container-id", TaskMode.Active);
     Mockito.verify(mockTaskAssignmentManager).writeTaskContainerMapping("task-4", "test-container-id", TaskMode.Active);
+
+    // Verify that the old, stale partition mappings are purged from the coordinator stream.
+    Mockito.verify(mockTaskPartitionAssignmentManager).delete(systemStreamPartitions);
+
+    // Verify that the new task-to-partition assignments are stored in the coordinator stream.
+    Mockito.verify(mockTaskPartitionAssignmentManager).writeTaskPartitionAssignment(testSystemStreamPartition1, ImmutableList.of("task-1"));
+    Mockito.verify(mockTaskPartitionAssignmentManager).writeTaskPartitionAssignment(testSystemStreamPartition2, ImmutableList.of("task-2"));
+    Mockito.verify(mockTaskPartitionAssignmentManager).writeTaskPartitionAssignment(testSystemStreamPartition3, ImmutableList.of("task-3"));
+    Mockito.verify(mockTaskPartitionAssignmentManager).writeTaskPartitionAssignment(testSystemStreamPartition4, ImmutableList.of("task-4"));
   }
 }
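The verifications above fix the new contract of JobModelManager.updateTaskAssignments: gather every SystemStreamPartition in the job model, purge the stale partition mappings with a single delete, then write one SSP-to-tasks record per partition. A Java sketch of that bookkeeping, offered as a reading aid (the actual implementation lives in JobModelManager.scala and may differ in detail):

    // Invert the job model into SSP-to-tasks form, then persist it.
    Set<SystemStreamPartition> allSsps = new HashSet<>();
    Map<SystemStreamPartition, List<String>> sspToTaskNames = new HashMap<>();
    for (ContainerModel container : jobModel.getContainers().values()) {
      for (TaskModel task : container.getTasks().values()) {
        for (SystemStreamPartition ssp : task.getSystemStreamPartitions()) {
          allSsps.add(ssp);
          sspToTaskNames.computeIfAbsent(ssp, key -> new ArrayList<>())
                        .add(task.getTaskName().getTaskName());
        }
      }
    }
    taskPartitionAssignmentManager.delete(allSsps);  // purge stale mappings first
    sspToTaskNames.forEach(taskPartitionAssignmentManager::writeTaskPartitionAssignment);

The getGrouperMetadata test above asserts the reverse view: the SSP-to-tasks map read from the coordinator stream is inverted back into the task-to-SSPs assignment exposed to the groupers.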
diff --git a/samza-test/src/test/java/org/apache/samza/test/processor/TestStreamApplication.java b/samza-test/src/test/java/org/apache/samza/test/processor/TestStreamApplication.java
index a2170cc..4c36109 100644 (file)
@@ -22,6 +22,7 @@ package org.apache.samza.test.processor;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.Serializable;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import org.apache.samza.application.descriptors.StreamApplicationDescriptor;
 import org.apache.samza.application.StreamApplication;
@@ -44,15 +45,15 @@ import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor;
 public class TestStreamApplication implements StreamApplication {
 
   private final String systemName;
-  private final String inputTopic;
+  private final List<String> inputTopics;
   private final String outputTopic;
   private final String appName;
   private final String processorName;
 
-  private TestStreamApplication(String systemName, String inputTopic, String outputTopic,
+  private TestStreamApplication(String systemName, List<String> inputTopics, String outputTopic,
       String appName, String processorName) {
     this.systemName = systemName;
-    this.inputTopic = inputTopic;
+    this.inputTopics = inputTopics;
     this.outputTopic = outputTopic;
     this.appName = appName;
     this.processorName = processorName;
@@ -61,11 +62,14 @@ public class TestStreamApplication implements StreamApplication {
   @Override
   public void describe(StreamApplicationDescriptor appDescriptor) {
     KafkaSystemDescriptor ksd = new KafkaSystemDescriptor(systemName);
-    KafkaInputDescriptor<String> isd = ksd.getInputDescriptor(inputTopic, new NoOpSerde<>());
     KafkaOutputDescriptor<String> osd = ksd.getOutputDescriptor(outputTopic, new StringSerde());
-    MessageStream<String> inputStream = appDescriptor.getInputStream(isd);
     OutputStream<String> outputStream = appDescriptor.getOutputStream(osd);
-    inputStream.map(new TestMapFunction(appName, processorName)).sendTo(outputStream);
+
+    for (String inputTopic : inputTopics) {
+      KafkaInputDescriptor<String> isd = ksd.getInputDescriptor(inputTopic, new NoOpSerde<>());
+      MessageStream<String> inputStream = appDescriptor.getInputStream(isd);
+      inputStream.map(new TestMapFunction(appName, processorName)).sendTo(outputStream);
+    }
   }
 
   public interface StreamApplicationCallback {
@@ -144,7 +148,7 @@ public class TestStreamApplication implements StreamApplication {
 
   public static StreamApplication getInstance(
       String systemName,
-      String inputTopic,
+      List<String> inputTopics,
       String outputTopic,
       CountDownLatch processedMessageLatch,
       StreamApplicationCallback callback,
@@ -154,7 +158,7 @@ public class TestStreamApplication implements StreamApplication {
     String processorName = config.get(JobConfig.PROCESSOR_ID());
     registerLatches(processedMessageLatch, kafkaEventsConsumedLatch, callback, appName, processorName);
 
-    StreamApplication app = new TestStreamApplication(systemName, inputTopic, outputTopic, appName, processorName);
+    StreamApplication app = new TestStreamApplication(systemName, inputTopics, outputTopic, appName, processorName);
     return app;
   }
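With the List-valued signature, single-topic call sites simply wrap their topic in a singleton list, as the updated TestZkLocalApplicationRunner below does throughout. A representative invocation (topic and latch names illustrative):

    StreamApplication app = TestStreamApplication.getInstance(
        "test-system",                                       // system name
        ImmutableList.of("input-topic-a", "input-topic-b"),  // input topics, now a list
        "output-topic",
        processedMessagesLatch,
        null,                                                // no StreamApplicationCallback
        kafkaEventsConsumedLatch,
        config);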
 
diff --git a/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java b/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
index 298abae..c773c5b 100644 (file)
 
 package org.apache.samza.test.processor;
 
+import com.google.common.base.Joiner;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
@@ -35,12 +37,14 @@ import java.util.Set;
 import java.util.HashSet;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Collectors;
 import kafka.admin.AdminUtils;
 import kafka.admin.RackAwareMode;
 import kafka.utils.TestUtils;
 import org.I0Itec.zkclient.ZkClient;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.samza.Partition;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.ClusterManagerConfig;
@@ -54,6 +58,7 @@ import org.apache.samza.SamzaException;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.coordinator.stream.CoordinatorStreamValueSerde;
 import org.apache.samza.job.ApplicationStatus;
+import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metadatastore.MetadataStore;
@@ -135,7 +140,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     // Set up stream application config map with the given testStreamAppName, testStreamAppId and test kafka system
     // TODO: processorId should typically come up from a processorID generator as processor.id will be deprecated in 0.14.0+
     Map<String, String> configMap =
-        buildStreamApplicationConfigMap(TEST_SYSTEM, inputKafkaTopic, testStreamAppName, testStreamAppId);
+        buildStreamApplicationConfigMap(ImmutableList.of(inputKafkaTopic), testStreamAppName, testStreamAppId);
     configMap.put(JobConfig.PROCESSOR_ID(), PROCESSOR_IDS[0]);
     applicationConfig1 = new ApplicationConfig(new MapConfig(configMap));
     configMap.put(JobConfig.PROCESSOR_ID(), PROCESSOR_IDS[1]);
@@ -195,13 +200,15 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     }
   }
 
-  private Map<String, String> buildStreamApplicationConfigMap(String systemName, String inputTopic,
-      String appName, String appId) {
+  private Map<String, String> buildStreamApplicationConfigMap(List<String> inputTopics, String appName, String appId) {
+    List<String> inputSystemStreams = inputTopics.stream()
+                                                 .map(topic -> String.format("%s.%s", TestZkLocalApplicationRunner.TEST_SYSTEM, topic))
+                                                 .collect(Collectors.toList());
     String coordinatorSystemName = "coordinatorSystem";
     Map<String, String> samzaContainerConfig = ImmutableMap.<String, String>builder()
         .put(ZkConfig.ZK_CONSENSUS_TIMEOUT_MS, BARRIER_TIMEOUT_MS)
-        .put(TaskConfig.INPUT_STREAMS(), inputTopic)
-        .put(JobConfig.JOB_DEFAULT_SYSTEM(), systemName)
+        .put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputSystemStreams))
+        .put(JobConfig.JOB_DEFAULT_SYSTEM(), TestZkLocalApplicationRunner.TEST_SYSTEM)
         .put(TaskConfig.IGNORED_EXCEPTIONS(), "*")
         .put(ZkConfig.ZK_CONNECT, zkConnect())
         .put(JobConfig.SSP_GROUPER_FACTORY(), TEST_SSP_GROUPER_FACTORY)
@@ -210,7 +217,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
         .put(ApplicationConfig.APP_NAME, appName)
         .put(ApplicationConfig.APP_ID, appId)
         .put("app.runner.class", "org.apache.samza.runtime.LocalApplicationRunner")
-        .put(String.format("systems.%s.samza.factory", systemName), TEST_SYSTEM_FACTORY)
+        .put(String.format("systems.%s.samza.factory", TestZkLocalApplicationRunner.TEST_SYSTEM), TEST_SYSTEM_FACTORY)
         .put(JobConfig.JOB_NAME(), appName)
         .put(JobConfig.JOB_ID(), appId)
         .put(TaskConfigJava.TASK_SHUTDOWN_MS, TASK_SHUTDOWN_MS)
@@ -222,9 +229,8 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
         .put("job.coordinator.replication.factor", "1")
         .build();
     Map<String, String> applicationConfig = Maps.newHashMap(samzaContainerConfig);
-
-    applicationConfig.putAll(StandaloneTestUtils.getKafkaSystemConfigs(systemName, bootstrapServers(), zkConnect(), null, StandaloneTestUtils.SerdeAlias.STRING, true));
     applicationConfig.putAll(StandaloneTestUtils.getKafkaSystemConfigs(coordinatorSystemName, bootstrapServers(), zkConnect(), null, StandaloneTestUtils.SerdeAlias.STRING, true));
+    applicationConfig.putAll(StandaloneTestUtils.getKafkaSystemConfigs(TestZkLocalApplicationRunner.TEST_SYSTEM, bootstrapServers(), zkConnect(), null, StandaloneTestUtils.SerdeAlias.STRING, true));
     return applicationConfig;
   }
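For concreteness, the rewritten helper turns its topic list into a comma-joined value for TaskConfig.INPUT_STREAMS (the task.inputs key). A worked example with illustrative topic names:

    // inputTopics = [page-views, user-events] with TEST_SYSTEM = "test-system" yields
    // task.inputs = "test-system.page-views,test-system.user-events"
    List<String> inputSystemStreams = ImmutableList.of("page-views", "user-events").stream()
        .map(topic -> String.format("%s.%s", "test-system", topic))
        .collect(Collectors.toList());
    String taskInputs = Joiner.on(',').join(inputSystemStreams);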
 
@@ -265,7 +271,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     CountDownLatch processedMessagesLatch = new CountDownLatch(NUM_KAFKA_EVENTS);
     Config localTestConfig2 = new MapConfig(applicationConfig2, testConfig);
     ApplicationRunner appRunner2 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputSinglePartitionKafkaTopic, outputSinglePartitionKafkaTopic, processedMessagesLatch,
+        TEST_SYSTEM, ImmutableList.of(inputSinglePartitionKafkaTopic), outputSinglePartitionKafkaTopic, processedMessagesLatch,
         null, null, localTestConfig2), localTestConfig2);
 
     // Callback handler for appRunner1.
@@ -287,7 +293,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     // Set up stream app appRunner1.
     Config localTestConfig1 = new MapConfig(applicationConfig1, testConfig);
     ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputSinglePartitionKafkaTopic, outputSinglePartitionKafkaTopic, null,
+        TEST_SYSTEM, ImmutableList.of(inputSinglePartitionKafkaTopic), outputSinglePartitionKafkaTopic, null,
         callback, kafkaEventsConsumedLatch, localTestConfig1), localTestConfig1);
     executeRun(appRunner1, localTestConfig1);
 
@@ -329,7 +335,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
 
     // Configuration, verification variables
     MapConfig testConfig = new MapConfig(ImmutableMap.of(JobConfig.SSP_GROUPER_FACTORY(),
-        "org.apache.samza.container.grouper.stream.AllSspToSingleTaskGrouperFactory", JobConfig.JOB_DEBOUNCE_TIME_MS(), "10"));
+        "org.apache.samza.container.grouper.stream.AllSspToSingleTaskGrouperFactory", JobConfig.JOB_DEBOUNCE_TIME_MS(), "10", ClusterManagerConfig.HOST_AFFINITY_ENABLED, "false"));
     // Declared as final array to update it from streamApplication callback(Variable should be declared final to access in lambda block).
     final JobModel[] previousJobModel = new JobModel[1];
     final String[] previousJobModelVersion = new String[1];
@@ -347,7 +353,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     CountDownLatch processedMessagesLatch = new CountDownLatch(NUM_KAFKA_EVENTS * 2);
     Config testAppConfig2 = new MapConfig(applicationConfig2, testConfig);
     ApplicationRunner appRunner2 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch, null,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch, null,
         null, testAppConfig2), testAppConfig2);
 
     // Callback handler for appRunner1.
@@ -371,7 +377,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     // Set up stream app appRunner1.
     Config testAppConfig1 = new MapConfig(applicationConfig1, testConfig);
     ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, null, streamApplicationCallback,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, null, streamApplicationCallback,
         kafkaEventsConsumedLatch, testAppConfig1), testAppConfig1);
     executeRun(appRunner1, testAppConfig1);
 
@@ -425,13 +431,13 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     CountDownLatch processedMessagesLatch3 = new CountDownLatch(1);
 
     ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
         applicationConfig1), applicationConfig1);
     ApplicationRunner appRunner2 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch2, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch2, null, kafkaEventsConsumedLatch,
         applicationConfig2), applicationConfig2);
     ApplicationRunner appRunner3 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch3, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch3, null, kafkaEventsConsumedLatch,
         applicationConfig3), applicationConfig3);
 
     executeRun(appRunner1, applicationConfig1);
@@ -492,10 +498,10 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     CountDownLatch processedMessagesLatch2 = new CountDownLatch(1);
 
     ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
         applicationConfig1), applicationConfig1);
     ApplicationRunner appRunner2 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch2, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch2, null, kafkaEventsConsumedLatch,
         applicationConfig2), applicationConfig2);
 
     // Run stream applications.
@@ -510,7 +516,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     publishKafkaEvents(inputKafkaTopic, NUM_KAFKA_EVENTS, 2 * NUM_KAFKA_EVENTS, PROCESSOR_IDS[2]);
     kafkaEventsConsumedLatch = new CountDownLatch(NUM_KAFKA_EVENTS);
     ApplicationRunner appRunner3 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, null, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, null, null, kafkaEventsConsumedLatch,
         applicationConfig2), applicationConfig2);
     // Fail when the duplicate processor joins.
     expectedException.expect(SamzaException.class);
@@ -531,7 +537,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     publishKafkaEvents(inputKafkaTopic, 0, NUM_KAFKA_EVENTS, PROCESSOR_IDS[0]);
 
     Map<String, String> configMap = buildStreamApplicationConfigMap(
-        TEST_SYSTEM, inputKafkaTopic, testStreamAppName, testStreamAppId);
+            ImmutableList.of(inputKafkaTopic), testStreamAppName, testStreamAppId);
 
     configMap.put(JobConfig.PROCESSOR_ID(), PROCESSOR_IDS[0]);
     Config applicationConfig1 = new MapConfig(configMap);
@@ -548,10 +554,10 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     CountDownLatch processedMessagesLatch2 = new CountDownLatch(1);
 
     ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
         applicationConfig1), applicationConfig1);
     ApplicationRunner appRunner2 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch2, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch2, null, kafkaEventsConsumedLatch,
         applicationConfig2), applicationConfig2);
 
     // Run stream application.
@@ -579,7 +585,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     processedMessagesLatch1 = new CountDownLatch(1);
     publishKafkaEvents(inputKafkaTopic, NUM_KAFKA_EVENTS, 2 * NUM_KAFKA_EVENTS, PROCESSOR_IDS[0]);
     ApplicationRunner appRunner3 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
         applicationConfig1), applicationConfig1);
     executeRun(appRunner3, applicationConfig1);
 
@@ -604,7 +610,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
   public void testShouldStopStreamApplicationWhenShutdownTimeOutIsLessThanContainerShutdownTime() throws Exception {
     publishKafkaEvents(inputKafkaTopic, 0, NUM_KAFKA_EVENTS, PROCESSOR_IDS[0]);
 
-    Map<String, String> configMap = buildStreamApplicationConfigMap(TEST_SYSTEM, inputKafkaTopic, testStreamAppName, testStreamAppId);
+    Map<String, String> configMap = buildStreamApplicationConfigMap(ImmutableList.of(inputKafkaTopic), testStreamAppName, testStreamAppId);
     configMap.put(TaskConfig.SHUTDOWN_MS(), "0");
 
     configMap.put(JobConfig.PROCESSOR_ID(), PROCESSOR_IDS[0]);
@@ -619,10 +625,10 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     CountDownLatch processedMessagesLatch2 = new CountDownLatch(1);
 
     ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch,
         applicationConfig1), applicationConfig1);
     ApplicationRunner appRunner2 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch2, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch2, null, kafkaEventsConsumedLatch,
         applicationConfig2), applicationConfig2);
 
     executeRun(appRunner1, applicationConfig1);
@@ -644,7 +650,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     CountDownLatch processedMessagesLatch3 = new CountDownLatch(1);
 
     ApplicationRunner appRunner3 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-        TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch3, null, kafkaEventsConsumedLatch,
+        TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch3, null, kafkaEventsConsumedLatch,
         applicationConfig3), applicationConfig3);
     executeRun(appRunner3, applicationConfig3);
 
@@ -681,7 +687,7 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     CountDownLatch processedMessagesLatch1 = new CountDownLatch(1);
 
     ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
-            TEST_SYSTEM, inputKafkaTopic, outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch1,
+            TEST_SYSTEM, ImmutableList.of(inputKafkaTopic), outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch1,
             applicationConfig1), applicationConfig1);
 
     executeRun(appRunner1, applicationConfig1);
@@ -735,6 +741,209 @@ public class TestZkLocalApplicationRunner extends StandaloneIntegrationTestHarne
     return new MapConfig(configMap);
   }
 
+  /**
+   * A. Create an input kafka topic with the partition count set to 32.
+   * B. Create and launch a stateful Samza application which consumes events from the input kafka topic.
+   * C. Validate that the {@link JobModel} contains 32 {@link SystemStreamPartition}s.
+   * D. Increase the partition count of the input kafka topic to 64.
+   * E. Validate that the new {@link JobModel} contains 64 {@link SystemStreamPartition}s and that the input
+   * SystemStreamPartitions are mapped to the correct tasks.
+   */
+  @Test
+  public void testStatefulSamzaApplicationShouldRedistributeInputPartitionsToCorrectTasksWhenAInputStreamIsExpanded() throws Exception {
+    // Setup input topics.
+    String statefulInputKafkaTopic = String.format("test-input-topic-%s", UUID.randomUUID().toString());
+    TestUtils.createTopic(zkUtils(), statefulInputKafkaTopic, 32, 1, servers(), new Properties());
+
+    // Generate configuration for the test.
+    Map<String, String> configMap = buildStreamApplicationConfigMap(ImmutableList.of(statefulInputKafkaTopic), testStreamAppName, testStreamAppId);
+    configMap.put(JobConfig.PROCESSOR_ID(), PROCESSOR_IDS[0]);
+    Config applicationConfig1 = new ApplicationConfig(new MapConfig(configMap));
+
+    publishKafkaEvents(statefulInputKafkaTopic, 0, NUM_KAFKA_EVENTS, PROCESSOR_IDS[0]);
+
+    // Create StreamApplication from configuration.
+    CountDownLatch kafkaEventsConsumedLatch1 = new CountDownLatch(NUM_KAFKA_EVENTS);
+    CountDownLatch processedMessagesLatch1 = new CountDownLatch(1);
+
+    ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
+            TEST_SYSTEM, ImmutableList.of(statefulInputKafkaTopic), outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch1,
+            applicationConfig1), applicationConfig1);
+
+    executeRun(appRunner1, applicationConfig1);
+    processedMessagesLatch1.await();
+
+    // Generate the expected task assignments before the input stream expansion.
+    Map<TaskName, Set<SystemStreamPartition>> expectedTaskAssignments = new HashMap<>();
+    for (int partition = 0; partition < 32; ++partition) {
+      TaskName taskName = new TaskName(String.format("Partition %d", partition));
+      SystemStreamPartition systemStreamPartition = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic, new Partition(partition));
+      expectedTaskAssignments.put(taskName, ImmutableSet.of(systemStreamPartition));
+    }
+
+    // Read the latest JobModel for validation.
+    String jobModelVersion = zkUtils.getJobModelVersion();
+    JobModel jobModel = zkUtils.getJobModel(jobModelVersion);
+    Set<SystemStreamPartition> ssps = getSystemStreamPartitions(jobModel);
+
+    // Validate that the input partition count is 32 in the JobModel.
+    Assert.assertEquals(32, ssps.size());
+
+    // Validate that the JobModel has the expected task assignments before the input stream expansion.
+    Map<TaskName, Set<SystemStreamPartition>> actualTaskAssignments = getTaskAssignments(jobModel);
+    Assert.assertEquals(expectedTaskAssignments, actualTaskAssignments);
+
+    // Increase the partition count of the input Kafka topic to 64.
+    AdminUtils.addPartitions(zkUtils(), statefulInputKafkaTopic, 64, "", true, RackAwareMode.Enforced$.MODULE$);
+
+    // Wait for the JobModel version to change due to the increase in the input partition count.
+    long jobModelWaitTimeInMillis = 10;
+    while (Objects.equals(zkUtils.getJobModelVersion(), jobModelVersion)) {
+      LOGGER.info("Waiting for new jobModel to be published");
+      Thread.sleep(jobModelWaitTimeInMillis);
+      jobModelWaitTimeInMillis = jobModelWaitTimeInMillis * 2;
+    }
+
+    // Read the latest JobModel for validation.
+    jobModelVersion = zkUtils.getJobModelVersion();
+    jobModel = zkUtils.getJobModel(jobModelVersion);
+    ssps = getSystemStreamPartitions(jobModel);
+
+    // Validate that the input partition count is 64 in the new JobModel.
+    Assert.assertEquals(64, ssps.size());
+
+    // Build the expected task assignments after the input stream expansion.
+    expectedTaskAssignments = new HashMap<>();
+    for (int partition = 0; partition < 32; ++partition) {
+      TaskName taskName = new TaskName(String.format("Partition %d", partition));
+      SystemStreamPartition ssp = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic, new Partition(partition));
+      SystemStreamPartition expandedSSP = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic, new Partition(partition + 32));
+      expectedTaskAssignments.put(taskName, ImmutableSet.of(ssp, expandedSSP));
+    }
+
+    // Validate that the new JobModel has the expected task assignments.
+    actualTaskAssignments = getTaskAssignments(jobModel);
+    Assert.assertEquals(expectedTaskAssignments, actualTaskAssignments);
+  }
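+
+  // Illustrative sketch (editorial; not part of this change): the assertions above encode
+  // the expansion-aware mapping rule, i.e. after the topic grows from 32 to 64 partitions,
+  // an expanded partition is expected to be routed back to the task that owned
+  // (partition % previousPartitionCount):
+  //
+  //   int previousPartitionCount = 32;
+  //   int expandedPartition = 47;
+  //   int owningPartition = expandedPartition % previousPartitionCount; // 15
+  //   TaskName owner = new TaskName(String.format("Partition %d", owningPartition));
+  //   // owner is "Partition 15", the same task that owned partition 15 before the expansion.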
+
+  /**
+   * A. Create an input Kafka topic T1 with the partition count set to 32.
+   * B. Create an input Kafka topic T2 with the partition count set to 32.
+   * C. Create and launch a stateful Samza application that consumes events from the Kafka topics T1 and T2.
+   * D. Validate that the {@link JobModel} contains 32 {@link SystemStreamPartition}s of T1 and 32 {@link SystemStreamPartition}s of T2.
+   * E. Increase the partition count of the input Kafka topic T1 to 64.
+   * F. Increase the partition count of the input Kafka topic T2 to 64.
+   * G. Validate that the new {@link JobModel} contains 64 {@link SystemStreamPartition}s of each of T1 and T2, and that the input
+   * SystemStreamPartitions are mapped to the correct tasks.
+   */
+  @Test
+  public void testStatefulSamzaApplicationShouldRedistributeInputPartitionsToCorrectTasksWhenMultipleInputStreamsAreExpanded() throws Exception {
+    // Setup the two input kafka topics.
+    String statefulInputKafkaTopic1 = String.format("test-input-topic-%s", UUID.randomUUID().toString());
+    String statefulInputKafkaTopic2 = String.format("test-input-topic-%s", UUID.randomUUID().toString());
+    TestUtils.createTopic(zkUtils(), statefulInputKafkaTopic1, 32, 1, servers(), new Properties());
+    TestUtils.createTopic(zkUtils(), statefulInputKafkaTopic2, 32, 1, servers(), new Properties());
+
+    // Generate configuration for the test.
+    Map<String, String> configMap = buildStreamApplicationConfigMap(ImmutableList.of(statefulInputKafkaTopic1, statefulInputKafkaTopic2),
+                                  testStreamAppName, testStreamAppId);
+    configMap.put(JobConfig.PROCESSOR_ID(), PROCESSOR_IDS[0]);
+    Config applicationConfig1 = new ApplicationConfig(new MapConfig(configMap));
+
+    // Publish events into the input kafka topics.
+    publishKafkaEvents(statefulInputKafkaTopic1, 0, NUM_KAFKA_EVENTS, PROCESSOR_IDS[0]);
+    publishKafkaEvents(statefulInputKafkaTopic2, 0, NUM_KAFKA_EVENTS, PROCESSOR_IDS[0]);
+
+    // Create and launch the StreamApplication from configuration.
+    CountDownLatch kafkaEventsConsumedLatch1 = new CountDownLatch(NUM_KAFKA_EVENTS);
+    CountDownLatch processedMessagesLatch1 = new CountDownLatch(1);
+
+    ApplicationRunner appRunner1 = ApplicationRunners.getApplicationRunner(TestStreamApplication.getInstance(
+        TEST_SYSTEM, ImmutableList.of(statefulInputKafkaTopic1, statefulInputKafkaTopic2), outputKafkaTopic, processedMessagesLatch1, null, kafkaEventsConsumedLatch1,
+        applicationConfig1), applicationConfig1);
+
+    executeRun(appRunner1, applicationConfig1);
+    processedMessagesLatch1.await();
+    kafkaEventsConsumedLatch1.await();
+
+    // Build the expected task assignments before the input stream expansion.
+    Map<TaskName, Set<SystemStreamPartition>> expectedTaskAssignments = new HashMap<>();
+    for (int partition = 0; partition < 32; ++partition) {
+      TaskName taskName = new TaskName(String.format("Partition %d", partition));
+      SystemStreamPartition systemStreamPartition1 = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic1, new Partition(partition));
+      SystemStreamPartition systemStreamPartition2 = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic2, new Partition(partition));
+      expectedTaskAssignments.put(taskName, ImmutableSet.of(systemStreamPartition1, systemStreamPartition2));
+    }
+
+    // Read the latest JobModel for validation.
+    String jobModelVersion = zkUtils.getJobModelVersion();
+    JobModel jobModel = zkUtils.getJobModel(jobModelVersion);
+    Set<SystemStreamPartition> ssps = getSystemStreamPartitions(jobModel);
+
+    // Validate that all 64 input partitions (32 from each topic) are present in the JobModel.
+    Assert.assertEquals(64, ssps.size());
+
+    // Validate that the JobModel has the expected task assignments before the input stream expansion.
+    Map<TaskName, Set<SystemStreamPartition>> actualTaskAssignments = getTaskAssignments(jobModel);
+    Assert.assertEquals(expectedTaskAssignments, actualTaskAssignments);
+
+    // Increase the partition count of the input Kafka topic T1 to 64.
+    AdminUtils.addPartitions(zkUtils(), statefulInputKafkaTopic1, 64, "", true, RackAwareMode.Enforced$.MODULE$);
+
+    // Increase the partition count of the input Kafka topic T2 to 64.
+    AdminUtils.addPartitions(zkUtils(), statefulInputKafkaTopic2, 64, "", true, RackAwareMode.Enforced$.MODULE$);
+
+    // Wait for the JobModel version to change due to the increase in the input partition count.
+    long jobModelWaitTimeInMillis = 10;
+    while (Objects.equals(zkUtils.getJobModelVersion(), jobModelVersion)) {
+      LOGGER.info("Waiting for new jobModel to be published");
+      Thread.sleep(jobModelWaitTimeInMillis);
+      jobModelWaitTimeInMillis = jobModelWaitTimeInMillis * 2;
+    }
+
+    // Read the latest JobModel for validation.
+    jobModelVersion = zkUtils.getJobModelVersion();
+    jobModel = zkUtils.getJobModel(jobModelVersion);
+    ssps = getSystemStreamPartitions(jobModel);
+
+    // Validate that all 128 input partitions (64 from each topic) are present in the new JobModel.
+    Assert.assertEquals(128, ssps.size());
+
+    // Build the expected task assignments after the input stream expansion.
+    expectedTaskAssignments = new HashMap<>();
+    for (int partition = 0; partition < 32; ++partition) {
+      TaskName taskName = new TaskName(String.format("Partition %d", partition));
+      SystemStreamPartition ssp1 = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic1, new Partition(partition));
+      SystemStreamPartition expandedSSP1 = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic1, new Partition(partition + 32));
+      SystemStreamPartition ssp2 = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic2, new Partition(partition));
+      SystemStreamPartition expandedSSP2 = new SystemStreamPartition(TEST_SYSTEM, statefulInputKafkaTopic2, new Partition(partition + 32));
+      expectedTaskAssignments.put(taskName, ImmutableSet.of(ssp1, expandedSSP1, ssp2, expandedSSP2));
+    }
+
+    // Validate that the new JobModel has the expected task assignments.
+    actualTaskAssignments = getTaskAssignments(jobModel);
+    Assert.assertEquals(expectedTaskAssignments, actualTaskAssignments);
+  }
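+
+  // A bounded variant of the JobModel-version poll used in both tests above
+  // (editorial sketch; the helper name and timeout are assumptions, not part of this change):
+  //
+  //   private void awaitNewJobModelVersion(String oldVersion, long maxWaitMillis)
+  //       throws InterruptedException {
+  //     long sleepMillis = 10;
+  //     long deadlineMillis = System.currentTimeMillis() + maxWaitMillis;
+  //     while (Objects.equals(zkUtils.getJobModelVersion(), oldVersion)) {
+  //       if (System.currentTimeMillis() > deadlineMillis) {
+  //         Assert.fail("Timed out waiting for a new JobModel version.");
+  //       }
+  //       Thread.sleep(sleepMillis);
+  //       sleepMillis *= 2; // exponential backoff, matching the inline loops above
+  //     }
+  //   }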
+
+  /**
+   * Computes the task to partition assignments of the given {@link JobModel}.
+   * @param jobModel the job model to compute the task to partition assignments for.
+   * @return the computed task to partition assignments.
+   */
+  private static Map<TaskName, Set<SystemStreamPartition>> getTaskAssignments(JobModel jobModel) {
+    Map<TaskName, Set<SystemStreamPartition>> taskAssignments = new HashMap<>();
+    for (ContainerModel containerModel : jobModel.getContainers().values()) {
+      for (TaskModel taskModel : containerModel.getTasks().values()) {
+        taskAssignments.computeIfAbsent(taskModel.getTaskName(), taskName -> new HashSet<>())
+            .addAll(taskModel.getSystemStreamPartitions());
+      }
+    }
+    return taskAssignments;
+  }
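+
+  // Equivalent Java 8 stream sketch (illustrative only; assumes java.util.stream.Collectors is imported):
+  //
+  //   return jobModel.getContainers().values().stream()
+  //       .flatMap(containerModel -> containerModel.getTasks().values().stream())
+  //       .collect(Collectors.toMap(TaskModel::getTaskName,
+  //           taskModel -> new HashSet<>(taskModel.getSystemStreamPartitions()),
+  //           (left, right) -> { left.addAll(right); return left; }));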
+
   private static Set<SystemStreamPartition> getSystemStreamPartitions(JobModel jobModel) {
     Set<SystemStreamPartition> ssps = new HashSet<>();
     jobModel.getContainers().forEach((containerName, containerModel) -> {