SAMZA-1557: Broadcast operator
authorXinyu Liu <xinyuliu.us@gmail.com>
Wed, 24 Jan 2018 22:27:39 +0000 (14:27 -0800)
committerxiliu <xiliu@linkedin.com>
Wed, 24 Jan 2018 22:27:39 +0000 (14:27 -0800)
This patch adds Broadcast operator that allows broadcasting messages to all tasks. It's the counterpart of the Samza broadcast stream in low level api, and will be used by BEAM runner to broadcast views as side input to other part of the pipeline.

Author: xiliu <xiliu@linkedin.com>

Reviewers: Jagadish V <vjagadish1989@gmail.com>

Closes #410 from xinyuiscool/SAMZA-1557

21 files changed:
build.gradle
gradle/dependency-versions.gradle
samza-api/src/main/java/org/apache/samza/operators/MessageStream.java
samza-api/src/main/java/org/apache/samza/system/StreamSpec.java
samza-core/src/main/java/org/apache/samza/execution/JobGraph.java
samza-core/src/main/java/org/apache/samza/execution/JobNode.java
samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java
samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
samza-core/src/main/java/org/apache/samza/operators/spec/BroadcastOperatorSpec.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java
samza-kafka/src/main/java/org/apache/samza/system/kafka/KafkaStreamSpec.java
samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemAdmin.scala
samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
samza-test/src/test/java/org/apache/samza/test/operator/BroadcastAssertApp.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/operator/StreamApplicationIntegrationTestHarness.java
samza-test/src/test/java/org/apache/samza/test/operator/TestRepartitionJoinWindowApp.java
samza-test/src/test/java/org/apache/samza/test/operator/data/PageView.java
samza-test/src/test/java/org/apache/samza/test/util/StreamAssert.java [new file with mode: 0644]

index 04b42f2..68e9bcf 100644 (file)
@@ -764,6 +764,7 @@ project(":samza-test_$scalaVersion") {
     testCompile project(":samza-core_$scalaVersion").sourceSets.test.output
     testCompile "org.scalatest:scalatest_$scalaVersion:$scalaTestVersion"
     testCompile "org.mockito:mockito-core:$mockitoVersion"
+    testCompile "org.hamcrest:hamcrest-all:$hamcrestVersion"
     testRuntime "org.slf4j:slf4j-simple:$slf4jVersion"
   }
 
index 6ee117f..90483bf 100644 (file)
@@ -27,6 +27,7 @@
   commonsLang3Version = "3.4"
   elasticsearchVersion = "2.2.0"
   guavaVersion = "17.0"
+  hamcrestVersion = "1.3"
   httpClientVersion = "4.4.1"
   jacksonVersion = "1.9.13"
   jerseyVersion = "2.22.1"
index f0a5526..98f0784 100644 (file)
@@ -276,4 +276,19 @@ public interface MessageStream<M> {
    */
   <K, V> void sendTo(Table<KV<K, V>> table);
 
+  /**
+   * Broadcasts messages in this {@link MessageStream} to all instances of its downstream operators..
+   * @param serde the {@link Serde} to use for (de)serializing the message.
+   * @param id id the unique id of this operator in this application
+   * @return the broadcast {@link MessageStream}
+   */
+  MessageStream<M> broadcast(Serde<M> serde, String id);
+
+  /**
+   * Same as calling {@link MessageStream#broadcast(Serde, String)} with a null Serde.
+   * @param id id the unique id of this operator in this application
+   * @return the broadcast {@link MessageStream}
+   */
+  MessageStream<M> broadcast(String id);
+
 }
index 523ff68..3a005c1 100644 (file)
@@ -77,6 +77,11 @@ public class StreamSpec {
   private final boolean isBounded;
 
   /**
+   * broadcast stream to all tasks
+   */
+  private final boolean isBroadcast;
+
+  /**
    * A set of all system-specific configurations for the stream.
    */
   private final Map<String, String> config;
@@ -98,7 +103,7 @@ public class StreamSpec {
    *                      Samza System abstraction. See {@link SystemFactory}
    */
   public StreamSpec(String id, String physicalName, String systemName) {
-    this(id, physicalName, systemName, DEFAULT_PARTITION_COUNT, false, Collections.emptyMap());
+    this(id, physicalName, systemName, DEFAULT_PARTITION_COUNT, false, false, Collections.emptyMap());
   }
 
   /**
@@ -117,7 +122,7 @@ public class StreamSpec {
    * @param partitionCount  The number of partitionts for the stream. A value of {@code 1} indicates unpartitioned.
    */
   public StreamSpec(String id, String physicalName, String systemName, int partitionCount) {
-    this(id, physicalName, systemName, partitionCount, false, Collections.emptyMap());
+    this(id, physicalName, systemName, partitionCount, false, false, Collections.emptyMap());
   }
 
   /**
@@ -137,7 +142,7 @@ public class StreamSpec {
    * @param config        A map of properties for the stream. These may be System-specfic.
    */
   public StreamSpec(String id, String physicalName, String systemName, boolean isBounded, Map<String, String> config) {
-    this(id, physicalName, systemName, DEFAULT_PARTITION_COUNT, isBounded, config);
+    this(id, physicalName, systemName, DEFAULT_PARTITION_COUNT, isBounded, false, config);
   }
 
   /**
@@ -156,9 +161,12 @@ public class StreamSpec {
    *
    * @param isBounded       The stream is bounded or not.
    *
+   * @param isBroadcast     This stream is broadcast or not.
+   *
    * @param config          A map of properties for the stream. These may be System-specfic.
    */
-  public StreamSpec(String id, String physicalName, String systemName, int partitionCount, boolean isBounded, Map<String, String> config) {
+  public StreamSpec(String id, String physicalName, String systemName, int partitionCount,
+                    boolean isBounded, boolean isBroadcast, Map<String, String> config) {
     validateLogicalIdentifier("streamId", id);
     validateLogicalIdentifier("systemName", systemName);
 
@@ -172,6 +180,7 @@ public class StreamSpec {
     this.physicalName = physicalName;
     this.partitionCount = partitionCount;
     this.isBounded = isBounded;
+    this.isBroadcast = isBroadcast;
 
     if (config != null) {
       this.config = Collections.unmodifiableMap(new HashMap<>(config));
@@ -189,11 +198,15 @@ public class StreamSpec {
    * @return                A copy of this StreamSpec with the specified partitionCount.
    */
   public StreamSpec copyWithPartitionCount(int partitionCount) {
-    return new StreamSpec(id, physicalName, systemName, partitionCount, this.isBounded, config);
+    return new StreamSpec(id, physicalName, systemName, partitionCount, this.isBounded, this.isBroadcast, config);
   }
 
   public StreamSpec copyWithPhysicalName(String physicalName) {
-    return new StreamSpec(id, physicalName, systemName, partitionCount, this.isBounded, config);
+    return new StreamSpec(id, physicalName, systemName, partitionCount, this.isBounded, this.isBroadcast, config);
+  }
+
+  public StreamSpec copyWithBroadCast() {
+    return new StreamSpec(id, physicalName, systemName, partitionCount, this.isBounded, true, config);
   }
 
   public String getId() {
@@ -240,6 +253,10 @@ public class StreamSpec {
     return isBounded;
   }
 
+  public boolean isBroadcast() {
+    return isBroadcast;
+  }
+
   private void validateLogicalIdentifier(String identifierName, String identifierValue) {
     if (identifierValue == null || !identifierValue.matches("[A-Za-z0-9_-]+")) {
       throw new IllegalArgumentException(String.format("Identifier '%s' is '%s'. It must match the expression [A-Za-z0-9_-]+", identifierName, identifierValue));
index 4a09260..abd3ce7 100644 (file)
@@ -184,6 +184,9 @@ import org.slf4j.LoggerFactory;
       edge = new StreamEdge(streamSpec, isIntermediate, config);
       edges.put(streamId, edge);
     }
+    if (streamSpec.isBroadcast()) {
+      edge.setPartitionCount(1);
+    }
     return edge;
   }
 
index 4e337d9..c0b4ee5 100644 (file)
@@ -29,6 +29,7 @@ import java.util.Map;
 import java.util.UUID;
 import java.util.stream.Collectors;
 
+import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.config.JobConfig;
@@ -37,6 +38,7 @@ import org.apache.samza.config.SerializerConfig;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.config.StreamConfig;
 import org.apache.samza.config.TaskConfig;
+import org.apache.samza.config.TaskConfigJava;
 import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
@@ -130,8 +132,26 @@ public class JobNode {
     Map<String, String> configs = new HashMap<>();
     configs.put(JobConfig.JOB_NAME(), jobName);
 
-    List<String> inputs = inEdges.stream().map(edge -> edge.getFormattedSystemStream()).collect(Collectors.toList());
+    final List<String> inputs = new ArrayList<>();
+    final List<String> broadcasts = new ArrayList<>();
+    for (StreamEdge inEdge : inEdges) {
+      String formattedSystemStream = inEdge.getFormattedSystemStream();
+      if (inEdge.getStreamSpec().isBroadcast()) {
+        broadcasts.add(formattedSystemStream + "#0");
+      } else {
+        inputs.add(formattedSystemStream);
+      }
+    }
     configs.put(TaskConfig.INPUT_STREAMS(), Joiner.on(',').join(inputs));
+    if (!broadcasts.isEmpty()) {
+      // TODO: remove this once we support defining broadcast input stream in high-level
+      // task.broadcast.input should be generated by the planner in the future.
+      final String taskBroadcasts = config.get(TaskConfigJava.BROADCAST_INPUT_STREAMS);
+      if (StringUtils.isNoneEmpty(taskBroadcasts)) {
+        broadcasts.add(taskBroadcasts);
+      }
+      configs.put(TaskConfigJava.BROADCAST_INPUT_STREAMS, Joiner.on(',').join(broadcasts));
+    }
 
     // set triggering interval if a window or join is defined
     if (streamGraph.hasWindowOrJoins()) {
index 07af54f..1681f30 100644 (file)
@@ -30,6 +30,7 @@ import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.functions.SinkFunction;
 import org.apache.samza.operators.functions.StreamTableJoinFunction;
+import org.apache.samza.operators.spec.BroadcastOperatorSpec;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec.OpCode;
@@ -199,4 +200,19 @@ public class MessageStreamImpl<M> implements MessageStream<M> {
     this.operatorSpec.registerNextOperatorSpec(op);
   }
 
+  @Override
+  public MessageStream<M> broadcast(Serde<M> serde, String userDefinedId) {
+    String opId = this.graph.getNextOpId(OpCode.BROADCAST, userDefinedId);
+    IntermediateMessageStreamImpl<M> intermediateStream = this.graph.getIntermediateStream(opId, serde, true);
+    BroadcastOperatorSpec<M> broadcastOperatorSpec =
+        OperatorSpecs.createBroadCastOperatorSpec(intermediateStream.getOutputStream(), opId);
+    this.operatorSpec.registerNextOperatorSpec(broadcastOperatorSpec);
+    return intermediateStream;
+  }
+
+  @Override
+  public MessageStream<M> broadcast(String userDefinedId) {
+    return broadcast(null, userDefinedId);
+  }
+
 }
index b607c62..7ddcd19 100644 (file)
@@ -168,17 +168,30 @@ public class StreamGraphImpl implements StreamGraph {
   }
 
   /**
+   * See {@link StreamGraphImpl#getIntermediateStream(String, Serde, boolean)}.
+   */
+  <M> IntermediateMessageStreamImpl<M> getIntermediateStream(String streamId, Serde<M> serde) {
+    return getIntermediateStream(streamId, serde, false);
+  }
+
+  /**
    * Internal helper for {@link MessageStreamImpl} to add an intermediate {@link MessageStream} to the graph.
    * An intermediate {@link MessageStream} is both an output and an input stream.
    *
    * @param streamId the id of the stream to be created.
    * @param serde the {@link Serde} to use for the message in the intermediate stream. If null, the default serde
    *              is used.
+   * @param isBroadcast whether the stream is a broadcast stream.
    * @param <M> the type of messages in the intermediate {@link MessageStream}
    * @return  the intermediate {@link MessageStreamImpl}
+   *
+   * TODO: once SAMZA-1566 is resolved, we should be able to pass in the StreamSpec directly.
    */
-  <M> IntermediateMessageStreamImpl<M> getIntermediateStream(String streamId, Serde<M> serde) {
+  <M> IntermediateMessageStreamImpl<M> getIntermediateStream(String streamId, Serde<M> serde, boolean isBroadcast) {
     StreamSpec streamSpec = runner.getStreamSpec(streamId);
+    if (isBroadcast) {
+      streamSpec = streamSpec.copyWithBroadCast();
+    }
 
     Preconditions.checkState(!inputOperators.containsKey(streamSpec) && !outputStreams.containsKey(streamSpec),
         "getIntermediateStream must not be called multiple times with the same streamId: " + streamId);
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
new file mode 100644 (file)
index 0000000..269e7bc
--- /dev/null
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.operators.impl;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.operators.spec.BroadcastOperatorSpec;
+import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.system.ControlMessage;
+import org.apache.samza.system.EndOfStreamMessage;
+import org.apache.samza.system.OutgoingMessageEnvelope;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.WatermarkMessage;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.TaskContext;
+import org.apache.samza.task.TaskCoordinator;
+
+import java.util.Collection;
+import java.util.Collections;
+
+class BroadcastOperatorImpl<M> extends OperatorImpl<M, Void> {
+
+  private final BroadcastOperatorSpec<M> broadcastOpSpec;
+  private final SystemStream systemStream;
+  private final String taskName;
+
+  BroadcastOperatorImpl(BroadcastOperatorSpec<M> broadcastOpSpec, TaskContext context) {
+    this.broadcastOpSpec = broadcastOpSpec;
+    this.systemStream = broadcastOpSpec.getOutputStream().getStreamSpec().toSystemStream();
+    this.taskName = context.getTaskName().getTaskName();
+  }
+
+  @Override
+  protected void handleInit(Config config, TaskContext context) {
+  }
+
+  @Override
+  protected Collection<Void> handleMessage(M message, MessageCollector collector, TaskCoordinator coordinator) {
+    collector.send(new OutgoingMessageEnvelope(systemStream, 0, null, message));
+    return Collections.emptyList();
+  }
+
+  @Override
+  protected void handleClose() {
+  }
+
+  @Override
+  protected OperatorSpec<M, Void> getOperatorSpec() {
+    return broadcastOpSpec;
+  }
+
+  @Override
+  protected Collection<Void> handleEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
+    sendControlMessage(new EndOfStreamMessage(taskName), collector);
+    return Collections.emptyList();
+  }
+
+  @Override
+  protected Collection<Void> handleWatermark(long watermark, MessageCollector collector, TaskCoordinator coordinator) {
+    sendControlMessage(new WatermarkMessage(watermark, taskName), collector);
+    return Collections.emptyList();
+  }
+
+  private void sendControlMessage(ControlMessage message, MessageCollector collector) {
+    OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, 0, null, message);
+    collector.send(envelopeOut);
+  }
+}
index ea278c1..3882544 100644 (file)
@@ -35,6 +35,7 @@ import org.apache.samza.operators.StreamGraphImpl;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.functions.PartialJoinFunction;
 import org.apache.samza.operators.impl.store.TimestampedValue;
+import org.apache.samza.operators.spec.BroadcastOperatorSpec;
 import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
@@ -218,6 +219,8 @@ public class OperatorImplGraph {
       return new StreamTableJoinOperatorImpl((StreamTableJoinOperatorSpec) operatorSpec, config, context);
     } else if (operatorSpec instanceof SendToTableOperatorSpec) {
       return new SendToTableOperatorImpl((SendToTableOperatorSpec) operatorSpec, config, context);
+    } else if (operatorSpec instanceof BroadcastOperatorSpec) {
+      return new BroadcastOperatorImpl((BroadcastOperatorSpec) operatorSpec, context);
     }
     throw new IllegalArgumentException(
         String.format("Unsupported OperatorSpec: %s", operatorSpec.getClass().getName()));
@@ -366,6 +369,9 @@ public class OperatorImplGraph {
     if (opSpec instanceof PartitionByOperatorSpec) {
       PartitionByOperatorSpec spec = (PartitionByOperatorSpec) opSpec;
       outputToInputStreams.put(spec.getOutputStream().getStreamSpec().toSystemStream(), input);
+    } else if (opSpec instanceof BroadcastOperatorSpec) {
+      BroadcastOperatorSpec spec = (BroadcastOperatorSpec) opSpec;
+      outputToInputStreams.put(spec.getOutputStream().getStreamSpec().toSystemStream(), input);
     } else {
       Collection<OperatorSpec> nextOperators = opSpec.getRegisteredOperatorSpecs();
       nextOperators.forEach(spec -> computeOutputToInput(input, spec, outputToInputStreams));
diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/BroadcastOperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/BroadcastOperatorSpec.java
new file mode 100644 (file)
index 0000000..6689690
--- /dev/null
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.samza.operators.spec;
+
+import org.apache.samza.operators.functions.WatermarkFunction;
+
+public class BroadcastOperatorSpec<M> extends OperatorSpec<M, Void> {
+  private final OutputStreamImpl<M> outputStream;
+
+
+  public BroadcastOperatorSpec(OutputStreamImpl<M> outputStream, String opId) {
+    super(OpCode.BROADCAST, opId);
+
+    this.outputStream = outputStream;
+  }
+
+  public OutputStreamImpl<M> getOutputStream() {
+    return this.outputStream;
+  }
+
+  @Override
+  public WatermarkFunction getWatermarkFn() {
+    return null;
+  }
+}
index 2a5991c..00b5318 100644 (file)
@@ -49,7 +49,8 @@ public abstract class OperatorSpec<M, OM> {
     WINDOW,
     MERGE,
     PARTITION_BY,
-    OUTPUT
+    OUTPUT,
+    BROADCAST
   }
 
   private final String opId;
index c752fe2..2a2e33a 100644 (file)
@@ -278,4 +278,16 @@ public class OperatorSpecs {
     return new SendToTableOperatorSpec(inputOpSpec, tableSpec, opId);
   }
 
+  /**
+   * Creates a {@link BroadcastOperatorSpec} for the Broadcast operator.
+   * @param outputStream the {@link OutputStreamImpl} to send messages to
+   * @param opId the unique ID of the operator
+   * @param <M> the type of input message
+   * @return the {@link BroadcastOperatorSpec}
+   */
+  public static <M> BroadcastOperatorSpec<M> createBroadCastOperatorSpec(
+      OutputStreamImpl<M> outputStream, String opId) {
+    return new BroadcastOperatorSpec<>(outputStream, opId);
+  }
+
 }
index a84c434..217248d 100644 (file)
@@ -110,6 +110,7 @@ public class KafkaStreamSpec extends StreamSpec {
                                 originalSpec.getSystemName(),
                                 originalSpec.getPartitionCount(),
                                 replicationFactor,
+                                originalSpec.isBroadcast(),
                                 mapToProperties(filterUnsupportedProperties(originalSpec.getConfig())));
   }
 
@@ -124,7 +125,7 @@ public class KafkaStreamSpec extends StreamSpec {
    * @param partitionCount  The number of partitions.
    */
   public KafkaStreamSpec(String id, String topicName, String systemName, int partitionCount) {
-    this(id, topicName, systemName, partitionCount, DEFAULT_REPLICATION_FACTOR, new Properties());
+    this(id, topicName, systemName, partitionCount, DEFAULT_REPLICATION_FACTOR, false, new Properties());
   }
 
   /**
@@ -145,11 +146,13 @@ public class KafkaStreamSpec extends StreamSpec {
    *
    * @param replicationFactor The number of topic replicas in the Kafka cluster for durability.
    *
+   * @param isBroadcast       The stream is broadcast or not.
+   *
    * @param properties        A set of properties for the stream. These may be System-specfic.
    */
   public KafkaStreamSpec(String id, String topicName, String systemName, int partitionCount, int replicationFactor,
-      Properties properties) {
-    super(id, topicName, systemName, partitionCount, false, propertiesToMap(properties));
+      Boolean isBroadcast, Properties properties) {
+    super(id, topicName, systemName, partitionCount, false, isBroadcast, propertiesToMap(properties));
 
     if (partitionCount < 1) {
       throw new IllegalArgumentException("Parameter 'partitionCount' must be > 0");
@@ -164,11 +167,13 @@ public class KafkaStreamSpec extends StreamSpec {
 
   @Override
   public StreamSpec copyWithPartitionCount(int partitionCount) {
-    return new KafkaStreamSpec(getId(), getPhysicalName(), getSystemName(), partitionCount, getReplicationFactor(), getProperties());
+    return new KafkaStreamSpec(getId(), getPhysicalName(), getSystemName(), partitionCount, getReplicationFactor(),
+        isBroadcast(), getProperties());
   }
 
   public KafkaStreamSpec copyWithReplicationFactor(int replicationFactor) {
-    return new KafkaStreamSpec(getId(), getPhysicalName(), getSystemName(), getPartitionCount(), replicationFactor, getProperties());
+    return new KafkaStreamSpec(getId(), getPhysicalName(), getSystemName(), getPartitionCount(), replicationFactor,
+        isBroadcast(), getProperties());
   }
 
   /**
@@ -177,7 +182,8 @@ public class KafkaStreamSpec extends StreamSpec {
    * @return new instance of {@link KafkaStreamSpec}
    */
   public KafkaStreamSpec copyWithProperties(Properties properties) {
-    return new KafkaStreamSpec(getId(), getPhysicalName(), getSystemName(), getPartitionCount(), getReplicationFactor(), properties);
+    return new KafkaStreamSpec(getId(), getPhysicalName(), getSystemName(), getPartitionCount(), getReplicationFactor(),
+        isBroadcast(), properties);
   }
 
   public int getReplicationFactor() {
index 4715141..a9a9bd7 100644 (file)
@@ -452,9 +452,11 @@ class KafkaSystemAdmin(
     if (spec.isChangeLogStream) {
       val topicName = spec.getPhysicalName
       val topicMeta = topicMetaInformation.getOrElse(topicName, throw new StreamValidationException("Unable to find topic information for topic " + topicName))
-      new KafkaStreamSpec(spec.getId, topicName, systemName, spec.getPartitionCount, topicMeta.replicationFactor, topicMeta.kafkaProps)
+      new KafkaStreamSpec(spec.getId, topicName, systemName, spec.getPartitionCount, topicMeta.replicationFactor,
+        spec.isBroadcast, topicMeta.kafkaProps)
     } else if (spec.isCoordinatorStream){
-      new KafkaStreamSpec(spec.getId, spec.getPhysicalName, systemName, 1, coordinatorStreamReplicationFactor, coordinatorStreamProperties)
+      new KafkaStreamSpec(spec.getId, spec.getPhysicalName, systemName, 1, coordinatorStreamReplicationFactor,
+        spec.isBroadcast, coordinatorStreamProperties)
     } else if (intermediateStreamProperties.contains(spec.getId)) {
       KafkaStreamSpec.fromSpec(spec).copyWithProperties(intermediateStreamProperties(spec.getId))
     } else {
index 86cb418..c4e57f7 100644 (file)
@@ -166,7 +166,7 @@ class TestKafkaCheckpointManager extends KafkaServerTestHarness {
 
     val systemFactory = Util.getObj[SystemFactory](systemFactoryClassName)
 
-    val spec = new KafkaStreamSpec("id", cpTopic, checkpointSystemName, 1, 1, props)
+    val spec = new KafkaStreamSpec("id", cpTopic, checkpointSystemName, 1, 1, false, props)
     new KafkaCheckpointManager(spec, systemFactory, failOnTopicValidation, config, new NoOpMetricsRegistry, serde)
   }
 
diff --git a/samza-test/src/test/java/org/apache/samza/test/operator/BroadcastAssertApp.java b/samza-test/src/test/java/org/apache/samza/test/operator/BroadcastAssertApp.java
new file mode 100644 (file)
index 0000000..9c89aba
--- /dev/null
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.operator;
+
+import org.apache.samza.application.StreamApplication;
+import org.apache.samza.config.Config;
+import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.StreamGraph;
+import org.apache.samza.serializers.JsonSerdeV2;
+import org.apache.samza.test.operator.data.PageView;
+import org.apache.samza.test.util.StreamAssert;
+
+import java.util.Arrays;
+
+import static org.apache.samza.test.operator.RepartitionJoinWindowApp.PAGE_VIEWS;
+
+public class BroadcastAssertApp implements StreamApplication {
+
+  @Override
+  public void init(StreamGraph graph, Config config) {
+    final JsonSerdeV2<PageView> serde = new JsonSerdeV2<>(PageView.class);
+    final MessageStream<PageView> broadcastPageViews = graph
+        .getInputStream(PAGE_VIEWS, serde)
+        .broadcast(serde, "pv");
+
+    /**
+     * Each task will see all the pageview events
+     */
+    StreamAssert.that("Each task contains all broadcast PageView events", broadcastPageViews, serde)
+        .forEachTask()
+        .containsInAnyOrder(
+            Arrays.asList(
+                new PageView("v1", "p1", "u1"),
+                new PageView("v2", "p2", "u1"),
+                new PageView("v3", "p1", "u2"),
+                new PageView("v4", "p3", "u2")
+            ));
+  }
+}
index db46982..04497bd 100644 (file)
@@ -32,6 +32,7 @@ import org.apache.samza.config.KafkaConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.runtime.ApplicationRunner;
 import org.apache.samza.test.harness.AbstractIntegrationTestHarness;
+import org.apache.samza.test.util.StreamAssert;
 import scala.Option;
 import scala.Option$;
 
@@ -248,6 +249,8 @@ public class StreamApplicationIntegrationTestHarness extends AbstractIntegration
     app = streamApplication;
     runner = ApplicationRunner.fromConfig(new MapConfig(configs));
     runner.run(streamApplication);
+
+    StreamAssert.waitForComplete();
   }
 
   public void setNumEmptyPolls(int numEmptyPolls) {
index 49611bb..a83f9cf 100644 (file)
@@ -20,23 +20,24 @@ package org.apache.samza.test.operator;
 
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.util.Collections;
 import java.util.List;
 
 import static org.apache.samza.test.operator.RepartitionJoinWindowApp.AD_CLICKS;
-import static org.apache.samza.test.operator.RepartitionJoinWindowApp.PAGE_VIEWS;
 import static org.apache.samza.test.operator.RepartitionJoinWindowApp.OUTPUT_TOPIC;
+import static org.apache.samza.test.operator.RepartitionJoinWindowApp.PAGE_VIEWS;
+
 
 /**
  * Test driver for {@link RepartitionJoinWindowApp}.
  */
 public class TestRepartitionJoinWindowApp extends StreamApplicationIntegrationTestHarness {
-  private static final String APP_NAME = "UserPageAdClickCounter";
 
-  @Test
-  public void testRepartitionJoinWindowApp() throws Exception {
+  @Before
+  public void setup() {
     // create topics
     createTopic(PAGE_VIEWS, 2);
     createTopic(AD_CLICKS, 2);
@@ -56,9 +57,14 @@ public class TestRepartitionJoinWindowApp extends StreamApplicationIntegrationTe
     produceMessage(AD_CLICKS, 0, "a1", "{\"viewId\":\"v3\",\"adId\":\"a1\"}");
     produceMessage(AD_CLICKS, 0, "a5", "{\"viewId\":\"v4\",\"adId\":\"a5\"}");
 
+  }
+
+  @Test
+  public void testRepartitionJoinWindowApp() throws Exception {
     // run the application
     RepartitionJoinWindowApp app = new RepartitionJoinWindowApp();
-    runApplication(app, APP_NAME, null);
+    final String appName = "UserPageAdClickCounter";
+    runApplication(app, appName, null);
 
     // consume and validate result
     List<ConsumerRecord<String, String>> messages = consumeMessages(Collections.singletonList(OUTPUT_TOPIC), 2);
@@ -71,4 +77,9 @@ public class TestRepartitionJoinWindowApp extends StreamApplicationIntegrationTe
       Assert.assertEquals("2", value);
     }
   }
+
+  @Test
+  public void testBroadcastApp() {
+    runApplication(new BroadcastAssertApp(), "BroadcastTest", null);
+  }
 }
index d2cebf9..b114b43 100644 (file)
 package org.apache.samza.test.operator.data;
 
 
+import org.codehaus.jackson.annotate.JsonProperty;
+
 public class PageView {
   private String viewId;
   private String pageId;
   private String userId;
 
-  public String getViewId() {
-    return viewId;
+  public PageView(@JsonProperty("view-id") String viewId,
+                  @JsonProperty("page-id") String pageId,
+                  @JsonProperty("user-id") String userId) {
+    this.viewId = viewId;
+    this.pageId = pageId;
+    this.userId = userId;
   }
 
-  public void setViewId(String viewId) {
-    this.viewId = viewId;
+  public String getViewId() {
+    return viewId;
   }
 
   public String getPageId() {
     return pageId;
   }
 
-  public void setPageId(String pageId) {
-    this.pageId = pageId;
-  }
-
   public String getUserId() {
     return userId;
   }
 
-  public void setUserId(String userId) {
-    this.userId = userId;
+  @Override
+  public int hashCode() {
+    final int prime = 31;
+    int result = viewId != null ? viewId.hashCode() : 0;
+    result = prime * result + (pageId != null ? pageId.hashCode() : 0);
+    result = prime * result + (userId != null ? userId.hashCode() : 0);
+    return result;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj)
+      return true;
+    if (obj == null)
+      return false;
+    if (getClass() != obj.getClass())
+      return false;
+
+    final PageView other = (PageView) obj;
+
+    if (viewId != null
+        ? !viewId.equals(other.viewId)
+        : other.viewId != null) {
+      return false;
+    }
+
+    if (pageId != null
+        ? !pageId.equals(other.pageId)
+        : other.pageId != null) {
+      return false;
+    }
+
+    return userId != null
+        ? userId.equals(other.userId)
+        : other.userId == null;
+  }
+
+  @Override
+  public String toString() {
+    return "viewId:" + viewId + "|pageId:" + pageId + "|userId:" + userId;
   }
 }
diff --git a/samza-test/src/test/java/org/apache/samza/test/util/StreamAssert.java b/samza-test/src/test/java/org/apache/samza/test/util/StreamAssert.java
new file mode 100644 (file)
index 0000000..8a46db0
--- /dev/null
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.util;
+
+import com.google.common.collect.Iterables;
+import org.apache.samza.config.Config;
+import org.apache.samza.operators.MessageStream;
+import org.apache.samza.operators.functions.SinkFunction;
+import org.apache.samza.serializers.KVSerde;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.TaskContext;
+import org.apache.samza.task.TaskCoordinator;
+import org.hamcrest.Matchers;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Timer;
+import java.util.TimerTask;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+
+import static org.junit.Assert.assertThat;
+
+/**
+ * An assertion on the content of a {@link MessageStream}.
+ *
+ * <p>Example: </pre>{@code
+ * MessageStream<String> stream = streamGraph.getInputStream("input", serde).map(some_function)...;
+ * ...
+ * StreamAssert.that(id, stream, stringSerde).containsInAnyOrder(Arrays.asList("a", "b", "c"));
+ * }</pre>
+ *
+ */
+public class StreamAssert<M> {
+  private final static Map<String, CountDownLatch> LATCHES = new ConcurrentHashMap<>();
+  private final static CountDownLatch PLACE_HOLDER = new CountDownLatch(0);
+
+  private final String id;
+  private final MessageStream<M> messageStream;
+  private final Serde<M> serde;
+  private boolean checkEachTask = false;
+
+  public static <M> StreamAssert<M> that(String id, MessageStream<M> messageStream, Serde<M> serde) {
+    return new StreamAssert<>(id, messageStream, serde);
+  }
+
+  private StreamAssert(String id, MessageStream<M> messageStream, Serde<M> serde) {
+    this.id = id;
+    this.messageStream = messageStream;
+    this.serde = serde;
+  }
+
+  public StreamAssert forEachTask() {
+    checkEachTask = true;
+    return this;
+  }
+
+  public void containsInAnyOrder(final Collection<M> expected) {
+    LATCHES.putIfAbsent(id, PLACE_HOLDER);
+    final MessageStream<M> streamToCheck = checkEachTask
+        ? messageStream
+        : messageStream
+          .partitionBy(m -> null, m -> m, KVSerde.of(new StringSerde(), serde), null)
+          .map(kv -> kv.value);
+
+    streamToCheck.sink(new CheckAgainstExpected<M>(id, expected, checkEachTask));
+  }
+
+  public static void waitForComplete() {
+    try {
+      while (!LATCHES.isEmpty()) {
+        final Set<String> ids  = new HashSet<>(LATCHES.keySet());
+        for (String id : ids) {
+          while (LATCHES.get(id) == PLACE_HOLDER) {
+            Thread.sleep(100);
+          }
+
+          final CountDownLatch latch = LATCHES.get(id);
+          if (latch != null) {
+            latch.await();
+            LATCHES.remove(id);
+          }
+        }
+      }
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private static final class CheckAgainstExpected<M> implements SinkFunction<M> {
+    private static final long TIMEOUT = 5000L;
+
+    private final String id;
+    private final boolean checkEachTask;
+    private final Collection<M> expected;
+
+
+    private transient Timer timer = new Timer();
+    private transient List<M> actual = Collections.synchronizedList(new ArrayList<>());
+    private transient TimerTask timerTask = new TimerTask() {
+      @Override
+      public void run() {
+        check();
+      }
+    };
+
+    CheckAgainstExpected(String id, Collection<M> expected, boolean checkEachTask) {
+      this.id = id;
+      this.expected = expected;
+      this.checkEachTask = checkEachTask;
+    }
+
+    @Override
+    public void init(Config config, TaskContext context) {
+      final SystemStreamPartition ssp = Iterables.getFirst(context.getSystemStreamPartitions(), null);
+      if (ssp == null ? false : ssp.getPartition().getPartitionId() == 0) {
+        final int count = checkEachTask ? context.getSamzaContainerContext().taskNames.size() : 1;
+        LATCHES.put(id, new CountDownLatch(count));
+        timer.schedule(timerTask, TIMEOUT);
+      }
+    }
+
+    @Override
+    public void apply(M message, MessageCollector messageCollector, TaskCoordinator taskCoordinator) {
+      actual.add(message);
+
+      if (actual.size() >= expected.size()) {
+        timerTask.cancel();
+        check();
+      }
+    }
+
+    private void check() {
+      final CountDownLatch latch = LATCHES.get(id);
+      try {
+        assertThat(actual, Matchers.containsInAnyOrder((M[]) expected.toArray()));
+      } finally {
+        latch.countDown();
+      }
+    }
+  }
+}