SAMZA-1321: Propagate end-of-stream and watermark messages
authorXinyu Liu <xiliu@xiliu-ld.linkedin.biz>
Thu, 29 Jun 2017 00:16:10 +0000 (17:16 -0700)
committerXinyu Liu <xiliu@xiliu-ld.linkedin.biz>
Thu, 29 Jun 2017 00:16:10 +0000 (17:16 -0700)
The patch completes the end-of-stream work flow across multi-stage pipeline. It also contains initial commit for supporting watermarks. For watermark, there are issues raised in the review feedback and will be addressed by further prs. The main logic this patch adds:

- EndOfStreamManager aggregates the end-of-stream control messages and propagates the result to downstream intermediate topics based on the topology of the IO in the StreamGraph.

- WatermarkManager aggregates the watermark control messages from the upstream tasks, passes them through the operators, and propagates them downstream.

In operator impl, I implemented similar watermark logic as Beam for watermark propagation:
* InputWatermark(op) = min { OutputWatermark(op') | op' is upstream of op }
* OutputWatermark(op) = min { InputWatermark(op), OldestWorkTime(op) }

Add quite a few unit tests and integration test. The code is 100% covered as reported by Intellij. Both control messages work as expected.

Author: Xinyu Liu <xiliu@xiliu-ld.linkedin.biz>

Reviewers: Yi Pan <nickpan47@gmail.com>

Closes #236 from xinyuiscool/SAMZA-1321

44 files changed:
samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
samza-api/src/main/java/org/apache/samza/system/StreamSpec.java
samza-core/src/main/java/org/apache/samza/control/ControlMessageListenerTask.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/control/ControlMessageUtils.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/control/EndOfStreamManager.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/control/IOGraph.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/control/Watermark.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/control/WatermarkManager.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/message/IntermediateMessageType.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OutputOperatorImpl.java
samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
samza-core/src/main/java/org/apache/samza/runtime/LocalApplicationRunner.java
samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java
samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java
samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java
samza-core/src/main/java/org/apache/samza/task/AsyncStreamTaskAdapter.java
samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
samza-core/src/main/scala/org/apache/samza/system/StreamMetadataCache.scala
samza-core/src/test/java/org/apache/samza/control/TestControlMessageUtils.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/control/TestEndOfStreamManager.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/control/TestIOGraph.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/control/TestWatermarkManager.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
samza-core/src/test/java/org/apache/samza/processor/TestStreamProcessor.java
samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java
samza-core/src/test/java/org/apache/samza/serializers/model/serializers/TestIntermediateMessageSerde.java
samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
samza-core/src/test/java/org/apache/samza/task/TestStreamOperatorTask.java [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
samza-hdfs/src/main/java/org/apache/samza/system/hdfs/HdfsSystemConsumer.java
samza-test/src/test/java/org/apache/samza/processor/TestStreamProcessorUtil.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/controlmessages/EndOfStreamIntegrationTest.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/util/ArraySystemConsumer.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/util/ArraySystemFactory.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/util/Base64Serializer.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/util/SimpleSystemAdmin.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/util/TestStreamConsumer.java [new file with mode: 0644]

index 0ced773..9182522 100644 (file)
@@ -91,12 +91,12 @@ public class IncomingMessageEnvelope {
   }
 
   /**
-   * Builds an end-of-stream envelope for an SSP. This is used by a {@link SystemConsumer} implementation to
-   * indicate that it is at end-of-stream. The end-of-stream message should not delivered to the task implementation.
+   * This method is deprecated in favor of EndOfStreamManager.buildEndOfStreamEnvelope(SystemStreamPartition ssp).
    *
    * @param ssp The SSP that is at end-of-stream.
    * @return an IncomingMessageEnvelope corresponding to end-of-stream for that SSP.
    */
+  @Deprecated
   public static IncomingMessageEnvelope buildEndOfStreamEnvelope(SystemStreamPartition ssp) {
     return new IncomingMessageEnvelope(ssp, END_OF_STREAM_OFFSET, null, null);
   }
index 0cdeb95..49531dd 100644 (file)
@@ -196,6 +196,10 @@ public class StreamSpec {
     return config.getOrDefault(propertyName, defaultValue);
   }
 
+  /**
+   * Returns the {@link SystemStream} corresponding to this spec, built from its
+   * system name and physical name.
+   * @return the {@link SystemStream} for this spec
+   */
+  public SystemStream toSystemStream() {
+    return new SystemStream(systemName, physicalName);
+  }
+
   private void validateLogicalIdentifier(String identifierName, String identifierValue) {
     if (identifierValue == null || !identifierValue.matches("[A-Za-z0-9_-]+")) {
       throw new IllegalArgumentException(String.format("Identifier '%s' is '%s'. It must match the expression [A-Za-z0-9_-]+", identifierName, identifierValue));
diff --git a/samza-core/src/main/java/org/apache/samza/control/ControlMessageListenerTask.java b/samza-core/src/main/java/org/apache/samza/control/ControlMessageListenerTask.java
new file mode 100644 (file)
index 0000000..9e4b40a
--- /dev/null
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.TaskCoordinator;
+
+
+/**
+ * The listener interface for the aggregation result of control messages.
+ * Any task that handles control messages such as {@link org.apache.samza.message.EndOfStreamMessage}
+ * and {@link org.apache.samza.message.WatermarkMessage} needs to implement this interface.
+ */
+public interface ControlMessageListenerTask {
+
+  /**
+   * Returns the topology of the streams. Any control message listener needs to
+   * provide this topology so Samza can propagate the control messages to the downstream streams.
+   * @return {@link IOGraph} describing the mapping from input streams to output streams
+   */
+  IOGraph getIOGraph();
+
+  /**
+   * Invoked when a Watermark arrives.
+   * @param watermark contains the watermark timestamp
+   * @param systemStream source stream that emits the watermark
+   * @param collector message collector, can be used to propagate control messages downstream
+   * @param coordinator task coordinator
+   */
+  void onWatermark(Watermark watermark, SystemStream systemStream, MessageCollector collector, TaskCoordinator coordinator);
+}
diff --git a/samza-core/src/main/java/org/apache/samza/control/ControlMessageUtils.java b/samza-core/src/main/java/org/apache/samza/control/ControlMessageUtils.java
new file mode 100644 (file)
index 0000000..ebb0d86
--- /dev/null
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import com.google.common.collect.Multimap;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.samza.message.ControlMessage;
+import org.apache.samza.system.OutgoingMessageEnvelope;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.task.MessageCollector;
+
+
+/**
+ * This class privates static utils for handling control messages
+ */
+class ControlMessageUtils {
+
+  /**
+   * Send a control message to every partition of the {@link SystemStream}
+   * @param message control message
+   * @param systemStream the stream to sent
+   * @param metadataCache stream metadata cache
+   * @param collector collector to send the message
+   */
+  static void sendControlMessage(ControlMessage message,
+      SystemStream systemStream,
+      StreamMetadataCache metadataCache,
+      MessageCollector collector) {
+    SystemStreamMetadata metadata = metadataCache.getSystemStreamMetadata(systemStream, true);
+    int partitionCount = metadata.getSystemStreamPartitionMetadata().size();
+    for (int i = 0; i < partitionCount; i++) {
+      OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, i, "", message);
+      collector.send(envelopeOut);
+    }
+  }
+
+  /**
+   * Calculate the mapping from an output stream to the number of upstream tasks that will produce to the output stream
+   * @param inputToTasks input stream to its consumer tasks mapping
+   * @param ioGraph topology of the stream inputs and outputs
+   * @return mapping from output to upstream task count
+   */
+  static Map<SystemStream, Integer> calculateUpstreamTaskCounts(Multimap<SystemStream, String> inputToTasks,
+      IOGraph ioGraph) {
+    if (ioGraph == null) {
+      return Collections.EMPTY_MAP;
+    }
+    Map<SystemStream, Integer> outputTaskCount = new HashMap<>();
+    ioGraph.getNodes().forEach(node -> {
+        // for each input stream, find out the tasks that are consuming this input stream using the inputToTasks mapping,
+        // then count the number of tasks
+        int count = node.getInputs().stream().flatMap(spec -> inputToTasks.get(spec.toSystemStream()).stream())
+            .collect(Collectors.toSet()).size();
+        // put the count of input tasks to the result
+        outputTaskCount.put(node.getOutput().toSystemStream(), count);
+      });
+    return outputTaskCount;
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/control/EndOfStreamManager.java b/samza-core/src/main/java/org/apache/samza/control/EndOfStreamManager.java
new file mode 100644 (file)
index 0000000..78a8741
--- /dev/null
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import com.google.common.collect.Multimap;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.message.EndOfStreamMessage;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.TaskCoordinator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * This class handles the end-of-stream control message. It aggregates the end-of-stream state for each input ssps of
+ * a task, and propagate the eos messages to downstream intermediate streams if needed.
+ *
+ * Internal use only.
+ */
+public class EndOfStreamManager {
+  private static final Logger log = LoggerFactory.getLogger(EndOfStreamManager.class);
+
+  private final String taskName;
+  private final MessageCollector collector;
+  // end-of-stream state per ssp
+  private final Map<SystemStreamPartition, EndOfStreamState> eosStates;
+  private final StreamMetadataCache metadataCache;
+  // topology information. Set during init()
+  private final ControlMessageListenerTask listener;
+  // mapping from output stream to its upstream task count
+  private final Map<SystemStream, Integer> upstreamTaskCounts;
+
+  public EndOfStreamManager(String taskName,
+      ControlMessageListenerTask listener,
+      Multimap<SystemStream, String> inputToTasks,
+      Set<SystemStreamPartition> ssps,
+      StreamMetadataCache metadataCache,
+      MessageCollector collector) {
+    this.taskName = taskName;
+    this.listener = listener;
+    this.metadataCache = metadataCache;
+    this.collector = collector;
+    Map<SystemStreamPartition, EndOfStreamState> states = new HashMap<>();
+    ssps.forEach(ssp -> {
+        states.put(ssp, new EndOfStreamState());
+      });
+    this.eosStates = Collections.unmodifiableMap(states);
+    this.upstreamTaskCounts = ControlMessageUtils.calculateUpstreamTaskCounts(inputToTasks, listener.getIOGraph());
+  }
+
+  public void update(IncomingMessageEnvelope envelope, TaskCoordinator coordinator) {
+    EndOfStreamState state = eosStates.get(envelope.getSystemStreamPartition());
+    EndOfStreamMessage message = (EndOfStreamMessage) envelope.getMessage();
+    state.update(message.getTaskName(), message.getTaskCount());
+    log.info("Received end-of-stream from task " + message.getTaskName() + " in " + envelope.getSystemStreamPartition());
+
+    SystemStream systemStream = envelope.getSystemStreamPartition().getSystemStream();
+    if (isEndOfStream(systemStream)) {
+      log.info("End-of-stream of input " + systemStream + " for " + systemStream);
+      listener.getIOGraph().getNodesOfInput(systemStream).forEach(node -> {
+          // find the intermediate streams that need broadcast the eos messages
+          if (node.isOutputIntermediate()) {
+            // check all the input stream partitions assigned to the task are end-of-stream
+            boolean inputsEndOfStream = node.getInputs().stream().allMatch(spec -> isEndOfStream(spec.toSystemStream()));
+            if (inputsEndOfStream) {
+              // broadcast the end-of-stream message to the intermediate stream
+              SystemStream outputStream = node.getOutput().toSystemStream();
+              sendEndOfStream(outputStream, upstreamTaskCounts.get(outputStream));
+            }
+          }
+        });
+
+      boolean allEndOfStream = eosStates.values().stream().allMatch(EndOfStreamState::isEndOfStream);
+      if (allEndOfStream) {
+        // all inputs have been end-of-stream, shut down the task
+        log.info("All input streams have reached the end for task " + taskName);
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+      }
+    }
+  }
+
+  /**
+   * Return true if all partitions of the systemStream that are assigned to the current task have reached EndOfStream.
+   * @param systemStream stream
+   * @return whether the stream reaches to the end for this task
+   */
+  boolean isEndOfStream(SystemStream systemStream) {
+    return eosStates.entrySet().stream()
+        .filter(entry -> entry.getKey().getSystemStream().equals(systemStream))
+        .allMatch(entry -> entry.getValue().isEndOfStream());
+  }
+
+  /**
+   * Send the EndOfStream control messages to downstream
+   * @param systemStream downstream stream
+   */
+  void sendEndOfStream(SystemStream systemStream, int taskCount) {
+    log.info("Send end-of-stream messages with upstream task count {} to all partitions of {}", taskCount, systemStream);
+    final EndOfStreamMessage message = new EndOfStreamMessage(taskName, taskCount);
+    ControlMessageUtils.sendControlMessage(message, systemStream, metadataCache, collector);
+  }
+
+  /**
+   * This class keeps the internal state for a ssp to be end-of-stream.
+   */
+  final static class EndOfStreamState {
+    // set of upstream tasks
+    private final Set<String> tasks = new HashSet<>();
+    private int expectedTotal = Integer.MAX_VALUE;
+    private boolean isEndOfStream = false;
+
+    void update(String taskName, int taskCount) {
+      if (taskName != null) {
+        tasks.add(taskName);
+      }
+      expectedTotal = taskCount;
+      isEndOfStream = tasks.size() == expectedTotal;
+    }
+
+    boolean isEndOfStream() {
+      return isEndOfStream;
+    }
+  }
+
+  /**
+   * Build an end-of-stream envelope for an ssp of a source input.
+   *
+   * @param ssp The SSP that is at end-of-stream.
+   * @return an IncomingMessageEnvelope corresponding to end-of-stream for that SSP.
+   */
+  public static IncomingMessageEnvelope buildEndOfStreamEnvelope(SystemStreamPartition ssp) {
+    return new IncomingMessageEnvelope(ssp, IncomingMessageEnvelope.END_OF_STREAM_OFFSET, null, new EndOfStreamMessage(null, 0));
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/control/IOGraph.java b/samza-core/src/main/java/org/apache/samza/control/IOGraph.java
new file mode 100644 (file)
index 0000000..a30c13d
--- /dev/null
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.operators.spec.OutputOperatorSpec;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemStream;
+
+
+/**
+ * This class provides the topology of stream inputs to outputs.
+ */
+public class IOGraph {
+
+  /**
+   * A node of the graph: the set of input streams that reach a single output stream.
+   */
+  public static final class IONode {
+    private final Set<StreamSpec> inputs = new HashSet<>();
+    private final StreamSpec output;
+    // true when the output operator is a PARTITION_BY, i.e. the output is an intermediate stream
+    private final boolean isOutputIntermediate;
+
+    IONode(StreamSpec output, boolean isOutputIntermediate) {
+      this.output = output;
+      this.isOutputIntermediate = isOutputIntermediate;
+    }
+
+    void addInput(StreamSpec input) {
+      inputs.add(input);
+    }
+
+    /** Returns an unmodifiable view of the input streams of this node. */
+    public Set<StreamSpec> getInputs() {
+      return Collections.unmodifiableSet(inputs);
+    }
+
+    public StreamSpec getOutput() {
+      return output;
+    }
+
+    public boolean isOutputIntermediate() {
+      return isOutputIntermediate;
+    }
+  }
+
+  // all IO nodes of the graph (unmodifiable)
+  final Collection<IONode> nodes;
+  // mapping from an input SystemStream to the nodes it feeds
+  final Multimap<SystemStream, IONode> inputToNodes;
+
+  public IOGraph(Collection<IONode> nodes) {
+    this.nodes = Collections.unmodifiableCollection(nodes);
+    this.inputToNodes = HashMultimap.create();
+    nodes.forEach(node -> {
+        node.getInputs().forEach(stream -> {
+            inputToNodes.put(new SystemStream(stream.getSystemName(), stream.getPhysicalName()), node);
+          });
+      });
+  }
+
+  public Collection<IONode> getNodes() {
+    return this.nodes;
+  }
+
+  /**
+   * Returns the nodes fed by the given input stream.
+   * @param input an input {@link SystemStream}
+   * @return collection of {@link IONode} consuming the input
+   */
+  public Collection<IONode> getNodesOfInput(SystemStream input) {
+    return inputToNodes.get(input);
+  }
+
+  /**
+   * Builds the {@link IOGraph} of a {@link StreamGraphImpl} by walking the operator specs
+   * reachable from each input operator.
+   * @param streamGraph the stream graph of the application
+   * @return the IOGraph of input to output streams
+   */
+  public static IOGraph buildIOGraph(StreamGraphImpl streamGraph) {
+    Map<Integer, IONode> nodes = new HashMap<>();
+    streamGraph.getInputOperators().entrySet().stream()
+        .forEach(entry -> buildIONodes(entry.getKey(), entry.getValue(), nodes));
+    return new IOGraph(nodes.values());
+  }
+
+  /* package private */
+  // Recursively traverses the operator spec graph starting from the given input. For every
+  // OutputOperatorSpec reached, creates (or reuses, keyed by op id) an IONode for its output
+  // stream and registers the input with it. The output is marked intermediate when the
+  // operator's op code is PARTITION_BY.
+  static void buildIONodes(StreamSpec input, OperatorSpec opSpec, Map<Integer, IONode> ioGraph) {
+    if (opSpec instanceof OutputOperatorSpec) {
+      OutputOperatorSpec outputOpSpec = (OutputOperatorSpec) opSpec;
+      IONode node = ioGraph.get(opSpec.getOpId());
+      if (node == null) {
+        StreamSpec output = outputOpSpec.getOutputStream().getStreamSpec();
+        node = new IONode(output, outputOpSpec.getOpCode() == OperatorSpec.OpCode.PARTITION_BY);
+        ioGraph.put(opSpec.getOpId(), node);
+      }
+      node.addInput(input);
+    }
+
+    Collection<OperatorSpec> nextOperators = opSpec.getRegisteredOperatorSpecs();
+    nextOperators.forEach(spec -> buildIONodes(input, spec, ioGraph));
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/control/Watermark.java b/samza-core/src/main/java/org/apache/samza/control/Watermark.java
new file mode 100644 (file)
index 0000000..a11e3b0
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import org.apache.samza.annotation.InterfaceStability;
+import org.apache.samza.system.SystemStream;
+
+
+/**
+ * A watermark is a monotonically increasing value, which represents the point up to which the
+ * system believes it has received all of the data before the watermark timestamp. Data that arrives
+ * with a timestamp that is before the watermark is considered late.
+ *
+ * <p>This is the aggregate result from the WatermarkManager, which keeps track of the control message
+ * {@link org.apache.samza.message.WatermarkMessage} and aggregates by returning the min of all watermark
+ * timestamps in each partition.
+ */
+@InterfaceStability.Unstable
+public interface Watermark {
+  /**
+   * Returns the timestamp of the watermark.
+   * Note that if the task consumes more than one partition of this stream, the watermark emitted is the min of
+   * watermarks across all partitions.
+   * @return timestamp
+   */
+  long getTimestamp();
+
+  /**
+   * Propagates the watermark to an intermediate stream.
+   * @param systemStream intermediate stream
+   */
+  void propagate(SystemStream systemStream);
+
+  /**
+   * Creates a copy of this watermark with the given timestamp.
+   * @param timestamp new timestamp
+   * @return new watermark
+   */
+  Watermark copyWithTimestamp(long timestamp);
+}
diff --git a/samza-core/src/main/java/org/apache/samza/control/WatermarkManager.java b/samza-core/src/main/java/org/apache/samza/control/WatermarkManager.java
new file mode 100644 (file)
index 0000000..c4fdd88
--- /dev/null
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import com.google.common.collect.Multimap;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import org.apache.samza.message.WatermarkMessage;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * This class manages watermarks. It aggregates the watermark control messages from the upstage tasks
+ * for each SSP into an envelope of {@link Watermark}, and provide a dispatcher to propagate it to downstream.
+ *
+ * Internal use only.
+ */
+public class WatermarkManager {
+  private static final Logger log = LoggerFactory.getLogger(WatermarkManager.class);
+  public static final long TIME_NOT_EXIST = -1;
+
+  private final String taskName;
+  private final Map<SystemStreamPartition, WatermarkState> watermarkStates;
+  private final Map<SystemStream, Long> watermarkPerStream;
+  private final StreamMetadataCache metadataCache;
+  private final MessageCollector collector;
+  // mapping from output stream to its upstream task count
+  private final Map<SystemStream, Integer> upstreamTaskCounts;
+
+  public WatermarkManager(String taskName,
+      ControlMessageListenerTask listener,
+      Multimap<SystemStream, String> inputToTasks,
+      Set<SystemStreamPartition> ssps,
+      StreamMetadataCache metadataCache,
+      MessageCollector collector) {
+    this.taskName = taskName;
+    this.watermarkPerStream = new HashMap<>();
+    this.metadataCache = metadataCache;
+    this.collector = collector;
+    this.upstreamTaskCounts = ControlMessageUtils.calculateUpstreamTaskCounts(inputToTasks, listener.getIOGraph());
+
+    Map<SystemStreamPartition, WatermarkState> states = new HashMap<>();
+    ssps.forEach(ssp -> {
+        states.put(ssp, new WatermarkState());
+        watermarkPerStream.put(ssp.getSystemStream(), TIME_NOT_EXIST);
+      });
+    this.watermarkStates = Collections.unmodifiableMap(states);
+  }
+
+  /**
+   * Update the watermark based on the incoming watermark message. The message contains
+   * a timestamp and the upstream producer task. The aggregation result is the minimal value
+   * of all watermarks for the stream:
+   * <ul>
+   *   <li>Watermark(ssp) = min { Watermark(task) | task is upstream producer and the count equals total expected tasks } </li>
+   *   <li>Watermark(stream) = min { Watermark(ssp) | ssp is a partition of stream that assigns to this task } </li>
+   * </ul>
+   *
+   * @param envelope the envelope contains {@link WatermarkMessage}
+   * @return watermark envelope if there is a new aggregate watermark for the stream
+   */
+  public Watermark update(IncomingMessageEnvelope envelope) {
+    SystemStreamPartition ssp = envelope.getSystemStreamPartition();
+    WatermarkState state = watermarkStates.get(ssp);
+    WatermarkMessage message = (WatermarkMessage) envelope.getMessage();
+    state.update(message.getTimestamp(), message.getTaskName(), message.getTaskCount());
+
+    if (state.getWatermarkTime() != TIME_NOT_EXIST) {
+      long minTimestamp = watermarkStates.entrySet().stream()
+          .filter(entry -> entry.getKey().getSystemStream().equals(ssp.getSystemStream()))
+          .map(entry -> entry.getValue().getWatermarkTime())
+          .min(Long::compare)
+          .get();
+      Long curWatermark = watermarkPerStream.get(ssp.getSystemStream());
+      if (curWatermark == null || curWatermark < minTimestamp) {
+        watermarkPerStream.put(ssp.getSystemStream(), minTimestamp);
+        return new WatermarkImpl(minTimestamp);
+      }
+    }
+
+    return null;
+  }
+
+  /* package private */
+  long getWatermarkTime(SystemStreamPartition ssp) {
+    return watermarkStates.get(ssp).getWatermarkTime();
+  }
+
+  /**
+   * Send the watermark message to all partitions of an intermediate stream
+   * @param timestamp watermark timestamp
+   * @param systemStream intermediate stream
+   */
+  void sendWatermark(long timestamp, SystemStream systemStream, int taskCount) {
+    log.info("Send end-of-stream messages to all partitions of " + systemStream);
+    final WatermarkMessage watermarkMessage = new WatermarkMessage(timestamp, taskName, taskCount);
+    ControlMessageUtils.sendControlMessage(watermarkMessage, systemStream, metadataCache, collector);
+  }
+
+  /**
+   * Per ssp state of the watermarks. This class keeps track of the latest watermark timestamp
+   * from each upstream producer tasks, and use the min to update the aggregated watermark time.
+   */
+  final static class WatermarkState {
+    private int expectedTotal = Integer.MAX_VALUE;
+    private final Map<String, Long> timestamps = new HashMap<>();
+    private long watermarkTime = TIME_NOT_EXIST;
+
+    void update(long timestamp, String taskName, int taskCount) {
+      if (taskName != null) {
+        timestamps.put(taskName, timestamp);
+      }
+      expectedTotal = taskCount;
+
+      if (timestamps.size() == expectedTotal) {
+        Optional<Long> min = timestamps.values().stream().min(Long::compare);
+        watermarkTime = min.orElse(timestamp);
+      }
+    }
+
+    long getWatermarkTime() {
+      return watermarkTime;
+    }
+  }
+
+  /**
+   * Implementation of the Watermark. It keeps a reference to the {@link WatermarkManager}
+   */
+  class WatermarkImpl implements Watermark {
+    private final long timestamp;
+
+    WatermarkImpl(long timestamp) {
+      this.timestamp = timestamp;
+    }
+
+    @Override
+    public long getTimestamp() {
+      return timestamp;
+    }
+
+    @Override
+    public void propagate(SystemStream systemStream) {
+      sendWatermark(timestamp, systemStream, upstreamTaskCounts.get(systemStream));
+    }
+
+    @Override
+    public Watermark copyWithTimestamp(long time) {
+      return new WatermarkImpl(time);
+    }
+  }
+
+  /**
+   * Build a watermark control message envelope for an ssp of a source input.
+   * @param timestamp watermark time
+   * @param ssp {@link SystemStreamPartition} where the watermark coming from.
+   * @return envelope of the watermark control message
+   */
+  public static IncomingMessageEnvelope buildWatermarkEnvelope(long timestamp, SystemStreamPartition ssp) {
+    return new IncomingMessageEnvelope(ssp, null, "", new WatermarkMessage(timestamp, null, 0));
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/message/IntermediateMessageType.java b/samza-core/src/main/java/org/apache/samza/message/IntermediateMessageType.java
new file mode 100644 (file)
index 0000000..25fbb14
--- /dev/null
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.message;
+
+/**
+ * The type of the intermediate stream message. The enum will be encoded using its ordinal value and
+ * put in the first byte of the serialization of intermediate message, so new values must only be appended.
+ * For more details, see {@link org.apache.samza.serializers.IntermediateMessageSerde}
+ */
+public enum IntermediateMessageType {
+  USER_MESSAGE, // ordinal 0: regular user-payload message
+  WATERMARK_MESSAGE, // ordinal 1: WatermarkMessage control message
+  END_OF_STREAM_MESSAGE; // ordinal 2: EndOfStreamMessage control message
+
+  /**
+   * Returns the {@link IntermediateMessageType} of a particular intermediate stream message.
+   * @param message an intermediate stream message
+   * @return type of the message
+   */
+  public static IntermediateMessageType of(Object message) {
+    if (message instanceof WatermarkMessage) {
+      return WATERMARK_MESSAGE;
+    } else if (message instanceof EndOfStreamMessage) {
+      return END_OF_STREAM_MESSAGE;
+    } else {
+      return USER_MESSAGE; // any non-control object defaults to a user message
+    }
+  }
+}
index c0da1b2..8718c06 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.OutputStreamImpl;
 import org.apache.samza.operators.stream.IntermediateMessageStreamImpl;
 import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.control.IOGraph;
 import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.system.StreamSpec;
 
@@ -205,4 +206,8 @@ public class StreamGraphImpl implements StreamGraph {
 
     return windowOrJoinSpecs.size() != 0;
   }
+
+  public IOGraph toIOGraph() {
+    return IOGraph.buildIOGraph(this);
+  }
 }
index 73bb83d..74e6748 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MetricsConfig;
+import org.apache.samza.control.Watermark;
+import org.apache.samza.control.WatermarkManager;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.Timer;
@@ -29,11 +35,6 @@ import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.util.HighResolutionClock;
 
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Set;
-
 
 /**
  * Abstract base class for all stream operator implementations.
@@ -47,8 +48,11 @@ public abstract class OperatorImpl<M, RM> {
   private Counter numMessage;
   private Timer handleMessageNs;
   private Timer handleTimerNs;
+  private long inputWatermarkTime = WatermarkManager.TIME_NOT_EXIST;
+  private long outputWatermarkTime = WatermarkManager.TIME_NOT_EXIST;
 
   Set<OperatorImpl<RM, ?>> registeredOperators;
+  Set<OperatorImpl<?, M>> prevOperators;
 
   /**
    * Initialize this {@link OperatorImpl} and its user-defined functions.
@@ -69,6 +73,7 @@ public abstract class OperatorImpl<M, RM> {
 
     this.highResClock = createHighResClock(config);
     registeredOperators = new HashSet<>();
+    prevOperators = new HashSet<>();
     MetricsRegistry metricsRegistry = context.getMetricsRegistry();
     this.numMessage = metricsRegistry.newCounter(METRICS_GROUP, opName + "-messages");
     this.handleMessageNs = metricsRegistry.newTimer(METRICS_GROUP, opName + "-handle-message-ns");
@@ -99,6 +104,11 @@ public abstract class OperatorImpl<M, RM> {
               getOperatorName()));
     }
     this.registeredOperators.add(nextOperator);
+    nextOperator.registerPrevOperator(this);
+  }
+
+  void registerPrevOperator(OperatorImpl<?, M> prevOperator) {
+    this.prevOperators.add(prevOperator);
   }
 
   /**
@@ -117,9 +127,7 @@ public abstract class OperatorImpl<M, RM> {
     long endNs = this.highResClock.nanoTime();
     this.handleMessageNs.update(endNs - startNs);
 
-    results.forEach(rm ->
-        this.registeredOperators.forEach(op ->
-            op.onMessage(rm, collector, coordinator)));
+    results.forEach(rm -> this.registeredOperators.forEach(op -> op.onMessage(rm, collector, coordinator)));
   }
 
   /**
@@ -147,9 +155,7 @@ public abstract class OperatorImpl<M, RM> {
     long endNs = this.highResClock.nanoTime();
     this.handleTimerNs.update(endNs - startNs);
 
-    results.forEach(rm ->
-        this.registeredOperators.forEach(op ->
-            op.onMessage(rm, collector, coordinator)));
+    results.forEach(rm -> this.registeredOperators.forEach(op -> op.onMessage(rm, collector, coordinator)));
     this.registeredOperators.forEach(op ->
         op.onTimer(collector, coordinator));
   }
@@ -167,6 +173,71 @@ public abstract class OperatorImpl<M, RM> {
     return Collections.emptyList();
   }
 
+  /**
+   * Populate the watermarks based on the following equations:
+   *
+   * <ul>
+   *   <li>InputWatermark(op) = min { OutputWatermark(op') | op' is upstream of op}</li>
+   *   <li>OutputWatermark(op) = min { InputWatermark(op), OldestWorkTime(op) }</li>
+   * </ul>
+   *
+   * @param watermark incoming watermark
+   * @param collector message collector
+   * @param coordinator task coordinator
+   */
+  public final void onWatermark(Watermark watermark,
+      MessageCollector collector,
+      TaskCoordinator coordinator) {
+    final long inputWatermarkMin;
+    if (prevOperators.isEmpty()) {
+      // for input operator, use the watermark time coming from the source input
+      inputWatermarkMin = watermark.getTimestamp();
+    } else {
+      // InputWatermark(op) = min { OutputWatermark(op') | op' is upstream of op}
+      inputWatermarkMin = prevOperators.stream().map(op -> op.getOutputWatermarkTime()).min(Long::compare).get(); // get() is safe: prevOperators is non-empty on this branch
+    }
+
+    if (inputWatermarkTime < inputWatermarkMin) {
+      // advance the watermark time of this operator
+      inputWatermarkTime = inputWatermarkMin;
+      Watermark inputWatermark = watermark.copyWithTimestamp(inputWatermarkTime);
+      long oldestWorkTime = handleWatermark(inputWatermark, collector, coordinator);
+
+      // OutputWatermark(op) = min { InputWatermark(op), OldestWorkTime(op) }
+      long outputWatermarkMin = Math.min(inputWatermarkTime, oldestWorkTime);
+      if (outputWatermarkTime < outputWatermarkMin) {
+        // propagate the advanced watermark to downstream operators
+        outputWatermarkTime = outputWatermarkMin;
+        Watermark outputWatermark = watermark.copyWithTimestamp(outputWatermarkTime);
+        this.registeredOperators.forEach(op -> op.onWatermark(outputWatermark, collector, coordinator));
+      }
+    }
+  }
+
+  /**
+   * Returns the oldest event time of the envelopes that haven't been processed by this operator.
+   * This default implementation of watermark handling simply returns the input watermark time.
+   * @param inputWatermark input watermark
+   * @param collector message collector
+   * @param coordinator task coordinator
+   * @return time of the oldest envelope still being processed
+   */
+  protected long handleWatermark(Watermark inputWatermark,
+      MessageCollector collector,
+      TaskCoordinator coordinator) {
+    return inputWatermark.getTimestamp();
+  }
+
+  /* package private */
+  long getInputWatermarkTime() {
+    return this.inputWatermarkTime;
+  }
+
+  /* package private */
+  long getOutputWatermarkTime() {
+    return this.outputWatermarkTime;
+  }
+
   public void close() {
     if (closed) {
       throw new IllegalStateException(
index fe59b74..f212b3e 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
+import java.util.Collection;
+import java.util.Collections;
 import org.apache.samza.config.Config;
+import org.apache.samza.control.Watermark;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.OutputOperatorSpec;
 import org.apache.samza.operators.spec.OutputStreamImpl;
@@ -28,9 +31,6 @@ import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
-import java.util.Collection;
-import java.util.Collections;
-
 
 /**
  * An operator that sends incoming messages to an output {@link SystemStream}.
@@ -69,4 +69,14 @@ class OutputOperatorImpl<M> extends OperatorImpl<M, Void> {
   protected OperatorSpec<M, Void> getOperatorSpec() {
     return outputOpSpec;
   }
+
+  @Override
+  protected long handleWatermark(Watermark inputWatermark,
+      MessageCollector collector,
+      TaskCoordinator coordinator) {
+    if (outputOpSpec.getOpCode() == OperatorSpec.OpCode.PARTITION_BY) { // propagate only into intermediate (partitionBy) streams — confirm other output opcodes never need it
+      inputWatermark.propagate(outputStream.getStreamSpec().toSystemStream());
+    }
+    return inputWatermark.getTimestamp();
+  }
 }
index 14a14a8..653c0bb 100644 (file)
@@ -36,7 +36,6 @@ import org.apache.samza.container.SamzaContainerListener;
 import org.apache.samza.coordinator.JobCoordinator;
 import org.apache.samza.coordinator.JobCoordinatorFactory;
 import org.apache.samza.coordinator.JobCoordinatorListener;
-import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.metrics.MetricsReporter;
 import org.apache.samza.task.AsyncStreamTaskFactory;
@@ -187,11 +186,11 @@ public class StreamProcessor {
 
   }
 
-  SamzaContainer createSamzaContainer(ContainerModel containerModel, int maxChangelogStreamPartitions) {
+  SamzaContainer createSamzaContainer(String processorId, JobModel jobModel) {
     return SamzaContainer.apply(
-        containerModel,
+        processorId,
+        jobModel,
         config,
-        maxChangelogStreamPartitions,
         Util.<String, MetricsReporter>javaMapAsScalaMap(customMetricsReporter),
         taskFactory);
   }
@@ -283,9 +282,7 @@ public class StreamProcessor {
             }
           };
 
-          container = createSamzaContainer(
-              jobModel.getContainers().get(processorId),
-              jobModel.maxChangeLogStreamPartitions);
+          container = createSamzaContainer(processorId, jobModel);
           container.setContainerListener(containerListener);
           LOGGER.info("Starting container " + container.toString());
           executorService = Executors.newSingleThreadExecutor(new ThreadFactoryBuilder()
@@ -318,4 +315,9 @@ public class StreamProcessor {
       }
     };
   }
+
+  /* package private for testing */
+  SamzaContainer getContainer() {
+    return container;
+  }
 }
index b0bfc8a..995645d 100644 (file)
@@ -272,4 +272,9 @@ public class LocalApplicationRunner extends AbstractApplicationRunner {
           taskFactory.getClass().getCanonicalName()));
     }
   }
+
+  /* package private for testing */
+  Set<StreamProcessor> getProcessors() {
+    return processors;
+  }
 }
index 5d0e455..50c8181 100644 (file)
@@ -34,7 +34,6 @@ import org.apache.samza.container.SamzaContainer$;
 import org.apache.samza.container.SamzaContainerExceptionHandler;
 import org.apache.samza.container.SamzaContainerListener;
 import org.apache.samza.job.ApplicationStatus;
-import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.metrics.MetricsReporter;
 import org.apache.samza.task.TaskFactoryUtil;
@@ -73,13 +72,12 @@ public class LocalContainerRunner extends AbstractApplicationRunner {
 
   @Override
   public void run(StreamApplication streamApp) {
-    ContainerModel containerModel = jobModel.getContainers().get(containerId);
     Object taskFactory = TaskFactoryUtil.createTaskFactory(config, streamApp, this);
 
     container = SamzaContainer$.MODULE$.apply(
-        containerModel,
+        containerId,
+        jobModel,
         config,
-        jobModel.maxChangeLogStreamPartitions,
         Util.<String, MetricsReporter>javaMapAsScalaMap(new HashMap<>()),
         taskFactory);
     container.setContainerListener(
index 26ef92c..0b98ec6 100644 (file)
@@ -22,7 +22,7 @@ package org.apache.samza.serializers;
 import java.util.Arrays;
 import org.apache.samza.SamzaException;
 import org.apache.samza.message.EndOfStreamMessage;
-import org.apache.samza.message.MessageType;
+import org.apache.samza.message.IntermediateMessageType;
 import org.apache.samza.message.WatermarkMessage;
 import org.codehaus.jackson.type.TypeReference;
 
@@ -86,16 +86,16 @@ public class IntermediateMessageSerde implements Serde<Object> {
   public Object fromBytes(byte[] bytes) {
     try {
       final Object object;
-      final MessageType type = MessageType.values()[bytes[0]];
+      final IntermediateMessageType type = IntermediateMessageType.values()[bytes[0]];
       final byte [] data = Arrays.copyOfRange(bytes, 1, bytes.length);
       switch (type) {
         case USER_MESSAGE:
           object = userMessageSerde.fromBytes(data);
           break;
-        case WATERMARK:
+        case WATERMARK_MESSAGE:
           object = watermarkSerde.fromBytes(data);
           break;
-        case END_OF_STREAM:
+        case END_OF_STREAM_MESSAGE:
           object = eosSerde.fromBytes(data);
           break;
         default:
@@ -117,15 +117,15 @@ public class IntermediateMessageSerde implements Serde<Object> {
   @Override
   public byte[] toBytes(Object object) {
     final byte [] data;
-    final MessageType type = MessageType.of(object);
+    final IntermediateMessageType type = IntermediateMessageType.of(object);
     switch (type) {
       case USER_MESSAGE:
         data = userMessageSerde.toBytes(object);
         break;
-      case WATERMARK:
+      case WATERMARK_MESSAGE:
         data = watermarkSerde.toBytes((WatermarkMessage) object);
         break;
-      case END_OF_STREAM:
+      case END_OF_STREAM_MESSAGE:
         data = eosSerde.toBytes((EndOfStreamMessage) object);
         break;
       default:
index e5c40df..b2903fb 100644 (file)
@@ -320,7 +320,7 @@ public class AsyncRunLoop implements Runnable, Throttleable {
       this.task = task;
       this.callbackManager = new TaskCallbackManager(this, callbackTimer, callbackTimeoutMs, maxConcurrency, clock);
       Set<SystemStreamPartition> sspSet = getWorkingSSPSet(task);
-      this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet);
+      this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, task.hasIntermediateStreams());
     }
 
     private void init() {
@@ -581,12 +581,14 @@ public class AsyncRunLoop implements Runnable, Throttleable {
     private final Set<SystemStreamPartition> processingSspSet;
     private final TaskName taskName;
     private final TaskInstanceMetrics taskMetrics;
+    private final boolean hasIntermediateStreams;
 
-    AsyncTaskState(TaskName taskName, TaskInstanceMetrics taskMetrics, Set<SystemStreamPartition> sspSet) {
+    AsyncTaskState(TaskName taskName, TaskInstanceMetrics taskMetrics, Set<SystemStreamPartition> sspSet, boolean hasIntermediateStreams) {
       this.taskName = taskName;
       this.taskMetrics = taskMetrics;
       this.pendingEnvelopeQueue = new ArrayDeque<>();
       this.processingSspSet = sspSet;
+      this.hasIntermediateStreams = hasIntermediateStreams;
     }
 
     private boolean checkEndOfStream() {
@@ -597,7 +599,9 @@ public class AsyncRunLoop implements Runnable, Throttleable {
         if (envelope.isEndOfStream()) {
           SystemStreamPartition ssp = envelope.getSystemStreamPartition();
           processingSspSet.remove(ssp);
-          pendingEnvelopeQueue.remove();
+          if (!hasIntermediateStreams) {
+            pendingEnvelopeQueue.remove();
+          }
         }
       }
       return processingSspSet.isEmpty();
@@ -651,7 +655,7 @@ public class AsyncRunLoop implements Runnable, Throttleable {
       if (isReady()) {
         if (needCommit) return WorkerOp.COMMIT;
         else if (needWindow) return WorkerOp.WINDOW;
-        else if (endOfStream) return WorkerOp.END_OF_STREAM;
+        else if (endOfStream && pendingEnvelopeQueue.isEmpty()) return WorkerOp.END_OF_STREAM;
         else if (!pendingEnvelopeQueue.isEmpty()) return WorkerOp.PROCESS;
       }
       return WorkerOp.NO_OP;
index e2fea95..e57a89f 100644 (file)
@@ -21,7 +21,11 @@ package org.apache.samza.task;
 
 import java.util.concurrent.ExecutorService;
 import org.apache.samza.config.Config;
+import org.apache.samza.control.ControlMessageListenerTask;
+import org.apache.samza.control.Watermark;
+import org.apache.samza.control.IOGraph;
 import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStream;
 
 
 /**
@@ -30,7 +34,7 @@ import org.apache.samza.system.IncomingMessageEnvelope;
  * the callbacks once it's done. If the thread pool is null, it follows the legacy
  * synchronous model to execute the tasks on the run loop thread.
  */
-public class AsyncStreamTaskAdapter implements AsyncStreamTask, InitableTask, WindowableTask, ClosableTask, EndOfStreamListenerTask {
+public class AsyncStreamTaskAdapter implements AsyncStreamTask, InitableTask, WindowableTask, ClosableTask, EndOfStreamListenerTask, ControlMessageListenerTask {
   private final StreamTask wrappedTask;
   private final ExecutorService executor;
 
@@ -96,4 +100,20 @@ public class AsyncStreamTaskAdapter implements AsyncStreamTask, InitableTask, Wi
       ((EndOfStreamListenerTask) wrappedTask).onEndOfStream(collector, coordinator);
     }
   }
+
+  @Override
+  public IOGraph getIOGraph() {
+    if (wrappedTask instanceof ControlMessageListenerTask) {
+      return ((ControlMessageListenerTask) wrappedTask).getIOGraph();
+    }
+    return null;
+  }
+
+  @Override
+  public void onWatermark(Watermark watermark, SystemStream stream, MessageCollector collector, TaskCoordinator coordinator) {
+    if (wrappedTask instanceof ControlMessageListenerTask) {
+      ((ControlMessageListenerTask) wrappedTask).onWatermark(watermark, stream, collector, coordinator);
+    }
+  }
+
 }
index a77ef3b..16b7e40 100644 (file)
@@ -21,10 +21,13 @@ package org.apache.samza.task;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.Config;
+import org.apache.samza.control.ControlMessageListenerTask;
+import org.apache.samza.control.Watermark;
 import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.impl.InputOperatorImpl;
 import org.apache.samza.operators.impl.OperatorImplGraph;
+import org.apache.samza.control.IOGraph;
 import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
@@ -36,7 +39,7 @@ import org.apache.samza.util.SystemClock;
  * A {@link StreamTask} implementation that brings all the operator API implementation components together and
  * feeds the input messages into the user-defined transformation chains in {@link StreamApplication}.
  */
-public final class StreamOperatorTask implements StreamTask, InitableTask, WindowableTask, ClosableTask {
+public final class StreamOperatorTask implements StreamTask, InitableTask, WindowableTask, ClosableTask, ControlMessageListenerTask {
 
   private final StreamApplication streamApplication;
   private final ApplicationRunner runner;
@@ -44,6 +47,7 @@ public final class StreamOperatorTask implements StreamTask, InitableTask, Windo
 
   private OperatorImplGraph operatorImplGraph;
   private ContextManager contextManager;
+  private IOGraph ioGraph;
 
   /**
    * Constructs an adaptor task to run the user-implemented {@link StreamApplication}.
@@ -87,6 +91,7 @@ public final class StreamOperatorTask implements StreamTask, InitableTask, Windo
 
     // create the operator impl DAG corresponding to the logical operator spec DAG
     this.operatorImplGraph = new OperatorImplGraph(streamGraph, config, context, clock);
+    this.ioGraph = streamGraph.toIOGraph();
   }
 
   /**
@@ -116,10 +121,31 @@ public final class StreamOperatorTask implements StreamTask, InitableTask, Windo
   }
 
   @Override
+  public IOGraph getIOGraph() {
+    return ioGraph;
+  }
+
+  @Override
+  public final void onWatermark(Watermark watermark,
+      SystemStream systemStream,
+      MessageCollector collector,
+      TaskCoordinator coordinator) {
+    InputOperatorImpl inputOpImpl = operatorImplGraph.getInputOperator(systemStream);
+    if (inputOpImpl != null) { // ignore watermarks on streams with no registered input operator
+      inputOpImpl.onWatermark(watermark, collector, coordinator);
+    }
+  }
+
+  @Override
   public void close() throws Exception {
     if (this.contextManager != null) {
       this.contextManager.close();
     }
     operatorImplGraph.close();
   }
+
+  /* package private for testing */
+  OperatorImplGraph getOperatorImplGraph() {
+    return this.operatorImplGraph;
+  }
 }
index 3bf5c95..481cbcf 100644 (file)
@@ -109,13 +109,14 @@ object SamzaContainer extends Logging {
   }
 
   def apply(
-    containerModel: ContainerModel,
+    containerId: String,
+    jobModel: JobModel,
     config: Config,
-    maxChangeLogStreamPartitions: Int,
     customReporters: Map[String, MetricsReporter] = Map[String, MetricsReporter](),
     taskFactory: Object) = {
-    val containerId = containerModel.getProcessorId()
+    val containerModel = jobModel.getContainers.get(containerId)
     val containerName = "samza-container-%s" format containerId
+    val maxChangeLogStreamPartitions = jobModel.maxChangeLogStreamPartitions
 
     var localityManager: LocalityManager = null
     if (new ClusterManagerConfig(config).getHostAffinityEnabled()) {
@@ -558,7 +559,9 @@ object SamzaContainer extends Logging {
           storageManager = storageManager,
           reporters = reporters,
           systemStreamPartitions = systemStreamPartitions,
-          exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config))
+          exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config),
+          jobModel = jobModel,
+          streamMetadataCache = streamMetadataCache)
 
       val taskInstance = createTaskInstance(task)
 
@@ -660,6 +663,8 @@ class SamzaContainer(
 
   def getStatus(): SamzaContainerStatus = status
 
+  def getTaskInstances() = taskInstances
+
   def setContainerListener(listener: SamzaContainerListener): Unit = {
     containerListener = listener
   }
index 84e993b..c14908c 100644 (file)
 package org.apache.samza.container
 
 
+import com.google.common.collect.HashMultimap
+import com.google.common.collect.Multimap
 import org.apache.samza.SamzaException
 import org.apache.samza.checkpoint.OffsetManager
 import org.apache.samza.config.Config
+import org.apache.samza.config.StreamConfig.Config2Stream
+import org.apache.samza.control.ControlMessageListenerTask
+import org.apache.samza.control.ControlMessageUtils
+import org.apache.samza.control.EndOfStreamManager
+import org.apache.samza.control.WatermarkManager
+import org.apache.samza.job.model.JobModel
+import org.apache.samza.message.MessageType
 import org.apache.samza.metrics.MetricsReporter
 import org.apache.samza.storage.TaskStorageManager
 import org.apache.samza.system.IncomingMessageEnvelope
+import org.apache.samza.system.StreamMetadataCache
 import org.apache.samza.system.SystemAdmin
 import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemStream
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.task.AsyncStreamTask
 import org.apache.samza.task.ClosableTask
@@ -42,9 +53,31 @@ import org.apache.samza.task.WindowableTask
 import org.apache.samza.util.Logging
 
 import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
+
+object TaskInstance {
+  /**
+   * Build a map from a stream to its consumer tasks
+   * @param jobModel job model which contains ssp-to-task assignment
+   * @return the map of input stream to tasks
+   */
+  def buildInputToTasks(jobModel: JobModel): Multimap[SystemStream, String] = {
+    val streamToTasks: Multimap[SystemStream, String] = HashMultimap.create[SystemStream, String]
+    if (jobModel != null) { // jobModel defaults to null in the TaskInstance constructor; return an empty mapping in that case
+      for (containerModel <- jobModel.getContainers.values) {
+        for (taskModel <- containerModel.getTasks.values) {
+          for (ssp <- taskModel.getSystemStreamPartitions) {
+            streamToTasks.put(ssp.getSystemStream, taskModel.getTaskName.toString)
+          }
+        }
+      }
+    }
+    return streamToTasks
+  }
+}
 
 class TaskInstance(
-  task: Any,
+  val task: Any,
   val taskName: TaskName,
   config: Config,
   val metrics: TaskInstanceMetrics,
@@ -56,12 +89,15 @@ class TaskInstance(
   storageManager: TaskStorageManager = null,
   reporters: Map[String, MetricsReporter] = Map(),
   val systemStreamPartitions: Set[SystemStreamPartition] = Set(),
-  val exceptionHandler: TaskInstanceExceptionHandler = new TaskInstanceExceptionHandler) extends Logging {
+  val exceptionHandler: TaskInstanceExceptionHandler = new TaskInstanceExceptionHandler,
+  jobModel: JobModel = null,
+  streamMetadataCache: StreamMetadataCache = null) extends Logging {
   val isInitableTask = task.isInstanceOf[InitableTask]
   val isWindowableTask = task.isInstanceOf[WindowableTask]
   val isEndOfStreamListenerTask = task.isInstanceOf[EndOfStreamListenerTask]
   val isClosableTask = task.isInstanceOf[ClosableTask]
   val isAsyncTask = task.isInstanceOf[AsyncStreamTask]
+  val isControlMessageListener = task.isInstanceOf[ControlMessageListenerTask]
 
   val context = new TaskContext {
     var userContext: Object = null;
@@ -93,9 +129,15 @@ class TaskInstance(
   // store the (ssp -> if this ssp is catched up) mapping. "catched up"
   // means the same ssp in other taskInstances have the same offset as
   // the one here.
-  var ssp2catchedupMapping: scala.collection.mutable.Map[SystemStreamPartition, Boolean] =
+  var ssp2CaughtupMapping: scala.collection.mutable.Map[SystemStreamPartition, Boolean] =
     scala.collection.mutable.Map[SystemStreamPartition, Boolean]()
-  systemStreamPartitions.foreach(ssp2catchedupMapping += _ -> false)
+  systemStreamPartitions.foreach(ssp2CaughtupMapping += _ -> false)
+
+  val inputToTasksMapping = TaskInstance.buildInputToTasks(jobModel)
+  var endOfStreamManager: EndOfStreamManager = null
+  var watermarkManager: WatermarkManager = null
+
+  val hasIntermediateStreams = config.getStreamIds.exists(config.getIsIntermediate(_))
 
   def registerMetrics {
     debug("Registering metrics for taskName: %s" format taskName)
@@ -127,6 +169,22 @@ class TaskInstance(
     } else {
       debug("Skipping task initialization for taskName: %s" format taskName)
     }
+
+    if (isControlMessageListener) {
+      endOfStreamManager = new EndOfStreamManager(taskName.getTaskName,
+                                                  task.asInstanceOf[ControlMessageListenerTask],
+                                                  inputToTasksMapping,
+                                                  systemStreamPartitions.asJava,
+                                                  streamMetadataCache,
+                                                  collector)
+
+      watermarkManager = new WatermarkManager(taskName.getTaskName,
+                                                  task.asInstanceOf[ControlMessageListenerTask],
+                                                  inputToTasksMapping,
+                                                  systemStreamPartitions.asJava,
+                                                  streamMetadataCache,
+                                                  collector)
+    }
   }
 
   def registerProducers {
@@ -154,31 +212,62 @@ class TaskInstance(
     callbackFactory: TaskCallbackFactory = null) {
     metrics.processes.inc
 
-    if (!ssp2catchedupMapping.getOrElse(envelope.getSystemStreamPartition,
+    if (!ssp2CaughtupMapping.getOrElse(envelope.getSystemStreamPartition,
       throw new SamzaException(envelope.getSystemStreamPartition + " is not registered!"))) {
       checkCaughtUp(envelope)
     }
 
-    if (ssp2catchedupMapping(envelope.getSystemStreamPartition)) {
+    if (ssp2CaughtupMapping(envelope.getSystemStreamPartition)) {
       metrics.messagesActuallyProcessed.inc
 
       trace("Processing incoming message envelope for taskName and SSP: %s, %s"
         format (taskName, envelope.getSystemStreamPartition))
 
-      if (isAsyncTask) {
-        exceptionHandler.maybeHandle {
-          val callback = callbackFactory.createCallback()
-          task.asInstanceOf[AsyncStreamTask].processAsync(envelope, collector, coordinator, callback)
-        }
-      } else {
-        exceptionHandler.maybeHandle {
-         task.asInstanceOf[StreamTask].process(envelope, collector, coordinator)
-        }
+      MessageType.of(envelope.getMessage) match {
+        case MessageType.USER_MESSAGE =>
+          if (isAsyncTask) {
+            exceptionHandler.maybeHandle {
+             val callback = callbackFactory.createCallback()
+             task.asInstanceOf[AsyncStreamTask].processAsync(envelope, collector, coordinator, callback)
+            }
+          }
+          else {
+            exceptionHandler.maybeHandle {
+             task.asInstanceOf[StreamTask].process(envelope, collector, coordinator)
+            }
+
+            trace("Updating offset map for taskName, SSP and offset: %s, %s, %s"
+              format(taskName, envelope.getSystemStreamPartition, envelope.getOffset))
+
+            offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset)
+          }
 
-        trace("Updating offset map for taskName, SSP and offset: %s, %s, %s"
-          format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
+        case MessageType.END_OF_STREAM =>
+          if (isControlMessageListener) {
+            // handle eos synchronously.
+            runSync(callbackFactory) {
+              endOfStreamManager.update(envelope, coordinator)
+            }
+          } else {
+            warn("Ignore end-of-stream message due to %s not implementing ControlMessageListener."
+              format(task.getClass.toString))
+          }
 
-        offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset)
+        case MessageType.WATERMARK =>
+          if (isControlMessageListener) {
+            // handle watermark synchronously in the run loop thread.
+            // we might consider running it asynchronously later
+            runSync(callbackFactory) {
+              val watermark = watermarkManager.update(envelope)
+              if (watermark != null) {
+                val stream = envelope.getSystemStreamPartition.getSystemStream
+                task.asInstanceOf[ControlMessageListenerTask].onWatermark(watermark, stream, collector, coordinator)
+              }
+            }
+          } else {
+            warn("Ignore watermark message due to %s not implementing ControlMessageListener."
+              format(task.getClass.toString))
+          }
       }
     }
   }
@@ -255,7 +344,7 @@ class TaskInstance(
     systemAdmins match {
       case null => {
         warn("systemAdmin is null. Set all SystemStreamPartitions to catched-up")
-        ssp2catchedupMapping(envelope.getSystemStreamPartition) = true
+        ssp2CaughtupMapping(envelope.getSystemStreamPartition) = true
       }
       case others => {
         val startingOffset = offsetManager.getStartingOffset(taskName, envelope.getSystemStreamPartition)
@@ -264,16 +353,26 @@ class TaskInstance(
         others(system).offsetComparator(envelope.getOffset, startingOffset) match {
           case null => {
             info("offsets in " + system + " is not comparable. Set all SystemStreamPartitions to catched-up")
-            ssp2catchedupMapping(envelope.getSystemStreamPartition) = true // not comparable
+            ssp2CaughtupMapping(envelope.getSystemStreamPartition) = true // not comparable
           }
           case result => {
             if (result >= 0) {
               info(envelope.getSystemStreamPartition.toString + " is catched up.")
-              ssp2catchedupMapping(envelope.getSystemStreamPartition) = true
+              ssp2CaughtupMapping(envelope.getSystemStreamPartition) = true
             }
           }
         }
       }
     }
   }
+
+  private def runSync(callbackFactory: TaskCallbackFactory)(runCodeBlock: => Unit) = {
+    val callback = callbackFactory.createCallback()
+    try {
+      runCodeBlock
+      callback.complete()
+    } catch {
+      case t: Throwable => callback.failure(t)
+    }
+  }
 }
index 385a060..6de4ce0 100644 (file)
@@ -38,7 +38,7 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
     info("Creating a ThreadJob, which is only meant for debugging.")
     val coordinator = JobModelManager(config)
     val jobModel = coordinator.jobModel
-    val containerModel = jobModel.getContainers.get("0")
+    val containerId = "0"
     val jmxServer = new JmxServer
     val streamApp = TaskFactoryUtil.createStreamApplication(config)
     val appRunner = new LocalContainerRunner(jobModel, "0")
@@ -66,9 +66,9 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
     try {
       coordinator.start
       val container = SamzaContainer(
-        containerModel,
+        containerId,
+        jobModel,
         config,
-        jobModel.maxChangeLogStreamPartitions,
         Map[String, MetricsReporter](),
         taskFactory)
       container.setContainerListener(containerListener)
index a1b1e27..271279f 100644 (file)
@@ -100,6 +100,17 @@ class StreamMetadataCache (
     allResults
   }
 
+  /**
+   * Returns metadata about the given stream. If the metadata isn't in the cache, it is retrieved from the systems
+   * using the given SystemAdmins.
+   *
+   * @param stream the SystemStream for which the metadata is requested
+   * @param partitionsMetadataOnly Flag to indicate that only partition count metadata should be fetched/refreshed
+   */
+  def getSystemStreamMetadata(stream: SystemStream, partitionsMetadataOnly: Boolean): SystemStreamMetadata = {
+    getStreamMetadata(Set(stream), partitionsMetadataOnly).get(stream).orNull
+  }
+
   private def getFromCache(stream: SystemStream, now: Long) = {
     cache.get(stream) match {
       case Some(CacheEntry(metadata, lastRefresh)) =>
diff --git a/samza-core/src/test/java/org/apache/samza/control/TestControlMessageUtils.java b/samza-core/src/test/java/org/apache/samza/control/TestControlMessageUtils.java
new file mode 100644 (file)
index 0000000..8351802
--- /dev/null
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.Partition;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.message.ControlMessage;
+import org.apache.samza.system.OutgoingMessageEnvelope;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.task.MessageCollector;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.anyObject;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+
+public class TestControlMessageUtils {
+
+  @Test
+  public void testSendControlMessage() {
+    SystemStreamMetadata metadata = mock(SystemStreamMetadata.class);
+    Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> partitionMetadata = new HashMap<>();
+    partitionMetadata.put(new Partition(0), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(1), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(2), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(3), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    when(metadata.getSystemStreamPartitionMetadata()).thenReturn(partitionMetadata);
+    StreamMetadataCache metadataCache = mock(StreamMetadataCache.class);
+    when(metadataCache.getSystemStreamMetadata(anyObject(), anyBoolean())).thenReturn(metadata);
+
+    SystemStream systemStream = new SystemStream("test-system", "test-stream");
+    Set<Integer> partitions = new HashSet<>();
+    MessageCollector collector = mock(MessageCollector.class);
+    doAnswer(invocation -> {
+        OutgoingMessageEnvelope envelope = (OutgoingMessageEnvelope) invocation.getArguments()[0];
+        partitions.add((Integer) envelope.getPartitionKey());
+        assertEquals(envelope.getSystemStream(), systemStream);
+        return null;
+      }).when(collector).send(any());
+
+    ControlMessageUtils.sendControlMessage(mock(ControlMessage.class), systemStream, metadataCache, collector);
+    assertEquals(partitions.size(), 4);
+  }
+
+  @Test
+  public void testCalculateUpstreamTaskCounts() {
+    SystemStream input1 = new SystemStream("test-system", "input-stream-1");
+    SystemStream input2 = new SystemStream("test-system", "input-stream-2");
+    SystemStream input3 = new SystemStream("test-system", "input-stream-3");
+
+    Multimap<SystemStream, String> inputToTasks = HashMultimap.create();
+    TaskName t0 = new TaskName("task 0"); //consume input1 and input2
+    TaskName t1 = new TaskName("task 1"); //consume input1 and input2 and input 3
+    TaskName t2 = new TaskName("task 2"); //consume input2 and input 3
+    inputToTasks.put(input1, t0.getTaskName());
+    inputToTasks.put(input1, t1.getTaskName());
+    inputToTasks.put(input2, t0.getTaskName());
+    inputToTasks.put(input2, t1.getTaskName());
+    inputToTasks.put(input2, t2.getTaskName());
+    inputToTasks.put(input3, t1.getTaskName());
+    inputToTasks.put(input3, t2.getTaskName());
+
+    StreamSpec inputSpec2 = new StreamSpec("input-stream-2", "input-stream-2", "test-system");
+    StreamSpec inputSpec3 = new StreamSpec("input-stream-3", "input-stream-3", "test-system");
+    StreamSpec intSpec1 = new StreamSpec("int-stream-1", "int-stream-1", "test-system");
+    StreamSpec intSpec2 = new StreamSpec("int-stream-2", "int-stream-2", "test-system");
+
+    List<IOGraph.IONode> nodes = new ArrayList<>();
+    IOGraph.IONode node = new IOGraph.IONode(intSpec1, true);
+    node.addInput(inputSpec2);
+    nodes.add(node);
+    node = new IOGraph.IONode(intSpec2, true);
+    node.addInput(inputSpec3);
+    nodes.add(node);
+    IOGraph ioGraph = new IOGraph(nodes);
+
+    Map<SystemStream, Integer> counts = ControlMessageUtils.calculateUpstreamTaskCounts(inputToTasks, ioGraph);
+    assertEquals(counts.get(intSpec1.toSystemStream()).intValue(), 3);
+    assertEquals(counts.get(intSpec2.toSystemStream()).intValue(), 2);
+  }
+
+}
diff --git a/samza-core/src/test/java/org/apache/samza/control/TestEndOfStreamManager.java b/samza-core/src/test/java/org/apache/samza/control/TestEndOfStreamManager.java
new file mode 100644 (file)
index 0000000..cc70b6b
--- /dev/null
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.Partition;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.message.EndOfStreamMessage;
+import org.apache.samza.operators.spec.OperatorSpecs;
+import org.apache.samza.operators.spec.OutputOperatorSpec;
+import org.apache.samza.operators.spec.OutputStreamImpl;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.OutgoingMessageEnvelope;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.TaskCoordinator;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyObject;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+
+public class TestEndOfStreamManager {
+  StreamMetadataCache metadataCache;
+
+  @Before
+  public void setup() {
+    SystemStreamMetadata metadata = mock(SystemStreamMetadata.class);
+    Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> partitionMetadata = new HashMap<>();
+    partitionMetadata.put(new Partition(0), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(1), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(2), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(3), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    when(metadata.getSystemStreamPartitionMetadata()).thenReturn(partitionMetadata);
+    metadataCache = mock(StreamMetadataCache.class);
+    when(metadataCache.getSystemStreamMetadata(anyObject(), anyBoolean())).thenReturn(metadata);
+  }
+
+  @Test
+  public void testUpdateFromInputSource() {
+    SystemStreamPartition ssp = new SystemStreamPartition("test-system", "test-stream", new Partition(0));
+    TaskName taskName = new TaskName("Task 0");
+    Multimap<SystemStream, String> streamToTasks = HashMultimap.create();
+    streamToTasks.put(ssp.getSystemStream(), taskName.getTaskName());
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(new IOGraph(Collections.emptyList()));
+    EndOfStreamManager manager = new EndOfStreamManager("Task 0", listener, streamToTasks, Collections.singleton(ssp), null, null);
+    manager.update(EndOfStreamManager.buildEndOfStreamEnvelope(ssp), mock(TaskCoordinator.class));
+    assertTrue(manager.isEndOfStream(ssp.getSystemStream()));
+  }
+
+  @Test
+  public void testUpdateFromIntermediateStream() {
+    SystemStreamPartition[] ssps = new SystemStreamPartition[3];
+    ssps[0] = new SystemStreamPartition("test-system", "test-stream-1", new Partition(0));
+    ssps[1] = new SystemStreamPartition("test-system", "test-stream-2", new Partition(0));
+    ssps[2] = new SystemStreamPartition("test-system", "test-stream-2", new Partition(1));
+
+    TaskName taskName = new TaskName("Task 0");
+    Multimap<SystemStream, String> streamToTasks = HashMultimap.create();
+    for (SystemStreamPartition ssp : ssps) {
+      streamToTasks.put(ssp.getSystemStream(), taskName.getTaskName());
+    }
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(new IOGraph(Collections.emptyList()));
+    EndOfStreamManager manager = new EndOfStreamManager("Task 0", listener, streamToTasks, new HashSet<>(Arrays.asList(ssps)), null, null);
+
+    int envelopeCount = 4;
+    IncomingMessageEnvelope[] envelopes = new IncomingMessageEnvelope[envelopeCount];
+    for (int i = 0; i < envelopeCount; i++) {
+      envelopes[i] = new IncomingMessageEnvelope(ssps[0], "dummy-offset", "", new EndOfStreamMessage("task " + i, envelopeCount));
+    }
+    TaskCoordinator coordinator = mock(TaskCoordinator.class);
+
+    // verify the first three messages won't result in end-of-stream
+    for (int i = 0; i < 3; i++) {
+      manager.update(envelopes[i], coordinator);
+      assertFalse(manager.isEndOfStream(ssps[0].getSystemStream()));
+    }
+    // the fourth message will end the stream
+    manager.update(envelopes[3], coordinator);
+    assertTrue(manager.isEndOfStream(ssps[0].getSystemStream()));
+    assertFalse(manager.isEndOfStream(ssps[1].getSystemStream()));
+
+    // stream2 has two partitions assigned to this task, so it requires a message from each partition to end it
+    envelopes = new IncomingMessageEnvelope[envelopeCount];
+    for (int i = 0; i < envelopeCount; i++) {
+      envelopes[i] = new IncomingMessageEnvelope(ssps[1], "dummy-offset", "dummy-key", new EndOfStreamMessage("task " + i, envelopeCount));
+    }
+    // verify the messages for the partition 0 won't result in end-of-stream
+    for (int i = 0; i < 4; i++) {
+      manager.update(envelopes[i], coordinator);
+      assertFalse(manager.isEndOfStream(ssps[1].getSystemStream()));
+    }
+    for (int i = 0; i < envelopeCount; i++) {
+      envelopes[i] = new IncomingMessageEnvelope(ssps[2], "dummy-offset", "dummy-key", new EndOfStreamMessage("task " + i, envelopeCount));
+    }
+    for (int i = 0; i < 3; i++) {
+      manager.update(envelopes[i], coordinator);
+      assertFalse(manager.isEndOfStream(ssps[1].getSystemStream()));
+    }
+    // the fourth message will end the stream
+    manager.update(envelopes[3], coordinator);
+    assertTrue(manager.isEndOfStream(ssps[1].getSystemStream()));
+  }
+
+  @Test
+  public void testUpdateFromIntermediateStreamWith2Tasks() {
+    SystemStreamPartition[] ssps0 = new SystemStreamPartition[2];
+    ssps0[0] = new SystemStreamPartition("test-system", "test-stream-1", new Partition(0));
+    ssps0[1] = new SystemStreamPartition("test-system", "test-stream-2", new Partition(0));
+
+    SystemStreamPartition ssp1 = new SystemStreamPartition("test-system", "test-stream-2", new Partition(1));
+
+    TaskName t0 = new TaskName("Task 0");
+    Multimap<SystemStream, String> streamToTasks = HashMultimap.create();
+    for (SystemStreamPartition ssp : ssps0) {
+      streamToTasks.put(ssp.getSystemStream(), t0.getTaskName());
+    }
+
+    TaskName t1 = new TaskName("Task 1");
+    streamToTasks.put(ssp1, t1.getTaskName());
+
+    List<StreamSpec> inputs = new ArrayList<>();
+    inputs.add(new StreamSpec("test-stream-1", "test-stream-1", "test-system"));
+    inputs.add(new StreamSpec("test-stream-2", "test-stream-2", "test-system"));
+    StreamSpec outputSpec = new StreamSpec("int-stream", "int-stream", "test-system");
+    IOGraph ioGraph = TestIOGraph.buildSimpleIOGraph(inputs, outputSpec, true);
+
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(ioGraph);
+
+    EndOfStreamManager manager0 = spy(new EndOfStreamManager("Task 0", listener, streamToTasks, new HashSet<>(Arrays.asList(ssps0)), null, null));
+    manager0.update(EndOfStreamManager.buildEndOfStreamEnvelope(ssps0[0]), mock(TaskCoordinator.class));
+    assertTrue(manager0.isEndOfStream(ssps0[0].getSystemStream()));
+    doNothing().when(manager0).sendEndOfStream(any(), anyInt());
+    manager0.update(EndOfStreamManager.buildEndOfStreamEnvelope(ssps0[1]), mock(TaskCoordinator.class));
+    assertTrue(manager0.isEndOfStream(ssps0[1].getSystemStream()));
+    verify(manager0).sendEndOfStream(any(), anyInt());
+
+    EndOfStreamManager manager1 = spy(new EndOfStreamManager("Task 1", listener, streamToTasks, Collections.singleton(
+        ssp1), null, null));
+    doNothing().when(manager1).sendEndOfStream(any(), anyInt());
+    manager1.update(EndOfStreamManager.buildEndOfStreamEnvelope(ssp1), mock(TaskCoordinator.class));
+    assertTrue(manager1.isEndOfStream(ssp1.getSystemStream()));
+    verify(manager1).sendEndOfStream(any(), anyInt());
+  }
+
+  @Test
+  public void testSendEndOfStream() {
+    StreamSpec ints = new StreamSpec("int-stream", "int-stream", "test-system");
+    StreamSpec input = new StreamSpec("input-stream", "input-stream", "test-system");
+    IOGraph ioGraph = TestIOGraph.buildSimpleIOGraph(Collections.singletonList(input), ints, true);
+
+    Multimap<SystemStream, String> inputToTasks = HashMultimap.create();
+    for (int i = 0; i < 8; i++) {
+      inputToTasks.put(input.toSystemStream(), "Task " + i);
+    }
+
+    MessageCollector collector = mock(MessageCollector.class);
+    TaskName taskName = new TaskName("Task 0");
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(ioGraph);
+    EndOfStreamManager manager = new EndOfStreamManager(taskName.getTaskName(),
+        listener,
+        inputToTasks,
+        Collections.EMPTY_SET,
+        metadataCache,
+        collector);
+
+    Set<Integer> partitions = new HashSet<>();
+    doAnswer(invocation -> {
+        OutgoingMessageEnvelope envelope = (OutgoingMessageEnvelope) invocation.getArguments()[0];
+        partitions.add((Integer) envelope.getPartitionKey());
+        EndOfStreamMessage eosMessage = (EndOfStreamMessage) envelope.getMessage();
+        assertEquals(eosMessage.getTaskName(), taskName.getTaskName());
+        assertEquals(eosMessage.getTaskCount(), 8);
+        return null;
+      }).when(collector).send(any());
+
+    manager.sendEndOfStream(input.toSystemStream(), 8);
+    assertEquals(partitions.size(), 4);
+  }
+
+  @Test
+  public void testPropagate() {
+    List<StreamSpec> inputs = new ArrayList<>();
+    inputs.add(new StreamSpec("input-stream-1", "input-stream-1", "test-system"));
+    inputs.add(new StreamSpec("input-stream-2", "input-stream-2", "test-system"));
+    StreamSpec outputSpec = new StreamSpec("int-stream", "int-stream", "test-system");
+
+    SystemStream input1 = new SystemStream("test-system", "input-stream-1");
+    SystemStream input2 = new SystemStream("test-system", "input-stream-2");
+    SystemStream ints = new SystemStream("test-system", "int-stream");
+    SystemStreamPartition[] ssps = new SystemStreamPartition[3];
+    ssps[0] = new SystemStreamPartition(input1, new Partition(0));
+    ssps[1] = new SystemStreamPartition(input2, new Partition(0));
+    ssps[2] = new SystemStreamPartition(ints, new Partition(0));
+
+    Set<SystemStreamPartition> sspSet = new HashSet<>(Arrays.asList(ssps));
+    TaskName taskName = new TaskName("task 0");
+    Multimap<SystemStream, String> streamToTasks = HashMultimap.create();
+    for (SystemStreamPartition ssp : ssps) {
+      streamToTasks.put(ssp.getSystemStream(), taskName.getTaskName());
+    }
+
+    IOGraph ioGraph = TestIOGraph.buildSimpleIOGraph(inputs, outputSpec, true);
+    MessageCollector collector = mock(MessageCollector.class);
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(ioGraph);
+    EndOfStreamManager manager = spy(
+        new EndOfStreamManager("task 0", listener, streamToTasks, sspSet, metadataCache, collector));
+    TaskCoordinator coordinator = mock(TaskCoordinator.class);
+
+    // ssp1 end-of-stream, wait for ssp2
+    manager.update(EndOfStreamManager.buildEndOfStreamEnvelope(ssps[0]), coordinator);
+    verify(manager, never()).sendEndOfStream(any(), anyInt());
+
+    // ssp2 end-of-stream, propagate to intermediate
+    manager.update(EndOfStreamManager.buildEndOfStreamEnvelope(ssps[1]), coordinator);
+    doNothing().when(manager).sendEndOfStream(any(), anyInt());
+    ArgumentCaptor<SystemStream> argument = ArgumentCaptor.forClass(SystemStream.class);
+    verify(manager).sendEndOfStream(argument.capture(), anyInt());
+    assertEquals(ints, argument.getValue());
+
+    // intermediate end-of-stream, shutdown the task
+    manager.update(EndOfStreamManager.buildEndOfStreamEnvelope(ssps[2]), coordinator);
+    doNothing().when(coordinator).shutdown(any());
+    ArgumentCaptor<TaskCoordinator.RequestScope> arg = ArgumentCaptor.forClass(TaskCoordinator.RequestScope.class);
+    verify(coordinator).shutdown(arg.capture());
+    assertEquals(TaskCoordinator.RequestScope.CURRENT_TASK, arg.getValue());
+  }
+
+  // Test the case where the tasks publishing to the intermediate stream are a subset of all tasks
+  @Test
+  public void testPropagateWith2Tasks() {
+    StreamSpec outputSpec = new StreamSpec("int-stream", "int-stream", "test-system");
+    OutputStreamImpl outputStream = new OutputStreamImpl(outputSpec, null, null);
+    OutputOperatorSpec partitionByOp = OperatorSpecs.createPartitionByOperatorSpec(outputStream, 0);
+
+    List<StreamSpec> inputs = new ArrayList<>();
+    inputs.add(new StreamSpec("input-stream-1", "input-stream-1", "test-system"));
+
+    IOGraph ioGraph = TestIOGraph.buildSimpleIOGraph(inputs, outputSpec, true);
+
+    SystemStream input1 = new SystemStream("test-system", "input-stream-1");
+    SystemStream ints = new SystemStream("test-system", "int-stream");
+    SystemStreamPartition ssp1 = new SystemStreamPartition(input1, new Partition(0));
+    SystemStreamPartition ssp2 = new SystemStreamPartition(ints, new Partition(0));
+
+    TaskName t0 = new TaskName("task 0");
+    TaskName t1 = new TaskName("task 1");
+    Multimap<SystemStream, String> streamToTasks = HashMultimap.create();
+    streamToTasks.put(ssp1.getSystemStream(), t0.getTaskName());
+    streamToTasks.put(ssp2.getSystemStream(), t1.getTaskName());
+
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(ioGraph);
+
+    EndOfStreamManager manager0 = spy(
+        new EndOfStreamManager(t0.getTaskName(), listener, streamToTasks, Collections.singleton(ssp1), metadataCache, null));
+    EndOfStreamManager manager1 = spy(
+        new EndOfStreamManager(t1.getTaskName(), listener, streamToTasks, Collections.singleton(ssp2), metadataCache, null));
+
+    TaskCoordinator coordinator0 = mock(TaskCoordinator.class);
+    TaskCoordinator coordinator1 = mock(TaskCoordinator.class);
+
+    // ssp1 end-of-stream
+    doNothing().when(manager0).sendEndOfStream(any(), anyInt());
+    doNothing().when(coordinator0).shutdown(any());
+    manager0.update(EndOfStreamManager.buildEndOfStreamEnvelope(ssp1), coordinator0);
+    //verify task count is 1
+    ArgumentCaptor<Integer> argument = ArgumentCaptor.forClass(Integer.class);
+    verify(manager0).sendEndOfStream(any(), argument.capture());
+    assertTrue(argument.getValue() == 1);
+    ArgumentCaptor<TaskCoordinator.RequestScope> arg = ArgumentCaptor.forClass(TaskCoordinator.RequestScope.class);
+    verify(coordinator0).shutdown(arg.capture());
+    assertEquals(TaskCoordinator.RequestScope.CURRENT_TASK, arg.getValue());
+
+    // int1 end-of-stream
+    IncomingMessageEnvelope intEos = new IncomingMessageEnvelope(ssp2, null, null, new EndOfStreamMessage(t0.getTaskName(), 1));
+    manager1.update(intEos, coordinator1);
+    doNothing().when(coordinator1).shutdown(any());
+    verify(manager1, never()).sendEndOfStream(any(), anyInt());
+    arg = ArgumentCaptor.forClass(TaskCoordinator.RequestScope.class);
+    verify(coordinator1).shutdown(arg.capture());
+    assertEquals(TaskCoordinator.RequestScope.CURRENT_TASK, arg.getValue());
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/control/TestIOGraph.java b/samza-core/src/test/java/org/apache/samza/control/TestIOGraph.java
new file mode 100644 (file)
index 0000000..39c56c3
--- /dev/null
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.control.IOGraph.IONode;
+import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.OutputStream;
+import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.functions.JoinFunction;
+import org.apache.samza.runtime.ApplicationRunner;
+import org.apache.samza.system.StreamSpec;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+
+public class TestIOGraph {
+  StreamSpec input1;
+  StreamSpec input2;
+  StreamSpec input3;
+  StreamSpec output1;
+  StreamSpec output2;
+  StreamSpec int1;
+  StreamSpec int2;
+
+  StreamGraphImpl streamGraph;
+
+  @Before
+  public void setup() {
+    ApplicationRunner runner = mock(ApplicationRunner.class);
+    Map<String, String> configMap = new HashMap<>();
+    configMap.put(JobConfig.JOB_NAME(), "test-app");
+    configMap.put(JobConfig.JOB_DEFAULT_SYSTEM(), "test-system");
+    Config config = new MapConfig(configMap);
+
+    /**
+     * the graph looks like the following. number of partitions in parentheses. quotes indicate expected value.
+     *
+     *                                    input1 -> map -> join -> output1
+     *                                                       |
+     *                      input2 -> partitionBy -> filter -|
+     *                                                       |
+     *           input3 -> filter -> partitionBy -> map -> join -> output2
+     *
+     */
+    input1 = new StreamSpec("input1", "input1", "system1");
+    input2 = new StreamSpec("input2", "input2", "system2");
+    input3 = new StreamSpec("input3", "input3", "system2");
+
+    output1 = new StreamSpec("output1", "output1", "system1");
+    output2 = new StreamSpec("output2", "output2", "system2");
+
+    runner = mock(ApplicationRunner.class);
+    when(runner.getStreamSpec("input1")).thenReturn(input1);
+    when(runner.getStreamSpec("input2")).thenReturn(input2);
+    when(runner.getStreamSpec("input3")).thenReturn(input3);
+    when(runner.getStreamSpec("output1")).thenReturn(output1);
+    when(runner.getStreamSpec("output2")).thenReturn(output2);
+
+    // intermediate streams used in tests
+    int1 = new StreamSpec("test-app-1-partition_by-3", "test-app-1-partition_by-3", "default-system");
+    int2 = new StreamSpec("test-app-1-partition_by-8", "test-app-1-partition_by-8", "default-system");
+    when(runner.getStreamSpec("test-app-1-partition_by-3"))
+        .thenReturn(int1);
+    when(runner.getStreamSpec("test-app-1-partition_by-8"))
+        .thenReturn(int2);
+
+    streamGraph = new StreamGraphImpl(runner, config);
+    BiFunction msgBuilder = mock(BiFunction.class);
+    MessageStream m1 = streamGraph.getInputStream("input1", msgBuilder).map(m -> m);
+    MessageStream m2 = streamGraph.getInputStream("input2", msgBuilder).partitionBy(m -> "haha").filter(m -> true);
+    MessageStream m3 = streamGraph.getInputStream("input3", msgBuilder).filter(m -> true).partitionBy(m -> "hehe").map(m -> m);
+    Function mockFn = mock(Function.class);
+    OutputStream<Object, Object, Object> om1 = streamGraph.getOutputStream("output1", mockFn, mockFn);
+    OutputStream<Object, Object, Object> om2 = streamGraph.getOutputStream("output2", mockFn, mockFn);
+
+    m1.join(m2, mock(JoinFunction.class), Duration.ofHours(2)).sendTo(om1);
+    m3.join(m2, mock(JoinFunction.class), Duration.ofHours(1)).sendTo(om2);
+  }
+
+  @Test
+  public void testBuildIOGraph() {
+    IOGraph ioGraph = streamGraph.toIOGraph();
+    assertEquals(ioGraph.getNodes().size(), 4);
+
+    for (IONode node : ioGraph.getNodes()) {
+      if (node.getOutput().equals(output1)) {
+        assertEquals(node.getInputs().size(), 2);
+        assertFalse(node.isOutputIntermediate());
+        StreamSpec[] inputs = sort(node.getInputs());
+        assertEquals(inputs[0], input1);
+        assertEquals(inputs[1], int1);
+      } else if (node.getOutput().equals(output2)) {
+        assertEquals(node.getInputs().size(), 2);
+        assertFalse(node.isOutputIntermediate());
+        StreamSpec[] inputs = sort(node.getInputs());
+        assertEquals(inputs[0], int1);
+        assertEquals(inputs[1], int2);
+      } else if (node.getOutput().equals(int1)) {
+        assertEquals(node.getInputs().size(), 1);
+        assertTrue(node.isOutputIntermediate());
+        StreamSpec[] inputs = sort(node.getInputs());
+        assertEquals(inputs[0], input2);
+      } else if (node.getOutput().equals(int2)) {
+        assertEquals(node.getInputs().size(), 1);
+        assertTrue(node.isOutputIntermediate());
+        StreamSpec[] inputs = sort(node.getInputs());
+        assertEquals(inputs[0], input3);
+      }
+    }
+  }
+
+  @Test
+  public void testNodesOfInput() {
+    IOGraph ioGraph = streamGraph.toIOGraph();
+    Collection<IONode> nodes = ioGraph.getNodesOfInput(input1.toSystemStream());
+    assertEquals(nodes.size(), 1);
+    IONode node = nodes.iterator().next();
+    assertEquals(node.getOutput(), output1);
+    assertEquals(node.getInputs().size(), 2);
+    assertFalse(node.isOutputIntermediate());
+
+    nodes = ioGraph.getNodesOfInput(input2.toSystemStream());
+    assertEquals(nodes.size(), 1);
+    node = nodes.iterator().next();
+    assertEquals(node.getOutput(), int1);
+    assertEquals(node.getInputs().size(), 1);
+    assertTrue(node.isOutputIntermediate());
+
+    nodes = ioGraph.getNodesOfInput(int1.toSystemStream());
+    assertEquals(nodes.size(), 2);
+    nodes.forEach(n -> {
+        assertEquals(n.getInputs().size(), 2);
+      });
+
+    nodes = ioGraph.getNodesOfInput(input3.toSystemStream());
+    assertEquals(nodes.size(), 1);
+    node = nodes.iterator().next();
+    assertEquals(node.getOutput(), int2);
+    assertEquals(node.getInputs().size(), 1);
+    assertTrue(node.isOutputIntermediate());
+
+    nodes = ioGraph.getNodesOfInput(int2.toSystemStream());
+    assertEquals(nodes.size(), 1);
+    node = nodes.iterator().next();
+    assertEquals(node.getOutput(), output2);
+    assertEquals(node.getInputs().size(), 2);
+    assertFalse(node.isOutputIntermediate());
+  }
+
+  private static StreamSpec[] sort(Set<StreamSpec> specs) {
+    StreamSpec[] array = new StreamSpec[specs.size()];
+    specs.toArray(array);
+    Arrays.sort(array, (s1, s2) -> s1.getId().compareTo(s2.getId()));
+    return array;
+  }
+
+  public static IOGraph buildSimpleIOGraph(List<StreamSpec> inputs,
+      StreamSpec output,
+      boolean isOutputIntermediate) {
+    IONode node = new IONode(output, isOutputIntermediate);
+    inputs.forEach(input -> node.addInput(input));
+    return new IOGraph(Collections.singleton(node));
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/control/TestWatermarkManager.java b/samza-core/src/test/java/org/apache/samza/control/TestWatermarkManager.java
new file mode 100644 (file)
index 0000000..8fe7a16
--- /dev/null
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.control;
+
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.Partition;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.message.WatermarkMessage;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.OutgoingMessageEnvelope;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyLong;
+import static org.mockito.Matchers.anyObject;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+
+/**
+ * Tests for WatermarkManager: per-partition aggregation of
+ * {@link WatermarkMessage}s into a task-level watermark, broadcast of
+ * watermarks to every partition of an intermediate stream, and propagation
+ * of an updated watermark through the listener's {@link IOGraph}.
+ */
+public class TestWatermarkManager {
+
+  // NOTE(review): built in setup() but the tests below either construct their
+  // own local mock (testSendWatermark) or pass null caches — confirm this
+  // shared field is still needed.
+  StreamMetadataCache metadataCache;
+
+  @Before
+  public void setup() {
+    // Mock a stream with 4 partitions so a send would fan out to each of them.
+    SystemStreamMetadata metadata = mock(SystemStreamMetadata.class);
+    Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> partitionMetadata = new HashMap<>();
+    partitionMetadata.put(new Partition(0), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(1), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(2), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(3), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    when(metadata.getSystemStreamPartitionMetadata()).thenReturn(partitionMetadata);
+    metadataCache = mock(StreamMetadataCache.class);
+    when(metadataCache.getSystemStreamMetadata(anyObject(), anyBoolean())).thenReturn(metadata);
+  }
+
+  /**
+   * A watermark envelope built for a source (non-intermediate) partition
+   * should immediately yield a watermark carrying the envelope's timestamp.
+   */
+  @Test
+  public void testUpdateFromInputSource() {
+    SystemStreamPartition ssp = new SystemStreamPartition("test-system", "test-stream", new Partition(0));
+    TaskName taskName = new TaskName("Task 0");
+    Multimap<SystemStream, String> streamToTasks = HashMultimap.create();
+    streamToTasks.put(ssp.getSystemStream(), taskName.getTaskName());
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(new IOGraph(Collections.emptyList()));
+    WatermarkManager manager = new WatermarkManager("Task 0", listener, streamToTasks, Collections.singleton(ssp), null, null);
+    long time = System.currentTimeMillis();
+    Watermark watermark = manager.update(WatermarkManager.buildWatermarkEnvelope(time, ssp));
+    assertEquals(watermark.getTimestamp(), time);
+  }
+
+  /**
+   * For an intermediate stream, a watermark is only produced once every
+   * upstream task (taskCount = 4 here) has reported, and the aggregated time
+   * is the minimum of the reported times. A stream with multiple partitions
+   * assigned to this task additionally requires every partition to report
+   * before the task-level watermark advances.
+   */
+  @Test
+  public void testUpdateFromIntermediateStream() {
+    SystemStreamPartition[] ssps = new SystemStreamPartition[3];
+    ssps[0] = new SystemStreamPartition("test-system", "test-stream-1", new Partition(0));
+    ssps[1] = new SystemStreamPartition("test-system", "test-stream-2", new Partition(0));
+    ssps[2] = new SystemStreamPartition("test-system", "test-stream-2", new Partition(1));
+
+    TaskName taskName = new TaskName("Task 0");
+    Multimap<SystemStream, String> streamToTasks = HashMultimap.create();
+    for (SystemStreamPartition ssp : ssps) {
+      streamToTasks.put(ssp.getSystemStream(), taskName.getTaskName());
+    }
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(new IOGraph(Collections.emptyList()));
+
+    WatermarkManager manager = new WatermarkManager("Task 0", listener, streamToTasks, new HashSet<>(Arrays.asList(ssps)), null, null);
+    int envelopeCount = 4;
+    IncomingMessageEnvelope[] envelopes = new IncomingMessageEnvelope[envelopeCount];
+
+    // Four upstream tasks report these times on stream-1 partition 0.
+    long[] time = {300L, 200L, 100L, 400L};
+    for (int i = 0; i < envelopeCount; i++) {
+      envelopes[i] = new IncomingMessageEnvelope(ssps[0], "dummy-offset", "", new WatermarkMessage(time[i], "task " + i, envelopeCount));
+    }
+    for (int i = 0; i < 3; i++) {
+      assertNull(manager.update(envelopes[i]));
+    }
+    // verify the first three messages won't generate a watermark
+    assertEquals(manager.getWatermarkTime(ssps[0]), WatermarkManager.TIME_NOT_EXIST);
+    // the fourth message will generate a watermark: min(300, 200, 100, 400)
+    Watermark watermark = manager.update(envelopes[3]);
+    assertNotNull(watermark);
+    assertEquals(watermark.getTimestamp(), 100);
+    assertEquals(manager.getWatermarkTime(ssps[1]), WatermarkManager.TIME_NOT_EXIST);
+    assertEquals(manager.getWatermarkTime(ssps[2]), WatermarkManager.TIME_NOT_EXIST);
+
+
+    // stream2 has two partitions assigned to this task, so it requires a message from each partition to calculate watermarks
+    long[] time1 = {300L, 200L, 100L, 400L};
+    envelopes = new IncomingMessageEnvelope[envelopeCount];
+    for (int i = 0; i < envelopeCount; i++) {
+      envelopes[i] = new IncomingMessageEnvelope(ssps[1], "dummy-offset", "", new WatermarkMessage(time1[i], "task " + i, envelopeCount));
+    }
+    // verify the messages for the partition 0 won't generate watermark
+    for (int i = 0; i < 4; i++) {
+      assertNull(manager.update(envelopes[i]));
+    }
+    // partition 0's aggregated time is set (min = 100), but partition 1 hasn't reported
+    assertEquals(manager.getWatermarkTime(ssps[1]), 100L);
+
+    long[] time2 = {350L, 150L, 500L, 80L};
+    for (int i = 0; i < envelopeCount; i++) {
+      envelopes[i] = new IncomingMessageEnvelope(ssps[2], "dummy-offset", "", new WatermarkMessage(time2[i], "task " + i, envelopeCount));
+    }
+    for (int i = 0; i < 3; i++) {
+      assertNull(manager.update(envelopes[i]));
+    }
+    assertEquals(manager.getWatermarkTime(ssps[2]), WatermarkManager.TIME_NOT_EXIST);
+    // the fourth message will generate the watermark: min(80 on p1, 100 on p0) = 80
+    watermark = manager.update(envelopes[3]);
+    assertNotNull(watermark);
+    assertEquals(manager.getWatermarkTime(ssps[2]), 80L);
+    assertEquals(watermark.getTimestamp(), 80L);
+  }
+
+  /**
+   * sendWatermark should broadcast one WatermarkMessage to each of the
+   * intermediate stream's 4 partitions, stamped with this task's name,
+   * the producer task count and the watermark time.
+   */
+  @Test
+  public void testSendWatermark() {
+    SystemStream ints = new SystemStream("test-system", "int-stream");
+    SystemStreamMetadata metadata = mock(SystemStreamMetadata.class);
+    Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> partitionMetadata = new HashMap<>();
+    partitionMetadata.put(new Partition(0), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(1), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(2), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    partitionMetadata.put(new Partition(3), mock(SystemStreamMetadata.SystemStreamPartitionMetadata.class));
+    when(metadata.getSystemStreamPartitionMetadata()).thenReturn(partitionMetadata);
+    StreamMetadataCache metadataCache = mock(StreamMetadataCache.class);
+    when(metadataCache.getSystemStreamMetadata(anyObject(), anyBoolean())).thenReturn(metadata);
+
+    MessageCollector collector = mock(MessageCollector.class);
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(new IOGraph(Collections.emptyList()));
+
+    WatermarkManager manager = new WatermarkManager("task 0",
+        listener,
+        HashMultimap.create(),
+        Collections.EMPTY_SET,
+        metadataCache,
+        collector);
+
+    long time = System.currentTimeMillis();
+    // Collect the partition keys of everything sent so we can verify fan-out.
+    Set<Integer> partitions = new HashSet<>();
+    doAnswer(invocation -> {
+        OutgoingMessageEnvelope envelope = (OutgoingMessageEnvelope) invocation.getArguments()[0];
+        partitions.add((Integer) envelope.getPartitionKey());
+        WatermarkMessage watermarkMessage = (WatermarkMessage) envelope.getMessage();
+        assertEquals(watermarkMessage.getTaskName(), "task 0");
+        assertEquals(watermarkMessage.getTaskCount(), 8);
+        assertEquals(watermarkMessage.getTimestamp(), time);
+        return null;
+      }).when(collector).send(any());
+
+    manager.sendWatermark(time, ints, 8);
+    // one message per partition of the 4-partition stream
+    assertEquals(partitions.size(), 4);
+  }
+
+  /**
+   * A watermark produced by update() can be copied with a new timestamp and
+   * propagated to a downstream intermediate stream; this should delegate to
+   * sendWatermark with the new timestamp and the expected producer count
+   * (2 here — presumably the tasks consuming the node's inputs, t0 and t1;
+   * confirm against WatermarkManager's propagation logic).
+   */
+  @Test
+  public void testPropagate() {
+    StreamSpec outputSpec = new StreamSpec("int-stream", "int-stream", "test-system");
+    List<StreamSpec> inputs = new ArrayList<>();
+    inputs.add(new StreamSpec("input-stream-1", "input-stream-1", "test-system"));
+    inputs.add(new StreamSpec("input-stream-2", "input-stream-2", "test-system"));
+
+    // Topology: input1 + input2 -> int-stream (intermediate)
+    IOGraph ioGraph = TestIOGraph.buildSimpleIOGraph(inputs, outputSpec, true);
+
+    SystemStream input1 = new SystemStream("test-system", "input-stream-1");
+    SystemStream input2 = new SystemStream("test-system", "input-stream-2");
+    SystemStream input3 = new SystemStream("test-system", "input-stream-3");
+    SystemStream ints = new SystemStream("test-system", "int-stream");
+    SystemStreamPartition[] ssps0 = new SystemStreamPartition[3];
+    ssps0[0] = new SystemStreamPartition(input1, new Partition(0));
+    ssps0[1] = new SystemStreamPartition(input2, new Partition(0));
+    ssps0[2] = new SystemStreamPartition(ints, new Partition(0));
+
+    SystemStreamPartition[] ssps1 = new SystemStreamPartition[4];
+    ssps1[0] = new SystemStreamPartition(input1, new Partition(1));
+    ssps1[1] = new SystemStreamPartition(input2, new Partition(1));
+    ssps1[2] = new SystemStreamPartition(input3, new Partition(1));
+    ssps1[3] = new SystemStreamPartition(ints, new Partition(1));
+
+    SystemStreamPartition[] ssps2 = new SystemStreamPartition[2];
+    ssps2[0] = new SystemStreamPartition(input3, new Partition(2));
+    ssps2[1] = new SystemStreamPartition(ints, new Partition(2));
+
+
+    TaskName t0 = new TaskName("task 0"); //consume input1 and input2
+    TaskName t1 = new TaskName("task 1"); //consume input 1 and input2 and input 3
+    TaskName t2 = new TaskName("task 2"); //consume input3
+    Multimap<SystemStream, String> inputToTasks = HashMultimap.create();
+    for (SystemStreamPartition ssp : ssps0) {
+      inputToTasks.put(ssp.getSystemStream(), t0.getTaskName());
+    }
+    for (SystemStreamPartition ssp : ssps1) {
+      inputToTasks.put(ssp.getSystemStream(), t1.getTaskName());
+    }
+    for (SystemStreamPartition ssp : ssps2) {
+      inputToTasks.put(ssp.getSystemStream(), t2.getTaskName());
+    }
+
+    ControlMessageListenerTask listener = mock(ControlMessageListenerTask.class);
+    when(listener.getIOGraph()).thenReturn(ioGraph);
+    // spy so we can stub out and verify the actual sendWatermark call
+    WatermarkManager manager = spy(
+        new WatermarkManager(t0.getTaskName(), listener, inputToTasks, new HashSet<>(Arrays.asList(ssps0)), null, null));
+
+    IncomingMessageEnvelope envelope = WatermarkManager.buildWatermarkEnvelope(System.currentTimeMillis(), ssps0[0]);
+    doNothing().when(manager).sendWatermark(anyLong(), any(), anyInt());
+    Watermark watermark = manager.update(envelope);
+    assertNotNull(watermark);
+    long time = System.currentTimeMillis();
+    Watermark updatedWatermark = watermark.copyWithTimestamp(time);
+    updatedWatermark.propagate(ints);
+    ArgumentCaptor<Long> arg1 = ArgumentCaptor.forClass(Long.class);
+    ArgumentCaptor<SystemStream> arg2 = ArgumentCaptor.forClass(SystemStream.class);
+    ArgumentCaptor<Integer> arg3 = ArgumentCaptor.forClass(Integer.class);
+    verify(manager).sendWatermark(arg1.capture(), arg2.capture(), arg3.capture());
+    assertEquals(arg1.getValue().longValue(), time);
+    assertEquals(arg2.getValue(), ints);
+    assertEquals(arg3.getValue().intValue(), 2);
+  }
+}
index d50d271..3ae8f5b 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators.impl;
 
+import java.util.Set;
 import org.apache.samza.config.Config;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistry;
@@ -210,5 +211,22 @@ public class TestOperatorImpl {
      super(OpCode.INPUT, 1);
     }
   }
+
+  /** Test-only accessor: exposes the operator's registered downstream operators. */
+  public static Set<OperatorImpl> getNextOperators(OperatorImpl op) {
+    return op.registeredOperators;
+  }
+
+  /** Test-only accessor: returns the OpCode of the operator's spec. */
+  public static OperatorSpec.OpCode getOpCode(OperatorImpl op) {
+    return op.getOperatorSpec().getOpCode();
+  }
+
+  /** Test-only accessor: returns the operator's current input watermark time. */
+  public static long getInputWatermark(OperatorImpl op) {
+    return op.getInputWatermarkTime();
+  }
+
+  /** Test-only accessor: returns the operator's current output watermark time. */
+  public static long getOutputWatermark(OperatorImpl op) {
+    return op.getOutputWatermarkTime();
+  }
+
 }
 
index 6a8d765..fc1259c 100644 (file)
@@ -88,9 +88,7 @@ public class TestStreamProcessor {
     }
 
     @Override
-    SamzaContainer createSamzaContainer(
-        ContainerModel containerModel,
-        int maxChangelogStreamPartitions) {
+    SamzaContainer createSamzaContainer(String processorId, JobModel jobModel) {
       if (container == null) {
         RunLoop mockRunLoop = mock(RunLoop.class);
         doAnswer(invocation ->
index a04bd3b..4be4e73 100644 (file)
@@ -19,6 +19,7 @@
 
 package org.apache.samza.runtime;
 
+import java.util.Set;
 import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.JobConfig;
@@ -330,7 +331,6 @@ public class TestLocalApplicationRunner {
         return null;
       }).when(sp).start();
 
-
     LocalApplicationRunner spy = spy(runner);
     doReturn(sp).when(spy).createStreamProcessor(anyObject(), anyObject(), captor.capture());
 
@@ -343,4 +343,8 @@ public class TestLocalApplicationRunner {
     assertEquals(spy.status(app), ApplicationStatus.UnsuccessfulFinish);
   }
 
+  /** Test-only bridge so tests in other packages can inspect the runner's StreamProcessors. */
+  public static Set<StreamProcessor> getProcessors(LocalApplicationRunner runner) {
+    return runner.getProcessors();
+  }
+
 }
index 5b76bba..7192525 100644 (file)
@@ -26,7 +26,7 @@ import java.io.ObjectOutputStream;
 import java.io.Serializable;
 import org.apache.samza.message.EndOfStreamMessage;
 import org.apache.samza.message.WatermarkMessage;
-import org.apache.samza.message.MessageType;
+import org.apache.samza.message.IntermediateMessageType;
 import org.apache.samza.serializers.IntermediateMessageSerde;
 import org.apache.samza.serializers.Serde;
 import org.junit.Test;
@@ -96,7 +96,7 @@ public class TestIntermediateMessageSerde {
     TestUserMessage userMessage = new TestUserMessage(msg, 0, System.currentTimeMillis());
     byte[] bytes = imserde.toBytes(userMessage);
     TestUserMessage de = (TestUserMessage) imserde.fromBytes(bytes);
-    assertEquals(MessageType.of(de), MessageType.USER_MESSAGE);
+    assertEquals(IntermediateMessageType.of(de), IntermediateMessageType.USER_MESSAGE);
     assertEquals(de.getMessage(), msg);
     assertEquals(de.getOffset(), 0);
     assertTrue(de.getTimestamp() > 0);
@@ -109,7 +109,7 @@ public class TestIntermediateMessageSerde {
     WatermarkMessage watermark = new WatermarkMessage(System.currentTimeMillis(), taskName, 8);
     byte[] bytes = imserde.toBytes(watermark);
     WatermarkMessage de = (WatermarkMessage) imserde.fromBytes(bytes);
-    assertEquals(MessageType.of(de), MessageType.WATERMARK);
+    assertEquals(IntermediateMessageType.of(de), IntermediateMessageType.WATERMARK_MESSAGE);
     assertEquals(de.getTaskName(), taskName);
     assertEquals(de.getTaskCount(), 8);
     assertTrue(de.getTimestamp() > 0);
@@ -123,7 +123,7 @@ public class TestIntermediateMessageSerde {
     EndOfStreamMessage eos = new EndOfStreamMessage(taskName, 8);
     byte[] bytes = imserde.toBytes(eos);
     EndOfStreamMessage de = (EndOfStreamMessage) imserde.fromBytes(bytes);
-    assertEquals(MessageType.of(de), MessageType.END_OF_STREAM);
+    assertEquals(IntermediateMessageType.of(de), IntermediateMessageType.END_OF_STREAM_MESSAGE);
     assertEquals(de.getTaskName(), taskName);
     assertEquals(de.getTaskCount(), 8);
     assertEquals(de.getVersion(), 1);
index 1afc26a..03931f1 100644 (file)
@@ -39,6 +39,7 @@ import org.apache.samza.container.TaskInstance;
 import org.apache.samza.container.TaskInstanceExceptionHandler;
 import org.apache.samza.container.TaskInstanceMetrics;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.control.EndOfStreamManager;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemConsumer;
@@ -78,15 +79,15 @@ public class TestAsyncRunLoop {
   private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
   private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
   private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
-  private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
-  private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
+  private final IncomingMessageEnvelope ssp0EndOfStream = EndOfStreamManager.buildEndOfStreamEnvelope(ssp0);
+  private final IncomingMessageEnvelope ssp1EndOfStream = EndOfStreamManager.buildEndOfStreamEnvelope(ssp1);
 
   TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
     TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
     scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
     return new TaskInstance(task, taskName, mock(Config.class), taskInstanceMetrics,
         null, consumers, mock(TaskInstanceCollector.class), mock(SamzaContainerContext.class),
-        manager, null, null, sspSet, new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()));
+        manager, null, null, sspSet, new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()), null, null);
   }
 
   TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp) {
@@ -569,7 +570,7 @@ public class TestAsyncRunLoop {
     SystemStreamPartition ssp2 = new SystemStreamPartition("system1", "stream2", p2);
     IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp2, "1", "key1", "message1");
     IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp2, "2", "key1", "message1");
-    IncomingMessageEnvelope envelope3 = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp2);
+    IncomingMessageEnvelope envelope3 = EndOfStreamManager.buildEndOfStreamEnvelope(ssp2);
 
     Map<SystemStreamPartition, List<IncomingMessageEnvelope>> sspMap = new HashMap<>();
     List<IncomingMessageEnvelope> messageList = new ArrayList<>();
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestStreamOperatorTask.java b/samza-core/src/test/java/org/apache/samza/task/TestStreamOperatorTask.java
new file mode 100644 (file)
index 0000000..45b08d7
--- /dev/null
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import org.apache.samza.operators.impl.OperatorImplGraph;
+
+
+/**
+ * Test-only bridge living in the org.apache.samza.task package so that tests
+ * elsewhere can reach StreamOperatorTask's package-private operator graph.
+ */
+public class TestStreamOperatorTask {
+
+  /** Returns the task's OperatorImplGraph for inspection in tests. */
+  public static OperatorImplGraph getOperatorImplGraph(StreamOperatorTask task) {
+    return task.getOperatorImplGraph();
+  }
+}
index 40974a6..9025077 100644 (file)
 
 package org.apache.samza.container
 
+
+import java.util
+import java.util
+import java.util.Collections
 import java.util.concurrent.ConcurrentHashMap
+import com.google.common.collect.Multimap
 import org.apache.samza.SamzaException
 import org.apache.samza.Partition
 import org.apache.samza.checkpoint.OffsetManager
 import org.apache.samza.config.Config
 import org.apache.samza.config.MapConfig
+import org.apache.samza.control.ControlMessageUtils
+import org.apache.samza.job.model.ContainerModel
+import org.apache.samza.job.model.JobModel
+import org.apache.samza.job.model.TaskModel
 import org.apache.samza.metrics.Counter
 import org.apache.samza.metrics.Metric
 import org.apache.samza.metrics.MetricsRegistryMap
@@ -354,6 +363,36 @@ class TestTaskInstance {
     val expected = List(envelope1, envelope2, envelope4)
     assertEquals(expected, result.toList)
   }
+
+  @Test
+  def testBuildInputToTasks = {
+    val system: String = "test-system"
+    val stream0: String = "test-stream-0"
+    val stream1: String = "test-stream-1"
+
+    val ssp0: SystemStreamPartition = new SystemStreamPartition(system, stream0, new Partition(0))
+    val ssp1: SystemStreamPartition = new SystemStreamPartition(system, stream0, new Partition(1))
+    val ssp2: SystemStreamPartition = new SystemStreamPartition(system, stream1, new Partition(0))
+
+    val task0: TaskName = new TaskName("Task 0")
+    val task1: TaskName = new TaskName("Task 1")
+    val ssps: util.Set[SystemStreamPartition] = new util.HashSet[SystemStreamPartition]
+    ssps.add(ssp0)
+    ssps.add(ssp2)
+    val tm0: TaskModel = new TaskModel(task0, ssps, new Partition(0))
+    val cm0: ContainerModel = new ContainerModel("c0", 0, Collections.singletonMap(task0, tm0))
+    val tm1: TaskModel = new TaskModel(task1, Collections.singleton(ssp1), new Partition(1))
+    val cm1: ContainerModel = new ContainerModel("c1", 1, Collections.singletonMap(task1, tm1))
+
+    val cms: util.Map[String, ContainerModel] = new util.HashMap[String, ContainerModel]
+    cms.put(cm0.getProcessorId, cm0)
+    cms.put(cm1.getProcessorId, cm1)
+
+    val jobModel: JobModel = new JobModel(new MapConfig, cms, null)
+    val streamToTasks: Multimap[SystemStream, String] = TaskInstance.buildInputToTasks(jobModel)
+    assertEquals(streamToTasks.get(ssp0.getSystemStream).size, 2)
+    assertEquals(streamToTasks.get(ssp2.getSystemStream).size, 1)
+  }
 }
 
 class MockSystemAdmin extends SystemAdmin {
index fb9bb56..de0d1da 100644 (file)
@@ -36,6 +36,7 @@ import org.apache.commons.lang.Validate;
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
+import org.apache.samza.control.EndOfStreamManager;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.system.IncomingMessageEnvelope;
@@ -236,7 +237,7 @@ public class HdfsSystemConsumer extends BlockingEnvelopeMap {
       consumerMetrics.incNumEvents(systemStreamPartition);
       consumerMetrics.incTotalNumEvents();
     }
-    offerMessage(systemStreamPartition, IncomingMessageEnvelope.buildEndOfStreamEnvelope(systemStreamPartition));
+    offerMessage(systemStreamPartition, EndOfStreamManager.buildEndOfStreamEnvelope(systemStreamPartition));
     reader.close();
   }
 
diff --git a/samza-test/src/test/java/org/apache/samza/processor/TestStreamProcessorUtil.java b/samza-test/src/test/java/org/apache/samza/processor/TestStreamProcessorUtil.java
new file mode 100644 (file)
index 0000000..08e866e
--- /dev/null
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.processor;
+
+import org.apache.samza.container.SamzaContainer;
+
+/**
+ * Test-only bridge in the org.apache.samza.processor package exposing a
+ * StreamProcessor's package-private SamzaContainer to integration tests.
+ */
+public class TestStreamProcessorUtil {
+  /** Returns the processor's underlying container for inspection in tests. */
+  public static SamzaContainer getContainer(StreamProcessor processor) {
+    return processor.getContainer();
+  }
+}
diff --git a/samza-test/src/test/java/org/apache/samza/test/controlmessages/EndOfStreamIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/controlmessages/EndOfStreamIntegrationTest.java
new file mode 100644 (file)
index 0000000..26abb13
--- /dev/null
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.controlmessages;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import org.apache.samza.application.StreamApplication;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.JobCoordinatorConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.container.grouper.task.SingleContainerGrouperFactory;
+import org.apache.samza.runtime.LocalApplicationRunner;
+import org.apache.samza.standalone.PassthroughJobCoordinatorFactory;
+import org.apache.samza.test.controlmessages.TestData.PageView;
+import org.apache.samza.test.controlmessages.TestData.PageViewJsonSerdeFactory;
+import org.apache.samza.test.harness.AbstractIntegrationTestHarness;
+import org.apache.samza.test.util.ArraySystemFactory;
+import org.apache.samza.test.util.Base64Serializer;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+
+/**
+ * This test uses an array as a bounded input source, and does a partitionBy() and sink() after reading the input.
+ * It verifies the pipeline will stop and the number of output messages should equal to the input.
+ */
+public class EndOfStreamIntegrationTest extends AbstractIntegrationTestHarness {
+
+
+  private static final String[] PAGEKEYS = {"inbox", "home", "search", "pymk", "group", "job"};
+
+  /**
+   * Runs a bounded array source through partitionBy() + sink() and asserts the
+   * pipeline shuts down via end-of-stream with all messages delivered.
+   */
+  @Test
+  public void testPipeline() throws  Exception {
+    // Generate random page views as the bounded input.
+    Random random = new Random();
+    int count = 100;
+    PageView[] pageviews = new PageView[count];
+    for (int i = 0; i < count; i++) {
+      // NOTE(review): nextInt(PAGEKEYS.length - 1) never selects the last key
+      // ("job"); likely meant nextInt(PAGEKEYS.length).
+      String pagekey = PAGEKEYS[random.nextInt(PAGEKEYS.length - 1)];
+      int memberId = random.nextInt(10);
+      pageviews[i] = new PageView(pagekey, memberId);
+    }
+
+    // Bounded in-memory "test" system serving the serialized array.
+    int partitionCount = 4;
+    Map<String, String> configs = new HashMap<>();
+    configs.put("systems.test.samza.factory", ArraySystemFactory.class.getName());
+    configs.put("streams.PageView.samza.system", "test");
+    configs.put("streams.PageView.source", Base64Serializer.serialize(pageviews));
+    configs.put("streams.PageView.partitionCount", String.valueOf(partitionCount));
+
+    // Standalone job coordination: single container, passthrough coordinator.
+    configs.put(JobConfig.JOB_NAME(), "test-eos-job");
+    configs.put(JobConfig.PROCESSOR_ID(), "1");
+    configs.put(JobCoordinatorConfig.JOB_COORDINATOR_FACTORY, PassthroughJobCoordinatorFactory.class.getName());
+    configs.put(TaskConfig.GROUPER_FACTORY(), SingleContainerGrouperFactory.class.getName());
+
+    // Kafka (from the test harness) hosts the intermediate partitionBy stream.
+    configs.put("systems.kafka.samza.factory", "org.apache.samza.system.kafka.KafkaSystemFactory");
+    configs.put("systems.kafka.producer.bootstrap.servers", bootstrapUrl());
+    configs.put("systems.kafka.consumer.zookeeper.connect", zkConnect());
+    configs.put("systems.kafka.samza.key.serde", "int");
+    configs.put("systems.kafka.samza.msg.serde", "json");
+    configs.put("systems.kafka.default.stream.replication.factor", "1");
+    configs.put("job.default.system", "kafka");
+
+    configs.put("serializers.registry.int.class", "org.apache.samza.serializers.IntegerSerdeFactory");
+    configs.put("serializers.registry.json.class", PageViewJsonSerdeFactory.class.getName());
+
+    final LocalApplicationRunner runner = new LocalApplicationRunner(new MapConfig(configs));
+    List<PageView> received = new ArrayList<>();
+    final StreamApplication app = (streamGraph, cfg) -> {
+      streamGraph.getInputStream("PageView", (k, v) -> (PageView) v)
+        .partitionBy(PageView::getMemberId)
+        .sink((m, collector, coordinator) -> {
+            received.add(m);
+          });
+    };
+    runner.run(app);
+    // Blocks until end-of-stream has propagated through the whole pipeline.
+    runner.waitForFinish();
+
+    // count * partitionCount suggests the array source replays the full array
+    // on each of the partitionCount input partitions — TODO confirm against
+    // ArraySystemFactory.
+    assertEquals(received.size(), count * partitionCount);
+  }
+}
diff --git a/samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java b/samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java
new file mode 100644 (file)
index 0000000..8541b55
--- /dev/null
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.controlmessages;
+
+import java.io.Serializable;
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.serializers.SerdeFactory;
+import org.codehaus.jackson.annotate.JsonCreator;
+import org.codehaus.jackson.annotate.JsonProperty;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.codehaus.jackson.type.TypeReference;
+
+public class TestData {
+
+  public static class PageView implements Serializable {
+    @JsonProperty("pageKey")
+    final String pageKey;
+    @JsonProperty("memberId")
+    final int memberId;
+
+    @JsonProperty("pageKey")
+    public String getPageKey() {
+      return pageKey;
+    }
+
+    @JsonProperty("memberId")
+    public int getMemberId() {
+      return memberId;
+    }
+
+    @JsonCreator
+    public PageView(@JsonProperty("pageKey") String pageKey, @JsonProperty("memberId") int memberId) {
+      this.pageKey = pageKey;
+      this.memberId = memberId;
+    }
+  }
+
+  public static class PageViewJsonSerdeFactory implements SerdeFactory<PageView> {
+    @Override
+    public Serde<PageView> getSerde(String name, Config config) {
+      return new PageViewJsonSerde();
+    }
+  }
+
+  public static class PageViewJsonSerde implements Serde<PageView> {
+    ObjectMapper mapper = new ObjectMapper();
+
+    @Override
+    public PageView fromBytes(byte[] bytes) {
+      try {
+        return mapper.readValue(new String(bytes, "UTF-8"), new TypeReference<PageView>() { });
+      } catch (Exception e) {
+        throw new SamzaException(e);
+      }
+    }
+
+    @Override
+    public byte[] toBytes(PageView pv) {
+      try {
+        return mapper.writeValueAsString(pv).getBytes("UTF-8");
+      } catch (Exception e) {
+        throw new SamzaException(e);
+      }
+    }
+  }
+}
diff --git a/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java
new file mode 100644 (file)
index 0000000..58da8bd
--- /dev/null
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.controlmessages;
+
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.Partition;
+import org.apache.samza.application.StreamApplication;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.JobCoordinatorConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.container.SamzaContainer;
+import org.apache.samza.container.TaskInstance;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.container.grouper.task.SingleContainerGrouperFactory;
+import org.apache.samza.control.EndOfStreamManager;
+import org.apache.samza.control.WatermarkManager;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.operators.impl.InputOperatorImpl;
+import org.apache.samza.operators.impl.OperatorImpl;
+import org.apache.samza.operators.impl.OperatorImplGraph;
+import org.apache.samza.operators.impl.TestOperatorImpl;
+import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.processor.StreamProcessor;
+import org.apache.samza.processor.TestStreamProcessorUtil;
+import org.apache.samza.runtime.LocalApplicationRunner;
+import org.apache.samza.runtime.TestLocalApplicationRunner;
+import org.apache.samza.serializers.IntegerSerdeFactory;
+import org.apache.samza.serializers.StringSerdeFactory;
+import org.apache.samza.standalone.PassthroughJobCoordinatorFactory;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemProducer;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.AsyncStreamTaskAdapter;
+import org.apache.samza.task.StreamOperatorTask;
+import org.apache.samza.task.TestStreamOperatorTask;
+import org.apache.samza.test.controlmessages.TestData.PageView;
+import org.apache.samza.test.controlmessages.TestData.PageViewJsonSerdeFactory;
+import org.apache.samza.test.harness.AbstractIntegrationTestHarness;
+import org.apache.samza.test.util.SimpleSystemAdmin;
+import org.apache.samza.test.util.TestStreamConsumer;
+import org.junit.Test;
+import scala.collection.JavaConverters;
+
+import static org.junit.Assert.assertEquals;
+
+
/**
 * Integration test for watermark propagation through a two-stage
 * (partitionBy -&gt; sink) pipeline. A fixed stream of page views interleaved
 * with watermark and end-of-stream control messages is served by
 * {@link TestSystemFactory}; after the job finishes, the test digs the
 * operator graph out of each task via reflection and checks the input/output
 * watermark each operator settled on.
 */
public class WatermarkIntegrationTest extends AbstractIntegrationTestHarness {

  // Monotonically increasing offset handed out by createIncomingMessage().
  private static int offset = 1;
  private static final String TEST_SYSTEM = "test";
  private static final String TEST_STREAM = "PageView";
  private static final int PARTITION_COUNT = 2;
  private static final SystemStreamPartition SSP0 = new SystemStreamPartition(TEST_SYSTEM, TEST_STREAM, new Partition(0));
  private static final SystemStreamPartition SSP1 = new SystemStreamPartition(TEST_SYSTEM, TEST_STREAM, new Partition(1));

  // Fixed input: SSP0's watermark advances 1 -> 4, SSP1's advances 2 -> 3;
  // both partitions are terminated by an end-of-stream control message so
  // runner.waitForFinish() below can return.
  private final static List<IncomingMessageEnvelope> TEST_DATA = new ArrayList<>();
  static {
    TEST_DATA.add(createIncomingMessage(new PageView("inbox", 1), SSP0));
    TEST_DATA.add(createIncomingMessage(new PageView("home", 2), SSP1));
    TEST_DATA.add(WatermarkManager.buildWatermarkEnvelope(1, SSP0));
    TEST_DATA.add(WatermarkManager.buildWatermarkEnvelope(2, SSP1));
    TEST_DATA.add(WatermarkManager.buildWatermarkEnvelope(4, SSP0));
    TEST_DATA.add(WatermarkManager.buildWatermarkEnvelope(3, SSP1));
    TEST_DATA.add(createIncomingMessage(new PageView("search", 3), SSP0));
    TEST_DATA.add(createIncomingMessage(new PageView("pymk", 4), SSP1));
    TEST_DATA.add(EndOfStreamManager.buildEndOfStreamEnvelope(SSP0));
    TEST_DATA.add(EndOfStreamManager.buildEndOfStreamEnvelope(SSP1));
  }

  /** Test system whose consumer replays {@link #TEST_DATA} for every partition. */
  public final static class TestSystemFactory implements SystemFactory {
    @Override
    public SystemConsumer getConsumer(String systemName, Config config, MetricsRegistry registry) {
      return new TestStreamConsumer(TEST_DATA);
    }

    @Override
    public SystemProducer getProducer(String systemName, Config config, MetricsRegistry registry) {
      // Consume-only test system; no producer is needed.
      return null;
    }

    @Override
    public SystemAdmin getAdmin(String systemName, Config config) {
      return new SimpleSystemAdmin(config);
    }
  }

  // Wraps a user message into an envelope on the given partition, using a
  // globally increasing integer offset.
  private static IncomingMessageEnvelope createIncomingMessage(Object message, SystemStreamPartition ssp) {
    return new IncomingMessageEnvelope(ssp, String.valueOf(offset++), "", message);
  }

  @Test
  public void testWatermark() throws Exception {
    Map<String, String> configs = new HashMap<>();
    // The "test" system serves TEST_DATA; the intermediate (partitionBy)
    // stream goes through the kafka system configured below.
    configs.put("systems.test.samza.factory", TestSystemFactory.class.getName());
    configs.put("streams.PageView.samza.system", "test");
    configs.put("streams.PageView.partitionCount", String.valueOf(PARTITION_COUNT));

    configs.put(JobConfig.JOB_NAME(), "test-watermark-job");
    configs.put(JobConfig.PROCESSOR_ID(), "1");
    configs.put(JobCoordinatorConfig.JOB_COORDINATOR_FACTORY, PassthroughJobCoordinatorFactory.class.getName());
    configs.put(TaskConfig.GROUPER_FACTORY(), SingleContainerGrouperFactory.class.getName());

    configs.put("systems.kafka.samza.factory", "org.apache.samza.system.kafka.KafkaSystemFactory");
    configs.put("systems.kafka.producer.bootstrap.servers", bootstrapUrl());
    configs.put("systems.kafka.consumer.zookeeper.connect", zkConnect());
    configs.put("systems.kafka.samza.key.serde", "int");
    configs.put("systems.kafka.samza.msg.serde", "json");
    configs.put("systems.kafka.default.stream.replication.factor", "1");
    configs.put("job.default.system", "kafka");

    configs.put("serializers.registry.int.class", IntegerSerdeFactory.class.getName());
    configs.put("serializers.registry.string.class", StringSerdeFactory.class.getName());
    configs.put("serializers.registry.json.class", PageViewJsonSerdeFactory.class.getName());

    final LocalApplicationRunner runner = new LocalApplicationRunner(new MapConfig(configs));
    List<PageView> received = new ArrayList<>();
    // Two-stage pipeline: repartition by member id, then collect into `received`.
    final StreamApplication app = (streamGraph, cfg) -> {
      streamGraph.getInputStream("PageView", (k, v) -> (PageView) v)
          .partitionBy(PageView::getMemberId)
          .sink((m, collector, coordinator) -> {
              received.add(m);
            });
    };
    runner.run(app);
    // Grab the task handles before the job drains, so we can inspect the
    // operator graphs after it finishes.
    Map<String, StreamOperatorTask> tasks = getTaskOperationGraphs(runner);

    runner.waitForFinish();

    // Task 0 consumes SSP0, whose last watermark is 4, so its partitionBy
    // sees input/output watermark 4. The sink sits behind the repartitioned
    // intermediate stream fed by both tasks, so its watermark is the min of
    // the upstream outputs: min(4, 3) = 3 (see WatermarkManager).
    // NOTE(review): assertEquals(actual, expected) is the reverse of JUnit's
    // (expected, actual) convention; failure messages will read backwards.
    StreamOperatorTask task0 = tasks.get("Partition 0");
    OperatorImplGraph graph = TestStreamOperatorTask.getOperatorImplGraph(task0);
    OperatorImpl pb = getOperator(graph, OperatorSpec.OpCode.PARTITION_BY);
    assertEquals(TestOperatorImpl.getInputWatermark(pb), 4);
    assertEquals(TestOperatorImpl.getOutputWatermark(pb), 4);
    OperatorImpl sink = getOperator(graph, OperatorSpec.OpCode.SINK);
    assertEquals(TestOperatorImpl.getInputWatermark(sink), 3);
    assertEquals(TestOperatorImpl.getOutputWatermark(sink), 3);

    // Task 1 consumes SSP1, whose last watermark is 3; its sink likewise
    // settles at min(4, 3) = 3.
    StreamOperatorTask task1 = tasks.get("Partition 1");
    graph = TestStreamOperatorTask.getOperatorImplGraph(task1);
    pb = getOperator(graph, OperatorSpec.OpCode.PARTITION_BY);
    assertEquals(TestOperatorImpl.getInputWatermark(pb), 3);
    assertEquals(TestOperatorImpl.getOutputWatermark(pb), 3);
    sink = getOperator(graph, OperatorSpec.OpCode.SINK);
    assertEquals(TestOperatorImpl.getInputWatermark(sink), 3);
    assertEquals(TestOperatorImpl.getOutputWatermark(sink), 3);
  }

  /**
   * Maps task name (e.g. "Partition 0") to its {@link StreamOperatorTask} by
   * reflecting into the runner's container and unwrapping each task's private
   * {@code wrappedTask} field on {@link AsyncStreamTaskAdapter}.
   */
  Map<String, StreamOperatorTask> getTaskOperationGraphs(LocalApplicationRunner runner) throws Exception {
    StreamProcessor processor = TestLocalApplicationRunner.getProcessors(runner).iterator().next();
    SamzaContainer container = TestStreamProcessorUtil.getContainer(processor);
    Map<TaskName, TaskInstance> taskInstances = JavaConverters.mapAsJavaMapConverter(container.getTaskInstances()).asJava();
    Map<String, StreamOperatorTask> tasks = new HashMap<>();
    for (Map.Entry<TaskName, TaskInstance> entry : taskInstances.entrySet()) {
      AsyncStreamTaskAdapter adapter = (AsyncStreamTaskAdapter) entry.getValue().task();
      Field field = AsyncStreamTaskAdapter.class.getDeclaredField("wrappedTask");
      field.setAccessible(true);
      StreamOperatorTask task = (StreamOperatorTask) field.get(adapter);
      tasks.put(entry.getKey().getTaskName(), task);
    }
    return tasks;
  }

  /**
   * Walks the operator chain from each input operator and returns the first
   * operator with the given op code, or null if none is found.
   * Note: only follows one next-operator at each step (iterator().next()),
   * which is sufficient for the linear pipeline built in this test —
   * presumably not for branching graphs; confirm before reusing.
   */
  OperatorImpl getOperator(OperatorImplGraph graph, OperatorSpec.OpCode opCode) {
    for (InputOperatorImpl input : graph.getAllInputOperators()) {
      Set<OperatorImpl> nextOps = TestOperatorImpl.getNextOperators(input);
      while (!nextOps.isEmpty()) {
        OperatorImpl op = nextOps.iterator().next();
        if (TestOperatorImpl.getOpCode(op) == opCode) {
          return op;
        } else {
          nextOps = TestOperatorImpl.getNextOperators(op);
        }
      }
    }
    return null;
  }
}
diff --git a/samza-test/src/test/java/org/apache/samza/test/util/ArraySystemConsumer.java b/samza-test/src/test/java/org/apache/samza/test/util/ArraySystemConsumer.java
new file mode 100644 (file)
index 0000000..9b96216
--- /dev/null
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.util;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.config.Config;
+import org.apache.samza.control.EndOfStreamManager;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemStreamPartition;
+
+/**
+ * A simple implementation of array system consumer
+ */
+public class ArraySystemConsumer implements SystemConsumer {
+  boolean done = false;
+  private final Config config;
+
+  public ArraySystemConsumer(Config config) {
+    this.config = config;
+  }
+
+  @Override
+  public void start() {
+  }
+
+  @Override
+  public void stop() {
+  }
+
+  @Override
+  public void register(SystemStreamPartition systemStreamPartition, String s) {
+  }
+
+  @Override
+  public Map<SystemStreamPartition, List<IncomingMessageEnvelope>> poll(Set<SystemStreamPartition> set, long l) throws InterruptedException {
+    if (!done) {
+      Map<SystemStreamPartition, List<IncomingMessageEnvelope>> envelopeMap = new HashMap<>();
+      set.forEach(ssp -> {
+          List<IncomingMessageEnvelope> envelopes = Arrays.stream(getArrayObjects(ssp.getSystemStream().getStream(), config))
+              .map(object -> new IncomingMessageEnvelope(ssp, null, null, object)).collect(Collectors.toList());
+          envelopes.add(EndOfStreamManager.buildEndOfStreamEnvelope(ssp));
+          envelopeMap.put(ssp, envelopes);
+        });
+      done = true;
+      return envelopeMap;
+    } else {
+      return Collections.emptyMap();
+    }
+
+  }
+
+  private static Object[] getArrayObjects(String stream, Config config) {
+    try {
+      return Base64Serializer.deserialize(config.get("streams." + stream + ".source"), Object[].class);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+}
diff --git a/samza-test/src/test/java/org/apache/samza/test/util/ArraySystemFactory.java b/samza-test/src/test/java/org/apache/samza/test/util/ArraySystemFactory.java
new file mode 100644 (file)
index 0000000..0632865
--- /dev/null
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.util;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemProducer;
+
+
+/**
+ * System factory for the stream from an array
+ */
+public class ArraySystemFactory implements SystemFactory {
+
+  @Override
+  public SystemConsumer getConsumer(String systemName, Config config, MetricsRegistry metricsRegistry) {
+    return new ArraySystemConsumer(config);
+  }
+
+  @Override
+  public SystemProducer getProducer(String systemName, Config config, MetricsRegistry metricsRegistry) {
+    // no producer
+    return null;
+  }
+
+  @Override
+  public SystemAdmin getAdmin(String systemName, Config config) {
+    return new SimpleSystemAdmin(config);
+  }
+}
diff --git a/samza-test/src/test/java/org/apache/samza/test/util/Base64Serializer.java b/samza-test/src/test/java/org/apache/samza/test/util/Base64Serializer.java
new file mode 100644 (file)
index 0000000..1a17a3d
--- /dev/null
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.util;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.util.Base64;
+
+
/**
 * Round-trips {@link Serializable} objects through Java serialization encoded
 * as a Base64 string — handy for smuggling test fixtures through string-only
 * config maps.
 */
public class Base64Serializer {
  // Static-utility class; not instantiable.
  private Base64Serializer() {}

  /** Same as {@link #serialize}, wrapping any {@link IOException} in a {@link RuntimeException}. */
  public static String serializeUnchecked(Serializable serializable) {
    try {
      return serialize(serializable);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Java-serializes the given object and encodes the bytes as Base64.
   *
   * @param serializable object to serialize
   * @return Base64 encoding of the serialized bytes
   * @throws IOException if serialization fails
   */
  public static String serialize(Serializable serializable) throws IOException {
    final ByteArrayOutputStream baos = new ByteArrayOutputStream();
    // try-with-resources flushes/closes the stream even when writeObject
    // throws (previously the ObjectOutputStream leaked on the error path).
    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
      oos.writeObject(serializable);
    }
    return Base64.getEncoder().encodeToString(baos.toByteArray());
  }

  /** Same as {@link #deserialize}, wrapping checked exceptions in a {@link RuntimeException}. */
  public static <T> T deserializeUnchecked(String serialized, Class<T> klass) {
    try {
      return deserialize(serialized, klass);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Decodes a Base64 string produced by {@link #serialize} back into an object.
   *
   * @param serialized Base64 string to decode
   * @param klass expected type of the result (used only for the generic cast)
   * @return the deserialized object
   * @throws IOException if deserialization fails
   * @throws ClassNotFoundException if a serialized class is not on the classpath
   */
  public static <T> T deserialize(String serialized, Class<T> klass) throws IOException, ClassNotFoundException {
    final byte[] bytes = Base64.getDecoder().decode(serialized);
    // Close the stream even if readObject throws.
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
      @SuppressWarnings("unchecked")
      T object = (T) ois.readObject();
      return object;
    }
  }
}
diff --git a/samza-test/src/test/java/org/apache/samza/test/util/SimpleSystemAdmin.java b/samza-test/src/test/java/org/apache/samza/test/util/SimpleSystemAdmin.java
new file mode 100644 (file)
index 0000000..41f01c5
--- /dev/null
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.util;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import org.apache.samza.Partition;
+import org.apache.samza.config.Config;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+
+
+/**
+ * A dummy system admin
+ */
+public class SimpleSystemAdmin implements SystemAdmin {
+  private final Config config;
+
+  public SimpleSystemAdmin(Config config) {
+    this.config = config;
+  }
+
+  @Override
+  public Map<SystemStreamPartition, String> getOffsetsAfter(Map<SystemStreamPartition, String> offsets) {
+    return offsets.entrySet().stream()
+        .collect(Collectors.toMap(Map.Entry::getKey, null));
+  }
+
+  @Override
+  public Map<String, SystemStreamMetadata> getSystemStreamMetadata(Set<String> streamNames) {
+    return streamNames.stream()
+        .collect(Collectors.toMap(
+            Function.<String>identity(),
+            streamName -> {
+            Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> metadataMap = new HashMap<>();
+            int partitionCount = config.getInt("streams." + streamName + ".partitionCount", 1);
+            for (int i = 0; i < partitionCount; i++) {
+              metadataMap.put(new Partition(i), new SystemStreamMetadata.SystemStreamPartitionMetadata(null, null, null));
+            }
+            return new SystemStreamMetadata(streamName, metadataMap);
+          }));
+  }
+
+  @Override
+  public void createChangelogStream(String streamName, int numOfPartitions) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void validateChangelogStream(String streamName, int numOfPartitions) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void createCoordinatorStream(String streamName) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Integer offsetComparator(String offset1, String offset2) {
+    if (offset1 == null) {
+      return offset2 == null ? 0 : -1;
+    } else if (offset2 == null) {
+      return 1;
+    }
+    return offset1.compareTo(offset2);
+  }
+}
+
diff --git a/samza-test/src/test/java/org/apache/samza/test/util/TestStreamConsumer.java b/samza-test/src/test/java/org/apache/samza/test/util/TestStreamConsumer.java
new file mode 100644 (file)
index 0000000..31eee15
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.util;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemStreamPartition;
+
+public class TestStreamConsumer implements SystemConsumer {
+  private List<IncomingMessageEnvelope> envelopes;
+
+  public TestStreamConsumer(List<IncomingMessageEnvelope> envelopes) {
+    this.envelopes = envelopes;
+  }
+
+  @Override
+  public void start() { }
+
+  @Override
+  public void stop() { }
+
+  @Override
+  public void register(SystemStreamPartition systemStreamPartition, String offset) { }
+
+  @Override
+  public Map<SystemStreamPartition, List<IncomingMessageEnvelope>> poll(
+      Set<SystemStreamPartition> systemStreamPartitions, long timeout)
+      throws InterruptedException {
+    return systemStreamPartitions.stream().collect(Collectors.toMap(ssp -> ssp, ssp -> envelopes));
+  }
+}