- Added `close()` to the lifecycle of `OperatorImpl`s and all `Function`s.
- Added unit tests to verify calls to `close()`.
Author: vjagadish1989 <jvenkatr@linkedin.com>
Reviewers: Prateek Maheshwari<pmaheshw@linkedin.com>
Closes #208 from vjagadish1989/operator_functions
* runner.waitForFinish();
* }
* }</pre>
+ *
* <p>
* Implementation Notes: Currently StreamApplications are wrapped in a {@link StreamTask} during execution.
* A new StreamApplication instance will be created and initialized when planning the execution, as well as for each
- * {@link StreamTask} instance used for processing incoming messages. Execution is synchronous and thread-safe
- * within each {@link StreamTask}.
+ * {@link StreamTask} instance used for processing incoming messages. Execution is synchronous and thread-safe within
+ * each {@link StreamTask}.
+ *
+ * <p>
+ * Functions implemented for transforms in StreamApplications (for example,
+ * {@link org.apache.samza.operators.functions.MapFunction} and
+ * {@link org.apache.samza.operators.functions.FilterFunction}) are initable and closable. They are initialized
+ * before messages are delivered to them and closed after their execution when the {@link StreamTask} instance is closed.
+ * See {@link InitableFunction} and {@link org.apache.samza.operators.functions.ClosableFunction}.
*/
@InterfaceStability.Unstable
public interface StreamApplication {
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.operators.functions;
+
+import org.apache.samza.annotation.InterfaceStability;
+
+/**
+ * A function that can be closed after its execution.
+ *
+ * <p> Implement {@link #close()} to free resources used during the execution of the function, clean up state, etc.
+ * The default implementation of {@link #close()} does nothing, so overriding it is optional.
+ */
+@InterfaceStability.Unstable
+public interface ClosableFunction {
+ default void close() {}
+}
*/
@InterfaceStability.Unstable
@FunctionalInterface
-public interface FilterFunction<M> extends InitableFunction {
+public interface FilterFunction<M> extends InitableFunction, ClosableFunction {
/**
* Returns a boolean indicating whether this message should be retained or filtered out.
*/
@InterfaceStability.Unstable
@FunctionalInterface
-public interface FlatMapFunction<M, OM> extends InitableFunction {
+public interface FlatMapFunction<M, OM> extends InitableFunction, ClosableFunction {
/**
* Transforms the provided message into a collection of 0 or more messages.
/**
* Incrementally updates the window value as messages are added to the window.
*/
-public interface FoldLeftFunction<M, WV> extends InitableFunction {
+public interface FoldLeftFunction<M, WV> extends InitableFunction, ClosableFunction {
/**
* Incrementally updates the window value as messages are added to the window.
* @param <RM> type of the joined message
*/
@InterfaceStability.Unstable
-public interface JoinFunction<K, M, JM, RM> extends InitableFunction {
+public interface JoinFunction<K, M, JM, RM> extends InitableFunction, ClosableFunction {
/**
* Joins the provided messages and returns the joined message.
*/
@InterfaceStability.Unstable
@FunctionalInterface
-public interface MapFunction<M, OM> extends InitableFunction {
+public interface MapFunction<M, OM> extends InitableFunction, ClosableFunction {
/**
* Transforms the provided message into another message.
*/
@InterfaceStability.Unstable
@FunctionalInterface
-public interface SinkFunction<M> extends InitableFunction {
+public interface SinkFunction<M> extends InitableFunction, ClosableFunction {
/**
* Allows sending the provided message to an output {@link org.apache.samza.system.SystemStream} using
thisStreamState = new InternalInMemoryStore<>();
}
+
+ @Override
+ public void close() {
+ // joinFn#close() must only be called once, so we do it in this partial join function's #close.
+ joinFn.close();
+ }
};
PartialJoinFunction<K, OM, M, TM> otherPartialJoinFn = new PartialJoinFunction<K, OM, M, TM>() {
/**
* An internal function that maintains state and join logic for one side of a two-way join.
*/
-public interface PartialJoinFunction<K, M, JM, RM> extends InitableFunction {
+public interface PartialJoinFunction<K, M, JM, RM> extends InitableFunction, ClosableFunction {
/**
* Joins a message in this stream with a message from another stream.
private static final String METRICS_GROUP = OperatorImpl.class.getName();
private boolean initialized;
+ private boolean closed;
private Set<OperatorImpl<RM, ?>> registeredOperators;
private HighResolutionClock highResClock;
private Counter numMessage;
throw new IllegalStateException(String.format("Attempted to initialize Operator %s more than once.", opName));
}
+ if (closed) {
+ throw new IllegalStateException(String.format("Attempted to initialize Operator %s after it was closed.", opName));
+ }
+
this.highResClock = createHighResClock(config);
registeredOperators = new HashSet<>();
MetricsRegistry metricsRegistry = context.getMetricsRegistry();
return Collections.emptyList();
}
+ /**
+ * Closes this operator by delegating to {@link #handleClose()} and then marking the operator as closed.
+ *
+ * @throws IllegalStateException if this operator has already been closed
+ */
+ public void close() {
+ String opName = getOperatorSpec().getOpName();
+
+ if (closed) {
+ throw new IllegalStateException(String.format("Attempted to close Operator %s more than once.", opName));
+ }
+ // The flag is set only after handleClose() returns, so it stays unset if handleClose() throws.
+ handleClose();
+ closed = true;
+ }
+
+ protected abstract void handleClose();
+
/**
* Get the {@link OperatorSpec} for this {@link OperatorImpl}.
*
}
/**
+ * Get all {@link OperatorImpl}s for the graph.
+ *
+ * @return an unmodifiable view of all {@link OperatorImpl}s for the graph
+ */
+ public Collection<OperatorImpl> getAllOperators() {
+ return Collections.unmodifiableCollection(this.operatorImpls.values());
+ }
+
+ /**
* Traverses the DAG of {@link OperatorSpec}s starting from the provided {@link MessageStreamImpl},
* creates the corresponding DAG of {@link OperatorImpl}s, and returns its root {@link RootOperatorImpl} node.
*
}
@Override
+ protected void handleClose() {
+ this.thisPartialJoinFn.close();
+ }
+
+ @Override
protected OperatorSpec<RM> getOperatorSpec() {
return partialJoinOpSpec;
}
return Collections.singletonList(message);
}
+ @Override
+ protected void handleClose() {
+ }
+
// TODO: SAMZA-1221 - Change to InputOperatorSpec that also builds the message
@Override
protected OperatorSpec<M> getOperatorSpec() {
}
@Override
+ protected void handleClose() {
+ this.sinkFn.close();
+ }
+
+ @Override
protected OperatorSpec<M> getOperatorSpec() {
return sinkOpSpec;
}
}
@Override
+ protected void handleClose() {
+ this.transformFn.close();
+ }
+
+ @Override
protected OperatorSpec<RM> getOperatorSpec() {
return streamOpSpec;
}
return windowOpSpec;
}
+ @Override
+ protected void handleClose() {
+ WindowInternal<M, WK, WV> window = windowOpSpec.getWindow();
+
+ if (window.getFoldLeftFunction() != null) {
+ window.getFoldLeftFunction().close();
+ }
+ }
+
/**
* Get the key to be used for lookups in the store for this message.
*/
import org.apache.samza.config.Config;
import org.apache.samza.operators.ContextManager;
import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.impl.OperatorImpl;
import org.apache.samza.operators.impl.OperatorImplGraph;
import org.apache.samza.operators.impl.RootOperatorImpl;
import org.apache.samza.operators.stream.InputStreamInternal;
import org.apache.samza.util.Clock;
import org.apache.samza.util.SystemClock;
+import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
if (this.contextManager != null) {
this.contextManager.close();
}
+
+ Collection<OperatorImpl> allOperators = operatorImplGraph.getAllOperators();
+ allOperators.forEach(OperatorImpl::close);
}
}
@Test
public void join() throws Exception {
- StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
}
@Test
+ public void testJoinFnInitAndClose() throws Exception {
+ TestJoinFunction joinFn = new TestJoinFunction();
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(joinFn));
+ // the join function is expected to have been initialized exactly once by the time the task is created
+ assertEquals(1, joinFn.getNumInitCalls());
+ MessageCollector messageCollector = mock(MessageCollector.class);
+
+ // push messages to first stream
+ numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
+
+ // close should not have been called yet
+ assertEquals(0, joinFn.getNumCloseCalls());
+ sot.close();
+
+ // close should be called exactly once from sot.close()
+ assertEquals(1, joinFn.getNumCloseCalls());
+ }
+
+ @Test
public void joinReverse() throws Exception {
- StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
@Test
public void joinNoMatch() throws Exception {
- StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
@Test
public void joinNoMatchReverse() throws Exception {
- StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
@Test
public void joinRetainsLatestMessageForKey() throws Exception {
- StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
@Test
public void joinRetainsLatestMessageForKeyReverse() throws Exception {
- StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
@Test
public void joinRetainsMatchedMessages() throws Exception {
- StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
@Test
public void joinRetainsMatchedMessagesReverse() throws Exception {
- StreamOperatorTask sot = createStreamOperatorTask(new SystemClock());
+ StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
@Test
public void joinRemovesExpiredMessages() throws Exception {
TestClock testClock = new TestClock();
- StreamOperatorTask sot = createStreamOperatorTask(testClock);
+ StreamOperatorTask sot = createStreamOperatorTask(testClock, new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
@Test
public void joinRemovesExpiredMessagesReverse() throws Exception {
TestClock testClock = new TestClock();
- StreamOperatorTask sot = createStreamOperatorTask(testClock);
+ StreamOperatorTask sot = createStreamOperatorTask(testClock, new TestJoinStreamApplication(new TestJoinFunction()));
List<Integer> output = new ArrayList<>();
MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
assertTrue(output.isEmpty());
}
- private StreamOperatorTask createStreamOperatorTask(Clock clock) throws Exception {
+ private StreamOperatorTask createStreamOperatorTask(Clock clock, StreamApplication app) throws Exception {
ApplicationRunner runner = mock(ApplicationRunner.class);
when(runner.getStreamSpec("instream")).thenReturn(new StreamSpec("instream", "instream", "insystem"));
when(runner.getStreamSpec("instream2")).thenReturn(new StreamSpec("instream2", "instream2", "insystem2"));
Config config = mock(Config.class);
- StreamApplication sgb = new TestStreamApplication();
- StreamOperatorTask sot = new StreamOperatorTask(sgb, runner, clock);
+ StreamOperatorTask sot = new StreamOperatorTask(app, runner, clock);
sot.init(config, taskContext);
return sot;
}
- private static class TestStreamApplication implements StreamApplication {
+ private static class TestJoinStreamApplication implements StreamApplication {
+
+ private final TestJoinFunction joinFn;
+
+ TestJoinStreamApplication(TestJoinFunction joinFn) {
+ this.joinFn = joinFn;
+ }
+
@Override
public void init(StreamGraph graph, Config config) {
MessageStream<FirstStreamIME> inStream =
SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
inStream
- .join(inStream2, new TestJoinFunction(), JOIN_TTL)
+ .join(inStream2, joinFn, JOIN_TTL)
.sink((message, messageCollector, taskCoordinator) -> {
messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
});
}
private static class TestJoinFunction implements JoinFunction<Integer, FirstStreamIME, SecondStreamIME, Integer> {
+
+ private int numInitCalls = 0;
+ private int numCloseCalls = 0;
+
+ @Override
+ public void init(Config config, TaskContext context) {
+ numInitCalls++;
+ }
+
@Override
public Integer apply(FirstStreamIME message, SecondStreamIME otherMessage) {
return (Integer) message.getMessage() + (Integer) otherMessage.getMessage();
public Integer getSecondKey(SecondStreamIME message) {
return (Integer) message.getKey();
}
+
+ @Override
+ public void close() {
+ numCloseCalls++;
+ }
+
+ public int getNumInitCalls() {
+ return numInitCalls;
+ }
+
+ public int getNumCloseCalls() {
+ return numCloseCalls;
+ }
}
private static class FirstStreamIME extends IncomingMessageEnvelope {
}
@Override
+ protected void handleClose() {}
+
+ @Override
protected OperatorSpec<Object> getOperatorSpec() {
return new TestOpSpec();
}
public class TestSinkOperatorImpl {
@Test
- public void testSinkOperator() {
- SinkOperatorSpec<TestOutputMessageEnvelope> sinkOp = mock(SinkOperatorSpec.class);
+ public void testSinkOperatorSinkFunction() {
SinkFunction<TestOutputMessageEnvelope> sinkFn = mock(SinkFunction.class);
- when(sinkOp.getSinkFn()).thenReturn(sinkFn);
- Config mockConfig = mock(Config.class);
- TaskContext mockContext = mock(TaskContext.class);
- SinkOperatorImpl<TestOutputMessageEnvelope> sinkImpl = new SinkOperatorImpl<>(sinkOp, mockConfig, mockContext);
+ SinkOperatorImpl<TestOutputMessageEnvelope> sinkImpl = createSinkOperator(sinkFn);
TestOutputMessageEnvelope mockMsg = mock(TestOutputMessageEnvelope.class);
MessageCollector mockCollector = mock(MessageCollector.class);
TaskCoordinator mockCoordinator = mock(TaskCoordinator.class);
sinkImpl.handleMessage(mockMsg, mockCollector, mockCoordinator);
verify(sinkFn, times(1)).apply(mockMsg, mockCollector, mockCoordinator);
}
+
+ @Test
+ public void testSinkOperatorClose() {
+ TestOutputMessageEnvelope mockMsg = mock(TestOutputMessageEnvelope.class);
+ MessageCollector mockCollector = mock(MessageCollector.class);
+ TaskCoordinator mockCoordinator = mock(TaskCoordinator.class);
+ SinkFunction<TestOutputMessageEnvelope> sinkFn = mock(SinkFunction.class);
+
+ SinkOperatorImpl<TestOutputMessageEnvelope> sinkImpl = createSinkOperator(sinkFn);
+ sinkImpl.handleMessage(mockMsg, mockCollector, mockCoordinator);
+ verify(sinkFn, times(1)).apply(mockMsg, mockCollector, mockCoordinator);
+
+ // ensure that close is not called yet
+ verify(sinkFn, times(0)).close();
+
+ sinkImpl.handleClose();
+ // ensure that close is called once from handleClose()
+ verify(sinkFn, times(1)).close();
+ }
+
+ /**
+ * Creates a {@link SinkOperatorImpl} backed by a mocked {@link SinkOperatorSpec} that returns the given sink function.
+ *
+ * @param sinkFn the sink function the mocked operator spec should return
+ * @return a sink operator built from the mocked spec, config and task context
+ */
+ private SinkOperatorImpl<TestOutputMessageEnvelope> createSinkOperator(
+ SinkFunction<TestOutputMessageEnvelope> sinkFn) {
+ SinkOperatorSpec<TestOutputMessageEnvelope> sinkOp = mock(SinkOperatorSpec.class);
+ when(sinkOp.getSinkFn()).thenReturn(sinkFn);
+
+ Config mockConfig = mock(Config.class);
+ TaskContext mockContext = mock(TaskContext.class);
+ return new SinkOperatorImpl<>(sinkOp, mockConfig, mockContext);
+ }
}
verify(txfmFn, times(1)).apply(inMsg);
assertEquals(results, mockOutputs);
}
+
+ @Test
+ public void testSimpleOperatorClose() {
+ StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> mockOp = mock(StreamOperatorSpec.class);
+ FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> txfmFn = mock(FlatMapFunction.class);
+ when(mockOp.getTransformFn()).thenReturn(txfmFn);
+ Config mockConfig = mock(Config.class);
+ TaskContext mockContext = mock(TaskContext.class);
+
+ StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl =
+ spy(new StreamOperatorImpl<>(mockOp, mockConfig, mockContext));
+
+ // ensure that close is not called yet
+ verify(txfmFn, times(0)).close();
+ opImpl.handleClose();
+ // ensure that close is called once inside handleClose()
+ verify(txfmFn, times(1)).close();
+ }
}