SAMZA-1123; Create intermediate stream in partitionBy() operator
authorXinyu Liu <xiliu@xiliu-ld.linkedin.biz>
Fri, 10 Mar 2017 22:08:18 +0000 (14:08 -0800)
committerXinyu Liu <xiliu@xiliu-ld.linkedin.biz>
Fri, 10 Mar 2017 22:08:18 +0000 (14:08 -0800)
For partitionBy() operator, Samza generates an intermediate stream with id based on operator name and id, and system based on config. The intermediate streams will be materialized later by different execution environments. For example, if the intermediate stream is a Kafka stream, the topic will be created before the application starts.

Also renamed the config from "job.runner.class" to "app.runner.class".

Author: Xinyu Liu <xiliu@xiliu-ld.linkedin.biz>

Reviewers: Prateek Maheshwari <pmaheshwari@linkedin.com>

Closes #79 from xinyuiscool/SAMZA-1123

samza-api/src/main/java/org/apache/samza/runtime/ApplicationRunner.java
samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java
samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java
samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java
samza-core/src/main/java/org/apache/samza/runtime/AbstractApplicationRunner.java
samza-core/src/main/java/org/apache/samza/runtime/LocalApplicationRunner.java
samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
samza-core/src/main/scala/org/apache/samza/config/StreamConfig.scala
samza-core/src/test/java/org/apache/samza/example/TestBasicStreamGraphs.java
samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java

index ff31eff..62c8a02 100644 (file)
@@ -35,7 +35,7 @@ import org.apache.samza.system.StreamSpec;
 @InterfaceStability.Unstable
 public interface ApplicationRunner {
 
-  String RUNNER_CONFIG = "job.runner.class";
+  String RUNNER_CONFIG = "app.runner.class";
   String DEFAULT_RUNNER_CLASS = "org.apache.samza.runtime.RemoteApplicationRunner";
 
   /**
index 830e4a5..b22f199 100644 (file)
@@ -163,10 +163,10 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
 
   @Override
   public <K> MessageStream<M> partitionBy(Function<M, K> parKeyExtractor) {
-    MessageStreamImpl<M> intStream = this.graph.createIntStream(parKeyExtractor);
+    int opId = graph.getNextOpId();
+    MessageStreamImpl<M> intStream = this.graph.generateIntStreamFromOpId(opId, parKeyExtractor);
     OutputStream<M> outputStream = this.graph.getOutputStream(intStream);
-    this.registeredOperatorSpecs.add(OperatorSpecs.createPartitionOperatorSpec(outputStream.getSinkFunction(),
-        this.graph, outputStream));
+    this.registeredOperatorSpecs.add(OperatorSpecs.createPartitionOperatorSpec(outputStream.getSinkFunction(), outputStream, opId));
     return intStream;
   }
   /**
index 8ca8157..f801097 100644 (file)
  */
 package org.apache.samza.operators;
 
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.function.Function;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
 import org.apache.samza.operators.data.MessageEnvelope;
 import org.apache.samza.operators.functions.SinkFunction;
+import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.StreamSpec;
@@ -28,10 +35,6 @@ import org.apache.samza.system.SystemStream;
 import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.TaskCoordinator;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
 /**
  * The implementation of {@link StreamGraph} interface. This class provides implementation of methods to allow users to
  * create system input/output/intermediate streams.
@@ -129,9 +132,16 @@ public class StreamGraphImpl implements StreamGraph {
    */
   private final Map<String, MessageStream> inStreams = new HashMap<>();
   private final Map<String, OutputStream> outStreams = new HashMap<>();
+  private final ApplicationRunner runner;
+  private final Config config;
 
   private ContextManager contextManager = new ContextManager() { };
 
+  public StreamGraphImpl(ApplicationRunner runner, Config config) {
+    this.runner = runner;
+    this.config = config;
+  }
+
   @Override
   public <K, V, M extends MessageEnvelope<K, V>> MessageStream<M> createInStream(StreamSpec streamSpec, Serde<K> keySerde, Serde<V> msgSerde) {
     if (!this.inStreams.containsKey(streamSpec.getId())) {
@@ -163,7 +173,8 @@ public class StreamGraphImpl implements StreamGraph {
    * @return  the {@link MessageStreamImpl} object
    */
   @Override
-  public <K, V, M extends MessageEnvelope<K, V>> OutputStream<M> createIntStream(StreamSpec streamSpec, Serde<K> keySerde, Serde<V> msgSerde) {
+  public <K, V, M extends MessageEnvelope<K, V>> OutputStream<M> createIntStream(StreamSpec streamSpec,
+      Serde<K> keySerde, Serde<V> msgSerde) {
     if (!this.inStreams.containsKey(streamSpec.getId())) {
       this.inStreams.putIfAbsent(streamSpec.getId(), new IntermediateStreamImpl<K, K, V, M>(this, streamSpec, keySerde, msgSerde));
     }
@@ -182,7 +193,12 @@ public class StreamGraphImpl implements StreamGraph {
 
   @Override public Map<StreamSpec, OutputStream> getOutStreams() {
     Map<StreamSpec, OutputStream> outStreamMap = new HashMap<>();
-    this.outStreams.forEach((ss, entry) -> outStreamMap.put(((OutputStreamImpl) entry).getSpec(), entry));
+    this.outStreams.forEach((ss, entry) -> {
+        StreamSpec streamSpec = (entry instanceof IntermediateStreamImpl) ?
+          ((IntermediateStreamImpl) entry).getSpec() :
+          ((OutputStreamImpl) entry).getSpec();
+        outStreamMap.put(streamSpec, entry);
+      });
     return Collections.unmodifiableMap(outStreamMap);
   }
 
@@ -208,8 +224,8 @@ public class StreamGraphImpl implements StreamGraph {
    */
   public MessageStreamImpl getInputStream(SystemStream sstream) {
     for (MessageStream entry: this.inStreams.values()) {
-      if (((InputStreamImpl) entry).getSpec().getSystemName() == sstream.getSystem() &&
-          ((InputStreamImpl) entry).getSpec().getPhysicalName() == sstream.getStream()) {
+      if (((InputStreamImpl) entry).getSpec().getSystemName().equals(sstream.getSystem()) &&
+          ((InputStreamImpl) entry).getSpec().getPhysicalName().equals(sstream.getStream())) {
         return (MessageStreamImpl) entry;
       }
     }
@@ -224,30 +240,25 @@ public class StreamGraphImpl implements StreamGraph {
   }
 
   /**
-   * Method to create intermediate topics for {@link MessageStreamImpl#partitionBy(Function)} method.
+   * Method to generate intermediate stream from an operator ID.
    *
+   * @param opId  operator ID
    * @param parKeyFn  the function to extract the partition key from the input message
    * @param <PK>  the type of partition key
    * @param <M>  the type of input message
    * @return  the {@link OutputStream} object for the re-partitioned stream
    */
-  <PK, M> MessageStreamImpl<M> createIntStream(Function<M, PK> parKeyFn) {
-    // TODO: placeholder to auto-generate intermediate streams via {@link StreamSpec}
-    StreamSpec streamSpec = this.createIntStreamSpec();
-
-    if (!this.inStreams.containsKey(streamSpec.getId())) {
-      this.inStreams.putIfAbsent(streamSpec.getId(), new IntermediateStreamImpl(this, streamSpec, null, null, parKeyFn));
-    }
+  <PK, M> MessageStreamImpl<M> generateIntStreamFromOpId(int opId, Function<M, PK> parKeyFn) {
+    String opNameWithId = String.format("%s-%s", OperatorSpec.OpCode.PARTITION_BY.name().toLowerCase(), opId);
+    String streamId = String.format("%s-%s-%s",
+        config.get(JobConfig.JOB_NAME()),
+        config.get(JobConfig.JOB_ID(), "1"),
+        opNameWithId);
+    StreamSpec streamSpec = runner.streamFromConfig(streamId);
+
+    this.inStreams.putIfAbsent(streamSpec.getId(), new IntermediateStreamImpl(this, streamSpec, null, null, parKeyFn));
     IntermediateStreamImpl intStream = (IntermediateStreamImpl) this.inStreams.get(streamSpec.getId());
-    if (!this.outStreams.containsKey(streamSpec.getId())) {
-      this.outStreams.putIfAbsent(streamSpec.getId(), intStream);
-    }
+    this.outStreams.putIfAbsent(streamSpec.getId(), intStream);
     return intStream;
   }
-
-  private StreamSpec createIntStreamSpec() {
-    // TODO: placeholder to generate the intermediate stream's {@link StreamSpec} automatically
-    return null;
-  }
-
 }
index d626852..ae82f9d 100644 (file)
@@ -147,13 +147,12 @@ public class OperatorSpecs {
    * Creates a {@link SinkOperatorSpec}.
    *
    * @param sinkFn  the sink function
-   * @param graph  the {@link StreamGraphImpl} object
    * @param stream  the {@link OutputStream} where the message is sent to
    * @param <M>  type of input message
    * @return  the {@link SinkOperatorSpec}
    */
-  public static <M> SinkOperatorSpec<M> createPartitionOperatorSpec(SinkFunction<M> sinkFn, StreamGraphImpl graph, OutputStream<M> stream) {
-    return new SinkOperatorSpec<>(sinkFn, OperatorSpec.OpCode.PARTITION_BY, graph.getNextOpId(), stream);
+  public static <M> SinkOperatorSpec<M> createPartitionOperatorSpec(SinkFunction<M> sinkFn, OutputStream<M> stream, int opId) {
+    return new SinkOperatorSpec<>(sinkFn, OperatorSpec.OpCode.PARTITION_BY, opId, stream);
   }
 
   /**
index 8e21ec2..a864868 100644 (file)
@@ -39,8 +39,7 @@ public abstract class AbstractApplicationRunner implements ApplicationRunner {
   @Override
   public StreamSpec streamFromConfig(String streamId) {
     StreamConfig streamConfig = new StreamConfig(config);
-    String physicalName = streamConfig.getPhysicalName(streamId, streamId);
-
+    String physicalName = streamConfig.getPhysicalName(streamId);
     return streamFromConfig(streamId, physicalName);
   }
 
index c936553..eb9a997 100644 (file)
 
 package org.apache.samza.runtime;
 
-import org.apache.samza.operators.StreamGraph;
-import org.apache.samza.operators.StreamGraphBuilder;
 import org.apache.samza.config.Config;
-import org.apache.samza.operators.StreamGraphImpl;
+import org.apache.samza.operators.StreamGraphBuilder;
 
 
 /**
@@ -34,13 +32,6 @@ public class LocalApplicationRunner extends AbstractApplicationRunner {
     super(config);
   }
 
-  // TODO: may want to move this to a common base class for all {@link ExecutionEnvironment}
-  StreamGraph createGraph(StreamGraphBuilder app, Config config) {
-    StreamGraphImpl graph = new StreamGraphImpl();
-    app.init(graph, config);
-    return graph;
-  }
-
   @Override public void run(StreamGraphBuilder app, Config config) {
     // 1. get logic graph for optimization
     // StreamGraph logicGraph = this.createGraph(app, config);
@@ -50,5 +41,4 @@ public class LocalApplicationRunner extends AbstractApplicationRunner {
     // 5. create the configuration for StreamProcessor
     // 6. start the StreamProcessor w/ optimized instance of StreamGraphBuilder
   }
-
 }
index b007e3c..c697b62 100644 (file)
@@ -18,6 +18,8 @@
  */
 package org.apache.samza.task;
 
+import java.util.HashMap;
+import java.util.Map;
 import org.apache.samza.config.Config;
 import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.MessageStreamImpl;
@@ -25,12 +27,10 @@ import org.apache.samza.operators.StreamGraphBuilder;
 import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.data.InputMessageEnvelope;
 import org.apache.samza.operators.impl.OperatorGraph;
+import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
 
-import java.util.HashMap;
-import java.util.Map;
-
 
 /**
  * Execution of the logic sub-DAG
@@ -69,25 +69,28 @@ public final class StreamOperatorTask implements StreamTask, InitableTask, Windo
 
   private final StreamGraphBuilder graphBuilder;
 
+  private final ApplicationRunner runner;
+
   private ContextManager contextManager;
 
-  public StreamOperatorTask(StreamGraphBuilder graphBuilder) {
+  public StreamOperatorTask(StreamGraphBuilder graphBuilder, ApplicationRunner runner) {
     this.graphBuilder = graphBuilder;
+    this.runner = runner;
   }
 
   @Override
   public final void init(Config config, TaskContext context) throws Exception {
     // create the MessageStreamsImpl object and initialize app-specific logic DAG within the task
-    StreamGraphImpl streams = new StreamGraphImpl();
-    this.graphBuilder.init(streams, config);
+    StreamGraphImpl streamGraph = new StreamGraphImpl(this.runner, config);
+    this.graphBuilder.init(streamGraph, config);
     // get the context manager of the {@link StreamGraph} and initialize the task-specific context
-    this.contextManager = streams.getContextManager();
+    this.contextManager = streamGraph.getContextManager();
 
     Map<SystemStream, MessageStreamImpl> inputBySystemStream = new HashMap<>();
     context.getSystemStreamPartitions().forEach(ssp -> {
         if (!inputBySystemStream.containsKey(ssp.getSystemStream())) {
           // create mapping from the physical input {@link SystemStream} to the logic {@link MessageStream}
-          inputBySystemStream.putIfAbsent(ssp.getSystemStream(), streams.getInputStream(ssp.getSystemStream()));
+          inputBySystemStream.putIfAbsent(ssp.getSystemStream(), streamGraph.getInputStream(ssp.getSystemStream()));
         }
       });
     operatorGraph.init(inputBySystemStream, config, this.contextManager.initTaskContext(config, context));
index 6a3ed4b..4cce32f 100644 (file)
@@ -119,7 +119,7 @@ class StreamConfig(config: Config) extends ScalaMapConfig(config) with Logging {
     val allProperties = subset(StreamConfig.STREAM_ID_PREFIX format streamId)
     val samzaProperties = allProperties.subset(StreamConfig.SAMZA_PROPERTY, false)
     val filteredStreamProperties:java.util.Map[String, String] = allProperties.filterKeys(k => !samzaProperties.containsKey(k))
-    val inheritedLegacyProperties:java.util.Map[String, String] = getSystemStreamProperties(getSystem(streamId), getPhysicalName(streamId, streamId))
+    val inheritedLegacyProperties:java.util.Map[String, String] = getSystemStreamProperties(getSystem(streamId), getPhysicalName(streamId))
     new MapConfig(java.util.Arrays.asList(inheritedLegacyProperties, filteredStreamProperties))
   }
 
@@ -145,11 +145,11 @@ class StreamConfig(config: Config) extends ScalaMapConfig(config) with Logging {
     * Gets the physical name for the specified streamId.
     *
     * @param streamId             the identifier for the stream in the config.
-    * @param defaultPhysicalName  the default to use if the physical name is missing.
     * @return                     the physical identifier for the stream or the default if it is undefined.
     */
-  def getPhysicalName(streamId: String, defaultPhysicalName: String) = {
-    get(StreamConfig.PHYSICAL_NAME_FOR_STREAM_ID format streamId, defaultPhysicalName)
+  def getPhysicalName(streamId: String) = {
+    // use streamId as the default physical name
+    get(StreamConfig.PHYSICAL_NAME_FOR_STREAM_ID format streamId, streamId)
   }
 
   /**
@@ -177,7 +177,7 @@ class StreamConfig(config: Config) extends ScalaMapConfig(config) with Logging {
   }
 
   private def systemStreamToStreamId(systemStream: SystemStream): String = {
-   val streamIds = getStreamIdsForSystem(systemStream.getSystem).filter(streamId => systemStream.getStream().equals(getPhysicalName(streamId, streamId)))
+   val streamIds = getStreamIdsForSystem(systemStream.getSystem).filter(streamId => systemStream.getStream().equals(getPhysicalName(streamId)))
     if (streamIds.size > 1) {
       throw new IllegalStateException("There was more than one stream found for system stream %s" format(systemStream))
     }
@@ -194,7 +194,7 @@ class StreamConfig(config: Config) extends ScalaMapConfig(config) with Logging {
     * will use the streamId as the stream name if the physicalName doesn't exist.
     */
   private def streamIdToSystemStream(streamId: String): SystemStream = {
-    new SystemStream(getSystem(streamId), getPhysicalName(streamId, streamId))
+    new SystemStream(getSystem(streamId), getPhysicalName(streamId))
   }
 
   private def nonEmptyOption(value: String): Option[String] = {
index 8ecd44f..6975955 100644 (file)
@@ -21,7 +21,9 @@ package org.apache.samza.example;
 import java.lang.reflect.Field;
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
 import org.apache.samza.operators.impl.OperatorGraph;
+import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.task.StreamOperatorTask;
 import org.apache.samza.task.TaskContext;
@@ -45,18 +47,20 @@ public class TestBasicStreamGraphs {
       }
     } };
 
+  private final ApplicationRunner runner = mock(ApplicationRunner.class);
+
   @Test
   public void testUserTask() throws Exception {
-    Config mockConfig = mock(Config.class);
+    Config config = new MapConfig();
     TaskContext mockContext = mock(TaskContext.class);
     when(mockContext.getSystemStreamPartitions()).thenReturn(this.inputPartitions);
     TestWindowExample userTask = new TestWindowExample(this.inputPartitions);
-    StreamOperatorTask adaptorTask = new StreamOperatorTask(userTask);
+    StreamOperatorTask adaptorTask = new StreamOperatorTask(userTask, runner);
     Field pipelineMapFld = StreamOperatorTask.class.getDeclaredField("operatorGraph");
     pipelineMapFld.setAccessible(true);
     OperatorGraph opGraph = (OperatorGraph) pipelineMapFld.get(adaptorTask);
 
-    adaptorTask.init(mockConfig, mockContext);
+    adaptorTask.init(config, mockContext);
     this.inputPartitions.forEach(partition -> {
         assertNotNull(opGraph.get(partition.getSystemStream()));
       });
@@ -64,16 +68,16 @@ public class TestBasicStreamGraphs {
 
   @Test
   public void testSplitTask() throws Exception {
-    Config mockConfig = mock(Config.class);
+    Config config = new MapConfig();
     TaskContext mockContext = mock(TaskContext.class);
     when(mockContext.getSystemStreamPartitions()).thenReturn(this.inputPartitions);
     TestBroadcastExample splitTask = new TestBroadcastExample(this.inputPartitions);
-    StreamOperatorTask adaptorTask = new StreamOperatorTask(splitTask);
+    StreamOperatorTask adaptorTask = new StreamOperatorTask(splitTask, runner);
     Field pipelineMapFld = StreamOperatorTask.class.getDeclaredField("operatorGraph");
     pipelineMapFld.setAccessible(true);
     OperatorGraph opGraph = (OperatorGraph) pipelineMapFld.get(adaptorTask);
 
-    adaptorTask.init(mockConfig, mockContext);
+    adaptorTask.init(config, mockContext);
     this.inputPartitions.forEach(partition -> {
         assertNotNull(opGraph.get(partition.getSystemStream()));
       });
@@ -81,16 +85,16 @@ public class TestBasicStreamGraphs {
 
   @Test
   public void testJoinTask() throws Exception {
-    Config mockConfig = mock(Config.class);
+    Config config = new MapConfig();
     TaskContext mockContext = mock(TaskContext.class);
     when(mockContext.getSystemStreamPartitions()).thenReturn(this.inputPartitions);
     TestJoinExample joinTask = new TestJoinExample(this.inputPartitions);
-    StreamOperatorTask adaptorTask = new StreamOperatorTask(joinTask);
+    StreamOperatorTask adaptorTask = new StreamOperatorTask(joinTask, runner);
     Field pipelineMapFld = StreamOperatorTask.class.getDeclaredField("operatorGraph");
     pipelineMapFld.setAccessible(true);
     OperatorGraph opGraph = (OperatorGraph) pipelineMapFld.get(adaptorTask);
 
-    adaptorTask.init(mockConfig, mockContext);
+    adaptorTask.init(config, mockContext);
     this.inputPartitions.forEach(partition -> {
         assertNotNull(opGraph.get(partition.getSystemStream()));
       });
index 160a47a..c22bd95 100644 (file)
  */
 package org.apache.samza.operators;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
 import org.apache.samza.operators.functions.FilterFunction;
 import org.apache.samza.operators.functions.FlatMapFunction;
 import org.apache.samza.operators.functions.JoinFunction;
@@ -27,17 +37,13 @@ import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.PartialJoinOperatorSpec;
 import org.apache.samza.operators.spec.SinkOperatorSpec;
 import org.apache.samza.operators.spec.StreamOperatorSpec;
+import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.TaskCoordinator;
 import org.junit.Test;
 
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.Set;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
@@ -197,8 +203,36 @@ public class TestMessageStreamImpl {
     assertTrue(mergeOp instanceof StreamOperatorSpec);
     assertEquals(((StreamOperatorSpec) mergeOp).getNextStream(), mergeOutput);
     TestMessageEnvelope mockMsg = mock(TestMessageEnvelope.class);
-    Collection<TestMessageEnvelope> outputs = ((StreamOperatorSpec<TestMessageEnvelope, TestMessageEnvelope>) mergeOp).getTransformFn().apply(mockMsg);
+    Collection<TestMessageEnvelope> outputs = ((StreamOperatorSpec<TestMessageEnvelope, TestMessageEnvelope>) mergeOp).getTransformFn().apply(
+        mockMsg);
     assertEquals(outputs.size(), 1);
     assertEquals(outputs.iterator().next(), mockMsg);
   }
+
+  @Test
+  public void testPartitionBy() {
+    Map<String, String> map = new HashMap<>();
+    map.put(JobConfig.JOB_DEFAULT_SYSTEM(), "testsystem");
+    Config config = new MapConfig(map);
+    ApplicationRunner runner = ApplicationRunner.fromConfig(config);
+    StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
+    MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(streamGraph);
+    Function<TestMessageEnvelope, String> keyExtractorFunc = m -> "222";
+    inputStream.partitionBy(keyExtractorFunc);
+    assertTrue(streamGraph.getInStreams().size() == 1);
+    assertTrue(streamGraph.getOutStreams().size() == 1);
+
+    Collection<OperatorSpec> subs = inputStream.getRegisteredOperatorSpecs();
+    assertEquals(subs.size(), 1);
+    OperatorSpec<TestMessageEnvelope> partitionByOp = subs.iterator().next();
+    assertTrue(partitionByOp instanceof SinkOperatorSpec);
+    assertNull(partitionByOp.getNextStream());
+
+    ((SinkOperatorSpec) partitionByOp).getSinkFn().apply(new TestMessageEnvelope("111", "test", 1000), new MessageCollector() {
+      @Override
+      public void send(OutgoingMessageEnvelope envelope) {
+        assertTrue(envelope.getPartitionKey().equals("222"));
+      }
+    }, null);
+  }
 }