SAMZA-1814: consolidate JobNode and JobGraph configuration generation for high and...
authorYi Pan (Data Infrastructure) <nickpan47@gmail.com>
Wed, 26 Sep 2018 18:00:51 +0000 (11:00 -0700)
committerYi Pan (Data Infrastructure) <nickpan47@gmail.com>
Wed, 26 Sep 2018 18:00:51 +0000 (11:00 -0700)
High-level changes:
- Move configuration generation to JobNodeConfigurationGenerator
- Move the intermediate partition calculation to IntermediationStreamPartitionPlanner
- Consolidate the code in JobPlanner and ExecutionPlanner for high and low-level API plan/configuration generation

Author: Yi Pan (Data Infrastructure) <nickpan47@gmail.com>
Author: Yi Pan (Data Infrastructure) <yipan@yipan-mn1.linkedin.biz>
Author: Yi Pan (Data Infrastructure) <yipan@yipan-ld2.linkedin.biz>
Author: Prateek Maheshwari <pmaheshwari@linkedin.com>
Author: Prateek Maheshwari <prateekm@utexas.edu>
Author: prateekm <prateekm@utexas.edu>

Reviewers: Prateek Maheshwari <pmaheshwari@apache.org>, Cameron Lee <calee@linkedin.com>

Closes #642 from nickpan47/SAMZA-1814 and squashes the following commits:

214373966 [Yi Pan (Data Infrastructure)] Merge branch 'master' into SAMZA-1814. With minor fixes to allow merge correctly.
f8c8108ac [Yi Pan (Data Infrastructure)] SAMZA-1814: Fix merging errors.
c8681a028 [Yi Pan (Data Infrastructure)] Merge branch 'master' into SAMZA-1814
b66b9fa9d [Yi Pan (Data Infrastructure)] SAMZA-1814: moving serde generation to a single top-level configuration generation, not embedded in table. Address review comments
0db5068dd [Yi Pan (Data Infrastructure)] SAMZA-1814: fix merge issue and consolidated some test classes
2c856c5f5 [Yi Pan (Data Infrastructure)] SAMZA-1814: consolidate configuration generation for high and low-level APIs
ffc6f1a70 [Yi Pan (Data Infrastructure)] SAMZA-1814: consolidate configuration generation in ExecutionPlanner between high and low-level API applications
c7fde4a03 [Yi Pan (Data Infrastructure)] Merge branch 'master' into SAMZA-1814
44844635b [Yi Pan (Data Infrastructure)] SAMZA-1814: merge with master
8797cdd4f [Yi Pan (Data Infrastructure)] SAMZA-1814: merge with master
dae98cebe [Yi Pan (Data Infrastructure)] SAMZA-1814: consolidate the configure generation between high and low-level API applications
3a91b9a62 [Yi Pan (Data Infrastructure)] SAMZA-1814: moving some logic to ApplicationDescriptorImpl to simplify the JobGraph/JobNode code
97c00a2e0 [Yi Pan (Data Infrastructure)] SAMZA-1814: WIP unit tests fixed for configure generation.
05637e6e6 [Yi Pan (Data Infrastructure)] SAMZA-1814: WIP consolidate all JobGraph and JobNode Json and Config generation code to support both high- and low-level applications
9d564642a [Yi Pan (Data Infrastructure)] Merge branch 'master' into SAMZA-1814
16bef1b1b [Yi Pan (Data Infrastructure)] SAMZA-1814: WIP fixing the task application configuration generation in the planner
66af5b706 [Yi Pan (Data Infrastructure)] SAMZA-1789: addressing Cameron's review comments.
ec4bb1dca [Yi Pan (Data Infrastructure)] SAMZA-1789: merge with fix for SAMZA-1836
9c89c63dc [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
91fcd73ae [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
34ffda8ae [Yi Pan (Data Infrastructure)] SAMZA-1789: disabling tests due to SAMZA-1836
02076c850 [Yi Pan (Data Infrastructure)] SAMZA-1789: fixed the modifier for the mandatory constructor of ApplicationRunner; Disabled three tests due to wrong configure for test systems
222abf21f [Yi Pan (Data Infrastructure)] SAMZA-1789: added a constructor to StreamProcessor to take a StreamProcessorListenerFactory
7a73992a5 [Yi Pan (Data Infrastructure)] SAMZA-1789: fixing checkstyle and javadoc errors
9997b98bb [Yi Pan (Data Infrastructure)] SAMZA-1789: renamed all ApplicationDescriptor classes with full-spelling of Application
f4b3d43a4 [Yi Pan (Data Infrastructure)] SAMZA-1789: Fxing TaskApplication examples and some checkstyle errors
f2969f8df [Yi Pan (Data Infrastructure)] SAMZA-1789: fixed ApplicationDescriptor to use InputDescriptor and OutputDescriptor; addressed Prateek's comments.
f04404cc2 [Yi Pan (Data Infrastructure)] SAMZA-1789: move createStreams out of the loop in prepareJobs
33753f72d [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
12c09af06 [Yi Pan (Data Infrastructure)] SAMZA-1789: Fix a merging error (with SAMZA-1813)
a072118d0 [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
e7af6932d [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
8d4d3ffda [Yi Pan (Data Infrastructure)] Merge with master
055bd91e4 [Yi Pan (Data Infrastructure)] SAMZA-1789: fix unit test with ThreadJobFactory
247dcff4c [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
1621c4d00 [Yi Pan (Data Infrastructure)] SAMZA-1789: a few more fixes to address Cameron's reviews
6e446fe6d [Yi Pan (Data Infrastructure)] SAMZA-1789: address Cameron's review comments.
4382d45db [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
3b2f04d54 [Yi Pan (Data Infrastructure)] SAMZA-1789: moved all impl classes from samza-api to samza-core.
db96da830 [Yi Pan (Data Infrastructure)] SAMZA-1789: WIP - revision to address review feedbacks.
014337170 [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
a82708bb0 [Yi Pan (Data Infrastructure)] SAMZA-1789: unify ApplicationDescriptor and ApplicationRunner for high- and low-level APIs in YARN and standalone environment
c4bb0dce6 [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-runtime-with-processor-callbacks
f20cdcda6 [Yi Pan (Data Infrastructure)] WIP: adding unit tests. Pending update on StreamProcessorLifecycleListener, LocalContainerRunner, and SamzaContainerListener
973eb5261 [Yi Pan (Data Infrastructure)] WIP: compiles, still working on LocalContainerRunner refactor
fb1bc49e0 [Yi Pan (Data Infrastructure)] Merge branch 'master' into app-spec-with-app-runtime-Jul-16-18
30a4e5f0a [Yi Pan (Data Infrastructure)] WIP: application runner refactor - proto-type for SEP-13
95577b74c [Yi Pan (Data Infrastructure)] WIP: trying to figure out the two interface classes for spec: a) spec builder in init(); b) spec reader in all other lifecycle methods
42782d815 [Yi Pan (Data Infrastructure)] Merge branch 'prateek-remove-app-runner-stream-spec' into app-spec-with-app-runtime-Jul-16-18
d43e92319 [Yi Pan (Data Infrastructure)] WIP: proto-type with ApplicationRunnable and no ApplicationRunner exposed to user
f1cb8f0eb [Yi Pan (Data Infrastructure)] Merge branch 'master' into single-app-api-May-21-18
7e71dc7e0 [Yi Pan (Data Infrastructure)] Merge with master
856193013 [Prateek Maheshwari] Merge branch 'master' into stream-spec-cleanup
7d7aa5088 [Prateek Maheshwari] Updated with Cameron and Daniel's feedback.
8e6fc2dac [prateekm] Remove all usages of StreamSpec and ApplicationRunner from the operator spec and impl layers.

47 files changed:
build.gradle
samza-core/src/main/java/org/apache/samza/application/ApplicationDescriptorImpl.java
samza-core/src/main/java/org/apache/samza/application/StreamApplicationDescriptorImpl.java
samza-core/src/main/java/org/apache/samza/application/TaskApplicationDescriptorImpl.java
samza-core/src/main/java/org/apache/samza/execution/ExecutionPlanner.java
samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/execution/JobGraph.java
samza-core/src/main/java/org/apache/samza/execution/JobGraphJsonGenerator.java
samza-core/src/main/java/org/apache/samza/execution/JobNode.java
samza-core/src/main/java/org/apache/samza/execution/JobNodeConfigurationGenerator.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/execution/JobPlanner.java
samza-core/src/main/java/org/apache/samza/execution/LocalJobPlanner.java
samza-core/src/main/java/org/apache/samza/execution/OperatorSpecGraphAnalyzer.java
samza-core/src/main/java/org/apache/samza/execution/RemoteJobPlanner.java
samza-core/src/main/java/org/apache/samza/operators/BaseTableDescriptor.java
samza-core/src/main/java/org/apache/samza/operators/OperatorSpecGraph.java
samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java
samza-core/src/main/java/org/apache/samza/table/TableConfigGenerator.java
samza-core/src/main/java/org/apache/samza/zk/ZkJobCoordinatorFactory.java
samza-core/src/main/scala/org/apache/samza/config/JobConfig.scala
samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
samza-core/src/main/scala/org/apache/samza/metrics/reporter/MetricsSnapshotReporterFactory.scala
samza-core/src/main/scala/org/apache/samza/util/CoordinatorStreamUtil.scala
samza-core/src/test/java/org/apache/samza/application/TestStreamApplicationDescriptorImpl.java
samza-core/src/test/java/org/apache/samza/application/TestTaskApplicationDescriptorImpl.java
samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java
samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/execution/TestJobGraph.java
samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java
samza-core/src/test/java/org/apache/samza/execution/TestJobNode.java [deleted file]
samza-core/src/test/java/org/apache/samza/execution/TestJobNodeConfigurationGenerator.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/execution/TestRemoteJobPlanner.java
samza-core/src/test/java/org/apache/samza/operators/TestOperatorSpecGraph.java
samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java
samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java
samza-core/src/test/java/org/apache/samza/runtime/TestRemoteApplicationRunner.java
samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemFactory.scala
samza-kafka/src/main/scala/org/apache/samza/checkpoint/kafka/KafkaCheckpointManagerFactory.scala
samza-kafka/src/main/scala/org/apache/samza/config/KafkaConsumerConfig.java
samza-kafka/src/main/scala/org/apache/samza/util/KafkaUtil.scala
samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java
samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
samza-test/src/main/java/org/apache/samza/test/framework/system/InMemorySystemDescriptor.java
samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java
samza-yarn/src/main/java/org/apache/samza/validation/YarnJobValidationTool.java
samza-yarn/src/main/scala/org/apache/samza/job/yarn/YarnJob.scala

index 48a28f2..a45a875 100644 (file)
@@ -194,6 +194,7 @@ project(":samza-core_$scalaVersion") {
     testCompile "org.powermock:powermock-core:$powerMockVersion"
     testCompile "org.powermock:powermock-module-junit4:$powerMockVersion"
     testCompile "org.scalatest:scalatest_$scalaVersion:$scalaTestVersion"
+    testCompile "org.hamcrest:hamcrest-all:$hamcrestVersion"
   }
 
   checkstyle {
index 9679136..b58d5a5 100644 (file)
@@ -19,6 +19,7 @@
 package org.apache.samza.application;
 
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Optional;
@@ -26,13 +27,20 @@ import java.util.Set;
 import org.apache.samza.config.Config;
 import org.apache.samza.metrics.MetricsReporterFactory;
 import org.apache.samza.operators.ContextManager;
+import org.apache.samza.operators.KV;
 import org.apache.samza.operators.TableDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.InputDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor;
 import org.apache.samza.operators.descriptors.base.system.SystemDescriptor;
+import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.runtime.ProcessorLifecycleListener;
 import org.apache.samza.runtime.ProcessorLifecycleListenerFactory;
+import org.apache.samza.serializers.KVSerde;
+import org.apache.samza.serializers.NoOpSerde;
+import org.apache.samza.serializers.Serde;
 import org.apache.samza.task.TaskContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
@@ -46,10 +54,15 @@ import org.apache.samza.task.TaskContext;
  */
 public abstract class ApplicationDescriptorImpl<S extends ApplicationDescriptor>
     implements ApplicationDescriptor<S> {
+  private static final Logger LOGGER = LoggerFactory.getLogger(ApplicationDescriptorImpl.class);
 
-  final Config config;
   private final Class<? extends SamzaApplication> appClass;
   private final Map<String, MetricsReporterFactory> reporterFactories = new LinkedHashMap<>();
+  // serdes used by input/output/intermediate streams, keyed by streamId
+  private final Map<String, KV<Serde, Serde>> streamSerdes = new HashMap<>();
+  // serdes used by tables, keyed by tableId
+  private final Map<String, KV<Serde, Serde>> tableSerdes = new HashMap<>();
+  final Config config;
 
   // Default to no-op functions in ContextManager
   // TODO: this should be replaced by shared context factory defined in SAMZA-1714
@@ -142,6 +155,35 @@ public abstract class ApplicationDescriptorImpl<S extends ApplicationDescriptor>
   }
 
   /**
+   * Get the corresponding {@link KVSerde} for the input {@code inputStreamId}
+   *
+   * @param streamId id of the stream
+   * @return the {@link KVSerde} for the stream. null if the serde is not defined or {@code streamId} does not exist
+   */
+  public KV<Serde, Serde> getStreamSerdes(String streamId) {
+    return streamSerdes.get(streamId);
+  }
+
+  /**
+   * Get the corresponding {@link KVSerde} for the input {@code inputStreamId}
+   *
+   * @param tableId id of the table
+   * @return the {@link KVSerde} for the stream. null if the serde is not defined or {@code streamId} does not exist
+   */
+  public KV<Serde, Serde> getTableSerdes(String tableId) {
+    return tableSerdes.get(tableId);
+  }
+
+  /**
+   * Get the map of all {@link InputOperatorSpec}s in this applicaiton
+   *
+   * @return an immutable map from streamId to {@link InputOperatorSpec}. Default to empty map for low-level {@link TaskApplication}
+   */
+  public Map<String, InputOperatorSpec> getInputOperators() {
+    return Collections.EMPTY_MAP;
+  }
+
+  /**
    * Get all the {@link InputDescriptor}s to this application
    *
    * @return an immutable map of streamId to {@link InputDescriptor}
@@ -176,4 +218,66 @@ public abstract class ApplicationDescriptorImpl<S extends ApplicationDescriptor>
    */
   public abstract Set<SystemDescriptor> getSystemDescriptors();
 
+  /**
+   * Get all the unique input streamIds in this application
+   *
+   * @return an immutable set of input streamIds
+   */
+  public abstract Set<String> getInputStreamIds();
+
+  /**
+   * Get all the unique output streamIds in this application
+   *
+   * @return an immutable set of output streamIds
+   */
+  public abstract Set<String> getOutputStreamIds();
+
+  KV<Serde, Serde> getOrCreateStreamSerdes(String streamId, Serde serde) {
+    Serde keySerde, valueSerde;
+
+    KV<Serde, Serde> currentSerdePair = streamSerdes.get(streamId);
+
+    if (serde instanceof KVSerde) {
+      keySerde = ((KVSerde) serde).getKeySerde();
+      valueSerde = ((KVSerde) serde).getValueSerde();
+    } else {
+      keySerde = new NoOpSerde();
+      valueSerde = serde;
+    }
+
+    if (currentSerdePair == null) {
+      if (keySerde instanceof NoOpSerde) {
+        LOGGER.info("Using NoOpSerde as the key serde for stream " + streamId +
+            ". Keys will not be (de)serialized");
+      }
+      if (valueSerde instanceof NoOpSerde) {
+        LOGGER.info("Using NoOpSerde as the value serde for stream " + streamId +
+            ". Values will not be (de)serialized");
+      }
+      streamSerdes.put(streamId, KV.of(keySerde, valueSerde));
+    } else if (!currentSerdePair.getKey().equals(keySerde) || !currentSerdePair.getValue().equals(valueSerde)) {
+      throw new IllegalArgumentException(String.format("Serde for stream %s is already defined. Cannot change it to "
+          + "different serdes.", streamId));
+    }
+    return streamSerdes.get(streamId);
+  }
+
+  KV<Serde, Serde> getOrCreateTableSerdes(String tableId, KVSerde kvSerde) {
+    Serde keySerde, valueSerde;
+    keySerde = kvSerde.getKeySerde();
+    valueSerde = kvSerde.getValueSerde();
+
+    if (!tableSerdes.containsKey(tableId)) {
+      tableSerdes.put(tableId, KV.of(keySerde, valueSerde));
+      return tableSerdes.get(tableId);
+    }
+
+    KV<Serde, Serde> currentSerdePair = tableSerdes.get(tableId);
+    if (!currentSerdePair.getKey().equals(keySerde) || !currentSerdePair.getValue().equals(valueSerde)) {
+      throw new IllegalArgumentException(String.format("Serde for table %s is already defined. Cannot change it to "
+          + "different serdes.", tableId));
+    }
+    return streamSerdes.get(tableId);
+  }
+
 }
\ No newline at end of file
index d50b0d0..5129913 100644 (file)
@@ -51,7 +51,6 @@ import org.apache.samza.operators.spec.OperatorSpecs;
 import org.apache.samza.operators.spec.OutputStreamImpl;
 import org.apache.samza.operators.stream.IntermediateMessageStreamImpl;
 import org.apache.samza.serializers.KVSerde;
-import org.apache.samza.serializers.NoOpSerde;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.TableSpec;
@@ -78,7 +77,7 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl<S
   // We use a LHM for deterministic order in initializing and closing operators.
   private final Map<String, InputOperatorSpec> inputOperators = new LinkedHashMap<>();
   private final Map<String, OutputStreamImpl> outputStreams = new LinkedHashMap<>();
-  private final Map<TableSpec, TableImpl> tables = new LinkedHashMap<>();
+  private final Map<String, TableImpl> tables = new LinkedHashMap<>();
   private final Set<String> operatorIds = new HashSet<>();
 
   private Optional<SystemDescriptor> defaultSystemDescriptorOptional = Optional.empty();
@@ -125,7 +124,7 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl<S
         "getInputStream must not be called multiple times with the same streamId: " + streamId);
 
     Serde serde = inputDescriptor.getSerde();
-    KV<Serde, Serde> kvSerdes = getKVSerdes(streamId, serde);
+    KV<Serde, Serde> kvSerdes = getOrCreateStreamSerdes(streamId, serde);
     if (outputStreams.containsKey(streamId)) {
       OutputStreamImpl outputStream = outputStreams.get(streamId);
       Serde keySerde = outputStream.getKeySerde();
@@ -156,7 +155,7 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl<S
         "getOutputStream must not be called multiple times with the same streamId: " + streamId);
 
     Serde serde = outputDescriptor.getSerde();
-    KV<Serde, Serde> kvSerdes = getKVSerdes(streamId, serde);
+    KV<Serde, Serde> kvSerdes = getOrCreateStreamSerdes(streamId, serde);
     if (inputOperators.containsKey(streamId)) {
       InputOperatorSpec inputOperatorSpec = inputOperators.get(streamId);
       Serde keySerde = inputOperatorSpec.getKeySerde();
@@ -186,13 +185,15 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl<S
         String.format("add table descriptors multiple times with the same tableId: %s", tableDescriptor.getTableId()));
     tableDescriptors.put(tableDescriptor.getTableId(), tableDescriptor);
 
-    TableSpec tableSpec = ((BaseTableDescriptor) tableDescriptor).getTableSpec();
-    if (tables.containsKey(tableSpec)) {
+    BaseTableDescriptor baseTableDescriptor = (BaseTableDescriptor) tableDescriptor;
+    TableSpec tableSpec = baseTableDescriptor.getTableSpec();
+    if (tables.containsKey(tableSpec.getId())) {
       throw new IllegalStateException(
           String.format("getTable() invoked multiple times with the same tableId: %s", tableId));
     }
-    tables.put(tableSpec, new TableImpl(tableSpec));
-    return tables.get(tableSpec);
+    tables.put(tableSpec.getId(), new TableImpl(tableSpec));
+    getOrCreateTableSerdes(tableSpec.getId(), baseTableDescriptor.getSerde());
+    return tables.get(tableSpec.getId());
   }
 
   /**
@@ -247,6 +248,16 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl<S
     return Collections.unmodifiableSet(new HashSet<>(systemDescriptors.values()));
   }
 
+  @Override
+  public Set<String> getInputStreamIds() {
+    return Collections.unmodifiableSet(new HashSet<>(inputOperators.keySet()));
+  }
+
+  @Override
+  public Set<String> getOutputStreamIds() {
+    return Collections.unmodifiableSet(new HashSet<>(outputStreams.keySet()));
+  }
+
   /**
    * Get the default {@link SystemDescriptor} in this application
    *
@@ -306,7 +317,7 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl<S
     return Collections.unmodifiableMap(outputStreams);
   }
 
-  public Map<TableSpec, TableImpl> getTables() {
+  public Map<String, TableImpl> getTables() {
     return Collections.unmodifiableMap(tables);
   }
 
@@ -342,7 +353,7 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl<S
       kvSerdes = new KV<>(null, null); // and that key and msg serdes are provided for job.default.system in configs
     } else {
       isKeyed = serde instanceof KVSerde;
-      kvSerdes = getKVSerdes(streamId, serde);
+      kvSerdes = getOrCreateStreamSerdes(streamId, serde);
     }
 
     InputTransformer transformer = (InputTransformer) getDefaultSystemDescriptor()
@@ -356,29 +367,6 @@ public class StreamApplicationDescriptorImpl extends ApplicationDescriptorImpl<S
     return new IntermediateMessageStreamImpl<>(this, inputOperators.get(streamId), outputStreams.get(streamId));
   }
 
-  private KV<Serde, Serde> getKVSerdes(String streamId, Serde serde) {
-    Serde keySerde, valueSerde;
-
-    if (serde instanceof KVSerde) {
-      keySerde = ((KVSerde) serde).getKeySerde();
-      valueSerde = ((KVSerde) serde).getValueSerde();
-    } else {
-      keySerde = new NoOpSerde();
-      valueSerde = serde;
-    }
-
-    if (keySerde instanceof NoOpSerde) {
-      LOGGER.info("Using NoOpSerde as the key serde for stream " + streamId +
-          ". Keys will not be (de)serialized");
-    }
-    if (valueSerde instanceof NoOpSerde) {
-      LOGGER.info("Using NoOpSerde as the value serde for stream " + streamId +
-          ". Values will not be (de)serialized");
-    }
-
-    return KV.of(keySerde, valueSerde);
-  }
-
   // check uniqueness of the {@code systemDescriptor} and add if it is unique
   private void addSystemDescriptor(SystemDescriptor systemDescriptor) {
     Preconditions.checkState(!systemDescriptors.containsKey(systemDescriptor.getSystemName())
index 3597d7c..d140a90 100644 (file)
@@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Set;
 import org.apache.samza.config.Config;
+import org.apache.samza.operators.BaseTableDescriptor;
 import org.apache.samza.operators.TableDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.InputDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor;
@@ -65,6 +66,7 @@ public class TaskApplicationDescriptorImpl extends ApplicationDescriptorImpl<Tas
     // TODO: SAMZA-1841: need to add to the broadcast streams if inputDescriptor is for a broadcast stream
     Preconditions.checkState(!inputDescriptors.containsKey(inputDescriptor.getStreamId()),
         String.format("add input descriptors multiple times with the same streamId: %s", inputDescriptor.getStreamId()));
+    getOrCreateStreamSerdes(inputDescriptor.getStreamId(), inputDescriptor.getSerde());
     inputDescriptors.put(inputDescriptor.getStreamId(), inputDescriptor);
     addSystemDescriptor(inputDescriptor.getSystemDescriptor());
   }
@@ -73,6 +75,7 @@ public class TaskApplicationDescriptorImpl extends ApplicationDescriptorImpl<Tas
   public void addOutputStream(OutputDescriptor outputDescriptor) {
     Preconditions.checkState(!outputDescriptors.containsKey(outputDescriptor.getStreamId()),
         String.format("add output descriptors multiple times with the same streamId: %s", outputDescriptor.getStreamId()));
+    getOrCreateStreamSerdes(outputDescriptor.getStreamId(), outputDescriptor.getSerde());
     outputDescriptors.put(outputDescriptor.getStreamId(), outputDescriptor);
     addSystemDescriptor(outputDescriptor.getSystemDescriptor());
   }
@@ -81,6 +84,7 @@ public class TaskApplicationDescriptorImpl extends ApplicationDescriptorImpl<Tas
   public void addTable(TableDescriptor tableDescriptor) {
     Preconditions.checkState(!tableDescriptors.containsKey(tableDescriptor.getTableId()),
         String.format("add table descriptors multiple times with the same tableId: %s", tableDescriptor.getTableId()));
+    getOrCreateTableSerdes(tableDescriptor.getTableId(), ((BaseTableDescriptor) tableDescriptor).getSerde());
     tableDescriptors.put(tableDescriptor.getTableId(), tableDescriptor);
   }
 
@@ -111,6 +115,16 @@ public class TaskApplicationDescriptorImpl extends ApplicationDescriptorImpl<Tas
     return Collections.unmodifiableSet(new HashSet<>(systemDescriptors.values()));
   }
 
+  @Override
+  public Set<String> getInputStreamIds() {
+    return Collections.unmodifiableSet(new HashSet<>(inputDescriptors.keySet()));
+  }
+
+  @Override
+  public Set<String> getOutputStreamIds() {
+    return Collections.unmodifiableSet(new HashSet<>(outputDescriptors.keySet()));
+  }
+
   /**
    * Get the user-defined {@link TaskFactory}
    * @return the {@link TaskFactory} object
index 46aef8d..eea6387 100644 (file)
@@ -22,72 +22,57 @@ package org.apache.samza.execution;
 import com.google.common.collect.HashMultimap;
 import com.google.common.collect.Multimap;
 import com.google.common.collect.Sets;
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
-import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.samza.SamzaException;
+import org.apache.samza.application.ApplicationDescriptor;
+import org.apache.samza.application.ApplicationDescriptorImpl;
+import org.apache.samza.application.LegacyTaskApplication;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.ClusterManagerConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.StreamConfig;
-import org.apache.samza.operators.OperatorSpecGraph;
-import org.apache.samza.operators.spec.InputOperatorSpec;
-import org.apache.samza.operators.spec.JoinOperatorSpec;
+import org.apache.samza.operators.BaseTableDescriptor;
 import org.apache.samza.system.StreamSpec;
 import org.apache.samza.table.TableSpec;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import static org.apache.samza.execution.ExecutionPlanner.StreamEdgeSet.StreamEdgeSetCategory;
 import static org.apache.samza.util.StreamUtil.*;
 
 
 /**
- * The ExecutionPlanner creates the physical execution graph for the {@link OperatorSpecGraph}, and
+ * The ExecutionPlanner creates the physical execution graph for the {@link ApplicationDescriptorImpl}, and
  * the intermediate topics needed for the execution.
  */
 // TODO: ExecutionPlanner needs to be able to generate single node JobGraph for low-level TaskApplication as well (SAMZA-1811)
 public class ExecutionPlanner {
   private static final Logger log = LoggerFactory.getLogger(ExecutionPlanner.class);
 
-  /* package private */ static final int MAX_INFERRED_PARTITIONS = 256;
-
   private final Config config;
-  private final StreamConfig streamConfig;
   private final StreamManager streamManager;
 
   public ExecutionPlanner(Config config, StreamManager streamManager) {
     this.config = config;
     this.streamManager = streamManager;
-    this.streamConfig = new StreamConfig(config);
   }
 
-  public ExecutionPlan plan(OperatorSpecGraph opSpecGraph) {
+  public ExecutionPlan plan(ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) {
     validateConfig();
 
-    // Create physical job graph based on stream graph
-    JobGraph jobGraph = createJobGraph(opSpecGraph);
-
-    // Fetch the external streams partition info
-    fetchInputAndOutputStreamPartitions(jobGraph);
+    // create physical job graph based on stream graph
+    JobGraph jobGraph = createJobGraph(config, appDesc);
 
-    // Verify agreement in partition count between all joined input/intermediate streams
-    validateJoinInputStreamPartitions(jobGraph);
+    // fetch the external streams partition info
+    setInputAndOutputStreamPartitionCount(jobGraph, streamManager);
 
-    if (!jobGraph.getIntermediateStreamEdges().isEmpty()) {
-      // Set partition count of intermediate streams not participating in joins
-      setIntermediateStreamPartitions(jobGraph);
-
-      // Validate partition counts were assigned for all intermediate streams
-      validateIntermediateStreamPartitions(jobGraph);
-    }
+    // figure out the partitions for internal streams
+    new IntermediateStreamManager(config, appDesc).calculatePartitions(jobGraph);
 
     return jobGraph;
   }
@@ -103,21 +88,23 @@ public class ExecutionPlanner {
   }
 
   /**
-   * Creates the physical graph from {@link OperatorSpecGraph}
+   * Create the physical graph from {@link ApplicationDescriptorImpl}
    */
-  /* package private */ JobGraph createJobGraph(OperatorSpecGraph opSpecGraph) {
-    JobGraph jobGraph = new JobGraph(config, opSpecGraph);
-
+  /* package private */
+  JobGraph createJobGraph(Config config, ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) {
+    JobGraph jobGraph = new JobGraph(config, appDesc);
+    StreamConfig streamConfig = new StreamConfig(config);
     // Source streams contain both input and intermediate streams.
-    Set<StreamSpec> sourceStreams = getStreamSpecs(opSpecGraph.getInputOperators().keySet(), streamConfig);
+    Set<StreamSpec> sourceStreams = getStreamSpecs(appDesc.getInputStreamIds(), streamConfig);
     // Sink streams contain both output and intermediate streams.
-    Set<StreamSpec> sinkStreams = getStreamSpecs(opSpecGraph.getOutputStreams().keySet(), streamConfig);
+    Set<StreamSpec> sinkStreams = getStreamSpecs(appDesc.getOutputStreamIds(), streamConfig);
 
     Set<StreamSpec> intermediateStreams = Sets.intersection(sourceStreams, sinkStreams);
     Set<StreamSpec> inputStreams = Sets.difference(sourceStreams, intermediateStreams);
     Set<StreamSpec> outputStreams = Sets.difference(sinkStreams, intermediateStreams);
 
-    Set<TableSpec> tables = opSpecGraph.getTables().keySet();
+    Set<TableSpec> tables = appDesc.getTableDescriptors().stream()
+        .map(tableDescriptor -> ((BaseTableDescriptor) tableDescriptor).getTableSpec()).collect(Collectors.toSet());
 
     // For this phase, we have a single job node for the whole dag
     String jobName = config.get(JobConfig.JOB_NAME());
@@ -136,15 +123,20 @@ public class ExecutionPlanner {
     // Add tables
     tables.forEach(spec -> jobGraph.addTable(spec, node));
 
-    jobGraph.validate();
+    if (!LegacyTaskApplication.class.isAssignableFrom(appDesc.getAppClass())) {
+      // skip the validation when input streamIds are empty. This is only possible for LegacyTaskApplication
+      jobGraph.validate();
+    }
 
     return jobGraph;
   }
 
   /**
-   * Fetches the partitions of input/output streams and update the corresponding StreamEdges.
+   * Fetch the partitions of source/sink streams and update the StreamEdges.
+   * @param jobGraph {@link JobGraph}
+   * @param streamManager the {@link StreamManager} to interface with the streams.
    */
-  /* package private */ void fetchInputAndOutputStreamPartitions(JobGraph jobGraph) {
+  /* package private */ static void setInputAndOutputStreamPartitionCount(JobGraph jobGraph, StreamManager streamManager) {
     Set<StreamEdge> existingStreams = new HashSet<>();
     existingStreams.addAll(jobGraph.getInputStreams());
     existingStreams.addAll(jobGraph.getOutputStreams());
@@ -182,224 +174,4 @@ public class ExecutionPlanner {
     }
   }
 
-  /**
-   * Validates agreement in partition count between input/intermediate streams participating in join operations.
-   */
-  private void validateJoinInputStreamPartitions(JobGraph jobGraph) {
-    // Group input operator specs (input/intermediate streams) by the joins they participate in.
-    Multimap<JoinOperatorSpec, InputOperatorSpec> joinOpSpecToInputOpSpecs =
-        OperatorSpecGraphAnalyzer.getJoinToInputOperatorSpecs(jobGraph.getSpecGraph());
-
-    // Convert every group of input operator specs into a group of corresponding stream edges.
-    List<StreamEdgeSet> streamEdgeSets = new ArrayList<>();
-    for (JoinOperatorSpec joinOpSpec : joinOpSpecToInputOpSpecs.keySet()) {
-      Collection<InputOperatorSpec> joinedInputOpSpecs = joinOpSpecToInputOpSpecs.get(joinOpSpec);
-      StreamEdgeSet streamEdgeSet = getStreamEdgeSet(joinOpSpec.getOpId(), joinedInputOpSpecs, jobGraph);
-      streamEdgeSets.add(streamEdgeSet);
-    }
-
-    /*
-     * Sort the stream edge groups by their category so they appear in this order:
-     *   1. groups composed exclusively of stream edges with set partition counts
-     *   2. groups composed of a mix of stream edges  with set/unset partition counts
-     *   3. groups composed exclusively of stream edges with unset partition counts
-     *
-     *   This guarantees that we process the most constrained stream edge groups first,
-     *   which is crucial for intermediate stream edges that are members of multiple
-     *   stream edge groups. For instance, if we have the following groups of stream
-     *   edges (partition counts in parentheses, question marks for intermediate streams):
-     *
-     *      a. e1 (16), e2 (16)
-     *      b. e2 (16), e3 (?)
-     *      c. e3 (?), e4 (?)
-     *
-     *   processing them in the above order (most constrained first) is guaranteed to
-     *   yield correct assignment of partition counts of e3 and e4 in a single scan.
-     */
-    Collections.sort(streamEdgeSets, Comparator.comparingInt(e -> e.getCategory().getSortOrder()));
-
-    // Verify agreement between joined input/intermediate streams.
-    // This may involve setting partition counts of intermediate stream edges.
-    streamEdgeSets.forEach(ExecutionPlanner::validateAndAssignStreamEdgeSetPartitions);
-  }
-
-  /**
-   * Creates a {@link StreamEdgeSet} whose Id is {@code setId}, and {@link StreamEdge}s
-   * correspond to the provided {@code inputOpSpecs}.
-   */
-  private StreamEdgeSet getStreamEdgeSet(String setId, Iterable<InputOperatorSpec> inputOpSpecs,
-      JobGraph jobGraph) {
-
-    int countStreamEdgeWithSetPartitions = 0;
-    Set<StreamEdge> streamEdges = new HashSet<>();
-
-    for (InputOperatorSpec inputOpSpec : inputOpSpecs) {
-      StreamEdge streamEdge = jobGraph.getOrCreateStreamEdge(getStreamSpec(inputOpSpec.getStreamId(), streamConfig));
-      if (streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) {
-        ++countStreamEdgeWithSetPartitions;
-      }
-      streamEdges.add(streamEdge);
-    }
-
-    // Determine category of stream group based on stream partition counts.
-    StreamEdgeSetCategory category;
-    if (countStreamEdgeWithSetPartitions == 0) {
-      category = StreamEdgeSetCategory.NO_PARTITION_COUNT_SET;
-    } else if (countStreamEdgeWithSetPartitions == streamEdges.size()) {
-      category = StreamEdgeSetCategory.ALL_PARTITION_COUNT_SET;
-    } else {
-      category = StreamEdgeSetCategory.SOME_PARTITION_COUNT_SET;
-    }
-
-    return new StreamEdgeSet(setId, streamEdges, category);
-  }
-
-  /**
-   * Sets partition count of intermediate streams which have not been assigned partition counts.
-   */
-  private void setIntermediateStreamPartitions(JobGraph jobGraph) {
-    final String defaultPartitionsConfigProperty = JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS();
-    int partitions = config.getInt(defaultPartitionsConfigProperty, StreamEdge.PARTITIONS_UNKNOWN);
-    if (partitions == StreamEdge.PARTITIONS_UNKNOWN) {
-      // use the following simple algo to figure out the partitions
-      // partition = MAX(MAX(Input topic partitions), MAX(Output topic partitions))
-      // partition will be further bounded by MAX_INFERRED_PARTITIONS.
-      // This is important when running in hadoop where an HDFS input can have lots of files (partitions).
-      int maxInPartitions = maxPartitions(jobGraph.getInputStreams());
-      int maxOutPartitions = maxPartitions(jobGraph.getOutputStreams());
-      partitions = Math.max(maxInPartitions, maxOutPartitions);
-
-      if (partitions > MAX_INFERRED_PARTITIONS) {
-        partitions = MAX_INFERRED_PARTITIONS;
-        log.warn(String.format("Inferred intermediate stream partition count %d is greater than the max %d. Using the max.",
-            partitions, MAX_INFERRED_PARTITIONS));
-      }
-    } else {
-      // Reject any zero or other negative values explicitly specified in config.
-      if (partitions <= 0) {
-        throw new SamzaException(String.format("Invalid value %d specified for config property %s", partitions,
-            defaultPartitionsConfigProperty));
-      }
-
-      log.info("Using partition count value {} specified for config property {}", partitions,
-          defaultPartitionsConfigProperty);
-    }
-
-    for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
-      if (edge.getPartitionCount() <= 0) {
-        log.info("Set the partition count for intermediate stream {} to {}.", edge.getName(), partitions);
-        edge.setPartitionCount(partitions);
-      }
-    }
-  }
-
-  /**
-   * Ensures all intermediate streams have been assigned partition counts.
-   */
-  private static void validateIntermediateStreamPartitions(JobGraph jobGraph) {
-    for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
-      if (edge.getPartitionCount() <= 0) {
-        throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName()));
-      }
-    }
-  }
-
-  /**
-   * Ensures that all streams in the supplied {@link StreamEdgeSet} agree in partition count.
-   * This may include setting partition counts of intermediate streams in this set that do not
-   * have their partition counts set.
-   */
-  private static void validateAndAssignStreamEdgeSetPartitions(StreamEdgeSet streamEdgeSet) {
-    Set<StreamEdge> streamEdges = streamEdgeSet.getStreamEdges();
-    StreamEdge firstStreamEdgeWithSetPartitions =
-        streamEdges.stream()
-            .filter(streamEdge -> streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN)
-            .findFirst()
-            .orElse(null);
-
-    // This group consists exclusively of intermediate streams with unknown partition counts.
-    // We cannot do any validation/computation of partition counts of such streams right here,
-    // but they are tackled later in the ExecutionPlanner.
-    if (firstStreamEdgeWithSetPartitions == null) {
-      return;
-    }
-
-    // Make sure all other stream edges in this group have the same partition count.
-    int partitions = firstStreamEdgeWithSetPartitions.getPartitionCount();
-    for (StreamEdge streamEdge : streamEdges) {
-      int streamPartitions = streamEdge.getPartitionCount();
-      if (streamPartitions == StreamEdge.PARTITIONS_UNKNOWN) {
-        streamEdge.setPartitionCount(partitions);
-        log.info("Inferred the partition count {} for the join operator {} from {}."
-            , new Object[] {partitions, streamEdgeSet.getSetId(), firstStreamEdgeWithSetPartitions.getName()});
-      } else if (streamPartitions != partitions) {
-        throw  new SamzaException(String.format(
-            "Unable to resolve input partitions of stream %s for the join %s. Expected: %d, Actual: %d",
-            streamEdge.getName(), streamEdgeSet.getSetId(), partitions, streamPartitions));
-      }
-    }
-  }
-
-  /* package private */ static int maxPartitions(Collection<StreamEdge> edges) {
-    return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(StreamEdge.PARTITIONS_UNKNOWN);
-  }
-
-  /**
-   * Represents a set of {@link StreamEdge}s.
-   */
-  /* package private */ static class StreamEdgeSet {
-
-    /**
-     * Indicates whether all stream edges in this group have their partition counts assigned.
-     */
-    public enum StreamEdgeSetCategory {
-      /**
-       * All stream edges in this group have their partition counts assigned.
-       */
-      ALL_PARTITION_COUNT_SET(0),
-
-      /**
-       * Only some stream edges in this group have their partition counts assigned.
-       */
-      SOME_PARTITION_COUNT_SET(1),
-
-      /**
-       * No stream edge in this group is assigned a partition count.
-       */
-      NO_PARTITION_COUNT_SET(2);
-
-
-      private final int sortOrder;
-
-      StreamEdgeSetCategory(int sortOrder) {
-        this.sortOrder = sortOrder;
-      }
-
-      public int getSortOrder() {
-        return sortOrder;
-      }
-    }
-
-    private final String setId;
-    private final Set<StreamEdge> streamEdges;
-    private final StreamEdgeSetCategory category;
-
-    public StreamEdgeSet(String setId, Set<StreamEdge> streamEdges, StreamEdgeSetCategory category) {
-      this.setId = setId;
-      this.streamEdges = streamEdges;
-      this.category = category;
-    }
-
-    public Set<StreamEdge> getStreamEdges() {
-      return streamEdges;
-    }
-
-    public String getSetId() {
-      return setId;
-    }
-
-    public StreamEdgeSetCategory getCategory() {
-      return category;
-    }
-  }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java b/samza-core/src/main/java/org/apache/samza/execution/IntermediateStreamManager.java
new file mode 100644 (file)
index 0000000..66cbe6a
--- /dev/null
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.execution;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Multimap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.SamzaException;
+import org.apache.samza.application.ApplicationDescriptor;
+import org.apache.samza.application.ApplicationDescriptorImpl;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.operators.spec.InputOperatorSpec;
+import org.apache.samza.operators.spec.JoinOperatorSpec;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * {@link IntermediateStreamManager} calculates intermediate stream partitions based on the high-level application graph.
+ */
+class IntermediateStreamManager {
+
+  private static final Logger log = LoggerFactory.getLogger(IntermediateStreamManager.class);
+
+  private final Config config;
+  private final Map<String, InputOperatorSpec> inputOperators;
+
+  @VisibleForTesting
+  static final int MAX_INFERRED_PARTITIONS = 256;
+
+  IntermediateStreamManager(Config config, ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) {
+    this.config = config;
+    this.inputOperators = appDesc.getInputOperators();
+  }
+
+  /**
+   * Figure out the number of partitions of all streams
+   */
+  /* package private */ void calculatePartitions(JobGraph jobGraph) {
+
+    // Verify agreement in partition count between all joined input/intermediate streams
+    validateJoinInputStreamPartitions(jobGraph);
+
+    if (!jobGraph.getIntermediateStreamEdges().isEmpty()) {
+      // Set partition count of intermediate streams not participating in joins
+      setIntermediateStreamPartitions(jobGraph);
+
+      // Validate partition counts were assigned for all intermediate streams
+      validateIntermediateStreamPartitions(jobGraph);
+    }
+  }
+
+  /**
+   * Validates agreement in partition count between input/intermediate streams participating in join operations.
+   */
+  private void validateJoinInputStreamPartitions(JobGraph jobGraph) {
+    // Group input operator specs (input/intermediate streams) by the joins they participate in.
+    Multimap<JoinOperatorSpec, InputOperatorSpec> joinOpSpecToInputOpSpecs =
+        OperatorSpecGraphAnalyzer.getJoinToInputOperatorSpecs(inputOperators.values());
+
+    // Convert every group of input operator specs into a group of corresponding stream edges.
+    List<StreamEdgeSet> streamEdgeSets = new ArrayList<>();
+    for (JoinOperatorSpec joinOpSpec : joinOpSpecToInputOpSpecs.keySet()) {
+      Collection<InputOperatorSpec> joinedInputOpSpecs = joinOpSpecToInputOpSpecs.get(joinOpSpec);
+      StreamEdgeSet streamEdgeSet = getStreamEdgeSet(joinOpSpec.getOpId(), joinedInputOpSpecs, jobGraph);
+      streamEdgeSets.add(streamEdgeSet);
+    }
+
+    /*
+     * Sort the stream edge groups by their category so they appear in this order:
+     *   1. groups composed exclusively of stream edges with set partition counts
+     *   2. groups composed of a mix of stream edges  with set/unset partition counts
+     *   3. groups composed exclusively of stream edges with unset partition counts
+     *
+     *   This guarantees that we process the most constrained stream edge groups first,
+     *   which is crucial for intermediate stream edges that are members of multiple
+     *   stream edge groups. For instance, if we have the following groups of stream
+     *   edges (partition counts in parentheses, question marks for intermediate streams):
+     *
+     *      a. e1 (16), e2 (16)
+     *      b. e2 (16), e3 (?)
+     *      c. e3 (?), e4 (?)
+     *
+     *   processing them in the above order (most constrained first) is guaranteed to
+     *   yield correct assignment of partition counts of e3 and e4 in a single scan.
+     */
+    Collections.sort(streamEdgeSets, Comparator.comparingInt(e -> e.getCategory().getSortOrder()));
+
+    // Verify agreement between joined input/intermediate streams.
+    // This may involve setting partition counts of intermediate stream edges.
+    streamEdgeSets.forEach(IntermediateStreamManager::validateAndAssignStreamEdgeSetPartitions);
+  }
+
+  /**
+   * Creates a {@link StreamEdgeSet} whose Id is {@code setId}, and {@link StreamEdge}s
+   * correspond to the provided {@code inputOpSpecs}.
+   */
+  private StreamEdgeSet getStreamEdgeSet(String setId, Iterable<InputOperatorSpec> inputOpSpecs,
+      JobGraph jobGraph) {
+
+    int countStreamEdgeWithSetPartitions = 0;
+    Set<StreamEdge> streamEdges = new HashSet<>();
+
+    for (InputOperatorSpec inputOpSpec : inputOpSpecs) {
+      StreamEdge streamEdge = jobGraph.getStreamEdge(inputOpSpec.getStreamId());
+      if (streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN) {
+        ++countStreamEdgeWithSetPartitions;
+      }
+      streamEdges.add(streamEdge);
+    }
+
+    // Determine category of stream group based on stream partition counts.
+    StreamEdgeSet.StreamEdgeSetCategory category;
+    if (countStreamEdgeWithSetPartitions == 0) {
+      category = StreamEdgeSet.StreamEdgeSetCategory.NO_PARTITION_COUNT_SET;
+    } else if (countStreamEdgeWithSetPartitions == streamEdges.size()) {
+      category = StreamEdgeSet.StreamEdgeSetCategory.ALL_PARTITION_COUNT_SET;
+    } else {
+      category = StreamEdgeSet.StreamEdgeSetCategory.SOME_PARTITION_COUNT_SET;
+    }
+
+    return new StreamEdgeSet(setId, streamEdges, category);
+  }
+
+  /**
+   * Sets partition count of intermediate streams which have not been assigned partition counts.
+   */
+  private void setIntermediateStreamPartitions(JobGraph jobGraph) {
+    final String defaultPartitionsConfigProperty = JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS();
+    int partitions = config.getInt(defaultPartitionsConfigProperty, StreamEdge.PARTITIONS_UNKNOWN);
+    if (partitions == StreamEdge.PARTITIONS_UNKNOWN) {
+      // use the following simple algo to figure out the partitions
+      // partition = MAX(MAX(Input topic partitions), MAX(Output topic partitions))
+      // partition will be further bounded by MAX_INFERRED_PARTITIONS.
+      // This is important when running in hadoop where an HDFS input can have lots of files (partitions).
+      int maxInPartitions = maxPartitions(jobGraph.getInputStreams());
+      int maxOutPartitions = maxPartitions(jobGraph.getOutputStreams());
+      partitions = Math.max(maxInPartitions, maxOutPartitions);
+
+      if (partitions > MAX_INFERRED_PARTITIONS) {
+        partitions = MAX_INFERRED_PARTITIONS;
+        log.warn(String.format("Inferred intermediate stream partition count %d is greater than the max %d. Using the max.",
+            partitions, MAX_INFERRED_PARTITIONS));
+      }
+    } else {
+      // Reject any zero or other negative values explicitly specified in config.
+      if (partitions <= 0) {
+        throw new SamzaException(String.format("Invalid value %d specified for config property %s", partitions,
+            defaultPartitionsConfigProperty));
+      }
+
+      log.info("Using partition count value {} specified for config property {}", partitions,
+          defaultPartitionsConfigProperty);
+    }
+
+    for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
+      if (edge.getPartitionCount() <= 0) {
+        log.info("Set the partition count for intermediate stream {} to {}.", edge.getName(), partitions);
+        edge.setPartitionCount(partitions);
+      }
+    }
+  }
+
+  /**
+   * Ensures all intermediate streams have been assigned partition counts.
+   */
+  private static void validateIntermediateStreamPartitions(JobGraph jobGraph) {
+    for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
+      if (edge.getPartitionCount() <= 0) {
+        throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName()));
+      }
+    }
+  }
+
+  /**
+   * Ensures that all streams in the supplied {@link StreamEdgeSet} agree in partition count.
+   * This may include setting partition counts of intermediate streams in this set that do not
+   * have their partition counts set.
+   */
+  private static void validateAndAssignStreamEdgeSetPartitions(StreamEdgeSet streamEdgeSet) {
+    Set<StreamEdge> streamEdges = streamEdgeSet.getStreamEdges();
+    StreamEdge firstStreamEdgeWithSetPartitions =
+        streamEdges.stream()
+            .filter(streamEdge -> streamEdge.getPartitionCount() != StreamEdge.PARTITIONS_UNKNOWN)
+            .findFirst()
+            .orElse(null);
+
+    // This group consists exclusively of intermediate streams with unknown partition counts.
+    // We cannot do any validation/computation of partition counts of such streams right here,
+    // but they are tackled later in the ExecutionPlanner.
+    if (firstStreamEdgeWithSetPartitions == null) {
+      return;
+    }
+
+    // Make sure all other stream edges in this group have the same partition count.
+    int partitions = firstStreamEdgeWithSetPartitions.getPartitionCount();
+    for (StreamEdge streamEdge : streamEdges) {
+      int streamPartitions = streamEdge.getPartitionCount();
+      if (streamPartitions == StreamEdge.PARTITIONS_UNKNOWN) {
+        streamEdge.setPartitionCount(partitions);
+        log.info("Inferred the partition count {} for the join operator {} from {}.",
+            new Object[] {partitions, streamEdgeSet.getSetId(), firstStreamEdgeWithSetPartitions.getName()});
+      } else if (streamPartitions != partitions) {
+        throw  new SamzaException(String.format(
+            "Unable to resolve input partitions of stream %s for the join %s. Expected: %d, Actual: %d",
+            streamEdge.getName(), streamEdgeSet.getSetId(), partitions, streamPartitions));
+      }
+    }
+  }
+
+  /* package private */ static int maxPartitions(Collection<StreamEdge> edges) {
+    return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(StreamEdge.PARTITIONS_UNKNOWN);
+  }
+
+  /**
+   * Represents a set of {@link StreamEdge}s.
+   */
+  /* package private */ static class StreamEdgeSet {
+
+    /**
+     * Indicates whether all stream edges in this group have their partition counts assigned.
+     */
+    public enum StreamEdgeSetCategory {
+      /**
+       * All stream edges in this group have their partition counts assigned.
+       */
+      ALL_PARTITION_COUNT_SET(0),
+
+      /**
+       * Only some stream edges in this group have their partition counts assigned.
+       */
+      SOME_PARTITION_COUNT_SET(1),
+
+      /**
+       * No stream edge in this group is assigned a partition count.
+       */
+      NO_PARTITION_COUNT_SET(2);
+
+
+      private final int sortOrder;
+
+      StreamEdgeSetCategory(int sortOrder) {
+        this.sortOrder = sortOrder;
+      }
+
+      public int getSortOrder() {
+        return sortOrder;
+      }
+    }
+
+    private final String setId;
+    private final Set<StreamEdge> streamEdges;
+    private final StreamEdgeSetCategory category;
+
+    StreamEdgeSet(String setId, Set<StreamEdge> streamEdges, StreamEdgeSetCategory category) {
+      this.setId = setId;
+      this.streamEdges = streamEdges;
+      this.category = category;
+    }
+
+    Set<StreamEdge> getStreamEdges() {
+      return streamEdges;
+    }
+
+    String getSetId() {
+      return setId;
+    }
+
+    StreamEdgeSetCategory getCategory() {
+      return category;
+    }
+  }
+}
index 5b19095..d975188 100644 (file)
@@ -31,10 +31,11 @@ import java.util.Queue;
 import java.util.Set;
 import java.util.stream.Collectors;
 
+import org.apache.samza.application.ApplicationDescriptor;
+import org.apache.samza.application.ApplicationDescriptorImpl;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
-import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.system.StreamSpec;
 import org.apache.samza.table.TableSpec;
 import org.slf4j.Logger;
@@ -59,16 +60,21 @@ import org.slf4j.LoggerFactory;
   private final Set<StreamEdge> intermediateStreams = new HashSet<>();
   private final Set<TableSpec> tables = new HashSet<>();
   private final Config config;
-  private final JobGraphJsonGenerator jsonGenerator = new JobGraphJsonGenerator();
-  private final OperatorSpecGraph specGraph;
+  private final JobGraphJsonGenerator jsonGenerator;
+  private final JobNodeConfigurationGenerator configGenerator;
+  private final ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc;
 
   /**
    * The JobGraph is only constructed by the {@link ExecutionPlanner}.
-   * @param config Config
+   *
+   * @param config configuration for the application
+   * @param appDesc {@link ApplicationDescriptorImpl} describing the application
    */
-  JobGraph(Config config, OperatorSpecGraph specGraph) {
+  JobGraph(Config config, ApplicationDescriptorImpl appDesc) {
     this.config = config;
-    this.specGraph = specGraph;
+    this.appDesc = appDesc;
+    this.jsonGenerator = new JobGraphJsonGenerator();
+    this.configGenerator = new JobNodeConfigurationGenerator();
   }
 
   @Override
@@ -91,11 +97,6 @@ import org.slf4j.LoggerFactory;
         .collect(Collectors.toList());
   }
 
-  void addTable(TableSpec tableSpec, JobNode node) {
-    tables.add(tableSpec);
-    node.addTable(tableSpec);
-  }
-
   @Override
   public String getPlanAsJson() throws Exception {
     return jsonGenerator.toJson(this);
@@ -105,14 +106,11 @@ import org.slf4j.LoggerFactory;
    * Returns the config for this application
    * @return {@link ApplicationConfig}
    */
+  @Override
   public ApplicationConfig getApplicationConfig() {
     return new ApplicationConfig(config);
   }
 
-  public OperatorSpecGraph getSpecGraph() {
-    return specGraph;
-  }
-
   /**
    * Add a source stream to a {@link JobNode}
    * @param streamSpec input stream
@@ -152,20 +150,20 @@ import org.slf4j.LoggerFactory;
     intermediateStreams.add(edge);
   }
 
+  void addTable(TableSpec tableSpec, JobNode node) {
+    tables.add(tableSpec);
+    node.addTable(tableSpec);
+  }
+
   /**
    * Get the {@link JobNode}. Create one if it does not exist.
    * @param jobName name of the job
    * @param jobId id of the job
-   * @return
+   * @return {@link JobNode} created with {@code jobName} and {@code jobId}
    */
   JobNode getOrCreateJobNode(String jobName, String jobId) {
-    String nodeId = JobNode.createId(jobName, jobId);
-    JobNode node = nodes.get(nodeId);
-    if (node == null) {
-      node = new JobNode(jobName, jobId, specGraph, config);
-      nodes.put(nodeId, node);
-    }
-    return node;
+    String nodeId = JobNode.createJobNameAndId(jobName, jobId);
+    return nodes.computeIfAbsent(nodeId, k -> new JobNode(jobName, jobId, config, appDesc, configGenerator));
   }
 
   /**
@@ -178,20 +176,13 @@ import org.slf4j.LoggerFactory;
   }
 
   /**
-   * Get the {@link StreamEdge} for a {@link StreamSpec}. Create one if it does not exist.
-   * @param streamSpec  spec of the StreamEdge
-   * @param isIntermediate  boolean flag indicating whether it's an intermediate stream
+   * Get the {@link StreamEdge} for {@code streamId}.
+   *
+   * @param streamId the streamId for the {@link StreamEdge}
    * @return stream edge
    */
-  StreamEdge getOrCreateStreamEdge(StreamSpec streamSpec, boolean isIntermediate) {
-    String streamId = streamSpec.getId();
-    StreamEdge edge = edges.get(streamId);
-    if (edge == null) {
-      boolean isBroadcast = specGraph.getBroadcastStreams().contains(streamId);
-      edge = new StreamEdge(streamSpec, isIntermediate, isBroadcast, config);
-      edges.put(streamId, edge);
-    }
-    return edge;
+  StreamEdge getStreamEdge(String streamId) {
+    return edges.get(streamId);
   }
 
   /**
@@ -248,6 +239,23 @@ import org.slf4j.LoggerFactory;
   }
 
   /**
+   * Get the {@link StreamEdge} for a {@link StreamSpec}. Create one if it does not exist.
+   * @param streamSpec  spec of the StreamEdge
+   * @param isIntermediate  boolean flag indicating whether it's an intermediate stream
+   * @return stream edge
+   */
+  private StreamEdge getOrCreateStreamEdge(StreamSpec streamSpec, boolean isIntermediate) {
+    String streamId = streamSpec.getId();
+    StreamEdge edge = edges.get(streamId);
+    if (edge == null) {
+      boolean isBroadcast = appDesc.getBroadcastStreams().contains(streamId);
+      edge = new StreamEdge(streamSpec, isIntermediate, isBroadcast, config);
+      edges.put(streamId, edge);
+    }
+    return edge;
+  }
+
+  /**
    * Validate the input streams should have indegree being 0 and outdegree greater than 0
    */
   private void validateInputStreams() {
@@ -305,7 +313,7 @@ import org.slf4j.LoggerFactory;
       Set<JobNode> unreachable = new HashSet<>(nodes.values());
       unreachable.removeAll(reachable);
       throw new IllegalArgumentException(String.format("Jobs %s cannot be reached from Sources.",
-          String.join(", ", unreachable.stream().map(JobNode::getId).collect(Collectors.toList()))));
+          String.join(", ", unreachable.stream().map(JobNode::getJobNameAndId).collect(Collectors.toList()))));
     }
   }
 
@@ -325,7 +333,7 @@ import org.slf4j.LoggerFactory;
 
     while (!queue.isEmpty()) {
       JobNode node = queue.poll();
-      node.getOutEdges().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(target -> {
+      node.getOutEdges().values().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(target -> {
           if (!visited.contains(target)) {
             visited.add(target);
             queue.offer(target);
@@ -351,9 +359,9 @@ import org.slf4j.LoggerFactory;
     Map<String, Long> indegree = new HashMap<>();
     Set<JobNode> visited = new HashSet<>();
     pnodes.forEach(node -> {
-        String nid = node.getId();
+        String nid = node.getJobNameAndId();
         //only count the degrees of intermediate streams
-        long degree = node.getInEdges().stream().filter(e -> !inputStreams.contains(e)).count();
+        long degree = node.getInEdges().values().stream().filter(e -> !inputStreams.contains(e)).count();
         indegree.put(nid, degree);
 
         if (degree == 0L) {
@@ -378,8 +386,8 @@ import org.slf4j.LoggerFactory;
       while (!q.isEmpty()) {
         JobNode node = q.poll();
         sortedNodes.add(node);
-        node.getOutEdges().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(n -> {
-            String nid = n.getId();
+        node.getOutEdges().values().stream().flatMap(edge -> edge.getTargetNodes().stream()).forEach(n -> {
+            String nid = n.getJobNameAndId();
             Long degree = indegree.get(nid) - 1;
             indegree.put(nid, degree);
             if (degree == 0L && !visited.contains(n)) {
@@ -400,7 +408,7 @@ import org.slf4j.LoggerFactory;
           long min = Long.MAX_VALUE;
           JobNode minNode = null;
           for (JobNode node : reachable) {
-            Long degree = indegree.get(node.getId());
+            Long degree = indegree.get(node.getJobNameAndId());
             if (degree < min) {
               min = degree;
               minNode = node;
index 91453d2..18705e4 100644 (file)
@@ -32,7 +32,6 @@ import java.util.stream.Collectors;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
-import org.apache.samza.operators.spec.OperatorSpec.OpCode;
 import org.apache.samza.operators.spec.OutputOperatorSpec;
 import org.apache.samza.operators.spec.OutputStreamImpl;
 import org.apache.samza.operators.spec.PartitionByOperatorSpec;
@@ -140,7 +139,7 @@ import org.codehaus.jackson.map.ObjectMapper;
     jobGraph.getTables().forEach(t -> buildTableJson(t, jobGraphJson.tables));
 
     jobGraphJson.jobs = jobGraph.getJobNodes().stream()
-        .map(jobNode -> buildJobNodeJson(jobNode))
+        .map(this::buildJobNodeJson)
         .collect(Collectors.toList());
 
     ByteArrayOutputStream out = new ByteArrayOutputStream();
@@ -149,54 +148,12 @@ import org.codehaus.jackson.map.ObjectMapper;
     return new String(out.toByteArray());
   }
 
-  /**
-   * Create JSON POJO for a {@link JobNode}, including the {@link org.apache.samza.operators.StreamGraph} for this job
-   * @param jobNode job node in the {@link JobGraph}
-   * @return {@link org.apache.samza.execution.JobGraphJsonGenerator.JobNodeJson}
-   */
-  private JobNodeJson buildJobNodeJson(JobNode jobNode) {
-    JobNodeJson job = new JobNodeJson();
-    job.jobName = jobNode.getJobName();
-    job.jobId = jobNode.getJobId();
-    job.operatorGraph = buildOperatorGraphJson(jobNode);
-    return job;
-  }
-
-  /**
-   * Traverse the {@link OperatorSpec} graph and build the operator graph JSON POJO.
-   * @param jobNode job node in the {@link JobGraph}
-   * @return {@link org.apache.samza.execution.JobGraphJsonGenerator.OperatorGraphJson}
-   */
-  private OperatorGraphJson buildOperatorGraphJson(JobNode jobNode) {
-    OperatorGraphJson opGraph = new OperatorGraphJson();
-    opGraph.inputStreams = new ArrayList<>();
-    jobNode.getSpecGraph().getInputOperators().forEach((streamId, operatorSpec) -> {
-        StreamJson inputJson = new StreamJson();
-        opGraph.inputStreams.add(inputJson);
-        inputJson.streamId = streamId;
-        inputJson.nextOperatorIds = operatorSpec.getRegisteredOperatorSpecs().stream()
-            .map(OperatorSpec::getOpId).collect(Collectors.toSet());
-
-        updateOperatorGraphJson(operatorSpec, opGraph);
-      });
-
-    opGraph.outputStreams = new ArrayList<>();
-    jobNode.getSpecGraph().getOutputStreams().keySet().forEach(streamId -> {
-        StreamJson outputJson = new StreamJson();
-        outputJson.streamId = streamId;
-        opGraph.outputStreams.add(outputJson);
-      });
-    return opGraph;
-  }
-
-  /**
-   * Traverse the {@link OperatorSpec} graph recursively and update the operator graph JSON POJO.
-   * @param operatorSpec input
-   * @param opGraph operator graph to build
-   */
   private void updateOperatorGraphJson(OperatorSpec operatorSpec, OperatorGraphJson opGraph) {
-    // TODO xiliu: render input operators instead of input streams
-    if (operatorSpec.getOpCode() != OpCode.INPUT) {
+    if (operatorSpec == null) {
+      // task application may not have any defined OperatorSpec
+      return;
+    }
+    if (operatorSpec.getOpCode() != OperatorSpec.OpCode.INPUT) {
       opGraph.operators.put(operatorSpec.getOpId(), operatorToMap(operatorSpec));
     }
     Collection<OperatorSpec> specs = operatorSpec.getRegisteredOperatorSpecs();
@@ -243,6 +200,46 @@ import org.codehaus.jackson.map.ObjectMapper;
   }
 
   /**
+   * Create JSON POJO for a {@link JobNode}, including the {@link org.apache.samza.application.ApplicationDescriptorImpl}
+   * for this job
+   *
+   * @param jobNode job node in the {@link JobGraph}
+   * @return {@link org.apache.samza.execution.JobGraphJsonGenerator.JobNodeJson}
+   */
+  private JobNodeJson buildJobNodeJson(JobNode jobNode) {
+    JobNodeJson job = new JobNodeJson();
+    job.jobName = jobNode.getJobName();
+    job.jobId = jobNode.getJobId();
+    job.operatorGraph = buildOperatorGraphJson(jobNode);
+    return job;
+  }
+
+  /**
+   * Traverse the {@link OperatorSpec} graph and build the operator graph JSON POJO.
+   * @param jobNode job node in the {@link JobGraph}
+   * @return {@link org.apache.samza.execution.JobGraphJsonGenerator.OperatorGraphJson}
+   */
+  private OperatorGraphJson buildOperatorGraphJson(JobNode jobNode) {
+    OperatorGraphJson opGraph = new OperatorGraphJson();
+    opGraph.inputStreams = new ArrayList<>();
+    jobNode.getInEdges().values().forEach(inStream -> {
+        StreamJson inputJson = new StreamJson();
+        opGraph.inputStreams.add(inputJson);
+        inputJson.streamId = inStream.getStreamSpec().getId();
+        inputJson.nextOperatorIds = jobNode.getNextOperatorIds(inputJson.streamId);
+        updateOperatorGraphJson(jobNode.getInputOperator(inputJson.streamId), opGraph);
+      });
+
+    opGraph.outputStreams = new ArrayList<>();
+    jobNode.getOutEdges().values().forEach(outStream -> {
+        StreamJson outputJson = new StreamJson();
+        outputJson.streamId = outStream.getStreamSpec().getId();
+        opGraph.outputStreams.add(outputJson);
+      });
+    return opGraph;
+  }
+
+  /**
    * Get or create the JSON POJO for a {@link StreamEdge}
    * @param edge {@link StreamEdge}
    * @param streamEdges map of streamId to {@link org.apache.samza.execution.JobGraphJsonGenerator.StreamEdgeJson}
@@ -261,15 +258,11 @@ import org.codehaus.jackson.map.ObjectMapper;
       edgeJson.streamSpec = streamSpecJson;
 
       List<String> sourceJobs = new ArrayList<>();
-      edge.getSourceNodes().forEach(jobNode -> {
-          sourceJobs.add(jobNode.getJobName());
-        });
+      edge.getSourceNodes().forEach(jobNode -> sourceJobs.add(jobNode.getJobName()));
       edgeJson.sourceJobs = sourceJobs;
 
       List<String> targetJobs = new ArrayList<>();
-      edge.getTargetNodes().forEach(jobNode -> {
-          targetJobs.add(jobNode.getJobName());
-        });
+      edge.getTargetNodes().forEach(jobNode -> targetJobs.add(jobNode.getJobName()));
       edgeJson.targetJobs = targetJobs;
 
       streamEdges.put(streamId, edgeJson);
@@ -285,12 +278,7 @@ import org.codehaus.jackson.map.ObjectMapper;
    */
   private TableSpecJson buildTableJson(TableSpec tableSpec, Map<String, TableSpecJson> tableSpecs) {
     String tableId = tableSpec.getId();
-    TableSpecJson tableSpecJson = tableSpecs.get(tableId);
-    if (tableSpecJson == null) {
-      tableSpecJson = buildTableJson(tableSpec);
-      tableSpecs.put(tableId, tableSpecJson);
-    }
-    return tableSpecJson;
+    return tableSpecs.computeIfAbsent(tableId, k -> buildTableJson(tableSpec));
   }
 
   /**
index 47705ee..af556f5 100644 (file)
 
 package org.apache.samza.execution;
 
-import java.util.ArrayList;
-import java.util.Base64;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
-import java.util.UUID;
+import java.util.Objects;
+import java.util.Set;
 import java.util.stream.Collectors;
-
-import org.apache.commons.lang3.StringUtils;
+import org.apache.samza.application.ApplicationDescriptor;
+import org.apache.samza.application.ApplicationDescriptorImpl;
+import org.apache.samza.application.LegacyTaskApplication;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
-import org.apache.samza.config.MapConfig;
-import org.apache.samza.config.SerializerConfig;
-import org.apache.samza.config.StorageConfig;
-import org.apache.samza.config.StreamConfig;
-import org.apache.samza.config.TaskConfig;
-import org.apache.samza.config.TaskConfigJava;
-import org.apache.samza.operators.OperatorSpecGraph;
+import org.apache.samza.operators.KV;
 import org.apache.samza.operators.spec.InputOperatorSpec;
-import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
-import org.apache.samza.operators.spec.OutputStreamImpl;
-import org.apache.samza.operators.spec.StatefulOperatorSpec;
-import org.apache.samza.operators.spec.WindowOperatorSpec;
-import org.apache.samza.table.TableConfigGenerator;
-import org.apache.samza.util.MathUtil;
 import org.apache.samza.serializers.Serde;
-import org.apache.samza.serializers.SerializableSerde;
 import org.apache.samza.table.TableSpec;
-import org.apache.samza.util.StreamUtil;
-import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.base.Joiner;
-
-
 /**
  * A JobNode is a physical execution unit. In RemoteExecutionEnvironment, it's a job that will be submitted
  * to remote cluster. In LocalExecutionEnvironment, it's a set of StreamProcessors for local execution.
@@ -65,64 +46,71 @@ import com.google.common.base.Joiner;
  */
 public class JobNode {
   private static final Logger log = LoggerFactory.getLogger(JobNode.class);
-  private static final String CONFIG_INTERNAL_EXECUTION_PLAN = "samza.internal.execution.plan";
 
   private final String jobName;
   private final String jobId;
-  private final String id;
-  private final OperatorSpecGraph specGraph;
-  private final List<StreamEdge> inEdges = new ArrayList<>();
-  private final List<StreamEdge> outEdges = new ArrayList<>();
-  private final List<TableSpec> tables = new ArrayList<>();
+  private final String jobNameAndId;
   private final Config config;
-
-  JobNode(String jobName, String jobId, OperatorSpecGraph specGraph, Config config) {
+  private final JobNodeConfigurationGenerator configGenerator;
+  // The following maps (i.e. inEdges and outEdges) uses the streamId as the key
+  private final Map<String, StreamEdge> inEdges = new HashMap<>();
+  private final Map<String, StreamEdge> outEdges = new HashMap<>();
+  // Similarly, tables uses tableId as the key
+  private final Map<String, TableSpec> tables = new HashMap<>();
+  private final ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc;
+
+  JobNode(String jobName, String jobId, Config config, ApplicationDescriptorImpl appDesc,
+      JobNodeConfigurationGenerator configureGenerator) {
     this.jobName = jobName;
     this.jobId = jobId;
-    this.id = createId(jobName, jobId);
-    this.specGraph = specGraph;
+    this.jobNameAndId = createJobNameAndId(jobName, jobId);
     this.config = config;
+    this.appDesc = appDesc;
+    this.configGenerator = configureGenerator;
   }
 
-  public static Config mergeJobConfig(Config fullConfig, Config generatedConfig) {
-    return new JobConfig(Util.rewriteConfig(extractScopedConfig(
-        fullConfig, generatedConfig, String.format(JobConfig.CONFIG_JOB_PREFIX(), new JobConfig(fullConfig).getName().get()))));
-  }
-
-  public OperatorSpecGraph getSpecGraph() {
-    return this.specGraph;
+  static String createJobNameAndId(String jobName, String jobId) {
+    return String.format("%s-%s", jobName, jobId);
   }
 
-  public  String getId() {
-    return id;
+  String getJobNameAndId() {
+    return jobNameAndId;
   }
 
-  public String getJobName() {
+  String getJobName() {
     return jobName;
   }
 
-  public String getJobId() {
+  String getJobId() {
     return jobId;
   }
 
+  Config getConfig() {
+    return config;
+  }
+
   void addInEdge(StreamEdge in) {
-    inEdges.add(in);
+    inEdges.put(in.getStreamSpec().getId(), in);
   }
 
   void addOutEdge(StreamEdge out) {
-    outEdges.add(out);
+    outEdges.put(out.getStreamSpec().getId(), out);
   }
 
-  List<StreamEdge> getInEdges() {
+  void addTable(TableSpec tableSpec) {
+    tables.put(tableSpec.getId(), tableSpec);
+  }
+
+  Map<String, StreamEdge> getInEdges() {
     return inEdges;
   }
 
-  List<StreamEdge> getOutEdges() {
+  Map<String, StreamEdge> getOutEdges() {
     return outEdges;
   }
 
-  void addTable(TableSpec tableSpec) {
-    tables.add(tableSpec);
+  Map<String, TableSpec> getTables() {
+    return tables;
   }
 
   /**
@@ -130,250 +118,65 @@ public class JobNode {
    * @param executionPlanJson JSON representation of the execution plan
    * @return config of the job
    */
-  public JobConfig generateConfig(String executionPlanJson) {
-    Map<String, String> configs = new HashMap<>();
-    configs.put(JobConfig.JOB_NAME(), jobName);
-    configs.put(JobConfig.JOB_ID(), jobId);
+  JobConfig generateConfig(String executionPlanJson) {
+    return configGenerator.generateJobConfig(this, executionPlanJson);
+  }
 
-    final List<String> inputs = new ArrayList<>();
-    final List<String> broadcasts = new ArrayList<>();
-    for (StreamEdge inEdge : inEdges) {
-      String formattedSystemStream = inEdge.getName();
-      if (inEdge.isBroadcast()) {
-        broadcasts.add(formattedSystemStream + "#0");
-      } else {
-        inputs.add(formattedSystemStream);
-      }
+  KV<Serde, Serde> getInputSerdes(String streamId) {
+    if (!inEdges.containsKey(streamId)) {
+      return null;
     }
+    return appDesc.getStreamSerdes(streamId);
+  }
 
-    if (!broadcasts.isEmpty()) {
-      // TODO: remove this once we support defining broadcast input stream in high-level
-      // task.broadcast.input should be generated by the planner in the future.
-      final String taskBroadcasts = config.get(TaskConfigJava.BROADCAST_INPUT_STREAMS);
-      if (StringUtils.isNoneEmpty(taskBroadcasts)) {
-        broadcasts.add(taskBroadcasts);
-      }
-      configs.put(TaskConfigJava.BROADCAST_INPUT_STREAMS, Joiner.on(',').join(broadcasts));
+  KV<Serde, Serde> getOutputSerde(String streamId) {
+    if (!outEdges.containsKey(streamId)) {
+      return null;
     }
+    return appDesc.getStreamSerdes(streamId);
+  }
 
-    // set triggering interval if a window or join is defined
-    if (specGraph.hasWindowOrJoins()) {
-      if ("-1".equals(config.get(TaskConfig.WINDOW_MS(), "-1"))) {
-        long triggerInterval = computeTriggerInterval();
-        log.info("Using triggering interval: {} for jobName: {}", triggerInterval, jobName);
+  Collection<OperatorSpec> getReachableOperators() {
+    Set<OperatorSpec> inputOperatorsInJobNode = inEdges.values().stream().map(inEdge ->
+        appDesc.getInputOperators().get(inEdge.getStreamSpec().getId())).filter(Objects::nonNull).collect(Collectors.toSet());
+    Set<OperatorSpec> reachableOperators = new HashSet<>();
+    findReachableOperators(inputOperatorsInJobNode, reachableOperators);
+    return reachableOperators;
+  }
 
-        configs.put(TaskConfig.WINDOW_MS(), String.valueOf(triggerInterval));
-      }
+  // get all next operators consuming from the input {@code streamId}
+  Set<String> getNextOperatorIds(String streamId) {
+    if (!appDesc.getInputOperators().containsKey(streamId) || !inEdges.containsKey(streamId)) {
+      return new HashSet<>();
     }
+    return appDesc.getInputOperators().get(streamId).getRegisteredOperatorSpecs().stream()
+        .map(op -> op.getOpId()).collect(Collectors.toSet());
+  }
 
-    specGraph.getAllOperatorSpecs().forEach(opSpec -> {
-        if (opSpec instanceof StatefulOperatorSpec) {
-          ((StatefulOperatorSpec) opSpec).getStoreDescriptors()
-              .forEach(sd -> configs.putAll(sd.getStorageConfigs()));
-          // store key and message serdes are configured separately in #addSerdeConfigs
-        }
-      });
-
-    configs.put(CONFIG_INTERNAL_EXECUTION_PLAN, executionPlanJson);
-
-    // write input/output streams to configs
-    inEdges.stream().filter(StreamEdge::isIntermediate).forEach(edge -> configs.putAll(edge.generateConfig()));
-
-    // write serialized serde instances and stream serde configs to configs
-    addSerdeConfigs(configs);
-
-    configs.putAll(TableConfigGenerator.generateConfigsForTableSpecs(new MapConfig(configs), tables));
-
-    // Add side inputs to the inputs and mark the stream as bootstrap
-    tables.forEach(tableSpec -> {
-        List<String> sideInputs = tableSpec.getSideInputs();
-        if (sideInputs != null && !sideInputs.isEmpty()) {
-          sideInputs.stream()
-              .map(sideInput -> StreamUtil.getSystemStreamFromNameOrId(config, sideInput))
-              .forEach(systemStream -> {
-                  inputs.add(StreamUtil.getNameFromSystemStream(systemStream));
-                  configs.put(String.format(StreamConfig.STREAM_PREFIX() + StreamConfig.BOOTSTRAP(),
-                      systemStream.getSystem(), systemStream.getStream()), "true");
-                });
-        }
-      });
-
-    configs.put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputs));
-
-    log.info("Job {} has generated configs {}", jobName, configs);
-
-    String configPrefix = String.format(JobConfig.CONFIG_JOB_PREFIX(), jobName);
-
-    // Disallow user specified job inputs/outputs. This info comes strictly from the user application.
-    Map<String, String> allowedConfigs = new HashMap<>(config);
-    if (allowedConfigs.containsKey(TaskConfig.INPUT_STREAMS())) {
-      log.warn("Specifying task inputs in configuration is not allowed with Fluent API. "
-          + "Ignoring configured value for " + TaskConfig.INPUT_STREAMS());
-      allowedConfigs.remove(TaskConfig.INPUT_STREAMS());
+  InputOperatorSpec getInputOperator(String inputStreamId) {
+    if (!inEdges.containsKey(inputStreamId)) {
+      return null;
     }
-
-    log.debug("Job {} has allowed configs {}", jobName, allowedConfigs);
-    return new JobConfig(
-        Util.rewriteConfig(
-            extractScopedConfig(new MapConfig(allowedConfigs), new MapConfig(configs), configPrefix)));
+    return appDesc.getInputOperators().get(inputStreamId);
   }
 
-  /**
-   * Serializes the {@link Serde} instances for operators, adds them to the provided config, and
-   * sets the serde configuration for the input/output/intermediate streams appropriately.
-   *
-   * We try to preserve the number of Serde instances before and after serialization. However we don't
-   * guarantee that references shared between these serdes instances (e.g. an Jackson ObjectMapper shared
-   * between two json serdes) are shared after deserialization too.
-   *
-   * Ideally all the user defined objects in the application should be serialized and de-serialized in one pass
-   * from the same output/input stream so that we can maintain reference sharing relationships.
-   *
-   * @param configs the configs to add serialized serde instances and stream serde configs to
-   */
-  void addSerdeConfigs(Map<String, String> configs) {
-    // collect all key and msg serde instances for streams
-    Map<String, Serde> streamKeySerdes = new HashMap<>();
-    Map<String, Serde> streamMsgSerdes = new HashMap<>();
-    Map<String, InputOperatorSpec> inputOperators = specGraph.getInputOperators();
-    inEdges.forEach(edge -> {
-        String streamId = edge.getStreamSpec().getId();
-        InputOperatorSpec inputOperatorSpec = inputOperators.get(streamId);
-        Serde keySerde = inputOperatorSpec.getKeySerde();
-        if (keySerde != null) {
-          streamKeySerdes.put(streamId, keySerde);
-        }
-        Serde valueSerde = inputOperatorSpec.getValueSerde();
-        if (valueSerde != null) {
-          streamMsgSerdes.put(streamId, valueSerde);
-        }
-      });
-    Map<String, OutputStreamImpl> outputStreams = specGraph.getOutputStreams();
-    outEdges.forEach(edge -> {
-        String streamId = edge.getStreamSpec().getId();
-        OutputStreamImpl outputStream = outputStreams.get(streamId);
-        Serde keySerde = outputStream.getKeySerde();
-        if (keySerde != null) {
-          streamKeySerdes.put(streamId, keySerde);
-        }
-        Serde valueSerde = outputStream.getValueSerde();
-        if (valueSerde != null) {
-          streamMsgSerdes.put(streamId, valueSerde);
-        }
-      });
-
-    // collect all key and msg serde instances for stores
-    Map<String, Serde> storeKeySerdes = new HashMap<>();
-    Map<String, Serde> storeMsgSerdes = new HashMap<>();
-    specGraph.getAllOperatorSpecs().forEach(opSpec -> {
-        if (opSpec instanceof StatefulOperatorSpec) {
-          ((StatefulOperatorSpec) opSpec).getStoreDescriptors().forEach(storeDescriptor -> {
-              storeKeySerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getKeySerde());
-              storeMsgSerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getMsgSerde());
-            });
-        }
-      });
-
-    // for each unique stream or store serde instance, generate a unique name and serialize to config
-    HashSet<Serde> serdes = new HashSet<>(streamKeySerdes.values());
-    serdes.addAll(streamMsgSerdes.values());
-    serdes.addAll(storeKeySerdes.values());
-    serdes.addAll(storeMsgSerdes.values());
-    SerializableSerde<Serde> serializableSerde = new SerializableSerde<>();
-    Base64.Encoder base64Encoder = Base64.getEncoder();
-    Map<Serde, String> serdeUUIDs = new HashMap<>();
-    serdes.forEach(serde -> {
-        String serdeName = serdeUUIDs.computeIfAbsent(serde,
-            s -> serde.getClass().getSimpleName() + "-" + UUID.randomUUID().toString());
-        configs.putIfAbsent(String.format(SerializerConfig.SERDE_SERIALIZED_INSTANCE(), serdeName),
-            base64Encoder.encodeToString(serializableSerde.toBytes(serde)));
-      });
-
-    // set key and msg serdes for streams to the serde names generated above
-    streamKeySerdes.forEach((streamId, serde) -> {
-        String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), streamId);
-        String keySerdeConfigKey = streamIdPrefix + StreamConfig.KEY_SERDE();
-        configs.put(keySerdeConfigKey, serdeUUIDs.get(serde));
-      });
-
-    streamMsgSerdes.forEach((streamId, serde) -> {
-        String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), streamId);
-        String valueSerdeConfigKey = streamIdPrefix + StreamConfig.MSG_SERDE();
-        configs.put(valueSerdeConfigKey, serdeUUIDs.get(serde));
-      });
-
-    // set key and msg serdes for stores to the serde names generated above
-    storeKeySerdes.forEach((storeName, serde) -> {
-        String keySerdeConfigKey = String.format(StorageConfig.KEY_SERDE(), storeName);
-        configs.put(keySerdeConfigKey, serdeUUIDs.get(serde));
-      });
-
-    storeMsgSerdes.forEach((storeName, serde) -> {
-        String msgSerdeConfigKey = String.format(StorageConfig.MSG_SERDE(), storeName);
-        configs.put(msgSerdeConfigKey, serdeUUIDs.get(serde));
-      });
+  boolean isLegacyTaskApplication() {
+    return LegacyTaskApplication.class.isAssignableFrom(appDesc.getAppClass());
   }
 
-  /**
-   * Computes the triggering interval to use during the execution of this {@link JobNode}
-   */
-  private long computeTriggerInterval() {
-    // Obtain the operator specs from the specGraph
-    Collection<OperatorSpec> operatorSpecs = specGraph.getAllOperatorSpecs();
-
-    // Filter out window operators, and obtain a list of their triggering interval values
-    List<Long> windowTimerIntervals = operatorSpecs.stream()
-        .filter(spec -> spec.getOpCode() == OperatorSpec.OpCode.WINDOW)
-        .map(spec -> ((WindowOperatorSpec) spec).getDefaultTriggerMs())
-        .collect(Collectors.toList());
-
-    // Filter out the join operators, and obtain a list of their ttl values
-    List<Long> joinTtlIntervals = operatorSpecs.stream()
-        .filter(spec -> spec instanceof JoinOperatorSpec)
-        .map(spec -> ((JoinOperatorSpec) spec).getTtlMs())
-        .collect(Collectors.toList());
-
-    // Combine both the above lists
-    List<Long> candidateTimerIntervals = new ArrayList<>(joinTtlIntervals);
-    candidateTimerIntervals.addAll(windowTimerIntervals);
-
-    if (candidateTimerIntervals.isEmpty()) {
-      return -1;
-    }
-
-    // Compute the gcd of the resultant list
-    return MathUtil.gcd(candidateTimerIntervals);
+  KV<Serde, Serde> getTableSerdes(String tableId) {
+    //TODO: SAMZA-1893: should test whether the table is used in the current JobNode
+    return appDesc.getTableSerdes(tableId);
   }
 
-  /**
-   * This function extract the subset of configs from the full config, and use it to override the generated configs
-   * from the job.
-   * @param fullConfig full config
-   * @param generatedConfig config generated for the job
-   * @param configPrefix prefix to extract the subset of the config overrides
-   * @return config that merges the generated configs and overrides
-   */
-  private static Config extractScopedConfig(Config fullConfig, Config generatedConfig, String configPrefix) {
-    Config scopedConfig = fullConfig.subset(configPrefix);
-
-    Config[] configPrecedence = new Config[] {fullConfig, generatedConfig, scopedConfig};
-    // Strip empty configs so they don't override the configs before them.
-    Map<String, String> mergedConfig = new HashMap<>();
-    for (Map<String, String> config : configPrecedence) {
-      for (Map.Entry<String, String> property : config.entrySet()) {
-        String value = property.getValue();
-        if (!(value == null || value.isEmpty())) {
-          mergedConfig.put(property.getKey(), property.getValue());
+  private void findReachableOperators(Collection<OperatorSpec> inputOperatorsInJobNode,
+      Set<OperatorSpec> reachableOperators) {
+    inputOperatorsInJobNode.forEach(op -> {
+        if (reachableOperators.contains(op)) {
+          return;
         }
-      }
-    }
-    scopedConfig = new MapConfig(mergedConfig);
-    log.debug("Prefix '{}' has merged config {}", configPrefix, scopedConfig);
-
-    return scopedConfig;
-  }
-
-  static String createId(String jobName, String jobId) {
-    return String.format("%s-%s", jobName, jobId);
+        reachableOperators.add(op);
+        findReachableOperators(op.getRegisteredOperatorSpecs(), reachableOperators);
+      });
   }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/execution/JobNodeConfigurationGenerator.java b/samza-core/src/main/java/org/apache/samza/execution/JobNodeConfigurationGenerator.java
new file mode 100644 (file)
index 0000000..676d28e
--- /dev/null
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.execution;
+
+import com.google.common.base.Joiner;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JavaTableConfig;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.SerializerConfig;
+import org.apache.samza.config.StorageConfig;
+import org.apache.samza.config.StreamConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.config.TaskConfigJava;
+import org.apache.samza.operators.KV;
+import org.apache.samza.operators.spec.JoinOperatorSpec;
+import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.operators.spec.StatefulOperatorSpec;
+import org.apache.samza.operators.spec.StoreDescriptor;
+import org.apache.samza.operators.spec.WindowOperatorSpec;
+import org.apache.samza.serializers.NoOpSerde;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.serializers.SerializableSerde;
+import org.apache.samza.table.TableConfigGenerator;
+import org.apache.samza.table.TableSpec;
+import org.apache.samza.util.MathUtil;
+import org.apache.samza.util.StreamUtil;
+import org.apache.samza.util.Util;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * This class provides methods to generate configuration for a {@link JobNode}
+ */
+/* package private */ class JobNodeConfigurationGenerator {
+
+  private static final Logger LOG = LoggerFactory.getLogger(JobNodeConfigurationGenerator.class);
+
+  static final String CONFIG_INTERNAL_EXECUTION_PLAN = "samza.internal.execution.plan";
+
+  static JobConfig mergeJobConfig(Config originalConfig, Config generatedConfig) {
+    JobConfig jobConfig = new JobConfig(originalConfig);
+    String jobNameAndId = JobNode.createJobNameAndId(jobConfig.getName().get(), jobConfig.getJobId());
+    return new JobConfig(Util.rewriteConfig(extractScopedConfig(originalConfig, generatedConfig,
+        String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), jobNameAndId))));
+  }
+
+  JobConfig generateJobConfig(JobNode jobNode, String executionPlanJson) {
+    Map<String, String> configs = new HashMap<>();
+    // set up job name and job ID
+    configs.put(JobConfig.JOB_NAME(), jobNode.getJobName());
+    configs.put(JobConfig.JOB_ID(), jobNode.getJobId());
+
+    Map<String, StreamEdge> inEdges = jobNode.getInEdges();
+    Map<String, StreamEdge> outEdges = jobNode.getOutEdges();
+    Collection<OperatorSpec> reachableOperators = jobNode.getReachableOperators();
+    List<StoreDescriptor> stores = getStoreDescriptors(reachableOperators);
+    Map<String, TableSpec> reachableTables = getReachableTables(reachableOperators, jobNode);
+    Config config = jobNode.getConfig();
+
+    // check all inputs to the node for broadcast and input streams
+    final Set<String> inputs = new HashSet<>();
+    final Set<String> broadcasts = new HashSet<>();
+    for (StreamEdge inEdge : inEdges.values()) {
+      String formattedSystemStream = inEdge.getName();
+      if (inEdge.isBroadcast()) {
+        broadcasts.add(formattedSystemStream + "#0");
+      } else {
+        inputs.add(formattedSystemStream);
+      }
+    }
+
+    configureBroadcastInputs(configs, config, broadcasts);
+
+    // compute window and join operator intervals in this node
+    configureWindowInterval(configs, config, reachableOperators);
+
+    // set store configuration for stateful operators.
+    stores.forEach(sd -> configs.putAll(sd.getStorageConfigs()));
+
+    // set the execution plan in json
+    configs.put(CONFIG_INTERNAL_EXECUTION_PLAN, executionPlanJson);
+
+    // write intermediate input/output streams to configs
+    inEdges.values().stream().filter(StreamEdge::isIntermediate).forEach(edge -> configs.putAll(edge.generateConfig()));
+
+    // write serialized serde instances and stream, store, and table serdes to configs
+    // serde configuration generation has to happen before table configuration, since the serde configuration
+    // is required when generating configurations for some TableProvider (i.e. local store backed tables)
+    configureSerdes(configs, inEdges, outEdges, stores, reachableTables.keySet(), jobNode);
+
+    // generate table configuration and potential side input configuration
+    configureTables(configs, config, reachableTables, inputs);
+
+    // finalize the task.inputs configuration
+    configs.put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputs));
+
+    LOG.info("Job {} has generated configs {}", jobNode.getJobNameAndId(), configs);
+
+    // apply configure rewriters and user configure overrides
+    return applyConfigureRewritersAndOverrides(configs, config, jobNode);
+  }
+
+  private Map<String, TableSpec> getReachableTables(Collection<OperatorSpec> reachableOperators, JobNode jobNode) {
+    // TODO: Fix this in SAMZA-1893. For now, returning all tables for single-job execution plan
+    return jobNode.getTables();
+  }
+
+  private void configureBroadcastInputs(Map<String, String> configs, Config config, Set<String> broadcastStreams) {
+    // TODO: SAMZA-1841: remove this once we support defining broadcast input stream in high-level
+    // task.broadcast.input should be generated by the planner in the future.
+    if (broadcastStreams.isEmpty()) {
+      return;
+    }
+    final String taskBroadcasts = config.get(TaskConfigJava.BROADCAST_INPUT_STREAMS);
+    if (StringUtils.isNoneEmpty(taskBroadcasts)) {
+      broadcastStreams.add(taskBroadcasts);
+    }
+    configs.put(TaskConfigJava.BROADCAST_INPUT_STREAMS, Joiner.on(',').join(broadcastStreams));
+  }
+
+  private void configureWindowInterval(Map<String, String> configs, Config config,
+      Collection<OperatorSpec> reachableOperators) {
+    if (!reachableOperators.stream().anyMatch(op -> op.getOpCode() == OperatorSpec.OpCode.WINDOW
+        || op.getOpCode() == OperatorSpec.OpCode.JOIN)) {
+      return;
+    }
+
+    // set triggering interval if a window or join is defined. Only applies to high-level applications
+    if ("-1".equals(config.get(TaskConfig.WINDOW_MS(), "-1"))) {
+      long triggerInterval = computeTriggerInterval(reachableOperators);
+      LOG.info("Using triggering interval: {}", triggerInterval);
+
+      configs.put(TaskConfig.WINDOW_MS(), String.valueOf(triggerInterval));
+    }
+  }
+
+  /**
+   * Computes the triggering interval to use during the execution of this {@link JobNode}
+   */
+  private long computeTriggerInterval(Collection<OperatorSpec> reachableOperators) {
+    List<Long> windowTimerIntervals =  reachableOperators.stream()
+        .filter(spec -> spec.getOpCode() == OperatorSpec.OpCode.WINDOW)
+        .map(spec -> ((WindowOperatorSpec) spec).getDefaultTriggerMs())
+        .collect(Collectors.toList());
+
+    // Filter out the join operators, and obtain a list of their ttl values
+    List<Long> joinTtlIntervals = reachableOperators.stream()
+        .filter(spec -> spec instanceof JoinOperatorSpec)
+        .map(spec -> ((JoinOperatorSpec) spec).getTtlMs())
+        .collect(Collectors.toList());
+
+    // Combine both the above lists
+    List<Long> candidateTimerIntervals = new ArrayList<>(joinTtlIntervals);
+    candidateTimerIntervals.addAll(windowTimerIntervals);
+
+    if (candidateTimerIntervals.isEmpty()) {
+      return -1;
+    }
+
+    // Compute the gcd of the resultant list
+    return MathUtil.gcd(candidateTimerIntervals);
+  }
+
+  private JobConfig applyConfigureRewritersAndOverrides(Map<String, String> configs, Config config, JobNode jobNode) {
+    // Disallow user specified job inputs/outputs. This info comes strictly from the user application.
+    Map<String, String> allowedConfigs = new HashMap<>(config);
+    if (!jobNode.isLegacyTaskApplication()) {
+      if (allowedConfigs.containsKey(TaskConfig.INPUT_STREAMS())) {
+        LOG.warn("Specifying task inputs in configuration is not allowed for SamzaApplication. "
+            + "Ignoring configured value for " + TaskConfig.INPUT_STREAMS());
+        allowedConfigs.remove(TaskConfig.INPUT_STREAMS());
+      }
+    }
+
+    LOG.debug("Job {} has allowed configs {}", jobNode.getJobNameAndId(), allowedConfigs);
+    return mergeJobConfig(new MapConfig(allowedConfigs), new MapConfig(configs));
+  }
+
+  /**
+   * This function extract the subset of configs from the full config, and use it to override the generated configs
+   * from the job.
+   * @param fullConfig full config
+   * @param generatedConfig config generated for the job
+   * @param configPrefix prefix to extract the subset of the config overrides
+   * @return config that merges the generated configs and overrides
+   */
+  private static Config extractScopedConfig(Config fullConfig, Config generatedConfig, String configPrefix) {
+    Config scopedConfig = fullConfig.subset(configPrefix);
+
+    Config[] configPrecedence = new Config[] {fullConfig, generatedConfig, scopedConfig};
+    // Strip empty configs so they don't override the configs before them.
+    Map<String, String> mergedConfig = new HashMap<>();
+    for (Map<String, String> config : configPrecedence) {
+      for (Map.Entry<String, String> property : config.entrySet()) {
+        String value = property.getValue();
+        if (!(value == null || value.isEmpty())) {
+          mergedConfig.put(property.getKey(), property.getValue());
+        }
+      }
+    }
+    scopedConfig = new MapConfig(mergedConfig);
+    LOG.debug("Prefix '{}' has merged config {}", configPrefix, scopedConfig);
+
+    return scopedConfig;
+  }
+
+  private List<StoreDescriptor> getStoreDescriptors(Collection<OperatorSpec> reachableOperators) {
+    return reachableOperators.stream().filter(operatorSpec -> operatorSpec instanceof StatefulOperatorSpec)
+        .map(operatorSpec -> ((StatefulOperatorSpec) operatorSpec).getStoreDescriptors()).flatMap(Collection::stream)
+        .collect(Collectors.toList());
+  }
+
+  private void configureTables(Map<String, String> configs, Config config, Map<String, TableSpec> tables, Set<String> inputs) {
+    configs.putAll(TableConfigGenerator.generateConfigsForTableSpecs(new MapConfig(configs),
+        tables.values().stream().collect(Collectors.toList())));
+
+    // Add side inputs to the inputs and mark the stream as bootstrap
+    tables.values().forEach(tableSpec -> {
+        List<String> sideInputs = tableSpec.getSideInputs();
+        if (sideInputs != null && !sideInputs.isEmpty()) {
+          sideInputs.stream()
+              .map(sideInput -> StreamUtil.getSystemStreamFromNameOrId(config, sideInput))
+              .forEach(systemStream -> {
+                  inputs.add(StreamUtil.getNameFromSystemStream(systemStream));
+                  configs.put(String.format(StreamConfig.STREAM_PREFIX() + StreamConfig.BOOTSTRAP(),
+                      systemStream.getSystem(), systemStream.getStream()), "true");
+                });
+        }
+      });
+  }
+
+  /**
+   * Serializes the {@link Serde} instances for operators, adds them to the provided config, and
+   * sets the serde configuration for the input/output/intermediate streams appropriately.
+   *
+   * We try to preserve the number of Serde instances before and after serialization. However we don't
+   * guarantee that references shared between these serdes instances (e.g. an Jackson ObjectMapper shared
+   * between two json serdes) are shared after deserialization too.
+   *
+   * Ideally all the user defined objects in the application should be serialized and de-serialized in one pass
+   * from the same output/input stream so that we can maintain reference sharing relationships.
+   *
+   * @param configs the configs to add serialized serde instances and stream serde configs to
+   */
+  private void configureSerdes(Map<String, String> configs, Map<String, StreamEdge> inEdges, Map<String, StreamEdge> outEdges,
+      List<StoreDescriptor> stores, Collection<String> tables, JobNode jobNode) {
+    // collect all key and msg serde instances for streams
+    Map<String, Serde> streamKeySerdes = new HashMap<>();
+    Map<String, Serde> streamMsgSerdes = new HashMap<>();
+    inEdges.keySet().forEach(streamId ->
+        addSerdes(jobNode.getInputSerdes(streamId), streamId, streamKeySerdes, streamMsgSerdes));
+    outEdges.keySet().forEach(streamId ->
+        addSerdes(jobNode.getOutputSerde(streamId), streamId, streamKeySerdes, streamMsgSerdes));
+
+    Map<String, Serde> storeKeySerdes = new HashMap<>();
+    Map<String, Serde> storeMsgSerdes = new HashMap<>();
+    stores.forEach(storeDescriptor -> {
+        storeKeySerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getKeySerde());
+        storeMsgSerdes.put(storeDescriptor.getStoreName(), storeDescriptor.getMsgSerde());
+      });
+
+    Map<String, Serde> tableKeySerdes = new HashMap<>();
+    Map<String, Serde> tableMsgSerdes = new HashMap<>();
+    tables.forEach(tableId -> {
+        addSerdes(jobNode.getTableSerdes(tableId), tableId, tableKeySerdes, tableMsgSerdes);
+      });
+
+    // for each unique stream or store serde instance, generate a unique name and serialize to config
+    HashSet<Serde> serdes = new HashSet<>(streamKeySerdes.values());
+    serdes.addAll(streamMsgSerdes.values());
+    serdes.addAll(storeKeySerdes.values());
+    serdes.addAll(storeMsgSerdes.values());
+    serdes.addAll(tableKeySerdes.values());
+    serdes.addAll(tableMsgSerdes.values());
+    SerializableSerde<Serde> serializableSerde = new SerializableSerde<>();
+    Base64.Encoder base64Encoder = Base64.getEncoder();
+    Map<Serde, String> serdeUUIDs = new HashMap<>();
+    serdes.forEach(serde -> {
+        String serdeName = serdeUUIDs.computeIfAbsent(serde,
+            s -> serde.getClass().getSimpleName() + "-" + UUID.randomUUID().toString());
+        configs.putIfAbsent(String.format(SerializerConfig.SERDE_SERIALIZED_INSTANCE(), serdeName),
+            base64Encoder.encodeToString(serializableSerde.toBytes(serde)));
+      });
+
+    // set key and msg serdes for streams to the serde names generated above
+    streamKeySerdes.forEach((streamId, serde) -> {
+        String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), streamId);
+        String keySerdeConfigKey = streamIdPrefix + StreamConfig.KEY_SERDE();
+        configs.put(keySerdeConfigKey, serdeUUIDs.get(serde));
+      });
+
+    streamMsgSerdes.forEach((streamId, serde) -> {
+        String streamIdPrefix = String.format(StreamConfig.STREAM_ID_PREFIX(), streamId);
+        String valueSerdeConfigKey = streamIdPrefix + StreamConfig.MSG_SERDE();
+        configs.put(valueSerdeConfigKey, serdeUUIDs.get(serde));
+      });
+
+    // set key and msg serdes for stores to the serde names generated above
+    storeKeySerdes.forEach((storeName, serde) -> {
+        String keySerdeConfigKey = String.format(StorageConfig.KEY_SERDE(), storeName);
+        configs.put(keySerdeConfigKey, serdeUUIDs.get(serde));
+      });
+
+    storeMsgSerdes.forEach((storeName, serde) -> {
+        String msgSerdeConfigKey = String.format(StorageConfig.MSG_SERDE(), storeName);
+        configs.put(msgSerdeConfigKey, serdeUUIDs.get(serde));
+      });
+
+    // set key and msg serdes for stores to the serde names generated above
+    tableKeySerdes.forEach((tableId, serde) -> {
+        String keySerdeConfigKey = String.format(JavaTableConfig.TABLE_KEY_SERDE, tableId);
+        configs.put(keySerdeConfigKey, serdeUUIDs.get(serde));
+      });
+
+    tableMsgSerdes.forEach((tableId, serde) -> {
+        String valueSerdeConfigKey = String.format(JavaTableConfig.TABLE_VALUE_SERDE, tableId);
+        configs.put(valueSerdeConfigKey, serdeUUIDs.get(serde));
+      });
+  }
+
+  private void addSerdes(KV<Serde, Serde> serdes, String streamId, Map<String, Serde> keySerdeMap,
+      Map<String, Serde> msgSerdeMap) {
+    if (serdes != null) {
+      if (serdes.getKey() != null && !(serdes.getKey() instanceof NoOpSerde)) {
+        keySerdeMap.put(streamId, serdes.getKey());
+      }
+      if (serdes.getValue() != null && !(serdes.getValue() instanceof NoOpSerde)) {
+        msgSerdeMap.put(streamId, serdes.getValue());
+      }
+    }
+  }
+}
index a2050e5..abbec18 100644 (file)
@@ -20,29 +20,20 @@ package org.apache.samza.execution;
 
 import java.io.File;
 import java.io.PrintWriter;
-import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import org.apache.commons.lang3.StringUtils;
-import org.apache.samza.SamzaException;
 import org.apache.samza.application.ApplicationDescriptor;
 import org.apache.samza.application.ApplicationDescriptorImpl;
-import org.apache.samza.application.StreamApplicationDescriptorImpl;
-import org.apache.samza.application.TaskApplicationDescriptorImpl;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.ShellCommandConfig;
 import org.apache.samza.config.StreamConfig;
-import org.apache.samza.operators.BaseTableDescriptor;
-import org.apache.samza.operators.OperatorSpecGraph;
-import org.apache.samza.table.TableConfigGenerator;
-import org.apache.samza.table.TableSpec;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -64,22 +55,7 @@ public abstract class JobPlanner {
     this.config = descriptor.getConfig();
   }
 
-  public List<JobConfig> prepareJobs() {
-    String appId = new ApplicationConfig(appDesc.getConfig()).getGlobalAppId();
-    if (appDesc instanceof TaskApplicationDescriptorImpl) {
-      return Collections.singletonList(prepareTaskJob((TaskApplicationDescriptorImpl) appDesc));
-    } else if (appDesc instanceof StreamApplicationDescriptorImpl) {
-      try {
-        return prepareStreamJobs((StreamApplicationDescriptorImpl) appDesc);
-      } catch (Exception e) {
-        throw new SamzaException("Failed to generate JobConfig for StreamApplication " + appId, e);
-      }
-    }
-    throw new IllegalArgumentException(String.format("ApplicationDescriptorImpl has to be either TaskApplicationDescriptorImpl or "
-        + "StreamApplicationDescriptorImpl. class %s is not supported", appDesc.getClass().getName()));
-  }
-
-  abstract List<JobConfig> prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) throws Exception;
+  public abstract List<JobConfig> prepareJobs();
 
   StreamManager buildAndStartStreamManager(Config config) {
     StreamManager streamManager = new StreamManager(config);
@@ -87,12 +63,12 @@ public abstract class JobPlanner {
     return streamManager;
   }
 
-  ExecutionPlan getExecutionPlan(OperatorSpecGraph specGraph) {
-    return getExecutionPlan(specGraph, null);
+  ExecutionPlan getExecutionPlan() {
+    return getExecutionPlan(null);
   }
 
   /* package private */
-  ExecutionPlan getExecutionPlan(OperatorSpecGraph specGraph, String runId) {
+  ExecutionPlan getExecutionPlan(String runId) {
 
     // update application configs
     Map<String, String> cfg = new HashMap<>();
@@ -101,8 +77,8 @@ public abstract class JobPlanner {
     }
 
     StreamConfig streamConfig = new StreamConfig(config);
-    Set<String> inputStreams = new HashSet<>(specGraph.getInputOperators().keySet());
-    inputStreams.removeAll(specGraph.getOutputStreams().keySet());
+    Set<String> inputStreams = new HashSet<>(appDesc.getInputStreamIds());
+    inputStreams.removeAll(appDesc.getOutputStreamIds());
     ApplicationConfig.ApplicationMode mode = inputStreams.stream().allMatch(streamConfig::getIsBounded)
         ? ApplicationConfig.ApplicationMode.BATCH : ApplicationConfig.ApplicationMode.STREAM;
     cfg.put(ApplicationConfig.APP_MODE, mode.name());
@@ -117,12 +93,12 @@ public abstract class JobPlanner {
 
     // create the physical execution plan and merge with overrides. This works for a single-stage job now
     // TODO: This should all be consolidated with ExecutionPlanner after fixing SAMZA-1811
-    Config mergedConfig = JobNode.mergeJobConfig(config, new MapConfig(cfg));
+    Config mergedConfig = JobNodeConfigurationGenerator.mergeJobConfig(config, new MapConfig(cfg));
     // creating the StreamManager to get all input/output streams' metadata for planning
     StreamManager streamManager = buildAndStartStreamManager(mergedConfig);
     try {
       ExecutionPlanner planner = new ExecutionPlanner(mergedConfig, streamManager);
-      return planner.plan(specGraph);
+      return planner.plan(appDesc);
     } finally {
       streamManager.stop();
     }
@@ -149,25 +125,6 @@ public abstract class JobPlanner {
     }
   }
 
-  // TODO: SAMZA-1814: the following configuration generation still misses serde configuration generation,
-  // side input configuration, broadcast input and task inputs configuration generation for low-level task
-  // applications
-  // helper method to generate a single node job configuration for low level task applications
-  private JobConfig prepareTaskJob(TaskApplicationDescriptorImpl taskAppDesc) {
-    // copy original configure
-    Map<String, String> cfg = new HashMap<>();
-    // expand system and streams configure
-    Map<String, String> systemStreamConfigs = expandSystemStreamConfigs(taskAppDesc);
-    cfg.putAll(systemStreamConfigs);
-    // expand table configure
-    cfg.putAll(expandTableConfigs(cfg, taskAppDesc));
-    // adding app.class in the configuration
-    cfg.put(ApplicationConfig.APP_CLASS, appDesc.getAppClass().getName());
-    // create the physical execution plan and merge with overrides. This works for a single-stage job now
-    // TODO: This should all be consolidated with ExecutionPlanner after fixing SAMZA-1811
-    return new JobConfig(JobNode.mergeJobConfig(config, new MapConfig(cfg)));
-  }
-
   private Map<String, String> expandSystemStreamConfigs(ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) {
     Map<String, String> systemStreamConfigs = new HashMap<>();
     appDesc.getInputDescriptors().forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig()));
@@ -177,12 +134,4 @@ public abstract class JobPlanner {
         systemStreamConfigs.put(JobConfig.JOB_DEFAULT_SYSTEM(), dsd.getSystemName()));
     return systemStreamConfigs;
   }
-
-  private Map<String, String> expandTableConfigs(Map<String, String> originConfig,
-      ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) {
-    List<TableSpec> tableSpecs = new ArrayList<>();
-    appDesc.getTableDescriptors().stream().map(td -> ((BaseTableDescriptor) td).getTableSpec())
-        .forEach(spec -> tableSpecs.add(spec));
-    return TableConfigGenerator.generateConfigsForTableSpecs(new MapConfig(originConfig), tableSpecs);
-  }
 }
index 7996d6b..86aca0f 100644 (file)
@@ -25,7 +25,6 @@ import java.util.concurrent.TimeoutException;
 import org.apache.samza.SamzaException;
 import org.apache.samza.application.ApplicationDescriptor;
 import org.apache.samza.application.ApplicationDescriptorImpl;
-import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.JobCoordinatorConfig;
@@ -37,7 +36,7 @@ import org.slf4j.LoggerFactory;
 
 
 /**
- * Temporarily helper class with specific implementation of {@link JobPlanner#prepareStreamJobs(StreamApplicationDescriptorImpl)}
+ * Temporarily helper class with specific implementation of {@link JobPlanner#prepareJobs()}
  * for standalone Samza processors.
  *
  * TODO: we need to consolidate this with {@link ExecutionPlanner} after SAMZA-1811.
@@ -53,17 +52,23 @@ public class LocalJobPlanner extends JobPlanner {
   }
 
   @Override
-  List<JobConfig> prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) throws Exception {
+  public List<JobConfig> prepareJobs() {
     // for high-level DAG, generating the plan and job configs
     // 1. initialize and plan
-    ExecutionPlan plan = getExecutionPlan(streamAppDesc.getOperatorSpecGraph());
+    ExecutionPlan plan = getExecutionPlan();
 
-    String executionPlanJson = plan.getPlanAsJson();
+    String executionPlanJson = "";
+    try {
+      executionPlanJson = plan.getPlanAsJson();
+    } catch (Exception e) {
+      throw new SamzaException("Failed to create plan JSON.", e);
+    }
     writePlanJsonFile(executionPlanJson);
     LOG.info("Execution Plan: \n" + executionPlanJson);
     String planId = String.valueOf(executionPlanJson.hashCode());
 
-    if (plan.getJobConfigs().isEmpty()) {
+    List<JobConfig> jobConfigs = plan.getJobConfigs();
+    if (jobConfigs.isEmpty()) {
       throw new SamzaException("No jobs in the plan.");
     }
 
@@ -71,7 +76,7 @@ public class LocalJobPlanner extends JobPlanner {
     // TODO: System generated intermediate streams should have robust naming scheme. See SAMZA-1391
     // TODO: this works for single-job applications. For multi-job applications, ExecutionPlan should return an AppConfig
     // to be used for the whole application
-    JobConfig jobConfig = plan.getJobConfigs().get(0);
+    JobConfig jobConfig = jobConfigs.get(0);
     StreamManager streamManager = null;
     try {
       // create the StreamManager to create intermediate streams in the plan
@@ -82,7 +87,7 @@ public class LocalJobPlanner extends JobPlanner {
         streamManager.stop();
       }
     }
-    return plan.getJobConfigs();
+    return jobConfigs;
   }
 
   /**
index aa1dff9..ca91214 100644 (file)
@@ -27,15 +27,14 @@ import java.util.HashSet;
 import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Function;
-import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 
 
 /**
- * A utility class that encapsulates the logic for traversing an {@link OperatorSpecGraph} and building
- * associations between related {@link OperatorSpec}s.
+ * A utility class that encapsulates the logic for traversing operators in the graph from the set of {@link InputOperatorSpec}
+ * and building associations between related {@link OperatorSpec}s.
  */
 /* package private */ class OperatorSpecGraphAnalyzer {
 
@@ -43,14 +42,13 @@ import org.apache.samza.operators.spec.OperatorSpec;
    * Returns a grouping of {@link InputOperatorSpec}s by the joins, i.e. {@link JoinOperatorSpec}s, they participate in.
    */
   public static Multimap<JoinOperatorSpec, InputOperatorSpec> getJoinToInputOperatorSpecs(
-      OperatorSpecGraph operatorSpecGraph) {
+      Collection<InputOperatorSpec> inputOperatorSpecs) {
 
     Multimap<JoinOperatorSpec, InputOperatorSpec> joinOpSpecToInputOpSpecs = HashMultimap.create();
 
     // Traverse graph starting from every input operator spec, observing connectivity between input operator specs
     // and Join operator specs.
-    Iterable<InputOperatorSpec> inputOpSpecs = operatorSpecGraph.getInputOperators().values();
-    for (InputOperatorSpec inputOpSpec : inputOpSpecs) {
+    for (InputOperatorSpec inputOpSpec : inputOperatorSpecs) {
       // Observe all join operator specs reachable from this input operator spec.
       JoinOperatorSpecVisitor joinOperatorSpecVisitor = new JoinOperatorSpecVisitor();
       traverse(inputOpSpec, joinOperatorSpecVisitor, opSpec -> opSpec.getRegisteredOperatorSpecs());
@@ -77,7 +75,7 @@ import org.apache.samza.operators.spec.OperatorSpec;
   }
 
   /**
-   * An {@link OperatorSpecGraph} visitor that records all {@link JoinOperatorSpec}s encountered in the graph.
+   * An visitor that records all {@link JoinOperatorSpec}s encountered in the graph of {@link OperatorSpec}s
    */
   private static class JoinOperatorSpecVisitor implements Consumer<OperatorSpec> {
     private Set<JoinOperatorSpec> joinOpSpecs = new HashSet<>();
index 254ff97..54f86d5 100644 (file)
@@ -23,7 +23,6 @@ import java.util.UUID;
 import org.apache.samza.SamzaException;
 import org.apache.samza.application.ApplicationDescriptor;
 import org.apache.samza.application.ApplicationDescriptorImpl;
-import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
@@ -34,7 +33,7 @@ import org.slf4j.LoggerFactory;
 
 
 /**
- * Temporary helper class with specific implementation of {@link JobPlanner#prepareStreamJobs(StreamApplicationDescriptorImpl)}
+ * Temporary helper class with specific implementation of {@link JobPlanner#prepareJobs()}
  * for remote-launched Samza processors (e.g. in YARN).
  *
  * TODO: we need to consolidate this class with {@link ExecutionPlanner} after SAMZA-1811.
@@ -47,7 +46,7 @@ public class RemoteJobPlanner extends JobPlanner {
   }
 
   @Override
-  List<JobConfig> prepareStreamJobs(StreamApplicationDescriptorImpl streamAppDesc) throws Exception {
+  public List<JobConfig> prepareJobs() {
     // for high-level DAG, generate the plan and job configs
     // TODO: run.id needs to be set for standalone: SAMZA-1531
     // run.id is based on current system time with the most significant bits in UUID (8 digits) to avoid collision
@@ -55,17 +54,22 @@ public class RemoteJobPlanner extends JobPlanner {
     LOG.info("The run id for this run is {}", runId);
 
     // 1. initialize and plan
-    ExecutionPlan plan = getExecutionPlan(streamAppDesc.getOperatorSpecGraph(), runId);
-    writePlanJsonFile(plan.getPlanAsJson());
+    ExecutionPlan plan = getExecutionPlan(runId);
+    try {
+      writePlanJsonFile(plan.getPlanAsJson());
+    } catch (Exception e) {
+      throw new SamzaException("Failed to create plan JSON.", e);
+    }
 
-    if (plan.getJobConfigs().isEmpty()) {
+    List<JobConfig> jobConfigs = plan.getJobConfigs();
+    if (jobConfigs.isEmpty()) {
       throw new SamzaException("No jobs in the plan.");
     }
 
     // 2. create the necessary streams
     // TODO: this works for single-job applications. For multi-job applications, ExecutionPlan should return an AppConfig
     // to be used for the whole application
-    JobConfig jobConfig = plan.getJobConfigs().get(0);
+    JobConfig jobConfig = jobConfigs.get(0);
     StreamManager streamManager = null;
     try {
       // create the StreamManager to create intermediate streams in the plan
@@ -79,7 +83,7 @@ public class RemoteJobPlanner extends JobPlanner {
         streamManager.stop();
       }
     }
-    return plan.getJobConfigs();
+    return jobConfigs;
   }
 
   private Config getConfigFromPrevRun() {
index 1e4194a..1830d1c 100644 (file)
@@ -73,6 +73,15 @@ abstract public class BaseTableDescriptor<K, V, D extends BaseTableDescriptor<K,
   }
 
   /**
+   * Get the serde assigned to this {@link TableDescriptor}
+   *
+   * @return {@link KVSerde} used by this table
+   */
+  public KVSerde<K, V> getSerde() {
+    return serde;
+  }
+
+  /**
    * Generate config for {@link TableSpec}; this method is used internally.
    * @param tableSpecConfig configuration for the {@link TableSpec}
    */
index b75b1e8..5329fd7 100644 (file)
@@ -30,7 +30,6 @@ import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.OutputStreamImpl;
 import org.apache.samza.serializers.SerializableSerde;
-import org.apache.samza.table.TableSpec;
 
 
 /**
@@ -45,7 +44,6 @@ public class OperatorSpecGraph implements Serializable {
   private final Map<String, InputOperatorSpec> inputOperators;
   private final Map<String, OutputStreamImpl> outputStreams;
   private final Set<String> broadcastStreams;
-  private final Map<TableSpec, TableImpl> tables;
   private final Set<OperatorSpec> allOpSpecs;
   private final boolean hasWindowOrJoins;
 
@@ -57,7 +55,6 @@ public class OperatorSpecGraph implements Serializable {
     this.inputOperators = streamAppDesc.getInputOperators();
     this.outputStreams = streamAppDesc.getOutputStreams();
     this.broadcastStreams = streamAppDesc.getBroadcastStreams();
-    this.tables = streamAppDesc.getTables();
     this.allOpSpecs = Collections.unmodifiableSet(this.findAllOperatorSpecs());
     this.hasWindowOrJoins = checkWindowOrJoins();
     this.serializedOpSpecGraph = opSpecGraphSerde.toBytes(this);
@@ -75,10 +72,6 @@ public class OperatorSpecGraph implements Serializable {
     return broadcastStreams;
   }
 
-  public Map<TableSpec, TableImpl> getTables() {
-    return tables;
-  }
-
   /**
    * Get all {@link OperatorSpec}s available in this {@link StreamApplicationDescriptorImpl}
    *
index 98864d2..b9bb1f6 100644 (file)
@@ -75,7 +75,7 @@ public class LocalContainerRunner {
       throw new SamzaException("can not find the job name");
     }
     String jobName = jobConfig.getName().get();
-    String jobId = jobConfig.getJobId().getOrElse(ScalaJavaUtil.defaultValue("1"));
+    String jobId = jobConfig.getJobId();
     MDC.put("containerName", "samza-container-" + containerId);
     MDC.put("jobName", jobName);
     MDC.put("jobId", jobId);
index 085131c..03be758 100644 (file)
 package org.apache.samza.table;
 
 import java.util.ArrayList;
-import java.util.Base64;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.UUID;
 
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JavaTableConfig;
-import org.apache.samza.config.SerializerConfig;
 import org.apache.samza.operators.BaseTableDescriptor;
 import org.apache.samza.operators.TableDescriptor;
 import org.apache.samza.operators.TableImpl;
-import org.apache.samza.serializers.Serde;
-import org.apache.samza.serializers.SerializableSerde;
 import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -66,8 +60,6 @@ public class TableConfigGenerator {
   static public Map<String, String> generateConfigsForTableSpecs(Config config, List<TableSpec> tableSpecs) {
     Map<String, String> tableConfigs = new HashMap<>();
 
-    tableConfigs.putAll(generateTableKVSerdeConfigs(tableSpecs));
-
     tableSpecs.forEach(tableSpec -> {
         // Add table provider factory config
         tableConfigs.put(String.format(JavaTableConfig.TABLE_PROVIDER_FACTORY, tableSpec.getId()),
@@ -103,44 +95,4 @@ public class TableConfigGenerator {
       });
     return new ArrayList<>(tableSpecs.keySet());
   }
-
-  static private Map<String, String> generateTableKVSerdeConfigs(List<TableSpec> tableSpecs) {
-    Map<String, String> serdeConfigs = new HashMap<>();
-
-    // Collect key and msg serde instances for all the tables
-    Map<String, Serde> tableKeySerdes = new HashMap<>();
-    Map<String, Serde> tableValueSerdes = new HashMap<>();
-    HashSet<Serde> serdes = new HashSet<>();
-
-    tableSpecs.forEach(tableSpec -> {
-        tableKeySerdes.put(tableSpec.getId(), tableSpec.getSerde().getKeySerde());
-        tableValueSerdes.put(tableSpec.getId(), tableSpec.getSerde().getValueSerde());
-      });
-    serdes.addAll(tableKeySerdes.values());
-    serdes.addAll(tableValueSerdes.values());
-
-    // Generate serde names
-    SerializableSerde<Serde> serializableSerde = new SerializableSerde<>();
-    Base64.Encoder base64Encoder = Base64.getEncoder();
-    Map<Serde, String> serdeUUIDs = new HashMap<>();
-    serdes.forEach(serde -> {
-        String serdeName = serdeUUIDs.computeIfAbsent(serde,
-            s -> serde.getClass().getSimpleName() + "-" + UUID.randomUUID().toString());
-        serdeConfigs.putIfAbsent(String.format(SerializerConfig.SERDE_SERIALIZED_INSTANCE(), serdeName),
-            base64Encoder.encodeToString(serializableSerde.toBytes(serde)));
-      });
-
-    // Set key and msg serdes for tables to the serde names generated above
-    tableKeySerdes.forEach((tableId, serde) -> {
-        String keySerdeConfigKey = String.format(JavaTableConfig.TABLE_KEY_SERDE, tableId);
-        serdeConfigs.put(keySerdeConfigKey, serdeUUIDs.get(serde));
-      });
-
-    tableValueSerdes.forEach((tableId, serde) -> {
-        String valueSerdeConfigKey = String.format(JavaTableConfig.TABLE_VALUE_SERDE, tableId);
-        serdeConfigs.put(valueSerdeConfigKey, serdeUUIDs.get(serde));
-      });
-
-    return serdeConfigs;
-  }
 }
index 3dad6c1..41294a3 100644 (file)
@@ -35,7 +35,6 @@ public class ZkJobCoordinatorFactory implements JobCoordinatorFactory {
 
   private static final Logger LOG = LoggerFactory.getLogger(ZkJobCoordinatorFactory.class);
   private static final String JOB_COORDINATOR_ZK_PATH_FORMAT = "%s/%s-%s-coordinationData";
-  private static final String DEFAULT_JOB_ID = "1";
   private static final String DEFAULT_JOB_NAME = "defaultJob";
 
   /**
@@ -68,9 +67,7 @@ public class ZkJobCoordinatorFactory implements JobCoordinatorFactory {
     String jobName = jobConfig.getName().isDefined()
         ? jobConfig.getName().get()
         : DEFAULT_JOB_NAME;
-    String jobId = jobConfig.getJobId().isDefined()
-        ? jobConfig.getJobId().get()
-        : DEFAULT_JOB_ID;
+    String jobId = jobConfig.getJobId();
 
     return String.format(JOB_COORDINATOR_ZK_PATH_FORMAT, appId, jobName, jobId);
   }
index fc8780f..d7b71b5 100644 (file)
@@ -39,7 +39,7 @@ object JobConfig {
    */
   val CONFIG_REWRITERS = "job.config.rewriters" // streaming.job_config_rewriters
   val CONFIG_REWRITER_CLASS = "job.config.rewriter.%s.class" // streaming.job_config_rewriter_class - regex, system, config
-  val CONFIG_JOB_PREFIX = "jobs.%s."
+  val CONFIG_OVERRIDE_JOBS_PREFIX = "jobs.%s."
   val JOB_NAME = "job.name" // streaming.job_name
   val JOB_ID = "job.id" // streaming.job_id
   val SAMZA_FWK_PATH = "samza.fwk.path"
@@ -164,7 +164,7 @@ class JobConfig(config: Config) extends ScalaMapConfig(config) with Logging {
 
   def getStreamJobFactoryClass = getOption(JobConfig.STREAM_JOB_FACTORY_CLASS)
 
-  def getJobId = getOption(JobConfig.JOB_ID)
+  def getJobId = getOption(JobConfig.JOB_ID).getOrElse("1")
 
   def failOnCheckpointValidation = { getBoolean(JobConfig.JOB_FAIL_CHECKPOINT_VALIDATION, true) }
 
index fba7329..417fc18 100644 (file)
@@ -101,7 +101,7 @@ object SamzaContainer extends Logging {
     if(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR) != null) {
       val jobNameAndId = (
         config.getName.getOrElse(throw new ConfigException("Missing required config: job.name")),
-        config.getJobId.getOrElse("1")
+        config.getJobId
       )
 
       loggedStorageBaseDir = new File(System.getenv(ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR)
index d1e6554..8a9c021 100644 (file)
@@ -44,7 +44,6 @@ class MetricsSnapshotReporterFactory extends MetricsReporterFactory with Logging
 
     val jobId = config
       .getJobId
-      .getOrElse(1.toString)
 
     val taskClass = config
       .getTaskClass
index cd74716..bfb2271 100644 (file)
@@ -89,6 +89,6 @@ object CoordinatorStreamUtil {
     */
   private def getJobNameAndId(config: Config) = {
     (config.getName.getOrElse(throw new ConfigException("Missing required config: job.name")),
-      config.getJobId.getOrElse("1"))
+      config.getJobId)
   }
 }
index db85e33..1fe6023 100644 (file)
@@ -522,10 +522,11 @@ public class TestStreamApplicationDescriptorImpl {
     TableSpec testTableSpec = new TableSpec("t1", KVSerde.of(new NoOpSerde(), new NoOpSerde()), "", new HashMap<>());
     when(mockTableDescriptor.getTableSpec()).thenReturn(testTableSpec);
     when(mockTableDescriptor.getTableId()).thenReturn(testTableSpec.getId());
+    when(mockTableDescriptor.getSerde()).thenReturn(testTableSpec.getSerde());
     StreamApplicationDescriptorImpl streamAppDesc = new StreamApplicationDescriptorImpl(appDesc -> {
         appDesc.getTable(mockTableDescriptor);
       }, mockConfig);
-    assertNotNull(streamAppDesc.getTables().get(testTableSpec));
+    assertNotNull(streamAppDesc.getTables().get(testTableSpec.getId()));
   }
 
   @Test
index 9418c1f..abe5ce1 100644 (file)
@@ -23,12 +23,14 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import org.apache.samza.config.Config;
+import org.apache.samza.operators.BaseTableDescriptor;
 import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.TableDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.InputDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor;
 import org.apache.samza.operators.descriptors.base.system.SystemDescriptor;
 import org.apache.samza.runtime.ProcessorLifecycleListenerFactory;
+import org.apache.samza.serializers.KVSerde;
 import org.apache.samza.task.TaskFactory;
 import org.junit.Before;
 import org.junit.Test;
@@ -64,10 +66,12 @@ public class TestTaskApplicationDescriptorImpl {
       this.add(mock2);
     } };
   private Set<TableDescriptor> mockTables = new HashSet<TableDescriptor>() { {
-      TableDescriptor mock1 = mock(TableDescriptor.class);
-      TableDescriptor mock2 = mock(TableDescriptor.class);
+      BaseTableDescriptor mock1 = mock(BaseTableDescriptor.class);
+      BaseTableDescriptor mock2 = mock(BaseTableDescriptor.class);
       when(mock1.getTableId()).thenReturn("test-table1");
       when(mock2.getTableId()).thenReturn("test-table2");
+      when(mock1.getSerde()).thenReturn(mock(KVSerde.class));
+      when(mock2.getSerde()).thenReturn(mock(KVSerde.class));
       this.add(mock1);
       this.add(mock2);
     } };
diff --git a/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java b/samza-core/src/test/java/org/apache/samza/execution/ExecutionPlannerTestBase.java
new file mode 100644 (file)
index 0000000..f507c70
--- /dev/null
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.execution;
+
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.samza.application.ApplicationDescriptorImpl;
+import org.apache.samza.application.LegacyTaskApplication;
+import org.apache.samza.application.StreamApplication;
+import org.apache.samza.application.StreamApplicationDescriptorImpl;
+import org.apache.samza.application.TaskApplication;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.operators.KV;
+import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.OutputStream;
+import org.apache.samza.operators.descriptors.GenericInputDescriptor;
+import org.apache.samza.operators.descriptors.GenericOutputDescriptor;
+import org.apache.samza.operators.descriptors.GenericSystemDescriptor;
+import org.apache.samza.operators.functions.JoinFunction;
+import org.apache.samza.serializers.JsonSerdeV2;
+import org.apache.samza.serializers.KVSerde;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.task.IdentityStreamTask;
+import org.junit.Before;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+
+
+/**
+ * Unit test base class to set up commonly used test application and configuration.
+ */
+class ExecutionPlannerTestBase {
+  protected StreamApplicationDescriptorImpl mockStreamAppDesc;
+  protected Config mockConfig;
+  protected JobNode mockJobNode;
+  protected KVSerde<String, Object> defaultSerde;
+  protected GenericSystemDescriptor inputSystemDescriptor;
+  protected GenericSystemDescriptor outputSystemDescriptor;
+  protected GenericSystemDescriptor intermediateSystemDescriptor;
+  protected GenericInputDescriptor<KV<String, Object>> input1Descriptor;
+  protected GenericInputDescriptor<KV<String, Object>> input2Descriptor;
+  protected GenericInputDescriptor<KV<String, Object>> intermediateInputDescriptor;
+  protected GenericInputDescriptor<KV<String, Object>> broadcastInputDesriptor;
+  protected GenericOutputDescriptor<KV<String, Object>> outputDescriptor;
+  protected GenericOutputDescriptor<KV<String, Object>> intermediateOutputDescriptor;
+
+  @Before
+  public void setUp() {
+    defaultSerde = KVSerde.of(new StringSerde(), new JsonSerdeV2<>());
+    inputSystemDescriptor = new GenericSystemDescriptor("input-system", "mockSystemFactoryClassName");
+    outputSystemDescriptor = new GenericSystemDescriptor("output-system", "mockSystemFactoryClassName");
+    intermediateSystemDescriptor = new GenericSystemDescriptor("intermediate-system", "mockSystemFactoryClassName");
+    input1Descriptor = inputSystemDescriptor.getInputDescriptor("input1", defaultSerde);
+    input2Descriptor = inputSystemDescriptor.getInputDescriptor("input2", defaultSerde);
+    outputDescriptor = outputSystemDescriptor.getOutputDescriptor("output", defaultSerde);
+    intermediateInputDescriptor = intermediateSystemDescriptor.getInputDescriptor("jobName-jobId-partition_by-p1", defaultSerde)
+        .withPhysicalName("jobName-jobId-partition_by-p1");
+    intermediateOutputDescriptor = intermediateSystemDescriptor.getOutputDescriptor("jobName-jobId-partition_by-p1", defaultSerde)
+        .withPhysicalName("jobName-jobId-partition_by-p1");
+    broadcastInputDesriptor = intermediateSystemDescriptor.getInputDescriptor("jobName-jobId-broadcast-b1", defaultSerde)
+        .withPhysicalName("jobName-jobId-broadcast-b1");
+
+    Map<String, String> configs = new HashMap<>();
+    configs.put(JobConfig.JOB_NAME(), "jobName");
+    configs.put(JobConfig.JOB_ID(), "jobId");
+    configs.putAll(input1Descriptor.toConfig());
+    configs.putAll(input2Descriptor.toConfig());
+    configs.putAll(outputDescriptor.toConfig());
+    configs.putAll(inputSystemDescriptor.toConfig());
+    configs.putAll(outputSystemDescriptor.toConfig());
+    configs.putAll(intermediateSystemDescriptor.toConfig());
+    configs.put(JobConfig.JOB_DEFAULT_SYSTEM(), intermediateSystemDescriptor.getSystemName());
+    mockConfig = spy(new MapConfig(configs));
+
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig);
+  }
+
+  String getJobNameAndId() {
+    return "jobName-jobId";
+  }
+
+  void configureJobNode(ApplicationDescriptorImpl mockStreamAppDesc) {
+    JobGraph jobGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class))
+        .createJobGraph(mockConfig, mockStreamAppDesc);
+    mockJobNode = spy(jobGraph.getJobNodes().get(0));
+  }
+
+  StreamApplication getRepartitionOnlyStreamApplication() {
+    return appDesc -> {
+      MessageStream<KV<String, Object>> input1 = appDesc.getInputStream(input1Descriptor);
+      input1.partitionBy(KV::getKey, KV::getValue, "p1");
+    };
+  }
+
+  StreamApplication getRepartitionJoinStreamApplication() {
+    return appDesc -> {
+      MessageStream<KV<String, Object>> input1 = appDesc.getInputStream(input1Descriptor);
+      MessageStream<KV<String, Object>> input2 = appDesc.getInputStream(input2Descriptor);
+      OutputStream<KV<String, Object>> output = appDesc.getOutputStream(outputDescriptor);
+      JoinFunction<String, Object, Object, KV<String, Object>> mockJoinFn = mock(JoinFunction.class);
+      input1
+          .partitionBy(KV::getKey, KV::getValue, defaultSerde, "p1")
+          .map(kv -> kv.value)
+          .join(input2.map(kv -> kv.value), mockJoinFn,
+              new StringSerde(), new JsonSerdeV2<>(Object.class), new JsonSerdeV2<>(Object.class),
+              Duration.ofHours(1), "j1")
+          .sendTo(output);
+    };
+  }
+
+  TaskApplication getTaskApplication() {
+    return appDesc -> {
+      appDesc.addInputStream(input1Descriptor);
+      appDesc.addInputStream(input2Descriptor);
+      appDesc.addInputStream(intermediateInputDescriptor);
+      appDesc.addOutputStream(intermediateOutputDescriptor);
+      appDesc.addOutputStream(outputDescriptor);
+      appDesc.setTaskFactory(() -> new IdentityStreamTask());
+    };
+  }
+
+  TaskApplication getLegacyTaskApplication() {
+    return new LegacyTaskApplication(IdentityStreamTask.class.getName());
+  }
+
+  StreamApplication getBroadcastOnlyStreamApplication(Serde serde) {
+    return appDesc -> {
+      MessageStream<KV<String, Object>> input = appDesc.getInputStream(input1Descriptor);
+      if (serde != null) {
+        input.broadcast(serde, "b1");
+      } else {
+        input.broadcast("b1");
+      }
+    };
+  }
+}
index 779d299..61289af 100644 (file)
@@ -24,12 +24,18 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
+import org.apache.samza.application.ApplicationDescriptor;
+import org.apache.samza.application.LegacyTaskApplication;
+import org.apache.samza.application.SamzaApplication;
 import org.apache.samza.application.StreamApplicationDescriptorImpl;
+import org.apache.samza.application.TaskApplicationDescriptorImpl;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
@@ -37,9 +43,13 @@ import org.apache.samza.config.TaskConfig;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
 import org.apache.samza.operators.OutputStream;
+import org.apache.samza.operators.TableDescriptor;
 import org.apache.samza.operators.descriptors.GenericInputDescriptor;
 import org.apache.samza.operators.descriptors.GenericOutputDescriptor;
 import org.apache.samza.operators.descriptors.GenericSystemDescriptor;
+import org.apache.samza.operators.descriptors.base.stream.InputDescriptor;
+import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor;
+import org.apache.samza.operators.descriptors.base.system.SystemDescriptor;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.windows.Windows;
 import org.apache.samza.serializers.KVSerde;
@@ -54,8 +64,12 @@ import org.apache.samza.testUtils.StreamTestUtils;
 import org.junit.Before;
 import org.junit.Test;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 
 public class TestExecutionPlanner {
@@ -63,6 +77,11 @@ public class TestExecutionPlanner {
   private static final String DEFAULT_SYSTEM = "test-system";
   private static final int DEFAULT_PARTITIONS = 10;
 
+  private final Set<SystemDescriptor> systemDescriptors = new HashSet<>();
+  private final Map<String, InputDescriptor> inputDescriptors = new HashMap<>();
+  private final Map<String, OutputDescriptor> outputDescriptors = new HashMap<>();
+  private final Set<TableDescriptor> tableDescriptors = new HashSet<>();
+
   private SystemAdmins systemAdmins;
   private StreamManager streamManager;
   private Config config;
@@ -78,6 +97,8 @@ public class TestExecutionPlanner {
   private GenericOutputDescriptor<KV<Object, Object>> output1Descriptor;
   private StreamSpec output2Spec;
   private GenericOutputDescriptor<KV<Object, Object>> output2Descriptor;
+  private GenericSystemDescriptor system1Descriptor;
+  private GenericSystemDescriptor system2Descriptor;
 
   static SystemAdmin createSystemAdmin(Map<String, Integer> streamToPartitions) {
 
@@ -236,20 +257,35 @@ public class TestExecutionPlanner {
 
     KVSerde<Object, Object> kvSerde = new KVSerde<>(new NoOpSerde(), new NoOpSerde());
     String mockSystemFactoryClass = "factory.class.name";
-    GenericSystemDescriptor system1 = new GenericSystemDescriptor("system1", mockSystemFactoryClass);
-    GenericSystemDescriptor system2 = new GenericSystemDescriptor("system2", mockSystemFactoryClass);
-    input1Descriptor = system1.getInputDescriptor("input1", kvSerde);
-    input2Descriptor = system2.getInputDescriptor("input2", kvSerde);
-    input3Descriptor = system2.getInputDescriptor("input3", kvSerde);
-    input4Descriptor = system1.getInputDescriptor("input4", kvSerde);
-    output1Descriptor = system1.getOutputDescriptor("output1", kvSerde);
-    output2Descriptor = system2.getOutputDescriptor("output2", kvSerde);
+    system1Descriptor = new GenericSystemDescriptor("system1", mockSystemFactoryClass);
+    system2Descriptor = new GenericSystemDescriptor("system2", mockSystemFactoryClass);
+    input1Descriptor = system1Descriptor.getInputDescriptor("input1", kvSerde);
+    input2Descriptor = system2Descriptor.getInputDescriptor("input2", kvSerde);
+    input3Descriptor = system2Descriptor.getInputDescriptor("input3", kvSerde);
+    input4Descriptor = system1Descriptor.getInputDescriptor("input4", kvSerde);
+    output1Descriptor = system1Descriptor.getOutputDescriptor("output1", kvSerde);
+    output2Descriptor = system2Descriptor.getOutputDescriptor("output2", kvSerde);
+
+    // clean and set up sets and maps of descriptors
+    systemDescriptors.clear();
+    inputDescriptors.clear();
+    outputDescriptors.clear();
+    tableDescriptors.clear();
+    systemDescriptors.add(system1Descriptor);
+    systemDescriptors.add(system2Descriptor);
+    inputDescriptors.put(input1Descriptor.getStreamId(), input1Descriptor);
+    inputDescriptors.put(input2Descriptor.getStreamId(), input2Descriptor);
+    inputDescriptors.put(input3Descriptor.getStreamId(), input3Descriptor);
+    inputDescriptors.put(input4Descriptor.getStreamId(), input4Descriptor);
+    outputDescriptors.put(output1Descriptor.getStreamId(), output1Descriptor);
+    outputDescriptors.put(output2Descriptor.getStreamId(), output2Descriptor);
+
 
     // set up external partition count
     Map<String, Integer> system1Map = new HashMap<>();
     system1Map.put("input1", 64);
     system1Map.put("output1", 8);
-    system1Map.put("input4", ExecutionPlanner.MAX_INFERRED_PARTITIONS * 2);
+    system1Map.put("input4", IntermediateStreamManager.MAX_INFERRED_PARTITIONS * 2);
     Map<String, Integer> system2Map = new HashMap<>();
     system2Map.put("input2", 16);
     system2Map.put("input3", 32);
@@ -268,7 +304,7 @@ public class TestExecutionPlanner {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin();
 
-    JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());
+    JobGraph jobGraph = planner.createJobGraph(graphSpec.getConfig(), graphSpec);
     assertTrue(jobGraph.getInputStreams().size() == 3);
     assertTrue(jobGraph.getOutputStreams().size() == 2);
     assertTrue(jobGraph.getIntermediateStreams().size() == 2); // two streams generated by partitionBy
@@ -278,9 +314,9 @@ public class TestExecutionPlanner {
   public void testFetchExistingStreamPartitions() {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin();
-    JobGraph jobGraph = planner.createJobGraph(graphSpec.getOperatorSpecGraph());
+    JobGraph jobGraph = planner.createJobGraph(graphSpec.getConfig(), graphSpec);
 
-    planner.fetchInputAndOutputStreamPartitions(jobGraph);
+    ExecutionPlanner.setInputAndOutputStreamPartitionCount(jobGraph, streamManager);
     assertTrue(jobGraph.getOrCreateStreamEdge(input1Spec).getPartitionCount() == 64);
     assertTrue(jobGraph.getOrCreateStreamEdge(input2Spec).getPartitionCount() == 16);
     assertTrue(jobGraph.getOrCreateStreamEdge(input3Spec).getPartitionCount() == 32);
@@ -296,7 +332,10 @@ public class TestExecutionPlanner {
   public void testCalculateJoinInputPartitions() {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin();
-    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph());
+    JobGraph jobGraph = planner.createJobGraph(graphSpec.getConfig(), graphSpec);
+
+    ExecutionPlanner.setInputAndOutputStreamPartitionCount(jobGraph, streamManager);
+    new IntermediateStreamManager(config, graphSpec).calculatePartitions(jobGraph);
 
     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
@@ -309,7 +348,7 @@ public class TestExecutionPlanner {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithInvalidJoin();
 
-    planner.plan(graphSpec.getOperatorSpecGraph());
+    planner.plan(graphSpec);
   }
 
   @Test
@@ -320,7 +359,7 @@ public class TestExecutionPlanner {
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createSimpleGraph();
-    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph());
+    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec);
 
     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
@@ -336,7 +375,7 @@ public class TestExecutionPlanner {
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoin();
-    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
+    ExecutionPlan plan = planner.plan(graphSpec);
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     for (JobConfig config : jobConfigs) {
       System.out.println(config);
@@ -351,7 +390,7 @@ public class TestExecutionPlanner {
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoinAndWindow();
-    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
+    ExecutionPlan plan = planner.plan(graphSpec);
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     assertEquals(1, jobConfigs.size());
 
@@ -368,7 +407,7 @@ public class TestExecutionPlanner {
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createStreamGraphWithJoinAndWindow();
-    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
+    ExecutionPlan plan = planner.plan(graphSpec);
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     assertEquals(1, jobConfigs.size());
 
@@ -384,7 +423,7 @@ public class TestExecutionPlanner {
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createSimpleGraph();
-    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
+    ExecutionPlan plan = planner.plan(graphSpec);
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     assertEquals(1, jobConfigs.size());
     assertFalse(jobConfigs.get(0).containsKey(TaskConfig.WINDOW_MS()));
@@ -399,7 +438,7 @@ public class TestExecutionPlanner {
 
     ExecutionPlanner planner = new ExecutionPlanner(cfg, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createSimpleGraph();
-    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
+    ExecutionPlan plan = planner.plan(graphSpec);
     List<JobConfig> jobConfigs = plan.getJobConfigs();
     assertEquals(1, jobConfigs.size());
     assertEquals("2000", jobConfigs.get(0).get(TaskConfig.WINDOW_MS()));
@@ -409,7 +448,7 @@ public class TestExecutionPlanner {
   public void testCalculateIntStreamPartitions() {
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
     StreamApplicationDescriptorImpl graphSpec = createSimpleGraph();
-    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph());
+    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec);
 
     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
@@ -430,15 +469,15 @@ public class TestExecutionPlanner {
     edge.setPartitionCount(16);
     edges.add(edge);
 
-    assertEquals(32, ExecutionPlanner.maxPartitions(edges));
+    assertEquals(32, IntermediateStreamManager.maxPartitions(edges));
 
     edges = Collections.emptyList();
-    assertEquals(StreamEdge.PARTITIONS_UNKNOWN, ExecutionPlanner.maxPartitions(edges));
+    assertEquals(StreamEdge.PARTITIONS_UNKNOWN, IntermediateStreamManager.maxPartitions(edges));
   }
 
   @Test
   public void testMaxPartitionLimit() throws Exception {
-    int partitionLimit = ExecutionPlanner.MAX_INFERRED_PARTITIONS;
+    int partitionLimit = IntermediateStreamManager.MAX_INFERRED_PARTITIONS;
 
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
     StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
@@ -447,11 +486,99 @@ public class TestExecutionPlanner {
         input1.partitionBy(m -> m.key, m -> m.value, "p1").map(kv -> kv).sendTo(output1);
       }, config);
 
-    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec.getOperatorSpecGraph());
+    JobGraph jobGraph = (JobGraph) planner.plan(graphSpec);
 
     // the partitions should be the same as input1
     jobGraph.getIntermediateStreams().forEach(edge -> {
         assertEquals(partitionLimit, edge.getPartitionCount()); // max of input1 and output1
       });
   }
+
+  @Test
+  public void testCreateJobGraphForTaskApplication() {
+    TaskApplicationDescriptorImpl taskAppDesc = mock(TaskApplicationDescriptorImpl.class);
+    // add interemediate streams
+    String intermediateStream1 = "intermediate-stream1";
+    String intermediateBroadcast = "intermediate-broadcast1";
+    // intermediate stream1, not broadcast
+    GenericInputDescriptor<KV<Object, Object>> intermediateInput1 = system1Descriptor.getInputDescriptor(
+        intermediateStream1, new KVSerde<>(new NoOpSerde(), new NoOpSerde()));
+    GenericOutputDescriptor<KV<Object, Object>> intermediateOutput1 = system1Descriptor.getOutputDescriptor(
+        intermediateStream1, new KVSerde<>(new NoOpSerde(), new NoOpSerde()));
+    // intermediate stream2, broadcast
+    GenericInputDescriptor<KV<Object, Object>> intermediateBroacastInput1 = system1Descriptor.getInputDescriptor(
+        intermediateBroadcast, new KVSerde<>(new NoOpSerde<>(), new NoOpSerde<>()));
+    GenericOutputDescriptor<KV<Object, Object>> intermediateBroacastOutput1 = system1Descriptor.getOutputDescriptor(
+        intermediateBroadcast, new KVSerde<>(new NoOpSerde<>(), new NoOpSerde<>()));
+    inputDescriptors.put(intermediateStream1, intermediateInput1);
+    outputDescriptors.put(intermediateStream1, intermediateOutput1);
+    inputDescriptors.put(intermediateBroadcast, intermediateBroacastInput1);
+    outputDescriptors.put(intermediateBroadcast, intermediateBroacastOutput1);
+    Set<String> broadcastStreams = new HashSet<>();
+    broadcastStreams.add(intermediateBroadcast);
+
+    when(taskAppDesc.getInputDescriptors()).thenReturn(inputDescriptors);
+    when(taskAppDesc.getInputStreamIds()).thenReturn(inputDescriptors.keySet());
+    when(taskAppDesc.getOutputDescriptors()).thenReturn(outputDescriptors);
+    when(taskAppDesc.getOutputStreamIds()).thenReturn(outputDescriptors.keySet());
+    when(taskAppDesc.getTableDescriptors()).thenReturn(Collections.emptySet());
+    when(taskAppDesc.getSystemDescriptors()).thenReturn(systemDescriptors);
+    when(taskAppDesc.getBroadcastStreams()).thenReturn(broadcastStreams);
+    doReturn(MockTaskApplication.class).when(taskAppDesc).getAppClass();
+
+    Map<String, String> systemStreamConfigs = new HashMap<>();
+    inputDescriptors.forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig()));
+    outputDescriptors.forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig()));
+    systemDescriptors.forEach(sd -> systemStreamConfigs.putAll(sd.toConfig()));
+
+    ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
+    JobGraph jobGraph = planner.createJobGraph(config, taskAppDesc);
+    assertEquals(1, jobGraph.getJobNodes().size());
+    assertTrue(jobGraph.getInputStreams().stream().map(edge -> edge.getName())
+        .filter(streamId -> inputDescriptors.containsKey(streamId)).collect(Collectors.toList()).isEmpty());
+    Set<String> intermediateStreams = new HashSet<>(inputDescriptors.keySet());
+    jobGraph.getInputStreams().forEach(edge -> {
+        if (intermediateStreams.contains(edge.getStreamSpec().getId())) {
+          intermediateStreams.remove(edge.getStreamSpec().getId());
+        }
+      });
+    assertEquals(new HashSet<String>() { { this.add(intermediateStream1); this.add(intermediateBroadcast); } }.toArray(),
+        intermediateStreams.toArray());
+  }
+
+  @Test
+  public void testCreateJobGraphForLegacyTaskApplication() {
+    TaskApplicationDescriptorImpl taskAppDesc = mock(TaskApplicationDescriptorImpl.class);
+
+    when(taskAppDesc.getInputDescriptors()).thenReturn(new HashMap<>());
+    when(taskAppDesc.getOutputDescriptors()).thenReturn(new HashMap<>());
+    when(taskAppDesc.getTableDescriptors()).thenReturn(new HashSet<>());
+    when(taskAppDesc.getSystemDescriptors()).thenReturn(new HashSet<>());
+    when(taskAppDesc.getBroadcastStreams()).thenReturn(new HashSet<>());
+    doReturn(LegacyTaskApplication.class).when(taskAppDesc).getAppClass();
+
+    Map<String, String> systemStreamConfigs = new HashMap<>();
+    inputDescriptors.forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig()));
+    outputDescriptors.forEach((key, value) -> systemStreamConfigs.putAll(value.toConfig()));
+    systemDescriptors.forEach(sd -> systemStreamConfigs.putAll(sd.toConfig()));
+
+    ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
+    JobGraph jobGraph = planner.createJobGraph(config, taskAppDesc);
+    assertEquals(1, jobGraph.getJobNodes().size());
+    JobNode jobNode = jobGraph.getJobNodes().get(0);
+    assertEquals("test-app", jobNode.getJobName());
+    assertEquals("test-app-1", jobNode.getJobNameAndId());
+    assertEquals(0, jobNode.getInEdges().size());
+    assertEquals(0, jobNode.getOutEdges().size());
+    assertEquals(0, jobNode.getTables().size());
+    assertEquals(config, jobNode.getConfig());
+  }
+
+  public static class MockTaskApplication implements SamzaApplication {
+
+    @Override
+    public void describe(ApplicationDescriptor appDesc) {
+
+    }
+  }
 }
diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java b/samza-core/src/test/java/org/apache/samza/execution/TestIntermediateStreamManager.java
new file mode 100644 (file)
index 0000000..bc15709
--- /dev/null
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.execution;
+
+import org.apache.samza.application.StreamApplicationDescriptorImpl;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Unit tests for {@link IntermediateStreamManager}
+ */
+public class TestIntermediateStreamManager extends ExecutionPlannerTestBase {
+
+  @Test
+  public void testCalculateRepartitionJoinTopicPartitions() {
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig);
+    IntermediateStreamManager partitionPlanner = new IntermediateStreamManager(mockConfig, mockStreamAppDesc);
+    JobGraph mockGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class))
+        .createJobGraph(mockConfig, mockStreamAppDesc);
+    // set the input stream partitions
+    mockGraph.getInputStreams().forEach(inEdge -> {
+        if (inEdge.getStreamSpec().getId().equals(input1Descriptor.getStreamId())) {
+          inEdge.setPartitionCount(6);
+        } else if (inEdge.getStreamSpec().getId().equals(input2Descriptor.getStreamId())) {
+          inEdge.setPartitionCount(5);
+        }
+      });
+    partitionPlanner.calculatePartitions(mockGraph);
+    assertEquals(1, mockGraph.getIntermediateStreamEdges().size());
+    assertEquals(5, mockGraph.getIntermediateStreamEdges().stream()
+        .filter(inEdge -> inEdge.getStreamSpec().getId().equals(intermediateInputDescriptor.getStreamId()))
+        .findFirst().get().getPartitionCount());
+  }
+
+  @Test
+  public void testCalculateRepartitionIntermediateTopicPartitions() {
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionOnlyStreamApplication(), mockConfig);
+    IntermediateStreamManager partitionPlanner = new IntermediateStreamManager(mockConfig, mockStreamAppDesc);
+    JobGraph mockGraph = new ExecutionPlanner(mockConfig, mock(StreamManager.class))
+        .createJobGraph(mockConfig, mockStreamAppDesc);
+    // set the input stream partitions
+    mockGraph.getInputStreams().forEach(inEdge -> inEdge.setPartitionCount(7));
+    partitionPlanner.calculatePartitions(mockGraph);
+    assertEquals(1, mockGraph.getIntermediateStreamEdges().size());
+    assertEquals(7, mockGraph.getIntermediateStreamEdges().stream()
+        .filter(inEdge -> inEdge.getStreamSpec().getId().equals(intermediateInputDescriptor.getStreamId()))
+        .findFirst().get().getPartitionCount());
+  }
+
+}
index ed35d67..4de0485 100644 (file)
 
 package org.apache.samza.execution;
 
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import org.apache.samza.operators.OperatorSpecGraph;
+import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.system.StreamSpec;
 import org.junit.Before;
 import org.junit.Test;
@@ -32,7 +31,6 @@ import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 
 public class TestJobGraph {
@@ -61,9 +59,8 @@ public class TestJobGraph {
    * 2 9 10
    */
   private void createGraph1() {
-    OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class);
-    when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet());
-    graph1 = new JobGraph(null, specGraph);
+    StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class);
+    graph1 = new JobGraph(null, appDesc);
 
     JobNode n2 = graph1.getOrCreateJobNode("2", "1");
     JobNode n3 = graph1.getOrCreateJobNode("3", "1");
@@ -96,9 +93,8 @@ public class TestJobGraph {
    *      |<---6 <--|    <>
    */
   private void createGraph2() {
-    OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class);
-    when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet());
-    graph2 = new JobGraph(null, specGraph);
+    StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class);
+    graph2 = new JobGraph(null, appDesc);
 
     JobNode n1 = graph2.getOrCreateJobNode("1", "1");
     JobNode n2 = graph2.getOrCreateJobNode("2", "1");
@@ -125,9 +121,8 @@ public class TestJobGraph {
    * 1<->1 -> 2<->2
    */
   private void createGraph3() {
-    OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class);
-    when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet());
-    graph3 = new JobGraph(null, specGraph);
+    StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class);
+    graph3 = new JobGraph(null, appDesc);
 
     JobNode n1 = graph3.getOrCreateJobNode("1", "1");
     JobNode n2 = graph3.getOrCreateJobNode("2", "1");
@@ -143,9 +138,8 @@ public class TestJobGraph {
    * 1<->1
    */
   private void createGraph4() {
-    OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class);
-    when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet());
-    graph4 = new JobGraph(null, specGraph);
+    StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class);
+    graph4 = new JobGraph(null, appDesc);
 
     JobNode n1 = graph4.getOrCreateJobNode("1", "1");
 
@@ -163,9 +157,8 @@ public class TestJobGraph {
 
   @Test
   public void testAddSource() {
-    OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class);
-    when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet());
-    JobGraph graph = new JobGraph(null, specGraph);
+    StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class);
+    JobGraph graph = new JobGraph(null, appDesc);
 
     /**
      * s1 -> 1
@@ -206,9 +199,8 @@ public class TestJobGraph {
      * 2 -> s2
      * 2 -> s3
      */
-    OperatorSpecGraph specGraph = mock(OperatorSpecGraph.class);
-    when(specGraph.getBroadcastStreams()).thenReturn(Collections.emptySet());
-    JobGraph graph = new JobGraph(null, specGraph);
+    StreamApplicationDescriptorImpl appDesc = mock(StreamApplicationDescriptorImpl.class);
+    JobGraph graph = new JobGraph(null, appDesc);
     JobNode n1 = graph.getOrCreateJobNode("1", "1");
     JobNode n2 = graph.getOrCreateJobNode("2", "1");
     StreamSpec s1 = genStream();
index ae6e25e..c207118 100644 (file)
 package org.apache.samza.execution;
 
 import java.time.Duration;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.samza.application.StreamApplicationDescriptorImpl;
+import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
@@ -40,10 +45,13 @@ import org.apache.samza.serializers.LongSerde;
 import org.apache.samza.serializers.NoOpSerde;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.system.StreamSpec;
 import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.testUtils.StreamTestUtils;
 import org.codehaus.jackson.map.ObjectMapper;
+import org.hamcrest.Matchers;
+import org.junit.Before;
 import org.junit.Test;
 
 import static org.apache.samza.execution.TestExecutionPlanner.*;
@@ -51,16 +59,68 @@ import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
 
 
+/**
+ * Unit test for {@link JobGraphJsonGenerator}
+ */
 public class TestJobGraphJsonGenerator {
+  private Config mockConfig;
+  private JobNode mockJobNode;
+  private StreamSpec input1Spec;
+  private StreamSpec input2Spec;
+  private StreamSpec outputSpec;
+  private StreamSpec repartitionSpec;
+  private KVSerde<String, Object> defaultSerde;
+  private GenericSystemDescriptor inputSystemDescriptor;
+  private GenericSystemDescriptor outputSystemDescriptor;
+  private GenericSystemDescriptor intermediateSystemDescriptor;
+  private GenericInputDescriptor<KV<String, Object>> input1Descriptor;
+  private GenericInputDescriptor<KV<String, Object>> input2Descriptor;
+  private GenericOutputDescriptor<KV<String, Object>> outputDescriptor;
 
-  public class PageViewEvent {
-    String getCountry() {
-      return "";
-    }
+  @Before
+  public void setUp() {
+    input1Spec = new StreamSpec("input1", "input1", "input-system");
+    input2Spec = new StreamSpec("input2", "input2", "input-system");
+    outputSpec = new StreamSpec("output", "output", "output-system");
+    repartitionSpec =
+        new StreamSpec("jobName-jobId-partition_by-p1", "partition_by-p1", "intermediate-system");
+
+
+    defaultSerde = KVSerde.of(new StringSerde(), new JsonSerdeV2<>());
+    inputSystemDescriptor = new GenericSystemDescriptor("input-system", "mockSystemFactoryClassName");
+    outputSystemDescriptor = new GenericSystemDescriptor("output-system", "mockSystemFactoryClassName");
+    intermediateSystemDescriptor = new GenericSystemDescriptor("intermediate-system", "mockSystemFactoryClassName");
+    input1Descriptor = inputSystemDescriptor.getInputDescriptor("input1", defaultSerde);
+    input2Descriptor = inputSystemDescriptor.getInputDescriptor("input2", defaultSerde);
+    outputDescriptor = outputSystemDescriptor.getOutputDescriptor("output", defaultSerde);
+
+    Map<String, String> configs = new HashMap<>();
+    configs.put(JobConfig.JOB_NAME(), "jobName");
+    configs.put(JobConfig.JOB_ID(), "jobId");
+    mockConfig = spy(new MapConfig(configs));
+
+    mockJobNode = mock(JobNode.class);
+    StreamEdge input1Edge = new StreamEdge(input1Spec, false, false, mockConfig);
+    StreamEdge input2Edge = new StreamEdge(input2Spec, false, false, mockConfig);
+    StreamEdge outputEdge = new StreamEdge(outputSpec, false, false, mockConfig);
+    StreamEdge repartitionEdge = new StreamEdge(repartitionSpec, true, false, mockConfig);
+    Map<String, StreamEdge> inputEdges = new HashMap<>();
+    inputEdges.put(input1Descriptor.getStreamId(), input1Edge);
+    inputEdges.put(input2Descriptor.getStreamId(), input2Edge);
+    inputEdges.put(repartitionSpec.getId(), repartitionEdge);
+    Map<String, StreamEdge> outputEdges = new HashMap<>();
+    outputEdges.put(outputDescriptor.getStreamId(), outputEdge);
+    outputEdges.put(repartitionSpec.getId(), repartitionEdge);
+    when(mockJobNode.getInEdges()).thenReturn(inputEdges);
+    when(mockJobNode.getOutEdges()).thenReturn(outputEdges);
+    when(mockJobNode.getConfig()).thenReturn(mockConfig);
+    when(mockJobNode.getJobName()).thenReturn("jobName");
+    when(mockJobNode.getJobId()).thenReturn("jobId");
+    when(mockJobNode.getJobNameAndId()).thenReturn(JobNode.createJobNameAndId("jobName", "jobId"));
   }
 
   @Test
-  public void test() throws Exception {
+  public void testRepartitionedJoinStreamApplication() throws Exception {
 
     /**
      * the graph looks like the following.
@@ -142,7 +202,7 @@ public class TestJobGraphJsonGenerator {
       }, config);
 
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
+    ExecutionPlan plan = planner.plan(graphSpec);
     String json = plan.getPlanAsJson();
     System.out.println(json);
 
@@ -157,7 +217,7 @@ public class TestJobGraphJsonGenerator {
   }
 
   @Test
-  public void test2() throws Exception {
+  public void testRepartitionedWindowStreamApplication() throws Exception {
     Map<String, String> configMap = new HashMap<>();
     configMap.put(JobConfig.JOB_NAME(), "test-app");
     configMap.put(JobConfig.JOB_DEFAULT_SYSTEM(), "test-system");
@@ -202,7 +262,7 @@ public class TestJobGraphJsonGenerator {
       }, config);
 
     ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
-    ExecutionPlan plan = planner.plan(graphSpec.getOperatorSpecGraph());
+    ExecutionPlan plan = planner.plan(graphSpec);
     String json = plan.getPlanAsJson();
     System.out.println(json);
 
@@ -222,4 +282,75 @@ public class TestJobGraphJsonGenerator {
     assertEquals(operatorGraphJson.operators.get("test-app-1-send_to-5").get("outputStreamId"),
         "PageViewCount");
   }
+
+  @Test
+  public void testTaskApplication() throws Exception {
+    JobGraphJsonGenerator jsonGenerator = new JobGraphJsonGenerator();
+    JobGraph mockJobGraph = mock(JobGraph.class);
+    ApplicationConfig mockAppConfig = mock(ApplicationConfig.class);
+    when(mockAppConfig.getAppName()).thenReturn("testTaskApp");
+    when(mockAppConfig.getAppId()).thenReturn("testTaskAppId");
+    when(mockJobGraph.getApplicationConfig()).thenReturn(mockAppConfig);
+    // compute the three disjoint sets of the JobGraph: input only, output only, and intermediate streams
+    Set<StreamEdge> inEdges = new HashSet<>(mockJobNode.getInEdges().values());
+    Set<StreamEdge> outEdges = new HashSet<>(mockJobNode.getOutEdges().values());
+    Set<StreamEdge> intermediateEdges = new HashSet<>(inEdges);
+    // intermediate streams are the intersection between input and output
+    intermediateEdges.retainAll(outEdges);
+    // remove all intermediate streams from input
+    inEdges.removeAll(intermediateEdges);
+    // remove all intermediate streams from output
+    outEdges.removeAll(intermediateEdges);
+    // set the return values for mockJobGraph
+    when(mockJobGraph.getInputStreams()).thenReturn(inEdges);
+    when(mockJobGraph.getOutputStreams()).thenReturn(outEdges);
+    when(mockJobGraph.getIntermediateStreamEdges()).thenReturn(intermediateEdges);
+    when(mockJobGraph.getJobNodes()).thenReturn(Collections.singletonList(mockJobNode));
+    String graphJson = jsonGenerator.toJson(mockJobGraph);
+    ObjectMapper objectMapper = new ObjectMapper();
+    JobGraphJsonGenerator.JobGraphJson jsonObject = objectMapper.readValue(graphJson.getBytes(), JobGraphJsonGenerator.JobGraphJson.class);
+    assertEquals("testTaskAppId", jsonObject.applicationId);
+    assertEquals("testTaskApp", jsonObject.applicationName);
+    Set<String> inStreamIds = inEdges.stream().map(stream -> stream.getStreamSpec().getId()).collect(Collectors.toSet());
+    assertThat(jsonObject.sourceStreams.keySet(), Matchers.containsInAnyOrder(inStreamIds.toArray()));
+    Set<String> outStreamIds = outEdges.stream().map(stream -> stream.getStreamSpec().getId()).collect(Collectors.toSet());
+    assertThat(jsonObject.sinkStreams.keySet(), Matchers.containsInAnyOrder(outStreamIds.toArray()));
+    Set<String> intStreamIds = intermediateEdges.stream().map(stream -> stream.getStreamSpec().getId()).collect(Collectors.toSet());
+    assertThat(jsonObject.intermediateStreams.keySet(), Matchers.containsInAnyOrder(intStreamIds.toArray()));
+    JobGraphJsonGenerator.JobNodeJson expectedNodeJson = new JobGraphJsonGenerator.JobNodeJson();
+    expectedNodeJson.jobId = mockJobNode.getJobId();
+    expectedNodeJson.jobName = mockJobNode.getJobName();
+    assertEquals(1, jsonObject.jobs.size());
+    JobGraphJsonGenerator.JobNodeJson actualNodeJson = jsonObject.jobs.get(0);
+    assertEquals(expectedNodeJson.jobId, actualNodeJson.jobId);
+    assertEquals(expectedNodeJson.jobName, actualNodeJson.jobName);
+    assertEquals(3, actualNodeJson.operatorGraph.inputStreams.size());
+    assertEquals(2, actualNodeJson.operatorGraph.outputStreams.size());
+    assertEquals(0, actualNodeJson.operatorGraph.operators.size());
+  }
+
+  @Test
+  public void testLegacyTaskApplication() throws Exception {
+    JobGraphJsonGenerator jsonGenerator = new JobGraphJsonGenerator();
+    JobGraph mockJobGraph = mock(JobGraph.class);
+    ApplicationConfig mockAppConfig = mock(ApplicationConfig.class);
+    when(mockAppConfig.getAppName()).thenReturn("testTaskApp");
+    when(mockAppConfig.getAppId()).thenReturn("testTaskAppId");
+    when(mockJobGraph.getApplicationConfig()).thenReturn(mockAppConfig);
+    String graphJson = jsonGenerator.toJson(mockJobGraph);
+    ObjectMapper objectMapper = new ObjectMapper();
+    JobGraphJsonGenerator.JobGraphJson jsonObject = objectMapper.readValue(graphJson.getBytes(), JobGraphJsonGenerator.JobGraphJson.class);
+    assertEquals("testTaskAppId", jsonObject.applicationId);
+    assertEquals("testTaskApp", jsonObject.applicationName);
+    JobGraphJsonGenerator.JobNodeJson expectedNodeJson = new JobGraphJsonGenerator.JobNodeJson();
+    expectedNodeJson.jobId = mockJobNode.getJobId();
+    expectedNodeJson.jobName = mockJobNode.getJobName();
+    assertEquals(0, jsonObject.jobs.size());
+  }
+
+  public class PageViewEvent {
+    String getCountry() {
+      return "";
+    }
+  }
 }
diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestJobNode.java b/samza-core/src/test/java/org/apache/samza/execution/TestJobNode.java
deleted file mode 100644 (file)
index 163b094..0000000
+++ /dev/null
@@ -1,228 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.execution;
-
-import java.time.Duration;
-import java.util.Base64;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.stream.Collectors;
-import org.apache.samza.application.StreamApplicationDescriptorImpl;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-import org.apache.samza.config.MapConfig;
-import org.apache.samza.config.SerializerConfig;
-import org.apache.samza.operators.KV;
-import org.apache.samza.operators.MessageStream;
-import org.apache.samza.operators.OutputStream;
-import org.apache.samza.operators.descriptors.GenericInputDescriptor;
-import org.apache.samza.operators.descriptors.GenericOutputDescriptor;
-import org.apache.samza.operators.descriptors.GenericSystemDescriptor;
-import org.apache.samza.operators.functions.JoinFunction;
-import org.apache.samza.operators.impl.store.TimestampedValueSerde;
-import org.apache.samza.serializers.JsonSerdeV2;
-import org.apache.samza.serializers.KVSerde;
-import org.apache.samza.serializers.Serde;
-import org.apache.samza.serializers.SerializableSerde;
-import org.apache.samza.serializers.StringSerde;
-import org.apache.samza.system.StreamSpec;
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.*;
-
-public class TestJobNode {
-
-  @Test
-  public void testAddSerdeConfigs() {
-    StreamSpec input1Spec = new StreamSpec("input1", "input1", "input-system");
-    StreamSpec input2Spec = new StreamSpec("input2", "input2", "input-system");
-    StreamSpec outputSpec = new StreamSpec("output", "output", "output-system");
-    StreamSpec partitionBySpec =
-        new StreamSpec("jobName-jobId-partition_by-p1", "partition_by-p1", "intermediate-system");
-
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName");
-    when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId");
-
-    StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
-        KVSerde<String, Object> serde = KVSerde.of(new StringSerde(), new JsonSerdeV2<>());
-        GenericSystemDescriptor sd = new GenericSystemDescriptor("system1", "mockSystemFactoryClass");
-        GenericInputDescriptor<KV<String, Object>> inputDescriptor1 = sd.getInputDescriptor("input1", serde);
-        GenericInputDescriptor<KV<String, Object>> inputDescriptor2 = sd.getInputDescriptor("input2", serde);
-        GenericOutputDescriptor<KV<String, Object>> outputDescriptor = sd.getOutputDescriptor("output", serde);
-        MessageStream<KV<String, Object>> input1 = appDesc.getInputStream(inputDescriptor1);
-        MessageStream<KV<String, Object>> input2 = appDesc.getInputStream(inputDescriptor2);
-        OutputStream<KV<String, Object>> output = appDesc.getOutputStream(outputDescriptor);
-        JoinFunction<String, Object, Object, KV<String, Object>> mockJoinFn = mock(JoinFunction.class);
-        input1
-            .partitionBy(KV::getKey, KV::getValue, serde, "p1")
-            .map(kv -> kv.value)
-            .join(input2.map(kv -> kv.value), mockJoinFn,
-                new StringSerde(), new JsonSerdeV2<>(Object.class), new JsonSerdeV2<>(Object.class),
-                Duration.ofHours(1), "j1")
-            .sendTo(output);
-      }, mockConfig);
-
-    JobNode jobNode = new JobNode("jobName", "jobId", graphSpec.getOperatorSpecGraph(), mockConfig);
-    Config config = new MapConfig();
-    StreamEdge input1Edge = new StreamEdge(input1Spec, false, false, config);
-    StreamEdge input2Edge = new StreamEdge(input2Spec, false, false, config);
-    StreamEdge outputEdge = new StreamEdge(outputSpec, false, false, config);
-    StreamEdge repartitionEdge = new StreamEdge(partitionBySpec, true, false, config);
-    jobNode.addInEdge(input1Edge);
-    jobNode.addInEdge(input2Edge);
-    jobNode.addOutEdge(outputEdge);
-    jobNode.addInEdge(repartitionEdge);
-    jobNode.addOutEdge(repartitionEdge);
-
-    Map<String, String> configs = new HashMap<>();
-    jobNode.addSerdeConfigs(configs);
-
-    MapConfig mapConfig = new MapConfig(configs);
-    Config serializers = mapConfig.subset("serializers.registry.", true);
-
-    // make sure that the serializers deserialize correctly
-    SerializableSerde<Serde> serializableSerde = new SerializableSerde<>();
-    Map<String, Serde> deserializedSerdes = serializers.entrySet().stream().collect(Collectors.toMap(
-        e -> e.getKey().replace(SerializerConfig.SERIALIZED_INSTANCE_SUFFIX(), ""),
-        e -> serializableSerde.fromBytes(Base64.getDecoder().decode(e.getValue().getBytes()))
-    ));
-    assertEquals(5, serializers.size()); // 2 default + 3 specific for join
-
-    String input1KeySerde = mapConfig.get("streams.input1.samza.key.serde");
-    String input1MsgSerde = mapConfig.get("streams.input1.samza.msg.serde");
-    assertTrue("Serialized serdes should contain input1 key serde",
-        deserializedSerdes.containsKey(input1KeySerde));
-    assertTrue("Serialized input1 key serde should be a StringSerde",
-        input1KeySerde.startsWith(StringSerde.class.getSimpleName()));
-    assertTrue("Serialized serdes should contain input1 msg serde",
-        deserializedSerdes.containsKey(input1MsgSerde));
-    assertTrue("Serialized input1 msg serde should be a JsonSerdeV2",
-        input1MsgSerde.startsWith(JsonSerdeV2.class.getSimpleName()));
-
-    String input2KeySerde = mapConfig.get("streams.input2.samza.key.serde");
-    String input2MsgSerde = mapConfig.get("streams.input2.samza.msg.serde");
-    assertTrue("Serialized serdes should contain input2 key serde",
-        deserializedSerdes.containsKey(input2KeySerde));
-    assertTrue("Serialized input2 key serde should be a StringSerde",
-        input2KeySerde.startsWith(StringSerde.class.getSimpleName()));
-    assertTrue("Serialized serdes should contain input2 msg serde",
-        deserializedSerdes.containsKey(input2MsgSerde));
-    assertTrue("Serialized input2 msg serde should be a JsonSerdeV2",
-        input2MsgSerde.startsWith(JsonSerdeV2.class.getSimpleName()));
-
-    String outputKeySerde = mapConfig.get("streams.output.samza.key.serde");
-    String outputMsgSerde = mapConfig.get("streams.output.samza.msg.serde");
-    assertTrue("Serialized serdes should contain output key serde",
-        deserializedSerdes.containsKey(outputKeySerde));
-    assertTrue("Serialized output key serde should be a StringSerde",
-        outputKeySerde.startsWith(StringSerde.class.getSimpleName()));
-    assertTrue("Serialized serdes should contain output msg serde",
-        deserializedSerdes.containsKey(outputMsgSerde));
-    assertTrue("Serialized output msg serde should be a JsonSerdeV2",
-        outputMsgSerde.startsWith(JsonSerdeV2.class.getSimpleName()));
-
-    String partitionByKeySerde = mapConfig.get("streams.jobName-jobId-partition_by-p1.samza.key.serde");
-    String partitionByMsgSerde = mapConfig.get("streams.jobName-jobId-partition_by-p1.samza.msg.serde");
-    assertTrue("Serialized serdes should contain intermediate stream key serde",
-        deserializedSerdes.containsKey(partitionByKeySerde));
-    assertTrue("Serialized intermediate stream key serde should be a StringSerde",
-        partitionByKeySerde.startsWith(StringSerde.class.getSimpleName()));
-    assertTrue("Serialized serdes should contain intermediate stream msg serde",
-        deserializedSerdes.containsKey(partitionByMsgSerde));
-    assertTrue(
-        "Serialized intermediate stream msg serde should be a JsonSerdeV2",
-        partitionByMsgSerde.startsWith(JsonSerdeV2.class.getSimpleName()));
-
-    String leftJoinStoreKeySerde = mapConfig.get("stores.jobName-jobId-join-j1-L.key.serde");
-    String leftJoinStoreMsgSerde = mapConfig.get("stores.jobName-jobId-join-j1-L.msg.serde");
-    assertTrue("Serialized serdes should contain left join store key serde",
-        deserializedSerdes.containsKey(leftJoinStoreKeySerde));
-    assertTrue("Serialized left join store key serde should be a StringSerde",
-        leftJoinStoreKeySerde.startsWith(StringSerde.class.getSimpleName()));
-    assertTrue("Serialized serdes should contain left join store msg serde",
-        deserializedSerdes.containsKey(leftJoinStoreMsgSerde));
-    assertTrue("Serialized left join store msg serde should be a TimestampedValueSerde",
-        leftJoinStoreMsgSerde.startsWith(TimestampedValueSerde.class.getSimpleName()));
-
-    String rightJoinStoreKeySerde = mapConfig.get("stores.jobName-jobId-join-j1-R.key.serde");
-    String rightJoinStoreMsgSerde = mapConfig.get("stores.jobName-jobId-join-j1-R.msg.serde");
-    assertTrue("Serialized serdes should contain right join store key serde",
-        deserializedSerdes.containsKey(rightJoinStoreKeySerde));
-    assertTrue("Serialized right join store key serde should be a StringSerde",
-        rightJoinStoreKeySerde.startsWith(StringSerde.class.getSimpleName()));
-    assertTrue("Serialized serdes should contain right join store msg serde",
-        deserializedSerdes.containsKey(rightJoinStoreMsgSerde));
-    assertTrue("Serialized right join store msg serde should be a TimestampedValueSerde",
-        rightJoinStoreMsgSerde.startsWith(TimestampedValueSerde.class.getSimpleName()));
-  }
-
-  @Test
-  public void testAddSerdeConfigsForRepartitionWithNoDefaultSystem() {
-    StreamSpec inputSpec = new StreamSpec("input", "input", "input-system");
-    StreamSpec partitionBySpec =
-        new StreamSpec("jobName-jobId-partition_by-p1", "partition_by-p1", "intermediate-system");
-
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName");
-    when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId");
-
-    StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
-        GenericSystemDescriptor sd = new GenericSystemDescriptor("system1", "mockSystemFactoryClassName");
-        GenericInputDescriptor<KV<String, Object>> inputDescriptor1 =
-            sd.getInputDescriptor("input", KVSerde.of(new StringSerde(), new JsonSerdeV2<>()));
-        MessageStream<KV<String, Object>> input = appDesc.getInputStream(inputDescriptor1);
-        input.partitionBy(KV::getKey, KV::getValue, "p1");
-      }, mockConfig);
-
-    JobNode jobNode = new JobNode("jobName", "jobId", graphSpec.getOperatorSpecGraph(), mockConfig);
-    Config config = new MapConfig();
-    StreamEdge input1Edge = new StreamEdge(inputSpec, false, false, config);
-    StreamEdge repartitionEdge = new StreamEdge(partitionBySpec, true, false, config);
-    jobNode.addInEdge(input1Edge);
-    jobNode.addInEdge(repartitionEdge);
-    jobNode.addOutEdge(repartitionEdge);
-
-    Map<String, String> configs = new HashMap<>();
-    jobNode.addSerdeConfigs(configs);
-
-    MapConfig mapConfig = new MapConfig(configs);
-    Config serializers = mapConfig.subset("serializers.registry.", true);
-
-    // make sure that the serializers deserialize correctly
-    SerializableSerde<Serde> serializableSerde = new SerializableSerde<>();
-    Map<String, Serde> deserializedSerdes = serializers.entrySet().stream().collect(Collectors.toMap(
-        e -> e.getKey().replace(SerializerConfig.SERIALIZED_INSTANCE_SUFFIX(), ""),
-        e -> serializableSerde.fromBytes(Base64.getDecoder().decode(e.getValue().getBytes()))
-    ));
-    assertEquals(2, serializers.size()); // 2 input stream
-
-    String partitionByKeySerde = mapConfig.get("streams.jobName-jobId-partition_by-p1.samza.key.serde");
-    String partitionByMsgSerde = mapConfig.get("streams.jobName-jobId-partition_by-p1.samza.msg.serde");
-    assertTrue("Serialized serdes should not contain intermediate stream key serde",
-        !deserializedSerdes.containsKey(partitionByKeySerde));
-    assertTrue("Serialized serdes should not contain intermediate stream msg serde",
-        !deserializedSerdes.containsKey(partitionByMsgSerde));
-  }
-}
diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestJobNodeConfigurationGenerator.java b/samza-core/src/test/java/org/apache/samza/execution/TestJobNodeConfigurationGenerator.java
new file mode 100644 (file)
index 0000000..f351c44
--- /dev/null
@@ -0,0 +1,509 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.execution;
+
+import com.google.common.base.Joiner;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.application.StreamApplicationDescriptorImpl;
+import org.apache.samza.application.TaskApplicationDescriptorImpl;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.ConfigRewriter;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.SerializerConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.config.TaskConfigJava;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.operators.BaseTableDescriptor;
+import org.apache.samza.operators.KV;
+import org.apache.samza.operators.TableDescriptor;
+import org.apache.samza.operators.descriptors.GenericInputDescriptor;
+import org.apache.samza.operators.impl.store.TimestampedValueSerde;
+import org.apache.samza.serializers.JsonSerdeV2;
+import org.apache.samza.serializers.KVSerde;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.serializers.SerializableSerde;
+import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.table.Table;
+import org.apache.samza.table.TableProvider;
+import org.apache.samza.table.TableProviderFactory;
+import org.apache.samza.table.TableSpec;
+import org.apache.samza.task.TaskContext;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+
+/**
+ * Unit test for {@link JobNodeConfigurationGenerator}
+ */
+public class TestJobNodeConfigurationGenerator extends ExecutionPlannerTestBase {
+
+  @Test
+  public void testConfigureSerdesWithRepartitionJoinApplication() {
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig);
+    configureJobNode(mockStreamAppDesc);
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+
+    // Verify the results
+    Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedJobConfig, jobConfig);
+    // additional, check the computed window.ms for join
+    assertEquals("3600000", jobConfig.get(TaskConfig.WINDOW_MS()));
+    Map<String, Serde> deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 5);
+    validateStreamConfigures(jobConfig, deserializedSerdes);
+    validateJoinStoreConfigures(jobConfig, deserializedSerdes);
+  }
+
+  @Test
+  public void testConfigureSerdesForRepartitionWithNoDefaultSystem() {
+    // set the application to RepartitionOnlyStreamApplication
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionOnlyStreamApplication(), mockConfig);
+    configureJobNode(mockStreamAppDesc);
+
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+
+    // Verify the results
+    Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedJobConfig, jobConfig);
+
+    Map<String, Serde> deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2);
+    validateStreamConfigures(jobConfig, null);
+
+    String partitionByKeySerde = jobConfig.get("streams.jobName-jobId-partition_by-p1.samza.key.serde");
+    String partitionByMsgSerde = jobConfig.get("streams.jobName-jobId-partition_by-p1.samza.msg.serde");
+    assertTrue("Serialized serdes should not contain intermediate stream key serde",
+        !deserializedSerdes.containsKey(partitionByKeySerde));
+    assertTrue("Serialized serdes should not contain intermediate stream msg serde",
+        !deserializedSerdes.containsKey(partitionByMsgSerde));
+  }
+
+  @Test
+  public void testGenerateJobConfigWithTaskApplication() {
+    // set the application to TaskApplication, which still wire up all input/output/intermediate streams
+    TaskApplicationDescriptorImpl taskAppDesc = new TaskApplicationDescriptorImpl(getTaskApplication(), mockConfig);
+    configureJobNode(taskAppDesc);
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+
+    // Verify the results
+    Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedJobConfig, jobConfig);
+    Map<String, Serde> deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2);
+    validateStreamConfigures(jobConfig, deserializedSerdes);
+  }
+
+  @Test
+  public void testGenerateJobConfigWithLegacyTaskApplication() {
+    TaskApplicationDescriptorImpl taskAppDesc = new TaskApplicationDescriptorImpl(getLegacyTaskApplication(), mockConfig);
+    configureJobNode(taskAppDesc);
+    Map<String, String> originConfig = new HashMap<>(mockConfig);
+
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "");
+    // jobConfig should be exactly the same as original config
+    Map<String, String> generatedConfig = new HashMap<>(jobConfig);
+    assertEquals(originConfig, generatedConfig);
+  }
+
+  @Test
+  public void testBroadcastStreamApplication() {
+    // set the application to BroadcastStreamApplication
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getBroadcastOnlyStreamApplication(defaultSerde), mockConfig);
+    configureJobNode(mockStreamAppDesc);
+
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+    Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedJobConfig, jobConfig);
+    Map<String, Serde> deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2);
+    validateStreamSerdeConfigure(broadcastInputDesriptor.getStreamId(), jobConfig, deserializedSerdes);
+    validateIntermediateStreamConfigure(broadcastInputDesriptor.getStreamId(), broadcastInputDesriptor.getPhysicalName().get(), jobConfig);
+  }
+
+  @Test
+  public void testBroadcastStreamApplicationWithoutSerde() {
+    // set the application to BroadcastStreamApplication withoutSerde
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getBroadcastOnlyStreamApplication(null), mockConfig);
+    configureJobNode(mockStreamAppDesc);
+
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+    Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedJobConfig, jobConfig);
+    Map<String, Serde> deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2);
+    validateIntermediateStreamConfigure(broadcastInputDesriptor.getStreamId(), broadcastInputDesriptor.getPhysicalName().get(), jobConfig);
+
+    String keySerde = jobConfig.get(String.format("streams.%s.samza.key.serde", broadcastInputDesriptor.getStreamId()));
+    String msgSerde = jobConfig.get(String.format("streams.%s.samza.msg.serde", broadcastInputDesriptor.getStreamId()));
+    assertTrue("Serialized serdes should not contain intermediate stream key serde",
+        !deserializedSerdes.containsKey(keySerde));
+    assertTrue("Serialized serdes should not contain intermediate stream msg serde",
+        !deserializedSerdes.containsKey(msgSerde));
+  }
+
+  @Test
+  public void testStreamApplicationWithTableAndSideInput() {
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig);
+    // add table to the RepartitionJoinStreamApplication
+    GenericInputDescriptor<KV<String, Object>> sideInput1 = inputSystemDescriptor.getInputDescriptor("sideInput1", defaultSerde);
+    BaseTableDescriptor mockTableDescriptor = mock(BaseTableDescriptor.class);
+    TableSpec mockTableSpec = mock(TableSpec.class);
+    when(mockTableSpec.getId()).thenReturn("testTable");
+    when(mockTableSpec.getSerde()).thenReturn((KVSerde) defaultSerde);
+    when(mockTableSpec.getTableProviderFactoryClassName()).thenReturn(MockTableProviderFactory.class.getName());
+    List<String> sideInputs = new ArrayList<>();
+    sideInputs.add(sideInput1.getStreamId());
+    when(mockTableSpec.getSideInputs()).thenReturn(sideInputs);
+    when(mockTableDescriptor.getTableId()).thenReturn("testTable");
+    when(mockTableDescriptor.getTableSpec()).thenReturn(mockTableSpec);
+    when(mockTableDescriptor.getSerde()).thenReturn(defaultSerde);
+    // add side input and terminate at table in the appplication
+    mockStreamAppDesc.getInputStream(sideInput1).sendTo(mockStreamAppDesc.getTable(mockTableDescriptor));
+    StreamEdge sideInputEdge = new StreamEdge(new StreamSpec(sideInput1.getStreamId(), "sideInput1",
+        inputSystemDescriptor.getSystemName()), false, false, mockConfig);
+    // need to put the sideInput related stream configuration to the original config
+    // TODO: this is confusing since part of the system and stream related configuration is generated outside the JobGraphConfigureGenerator
+    // It would be nice if all system and stream related configuration is generated in one place and only intermediate stream
+    // configuration is generated by JobGraphConfigureGenerator
+    Map<String, String> configs = new HashMap<>(mockConfig);
+    configs.putAll(sideInputEdge.generateConfig());
+    mockConfig = spy(new MapConfig(configs));
+    configureJobNode(mockStreamAppDesc);
+
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+    Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedJobConfig, jobConfig);
+    Map<String, Serde> deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 5);
+    validateTableConfigure(jobConfig, deserializedSerdes, mockTableDescriptor);
+  }
+
+  @Test
+  public void testTaskApplicationWithTableAndSideInput() {
+    // add table to the RepartitionJoinStreamApplication
+    GenericInputDescriptor<KV<String, Object>> sideInput1 = inputSystemDescriptor.getInputDescriptor("sideInput1", defaultSerde);
+    BaseTableDescriptor mockTableDescriptor = mock(BaseTableDescriptor.class);
+    TableSpec mockTableSpec = mock(TableSpec.class);
+    when(mockTableSpec.getId()).thenReturn("testTable");
+    when(mockTableSpec.getSerde()).thenReturn((KVSerde) defaultSerde);
+    when(mockTableSpec.getTableProviderFactoryClassName()).thenReturn(MockTableProviderFactory.class.getName());
+    List<String> sideInputs = new ArrayList<>();
+    sideInputs.add(sideInput1.getStreamId());
+    when(mockTableSpec.getSideInputs()).thenReturn(sideInputs);
+    when(mockTableDescriptor.getTableId()).thenReturn("testTable");
+    when(mockTableDescriptor.getTableSpec()).thenReturn(mockTableSpec);
+    when(mockTableDescriptor.getSerde()).thenReturn(defaultSerde);
+    StreamEdge sideInputEdge = new StreamEdge(new StreamSpec(sideInput1.getStreamId(), "sideInput1",
+        inputSystemDescriptor.getSystemName()), false, false, mockConfig);
+    // need to put the sideInput related stream configuration to the original config
+    // TODO: this is confusing since part of the system and stream related configuration is generated outside the JobGraphConfigureGenerator
+    // It would be nice if all system and stream related configuration is generated in one place and only intermediate stream
+    // configuration is generated by JobGraphConfigureGenerator
+    Map<String, String> configs = new HashMap<>(mockConfig);
+    configs.putAll(sideInputEdge.generateConfig());
+    mockConfig = spy(new MapConfig(configs));
+
+    // set the application to TaskApplication, which still wire up all input/output/intermediate streams
+    TaskApplicationDescriptorImpl taskAppDesc = new TaskApplicationDescriptorImpl(getTaskApplication(), mockConfig);
+    // add table to the task application
+    taskAppDesc.addTable(mockTableDescriptor);
+    taskAppDesc.addInputStream(inputSystemDescriptor.getInputDescriptor("sideInput1", defaultSerde));
+    configureJobNode(taskAppDesc);
+
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+
+    // Verify the results
+    Config expectedJobConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedJobConfig, jobConfig);
+    Map<String, Serde> deserializedSerdes = validateAndGetDeserializedSerdes(jobConfig, 2);
+    validateStreamConfigures(jobConfig, deserializedSerdes);
+    validateTableConfigure(jobConfig, deserializedSerdes, mockTableDescriptor);
+  }
+
+  @Test
+  public void testTaskInputsRemovedFromOriginalConfig() {
+    Map<String, String> configs = new HashMap<>(mockConfig);
+    configs.put(TaskConfig.INPUT_STREAMS(), "not.allowed1,not.allowed2");
+    mockConfig = spy(new MapConfig(configs));
+
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getBroadcastOnlyStreamApplication(defaultSerde), mockConfig);
+    configureJobNode(mockStreamAppDesc);
+
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+    Config expectedConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedConfig, jobConfig);
+  }
+
+  @Test
+  public void testTaskInputsRetainedForLegacyTaskApplication() {
+    Map<String, String> originConfig = new HashMap<>(mockConfig);
+    originConfig.put(TaskConfig.INPUT_STREAMS(), "must.retain1,must.retain2");
+    mockConfig = new MapConfig(originConfig);
+    TaskApplicationDescriptorImpl taskAppDesc = new TaskApplicationDescriptorImpl(getLegacyTaskApplication(), mockConfig);
+    configureJobNode(taskAppDesc);
+
+    // create the JobGraphConfigureGenerator and generate the jobConfig for the jobNode
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "");
+    // jobConfig should be exactly the same as original config
+    Map<String, String> generatedConfig = new HashMap<>(jobConfig);
+    assertEquals(originConfig, generatedConfig);
+  }
+
+  @Test
+  public void testOverrideConfigs() {
+    Map<String, String> configs = new HashMap<>(mockConfig);
+    String streamCfgToOverride = String.format("streams.%s.samza.system", intermediateInputDescriptor.getStreamId());
+    String overrideCfgKey = String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), getJobNameAndId()) + streamCfgToOverride;
+    configs.put(overrideCfgKey, "customized-system");
+    mockConfig = spy(new MapConfig(configs));
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig);
+    configureJobNode(mockStreamAppDesc);
+
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+    Config expectedConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedConfig, jobConfig);
+    assertEquals("customized-system", jobConfig.get(streamCfgToOverride));
+  }
+
+  @Test
+  public void testConfigureRewriter() {
+    Map<String, String> configs = new HashMap<>(mockConfig);
+    String streamCfgToOverride = String.format("streams.%s.samza.system", intermediateInputDescriptor.getStreamId());
+    String overrideCfgKey = String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), getJobNameAndId()) + streamCfgToOverride;
+    configs.put(overrideCfgKey, "customized-system");
+    configs.put(String.format(JobConfig.CONFIG_REWRITER_CLASS(), "mock"), MockConfigRewriter.class.getName());
+    configs.put(JobConfig.CONFIG_REWRITERS(), "mock");
+    configs.put(String.format("job.config.rewriter.mock.%s", streamCfgToOverride), "rewritten-system");
+    mockConfig = spy(new MapConfig(configs));
+    mockStreamAppDesc = new StreamApplicationDescriptorImpl(getRepartitionJoinStreamApplication(), mockConfig);
+    configureJobNode(mockStreamAppDesc);
+
+    JobNodeConfigurationGenerator configureGenerator = new JobNodeConfigurationGenerator();
+    JobConfig jobConfig = configureGenerator.generateJobConfig(mockJobNode, "testJobGraphJson");
+    Config expectedConfig = getExpectedJobConfig(mockConfig, mockJobNode.getInEdges());
+    validateJobConfig(expectedConfig, jobConfig);
+    assertEquals("rewritten-system", jobConfig.get(streamCfgToOverride));
+  }
+
+  private void validateTableConfigure(JobConfig jobConfig, Map<String, Serde> deserializedSerdes,
+      TableDescriptor tableDescriptor) {
+    Config tableConfig = jobConfig.subset(String.format("tables.%s.", tableDescriptor.getTableId()));
+    assertEquals(MockTableProviderFactory.class.getName(), tableConfig.get("provider.factory"));
+    MockTableProvider mockTableProvider =
+        (MockTableProvider) new MockTableProviderFactory().getTableProvider(((BaseTableDescriptor) tableDescriptor).getTableSpec());
+    assertEquals(mockTableProvider.configMap.get("mock.table.provider.config"), jobConfig.get("mock.table.provider.config"));
+    validateTableSerdeConfigure(tableDescriptor.getTableId(), jobConfig, deserializedSerdes);
+  }
+
+  private Config getExpectedJobConfig(Config originConfig, Map<String, StreamEdge> inputEdges) {
+    Map<String, String> configMap = new HashMap<>(originConfig);
+    Set<String> inputs = new HashSet<>();
+    Set<String> broadcasts = new HashSet<>();
+    for (StreamEdge inputEdge : inputEdges.values()) {
+      if (inputEdge.isBroadcast()) {
+        broadcasts.add(inputEdge.getName() + "#0");
+      } else {
+        inputs.add(inputEdge.getName());
+      }
+    }
+    if (!inputs.isEmpty()) {
+      configMap.put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputs));
+    }
+    if (!broadcasts.isEmpty()) {
+      configMap.put(TaskConfigJava.BROADCAST_INPUT_STREAMS, Joiner.on(',').join(broadcasts));
+    }
+    return new MapConfig(configMap);
+  }
+
+  private Map<String, Serde> validateAndGetDeserializedSerdes(Config jobConfig, int numSerdes) {
+    Config serializers = jobConfig.subset("serializers.registry.", true);
+    // make sure that the serializers deserialize correctly
+    SerializableSerde<Serde> serializableSerde = new SerializableSerde<>();
+    assertEquals(numSerdes, serializers.size());
+    return serializers.entrySet().stream().collect(Collectors.toMap(
+        e -> e.getKey().replace(SerializerConfig.SERIALIZED_INSTANCE_SUFFIX(), ""),
+        e -> serializableSerde.fromBytes(Base64.getDecoder().decode(e.getValue().getBytes()))
+    ));
+  }
+
+  private void validateJobConfig(Config expectedConfig, JobConfig jobConfig) {
+    assertEquals(expectedConfig.get(JobConfig.JOB_NAME()), jobConfig.getName().get());
+    assertEquals(expectedConfig.get(JobConfig.JOB_ID()), jobConfig.getJobId());
+    assertEquals("testJobGraphJson", jobConfig.get(JobNodeConfigurationGenerator.CONFIG_INTERNAL_EXECUTION_PLAN));
+    assertEquals(expectedConfig.get(TaskConfig.INPUT_STREAMS()), jobConfig.get(TaskConfig.INPUT_STREAMS()));
+    assertEquals(expectedConfig.get(TaskConfigJava.BROADCAST_INPUT_STREAMS), jobConfig.get(TaskConfigJava.BROADCAST_INPUT_STREAMS));
+  }
+
+  private void validateStreamSerdeConfigure(String streamId, Config config, Map<String, Serde> deserializedSerdes) {
+    Config streamConfig = config.subset(String.format("streams.%s.samza.", streamId));
+    String keySerdeName = streamConfig.get("key.serde");
+    String valueSerdeName = streamConfig.get("msg.serde");
+    assertTrue(String.format("Serialized serdes should contain %s key serde", streamId), deserializedSerdes.containsKey(keySerdeName));
+    assertTrue(String.format("Serialized %s key serde should be a StringSerde", streamId), keySerdeName.startsWith(StringSerde.class.getSimpleName()));
+    assertTrue(String.format("Serialized serdes should contain %s msg serde", streamId), deserializedSerdes.containsKey(valueSerdeName));
+    assertTrue(String.format("Serialized %s msg serde should be a JsonSerdeV2", streamId), valueSerdeName.startsWith(JsonSerdeV2.class.getSimpleName()));
+  }
+
+  private void validateTableSerdeConfigure(String tableId, Config config, Map<String, Serde> deserializedSerdes) {
+    Config streamConfig = config.subset(String.format("tables.%s.", tableId));
+    String keySerdeName = streamConfig.get("key.serde");
+    String valueSerdeName = streamConfig.get("value.serde");
+    assertTrue(String.format("Serialized serdes should contain %s key serde", tableId), deserializedSerdes.containsKey(keySerdeName));
+    assertTrue(String.format("Serialized %s key serde should be a StringSerde", tableId), keySerdeName.startsWith(StringSerde.class.getSimpleName()));
+    assertTrue(String.format("Serialized serdes should contain %s value serde", tableId), deserializedSerdes.containsKey(valueSerdeName));
+    assertTrue(String.format("Serialized %s msg serde should be a JsonSerdeV2", tableId), valueSerdeName.startsWith(JsonSerdeV2.class.getSimpleName()));
+  }
+
+  private void validateIntermediateStreamConfigure(String streamId, String physicalName, Config config) {
+    Config intStreamConfig = config.subset(String.format("streams.%s.", streamId),  true);
+    assertEquals("intermediate-system", intStreamConfig.get("samza.system"));
+    assertEquals(String.valueOf(Integer.MAX_VALUE), intStreamConfig.get("samza.priority"));
+    assertEquals("true", intStreamConfig.get("samza.delete.committed.messages"));
+    assertEquals(physicalName, intStreamConfig.get("samza.physical.name"));
+    assertEquals("true", intStreamConfig.get("samza.intermediate"));
+    assertEquals("oldest", intStreamConfig.get("samza.offset.default"));
+  }
+
+  private void validateStreamConfigures(Config config, Map<String, Serde> deserializedSerdes) {
+
+    if (deserializedSerdes != null) {
+      validateStreamSerdeConfigure(input1Descriptor.getStreamId(), config, deserializedSerdes);
+      validateStreamSerdeConfigure(input2Descriptor.getStreamId(), config, deserializedSerdes);
+      validateStreamSerdeConfigure(outputDescriptor.getStreamId(), config, deserializedSerdes);
+      validateStreamSerdeConfigure(intermediateInputDescriptor.getStreamId(), config, deserializedSerdes);
+    }
+
+    // generated stream config for intermediate stream
+    String physicalName = intermediateInputDescriptor.getPhysicalName().isPresent() ?
+        intermediateInputDescriptor.getPhysicalName().get() : null;
+    validateIntermediateStreamConfigure(intermediateInputDescriptor.getStreamId(), physicalName, config);
+  }
+
+  private void validateJoinStoreConfigures(MapConfig mapConfig, Map<String, Serde> deserializedSerdes) {
+    String leftJoinStoreKeySerde = mapConfig.get("stores.jobName-jobId-join-j1-L.key.serde");
+    String leftJoinStoreMsgSerde = mapConfig.get("stores.jobName-jobId-join-j1-L.msg.serde");
+    assertTrue("Serialized serdes should contain left join store key serde",
+        deserializedSerdes.containsKey(leftJoinStoreKeySerde));
+    assertTrue("Serialized left join store key serde should be a StringSerde",
+        leftJoinStoreKeySerde.startsWith(StringSerde.class.getSimpleName()));
+    assertTrue("Serialized serdes should contain left join store msg serde",
+        deserializedSerdes.containsKey(leftJoinStoreMsgSerde));
+    assertTrue("Serialized left join store msg serde should be a TimestampedValueSerde",
+        leftJoinStoreMsgSerde.startsWith(TimestampedValueSerde.class.getSimpleName()));
+
+    String rightJoinStoreKeySerde = mapConfig.get("stores.jobName-jobId-join-j1-R.key.serde");
+    String rightJoinStoreMsgSerde = mapConfig.get("stores.jobName-jobId-join-j1-R.msg.serde");
+    assertTrue("Serialized serdes should contain right join store key serde",
+        deserializedSerdes.containsKey(rightJoinStoreKeySerde));
+    assertTrue("Serialized right join store key serde should be a StringSerde",
+        rightJoinStoreKeySerde.startsWith(StringSerde.class.getSimpleName()));
+    assertTrue("Serialized serdes should contain right join store msg serde",
+        deserializedSerdes.containsKey(rightJoinStoreMsgSerde));
+    assertTrue("Serialized right join store msg serde should be a TimestampedValueSerde",
+        rightJoinStoreMsgSerde.startsWith(TimestampedValueSerde.class.getSimpleName()));
+
+    Config leftJoinStoreConfig = mapConfig.subset("stores.jobName-jobId-join-j1-L.", true);
+    validateJoinStoreConfigure(leftJoinStoreConfig, "jobName-jobId-join-j1-L");
+    Config rightJoinStoreConfig = mapConfig.subset("stores.jobName-jobId-join-j1-R.", true);
+    validateJoinStoreConfigure(rightJoinStoreConfig, "jobName-jobId-join-j1-R");
+  }
+
+  private void validateJoinStoreConfigure(Config joinStoreConfig, String changelogName) {
+    assertEquals("org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory", joinStoreConfig.get("factory"));
+    assertEquals(changelogName, joinStoreConfig.get("changelog"));
+    assertEquals("delete", joinStoreConfig.get("changelog.kafka.cleanup.policy"));
+    assertEquals("3600000", joinStoreConfig.get("changelog.kafka.retention.ms"));
+    assertEquals("3600000", joinStoreConfig.get("rocksdb.ttl.ms"));
+  }
+
+  private static class MockTableProvider implements TableProvider {
+    private final Map<String, String> configMap;
+
+    MockTableProvider(Map<String, String> configMap) {
+      this.configMap = configMap;
+    }
+
+    @Override
+    public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
+
+    }
+
+    @Override
+    public Table getTable() {
+      return null;
+    }
+
+    @Override
+    public Map<String, String> generateConfig(Config jobConfig, Map<String, String> generatedConfig) {
+      return configMap;
+    }
+
+    @Override
+    public void close() {
+
+    }
+  }
+
+  public static class MockTableProviderFactory implements TableProviderFactory {
+
+    @Override
+    public TableProvider getTableProvider(TableSpec tableSpec) {
+      Map<String, String> configMap = new HashMap<>();
+      configMap.put("mock.table.provider.config", "mock.config.value");
+      return new MockTableProvider(configMap);
+    }
+  }
+
+  public static class MockConfigRewriter implements ConfigRewriter {
+
+    @Override
+    public Config rewrite(String name, Config config) {
+      Map<String, String> configMap = new HashMap<>(config);
+      configMap.putAll(config.subset(String.format("job.config.rewriter.%s.", name)));
+      return new MapConfig(configMap);
+    }
+  }
+}
index 988fb34..85921f4 100644 (file)
@@ -69,7 +69,7 @@ public class TestRemoteJobPlanner {
     ApplicationConfig mockAppConfig = mock(ApplicationConfig.class);
     when(mockAppConfig.getAppMode()).thenReturn(ApplicationConfig.ApplicationMode.STREAM);
     when(plan.getApplicationConfig()).thenReturn(mockAppConfig);
-    doReturn(plan).when(remotePlanner).getExecutionPlan(any(), any());
+    doReturn(plan).when(remotePlanner).getExecutionPlan(any());
 
     remotePlanner.prepareJobs();
 
index a5b15b8..57ae6d8 100644 (file)
@@ -117,7 +117,6 @@ public class TestOperatorSpecGraph {
     OperatorSpecGraph specGraph = new OperatorSpecGraph(mockAppDesc);
     assertEquals(specGraph.getInputOperators(), inputOpSpecMap);
     assertEquals(specGraph.getOutputStreams(), outputStrmMap);
-    assertTrue(specGraph.getTables().isEmpty());
     assertTrue(!specGraph.hasWindowOrJoins());
     assertEquals(specGraph.getAllOperatorSpecs(), this.allOpSpecs);
   }
index 7704a5b..a34fdc3 100644 (file)
@@ -53,7 +53,6 @@ public class OperatorSpecTestUtils {
   public static void assertClonedGraph(OperatorSpecGraph originalGraph, OperatorSpecGraph clonedGraph) {
     assertClonedInputs(originalGraph.getInputOperators(), clonedGraph.getInputOperators());
     assertClonedOutputs(originalGraph.getOutputStreams(), clonedGraph.getOutputStreams());
-    assertClonedTables(originalGraph.getTables(), clonedGraph.getTables());
     assertAllOperators(originalGraph.getAllOperatorSpecs(), clonedGraph.getAllOperatorSpecs());
   }
 
index 19ee74f..fd0ddf8 100644 (file)
@@ -25,10 +25,10 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.samza.application.ApplicationDescriptor;
 import org.apache.samza.application.ApplicationDescriptorImpl;
+import org.apache.samza.application.LegacyTaskApplication;
 import org.apache.samza.application.SamzaApplication;
 import org.apache.samza.application.ApplicationDescriptorUtil;
 import org.apache.samza.application.StreamApplication;
-import org.apache.samza.application.TaskApplication;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
@@ -37,7 +37,6 @@ import org.apache.samza.job.ApplicationStatus;
 import org.apache.samza.processor.StreamProcessor;
 import org.apache.samza.execution.LocalJobPlanner;
 import org.apache.samza.task.IdentityStreamTask;
-import org.apache.samza.task.StreamTaskFactory;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
@@ -73,8 +72,9 @@ public class TestLocalApplicationRunner {
     final Map<String, String> cfgs = new HashMap<>();
     cfgs.put(ApplicationConfig.APP_PROCESSOR_ID_GENERATOR_CLASS, UUIDGenerator.class.getName());
     cfgs.put(JobConfig.JOB_NAME(), "test-task-job");
+    cfgs.put(JobConfig.JOB_ID(), "jobId");
     config = new MapConfig(cfgs);
-    mockApp = (TaskApplication) appDesc -> appDesc.setTaskFactory((StreamTaskFactory) () -> new IdentityStreamTask());
+    mockApp = new LegacyTaskApplication(IdentityStreamTask.class.getName());
     prepareTest();
 
     StreamProcessor sp = mock(StreamProcessor.class);
@@ -186,7 +186,8 @@ public class TestLocalApplicationRunner {
   }
 
   private void prepareTest() {
-    ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc = ApplicationDescriptorUtil.getAppDescriptor(mockApp, config);
+    ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc =
+        ApplicationDescriptorUtil.getAppDescriptor(mockApp, config);
     localPlanner = spy(new LocalJobPlanner(appDesc));
     runner = spy(new LocalApplicationRunner(appDesc, localPlanner));
   }
index ae525fb..702cbfb 100644 (file)
@@ -124,7 +124,7 @@ public class TestRemoteApplicationRunner {
 
         @Override
         public ApplicationStatus getStatus() {
-          String jobId = c.getJobId().get();
+          String jobId = c.getJobId();
           switch (jobId) {
             case "newJob":
               return ApplicationStatus.New;
index 05d717a..3f5f11c 100644 (file)
@@ -35,7 +35,7 @@ class HdfsSystemFactory extends SystemFactory with Logging {
   def getProducer(systemName: String, config: Config, registry: MetricsRegistry) = {
     val jobConfig = new JobConfig(config)
     val jobName = jobConfig.getName.getOrElse(throw new ConfigException("Missing job name."))
-    val jobId = jobConfig.getJobId.getOrElse("1")
+    val jobId = jobConfig.getJobId
 
     val clientId = getClientId("samza-producer", jobName, jobId)
     val metrics = new HdfsSystemProducerMetrics(systemName, registry)
index 8d4098f..2999800 100644 (file)
@@ -32,7 +32,7 @@ class KafkaCheckpointManagerFactory extends CheckpointManagerFactory with Loggin
 
   def getCheckpointManager(config: Config, registry: MetricsRegistry): CheckpointManager = {
     val jobName = config.getName.getOrElse(throw new SamzaException("Missing job name in configs"))
-    val jobId = config.getJobId.getOrElse("1")
+    val jobId = config.getJobId
 
     val kafkaConfig = new KafkaConfig(config)
     val checkpointSystemName = kafkaConfig.getCheckpointSystem.getOrElse(
index 3fa66e5..6cebc28 100644 (file)
@@ -31,7 +31,6 @@ import org.apache.samza.SamzaException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.Option;
-import scala.runtime.AbstractFunction0;
 
 
 /**
@@ -126,11 +125,7 @@ public class KafkaConsumerConfig extends HashMap<String, Object> {
     }
     String jobName = (String) jobNameOption.get();
 
-    Option jobIdOption = jobConfig.getJobId();
-    String jobId = "1";
-    if (! jobIdOption.isEmpty()) {
-      jobId = (String) jobIdOption.get();
-    }
+    String jobId = jobConfig.getJobId();
 
     return String.format("%s-%s", jobName, jobId);
   }
@@ -156,11 +151,7 @@ public class KafkaConsumerConfig extends HashMap<String, Object> {
     }
     String jobName = (String) jobNameOption.get();
 
-    Option jobIdOption = jobConfig.getJobId();
-    String jobId = "1";
-    if (! jobIdOption.isEmpty()) {
-      jobId = (String) jobIdOption.get();
-    }
+    String jobId = jobConfig.getJobId();
 
     return String.format("%s-%s-%s", id.replaceAll("\\W", "_"), jobName.replaceAll("\\W", "_"),
         jobId.replaceAll("\\W", "_"));
index 601ffa2..2d09301 100644 (file)
@@ -40,7 +40,7 @@ object KafkaUtil extends Logging {
   def getClientId(id: String, config: Config): String = getClientId(
     id,
     config.getName.getOrElse(throw new ConfigException("Missing job name.")),
-    config.getJobId.getOrElse("1"))
+    config.getJobId)
 
   def getClientId(id: String, jobName: String, jobId: String): String =
     "%s-%s-%s" format
index 07f4f55..8231905 100644 (file)
@@ -89,10 +89,15 @@ abstract public class BaseLocalStoreBackedTableProvider extends BaseTableProvide
 
     Map<String, String> storeConfig = new HashMap<>();
 
-    // We assume the configuration for serde are already generated for this table,
-    // so we simply carry them over to store configuration.
-    //
-    JavaTableConfig tableConfig = new JavaTableConfig(new MapConfig(generatedConfig));
+    // serde configurations for tables are generated at top level by JobNodeConfigurationGenerator and are included
+    // in the global jobConfig. generatedConfig has all table specific configuration generated from TableSpec, such
+    // as TableProviderFactory, sideInputs, etc.
+    // Merge the global jobConfig and generatedConfig to get full access to configuration needed to create local
+    // store configuration
+    Map<String, String> mergedConfigMap = new HashMap<>(jobConfig);
+    mergedConfigMap.putAll(generatedConfig);
+    JobConfig mergedJobConfig = new JobConfig(new MapConfig(mergedConfigMap));
+    JavaTableConfig tableConfig = new JavaTableConfig(mergedJobConfig);
 
     String keySerde = tableConfig.getKeySerde(tableSpec.getId());
     storeConfig.put(String.format(StorageConfig.KEY_SERDE(), tableSpec.getId()), keySerde);
@@ -116,9 +121,7 @@ abstract public class BaseLocalStoreBackedTableProvider extends BaseTableProvide
     if (enableChangelog) {
       String changelogStream = tableSpec.getConfig().get(BaseLocalStoreBackedTableDescriptor.INTERNAL_CHANGELOG_STREAM);
       if (StringUtils.isEmpty(changelogStream)) {
-        changelogStream = String.format("%s-%s-table-%s",
-            jobConfig.get(JobConfig.JOB_NAME()),
-            jobConfig.get(JobConfig.JOB_ID(), "1"),
+        changelogStream = String.format("%s-%s-table-%s", mergedJobConfig.getName().get(), mergedJobConfig.getJobId(),
             tableSpec.getId());
       }
 
index 5c4ba3b..6d6613c 100644 (file)
@@ -31,9 +31,9 @@ import java.util.stream.Collectors;
 import org.apache.commons.lang.RandomStringUtils;
 import org.apache.commons.lang.exception.ExceptionUtils;
 import org.apache.samza.SamzaException;
+import org.apache.samza.application.LegacyTaskApplication;
 import org.apache.samza.application.SamzaApplication;
 import org.apache.samza.application.StreamApplication;
-import org.apache.samza.application.TaskApplication;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.JobCoordinatorConfig;
@@ -57,10 +57,7 @@ import org.apache.samza.system.SystemStreamMetadata;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.system.inmemory.InMemorySystemFactory;
 import org.apache.samza.task.AsyncStreamTask;
-import org.apache.samza.task.AsyncStreamTaskFactory;
 import org.apache.samza.task.StreamTask;
-import org.apache.samza.task.StreamTaskFactory;
-import org.apache.samza.task.TaskFactory;
 import org.apache.samza.test.framework.system.InMemoryInputDescriptor;
 import org.apache.samza.test.framework.system.InMemoryOutputDescriptor;
 import org.apache.samza.test.framework.system.InMemorySystemDescriptor;
@@ -86,8 +83,7 @@ public class TestRunner {
   public static final String JOB_NAME = "samza-test";
 
   private Map<String, String> configs;
-  private Class taskClass;
-  private StreamApplication app;
+  private SamzaApplication app;
   /*
    * inMemoryScope is a unique global key per TestRunner, this key when configured with {@link InMemorySystemDescriptor}
    * provides an isolated state to run with in memory system
@@ -112,7 +108,7 @@ public class TestRunner {
     this();
     Preconditions.checkNotNull(taskClass);
     configs.put(TaskConfig.TASK_CLASS(), taskClass.getName());
-    this.taskClass = taskClass;
+    this.app = new LegacyTaskApplication(taskClass.getName());
   }
 
   /**
@@ -159,6 +155,17 @@ public class TestRunner {
   }
 
   /**
+   * Only adds a config from {@code config} to samza job {@code configs} if they dont exist in it.
+   * @param config configs for the application
+   * @return this {@link TestRunner}
+   */
+  public TestRunner addConfigs(Map<String, String> config, String configPrefix) {
+    Preconditions.checkNotNull(config);
+    config.forEach((key, value) -> this.configs.putIfAbsent(String.format("%s%s", configPrefix, key), value));
+    return this;
+  }
+
+  /**
    * Adds a config to {@code configs} if its not already present. Overrides a config value for which key is already
    * exisiting in {@code configs}
    * @param key key of the config
@@ -168,7 +175,7 @@ public class TestRunner {
   public TestRunner addOverrideConfig(String key, String value) {
     Preconditions.checkNotNull(key);
     Preconditions.checkNotNull(value);
-    String configKeyPrefix = String.format(JobConfig.CONFIG_JOB_PREFIX(), JOB_NAME);
+    String configKeyPrefix = String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), getJobNameAndId());
     configs.put(String.format("%s%s", configKeyPrefix, key), value);
     return this;
   }
@@ -192,6 +199,10 @@ public class TestRunner {
     return this;
   }
 
+  private String getJobNameAndId() {
+    return String.format("%s-%s", JOB_NAME, configs.getOrDefault(JobConfig.JOB_ID(), "1"));
+  }
+
   /**
    * Adds the provided input stream with mock data to the test application.
    * @param descriptor describes the stream that is supposed to be input to Samza application
@@ -243,11 +254,10 @@ public class TestRunner {
    * @throws SamzaException if Samza job fails with exception and returns UnsuccessfulFinish as the statuscode
    */
   public void run(Duration timeout) {
-    Preconditions.checkState((app == null && taskClass != null) || (app != null && taskClass == null),
+    Preconditions.checkState(app != null,
         "TestRunner should run for Low Level Task api or High Level Application Api");
     Preconditions.checkState(!timeout.isZero() || !timeout.isNegative(), "Timeouts should be positive");
-    SamzaApplication testApp = app == null ? (TaskApplication) appDesc -> appDesc.setTaskFactory(createTaskFactory()) : app;
-    final LocalApplicationRunner runner = new LocalApplicationRunner(testApp, new MapConfig(configs));
+    final LocalApplicationRunner runner = new LocalApplicationRunner(app, new MapConfig(configs));
     runner.run();
     boolean timedOut = !runner.waitForFinish(timeout);
     Assert.assertFalse("Timed out waiting for application to finish", timedOut);
@@ -326,28 +336,6 @@ public class TestRunner {
             entry -> entry.getValue().stream().map(e -> (StreamMessageType) e.getMessage()).collect(Collectors.toList())));
   }
 
-  private TaskFactory createTaskFactory() {
-    if (StreamTask.class.isAssignableFrom(taskClass)) {
-      return (StreamTaskFactory) () -> {
-        try {
-          return (StreamTask) taskClass.newInstance();
-        } catch (InstantiationException | IllegalAccessException e) {
-          throw new SamzaException(String.format("Failed to instantiate StreamTask class %s", taskClass.getName()), e);
-        }
-      };
-    } else if (AsyncStreamTask.class.isAssignableFrom(taskClass)) {
-      return (AsyncStreamTaskFactory) () -> {
-        try {
-          return (AsyncStreamTask) taskClass.newInstance();
-        } catch (InstantiationException | IllegalAccessException e) {
-          throw new SamzaException(String.format("Failed to instantiate AsyncStreamTask class %s", taskClass.getName()), e);
-        }
-      };
-    }
-    throw new SamzaException(String.format("Not supported task.class %s. task.class has to implement either StreamTask "
-        + "or AsyncStreamTask", taskClass.getName()));
-  }
-
   /**
    * Creates an in memory stream with {@link InMemorySystemFactory} and feeds its partition with stream of messages
    * @param partitonData key of the map represents partitionId and value represents
@@ -367,7 +355,7 @@ public class TestRunner {
     InMemorySystemDescriptor imsd = (InMemorySystemDescriptor) descriptor.getSystemDescriptor();
     imsd.withInMemoryScope(this.inMemoryScope);
     addConfigs(descriptor.toConfig());
-    addConfigs(descriptor.getSystemDescriptor().toConfig());
+    addConfigs(descriptor.getSystemDescriptor().toConfig(), String.format(JobConfig.CONFIG_OVERRIDE_JOBS_PREFIX(), getJobNameAndId()));
     StreamSpec spec = new StreamSpec(descriptor.getStreamId(), streamName, systemName, partitonData.size());
     SystemFactory factory = new InMemorySystemFactory();
     Config config = new MapConfig(descriptor.toConfig(), descriptor.getSystemDescriptor().toConfig());
@@ -381,7 +369,7 @@ public class TestRunner {
             producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), key, value));
           });
         producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), null,
-          new EndOfStreamMessage(null)));
+            new EndOfStreamMessage(null)));
       });
   }
 }
index 92b23ef..e6e423f 100644 (file)
@@ -29,7 +29,6 @@ import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.SystemStreamMetadata;
 import org.apache.samza.system.inmemory.InMemorySystemFactory;
 import org.apache.samza.config.JavaSystemConfig;
-import org.apache.samza.test.framework.TestRunner;
 
 
 /**
@@ -60,9 +59,6 @@ public class InMemorySystemDescriptor extends SystemDescriptor<InMemorySystemDes
    * </ol>
    *
    **/
-  private static final String CONFIG_OVERRIDE_PREFIX = "jobs.%s.";
-  private static final String DEFAULT_STREAM_OFFSET_DEFAULT_CONFIG_KEY = "systems.%s.default.stream.samza.offset.default";
-
   private String inMemoryScope;
 
   /**
@@ -106,11 +102,7 @@ public class InMemorySystemDescriptor extends SystemDescriptor<InMemorySystemDes
   public Map<String, String> toConfig() {
     HashMap<String, String> configs = new HashMap<>(super.toConfig());
     configs.put(InMemorySystemConfig.INMEMORY_SCOPE, this.inMemoryScope);
-    configs.put(String.format(CONFIG_OVERRIDE_PREFIX + JavaSystemConfig.SYSTEM_FACTORY_FORMAT, TestRunner.JOB_NAME, getSystemName()),
-        FACTORY_CLASS_NAME);
-    configs.put(
-        String.format(CONFIG_OVERRIDE_PREFIX + DEFAULT_STREAM_OFFSET_DEFAULT_CONFIG_KEY, TestRunner.JOB_NAME,
-            getSystemName()), SystemStreamMetadata.OffsetType.OLDEST.toString());
+    configs.put(String.format(JavaSystemConfig.SYSTEM_FACTORY_FORMAT, getSystemName()), FACTORY_CLASS_NAME);
     return configs;
   }
 
index d123cee..6186ca7 100644 (file)
@@ -96,8 +96,6 @@ public class TestTableDescriptorsProvider {
     Assert.assertEquals(storageConfig.getStoreNames().get(0), localTableId);
     Assert.assertEquals(storageConfig.getStorageFactoryClassName(localTableId),
         RocksDbKeyValueStorageEngineFactory.class.getName());
-    Assert.assertTrue(storageConfig.getStorageKeySerde(localTableId).startsWith("StringSerde"));
-    Assert.assertTrue(storageConfig.getStorageMsgSerde(localTableId).startsWith("StringSerde"));
     Config storeConfig = resultConfig.subset("stores." + localTableId + ".", true);
     Assert.assertEquals(4, storeConfig.size());
     Assert.assertEquals(4096, storeConfig.getInt("rocksdb.block.size.bytes"));
@@ -107,10 +105,6 @@ public class TestTableDescriptorsProvider {
         RocksDbTableProviderFactory.class.getName());
     Assert.assertEquals(tableConfig.getTableProviderFactory(remoteTableId),
         RemoteTableProviderFactory.class.getName());
-    Assert.assertTrue(tableConfig.getKeySerde(localTableId).startsWith("StringSerde"));
-    Assert.assertTrue(tableConfig.getValueSerde(localTableId).startsWith("StringSerde"));
-    Assert.assertTrue(tableConfig.getKeySerde(remoteTableId).startsWith("StringSerde"));
-    Assert.assertTrue(tableConfig.getValueSerde(remoteTableId).startsWith("LongSerde"));
     Assert.assertEquals(tableConfig.getTableProviderFactory(localTableId), RocksDbTableProviderFactory.class.getName());
     Assert.assertEquals(tableConfig.getTableProviderFactory(remoteTableId), RemoteTableProviderFactory.class.getName());
   }
index b30b896..4adb93a 100644 (file)
@@ -76,7 +76,7 @@ public class YarnJobValidationTool {
     this.config = config;
     this.client = client;
     String name = this.config.getName().get();
-    String jobId = this.config.getJobId().nonEmpty()? this.config.getJobId().get() : "1";
+    String jobId = this.config.getJobId();
     this.jobName =  name + "_" + jobId;
     this.validator = validator;
   }
index d335448..1d72a88 100644 (file)
@@ -67,7 +67,7 @@ class YarnJob(config: Config, hadoopConfig: Configuration) extends StreamJob {
           }
           envMapWithJavaHome
         }),
-        Some("%s_%s" format(config.getName.get, config.getJobId.getOrElse(1)))
+        Some("%s_%s" format(config.getName.get, config.getJobId))
       )
     } catch {
       case e: Throwable =>
@@ -169,7 +169,7 @@ class YarnJob(config: Config, hadoopConfig: Configuration) extends StreamJob {
         // Get by name
         config.getName match {
           case Some(jobName) =>
-            val applicationName = "%s_%s" format(jobName, config.getJobId.getOrElse(1))
+            val applicationName = "%s_%s" format(jobName, config.getJobId)
             logger.info("Fetching status from YARN for application name %s" format applicationName)
             val applicationIds = client.getActiveApplicationIds(applicationName)