SAMZA-1361; OperatorImplGraph is using wrong keys to store/retrieve OperatorImpl...
authorPrateek Maheshwari <pmaheshw@linkedin.com>
Thu, 27 Jul 2017 22:13:33 +0000 (15:13 -0700)
committerJagadish <jagadish@apache.org>
Thu, 27 Jul 2017 22:13:33 +0000 (15:13 -0700)
Author: Prateek Maheshwari <pmaheshw@linkedin.com>

Reviewers: Jagadish <jagadish@apachce.org>

Closes #248 from prateekm/operatorimpl-key and squashes the following commits:

e733e9d3 [Prateek Maheshwari] Dummy commit to trigger jenkins build.
5a16f162 [Prateek Maheshwari] Updated with Yi's feedback.
8eb2c5df [Prateek Maheshwari] SAMZA-1361: OperatorImplGraph is using wrong keys to store/retrieve OperatorImpl in the map

samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java

index c0da1b2..2c2eb56 100644 (file)
@@ -170,13 +170,13 @@ public class StreamGraphImpl implements StreamGraph {
   /**
    * Get all {@link OperatorSpec}s available in this {@link StreamGraphImpl}
    *
-   * @return  a set of all available {@link OperatorSpec}s
+   * @return  all available {@link OperatorSpec}s
    */
   public Collection<OperatorSpec> getAllOperatorSpecs() {
     Collection<InputOperatorSpec> inputOperatorSpecs = inputOperators.values();
     Set<OperatorSpec> operatorSpecs = new HashSet<>();
-
     for (InputOperatorSpec inputOperatorSpec: inputOperatorSpecs) {
+      operatorSpecs.add(inputOperatorSpec);
       doGetOperatorSpecs(inputOperatorSpec, operatorSpecs);
     }
     return operatorSpecs;
index 73bb83d..8dd5acd 100644 (file)
@@ -186,13 +186,13 @@ public abstract class OperatorImpl<M, RM> {
   protected abstract OperatorSpec<M, RM> getOperatorSpec();
 
   /**
-   * Get the name for this {@link OperatorImpl}.
+   * Get the unique name for this {@link OperatorImpl} in the DAG.
    *
    * Some {@link OperatorImpl}s don't have a 1:1 mapping with their {@link OperatorSpec}. E.g., there are
    * 2 PartialJoinOperatorImpls for a JoinOperatorSpec. Overriding this method allows them to provide an
    * implementation specific name, e.g., for use in metrics.
    *
-   * @return the operator name
+   * @return the unique name for this {@link OperatorImpl} in the DAG
    */
   protected String getOperatorName() {
     return getOperatorSpec().getOpName();
index e5fce13..99496eb 100644 (file)
@@ -129,7 +129,7 @@ public class OperatorImplGraph {
    */
   OperatorImpl createAndRegisterOperatorImpl(OperatorSpec prevOperatorSpec, OperatorSpec operatorSpec,
       Config config, TaskContext context) {
-    if (!operatorImpls.containsKey(operatorSpec) || operatorSpec instanceof JoinOperatorSpec) {
+    if (!operatorImpls.containsKey(operatorSpec.getOpName()) || operatorSpec instanceof JoinOperatorSpec) {
       // Either this is the first time we've seen this operatorSpec, or this is a join operator spec
       // and we need to create 2 partial join operator impls for it. Initialize and register the sub-DAG.
       OperatorImpl operatorImpl = createOperatorImpl(prevOperatorSpec, operatorSpec, config, context);
@@ -145,7 +145,7 @@ public class OperatorImplGraph {
     } else {
       // the implementation corresponding to operatorSpec has already been instantiated
       // and registered, so we do not need to traverse the DAG further.
-      return operatorImpls.get(operatorSpec);
+      return operatorImpls.get(operatorSpec.getOpName());
     }
   }
 
@@ -179,7 +179,6 @@ public class OperatorImplGraph {
   private PartialJoinOperatorImpl createPartialJoinOperatorImpl(OperatorSpec prevOperatorSpec,
       JoinOperatorSpec joinOpSpec, Config config, TaskContext context, Clock clock) {
     Pair<PartialJoinFunction, PartialJoinFunction> partialJoinFunctions = getOrCreatePartialJoinFunctions(joinOpSpec);
-
     if (joinOpSpec.getLeftInputOpSpec().equals(prevOperatorSpec)) { // we got here from the left side of the join
       return new PartialJoinOperatorImpl(joinOpSpec, /* isLeftSide */ true,
           partialJoinFunctions.getLeft(), partialJoinFunctions.getRight(), config, context, clock);
index 0c41fb8..c51b1ea 100644 (file)
@@ -71,7 +71,7 @@ public class TestJoinOperator {
   }
 
   @Test
-  public void testJoinFnInitAndClose() throws Exception {
+  public void joinFnInitAndClose() throws Exception {
     TestJoinFunction joinFn = new TestJoinFunction();
     StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(joinFn));
     assertEquals(1, joinFn.getNumInitCalls());
index b2c7722..4505eef 100644 (file)
@@ -41,6 +41,7 @@ import org.junit.Test;
 
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.function.BiFunction;
 import java.util.function.Function;
@@ -57,6 +58,7 @@ import static org.mockito.Mockito.when;
 
 public class TestOperatorImplGraph {
 
+  @Test
   public void testEmptyChain() {
     StreamGraphImpl streamGraph = new StreamGraphImpl(mock(ApplicationRunner.class), mock(Config.class));
     OperatorImplGraph opGraph =
@@ -101,7 +103,6 @@ public class TestOperatorImplGraph {
     assertEquals(OpCode.SEND_TO, sendToOpImpl.getOperatorSpec().getOpCode());
   }
 
-
   @Test
   public void testBroadcastChain() {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
@@ -126,6 +127,28 @@ public class TestOperatorImplGraph {
   }
 
   @Test
+  public void testMergeChain() {
+    ApplicationRunner mockRunner = mock(ApplicationRunner.class);
+    when(mockRunner.getStreamSpec(eq("input"))).thenReturn(new StreamSpec("input", "input-stream", "input-system"));
+    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mock(Config.class));
+
+    MessageStream<Object> inputStream = streamGraph.getInputStream("input", mock(BiFunction.class));
+    MessageStream<Object> stream1 = inputStream.filter(mock(FilterFunction.class));
+    MessageStream<Object> stream2 = inputStream.map(mock(MapFunction.class));
+    MessageStream<Object> mergedStream = stream1.merge(Collections.singleton(stream2));
+    MapFunction mockMapFunction = mock(MapFunction.class);
+    mergedStream.map(mockMapFunction);
+
+    TaskContext mockTaskContext = mock(TaskContext.class);
+    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    OperatorImplGraph opImplGraph =
+        new OperatorImplGraph(streamGraph, mock(Config.class), mockTaskContext, mock(Clock.class));
+
+    // verify that the DAG after merge is only traversed & initialized once
+    verify(mockMapFunction, times(1)).init(any(Config.class), any(TaskContext.class));
+  }
+
+  @Test
   public void testJoinChain() {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec(eq("input1"))).thenReturn(new StreamSpec("input1", "input-stream1", "input-system"));