SAMZA-1264; Make Operator Functions Closable
authorvjagadish1989 <jvenkatr@linkedin.com>
Thu, 1 Jun 2017 05:20:14 +0000 (22:20 -0700)
committervjagadish1989 <jvenkatr@linkedin.com>
Thu, 1 Jun 2017 05:20:14 +0000 (22:20 -0700)
- Added `close()` to the lifecycle of `OperatorImpl`s, and all `Function`s.
- Added unit tests to verify calls to `close()`

Author: vjagadish1989 <jvenkatr@linkedin.com>

Reviewers: Prateek Maheshwari<pmaheshw@linkedin.com>

Closes #208 from vjagadish1989/operator_functions

22 files changed:
samza-api/src/main/java/org/apache/samza/application/StreamApplication.java
samza-api/src/main/java/org/apache/samza/operators/functions/ClosableFunction.java [new file with mode: 0644]
samza-api/src/main/java/org/apache/samza/operators/functions/FilterFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/FlatMapFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/FoldLeftFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/JoinFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/MapFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/SinkFunction.java
samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java
samza-core/src/main/java/org/apache/samza/operators/functions/PartialJoinFunction.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/RootOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/SinkOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/WindowOperatorImpl.java
samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestSinkOperatorImpl.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java

index 0d77295..f615207 100644 (file)
@@ -58,11 +58,18 @@ import org.apache.samza.task.TaskContext;
  *     runner.waitForFinish();
  *   }
  * }</pre>
+ *
  * <p>
  * Implementation Notes: Currently StreamApplications are wrapped in a {@link StreamTask} during execution.
  * A new StreamApplication instance will be created and initialized when planning the execution, as well as for each
- * {@link StreamTask} instance used for processing incoming messages. Execution is synchronous and thread-safe
- * within each {@link StreamTask}.
+ * {@link StreamTask} instance used for processing incoming messages. Execution is synchronous and thread-safe within
+ * each {@link StreamTask}.
+ *
+ * <p>
+ * Functions implemented for transforms in StreamApplications ({@link org.apache.samza.operators.functions.MapFunction},
+ * {@link org.apache.samza.operators.functions.FilterFunction} for e.g.) are initable and closable. They are initialized
+ * before messages are delivered to them and closed after their execution when the {@link StreamTask} instance is closed.
+ * See {@link InitableFunction} and {@link org.apache.samza.operators.functions.ClosableFunction}.
  */
 @InterfaceStability.Unstable
 public interface StreamApplication {
diff --git a/samza-api/src/main/java/org/apache/samza/operators/functions/ClosableFunction.java b/samza-api/src/main/java/org/apache/samza/operators/functions/ClosableFunction.java
new file mode 100644 (file)
index 0000000..2e73652
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.operators.functions;
+
+import org.apache.samza.annotation.InterfaceStability;
+
+/**
+ * A function that can be closed after its execution.
+ *
+ * <p> Implement {@link #close()} to free resources used during the execution of the function, clean up state etc.
+ *
+ */
+@InterfaceStability.Unstable
+public interface ClosableFunction {
+  default void close() {}
+}
index cd49d1b..31bbbd8 100644 (file)
@@ -28,7 +28,7 @@ import org.apache.samza.annotation.InterfaceStability;
  */
 @InterfaceStability.Unstable
 @FunctionalInterface
-public interface FilterFunction<M> extends InitableFunction {
+public interface FilterFunction<M> extends InitableFunction, ClosableFunction {
 
   /**
    * Returns a boolean indicating whether this message should be retained or filtered out.
index e6c4958..7e9253e 100644 (file)
@@ -31,7 +31,7 @@ import java.util.Collection;
  */
 @InterfaceStability.Unstable
 @FunctionalInterface
-public interface FlatMapFunction<M, OM>  extends InitableFunction {
+public interface FlatMapFunction<M, OM>  extends InitableFunction, ClosableFunction {
 
   /**
    * Transforms the provided message into a collection of 0 or more messages.
index 25728fc..78250e3 100644 (file)
@@ -22,7 +22,7 @@ package org.apache.samza.operators.functions;
 /**
  * Incrementally updates the window value as messages are added to the window.
  */
-public interface FoldLeftFunction<M, WV> extends InitableFunction {
+public interface FoldLeftFunction<M, WV> extends InitableFunction, ClosableFunction {
 
   /**
    * Incrementally updates the window value as messages are added to the window.
index f30a47d..954083d 100644 (file)
@@ -30,7 +30,7 @@ import org.apache.samza.annotation.InterfaceStability;
  * @param <RM>  type of the joined message
  */
 @InterfaceStability.Unstable
-public interface JoinFunction<K, M, JM, RM>  extends InitableFunction {
+public interface JoinFunction<K, M, JM, RM>  extends InitableFunction, ClosableFunction {
 
   /**
    * Joins the provided messages and returns the joined message.
index 240039f..a8c139f 100644 (file)
@@ -29,7 +29,7 @@ import org.apache.samza.annotation.InterfaceStability;
  */
 @InterfaceStability.Unstable
 @FunctionalInterface
-public interface MapFunction<M, OM>  extends InitableFunction {
+public interface MapFunction<M, OM>  extends InitableFunction, ClosableFunction {
 
   /**
    * Transforms the provided message into another message.
index 83aa0a1..e290d7d 100644 (file)
@@ -30,7 +30,7 @@ import org.apache.samza.task.TaskCoordinator;
  */
 @InterfaceStability.Unstable
 @FunctionalInterface
-public interface SinkFunction<M>  extends InitableFunction {
+public interface SinkFunction<M>  extends InitableFunction, ClosableFunction {
 
   /**
    * Allows sending the provided message to an output {@link org.apache.samza.system.SystemStream} using
index 4694262..0c84e90 100644 (file)
@@ -148,6 +148,12 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
 
         thisStreamState = new InternalInMemoryStore<>();
       }
+
+      @Override
+      public void close() {
+        // joinFn#close() must only be called once, so we do it in this partial join function's #close.
+        joinFn.close();
+      }
     };
 
     PartialJoinFunction<K, OM, M, TM> otherPartialJoinFn = new PartialJoinFunction<K, OM, M, TM>() {
index a961830..9b7956a 100644 (file)
@@ -23,7 +23,7 @@ import org.apache.samza.storage.kv.KeyValueStore;
 /**
  * An internal function that maintains state and join logic for one side of a two-way join.
  */
-public interface PartialJoinFunction<K, M, JM, RM> extends InitableFunction {
+public interface PartialJoinFunction<K, M, JM, RM> extends InitableFunction, ClosableFunction {
 
   /**
    * Joins a message in this stream with a message from another stream.
index d547869..23c31ac 100644 (file)
@@ -42,6 +42,7 @@ public abstract class OperatorImpl<M, RM> {
   private static final String METRICS_GROUP = OperatorImpl.class.getName();
 
   private boolean initialized;
+  private boolean closed;
   private Set<OperatorImpl<RM, ?>> registeredOperators;
   private HighResolutionClock highResClock;
   private Counter numMessage;
@@ -61,6 +62,10 @@ public abstract class OperatorImpl<M, RM> {
       throw new IllegalStateException(String.format("Attempted to initialize Operator %s more than once.", opName));
     }
 
+    if (closed) {
+      throw new IllegalStateException(String.format("Attempted to initialize Operator %s after it was closed.", opName));
+    }
+
     this.highResClock = createHighResClock(config);
     registeredOperators = new HashSet<>();
     MetricsRegistry metricsRegistry = context.getMetricsRegistry();
@@ -161,6 +166,18 @@ public abstract class OperatorImpl<M, RM> {
     return Collections.emptyList();
   }
 
+  public void close() {
+    String opName = getOperatorSpec().getOpName();
+
+    if (closed) {
+      throw new IllegalStateException(String.format("Attempted to close Operator %s more than once.", opName));
+    }
+    handleClose();
+    closed = true;
+  }
+
+  protected abstract void handleClose();
+
   /**
    * Get the {@link OperatorSpec} for this {@link OperatorImpl}.
    *
index d8ea592..78a6d1e 100644 (file)
@@ -100,6 +100,15 @@ public class OperatorImplGraph {
   }
 
   /**
+   * Get all {@link OperatorImpl}s for the graph.
+   *
+   * @return  an unmodifiable view of all {@link OperatorImpl}s for the graph
+   */
+  public Collection<OperatorImpl> getAllOperators() {
+    return Collections.unmodifiableCollection(this.operatorImpls.values());
+  }
+
+  /**
    * Traverses the DAG of {@link OperatorSpec}s starting from the provided {@link MessageStreamImpl},
    * creates the corresponding DAG of {@link OperatorImpl}s, and returns its root {@link RootOperatorImpl} node.
    *
index c7bdc22..b00a2e9 100644 (file)
@@ -112,6 +112,11 @@ class PartialJoinOperatorImpl<K, M, JM, RM> extends OperatorImpl<M, RM> {
   }
 
   @Override
+  protected void handleClose() {
+    this.thisPartialJoinFn.close();
+  }
+
+  @Override
   protected OperatorSpec<RM> getOperatorSpec() {
     return partialJoinOpSpec;
   }
index 059b567..45cb941 100644 (file)
@@ -44,6 +44,10 @@ public final class RootOperatorImpl<M> extends OperatorImpl<M, M> {
     return Collections.singletonList(message);
   }
 
+  @Override
+  protected void handleClose() {
+  }
+
   // TODO: SAMZA-1221 - Change to InputOperatorSpec that also builds the message
   @Override
   protected OperatorSpec<M> getOperatorSpec() {
index e82737f..4f698f8 100644 (file)
@@ -57,6 +57,11 @@ class SinkOperatorImpl<M> extends OperatorImpl<M, M> {
   }
 
   @Override
+  protected void handleClose() {
+    this.sinkFn.close();
+  }
+
+  @Override
   protected OperatorSpec<M> getOperatorSpec() {
     return sinkOpSpec;
   }
index bd4dce1..e720803 100644 (file)
@@ -58,6 +58,11 @@ class StreamOperatorImpl<M, RM> extends OperatorImpl<M, RM> {
   }
 
   @Override
+  protected void handleClose() {
+    this.transformFn.close();
+  }
+
+  @Override
   protected OperatorSpec<RM> getOperatorSpec() {
     return streamOpSpec;
   }
index a297aba..b258042 100644 (file)
@@ -158,6 +158,15 @@ public class WindowOperatorImpl<M, WK, WV> extends OperatorImpl<M, WindowPane<WK
     return windowOpSpec;
   }
 
+  @Override
+  protected void handleClose() {
+    WindowInternal<M, WK, WV> window = windowOpSpec.getWindow();
+
+    if (window.getFoldLeftFunction() != null) {
+      window.getFoldLeftFunction().close();
+    }
+  }
+
   /**
    * Get the key to be used for lookups in the store for this message.
    */
index b18cf06..a5f3f85 100644 (file)
@@ -22,6 +22,7 @@ import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.Config;
 import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.impl.OperatorImpl;
 import org.apache.samza.operators.impl.OperatorImplGraph;
 import org.apache.samza.operators.impl.RootOperatorImpl;
 import org.apache.samza.operators.stream.InputStreamInternal;
@@ -31,6 +32,7 @@ import org.apache.samza.system.SystemStream;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.SystemClock;
 
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -140,5 +142,8 @@ public final class StreamOperatorTask implements StreamTask, InitableTask, Windo
     if (this.contextManager != null) {
       this.contextManager.close();
     }
+
+    Collection<OperatorImpl> allOperators = operatorImplGraph.getAllOperators();
+    allOperators.forEach(OperatorImpl::close);
   }
 }
index 23b67aa..39745bf 100644 (file)
@@ -57,7 +57,7 @@ public class TestJoinOperator {
 
   @Test
   public void join() throws Exception {
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -71,8 +71,26 @@ public class TestJoinOperator {
   }
 
   @Test
+  public void testJoinFnInitAndClose() throws Exception {
+    TestJoinFunction joinFn = new TestJoinFunction();
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(joinFn));
+    assertEquals(joinFn.getNumInitCalls(), 1);
+    MessageCollector messageCollector = mock(MessageCollector.class);
+
+    // push messages to first stream
+    numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
+
+    // close should not be called till now
+    assertEquals(joinFn.getNumCloseCalls(), 0);
+    sot.close();
+
+    // close should be called from sot.close()
+    assertEquals(joinFn.getNumCloseCalls(), 1);
+  }
+
+  @Test
   public void joinReverse() throws Exception {
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -87,7 +105,7 @@ public class TestJoinOperator {
 
   @Test
   public void joinNoMatch() throws Exception {
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -101,7 +119,7 @@ public class TestJoinOperator {
 
   @Test
   public void joinNoMatchReverse() throws Exception {
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -115,7 +133,7 @@ public class TestJoinOperator {
 
   @Test
   public void joinRetainsLatestMessageForKey() throws Exception {
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -132,7 +150,7 @@ public class TestJoinOperator {
 
   @Test
   public void joinRetainsLatestMessageForKeyReverse() throws Exception {
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -149,7 +167,7 @@ public class TestJoinOperator {
 
   @Test
   public void joinRetainsMatchedMessages() throws Exception {
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -171,7 +189,7 @@ public class TestJoinOperator {
 
   @Test
   public void joinRetainsMatchedMessagesReverse() throws Exception {
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -194,7 +212,7 @@ public class TestJoinOperator {
   @Test
   public void joinRemovesExpiredMessages() throws Exception {
     TestClock testClock = new TestClock();
-    StreamOperatorTask sot = createStreamOperatorTask(testClock);
+    StreamOperatorTask sot = createStreamOperatorTask(testClock, new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -214,7 +232,7 @@ public class TestJoinOperator {
   @Test
   public void joinRemovesExpiredMessagesReverse() throws Exception {
     TestClock testClock = new TestClock();
-    StreamOperatorTask sot = createStreamOperatorTask(testClock);
+    StreamOperatorTask sot = createStreamOperatorTask(testClock, new TestJoinStreamApplication(new TestJoinFunction()));
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -230,7 +248,7 @@ public class TestJoinOperator {
     assertTrue(output.isEmpty());
   }
 
-  private StreamOperatorTask createStreamOperatorTask(Clock clock) throws Exception {
+  private StreamOperatorTask createStreamOperatorTask(Clock clock, StreamApplication app) throws Exception {
     ApplicationRunner runner = mock(ApplicationRunner.class);
     when(runner.getStreamSpec("instream")).thenReturn(new StreamSpec("instream", "instream", "insystem"));
     when(runner.getStreamSpec("instream2")).thenReturn(new StreamSpec("instream2", "instream2", "insystem2"));
@@ -243,13 +261,19 @@ public class TestJoinOperator {
 
     Config config = mock(Config.class);
 
-    StreamApplication sgb = new TestStreamApplication();
-    StreamOperatorTask sot = new StreamOperatorTask(sgb, runner, clock);
+    StreamOperatorTask sot = new StreamOperatorTask(app, runner, clock);
     sot.init(config, taskContext);
     return sot;
   }
 
-  private static class TestStreamApplication implements StreamApplication {
+  private static class TestJoinStreamApplication implements StreamApplication {
+
+    private final TestJoinFunction joinFn;
+
+    TestJoinStreamApplication(TestJoinFunction joinFn) {
+      this.joinFn = joinFn;
+    }
+
     @Override
     public void init(StreamGraph graph, Config config) {
       MessageStream<FirstStreamIME> inStream =
@@ -259,7 +283,7 @@ public class TestJoinOperator {
 
       SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
       inStream
-          .join(inStream2, new TestJoinFunction(), JOIN_TTL)
+          .join(inStream2, joinFn, JOIN_TTL)
           .sink((message, messageCollector, taskCoordinator) -> {
               messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
             });
@@ -267,6 +291,15 @@ public class TestJoinOperator {
   }
 
   private static class TestJoinFunction implements JoinFunction<Integer, FirstStreamIME, SecondStreamIME, Integer> {
+
+    private int numInitCalls = 0;
+    private int numCloseCalls = 0;
+
+    @Override
+    public void init(Config config, TaskContext context) {
+      numInitCalls++;
+    }
+
     @Override
     public Integer apply(FirstStreamIME message, SecondStreamIME otherMessage) {
       return (Integer) message.getMessage() + (Integer) otherMessage.getMessage();
@@ -281,6 +314,19 @@ public class TestJoinOperator {
     public Integer getSecondKey(SecondStreamIME message) {
       return (Integer) message.getKey();
     }
+
+    @Override
+    public void close() {
+      numCloseCalls++;
+    }
+
+    public int getNumInitCalls() {
+      return numInitCalls;
+    }
+
+    public int getNumCloseCalls() {
+      return numCloseCalls;
+    }
   }
 
   private static class FirstStreamIME extends IncomingMessageEnvelope {
index 99bf854..73d851b 100644 (file)
@@ -197,6 +197,9 @@ public class TestOperatorImpl {
     }
 
     @Override
+    protected void handleClose() {}
+
+    @Override
     protected OperatorSpec<Object> getOperatorSpec() {
       return new TestOpSpec();
     }
index 1c01e57..dc94e36 100644 (file)
@@ -36,13 +36,9 @@ import static org.mockito.Mockito.when;
 public class TestSinkOperatorImpl {
 
   @Test
-  public void testSinkOperator() {
-    SinkOperatorSpec<TestOutputMessageEnvelope> sinkOp = mock(SinkOperatorSpec.class);
+  public void testSinkOperatorSinkFunction() {
     SinkFunction<TestOutputMessageEnvelope> sinkFn = mock(SinkFunction.class);
-    when(sinkOp.getSinkFn()).thenReturn(sinkFn);
-    Config mockConfig = mock(Config.class);
-    TaskContext mockContext = mock(TaskContext.class);
-    SinkOperatorImpl<TestOutputMessageEnvelope> sinkImpl = new SinkOperatorImpl<>(sinkOp, mockConfig, mockContext);
+    SinkOperatorImpl<TestOutputMessageEnvelope> sinkImpl = createSinkOperator(sinkFn);
     TestOutputMessageEnvelope mockMsg = mock(TestOutputMessageEnvelope.class);
     MessageCollector mockCollector = mock(MessageCollector.class);
     TaskCoordinator mockCoordinator = mock(TaskCoordinator.class);
@@ -50,4 +46,32 @@ public class TestSinkOperatorImpl {
     sinkImpl.handleMessage(mockMsg, mockCollector, mockCoordinator);
     verify(sinkFn, times(1)).apply(mockMsg, mockCollector, mockCoordinator);
   }
+
+  @Test
+  public void testSinkOperatorClose() {
+    TestOutputMessageEnvelope mockMsg = mock(TestOutputMessageEnvelope.class);
+    MessageCollector mockCollector = mock(MessageCollector.class);
+    TaskCoordinator mockCoordinator = mock(TaskCoordinator.class);
+    SinkFunction<TestOutputMessageEnvelope> sinkFn = mock(SinkFunction.class);
+
+    SinkOperatorImpl<TestOutputMessageEnvelope> sinkImpl = createSinkOperator(sinkFn);
+    sinkImpl.handleMessage(mockMsg, mockCollector, mockCoordinator);
+    verify(sinkFn, times(1)).apply(mockMsg, mockCollector, mockCoordinator);
+
+    // ensure that close is not called yet
+    verify(sinkFn, times(0)).close();
+
+    sinkImpl.handleClose();
+    // ensure that close is called once from handleClose()
+    verify(sinkFn, times(1)).close();
+  }
+
+  private SinkOperatorImpl createSinkOperator(SinkFunction<TestOutputMessageEnvelope> sinkFn) {
+    SinkOperatorSpec<TestOutputMessageEnvelope> sinkOp = mock(SinkOperatorSpec.class);
+    when(sinkOp.getSinkFn()).thenReturn(sinkFn);
+
+    Config mockConfig = mock(Config.class);
+    TaskContext mockContext = mock(TaskContext.class);
+    return new SinkOperatorImpl<>(sinkOp, mockConfig, mockContext);
+  }
 }
index 36d7b92..a5d9539 100644 (file)
@@ -64,4 +64,22 @@ public class TestStreamOperatorImpl {
     verify(txfmFn, times(1)).apply(inMsg);
     assertEquals(results, mockOutputs);
   }
+
+  @Test
+  public void testSimpleOperatorClose() {
+    StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> mockOp = mock(StreamOperatorSpec.class);
+    FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> txfmFn = mock(FlatMapFunction.class);
+    when(mockOp.getTransformFn()).thenReturn(txfmFn);
+    Config mockConfig = mock(Config.class);
+    TaskContext mockContext = mock(TaskContext.class);
+
+    StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl =
+        spy(new StreamOperatorImpl<>(mockOp, mockConfig, mockContext));
+
+    // ensure that close is not called yet
+    verify(txfmFn, times(0)).close();
+    opImpl.handleClose();
+    // ensure that close is called once inside handleClose()
+    verify(txfmFn, times(1)).close();
+  }
 }