SAMZA-1346: GroupByContainerCount.balance() should guard against null…
authorJacob Maes <jmaes@linkedin.com>
Mon, 26 Jun 2017 20:43:33 +0000 (13:43 -0700)
committerJacob Maes <jmaes@linkedin.com>
Mon, 26 Jun 2017 20:43:33 +0000 (13:43 -0700)
… LocalityManager

Author: Jacob Maes <jmaes@linkedin.com>

Reviewers: Chris Pettitt <cpettitt@linkedin.com>

Closes #232 from jmakes/samza-1346

samza-core/src/main/java/org/apache/samza/container/grouper/task/GroupByContainerCount.java
samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java

index 246188e..74c69d6 100644 (file)
@@ -89,6 +89,11 @@ public class GroupByContainerCount implements BalancingTaskNameGrouper {
 
     validateTasks(tasks);
 
+    if (localityManager == null) {
+      log.info("Locality manager is null. Cannot read or write task assignments. Invoking grouper.");
+      return group(tasks);
+    }
+
     TaskAssignmentManager taskAssignmentManager = localityManager.getTaskAssignmentManager();
     List<TaskGroup> containers = getPreviousContainers(taskAssignmentManager, tasks.size());
     if (containers == null || containers.size() == 1 || containerCount == 1) {
index de4de7c..e89d673 100644 (file)
@@ -651,30 +651,11 @@ public class TestGroupByContainerCount {
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(3).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).balance(taskModels, localityManager);
 
     // Results should be the same as calling group()
-    Map<String, ContainerModel> containersMap = new HashMap<>();
-    for (ContainerModel container : containers) {
-      containersMap.put(container.getProcessorId(), container);
-    }
-    assertEquals(3, containers.size());
-    ContainerModel container0 = containersMap.get("0");
-    ContainerModel container1 = containersMap.get("1");
-    ContainerModel container2 = containersMap.get("2");
-    assertNotNull(container0);
-    assertNotNull(container1);
-    assertNotNull(container2);
-    assertEquals("0", container0.getProcessorId());
-    assertEquals("1", container1.getProcessorId());
-    assertEquals("2", container2.getProcessorId());
-    assertEquals(1, container0.getTasks().size());
-    assertEquals(1, container1.getTasks().size());
-    assertEquals(1, container2.getTasks().size());
-
-    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
-    assertTrue(container1.getTasks().containsKey(getTaskName(1)));
-    assertTrue(container2.getTasks().containsKey(getTaskName(2)));
+    assertEquals(groupContainers, balanceContainers);
 
     // Verify task mappings are saved
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
@@ -691,23 +672,11 @@ public class TestGroupByContainerCount {
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).balance(taskModels, localityManager);
 
-    // Results should be the same as calling group
-    Map<String, ContainerModel> containersMap = new HashMap<>();
-    for (ContainerModel container : containers) {
-      containersMap.put(container.getProcessorId(), container);
-    }
-
-    assertEquals(1, containers.size());
-    ContainerModel container0 = containersMap.get("0");
-    assertNotNull(container0);
-    assertEquals("0", container0.getProcessorId());
-    assertEquals(3, container0.getTasks().size());
-
-    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
-    assertTrue(container0.getTasks().containsKey(getTaskName(1)));
-    assertTrue(container0.getTasks().containsKey(getTaskName(2)));
+    // Results should be the same as calling group()
+    assertEquals(groupContainers, balanceContainers);
 
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
@@ -721,23 +690,11 @@ public class TestGroupByContainerCount {
     Set<TaskModel> taskModels = generateTaskModels(3);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(new HashMap<String, String>());
 
-    Set<ContainerModel> containers = new GroupByContainerCount(1).balance(taskModels, localityManager);
-
-    // Results should be the same as calling group
-    Map<String, ContainerModel> containersMap = new HashMap<>();
-    for (ContainerModel container : containers) {
-      containersMap.put(container.getProcessorId(), container);
-    }
-
-    assertEquals(1, containers.size());
-    ContainerModel container0 = containersMap.get("0");
-    assertNotNull(container0);
-    assertEquals("0", container0.getProcessorId());
-    assertEquals(3, container0.getTasks().size());
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).balance(taskModels, localityManager);
 
-    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
-    assertTrue(container0.getTasks().containsKey(getTaskName(1)));
-    assertTrue(container0.getTasks().containsKey(getTaskName(2)));
+    // Results should be the same as calling group()
+    assertEquals(groupContainers, balanceContainers);
 
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
@@ -750,27 +707,15 @@ public class TestGroupByContainerCount {
   public void testGroupTaskCountIncrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(generateTaskModels(taskCount - 1));
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(generateTaskModels(taskCount - 1)); // Here's the key step
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(1).balance(taskModels, localityManager);
-
-    // Results should be the same as calling group
-    Map<String, ContainerModel> containersMap = new HashMap<>();
-    for (ContainerModel container : containers) {
-      containersMap.put(container.getProcessorId(), container);
-    }
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).balance(taskModels, localityManager);
 
-    assertEquals(1, containers.size());
-    ContainerModel container0 = containersMap.get("0");
-    assertNotNull(container0);
-    assertEquals("0", container0.getProcessorId());
-    assertEquals(3, container0.getTasks().size());
-
-    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
-    assertTrue(container0.getTasks().containsKey(getTaskName(1)));
-    assertTrue(container0.getTasks().containsKey(getTaskName(2)));
+    // Results should be the same as calling group()
+    assertEquals(groupContainers, balanceContainers);
 
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
@@ -783,27 +728,15 @@ public class TestGroupByContainerCount {
   public void testGroupTaskCountDecrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(generateTaskModels(taskCount + 1));
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(generateTaskModels(taskCount + 1)); // Here's the key step
     Map<String, String> prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
     when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new GroupByContainerCount(1).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(1).balance(taskModels, localityManager);
 
-    // Results should be the same as calling group
-    Map<String, ContainerModel> containersMap = new HashMap<>();
-    for (ContainerModel container : containers) {
-      containersMap.put(container.getProcessorId(), container);
-    }
-
-    assertEquals(1, containers.size());
-    ContainerModel container0 = containersMap.get("0");
-    assertNotNull(container0);
-    assertEquals("0", container0.getProcessorId());
-    assertEquals(3, container0.getTasks().size());
-
-    assertTrue(container0.getTasks().containsKey(getTaskName(0)));
-    assertTrue(container0.getTasks().containsKey(getTaskName(1)));
-    assertTrue(container0.getTasks().containsKey(getTaskName(2)));
+    // Results should be the same as calling group()
+    assertEquals(groupContainers, balanceContainers);
 
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(), "0");
     verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(), "0");
@@ -857,4 +790,15 @@ public class TestGroupByContainerCount {
     new GroupByContainerCount(3).balance(taskModels, localityManager); //Should throw
 
   }
+
+  @Test
+  public void testBalancerWithNullLocalityManager() {
+    Set<TaskModel> taskModels = generateTaskModels(3);
+
+    Set<ContainerModel> groupContainers = new GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> balanceContainers = new GroupByContainerCount(3).balance(taskModels, null);
+
+    // Results should be the same as calling group()
+    assertEquals(groupContainers, balanceContainers);
+  }
 }