SAMZA-1759: Stream Assert utilities for low level and high level api for TestFramework
authorSanil Jain <snjain@linkedin.com>
Thu, 2 Aug 2018 00:18:21 +0000 (17:18 -0700)
committerBoris S <bshkolnik@linkedin.com>
Thu, 2 Aug 2018 00:18:21 +0000 (17:18 -0700)
Adding utilities and corresponding test for low and high level api

Author: Sanil Jain <snjain@linkedin.com>

Reviewers: Shanthoosh Venkataraman <spvenkat@usc.edu>

Closes #568 from Sanil15/SAMZA-1759 and squashes the following commits:

a4861089 [Sanil Jain] Reverting back travis increase for wait time
876a3a58 [Sanil Jain] Increase travis timeout
9e6482b1 [Sanil Jain] Fixing travis build, removing unused imports
526244e8 [Sanil Jain] Merge branch 'master' into SAMZA-1759
9f489acf [Sanil Jain] Moving tests that use MessageStreamAssert to same package name in test folder to use package private
a93e5a14 [Sanil Jain] Marking collection transient to ensure newer api changes work
5e6d3ed1 [Sanil Jain] Making MessageStreamAssert package private
a5a521cc [Sanil Jain] Splitting operator assertions outside StreamAssert to MessageStreamAssert, addressing review, renaming utils
d1e64180 [Sanil Jain] Cleaning unused imports
ff218ff7 [Sanil Jain] Removing contains method for operator level assertios for high level api
c5768772 [Sanil Jain] Merge branch 'SAMZA-1759' of https://github.com/Sanil15/samza into SAMZA-1759
c69d1bbb [Sanil Jain] StreamAssert Utilities for Low level and High Level Api, Adding More Test for Low Level api for testing multiple partitions and in mulithreaded mode
e3c8e2a5 [Sanil Jain] StreamAssert Utilities for Low level and High Level Api, Adding More Test for Low Level api for testing multiple partitions and in mulithreaded mode

16 files changed:
.travis.yml
samza-test/src/main/java/org/apache/samza/test/framework/MessageStreamAssert.java [new file with mode: 0644]
samza-test/src/main/java/org/apache/samza/test/framework/StreamAssert.java
samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
samza-test/src/main/java/org/apache/samza/test/framework/stream/CollectionStream.java
samza-test/src/test/java/org/apache/samza/test/framework/AsyncStreamTaskIntegrationTest.java
samza-test/src/test/java/org/apache/samza/test/framework/BroadcastAssertApp.java [moved from samza-test/src/test/java/org/apache/samza/test/operator/BroadcastAssertApp.java with 91% similarity]
samza-test/src/test/java/org/apache/samza/test/framework/MyAsyncStreamTask.java
samza-test/src/test/java/org/apache/samza/test/framework/MyStreamTestTask.java
samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java
samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTestHarness.java [moved from samza-test/src/test/java/org/apache/samza/test/operator/StreamApplicationIntegrationTestHarness.java with 99% similarity]
samza-test/src/test/java/org/apache/samza/test/framework/StreamTaskIntegrationTest.java
samza-test/src/test/java/org/apache/samza/test/framework/TestTimerApp.java [moved from samza-test/src/test/java/org/apache/samza/test/timer/TestTimerApp.java with 94% similarity]
samza-test/src/test/java/org/apache/samza/test/framework/TimerTest.java [moved from samza-test/src/test/java/org/apache/samza/test/timer/TimerTest.java with 90% similarity]
samza-test/src/test/java/org/apache/samza/test/operator/TestRepartitionJoinWindowApp.java
samza-test/src/test/java/org/apache/samza/test/operator/TestRepartitionWindowApp.java

index ef112f2..2a3ae0c 100644 (file)
@@ -5,15 +5,15 @@
 # The ASF licenses this file to You under the Apache License, Version 2.0
 # (the "License"); you may not use this file except in compliance with
 # the License.  You may obtain a copy of the License at
-# 
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# 
+#
 
 language: java
 
diff --git a/samza-test/src/main/java/org/apache/samza/test/framework/MessageStreamAssert.java b/samza-test/src/main/java/org/apache/samza/test/framework/MessageStreamAssert.java
new file mode 100644 (file)
index 0000000..1a1c24c
--- /dev/null
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.framework;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import org.apache.samza.config.Config;
+import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.functions.SinkFunction;
+import org.apache.samza.serializers.KVSerde;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.TaskContext;
+import org.apache.samza.task.TaskCoordinator;
+import org.hamcrest.Matchers;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Timer;
+import java.util.TimerTask;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+
+import static org.junit.Assert.assertThat;
+
+/**
+ * An assertion on the content of a {@link MessageStream}.
+ *
+ * <pre>Example: {@code
+ * MessageStream<String> stream = streamGraph.getInputStream("input", serde).map(some_function)...;
+ * ...
+ * MessageStreamAssert.that(id, stream, stringSerde).containsInAnyOrder(Arrays.asList("a", "b", "c"));
+ * }</pre>
+ *
+ */
+@VisibleForTesting
+class MessageStreamAssert<M> {
+  private final static Map<String, CountDownLatch> LATCHES = new ConcurrentHashMap<>();
+  private final static CountDownLatch PLACE_HOLDER = new CountDownLatch(0);
+
+  private final String id;
+  private final MessageStream<M> messageStream;
+  private final Serde<M> serde;
+  private boolean checkEachTask = false;
+
+  /**
+   * Constructs a MessageStreamAssert with an id and serde
+   * @param id unique id
+   * @param messageStream represents messageStream that you want to assert on
+   * @param serde serde used to deserialize messageStream
+   * @param <M> represents type of Message
+   * @return a MessageStreamAssert for the messages in the stream
+   */
+  public static <M> MessageStreamAssert<M> that(String id, MessageStream<M> messageStream, Serde<M> serde) {
+    return new MessageStreamAssert<>(id, messageStream, serde);
+  }
+
+  private MessageStreamAssert(String id, MessageStream<M> messageStream, Serde<M> serde) {
+    this.id = id;
+    this.messageStream = messageStream;
+    this.serde = serde;
+  }
+
+  public MessageStreamAssert forEachTask() {
+    checkEachTask = true;
+    return this;
+  }
+
+  public void containsInAnyOrder(final Collection<M> expected) {
+    LATCHES.putIfAbsent(id, PLACE_HOLDER);
+    final MessageStream<M> streamToCheck = checkEachTask
+        ? messageStream
+        : messageStream
+            .partitionBy(m -> null, m -> m, KVSerde.of(new StringSerde(), serde), null)
+            .map(kv -> kv.value);
+
+    streamToCheck.sink(new CheckAgainstExpected<M>(id, expected, checkEachTask));
+  }
+
+  public static void waitForComplete() {
+    try {
+      while (!LATCHES.isEmpty()) {
+        final Set<String> ids  = new HashSet<>(LATCHES.keySet());
+        for (String id : ids) {
+          while (LATCHES.get(id) == PLACE_HOLDER) {
+            Thread.sleep(100);
+          }
+
+          final CountDownLatch latch = LATCHES.get(id);
+          if (latch != null) {
+            latch.await();
+            LATCHES.remove(id);
+          }
+        }
+      }
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private static final class CheckAgainstExpected<M> implements SinkFunction<M> {
+    private static final long TIMEOUT = 5000L;
+
+    private final String id;
+    private final boolean checkEachTask;
+    private final transient Collection<M> expected;
+
+
+    private transient Timer timer = new Timer();
+    private transient List<M> actual = Collections.synchronizedList(new ArrayList<>());
+    private transient TimerTask timerTask = new TimerTask() {
+      @Override
+      public void run() {
+        check();
+      }
+    };
+
+    CheckAgainstExpected(String id, Collection<M> expected, boolean checkEachTask) {
+      this.id = id;
+      this.expected = expected;
+      this.checkEachTask = checkEachTask;
+    }
+
+    @Override
+    public void init(Config config, TaskContext context) {
+      final SystemStreamPartition ssp = Iterables.getFirst(context.getSystemStreamPartitions(), null);
+      if (ssp != null && ssp.getPartition().getPartitionId() == 0) {
+        final int count = checkEachTask ? context.getSamzaContainerContext().taskNames.size() : 1;
+        LATCHES.put(id, new CountDownLatch(count));
+        timer.schedule(timerTask, TIMEOUT);
+      }
+    }
+
+    @Override
+    public void apply(M message, MessageCollector messageCollector, TaskCoordinator taskCoordinator) {
+      actual.add(message);
+
+      if (actual.size() >= expected.size()) {
+        timerTask.cancel();
+        check();
+      }
+    }
+
+    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+      in.defaultReadObject();
+      timer = new Timer();
+      actual = Collections.synchronizedList(new ArrayList<>());
+      timerTask = new TimerTask() {
+        @Override
+        public void run() {
+          check();
+        }
+      };
+    }
+
+    private void check() {
+      final CountDownLatch latch = LATCHES.get(id);
+      try {
+        // Throws AssertionError on mismatch; a matching stream must return normally (no unconditional throw).
+        assertThat(actual, Matchers.containsInAnyOrder((M[]) expected.toArray()));
+      } finally {
+        latch.countDown();
+      }
+    }
+  }
+}
index a1ac299..9972d7f 100644 (file)
 
 package org.apache.samza.test.framework;
 
-import com.google.common.collect.Iterables;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import org.apache.samza.config.Config;
-import org.apache.samza.operators.MessageStream;
-import org.apache.samza.operators.functions.SinkFunction;
-import org.apache.samza.serializers.KVSerde;
-import org.apache.samza.serializers.Serde;
-import org.apache.samza.serializers.StringSerde;
-import org.apache.samza.system.SystemStreamPartition;
-import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
-import org.apache.samza.task.TaskCoordinator;
-import org.hamcrest.Matchers;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
+import com.google.common.base.Preconditions;
+import java.time.Duration;
+import java.util.stream.Collectors;
+import org.apache.samza.test.framework.stream.CollectionStream;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
-import java.util.Timer;
-import java.util.TimerTask;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.CountDownLatch;
+import org.hamcrest.collection.IsIterableContainingInAnyOrder;
+import org.hamcrest.collection.IsIterableContainingInOrder;
 
 import static org.junit.Assert.assertThat;
 
+
 /**
- * An assertion on the content of a {@link MessageStream}.
- *
- * <pre>Example: {@code
- * MessageStream<String> stream = streamGraph.getInputStream("input", serde).map(some_function)...;
- * ...
- * StreamAssert.that(id, stream, stringSerde).containsInAnyOrder(Arrays.asList("a", "b", "c"));
- * }</pre>
- *
+ * Assertion utils on the content of a {@link CollectionStream}.
  */
-public class StreamAssert<M> {
-  private final static Map<String, CountDownLatch> LATCHES = new ConcurrentHashMap<>();
-  private final static CountDownLatch PLACE_HOLDER = new CountDownLatch(0);
-
-  private final String id;
-  private final MessageStream<M> messageStream;
-  private final Serde<M> serde;
-  private boolean checkEachTask = false;
-
-  public static <M> StreamAssert<M> that(String id, MessageStream<M> messageStream, Serde<M> serde) {
-    return new StreamAssert<>(id, messageStream, serde);
-  }
-
-  private StreamAssert(String id, MessageStream<M> messageStream, Serde<M> serde) {
-    this.id = id;
-    this.messageStream = messageStream;
-    this.serde = serde;
-  }
-
-  public StreamAssert forEachTask() {
-    checkEachTask = true;
-    return this;
+public class StreamAssert {
+  /**
+   * Util to assert presence of messages in a stream with single partition in any order
+   *
+   * @param collectionStream represents the actual stream which will be consumed to compare against expected list
+   * @param expected represents the expected stream of messages
+   * @param timeout maximum time to wait for consuming the stream
+   * @param <M> represents the type of Message in the stream
+   * @throws InterruptedException when {@code consumeStream} is interrupted by another thread during polling messages
+   */
+  public static <M> void containsInAnyOrder(CollectionStream<M> collectionStream, final List<M> expected, Duration timeout)
+      throws InterruptedException {
+    Preconditions.checkNotNull(collectionStream, "This util is intended to use only on CollectionStream");
+    assertThat(TestRunner.consumeStream(collectionStream, timeout)
+        .entrySet()
+        .stream()
+        .flatMap(entry -> entry.getValue().stream())
+        .collect(Collectors.toList()), IsIterableContainingInAnyOrder.containsInAnyOrder(expected.toArray()));
   }
 
-  public void containsInAnyOrder(final Collection<M> expected) {
-    LATCHES.putIfAbsent(id, PLACE_HOLDER);
-    final MessageStream<M> streamToCheck = checkEachTask
-        ? messageStream
-        : messageStream
-          .partitionBy(m -> null, m -> m, KVSerde.of(new StringSerde(), serde), null)
-          .map(kv -> kv.value);
-
-    streamToCheck.sink(new CheckAgainstExpected<M>(id, expected, checkEachTask));
-  }
-
-  public static void waitForComplete() {
-    try {
-      while (!LATCHES.isEmpty()) {
-        final Set<String> ids  = new HashSet<>(LATCHES.keySet());
-        for (String id : ids) {
-          while (LATCHES.get(id) == PLACE_HOLDER) {
-            Thread.sleep(100);
-          }
-
-          final CountDownLatch latch = LATCHES.get(id);
-          if (latch != null) {
-            latch.await();
-            LATCHES.remove(id);
-          }
-        }
-      }
-    } catch (Exception e) {
-      throw new RuntimeException(e);
+  /**
+   * Util to assert presence of messages in a stream with multiple partitions in any order
+   *
+   * @param collectionStream represents the actual stream which will be consumed to compare against expected partition map
+   * @param expected represents a map of partitionId as key and list of messages in stream as value
+   * @param timeout maximum time to wait for consuming the stream
+   * @param <M> represents the type of Message in the stream
+   * @throws InterruptedException when {@code consumeStream} is interrupted by another thread during polling messages
+   *
+   */
+  public static <M> void containsInAnyOrder(CollectionStream<M> collectionStream, final Map<Integer, List<M>> expected,
+      Duration timeout) throws InterruptedException {
+    Preconditions.checkNotNull(collectionStream, "This util is intended to use only on CollectionStream");
+    Map<Integer, List<M>> actual = TestRunner.consumeStream(collectionStream, timeout);
+    for (Integer paritionId : expected.keySet()) {
+      assertThat(actual.get(paritionId),
+          IsIterableContainingInAnyOrder.containsInAnyOrder(expected.get(paritionId).toArray()));
     }
   }
 
-  private static final class CheckAgainstExpected<M> implements SinkFunction<M> {
-    private static final long TIMEOUT = 5000L;
-
-    private final String id;
-    private final boolean checkEachTask;
-    private final Collection<M> expected;
-
-
-    private transient Timer timer = new Timer();
-    private transient List<M> actual = Collections.synchronizedList(new ArrayList<>());
-    private transient TimerTask timerTask = new TimerTask() {
-      @Override
-      public void run() {
-        check();
-      }
-    };
-
-    CheckAgainstExpected(String id, Collection<M> expected, boolean checkEachTask) {
-      this.id = id;
-      this.expected = expected;
-      this.checkEachTask = checkEachTask;
-    }
-
-    @Override
-    public void init(Config config, TaskContext context) {
-      final SystemStreamPartition ssp = Iterables.getFirst(context.getSystemStreamPartitions(), null);
-      if (ssp == null ? false : ssp.getPartition().getPartitionId() == 0) {
-        final int count = checkEachTask ? context.getSamzaContainerContext().taskNames.size() : 1;
-        LATCHES.put(id, new CountDownLatch(count));
-        timer.schedule(timerTask, TIMEOUT);
-      }
-    }
-
-    @Override
-    public void apply(M message, MessageCollector messageCollector, TaskCoordinator taskCoordinator) {
-      actual.add(message);
-
-      if (actual.size() >= expected.size()) {
-        timerTask.cancel();
-        check();
-      }
-    }
-
-    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
-      in.defaultReadObject();
-      timer = new Timer();
-      actual = Collections.synchronizedList(new ArrayList<>());
-      timerTask = new TimerTask() {
-        @Override
-        public void run() {
-          check();
-        }
-      };
-    }
+  /**
+   * Util to assert ordering of messages in a stream with single partition
+   *
+   * @param collectionStream represents the actual stream which will be consumed to compare against expected list
+   * @param expected represents the expected stream of messages
+   * @param timeout maximum time to wait for consuming the stream
+   * @param <M> represents the type of Message in the stream
+   * @throws InterruptedException when {@code consumeStream} is interrupted by another thread during polling messages
+   */
+  public static <M> void containsInOrder(CollectionStream<M> collectionStream, final List<M> expected, Duration timeout)
+      throws InterruptedException {
+    Preconditions.checkNotNull(collectionStream, "This util is intended to use only on CollectionStream");
+    assertThat(TestRunner.consumeStream(collectionStream, timeout)
+        .entrySet()
+        .stream()
+        .flatMap(entry -> entry.getValue().stream())
+        .collect(Collectors.toList()), IsIterableContainingInOrder.contains(expected.toArray()));
+  }
 
-    private void check() {
-      final CountDownLatch latch = LATCHES.get(id);
-      try {
-        assertThat(actual, Matchers.containsInAnyOrder((M[]) expected.toArray()));
-      } finally {
-        latch.countDown();
-      }
+  /**
+   * Util to assert ordering of messages in a multi-partitioned stream
+   *
+   * @param collectionStream represents the actual stream which will be consumed to compare against expected partition map
+   * @param expected represents a map of partitionId as key and list of messages as value
+   * @param timeout maximum time to wait for consuming the stream
+   * @param <M> represents the type of Message in the stream
+   * @throws InterruptedException when {@code consumeStream} is interrupted by another thread during polling messages
+   */
+  public static <M> void containsInOrder(CollectionStream<M> collectionStream, final Map<Integer, List<M>> expected,
+      Duration timeout) throws InterruptedException {
+    Preconditions.checkNotNull(collectionStream, "This util is intended to use only on CollectionStream");
+    Map<Integer, List<M>> actual = TestRunner.consumeStream(collectionStream, timeout);
+    for (Integer paritionId : expected.keySet()) {
+      assertThat(actual.get(paritionId), IsIterableContainingInOrder.contains(expected.get(paritionId).toArray()));
     }
   }
 }
index 6e647d9..3c45967 100644 (file)
@@ -311,7 +311,7 @@ public class TestRunner {
    *         i.e messages in the partition
    * @throws InterruptedException Thrown when a blocking poll has been interrupted by another thread.
    */
-  public static <T> Map<Integer, List<T>> consumeStream(CollectionStream stream, Integer timeout) throws InterruptedException {
+  public static <T> Map<Integer, List<T>> consumeStream(CollectionStream stream, Duration timeout) throws InterruptedException {
     Preconditions.checkNotNull(stream);
     Preconditions.checkNotNull(stream.getSystemName());
     String streamName = stream.getStreamName();
@@ -334,7 +334,7 @@ public class TestRunner {
     long t = System.currentTimeMillis();
     Map<SystemStreamPartition, List<IncomingMessageEnvelope>> output = new HashMap<>();
     HashSet<SystemStreamPartition> didNotReachEndOfStream = new HashSet<>(ssps);
-    while (System.currentTimeMillis() < t + timeout) {
+    while (System.currentTimeMillis() < t + timeout.toMillis()) {
       Map<SystemStreamPartition, List<IncomingMessageEnvelope>> currentState = consumer.poll(ssps, 10);
       for (Map.Entry<SystemStreamPartition, List<IncomingMessageEnvelope>> entry : currentState.entrySet()) {
         SystemStreamPartition ssp = entry.getKey();
index b3d9485..320a0ac 100644 (file)
@@ -104,7 +104,7 @@ public class CollectionStream<T> {
   private CollectionStream(String systemName, String streamName, Map<Integer, ? extends Iterable<T>> initPartitions) {
     this(systemName, streamName);
     Preconditions.checkNotNull(initPartitions);
-    initPartitions = new HashMap<>(initPartitions);
+    this.initPartitions = new HashMap<>(initPartitions);
   }
 
   /**
index ad25cae..3a1eba0 100644 (file)
 package org.apache.samza.test.framework;
 
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.samza.operators.KV;
 import org.apache.samza.test.framework.stream.CollectionStream;
 import org.hamcrest.collection.IsIterableContainingInOrder;
 import org.junit.Assert;
@@ -44,10 +49,82 @@ public class AsyncStreamTaskIntegrationTest {
         .addOutputStream(output)
         .run(Duration.ofSeconds(2));
 
-    Assert.assertThat(TestRunner.consumeStream(output, 1000).get(0),
+    Assert.assertThat(TestRunner.consumeStream(output, Duration.ofMillis(1000)).get(0),
         IsIterableContainingInOrder.contains(outputList.toArray()));
   }
 
+  @Test
+  public void testAsyncTaskWithSinglePartitionUsingStreamAssert() throws Exception {
+    List<Integer> inputList = Arrays.asList(1, 2, 3, 4, 5);
+    List<Integer> outputList = Arrays.asList(50, 10, 20, 30, 40);
+
+    CollectionStream<Integer> input = CollectionStream.of("async-test", "ints", inputList);
+    CollectionStream output = CollectionStream.empty("async-test", "ints-out");
+
+    TestRunner
+        .of(MyAsyncStreamTask.class)
+        .addInputStream(input)
+        .addOutputStream(output)
+        .run(Duration.ofSeconds(2));
+
+    StreamAssert.containsInAnyOrder(output, outputList, Duration.ofMillis(1000));
+  }
+
+  @Test
+  public void testAsyncTaskWithMultiplePartition() throws Exception {
+    Map<Integer, List<KV>> inputPartitionData = new HashMap<>();
+    Map<Integer, List<Integer>> expectedOutputPartitionData = new HashMap<>();
+    List<Integer> partition = Arrays.asList(1, 2, 3, 4, 5);
+    List<Integer> outputPartition = partition.stream().map(x -> x * 10).collect(Collectors.toList());
+    for (int i = 0; i < 5; i++) {
+      List<KV> keyedPartition = new ArrayList<>();
+      for (Integer val : partition) {
+        keyedPartition.add(KV.of(i, val));
+      }
+      inputPartitionData.put(i, keyedPartition);
+      expectedOutputPartitionData.put(i, new ArrayList<Integer>(outputPartition));
+    }
+
+    CollectionStream<KV> inputStream = CollectionStream.of("async-test", "ints", inputPartitionData);
+    CollectionStream outputStream = CollectionStream.empty("async-test", "ints-out", 5);
+
+    TestRunner
+        .of(MyAsyncStreamTask.class)
+        .addInputStream(inputStream)
+        .addOutputStream(outputStream)
+        .run(Duration.ofSeconds(2));
+
+    StreamAssert.containsInOrder(outputStream, expectedOutputPartitionData, Duration.ofMillis(1000));
+  }
+
+  @Test
+  public void testAsyncTaskWithMultiplePartitionMultithreaded() throws Exception {
+    Map<Integer, List<KV>> inputPartitionData = new HashMap<>();
+    Map<Integer, List<Integer>> expectedOutputPartitionData = new HashMap<>();
+    List<Integer> partition = Arrays.asList(1, 2, 3, 4, 5);
+    List<Integer> outputPartition = partition.stream().map(x -> x * 10).collect(Collectors.toList());
+    for (int i = 0; i < 5; i++) {
+      List<KV> keyedPartition = new ArrayList<>();
+      for (Integer val : partition) {
+        keyedPartition.add(KV.of(i, val));
+      }
+      inputPartitionData.put(i, keyedPartition);
+      expectedOutputPartitionData.put(i, new ArrayList<Integer>(outputPartition));
+    }
+
+    CollectionStream<KV> inputStream = CollectionStream.of("async-test", "ints", inputPartitionData);
+    CollectionStream outputStream = CollectionStream.empty("async-test", "ints-out", 5);
+
+    TestRunner
+        .of(MyAsyncStreamTask.class)
+        .addInputStream(inputStream)
+        .addOutputStream(outputStream)
+        .addOverrideConfig("task.max.concurrency", "4")
+        .run(Duration.ofSeconds(2));
+
+    StreamAssert.containsInAnyOrder(outputStream, expectedOutputPartitionData, Duration.ofMillis(1000));
+  }
+
   /**
    * Job should fail because it times out too soon
    */
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.samza.test.operator;
+package org.apache.samza.test.framework;
 
 import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.Config;
@@ -25,7 +25,6 @@ import org.apache.samza.operators.MessageStream;
 import org.apache.samza.operators.StreamGraph;
 import org.apache.samza.serializers.JsonSerdeV2;
 import org.apache.samza.test.operator.data.PageView;
-import org.apache.samza.test.framework.StreamAssert;
 
 import java.util.Arrays;
 
@@ -46,7 +45,7 @@ public class BroadcastAssertApp implements StreamApplication {
     /**
      * Each task will see all the pageview events
      */
-    StreamAssert.that("Each task contains all broadcast PageView events", broadcastPageViews, serde)
+    MessageStreamAssert.that("Each task contains all broadcast PageView events", broadcastPageViews, serde)
         .forEachTask()
         .containsInAnyOrder(
             Arrays.asList(
index 347e766..4ecb4b6 100644 (file)
@@ -60,7 +60,8 @@ class RestCall extends Thread {
       System.out.println("Thread " + this.getName() + " interrupted.");
     }
     Integer obj = (Integer) envelope.getMessage();
-    messageCollector.send(new OutgoingMessageEnvelope(new SystemStream("async-test", "ints-out"), obj * 10));
+    messageCollector.send(new OutgoingMessageEnvelope(new SystemStream("async-test", "ints-out"),
+        envelope.getKey(), envelope.getKey(), obj * 10));
     callback.complete();
   }
 }
index c83e461..a07fe74 100644 (file)
@@ -32,6 +32,7 @@ public class MyStreamTestTask implements StreamTask {
   public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator)
       throws Exception {
     Integer obj = (Integer) envelope.getMessage();
-    collector.send(new OutgoingMessageEnvelope(new SystemStream("test", "output"), obj * 10));
+    collector.send(new OutgoingMessageEnvelope(new SystemStream("test", "output"),
+        envelope.getKey(), envelope.getKey(), obj * 10));
   }
 }
index 8ac40e1..ba4c985 100644 (file)
@@ -20,7 +20,9 @@ package org.apache.samza.test.framework;
 
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 import org.apache.samza.SamzaException;
 import org.apache.samza.application.StreamApplication;
@@ -60,10 +62,12 @@ public class StreamApplicationIntegrationTest {
     Random random = new Random();
     int count = 10;
     List<PageView> pageviews = new ArrayList<>(count);
+    Map<Integer, List<PageView>> expectedOutput = new HashMap<>();
     for (int i = 0; i < count; i++) {
       String pagekey = PAGEKEYS[random.nextInt(PAGEKEYS.length - 1)];
       int memberId = i;
-      pageviews.add(new PageView(pagekey, memberId));
+      PageView pv = new PageView(pagekey, memberId);
+      pageviews.add(pv);
     }
 
     CollectionStream<PageView> input = CollectionStream.of("test", "PageView", pageviews);
@@ -76,7 +80,7 @@ public class StreamApplicationIntegrationTest {
         .addOverrideConfig("job.default.system", "test")
         .run(Duration.ofMillis(1500));
 
-    Assert.assertEquals(TestRunner.consumeStream(output, 10000).get(random.nextInt(count)).size(), 1);
+    Assert.assertEquals(TestRunner.consumeStream(output, Duration.ofMillis(1000)).get(random.nextInt(count)).size(), 1);
   }
 
   public static final class Values {
@@ -124,4 +128,5 @@ public class StreamApplicationIntegrationTest {
         .addOverrideConfig("job.default.system", "test")
         .run(Duration.ofMillis(1000));
   }
+
 }
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.samza.test.operator;
+package org.apache.samza.test.framework;
 
 import kafka.utils.TestUtils;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
@@ -35,7 +35,6 @@ import org.apache.samza.runtime.AbstractApplicationRunner;
 import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.system.kafka.KafkaSystemAdmin;
 import org.apache.samza.test.harness.AbstractIntegrationTestHarness;
-import org.apache.samza.test.framework.StreamAssert;
 import scala.Option;
 import scala.Option$;
 
@@ -260,7 +259,7 @@ public class StreamApplicationIntegrationTestHarness extends AbstractIntegration
     AbstractApplicationRunner runner = (AbstractApplicationRunner) ApplicationRunner.fromConfig(config);
     runner.run(streamApplication);
 
-    StreamAssert.waitForComplete();
+    MessageStreamAssert.waitForComplete();
     return new RunApplicationContext(runner, config);
   }
 
index 2cc5977..f888b4a 100644 (file)
 package org.apache.samza.test.framework;
 
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 import org.apache.samza.SamzaException;
+import org.apache.samza.operators.KV;
 import org.apache.samza.test.framework.stream.CollectionStream;
 import org.hamcrest.collection.IsIterableContainingInOrder;
 import org.junit.Assert;
 import org.junit.Test;
 
+
 public class StreamTaskIntegrationTest {
 
   @Test
@@ -40,7 +46,7 @@ public class StreamTaskIntegrationTest {
 
     TestRunner.of(MyStreamTestTask.class).addInputStream(input).addOutputStream(output).run(Duration.ofSeconds(1));
 
-    Assert.assertThat(TestRunner.consumeStream(output, 1000).get(0),
+    Assert.assertThat(TestRunner.consumeStream(output, Duration.ofMillis(1000)).get(0),
         IsIterableContainingInOrder.contains(outputList.toArray()));
   }
 
@@ -54,11 +60,79 @@ public class StreamTaskIntegrationTest {
     CollectionStream<Double> input = CollectionStream.of("test", "doubles", inputList);
     CollectionStream output = CollectionStream.empty("test", "output");
 
+    TestRunner.of(MyStreamTestTask.class).addInputStream(input).addOutputStream(output).run(Duration.ofSeconds(1));
+  }
+
+  @Test
+  public void testSyncTaskWithSinglePartitionMultithreaded() throws Exception {
+    List<Integer> inputList = Arrays.asList(1, 2, 3, 4, 5);
+    List<Integer> outputList = Arrays.asList(10, 20, 30, 40, 50);
+
+    CollectionStream<Integer> input = CollectionStream.of("test", "input", inputList);
+    CollectionStream output = CollectionStream.empty("test", "output");
+
     TestRunner
         .of(MyStreamTestTask.class)
         .addInputStream(input)
         .addOutputStream(output)
+        .addOverrideConfig("job.container.thread.pool.size", "4")
         .run(Duration.ofSeconds(1));
+
+    StreamAssert.containsInOrder(output, outputList, Duration.ofMillis(1000));
   }
 
+  @Test
+  public void testSyncTaskWithMultiplePartition() throws Exception {
+    Map<Integer, List<KV>> inputPartitionData = new HashMap<>();
+    Map<Integer, List<Integer>> expectedOutputPartitionData = new HashMap<>();
+    List<Integer> partition = Arrays.asList(1, 2, 3, 4, 5);
+    List<Integer> outputPartition = partition.stream().map(x -> x * 10).collect(Collectors.toList());
+    for (int i = 0; i < 5; i++) {
+      List<KV> keyedPartition = new ArrayList<>();
+      for (Integer val : partition) {
+        keyedPartition.add(KV.of(i, val));
+      }
+      inputPartitionData.put(i, keyedPartition);
+      expectedOutputPartitionData.put(i, new ArrayList<Integer>(outputPartition));
+    }
+
+    CollectionStream<KV> inputStream = CollectionStream.of("test", "input", inputPartitionData);
+    CollectionStream outputStream = CollectionStream.empty("test", "output", 5);
+
+    TestRunner
+        .of(MyStreamTestTask.class)
+        .addInputStream(inputStream)
+        .addOutputStream(outputStream)
+        .run(Duration.ofSeconds(2));
+
+    StreamAssert.containsInOrder(outputStream, expectedOutputPartitionData, Duration.ofMillis(1000));
+  }
+
+  @Test
+  public void testSyncTaskWithMultiplePartitionMultithreaded() throws Exception {
+    Map<Integer, List<KV>> inputPartitionData = new HashMap<>();
+    Map<Integer, List<Integer>> expectedOutputPartitionData = new HashMap<>();
+    List<Integer> partition = Arrays.asList(1, 2, 3, 4, 5);
+    List<Integer> outputPartition = partition.stream().map(x -> x * 10).collect(Collectors.toList());
+    for (int i = 0; i < 5; i++) {
+      List<KV> keyedPartition = new ArrayList<>();
+      for (Integer val : partition) {
+        keyedPartition.add(KV.of(i, val));
+      }
+      inputPartitionData.put(i, keyedPartition);
+      expectedOutputPartitionData.put(i, new ArrayList<Integer>(outputPartition));
+    }
+
+    CollectionStream<KV> inputStream = CollectionStream.of("test", "input", inputPartitionData);
+    CollectionStream outputStream = CollectionStream.empty("test", "output", 5);
+
+    TestRunner
+        .of(MyStreamTestTask.class)
+        .addInputStream(inputStream)
+        .addOutputStream(outputStream)
+        .addOverrideConfig("job.container.thread.pool.size", "4")
+        .run(Duration.ofSeconds(2));
+
+    StreamAssert.containsInOrder(outputStream, expectedOutputPartitionData, Duration.ofMillis(1000));
+  }
 }
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.samza.test.timer;
+package org.apache.samza.test.framework;
 
 import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.Config;
@@ -28,7 +28,6 @@ import org.apache.samza.operators.functions.FlatMapFunction;
 import org.apache.samza.operators.functions.TimerFunction;
 import org.apache.samza.serializers.JsonSerdeV2;
 import org.apache.samza.test.operator.data.PageView;
-import org.apache.samza.test.framework.StreamAssert;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -45,7 +44,7 @@ public class TestTimerApp implements StreamApplication {
     final MessageStream<PageView> pageViews = graph.getInputStream(PAGE_VIEWS, serde);
     final MessageStream<PageView> output = pageViews.flatMap(new FlatmapTimerFn());
 
-    StreamAssert.that("Output from timer function should container all complete messages", output, serde)
+    MessageStreamAssert.that("Output from timer function should container all complete messages", output, serde)
         .containsInAnyOrder(
             Arrays.asList(
                 new PageView("v1-complete", "p1", "u1"),
  * under the License.
  */
 
-package org.apache.samza.test.timer;
+package org.apache.samza.test.framework;
 
-import org.apache.samza.test.operator.StreamApplicationIntegrationTestHarness;
 import org.junit.Before;
 import org.junit.Test;
 
 
-import static org.apache.samza.test.timer.TestTimerApp.PAGE_VIEWS;
+import static org.apache.samza.test.framework.TestTimerApp.PAGE_VIEWS;
 
 public class TimerTest extends StreamApplicationIntegrationTestHarness {
 
index a9a4026..2f75103 100644 (file)
@@ -27,6 +27,8 @@ import org.apache.samza.Partition;
 import org.apache.samza.system.SystemStreamMetadata;
 import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
 import org.apache.samza.system.kafka.KafkaSystemAdmin;
+import org.apache.samza.test.framework.BroadcastAssertApp;
+import org.apache.samza.test.framework.StreamApplicationIntegrationTestHarness;
 import org.apache.samza.util.ExponentialSleepStrategy;
 import org.junit.Assert;
 import org.junit.Test;
index fbc315f..058a690 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.samza.config.JobCoordinatorConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.TaskConfig;
+import org.apache.samza.test.framework.StreamApplicationIntegrationTestHarness;
 import org.apache.samza.test.operator.data.PageView;
 import org.codehaus.jackson.map.ObjectMapper;
 import org.junit.Assert;