SAMZA-1334: fix pre-condition for ContainerAllocator to work properly
authorYi Pan (Data Infrastructure) <nickpan47@gmail.com>
Tue, 20 Jun 2017 15:35:39 +0000 (08:35 -0700)
committerYi Pan (Data Infrastructure) <nickpan47@gmail.com>
Tue, 20 Jun 2017 15:35:39 +0000 (08:35 -0700)
We have observed issues when the LocalityManager reports the container locality mapping while the host-affinity is disabled in ContainerAllocator, in which the ContainerAllocator failed to release extra containers.

Hence, fix is in the form of make sure the pre-condition is met for the ContainerAllocator w/o host-affinity: the localityMap from the JobModel should contain no preferred host info.

Author: Yi Pan (Data Infrastructure) <nickpan47@gmail.com>

Reviewers: Jagadish <jagadish1989@gmail.com>

Closes #228 from nickpan47/SAMZA-1334 and squashes the following commits:

ad3320f [Yi Pan (Data Infrastructure)] SAMZA-1334: fix the pre-conditions for ContainerAllocator to work properly. Make sure JobModel is generated w/o LocalityManager if host-affinity is disabled
f76fff1 [Yi Pan (Data Infrastructure)] WIP: SAMZA-1334 fix

samza-core/src/main/java/org/apache/samza/job/model/JobModel.java
samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerAllocator.java
samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java
samza-core/src/test/java/org/apache/samza/clustermanager/TestHostAwareContainerAllocator.java
samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/testUtils/MockHttpServer.java [moved from samza-core/src/test/java/org/apache/samza/clustermanager/MockHttpServer.java with 97% similarity]

index dbb3867..1115faf 100644 (file)
@@ -44,7 +44,7 @@ public class JobModel {
   private final Map<String, ContainerModel> containers;
 
   private final LocalityManager localityManager;
-  private Map<String, String> localityMappings = new HashMap<String, String>();
+  private final Map<String, String> localityMappings;
 
   public int maxChangeLogStreamPartitions;
 
@@ -57,6 +57,8 @@ public class JobModel {
     this.containers = Collections.unmodifiableMap(containers);
     this.localityManager = localityManager;
 
+    // initialize container localityMappings
+    this.localityMappings = new HashMap<>();
     if (localityManager == null) {
       for (String containerId : containers.keySet()) {
         localityMappings.put(containerId, null);
index 353e297..6319173 100644 (file)
@@ -23,6 +23,7 @@ package org.apache.samza.coordinator
 import java.util
 import java.util.concurrent.atomic.AtomicReference
 
+import org.apache.samza.config.ClusterManagerConfig
 import org.apache.samza.config.JobConfig.Config2Job
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
@@ -98,6 +99,7 @@ object JobModelManager extends Logging {
     info("Got config: %s" format config)
     val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, SOURCE)
     changelogManager.start()
+
     val localityManager = new LocalityManager(coordinatorSystemProducer, coordinatorSystemConsumer)
     // We don't need to start() localityManager as they share the same instances with checkpoint and changelog managers.
     // TODO: This code will go away with refactoring - SAMZA-678
@@ -228,6 +230,8 @@ object JobModelManager extends Logging {
     val groups = grouper.group(allSystemStreamPartitions.asJava)
     info("SystemStreamPartitionGrouper %s has grouped the SystemStreamPartitions into %d tasks with the following taskNames: %s" format(grouper, groups.size(), groups.keySet()))
 
+    val isHostAffinityEnabled = new ClusterManagerConfig(config).getHostAffinityEnabled
+
     // If no mappings are present(first time the job is running) we return -1, this will allow 0 to be the first change
     // mapping.
     var maxChangelogPartitionId = changeLogPartitionMapping.asScala.values.map(_.toInt).toList.sorted.lastOption.getOrElse(-1)
@@ -256,13 +260,17 @@ object JobModelManager extends Logging {
     val containerGrouper = containerGrouperFactory.build(config)
     val containerModels = {
       containerGrouper match {
-        case grouper: BalancingTaskNameGrouper => grouper.balance(taskModels.asJava, localityManager)
+        case grouper: BalancingTaskNameGrouper if isHostAffinityEnabled => grouper.balance(taskModels.asJava, localityManager)
         case _ => containerGrouper.group(taskModels.asJava, containerIds)
       }
     }
     val containerMap = containerModels.asScala.map { case (containerModel) => containerModel.getProcessorId -> containerModel }.toMap
 
-    new JobModel(config, containerMap.asJava, localityManager)
+    if (isHostAffinityEnabled) {
+      new JobModel(config, containerMap.asJava, localityManager)
+    } else {
+      new JobModel(config, containerMap.asJava)
+    }
   }
 
   /**
index 989b82a..1e9d372 100644 (file)
@@ -21,12 +21,9 @@ package org.apache.samza.clustermanager;
 
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.TaskName;
 import org.apache.samza.coordinator.JobModelManager;
-import org.apache.samza.coordinator.server.HttpServer;
-import org.apache.samza.job.model.ContainerModel;
-import org.apache.samza.job.model.JobModel;
-import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.coordinator.JobModelManagerTestUtil;
+import org.apache.samza.testUtils.MockHttpServer;
 import org.eclipse.jetty.servlet.DefaultServlet;
 import org.eclipse.jetty.servlet.ServletHolder;
 import org.junit.After;
@@ -47,8 +44,9 @@ public class TestContainerAllocator {
   private final MockClusterResourceManagerCallback callback = new MockClusterResourceManagerCallback();
   private final MockClusterResourceManager manager = new MockClusterResourceManager(callback);
   private final Config config = getConfig();
-  private final JobModelManager reader = getJobModelReader(1);
-  private final SamzaApplicationState state = new SamzaApplicationState(reader);
+  private final JobModelManager jobModelManager = JobModelManagerTestUtil.getJobModelManager(config, 1,
+      new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class)));
+  private final SamzaApplicationState state = new SamzaApplicationState(jobModelManager);
   private ContainerAllocator containerAllocator;
   private MockContainerRequestState requestState;
   private Thread allocatorThread;
@@ -67,7 +65,7 @@ public class TestContainerAllocator {
 
   @After
   public void teardown() throws Exception {
-    reader.stop();
+    jobModelManager.stop();
     containerAllocator.stop();
   }
 
@@ -93,19 +91,6 @@ public class TestContainerAllocator {
     return new MapConfig(map);
   }
 
-  private static JobModelManager getJobModelReader(int containerCount) {
-    //Ideally, the JobModelReader should be constructed independent of HttpServer.
-    //That way it becomes easier to mock objects. Save it for later.
-
-    HttpServer server = new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class));
-    Map<String, ContainerModel> containers = new java.util.HashMap<>();
-    for (int i = 0; i < containerCount; i++) {
-      ContainerModel container = new ContainerModel(String.valueOf(i), i, new HashMap<TaskName, TaskModel>());
-      containers.put(String.valueOf(i), container);
-    }
-    JobModel jobModel = new JobModel(getConfig(), containers);
-    return new JobModelManager(jobModel, server, null);
-  }
 
 
   /**
@@ -132,10 +117,10 @@ public class TestContainerAllocator {
   public void testRequestContainers() throws Exception {
     Map<String, String> containersToHostMapping = new HashMap<String, String>() {
       {
-        put("0", "abc");
-        put("1", "def");
+        put("0", null);
+        put("1", null);
         put("2", null);
-        put("3", "abc");
+        put("3", null);
       }
     };
 
index 660012e..8199559 100644 (file)
@@ -23,14 +23,12 @@ import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.LocalityManager;
-import org.apache.samza.container.TaskName;
 import org.apache.samza.coordinator.JobModelManager;
+import org.apache.samza.coordinator.JobModelManagerTestUtil;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
-import org.apache.samza.job.model.ContainerModel;
-import org.apache.samza.job.model.JobModel;
-import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.testUtils.MockHttpServer;
 import org.eclipse.jetty.servlet.DefaultServlet;
 import org.eclipse.jetty.servlet.ServletHolder;
 import org.junit.After;
@@ -79,7 +77,7 @@ public class TestContainerProcessManager {
   private Config getConfigWithHostAffinity() {
     Map<String, String> map = new HashMap<>();
     map.putAll(config);
-    map.put("yarn.samza.host-affinity.enabled", "true");
+    map.put("job.host-affinity.enabled", "true");
     return new MapConfig(map);
   }
 
@@ -87,30 +85,24 @@ public class TestContainerProcessManager {
 
   private SamzaApplicationState state = null;
 
-  private JobModelManager getCoordinator(int containerCount) {
-    Map<String, ContainerModel> containers = new java.util.HashMap<>();
-    for (int i = 0; i < containerCount; i++) {
-      ContainerModel container = new ContainerModel(String.valueOf(i), i, new HashMap<TaskName, TaskModel>());
-      containers.put(String.valueOf(i), container);
-    }
+  private JobModelManager getJobModelManagerWithHostAffinity(int containerCount) {
     Map<String, Map<String, String>> localityMap = new HashMap<>();
     localityMap.put("0", new HashMap<String, String>() { {
         put(SetContainerHostMapping.HOST_KEY, "abc");
-      }
-    });
+      } });
     LocalityManager mockLocalityManager = mock(LocalityManager.class);
     when(mockLocalityManager.readContainerLocality()).thenReturn(localityMap);
 
-    JobModel jobModel = new JobModel(getConfig(), containers, mockLocalityManager);
-    JobModelManager.jobModelRef().getAndSet(jobModel);
+    return JobModelManagerTestUtil.getJobModelManagerWithLocalityManager(getConfig(), containerCount, mockLocalityManager, this.server);
+  }
 
-    return new JobModelManager(jobModel, this.server, null);
+  private JobModelManager getJobModelManagerWithoutHostAffinity(int containerCount) {
+    return JobModelManagerTestUtil.getJobModelManager(getConfig(), containerCount, this.server);
   }
 
   @Before
   public void setup() throws Exception {
     server = new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class));
-    state = new SamzaApplicationState(getCoordinator(1));
   }
 
   private Field getPrivateFieldFromTaskManager(String fieldName, ContainerProcessManager object) throws Exception {
@@ -127,6 +119,7 @@ public class TestContainerProcessManager {
     conf.put("yarn.container.memory.mb", "500");
     conf.put("yarn.container.cpu.cores", "5");
 
+    state = new SamzaApplicationState(getJobModelManagerWithoutHostAffinity(1));
     ContainerProcessManager taskManager = new ContainerProcessManager(
         new MapConfig(conf),
         state,
@@ -146,6 +139,7 @@ public class TestContainerProcessManager {
     conf.put("yarn.container.memory.mb", "500");
     conf.put("yarn.container.cpu.cores", "5");
 
+    state = new SamzaApplicationState(getJobModelManagerWithHostAffinity(1));
     taskManager = new ContainerProcessManager(
         new MapConfig(conf),
         state,
@@ -164,6 +158,8 @@ public class TestContainerProcessManager {
   @Test
   public void testOnInit() throws Exception {
     Config conf = getConfig();
+    state = new SamzaApplicationState(getJobModelManagerWithoutHostAffinity(1));
+
     ContainerProcessManager taskManager = new ContainerProcessManager(
         new MapConfig(conf),
         state,
@@ -200,6 +196,8 @@ public class TestContainerProcessManager {
   @Test
   public void testOnShutdown() throws Exception {
     Config conf = getConfig();
+    state = new SamzaApplicationState(getJobModelManagerWithoutHostAffinity(1));
+
     ContainerProcessManager taskManager =  new ContainerProcessManager(
         new MapConfig(conf),
         state,
@@ -226,6 +224,8 @@ public class TestContainerProcessManager {
   @Test
   public void testTaskManagerShouldStopWhenContainersFinish() {
     Config conf = getConfig();
+    state = new SamzaApplicationState(getJobModelManagerWithoutHostAffinity(1));
+
     ContainerProcessManager taskManager =  new ContainerProcessManager(
         new MapConfig(conf),
         state,
@@ -251,6 +251,7 @@ public class TestContainerProcessManager {
   @Test
   public void testNewContainerRequestedOnFailureWithUnknownCode() throws Exception {
     Config conf = getConfig();
+    state = new SamzaApplicationState(getJobModelManagerWithoutHostAffinity(1));
 
     ContainerProcessManager taskManager = new ContainerProcessManager(
         new MapConfig(conf),
@@ -330,6 +331,8 @@ public class TestContainerProcessManager {
     config.putAll(getConfig());
     config.remove("yarn.container.retry.count");
 
+    state = new SamzaApplicationState(getJobModelManagerWithoutHostAffinity(1));
+
     ContainerProcessManager taskManager = new ContainerProcessManager(
         new MapConfig(conf),
         state,
@@ -393,8 +396,11 @@ public class TestContainerProcessManager {
 
   @Test
   public void testAppMasterWithFwk() {
+    Config conf = getConfig();
+    state = new SamzaApplicationState(getJobModelManagerWithoutHostAffinity(1));
+
     ContainerProcessManager taskManager = new ContainerProcessManager(
-        new MapConfig(config),
+        new MapConfig(conf),
         state,
         new MetricsRegistryMap(),
         manager
index 83d31e2..32ec2d2 100644 (file)
@@ -26,12 +26,11 @@ import java.util.concurrent.atomic.AtomicInteger;
 
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.TaskName;
+import org.apache.samza.container.LocalityManager;
 import org.apache.samza.coordinator.JobModelManager;
-import org.apache.samza.coordinator.server.HttpServer;
-import org.apache.samza.job.model.ContainerModel;
-import org.apache.samza.job.model.JobModel;
-import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.coordinator.JobModelManagerTestUtil;
+import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
+import org.apache.samza.testUtils.MockHttpServer;
 import org.eclipse.jetty.servlet.DefaultServlet;
 import org.eclipse.jetty.servlet.ServletHolder;
 import org.junit.After;
@@ -42,13 +41,28 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class TestHostAwareContainerAllocator {
 
   private final MockClusterResourceManagerCallback callback = new MockClusterResourceManagerCallback();
   private final MockClusterResourceManager manager = new MockClusterResourceManager(callback);
   private final Config config = getConfig();
-  private final JobModelManager reader = getJobModelManager(1);
+  private final JobModelManager reader = initializeJobModelManager(config, 1);
+
+  private JobModelManager initializeJobModelManager(Config config, int containerCount) {
+    Map<String, Map<String, String>> localityMap = new HashMap<>();
+    localityMap.put("0", new HashMap<String, String>() { {
+        put(SetContainerHostMapping.HOST_KEY, "abc");
+      } });
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+    when(mockLocalityManager.readContainerLocality()).thenReturn(localityMap);
+
+    return JobModelManagerTestUtil.getJobModelManagerWithLocalityManager(getConfig(), containerCount, mockLocalityManager,
+        new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class)));
+  }
+
   private final SamzaApplicationState state = new SamzaApplicationState(reader);
   private HostAwareContainerAllocator containerAllocator;
   private final int timeoutMillis = 1000;
@@ -334,19 +348,4 @@ public class TestHostAwareContainerAllocator {
     return new MapConfig(map);
   }
 
-  private static JobModelManager getJobModelManager(int containerCount) {
-    //Ideally, the JobModelReader should be constructed independent of HttpServer.
-    //That way it becomes easier to mock objects. Save it for later.
-
-    HttpServer server = new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class));
-    Map<String, ContainerModel> containers = new java.util.HashMap<>();
-    for (int i = 0; i < containerCount; i++) {
-      ContainerModel container = new ContainerModel(String.valueOf(i), i, new HashMap<TaskName, TaskModel>());
-      containers.put(String.valueOf(i), container);
-    }
-    JobModel jobModel = new JobModel(getConfig(), containers);
-    return new JobModelManager(jobModel, server, null);
-  }
-
-
 }
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
new file mode 100644 (file)
index 0000000..b7514c4
--- /dev/null
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.coordinator;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.coordinator.server.HttpServer;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.JobModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.system.StreamMetadataCache;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Utils to create instances of {@link JobModelManager} in unit tests
+ */
+public class JobModelManagerTestUtil {
+
+  public static JobModelManager getJobModelManager(Config config, int containerCount, HttpServer server) {
+    return getJobModelManagerWithLocalityManager(config, containerCount, null, server);
+  }
+
+  public static JobModelManager getJobModelManagerWithLocalityManager(Config config, int containerCount, LocalityManager localityManager, HttpServer server) {
+    Map<String, ContainerModel> containers = new java.util.HashMap<>();
+    for (int i = 0; i < containerCount; i++) {
+      ContainerModel container = new ContainerModel(String.valueOf(i), i, new HashMap<TaskName, TaskModel>());
+      containers.put(String.valueOf(i), container);
+    }
+    JobModel jobModel = new JobModel(config, containers, localityManager);
+    return new JobModelManager(jobModel, server, null);
+  }
+
+  public static JobModelManager getJobModelManagerUsingReadModel(Config config, int containerCount, StreamMetadataCache streamMetadataCache,
+    LocalityManager locManager, HttpServer server) {
+    List<String> containerIds = new ArrayList<>();
+    for (int i = 0; i < containerCount; i++) {
+      containerIds.add(String.valueOf(i));
+    }
+    JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), locManager, streamMetadataCache, containerIds);
+    return new JobModelManager(jobModel, server, null);
+  }
+
+
+}
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
new file mode 100644 (file)
index 0000000..1d6fc65
--- /dev/null
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.coordinator;
+
+import org.apache.samza.Partition;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.grouper.task.TaskAssignmentManager;
+import org.apache.samza.coordinator.server.HttpServer;
+import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.testUtils.MockHttpServer;
+import org.eclipse.jetty.servlet.DefaultServlet;
+import org.eclipse.jetty.servlet.ServletHolder;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Collections;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.argThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.mockito.ArgumentMatcher;
+import scala.collection.JavaConversions;
+
+/**
+ * Unit tests for {@link JobModelManager}
+ */
+public class TestJobModelManager {
+  private final TaskAssignmentManager mockTaskManager = mock(TaskAssignmentManager.class);
+  private final LocalityManager mockLocalityManager = mock(LocalityManager.class);
+  private final Map<String, Map<String, String>> localityMappings = new HashMap<>();
+  private final HttpServer server = new MockHttpServer("/", 7777, null, new ServletHolder(DefaultServlet.class));
+  private final SystemStream inputStream = new SystemStream("test-system", "test-stream");
+  private final SystemStreamMetadata.SystemStreamPartitionMetadata mockSspMetadata = mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class);
+  private final Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> mockSspMetadataMap = Collections.singletonMap(new Partition(0), mockSspMetadata);
+  private final SystemStreamMetadata mockStreamMetadata = mock(SystemStreamMetadata.class);
+  private final scala.collection.immutable.Map<SystemStream, SystemStreamMetadata> mockStreamMetadataMap = new scala.collection.immutable.Map.Map1(inputStream, mockStreamMetadata);
+  private final StreamMetadataCache mockStreamMetadataCache = mock(StreamMetadataCache.class);
+  private final scala.collection.immutable.Set<SystemStream> inputStreamSet = JavaConversions.asScalaSet(Collections.singleton(inputStream)).toSet();
+
+  private JobModelManager jobModelManager;
+
+  @Before
+  public void setup() {
+    when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
+    when(mockStreamMetadataCache.getStreamMetadata(argThat(new ArgumentMatcher<scala.collection.immutable.Set<SystemStream>>() {
+      @Override
+      public boolean matches(Object argument) {
+        scala.collection.immutable.Set<SystemStream> set = (scala.collection.immutable.Set<SystemStream>) argument;
+        return set.equals(inputStreamSet);
+      }
+    }), anyBoolean())).thenReturn(mockStreamMetadataMap);
+    when(mockStreamMetadata.getSystemStreamPartitionMetadata()).thenReturn(mockSspMetadataMap);
+    when(mockLocalityManager.getTaskAssignmentManager()).thenReturn(mockTaskManager);
+    when(mockTaskManager.readTaskAssignment()).thenReturn(Collections.EMPTY_MAP);
+  }
+
+  @Test
+  public void testLocalityMapWithHostAffinity() {
+    Config config = new MapConfig(new HashMap<String, String>() {
+      {
+        put("yarn.container.count", "1");
+        put("systems.test-system.samza.factory", "org.apache.samza.job.yarn.MockSystemFactory");
+        put("yarn.container.memory.mb", "512");
+        put("yarn.package.path", "/foo");
+        put("task.inputs", "test-system.test-stream");
+        put("systems.test-system.samza.key.serde", "org.apache.samza.serializers.JsonSerde");
+        put("systems.test-system.samza.msg.serde", "org.apache.samza.serializers.JsonSerde");
+        put("yarn.container.retry.count", "1");
+        put("yarn.container.retry.window.ms", "1999999999");
+        put("yarn.allocator.sleep.ms", "10");
+        put("job.host-affinity.enabled", "true");
+      }
+    });
+
+    this.localityMappings.put("0", new HashMap<String, String>() { {
+        put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
+      } });
+    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server);
+
+    assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", "abc-affinity"); } });
+  }
+
+  @Test
+  public void testLocalityMapWithoutHostAffinity() {
+    Config config = new MapConfig(new HashMap<String, String>() {
+      {
+        put("yarn.container.count", "1");
+        put("systems.test-system.samza.factory", "org.apache.samza.job.yarn.MockSystemFactory");
+        put("yarn.container.memory.mb", "512");
+        put("yarn.package.path", "/foo");
+        put("task.inputs", "test-system.test-stream");
+        put("systems.test-system.samza.key.serde", "org.apache.samza.serializers.JsonSerde");
+        put("systems.test-system.samza.msg.serde", "org.apache.samza.serializers.JsonSerde");
+        put("yarn.container.retry.count", "1");
+        put("yarn.container.retry.window.ms", "1999999999");
+        put("yarn.allocator.sleep.ms", "10");
+        put("job.host-affinity.enabled", "false");
+      }
+    });
+
+    this.localityMappings.put("0", new HashMap<String, String>() { {
+        put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
+      } });
+    this.jobModelManager = JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, mockStreamMetadataCache, mockLocalityManager, server);
+
+    assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new HashMap<String, String>() { { this.put("0", null); } });
+  }
+}
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.samza.clustermanager;
+package org.apache.samza.testUtils;
 
 import org.apache.samza.coordinator.server.HttpServer;
 import org.eclipse.jetty.servlet.ServletHolder;