SAMZA-1067; Physical execution graph and planner for fluent API
authorXinyu Liu <xiliu@xiliu-ld.linkedin.biz>
Thu, 16 Mar 2017 01:27:01 +0000 (18:27 -0700)
committerXinyu Liu <xiliu@xiliu-ld.linkedin.biz>
Thu, 16 Mar 2017 01:27:01 +0000 (18:27 -0700)
Initial commit for the physical graph and plan. Design is there: https://issues.apache.org/jira/secure/attachment/12856670/SAMZA-1067.0.pdf.

The commit includes:

1) Physical ProcessorGraph, where each processor represents a physical execution unit (e.g. a job in Yarn).
2) A planner does the following:
   - create ProcessorGraph from StreamGraph. For this phase, the graph only contains a single node (single stage);
   - figure out the partitions of intermediate topics
   - create the topics

Please note currently the planner is used in the remote runner for now. Further changes/refactoring/cleanup are expected to be integrated with local runner.

Author: Xinyu Liu <xiliu@xiliu-ld.linkedin.biz>

Reviewers: Jagadish Venkatraman <jvenkatraman@linkedin.com>

Closes #75 from xinyuiscool/SAMZA-1067

samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/execution/ProcessorGraph.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/execution/ProcessorNode.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/execution/StreamEdge.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java
samza-core/src/main/scala/org/apache/samza/config/JobConfig.scala
samza-core/src/main/scala/org/apache/samza/job/JobRunner.scala
samza-core/src/main/scala/org/apache/samza/util/Util.scala
samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/execution/TestProcessorGraph.java [new file with mode: 0644]

diff --git a/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java b/samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
new file mode 100644 (file)
index 0000000..77790a8
--- /dev/null
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.execution;
+
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JavaSystemConfig;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.MessageStreamImpl;
+import org.apache.samza.operators.StreamGraph;
+import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.operators.spec.PartialJoinOperatorSpec;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.util.Util;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * The ExecutionPlanner creates the physical execution graph for the StreamGraph, and
+ * the intermediate topics needed for the execution.
+ */
+public class ExecutionPlanner {
+  private static final Logger log = LoggerFactory.getLogger(ExecutionPlanner.class);
+
+  private final Config config;
+
+  public ExecutionPlanner(Config config) {
+    this.config = config;
+  }
+
+  public ProcessorGraph plan(StreamGraph streamGraph) throws Exception {
+    Map<String, SystemAdmin> sysAdmins = getSystemAdmins(config);
+
+    // create physical processors based on stream graph
+    ProcessorGraph processorGraph = createProcessorGraph(streamGraph);
+
+    if (!processorGraph.getIntermediateStreams().isEmpty()) {
+      // figure out the partitions for internal streams
+      calculatePartitions(streamGraph, processorGraph, sysAdmins);
+
+      // create the streams
+      createStreams(processorGraph, sysAdmins);
+    }
+
+    return processorGraph;
+  }
+
+  /**
+   * Create the physical graph from StreamGraph
+   */
+  /* package private */ ProcessorGraph createProcessorGraph(StreamGraph streamGraph) {
+    // For this phase, we are going to create a processor for the whole dag
+    String processorId = config.get(JobConfig.JOB_NAME()); // only one processor, use the job name
+
+    ProcessorGraph processorGraph = new ProcessorGraph(config);
+    Set<StreamSpec> sourceStreams = new HashSet<>(streamGraph.getInStreams().keySet());
+    Set<StreamSpec> sinkStreams = new HashSet<>(streamGraph.getOutStreams().keySet());
+    Set<StreamSpec> intStreams = new HashSet<>(sourceStreams);
+    intStreams.retainAll(sinkStreams);
+    sourceStreams.removeAll(intStreams);
+    sinkStreams.removeAll(intStreams);
+
+    // add sources
+    sourceStreams.forEach(spec -> processorGraph.addSource(spec, processorId));
+
+    // add sinks
+    sinkStreams.forEach(spec -> processorGraph.addSink(spec, processorId));
+
+    // add intermediate streams
+    intStreams.forEach(spec -> processorGraph.addIntermediateStream(spec, processorId, processorId));
+
+    processorGraph.validate();
+
+    return processorGraph;
+  }
+
+  /**
+   * Figure out the number of partitions of all streams
+   */
+  /* package private */ void calculatePartitions(StreamGraph streamGraph, ProcessorGraph processorGraph, Map<String, SystemAdmin> sysAdmins) {
+    // fetch the external streams partition info
+    updateExistingPartitions(processorGraph, sysAdmins);
+
+    // calculate the partitions for the input streams of join operators
+    calculateJoinInputPartitions(streamGraph, processorGraph);
+
+    // calculate the partitions for the rest of intermediate streams
+    calculateIntStreamPartitions(processorGraph, config);
+
+    // validate all the partitions are assigned
+    validatePartitions(processorGraph);
+  }
+
+  /**
+   * Fetch the partitions of source/sink streams and update the StreamEdges.
+   * @param processorGraph ProcessorGraph
+   * @param sysAdmins mapping from system name to the {@link SystemAdmin}
+   */
+  /* package private */ static void updateExistingPartitions(ProcessorGraph processorGraph, Map<String, SystemAdmin> sysAdmins) {
+    Set<StreamEdge> existingStreams = new HashSet<>();
+    existingStreams.addAll(processorGraph.getSources());
+    existingStreams.addAll(processorGraph.getSinks());
+
+    Multimap<String, StreamEdge> systemToStreamEdges = HashMultimap.create();
+    // group the StreamEdge(s) based on the system name
+    existingStreams.forEach(streamEdge -> {
+        SystemStream systemStream = streamEdge.getSystemStream();
+        systemToStreamEdges.put(systemStream.getSystem(), streamEdge);
+      });
+    for (Map.Entry<String, Collection<StreamEdge>> entry : systemToStreamEdges.asMap().entrySet()) {
+      String systemName = entry.getKey();
+      Collection<StreamEdge> streamEdges = entry.getValue();
+      Map<String, StreamEdge> streamToStreamEdge = new HashMap<>();
+      // create the stream name to StreamEdge mapping for this system
+      streamEdges.forEach(streamEdge -> streamToStreamEdge.put(streamEdge.getSystemStream().getStream(), streamEdge));
+      SystemAdmin systemAdmin = sysAdmins.get(systemName);
+      // retrieve the metadata for the streams in this system
+      Map<String, SystemStreamMetadata> streamToMetadata = systemAdmin.getSystemStreamMetadata(streamToStreamEdge.keySet());
+      // set the partitions of a stream to its StreamEdge
+      streamToMetadata.forEach((stream, data) -> {
+          int partitions = data.getSystemStreamPartitionMetadata().size();
+          streamToStreamEdge.get(stream).setPartitionCount(partitions);
+          log.debug("Partition count is {} for stream {}", partitions, stream);
+        });
+    }
+  }
+
+  /**
+   * Calculate the partitions for the input streams of join operators
+   */
+  /* package private */ static void calculateJoinInputPartitions(StreamGraph streamGraph, ProcessorGraph processorGraph) {
+    // mapping from a source stream to all join specs reachable from it
+    Multimap<OperatorSpec, StreamEdge> joinSpecToStreamEdges = HashMultimap.create();
+    // reverse mapping of the above
+    Multimap<StreamEdge, OperatorSpec> streamEdgeToJoinSpecs = HashMultimap.create();
+    // Mapping from the output stream to the join spec. Since StreamGraph creates two partial join operators for a join and they
+    // will have the same output stream, this mapping is used to choose one of them as the unique join spec representing this join
+    // (who register first in the map wins).
+    Map<MessageStream, OperatorSpec> outputStreamToJoinSpec = new HashMap<>();
+    // A queue of joins with known input partitions
+    Queue<OperatorSpec> joinQ = new LinkedList<>();
+    // The visited set keeps track of the join specs that have been already inserted in the queue before
+    Set<OperatorSpec> visited = new HashSet<>();
+
+    streamGraph.getInStreams().entrySet().forEach(entry -> {
+        StreamEdge streamEdge = processorGraph.getOrCreateEdge(entry.getKey());
+        // Traverses the StreamGraph to find and update mappings for all Joins reachable from this input StreamEdge
+        findReachableJoins(entry.getValue(), streamEdge, joinSpecToStreamEdges, streamEdgeToJoinSpecs,
+            outputStreamToJoinSpec, joinQ, visited);
+      });
+
+    // At this point, joinQ contains joinSpecs where at least one of the input stream edge partitions is known.
+    while (!joinQ.isEmpty()) {
+      OperatorSpec join = joinQ.poll();
+      int partitions = StreamEdge.PARTITIONS_UNKNOWN;
+      // loop through the input streams to the join and find the partition count
+      for (StreamEdge edge : joinSpecToStreamEdges.get(join)) {
+        int edgePartitions = edge.getPartitionCount();
+        if (edgePartitions != StreamEdge.PARTITIONS_UNKNOWN) {
+          if (partitions == StreamEdge.PARTITIONS_UNKNOWN) {
+            //if the partition is not assigned
+            partitions = edgePartitions;
+          } else if (partitions != edgePartitions) {
+            throw  new SamzaException(String.format(
+                "Unable to resolve input partitions of stream %s for join. Expected: %d, Actual: %d",
+                edge.getFormattedSystemStream(), partitions, edgePartitions));
+          }
+        }
+      }
+      // assign the partition count for intermediate streams
+      for (StreamEdge edge : joinSpecToStreamEdges.get(join)) {
+        if (edge.getPartitionCount() <= 0) {
+          edge.setPartitionCount(partitions);
+
+          // find other joins can be inferred by setting this edge
+          for (OperatorSpec op : streamEdgeToJoinSpecs.get(edge)) {
+            if (!visited.contains(op)) {
+              joinQ.add(op);
+              visited.add(op);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * This function traverses the StreamGraph to find and update mappings for all Joins reachable from this input StreamEdge
+   * @param inputMessageStream next input MessageStream to traverse {@link MessageStream}
+   * @param sourceStreamEdge source {@link StreamEdge}
+   * @param joinSpecToStreamEdges mapping from join spec to its source {@link StreamEdge}s
+   * @param streamEdgeToJoinSpecs mapping from source {@link StreamEdge} to the join specs that consumes it
+   * @param outputStreamToJoinSpec mapping from the output stream to the join spec
+   * @param joinQ queue that contains joinSpecs where at least one of the input stream edge partitions is known.
+   */
+  private static void findReachableJoins(MessageStream inputMessageStream, StreamEdge sourceStreamEdge,
+      Multimap<OperatorSpec, StreamEdge> joinSpecToStreamEdges, Multimap<StreamEdge, OperatorSpec> streamEdgeToJoinSpecs,
+      Map<MessageStream, OperatorSpec> outputStreamToJoinSpec, Queue<OperatorSpec> joinQ, Set<OperatorSpec> visited) {
+    Collection<OperatorSpec> specs = ((MessageStreamImpl) inputMessageStream).getRegisteredOperatorSpecs();
+    for (OperatorSpec spec : specs) {
+      if (spec instanceof PartialJoinOperatorSpec) {
+        // every join will have two partial join operators
+        // we will choose one of them in order to consolidate the inputs
+        // the first one who registered with the outputStreamToJoinSpec will win
+        MessageStream output = spec.getNextStream();
+        OperatorSpec joinSpec = outputStreamToJoinSpec.get(output);
+        if (joinSpec == null) {
+          joinSpec = spec;
+          outputStreamToJoinSpec.put(output, joinSpec);
+        }
+
+        joinSpecToStreamEdges.put(joinSpec, sourceStreamEdge);
+        streamEdgeToJoinSpecs.put(sourceStreamEdge, joinSpec);
+
+        if (!visited.contains(joinSpec) && sourceStreamEdge.getPartitionCount() > 0) {
+          // put the joins with known input partitions into the queue
+          joinQ.add(joinSpec);
+          visited.add(joinSpec);
+        }
+      }
+
+      if (spec.getNextStream() != null) {
+        findReachableJoins(spec.getNextStream(), sourceStreamEdge, joinSpecToStreamEdges, streamEdgeToJoinSpecs, outputStreamToJoinSpec, joinQ,
+            visited);
+      }
+    }
+  }
+
+  private static void calculateIntStreamPartitions(ProcessorGraph processorGraph, Config config) {
+    int partitions = config.getInt(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), StreamEdge.PARTITIONS_UNKNOWN);
+    if (partitions < 0) {
+      // use the following simple algo to figure out the partitions
+      // partition = MAX(MAX(Input topic partitions), MAX(Output topic partitions))
+      int maxInPartitions = maxPartition(processorGraph.getSources());
+      int maxOutPartitions = maxPartition(processorGraph.getSinks());
+      partitions = Math.max(maxInPartitions, maxOutPartitions);
+    }
+    for (StreamEdge edge : processorGraph.getIntermediateStreams()) {
+      if (edge.getPartitionCount() <= 0) {
+        edge.setPartitionCount(partitions);
+      }
+    }
+  }
+
+  private static void validatePartitions(ProcessorGraph processorGraph) {
+    for (StreamEdge edge : processorGraph.getIntermediateStreams()) {
+      if (edge.getPartitionCount() <= 0) {
+        throw new SamzaException(String.format("Failure to assign the partitions to Stream %s", edge.getFormattedSystemStream()));
+      }
+    }
+  }
+
+  private static void createStreams(ProcessorGraph graph, Map<String, SystemAdmin> sysAdmins) {
+    Multimap<String, StreamSpec> streamsToCreate = HashMultimap.create();
+    graph.getIntermediateStreams().forEach(edge -> {
+        StreamSpec streamSpec = createStreamSpec(edge);
+        streamsToCreate.put(edge.getSystemStream().getSystem(), streamSpec);
+      });
+
+    for (Map.Entry<String, Collection<StreamSpec>> entry : streamsToCreate.asMap().entrySet()) {
+      String systemName = entry.getKey();
+      SystemAdmin systemAdmin = sysAdmins.get(systemName);
+
+      for (StreamSpec stream : entry.getValue()) {
+        log.info("Creating stream {} with partitions {} on system {}",
+            new Object[]{stream.getPhysicalName(), stream.getPartitionCount(), systemName});
+        systemAdmin.createStream(stream);
+      }
+    }
+  }
+
+  private static int maxPartition(Collection<StreamEdge> edges) {
+    return edges.stream().map(StreamEdge::getPartitionCount).reduce(Integer::max).get();
+  }
+
+  private static StreamSpec createStreamSpec(StreamEdge edge) {
+    StreamSpec orgSpec = edge.getStreamSpec();
+    return orgSpec.copyWithPartitionCount(edge.getPartitionCount());
+  }
+
+  private static Map<String, SystemAdmin> getSystemAdmins(Config config) {
+    return getSystemFactories(config).entrySet()
+        .stream()
+        .collect(Collectors.toMap(entry -> entry.getKey(), entry -> entry.getValue().getAdmin(entry.getKey(), config)));
+  }
+
+  private static Map<String, SystemFactory> getSystemFactories(Config config) {
+    Map<String, SystemFactory> systemFactories =
+        getSystemNames(config).stream().collect(Collectors.toMap(systemName -> systemName, systemName -> {
+            String systemFactoryClassName = new JavaSystemConfig(config).getSystemFactory(systemName);
+            if (systemFactoryClassName == null) {
+              throw new SamzaException(
+                  String.format("A stream uses system %s, which is missing from the configuration.", systemName));
+            }
+            return Util.getObj(systemFactoryClassName);
+          }));
+
+    return systemFactories;
+  }
+
+  private static Collection<String> getSystemNames(Config config) {
+    return new JavaSystemConfig(config).getSystemNames();
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/execution/ProcessorGraph.java b/samza-core/src/main/java/org/apache/samza/execution/ProcessorGraph.java
new file mode 100644 (file)
index 0000000..07cc3d3
--- /dev/null
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.execution;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.config.Config;
+import org.apache.samza.system.StreamSpec;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * The ProcessorGraph is the physical execution graph for a multi-stage Samza application.
+ * It contains the topology of execution processors connected with source/sink/intermediate streams.
+ * High level APIs are transformed into ProcessorGraph for planning, validation and execution.
+ * Source/sink streams are external streams while intermediate streams are created and managed by Samza.
+ * Note that intermediate streams are both the input and output of a ProcessorNode in ProcessorGraph.
+ * So the graph may have cycles and it's not a DAG.
+ */
+public class ProcessorGraph {
+  private static final Logger log = LoggerFactory.getLogger(ProcessorGraph.class);
+
+  private final Map<String, ProcessorNode> nodes = new HashMap<>();
+  private final Map<String, StreamEdge> edges = new HashMap<>();
+  private final Set<StreamEdge> sources = new HashSet<>();
+  private final Set<StreamEdge> sinks = new HashSet<>();
+  private final Set<StreamEdge> intermediateStreams = new HashSet<>();
+  private final Config config;
+
+  /**
+   * The ProcessorGraph is only constructed by the {@link ExecutionPlanner}.
+   * @param config Config
+   */
+  /* package private */ ProcessorGraph(Config config) {
+    this.config = config;
+  }
+
+  /**
+   * Add a source stream to a {@link ProcessorNode}
+   * @param input source stream
+   * @param targetProcessorId id of the {@link ProcessorNode}
+   */
+  /* package private */ void addSource(StreamSpec input, String targetProcessorId) {
+    ProcessorNode node = getOrCreateProcessor(targetProcessorId);
+    StreamEdge edge = getOrCreateEdge(input);
+    edge.addTargetNode(node);
+    node.addInEdge(edge);
+    sources.add(edge);
+  }
+
+  /**
+   * Add a sink stream to a {@link ProcessorNode}
+   * @param output sink stream
+   * @param sourceProcessorId id of the {@link ProcessorNode}
+   */
+  /* package private */ void addSink(StreamSpec output, String sourceProcessorId) {
+    ProcessorNode node = getOrCreateProcessor(sourceProcessorId);
+    StreamEdge edge = getOrCreateEdge(output);
+    edge.addSourceNode(node);
+    node.addOutEdge(edge);
+    sinks.add(edge);
+  }
+
+  /**
+   * Add an intermediate stream from source to target {@link ProcessorNode}
+   * @param streamSpec intermediate stream
+   * @param sourceProcessorId id of the source {@link ProcessorNode}
+   * @param targetProcessorId id of the target {@link ProcessorNode}
+   */
+  /* package private */ void addIntermediateStream(StreamSpec streamSpec, String sourceProcessorId, String targetProcessorId) {
+    ProcessorNode sourceNode = getOrCreateProcessor(sourceProcessorId);
+    ProcessorNode targetNode = getOrCreateProcessor(targetProcessorId);
+    StreamEdge edge = getOrCreateEdge(streamSpec);
+    edge.addSourceNode(sourceNode);
+    edge.addTargetNode(targetNode);
+    sourceNode.addOutEdge(edge);
+    targetNode.addInEdge(edge);
+    intermediateStreams.add(edge);
+  }
+
+  /**
+   * Get the {@link ProcessorNode} for an id. Create one if it does not exist.
+   * @param processorId id of the processor
+   * @return processor node
+   */
+  /* package private */ProcessorNode getOrCreateProcessor(String processorId) {
+    ProcessorNode node = nodes.get(processorId);
+    if (node == null) {
+      node = new ProcessorNode(processorId, config);
+      nodes.put(processorId, node);
+    }
+    return node;
+  }
+
+  /**
+   * Get the {@link StreamEdge} for a {@link StreamSpec}. Create one if it does not exist.
+   * @param streamSpec spec of the StreamEdge
+   * @return stream edge
+   */
+  /* package private */StreamEdge getOrCreateEdge(StreamSpec streamSpec) {
+    String streamId = streamSpec.getId();
+    StreamEdge edge = edges.get(streamId);
+    if (edge == null) {
+      edge = new StreamEdge(streamSpec);
+      edges.put(streamId, edge);
+    }
+    return edge;
+  }
+
+  /**
+   * Returns the processors to be executed in the topological order
+   * @return unmodifiable list of {@link ProcessorNode}
+   */
+  public List<ProcessorNode> getProcessors() {
+    List<ProcessorNode> sortedNodes = topologicalSort();
+    return Collections.unmodifiableList(sortedNodes);
+  }
+
+  /**
+   * Returns the source streams in the graph
+   * @return unmodifiable set of {@link StreamEdge}
+   */
+  public Set<StreamEdge> getSources() {
+    return Collections.unmodifiableSet(sources);
+  }
+
+  /**
+   * Return the sink streams in the graph
+   * @return unmodifiable set of {@link StreamEdge}
+   */
+  public Set<StreamEdge> getSinks() {
+    return Collections.unmodifiableSet(sinks);
+  }
+
+  /**
+   * Return the intermediate streams in the graph
+   * @return unmodifiable set of {@link StreamEdge}
+   */
+  public Set<StreamEdge> getIntermediateStreams() {
+    return Collections.unmodifiableSet(intermediateStreams);
+  }
+
+
+  /**
+   * Validate the graph has the correct topology, meaning the sources are coming from external streams,
+   * sinks are going to external streams, and the nodes are connected with intermediate streams.
+   * Also validate all the nodes are reachable from the sources.
+   */
+  public void validate() {
+    validateSources();
+    validateSinks();
+    validateInternalStreams();
+    validateReachability();
+  }
+
+  /**
+   * Validate the sources should have indegree being 0 and outdegree greater than 0
+   */
+  private void validateSources() {
+    sources.forEach(edge -> {
+        if (!edge.getSourceNodes().isEmpty()) {
+          throw new IllegalArgumentException(
+              String.format("Source stream %s should not have producers.", edge.getFormattedSystemStream()));
+        }
+        if (edge.getTargetNodes().isEmpty()) {
+          throw new IllegalArgumentException(
+              String.format("Source stream %s should have consumers.", edge.getFormattedSystemStream()));
+        }
+      });
+  }
+
+  /**
+   * Validate the sinks should have outdegree being 0 and indegree greater than 0
+   */
+  private void validateSinks() {
+    sinks.forEach(edge -> {
+        if (!edge.getTargetNodes().isEmpty()) {
+          throw new IllegalArgumentException(
+              String.format("Sink stream %s should not have consumers", edge.getFormattedSystemStream()));
+        }
+        if (edge.getSourceNodes().isEmpty()) {
+          throw new IllegalArgumentException(
+              String.format("Sink stream %s should have producers", edge.getFormattedSystemStream()));
+        }
+      });
+  }
+
+  /**
+   * Validate the internal streams should have both indegree and outdegree greater than 0
+   */
+  private void validateInternalStreams() {
+    Set<StreamEdge> internalEdges = new HashSet<>(edges.values());
+    internalEdges.removeAll(sources);
+    internalEdges.removeAll(sinks);
+
+    internalEdges.forEach(edge -> {
+        if (edge.getSourceNodes().isEmpty() || edge.getTargetNodes().isEmpty()) {
+          throw new IllegalArgumentException(
+              String.format("Internal stream %s should have both producers and consumers", edge.getFormattedSystemStream()));
+        }
+      });
+  }
+
+  /**
+   * Validate all nodes are reachable by sources.
+   */
+  private void validateReachability() {
+    // validate all nodes are reachable from the sources
+    final Set<ProcessorNode> reachable = findReachable();
+    if (reachable.size() != nodes.size()) {
+      Set<ProcessorNode> unreachable = new HashSet<>(nodes.values());
+      unreachable.removeAll(reachable);
+      throw new IllegalArgumentException(String.format("Processors %s cannot be reached from Sources.",
+          String.join(", ", unreachable.stream().map(ProcessorNode::getId).collect(Collectors.toList()))));
+    }
+  }
+
+  /**
+   * Find the reachable set of nodes using BFS.
+   * @return reachable set of {@link ProcessorNode}
+   */
+  /* package private */ Set<ProcessorNode> findReachable() {
+    Queue<ProcessorNode> queue = new ArrayDeque<>();
+    Set<ProcessorNode> visited = new HashSet<>();
+
+    sources.forEach(source -> {
+        List<ProcessorNode> next = source.getTargetNodes();
+        queue.addAll(next);
+        visited.addAll(next);
+      });
+
+    while (!queue.isEmpty()) {
+      ProcessorNode node = queue.poll();
+      node.getOutEdges().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(target -> {
+          if (!visited.contains(target)) {
+            visited.add(target);
+            queue.offer(target);
+          }
+        });
+    }
+
+    return visited;
+  }
+
+  /**
+   * An variation of Kahn's algorithm of topological sorting.
+   * This algorithm also takes account of the simple loops in the graph
+   * @return topologically sorted {@link ProcessorNode}s
+   */
+  /* package private */ List<ProcessorNode> topologicalSort() {
+    Collection<ProcessorNode> pnodes = nodes.values();
+    Queue<ProcessorNode> q = new ArrayDeque<>();
+    Map<String, Long> indegree = new HashMap<>();
+    Set<ProcessorNode> visited = new HashSet<>();
+    pnodes.forEach(node -> {
+        String nid = node.getId();
+        //only count the degrees of intermediate streams
+        long degree = node.getInEdges().stream().filter(e -> !sources.contains(e)).count();
+        indegree.put(nid, degree);
+
+        if (degree == 0L) {
+          // start from the nodes that has no intermediate input streams, so it only consumes from sources
+          q.add(node);
+          visited.add(node);
+        }
+      });
+
+    List<ProcessorNode> sortedNodes = new ArrayList<>();
+    Set<ProcessorNode> reachable = new HashSet<>();
+    while (sortedNodes.size() < pnodes.size()) {
+      // Here we use indegree-based approach to implment Kahn's algorithm for topological sort
+      // This approach will not change the graph itself during computation.
+      //
+      // The algorithm works as:
+      // 1. start with nodes with no incoming edges (in degree being 0) and inserted into the list
+      // 2. remove the edge from any node in the list to its connected nodes by changing the indegree of the connected nodes.
+      // 3. add any new nodes with ingree being 0
+      // 4. loop 1-3 until no more nodes with indegree 0
+      //
+      while (!q.isEmpty()) {
+        ProcessorNode node = q.poll();
+        sortedNodes.add(node);
+        node.getOutEdges().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(n -> {
+            String nid = n.getId();
+            Long degree = indegree.get(nid) - 1;
+            indegree.put(nid, degree);
+            if (degree == 0L && !visited.contains(n)) {
+              q.add(n);
+              visited.add(n);
+            }
+            reachable.add(n);
+          });
+      }
+
+      if (sortedNodes.size() < pnodes.size()) {
+        // The remaining nodes have cycles
+        // use the following approach to break the cycles
+        // start from the nodes that are reachable from previous traverse
+        reachable.removeAll(sortedNodes);
+        if (!reachable.isEmpty()) {
+          //find out the nodes with minimal input edge
+          long min = Long.MAX_VALUE;
+          ProcessorNode minNode = null;
+          for (ProcessorNode node : reachable) {
+            Long degree = indegree.get(node.getId());
+            if (degree < min) {
+              min = degree;
+              minNode = node;
+            }
+          }
+          // start from the node with minimal input edge again
+          q.add(minNode);
+        } else {
+          // all the remaining nodes should be reachable from sources
+          // start from sources again to find the next node that hasn't been visited
+          ProcessorNode nextNode = sources.stream().flatMap(source -> source.getTargetNodes().stream())
+              .filter(node -> !visited.contains(node))
+              .findAny().get();
+          q.add(nextNode);
+        }
+      }
+    }
+
+    return sortedNodes;
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/execution/ProcessorNode.java b/samza-core/src/main/java/org/apache/samza/execution/ProcessorNode.java
new file mode 100644 (file)
index 0000000..3cb695f
--- /dev/null
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.execution;
+
+import com.google.common.base.Joiner;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.util.Util;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * A ProcessorNode is a physical execution unit. In RemoteExecutionEnvironment, it's a job that will be submitted
+ * to remote cluster. In LocalExecutionEnvironment, it's a set of StreamProcessors for local execution.
+ * A ProcessorNode contains the input/output, and the configs for physical execution.
+ */
+public class ProcessorNode {
+  private static final Logger log = LoggerFactory.getLogger(ProcessorNode.class);
+  private static final String CONFIG_PROCESSOR_PREFIX = "processors.%s.";
+
+  private final String id;
+  private final List<StreamEdge> inEdges = new ArrayList<>();
+  private final List<StreamEdge> outEdges = new ArrayList<>();
+  private final Config config;
+
+  ProcessorNode(String id, Config config) {
+    this.id = id;
+    this.config = config;
+  }
+
+  public  String getId() {
+    return id;
+  }
+
+  void addInEdge(StreamEdge in) {
+    inEdges.add(in);
+  }
+
+  void addOutEdge(StreamEdge out) {
+    outEdges.add(out);
+  }
+
+  List<StreamEdge> getInEdges() {
+    return inEdges;
+  }
+
+  List<StreamEdge> getOutEdges() {
+    return outEdges;
+  }
+
+  public Config generateConfig() {
+    Map<String, String> configs = new HashMap<>();
+    configs.put(JobConfig.JOB_NAME(), id);
+
+    List<String> inputs = inEdges.stream().map(edge -> edge.getFormattedSystemStream()).collect(Collectors.toList());
+    configs.put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputs));
+    log.info("Processor {} has generated configs {}", id, configs);
+
+    String configPrefix = String.format(CONFIG_PROCESSOR_PREFIX, id);
+    // TODO: Disallow user specifying processor inputs/outputs. This info comes strictly from the pipeline.
+    return Util.rewriteConfig(extractScopedConfig(config, new MapConfig(configs), configPrefix));
+  }
+
+  /**
+   * This function extract the subset of configs from the full config, and use it to override the generated configs
+   * from the processor.
+   * @param fullConfig full config
+   * @param generatedConfig config generated from the processor
+   * @param configPrefix prefix to extract the subset of the config overrides
+   * @return config that merges the generated configs and overrides
+   */
+  private static Config extractScopedConfig(Config fullConfig, Config generatedConfig, String configPrefix) {
+    Config scopedConfig = fullConfig.subset(configPrefix);
+
+    Config[] configPrecedence = new Config[] {fullConfig, generatedConfig, scopedConfig};
+    // Strip empty configs so they don't override the configs before them.
+    Map<String, String> mergedConfig = new HashMap<>();
+    for (Map<String, String> config : configPrecedence) {
+      for (Map.Entry<String, String> property : config.entrySet()) {
+        String value = property.getValue();
+        if (!(value == null || value.isEmpty())) {
+          mergedConfig.put(property.getKey(), property.getValue());
+        }
+      }
+    }
+    scopedConfig = new MapConfig(mergedConfig);
+    log.debug("Prefix '{}' has merged config {}", configPrefix, scopedConfig);
+
+    return scopedConfig;
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/execution/StreamEdge.java b/samza-core/src/main/java/org/apache/samza/execution/StreamEdge.java
new file mode 100644 (file)
index 0000000..5dc4178
--- /dev/null
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.execution;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.util.Util;
+
+
+/**
+ * A StreamEdge connects the source {@link ProcessorNode}s to the target {@link ProcessorNode}s with a stream.
+ * If it's a sink StreamEdge, the target ProcessorNode is empty.
+ * If it's a source StreamEdge, the source ProcessorNode is empty.
+ */
+public class StreamEdge {
+  public static final int PARTITIONS_UNKNOWN = -1;
+
+  private final StreamSpec streamSpec;
+  private final List<ProcessorNode> sourceNodes = new ArrayList<>();
+  private final List<ProcessorNode> targetNodes = new ArrayList<>();
+
+  private String name = "";
+  private int partitions = PARTITIONS_UNKNOWN;
+
+  StreamEdge(StreamSpec streamSpec) {
+    this.streamSpec = streamSpec;
+    this.name = Util.getNameFromSystemStream(getSystemStream());
+  }
+
+  void addSourceNode(ProcessorNode sourceNode) {
+    sourceNodes.add(sourceNode);
+  }
+
+  void addTargetNode(ProcessorNode targetNode) {
+    targetNodes.add(targetNode);
+  }
+
+  StreamSpec getStreamSpec() {
+    return streamSpec;
+  }
+
+  SystemStream getSystemStream() {
+    return new SystemStream(streamSpec.getSystemName(), streamSpec.getPhysicalName());
+  }
+
+  String getFormattedSystemStream() {
+    return Util.getNameFromSystemStream(getSystemStream());
+  }
+
+  List<ProcessorNode> getSourceNodes() {
+    return sourceNodes;
+  }
+
+  List<ProcessorNode> getTargetNodes() {
+    return targetNodes;
+  }
+
+  int getPartitionCount() {
+    return partitions;
+  }
+
+  void setPartitionCount(int partitions) {
+    this.partitions = partitions;
+  }
+
+  String getName() {
+    return name;
+  }
+
+  void setName(String name) {
+    this.name = name;
+  }
+}
index a0c7820..c00f470 100644 (file)
@@ -148,6 +148,7 @@ public class OperatorSpecs {
    *
    * @param sinkFn  the sink function
    * @param stream  the {@link OutputStream} where the message is sent to
+   * @param opId operator ID
    * @param <M>  type of input message
    * @return  the {@link SinkOperatorSpec}
    */
index 9d6cbc2..4e14097 100644 (file)
@@ -49,6 +49,8 @@ object JobConfig {
   val JOB_CONTAINER_SINGLE_THREAD_MODE = "job.container.single.thread.mode"
   val JOB_REPLICATION_FACTOR = "job.coordinator.replication.factor"
   val JOB_SEGMENT_BYTES = "job.coordinator.segment.bytes"
+  val JOB_INTERMEDIATE_STREAM_PARTITIONS = "job.intermediate.stream.partitions"
+
   val SSP_GROUPER_FACTORY = "job.systemstreampartition.grouper.factory"
 
   val SSP_MATCHER_CLASS = "job.systemstreampartition.matcher.class";
index 022b480..6d8b24d 100644 (file)
@@ -36,34 +36,11 @@ import org.apache.samza.coordinator.stream.CoordinatorStreamSystemFactory
 object JobRunner extends Logging {
   val SOURCE = "job-runner"
 
-  /**
-   * Re-writes configuration using a ConfigRewriter, if one is defined. If
-   * there is no ConfigRewriter defined for the job, then this method is a
-   * no-op.
-   *
-   * @param config The config to re-write.
-   */
-  def rewriteConfig(config: Config): Config = {
-    def rewrite(c: Config, rewriterName: String): Config = {
-      val klass = config
-        .getConfigRewriterClass(rewriterName)
-        .getOrElse(throw new SamzaException("Unable to find class config for config rewriter %s." format rewriterName))
-      val rewriter = Util.getObj[ConfigRewriter](klass)
-      info("Re-writing config with " + rewriter)
-      rewriter.rewrite(rewriterName, c)
-    }
-
-    config.getConfigRewriters match {
-      case Some(rewriters) => rewriters.split(",").foldLeft(config)(rewrite(_, _))
-      case _ => config
-    }
-  }
-
   def main(args: Array[String]) {
     val cmdline = new CommandLine
     val options = cmdline.parser.parse(args: _*)
     val config = cmdline.loadConfig(options)
-    new JobRunner(rewriteConfig(config)).run()
+    new JobRunner(Util.rewriteConfig(config)).run()
   }
 }
 
index 9019d02..97bd22a 100644 (file)
@@ -23,6 +23,7 @@ import java.net._
 import java.io._
 import java.lang.management.ManagementFactory
 import java.util.zip.CRC32
+import org.apache.samza.config.ConfigRewriter
 import org.apache.samza.{SamzaException, Partition}
 import org.apache.samza.system.{SystemFactory, SystemStreamPartition, SystemStream}
 import java.util.Random
@@ -395,4 +396,28 @@ object Util extends Logging {
    * @return Scala clock function
    */
   implicit def asScalaClock(c: HighResolutionClock): () => Long = () => c.nanoTime()
+
+  /**
+   * Re-writes configuration using a ConfigRewriter, if one is defined. If
+   * there is no ConfigRewriter defined for the job, then this method is a
+   * no-op.
+   *
+   * @param config The config to re-write
+   * @return re-written config
+   */
+  def rewriteConfig(config: Config): Config = {
+    def rewrite(c: Config, rewriterName: String): Config = {
+      val klass = config
+              .getConfigRewriterClass(rewriterName)
+              .getOrElse(throw new SamzaException("Unable to find class config for config rewriter %s." format rewriterName))
+      val rewriter = Util.getObj[ConfigRewriter](klass)
+      info("Re-writing config with " + rewriter)
+      rewriter.rewrite(rewriterName, c)
+    }
+
+    config.getConfigRewriters match {
+      case Some(rewriters) => rewriters.split(",").foldLeft(config)(rewrite(_, _))
+      case _ => config
+    }
+  }
 }
diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java
new file mode 100644 (file)
index 0000000..ee73195
--- /dev/null
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.execution;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.Partition;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.StreamGraph;
+import org.apache.samza.operators.StreamGraphBuilder;
+import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.functions.JoinFunction;
+import org.apache.samza.operators.functions.SinkFunction;
+import org.apache.samza.runtime.AbstractApplicationRunner;
+import org.apache.samza.runtime.ApplicationRunner;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.TaskCoordinator;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+
+public class TestExecutionPlanner {
+
+  private Config config;
+
+  private static final String DEFAULT_SYSTEM = "test-system";
+  private static final int DEFAULT_PARTITIONS = 10;
+
+  private StreamSpec input1;
+  private StreamSpec input2;
+  private StreamSpec input3;
+  private StreamSpec output1;
+  private StreamSpec output2;
+
+  private Map<String, SystemAdmin> systemAdmins;
+
+  private ApplicationRunner runner;
+
+  private JoinFunction createJoin() {
+    return new JoinFunction() {
+      @Override
+      public Object apply(Object message, Object otherMessage) {
+        return null;
+      }
+
+      @Override
+      public Object getFirstKey(Object message) {
+        return null;
+      }
+
+      @Override
+      public Object getSecondKey(Object message) {
+        return null;
+      }
+    };
+  }
+
+  private SinkFunction createSink() {
+    return new SinkFunction() {
+      @Override
+      public void apply(Object message, MessageCollector messageCollector, TaskCoordinator taskCoordinator) {
+      }
+    };
+  }
+
+  private SystemAdmin createSystemAdmin(Map<String, Integer> streamToPartitions) {
+
+    return new SystemAdmin() {
+      @Override
+      public Map<SystemStreamPartition, String> getOffsetsAfter(Map<SystemStreamPartition, String> offsets) {
+        return null;
+      }
+
+      @Override
+      public Map<String, SystemStreamMetadata> getSystemStreamMetadata(Set<String> streamNames) {
+        Map<String, SystemStreamMetadata> map = new HashMap<>();
+        for (String stream : streamNames) {
+          Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> m = new HashMap<>();
+          for (int i = 0; i < streamToPartitions.get(stream); i++) {
+            m.put(new Partition(i), new SystemStreamMetadata.SystemStreamPartitionMetadata("", "", ""));
+          }
+          map.put(stream, new SystemStreamMetadata(stream, m));
+        }
+        return map;
+      }
+
+      @Override
+      public void createChangelogStream(String streamName, int numOfPartitions) {
+
+      }
+
+      @Override
+      public void validateChangelogStream(String streamName, int numOfPartitions) {
+
+      }
+
+      @Override
+      public void createCoordinatorStream(String streamName) {
+
+      }
+
+      @Override
+      public Integer offsetComparator(String offset1, String offset2) {
+        return null;
+      }
+    };
+  }
+
+  private StreamGraph createSimpleGraph() {
+    /**
+     * a simple graph of partitionBy and map
+     *
+     * input1 -> partitionBy -> map -> output1
+     *
+     */
+    StreamGraph streamGraph = new StreamGraphImpl(runner, config);
+    streamGraph.createInStream(input1, null, null).partitionBy(m -> "yes!!!").map(m -> m).sendTo(streamGraph.createOutStream(output1, null, null));
+    return streamGraph;
+  }
+
+  private StreamGraph createStreamGraphWithJoin() {
+
+    /** the graph looks like the following
+     *
+     *                        input1 -> map -> join -> output1
+     *                                           |
+     *          input2 -> partitionBy -> filter -|
+     *                                           |
+     * input3 -> filter -> partitionBy -> map -> join -> output2
+     *
+     */
+
+    StreamGraph streamGraph = new StreamGraphImpl(runner, config);
+    MessageStream m1 = streamGraph.createInStream(input1, null, null).map(m -> m);
+    MessageStream m2 = streamGraph.createInStream(input2, null, null).partitionBy(m -> "haha").filter(m -> true);
+    MessageStream m3 = streamGraph.createInStream(input3, null, null).filter(m -> true).partitionBy(m -> "hehe").map(m -> m);
+
+    m1.join(m2, createJoin()).sendTo(streamGraph.createOutStream(output1, null, null));
+    m3.join(m2, createJoin()).sendTo(streamGraph.createOutStream(output2, null, null));
+
+    return streamGraph;
+  }
+
+  @Before
+  public void setup() {
+    Map<String, String> configMap = new HashMap<>();
+    configMap.put(JobConfig.JOB_NAME(), "test-app");
+    configMap.put(JobConfig.JOB_DEFAULT_SYSTEM(), DEFAULT_SYSTEM);
+
+    config = new MapConfig(configMap);
+
+    input1 = new StreamSpec("input1", "input1", "system1");
+    input2 = new StreamSpec("input2", "input2", "system2");
+    input3 = new StreamSpec("input3", "input3", "system2");
+
+    output1 = new StreamSpec("output1", "output1", "system1");
+    output2 = new StreamSpec("output2", "output2", "system2");
+
+    // set up external partition count
+    Map<String, Integer> system1Map = new HashMap<>();
+    system1Map.put("input1", 64);
+    system1Map.put("output1", 8);
+    Map<String, Integer> system2Map = new HashMap<>();
+    system2Map.put("input2", 16);
+    system2Map.put("input3", 32);
+    system2Map.put("output2", 16);
+
+    SystemAdmin systemAdmin1 = createSystemAdmin(system1Map);
+    SystemAdmin systemAdmin2 = createSystemAdmin(system2Map);
+    systemAdmins = new HashMap<>();
+    systemAdmins.put("system1", systemAdmin1);
+    systemAdmins.put("system2", systemAdmin2);
+
+    runner = new AbstractApplicationRunner(config) {
+      @Override
+      public void run(StreamGraphBuilder graphBuilder, Config config) {
+      }
+    };
+  }
+
+  @Test
+  public void testCreateProcessorGraph() {
+    ExecutionPlanner planner = new ExecutionPlanner(config);
+    StreamGraph streamGraph = createStreamGraphWithJoin();
+
+    ProcessorGraph processorGraph = planner.createProcessorGraph(streamGraph);
+    assertTrue(processorGraph.getSources().size() == 3);
+    assertTrue(processorGraph.getSinks().size() == 2);
+    assertTrue(processorGraph.getIntermediateStreams().size() == 2); // two streams generated by partitionBy
+  }
+
+  @Test
+  public void testFetchExistingStreamPartitions() {
+    ExecutionPlanner planner = new ExecutionPlanner(config);
+    StreamGraph streamGraph = createStreamGraphWithJoin();
+    ProcessorGraph processorGraph = planner.createProcessorGraph(streamGraph);
+
+    ExecutionPlanner.updateExistingPartitions(processorGraph, systemAdmins);
+    assertTrue(processorGraph.getOrCreateEdge(input1).getPartitionCount() == 64);
+    assertTrue(processorGraph.getOrCreateEdge(input2).getPartitionCount() == 16);
+    assertTrue(processorGraph.getOrCreateEdge(input3).getPartitionCount() == 32);
+    assertTrue(processorGraph.getOrCreateEdge(output1).getPartitionCount() == 8);
+    assertTrue(processorGraph.getOrCreateEdge(output2).getPartitionCount() == 16);
+
+    processorGraph.getIntermediateStreams().forEach(edge -> {
+        assertTrue(edge.getPartitionCount() == -1);
+      });
+  }
+
+  @Test
+  public void testCalculateJoinInputPartitions() {
+    ExecutionPlanner planner = new ExecutionPlanner(config);
+    StreamGraph streamGraph = createStreamGraphWithJoin();
+    ProcessorGraph processorGraph = planner.createProcessorGraph(streamGraph);
+
+    ExecutionPlanner.updateExistingPartitions(processorGraph, systemAdmins);
+    ExecutionPlanner.calculateJoinInputPartitions(streamGraph, processorGraph);
+
+    // the partitions should be the same as input1
+    processorGraph.getIntermediateStreams().forEach(edge -> {
+        assertTrue(edge.getPartitionCount() == 64);
+      });
+  }
+
+  @Test
+  public void testDefaultPartitions() {
+    Map<String, String> map = new HashMap<>(config);
+    map.put(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), String.valueOf(DEFAULT_PARTITIONS));
+    Config cfg = new MapConfig(map);
+
+    ExecutionPlanner planner = new ExecutionPlanner(cfg);
+    StreamGraph streamGraph = createSimpleGraph();
+    ProcessorGraph processorGraph = planner.createProcessorGraph(streamGraph);
+    planner.calculatePartitions(streamGraph, processorGraph, systemAdmins);
+
+    // the partitions should be the same as input1
+    processorGraph.getIntermediateStreams().forEach(edge -> {
+        assertTrue(edge.getPartitionCount() == DEFAULT_PARTITIONS);
+      });
+  }
+
+  @Test
+  public void testCalculateIntStreamPartitions() {
+    ExecutionPlanner planner = new ExecutionPlanner(config);
+    StreamGraph streamGraph = createSimpleGraph();
+    ProcessorGraph processorGraph = planner.createProcessorGraph(streamGraph);
+    planner.calculatePartitions(streamGraph, processorGraph, systemAdmins);
+
+    // the partitions should be the same as input1
+    processorGraph.getIntermediateStreams().forEach(edge -> {
+        assertTrue(edge.getPartitionCount() == 64); // max of input1 and output1
+      });
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestProcessorGraph.java b/samza-core/src/test/java/org/apache/samza/execution/TestProcessorGraph.java
new file mode 100644 (file)
index 0000000..2bdf529
--- /dev/null
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.execution;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.system.StreamSpec;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+
+public class TestProcessorGraph {
+
+  ProcessorGraph graph1;
+  ProcessorGraph graph2;
+  int streamSeq = 0;
+
+  private StreamSpec genStream() {
+    ++streamSeq;
+
+    return new StreamSpec(String.valueOf(streamSeq), "test-stream", "test-system");
+  }
+
+  @Before
+  public void setup() {
+    /**
+     * graph1 is the example graph from wikipedia
+     *
+     * 5   7   3
+     * | / | / |
+     * v   v   |
+     * 11  8   |
+     * | \X   /
+     * v v \v
+     * 2 9 10
+     */
+    // init graph1
+    graph1 = new ProcessorGraph(null);
+    graph1.addSource(genStream(), "5");
+    graph1.addSource(genStream(), "7");
+    graph1.addSource(genStream(), "3");
+    graph1.addIntermediateStream(genStream(), "5", "11");
+    graph1.addIntermediateStream(genStream(), "7", "11");
+    graph1.addIntermediateStream(genStream(), "7", "8");
+    graph1.addIntermediateStream(genStream(), "3", "8");
+    graph1.addIntermediateStream(genStream(), "11", "2");
+    graph1.addIntermediateStream(genStream(), "11", "9");
+    graph1.addIntermediateStream(genStream(), "8", "9");
+    graph1.addIntermediateStream(genStream(), "11", "10");
+    graph1.addSink(genStream(), "2");
+    graph1.addSink(genStream(), "9");
+    graph1.addSink(genStream(), "10");
+
+    /**
+     * graph2 is a graph with a loop
+     * 1 -> 2 -> 3 -> 4 -> 5 -> 7
+     *      |<---6 <--|    <>
+     */
+    graph2 = new ProcessorGraph(null);
+    graph2.addSource(genStream(), "1");
+    graph2.addIntermediateStream(genStream(), "1", "2");
+    graph2.addIntermediateStream(genStream(), "2", "3");
+    graph2.addIntermediateStream(genStream(), "3", "4");
+    graph2.addIntermediateStream(genStream(), "4", "5");
+    graph2.addIntermediateStream(genStream(), "4", "6");
+    graph2.addIntermediateStream(genStream(), "6", "2");
+    graph2.addIntermediateStream(genStream(), "5", "5");
+    graph2.addIntermediateStream(genStream(), "5", "7");
+    graph2.addSink(genStream(), "7");
+  }
+
+  @Test
+  public void testAddSource() {
+    ProcessorGraph graph = new ProcessorGraph(null);
+
+    /**
+     * s1 -> 1
+     * s2 ->|
+     *
+     * s3 -> 2
+     *   |-> 3
+     */
+    StreamSpec s1 = genStream();
+    StreamSpec s2 = genStream();
+    StreamSpec s3 = genStream();
+    graph.addSource(s1, "1");
+    graph.addSource(s2, "1");
+    graph.addSource(s3, "2");
+    graph.addSource(s3, "3");
+
+    assertTrue(graph.getSources().size() == 3);
+
+    assertTrue(graph.getOrCreateProcessor("1").getInEdges().size() == 2);
+    assertTrue(graph.getOrCreateProcessor("2").getInEdges().size() == 1);
+    assertTrue(graph.getOrCreateProcessor("3").getInEdges().size() == 1);
+
+    assertTrue(graph.getOrCreateEdge(s1).getSourceNodes().size() == 0);
+    assertTrue(graph.getOrCreateEdge(s1).getTargetNodes().size() == 1);
+    assertTrue(graph.getOrCreateEdge(s2).getSourceNodes().size() == 0);
+    assertTrue(graph.getOrCreateEdge(s2).getTargetNodes().size() == 1);
+    assertTrue(graph.getOrCreateEdge(s3).getSourceNodes().size() == 0);
+    assertTrue(graph.getOrCreateEdge(s3).getTargetNodes().size() == 2);
+  }
+
+  @Test
+  public void testAddSink() {
+    /**
+     * 1 -> s1
+     * 2 -> s2
+     * 2 -> s3
+     */
+    StreamSpec s1 = genStream();
+    StreamSpec s2 = genStream();
+    StreamSpec s3 = genStream();
+    ProcessorGraph graph = new ProcessorGraph(null);
+    graph.addSink(s1, "1");
+    graph.addSink(s2, "2");
+    graph.addSink(s3, "2");
+
+    assertTrue(graph.getSinks().size() == 3);
+    assertTrue(graph.getOrCreateProcessor("1").getOutEdges().size() == 1);
+    assertTrue(graph.getOrCreateProcessor("2").getOutEdges().size() == 2);
+
+    assertTrue(graph.getOrCreateEdge(s1).getSourceNodes().size() == 1);
+    assertTrue(graph.getOrCreateEdge(s1).getTargetNodes().size() == 0);
+    assertTrue(graph.getOrCreateEdge(s2).getSourceNodes().size() == 1);
+    assertTrue(graph.getOrCreateEdge(s2).getTargetNodes().size() == 0);
+    assertTrue(graph.getOrCreateEdge(s3).getSourceNodes().size() == 1);
+    assertTrue(graph.getOrCreateEdge(s3).getTargetNodes().size() == 0);
+  }
+
+  @Test
+  public void testReachable() {
+    Set<ProcessorNode> reachable1 = graph1.findReachable();
+    assertTrue(reachable1.size() == 8);
+
+    Set<ProcessorNode> reachable2 = graph2.findReachable();
+    assertTrue(reachable2.size() == 7);
+  }
+
+  @Test
+  public void testTopologicalSort() {
+
+    // test graph1
+    List<ProcessorNode> sortedNodes1 = graph1.topologicalSort();
+    Map<String, Integer> idxMap1 = new HashMap<>();
+    for (int i = 0; i < sortedNodes1.size(); i++) {
+      idxMap1.put(sortedNodes1.get(i).getId(), i);
+    }
+
+    assertTrue(idxMap1.size() == 8);
+    assertTrue(idxMap1.get("11") > idxMap1.get("5"));
+    assertTrue(idxMap1.get("11") > idxMap1.get("7"));
+    assertTrue(idxMap1.get("8") > idxMap1.get("7"));
+    assertTrue(idxMap1.get("8") > idxMap1.get("3"));
+    assertTrue(idxMap1.get("2") > idxMap1.get("11"));
+    assertTrue(idxMap1.get("9") > idxMap1.get("8"));
+    assertTrue(idxMap1.get("9") > idxMap1.get("11"));
+    assertTrue(idxMap1.get("10") > idxMap1.get("11"));
+    assertTrue(idxMap1.get("10") > idxMap1.get("3"));
+
+    // test graph2
+    List<ProcessorNode> sortedNodes2 = graph2.topologicalSort();
+    Map<String, Integer> idxMap2 = new HashMap<>();
+    for (int i = 0; i < sortedNodes2.size(); i++) {
+      idxMap2.put(sortedNodes2.get(i).getId(), i);
+    }
+
+    assertTrue(idxMap2.size() == 7);
+    assertTrue(idxMap2.get("2") > idxMap2.get("1"));
+    assertTrue(idxMap2.get("3") > idxMap2.get("1"));
+    assertTrue(idxMap2.get("4") > idxMap2.get("1"));
+    assertTrue(idxMap2.get("6") > idxMap2.get("1"));
+    assertTrue(idxMap2.get("5") > idxMap2.get("4"));
+    assertTrue(idxMap2.get("7") > idxMap2.get("5"));
+  }
+}