SAMZA-1933: Fix NPE in LocalityManager.
authorShanthoosh Venkataraman <spvenkat@usc.edu>
Fri, 5 Oct 2018 17:58:39 +0000 (10:58 -0700)
committerPrateek Maheshwari <pmaheshwari@apache.org>
Fri, 5 Oct 2018 17:58:39 +0000 (10:58 -0700)
Author: Shanthoosh Venkataraman <spvenkat@usc.edu>

Reviewers: Prateek Maheshwari <pmaheshwari@apache.org>

Closes #684 from shanthoosh/fix_NPE_in_task_assignment_manager

samza-core/src/main/java/org/apache/samza/container/LocalityManager.java
samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCountFactory.java
samza-core/src/main/java/org/apache/samza/container/grouper/task/TaskAssignmentManager.java
samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
samza-rest/src/main/java/org/apache/samza/rest/proxy/task/SamzaTaskProxy.java
samza-yarn/src/test/java/org/apache/samza/webapp/TestApplicationMasterRestClient.java

index 20e86d9..63483b7 100644 (file)
@@ -25,7 +25,6 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
-import org.apache.samza.container.grouper.task.TaskAssignmentManager;
 import org.apache.samza.coordinator.stream.CoordinatorStreamKeySerde;
 import org.apache.samza.coordinator.stream.CoordinatorStreamValueSerde;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
@@ -48,7 +47,6 @@ public class LocalityManager {
   private final Serde<String> keySerde;
   private final Serde<String> valueSerde;
   private final MetadataStore metadataStore;
-  private final TaskAssignmentManager taskAssignmentManager;
 
   /**
    * Builds the LocalityManager based upon {@link Config} and {@link MetricsRegistry}.
@@ -81,7 +79,6 @@ public class LocalityManager {
     this.metadataStore.init();
     this.keySerde = keySerde;
     this.valueSerde = valueSerde;
-    this.taskAssignmentManager = new TaskAssignmentManager(config, metricsRegistry, keySerde, valueSerde);
   }
 
   /**
@@ -128,10 +125,5 @@ public class LocalityManager {
 
   public void close() {
     metadataStore.close();
-    taskAssignmentManager.close();
-  }
-
-  public TaskAssignmentManager getTaskAssignmentManager() {
-    return taskAssignmentManager;
   }
 }
index b4d6c90..759f82e 100644 (file)
@@ -29,10 +29,13 @@ import java.util.Map;
 import java.util.Set;
 
 import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
 import org.apache.samza.container.LocalityManager;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.metrics.MetricsRegistryMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -50,10 +53,12 @@ import org.slf4j.LoggerFactory;
 public class GroupByContainerCount implements BalancingTaskNameGrouper {
   private static final Logger log = LoggerFactory.getLogger(GroupByContainerCount.class);
   private final int containerCount;
+  private final Config config;
 
-  public GroupByContainerCount(int containerCount) {
+  public GroupByContainerCount(Config config) {
+    this.containerCount = new JobConfig(config).getContainerCount();
+    this.config = config;
     if (containerCount <= 0) throw new IllegalArgumentException("Must have at least one container");
-    this.containerCount = containerCount;
   }
 
   @Override
@@ -94,51 +99,56 @@ public class GroupByContainerCount implements BalancingTaskNameGrouper {
       return group(tasks);
     }
 
-    TaskAssignmentManager taskAssignmentManager = localityManager.getTaskAssignmentManager();
-    List<TaskGroup> containers = getPreviousContainers(taskAssignmentManager, tasks.size());
-    if (containers == null || containers.size() == 1 || containerCount == 1) {
-      log.info("Balancing does not apply. Invoking grouper.");
-      Set<ContainerModel> models = group(tasks);
-      saveTaskAssignments(models, taskAssignmentManager);
-      return models;
-    }
+    TaskAssignmentManager taskAssignmentManager =  new TaskAssignmentManager(config, new MetricsRegistryMap());
+    taskAssignmentManager.init();
+    try {
+      List<TaskGroup> containers = getPreviousContainers(taskAssignmentManager, tasks.size());
+      if (containers == null || containers.size() == 1 || containerCount == 1) {
+        log.info("Balancing does not apply. Invoking grouper.");
+        Set<ContainerModel> models = group(tasks);
+        saveTaskAssignments(models, taskAssignmentManager);
+        return models;
+      }
 
-    int prevContainerCount = containers.size();
-    int containerDelta = containerCount - prevContainerCount;
-    if (containerDelta == 0) {
-      log.info("Container count has not changed. Reusing previous container models.");
-      return buildContainerModels(tasks, containers);
-    }
-    log.info("Container count changed from {} to {}. Balancing tasks.", prevContainerCount, containerCount);
+      int prevContainerCount = containers.size();
+      int containerDelta = containerCount - prevContainerCount;
+      if (containerDelta == 0) {
+        log.info("Container count has not changed. Reusing previous container models.");
+        return buildContainerModels(tasks, containers);
+      }
+      log.info("Container count changed from {} to {}. Balancing tasks.", prevContainerCount, containerCount);
 
-    // Calculate the expected task count per container
-    int[] expectedTaskCountPerContainer = calculateTaskCountPerContainer(tasks.size(), prevContainerCount, containerCount);
+      // Calculate the expected task count per container
+      int[] expectedTaskCountPerContainer = calculateTaskCountPerContainer(tasks.size(), prevContainerCount, containerCount);
 
-    // Collect excess tasks from over-assigned containers
-    List<String> taskNamesToReassign = new LinkedList<>();
-    for (int i = 0; i < prevContainerCount; i++) {
-      TaskGroup taskGroup = containers.get(i);
-      while (taskGroup.size() > expectedTaskCountPerContainer[i]) {
-        taskNamesToReassign.add(taskGroup.removeTask());
+      // Collect excess tasks from over-assigned containers
+      List<String> taskNamesToReassign = new LinkedList<>();
+      for (int i = 0; i < prevContainerCount; i++) {
+        TaskGroup taskGroup = containers.get(i);
+        while (taskGroup.size() > expectedTaskCountPerContainer[i]) {
+          taskNamesToReassign.add(taskGroup.removeTask());
+        }
       }
-    }
 
-    // Assign tasks to the under-assigned containers
-    if (containerDelta > 0) {
-      List<TaskGroup> newContainers = createContainers(prevContainerCount, containerCount);
-      containers.addAll(newContainers);
-    } else {
-      containers = containers.subList(0, containerCount);
-    }
-    assignTasksToContainers(expectedTaskCountPerContainer, taskNamesToReassign, containers);
+      // Assign tasks to the under-assigned containers
+      if (containerDelta > 0) {
+        List<TaskGroup> newContainers = createContainers(prevContainerCount, containerCount);
+        containers.addAll(newContainers);
+      } else {
+        containers = containers.subList(0, containerCount);
+      }
+      assignTasksToContainers(expectedTaskCountPerContainer, taskNamesToReassign, containers);
 
-    // Transform containers to containerModel
-    Set<ContainerModel> models = buildContainerModels(tasks, containers);
+      // Transform containers to containerModel
+      Set<ContainerModel> models = buildContainerModels(tasks, containers);
 
-    // Save the results
-    saveTaskAssignments(models, taskAssignmentManager);
+      // Save the results
+      saveTaskAssignments(models, taskAssignmentManager);
 
-    return models;
+      return models;
+    } finally {
+      taskAssignmentManager.close();
+    }
   }
 
   /**
index f0e9686..06aba33 100644 (file)
@@ -19,8 +19,6 @@
 package org.apache.samza.container.grouper.task;
 
 import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-
 
 /**
  * Factory to build the GroupByContainerCount class.
@@ -28,6 +26,6 @@ import org.apache.samza.config.JobConfig;
 public class GroupByContainerCountFactory implements TaskNameGrouperFactory {
   @Override
   public TaskNameGrouper build(Config config) {
-    return new GroupByContainerCount(new JobConfig(config).getContainerCount());
+    return new GroupByContainerCount(config);
   }
 }
index 2bfd4c3..0ada91c 100644 (file)
@@ -83,7 +83,7 @@ public class TaskAssignmentManager {
     this.metadataStore = metadataStoreFactory.getMetadataStore(SetTaskContainerMapping.TYPE, config, metricsRegistry);
   }
 
-  public void init(Config config, MetricsRegistry metricsRegistry) {
+  public void init() {
     this.metadataStore.init();
   }
 
index 8d2d394..0c2f2fb 100644 (file)
@@ -25,42 +25,53 @@ import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
 import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.LocalityManager;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
 import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mockito;
+import org.powermock.api.mockito.PowerMockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
 
 import static org.apache.samza.container.mock.ContainerMocks.*;
 import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
 
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
 public class TestGroupByContainerCount {
   private TaskAssignmentManager taskAssignmentManager;
   private LocalityManager localityManager;
   @Before
-  public void setup() {
+  public void setup() throws Exception {
     taskAssignmentManager = mock(TaskAssignmentManager.class);
     localityManager = mock(LocalityManager.class);
-    when(localityManager.getTaskAssignmentManager()).thenReturn(taskAssignmentManager);
+    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
+    Mockito.doNothing().when(taskAssignmentManager).init();
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testGroupEmptyTasks() {
-    new GroupByContainerCount(1).group(new HashSet());
+    new GroupByContainerCount(getConfig(1)).group(new HashSet());
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testGroupFewerTasksThanContainers() {
     Set<TaskModel> taskModels = new HashSet<>();
     taskModels.add(getTaskModel(1));
-    new GroupByContainerCount(2).group(taskModels);
+    new GroupByContainerCount(getConfig(2)).group(taskModels);
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void testGrouperResultImmutable() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> containers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).group(taskModels);
     containers.remove(containers.iterator().next());
   }
 
@@ -68,7 +79,7 @@ public class TestGroupByContainerCount {
   public void testGroupHappyPath() {
     Set<TaskModel> taskModels = generateTaskModels(5);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -95,7 +106,7 @@ public class TestGroupByContainerCount {
   public void testGroupManyTasks() {
     Set<TaskModel> taskModels = generateTaskModels(21);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(2).group(taskModels);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).group(taskModels);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -163,11 +174,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerIncrease() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(4).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(4)).balance(taskModels, localityManager);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -245,11 +256,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerDecrease() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -320,15 +331,15 @@ public class TestGroupByContainerCount {
    *  T8  T7  T3
    */
   @Test
-  public void testBalancerMultipleReblances() {
+  public void testBalancerMultipleReblances() throws Exception {
     // Before
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(4).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(4)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
     // First balance
-    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -380,9 +391,9 @@ public class TestGroupByContainerCount {
     TaskAssignmentManager taskAssignmentManager2 = mock(TaskAssignmentManager.class);
     when(taskAssignmentManager2.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
     LocalityManager localityManager2 = mock(LocalityManager.class);
-    when(localityManager2.getTaskAssignmentManager()).thenReturn(taskAssignmentManager2);
+    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager2);
 
-    containers = new GroupByContainerCount(3).balance(taskModels, localityManager2);
+    containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager2);
 
     containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -455,11 +466,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerSame() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -529,7 +540,7 @@ public class TestGroupByContainerCount {
     prevTaskToContainerMapping.put(getTaskName(8).getTaskName(), "1");
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -595,7 +606,7 @@ public class TestGroupByContainerCount {
     prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "1");
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(3).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -636,12 +647,12 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerOldContainerCountOne() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
@@ -657,12 +668,12 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerNewContainerCountOne() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
@@ -677,10 +688,10 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerEmptyTaskMapping() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(new HashMap<String, String>());
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(new HashMap<>());
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
@@ -696,12 +707,12 @@ public class TestGroupByContainerCount {
   public void testGroupTaskCountIncrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(generateTaskModels(taskCount - 1)); // Here's the key step
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(2)).group(generateTaskModels(taskCount - 1)); // Here's the key step
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
@@ -717,12 +728,12 @@ public class TestGroupByContainerCount {
   public void testGroupTaskCountDecrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(generateTaskModels(taskCount + 1)); // Here's the key step
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(generateTaskModels(taskCount + 1)); // Here's the key step
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(1)).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
@@ -737,31 +748,31 @@ public class TestGroupByContainerCount {
   @Test(expected = IllegalArgumentException.class)
   public void testBalancerNewContainerCountGreaterThanTasks() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    new GroupByContainerCount(5).balance(taskModels, localityManager);     // Should throw
+    new GroupByContainerCount(getConfig(5)).balance(taskModels, localityManager);     // Should throw
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testBalancerEmptyTasks() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    new GroupByContainerCount(5).balance(new HashSet<TaskModel>(), localityManager);     // Should throw
+    new GroupByContainerCount(getConfig(5)).balance(new HashSet<>(), localityManager);     // Should throw
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void testBalancerResultImmutable() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(2).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
     containers.remove(containers.iterator().next());
   }
 
@@ -776,7 +787,7 @@ public class TestGroupByContainerCount {
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    new GroupByContainerCount(3).balance(taskModels, localityManager); //Should throw
+    new GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager); //Should throw
 
   }
 
@@ -784,10 +795,17 @@ public class TestGroupByContainerCount {
   public void testBalancerWithNullLocalityManager() {
     Set<TaskModel> taskModels = generateTaskModels(3);
 
-    Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
-    Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).balance(taskModels, null);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(getConfig(3)).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(getConfig(3)).balance(taskModels, null);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
   }
+
+
+  Config getConfig(int containerCount) {
+    Map<String, String> config = new HashMap<>();
+    config.put(JobConfig.JOB_CONTAINER_COUNT(), String.valueOf(containerCount));
+    return new MapConfig(config);
+  }
 }
index b9fe6fb..5bb78e8 100644 (file)
@@ -37,21 +37,25 @@ import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
 import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.api.mockito.PowerMockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
 
 import static org.apache.samza.container.mock.ContainerMocks.*;
 import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
 
 
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({TaskAssignmentManager.class, GroupByContainerIds.class})
 public class TestGroupByContainerIds {
 
   @Before
-  public void setup() {
+  public void setup() throws Exception {
     TaskAssignmentManager taskAssignmentManager = mock(TaskAssignmentManager.class);
     LocalityManager localityManager = mock(LocalityManager.class);
-    when(localityManager.getTaskAssignmentManager()).thenReturn(taskAssignmentManager);
-
-
+    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
   }
 
   private Config buildConfigForContainerCount(int count) {
index 879171e..fcdbf08 100644 (file)
@@ -68,7 +68,7 @@ public class TestTaskAssignmentManager {
   @Test
   public void testTaskAssignmentManager() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init(config, new MetricsRegistryMap());
+    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1", "Task2", "2", "Task3", "0", "Task4", "1");
 
@@ -83,9 +83,10 @@ public class TestTaskAssignmentManager {
     taskAssignmentManager.close();
   }
 
-  @Test public void testDeleteMappings() {
+  @Test
+  public void testDeleteMappings() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init(config, new MetricsRegistryMap());
+    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", "1");
 
@@ -104,9 +105,10 @@ public class TestTaskAssignmentManager {
     taskAssignmentManager.close();
   }
 
-  @Test public void testTaskAssignmentManagerEmptyCoordinatorStream() {
+  @Test
+  public void testTaskAssignmentManagerEmptyCoordinatorStream() {
     TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init(config, new MetricsRegistryMap());
+    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = new HashMap<>();
     Map<String, String> localMap = taskAssignmentManager.readTaskAssignment();
index 3130ed6..1dbf132 100644 (file)
@@ -23,6 +23,7 @@ import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.grouper.task.GroupByContainerCount;
 import org.apache.samza.container.grouper.task.TaskAssignmentManager;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
@@ -45,12 +46,18 @@ import static org.mockito.Matchers.argThat;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
+import org.junit.runner.RunWith;
 import org.mockito.ArgumentMatcher;
+import org.powermock.api.mockito.PowerMockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
 import scala.collection.JavaConversions;
 
 /**
  * Unit tests for {@link JobModelManager}
  */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
 public class TestJobModelManager {
   private final TaskAssignmentManager mockTaskManager = mock(TaskAssignmentManager.class);
   private final LocalityManager mockLocalityManager = mock(LocalityManager.class);
@@ -67,7 +74,7 @@ public class TestJobModelManager {
   private JobModelManager jobModelManager;
 
   @Before
-  public void setup() {
+  public void setup() throws Exception {
     when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
     when(mockStreamMetadataCache.getStreamMetadata(argThat(new ArgumentMatcher<scala.collection.immutable.Set<SystemStream>>() {
       @Override
@@ -77,7 +84,7 @@ public class TestJobModelManager {
       }
     }), anyBoolean())).thenReturn(mockStreamMetadataMap);
     when(mockStreamMetadata.getSystemStreamPartitionMetadata()).thenReturn(mockSspMetadataMap);
-    when(mockLocalityManager.getTaskAssignmentManager()).thenReturn(mockTaskManager);
+    PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(mockTaskManager);
     when(mockTaskManager.readTaskAssignment()).thenReturn(Collections.EMPTY_MAP);
   }
 
index b14d14b..daf665b 100644 (file)
@@ -34,6 +34,7 @@ import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.grouper.task.TaskAssignmentManager;
 import org.apache.samza.coordinator.stream.CoordinatorStreamSystemConsumer;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
 import org.apache.samza.metrics.MetricsRegistryMap;
@@ -131,7 +132,8 @@ public class SamzaTaskProxy implements TaskProxy {
   protected List<Task> readTasksFromCoordinatorStream(CoordinatorStreamSystemConsumer consumer) {
     LocalityManager localityManager = new LocalityManager(consumer.getConfig(), new MetricsRegistryMap());
     Map<String, Map<String, String>> containerIdToHostMapping = localityManager.readContainerLocality();
-    Map<String, String> taskNameToContainerIdMapping = localityManager.getTaskAssignmentManager().readTaskAssignment();
+    TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(consumer.getConfig(), new MetricsRegistryMap());
+    Map<String, String> taskNameToContainerIdMapping = taskAssignmentManager.readTaskAssignment();
     StorageConfig storageConfig = new StorageConfig(consumer.getConfig());
     List<String> storeNames = JavaConverters.seqAsJavaListConverter(storageConfig.getStoreNames()).asJava();
     return taskNameToContainerIdMapping.entrySet()
index 9c0dea7..d19badc 100644 (file)
@@ -44,6 +44,7 @@ import org.apache.samza.SamzaException;
 import org.apache.samza.clustermanager.SamzaApplicationState;
 import org.apache.samza.clustermanager.SamzaResource;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.container.grouper.task.GroupByContainerCount;
@@ -275,7 +276,9 @@ public class TestApplicationMasterRestClient {
         new TaskModel(new TaskName("task2"),
             ImmutableSet.of(new SystemStreamPartition(new SystemStream("system1", "stream1"), new Partition(1))),
             new Partition(1)));
-    GroupByContainerCount grouper = new GroupByContainerCount(2);
+    Map<String, String> config = new HashMap<>();
+    config.put(JobConfig.JOB_CONTAINER_COUNT(), String.valueOf(2));
+    GroupByContainerCount grouper = new GroupByContainerCount(new MapConfig(config));
     Set<ContainerModel> containerModels = grouper.group(taskModels);
     HashMap<String, ContainerModel> containers = new HashMap<>();
     for (ContainerModel containerModel : containerModels) {