SAMZA-1659: Serializable OperatorSpec
authorYi Pan (Data Infrastructure) <nickpan47@gmail.com>
Fri, 25 May 2018 16:37:55 +0000 (09:37 -0700)
committerYi Pan (Data Infrastructure) <yipan@yipan-mn1.linkedin.biz>
Fri, 25 May 2018 16:37:55 +0000 (09:37 -0700)
This change is to make the user supplied functions serializable. Hence, making the full user defined DAG serializable.

Author: Yi Pan (Data Infrastructure) <nickpan47@gmail.com>
Author: Yi Pan (Data Infrastructure) <yipan@yipan-mn1.linkedin.biz>
Author: Xinyu Liu <xiliu@xiliu-ld.linkedin.biz>

Reviewers: Jagadish <jvenkatraman@linkedin.com>, Prateek Maheshwari <pmaheshw@linkedin.com>

Closes #475 from nickpan47/serializable-opspec-only-Jan-24-18 and squashes the following commits:

db0dea73 [Yi Pan (Data Infrastructure)] SAMZA-1659: fix intermittent TestZkLocalApplicationRunner failure due to StreamProcessor#stop()
34716d42 [Yi Pan (Data Infrastructure)] SAMZA-1659: fix a comment on OperatorSpec#isClone()
37d4e6ae [Yi Pan (Data Infrastructure)] SAMZA-1659: addressing latest round of review comments
68674a14 [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
d3a7826c [Yi Pan (Data Infrastructure)] SAMZA-1659: addressing review comments
f83e8dd0 [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
acca418b [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
842a73d6 [Yi Pan (Data Infrastructure)] SAMZA-1659: making user-defined functions in high-level API serializable
ad85a2cb [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
c1567116 [Yi Pan (Data Infrastructure)] SAMZA-1659: Before re-merge with master. Still need to fix unit tests (moving OperatorSpec clone tests to OperatorSpecGraph.clone)
f2563f8e [Yi Pan (Data Infrastructure)] SAMZA-1659: serialize the whole DAG instead of each individual OperatorSpec.
24d33496 [Yi Pan (Data Infrastructure)] SAMZA-1659: updated according to review comments. Need to merge again with master.
3f643f8b [Yi Pan (Data Infrastructure)] SAMZA-1659: serialiable OperatorSpec
ed7d8c0e [Yi Pan (Data Infrastructure)] Fixed some javadoc and test files
94de218b [Yi Pan (Data Infrastructure)] Remove public access from StreamGraphImpl#getIntermediateStream(String, Serde)
8f4e9dd4 [Yi Pan (Data Infrastructure)] Serialization of StreamGraph in a wrapper class SerializedStreamGraph
f3bb1958 [Yi Pan (Data Infrastructure)] Fix some comments
c15246f5 [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
e981967d [Yi Pan (Data Infrastructure)] WIP: fixing unit test for SamzaSQL translators w/ serialization of operator functions
40583051 [Yi Pan (Data Infrastructure)] WIP: update the serialization of user functions after the merge
18ba924f [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
93951c5f [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
54a28801 [Yi Pan (Data Infrastructure)] WIP: broadcast, sendtotable, and streamtotablejoin serialization and unit tests
45eb1fb0 [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
7c8d1591 [Yi Pan (Data Infrastructure)] WIP: working on unit tests for trigger, broadcast, join, table, and SQL UDF function serialization
b973b105 [Yi Pan (Data Infrastructure)] Merge branch 'master' into serializable-opspec-only-Jan-24-18
aca42308 [Yi Pan (Data Infrastructure)] WIP: Serialize OperatorSpec only w/o StreamApplication interface change. Passed all build and tests.
0ebebfc3 [Yi Pan (Data Infrastructure)] WIP: serialization only change
1670aff0 [Yi Pan (Data Infrastructure)] WIP: class-loading of user program logic and main() method based user program logic are both included in ThreadJobFactory/ProcessJobFactory/YarnJobFactory. ThreadJobFactory test suite to be fixed.
4102aa8c [Yi Pan (Data Infrastructure)] WIP: continued working on potential offspring integration
dc7da87e [Yi Pan (Data Infrastructure)] WIP: unit tests for serialization
475a46bc [Yi Pan (Data Infrastructure)] WIP: fixed TestZkLocalApplicationRunner. Debugging issues w/ TestRepartitionWindowApp (i.e. missing changelog creation step when directly running LocalApplicationRunner)
6a14b2af [Yi Pan (Data Infrastructure)] WIP: fixed unit test failure for Windows
d4640329 [Yi Pan (Data Infrastructure)] WIP: fixing unit tests after merge
bf1ce907 [Yi Pan (Data Infrastructure)] WIP: removing StreamDescriptor first
50201728 [Yi Pan (Data Infrastructure)] Merge branch 'experiment-new-api-v2' into new-api-v2-0.14
dde1ab14 [Yi Pan (Data Infrastructure)] WIP: first end-to-end test
d7df6ed0 [Yi Pan (Data Infrastructure)] WIP: added all unit test for OperatorSpec#copy methods.
6fc6d4c0 [Yi Pan (Data Infrastructure)] WIP: experiment code to implement an end-to-end working example for new APIs
525d8bc1 [Yi Pan (Data Infrastructure)] Merge branch '0.14.0' into new-api-v2
e6fb96e5 [Yi Pan (Data Infrastructure)] WIP: merged all application types into StreamApplications
f227380f [Yi Pan (Data Infrastructure)] WIP: update the app runner classes
256155ad [Yi Pan (Data Infrastructure)] WIP: new API user code examples
4a6a58dc [Yi Pan (Data Infrastructure)] WIP: updated w/ low-level task API and global var ingestion/metrics reporter
3c50629e [Yi Pan (Data Infrastructure)] WIP: adding support for low-level task APIs
51541e13 [Yi Pan (Data Infrastructure)] WIP: cleanup StreamDescriptor
0bc7ee7b [Yi Pan (Data Infrastructure)] WIP: update the user code example on new APIs
cd528c1c [Yi Pan (Data Infrastructure)] WIP: updated spec and user DAG API
b898e6c0 [Yi Pan (Data Infrastructure)] WIP: new-api-v2
91f364f1 [Yi Pan (Data Infrastructure)] WIP: proto-type of input/output stream/system specs
ae3dc6ff [Yi Pan (Data Infrastructure)] WIP: new api revision
8bb97520 [Yi Pan (Data Infrastructure)] WIP: proto-type of input/output stream/system specs
5573a069 [Yi Pan (Data Infrastructure)] WIP: new api revision
aeb45730 [Xinyu Liu] SAMZA-1321: Propagate end-of-stream and watermark messages

126 files changed:
build.gradle
samza-api/src/main/java/org/apache/samza/application/StreamApplication.java
samza-api/src/main/java/org/apache/samza/config/MapConfig.java
samza-api/src/main/java/org/apache/samza/operators/MessageStream.java
samza-api/src/main/java/org/apache/samza/operators/functions/ClosableFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/FilterFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/FlatMapFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/FoldLeftFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/JoinFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/MapFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/SinkFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/StreamTableJoinFunction.java
samza-api/src/main/java/org/apache/samza/operators/functions/SupplierFunction.java [moved from samza-core/src/test/java/org/apache/samza/testUtils/InvalidStreamApplication.java with 65% similarity]
samza-api/src/main/java/org/apache/samza/operators/triggers/AnyTrigger.java
samza-api/src/main/java/org/apache/samza/operators/triggers/Trigger.java
samza-api/src/main/java/org/apache/samza/operators/windows/Window.java
samza-api/src/main/java/org/apache/samza/operators/windows/Windows.java
samza-api/src/main/java/org/apache/samza/operators/windows/internal/WindowInternal.java
samza-api/src/main/java/org/apache/samza/serializers/SerializableSerde.java
samza-api/src/main/java/org/apache/samza/system/StreamSpec.java
samza-api/src/main/java/org/apache/samza/system/SystemStreamPartition.java
samza-api/src/main/java/org/apache/samza/table/TableSpec.java
samza-api/src/test/java/org/apache/samza/operators/windows/TestWindowPane.java
samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
samza-core/src/main/java/org/apache/samza/execution/JobGraph.java
samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java
samza-core/src/main/java/org/apache/samza/execution/JobNode.java
samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java
samza-core/src/main/java/org/apache/samza/operators/OperatorSpecGraph.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/StreamGraphSpec.java [moved from samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java with 82% similarity]
samza-core/src/main/java/org/apache/samza/operators/TableImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
samza-core/src/main/java/org/apache/samza/operators/impl/OutputOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/WindowOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/spec/FilterOperatorSpec.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/spec/FlatMapOperatorSpec.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/spec/InputOperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/spec/JoinOperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/spec/MapOperatorSpec.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/spec/MergeOperatorSpec.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java
samza-core/src/main/java/org/apache/samza/operators/spec/OutputStreamImpl.java
samza-core/src/main/java/org/apache/samza/operators/spec/PartitionByOperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/spec/SendToTableOperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/spec/StreamOperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/spec/WindowOperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/stream/IntermediateMessageStreamImpl.java
samza-core/src/main/java/org/apache/samza/operators/triggers/Cancellable.java
samza-core/src/main/java/org/apache/samza/operators/triggers/TriggerImpl.java
samza-core/src/main/java/org/apache/samza/runtime/AbstractApplicationRunner.java
samza-core/src/main/java/org/apache/samza/runtime/LocalApplicationRunner.java
samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java
samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java
samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java
samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java
samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java
samza-core/src/test/java/org/apache/samza/execution/TestJobNode.java
samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java
samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java
samza-core/src/test/java/org/apache/samza/operators/TestOperatorSpecGraph.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphSpec.java [moved from samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphImpl.java with 70% similarity]
samza-core/src/test/java/org/apache/samza/operators/data/TestOutputMessageEnvelope.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java
samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpec.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/spec/TestPartitionByOperatorSpec.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/spec/TestWindowOperatorSpec.java
samza-core/src/test/java/org/apache/samza/runtime/TestAbstractApplicationRunner.java
samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java
samza-core/src/test/java/org/apache/samza/task/IdentityStreamTask.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java
samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaStreamSpec.java
samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlCompositeKey.java
samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/ProjectTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/QueryTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/SamzaSqlRelMessageJoinFunction.java
samza-sql/src/main/java/org/apache/samza/sql/translator/ScanTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/TranslatorContext.java
samza-sql/src/test/java/org/apache/samza/sql/data/TestSamzaSqlRelMessage.java [moved from samza-sql/src/test/java/org/apache/samza/sql/TestSamzaSqlRelMessage.java with 97% similarity]
samza-sql/src/test/java/org/apache/samza/sql/runner/TestSamzaSqlApplicationConfig.java [moved from samza-sql/src/test/java/org/apache/samza/sql/TestSamzaSqlApplicationConfig.java with 99% similarity]
samza-sql/src/test/java/org/apache/samza/sql/runner/TestSamzaSqlApplicationRunner.java [moved from samza-sql/src/test/java/org/apache/samza/sql/TestSamzaSqlApplicationRunner.java with 98% similarity]
samza-sql/src/test/java/org/apache/samza/sql/testutil/TestSamzaSqlFileParser.java [moved from samza-sql/src/test/java/org/apache/samza/sql/TestSamzaSqlFileParser.java with 98% similarity]
samza-sql/src/test/java/org/apache/samza/sql/testutil/TestSamzaSqlQueryParser.java [moved from samza-sql/src/test/java/org/apache/samza/sql/TestSamzaSqlQueryParser.java with 96% similarity]
samza-sql/src/test/java/org/apache/samza/sql/translator/TestFilterTranslator.java [new file with mode: 0644]
samza-sql/src/test/java/org/apache/samza/sql/translator/TestJoinTranslator.java [new file with mode: 0644]
samza-sql/src/test/java/org/apache/samza/sql/translator/TestProjectTranslator.java [new file with mode: 0644]
samza-sql/src/test/java/org/apache/samza/sql/translator/TestQueryTranslator.java [moved from samza-sql/src/test/java/org/apache/samza/sql/TestQueryTranslator.java with 62% similarity]
samza-sql/src/test/java/org/apache/samza/sql/translator/TestSamzaSqlRelMessageJoinFunction.java [moved from samza-sql/src/test/java/org/apache/samza/sql/TestSamzaSqlRelMessageJoinFunction.java with 98% similarity]
samza-sql/src/test/java/org/apache/samza/sql/translator/TranslatorTestBase.java [new file with mode: 0644]
samza-test/src/main/java/org/apache/samza/example/AppWithGlobalConfigExample.java [new file with mode: 0644]
samza-test/src/main/java/org/apache/samza/example/BroadcastExample.java [moved from samza-core/src/test/java/org/apache/samza/example/BroadcastExample.java with 70% similarity]
samza-test/src/main/java/org/apache/samza/example/KeyValueStoreExample.java [moved from samza-core/src/test/java/org/apache/samza/example/KeyValueStoreExample.java with 91% similarity]
samza-test/src/main/java/org/apache/samza/example/MergeExample.java [moved from samza-core/src/test/java/org/apache/samza/example/MergeExample.java with 65% similarity]
samza-test/src/main/java/org/apache/samza/example/OrderShipmentJoinExample.java [moved from samza-core/src/test/java/org/apache/samza/example/OrderShipmentJoinExample.java with 91% similarity]
samza-test/src/main/java/org/apache/samza/example/PageViewCounterExample.java [moved from samza-core/src/test/java/org/apache/samza/example/PageViewCounterExample.java with 85% similarity]
samza-test/src/main/java/org/apache/samza/example/RepartitionExample.java [moved from samza-core/src/test/java/org/apache/samza/example/RepartitionExample.java with 92% similarity]
samza-test/src/main/java/org/apache/samza/example/WindowExample.java [moved from samza-core/src/test/java/org/apache/samza/example/WindowExample.java with 91% similarity]
samza-test/src/main/java/org/apache/samza/test/framework/StreamAssert.java
samza-test/src/test/java/org/apache/samza/test/controlmessages/EndOfStreamIntegrationTest.java
samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java
samza-test/src/test/java/org/apache/samza/test/operator/RepartitionJoinWindowApp.java
samza-test/src/test/java/org/apache/samza/test/operator/RepartitionWindowApp.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/operator/SessionWindowApp.java
samza-test/src/test/java/org/apache/samza/test/operator/TestRepartitionJoinWindowApp.java
samza-test/src/test/java/org/apache/samza/test/operator/TestRepartitionWindowApp.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/operator/TumblingWindowApp.java
samza-test/src/test/java/org/apache/samza/test/operator/data/PageView.java
samza-test/src/test/java/org/apache/samza/test/processor/SharedContextFactories.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/processor/TestStreamApplication.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java
samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java
samza-test/src/test/java/org/apache/samza/test/timer/TestTimerApp.java

index 6872354..a94fcfa 100644 (file)
@@ -325,6 +325,9 @@ project(':samza-sql') {
 
     testCompile "junit:junit:$junitVersion"
     testCompile "org.mockito:mockito-core:$mockitoVersion"
+    testCompile "org.powermock:powermock-api-mockito:$powerMockVersion"
+    testCompile "org.powermock:powermock-core:$powerMockVersion"
+    testCompile "org.powermock:powermock-module-junit4:$powerMockVersion"
 
     testRuntime "org.slf4j:slf4j-simple:$slf4jVersion"
   }
@@ -756,10 +759,10 @@ project(":samza-test_$scalaVersion") {
     compile project(":samza-kv-inmemory_$scalaVersion")
     compile project(":samza-kv-rocksdb_$scalaVersion")
     compile project(":samza-core_$scalaVersion")
+    compile project(":samza-kafka_$scalaVersion")
     compile project(":samza-sql")
     runtime project(":samza-log4j")
     runtime project(":samza-yarn_$scalaVersion")
-    runtime project(":samza-kafka_$scalaVersion")
     runtime project(":samza-hdfs_$scalaVersion")
     compile "org.scala-lang:scala-library:$scalaLibVersion"
     compile "net.sf.jopt-simple:jopt-simple:$joptSimpleVersion"
index f615207..0b2142b 100644 (file)
@@ -61,9 +61,10 @@ import org.apache.samza.task.TaskContext;
  *
  * <p>
  * Implementation Notes: Currently StreamApplications are wrapped in a {@link StreamTask} during execution.
- * A new StreamApplication instance will be created and initialized when planning the execution, as well as for each
- * {@link StreamTask} instance used for processing incoming messages. Execution is synchronous and thread-safe within
- * each {@link StreamTask}.
+ * A new StreamApplication instance will be created and initialized with a user-defined {@link StreamGraph}
+ * when planning the execution. The {@link StreamGraph} and the functions implemented for transforms are required to
+ * be serializable. The execution planner will generate a serialized DAG which will be deserialized in each {@link StreamTask}
+ * instance used for processing incoming messages. Execution is synchronous and thread-safe within each {@link StreamTask}.
  *
  * <p>
  * Functions implemented for transforms in StreamApplications ({@link org.apache.samza.operators.functions.MapFunction},
index 0b1ed98..5af2535 100644 (file)
@@ -43,8 +43,11 @@ public class MapConfig extends Config {
 
   public MapConfig(List<Map<String, String>> maps) {
     this.map = new HashMap<>();
-    for (Map<String, String> m: maps)
-      this.map.putAll(m);
+    for (Map<String, String> m: maps) {
+      if (m != null) {
+        this.map.putAll(m);
+      }
+    }
   }
 
   public MapConfig(Map<String, String>... maps) {
index 98f0784..7797f9a 100644 (file)
@@ -21,7 +21,6 @@ package org.apache.samza.operators;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.function.Function;
 
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.operators.functions.FilterFunction;
@@ -237,34 +236,34 @@ public interface MessageStream<M> {
    * <p>
    * Unlike {@link #sendTo}, messages with a null key are all sent to partition 0.
    *
-   * @param keyExtractor the {@link Function} to extract the message and partition key from the input message.
+   * @param keyExtractor the {@link MapFunction} to extract the message and partition key from the input message.
    *                     Messages with a null key are all sent to partition 0.
-   * @param valueExtractor the {@link Function} to extract the value from the input message
+   * @param valueExtractor the {@link MapFunction} to extract the value from the input message
    * @param serde the {@link KVSerde} to use for (de)serializing the key and value.
    * @param id the unique id of this operator in this application
    * @param <K> the type of output key
    * @param <V> the type of output value
    * @return the repartitioned {@link MessageStream}
    */
-  <K, V> MessageStream<KV<K, V>> partitionBy(Function<? super M, ? extends K> keyExtractor,
-      Function<? super M, ? extends V> valueExtractor, KVSerde<K, V> serde, String id);
+  <K, V> MessageStream<KV<K, V>> partitionBy(MapFunction<? super M, ? extends K> keyExtractor,
+      MapFunction<? super M, ? extends V> valueExtractor, KVSerde<K, V> serde, String id);
 
   /**
-   * Same as calling {@link #partitionBy(Function, Function, KVSerde, String)} with a null KVSerde.
+   * Same as calling {@link #partitionBy(MapFunction, MapFunction, KVSerde, String)} with a null KVSerde.
    * <p>
    * Uses the default serde provided via {@link StreamGraph#setDefaultSerde}, which must be a KVSerde. If the default
    * serde is not a {@link KVSerde}, a runtime exception will be thrown. If no default serde has been provided
    * <b>before</b> calling this method, a {@code KVSerde<NoOpSerde, NoOpSerde>} is used.
    *
-   * @param keyExtractor the {@link Function} to extract the message and partition key from the input message
-   * @param valueExtractor the {@link Function} to extract the value from the input message
+   * @param keyExtractor the {@link MapFunction} to extract the message and partition key from the input message
+   * @param valueExtractor the {@link MapFunction} to extract the value from the input message
    * @param id the unique id of this operator in this application
    * @param <K> the type of output key
    * @param <V> the type of output value
    * @return the repartitioned {@link MessageStream}
    */
-  <K, V> MessageStream<KV<K, V>> partitionBy(Function<? super M, ? extends K> keyExtractor,
-      Function<? super M, ? extends V> valueExtractor, String id);
+  <K, V> MessageStream<KV<K, V>> partitionBy(MapFunction<? super M, ? extends K> keyExtractor,
+      MapFunction<? super M, ? extends V> valueExtractor, String id);
 
   /**
    * Sends messages in this {@link MessageStream} to a {@link Table}. The type of input message is expected
index ea83ba4..faf9fc5 100644 (file)
@@ -33,5 +33,8 @@ import org.apache.samza.annotation.InterfaceStability;
  */
 @InterfaceStability.Unstable
 public interface ClosableFunction {
+  /**
+   * Frees any resource acquired by the operators in {@link InitableFunction}
+   */
   default void close() {}
 }
index 31bbbd8..ce68e0f 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators.functions;
 
+import java.io.Serializable;
 import org.apache.samza.annotation.InterfaceStability;
 
 
@@ -28,7 +29,7 @@ import org.apache.samza.annotation.InterfaceStability;
  */
 @InterfaceStability.Unstable
 @FunctionalInterface
-public interface FilterFunction<M> extends InitableFunction, ClosableFunction {
+public interface FilterFunction<M> extends InitableFunction, ClosableFunction, Serializable {
 
   /**
    * Returns a boolean indicating whether this message should be retained or filtered out.
index 7e9253e..63d7061 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators.functions;
 
+import java.io.Serializable;
 import org.apache.samza.annotation.InterfaceStability;
 
 import java.util.Collection;
@@ -31,7 +32,7 @@ import java.util.Collection;
  */
 @InterfaceStability.Unstable
 @FunctionalInterface
-public interface FlatMapFunction<M, OM>  extends InitableFunction, ClosableFunction {
+public interface FlatMapFunction<M, OM>  extends InitableFunction, ClosableFunction, Serializable {
 
   /**
    * Transforms the provided message into a collection of 0 or more messages.
index 78250e3..d6ba205 100644 (file)
 
 package org.apache.samza.operators.functions;
 
+import java.io.Serializable;
+import org.apache.samza.annotation.InterfaceStability;
+
+
 /**
- * Incrementally updates the window value as messages are added to the window.
+ * Incrementally updates the aggregated value as messages are added. Main usage is in {@link org.apache.samza.operators.windows.Window} operator.
  */
-public interface FoldLeftFunction<M, WV> extends InitableFunction, ClosableFunction {
+@InterfaceStability.Unstable
+@FunctionalInterface
+public interface FoldLeftFunction<M, WV> extends InitableFunction, ClosableFunction, Serializable {
 
   /**
-   * Incrementally updates the window value as messages are added to the window.
+   * Incrementally updates the aggregated value as messages are added.
    *
-   * @param message the message being added to the window
-   * @param oldValue the previous value associated with the window
+   * @param message the message being added to the aggregated value
+   * @param oldValue the previous value
    * @return the new value
    */
   WV apply(M message, WV oldValue);
index 954083d..94a998d 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators.functions;
 
+import java.io.Serializable;
 import org.apache.samza.annotation.InterfaceStability;
 
 
@@ -30,7 +31,7 @@ import org.apache.samza.annotation.InterfaceStability;
  * @param <RM>  type of the joined message
  */
 @InterfaceStability.Unstable
-public interface JoinFunction<K, M, JM, RM>  extends InitableFunction, ClosableFunction {
+public interface JoinFunction<K, M, JM, RM>  extends InitableFunction, ClosableFunction, Serializable {
 
   /**
    * Joins the provided messages and returns the joined message.
index a8c139f..fad9cf8 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators.functions;
 
+import java.io.Serializable;
 import org.apache.samza.annotation.InterfaceStability;
 
 
@@ -29,7 +30,7 @@ import org.apache.samza.annotation.InterfaceStability;
  */
 @InterfaceStability.Unstable
 @FunctionalInterface
-public interface MapFunction<M, OM>  extends InitableFunction, ClosableFunction {
+public interface MapFunction<M, OM>  extends InitableFunction, ClosableFunction, Serializable {
 
   /**
    * Transforms the provided message into another message.
index e290d7d..2b44125 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators.functions;
 
+import java.io.Serializable;
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.TaskCoordinator;
@@ -30,7 +31,7 @@ import org.apache.samza.task.TaskCoordinator;
  */
 @InterfaceStability.Unstable
 @FunctionalInterface
-public interface SinkFunction<M>  extends InitableFunction, ClosableFunction {
+public interface SinkFunction<M>  extends InitableFunction, ClosableFunction, Serializable {
 
   /**
    * Allows sending the provided message to an output {@link org.apache.samza.system.SystemStream} using
index 6afcf67..356e07f 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators.functions;
 
+import java.io.Serializable;
 import org.apache.samza.annotation.InterfaceStability;
 
 
@@ -30,7 +31,7 @@ import org.apache.samza.annotation.InterfaceStability;
  * @param <JM> type of join results
  */
 @InterfaceStability.Unstable
-public interface StreamTableJoinFunction<K, M, R, JM> extends InitableFunction, ClosableFunction {
+public interface StreamTableJoinFunction<K, M, R, JM> extends InitableFunction, ClosableFunction, Serializable {
 
   /**
    * Joins the provided messages and table record, returns the joined message.
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.samza.testUtils;
+package org.apache.samza.operators.functions;
+
+import java.io.Serializable;
+import org.apache.samza.annotation.InterfaceStability;
+
 
 /**
- * Test class. Invalid class to implement {@link org.apache.samza.application.StreamApplication}
+ * A supplier to return a new value at each invocation
  */
-public class InvalidStreamApplication {
+@InterfaceStability.Unstable
+@FunctionalInterface
+public interface SupplierFunction<T> extends InitableFunction, ClosableFunction, Serializable {
+
+  /**
+   * Returns a value of type T
+   *
+   * @return a value for type T
+   */
+  T get();
 }
index f52b57b..6bdf406 100644 (file)
 */
 package org.apache.samza.operators.triggers;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
+
 /**
  * A {@link Trigger} fires as soon as any of its individual triggers has fired.
  */
 public class AnyTrigger<M> implements Trigger<M> {
 
-  private final List<Trigger<M>> triggers;
+  private final ArrayList<Trigger<M>> triggers;
 
   AnyTrigger(List<Trigger<M>> triggers) {
-    this.triggers = triggers;
+    this.triggers = new ArrayList<>();
+    this.triggers.addAll(triggers);
   }
 
   public List<Trigger<M>> getTriggers() {
-    return triggers;
+    return Collections.unmodifiableList(triggers);
   }
 }
index be0a877..f224fa2 100644 (file)
@@ -20,6 +20,7 @@
 package org.apache.samza.operators.triggers;
 
 
+import java.io.Serializable;
 import org.apache.samza.annotation.InterfaceStability;
 
 /**
@@ -30,6 +31,6 @@ import org.apache.samza.annotation.InterfaceStability;
  * @param <M> the type of the incoming message
  */
 @InterfaceStability.Unstable
-public interface Trigger<M> {
+public interface Trigger<M> extends Serializable {
 
 }
index 1c0fa53..7534fca 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators.windows;
 
+import java.io.Serializable;
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.operators.triggers.Trigger;
 
@@ -70,7 +71,7 @@ import org.apache.samza.operators.triggers.Trigger;
  * @param <WV> the type of the value in the window
  */
 @InterfaceStability.Unstable
-public interface Window<M, K, WV> {
+public interface Window<M, K, WV> extends Serializable {
 
   /**
    * Set the early triggers for this {@link Window}.
index 50391ff..4805a0e 100644 (file)
@@ -21,6 +21,8 @@ package org.apache.samza.operators.windows;
 
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.operators.functions.FoldLeftFunction;
+import org.apache.samza.operators.functions.MapFunction;
+import org.apache.samza.operators.functions.SupplierFunction;
 import org.apache.samza.operators.triggers.TimeTrigger;
 import org.apache.samza.operators.triggers.Trigger;
 import org.apache.samza.operators.triggers.Triggers;
@@ -30,8 +32,6 @@ import org.apache.samza.serializers.Serde;
 
 import java.time.Duration;
 import java.util.Collection;
-import java.util.function.Function;
-import java.util.function.Supplier;
 
 /**
  * APIs for creating different types of {@link Window}s.
@@ -84,7 +84,7 @@ import java.util.function.Supplier;
  * and triggers are fired and window panes are emitted per-key. It is possible to construct "keyed" variants
  * of the window types above.
  *
- * <p> The value for the window can be updated incrementally by providing an {@code initialValue} {@link Supplier}
+ * <p> The value for the window can be updated incrementally by providing an {@code initialValue} {@link SupplierFunction}
  * and an aggregating {@link FoldLeftFunction}. The initial value supplier is invoked every time a new window is
  * created. The aggregating function is invoked for each incoming message for the window. If these are not provided,
  * the emitted {@link WindowPane} will contain a collection of messages in the window.
@@ -105,8 +105,8 @@ public final class Windows {
    *
    * <pre> {@code
    *    MessageStream<UserClick> stream = ...;
-   *    Function<UserClick, String> keyFn = ...;
-   *    Supplier<Integer> initialValue = () -> 0;
+   *    MapFunction<UserClick, String> keyFn = ...;
+   *    SupplierFunction<Integer> initialValue = () -> 0;
    *    FoldLeftFunction<UserClick, Integer, Integer> maxAggregator = (m, c) -> Math.max(parseInt(m), c);
    *    MessageStream<WindowPane<String, Integer>> windowedStream = stream.window(
    *        Windows.keyedTumblingWindow(keyFn, Duration.ofSeconds(10), maxAggregator));
@@ -125,16 +125,15 @@ public final class Windows {
    * @param <K> the type of the key in the {@link Window}
    * @return the created {@link Window} function.
    */
-  public static <M, K, WV> Window<M, K, WV> keyedTumblingWindow(Function<? super M, ? extends K> keyFn, Duration interval,
-      Supplier<? extends WV> initialValue, FoldLeftFunction<? super M, WV> aggregator, Serde<K> keySerde,
+  public static <M, K, WV> Window<M, K, WV> keyedTumblingWindow(MapFunction<? super M, ? extends K> keyFn, Duration interval,
+      SupplierFunction<? extends WV> initialValue, FoldLeftFunction<? super M, WV> aggregator, Serde<K> keySerde,
       Serde<WV> windowValueSerde) {
 
     Trigger<M> defaultTrigger = new TimeTrigger<>(interval);
-    return new WindowInternal<>(defaultTrigger, (Supplier<WV>) initialValue, (FoldLeftFunction<M, WV>) aggregator,
-        (Function<M, K>) keyFn, null, WindowType.TUMBLING, keySerde, windowValueSerde, null);
+    return new WindowInternal<>(defaultTrigger, (SupplierFunction<WV>) initialValue, (FoldLeftFunction<M, WV>) aggregator,
+        (MapFunction<M, K>) keyFn, null, WindowType.TUMBLING, keySerde, windowValueSerde, null);
   }
 
-
   /**
    * Creates a {@link Window} that groups incoming messages into fixed-size, non-overlapping
    * processing time based windows using the provided keyFn.
@@ -157,12 +156,12 @@ public final class Windows {
    * @param <K> the type of the key in the {@link Window}
    * @return the created {@link Window} function
    */
-  public static <M, K> Window<M, K, Collection<M>> keyedTumblingWindow(Function<M, K> keyFn, Duration interval,
+  public static <M, K> Window<M, K, Collection<M>> keyedTumblingWindow(MapFunction<M, K> keyFn, Duration interval,
       Serde<K> keySerde, Serde<M> msgSerde) {
 
     Trigger<M> defaultTrigger = new TimeTrigger<>(interval);
-    return new WindowInternal<>(defaultTrigger, null, null, keyFn, null,
-        WindowType.TUMBLING, keySerde, null, msgSerde);
+    return new WindowInternal<>(defaultTrigger, null, null, keyFn, null, WindowType.TUMBLING,
+        keySerde, null, msgSerde);
   }
 
   /**
@@ -173,7 +172,7 @@ public final class Windows {
    *
    * <pre> {@code
    *    MessageStream<String> stream = ...;
-   *    Supplier<Integer> initialValue = () -> 0;
+   *    SupplierFunction<Integer> initialValue = () -> 0;
    *    FoldLeftFunction<String, Integer, Integer> maxAggregator = (m, c) -> Math.max(parseInt(m), c);
    *    MessageStream<WindowPane<Void, Integer>> windowedStream = stream.window(
    *        Windows.tumblingWindow(Duration.ofSeconds(10), maxAggregator));
@@ -189,10 +188,10 @@ public final class Windows {
    * @param <WV> the type of the {@link WindowPane} output value
    * @return the created {@link Window} function
    */
-  public static <M, WV> Window<M, Void, WV> tumblingWindow(Duration interval, Supplier<? extends WV> initialValue,
+  public static <M, WV> Window<M, Void, WV> tumblingWindow(Duration interval, SupplierFunction<? extends WV> initialValue,
       FoldLeftFunction<? super M, WV> aggregator, Serde<WV> windowValueSerde) {
     Trigger<M> defaultTrigger = new TimeTrigger<>(interval);
-    return new WindowInternal<>(defaultTrigger, (Supplier<WV>) initialValue, (FoldLeftFunction<M, WV>) aggregator,
+    return new WindowInternal<>(defaultTrigger, (SupplierFunction<WV>) initialValue, (FoldLeftFunction<M, WV>) aggregator,
         null, null, WindowType.TUMBLING, null, windowValueSerde, null);
   }
 
@@ -221,9 +220,8 @@ public final class Windows {
    */
   public static <M> Window<M, Void, Collection<M>> tumblingWindow(Duration duration, Serde<M> msgSerde) {
     Trigger<M> defaultTrigger = new TimeTrigger<>(duration);
-
-    return new WindowInternal<>(defaultTrigger, null, null, null,
-       null, WindowType.TUMBLING, null, null, msgSerde);
+    return new WindowInternal<>(defaultTrigger, null, null, null, null,
+        WindowType.TUMBLING, null, null, msgSerde);
   }
 
   /**
@@ -238,7 +236,7 @@ public final class Windows {
    *
    * <pre> {@code
    *    MessageStream<UserClick> stream = ...;
-   *    Supplier<Integer> initialValue = () -> 0;
+   *    SupplierFunction<Integer> initialValue = () -> 0;
    *    FoldLeftFunction<UserClick, Integer, Integer> maxAggregator = (m, c) -> Math.max(parseInt(m), c);
    *    Function<UserClick, String> userIdExtractor = m -> m.getUserId()..;
    *    MessageStream<WindowPane<String, Integer>> windowedStream = stream.window(
@@ -258,12 +256,12 @@ public final class Windows {
    * @param <WV> the type of the output value in the {@link WindowPane}
    * @return the created {@link Window} function
    */
-  public static <M, K, WV> Window<M, K, WV> keyedSessionWindow(Function<? super M, ? extends K> keyFn,
-      Duration sessionGap, Supplier<? extends WV> initialValue, FoldLeftFunction<? super M, WV> aggregator,
+  public static <M, K, WV> Window<M, K, WV> keyedSessionWindow(MapFunction<? super M, ? extends K> keyFn,
+      Duration sessionGap, SupplierFunction<? extends WV> initialValue, FoldLeftFunction<? super M, WV> aggregator,
       Serde<K> keySerde, Serde<WV> windowValueSerde) {
     Trigger<M> defaultTrigger = Triggers.timeSinceLastMessage(sessionGap);
-    return new WindowInternal<>(defaultTrigger, (Supplier<WV>) initialValue, (FoldLeftFunction<M, WV>) aggregator,
-        (Function<M, K>) keyFn, null, WindowType.SESSION, keySerde, windowValueSerde, null);
+    return new WindowInternal<>(defaultTrigger, (SupplierFunction<WV>) initialValue, (FoldLeftFunction<M, WV>) aggregator,
+        (MapFunction<M, K>) keyFn, null, WindowType.SESSION, keySerde, windowValueSerde, null);
   }
 
   /**
@@ -278,7 +276,7 @@ public final class Windows {
    *
    * <pre> {@code
    *    MessageStream<UserClick> stream = ...;
-   *    Supplier<Integer> initialValue = () -> 0;
+   *    SupplierFunction<Integer> initialValue = () -> 0;
    *    FoldLeftFunction<UserClick, Integer, Integer> maxAggregator = (m, c)-> Math.max(parseIntField(m), c);
    *    Function<UserClick, String> userIdExtractor = m -> m.getUserId()..;
    *    MessageStream<WindowPane<String>, Collection<M>> windowedStream = stream.window(
@@ -294,11 +292,10 @@ public final class Windows {
    * @param <K> the type of the key in the {@link Window}
    * @return the created {@link Window} function
    */
-  public static <M, K> Window<M, K, Collection<M>> keyedSessionWindow(Function<? super M, ? extends K> keyFn,
+  public static <M, K> Window<M, K, Collection<M>> keyedSessionWindow(MapFunction<? super M, ? extends K> keyFn,
       Duration sessionGap, Serde<K> keySerde, Serde<M> msgSerde) {
-
     Trigger<M> defaultTrigger = Triggers.timeSinceLastMessage(sessionGap);
-    return new WindowInternal<>(defaultTrigger, null, null, (Function<M, K>) keyFn,
+    return new WindowInternal<>(defaultTrigger, null, null, (MapFunction<M, K>) keyFn,
         null, WindowType.SESSION, keySerde, null, msgSerde);
   }
 }
index bc71872..ff19aba 100644 (file)
  */
 package org.apache.samza.operators.windows.internal;
 import org.apache.samza.annotation.InterfaceStability;
+import org.apache.samza.operators.functions.MapFunction;
+import org.apache.samza.operators.functions.SupplierFunction;
 import org.apache.samza.operators.functions.FoldLeftFunction;
 import org.apache.samza.operators.triggers.Trigger;
 import org.apache.samza.operators.windows.AccumulationMode;
 import org.apache.samza.operators.windows.Window;
 import org.apache.samza.serializers.Serde;
 
-import java.util.function.Function;
-import java.util.function.Supplier;
 
 /**
  *  Internal representation of a {@link Window}. This specifies default, early and late triggers for the {@link Window}
@@ -45,7 +45,7 @@ public final class WindowInternal<M, WK, WV> implements Window<M, WK, WV> {
   /**
    * The supplier of initial value to be used for windowed aggregations
    */
-  private final Supplier<WV> initializer;
+  private final SupplierFunction<WV> initializer;
 
   /*
    * The function that is applied each time a {@link MessageEnvelope} is added to this window.
@@ -55,28 +55,32 @@ public final class WindowInternal<M, WK, WV> implements Window<M, WK, WV> {
   /*
    * The function that extracts the key from a {@link MessageEnvelope}
    */
-  private final Function<M, WK> keyExtractor;
+  private final MapFunction<M, WK> keyExtractor;
 
   /*
    * The function that extracts the event time from a {@link MessageEnvelope}
    */
-  private final Function<M, Long> eventTimeExtractor;
+  private final MapFunction<M, Long> eventTimeExtractor;
 
   /**
    * The type of this window. Tumbling and Session windows are supported for now.
    */
   private final WindowType windowType;
 
-  private final Serde<WK> keySerde;
-  private final Serde<WV> windowValSerde;
-  private final Serde<M> msgSerde;
-
   private Trigger<M> earlyTrigger;
   private Trigger<M> lateTrigger;
   private AccumulationMode mode;
 
-  public WindowInternal(Trigger<M> defaultTrigger, Supplier<WV> initializer, FoldLeftFunction<M, WV> foldLeftFunction,
-      Function<M, WK> keyExtractor, Function<M, Long> eventTimeExtractor, WindowType windowType, Serde<WK> keySerde,
+  /**
+   * The following {@link Serde}s are serialized by the ExecutionPlanner when generating the store configs, and deserialized
+   * once during startup in SamzaContainer. They don't need to be deserialized here on a per-task basis
+   */
+  private transient final Serde<WK> keySerde;
+  private transient final Serde<WV> windowValSerde;
+  private transient final Serde<M> msgSerde;
+
+  public WindowInternal(Trigger<M> defaultTrigger, SupplierFunction<WV> initializer, FoldLeftFunction<M, WV> foldLeftFunction,
+      MapFunction<M, WK> keyExtractor, MapFunction<M, Long> eventTimeExtractor, WindowType windowType, Serde<WK> keySerde,
       Serde<WV> windowValueSerde, Serde<M> msgSerde) {
     this.defaultTrigger = defaultTrigger;
     this.initializer = initializer;
@@ -121,7 +125,7 @@ public final class WindowInternal<M, WK, WV> implements Window<M, WK, WV> {
     return lateTrigger;
   }
 
-  public Supplier<WV> getInitializer() {
+  public SupplierFunction<WV> getInitializer() {
     return initializer;
   }
 
@@ -129,11 +133,11 @@ public final class WindowInternal<M, WK, WV> implements Window<M, WK, WV> {
     return foldLeftFunction;
   }
 
-  public Function<M, WK> getKeyExtractor() {
+  public MapFunction<M, WK> getKeyExtractor() {
     return keyExtractor;
   }
 
-  public Function<M, Long> getEventTimeExtractor() {
+  public MapFunction<M, Long> getEventTimeExtractor() {
     return eventTimeExtractor;
   }
 
index d70746c..d49518c 100644 (file)
@@ -68,7 +68,7 @@ public class SerializableSerde<T extends Serializable> implements Serde<T> {
         ois = new ObjectInputStream(bis);
         return (T) ois.readObject();
       } catch (IOException | ClassNotFoundException e) {
-        throw new SamzaException("Error reading from input stream.");
+        throw new SamzaException("Error reading from input stream.", e);
       } finally {
         try {
           if (ois != null) {
index ce67d8d..cd86426 100644 (file)
@@ -19,6 +19,7 @@
 
 package org.apache.samza.system;
 
+import java.io.Serializable;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
@@ -33,7 +34,7 @@ import java.util.Map;
  *
  * It is immutable by design.
  */
-public class StreamSpec {
+public class StreamSpec implements Serializable {
 
   private static final int DEFAULT_PARTITION_COUNT = 1;
 
index 95cc266..e9ca9f7 100644 (file)
@@ -60,7 +60,7 @@ public class SystemStreamPartition extends SystemStream implements Comparable<Sy
   public Partition getPartition() {
     return partition;
   }
-  
+
   public SystemStream getSystemStream() {
     return new SystemStream(system, stream);
   }
@@ -69,7 +69,7 @@ public class SystemStreamPartition extends SystemStream implements Comparable<Sy
   public int hashCode() {
     return hash;
   }
-  
+
   private int computeHashCode() {
     final int prime = 31;
     int result = super.hashCode();
index 68043f9..ba57c2f 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.table;
 
+import java.io.Serializable;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
@@ -41,12 +42,17 @@ import org.apache.samza.serializers.KVSerde;
  * It is immutable by design.
  */
 @InterfaceStability.Unstable
-public class TableSpec {
+public class TableSpec implements Serializable {
 
   private final String id;
-  private final KVSerde serde;
   private final String tableProviderFactoryClassName;
-  private final Map<String, String> config = new HashMap<>();
+
+  /**
+   * The following fields are serialized by the ExecutionPlanner when generating the configs for a table, and deserialized
+   * once during startup in SamzaContainer. They don't need to be deserialized here on a per-task basis
+   */
+  private transient final KVSerde serde;
+  private transient final Map<String, String> config = new HashMap<>();
 
   /**
    * Default constructor
index 4184c9d..19cce6f 100644 (file)
@@ -27,7 +27,7 @@ import static org.junit.Assert.assertEquals;
 public class TestWindowPane {
   @Test
   public void testConstructor() {
-    WindowPane<String, Integer> wndOutput = new WindowPane(new WindowKey("testMsg", null), 10, AccumulationMode.DISCARDING, FiringType.EARLY);
+    WindowPane<String, Integer> wndOutput = new WindowPane<>(new WindowKey<>("testMsg", null), 10, AccumulationMode.DISCARDING, FiringType.EARLY);
     assertEquals(wndOutput.getKey().getKey(), "testMsg");
     assertEquals(wndOutput.getMessage(), Integer.valueOf(10));
   }
index e2c122a..9d8bd5f 100644 (file)
@@ -34,7 +34,7 @@ import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.ClusterManagerConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
-import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.system.StreamSpec;
@@ -61,18 +61,18 @@ public class ExecutionPlanner {
     this.streamManager = streamManager;
   }
 
-  public ExecutionPlan plan(StreamGraphImpl streamGraph) throws Exception {
+  public ExecutionPlan plan(OperatorSpecGraph specGraph) throws Exception {
     validateConfig();
 
     // create physical job graph based on stream graph
-    JobGraph jobGraph = createJobGraph(streamGraph);
+    JobGraph jobGraph = createJobGraph(specGraph);
 
     // fetch the external streams partition info
     updateExistingPartitions(jobGraph, streamManager);
 
     if (!jobGraph.getIntermediateStreamEdges().isEmpty()) {
       // figure out the partitions for internal streams
-      calculatePartitions(streamGraph, jobGraph);
+      calculatePartitions(jobGraph);
     }
 
     return jobGraph;
@@ -91,12 +91,12 @@ public class ExecutionPlanner {
   /**
    * Create the physical graph from StreamGraph
    */
-  /* package private */ JobGraph createJobGraph(StreamGraphImpl streamGraph) {
-    JobGraph jobGraph = new JobGraph(config);
-    Set<StreamSpec> sourceStreams = new HashSet<>(streamGraph.getInputOperators().keySet());
-    Set<StreamSpec> sinkStreams = new HashSet<>(streamGraph.getOutputStreams().keySet());
+  /* package private */ JobGraph createJobGraph(OperatorSpecGraph specGraph) {
+    JobGraph jobGraph = new JobGraph(config, specGraph);
+    Set<StreamSpec> sourceStreams = new HashSet<>(specGraph.getInputOperators().keySet());
+    Set<StreamSpec> sinkStreams = new HashSet<>(specGraph.getOutputStreams().keySet());
     Set<StreamSpec> intStreams = new HashSet<>(sourceStreams);
-    Set<TableSpec> tables = new HashSet<>(streamGraph.getTables().keySet());
+    Set<TableSpec> tables = new HashSet<>(specGraph.getTables().keySet());
     intStreams.retainAll(sinkStreams);
     sourceStreams.removeAll(intStreams);
     sinkStreams.removeAll(intStreams);
@@ -104,7 +104,7 @@ public class ExecutionPlanner {
     // For this phase, we have a single job node for the whole dag
     String jobName = config.get(JobConfig.JOB_NAME());
     String jobId = config.get(JobConfig.JOB_ID(), "1");
-    JobNode node = jobGraph.getOrCreateJobNode(jobName, jobId, streamGraph);
+    JobNode node = jobGraph.getOrCreateJobNode(jobName, jobId);
 
     // add sources
     sourceStreams.forEach(spec -> jobGraph.addSource(spec, node));
@@ -126,9 +126,9 @@ public class ExecutionPlanner {
   /**
    * Figure out the number of partitions of all streams
    */
-  /* package private */ void calculatePartitions(StreamGraphImpl streamGraph, JobGraph jobGraph) {
+  /* package private */ void calculatePartitions(JobGraph jobGraph) {
     // calculate the partitions for the input streams of join operators
-    calculateJoinInputPartitions(streamGraph, jobGraph);
+    calculateJoinInputPartitions(jobGraph);
 
     // calculate the partitions for the rest of intermediate streams
     calculateIntStreamPartitions(jobGraph, config);
@@ -172,7 +172,7 @@ public class ExecutionPlanner {
   /**
    * Calculate the partitions for the input streams of join operators
    */
-  /* package private */ static void calculateJoinInputPartitions(StreamGraphImpl streamGraph, JobGraph jobGraph) {
+  /* package private */ static void calculateJoinInputPartitions(JobGraph jobGraph) {
     // mapping from a source stream to all join specs reachable from it
     Multimap<OperatorSpec, StreamEdge> joinSpecToStreamEdges = HashMultimap.create();
     // reverse mapping of the above
@@ -182,7 +182,7 @@ public class ExecutionPlanner {
     // The visited set keeps track of the join specs that have been already inserted in the queue before
     Set<OperatorSpec> visited = new HashSet<>();
 
-    streamGraph.getInputOperators().entrySet().forEach(entry -> {
+    jobGraph.getSpecGraph().getInputOperators().entrySet().forEach(entry -> {
         StreamEdge streamEdge = jobGraph.getOrCreateStreamEdge(entry.getKey());
         // Traverses the StreamGraph to find and update mappings for all Joins reachable from this input StreamEdge
         findReachableJoins(entry.getValue(), streamEdge, joinSpecToStreamEdges, streamEdgeToJoinSpecs,
index abd3ce7..843db85 100644 (file)
@@ -34,7 +34,7 @@ import java.util.stream.Collectors;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
-import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.system.StreamSpec;
 import org.apache.samza.table.TableSpec;
 import org.slf4j.Logger;
@@ -60,13 +60,15 @@ import org.slf4j.LoggerFactory;
   private final Set<TableSpec> tables = new HashSet<>();
   private final Config config;
   private final JobGraphJsonGenerator jsonGenerator = new JobGraphJsonGenerator();
+  private final OperatorSpecGraph specGraph;
 
   /**
    * The JobGraph is only constructed by the {@link ExecutionPlanner}.
    * @param config Config
    */
-  JobGraph(Config config) {
+  JobGraph(Config config, OperatorSpecGraph specGraph) {
     this.config = config;
+    this.specGraph = specGraph;
   }
 
   @Override
@@ -107,6 +109,10 @@ import org.slf4j.LoggerFactory;
     return new ApplicationConfig(config);
   }
 
+  public OperatorSpecGraph getSpecGraph() {
+    return specGraph;
+  }
+
   /**
    * Add a source stream to a {@link JobNode}
    * @param input source stream
@@ -152,11 +158,11 @@ import org.slf4j.LoggerFactory;
    * @param jobId id of the job
    * @return
    */
-  JobNode getOrCreateJobNode(String jobName, String jobId, StreamGraphImpl streamGraph) {
+  JobNode getOrCreateJobNode(String jobName, String jobId) {
     String nodeId = JobNode.createId(jobName, jobId);
     JobNode node = nodes.get(nodeId);
     if (node == null) {
-      node = new JobNode(jobName, jobId, streamGraph, config);
+      node = new JobNode(jobName, jobId, specGraph, config);
       nodes.put(nodeId, node);
     }
     return node;
index 48d2219..298042b 100644 (file)
@@ -170,7 +170,7 @@ import org.codehaus.jackson.map.ObjectMapper;
   private OperatorGraphJson buildOperatorGraphJson(JobNode jobNode) {
     OperatorGraphJson opGraph = new OperatorGraphJson();
     opGraph.inputStreams = new ArrayList<>();
-    jobNode.getStreamGraph().getInputOperators().forEach((streamSpec, operatorSpec) -> {
+    jobNode.getSpecGraph().getInputOperators().forEach((streamSpec, operatorSpec) -> {
         StreamJson inputJson = new StreamJson();
         opGraph.inputStreams.add(inputJson);
         inputJson.streamId = streamSpec.getId();
@@ -181,7 +181,7 @@ import org.codehaus.jackson.map.ObjectMapper;
       });
 
     opGraph.outputStreams = new ArrayList<>();
-    jobNode.getStreamGraph().getOutputStreams().keySet().forEach(streamSpec -> {
+    jobNode.getSpecGraph().getOutputStreams().keySet().forEach(streamSpec -> {
         StreamJson outputJson = new StreamJson();
         outputJson.streamId = streamSpec.getId();
         opGraph.outputStreams.add(outputJson);
index 8abd463..db44d9f 100644 (file)
@@ -39,7 +39,7 @@ import org.apache.samza.config.StorageConfig;
 import org.apache.samza.config.StreamConfig;
 import org.apache.samza.config.TaskConfig;
 import org.apache.samza.config.TaskConfigJava;
-import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
@@ -73,22 +73,22 @@ public class JobNode {
   private final String jobName;
   private final String jobId;
   private final String id;
-  private final StreamGraphImpl streamGraph;
+  private final OperatorSpecGraph specGraph;
   private final List<StreamEdge> inEdges = new ArrayList<>();
   private final List<StreamEdge> outEdges = new ArrayList<>();
   private final List<TableSpec> tables = new ArrayList<>();
   private final Config config;
 
-  JobNode(String jobName, String jobId, StreamGraphImpl streamGraph, Config config) {
+  JobNode(String jobName, String jobId, OperatorSpecGraph specGraph, Config config) {
     this.jobName = jobName;
     this.jobId = jobId;
     this.id = createId(jobName, jobId);
-    this.streamGraph = streamGraph;
+    this.specGraph = specGraph;
     this.config = config;
   }
 
-  public StreamGraphImpl getStreamGraph() {
-    return streamGraph;
+  public OperatorSpecGraph getSpecGraph() {
+    return this.specGraph;
   }
 
   public  String getId() {
@@ -154,7 +154,7 @@ public class JobNode {
     }
 
     // set triggering interval if a window or join is defined
-    if (streamGraph.hasWindowOrJoins()) {
+    if (specGraph.hasWindowOrJoins()) {
       if ("-1".equals(config.get(TaskConfig.WINDOW_MS(), "-1"))) {
         long triggerInterval = computeTriggerInterval();
         log.info("Using triggering interval: {} for jobName: {}", triggerInterval, jobName);
@@ -163,7 +163,7 @@ public class JobNode {
       }
     }
 
-    streamGraph.getAllOperatorSpecs().forEach(opSpec -> {
+    specGraph.getAllOperatorSpecs().forEach(opSpec -> {
         if (opSpec instanceof StatefulOperatorSpec) {
           ((StatefulOperatorSpec) opSpec).getStoreDescriptors()
               .forEach(sd -> configs.putAll(sd.getStorageConfigs()));
@@ -228,14 +228,14 @@ public class JobNode {
     // collect all key and msg serde instances for streams
     Map<String, Serde> streamKeySerdes = new HashMap<>();
     Map<String, Serde> streamMsgSerdes = new HashMap<>();
-    Map<StreamSpec, InputOperatorSpec> inputOperators = streamGraph.getInputOperators();
+    Map<StreamSpec, InputOperatorSpec> inputOperators = specGraph.getInputOperators();
     inEdges.forEach(edge -> {
         String streamId = edge.getStreamSpec().getId();
         InputOperatorSpec inputOperatorSpec = inputOperators.get(edge.getStreamSpec());
         streamKeySerdes.put(streamId, inputOperatorSpec.getKeySerde());
         streamMsgSerdes.put(streamId, inputOperatorSpec.getValueSerde());
       });
-    Map<StreamSpec, OutputStreamImpl> outputStreams = streamGraph.getOutputStreams();
+    Map<StreamSpec, OutputStreamImpl> outputStreams = specGraph.getOutputStreams();
     outEdges.forEach(edge -> {
         String streamId = edge.getStreamSpec().getId();
         OutputStreamImpl outputStream = outputStreams.get(edge.getStreamSpec());
@@ -246,7 +246,7 @@ public class JobNode {
     // collect all key and msg serde instances for stores
     Map<String, Serde> storeKeySerdes = new HashMap<>();
     Map<String, Serde> storeMsgSerdes = new HashMap<>();
-    streamGraph.getAllOperatorSpecs().forEach(opSpec -> {
+    specGraph.getAllOperatorSpecs().forEach(opSpec -> {
         if (opSpec instanceof StatefulOperatorSpec) {
           ((StatefulOperatorSpec) opSpec).getStoreDescriptors().forEach(storeDescriptor -> {
               storeKeySerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getKeySerde());
@@ -320,8 +320,8 @@ public class JobNode {
    * Computes the triggering interval to use during the execution of this {@link JobNode}
    */
   private long computeTriggerInterval() {
-    // Obtain the operator specs from the streamGraph
-    Collection<OperatorSpec> operatorSpecs = streamGraph.getAllOperatorSpecs();
+    // Obtain the operator specs from the specGraph
+    Collection<OperatorSpec> operatorSpecs = specGraph.getAllOperatorSpecs();
 
     // Filter out window operators, and obtain a list of their triggering interval values
     List<Long> windowTimerIntervals = operatorSpecs.stream()
index 1681f30..6922c76 100644 (file)
@@ -21,7 +21,6 @@ package org.apache.samza.operators;
 
 import java.time.Duration;
 import java.util.Collection;
-import java.util.function.Function;
 
 import org.apache.samza.SamzaException;
 import org.apache.samza.operators.functions.FilterFunction;
@@ -64,16 +63,16 @@ import org.apache.samza.table.TableSpec;
  */
 public class MessageStreamImpl<M> implements MessageStream<M> {
   /**
-   * The {@link StreamGraphImpl} that contains this {@link MessageStreamImpl}
+   * The {@link StreamGraphSpec} that contains this {@link MessageStreamImpl}
    */
-  private final StreamGraphImpl graph;
+  private final StreamGraphSpec graph;
 
   /**
    * The {@link OperatorSpec} associated with this {@link MessageStreamImpl}
    */
   private final OperatorSpec operatorSpec;
 
-  public MessageStreamImpl(StreamGraphImpl graph, OperatorSpec<?, M> operatorSpec) {
+  public MessageStreamImpl(StreamGraphSpec graph, OperatorSpec<?, M> operatorSpec) {
     this.graph = graph;
     this.operatorSpec = operatorSpec;
   }
@@ -81,7 +80,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   @Override
   public <TM> MessageStream<TM> map(MapFunction<? super M, ? extends TM> mapFn) {
     String opId = this.graph.getNextOpId(OpCode.MAP);
-    OperatorSpec<M, TM> op = OperatorSpecs.createMapOperatorSpec(mapFn, opId);
+    StreamOperatorSpec<M, TM> op = OperatorSpecs.createMapOperatorSpec(mapFn, opId);
     this.operatorSpec.registerNextOperatorSpec(op);
     return new MessageStreamImpl<>(this.graph, op);
   }
@@ -89,7 +88,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   @Override
   public MessageStream<M> filter(FilterFunction<? super M> filterFn) {
     String opId = this.graph.getNextOpId(OpCode.FILTER);
-    OperatorSpec<M, M> op = OperatorSpecs.createFilterOperatorSpec(filterFn, opId);
+    StreamOperatorSpec<M, M> op = OperatorSpecs.createFilterOperatorSpec(filterFn, opId);
     this.operatorSpec.registerNextOperatorSpec(op);
     return new MessageStreamImpl<>(this.graph, op);
   }
@@ -97,7 +96,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   @Override
   public <TM> MessageStream<TM> flatMap(FlatMapFunction<? super M, ? extends TM> flatMapFn) {
     String opId = this.graph.getNextOpId(OpCode.FLAT_MAP);
-    OperatorSpec<M, TM> op = OperatorSpecs.createFlatMapOperatorSpec(flatMapFn, opId);
+    StreamOperatorSpec<M, TM> op = OperatorSpecs.createFlatMapOperatorSpec(flatMapFn, opId);
     this.operatorSpec.registerNextOperatorSpec(op);
     return new MessageStreamImpl<>(this.graph, op);
   }
@@ -112,15 +111,15 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   @Override
   public void sendTo(OutputStream<M> outputStream) {
     String opId = this.graph.getNextOpId(OpCode.SEND_TO);
-    OutputOperatorSpec<M> op = OperatorSpecs.createSendToOperatorSpec((OutputStreamImpl<M>) outputStream, opId);
+    OutputOperatorSpec<M> op = OperatorSpecs.createSendToOperatorSpec(
+        (OutputStreamImpl<M>) outputStream, opId);
     this.operatorSpec.registerNextOperatorSpec(op);
   }
 
   @Override
   public <K, WV> MessageStream<WindowPane<K, WV>> window(Window<M, K, WV> window, String userDefinedId) {
     String opId = this.graph.getNextOpId(OpCode.WINDOW, userDefinedId);
-    OperatorSpec<M, WindowPane<K, WV>> op = OperatorSpecs.createWindowOperatorSpec(
-        (WindowInternal<M, K, WV>) window, opId);
+    OperatorSpec<M, WindowPane<K, WV>> op = OperatorSpecs.createWindowOperatorSpec((WindowInternal<M, K, WV>) window, opId);
     this.operatorSpec.registerNextOperatorSpec(op);
     return new MessageStreamImpl<>(this.graph, op);
   }
@@ -131,24 +130,24 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
       Serde<K> keySerde, Serde<M> messageSerde, Serde<OM> otherMessageSerde,
       Duration ttl, String userDefinedId) {
     if (otherStream.equals(this)) throw new SamzaException("Cannot join a MessageStream with itself.");
-    OperatorSpec<?, OM> otherOpSpec = ((MessageStreamImpl<OM>) otherStream).getOperatorSpec();
     String opId = this.graph.getNextOpId(OpCode.JOIN, userDefinedId);
-    JoinOperatorSpec<K, M, OM, JM> joinOpSpec =
-        OperatorSpecs.createJoinOperatorSpec(this.operatorSpec, otherOpSpec, (JoinFunction<K, M, OM, JM>) joinFn,
-            keySerde, messageSerde, otherMessageSerde, ttl.toMillis(), opId);
-
-    this.operatorSpec.registerNextOperatorSpec(joinOpSpec);
-    otherOpSpec.registerNextOperatorSpec((OperatorSpec<OM, ?>) joinOpSpec);
+    OperatorSpec<?, OM> otherOpSpec = ((MessageStreamImpl<OM>) otherStream).getOperatorSpec();
+    JoinOperatorSpec<K, M, OM, JM> op =
+        OperatorSpecs.createJoinOperatorSpec(this.operatorSpec, otherOpSpec, (JoinFunction<K, M, OM, JM>) joinFn, keySerde,
+            messageSerde, otherMessageSerde, ttl.toMillis(), opId);
+    this.operatorSpec.registerNextOperatorSpec(op);
+    otherOpSpec.registerNextOperatorSpec((OperatorSpec<OM, ?>) op);
 
-    return new MessageStreamImpl<>(this.graph, joinOpSpec);
+    return new MessageStreamImpl<>(this.graph, op);
   }
 
   @Override
   public <K, R extends KV, JM> MessageStream<JM> join(Table<R> table,
       StreamTableJoinFunction<? extends K, ? super M, ? super R, ? extends JM> joinFn) {
+    String opId = this.graph.getNextOpId(OpCode.JOIN);
     TableSpec tableSpec = ((TableImpl) table).getTableSpec();
     StreamTableJoinOperatorSpec<K, M, R, JM> joinOpSpec = OperatorSpecs.createStreamTableJoinOperatorSpec(
-        tableSpec, (StreamTableJoinFunction<K, M, R, JM>) joinFn, this.graph.getNextOpId(OpCode.JOIN));
+        tableSpec, (StreamTableJoinFunction<K, M, R, JM>) joinFn, opId);
     this.operatorSpec.registerNextOperatorSpec(joinOpSpec);
     return new MessageStreamImpl<>(this.graph, joinOpSpec);
   }
@@ -157,46 +156,38 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   public MessageStream<M> merge(Collection<? extends MessageStream<? extends M>> otherStreams) {
     if (otherStreams.isEmpty()) return this;
     String opId = this.graph.getNextOpId(OpCode.MERGE);
-    StreamOperatorSpec<M, M> opSpec = OperatorSpecs.createMergeOperatorSpec(opId);
-    this.operatorSpec.registerNextOperatorSpec(opSpec);
-    otherStreams.forEach(other -> ((MessageStreamImpl<M>) other).getOperatorSpec().registerNextOperatorSpec(opSpec));
-    return new MessageStreamImpl<>(this.graph, opSpec);
+    StreamOperatorSpec<M, M> op = OperatorSpecs.createMergeOperatorSpec(opId);
+    this.operatorSpec.registerNextOperatorSpec(op);
+    otherStreams.forEach(other -> ((MessageStreamImpl<M>) other).getOperatorSpec().registerNextOperatorSpec(op));
+    return new MessageStreamImpl<>(this.graph, op);
   }
 
   @Override
-  public <K, V> MessageStream<KV<K, V>> partitionBy(Function<? super M, ? extends K> keyExtractor,
-      Function<? super M, ? extends V> valueExtractor, KVSerde<K, V> serde, String userDefinedId) {
+  public <K, V> MessageStream<KV<K, V>> partitionBy(MapFunction<? super M, ? extends K> keyExtractor,
+      MapFunction<? super M, ? extends V> valueExtractor, KVSerde<K, V> serde, String userDefinedId) {
     String opId = this.graph.getNextOpId(OpCode.PARTITION_BY, userDefinedId);
     IntermediateMessageStreamImpl<KV<K, V>> intermediateStream = this.graph.getIntermediateStream(opId, serde);
     if (!intermediateStream.isKeyed()) {
       // this can only happen when the default serde partitionBy variant is being used
       throw new SamzaException("partitionBy can not be used with a default serde that is not a KVSerde.");
     }
-    PartitionByOperatorSpec<M, K, V> partitionByOperatorSpec =
-        OperatorSpecs.createPartitionByOperatorSpec(
-            intermediateStream.getOutputStream(), keyExtractor, valueExtractor, opId);
+    PartitionByOperatorSpec<M, K, V> partitionByOperatorSpec = OperatorSpecs.createPartitionByOperatorSpec(
+        intermediateStream.getOutputStream(), keyExtractor, valueExtractor, opId);
     this.operatorSpec.registerNextOperatorSpec(partitionByOperatorSpec);
     return intermediateStream;
   }
 
   @Override
-  public <K, V> MessageStream<KV<K, V>> partitionBy(Function<? super M, ? extends K> keyExtractor,
-      Function<? super M, ? extends V> valueExtractor, String userDefinedId) {
+  public <K, V> MessageStream<KV<K, V>> partitionBy(MapFunction<? super M, ? extends K> keyExtractor,
+      MapFunction<? super M, ? extends V> valueExtractor, String userDefinedId) {
     return partitionBy(keyExtractor, valueExtractor, null, userDefinedId);
   }
 
-  /**
-   * Get the {@link OperatorSpec} associated with this {@link MessageStreamImpl}.
-   * @return the {@link OperatorSpec} associated with this {@link MessageStreamImpl}.
-   */
-  protected OperatorSpec<?, M> getOperatorSpec() {
-    return this.operatorSpec;
-  }
-
   @Override
   public <K, V> void sendTo(Table<KV<K, V>> table) {
-    SendToTableOperatorSpec<K, V> op = OperatorSpecs.createSendToTableOperatorSpec(
-        this.operatorSpec, ((TableImpl) table).getTableSpec(), this.graph.getNextOpId(OpCode.SEND_TO));
+    String opId = this.graph.getNextOpId(OpCode.SEND_TO);
+    SendToTableOperatorSpec<K, V> op =
+        OperatorSpecs.createSendToTableOperatorSpec(((TableImpl) table).getTableSpec(), opId);
     this.operatorSpec.registerNextOperatorSpec(op);
   }
 
@@ -215,4 +206,12 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
     return broadcast(null, userDefinedId);
   }
 
+  /**
+   * Get the {@link OperatorSpec} associated with this {@link MessageStreamImpl}.
+   * @return the {@link OperatorSpec} associated with this {@link MessageStreamImpl}.
+   */
+  protected OperatorSpec<?, M> getOperatorSpec() {
+    return this.operatorSpec;
+  }
+
 }
diff --git a/samza-core/src/main/java/org/apache/samza/operators/OperatorSpecGraph.java b/samza-core/src/main/java/org/apache/samza/operators/OperatorSpecGraph.java
new file mode 100644 (file)
index 0000000..ba51c7c
--- /dev/null
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.operators.spec.InputOperatorSpec;
+import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.operators.spec.OutputStreamImpl;
+import org.apache.samza.serializers.SerializableSerde;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.table.TableSpec;
+
+
+/**
+ * Defines the serialized format of {@link StreamGraphSpec}. This class encapsulates all getter methods to get the {@link OperatorSpec}
+ * initialized in the {@link StreamGraphSpec} and constructsthe corresponding serialized instances of {@link OperatorSpec}.
+ * The {@link StreamGraphSpec} and {@link OperatorSpec} instances included in this class are considered as immutable and read-only.
+ * The instance of {@link OperatorSpecGraph} should only be used in runtime to construct {@link org.apache.samza.task.StreamOperatorTask}.
+ */
+public class OperatorSpecGraph implements Serializable {
+  // We use a LHM for deterministic order in initializing and closing operators.
+  private final Map<StreamSpec, InputOperatorSpec> inputOperators;
+  private final Map<StreamSpec, OutputStreamImpl> outputStreams;
+  private final Map<TableSpec, TableImpl> tables;
+  private final Set<OperatorSpec> allOpSpecs;
+  private final boolean hasWindowOrJoins;
+
+  // The following objects are transient since they are recreateable.
+  private transient final SerializableSerde<OperatorSpecGraph> opSpecGraphSerde = new SerializableSerde<>();
+  private transient final byte[] serializedOpSpecGraph;
+
+  OperatorSpecGraph(StreamGraphSpec graphSpec) {
+    this.inputOperators = graphSpec.getInputOperators();
+    this.outputStreams = graphSpec.getOutputStreams();
+    this.tables = graphSpec.getTables();
+    this.allOpSpecs = Collections.unmodifiableSet(this.findAllOperatorSpecs());
+    this.hasWindowOrJoins = checkWindowOrJoins();
+    this.serializedOpSpecGraph = opSpecGraphSerde.toBytes(this);
+  }
+
+  public Map<StreamSpec, InputOperatorSpec> getInputOperators() {
+    return inputOperators;
+  }
+
+  public Map<StreamSpec, OutputStreamImpl> getOutputStreams() {
+    return outputStreams;
+  }
+
+  public Map<TableSpec, TableImpl> getTables() {
+    return tables;
+  }
+
+  /**
+   * Get all {@link OperatorSpec}s available in this {@link StreamGraphSpec}
+   *
+   * @return all available {@link OperatorSpec}s
+   */
+  public Collection<OperatorSpec> getAllOperatorSpecs() {
+    return allOpSpecs;
+  }
+
+  /**
+   * Returns <tt>true</tt> iff this {@link StreamGraphSpec} contains a join or a window operator
+   *
+   * @return  <tt>true</tt> iff this {@link StreamGraphSpec} contains a join or a window operator
+   */
+  public boolean hasWindowOrJoins() {
+    return hasWindowOrJoins;
+  }
+
+  /**
+   * Returns a deserialized {@link OperatorSpecGraph} as a copy from this instance of {@link OperatorSpecGraph}
+   * This is used to create per-task instance of {@link OperatorSpecGraph} when instantiating task instances.
+   *
+   * @return a copy of this {@link OperatorSpecGraph} object via deserialization
+   */
+  public OperatorSpecGraph clone() {
+    if (opSpecGraphSerde == null) {
+      throw new IllegalStateException("Cannot clone from an already deserialized OperatorSpecGraph.");
+    }
+    return opSpecGraphSerde.fromBytes(serializedOpSpecGraph);
+  }
+
+  private HashSet<OperatorSpec> findAllOperatorSpecs() {
+    Collection<InputOperatorSpec> inputOperatorSpecs = this.inputOperators.values();
+    HashSet<OperatorSpec> operatorSpecs = new HashSet<>();
+    for (InputOperatorSpec inputOperatorSpec : inputOperatorSpecs) {
+      operatorSpecs.add(inputOperatorSpec);
+      doGetOperatorSpecs(inputOperatorSpec, operatorSpecs);
+    }
+    return operatorSpecs;
+  }
+
+  private void doGetOperatorSpecs(OperatorSpec operatorSpec, Set<OperatorSpec> specs) {
+    Collection<OperatorSpec> registeredOperatorSpecs = operatorSpec.getRegisteredOperatorSpecs();
+    for (OperatorSpec registeredOperatorSpec : registeredOperatorSpecs) {
+      specs.add(registeredOperatorSpec);
+      doGetOperatorSpecs(registeredOperatorSpec, specs);
+    }
+  }
+
+  private boolean checkWindowOrJoins() {
+    Set<OperatorSpec> windowOrJoinSpecs = allOpSpecs.stream()
+        .filter(spec -> spec.getOpCode() == OperatorSpec.OpCode.WINDOW || spec.getOpCode() == OperatorSpec.OpCode.JOIN)
+        .collect(Collectors.toSet());
+
+    return windowOrJoinSpecs.size() != 0;
+  }
+
+}
  */
 package org.apache.samza.operators;
 
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Set;
 import java.util.regex.Pattern;
-import java.util.stream.Collectors;
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.operators.spec.InputOperatorSpec;
-import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec.OpCode;
 import org.apache.samza.operators.spec.OperatorSpecs;
 import org.apache.samza.operators.spec.OutputStreamImpl;
@@ -47,14 +44,17 @@ import org.apache.samza.table.TableSpec;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 
 /**
- * A {@link StreamGraph} that provides APIs for accessing {@link MessageStream}s to be used to
+ * This class defines:
+ * 1) an implementation of {@link StreamGraph} that provides APIs for accessing {@link MessageStream}s to be used to
  * create the DAG of transforms.
+ * 2) a builder that creates a serializable {@link OperatorSpecGraph} from user-defined DAG
  */
-public class StreamGraphImpl implements StreamGraph {
-  private static final Logger LOGGER = LoggerFactory.getLogger(StreamGraphImpl.class);
+public class StreamGraphSpec implements StreamGraph {
+  private static final Logger LOGGER = LoggerFactory.getLogger(StreamGraphSpec.class);
   private static final Pattern USER_DEFINED_ID_PATTERN = Pattern.compile("[\\d\\w-_.]+");
 
   // We use a LHM for deterministic order in initializing and closing operators.
@@ -64,7 +64,6 @@ public class StreamGraphImpl implements StreamGraph {
   private final ApplicationRunner runner;
   private final Config config;
 
-
   /**
    * The 0-based position of the next operator in the graph.
    * Part of the unique ID for each OperatorSpec in the graph.
@@ -75,8 +74,8 @@ public class StreamGraphImpl implements StreamGraph {
   private Serde<?> defaultSerde = new KVSerde(new NoOpSerde(), new NoOpSerde());
   private ContextManager contextManager = null;
 
-  public StreamGraphImpl(ApplicationRunner runner, Config config) {
-    // TODO: SAMZA-1118 - Move StreamSpec and ApplicationRunner out of StreamGraphImpl once Systems
+  public StreamGraphSpec(ApplicationRunner runner, Config config) {
+    // TODO: SAMZA-1118 - Move StreamSpec and ApplicationRunner out of StreamGraphSpec once Systems
     // can use streamId to send and receive messages.
     this.runner = runner;
     this.config = config;
@@ -167,66 +166,14 @@ public class StreamGraphImpl implements StreamGraph {
     return this;
   }
 
-  /**
-   * See {@link StreamGraphImpl#getIntermediateStream(String, Serde, boolean)}.
-   */
-  <M> IntermediateMessageStreamImpl<M> getIntermediateStream(String streamId, Serde<M> serde) {
-    return getIntermediateStream(streamId, serde, false);
-  }
-
-  /**
-   * Internal helper for {@link MessageStreamImpl} to add an intermediate {@link MessageStream} to the graph.
-   * An intermediate {@link MessageStream} is both an output and an input stream.
-   *
-   * @param streamId the id of the stream to be created.
-   * @param serde the {@link Serde} to use for the message in the intermediate stream. If null, the default serde
-   *              is used.
-   * @param isBroadcast whether the stream is a broadcast stream.
-   * @param <M> the type of messages in the intermediate {@link MessageStream}
-   * @return  the intermediate {@link MessageStreamImpl}
-   *
-   * TODO: once SAMZA-1566 is resolved, we should be able to pass in the StreamSpec directly.
-   */
-  <M> IntermediateMessageStreamImpl<M> getIntermediateStream(String streamId, Serde<M> serde, boolean isBroadcast) {
-    StreamSpec streamSpec = runner.getStreamSpec(streamId);
-    if (isBroadcast) {
-      streamSpec = streamSpec.copyWithBroadCast();
-    }
-
-    Preconditions.checkState(!inputOperators.containsKey(streamSpec) && !outputStreams.containsKey(streamSpec),
-        "getIntermediateStream must not be called multiple times with the same streamId: " + streamId);
-
-    if (serde == null) {
-      LOGGER.info("Using default serde for intermediate stream: " + streamId);
-      serde = (Serde<M>) defaultSerde;
-    }
-
-    boolean isKeyed = serde instanceof KVSerde;
-    KV<Serde, Serde> kvSerdes = getKVSerdes(streamId, serde);
-    InputOperatorSpec inputOperatorSpec =
-        OperatorSpecs.createInputOperatorSpec(streamSpec, kvSerdes.getKey(), kvSerdes.getValue(),
-            isKeyed, this.getNextOpId(OpCode.INPUT, null));
-    inputOperators.put(streamSpec, inputOperatorSpec);
-    outputStreams.put(streamSpec, new OutputStreamImpl(streamSpec, kvSerdes.getKey(), kvSerdes.getValue(), isKeyed));
-    return new IntermediateMessageStreamImpl<>(this, inputOperators.get(streamSpec), outputStreams.get(streamSpec));
-  }
-
-  public Map<StreamSpec, InputOperatorSpec> getInputOperators() {
-    return Collections.unmodifiableMap(inputOperators);
-  }
-
-  public Map<StreamSpec, OutputStreamImpl> getOutputStreams() {
-    return Collections.unmodifiableMap(outputStreams);
-  }
-
-  public Map<TableSpec, TableImpl> getTables() {
-    return Collections.unmodifiableMap(tables);
-  }
-
   public ContextManager getContextManager() {
     return this.contextManager;
   }
 
+  public OperatorSpecGraph getOperatorSpecGraph() {
+    return new OperatorSpecGraph(this);
+  }
+
   /**
    * Gets the unique ID for the next operator in the graph. The ID is of the following format:
    * jobName-jobId-opCode-(userDefinedId|nextOpNum);
@@ -235,7 +182,7 @@ public class StreamGraphImpl implements StreamGraph {
    * @param userDefinedId the optional user-provided name of the next operator or null
    * @return the unique ID for the next operator in the graph
    */
-  /* package private */ String getNextOpId(OpCode opCode, String userDefinedId) {
+  public String getNextOpId(OpCode opCode, String userDefinedId) {
     if (StringUtils.isNotBlank(userDefinedId) && !USER_DEFINED_ID_PATTERN.matcher(userDefinedId).matches()) {
       throw new SamzaException("Operator ID must not contain spaces and special characters: " + userDefinedId);
     }
@@ -260,47 +207,71 @@ public class StreamGraphImpl implements StreamGraph {
    * @param opCode the {@link OpCode} of the next operator
    * @return the unique ID for the next operator in the graph
    */
-  /* package private */ String getNextOpId(OpCode opCode) {
+  public String getNextOpId(OpCode opCode) {
     return getNextOpId(opCode, null);
   }
 
   /**
-   * Get all {@link OperatorSpec}s available in this {@link StreamGraphImpl}
+   * See {@link StreamGraphSpec#getIntermediateStream(String, Serde, boolean)}.
    *
-   * @return  all available {@link OperatorSpec}s
+   * @param <M> type of messages in the intermediate stream
+   * @param streamId the id of the stream to be created
+   * @param serde the {@link Serde} to use for messages in the intermediate stream. If null, the default serde is used.
+   * @return  the intermediate {@link MessageStreamImpl}
    */
-  public Collection<OperatorSpec> getAllOperatorSpecs() {
-    Collection<InputOperatorSpec> inputOperatorSpecs = inputOperators.values();
-    Set<OperatorSpec> operatorSpecs = new HashSet<>();
-    for (InputOperatorSpec inputOperatorSpec: inputOperatorSpecs) {
-      operatorSpecs.add(inputOperatorSpec);
-      doGetOperatorSpecs(inputOperatorSpec, operatorSpecs);
-    }
-    return operatorSpecs;
-  }
-
-  private void doGetOperatorSpecs(OperatorSpec operatorSpec, Set<OperatorSpec> specs) {
-    Collection<OperatorSpec> registeredOperatorSpecs = operatorSpec.getRegisteredOperatorSpecs();
-    for (OperatorSpec registeredOperatorSpec: registeredOperatorSpecs) {
-      specs.add(registeredOperatorSpec);
-      doGetOperatorSpecs(registeredOperatorSpec, specs);
-    }
+  @VisibleForTesting
+  public <M> IntermediateMessageStreamImpl<M> getIntermediateStream(String streamId, Serde<M> serde) {
+    return getIntermediateStream(streamId, serde, false);
   }
 
   /**
-   * Returns <tt>true</tt> iff this {@link StreamGraphImpl} contains a join or a window operator
+   * Internal helper for {@link MessageStreamImpl} to add an intermediate {@link MessageStream} to the graph.
+   * An intermediate {@link MessageStream} is both an output and an input stream.
    *
-   * @return  <tt>true</tt> iff this {@link StreamGraphImpl} contains a join or a window operator
+   * @param streamId the id of the stream to be created.
+   * @param serde the {@link Serde} to use for the message in the intermediate stream. If null, the default serde
+   *              is used.
+   * @param isBroadcast whether the stream is a broadcast stream.
+   * @param <M> the type of messages in the intermediate {@link MessageStream}
+   * @return  the intermediate {@link MessageStreamImpl}
+   *
+   * TODO: once SAMZA-1566 is resolved, we should be able to pass in the StreamSpec directly.
    */
-  public boolean hasWindowOrJoins() {
-    // Obtain the operator specs from the streamGraph
-    Collection<OperatorSpec> operatorSpecs = getAllOperatorSpecs();
+  @VisibleForTesting
+  <M> IntermediateMessageStreamImpl<M> getIntermediateStream(String streamId, Serde<M> serde, boolean isBroadcast) {
+    StreamSpec streamSpec = runner.getStreamSpec(streamId);
+    if (isBroadcast) {
+      streamSpec = streamSpec.copyWithBroadCast();
+    }
 
-    Set<OperatorSpec> windowOrJoinSpecs = operatorSpecs.stream()
-        .filter(spec -> spec.getOpCode() == OperatorSpec.OpCode.WINDOW || spec.getOpCode() == OperatorSpec.OpCode.JOIN)
-        .collect(Collectors.toSet());
+    Preconditions.checkState(!inputOperators.containsKey(streamSpec) && !outputStreams.containsKey(streamSpec),
+        "getIntermediateStream must not be called multiple times with the same streamId: " + streamId);
+
+    if (serde == null) {
+      LOGGER.info("Using default serde for intermediate stream: " + streamId);
+      serde = (Serde<M>) defaultSerde;
+    }
+
+    boolean isKeyed = serde instanceof KVSerde;
+    KV<Serde, Serde> kvSerdes = getKVSerdes(streamId, serde);
+    InputOperatorSpec inputOperatorSpec =
+        OperatorSpecs.createInputOperatorSpec(streamSpec, kvSerdes.getKey(), kvSerdes.getValue(),
+            isKeyed, this.getNextOpId(OpCode.INPUT, null));
+    inputOperators.put(streamSpec, inputOperatorSpec);
+    outputStreams.put(streamSpec, new OutputStreamImpl(streamSpec, kvSerdes.getKey(), kvSerdes.getValue(), isKeyed));
+    return new IntermediateMessageStreamImpl<>(this, inputOperators.get(streamSpec), outputStreams.get(streamSpec));
+  }
+
+  Map<StreamSpec, InputOperatorSpec> getInputOperators() {
+    return Collections.unmodifiableMap(inputOperators);
+  }
+
+  Map<StreamSpec, OutputStreamImpl> getOutputStreams() {
+    return Collections.unmodifiableMap(outputStreams);
+  }
 
-    return windowOrJoinSpecs.size() != 0;
+  Map<TableSpec, TableImpl> getTables() {
+    return Collections.unmodifiableMap(tables);
   }
 
   private KV<Serde, Serde> getKVSerdes(String streamId, Serde serde) {
index e671534..8ceada0 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators;
 
+import java.io.Serializable;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.TableSpec;
 
@@ -25,7 +26,7 @@ import org.apache.samza.table.TableSpec;
 /**
  * This class is the holder of a {@link TableSpec}
  */
-public class TableImpl implements Table {
+public class TableImpl implements Table, Serializable {
 
   private final TableSpec tableSpec;
 
index 269e7bc..8df670e 100644 (file)
@@ -42,7 +42,7 @@ class BroadcastOperatorImpl<M> extends OperatorImpl<M, Void> {
 
   BroadcastOperatorImpl(BroadcastOperatorSpec<M> broadcastOpSpec, TaskContext context) {
     this.broadcastOpSpec = broadcastOpSpec;
-    this.systemStream = broadcastOpSpec.getOutputStream().getStreamSpec().toSystemStream();
+    this.systemStream = broadcastOpSpec.getOutputStream().getSystemStream();
     this.taskName = context.getTaskName().getTaskName();
   }
 
index 608b2be..f0c0997 100644 (file)
@@ -196,7 +196,7 @@ public abstract class OperatorImpl<M, RM> {
 
     results.forEach(rm ->
         this.registeredOperators.forEach(op ->
-            op.onMessage(rm, collector, coordinator)));    
+            op.onMessage(rm, collector, coordinator)));
 
     WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn();
     if (watermarkFn != null) {
index bbc8783..0f51798 100644 (file)
@@ -25,7 +25,6 @@ import org.apache.samza.config.Config;
 import org.apache.samza.container.TaskContextImpl;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.operators.KV;
-import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.TimerRegistry;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.functions.PartialJoinFunction;
@@ -34,6 +33,7 @@ import org.apache.samza.operators.spec.BroadcastOperatorSpec;
 import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.operators.spec.OutputOperatorSpec;
 import org.apache.samza.operators.spec.PartitionByOperatorSpec;
 import org.apache.samza.operators.spec.SendToTableOperatorSpec;
@@ -81,26 +81,26 @@ public class OperatorImplGraph {
    * the two {@link PartialJoinOperatorImpl}s for a {@link JoinOperatorSpec} with each other since they're
    * reached from different {@link OperatorSpec} during DAG traversals.
    */
-  private final Map<String, KV<PartialJoinFunction, PartialJoinFunction>> joinFunctions = new HashMap<>();
+  private final Map<String, KV<PartialJoinOperatorImpl, PartialJoinOperatorImpl>> joinOpImpls = new HashMap<>();
 
   private final Clock clock;
 
   /**
    * Constructs the DAG of {@link OperatorImpl}s corresponding to the the DAG of {@link OperatorSpec}s
-   * in the {@code streamGraph}.
+   * in the {@code specGraph}.
    *
-   * @param streamGraph  the {@link StreamGraphImpl} containing the logical {@link OperatorSpec} DAG
+   * @param specGraph  the {@link OperatorSpecGraph} containing the logical {@link OperatorSpec} DAG
    * @param config  the {@link Config} required to instantiate operators
    * @param context  the {@link TaskContext} required to instantiate operators
    * @param clock  the {@link Clock} to get current time
    */
-  public OperatorImplGraph(StreamGraphImpl streamGraph, Config config, TaskContext context, Clock clock) {
+  public OperatorImplGraph(OperatorSpecGraph specGraph, Config config, TaskContext context, Clock clock) {
     this.clock = clock;
 
     TaskContextImpl taskContext = (TaskContextImpl) context;
-    Map<SystemStream, Integer> producerTaskCounts = hasIntermediateStreams(streamGraph) ?
+    Map<SystemStream, Integer> producerTaskCounts = hasIntermediateStreams(specGraph) ?
         getProducerTaskCountForIntermediateStreams(getStreamToConsumerTasks(taskContext.getJobModel()),
-            getIntermediateToInputStreamsMap(streamGraph)) :
+            getIntermediateToInputStreamsMap(specGraph)) :
         Collections.EMPTY_MAP;
     producerTaskCounts.forEach((stream, count) -> {
         LOG.info("{} has {} producer tasks.", stream, count);
@@ -113,7 +113,7 @@ public class OperatorImplGraph {
     taskContext.registerObject(WatermarkStates.class.getName(),
         new WatermarkStates(context.getSystemStreamPartitions(), producerTaskCounts));
 
-    streamGraph.getInputOperators().forEach((streamSpec, inputOpSpec) -> {
+    specGraph.getInputOperators().forEach((streamSpec, inputOpSpec) -> {
         SystemStream systemStream = new SystemStream(streamSpec.getSystemName(), streamSpec.getPhysicalName());
         InputOperatorImpl inputOperatorImpl =
             (InputOperatorImpl) createAndRegisterOperatorImpl(null, inputOpSpec, systemStream, config, context);
@@ -151,12 +151,13 @@ public class OperatorImplGraph {
    * creates the corresponding DAG of {@link OperatorImpl}s, and returns the root {@link OperatorImpl} node.
    *
    * @param prevOperatorSpec  the parent of the current {@code operatorSpec} in the traversal
-   * @param operatorSpec  the operatorSpec to create the {@link OperatorImpl} for
+   * @param operatorSpec  the {@link OperatorSpec} to create the {@link OperatorImpl} for
+   * @param inputStream  the source input stream that we traverse the {@link OperatorSpecGraph} from
    * @param config  the {@link Config} required to instantiate operators
    * @param context  the {@link TaskContext} required to instantiate operators
    * @return  the operator implementation for the operatorSpec
    */
-  OperatorImpl createAndRegisterOperatorImpl(OperatorSpec prevOperatorSpec, OperatorSpec operatorSpec,
+  private OperatorImpl createAndRegisterOperatorImpl(OperatorSpec prevOperatorSpec, OperatorSpec operatorSpec,
       SystemStream inputStream, Config config, TaskContext context) {
 
     if (!operatorImpls.containsKey(operatorSpec.getOpId()) || operatorSpec instanceof JoinOperatorSpec) {
@@ -178,7 +179,9 @@ public class OperatorImplGraph {
 
       Collection<OperatorSpec> registeredSpecs = operatorSpec.getRegisteredOperatorSpecs();
       registeredSpecs.forEach(registeredSpec -> {
-          OperatorImpl nextImpl = createAndRegisterOperatorImpl(operatorSpec, registeredSpec, inputStream, config, context);
+          LOG.debug("Creating operator {} with opCode: {}", registeredSpec.getOpId(), registeredSpec.getOpCode());
+          OperatorImpl nextImpl =
+              createAndRegisterOperatorImpl(operatorSpec, registeredSpec, inputStream, config, context);
           operatorImpl.registerNextOperator(nextImpl);
         });
       return operatorImpl;
@@ -199,7 +202,8 @@ public class OperatorImplGraph {
   /**
    * Creates a new {@link OperatorImpl} instance for the provided {@link OperatorSpec}.
    *
-   * @param operatorSpec  the immutable {@link OperatorSpec} definition.
+   * @param prevOperatorSpec the original {@link OperatorSpec} that produces output for {@code operatorSpec} from {@link OperatorSpecGraph}
+   * @param operatorSpec  the original {@link OperatorSpec} from {@link OperatorSpecGraph}
    * @param config  the {@link Config} required to instantiate operators
    * @param context  the {@link TaskContext} required to instantiate operators
    * @return  the {@link OperatorImpl} implementation instance
@@ -209,17 +213,19 @@ public class OperatorImplGraph {
     if (operatorSpec instanceof InputOperatorSpec) {
       return new InputOperatorImpl((InputOperatorSpec) operatorSpec);
     } else if (operatorSpec instanceof StreamOperatorSpec) {
-      return new StreamOperatorImpl((StreamOperatorSpec) operatorSpec, config, context);
+      return new StreamOperatorImpl((StreamOperatorSpec) operatorSpec);
     } else if (operatorSpec instanceof SinkOperatorSpec) {
       return new SinkOperatorImpl((SinkOperatorSpec) operatorSpec, config, context);
     } else if (operatorSpec instanceof OutputOperatorSpec) {
-      return new OutputOperatorImpl((OutputOperatorSpec) operatorSpec, config, context);
+      return new OutputOperatorImpl((OutputOperatorSpec) operatorSpec);
     } else if (operatorSpec instanceof PartitionByOperatorSpec) {
       return new PartitionByOperatorImpl((PartitionByOperatorSpec) operatorSpec, config, context);
     } else if (operatorSpec instanceof WindowOperatorSpec) {
       return new WindowOperatorImpl((WindowOperatorSpec) operatorSpec, clock);
     } else if (operatorSpec instanceof JoinOperatorSpec) {
-      return createPartialJoinOperatorImpl(prevOperatorSpec, (JoinOperatorSpec) operatorSpec, config, context, clock);
+      return getOrCreatePartialJoinOpImpls((JoinOperatorSpec) operatorSpec,
+          prevOperatorSpec.equals(((JoinOperatorSpec) operatorSpec).getLeftInputOpSpec()),
+          config, context, clock);
     } else if (operatorSpec instanceof StreamTableJoinOperatorSpec) {
       return new StreamTableJoinOperatorImpl((StreamTableJoinOperatorSpec) operatorSpec, config, context);
     } else if (operatorSpec instanceof SendToTableOperatorSpec) {
@@ -231,23 +237,24 @@ public class OperatorImplGraph {
         String.format("Unsupported OperatorSpec: %s", operatorSpec.getClass().getName()));
   }
 
-  private PartialJoinOperatorImpl createPartialJoinOperatorImpl(OperatorSpec prevOperatorSpec,
-      JoinOperatorSpec joinOpSpec, Config config, TaskContext context, Clock clock) {
-    KV<PartialJoinFunction, PartialJoinFunction> partialJoinFunctions = getOrCreatePartialJoinFunctions(joinOpSpec);
-    if (joinOpSpec.getLeftInputOpSpec().equals(prevOperatorSpec)) { // we got here from the left side of the join
-      return new PartialJoinOperatorImpl(joinOpSpec, /* isLeftSide */ true,
-          partialJoinFunctions.getKey(), partialJoinFunctions.getValue(), config, context, clock);
+  private PartialJoinOperatorImpl getOrCreatePartialJoinOpImpls(JoinOperatorSpec joinOpSpec, boolean isLeft,
+      Config config, TaskContext context, Clock clock) {
+    // get the per task pair of PartialJoinOperatorImpl for the corresponding {@code joinOpSpec}
+    KV<PartialJoinOperatorImpl, PartialJoinOperatorImpl> partialJoinOpImpls = joinOpImpls.computeIfAbsent(joinOpSpec.getOpId(),
+        joinOpId -> {
+        PartialJoinFunction leftJoinFn = createLeftJoinFn(joinOpSpec);
+        PartialJoinFunction rightJoinFn = createRightJoinFn(joinOpSpec);
+        return new KV(new PartialJoinOperatorImpl(joinOpSpec, true, leftJoinFn, rightJoinFn, config, context, clock),
+            new PartialJoinOperatorImpl(joinOpSpec, false, rightJoinFn, leftJoinFn, config, context, clock));
+      });
+
+    if (isLeft) { // we got here from the left side of the join
+      return partialJoinOpImpls.getKey();
     } else { // we got here from the right side of the join
-      return new PartialJoinOperatorImpl(joinOpSpec, /* isLeftSide */ false,
-          partialJoinFunctions.getValue(), partialJoinFunctions.getKey(), config, context, clock);
+      return partialJoinOpImpls.getValue();
     }
   }
 
-  private KV<PartialJoinFunction, PartialJoinFunction> getOrCreatePartialJoinFunctions(JoinOperatorSpec joinOpSpec) {
-    return joinFunctions.computeIfAbsent(joinOpSpec.getOpId(),
-        joinOpId -> KV.of(createLeftJoinFn(joinOpSpec), createRightJoinFn(joinOpSpec)));
-  }
-
   private PartialJoinFunction<Object, Object, Object, Object> createLeftJoinFn(JoinOperatorSpec joinOpSpec) {
     return new PartialJoinFunction<Object, Object, Object, Object>() {
       private final JoinFunction joinFn = joinOpSpec.getJoinFn();
@@ -316,8 +323,8 @@ public class OperatorImplGraph {
     };
   }
 
-  private boolean hasIntermediateStreams(StreamGraphImpl streamGraph) {
-    return !Collections.disjoint(streamGraph.getInputOperators().keySet(), streamGraph.getOutputStreams().keySet());
+  private boolean hasIntermediateStreams(OperatorSpecGraph specGraph) {
+    return !Collections.disjoint(specGraph.getInputOperators().keySet(), specGraph.getOutputStreams().keySet());
   }
 
   /**
@@ -358,12 +365,12 @@ public class OperatorImplGraph {
 
   /**
    * calculate the mapping from output streams to input streams
-   * @param streamGraph the user {@link StreamGraphImpl} instance
+   * @param specGraph the user {@link OperatorSpecGraph}
    * @return mapping from output streams to input streams
    */
-  static Multimap<SystemStream, SystemStream> getIntermediateToInputStreamsMap(StreamGraphImpl streamGraph) {
+  static Multimap<SystemStream, SystemStream> getIntermediateToInputStreamsMap(OperatorSpecGraph specGraph) {
     Multimap<SystemStream, SystemStream> outputToInputStreams = HashMultimap.create();
-    streamGraph.getInputOperators().entrySet().stream()
+    specGraph.getInputOperators().entrySet().stream()
         .forEach(
             entry -> computeOutputToInput(entry.getKey().toSystemStream(), entry.getValue(), outputToInputStreams));
     return outputToInputStreams;
index 27bef87..e625484 100644 (file)
@@ -42,11 +42,10 @@ class OutputOperatorImpl<M> extends OperatorImpl<M, Void> {
   private final OutputStreamImpl<M> outputStream;
   private final SystemStream systemStream;
 
-  OutputOperatorImpl(OutputOperatorSpec<M> outputOpSpec, Config config, TaskContext context) {
+  OutputOperatorImpl(OutputOperatorSpec<M> outputOpSpec) {
     this.outputOpSpec = outputOpSpec;
     this.outputStream = outputOpSpec.getOutputStream();
-    this.systemStream = new SystemStream(outputStream.getStreamSpec().getSystemName(),
-        outputStream.getStreamSpec().getPhysicalName());
+    this.systemStream = outputStream.getSystemStream();
   }
 
   @Override
index 9fc1e7c..dd64429 100644 (file)
@@ -21,6 +21,7 @@ package org.apache.samza.operators.impl;
 import org.apache.samza.config.Config;
 import org.apache.samza.container.TaskContextImpl;
 import org.apache.samza.operators.KV;
+import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.OutputStreamImpl;
 import org.apache.samza.operators.spec.PartitionByOperatorSpec;
@@ -36,7 +37,6 @@ import org.apache.samza.task.TaskCoordinator;
 
 import java.util.Collection;
 import java.util.Collections;
-import java.util.function.Function;
 
 
 /**
@@ -46,17 +46,15 @@ class PartitionByOperatorImpl<M, K, V> extends OperatorImpl<M, Void> {
 
   private final PartitionByOperatorSpec<M, K, V> partitionByOpSpec;
   private final SystemStream systemStream;
-  private final Function<? super M, ? extends K> keyFunction;
-  private final Function<? super M, ? extends V> valueFunction;
+  private final MapFunction<? super M, ? extends K> keyFunction;
+  private final MapFunction<? super M, ? extends V> valueFunction;
   private final String taskName;
   private final ControlMessageSender controlMessageSender;
 
   PartitionByOperatorImpl(PartitionByOperatorSpec<M, K, V> partitionByOpSpec, Config config, TaskContext context) {
     this.partitionByOpSpec = partitionByOpSpec;
     OutputStreamImpl<KV<K, V>> outputStream = partitionByOpSpec.getOutputStream();
-    this.systemStream = new SystemStream(
-        outputStream.getStreamSpec().getSystemName(),
-        outputStream.getStreamSpec().getPhysicalName());
+    this.systemStream = outputStream.getSystemStream();
     this.keyFunction = partitionByOpSpec.getKeyFunction();
     this.valueFunction = partitionByOpSpec.getValueFunction();
     this.taskName = context.getTaskName().getTaskName();
@@ -66,6 +64,8 @@ class PartitionByOperatorImpl<M, K, V> extends OperatorImpl<M, Void> {
 
   @Override
   protected void handleInit(Config config, TaskContext context) {
+    this.keyFunction.init(config, context);
+    this.valueFunction.init(config, context);
   }
 
   @Override
@@ -80,6 +80,8 @@ class PartitionByOperatorImpl<M, K, V> extends OperatorImpl<M, Void> {
 
   @Override
   protected void handleClose() {
+    this.keyFunction.close();
+    this.valueFunction.close();
   }
 
   @Override
@@ -100,7 +102,7 @@ class PartitionByOperatorImpl<M, K, V> extends OperatorImpl<M, Void> {
   }
 
   private void sendControlMessage(ControlMessage message, MessageCollector collector) {
-    SystemStream outputStream = partitionByOpSpec.getOutputStream().getStreamSpec().toSystemStream();
+    SystemStream outputStream = partitionByOpSpec.getOutputStream().getSystemStream();
     controlMessageSender.send(message, outputStream, collector);
   }
 }
index a51d5e6..6cd426b 100644 (file)
@@ -40,8 +40,7 @@ class StreamOperatorImpl<M, RM> extends OperatorImpl<M, RM> {
   private final StreamOperatorSpec<M, RM> streamOpSpec;
   private final FlatMapFunction<M, RM> transformFn;
 
-  StreamOperatorImpl(StreamOperatorSpec<M, RM> streamOpSpec,
-      Config config, TaskContext context) {
+  StreamOperatorImpl(StreamOperatorSpec<M, RM> streamOpSpec) {
     this.streamOpSpec = streamOpSpec;
     this.transformFn = streamOpSpec.getTransformFn();
   }
index 32406cb..6b5baae 100644 (file)
@@ -23,6 +23,8 @@ package org.apache.samza.operators.impl;
 import com.google.common.base.Preconditions;
 import org.apache.samza.config.Config;
 import org.apache.samza.operators.functions.FoldLeftFunction;
+import org.apache.samza.operators.functions.MapFunction;
+import org.apache.samza.operators.functions.SupplierFunction;
 import org.apache.samza.operators.impl.store.TimeSeriesKey;
 import org.apache.samza.operators.impl.store.TimeSeriesStore;
 import org.apache.samza.operators.impl.store.TimeSeriesStoreImpl;
@@ -58,8 +60,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
-import java.util.function.Function;
-import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 /**
@@ -93,8 +93,8 @@ public class WindowOperatorImpl<M, K> extends OperatorImpl<M, WindowPane<K, Obje
   private final Clock clock;
   private final WindowInternal<M, K, Object> window;
   private final FoldLeftFunction<M, Object> foldLeftFn;
-  private final Supplier<Object> initializer;
-  private final Function<M, K> keyFn;
+  private final SupplierFunction<Object> initializer;
+  private final MapFunction<M, K> keyFn;
 
   private final TriggerScheduler<K> triggerScheduler;
   private final Map<TriggerKey<K>, TriggerImplHandler> triggers = new HashMap<>();
@@ -112,11 +112,18 @@ public class WindowOperatorImpl<M, K> extends OperatorImpl<M, WindowPane<K, Obje
 
   @Override
   protected void handleInit(Config config, TaskContext context) {
-    WindowInternal<M, K, Object> window = windowOpSpec.getWindow();
 
     KeyValueStore<TimeSeriesKey<K>, Object> store =
         (KeyValueStore<TimeSeriesKey<K>, Object>) context.getStore(windowOpSpec.getOpId());
 
+    if (initializer != null) {
+      initializer.init(config, context);
+    }
+
+    if (keyFn != null) {
+      keyFn.init(config, context);
+    }
+
     // For aggregating windows, we use the store in over-write mode since we only retain the aggregated
     // value. Else, we use the store in append-mode.
     if (foldLeftFn != null) {
@@ -215,6 +222,12 @@ public class WindowOperatorImpl<M, K> extends OperatorImpl<M, WindowPane<K, Obje
     if (timeSeriesStore != null) {
       timeSeriesStore.close();
     }
+    if (initializer != null) {
+      initializer.close();
+    }
+    if (keyFn != null) {
+      keyFn.close();
+    }
   }
 
   private TriggerImplHandler getOrCreateTriggerImplHandler(TriggerKey<K> triggerKey, Trigger<M> trigger) {
diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/FilterOperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/FilterOperatorSpec.java
new file mode 100644 (file)
index 0000000..a5cdb82
--- /dev/null
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators.spec;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import org.apache.samza.config.Config;
+import org.apache.samza.operators.functions.FilterFunction;
+import org.apache.samza.operators.functions.FlatMapFunction;
+import org.apache.samza.operators.functions.TimerFunction;
+import org.apache.samza.operators.functions.WatermarkFunction;
+import org.apache.samza.task.TaskContext;
+
+
+/**
+ * The spec for an operator that filters input messages based on some conditions.
+ *
+ * @param <M> type of input message
+ */
+class FilterOperatorSpec<M> extends StreamOperatorSpec<M, M> {
+  private final FilterFunction<M> filterFn;
+
+  FilterOperatorSpec(FilterFunction<M> filterFn, String opId) {
+    super(new FlatMapFunction<M, M>() {
+      @Override
+      public Collection<M> apply(M message) {
+        return new ArrayList<M>() {
+          {
+            if (filterFn.apply(message)) {
+              this.add(message);
+            }
+          }
+        };
+      }
+
+      @Override
+      public void init(Config config, TaskContext context) {
+        filterFn.init(config, context);
+      }
+
+      @Override
+      public void close() {
+        filterFn.close();
+      }
+    }, OpCode.FILTER, opId);
+    this.filterFn = filterFn;
+  }
+
+  @Override
+  public WatermarkFunction getWatermarkFn() {
+    return this.filterFn instanceof WatermarkFunction ? (WatermarkFunction) this.filterFn : null;
+  }
+
+  @Override
+  public TimerFunction getTimerFn() {
+    return this.filterFn instanceof TimerFunction ? (TimerFunction) this.filterFn : null;
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/FlatMapOperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/FlatMapOperatorSpec.java
new file mode 100644 (file)
index 0000000..a93a221
--- /dev/null
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators.spec;
+
+import org.apache.samza.operators.functions.FlatMapFunction;
+import org.apache.samza.operators.functions.TimerFunction;
+import org.apache.samza.operators.functions.WatermarkFunction;
+
+
+/**
+ * The spec for an operator that transforms each input message to a collection of output messages.
+ *
+ * @param <M> type of input message
+ * @param <OM> type of output messages
+ */
+class FlatMapOperatorSpec<M, OM> extends StreamOperatorSpec<M, OM> {
+
+  FlatMapOperatorSpec(FlatMapFunction<M, OM> flatMapFn, String opId) {
+    super(flatMapFn, OpCode.FLAT_MAP, opId);
+  }
+
+  @Override
+  public WatermarkFunction getWatermarkFn() {
+    return this.transformFn instanceof WatermarkFunction ? (WatermarkFunction) this.transformFn : null;
+  }
+
+  @Override
+  public TimerFunction getTimerFn() {
+    return this.transformFn instanceof TimerFunction ? (TimerFunction) this.transformFn : null;
+  }
+}
index 2ed1e30..a636ac5 100644 (file)
@@ -20,8 +20,8 @@ package org.apache.samza.operators.spec;
 
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.functions.TimerFunction;
-import org.apache.samza.serializers.Serde;
 import org.apache.samza.operators.functions.WatermarkFunction;
+import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.StreamSpec;
 
 /**
@@ -33,10 +33,15 @@ import org.apache.samza.system.StreamSpec;
  */
 public class InputOperatorSpec<K, V> extends OperatorSpec<KV<K, V>, Object> { // Object == KV<K, V> | V
 
-  private final StreamSpec streamSpec;
-  private final Serde<K> keySerde;
-  private final Serde<V> valueSerde;
   private final boolean isKeyed;
+  private final StreamSpec streamSpec;
+
+  /**
+   * The following {@link Serde}s are serialized by the ExecutionPlanner when generating the configs for a stream, and deserialized
+   * once during startup in SamzaContainer. They don't need to be deserialized here on a per-task basis
+   */
+  private transient final Serde<K> keySerde;
+  private transient final Serde<V> valueSerde;
 
   public InputOperatorSpec(StreamSpec streamSpec,
       Serde<K> keySerde, Serde<V> valueSerde, boolean isKeyed, String opId) {
index 9e058ff..a218135 100644 (file)
@@ -42,14 +42,20 @@ import java.util.Map;
  */
 public class JoinOperatorSpec<K, M, OM, JM> extends OperatorSpec<Object, JM> implements StatefulOperatorSpec { // Object == M | OM
 
-  private final OperatorSpec<?, M> leftInputOpSpec;
-  private final OperatorSpec<?, OM> rightInputOpSpec;
   private final JoinFunction<K, M, OM, JM> joinFn;
-  private final Serde<K> keySerde;
-  private final Serde<TimestampedValue<M>> messageSerde;
-  private final Serde<TimestampedValue<OM>> otherMessageSerde;
   private final long ttlMs;
 
+  private final OperatorSpec<?, M> leftInputOpSpec;
+  private final OperatorSpec<?, OM> rightInputOpSpec;
+
+  /**
+   * The following {@link Serde}s are serialized by the ExecutionPlanner when generating the store configs for a join, and
+   * deserialized once during startup in SamzaContainer. They don't need to be deserialized here on a per-task basis
+   */
+  private transient final Serde<K> keySerde;
+  private transient final Serde<TimestampedValue<M>> messageSerde;
+  private transient final Serde<TimestampedValue<OM>> otherMessageSerde;
+
   /**
    * Default constructor for a {@link JoinOperatorSpec}.
    *
@@ -126,4 +132,5 @@ public class JoinOperatorSpec<K, M, OM, JM> extends OperatorSpec<Object, JM> imp
   public long getTtlMs() {
     return ttlMs;
   }
+
 }
diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/MapOperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/MapOperatorSpec.java
new file mode 100644 (file)
index 0000000..1e2190b
--- /dev/null
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators.spec;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import org.apache.samza.config.Config;
+import org.apache.samza.operators.functions.FlatMapFunction;
+import org.apache.samza.operators.functions.MapFunction;
+import org.apache.samza.operators.functions.TimerFunction;
+import org.apache.samza.operators.functions.WatermarkFunction;
+import org.apache.samza.task.TaskContext;
+
+
+/**
+ * The spec for an operator that transforms each input message to a single output message.
+ *
+ * @param <M> type of input message
+ * @param <OM> type of output messages
+ */
+class MapOperatorSpec<M, OM> extends StreamOperatorSpec<M, OM> {
+
+  private final MapFunction<M, OM> mapFn;
+
+  MapOperatorSpec(MapFunction<M, OM> mapFn, String opId) {
+    super(new FlatMapFunction<M, OM>() {
+      @Override
+      public Collection<OM> apply(M message) {
+        return new ArrayList<OM>() {
+          {
+            OM r = mapFn.apply(message);
+            if (r != null) {
+              this.add(r);
+            }
+          }
+        };
+      }
+
+      @Override
+      public void init(Config config, TaskContext context) {
+        mapFn.init(config, context);
+      }
+
+      @Override
+      public void close() {
+        mapFn.close();
+      }
+    }, OpCode.MAP, opId);
+    this.mapFn = mapFn;
+  }
+
+  @Override
+  public WatermarkFunction getWatermarkFn() {
+    return this.mapFn instanceof WatermarkFunction ? (WatermarkFunction) this.mapFn : null;
+  }
+
+  @Override
+  public TimerFunction getTimerFn() {
+    return this.mapFn instanceof TimerFunction ? (TimerFunction) this.mapFn : null;
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/MergeOperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/MergeOperatorSpec.java
new file mode 100644 (file)
index 0000000..987f72c
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators.spec;
+
+import java.util.ArrayList;
+import org.apache.samza.operators.functions.TimerFunction;
+import org.apache.samza.operators.functions.WatermarkFunction;
+
+
+/**
+ * The spec for an operator that combines messages from all input streams into a single output stream.
+ *
+ * @param <M> the type of messages in all input streams
+ */
+class MergeOperatorSpec<M> extends StreamOperatorSpec<M, M> {
+
+  MergeOperatorSpec(String opId) {
+    super((M message) ->
+        new ArrayList<M>() {
+        {
+          this.add(message);
+        }
+      }, OperatorSpec.OpCode.MERGE, opId);
+  }
+
+  @Override
+  public WatermarkFunction getWatermarkFn() {
+    return null;
+  }
+
+  @Override
+  public TimerFunction getTimerFn() {
+    return null;
+  }
+}
index 7b0a41b..e1e1c55 100644 (file)
@@ -18,9 +18,9 @@
  */
 package org.apache.samza.operators.spec;
 
+import java.io.Serializable;
 import java.util.Collection;
 import java.util.LinkedHashSet;
-import java.util.Set;
 
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.operators.MessageStream;
@@ -30,14 +30,14 @@ import org.apache.samza.operators.functions.WatermarkFunction;
 
 /**
  * A stream operator specification that holds all the information required to transform
- * the input {@link org.apache.samza.operators.MessageStreamImpl} and produce the output
- * {@link org.apache.samza.operators.MessageStreamImpl}.
+ * the input {@link MessageStreamImpl} and produce the output
+ * {@link MessageStreamImpl}.
  *
  * @param <M>  the type of input message to the operator
  * @param <OM>  the type of output message from the operator
  */
 @InterfaceStability.Unstable
-public abstract class OperatorSpec<M, OM> {
+public abstract class OperatorSpec<M, OM> implements Serializable {
 
   public enum OpCode {
     INPUT,
@@ -61,9 +61,15 @@ public abstract class OperatorSpec<M, OM> {
   /**
    * The set of operators that consume the messages produced from this operator.
    * <p>
-   * We use a LinkedHashSet since we need deterministic ordering in initializing/closing operators.
+   * We use a LinkedHashSet since we need both deterministic ordering in initializing/closing operators and serializability.
    */
-  private final Set<OperatorSpec<OM, ?>> nextOperatorSpecs = new LinkedHashSet<>();
+  private final LinkedHashSet<OperatorSpec<OM, ?>> nextOperatorSpecs = new LinkedHashSet<>();
+
+  // this method is used in unit tests to verify an {@link OperatorSpec} instance is a deserialized copy of this object.
+  final boolean isClone(OperatorSpec other) {
+    return this != other && this.getClass().isAssignableFrom(other.getClass())
+        && this.opCode.equals(other.opCode) && this.opId.equals(other.opId);
+  }
 
   public OperatorSpec(OpCode opCode, String opId) {
     this.opCode = opCode;
@@ -79,6 +85,11 @@ public abstract class OperatorSpec<M, OM> {
     nextOperatorSpecs.add(nextOperatorSpec);
   }
 
+  /**
+   * Get the collection of chained {@link OperatorSpec}s that are consuming the output of this node
+   *
+   * @return the collection of chained {@link OperatorSpec}s
+   */
   public Collection<OperatorSpec<OM, ?>> getRegisteredOperatorSpecs() {
     return nextOperatorSpecs;
   }
index c38f6e8..6e98d5a 100644 (file)
 
 package org.apache.samza.operators.spec;
 
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.function.Function;
-
-import org.apache.samza.config.Config;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.functions.FilterFunction;
 import org.apache.samza.operators.functions.FlatMapFunction;
@@ -35,7 +30,6 @@ import org.apache.samza.operators.windows.internal.WindowInternal;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.StreamSpec;
 import org.apache.samza.table.TableSpec;
-import org.apache.samza.task.TaskContext;
 
 
 /**
@@ -73,29 +67,7 @@ public class OperatorSpecs {
    */
   public static <M, OM> StreamOperatorSpec<M, OM> createMapOperatorSpec(
       MapFunction<? super M, ? extends OM> mapFn, String opId) {
-    return new StreamOperatorSpec<>(new FlatMapFunction<M, OM>() {
-      @Override
-      public Collection<OM> apply(M message) {
-        return new ArrayList<OM>() {
-          {
-            OM r = mapFn.apply(message);
-            if (r != null) {
-              this.add(r);
-            }
-          }
-        };
-      }
-
-      @Override
-      public void init(Config config, TaskContext context) {
-        mapFn.init(config, context);
-      }
-
-      @Override
-      public void close() {
-        mapFn.close();
-      }
-    }, mapFn, OperatorSpec.OpCode.MAP, opId);
+    return new MapOperatorSpec<>((MapFunction<M, OM>) mapFn, opId);
   }
 
   /**
@@ -108,28 +80,7 @@ public class OperatorSpecs {
    */
   public static <M> StreamOperatorSpec<M, M> createFilterOperatorSpec(
       FilterFunction<? super M> filterFn, String opId) {
-    return new StreamOperatorSpec<>(new FlatMapFunction<M, M>() {
-      @Override
-      public Collection<M> apply(M message) {
-        return new ArrayList<M>() {
-          {
-            if (filterFn.apply(message)) {
-              this.add(message);
-            }
-          }
-        };
-      }
-
-      @Override
-      public void init(Config config, TaskContext context) {
-        filterFn.init(config, context);
-      }
-
-      @Override
-      public void close() {
-        filterFn.close();
-      }
-    }, filterFn, OperatorSpec.OpCode.FILTER, opId);
+    return new FilterOperatorSpec<>((FilterFunction<M>) filterFn, opId);
   }
 
   /**
@@ -143,7 +94,7 @@ public class OperatorSpecs {
    */
   public static <M, OM> StreamOperatorSpec<M, OM> createFlatMapOperatorSpec(
       FlatMapFunction<? super M, ? extends OM> flatMapFn, String opId) {
-    return new StreamOperatorSpec<>((FlatMapFunction<M, OM>) flatMapFn, flatMapFn, OperatorSpec.OpCode.FLAT_MAP, opId);
+    return new FlatMapOperatorSpec<>((FlatMapFunction<M, OM>) flatMapFn, opId);
   }
 
   /**
@@ -183,8 +134,8 @@ public class OperatorSpecs {
    * @return  the {@link OutputOperatorSpec} for the partitionBy operator
    */
   public static <M, K, V> PartitionByOperatorSpec<M, K, V> createPartitionByOperatorSpec(
-      OutputStreamImpl<KV<K, V>> outputStream, Function<? super M, ? extends K> keyFunction,
-      Function<? super M, ? extends V> valueFunction, String opId) {
+      OutputStreamImpl<KV<K, V>> outputStream, MapFunction<? super M, ? extends K> keyFunction,
+      MapFunction<? super M, ? extends V> valueFunction, String opId) {
     return new PartitionByOperatorSpec<>(outputStream, keyFunction, valueFunction, opId);
   }
 
@@ -198,7 +149,6 @@ public class OperatorSpecs {
    * @param <WV>  the type of value in the window
    * @return  the {@link WindowOperatorSpec}
    */
-
   public static <M, WK, WV> WindowOperatorSpec<M, WK, WV> createWindowOperatorSpec(
       WindowInternal<M, WK, WV> window, String opId) {
     return new WindowOperatorSpec<>(window, opId);
@@ -236,13 +186,7 @@ public class OperatorSpecs {
    * @return  the {@link StreamOperatorSpec} for the merge
    */
   public static <M> StreamOperatorSpec<M, M> createMergeOperatorSpec(String opId) {
-    return new StreamOperatorSpec<>(message ->
-        new ArrayList<M>() {
-          {
-            this.add(message);
-          }
-        },
-        null, OperatorSpec.OpCode.MERGE, opId);
+    return new MergeOperatorSpec<>(opId);
   }
 
   /**
@@ -266,7 +210,6 @@ public class OperatorSpecs {
    * Creates a {@link SendToTableOperatorSpec} with a key extractor and a value extractor function,
    * the type of incoming message is expected to be KV&#60;K, V&#62;.
    *
-   * @param inputOpSpec the operator spec for the input stream
    * @param tableSpec the table spec for the underlying table
    * @param opId the unique ID of the operator
    * @param <K> the type of the table record key
@@ -274,8 +217,8 @@ public class OperatorSpecs {
    * @return the {@link SendToTableOperatorSpec}
    */
   public static <K, V> SendToTableOperatorSpec<K, V> createSendToTableOperatorSpec(
-      OperatorSpec<?, KV<K, V>> inputOpSpec, TableSpec tableSpec, String opId) {
-    return new SendToTableOperatorSpec(inputOpSpec, tableSpec, opId);
+     TableSpec tableSpec, String opId) {
+    return new SendToTableOperatorSpec(tableSpec, opId);
   }
 
   /**
index e439c4e..fe0abcb 100644 (file)
  */
 package org.apache.samza.operators.spec;
 
+import java.io.Serializable;
 import org.apache.samza.operators.OutputStream;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemStream;
 
 
-public class OutputStreamImpl<M> implements OutputStream<M> {
+public class OutputStreamImpl<M> implements OutputStream<M>, Serializable {
 
   private final StreamSpec streamSpec;
-  private final Serde keySerde;
-  private final Serde valueSerde;
   private final boolean isKeyed;
 
+  /**
+   * The following fields are serialized by the ExecutionPlanner when generating the configs for the output stream, and
+   * deserialized once during startup in SamzaContainer. They don't need to be deserialized here on a per-task basis
+   */
+  private transient final Serde keySerde;
+  private transient final Serde valueSerde;
+
   public OutputStreamImpl(StreamSpec streamSpec,
       Serde keySerde, Serde valueSerde, boolean isKeyed) {
     this.streamSpec = streamSpec;
@@ -50,6 +57,10 @@ public class OutputStreamImpl<M> implements OutputStream<M> {
     return valueSerde;
   }
 
+  public SystemStream getSystemStream() {
+    return this.streamSpec.toSystemStream();
+  }
+
   public boolean isKeyed() {
     return isKeyed;
   }
index a0a9b61..d6bf3d9 100644 (file)
 package org.apache.samza.operators.spec;
 
 import org.apache.samza.operators.KV;
+import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.functions.TimerFunction;
 import org.apache.samza.operators.functions.WatermarkFunction;
 
-import java.util.function.Function;
+import static com.google.common.base.Preconditions.checkArgument;
 
 
 /**
@@ -39,21 +40,25 @@ import java.util.function.Function;
 public class PartitionByOperatorSpec<M, K, V> extends OperatorSpec<M, Void> {
 
   private final OutputStreamImpl<KV<K, V>> outputStream;
-  private final Function<? super M, ? extends K> keyFunction;
-  private final Function<? super M, ? extends V> valueFunction;
+  private final MapFunction<? super M, ? extends K> keyFunction;
+  private final MapFunction<? super M, ? extends V> valueFunction;
 
   /**
    * Constructs an {@link PartitionByOperatorSpec} to send messages to the provided {@code outputStream}
    *
    * @param outputStream the {@link OutputStreamImpl} to send messages to
-   * @param keyFunction the {@link Function} for extracting the key from the message
-   * @param valueFunction the {@link Function} for extracting the value from the message
+   * @param keyFunction the {@link MapFunction} for extracting the key from the message
+   * @param valueFunction the {@link MapFunction} for extracting the value from the message
    * @param opId the unique ID of this {@link SinkOperatorSpec} in the graph
    */
   PartitionByOperatorSpec(OutputStreamImpl<KV<K, V>> outputStream,
-      Function<? super M, ? extends K> keyFunction,
-      Function<? super M, ? extends V> valueFunction, String opId) {
+      MapFunction<? super M, ? extends K> keyFunction,
+      MapFunction<? super M, ? extends V> valueFunction, String opId) {
     super(OpCode.PARTITION_BY, opId);
+    checkArgument(!(keyFunction instanceof TimerFunction || keyFunction instanceof WatermarkFunction),
+        "keyFunction for partitionBy should not implement TimerFunction or WatermarkFunction.");
+    checkArgument(!(valueFunction instanceof TimerFunction || valueFunction instanceof WatermarkFunction),
+        "valueFunction for partitionBy should not implement TimerFunction or WatermarkFunction.");
     this.outputStream = outputStream;
     this.keyFunction = keyFunction;
     this.valueFunction = valueFunction;
@@ -67,11 +72,11 @@ public class PartitionByOperatorSpec<M, K, V> extends OperatorSpec<M, Void> {
     return this.outputStream;
   }
 
-  public Function<? super M, ? extends K> getKeyFunction() {
+  public MapFunction<? super M, ? extends K> getKeyFunction() {
     return keyFunction;
   }
 
-  public Function<? super M, ? extends V> getValueFunction() {
+  public MapFunction<? super M, ? extends V> getValueFunction() {
     return valueFunction;
   }
 
index e1b51be..22f393e 100644 (file)
@@ -35,26 +35,19 @@ import org.apache.samza.table.TableSpec;
 @InterfaceStability.Unstable
 public class SendToTableOperatorSpec<K, V> extends OperatorSpec<KV<K, V>, Void> {
 
-  private final OperatorSpec<?, KV<K, V>> inputOpSpec;
   private final TableSpec tableSpec;
 
   /**
    * Constructor for a {@link SendToTableOperatorSpec}.
    *
-   * @param inputOpSpec  the operator spec of the input stream
    * @param tableSpec  the table spec of the table written to
    * @param opId  the unique ID for this operator
    */
-  SendToTableOperatorSpec(OperatorSpec<?, KV<K, V>> inputOpSpec, TableSpec tableSpec, String opId) {
+  SendToTableOperatorSpec(TableSpec tableSpec, String opId) {
     super(OpCode.SEND_TO, opId);
-    this.inputOpSpec = inputOpSpec;
     this.tableSpec = tableSpec;
   }
 
-  public OperatorSpec<?, KV<K, V>> getInputOpSpec() {
-    return inputOpSpec;
-  }
-
   public TableSpec getTableSpec() {
     return tableSpec;
   }
index 644eb6c..3addbf7 100644 (file)
 package org.apache.samza.operators.spec;
 
 import org.apache.samza.operators.functions.FlatMapFunction;
-import org.apache.samza.operators.functions.TimerFunction;
-import org.apache.samza.operators.functions.WatermarkFunction;
-
 
 /**
- * The spec for a simple stream operator that outputs 0 or more messages for each input message.
+ * The common spec for a simple stream operator that outputs 0 or more messages for each input message.
  *
  * @param <M>  the type of input message
  * @param <OM>  the type of output message
  */
-public class StreamOperatorSpec<M, OM> extends OperatorSpec<M, OM> {
+public abstract class StreamOperatorSpec<M, OM> extends OperatorSpec<M, OM> {
 
-  private final FlatMapFunction<M, OM> transformFn;
-  private final Object originalFn;
+  protected final FlatMapFunction<M, OM> transformFn;
 
   /**
    * Constructor for a {@link StreamOperatorSpec}.
    *
    * @param transformFn  the transformation function
-   * @param originalFn the original user function before wrapping to transformFn
    * @param opCode  the {@link OpCode} for this {@link StreamOperatorSpec}
    * @param opId  the unique ID for this {@link StreamOperatorSpec}
    */
-  StreamOperatorSpec(FlatMapFunction<M, OM> transformFn, Object originalFn, OperatorSpec.OpCode opCode, String opId) {
+  protected StreamOperatorSpec(FlatMapFunction<M, OM> transformFn, OperatorSpec.OpCode opCode, String opId) {
     super(opCode, opId);
     this.transformFn = transformFn;
-    this.originalFn = originalFn;
   }
 
   public FlatMapFunction<M, OM> getTransformFn() {
     return this.transformFn;
   }
 
-  @Override
-  public WatermarkFunction getWatermarkFn() {
-    return originalFn instanceof WatermarkFunction ? (WatermarkFunction) originalFn : null;
-  }
-
-  @Override
-  public TimerFunction getTimerFn() {
-    return originalFn instanceof TimerFunction ? (TimerFunction) originalFn : null;
-  }
 }
index 73d10ff..8d1ad29 100644 (file)
@@ -40,6 +40,8 @@ import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
 
+import static com.google.common.base.Preconditions.*;
+
 
 /**
  * The spec for an operator that groups messages into finite windows for processing
@@ -61,6 +63,15 @@ public class WindowOperatorSpec<M, WK, WV> extends OperatorSpec<M, WindowPane<WK
    */
   WindowOperatorSpec(WindowInternal<M, WK, WV> window, String opId) {
     super(OpCode.WINDOW, opId);
+    checkArgument(window.getInitializer() == null ||
+        !(window.getInitializer() instanceof TimerFunction || window.getInitializer() instanceof WatermarkFunction),
+        "A window does not accepts a user-defined TimerFunction or WatermarkFunction as the initializer.");
+    checkArgument(window.getKeyExtractor() == null ||
+        !(window.getKeyExtractor() instanceof TimerFunction || window.getKeyExtractor() instanceof WatermarkFunction),
+        "A window does not accepts a user-defined TimerFunction or WatermarkFunction as the keyExtractor.");
+    checkArgument(window.getEventTimeExtractor() == null ||
+        !(window.getEventTimeExtractor() instanceof TimerFunction || window.getEventTimeExtractor() instanceof WatermarkFunction),
+        "A window does not accepts a user-defined TimerFunction or WatermarkFunction as the eventTimeExtractor.");
     this.window = window;
   }
 
index 5eeca99..272ba63 100644 (file)
@@ -20,7 +20,7 @@ package org.apache.samza.operators.stream;
 
 import org.apache.samza.operators.MessageStreamImpl;
 import org.apache.samza.operators.OutputStream;
-import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.StreamGraphSpec;
 import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.OutputStreamImpl;
@@ -45,7 +45,7 @@ public class IntermediateMessageStreamImpl<M> extends MessageStreamImpl<M> imple
   private final OutputStreamImpl<M> outputStream;
   private final boolean isKeyed;
 
-  public IntermediateMessageStreamImpl(StreamGraphImpl graph, InputOperatorSpec<?, M> inputOperatorSpec,
+  public IntermediateMessageStreamImpl(StreamGraphSpec graph, InputOperatorSpec<?, M> inputOperatorSpec,
       OutputStreamImpl<M> outputStream) {
     super(graph, (OperatorSpec<?, M>) inputOperatorSpec);
     this.outputStream = outputStream;
index ca0ba67..96defd5 100644 (file)
@@ -30,5 +30,5 @@ public interface Cancellable {
    *
    * @return the result of the cancelation
    */
-  public boolean cancel();
+  boolean cancel();
 }
index 705cab7..b186cdb 100644 (file)
@@ -46,14 +46,14 @@ public interface TriggerImpl<M, WK> {
    * @param message the incoming message
    * @param context the {@link TriggerScheduler} to schedule and cancel callbacks
    */
-  public void onMessage(M message, TriggerScheduler<WK> context);
+  void onMessage(M message, TriggerScheduler<WK> context);
 
   /**
    * Returns {@code true} if the current state of the trigger indicates that its condition
    * is satisfied and it is ready to fire.
    * @return if this trigger should fire.
    */
-  public boolean shouldFire();
+  boolean shouldFire();
 
   /**
    * Invoked when the execution of this {@link TriggerImpl} is canceled by an up-stream {@link TriggerImpl}.
@@ -61,6 +61,6 @@ public interface TriggerImpl<M, WK> {
    * No calls to {@link #onMessage(Object, TriggerScheduler)} or {@link #shouldFire()} will be invoked
    * after this invocation.
    */
-  public void cancel();
+  void cancel();
 
 }
index 68962ce..5043977 100644 (file)
@@ -29,7 +29,8 @@ import org.apache.samza.config.StreamConfig;
 import org.apache.samza.execution.ExecutionPlan;
 import org.apache.samza.execution.ExecutionPlanner;
 import org.apache.samza.execution.StreamManager;
-import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.OperatorSpecGraph;
+import org.apache.samza.operators.StreamGraphSpec;
 import org.apache.samza.system.StreamSpec;
 import org.apache.samza.system.SystemAdmins;
 import org.slf4j.Logger;
@@ -44,7 +45,7 @@ import java.util.Set;
 
 
 /**
- * Defines common, core behavior for implementations of the {@link ApplicationRunner} API
+ * Defines common, core behavior for implementations of the {@link ApplicationRunner} API.
  */
 public abstract class AbstractApplicationRunner extends ApplicationRunner {
   private static final Logger log = LoggerFactory.getLogger(AbstractApplicationRunner.class);
@@ -52,8 +53,14 @@ public abstract class AbstractApplicationRunner extends ApplicationRunner {
   private final StreamManager streamManager;
   private final SystemAdmins systemAdmins;
 
+  /**
+   * The {@link ApplicationRunner} is supposed to run a single {@link StreamApplication} instance in the full life-cycle
+   */
+  protected final StreamGraphSpec graphSpec;
+
   public AbstractApplicationRunner(Config config) {
     super(config);
+    this.graphSpec = new StreamGraphSpec(this, config);
     this.systemAdmins = new SystemAdmins(config);
     this.streamManager = new StreamManager(systemAdmins);
   }
@@ -126,23 +133,23 @@ public abstract class AbstractApplicationRunner extends ApplicationRunner {
   /* package private */
   ExecutionPlan getExecutionPlan(StreamApplication app, String runId) throws Exception {
     // build stream graph
-    StreamGraphImpl streamGraph = new StreamGraphImpl(this, config);
-    app.init(streamGraph, config);
+    app.init(graphSpec, config);
 
+    OperatorSpecGraph specGraph = graphSpec.getOperatorSpecGraph();
     // create the physical execution plan
     Map<String, String> cfg = new HashMap<>(config);
     if (StringUtils.isNoneEmpty(runId)) {
       cfg.put(ApplicationConfig.APP_RUN_ID, runId);
     }
 
-    Set<StreamSpec> inputStreams = new HashSet<>(streamGraph.getInputOperators().keySet());
-    inputStreams.removeAll(streamGraph.getOutputStreams().keySet());
+    Set<StreamSpec> inputStreams = new HashSet<>(specGraph.getInputOperators().keySet());
+    inputStreams.removeAll(specGraph.getOutputStreams().keySet());
     ApplicationMode mode = inputStreams.stream().allMatch(StreamSpec::isBounded)
         ? ApplicationMode.BATCH : ApplicationMode.STREAM;
     cfg.put(ApplicationConfig.APP_MODE, mode.name());
 
     ExecutionPlanner planner = new ExecutionPlanner(new MapConfig(cfg), streamManager);
-    return planner.plan(streamGraph);
+    return planner.plan(specGraph);
   }
 
   /* package private for testing */
index 1284060..d64e57a 100644 (file)
@@ -42,6 +42,7 @@ import org.apache.samza.coordinator.CoordinationUtils;
 import org.apache.samza.coordinator.DistributedLockWithState;
 import org.apache.samza.execution.ExecutionPlan;
 import org.apache.samza.job.ApplicationStatus;
+import org.apache.samza.operators.StreamGraphSpec;
 import org.apache.samza.processor.StreamProcessor;
 import org.apache.samza.processor.StreamProcessorLifecycleListener;
 import org.apache.samza.system.StreamSpec;
@@ -139,7 +140,7 @@ public class LocalApplicationRunner extends AbstractApplicationRunner {
     LOG.info("LocalApplicationRunner will run " + taskName);
     LocalStreamProcessorLifeCycleListener listener = new LocalStreamProcessorLifeCycleListener();
 
-    StreamProcessor processor = createStreamProcessor(jobConfig, null, listener);
+    StreamProcessor processor = createStreamProcessor(jobConfig, listener);
 
     numProcessorsToStart.set(1);
     listener.setProcessor(processor);
@@ -169,7 +170,7 @@ public class LocalApplicationRunner extends AbstractApplicationRunner {
       plan.getJobConfigs().forEach(jobConfig -> {
           LOG.debug("Starting job {} StreamProcessor with config {}", jobConfig.getName(), jobConfig);
           LocalStreamProcessorLifeCycleListener listener = new LocalStreamProcessorLifeCycleListener();
-          StreamProcessor processor = createStreamProcessor(jobConfig, app, listener);
+          StreamProcessor processor = createStreamProcessor(jobConfig, graphSpec, listener);
           listener.setProcessor(processor);
           processors.add(processor);
         });
@@ -284,15 +285,32 @@ public class LocalApplicationRunner extends AbstractApplicationRunner {
   /**
    * Create {@link StreamProcessor} based on {@link StreamApplication} and the config
    * @param config config
-   * @param app {@link StreamApplication}
    * @return {@link StreamProcessor]}
    */
   /* package private */
   StreamProcessor createStreamProcessor(
       Config config,
-      StreamApplication app,
       StreamProcessorLifecycleListener listener) {
-    Object taskFactory = TaskFactoryUtil.createTaskFactory(config, app, new LocalApplicationRunner(config));
+    Object taskFactory = TaskFactoryUtil.createTaskFactory(config);
+    return getStreamProcessorInstance(config, taskFactory, listener);
+  }
+
+  /**
+   * Create {@link StreamProcessor} based on {@link StreamApplication} and the config
+   * @param config config
+   * @param graphBuilder {@link StreamGraphSpec}
+   * @return {@link StreamProcessor]}
+   */
+  /* package private */
+  StreamProcessor createStreamProcessor(
+      Config config,
+      StreamGraphSpec graphBuilder,
+      StreamProcessorLifecycleListener listener) {
+    Object taskFactory = TaskFactoryUtil.createTaskFactory(graphBuilder.getOperatorSpecGraph(), graphBuilder.getContextManager());
+    return getStreamProcessorInstance(config, taskFactory, listener);
+  }
+
+  private StreamProcessor getStreamProcessorInstance(Config config, Object taskFactory, StreamProcessorLifecycleListener listener) {
     if (taskFactory instanceof StreamTaskFactory) {
       return new StreamProcessor(
           config, new HashMap<>(), (StreamTaskFactory) taskFactory, listener);
index 5831910..7751241 100644 (file)
@@ -70,8 +70,7 @@ public class LocalContainerRunner extends AbstractApplicationRunner {
 
   @Override
   public void run(StreamApplication streamApp) {
-    super.run(streamApp);
-    Object taskFactory = TaskFactoryUtil.createTaskFactory(config, streamApp, this);
+    Object taskFactory = getTaskFactory(streamApp);
 
     container = SamzaContainer$.MODULE$.apply(
         containerId,
@@ -106,6 +105,14 @@ public class LocalContainerRunner extends AbstractApplicationRunner {
     }
   }
 
+  private Object getTaskFactory(StreamApplication streamApp) {
+    if (streamApp != null) {
+      streamApp.init(graphSpec, config);
+      return TaskFactoryUtil.createTaskFactory(graphSpec.getOperatorSpecGraph(), graphSpec.getContextManager());
+    }
+    return TaskFactoryUtil.createTaskFactory(config);
+  }
+
   @Override
   public void kill(StreamApplication streamApp) {
     // Ultimately this class probably won't end up extending ApplicationRunner, so this will be deleted
index e4b3c62..fdd134f 100644 (file)
  */
 package org.apache.samza.task;
 
-import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.Config;
+import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.system.MessageType;
 import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.KV;
-import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.impl.InputOperatorImpl;
 import org.apache.samza.operators.impl.OperatorImplGraph;
-import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.WatermarkMessage;
@@ -39,41 +37,45 @@ import org.slf4j.LoggerFactory;
 
 /**
  * A {@link StreamTask} implementation that brings all the operator API implementation components together and
- * feeds the input messages into the user-defined transformation chains in {@link StreamApplication}.
+ * feeds the input messages into the user-defined transformation chains in {@link OperatorSpecGraph}.
  */
 public class StreamOperatorTask implements StreamTask, InitableTask, WindowableTask, ClosableTask {
   private static final Logger LOG = LoggerFactory.getLogger(StreamOperatorTask.class);
 
-  private final StreamApplication streamApplication;
-  private final ApplicationRunner runner;
+  private final OperatorSpecGraph specGraph;
+  // TODO: to be replaced by proper scope of shared context factory in SAMZA-1714
+  private final ContextManager contextManager;
   private final Clock clock;
 
   private OperatorImplGraph operatorImplGraph;
-  private ContextManager contextManager;
 
   /**
-   * Constructs an adaptor task to run the user-implemented {@link StreamApplication}.
-   * @param streamApplication the user-implemented {@link StreamApplication} that creates the logical DAG
-   * @param runner the {@link ApplicationRunner} to get the mapping between logical and physical streams
+   * Constructs an adaptor task to run the user-implemented {@link OperatorSpecGraph}.
+   * @param specGraph the serialized version of user-implemented {@link OperatorSpecGraph}
+   *                  that includes the logical DAG
+   * @param contextManager the {@link ContextManager} used to set up the shared context used by operators in the DAG
    * @param clock the {@link Clock} to use for time-keeping
    */
-  public StreamOperatorTask(StreamApplication streamApplication, ApplicationRunner runner, Clock clock) {
-    this.streamApplication = streamApplication;
-    this.runner = runner;
+  public StreamOperatorTask(OperatorSpecGraph specGraph, ContextManager contextManager, Clock clock) {
+    this.specGraph = specGraph.clone();
+    this.contextManager = contextManager;
     this.clock = clock;
   }
 
-  public StreamOperatorTask(StreamApplication application, ApplicationRunner runner) {
-    this(application, runner, SystemClock.instance());
+  public StreamOperatorTask(OperatorSpecGraph specGraph, ContextManager contextManager) {
+    this(specGraph, contextManager, SystemClock.instance());
   }
 
   /**
    * Initializes this task during startup.
    * <p>
-   * Implementation: Initializes the user-implemented {@link StreamApplication}. The {@link StreamApplication} sets
-   * the input and output streams and the task-wide context manager using the {@link StreamGraphImpl} APIs,
-   * and the logical transforms using the {@link org.apache.samza.operators.MessageStream} APIs. It then uses
-   * the {@link StreamGraphImpl} to create the {@link OperatorImplGraph} corresponding to the logical DAG.
+   * Implementation: Initializes the runtime {@link OperatorImplGraph} according to user-defined {@link OperatorSpecGraph}.
+   * The {@link org.apache.samza.operators.StreamGraphSpec} sets the input and output streams and the task-wide
+   * context manager using the {@link org.apache.samza.operators.StreamGraph} APIs,
+   * and the logical transforms using the {@link org.apache.samza.operators.MessageStream} APIs. After the
+   * {@link org.apache.samza.operators.StreamGraphSpec} is initialized once by the application, it then creates
+   * an immutable {@link OperatorSpecGraph} accordingly, which is passed in to this class to create the {@link OperatorImplGraph}
+   * corresponding to the logical DAG.
    *
    * @param config allows accessing of fields in the configuration files that this StreamTask is specified in
    * @param context allows initializing and accessing contextual data of this StreamTask
@@ -81,18 +83,14 @@ public class StreamOperatorTask implements StreamTask, InitableTask, WindowableT
    */
   @Override
   public final void init(Config config, TaskContext context) throws Exception {
-    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
-    // initialize the user-implemented stream application.
-    this.streamApplication.init(streamGraph, config);
 
-    // get the user-implemented context manager and initialize it
-    this.contextManager = streamGraph.getContextManager();
+    // get the user-implemented per task context manager and initialize it
     if (this.contextManager != null) {
       this.contextManager.init(config, context);
     }
 
     // create the operator impl DAG corresponding to the logical operator spec DAG
-    this.operatorImplGraph = new OperatorImplGraph(streamGraph, config, context, clock);
+    this.operatorImplGraph = new OperatorImplGraph(specGraph, config, context, clock);
   }
 
   /**
index 2a894ae..38ae854 100644 (file)
@@ -24,7 +24,8 @@ import org.apache.samza.config.Config;
 import org.apache.samza.config.ConfigException;
 import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.TaskConfig;
-import org.apache.samza.runtime.ApplicationRunner;
+import org.apache.samza.operators.ContextManager;
+import org.apache.samza.operators.OperatorSpecGraph;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -41,19 +42,28 @@ public class TaskFactoryUtil {
   private static final Logger log = LoggerFactory.getLogger(TaskFactoryUtil.class);
 
   /**
-   * This method creates a task factory class based on the configuration and {@link StreamApplication}
+   * This method creates a task factory class based on the {@link StreamApplication}
+   *
+   * @param specGraph the {@link OperatorSpecGraph}
+   * @param contextManager the {@link ContextManager} to set up initial context for {@code specGraph}
+   * @return  a task factory object, either a instance of {@link StreamTaskFactory} or {@link AsyncStreamTaskFactory}
+   */
+  public static Object createTaskFactory(OperatorSpecGraph specGraph, ContextManager contextManager) {
+    return createStreamOperatorTaskFactory(specGraph, contextManager);
+  }
+
+  /**
+   * This method creates a task factory class based on the configuration
    *
    * @param config  the {@link Config} for this job
-   * @param streamApp the {@link StreamApplication}
-   * @param runner  the {@link ApplicationRunner} to run this job
    * @return  a task factory object, either a instance of {@link StreamTaskFactory} or {@link AsyncStreamTaskFactory}
    */
-  public static Object createTaskFactory(Config config, StreamApplication streamApp, ApplicationRunner runner) {
-    return (streamApp != null) ? createStreamOperatorTaskFactory(streamApp, runner) : fromTaskClassConfig(config);
+  public static Object createTaskFactory(Config config) {
+    return fromTaskClassConfig(config);
   }
 
-  private static StreamTaskFactory createStreamOperatorTaskFactory(StreamApplication streamApp, ApplicationRunner runner) {
-    return () -> new StreamOperatorTask(streamApp, runner);
+  private static StreamTaskFactory createStreamOperatorTaskFactory(OperatorSpecGraph specGraph, ContextManager contextManager) {
+    return () -> new StreamOperatorTask(specGraph, contextManager);
   }
 
   /**
index 61e8c77..64ee7f3 100644 (file)
@@ -28,7 +28,6 @@ import org.apache.samza.config.Config
 import org.apache.samza.config.StreamConfig.Config2Stream
 import org.apache.samza.job.model.JobModel
 import org.apache.samza.metrics.MetricsReporter
-import org.apache.samza.operators.functions.TimerFunction
 import org.apache.samza.storage.TaskStorageManager
 import org.apache.samza.system._
 import org.apache.samza.table.TableManager
index e5ce3c8..029b375 100644 (file)
@@ -27,10 +27,12 @@ import org.apache.samza.coordinator.JobModelManager
 import org.apache.samza.coordinator.stream.CoordinatorStreamManager
 import org.apache.samza.job.{StreamJob, StreamJobFactory}
 import org.apache.samza.metrics.{JmxServer, MetricsRegistryMap, MetricsReporter}
+import org.apache.samza.operators.StreamGraphSpec
 import org.apache.samza.runtime.LocalContainerRunner
 import org.apache.samza.storage.ChangelogStreamManager
 import org.apache.samza.task.TaskFactoryUtil
 import org.apache.samza.util.Logging
+
 import scala.collection.JavaConversions._
 import scala.collection.mutable
 
@@ -71,7 +73,14 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
     val jmxServer = new JmxServer
     val streamApp = TaskFactoryUtil.createStreamApplication(config)
     val appRunner = new LocalContainerRunner(jobModel, "0")
-    val taskFactory = TaskFactoryUtil.createTaskFactory(config, streamApp, appRunner)
+
+    val taskFactory = if (streamApp != null) {
+      val graphSpec = new StreamGraphSpec(appRunner, config)
+      streamApp.init(graphSpec, config)
+      TaskFactoryUtil.createTaskFactory(graphSpec.getOperatorSpecGraph(), graphSpec.getContextManager)
+    } else {
+      TaskFactoryUtil.createTaskFactory(config)
+    }
 
     // Give developers a nice friendly warning if they've specified task.opts and are using a threaded job.
     config.getTaskOpts match {
index 664f3b1..83fe5ad 100644 (file)
@@ -34,8 +34,8 @@ import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.TaskConfig;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.StreamGraphSpec;
 import org.apache.samza.operators.OutputStream;
-import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.windows.Windows;
 import org.apache.samza.runtime.ApplicationRunner;
@@ -97,24 +97,24 @@ public class TestExecutionPlanner {
     };
   }
 
-  private StreamGraphImpl createSimpleGraph() {
+  private StreamGraphSpec createSimpleGraph() {
     /**
      * a simple graph of partitionBy and map
      *
      * input1 -> partitionBy -> map -> output1
      *
      */
-    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
-    MessageStream<KV<Object, Object>> input1 = streamGraph.getInputStream("input1");
-    OutputStream<KV<Object, Object>> output1 = streamGraph.getOutputStream("output1");
+    StreamGraphSpec graphSpec = new StreamGraphSpec(runner, config);
+    MessageStream<KV<Object, Object>> input1 = graphSpec.getInputStream("input1");
+    OutputStream<KV<Object, Object>> output1 = graphSpec.getOutputStream("output1");
     input1
         .partitionBy(m -> m.key, m -> m.value, "p1")
         .map(kv -> kv)
         .sendTo(output1);
-    return streamGraph;
+    return graphSpec;
   }
 
-  private StreamGraphImpl createStreamGraphWithJoin() {
+  private StreamGraphSpec createStreamGraphWithJoin() {
 
     /**
      * the graph looks like the following. number of partitions in parentheses. quotes indicate expected value.
@@ -127,76 +127,79 @@ public class TestExecutionPlanner {
      *
      */
 
-    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(runner, config);
     MessageStream<KV<Object, Object>> messageStream1 =
-        streamGraph.<KV<Object, Object>>getInputStream("input1")
+        graphSpec.<KV<Object, Object>>getInputStream("input1")
             .map(m -> m);
     MessageStream<KV<Object, Object>> messageStream2 =
-        streamGraph.<KV<Object, Object>>getInputStream("input2")
+        graphSpec.<KV<Object, Object>>getInputStream("input2")
             .partitionBy(m -> m.key, m -> m.value, "p1")
             .filter(m -> true);
     MessageStream<KV<Object, Object>> messageStream3 =
-        streamGraph.<KV<Object, Object>>getInputStream("input3")
+        graphSpec.<KV<Object, Object>>getInputStream("input3")
             .filter(m -> true)
             .partitionBy(m -> m.key, m -> m.value, "p2")
             .map(m -> m);
-    OutputStream<KV<Object, Object>> output1 = streamGraph.getOutputStream("output1");
-    OutputStream<KV<Object, Object>> output2 = streamGraph.getOutputStream("output2");
+    OutputStream<KV<Object, Object>> output1 = graphSpec.getOutputStream("output1");
+    OutputStream<KV<Object, Object>> output2 = graphSpec.getOutputStream("output2");
 
     messageStream1
-        .join(messageStream2, mock(JoinFunction.class),
+        .join(messageStream2,
+            (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
             mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1")
         .sendTo(output1);
     messageStream3
-        .join(messageStream2, mock(JoinFunction.class),
+        .join(messageStream2,
+            (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
             mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2")
         .sendTo(output2);
 
-    return streamGraph;
+    return graphSpec;
   }
 
-  private StreamGraphImpl createStreamGraphWithJoinAndWindow() {
+  private StreamGraphSpec createStreamGraphWithJoinAndWindow() {
 
-    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(runner, config);
     MessageStream<KV<Object, Object>> messageStream1 =
-        streamGraph.<KV<Object, Object>>getInputStream("input1")
+        graphSpec.<KV<Object, Object>>getInputStream("input1")
             .map(m -> m);
     MessageStream<KV<Object, Object>> messageStream2 =
-        streamGraph.<KV<Object, Object>>getInputStream("input2")
+        graphSpec.<KV<Object, Object>>getInputStream("input2")
             .partitionBy(m -> m.key, m -> m.value, "p1")
             .filter(m -> true);
     MessageStream<KV<Object, Object>> messageStream3 =
-        streamGraph.<KV<Object, Object>>getInputStream("input3")
+        graphSpec.<KV<Object, Object>>getInputStream("input3")
             .filter(m -> true)
             .partitionBy(m -> m.key, m -> m.value, "p2")
             .map(m -> m);
-    OutputStream<KV<Object, Object>> output1 = streamGraph.getOutputStream("output1");
-    OutputStream<KV<Object, Object>> output2 = streamGraph.getOutputStream("output2");
+    OutputStream<KV<Object, Object>> output1 = graphSpec.getOutputStream("output1");
+    OutputStream<KV<Object, Object>> output2 = graphSpec.getOutputStream("output2");
 
     messageStream1.map(m -> m)
         .filter(m->true)
-        .window(Windows.<KV<Object, Object>, Object>keyedTumblingWindow(m -> m, Duration.ofMillis(8),
-            mock(Serde.class), mock(Serde.class)), "w1");
+        .window(Windows.keyedTumblingWindow(m -> m, Duration.ofMillis(8), mock(Serde.class), mock(Serde.class)), "w1");
 
     messageStream2.map(m -> m)
         .filter(m->true)
-        .window(Windows.<KV<Object, Object>, Object>keyedTumblingWindow(m -> m, Duration.ofMillis(16),
-            mock(Serde.class), mock(Serde.class)), "w2");
+        .window(Windows.keyedTumblingWindow(m -> m, Duration.ofMillis(16), mock(Serde.class), mock(Serde.class)), "w2");
 
     messageStream1
-        .join(messageStream2, mock(JoinFunction.class),
+        .join(messageStream2,
+            (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
             mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(1600), "j1")
         .sendTo(output1);
     messageStream3
-        .join(messageStream2, mock(JoinFunction.class),
+        .join(messageStream2,
+            (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
             mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(100), "j2")
         .sendTo(output2);
     messageStream3
-        .join(messageStream2, mock(JoinFunction.class),
+        .join(messageStream2,
+            (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
             mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofMillis(252), "j3")
         .sendTo(output2);
 
-    return streamGraph;
+    return graphSpec;
   }
 
   @Before
@@ -252,9 +255,9 @@ public class TestExecutionPlanner {
   @Test
   public void testCreateProcessorGraph() {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    StreamGraphImpl streamGraph = createStreamGraphWithJoin();
+    StreamGraphSpec graphSpec = createStreamGraphWithJoin();
 
-    JobGraph jobGraph = planner.createJobGraph(streamGraph);
+    JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());
     assertTrue(jobGraph.getSources().size() == 3);
     assertTrue(jobGraph.getSinks().size() == 2);
     assertTrue(jobGraph.getIntermediateStreams().size() == 2); // two streams generated by partitionBy
@@ -263,8 +266,8 @@ public class TestExecutionPlanner {
   @Test
   public void testFetchExistingStreamPartitions() {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    StreamGraphImpl streamGraph = createStreamGraphWithJoin();
-    JobGraph jobGraph = planner.createJobGraph(streamGraph);
+    StreamGraphSpec graphSpec = createStreamGraphWithJoin();
+    JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());
 
     ExecutionPlanner.updateExistingPartitions(jobGraph, streamManager);
     assertTrue(jobGraph.getOrCreateStreamEdge(input1).getPartitionCount() == 64);
@@ -281,11 +284,11 @@ public class TestExecutionPlanner {
   @Test
   public void testCalculateJoinInputPartitions() {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    StreamGraphImpl streamGraph = createStreamGraphWithJoin();
-    JobGraph jobGraph = planner.createJobGraph(streamGraph);
+    StreamGraphSpec graphSpec = createStreamGraphWithJoin();
+    JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());
 
     ExecutionPlanner.updateExistingPartitions(jobGraph, streamManager);
-    ExecutionPlanner.calculateJoinInputPartitions(streamGraph, jobGraph);
+    ExecutionPlanner.calculateJoinInputPartitions(jobGraph);
 
     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
@@ -300,9 +303,9 @@ public class TestExecutionPlanner {
     Config cfg = new MapConfig(map);
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
-    StreamGraphImpl streamGraph = createSimpleGraph();
-    JobGraph jobGraph = planner.createJobGraph(streamGraph);
-    planner.calculatePartitions(streamGraph, jobGraph);
+    StreamGraphSpec graphSpec = createSimpleGraph();
+    JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());
+    planner.calculatePartitions(jobGraph);
 
     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
@@ -317,8 +320,8 @@ public class TestExecutionPlanner {
     Config cfg = new MapConfig(map);
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
-    StreamGraphImpl streamGraph = createStreamGraphWithJoin();
-    ExecutionPlan plan = planner.plan(streamGraph);
+    StreamGraphSpec graphSpec = createStreamGraphWithJoin();
+    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     for (JobConfig config : jobConfigs) {
       System.out.println(config);
@@ -332,8 +335,8 @@ public class TestExecutionPlanner {
     Config cfg = new MapConfig(map);
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
-    StreamGraphImpl streamGraph = createStreamGraphWithJoinAndWindow();
-    ExecutionPlan plan = planner.plan(streamGraph);
+    StreamGraphSpec graphSpec = createStreamGraphWithJoinAndWindow();
+    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     assertEquals(1, jobConfigs.size());
 
@@ -349,8 +352,8 @@ public class TestExecutionPlanner {
     Config cfg = new MapConfig(map);
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
-    StreamGraphImpl streamGraph = createStreamGraphWithJoinAndWindow();
-    ExecutionPlan plan = planner.plan(streamGraph);
+    StreamGraphSpec graphSpec = createStreamGraphWithJoinAndWindow();
+    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     assertEquals(1, jobConfigs.size());
 
@@ -366,8 +369,8 @@ public class TestExecutionPlanner {
     Config cfg = new MapConfig(map);
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
-    StreamGraphImpl streamGraph = createSimpleGraph();
-    ExecutionPlan plan = planner.plan(streamGraph);
+    StreamGraphSpec graphSpec = createSimpleGraph();
+    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     assertEquals(1, jobConfigs.size());
     assertFalse(jobConfigs.get(0).containsKey(TaskConfig.WINDOW_MS()));
@@ -381,8 +384,8 @@ public class TestExecutionPlanner {
     Config cfg = new MapConfig(map);
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
-    StreamGraphImpl streamGraph = createSimpleGraph();
-    ExecutionPlan plan = planner.plan(streamGraph);
+    StreamGraphSpec graphSpec = createSimpleGraph();
+    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     assertEquals(1, jobConfigs.size());
     assertEquals("2000", jobConfigs.get(0).get(TaskConfig.WINDOW_MS()));
@@ -391,8 +394,8 @@ public class TestExecutionPlanner {
   @Test
   public void testCalculateIntStreamPartitions() throws Exception {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    StreamGraphImpl streamGraph = createSimpleGraph();
-    JobGraph jobGraph = (JobGraph) planner.plan(streamGraph);
+    StreamGraphSpec graphSpec = createSimpleGraph();
+    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph());
 
     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
@@ -424,12 +427,12 @@ public class TestExecutionPlanner {
     int partitionLimit = ExecutionPlanner.MAX_INFERRED_PARTITIONS;
 
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(runner, config);
 
-    MessageStream<KV<Object, Object>> input1 = streamGraph.getInputStream("input4");
-    OutputStream<KV<Object, Object>> output1 = streamGraph.getOutputStream("output1");
+    MessageStream<KV<Object, Object>> input1 = graphSpec.getInputStream("input4");
+    OutputStream<KV<Object, Object>> output1 = graphSpec.getOutputStream("output1");
     input1.partitionBy(m -> m.key, m -> m.value, "p1").map(kv -> kv).sendTo(output1);
-    JobGraph jobGraph = (JobGraph) planner.plan(streamGraph);
+    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph());
 
     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
index bf131ce..359c422 100644 (file)
@@ -57,16 +57,16 @@ public class TestJobGraph {
    * 2 9 10
    */
   private void createGraph1() {
-    graph1 = new JobGraph(null);
+    graph1 = new JobGraph(null, null);
 
-    JobNode n2 = graph1.getOrCreateJobNode("2", "1", null);
-    JobNode n3 = graph1.getOrCreateJobNode("3", "1", null);
-    JobNode n5 = graph1.getOrCreateJobNode("5", "1", null);
-    JobNode n7 = graph1.getOrCreateJobNode("7", "1", null);
-    JobNode n8 = graph1.getOrCreateJobNode("8", "1", null);
-    JobNode n9 = graph1.getOrCreateJobNode("9", "1", null);
-    JobNode n10 = graph1.getOrCreateJobNode("10", "1", null);
-    JobNode n11 = graph1.getOrCreateJobNode("11", "1", null);
+    JobNode n2 = graph1.getOrCreateJobNode("2", "1");
+    JobNode n3 = graph1.getOrCreateJobNode("3", "1");
+    JobNode n5 = graph1.getOrCreateJobNode("5", "1");
+    JobNode n7 = graph1.getOrCreateJobNode("7", "1");
+    JobNode n8 = graph1.getOrCreateJobNode("8", "1");
+    JobNode n9 = graph1.getOrCreateJobNode("9", "1");
+    JobNode n10 = graph1.getOrCreateJobNode("10", "1");
+    JobNode n11 = graph1.getOrCreateJobNode("11", "1");
 
     graph1.addSource(genStream(), n5);
     graph1.addSource(genStream(), n7);
@@ -90,15 +90,15 @@ public class TestJobGraph {
    *      |<---6 <--|    <>
    */
   private void createGraph2() {
-    graph2 = new JobGraph(null);
+    graph2 = new JobGraph(null, null);
 
-    JobNode n1 = graph2.getOrCreateJobNode("1", "1", null);
-    JobNode n2 = graph2.getOrCreateJobNode("2", "1", null);
-    JobNode n3 = graph2.getOrCreateJobNode("3", "1", null);
-    JobNode n4 = graph2.getOrCreateJobNode("4", "1", null);
-    JobNode n5 = graph2.getOrCreateJobNode("5", "1", null);
-    JobNode n6 = graph2.getOrCreateJobNode("6", "1", null);
-    JobNode n7 = graph2.getOrCreateJobNode("7", "1", null);
+    JobNode n1 = graph2.getOrCreateJobNode("1", "1");
+    JobNode n2 = graph2.getOrCreateJobNode("2", "1");
+    JobNode n3 = graph2.getOrCreateJobNode("3", "1");
+    JobNode n4 = graph2.getOrCreateJobNode("4", "1");
+    JobNode n5 = graph2.getOrCreateJobNode("5", "1");
+    JobNode n6 = graph2.getOrCreateJobNode("6", "1");
+    JobNode n7 = graph2.getOrCreateJobNode("7", "1");
 
     graph2.addSource(genStream(), n1);
     graph2.addIntermediateStream(genStream(), n1, n2);
@@ -117,10 +117,10 @@ public class TestJobGraph {
    * 1<->1 -> 2<->2
    */
   private void createGraph3() {
-    graph3 = new JobGraph(null);
+    graph3 = new JobGraph(null, null);
 
-    JobNode n1 = graph3.getOrCreateJobNode("1", "1", null);
-    JobNode n2 = graph3.getOrCreateJobNode("2", "1", null);
+    JobNode n1 = graph3.getOrCreateJobNode("1", "1");
+    JobNode n2 = graph3.getOrCreateJobNode("2", "1");
 
     graph3.addSource(genStream(), n1);
     graph3.addIntermediateStream(genStream(), n1, n1);
@@ -133,9 +133,9 @@ public class TestJobGraph {
    * 1<->1
    */
   private void createGraph4() {
-    graph4 = new JobGraph(null);
+    graph4 = new JobGraph(null, null);
 
-    JobNode n1 = graph4.getOrCreateJobNode("1", "1", null);
+    JobNode n1 = graph4.getOrCreateJobNode("1", "1");
 
     graph4.addSource(genStream(), n1);
     graph4.addIntermediateStream(genStream(), n1, n1);
@@ -151,7 +151,7 @@ public class TestJobGraph {
 
   @Test
   public void testAddSource() {
-    JobGraph graph = new JobGraph(null);
+    JobGraph graph = new JobGraph(null, null);
 
     /**
      * s1 -> 1
@@ -160,9 +160,9 @@ public class TestJobGraph {
      * s3 -> 2
      *   |-> 3
      */
-    JobNode n1 = graph.getOrCreateJobNode("1", "1", null);
-    JobNode n2 = graph.getOrCreateJobNode("2", "1", null);
-    JobNode n3 = graph.getOrCreateJobNode("3", "1", null);
+    JobNode n1 = graph.getOrCreateJobNode("1", "1");
+    JobNode n2 = graph.getOrCreateJobNode("2", "1");
+    JobNode n3 = graph.getOrCreateJobNode("3", "1");
     StreamSpec s1 = genStream();
     StreamSpec s2 = genStream();
     StreamSpec s3 = genStream();
@@ -173,9 +173,9 @@ public class TestJobGraph {
 
     assertTrue(graph.getSources().size() == 3);
 
-    assertTrue(graph.getOrCreateJobNode("1", "1", null).getInEdges().size() == 2);
-    assertTrue(graph.getOrCreateJobNode("2", "1", null).getInEdges().size() == 1);
-    assertTrue(graph.getOrCreateJobNode("3", "1", null).getInEdges().size() == 1);
+    assertTrue(graph.getOrCreateJobNode("1", "1").getInEdges().size() == 2);
+    assertTrue(graph.getOrCreateJobNode("2", "1").getInEdges().size() == 1);
+    assertTrue(graph.getOrCreateJobNode("3", "1").getInEdges().size() == 1);
 
     assertTrue(graph.getOrCreateStreamEdge(s1).getSourceNodes().size() == 0);
     assertTrue(graph.getOrCreateStreamEdge(s1).getTargetNodes().size() == 1);
@@ -192,9 +192,9 @@ public class TestJobGraph {
      * 2 -> s2
      * 2 -> s3
      */
-    JobGraph graph = new JobGraph(null);
-    JobNode n1 = graph.getOrCreateJobNode("1", "1", null);
-    JobNode n2 = graph.getOrCreateJobNode("2", "1", null);
+    JobGraph graph = new JobGraph(null, null);
+    JobNode n1 = graph.getOrCreateJobNode("1", "1");
+    JobNode n2 = graph.getOrCreateJobNode("2", "1");
     StreamSpec s1 = genStream();
     StreamSpec s2 = genStream();
     StreamSpec s3 = genStream();
@@ -203,8 +203,8 @@ public class TestJobGraph {
     graph.addSink(s3, n2);
 
     assertTrue(graph.getSinks().size() == 3);
-    assertTrue(graph.getOrCreateJobNode("1", "1", null).getOutEdges().size() == 1);
-    assertTrue(graph.getOrCreateJobNode("2", "1", null).getOutEdges().size() == 2);
+    assertTrue(graph.getOrCreateJobNode("1", "1").getOutEdges().size() == 1);
+    assertTrue(graph.getOrCreateJobNode("2", "1").getOutEdges().size() == 2);
 
     assertTrue(graph.getOrCreateStreamEdge(s1).getSourceNodes().size() == 1);
     assertTrue(graph.getOrCreateStreamEdge(s1).getTargetNodes().size() == 0);
index f218e89..abe8969 100644 (file)
@@ -28,7 +28,7 @@ import org.apache.samza.config.MapConfig;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
 import org.apache.samza.operators.OutputStream;
-import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.StreamGraphSpec;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.windows.Windows;
 import org.apache.samza.runtime.ApplicationRunner;
@@ -114,35 +114,37 @@ public class TestJobGraphJsonGenerator {
     when(systemAdmins.getSystemAdmin("system2")).thenReturn(systemAdmin2);
     StreamManager streamManager = new StreamManager(systemAdmins);
 
-    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
-    streamGraph.setDefaultSerde(KVSerde.of(new NoOpSerde<>(), new NoOpSerde<>()));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(runner, config);
+    graphSpec.setDefaultSerde(KVSerde.of(new NoOpSerde<>(), new NoOpSerde<>()));
     MessageStream<KV<Object, Object>> messageStream1 =
-        streamGraph.<KV<Object, Object>>getInputStream("input1")
+        graphSpec.<KV<Object, Object>>getInputStream("input1")
             .map(m -> m);
     MessageStream<KV<Object, Object>> messageStream2 =
-        streamGraph.<KV<Object, Object>>getInputStream("input2")
+        graphSpec.<KV<Object, Object>>getInputStream("input2")
             .partitionBy(m -> m.key, m -> m.value, "p1")
             .filter(m -> true);
     MessageStream<KV<Object, Object>> messageStream3 =
-        streamGraph.<KV<Object, Object>>getInputStream("input3")
+        graphSpec.<KV<Object, Object>>getInputStream("input3")
             .filter(m -> true)
             .partitionBy(m -> m.key, m -> m.value, "p2")
             .map(m -> m);
-    OutputStream<KV<Object, Object>> outputStream1 = streamGraph.getOutputStream("output1");
-    OutputStream<KV<Object, Object>> outputStream2 = streamGraph.getOutputStream("output2");
+    OutputStream<KV<Object, Object>> outputStream1 = graphSpec.getOutputStream("output1");
+    OutputStream<KV<Object, Object>> outputStream2 = graphSpec.getOutputStream("output2");
 
     messageStream1
-        .join(messageStream2, mock(JoinFunction.class),
+        .join(messageStream2,
+            (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
             mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1")
         .sendTo(outputStream1);
     messageStream2.sink((message, collector, coordinator) -> { });
     messageStream3
-        .join(messageStream2, mock(JoinFunction.class),
+        .join(messageStream2,
+            (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class),
             mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2")
         .sendTo(outputStream2);
 
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    ExecutionPlan plan = planner.plan(streamGraph);
+    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
     String json = plan.getPlanAsJson();
     System.out.println(json);
 
@@ -187,8 +189,8 @@ public class TestJobGraphJsonGenerator {
     when(systemAdmins.getSystemAdmin("kafka")).thenReturn(systemAdmin2);
     StreamManager streamManager = new StreamManager(systemAdmins);
 
-    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
-    MessageStream<KV<String, PageViewEvent>> inputStream = streamGraph.getInputStream("PageView");
+    StreamGraphSpec graphSpec = new StreamGraphSpec(runner, config);
+    MessageStream<KV<String, PageViewEvent>> inputStream = graphSpec.getInputStream("PageView");
     inputStream
         .partitionBy(kv -> kv.getValue().getCountry(), kv -> kv.getValue(), "keyed-by-country")
         .window(Windows.keyedTumblingWindow(kv -> kv.getValue().getCountry(),
@@ -198,10 +200,10 @@ public class TestJobGraphJsonGenerator {
             new StringSerde(),
             new LongSerde()), "count-by-country")
         .map(pane -> new KV<>(pane.getKey().getKey(), pane.getMessage()))
-        .sendTo(streamGraph.getOutputStream("PageViewCount"));
+        .sendTo(graphSpec.getOutputStream("PageViewCount"));
 
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    ExecutionPlan plan = planner.plan(streamGraph);
+    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
     String json = plan.getPlanAsJson();
     System.out.println(json);
 
index 53e8bf6..c43e242 100644 (file)
@@ -25,8 +25,8 @@ import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.SerializerConfig;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.StreamGraphSpec;
 import org.apache.samza.operators.OutputStream;
-import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.impl.store.TimestampedValueSerde;
 import org.apache.samza.runtime.ApplicationRunner;
@@ -71,11 +71,11 @@ public class TestJobNode {
     when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName");
     when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId");
 
-    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mockConfig);
-    streamGraph.setDefaultSerde(KVSerde.of(new StringSerde(), new JsonSerdeV2<>()));
-    MessageStream<KV<String, Object>> input1 = streamGraph.getInputStream("input1");
-    MessageStream<KV<String, Object>> input2 = streamGraph.getInputStream("input2");
-    OutputStream<KV<String, Object>> output = streamGraph.getOutputStream("output");
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
+    graphSpec.setDefaultSerde(KVSerde.of(new StringSerde(), new JsonSerdeV2<>()));
+    MessageStream<KV<String, Object>> input1 = graphSpec.getInputStream("input1");
+    MessageStream<KV<String, Object>> input2 = graphSpec.getInputStream("input2");
+    OutputStream<KV<String, Object>> output = graphSpec.getOutputStream("output");
     JoinFunction<String, Object, Object, KV<String, Object>> mockJoinFn = mock(JoinFunction.class);
     input1
         .partitionBy(KV::getKey, KV::getValue, "p1").map(kv -> kv.value)
@@ -84,7 +84,7 @@ public class TestJobNode {
             Duration.ofHours(1), "j1")
         .sendTo(output);
 
-    JobNode jobNode = new JobNode("jobName", "jobId", streamGraph, mockConfig);
+    JobNode jobNode = new JobNode("jobName", "jobId", graphSpec.getOperatorSpecGraph(), mockConfig);
     Config config = new MapConfig();
     StreamEdge input1Edge = new StreamEdge(input1Spec, config);
     StreamEdge input2Edge = new StreamEdge(input2Spec, config);
index dac4e94..602b595 100644 (file)
@@ -21,9 +21,8 @@ package org.apache.samza.operators;
 import com.google.common.collect.ImmutableSet;
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
-import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.TaskContextImpl;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.operators.functions.JoinFunction;
@@ -44,18 +43,24 @@ import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.testUtils.TestClock;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.SystemClock;
+import org.junit.Before;
 import org.junit.Test;
 
+import java.io.IOException;
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 
-import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.anyString;
+import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.eq;
+import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public class TestJoinOperator {
@@ -64,10 +69,22 @@ public class TestJoinOperator {
   private final TaskCoordinator taskCoordinator = mock(TaskCoordinator.class);
   private final Set<Integer> numbers = ImmutableSet.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
 
+  private Config config;
+
+  @Before
+  public void setUp() {
+    Map<String, String> mapConfig = new HashMap<>();
+    mapConfig.put("app.runner.class", "org.apache.samza.runtime.LocalApplicationRunner");
+    mapConfig.put("job.default.system", "insystem");
+    mapConfig.put("job.name", "jobName");
+    mapConfig.put("job.id", "jobId");
+    config = new MapConfig(mapConfig);
+  }
+
   @Test
   public void join() throws Exception {
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -82,43 +99,42 @@ public class TestJoinOperator {
 
   @Test(expected = SamzaException.class)
   public void joinWithSelfThrowsException() throws Exception {
-    StreamApplication app = new StreamApplication() {
-      @Override
-      public void init(StreamGraph graph, Config config) {
-        IntegerSerde integerSerde = new IntegerSerde();
-        KVSerde<Integer, Integer> kvSerde = KVSerde.of(integerSerde, integerSerde);
-        MessageStream<KV<Integer, Integer>> inStream = graph.getInputStream("instream", kvSerde);
-
-        inStream.join(inStream, new TestJoinFunction(), integerSerde, kvSerde, kvSerde, JOIN_TTL, "join");
-      }
-    };
-
-    createStreamOperatorTask(new SystemClock(), app); // should throw an exception
+    config.put("streams.instream.system", "insystem");
+
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mock(ApplicationRunner.class), config);
+    IntegerSerde integerSerde = new IntegerSerde();
+    KVSerde<Integer, Integer> kvSerde = KVSerde.of(integerSerde, integerSerde);
+    MessageStream<KV<Integer, Integer>> inStream = graphSpec.getInputStream("instream", kvSerde);
+
+    inStream.join(inStream, new TestJoinFunction(), integerSerde, kvSerde, kvSerde, JOIN_TTL, "join");
+
+    createStreamOperatorTask(new SystemClock(), graphSpec); // should throw an exception
   }
 
   @Test
   public void joinFnInitAndClose() throws Exception {
     TestJoinFunction joinFn = new TestJoinFunction();
-    TestJoinStreamApplication app = new TestJoinStreamApplication(joinFn);
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
-    assertEquals(1, joinFn.getNumInitCalls());
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(joinFn);
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
+
     MessageCollector messageCollector = mock(MessageCollector.class);
 
     // push messages to first stream
     numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator));
 
     // close should not be called till now
-    assertEquals(0, joinFn.getNumCloseCalls());
     sot.close();
 
-    // close should be called from sot.close()
-    assertEquals(1, joinFn.getNumCloseCalls());
+    verify(messageCollector, times(0)).send(any(OutgoingMessageEnvelope.class));
+    // Make sure the joinFn has been copied instead of directly referred by the task instance
+    assertEquals(0, joinFn.getNumInitCalls());
+    assertEquals(0, joinFn.getNumCloseCalls());
   }
 
   @Test
   public void joinReverse() throws Exception {
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -133,8 +149,8 @@ public class TestJoinOperator {
 
   @Test
   public void joinNoMatch() throws Exception {
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -148,8 +164,8 @@ public class TestJoinOperator {
 
   @Test
   public void joinNoMatchReverse() throws Exception {
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -163,8 +179,8 @@ public class TestJoinOperator {
 
   @Test
   public void joinRetainsLatestMessageForKey() throws Exception {
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -181,8 +197,8 @@ public class TestJoinOperator {
 
   @Test
   public void joinRetainsLatestMessageForKeyReverse() throws Exception {
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -199,8 +215,8 @@ public class TestJoinOperator {
 
   @Test
   public void joinRetainsMatchedMessages() throws Exception {
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -222,8 +238,8 @@ public class TestJoinOperator {
 
   @Test
   public void joinRetainsMatchedMessagesReverse() throws Exception {
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(new SystemClock(), graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -246,8 +262,8 @@ public class TestJoinOperator {
   @Test
   public void joinRemovesExpiredMessages() throws Exception {
     TestClock testClock = new TestClock();
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(testClock, app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(testClock, graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -266,8 +282,8 @@ public class TestJoinOperator {
   @Test
   public void joinRemovesExpiredMessagesReverse() throws Exception {
     TestClock testClock = new TestClock();
-    TestJoinStreamApplication app = new TestJoinStreamApplication(new TestJoinFunction());
-    StreamOperatorTask sot = createStreamOperatorTask(testClock, app);
+    StreamGraphSpec graphSpec = this.getTestJoinStreamGraph(new TestJoinFunction());
+    StreamOperatorTask sot = createStreamOperatorTask(testClock, graphSpec);
     List<Integer> output = new ArrayList<>();
     MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage());
 
@@ -283,15 +299,12 @@ public class TestJoinOperator {
     assertTrue(output.isEmpty());
   }
 
-  private StreamOperatorTask createStreamOperatorTask(Clock clock, StreamApplication app) throws Exception {
-    ApplicationRunner runner = mock(ApplicationRunner.class);
-    when(runner.getStreamSpec("instream")).thenReturn(new StreamSpec("instream", "instream", "insystem"));
-    when(runner.getStreamSpec("instream2")).thenReturn(new StreamSpec("instream2", "instream2", "insystem2"));
+  private StreamOperatorTask createStreamOperatorTask(Clock clock, StreamGraphSpec graphSpec) throws Exception {
 
     TaskContextImpl taskContext = mock(TaskContextImpl.class);
     when(taskContext.getSystemStreamPartitions()).thenReturn(ImmutableSet
         .of(new SystemStreamPartition("insystem", "instream", new Partition(0)),
-            new SystemStreamPartition("insystem2", "instream2", new Partition(0))));
+            new SystemStreamPartition("insystem", "instream2", new Partition(0))));
     when(taskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
     // need to return different stores for left and right side
     IntegerSerde integerSerde = new IntegerSerde();
@@ -301,35 +314,30 @@ public class TestJoinOperator {
     when(taskContext.getStore(eq("jobName-jobId-join-j1-R")))
         .thenReturn(new TestInMemoryStore(integerSerde, timestampedValueSerde));
 
-    Config config = mock(Config.class);
-    when(config.get(JobConfig.JOB_NAME())).thenReturn("jobName");
-    when(config.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId");
-
-    StreamOperatorTask sot = new StreamOperatorTask(app, runner, clock);
+    StreamOperatorTask sot = new StreamOperatorTask(graphSpec.getOperatorSpecGraph(), graphSpec.getContextManager(), clock);
     sot.init(config, taskContext);
     return sot;
   }
 
-  private static class TestJoinStreamApplication implements StreamApplication {
-
-    private final TestJoinFunction joinFn;
-
-    TestJoinStreamApplication(TestJoinFunction joinFn) {
-      this.joinFn = joinFn;
-    }
+  private StreamGraphSpec getTestJoinStreamGraph(TestJoinFunction joinFn) throws IOException {
+    ApplicationRunner runner = mock(ApplicationRunner.class);
+    when(runner.getStreamSpec("instream")).thenReturn(new StreamSpec("instream", "instream", "insystem"));
+    when(runner.getStreamSpec("instream2")).thenReturn(new StreamSpec("instream2", "instream2", "insystem"));
 
-    @Override
-    public void init(StreamGraph graph, Config config) {
-      IntegerSerde integerSerde = new IntegerSerde();
-      KVSerde<Integer, Integer> kvSerde = KVSerde.of(integerSerde, integerSerde);
-      MessageStream<KV<Integer, Integer>> inStream = graph.getInputStream("instream", kvSerde);
-      MessageStream<KV<Integer, Integer>> inStream2 = graph.getInputStream("instream2", kvSerde);
-
-      SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
-      inStream
-          .join(inStream2, joinFn, integerSerde, kvSerde, kvSerde, JOIN_TTL, "j1")
-          .sink((m, mc, tc) -> mc.send(new OutgoingMessageEnvelope(outputSystemStream, m)));
-    }
+    StreamGraphSpec graphSpec = new StreamGraphSpec(runner, config);
+    IntegerSerde integerSerde = new IntegerSerde();
+    KVSerde<Integer, Integer> kvSerde = KVSerde.of(integerSerde, integerSerde);
+    MessageStream<KV<Integer, Integer>> inStream = graphSpec.getInputStream("instream", kvSerde);
+    MessageStream<KV<Integer, Integer>> inStream2 = graphSpec.getInputStream("instream2", kvSerde);
+
+    inStream
+        .join(inStream2, joinFn, integerSerde, kvSerde, kvSerde, JOIN_TTL, "j1")
+        .sink((message, messageCollector, taskCoordinator) -> {
+            SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
+            messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
+          });
+
+    return graphSpec;
   }
 
   private static class TestJoinFunction
@@ -380,7 +388,7 @@ public class TestJoinOperator {
 
   private static class SecondStreamIME extends IncomingMessageEnvelope {
     SecondStreamIME(Integer key, Integer value) {
-      super(new SystemStreamPartition("insystem2", "instream2", new Partition(0)), "1", key, value);
+      super(new SystemStreamPartition("insystem", "instream2", new Partition(0)), "1", key, value);
     }
   }
 }
index 96e234e..fff85e8 100644 (file)
  */
 package org.apache.samza.operators;
 
+import com.google.common.collect.ImmutableList;
+import java.io.IOException;
 import java.time.Duration;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.function.Function;
-import java.util.function.Supplier;
 
 import org.apache.samza.operators.data.TestMessageEnvelope;
 import org.apache.samza.operators.data.TestOutputMessageEnvelope;
@@ -32,6 +32,7 @@ import org.apache.samza.operators.functions.FoldLeftFunction;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.functions.SinkFunction;
+import org.apache.samza.operators.functions.SupplierFunction;
 import org.apache.samza.operators.functions.StreamTableJoinFunction;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
@@ -54,8 +55,6 @@ import org.apache.samza.table.TableSpec;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 
-import com.google.common.collect.ImmutableList;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
@@ -71,7 +70,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testMap() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
 
@@ -96,7 +95,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testFlatMap() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
 
@@ -113,7 +112,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testFlatMapWithRelaxedTypes() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<TestInputMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
 
@@ -133,7 +132,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testFilter() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
 
@@ -158,7 +157,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testSink() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
 
@@ -175,7 +174,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testSendTo() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
     OutputStreamImpl<TestMessageEnvelope> mockOutputStreamImpl = mock(OutputStreamImpl.class);
@@ -200,8 +199,8 @@ public class TestMessageStreamImpl {
   }
 
   @Test
-  public void testRepartition() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+  public void testPartitionBy() throws IOException {
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     String mockOpName = "mockName";
     when(mockGraph.getNextOpId(anyObject(), anyObject())).thenReturn(mockOpName);
@@ -215,8 +214,8 @@ public class TestMessageStreamImpl {
     when(mockIntermediateStream.isKeyed()).thenReturn(true);
 
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
-    Function mockKeyFunction = mock(Function.class);
-    Function mockValueFunction = mock(Function.class);
+    MapFunction mockKeyFunction = mock(MapFunction.class);
+    MapFunction mockValueFunction = mock(MapFunction.class);
     inputStream.partitionBy(mockKeyFunction, mockValueFunction, mockKVSerde, "p1");
 
     ArgumentCaptor<OperatorSpec> registeredOpCaptor = ArgumentCaptor.forClass(OperatorSpec.class);
@@ -232,7 +231,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testRepartitionWithoutSerde() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     String mockOpName = "mockName";
     when(mockGraph.getNextOpId(anyObject(), anyObject())).thenReturn(mockOpName);
@@ -245,8 +244,8 @@ public class TestMessageStreamImpl {
     when(mockIntermediateStream.isKeyed()).thenReturn(true);
 
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
-    Function mockKeyFunction = mock(Function.class);
-    Function mockValueFunction = mock(Function.class);
+    MapFunction mockKeyFunction = mock(MapFunction.class);
+    MapFunction mockValueFunction = mock(MapFunction.class);
     inputStream.partitionBy(mockKeyFunction, mockValueFunction, "p1");
 
     ArgumentCaptor<OperatorSpec> registeredOpCaptor = ArgumentCaptor.forClass(OperatorSpec.class);
@@ -262,18 +261,17 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testWindowWithRelaxedTypes() throws Exception {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec = mock(OperatorSpec.class);
     MessageStream<TestInputMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec);
 
-    Function<TestMessageEnvelope, String> keyExtractor = m -> m.getKey();
+    MapFunction<TestMessageEnvelope, String> keyExtractor = m -> m.getKey();
     FoldLeftFunction<TestMessageEnvelope, Integer> aggregator = (m, c) -> c + 1;
-    Supplier<Integer> initialValue = () -> 0;
+    SupplierFunction<Integer> initialValue = () -> 0;
 
     // should compile since TestMessageEnvelope (input for functions) is base class of TestInputMessageEnvelope (M)
-    Window<TestInputMessageEnvelope, String, Integer> window =
-        Windows.keyedTumblingWindow(keyExtractor, Duration.ofHours(1), initialValue, aggregator,
-            null, mock(Serde.class));
+    Window<TestInputMessageEnvelope, String, Integer> window = Windows
+        .keyedTumblingWindow(keyExtractor, Duration.ofHours(1), initialValue, aggregator, null, mock(Serde.class));
     MessageStream<WindowPane<String, Integer>> windowedStream = inputStream.window(window, "w1");
 
     ArgumentCaptor<OperatorSpec> registeredOpCaptor = ArgumentCaptor.forClass(OperatorSpec.class);
@@ -287,7 +285,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testJoin() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec leftInputOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<TestMessageEnvelope> source1 = new MessageStreamImpl<>(mockGraph, leftInputOpSpec);
     OperatorSpec rightInputOpSpec = mock(OperatorSpec.class);
@@ -319,7 +317,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testSendToTable() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec inputOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<TestMessageEnvelope> source = new MessageStreamImpl<>(mockGraph, inputOpSpec);
 
@@ -336,13 +334,12 @@ public class TestMessageStreamImpl {
     SendToTableOperatorSpec sendToTableOperatorSpec = (SendToTableOperatorSpec) registeredOpSpec;
 
     assertEquals(OpCode.SEND_TO, sendToTableOperatorSpec.getOpCode());
-    assertEquals(inputOpSpec, sendToTableOperatorSpec.getInputOpSpec());
     assertEquals(tableSpec, sendToTableOperatorSpec.getTableSpec());
   }
 
   @Test
   public void testStreamTableJoin() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec leftInputOpSpec = mock(OperatorSpec.class);
     MessageStreamImpl<KV<String, TestMessageEnvelope>> source1 = new MessageStreamImpl<>(mockGraph, leftInputOpSpec);
     OperatorSpec rightInputOpSpec = mock(OperatorSpec.class);
@@ -370,7 +367,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testMerge() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     OperatorSpec mockOpSpec1 = mock(OperatorSpec.class);
     MessageStream<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mockOpSpec1);
 
@@ -410,7 +407,7 @@ public class TestMessageStreamImpl {
 
   @Test
   public void testMergeWithRelaxedTypes() {
-    StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
+    StreamGraphSpec mockGraph = mock(StreamGraphSpec.class);
     MessageStream<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph, mock(OperatorSpec.class));
 
     // other streams have the same message type T as input stream message type M
diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestOperatorSpecGraph.java b/samza-core/src/test/java/org/apache/samza/operators/TestOperatorSpecGraph.java
new file mode 100644 (file)
index 0000000..2be88ca
--- /dev/null
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for THE
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators;
+
+import java.io.IOException;
+import java.io.NotSerializableException;
+import java.io.ObjectInputStream;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.SamzaException;
+import org.apache.samza.operators.functions.TimerFunction;
+import org.apache.samza.operators.functions.WatermarkFunction;
+import org.apache.samza.operators.spec.InputOperatorSpec;
+import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.operators.spec.OperatorSpecTestUtils;
+import org.apache.samza.operators.spec.OperatorSpecs;
+import org.apache.samza.operators.spec.OutputOperatorSpec;
+import org.apache.samza.operators.spec.OutputStreamImpl;
+import org.apache.samza.operators.spec.SinkOperatorSpec;
+import org.apache.samza.operators.spec.StreamOperatorSpec;
+import org.apache.samza.serializers.NoOpSerde;
+import org.apache.samza.system.StreamSpec;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.api.mockito.PowerMockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+
+/**
+ * Unit tests for {@link OperatorSpecGraph}
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(OperatorSpec.class)
+public class TestOperatorSpecGraph {
+
+  private StreamGraphSpec mockGraph;
+  private Map<StreamSpec, InputOperatorSpec> inputOpSpecMap;
+  private Map<StreamSpec, OutputStreamImpl> outputStrmMap;
+  private Set<OperatorSpec> allOpSpecs;
+
+  @Before
+  public void setUp() {
+    this.mockGraph = mock(StreamGraphSpec.class);
+
+    /**
+     * Setup two linear transformation pipelines:
+     * 1) input1 --> filter --> sendTo
+     * 2) input2 --> map --> sink
+     */
+    StreamSpec testInputSpec = new StreamSpec("test-input-1", "test-input-1", "kafka");
+    InputOperatorSpec testInput = new InputOperatorSpec(testInputSpec, new NoOpSerde(), new NoOpSerde(), true, "test-input-1");
+    StreamOperatorSpec filterOp = OperatorSpecs.createFilterOperatorSpec(m -> true, "test-filter-2");
+    StreamSpec testOutputSpec = new StreamSpec("test-output-1", "test-output-1", "kafka");
+    OutputStreamImpl outputStream1 = new OutputStreamImpl(testOutputSpec, null, null, true);
+    OutputOperatorSpec outputSpec = OperatorSpecs.createSendToOperatorSpec(outputStream1, "test-output-3");
+    testInput.registerNextOperatorSpec(filterOp);
+    filterOp.registerNextOperatorSpec(outputSpec);
+    StreamSpec testInputSpec2 = new StreamSpec("test-input-2", "test-input-2", "kafka");
+    InputOperatorSpec testInput2 = new InputOperatorSpec(testInputSpec2, new NoOpSerde(), new NoOpSerde(), true, "test-input-4");
+    StreamOperatorSpec testMap = OperatorSpecs.createMapOperatorSpec(m -> m, "test-map-5");
+    SinkOperatorSpec testSink = OperatorSpecs.createSinkOperatorSpec((m, mc, tc) -> { }, "test-sink-6");
+    testInput2.registerNextOperatorSpec(testMap);
+    testMap.registerNextOperatorSpec(testSink);
+
+    this.inputOpSpecMap = new LinkedHashMap<>();
+    inputOpSpecMap.put(testInputSpec, testInput);
+    inputOpSpecMap.put(testInputSpec2, testInput2);
+    this.outputStrmMap = new LinkedHashMap<>();
+    outputStrmMap.put(testOutputSpec, outputStream1);
+    when(mockGraph.getInputOperators()).thenReturn(Collections.unmodifiableMap(inputOpSpecMap));
+    when(mockGraph.getOutputStreams()).thenReturn(Collections.unmodifiableMap(outputStrmMap));
+    this.allOpSpecs = new HashSet<OperatorSpec>() { {
+        this.add(testInput);
+        this.add(filterOp);
+        this.add(outputSpec);
+        this.add(testInput2);
+        this.add(testMap);
+        this.add(testSink);
+      } };
+  }
+
+  @After
+  public void tearDown() {
+    this.mockGraph = null;
+    this.inputOpSpecMap = null;
+    this.outputStrmMap = null;
+    this.allOpSpecs = null;
+  }
+
+  @Test
+  public void testConstructor() {
+    OperatorSpecGraph specGraph = new OperatorSpecGraph(mockGraph);
+    assertEquals(specGraph.getInputOperators(), inputOpSpecMap);
+    assertEquals(specGraph.getOutputStreams(), outputStrmMap);
+    assertTrue(specGraph.getTables().isEmpty());
+    assertTrue(!specGraph.hasWindowOrJoins());
+    assertEquals(specGraph.getAllOperatorSpecs(), this.allOpSpecs);
+  }
+
+  @Test
+  public void testClone() {
+    OperatorSpecGraph operatorSpecGraph = new OperatorSpecGraph(mockGraph);
+    OperatorSpecGraph clonedSpecGraph = operatorSpecGraph.clone();
+    OperatorSpecTestUtils.assertClonedGraph(operatorSpecGraph, clonedSpecGraph);
+  }
+
+  @Test(expected = NotSerializableException.class)
+  public void testCloneWithSerializationError() throws Throwable {
+    OperatorSpec mockFailedOpSpec = PowerMockito.mock(OperatorSpec.class);
+    when(mockFailedOpSpec.getOpId()).thenReturn("test-failed-op-4");
+    allOpSpecs.add(mockFailedOpSpec);
+    inputOpSpecMap.values().stream().findFirst().get().registerNextOperatorSpec(mockFailedOpSpec);
+
+    //failed with serialization error
+    try {
+      new OperatorSpecGraph(mockGraph);
+      fail("Should have failed with serialization error");
+    } catch (SamzaException se) {
+      throw se.getCause();
+    }
+  }
+
+  @Test(expected = IOException.class)
+  public void testCloneWithDeserializationError() throws Throwable {
+    TestDeserializeOperatorSpec testOp = new TestDeserializeOperatorSpec(OperatorSpec.OpCode.MAP, "test-failed-op-4");
+    this.allOpSpecs.add(testOp);
+    inputOpSpecMap.values().stream().findFirst().get().registerNextOperatorSpec(testOp);
+
+    OperatorSpecGraph operatorSpecGraph = new OperatorSpecGraph(mockGraph);
+    //failed with serialization error
+    try {
+      operatorSpecGraph.clone();
+      fail("Should have failed with serialization error");
+    } catch (SamzaException se) {
+      throw se.getCause();
+    }
+  }
+
+  private static class TestDeserializeOperatorSpec extends OperatorSpec {
+
+    public TestDeserializeOperatorSpec(OpCode opCode, String opId) {
+      super(opCode, opId);
+    }
+
+    private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+      throw new IOException("Raise IOException to cause deserialization failure");
+    }
+
+    @Override
+    public WatermarkFunction getWatermarkFn() {
+      return null;
+    }
+
+    @Override
+    public TimerFunction getTimerFn() {
+      return null;
+    }
+  }
+
+}
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.operators;
 
+import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -38,33 +39,32 @@ import org.apache.samza.system.StreamSpec;
 import org.apache.samza.table.TableSpec;
 import org.junit.Test;
 
-import com.google.common.collect.ImmutableList;
-import junit.framework.Assert;
-
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
-public class TestStreamGraphImpl {
+public class TestStreamGraphSpec {
 
   @Test
   public void testGetInputStreamWithValueSerde() {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
     Serde mockValueSerde = mock(Serde.class);
-    MessageStream<TestMessageEnvelope> inputStream = graph.getInputStream("test-stream-1", mockValueSerde);
+    MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1", mockValueSerde);
 
     InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec =
         (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec();
     assertEquals(OpCode.INPUT, inputOpSpec.getOpCode());
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), inputOpSpec);
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec);
     assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec());
     assertTrue(inputOpSpec.getKeySerde() instanceof NoOpSerde);
     assertEquals(mockValueSerde, inputOpSpec.getValueSerde());
@@ -75,19 +75,19 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
     KVSerde mockKVSerde = mock(KVSerde.class);
     Serde mockKeySerde = mock(Serde.class);
     Serde mockValueSerde = mock(Serde.class);
     doReturn(mockKeySerde).when(mockKVSerde).getKeySerde();
     doReturn(mockValueSerde).when(mockKVSerde).getValueSerde();
-    MessageStream<TestMessageEnvelope> inputStream = graph.getInputStream("test-stream-1", mockKVSerde);
+    MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1", mockKVSerde);
 
     InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec =
         (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec();
     assertEquals(OpCode.INPUT, inputOpSpec.getOpCode());
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), inputOpSpec);
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec);
     assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec());
     assertEquals(mockKeySerde, inputOpSpec.getKeySerde());
     assertEquals(mockValueSerde, inputOpSpec.getValueSerde());
@@ -98,9 +98,9 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
-    graph.getInputStream("test-stream-1", null);
+    graphSpec.getInputStream("test-stream-1", null);
   }
 
   @Test
@@ -108,16 +108,16 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
     Serde mockValueSerde = mock(Serde.class);
-    graph.setDefaultSerde(mockValueSerde);
-    MessageStream<TestMessageEnvelope> inputStream = graph.getInputStream("test-stream-1");
+    graphSpec.setDefaultSerde(mockValueSerde);
+    MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1");
 
     InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec =
         (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec();
     assertEquals(OpCode.INPUT, inputOpSpec.getOpCode());
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), inputOpSpec);
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec);
     assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec());
     assertTrue(inputOpSpec.getKeySerde() instanceof NoOpSerde);
     assertEquals(mockValueSerde, inputOpSpec.getValueSerde());
@@ -128,20 +128,20 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
     KVSerde mockKVSerde = mock(KVSerde.class);
     Serde mockKeySerde = mock(Serde.class);
     Serde mockValueSerde = mock(Serde.class);
     doReturn(mockKeySerde).when(mockKVSerde).getKeySerde();
     doReturn(mockValueSerde).when(mockKVSerde).getValueSerde();
-    graph.setDefaultSerde(mockKVSerde);
-    MessageStream<TestMessageEnvelope> inputStream = graph.getInputStream("test-stream-1");
+    graphSpec.setDefaultSerde(mockKVSerde);
+    MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1");
 
     InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec =
         (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec();
     assertEquals(OpCode.INPUT, inputOpSpec.getOpCode());
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), inputOpSpec);
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec);
     assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec());
     assertEquals(mockKeySerde, inputOpSpec.getKeySerde());
     assertEquals(mockValueSerde, inputOpSpec.getValueSerde());
@@ -153,14 +153,14 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
-    MessageStream<TestMessageEnvelope> inputStream = graph.getInputStream("test-stream-1");
+    MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1");
 
     InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec =
         (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec();
     assertEquals(OpCode.INPUT, inputOpSpec.getOpCode());
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), inputOpSpec);
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec);
     assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec());
     assertTrue(inputOpSpec.getKeySerde() instanceof NoOpSerde);
     assertTrue(inputOpSpec.getValueSerde() instanceof NoOpSerde);
@@ -171,14 +171,14 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
-    MessageStream<TestMessageEnvelope> inputStream = graph.getInputStream("test-stream-1");
+    MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1");
 
     InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec =
         (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec();
     assertEquals(OpCode.INPUT, inputOpSpec.getOpCode());
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), inputOpSpec);
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec);
     assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec());
   }
 
@@ -190,18 +190,18 @@ public class TestStreamGraphImpl {
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec1);
     when(mockRunner.getStreamSpec("test-stream-2")).thenReturn(mockStreamSpec2);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
-    MessageStream<Object> inputStream1 = graph.getInputStream("test-stream-1");
-    MessageStream<Object> inputStream2 = graph.getInputStream("test-stream-2");
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
+    MessageStream<Object> inputStream1 = graphSpec.getInputStream("test-stream-1");
+    MessageStream<Object> inputStream2 = graphSpec.getInputStream("test-stream-2");
 
     InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec1 =
         (InputOperatorSpec) ((MessageStreamImpl<Object>) inputStream1).getOperatorSpec();
     InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec2 =
         (InputOperatorSpec) ((MessageStreamImpl<Object>) inputStream2).getOperatorSpec();
 
-    assertEquals(graph.getInputOperators().size(), 2);
-    assertEquals(graph.getInputOperators().get(mockStreamSpec1), inputOpSpec1);
-    assertEquals(graph.getInputOperators().get(mockStreamSpec2), inputOpSpec2);
+    assertEquals(graphSpec.getInputOperators().size(), 2);
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec1), inputOpSpec1);
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec2), inputOpSpec2);
   }
 
   @Test(expected = IllegalStateException.class)
@@ -209,10 +209,10 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class));
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
-    graph.getInputStream("test-stream-1");
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
+    graphSpec.getInputStream("test-stream-1");
     // should throw exception
-    graph.getInputStream("test-stream-1");
+    graphSpec.getInputStream("test-stream-1");
   }
 
   @Test
@@ -221,14 +221,14 @@ public class TestStreamGraphImpl {
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
     Serde mockValueSerde = mock(Serde.class);
     OutputStream<TestMessageEnvelope> outputStream =
-        graph.getOutputStream("test-stream-1", mockValueSerde);
+        graphSpec.getOutputStream("test-stream-1", mockValueSerde);
 
     OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream;
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
     assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec());
     assertTrue(outputStreamImpl.getKeySerde() instanceof NoOpSerde);
     assertEquals(mockValueSerde, outputStreamImpl.getValueSerde());
@@ -240,17 +240,17 @@ public class TestStreamGraphImpl {
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
     KVSerde mockKVSerde = mock(KVSerde.class);
     Serde mockKeySerde = mock(Serde.class);
     Serde mockValueSerde = mock(Serde.class);
     doReturn(mockKeySerde).when(mockKVSerde).getKeySerde();
     doReturn(mockValueSerde).when(mockKVSerde).getValueSerde();
-    graph.setDefaultSerde(mockKVSerde);
-    OutputStream<TestMessageEnvelope> outputStream = graph.getOutputStream("test-stream-1", mockKVSerde);
+    graphSpec.setDefaultSerde(mockKVSerde);
+    OutputStream<TestMessageEnvelope> outputStream = graphSpec.getOutputStream("test-stream-1", mockKVSerde);
 
     OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream;
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
     assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec());
     assertEquals(mockKeySerde, outputStreamImpl.getKeySerde());
     assertEquals(mockValueSerde, outputStreamImpl.getValueSerde());
@@ -262,9 +262,9 @@ public class TestStreamGraphImpl {
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
-    graph.getOutputStream("test-stream-1", null);
+    graphSpec.getOutputStream("test-stream-1", null);
   }
 
   @Test
@@ -274,12 +274,12 @@ public class TestStreamGraphImpl {
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
 
     Serde mockValueSerde = mock(Serde.class);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
-    graph.setDefaultSerde(mockValueSerde);
-    OutputStream<TestMessageEnvelope> outputStream = graph.getOutputStream("test-stream-1");
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
+    graphSpec.setDefaultSerde(mockValueSerde);
+    OutputStream<TestMessageEnvelope> outputStream = graphSpec.getOutputStream("test-stream-1");
 
     OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream;
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
     assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec());
     assertTrue(outputStreamImpl.getKeySerde() instanceof NoOpSerde);
     assertEquals(mockValueSerde, outputStreamImpl.getValueSerde());
@@ -291,18 +291,18 @@ public class TestStreamGraphImpl {
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
     KVSerde mockKVSerde = mock(KVSerde.class);
     Serde mockKeySerde = mock(Serde.class);
     Serde mockValueSerde = mock(Serde.class);
     doReturn(mockKeySerde).when(mockKVSerde).getKeySerde();
     doReturn(mockValueSerde).when(mockKVSerde).getValueSerde();
-    graph.setDefaultSerde(mockKVSerde);
+    graphSpec.setDefaultSerde(mockKVSerde);
 
-    OutputStream<TestMessageEnvelope> outputStream = graph.getOutputStream("test-stream-1");
+    OutputStream<TestMessageEnvelope> outputStream = graphSpec.getOutputStream("test-stream-1");
 
     OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream;
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
     assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec());
     assertEquals(mockKeySerde, outputStreamImpl.getKeySerde());
     assertEquals(mockValueSerde, outputStreamImpl.getValueSerde());
@@ -314,12 +314,12 @@ public class TestStreamGraphImpl {
     StreamSpec mockStreamSpec = mock(StreamSpec.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
-    OutputStream<TestMessageEnvelope> outputStream = graph.getOutputStream("test-stream-1");
+    OutputStream<TestMessageEnvelope> outputStream = graphSpec.getOutputStream("test-stream-1");
 
     OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream;
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl);
     assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec());
     assertTrue(outputStreamImpl.getKeySerde() instanceof NoOpSerde);
     assertTrue(outputStreamImpl.getValueSerde() instanceof NoOpSerde);
@@ -330,9 +330,9 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class));
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
-    graph.getInputStream("test-stream-1");
-    graph.setDefaultSerde(mock(Serde.class)); // should throw exception
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
+    graphSpec.getInputStream("test-stream-1");
+    graphSpec.setDefaultSerde(mock(Serde.class)); // should throw exception
   }
 
   @Test(expected = IllegalStateException.class)
@@ -340,9 +340,9 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class));
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
-    graph.getOutputStream("test-stream-1");
-    graph.setDefaultSerde(mock(Serde.class)); // should throw exception
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
+    graphSpec.getOutputStream("test-stream-1");
+    graphSpec.setDefaultSerde(mock(Serde.class)); // should throw exception
   }
 
   @Test(expected = IllegalStateException.class)
@@ -350,9 +350,9 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class));
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
-    graph.getIntermediateStream("test-stream-1", null);
-    graph.setDefaultSerde(mock(Serde.class)); // should throw exception
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
+    graphSpec.getIntermediateStream("test-stream-1", null);
+    graphSpec.setDefaultSerde(mock(Serde.class)); // should throw exception
   }
 
   @Test(expected = IllegalStateException.class)
@@ -360,9 +360,9 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class));
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
-    graph.getOutputStream("test-stream-1");
-    graph.getOutputStream("test-stream-1"); // should throw exception
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
+    graphSpec.getOutputStream("test-stream-1");
+    graphSpec.getOutputStream("test-stream-1"); // should throw exception
   }
 
   @Test
@@ -373,14 +373,14 @@ public class TestStreamGraphImpl {
     String mockStreamName = "mockStreamName";
     when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
 
     Serde mockValueSerde = mock(Serde.class);
     IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl =
-        graph.getIntermediateStream(mockStreamName, mockValueSerde);
+        graphSpec.getIntermediateStream(mockStreamName, mockValueSerde);
 
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec());
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream());
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec());
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream());
     assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec());
     assertTrue(intermediateStreamImpl.getOutputStream().getKeySerde() instanceof NoOpSerde);
     assertEquals(mockValueSerde, intermediateStreamImpl.getOutputStream().getValueSerde());
@@ -396,7 +396,7 @@ public class TestStreamGraphImpl {
     String mockStreamName = "mockStreamName";
     when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
 
     KVSerde mockKVSerde = mock(KVSerde.class);
     Serde mockKeySerde = mock(Serde.class);
@@ -404,10 +404,10 @@ public class TestStreamGraphImpl {
     doReturn(mockKeySerde).when(mockKVSerde).getKeySerde();
     doReturn(mockValueSerde).when(mockKVSerde).getValueSerde();
     IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl =
-        graph.getIntermediateStream(mockStreamName, mockKVSerde);
+        graphSpec.getIntermediateStream(mockStreamName, mockKVSerde);
 
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec());
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream());
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec());
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream());
     assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec());
     assertEquals(mockKeySerde, intermediateStreamImpl.getOutputStream().getKeySerde());
     assertEquals(mockValueSerde, intermediateStreamImpl.getOutputStream().getValueSerde());
@@ -423,7 +423,7 @@ public class TestStreamGraphImpl {
     String mockStreamName = "mockStreamName";
     when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graph = new StreamGraphSpec(mockRunner, mockConfig);
 
     Serde mockValueSerde = mock(Serde.class);
     graph.setDefaultSerde(mockValueSerde);
@@ -447,19 +447,19 @@ public class TestStreamGraphImpl {
     String mockStreamName = "mockStreamName";
     when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
 
     KVSerde mockKVSerde = mock(KVSerde.class);
     Serde mockKeySerde = mock(Serde.class);
     Serde mockValueSerde = mock(Serde.class);
     doReturn(mockKeySerde).when(mockKVSerde).getKeySerde();
     doReturn(mockValueSerde).when(mockKVSerde).getValueSerde();
-    graph.setDefaultSerde(mockKVSerde);
+    graphSpec.setDefaultSerde(mockKVSerde);
     IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl =
-        graph.getIntermediateStream(mockStreamName, null);
+        graphSpec.getIntermediateStream(mockStreamName, null);
 
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec());
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream());
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec());
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream());
     assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec());
     assertEquals(mockKeySerde, intermediateStreamImpl.getOutputStream().getKeySerde());
     assertEquals(mockValueSerde, intermediateStreamImpl.getOutputStream().getValueSerde());
@@ -475,12 +475,12 @@ public class TestStreamGraphImpl {
     String mockStreamName = "mockStreamName";
     when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
     IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl =
-        graph.getIntermediateStream(mockStreamName, null);
+        graphSpec.getIntermediateStream(mockStreamName, null);
 
-    assertEquals(graph.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec());
-    assertEquals(graph.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream());
+    assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec());
+    assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream());
     assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec());
     assertTrue(intermediateStreamImpl.getOutputStream().getKeySerde() instanceof NoOpSerde);
     assertTrue(intermediateStreamImpl.getOutputStream().getValueSerde() instanceof NoOpSerde);
@@ -493,9 +493,9 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class));
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mock(Config.class));
-    graph.getIntermediateStream("test-stream-1", mock(Serde.class));
-    graph.getIntermediateStream("test-stream-1", mock(Serde.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
+    graphSpec.getIntermediateStream("test-stream-1", mock(Serde.class));
+    graphSpec.getIntermediateStream("test-stream-1", mock(Serde.class));
   }
 
   @Test
@@ -505,10 +505,10 @@ public class TestStreamGraphImpl {
     when(mockConfig.get(eq(JobConfig.JOB_NAME()))).thenReturn("jobName");
     when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("1234");
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
-    assertEquals("jobName-1234-merge-0", graph.getNextOpId(OpCode.MERGE, null));
-    assertEquals("jobName-1234-join-customName", graph.getNextOpId(OpCode.JOIN, "customName"));
-    assertEquals("jobName-1234-map-2", graph.getNextOpId(OpCode.MAP, null));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
+    assertEquals("jobName-1234-merge-0", graphSpec.getNextOpId(OpCode.MERGE, null));
+    assertEquals("jobName-1234-join-customName", graphSpec.getNextOpId(OpCode.JOIN, "customName"));
+    assertEquals("jobName-1234-map-2", graphSpec.getNextOpId(OpCode.MAP, null));
   }
 
   @Test(expected = SamzaException.class)
@@ -518,9 +518,9 @@ public class TestStreamGraphImpl {
     when(mockConfig.get(eq(JobConfig.JOB_NAME()))).thenReturn("jobName");
     when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("1234");
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
-    assertEquals("jobName-1234-join-customName", graph.getNextOpId(OpCode.JOIN, "customName"));
-    graph.getNextOpId(OpCode.JOIN, "customName"); // should throw
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
+    assertEquals("jobName-1234-join-customName", graphSpec.getNextOpId(OpCode.JOIN, "customName"));
+    graphSpec.getNextOpId(OpCode.JOIN, "customName"); // should throw
   }
 
   @Test
@@ -530,32 +530,32 @@ public class TestStreamGraphImpl {
     when(mockConfig.get(eq(JobConfig.JOB_NAME()))).thenReturn("jobName");
     when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("1234");
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
 
     // null and empty userDefinedIDs should fall back to autogenerated IDs.
     try {
-      graph.getNextOpId(OpCode.FILTER, null);
-      graph.getNextOpId(OpCode.FILTER, "");
-      graph.getNextOpId(OpCode.FILTER, " ");
-      graph.getNextOpId(OpCode.FILTER, "\t");
+      graphSpec.getNextOpId(OpCode.FILTER, null);
+      graphSpec.getNextOpId(OpCode.FILTER, "");
+      graphSpec.getNextOpId(OpCode.FILTER, " ");
+      graphSpec.getNextOpId(OpCode.FILTER, "\t");
     } catch (SamzaException e) {
-      Assert.fail("Received an error with a null or empty operator ID instead of defaulting to auto-generated ID.");
+      fail("Received an error with a null or empty operator ID instead of defaulting to auto-generated ID.");
     }
 
     List<String> validOpIds = ImmutableList.of("op.id", "op_id", "op-id", "1000", "op_1", "OP_ID");
     for (String validOpId: validOpIds) {
       try {
-        graph.getNextOpId(OpCode.FILTER, validOpId);
+        graphSpec.getNextOpId(OpCode.FILTER, validOpId);
       } catch (Exception e) {
-        Assert.fail("Received an exception with a valid operator ID: " + validOpId);
+        fail("Received an exception with a valid operator ID: " + validOpId);
       }
     }
 
     List<String> invalidOpIds = ImmutableList.of("op id", "op#id");
     for (String invalidOpId: invalidOpIds) {
       try {
-        graph.getNextOpId(OpCode.FILTER, invalidOpId);
-        Assert.fail("Did not receive an exception with an invalid operator ID: " + invalidOpId);
+        graphSpec.getNextOpId(OpCode.FILTER, invalidOpId);
+        fail("Did not receive an exception with an invalid operator ID: " + invalidOpId);
       } catch (SamzaException e) { }
     }
   }
@@ -565,7 +565,7 @@ public class TestStreamGraphImpl {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     Config mockConfig = mock(Config.class);
 
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
 
     StreamSpec testStreamSpec1 = new StreamSpec("test-stream-1", "physical-stream-1", "test-system");
     when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(testStreamSpec1);
@@ -576,26 +576,26 @@ public class TestStreamGraphImpl {
     StreamSpec testStreamSpec3 = new StreamSpec("test-stream-3", "physical-stream-3", "test-system");
     when(mockRunner.getStreamSpec("test-stream-3")).thenReturn(testStreamSpec3);
 
-    graph.getInputStream("test-stream-1");
-    graph.getInputStream("test-stream-2");
-    graph.getInputStream("test-stream-3");
+    graphSpec.getInputStream("test-stream-1");
+    graphSpec.getInputStream("test-stream-2");
+    graphSpec.getInputStream("test-stream-3");
 
-    List<InputOperatorSpec> inputSpecs = new ArrayList<>(graph.getInputOperators().values());
-    Assert.assertEquals(inputSpecs.size(), 3);
-    Assert.assertEquals(inputSpecs.get(0).getStreamSpec(), testStreamSpec1);
-    Assert.assertEquals(inputSpecs.get(1).getStreamSpec(), testStreamSpec2);
-    Assert.assertEquals(inputSpecs.get(2).getStreamSpec(), testStreamSpec3);
+    List<InputOperatorSpec> inputSpecs = new ArrayList<>(graphSpec.getInputOperators().values());
+    assertEquals(inputSpecs.size(), 3);
+    assertEquals(inputSpecs.get(0).getStreamSpec(), testStreamSpec1);
+    assertEquals(inputSpecs.get(1).getStreamSpec(), testStreamSpec2);
+    assertEquals(inputSpecs.get(2).getStreamSpec(), testStreamSpec3);
   }
 
   @Test
   public void testGetTable() {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     Config mockConfig = mock(Config.class);
-    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
 
     BaseTableDescriptor mockTableDescriptor = mock(BaseTableDescriptor.class);
     when(mockTableDescriptor.getTableSpec()).thenReturn(
         new TableSpec("t1", KVSerde.of(new NoOpSerde(), new NoOpSerde()), "", new HashMap<>()));
-    Assert.assertNotNull(graph.getTable(mockTableDescriptor));
+    assertNotNull(graphSpec.getTable(mockTableDescriptor));
   }
 }
index f9537a3..519e5df 100644 (file)
@@ -35,5 +35,19 @@ public class TestOutputMessageEnvelope {
   public String getKey() {
     return this.key;
   }
+
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof TestOutputMessageEnvelope)) {
+      return false;
+    }
+    TestOutputMessageEnvelope otherMsg = (TestOutputMessageEnvelope) other;
+    return this.key.equals(otherMsg.key) && this.value.equals(otherMsg.value);
+  }
+
+  @Override
+  public int hashCode() {
+    return String.format("%s:%d", key, value).hashCode();
+  }
 }
 
index 2d8d1eb..b87e5ed 100644 (file)
@@ -21,11 +21,17 @@ package org.apache.samza.operators.impl;
 
 import com.google.common.collect.HashMultimap;
 import com.google.common.collect.Multimap;
+import java.io.Serializable;
+import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiFunction;
+import java.util.function.Function;
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
@@ -39,9 +45,11 @@ import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.StreamGraphSpec;
 import org.apache.samza.operators.OutputStream;
-import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.functions.ClosableFunction;
 import org.apache.samza.operators.functions.FilterFunction;
+import org.apache.samza.operators.functions.InitableFunction;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.impl.store.TimestampedValue;
@@ -58,34 +66,160 @@ import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.TaskContext;
+import java.util.List;
 import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.SystemClock;
+import org.junit.After;
 import org.junit.Test;
 
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public class TestOperatorImplGraph {
 
+  private void addOperatorRecursively(HashSet<OperatorImpl> s, OperatorImpl op) {
+    List<OperatorImpl> operators = new ArrayList<>();
+    operators.add(op);
+    while (!operators.isEmpty()) {
+      OperatorImpl opImpl = operators.remove(0);
+      s.add(opImpl);
+      if (!opImpl.registeredOperators.isEmpty()) {
+        operators.addAll(opImpl.registeredOperators);
+      }
+    }
+  }
+
+  static class TestMapFunction<M, OM> extends BaseTestFunction implements MapFunction<M, OM> {
+    final Function<M, OM> mapFn;
+
+    public TestMapFunction(String opId, Function<M, OM> mapFn) {
+      super(opId);
+      this.mapFn = mapFn;
+    }
+
+    @Override
+    public OM apply(M message) {
+      return this.mapFn.apply(message);
+    }
+  }
+
+  static class TestJoinFunction<K, M, JM, RM> extends BaseTestFunction implements JoinFunction<K, M, JM, RM> {
+    final BiFunction<M, JM, RM> joiner;
+    final Function<M, K> firstKeyFn;
+    final Function<JM, K> secondKeyFn;
+    final Collection<RM> joinResults = new HashSet<>();
+
+    public TestJoinFunction(String opId, BiFunction<M, JM, RM> joiner, Function<M, K> firstKeyFn, Function<JM, K> secondKeyFn) {
+      super(opId);
+      this.joiner = joiner;
+      this.firstKeyFn = firstKeyFn;
+      this.secondKeyFn = secondKeyFn;
+    }
+
+    @Override
+    public RM apply(M message, JM otherMessage) {
+      RM result = this.joiner.apply(message, otherMessage);
+      this.joinResults.add(result);
+      return result;
+    }
+
+    @Override
+    public K getFirstKey(M message) {
+      return this.firstKeyFn.apply(message);
+    }
+
+    @Override
+    public K getSecondKey(JM message) {
+      return this.secondKeyFn.apply(message);
+    }
+  }
+
+  static abstract class BaseTestFunction implements InitableFunction, ClosableFunction, Serializable {
+
+    static Map<TaskName, Map<String, BaseTestFunction>> perTaskFunctionMap = new HashMap<>();
+    static Map<TaskName, List<String>> perTaskInitList = new HashMap<>();
+    static Map<TaskName, List<String>> perTaskCloseList = new HashMap<>();
+    int numInitCalled = 0;
+    int numCloseCalled = 0;
+    TaskName taskName = null;
+    final String opId;
+
+    public BaseTestFunction(String opId) {
+      this.opId = opId;
+    }
+
+    static public void reset() {
+      perTaskFunctionMap.clear();
+      perTaskCloseList.clear();
+      perTaskInitList.clear();
+    }
+
+    static public BaseTestFunction getInstanceByTaskName(TaskName taskName, String opId) {
+      return perTaskFunctionMap.get(taskName).get(opId);
+    }
+
+    static public List<String> getInitListByTaskName(TaskName taskName) {
+      return perTaskInitList.get(taskName);
+    }
+
+    static public List<String> getCloseListByTaskName(TaskName taskName) {
+      return perTaskCloseList.get(taskName);
+    }
+
+    @Override
+    public void close() {
+      if (this.taskName == null) {
+        throw new IllegalStateException("Close called before init");
+      }
+      if (perTaskFunctionMap.get(this.taskName) == null || !perTaskFunctionMap.get(this.taskName).containsKey(opId)) {
+        throw new IllegalStateException("Close called before init");
+      }
+
+      if (perTaskCloseList.get(this.taskName) == null) {
+        perTaskCloseList.put(taskName, new ArrayList<String>() { { this.add(opId); } });
+      } else {
+        perTaskCloseList.get(taskName).add(opId);
+      }
+
+      this.numCloseCalled++;
+    }
+
+    @Override
+    public void init(Config config, TaskContext context) {
+      if (perTaskFunctionMap.get(context.getTaskName()) == null) {
+        perTaskFunctionMap.put(context.getTaskName(), new HashMap<String, BaseTestFunction>() { { this.put(opId, BaseTestFunction.this); } });
+      } else {
+        if (perTaskFunctionMap.get(context.getTaskName()).containsKey(opId)) {
+          throw new IllegalStateException(String.format("Multiple init called for op %s in the same task instance %s", opId, this.taskName.getTaskName()));
+        }
+        perTaskFunctionMap.get(context.getTaskName()).put(opId, this);
+      }
+      if (perTaskInitList.get(context.getTaskName()) == null) {
+        perTaskInitList.put(context.getTaskName(), new ArrayList<String>() { { this.add(opId); } });
+      } else {
+        perTaskInitList.get(context.getTaskName()).add(opId);
+      }
+      this.taskName = context.getTaskName();
+      this.numInitCalled++;
+    }
+  }
+
+  @After
+  public void tearDown() {
+    BaseTestFunction.reset();
+  }
+
   @Test
   public void testEmptyChain() {
-    StreamGraphImpl streamGraph = new StreamGraphImpl(mock(ApplicationRunner.class), mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mock(ApplicationRunner.class), mock(Config.class));
     OperatorImplGraph opGraph =
-        new OperatorImplGraph(streamGraph, mock(Config.class), mock(TaskContextImpl.class), mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mock(TaskContextImpl.class), mock(Clock.class));
     assertEquals(0, opGraph.getAllInputOperators().size());
   }
 
@@ -94,10 +228,10 @@ public class TestOperatorImplGraph {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec(eq("input"))).thenReturn(new StreamSpec("input", "input-stream", "input-system"));
     when(mockRunner.getStreamSpec(eq("output"))).thenReturn(mock(StreamSpec.class));
-    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
-    MessageStream<Object> inputStream = streamGraph.getInputStream("input");
-    OutputStream<Object> outputStream = streamGraph.getOutputStream("output");
+    MessageStream<Object> inputStream = graphSpec.getInputStream("input");
+    OutputStream<Object> outputStream = graphSpec.getOutputStream("output");
 
     inputStream
         .filter(mock(FilterFunction.class))
@@ -108,7 +242,7 @@ public class TestOperatorImplGraph {
     when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
     when(mockTaskContext.getTaskName()).thenReturn(new TaskName("task 0"));
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(streamGraph, mock(Config.class), mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mockTaskContext, mock(Clock.class));
 
     InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream"));
     assertEquals(1, inputOpImpl.registeredOperators.size());
@@ -136,9 +270,9 @@ public class TestOperatorImplGraph {
     Config mockConfig = mock(Config.class);
     when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName");
     when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId");
-    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mockConfig);
-    MessageStream<Object> inputStream = streamGraph.getInputStream("input");
-    OutputStream<KV<Integer, String>> outputStream = streamGraph
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
+    MessageStream<Object> inputStream = graphSpec.getInputStream("input");
+    OutputStream<KV<Integer, String>> outputStream = graphSpec
         .getOutputStream("output", KVSerde.of(mock(IntegerSerde.class), mock(StringSerde.class)));
 
     inputStream
@@ -160,7 +294,7 @@ public class TestOperatorImplGraph {
         new SamzaContainerContext("0", mockConfig, Collections.singleton(new TaskName("task 0")), new MetricsRegistryMap());
     when(mockTaskContext.getSamzaContainerContext()).thenReturn(containerContext);
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(streamGraph, mockConfig, mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mockConfig, mockTaskContext, mock(Clock.class));
 
     InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream"));
     assertEquals(1, inputOpImpl.registeredOperators.size());
@@ -182,16 +316,16 @@ public class TestOperatorImplGraph {
   public void testBroadcastChain() {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
     when(mockRunner.getStreamSpec(eq("input"))).thenReturn(new StreamSpec("input", "input-stream", "input-system"));
-    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
-    MessageStream<Object> inputStream = streamGraph.getInputStream("input");
+    MessageStream<Object> inputStream = graphSpec.getInputStream("input");
     inputStream.filter(mock(FilterFunction.class));
     inputStream.map(mock(MapFunction.class));
 
     TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
     when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(streamGraph, mock(Config.class), mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mockTaskContext, mock(Clock.class));
 
     InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream"));
     assertEquals(2, inputOpImpl.registeredOperators.size());
@@ -204,23 +338,36 @@ public class TestOperatorImplGraph {
   @Test
   public void testMergeChain() {
     ApplicationRunner mockRunner = mock(ApplicationRunner.class);
-    when(mockRunner.getStreamSpec(eq("input"))).thenReturn(new StreamSpec("input", "input-stream", "input-system"));
-    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mock(Config.class));
+    when(mockRunner.getStreamSpec(eq("input")))
+        .thenReturn(new StreamSpec("input", "input-stream", "input-system"));
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class));
 
-    MessageStream<Object> inputStream = streamGraph.getInputStream("input");
+    MessageStream<Object> inputStream = graphSpec.getInputStream("input");
     MessageStream<Object> stream1 = inputStream.filter(mock(FilterFunction.class));
     MessageStream<Object> stream2 = inputStream.map(mock(MapFunction.class));
     MessageStream<Object> mergedStream = stream1.merge(Collections.singleton(stream2));
-    MapFunction mockMapFunction = mock(MapFunction.class);
-    mergedStream.map(mockMapFunction);
 
     TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
+    TaskName mockTaskName = mock(TaskName.class);
+    when(mockTaskContext.getTaskName()).thenReturn(mockTaskName);
     when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+
+    MapFunction testMapFunction = new TestMapFunction<Object, Object>("test-map-1", (Function & Serializable) m -> m);
+    mergedStream.map(testMapFunction);
+
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(streamGraph, mock(Config.class), mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mockTaskContext, mock(Clock.class));
+
+    Set<OperatorImpl> opSet = opImplGraph.getAllInputOperators().stream().collect(HashSet::new,
+        (s, op) -> addOperatorRecursively(s, op), HashSet::addAll);
+    Object[] mergeOps = opSet.stream().filter(op -> op.getOperatorSpec().getOpCode() == OpCode.MERGE).toArray();
+    assertEquals(mergeOps.length, 1);
+    assertEquals(((OperatorImpl) mergeOps[0]).registeredOperators.size(), 1);
+    OperatorImpl mapOp = (OperatorImpl) ((OperatorImpl) mergeOps[0]).registeredOperators.iterator().next();
+    assertEquals(mapOp.getOperatorSpec().getOpCode(), OpCode.MAP);
 
     // verify that the DAG after merge is only traversed & initialized once
-    verify(mockMapFunction, times(1)).init(any(Config.class), any(TaskContext.class));
+    assertEquals(TestMapFunction.getInstanceByTaskName(mockTaskName, "test-map-1").numInitCalled, 1);
   }
 
   @Test
@@ -231,25 +378,30 @@ public class TestOperatorImplGraph {
     Config mockConfig = mock(Config.class);
     when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName");
     when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId");
-    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mockConfig);
-
-    JoinFunction mockJoinFunction = mock(JoinFunction.class);
-    MessageStream<Object> inputStream1 = streamGraph.getInputStream("input1", new NoOpSerde<>());
-    MessageStream<Object> inputStream2 = streamGraph.getInputStream("input2", new NoOpSerde<>());
-    inputStream1.join(inputStream2, mockJoinFunction,
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
+
+    Integer joinKey = new Integer(1);
+    Function<Object, Integer> keyFn = (Function & Serializable) m -> joinKey;
+    JoinFunction testJoinFunction = new TestJoinFunction("jobName-jobId-join-j1",
+        (BiFunction & Serializable) (m1, m2) -> KV.of(m1, m2), keyFn, keyFn);
+    MessageStream<Object> inputStream1 = graphSpec.getInputStream("input1", new NoOpSerde<>());
+    MessageStream<Object> inputStream2 = graphSpec.getInputStream("input2", new NoOpSerde<>());
+    inputStream1.join(inputStream2, testJoinFunction,
         mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j1");
 
+    TaskName mockTaskName = mock(TaskName.class);
     TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
+    when(mockTaskContext.getTaskName()).thenReturn(mockTaskName);
     when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
     KeyValueStore mockLeftStore = mock(KeyValueStore.class);
     when(mockTaskContext.getStore(eq("jobName-jobId-join-j1-L"))).thenReturn(mockLeftStore);
     KeyValueStore mockRightStore = mock(KeyValueStore.class);
     when(mockTaskContext.getStore(eq("jobName-jobId-join-j1-R"))).thenReturn(mockRightStore);
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(streamGraph, mockConfig, mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mockConfig, mockTaskContext, mock(Clock.class));
 
     // verify that join function is initialized once.
-    verify(mockJoinFunction, times(1)).init(any(Config.class), any(TaskContext.class));
+    assertEquals(TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1").numInitCalled, 1);
 
     InputOperatorImpl inputOpImpl1 = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream1"));
     InputOperatorImpl inputOpImpl2 = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream2"));
@@ -261,24 +413,23 @@ public class TestOperatorImplGraph {
     assertEquals(leftPartialJoinOpImpl.getOperatorSpec(), rightPartialJoinOpImpl.getOperatorSpec());
     assertNotSame(leftPartialJoinOpImpl, rightPartialJoinOpImpl);
 
-    Object joinKey = new Object();
     // verify that left partial join operator calls getFirstKey
     Object mockLeftMessage = mock(Object.class);
     long currentTimeMillis = System.currentTimeMillis();
     when(mockLeftStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockLeftMessage, currentTimeMillis));
-    when(mockJoinFunction.getFirstKey(eq(mockLeftMessage))).thenReturn(joinKey);
     inputOpImpl1.onMessage(KV.of("", mockLeftMessage), mock(MessageCollector.class), mock(TaskCoordinator.class));
-    verify(mockJoinFunction, times(1)).getFirstKey(mockLeftMessage);
 
     // verify that right partial join operator calls getSecondKey
     Object mockRightMessage = mock(Object.class);
     when(mockRightStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockRightMessage, currentTimeMillis));
-    when(mockJoinFunction.getSecondKey(eq(mockRightMessage))).thenReturn(joinKey);
     inputOpImpl2.onMessage(KV.of("", mockRightMessage), mock(MessageCollector.class), mock(TaskCoordinator.class));
-    verify(mockJoinFunction, times(1)).getSecondKey(mockRightMessage);
+
 
     // verify that the join function apply is called with the correct messages on match
-    verify(mockJoinFunction, times(1)).apply(mockLeftMessage, mockRightMessage);
+    assertEquals(((TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1")).joinResults.size(), 1);
+    KV joinResult = (KV) ((TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1")).joinResults.iterator().next();
+    assertEquals(joinResult.getKey(), mockLeftMessage);
+    assertEquals(joinResult.getValue(), mockRightMessage);
   }
 
   @Test
@@ -287,23 +438,25 @@ public class TestOperatorImplGraph {
     when(mockRunner.getStreamSpec("input1")).thenReturn(new StreamSpec("input1", "input-stream1", "input-system"));
     when(mockRunner.getStreamSpec("input2")).thenReturn(new StreamSpec("input2", "input-stream2", "input-system"));
     Config mockConfig = mock(Config.class);
+    TaskName mockTaskName = mock(TaskName.class);
     TaskContextImpl mockContext = mock(TaskContextImpl.class);
+    when(mockContext.getTaskName()).thenReturn(mockTaskName);
     when(mockContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-    StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mockConfig);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig);
 
-    MessageStream<Object> inputStream1 = streamGraph.getInputStream("input1");
-    MessageStream<Object> inputStream2 = streamGraph.getInputStream("input2");
+    MessageStream<Object> inputStream1 = graphSpec.getInputStream("input1");
+    MessageStream<Object> inputStream2 = graphSpec.getInputStream("input2");
 
-    List<String> initializedOperators = new ArrayList<>();
-    List<String> closedOperators = new ArrayList<>();
+    Function mapFn = (Function & Serializable) m -> m;
+    inputStream1.map(new TestMapFunction<Object, Object>("1", mapFn))
+        .map(new TestMapFunction<Object, Object>("2", mapFn));
 
-    inputStream1.map(createMapFunction("1", initializedOperators, closedOperators))
-        .map(createMapFunction("2", initializedOperators, closedOperators));
+    inputStream2.map(new TestMapFunction<Object, Object>("3", mapFn))
+        .map(new TestMapFunction<Object, Object>("4", mapFn));
 
-    inputStream2.map(createMapFunction("3", initializedOperators, closedOperators))
-        .map(createMapFunction("4", initializedOperators, closedOperators));
+    OperatorImplGraph opImplGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mockConfig, mockContext, SystemClock.instance());
 
-    OperatorImplGraph opImplGraph = new OperatorImplGraph(streamGraph, mockConfig, mockContext, SystemClock.instance());
+    List<String> initializedOperators = BaseTestFunction.getInitListByTaskName(mockTaskName);
 
     // Assert that initialization occurs in topological order.
     assertEquals(initializedOperators.get(0), "1");
@@ -313,35 +466,13 @@ public class TestOperatorImplGraph {
 
     // Assert that finalization occurs in reverse topological order.
     opImplGraph.close();
+    List<String> closedOperators = BaseTestFunction.getCloseListByTaskName(mockTaskName);
     assertEquals(closedOperators.get(0), "4");
     assertEquals(closedOperators.get(1), "3");
     assertEquals(closedOperators.get(2), "2");
     assertEquals(closedOperators.get(3), "1");
   }
 
-  /**
-   * Creates an identity map function that appends to the provided lists when init/close is invoked.
-   */
-  private MapFunction<Object, Object> createMapFunction(String id,
-      List<String> initializedOperators, List<String> finalizedOperators) {
-    return new MapFunction<Object, Object>() {
-      @Override
-      public void init(Config config, TaskContext context) {
-        initializedOperators.add(id);
-      }
-
-      @Override
-      public void close() {
-        finalizedOperators.add(id);
-      }
-
-      @Override
-      public Object apply(Object message) {
-        return message;
-      }
-    };
-  }
-
   @Test
   public void testGetStreamToConsumerTasks() {
     String system = "test-system";
@@ -409,16 +540,16 @@ public class TestOperatorImplGraph {
     when(runner.getStreamSpec("test-app-1-partition_by-p2")).thenReturn(int1);
     when(runner.getStreamSpec("test-app-1-partition_by-p1")).thenReturn(int2);
 
-    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
-    MessageStream messageStream1 = streamGraph.getInputStream("input1").map(m -> m);
-    MessageStream messageStream2 = streamGraph.getInputStream("input2").filter(m -> true);
+    StreamGraphSpec graphSpec = new StreamGraphSpec(runner, config);
+    MessageStream messageStream1 = graphSpec.getInputStream("input1").map(m -> m);
+    MessageStream messageStream2 = graphSpec.getInputStream("input2").filter(m -> true);
     MessageStream messageStream3 =
-        streamGraph.getInputStream("input3")
+        graphSpec.getInputStream("input3")
             .filter(m -> true)
             .partitionBy(m -> "hehe", m -> m, "p1")
             .map(m -> m);
-    OutputStream<Object> outputStream1 = streamGraph.getOutputStream("output1");
-    OutputStream<Object> outputStream2 = streamGraph.getOutputStream("output2");
+    OutputStream<Object> outputStream1 = graphSpec.getOutputStream("output1");
+    OutputStream<Object> outputStream2 = graphSpec.getOutputStream("output2");
 
     messageStream1
         .join(messageStream2, mock(JoinFunction.class),
@@ -430,7 +561,8 @@ public class TestOperatorImplGraph {
             mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2")
         .sendTo(outputStream2);
 
-    Multimap<SystemStream, SystemStream> outputToInput = OperatorImplGraph.getIntermediateToInputStreamsMap(streamGraph);
+    Multimap<SystemStream, SystemStream> outputToInput =
+        OperatorImplGraph.getIntermediateToInputStreamsMap(graphSpec.getOperatorSpecGraph());
     Collection<SystemStream> inputs = outputToInput.get(int1.toSystemStream());
     assertEquals(inputs.size(), 2);
     assertTrue(inputs.contains(input1.toSystemStream()));
index a91c1af..873cd3c 100644 (file)
@@ -48,7 +48,7 @@ public class TestStreamOperatorImpl {
     Config mockConfig = mock(Config.class);
     TaskContext mockContext = mock(TaskContext.class);
     StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl =
-        new StreamOperatorImpl<>(mockOp, mockConfig, mockContext);
+        new StreamOperatorImpl<>(mockOp);
     TestMessageEnvelope inMsg = mock(TestMessageEnvelope.class);
     Collection<TestOutputMessageEnvelope> mockOutputs = mock(Collection.class);
     when(txfmFn.apply(inMsg)).thenReturn(mockOutputs);
@@ -69,7 +69,7 @@ public class TestStreamOperatorImpl {
     TaskContext mockContext = mock(TaskContext.class);
 
     StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl =
-        new StreamOperatorImpl<>(mockOp, mockConfig, mockContext);
+        new StreamOperatorImpl<>(mockOp);
 
     // ensure that close is not called yet
     verify(txfmFn, times(0)).close();
index 7d0c623..9741fc4 100644 (file)
@@ -22,19 +22,20 @@ package org.apache.samza.operators.impl;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
-import junit.framework.Assert;
 import org.apache.samza.Partition;
-import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.container.TaskContextImpl;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
-import org.apache.samza.operators.StreamGraph;
+import org.apache.samza.operators.StreamGraphSpec;
+import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.impl.store.TestInMemoryStore;
 import org.apache.samza.operators.impl.store.TimeSeriesKeySerde;
+import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.operators.triggers.FiringType;
 import org.apache.samza.operators.triggers.Trigger;
 import org.apache.samza.operators.triggers.Triggers;
@@ -54,19 +55,25 @@ import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.StreamOperatorTask;
 import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.testUtils.TestClock;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.io.IOException;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
-import java.util.function.Function;
+import java.util.Map;
+import java.util.Collections;
 
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 public class TestWindowOperator {
   private final TaskCoordinator taskCoordinator = mock(TaskCoordinator.class);
@@ -83,26 +90,32 @@ public class TestWindowOperator {
     taskContext = mock(TaskContextImpl.class);
     runner = mock(ApplicationRunner.class);
     Serde storeKeySerde = new TimeSeriesKeySerde(new IntegerSerde());
-    Serde storeValSerde = new IntegerEnvelopeSerde();
+    Serde storeValSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde());
 
     when(taskContext.getSystemStreamPartitions()).thenReturn(ImmutableSet
         .of(new SystemStreamPartition("kafka", "integers", new Partition(0))));
     when(taskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-
     when(taskContext.getStore("jobName-jobId-window-w1"))
         .thenReturn(new TestInMemoryStore<>(storeKeySerde, storeValSerde));
     when(runner.getStreamSpec("integers")).thenReturn(new StreamSpec("integers", "integers", "kafka"));
+
+    Map<String, String> mapConfig = new HashMap<>();
+    mapConfig.put("app.runner.class", "org.apache.samza.runtime.LocalApplicationRunner");
+    mapConfig.put("job.default.system", "kafka");
+    mapConfig.put("job.name", "jobName");
+    mapConfig.put("job.id", "jobId");
+    config = new MapConfig(mapConfig);
   }
 
   @Test
   public void testTumblingWindowsDiscardingMode() throws Exception {
 
-    StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.DISCARDING,
-        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2)));
+    OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.DISCARDING,
+        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
@@ -130,12 +143,12 @@ public class TestWindowOperator {
   @Test
   public void testNonKeyedTumblingWindowsDiscardingMode() throws Exception {
 
-    StreamApplication sgb = new TumblingWindowStreamApplication(AccumulationMode.DISCARDING,
-        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(1000)));
+    OperatorSpecGraph sgb = this.getTumblingWindowStreamGraph(AccumulationMode.DISCARDING,
+        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(1000))).getOperatorSpecGraph();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
 
     MessageCollector messageCollector =
@@ -159,12 +172,12 @@ public class TestWindowOperator {
     when(taskContext.getStore("jobName-jobId-window-w1"))
         .thenReturn(new TestInMemoryStore<>(new TimeSeriesKeySerde(new IntegerSerde()), new IntegerSerde()));
 
-    StreamApplication sgb = new AggregateTumblingWindowStreamApplication(AccumulationMode.DISCARDING,
-        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2)));
+    OperatorSpecGraph sgb = this.getAggregateTumblingWindowStreamGraph(AccumulationMode.DISCARDING,
+        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph();
     List<WindowPane<Integer, Integer>> windowPanes = new ArrayList<>();
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
     MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Integer>) envelope.getMessage());
     integers.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator));
@@ -181,11 +194,11 @@ public class TestWindowOperator {
 
   @Test
   public void testTumblingWindowsAccumulatingMode() throws Exception {
-    StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.ACCUMULATING,
-        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2)));
+    OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING,
+        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
 
     MessageCollector messageCollector =
@@ -210,10 +223,11 @@ public class TestWindowOperator {
 
   @Test
   public void testSessionWindowsDiscardingMode() throws Exception {
-    StreamApplication sgb = new KeyedSessionWindowStreamApplication(AccumulationMode.DISCARDING, Duration.ofMillis(500));
+    OperatorSpecGraph sgb =
+        this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
@@ -255,10 +269,10 @@ public class TestWindowOperator {
 
   @Test
   public void testSessionWindowsAccumulatingMode() throws Exception {
-    StreamApplication sgb = new KeyedSessionWindowStreamApplication(AccumulationMode.DISCARDING,
-        Duration.ofMillis(500));
+    OperatorSpecGraph sgb = this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING,
+        Duration.ofMillis(500)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     MessageCollector messageCollector =
@@ -287,10 +301,10 @@ public class TestWindowOperator {
 
   @Test
   public void testCancellationOfOnceTrigger() throws Exception {
-    StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.ACCUMULATING,
-        Duration.ofSeconds(1), Triggers.count(2));
+    OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING,
+        Duration.ofSeconds(1), Triggers.count(2)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
 
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
@@ -331,10 +345,10 @@ public class TestWindowOperator {
 
   @Test
   public void testCancellationOfAnyTrigger() throws Exception {
-    StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1),
-        Triggers.any(Triggers.count(2), Triggers.timeSinceFirstMessage(Duration.ofMillis(500))));
+    OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1),
+        Triggers.any(Triggers.count(2), Triggers.timeSinceFirstMessage(Duration.ofMillis(500)))).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
 
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
@@ -389,15 +403,15 @@ public class TestWindowOperator {
   @Test
   public void testCancelationOfRepeatingNestedTriggers() throws Exception {
 
-    StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1),
-        Triggers.repeat(Triggers.any(Triggers.count(2), Triggers.timeSinceFirstMessage(Duration.ofMillis(500)))));
+    OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1),
+        Triggers.repeat(Triggers.any(Triggers.count(2), Triggers.timeSinceFirstMessage(Duration.ofMillis(500))))).getOperatorSpecGraph();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
 
     task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator);
@@ -434,12 +448,12 @@ public class TestWindowOperator {
     when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates);
     when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
 
-    StreamApplication sgb = new TumblingWindowStreamApplication(AccumulationMode.DISCARDING,
-        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2)));
+    OperatorSpecGraph sgb = this.getTumblingWindowStreamGraph(AccumulationMode.DISCARDING,
+        Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
 
     MessageCollector messageCollector =
@@ -475,10 +489,11 @@ public class TestWindowOperator {
     when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates);
     when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
 
-    StreamApplication sgb = new KeyedSessionWindowStreamApplication(AccumulationMode.DISCARDING, Duration.ofMillis(500));
+    OperatorSpecGraph sgb =
+        this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
 
     MessageCollector messageCollector =
@@ -511,10 +526,11 @@ public class TestWindowOperator {
     when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates);
     when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
 
-    StreamApplication sgb = new KeyedSessionWindowStreamApplication(AccumulationMode.DISCARDING, Duration.ofMillis(500));
+    OperatorSpecGraph sgb =
+        this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
     task.init(config, taskContext);
 
     MessageCollector messageCollector =
@@ -534,144 +550,83 @@ public class TestWindowOperator {
     verify(taskCoordinator, times(1)).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
   }
 
-  private class KeyedTumblingWindowStreamApplication implements StreamApplication {
-
-    private final AccumulationMode mode;
-    private final Duration duration;
-    private final Trigger<IntegerEnvelope> earlyTrigger;
-    private final SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
+  private StreamGraphSpec getKeyedTumblingWindowStreamGraph(AccumulationMode mode,
+      Duration duration, Trigger<KV<Integer, Integer>> earlyTrigger) throws IOException {
+    StreamGraphSpec graph = new StreamGraphSpec(runner, config);
 
-    KeyedTumblingWindowStreamApplication(AccumulationMode mode,
-        Duration timeDuration, Trigger<IntegerEnvelope> earlyTrigger) {
-      this.mode = mode;
-      this.duration = timeDuration;
-      this.earlyTrigger = earlyTrigger;
-    }
+    KVSerde<Integer, Integer> kvSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde());
+    graph.getInputStream("integers", kvSerde)
+        .window(Windows.keyedTumblingWindow(KV::getKey, duration, new IntegerSerde(), kvSerde)
+            .setEarlyTrigger(earlyTrigger).setAccumulationMode(mode), "w1")
+        .sink((message, messageCollector, taskCoordinator) -> {
+            SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
+            messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
+          });
 
-    @Override
-    public void init(StreamGraph graph, Config config) {
-      MessageStream<IntegerEnvelope> inStream =
-          graph.getInputStream("integers", KVSerde.of(new IntegerSerde(), new IntegerSerde()))
-              .map(kv -> new IntegerEnvelope(kv.getKey()));
-      Function<IntegerEnvelope, Integer> keyFn = m -> (Integer) m.getKey();
-      inStream
-          .map(m -> m)
-          .window(Windows.keyedTumblingWindow(keyFn, duration, new IntegerSerde(), new IntegerEnvelopeSerde())
-              .setEarlyTrigger(earlyTrigger)
-              .setAccumulationMode(mode), "w1")
-          .sink((message, messageCollector, taskCoordinator) -> {
-              messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
-            });
-    }
+    return graph;
   }
 
-  private class TumblingWindowStreamApplication implements StreamApplication {
-
-    private final AccumulationMode mode;
-    private final Duration duration;
-    private final Trigger<IntegerEnvelope> earlyTrigger;
-    private final SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
+  private StreamGraphSpec getTumblingWindowStreamGraph(AccumulationMode mode,
+      Duration duration, Trigger<KV<Integer, Integer>> earlyTrigger) throws IOException {
+    StreamGraphSpec graph = new StreamGraphSpec(runner, config);
 
-    TumblingWindowStreamApplication(AccumulationMode mode,
-                                         Duration timeDuration, Trigger<IntegerEnvelope> earlyTrigger) {
-      this.mode = mode;
-      this.duration = timeDuration;
-      this.earlyTrigger = earlyTrigger;
-    }
-
-    @Override
-    public void init(StreamGraph graph, Config config) {
-      MessageStream<IntegerEnvelope> inStream =
-          graph.getInputStream("integers", KVSerde.of(new IntegerSerde(), new IntegerSerde()))
-              .map(kv -> new IntegerEnvelope(kv.getKey()));
-      Function<IntegerEnvelope, Integer> keyFn = m -> (Integer) m.getKey();
-      inStream
-          .map(m -> m)
-          .window(Windows.tumblingWindow(duration, new IntegerEnvelopeSerde())
-              .setEarlyTrigger(earlyTrigger)
-              .setAccumulationMode(mode), "w1")
-          .sink((message, messageCollector, taskCoordinator) -> {
-              messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
-            });
-    }
+    KVSerde<Integer, Integer> kvSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde());
+    graph.getInputStream("integers", kvSerde)
+        .window(Windows.tumblingWindow(duration, kvSerde).setEarlyTrigger(earlyTrigger)
+            .setAccumulationMode(mode), "w1")
+        .sink((message, messageCollector, taskCoordinator) -> {
+            SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
+            messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
+          });
+    return graph;
   }
 
-  private class AggregateTumblingWindowStreamApplication implements StreamApplication {
+  private StreamGraphSpec getKeyedSessionWindowStreamGraph(AccumulationMode mode, Duration duration) throws IOException {
+    StreamGraphSpec graph = new StreamGraphSpec(runner, config);
 
-    private final AccumulationMode mode;
-    private final Duration duration;
-    private final Trigger<IntegerEnvelope> earlyTrigger;
-    private final SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
-
-    AggregateTumblingWindowStreamApplication(AccumulationMode mode, Duration timeDuration,
-        Trigger<IntegerEnvelope> earlyTrigger) {
-      this.mode = mode;
-      this.duration = timeDuration;
-      this.earlyTrigger = earlyTrigger;
-    }
-
-    @Override
-    public void init(StreamGraph graph, Config config) {
-      MessageStream<KV<Integer, Integer>> integers = graph.getInputStream("integers",
-          KVSerde.of(new IntegerSerde(), new IntegerSerde()));
-
-      integers
-        .map(kv -> new IntegerEnvelope(kv.getKey()))
-        .window(Windows.<IntegerEnvelope, Integer>tumblingWindow(this.duration, () -> 0, (m, c) -> c + 1, new IntegerSerde())
-            .setEarlyTrigger(earlyTrigger)
+    KVSerde<Integer, Integer> kvSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde());
+    graph.getInputStream("integers", kvSerde)
+        .window(Windows.keyedSessionWindow(KV::getKey, duration, new IntegerSerde(), kvSerde)
             .setAccumulationMode(mode), "w1")
         .sink((message, messageCollector, taskCoordinator) -> {
+            SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
             messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message));
           });
-    }
+    return graph;
   }
 
-  private class KeyedSessionWindowStreamApplication implements StreamApplication {
+  private StreamGraphSpec getAggregateTumblingWindowStreamGraph(AccumulationMode mode, Duration timeDuration,
+        Trigger<IntegerEnvelope> earlyTrigger) throws IOException {
+    StreamGraphSpec graph = new StreamGraphSpec(runner, config);
 
-    private final AccumulationMode mode;
-    private final Duration duration;
-    private final SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream");
+    MessageStream<KV<Integer, Integer>> integers = graph.getInputStream("integers",
+        KVSerde.of(new IntegerSerde(), new IntegerSerde()));
 
-    KeyedSessionWindowStreamApplication(AccumulationMode mode, Duration duration) {
-      this.mode = mode;
-      this.duration = duration;
-    }
-
-    @Override
-    public void init(StreamGraph graph, Config config) {
-      MessageStream<IntegerEnvelope> inStream =
-          graph.getInputStream("integers", KVSerde.of(new IntegerSerde(), new IntegerSerde()))
-              .map(kv -> new IntegerEnvelope(kv.getKey()));
-      Function<IntegerEnvelope, Integer> keyFn = m -> (Integer) m.getKey();
-
-      inStream
-          .map(m -> m)
-          .window(Windows.keyedSessionWindow(keyFn, duration, new IntegerSerde(), new IntegerEnvelopeSerde())
-              .setAccumulationMode(mode), "w1")
-          .sink((message, messageCollector, taskCoordinator) -> {
-           &n