SAMZA-1361; OperatorImplGraph is using wrong keys to store/retrieve OperatorImpl...
[samza.git] / samza-core / src / test / java / org / apache / samza / operators / impl / TestOperatorImplGraph.java
index b2c7722..4505eef 100644 (file)
@@ -41,6 +41,7 @@ import org.junit.Test;
 
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.function.BiFunction;
 import java.util.function.Function;
@@ -57,6 +58,7 @@ import static org.mockito.Mockito.when;
 
 public class TestOperatorImplGraph {
 
+  @Test
   public void testEmptyChain() {
     StreamGraphImpl streamGraph = new StreamGraphImpl(mock(ApplicationRunner.class), mock(Config.class));
     OperatorImplGraph opGraph =
@@ -101,7 +103,6 @@ public class TestOperatorImplGraph {
     assertEquals(OpCode.SEND_TO, sendToOpImpl.getOperatorSpec().getOpCode());
   }
 
-
   @Test
   public void testBroadcastChain() {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
@@ -126,6 +127,28 @@ public class TestOperatorImplGraph {
   }
 
   @Test
+  public void testMergeChain() {
+    ApplicationRunner mockRunner = mock(ApplicationRunner.class);
+    when(mockRunner.getStreamSpec(eq("input"))).thenReturn(new StreamSpec("input", "input-stream", "input-system"));
+    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mock(Config.class));
+
+    MessageStream<Object> inputStream = streamGraph.getInputStream("input", mock(BiFunction.class));
+    MessageStream<Object> stream1 = inputStream.filter(mock(FilterFunction.class));
+    MessageStream<Object> stream2 = inputStream.map(mock(MapFunction.class));
+    MessageStream<Object> mergedStream = stream1.merge(Collections.singleton(stream2));
+    MapFunction mockMapFunction = mock(MapFunction.class);
+    mergedStream.map(mockMapFunction);
+
+    TaskContext mockTaskContext = mock(TaskContext.class);
+    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    OperatorImplGraph opImplGraph =
+        new OperatorImplGraph(streamGraph, mock(Config.class), mockTaskContext, mock(Clock.class));
+
+    // verify that the DAG after merge is only traversed & initialized once
+    verify(mockMapFunction, times(1)).init(any(Config.class), any(TaskContext.class));
+  }
+
+  @Test
   public void testJoinChain() {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec(eq("input1"))).thenReturn(new StreamSpec("input1", "input-stream1", "input-system"));