SAMZA-1226: relax type parameters in MessageStream functions
authorYi Pan (Data Infrastructure) <nickpan47@gmail.com>
Fri, 21 Apr 2017 23:14:41 +0000 (16:14 -0700)
committerYi Pan (Data Infrastructure) <nickpan47@gmail.com>
Fri, 21 Apr 2017 23:14:41 +0000 (16:14 -0700)
relax the type parameter in user supplied functions in fluent API

Author: Yi Pan (Data Infrastructure) <nickpan47@gmail.com>

Reviewers: Prateek Maheshwari <pmaheshw@linkedin.com>, Navina Ramesh <navina@apache.org>, Jacob Maes <jmaes@linkedin.com>

Closes #133 from nickpan47/SAMZA-1226 and squashes the following commits:

b8d3461 [Yi Pan (Data Infrastructure)] SAMZA-1226: cleanup code example in StreamApplication javadoc
93fa471 [Yi Pan (Data Infrastructure)] SAMZA-1226: added more unit tests for type-cast functions
18e1e9f [Yi Pan (Data Infrastructure)] SAMZA-1226: address review feedbacks
b5da53b [Yi Pan (Data Infrastructure)] Merge branch 'master' into SAMZA-1226
7981b83 [Yi Pan (Data Infrastructure)] SAMZA-1226: relax type parameters in MessageStream functions

23 files changed:
samza-api/src/main/java/org/apache/samza/application/StreamApplication.java
samza-api/src/main/java/org/apache/samza/operators/MessageStream.java
samza-api/src/main/java/org/apache/samza/operators/StreamGraph.java
samza-api/src/main/java/org/apache/samza/operators/windows/Windows.java
samza-api/src/main/java/org/apache/samza/task/TaskContext.java
samza-core/src/main/java/org/apache/samza/config/ApplicationConfig.java
samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java
samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java
samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java
samza-core/src/main/java/org/apache/samza/operators/spec/StreamOperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/stream/IntermediateStreamInternalImpl.java
samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java
samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java
samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java
samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java
samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphImpl.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/data/MessageType.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/data/TestExtOutputMessageEnvelope.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/data/TestInputMessageEnvelope.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/operators/data/TestMessageEnvelope.java
samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java
samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java
samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala

index eeece10..a26c5af 100644 (file)
@@ -24,11 +24,43 @@ import org.apache.samza.operators.StreamGraph;
 
 
 /**
- * This interface defines a template for stream application that user will implement to create operator DAG in {@link StreamGraph}.
+ * This interface defines a template for stream application that user will implement to initialize operator DAG in {@link StreamGraph}.
+ *
+ * <p>
+ * User program implements {@link StreamApplication#init(StreamGraph, Config)} method to initialize the transformation logic
+ * from all input streams to output streams. A simple user code example is shown below:
+ * </p>
+ *
+ * <pre>{@code
+ * public class PageViewCounterExample implements StreamApplication {
+ *   // max timeout is 60 seconds
+ *   private static final MAX_TIMEOUT = 60000;
+ *
+ *   public void init(StreamGraph graph, Config config) {
+ *     MessageStream<PageViewEvent> pageViewEvents = graph.getInputStream("pageViewEventStream", (k, m) -> (PageViewEvent) m);
+ *     OutputStream<String, PageViewEvent, PageViewEvent> pageViewEventFilteredStream = graph
+ *       .getOutputStream("pageViewEventFiltered", m -> m.memberId, m -> m);
+ *
+ *     pageViewEvents
+ *       .filter(m -> !(m.getMessage().getEventTime() < System.currentTimeMillis() - MAX_TIMEOUT))
+ *       .sendTo(pageViewEventFilteredStream);
+ *   }
+ *
+ *   // local execution mode
+ *   public static void main(String[] args) {
+ *     CommandLine cmdLine = new CommandLine();
+ *     Config config = cmdLine.loadConfig(cmdLine.parser().parse(args));
+ *     PageViewCounterExample userApp = new PageViewCounterExample();
+ *     ApplicationRunner localRunner = ApplicationRunner.getLocalRunner(config);
+ *     localRunner.run(userApp);
+ *   }
+ *
+ * }
+ * }</pre>
+ *
  */
 @InterfaceStability.Unstable
 public interface StreamApplication {
-  static final String APP_CLASS_CONFIG = "app.class";
 
   /**
    * Users are required to implement this abstract method to initialize the processing logic of the application, in terms
@@ -38,4 +70,5 @@ public interface StreamApplication {
    * @param config  the {@link Config} of the application
    */
   void init(StreamGraph graph, Config config);
+
 }
index 345bff0..c406a93 100644 (file)
@@ -50,7 +50,7 @@ public interface MessageStream<M> {
    * @param <TM> the type of messages in the transformed {@link MessageStream}
    * @return the transformed {@link MessageStream}
    */
-  <TM> MessageStream<TM> map(MapFunction<M, TM> mapFn);
+  <TM> MessageStream<TM> map(MapFunction<? super M, ? extends TM> mapFn);
 
   /**
    * Applies the provided 1:n function to transform a message in this {@link MessageStream}
@@ -60,7 +60,7 @@ public interface MessageStream<M> {
    * @param <TM> the type of messages in the transformed {@link MessageStream}
    * @return the transformed {@link MessageStream}
    */
-  <TM> MessageStream<TM> flatMap(FlatMapFunction<M, TM> flatMapFn);
+  <TM> MessageStream<TM> flatMap(FlatMapFunction<? super M, ? extends TM> flatMapFn);
 
   /**
    * Applies the provided function to messages in this {@link MessageStream} and returns the
@@ -72,7 +72,7 @@ public interface MessageStream<M> {
    * @param filterFn the predicate to filter messages from this {@link MessageStream}
    * @return the transformed {@link MessageStream}
    */
-  MessageStream<M> filter(FilterFunction<M> filterFn);
+  MessageStream<M> filter(FilterFunction<? super M> filterFn);
 
   /**
    * Allows sending messages in this {@link MessageStream} to an output system using the provided {@link SinkFunction}.
@@ -83,7 +83,7 @@ public interface MessageStream<M> {
    *
    * @param sinkFn the function to send messages in this stream to an external system
    */
-  void sink(SinkFunction<M> sinkFn);
+  void sink(SinkFunction<? super M> sinkFn);
 
   /**
    * Allows sending messages in this {@link MessageStream} to an output {@link MessageStream}.
@@ -120,10 +120,10 @@ public interface MessageStream<M> {
    * @param ttl the ttl for messages in each stream
    * @param <K> the type of join key
    * @param <OM> the type of messages in the other stream
-   * @param <RM> the type of messages resulting from the {@code joinFn}
+   * @param <TM> the type of messages resulting from the {@code joinFn}
    * @return the joined {@link MessageStream}
    */
-  <K, OM, RM> MessageStream<RM> join(MessageStream<OM> otherStream, JoinFunction<K, M, OM, RM> joinFn, Duration ttl);
+  <K, OM, TM> MessageStream<TM> join(MessageStream<OM> otherStream, JoinFunction<? extends K, ? super M, ? super OM, ? extends TM> joinFn, Duration ttl);
 
   /**
    * Merge all {@code otherStreams} with this {@link MessageStream}.
@@ -133,7 +133,7 @@ public interface MessageStream<M> {
    * @param otherStreams other {@link MessageStream}s to be merged with this {@link MessageStream}
    * @return the merged {@link MessageStream}
    */
-  MessageStream<M> merge(Collection<MessageStream<M>> otherStreams);
+  MessageStream<M> merge(Collection<MessageStream<? extends M>> otherStreams);
 
   /**
    * Sends the messages of type {@code M}in this {@link MessageStream} to a repartitioned output stream and consumes
@@ -144,6 +144,6 @@ public interface MessageStream<M> {
    * @param <K> the type of output message key and partition key
    * @return the repartitioned {@link MessageStream}
    */
-  <K> MessageStream<M> partitionBy(Function<M, K> keyExtractor);
+  <K> MessageStream<M> partitionBy(Function<? super M, ? extends K> keyExtractor);
 
 }
index ff1c580..a03f7c3 100644 (file)
@@ -40,7 +40,7 @@ public interface StreamGraph {
    * @param <M> the type of message in the input {@link MessageStream}
    * @return the input {@link MessageStream}
    */
-  <K, V, M> MessageStream<M> getInputStream(String streamId, BiFunction<K, V, M> msgBuilder);
+  <K, V, M> MessageStream<M> getInputStream(String streamId, BiFunction<? super K, ? super V, ? extends M> msgBuilder);
 
   /**
    * Gets the {@link OutputStream} corresponding to the logical {@code streamId}.
@@ -54,7 +54,7 @@ public interface StreamGraph {
    * @return the output {@link MessageStream}
    */
   <K, V, M> OutputStream<K, V, M> getOutputStream(String streamId,
-      Function<M, K> keyExtractor, Function<M, V> msgExtractor);
+      Function<? super M, ? extends K> keyExtractor, Function<? super M, ? extends V> msgExtractor);
 
   /**
    * Sets the {@link ContextManager} for this {@link StreamGraph}.
index 9192fc1..721b4c0 100644 (file)
@@ -119,11 +119,12 @@ public final class Windows {
    * @param <K> the type of the key in the {@link Window}
    * @return the created {@link Window} function.
    */
-  public static <M, K, WV> Window<M, K, WV> keyedTumblingWindow(Function<M, K> keyFn, Duration interval,
-                                                                Supplier<WV> initialValue, FoldLeftFunction<M, WV> foldFn) {
+  public static <M, K, WV> Window<M, K, WV> keyedTumblingWindow(Function<? super M, ? extends K> keyFn, Duration interval,
+                                                                Supplier<? extends WV> initialValue, FoldLeftFunction<? super M, WV> foldFn) {
 
     Trigger<M> defaultTrigger = new TimeTrigger<>(interval);
-    return new WindowInternal<M, K, WV>(defaultTrigger, initialValue, foldFn, keyFn, null, WindowType.TUMBLING);
+    return new WindowInternal<>(defaultTrigger, (Supplier<WV>) initialValue, (FoldLeftFunction<M, WV>) foldFn,
+        (Function<M, K>) keyFn, null, WindowType.TUMBLING);
   }
 
 
@@ -147,10 +148,10 @@ public final class Windows {
    * @param <K> the type of the key in the {@link Window}
    * @return the created {@link Window} function
    */
-  public static <M, K> Window<M, K, Collection<M>> keyedTumblingWindow(Function<M, K> keyFn, Duration interval) {
+  public static <M, K> Window<M, K, Collection<M>> keyedTumblingWindow(Function<? super M, ? extends K> keyFn, Duration interval) {
     FoldLeftFunction<M, Collection<M>> aggregator = createAggregator();
 
-    Supplier<Collection<M>> initialValue = () -> new ArrayList<>();
+    Supplier<Collection<M>> initialValue = ArrayList::new;
     return keyedTumblingWindow(keyFn, interval, initialValue, aggregator);
   }
 
@@ -175,10 +176,11 @@ public final class Windows {
    * @param <WV> the type of the {@link WindowPane} output value
    * @return the created {@link Window} function
    */
-  public static <M, WV> Window<M, Void, WV> tumblingWindow(Duration duration, Supplier<WV> initialValue,
-                                                           FoldLeftFunction<M, WV> foldFn) {
+  public static <M, WV> Window<M, Void, WV> tumblingWindow(Duration duration, Supplier<? extends WV> initialValue,
+                                                           FoldLeftFunction<? super M, WV> foldFn) {
     Trigger<M> defaultTrigger = Triggers.repeat(new TimeTrigger<>(duration));
-    return new WindowInternal<>(defaultTrigger, initialValue, foldFn, null, null, WindowType.TUMBLING);
+    return new WindowInternal<>(defaultTrigger, (Supplier<WV>) initialValue, (FoldLeftFunction<M, WV>) foldFn,
+        null, null, WindowType.TUMBLING);
   }
 
   /**
@@ -203,7 +205,7 @@ public final class Windows {
   public static <M> Window<M, Void, Collection<M>> tumblingWindow(Duration duration) {
     FoldLeftFunction<M, Collection<M>> aggregator = createAggregator();
 
-    Supplier<Collection<M>> initialValue = () -> new ArrayList<>();
+    Supplier<Collection<M>> initialValue = ArrayList::new;
     return tumblingWindow(duration, initialValue, aggregator);
   }
 
@@ -235,10 +237,11 @@ public final class Windows {
    * @param <WV> the type of the output value in the {@link WindowPane}
    * @return the created {@link Window} function
    */
-  public static <M, K, WV> Window<M, K, WV> keyedSessionWindow(Function<M, K> keyFn, Duration sessionGap,
-                                                               Supplier<WV> initialValue, FoldLeftFunction<M, WV> foldFn) {
+  public static <M, K, WV> Window<M, K, WV> keyedSessionWindow(Function<? super M, ? extends K> keyFn, Duration sessionGap,
+                                                               Supplier<? extends WV> initialValue, FoldLeftFunction<? super M, WV> foldFn) {
     Trigger<M> defaultTrigger = Triggers.timeSinceLastMessage(sessionGap);
-    return new WindowInternal<>(defaultTrigger, initialValue, foldFn, keyFn, null, WindowType.SESSION);
+    return new WindowInternal<>(defaultTrigger, (Supplier<WV>) initialValue, (FoldLeftFunction<M, WV>) foldFn, (Function<M, K>) keyFn,
+        null, WindowType.SESSION);
   }
 
   /**
@@ -265,11 +268,11 @@ public final class Windows {
    * @param <K> the type of the key in the {@link Window}
    * @return the created {@link Window} function
    */
-  public static <M, K> Window<M, K, Collection<M>> keyedSessionWindow(Function<M, K> keyFn, Duration sessionGap) {
+  public static <M, K> Window<M, K, Collection<M>> keyedSessionWindow(Function<? super M, ? extends K> keyFn, Duration sessionGap) {
 
     FoldLeftFunction<M, Collection<M>> aggregator = createAggregator();
 
-    Supplier<Collection<M>> initialValue = () -> new ArrayList<>();
+    Supplier<Collection<M>> initialValue = ArrayList::new;
     return keyedSessionWindow(keyFn, sessionGap, initialValue, aggregator);
   }
 
index 128cff1..dc5742f 100644 (file)
@@ -58,10 +58,9 @@ public interface TaskContext {
   /**
    * Method to allow user to return customized context
    *
-   * @param <T>  the type of user-defined task context
    * @return  user-defined task context object
    */
-  default <T> T getUserDefinedContext() {
+  default Object getUserDefinedContext() {
     return null;
   };
 }
index 1c00735..9eb4161 100644 (file)
@@ -46,6 +46,7 @@ public class ApplicationConfig extends MapConfig {
   public static final String APP_COORDINATION_SERVICE_FACTORY_CLASS = "app.coordination.service.factory.class";
   public static final String APP_NAME = "app.name";
   public static final String APP_ID = "app.id";
+  public static final String APP_CLASS = "app.class";
 
   public ApplicationConfig(Config config) {
     super(config);
@@ -67,6 +68,10 @@ public class ApplicationConfig extends MapConfig {
     return get(APP_ID, get(JobConfig.JOB_ID(), "1"));
   }
 
+  public String getAppClass() {
+    return get(APP_CLASS, null);
+  }
+
   @Deprecated
   public String getProcessorId() {
     return get(PROCESSOR_ID, null);
index dfe231e..69a41db 100644 (file)
@@ -72,7 +72,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   }
 
   @Override
-  public <TM> MessageStream<TM> map(MapFunction<M, TM> mapFn) {
+  public <TM> MessageStream<TM> map(MapFunction<? super M, ? extends TM> mapFn) {
     OperatorSpec<TM> op = OperatorSpecs.createMapOperatorSpec(
         mapFn, new MessageStreamImpl<>(this.graph), this.graph.getNextOpId());
     this.registeredOperatorSpecs.add(op);
@@ -80,7 +80,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   }
 
   @Override
-  public MessageStream<M> filter(FilterFunction<M> filterFn) {
+  public MessageStream<M> filter(FilterFunction<? super M> filterFn) {
     OperatorSpec<M> op = OperatorSpecs.createFilterOperatorSpec(
         filterFn, new MessageStreamImpl<>(this.graph), this.graph.getNextOpId());
     this.registeredOperatorSpecs.add(op);
@@ -88,7 +88,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   }
 
   @Override
-  public <TM> MessageStream<TM> flatMap(FlatMapFunction<M, TM> flatMapFn) {
+  public <TM> MessageStream<TM> flatMap(FlatMapFunction<? super M, ? extends TM> flatMapFn) {
     OperatorSpec<TM> op = OperatorSpecs.createStreamOperatorSpec(
         flatMapFn, new MessageStreamImpl<>(this.graph), this.graph.getNextOpId());
     this.registeredOperatorSpecs.add(op);
@@ -96,7 +96,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   }
 
   @Override
-  public void sink(SinkFunction<M> sinkFn) {
+  public void sink(SinkFunction<? super M> sinkFn) {
     SinkOperatorSpec<M> op = OperatorSpecs.createSinkOperatorSpec(sinkFn, this.graph.getNextOpId());
     this.registeredOperatorSpecs.add(op);
   }
@@ -110,22 +110,22 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
 
   @Override
   public <K, WV> MessageStream<WindowPane<K, WV>> window(Window<M, K, WV> window) {
-    OperatorSpec<WindowPane<K, WV>> wndOp = OperatorSpecs.createWindowOperatorSpec((WindowInternal<M, K, WV>) window,
-        new MessageStreamImpl<>(this.graph), this.graph.getNextOpId());
+    OperatorSpec<WindowPane<K, WV>> wndOp = OperatorSpecs.createWindowOperatorSpec(
+        (WindowInternal<M, K, WV>) window, new MessageStreamImpl<>(this.graph), this.graph.getNextOpId());
     this.registeredOperatorSpecs.add(wndOp);
     return wndOp.getNextStream();
   }
 
   @Override
-  public <K, JM, RM> MessageStream<RM> join(
-      MessageStream<JM> otherStream, JoinFunction<K, M, JM, RM> joinFn, Duration ttl) {
-    MessageStreamImpl<RM> nextStream = new MessageStreamImpl<>(this.graph);
+  public <K, OM, TM> MessageStream<TM> join(
+      MessageStream<OM> otherStream, JoinFunction<? extends K, ? super M, ? super OM, ? extends TM> joinFn, Duration ttl) {
+    MessageStreamImpl<TM> nextStream = new MessageStreamImpl<>(this.graph);
 
-    PartialJoinFunction<K, M, JM, RM> thisPartialJoinFn = new PartialJoinFunction<K, M, JM, RM>() {
-      private KeyValueStore<K, PartialJoinMessage<M>> thisStreamState;
+    PartialJoinFunction<K, M, OM, TM> thisPartialJoinFn = new PartialJoinFunction<K, M, OM, TM>() {
+      private KeyValueStore<K, PartialJoinFunction.PartialJoinMessage<M>> thisStreamState;
 
       @Override
-      public RM apply(M m, JM jm) {
+      public TM apply(M m, OM jm) {
         return joinFn.apply(m, jm);
       }
 
@@ -148,21 +148,21 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
       }
     };
 
-    PartialJoinFunction<K, JM, M, RM> otherPartialJoinFn = new PartialJoinFunction<K, JM, M, RM>() {
-      private KeyValueStore<K, PartialJoinMessage<JM>> otherStreamState;
+    PartialJoinFunction<K, OM, M, TM> otherPartialJoinFn = new PartialJoinFunction<K, OM, M, TM>() {
+      private KeyValueStore<K, PartialJoinMessage<OM>> otherStreamState;
 
       @Override
-      public RM apply(JM om, M m) {
+      public TM apply(OM om, M m) {
         return joinFn.apply(m, om);
       }
 
       @Override
-      public K getKey(JM message) {
+      public K getKey(OM message) {
         return joinFn.getSecondKey(message);
       }
 
       @Override
-      public KeyValueStore<K, PartialJoinMessage<JM>> getState() {
+      public KeyValueStore<K, PartialJoinMessage<OM>> getState() {
         return otherStreamState;
       }
 
@@ -175,7 +175,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
     this.registeredOperatorSpecs.add(OperatorSpecs.createPartialJoinOperatorSpec(
         thisPartialJoinFn, otherPartialJoinFn, ttl.toMillis(), nextStream, this.graph.getNextOpId()));
 
-    ((MessageStreamImpl<JM>) otherStream).registeredOperatorSpecs
+    ((MessageStreamImpl<OM>) otherStream).registeredOperatorSpecs
         .add(OperatorSpecs.createPartialJoinOperatorSpec(
             otherPartialJoinFn, thisPartialJoinFn, ttl.toMillis(), nextStream, this.graph.getNextOpId()));
 
@@ -183,7 +183,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   }
 
   @Override
-  public MessageStream<M> merge(Collection<MessageStream<M>> otherStreams) {
+  public MessageStream<M> merge(Collection<MessageStream<? extends M>> otherStreams) {
     MessageStreamImpl<M> nextStream = new MessageStreamImpl<>(this.graph);
 
     otherStreams.add(this);
@@ -193,7 +193,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
   }
 
   @Override
-  public <K> MessageStream<M> partitionBy(Function<M, K> keyExtractor) {
+  public <K> MessageStream<M> partitionBy(Function<? super M, ? extends K> keyExtractor) {
     int opId = this.graph.getNextOpId();
     String opName = String.format("%s-%s", OperatorSpec.OpCode.PARTITION_BY.name().toLowerCase(), opId);
     MessageStreamImpl<M> intermediateStream =
index a49b68e..86ce6a4 100644 (file)
@@ -61,16 +61,25 @@ public class StreamGraphImpl implements StreamGraph {
   }
 
   @Override
-  public <K, V, M> MessageStream<M> getInputStream(String streamId, BiFunction<K, V, M> msgBuilder) {
+  public <K, V, M> MessageStream<M> getInputStream(String streamId, BiFunction<? super K, ? super V, ? extends M> msgBuilder) {
+    if (msgBuilder == null) {
+      throw new IllegalArgumentException("msgBuilder can't be null for an input stream");
+    }
     return inStreams.computeIfAbsent(runner.getStreamSpec(streamId),
-        streamSpec -> new InputStreamInternalImpl<>(this, streamSpec, msgBuilder));
+        streamSpec -> new InputStreamInternalImpl<>(this, streamSpec, (BiFunction<K, V, M>) msgBuilder));
   }
 
   @Override
   public <K, V, M> OutputStream<K, V, M> getOutputStream(String streamId,
-      Function<M, K> keyExtractor, Function<M, V> msgExtractor) {
+      Function<? super M, ? extends K> keyExtractor, Function<? super M, ? extends V> msgExtractor) {
+    if (keyExtractor == null) {
+      throw new IllegalArgumentException("keyExtractor can't be null for an output stream.");
+    }
+    if (msgExtractor == null) {
+      throw new IllegalArgumentException("msgExtractor can't be null for an output stream.");
+    }
     return outStreams.computeIfAbsent(runner.getStreamSpec(streamId),
-        streamSpec -> new OutputStreamInternalImpl<>(this, streamSpec, keyExtractor, msgExtractor));
+        streamSpec -> new OutputStreamInternalImpl<>(this, streamSpec, (Function<M, K>) keyExtractor, (Function<M, V>) msgExtractor));
   }
 
   @Override
@@ -95,16 +104,28 @@ public class StreamGraphImpl implements StreamGraph {
    * @return  the intermediate {@link MessageStreamImpl}
    */
   <K, V, M> MessageStreamImpl<M> getIntermediateStream(String streamName,
-      Function<M, K> keyExtractor, Function<M, V> msgExtractor, BiFunction<K, V, M> msgBuilder) {
+      Function<? super M, ? extends K> keyExtractor, Function<? super M, ? extends V> msgExtractor, BiFunction<? super K, ? super V, ? extends M> msgBuilder) {
     String streamId = String.format("%s-%s-%s",
         config.get(JobConfig.JOB_NAME()),
         config.get(JobConfig.JOB_ID(), "1"),
         streamName);
+    if (msgBuilder == null) {
+      throw new IllegalArgumentException("msgBuilder cannot be null for an intermediate stream");
+    }
+
+    if (keyExtractor == null) {
+      throw new IllegalArgumentException("keyExtractor can't be null for an output stream.");
+    }
+    if (msgExtractor == null) {
+      throw new IllegalArgumentException("msgExtractor can't be null for an output stream.");
+    }
+
     StreamSpec streamSpec = runner.getStreamSpec(streamId);
     IntermediateStreamInternalImpl<K, V, M> intStream =
         (IntermediateStreamInternalImpl<K, V, M>) inStreams
             .computeIfAbsent(streamSpec,
-                k -> new IntermediateStreamInternalImpl<>(this, streamSpec, keyExtractor, msgExtractor, msgBuilder));
+                k -> new IntermediateStreamInternalImpl<>(this, streamSpec, (Function<M, K>) keyExtractor,
+                    (Function<M, V>) msgExtractor, (BiFunction<K, V, M>) msgBuilder));
     outStreams.putIfAbsent(streamSpec, intStream);
     return intStream;
   }
index e2c4b9a..0b93bbe 100644 (file)
@@ -53,7 +53,7 @@ public class OperatorSpecs {
    * @return  the {@link StreamOperatorSpec}
    */
   public static <M, OM> StreamOperatorSpec<M, OM> createMapOperatorSpec(
-      MapFunction<M, OM> mapFn, MessageStreamImpl<OM> nextStream, int opId) {
+      MapFunction<? super M, ? extends OM> mapFn, MessageStreamImpl<OM> nextStream, int opId) {
     return new StreamOperatorSpec<>(new FlatMapFunction<M, OM>() {
       @Override
       public Collection<OM> apply(M message) {
@@ -84,7 +84,7 @@ public class OperatorSpecs {
    * @return  the {@link StreamOperatorSpec}
    */
   public static <M> StreamOperatorSpec<M, M> createFilterOperatorSpec(
-      FilterFunction<M> filterFn, MessageStreamImpl<M> nextStream, int opId) {
+      FilterFunction<? super M> filterFn, MessageStreamImpl<M> nextStream, int opId) {
     return new StreamOperatorSpec<>(new FlatMapFunction<M, M>() {
       @Override
       public Collection<M> apply(M message) {
@@ -115,8 +115,8 @@ public class OperatorSpecs {
    * @return  the {@link StreamOperatorSpec}
    */
   public static <M, OM> StreamOperatorSpec<M, OM> createStreamOperatorSpec(
-      FlatMapFunction<M, OM> transformFn, MessageStreamImpl<OM> nextStream, int opId) {
-    return new StreamOperatorSpec<>(transformFn, nextStream, OperatorSpec.OpCode.FLAT_MAP, opId);
+      FlatMapFunction<? super M, ? extends OM> transformFn, MessageStreamImpl<OM> nextStream, int opId) {
+    return new StreamOperatorSpec<>((FlatMapFunction<M, OM>) transformFn, nextStream, OperatorSpec.OpCode.FLAT_MAP, opId);
   }
 
   /**
@@ -127,8 +127,8 @@ public class OperatorSpecs {
    * @param <M>  type of input message
    * @return  the {@link SinkOperatorSpec} for the sink operator
    */
-  public static <M> SinkOperatorSpec<M> createSinkOperatorSpec(SinkFunction<M> sinkFn, int opId) {
-    return new SinkOperatorSpec<>(sinkFn, OperatorSpec.OpCode.SINK, opId);
+  public static <M> SinkOperatorSpec<M> createSinkOperatorSpec(SinkFunction<? super M> sinkFn, int opId) {
+    return new SinkOperatorSpec<>((SinkFunction<M>) sinkFn, OperatorSpec.OpCode.SINK, opId);
   }
 
   /**
@@ -195,7 +195,7 @@ public class OperatorSpecs {
   public static <K, M, JM, RM> PartialJoinOperatorSpec<K, M, JM, RM> createPartialJoinOperatorSpec(
       PartialJoinFunction<K, M, JM, RM> thisPartialJoinFn, PartialJoinFunction<K, JM, M, RM> otherPartialJoinFn,
       long ttlMs, MessageStreamImpl<RM> nextStream, int opId) {
-    return new PartialJoinOperatorSpec<K, M, JM, RM>(thisPartialJoinFn, otherPartialJoinFn, ttlMs, nextStream, opId);
+    return new PartialJoinOperatorSpec<>(thisPartialJoinFn, otherPartialJoinFn, ttlMs, nextStream, opId);
   }
 
   /**
@@ -207,7 +207,7 @@ public class OperatorSpecs {
    * @return  the {@link StreamOperatorSpec} for the merge
    */
   public static <M> StreamOperatorSpec<M, M> createMergeOperatorSpec(MessageStreamImpl<M> nextStream, int opId) {
-    return new StreamOperatorSpec<M, M>(message ->
+    return new StreamOperatorSpec<>(message ->
         new ArrayList<M>() {
           {
             this.add(message);
index 3c427c7..f9bbe2d 100644 (file)
@@ -45,7 +45,7 @@ public class StreamOperatorSpec<M, OM> implements OperatorSpec<OM> {
    * @param opCode  the {@link OpCode} for this {@link StreamOperatorSpec}
    * @param opId  the unique id for this {@link StreamOperatorSpec} in a {@link org.apache.samza.operators.StreamGraph}
    */
-  StreamOperatorSpec(FlatMapFunction<M, OM> transformFn, MessageStreamImpl nextStream,
+  StreamOperatorSpec(FlatMapFunction<M, OM> transformFn, MessageStreamImpl<OM> nextStream,
       OperatorSpec.OpCode opCode, int opId) {
     this.transformFn = transformFn;
     this.nextStream = nextStream;
index a1bee6a..8f45f7a 100644 (file)
@@ -33,8 +33,8 @@ public class IntermediateStreamInternalImpl<K, V, M> extends MessageStreamImpl<M
   private final Function<M, V> msgExtractor;
   private final BiFunction<K, V, M> msgBuilder;
 
-  public IntermediateStreamInternalImpl(StreamGraphImpl graph, StreamSpec streamSpec,
-      Function<M, K> keyExtractor, Function<M, V> msgExtractor, BiFunction<K, V, M> msgBuilder) {
+  public IntermediateStreamInternalImpl(StreamGraphImpl graph, StreamSpec streamSpec, Function<M, K> keyExtractor,
+      Function<M, V> msgExtractor, BiFunction<K, V, M> msgBuilder) {
     super(graph);
     this.streamSpec = streamSpec;
     this.keyExtractor = keyExtractor;
index 445d13e..6408e6f 100644 (file)
@@ -19,6 +19,7 @@
 package org.apache.samza.task;
 
 import org.apache.samza.SamzaException;
+import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.ConfigException;
 import org.apache.samza.application.StreamApplication;
@@ -158,19 +159,20 @@ public class TaskFactoryUtil {
    * @return {@link StreamApplication} instance
    */
   public static StreamApplication createStreamApplication(Config config) {
-    if (config.get(StreamApplication.APP_CLASS_CONFIG) != null && !config.get(StreamApplication.APP_CLASS_CONFIG).isEmpty()) {
+    ApplicationConfig appConfig = new ApplicationConfig(config);
+    if (appConfig.getAppClass() != null && !appConfig.getAppClass().isEmpty()) {
       TaskConfig taskConfig = new TaskConfig(config);
       if (taskConfig.getTaskClass() != null && !taskConfig.getTaskClass().isEmpty()) {
         throw new ConfigException("High level StreamApplication API cannot be used together with low-level API using task.class.");
       }
 
-      String appClassName = config.get(StreamApplication.APP_CLASS_CONFIG);
+      String appClassName = appConfig.getAppClass();
       try {
         Class<?> builderClass = Class.forName(appClassName);
         return (StreamApplication) builderClass.newInstance();
       } catch (Throwable t) {
         String errorMsg = String.format("Failed to create StreamApplication class from the config. %s = %s",
-            StreamApplication.APP_CLASS_CONFIG, config.get(StreamApplication.APP_CLASS_CONFIG));
+            ApplicationConfig.APP_CLASS, appConfig.getAppClass());
         log.error(errorMsg, t);
         throw new ConfigException(errorMsg, t);
       }
index c55fcd0..b7f952a 100644 (file)
@@ -42,6 +42,8 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiFunction;
+import java.util.function.Function;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -116,8 +118,10 @@ public class TestExecutionPlanner {
      *
      */
     StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
-    OutputStream<Object, Object, Object> output1 = streamGraph.getOutputStream("output1", null, null);
-    streamGraph.getInputStream("input1", null)
+    Function mockFn = mock(Function.class);
+    OutputStream<Object, Object, Object> output1 = streamGraph.getOutputStream("output1", mockFn, mockFn);
+    BiFunction mockBuilder = mock(BiFunction.class);
+    streamGraph.getInputStream("input1", mockBuilder)
         .partitionBy(m -> "yes!!!").map(m -> m)
         .sendTo(output1);
     return streamGraph;
@@ -137,11 +141,13 @@ public class TestExecutionPlanner {
      */
 
     StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
-    MessageStream m1 = streamGraph.getInputStream("input1", null).map(m -> m);
-    MessageStream m2 = streamGraph.getInputStream("input2", null).partitionBy(m -> "haha").filter(m -> true);
-    MessageStream m3 = streamGraph.getInputStream("input3", null).filter(m -> true).partitionBy(m -> "hehe").map(m -> m);
-    OutputStream<Object, Object, Object> output1 = streamGraph.getOutputStream("output1", null, null);
-    OutputStream<Object, Object, Object> output2 = streamGraph.getOutputStream("output2", null, null);
+    BiFunction msgBuilder = mock(BiFunction.class);
+    MessageStream m1 = streamGraph.getInputStream("input1", msgBuilder).map(m -> m);
+    MessageStream m2 = streamGraph.getInputStream("input2", msgBuilder).partitionBy(m -> "haha").filter(m -> true);
+    MessageStream m3 = streamGraph.getInputStream("input3", msgBuilder).filter(m -> true).partitionBy(m -> "hehe").map(m -> m);
+    Function mockFn = mock(Function.class);
+    OutputStream<Object, Object, Object> output1 = streamGraph.getOutputStream("output1", mockFn, mockFn);
+    OutputStream<Object, Object, Object> output2 = streamGraph.getOutputStream("output2", mockFn, mockFn);
 
     m1.join(m2, mock(JoinFunction.class), Duration.ofHours(2)).sendTo(output1);
     m3.join(m2, mock(JoinFunction.class), Duration.ofHours(1)).sendTo(output2);
index 9f9945b..c4ab922 100644 (file)
@@ -22,6 +22,9 @@ package org.apache.samza.execution;
 import java.time.Duration;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
@@ -101,11 +104,13 @@ public class TestJobGraphJsonGenerator {
     StreamManager streamManager = new StreamManager(systemAdmins);
 
     StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config);
-    MessageStream m1 = streamGraph.getInputStream("input1", null).map(m -> m);
-    MessageStream m2 = streamGraph.getInputStream("input2", null).partitionBy(m -> "haha").filter(m -> true);
-    MessageStream m3 = streamGraph.getInputStream("input3", null).filter(m -> true).partitionBy(m -> "hehe").map(m -> m);
-    OutputStream<Object, Object, Object> outputStream1 = streamGraph.getOutputStream("output1", null, null);
-    OutputStream<Object, Object, Object> outputStream2 = streamGraph.getOutputStream("output2", null, null);
+    BiFunction mockBuilder = mock(BiFunction.class);
+    MessageStream m1 = streamGraph.getInputStream("input1", mockBuilder).map(m -> m);
+    MessageStream m2 = streamGraph.getInputStream("input2", mockBuilder).partitionBy(m -> "haha").filter(m -> true);
+    MessageStream m3 = streamGraph.getInputStream("input3", mockBuilder).filter(m -> true).partitionBy(m -> "hehe").map(m -> m);
+    Function mockFn = mock(Function.class);
+    OutputStream<Object, Object, Object> outputStream1 = streamGraph.getOutputStream("output1", mockFn, mockFn);
+    OutputStream<Object, Object, Object> outputStream2 = streamGraph.getOutputStream("output2", mockFn, mockFn);
 
     m1.join(m2, mock(JoinFunction.class), Duration.ofHours(2)).sendTo(outputStream1);
     m3.join(m2, mock(JoinFunction.class), Duration.ofHours(1)).sendTo(outputStream2);
index e815b81..44870fd 100644 (file)
@@ -21,8 +21,7 @@ package org.apache.samza.operators;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
-import org.apache.samza.operators.data.TestMessageEnvelope;
-import org.apache.samza.operators.data.TestOutputMessageEnvelope;
+import org.apache.samza.operators.data.*;
 import org.apache.samza.operators.functions.FilterFunction;
 import org.apache.samza.operators.functions.FlatMapFunction;
 import org.apache.samza.operators.functions.JoinFunction;
@@ -36,6 +35,7 @@ import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.junit.Test;
 
@@ -43,14 +43,15 @@ import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
-import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.function.Function;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -72,7 +73,7 @@ public class TestMessageStreamImpl {
     assertEquals(mapOp.getNextStream(), outputStream);
     // assert that the transformation function is what we defined above
     TestMessageEnvelope xTestMsg = mock(TestMessageEnvelope.class);
-    TestMessageEnvelope.MessageType mockInnerTestMessage = mock(TestMessageEnvelope.MessageType.class);
+    MessageType mockInnerTestMessage = mock(MessageType.class);
     when(xTestMsg.getKey()).thenReturn("test-msg-key");
     when(xTestMsg.getMessage()).thenReturn(mockInnerTestMessage);
     when(mockInnerTestMessage.getValue()).thenReturn("123456789");
@@ -87,20 +88,74 @@ public class TestMessageStreamImpl {
   @Test
   public void testFlatMap() {
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph);
-    Set<TestOutputMessageEnvelope> flatOuts = new HashSet<TestOutputMessageEnvelope>() { {
+    List<TestOutputMessageEnvelope> flatOuts = new ArrayList<TestOutputMessageEnvelope>() { {
         this.add(mock(TestOutputMessageEnvelope.class));
         this.add(mock(TestOutputMessageEnvelope.class));
         this.add(mock(TestOutputMessageEnvelope.class));
       } };
-    FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> xFlatMap = (TestMessageEnvelope message) -> flatOuts;
+    final List<TestMessageEnvelope> inputMsgs = new ArrayList<>();
+    FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> xFlatMap = (TestMessageEnvelope message) -> {
+      inputMsgs.add(message);
+      return flatOuts;
+    };
+    MessageStream<TestOutputMessageEnvelope> outputStream = inputStream.flatMap(xFlatMap);
+    Collection<OperatorSpec> subs = inputStream.getRegisteredOperatorSpecs();
+    assertEquals(subs.size(), 1);
+    OperatorSpec<TestOutputMessageEnvelope> flatMapOp = subs.iterator().next();
+    assertTrue(flatMapOp instanceof StreamOperatorSpec);
+    assertEquals(flatMapOp.getNextStream(), outputStream);
+    assertEquals(((StreamOperatorSpec) flatMapOp).getTransformFn(), xFlatMap);
+
+    TestMessageEnvelope mockInput  = mock(TestMessageEnvelope.class);
+    // assert that the transformation function is what we defined above
+    List<TestOutputMessageEnvelope> result = (List<TestOutputMessageEnvelope>)
+        ((StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) flatMapOp).getTransformFn().apply(mockInput);
+    assertEquals(flatOuts, result);
+    assertEquals(inputMsgs.size(), 1);
+    assertEquals(inputMsgs.get(0), mockInput);
+  }
+
+  @Test
+  public void testFlatMapWithRelaxedTypes() {
+    MessageStreamImpl<TestInputMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph);
+    List<TestExtOutputMessageEnvelope> flatOuts = new ArrayList<TestExtOutputMessageEnvelope>() { {
+        this.add(new TestExtOutputMessageEnvelope("output-key-1", 1, "output-id-001"));
+        this.add(new TestExtOutputMessageEnvelope("output-key-2", 2, "output-id-002"));
+        this.add(new TestExtOutputMessageEnvelope("output-key-3", 3, "output-id-003"));
+      } };
+
+    class MyFlatMapFunction implements FlatMapFunction<TestMessageEnvelope, TestExtOutputMessageEnvelope> {
+      public final List<TestMessageEnvelope> inputMsgs = new ArrayList<>();
+
+      @Override
+      public Collection<TestExtOutputMessageEnvelope> apply(TestMessageEnvelope message) {
+        inputMsgs.add(message);
+        return flatOuts;
+      }
+
+      @Override
+      public void init(Config config, TaskContext context) {
+        inputMsgs.clear();
+      }
+    }
+
+    MyFlatMapFunction xFlatMap = new MyFlatMapFunction();
+
     MessageStream<TestOutputMessageEnvelope> outputStream = inputStream.flatMap(xFlatMap);
     Collection<OperatorSpec> subs = inputStream.getRegisteredOperatorSpecs();
     assertEquals(subs.size(), 1);
     OperatorSpec<TestOutputMessageEnvelope> flatMapOp = subs.iterator().next();
     assertTrue(flatMapOp instanceof StreamOperatorSpec);
     assertEquals(flatMapOp.getNextStream(), outputStream);
+    assertEquals(((StreamOperatorSpec) flatMapOp).getTransformFn(), xFlatMap);
+
+    TestMessageEnvelope mockInput  = mock(TestMessageEnvelope.class);
     // assert that the transformation function is what we defined above
-    assertEquals(((StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) flatMapOp).getTransformFn(), xFlatMap);
+    List<TestOutputMessageEnvelope> result = (List<TestOutputMessageEnvelope>)
+        ((StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) flatMapOp).getTransformFn().apply(mockInput);
+    assertEquals(flatOuts, result);
+    assertEquals(xFlatMap.inputMsgs.size(), 1);
+    assertEquals(xFlatMap.inputMsgs.get(0), mockInput);
   }
 
   @Test
@@ -116,7 +171,7 @@ public class TestMessageStreamImpl {
     // assert that the transformation function is what we defined above
     FlatMapFunction<TestMessageEnvelope, TestMessageEnvelope> txfmFn = ((StreamOperatorSpec<TestMessageEnvelope, TestMessageEnvelope>) filterOp).getTransformFn();
     TestMessageEnvelope mockMsg = mock(TestMessageEnvelope.class);
-    TestMessageEnvelope.MessageType mockInnerTestMessage = mock(TestMessageEnvelope.MessageType.class);
+    MessageType mockInnerTestMessage = mock(MessageType.class);
     when(mockMsg.getMessage()).thenReturn(mockInnerTestMessage);
     when(mockInnerTestMessage.getEventTime()).thenReturn(11111L);
     Collection<TestMessageEnvelope> output = txfmFn.apply(mockMsg);
@@ -131,8 +186,9 @@ public class TestMessageStreamImpl {
   @Test
   public void testSink() {
     MessageStreamImpl<TestMessageEnvelope> inputStream = new MessageStreamImpl<>(mockGraph);
+    SystemStream testStream = new SystemStream("test-sys", "test-stream");
     SinkFunction<TestMessageEnvelope> xSink = (TestMessageEnvelope m, MessageCollector mc, TaskCoordinator tc) -> {
-      mc.send(new OutgoingMessageEnvelope(new SystemStream("test-sys", "test-stream"), m.getMessage()));
+      mc.send(new OutgoingMessageEnvelope(testStream, m.getMessage()));
       tc.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
     };
     inputStream.sink(xSink);
@@ -141,6 +197,21 @@ public class TestMessageStreamImpl {
     OperatorSpec<TestMessageEnvelope> sinkOp = subs.iterator().next();
     assertTrue(sinkOp instanceof SinkOperatorSpec);
     assertEquals(((SinkOperatorSpec) sinkOp).getSinkFn(), xSink);
+
+    TestMessageEnvelope mockTest1 = mock(TestMessageEnvelope.class);
+    MessageType mockMsgBody = mock(MessageType.class);
+    when(mockTest1.getMessage()).thenReturn(mockMsgBody);
+    final List<OutgoingMessageEnvelope> outMsgs = new ArrayList<>();
+    MessageCollector mockCollector = mock(MessageCollector.class);
+    doAnswer(invocation -> {
+        outMsgs.add((OutgoingMessageEnvelope) invocation.getArguments()[0]);
+        return null;
+      }).when(mockCollector).send(any());
+    TaskCoordinator mockCoordinator = mock(TaskCoordinator.class);
+    ((SinkOperatorSpec) sinkOp).getSinkFn().apply(mockTest1, mockCollector, mockCoordinator);
+    assertEquals(1, outMsgs.size());
+    assertEquals(testStream, outMsgs.get(0).getSystemStream());
+    assertEquals(mockMsgBody, outMsgs.get(0).getMessage());
   }
 
   @Test
@@ -189,14 +260,14 @@ public class TestMessageStreamImpl {
   @Test
   public void testMerge() {
     MessageStream<TestMessageEnvelope> merge1 = new MessageStreamImpl<>(mockGraph);
-    Collection<MessageStream<TestMessageEnvelope>> others = new ArrayList<MessageStream<TestMessageEnvelope>>() { {
+    Collection<MessageStream<? extends TestMessageEnvelope>> others = new ArrayList<MessageStream<? extends TestMessageEnvelope>>() { {
         this.add(new MessageStreamImpl<>(mockGraph));
         this.add(new MessageStreamImpl<>(mockGraph));
       } };
     MessageStream<TestMessageEnvelope> mergeOutput = merge1.merge(others);
     validateMergeOperator(merge1, mergeOutput);
 
-    others.forEach(merge -> validateMergeOperator(merge, mergeOutput));
+    others.forEach(merge -> validateMergeOperator((MessageStream<TestMessageEnvelope>) merge, mergeOutput));
   }
 
   private void validateMergeOperator(MessageStream<TestMessageEnvelope> mergeSource, MessageStream<TestMessageEnvelope> mergeOutput) {
diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphImpl.java b/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphImpl.java
new file mode 100644 (file)
index 0000000..3ab1a3c
--- /dev/null
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.operators.data.MessageType;
+import org.apache.samza.operators.data.TestInputMessageEnvelope;
+import org.apache.samza.operators.data.TestMessageEnvelope;
+import org.apache.samza.operators.stream.InputStreamInternalImpl;
+import org.apache.samza.operators.stream.IntermediateStreamInternalImpl;
+import org.apache.samza.operators.stream.OutputStreamInternalImpl;
+import org.apache.samza.runtime.ApplicationRunner;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.task.TaskContext;
+import org.junit.Test;
+
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class TestStreamGraphImpl {
+
+  @Test
+  public void testGetInputStream() {
+    ApplicationRunner mockRunner = mock(ApplicationRunner.class);
+    Config mockConfig = mock(Config.class);
+    StreamSpec testStreamSpec = new StreamSpec("test-stream-1", "physical-stream-1", "test-system");
+    when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(testStreamSpec);
+
+    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    BiFunction<String, MessageType, TestInputMessageEnvelope> xMsgBuilder =
+        (k, v) -> new TestInputMessageEnvelope(k, v.getValue(), v.getEventTime(), "input-id-1");
+    MessageStream<TestMessageEnvelope> mInputStream = graph.getInputStream("test-stream-1", xMsgBuilder);
+    assertEquals(graph.getInputStreams().get(testStreamSpec), mInputStream);
+    assertTrue(mInputStream instanceof InputStreamInternalImpl);
+    assertEquals(((InputStreamInternalImpl) mInputStream).getMsgBuilder(), xMsgBuilder);
+
+    String key = "test-input-key";
+    MessageType msgBody = new MessageType("test-msg-value", 333333L);
+    TestMessageEnvelope xInputMsg = ((InputStreamInternalImpl<String, MessageType, TestMessageEnvelope>) mInputStream).
+        getMsgBuilder().apply(key, msgBody);
+    assertEquals(xInputMsg.getKey(), key);
+    assertEquals(xInputMsg.getMessage().getValue(), msgBody.getValue());
+    assertEquals(xInputMsg.getMessage().getEventTime(), msgBody.getEventTime());
+    assertEquals(((TestInputMessageEnvelope) xInputMsg).getInputId(), "input-id-1");
+  }
+
+  @Test
+  public void testGetOutputStream() {
+    ApplicationRunner mockRunner = mock(ApplicationRunner.class);
+    Config mockConfig = mock(Config.class);
+    StreamSpec testStreamSpec = new StreamSpec("test-stream-1", "physical-stream-1", "test-system");
+    when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(testStreamSpec);
+
+    class MyMessageType extends MessageType {
+      public final String outputId;
+
+      public MyMessageType(String value, long eventTime, String outputId) {
+        super(value, eventTime);
+        this.outputId = outputId;
+      }
+    }
+
+    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    Function<TestMessageEnvelope, String> xKeyExtractor = x -> x.getKey();
+    Function<TestMessageEnvelope, MyMessageType> xMsgExtractor =
+        x -> new MyMessageType(x.getMessage().getValue(), x.getMessage().getEventTime(), "test-output-id-1");
+
+    OutputStream<String, MyMessageType, TestInputMessageEnvelope> mOutputStream =
+        graph.getOutputStream("test-stream-1", xKeyExtractor, xMsgExtractor);
+    assertEquals(graph.getOutputStreams().get(testStreamSpec), mOutputStream);
+    assertTrue(mOutputStream instanceof OutputStreamInternalImpl);
+    assertEquals(((OutputStreamInternalImpl) mOutputStream).getKeyExtractor(), xKeyExtractor);
+    assertEquals(((OutputStreamInternalImpl) mOutputStream).getMsgExtractor(), xMsgExtractor);
+
+    TestInputMessageEnvelope xInputMsg = new TestInputMessageEnvelope("test-key-1", "test-msg-1", 33333L, "input-id-1");
+    assertEquals(((OutputStreamInternalImpl<String, MyMessageType, TestInputMessageEnvelope>) mOutputStream).
+        getKeyExtractor().apply(xInputMsg), "test-key-1");
+    assertEquals(((OutputStreamInternalImpl<String, MyMessageType, TestInputMessageEnvelope>) mOutputStream).
+        getMsgExtractor().apply(xInputMsg).getValue(), "test-msg-1");
+    assertEquals(((OutputStreamInternalImpl<String, MyMessageType, TestInputMessageEnvelope>) mOutputStream).
+        getMsgExtractor().apply(xInputMsg).getEventTime(), 33333L);
+    assertEquals(((OutputStreamInternalImpl<String, MyMessageType, TestInputMessageEnvelope>) mOutputStream).
+        getMsgExtractor().apply(xInputMsg).outputId, "test-output-id-1");
+  }
+
+  @Test
+  public void testWithContextManager() {
+    ApplicationRunner mockRunner = mock(ApplicationRunner.class);
+    Config mockConfig = mock(Config.class);
+
+    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+
+    // ensure that default is noop
+    TaskContext mockContext = mock(TaskContext.class);
+    assertEquals(graph.getContextManager().initTaskContext(mockConfig, mockContext), mockContext);
+
+    ContextManager testContextManager = new ContextManager() {
+      @Override
+      public TaskContext initTaskContext(Config config, TaskContext context) {
+        return null;
+      }
+
+      @Override
+      public void finalizeTaskContext() {
+
+      }
+    };
+
+    graph.withContextManager(testContextManager);
+    assertEquals(graph.getContextManager(), testContextManager);
+  }
+
+  @Test
+  public void testGetIntermediateStream() {
+    ApplicationRunner mockRunner = mock(ApplicationRunner.class);
+    Config mockConfig = mock(Config.class);
+    StreamSpec testStreamSpec = new StreamSpec("myJob-i001-test-stream-1", "physical-stream-1", "test-system");
+    when(mockRunner.getStreamSpec("myJob-i001-test-stream-1")).thenReturn(testStreamSpec);
+    when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("myJob");
+    when(mockConfig.get(JobConfig.JOB_ID(), "1")).thenReturn("i001");
+
+    class MyMessageType extends MessageType {
+      public final String outputId;
+
+      public MyMessageType(String value, long eventTime, String outputId) {
+        super(value, eventTime);
+        this.outputId = outputId;
+      }
+    }
+
+    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    Function<TestMessageEnvelope, String> xKeyExtractor = x -> x.getKey();
+    Function<TestMessageEnvelope, MyMessageType> xMsgExtractor =
+        x -> new MyMessageType(x.getMessage().getValue(), x.getMessage().getEventTime(), "test-output-id-1");
+    BiFunction<String, MessageType, TestInputMessageEnvelope> xMsgBuilder =
+        (k, v) -> new TestInputMessageEnvelope(k, v.getValue(), v.getEventTime(), "input-id-1");
+
+    MessageStream<TestMessageEnvelope> mIntermediateStream =
+        graph.getIntermediateStream("test-stream-1", xKeyExtractor, xMsgExtractor, xMsgBuilder);
+    assertEquals(graph.getOutputStreams().get(testStreamSpec), mIntermediateStream);
+    assertTrue(mIntermediateStream instanceof IntermediateStreamInternalImpl);
+    assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream).getKeyExtractor(), xKeyExtractor);
+    assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream).getMsgExtractor(), xMsgExtractor);
+    assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream).getMsgBuilder(), xMsgBuilder);
+
+    TestMessageEnvelope xInputMsg = new TestMessageEnvelope("test-key-1", "test-msg-1", 33333L);
+    assertEquals(((IntermediateStreamInternalImpl<String, MessageType, TestMessageEnvelope>) mIntermediateStream).
+        getKeyExtractor().apply(xInputMsg), "test-key-1");
+    assertEquals(((IntermediateStreamInternalImpl<String, MessageType, TestMessageEnvelope>) mIntermediateStream).
+        getMsgExtractor().apply(xInputMsg).getValue(), "test-msg-1");
+    assertEquals(((IntermediateStreamInternalImpl<String, MessageType, TestMessageEnvelope>) mIntermediateStream).
+        getMsgExtractor().apply(xInputMsg).getEventTime(), 33333L);
+    assertEquals(((IntermediateStreamInternalImpl<String, MessageType, TestMessageEnvelope>) mIntermediateStream).
+        getMsgBuilder().apply("test-key-1", new MyMessageType("test-msg-1", 33333L, "test-output-id-1")).getKey(), "test-key-1");
+    assertEquals(((IntermediateStreamInternalImpl<String, MessageType, TestMessageEnvelope>) mIntermediateStream).
+        getMsgBuilder().apply("test-key-1", new MyMessageType("test-msg-1", 33333L, "test-output-id-1")).getMessage().getValue(), "test-msg-1");
+    assertEquals(((IntermediateStreamInternalImpl<String, MessageType, TestMessageEnvelope>) mIntermediateStream).
+        getMsgBuilder().apply("test-key-1", new MyMessageType("test-msg-1", 33333L, "test-output-id-1")).getMessage().getEventTime(), 33333L);
+  }
+
+  @Test
+  public void testGetNextOpId() {
+    ApplicationRunner mockRunner = mock(ApplicationRunner.class);
+    Config mockConfig = mock(Config.class);
+
+    StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig);
+    assertEquals(graph.getNextOpId(), 0);
+    assertEquals(graph.getNextOpId(), 1);
+  }
+
+}
diff --git a/samza-core/src/test/java/org/apache/samza/operators/data/MessageType.java b/samza-core/src/test/java/org/apache/samza/operators/data/MessageType.java
new file mode 100644 (file)
index 0000000..3fd015b
--- /dev/null
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators.data;
+
+public class MessageType {
+  private final String value;
+  private final long eventTime;
+
+  public MessageType(String value, long eventTime) {
+    this.value = value;
+    this.eventTime = eventTime;
+  }
+
+  public long getEventTime() {
+    return eventTime;
+  }
+
+  public String getValue() {
+    return value;
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/operators/data/TestExtOutputMessageEnvelope.java b/samza-core/src/test/java/org/apache/samza/operators/data/TestExtOutputMessageEnvelope.java
new file mode 100644 (file)
index 0000000..22222ed
--- /dev/null
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators.data;
+
+public class TestExtOutputMessageEnvelope extends TestOutputMessageEnvelope {
+  private final String outputId;
+
+  public TestExtOutputMessageEnvelope(String key, Integer value, String outputId) {
+    super(key, value);
+    this.outputId = outputId;
+  }
+
+}
diff --git a/samza-core/src/test/java/org/apache/samza/operators/data/TestInputMessageEnvelope.java b/samza-core/src/test/java/org/apache/samza/operators/data/TestInputMessageEnvelope.java
new file mode 100644 (file)
index 0000000..089f534
--- /dev/null
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.operators.data;
+
+public class TestInputMessageEnvelope extends TestMessageEnvelope {
+  private final String inputId;
+
+  public TestInputMessageEnvelope(String key, String value, long eventTime, String inputId) {
+    super(key, value, eventTime);
+    this.inputId = inputId;
+  }
+
+  public String getInputId() {
+    return this.inputId;
+  }
+}
index 2524c28..05a63cd 100644 (file)
@@ -37,21 +37,4 @@ public class TestMessageEnvelope {
     return this.key;
   }
 
-  public class MessageType {
-    private final String value;
-    private final long eventTime;
-
-    public MessageType(String value, long eventTime) {
-      this.value = value;
-      this.eventTime = eventTime;
-    }
-
-    public long getEventTime() {
-      return eventTime;
-    }
-
-    public String getValue() {
-      return value;
-    }
-  }
 }
index 37e3d1a..d227206 100644 (file)
@@ -21,6 +21,8 @@ package org.apache.samza.operators.spec;
 import org.apache.samza.operators.MessageStreamImpl;
 import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.TestMessageStreamImplUtil;
+import org.apache.samza.operators.data.MessageType;
+import org.apache.samza.operators.data.TestInputMessageEnvelope;
 import org.apache.samza.operators.data.TestMessageEnvelope;
 import org.apache.samza.operators.data.TestOutputMessageEnvelope;
 import org.apache.samza.operators.functions.FlatMapFunction;
@@ -31,39 +33,71 @@ import org.apache.samza.operators.stream.OutputStreamInternalImpl;
 import org.apache.samza.operators.windows.WindowPane;
 import org.apache.samza.operators.windows.internal.WindowInternal;
 import org.apache.samza.operators.windows.internal.WindowType;
+import org.apache.samza.system.OutgoingMessageEnvelope;
+import org.apache.samza.system.SystemStream;
 import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.TaskCoordinator;
 import org.junit.Test;
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.List;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 
 public class TestOperatorSpecs {
   @Test
   public void testCreateStreamOperator() {
-    FlatMapFunction<?, TestMessageEnvelope> transformFn = m -> new ArrayList<TestMessageEnvelope>() { {
+    FlatMapFunction<Object, TestMessageEnvelope> transformFn = m -> new ArrayList<TestMessageEnvelope>() { {
           this.add(new TestMessageEnvelope(m.toString(), m.toString(), 12345L));
         } };
     MessageStreamImpl<TestMessageEnvelope> mockOutput = mock(MessageStreamImpl.class);
-    StreamOperatorSpec<?, TestMessageEnvelope> streamOp =
+    StreamOperatorSpec<Object, TestMessageEnvelope> streamOp =
         OperatorSpecs.createStreamOperatorSpec(transformFn, mockOutput, 1);
     assertEquals(streamOp.getTransformFn(), transformFn);
+
+    Object mockInput = mock(Object.class);
+    when(mockInput.toString()).thenReturn("test-string-1");
+    List<TestMessageEnvelope> outputs = (List<TestMessageEnvelope>) streamOp.getTransformFn().apply(mockInput);
+    assertEquals(outputs.size(), 1);
+    assertEquals(outputs.get(0).getKey(), "test-string-1");
+    assertEquals(outputs.get(0).getMessage().getValue(), "test-string-1");
+    assertEquals(outputs.get(0).getMessage().getEventTime(), 12345L);
     assertEquals(streamOp.getNextStream(), mockOutput);
   }
 
   @Test
   public void testCreateSinkOperator() {
+    SystemStream testStream = new SystemStream("test-sys", "test-stream");
     SinkFunction<TestMessageEnvelope> sinkFn = (TestMessageEnvelope message, MessageCollector messageCollector,
-          TaskCoordinator taskCoordinator) -> { };
+          TaskCoordinator taskCoordinator) -> {
+      messageCollector.send(new OutgoingMessageEnvelope(testStream, message.getKey(), message.getMessage()));
+    };
     SinkOperatorSpec<TestMessageEnvelope> sinkOp = OperatorSpecs.createSinkOperatorSpec(sinkFn, 1);
     assertEquals(sinkOp.getSinkFn(), sinkFn);
+
+    TestMessageEnvelope mockInput = mock(TestMessageEnvelope.class);
+    when(mockInput.getKey()).thenReturn("my-test-msg-key");
+    MessageType mockMsgBody = mock(MessageType.class);
+    when(mockInput.getMessage()).thenReturn(mockMsgBody);
+    final List<OutgoingMessageEnvelope> outputMsgs = new ArrayList<>();
+    MessageCollector mockCollector = mock(MessageCollector.class);
+    doAnswer(invocation -> {
+        outputMsgs.add((OutgoingMessageEnvelope) invocation.getArguments()[0]);
+        return null;
+      }).when(mockCollector).send(any());
+    sinkOp.getSinkFn().apply(mockInput, mockCollector, null);
+    assertEquals(1, outputMsgs.size());
+    assertEquals(outputMsgs.get(0).getKey(), "my-test-msg-key");
+    assertEquals(outputMsgs.get(0).getMessage(), mockMsgBody);
     assertEquals(sinkOp.getOpCode(), OperatorSpec.OpCode.SINK);
     assertEquals(sinkOp.getNextStream(), null);
   }
@@ -104,6 +138,27 @@ public class TestOperatorSpecs {
   }
 
   @Test
+  public void testCreateWindowOperatorWithRelaxedTypes() throws Exception {
+    Function<TestMessageEnvelope, String> keyExtractor = m -> m.getKey();
+    FoldLeftFunction<TestMessageEnvelope, Integer> aggregator = (m, c) -> c + 1;
+    Supplier<Integer> initialValue = () -> 0;
+    //instantiate a window using reflection
+    WindowInternal<TestInputMessageEnvelope, String, Integer> window = new WindowInternal(null, initialValue, aggregator, keyExtractor, null, WindowType.TUMBLING);
+
+    MessageStreamImpl<WindowPane<String, Integer>> mockWndOut = mock(MessageStreamImpl.class);
+    WindowOperatorSpec spec =
+        OperatorSpecs.createWindowOperatorSpec(window, mockWndOut, 1);
+    assertEquals(spec.getWindow(), window);
+    assertEquals(spec.getWindow().getKeyExtractor(), keyExtractor);
+    assertEquals(spec.getWindow().getFoldLeftFunction(), aggregator);
+
+    // make sure that the functions with relaxed types work as expected
+    TestInputMessageEnvelope inputMsg = new TestInputMessageEnvelope("test-input-key1", "test-value-1", 23456L, "input-id-1");
+    assertEquals("test-input-key1", spec.getWindow().getKeyExtractor().apply(inputMsg));
+    assertEquals(1, spec.getWindow().getFoldLeftFunction().apply(inputMsg, 0));
+  }
+
+  @Test
   public void testCreatePartialJoinOperator() {
     PartialJoinFunction<String, TestMessageEnvelope, TestMessageEnvelope, TestOutputMessageEnvelope> thisPartialJoinFn
         = mock(PartialJoinFunction.class);
index 0b051e8..e300996 100644 (file)
@@ -19,6 +19,7 @@
 package org.apache.samza.task;
 
 import org.apache.samza.SamzaException;
+import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.ConfigException;
 import org.apache.samza.config.MapConfig;
@@ -74,7 +75,7 @@ public class TestTaskFactoryUtil {
   public void testCreateStreamApplication() throws Exception {
     Config config = new MapConfig(new HashMap<String, String>() {
       {
-        this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.TestStreamApplication");
+        this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.TestStreamApplication");
       }
     });
     StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config);
@@ -85,7 +86,7 @@ public class TestTaskFactoryUtil {
 
     config = new MapConfig(new HashMap<String, String>() {
       {
-        this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.InvalidStreamApplication");
+        this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication");
       }
     });
     try {
@@ -97,7 +98,7 @@ public class TestTaskFactoryUtil {
 
     config = new MapConfig(new HashMap<String, String>() {
       {
-        this.put(StreamApplication.APP_CLASS_CONFIG, "no.such.class");
+        this.put(ApplicationConfig.APP_CLASS, "no.such.class");
       }
     });
     try {
@@ -109,7 +110,7 @@ public class TestTaskFactoryUtil {
 
     config = new MapConfig(new HashMap<String, String>() {
       {
-        this.put(StreamApplication.APP_CLASS_CONFIG, "");
+        this.put(ApplicationConfig.APP_CLASS, "");
       }
     });
     streamApp = TaskFactoryUtil.createStreamApplication(config);
@@ -124,7 +125,7 @@ public class TestTaskFactoryUtil {
   public void testCreateStreamApplicationWithTaskClass() throws Exception {
     Config config = new MapConfig(new HashMap<String, String>() {
       {
-        this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.TestStreamApplication");
+        this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.TestStreamApplication");
       }
     });
     StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config);
@@ -133,7 +134,7 @@ public class TestTaskFactoryUtil {
     config = new MapConfig(new HashMap<String, String>() {
       {
         this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask");
-        this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.TestStreamApplication");
+        this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.TestStreamApplication");
       }
     });
     try {
@@ -146,7 +147,7 @@ public class TestTaskFactoryUtil {
     config = new MapConfig(new HashMap<String, String>() {
       {
         this.put("task.class", "no.such.class");
-        this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.TestStreamApplication");
+        this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.TestStreamApplication");
       }
     });
     try {
@@ -162,7 +163,7 @@ public class TestTaskFactoryUtil {
 
     Config config = new MapConfig(new HashMap<String, String>() {
       {
-        this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.InvalidStreamApplication");
+        this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication");
       }
     });
     try {
@@ -175,7 +176,7 @@ public class TestTaskFactoryUtil {
     config = new MapConfig(new HashMap<String, String>() {
       {
         this.put("task.class", "org.apache.samza.testUtils.TestStreamTask");
-        this.put(StreamApplication.APP_CLASS_CONFIG, "");
+        this.put(ApplicationConfig.APP_CLASS, "");
       }
     });
     StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config);
@@ -186,7 +187,7 @@ public class TestTaskFactoryUtil {
     config = new MapConfig(new HashMap<String, String>() {
       {
         this.put("task.class", "");
-        this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.InvalidStreamApplication");
+        this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication");
       }
     });
     try {
@@ -226,7 +227,7 @@ public class TestTaskFactoryUtil {
 
     Config config = new MapConfig(new HashMap<String, String>() {
       {
-        this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.InvalidStreamApplication");
+        this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication");
       }
     });
     try {
@@ -239,7 +240,7 @@ public class TestTaskFactoryUtil {
     config = new MapConfig(new HashMap<String, String>() {
       {
         this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask");
-        this.put(StreamApplication.APP_CLASS_CONFIG, "");
+        this.put(ApplicationConfig.APP_CLASS, "");
       }
     });
     StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config);
@@ -250,7 +251,7 @@ public class TestTaskFactoryUtil {
     config = new MapConfig(new HashMap<String, String>() {
       {
         this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask");
-        this.put(StreamApplication.APP_CLASS_CONFIG, null);
+        this.put(ApplicationConfig.APP_CLASS, null);
       }
     });
     streamApp = TaskFactoryUtil.createStreamApplication(config);
index cda2690..29fb6d3 100644 (file)
@@ -119,7 +119,7 @@ object StreamTaskTestUtil {
     servers = configs.map(TestUtils.createServer(_)).toBuffer
 
     val brokerList = TestUtils.getBrokerListStrFromServers(servers, SecurityProtocol.PLAINTEXT)
-    brokers = brokerList.split(",").map(p => "localhost" + p).mkString(",")
+    brokers = brokerList.split(",").map(p => "127.0.0.1" + p).mkString(",")
 
     // setup the zookeeper and bootstrap servers for local kafka cluster
     jobConfig ++= Map("systems.kafka.consumer.zookeeper.connect" -> zkConnect,