SAMZA-863: Multithreading support in Samza
authorXinyu Liu <xiliu@linkedin.com>
Tue, 19 Jul 2016 18:53:44 +0000 (11:53 -0700)
committerYi Pan (Data Infrastructure) <nickpan47@gmail.com>
Tue, 19 Jul 2016 18:53:44 +0000 (11:53 -0700)
43 files changed:
checkstyle/import-control.xml
samza-api/src/main/java/org/apache/samza/task/AsyncStreamTask.java [new file with mode: 0644]
samza-api/src/main/java/org/apache/samza/task/TaskCallback.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/task/AsyncStreamTaskAdapter.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/task/CoordinatorRequests.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/task/TaskCallbackFactory.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/task/TaskCallbackImpl.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/task/TaskCallbackListener.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/task/TaskCallbackManager.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/task/TaskCallbackTimeoutException.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/util/Utils.java [new file with mode: 0644]
samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
samza-core/src/main/scala/org/apache/samza/config/JobConfig.scala
samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
samza-core/src/main/scala/org/apache/samza/container/SamzaContainerMetrics.scala
samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
samza-core/src/main/scala/org/apache/samza/coordinator/JobCoordinator.scala
samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
samza-hdfs/src/main/scala/org/apache/samza/system/hdfs/HdfsSystemProducer.scala
samza-kafka/src/main/scala/org/apache/samza/migration/KafkaCheckpointMigration.scala
samza-kafka/src/main/scala/org/apache/samza/system/kafka/KafkaSystemProducer.scala
samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaSystemProducerJava.java
samza-kafka/src/test/scala/org/apache/samza/system/kafka/TestKafkaSystemProducer.scala
samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
samza-test/src/test/scala/org/apache/samza/storage/kv/TestKeyValueStores.scala

index 325c381..c85dc94 100644 (file)
@@ -33,6 +33,7 @@
     <allow pkg="org.apache.commons" />
     <allow class="scala.collection.JavaConversions" />
     <allow class="scala.collection.JavaConverters" />
+    <allow pkg="scala.runtime" />
 
     <subpackage name="config">
         <allow class="org.apache.samza.SamzaException" />
         <allow pkg="org.apache.samza.container" />
         <allow pkg="org.apache.samza.metrics" />
         <allow pkg="org.apache.samza.system" />
+        <allow pkg="org.apache.samza.util" />
+        <allow pkg="org.apache.samza.checkpoint" />
+        <allow class="org.apache.samza.SamzaException" />
+        <allow class="org.apache.samza.Partition" />
     </subpackage>
 
     <subpackage name="container">
         <allow pkg="org.apache.samza.util" />
         <allow pkg="junit.framework" />
         <allow class="org.apache.samza.coordinator.stream.AbstractCoordinatorStreamManager" />
-
+        <allow class="org.apache.samza.SamzaException" />
+        <allow pkg="org.apache.samza.system" />
+        <allow pkg="org.apache.samza.task" />
+        <allow pkg="org.apache.samza.util" />
         <subpackage name="grouper">
             <subpackage name="stream">
                 <allow pkg="org.apache.samza.system" />
diff --git a/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTask.java b/samza-api/src/main/java/org/apache/samza/task/AsyncStreamTask.java
new file mode 100644 (file)
index 0000000..684ba0b
--- /dev/null
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import org.apache.samza.system.IncomingMessageEnvelope;
+
+/**
+ * An AsyncStreamTask is the basic class to support multithreading execution in Samza container. It’s provided for better
+ * parallelism and resource utilization. This class allows task to make asynchronous calls and fire callbacks upon completion.
+ * Similar to {@link StreamTask}, an AsyncStreamTask may be augmented by implementing other interfaces, such as
+ * {@link InitableTask}, {@link WindowableTask}, or {@link ClosableTask}. The following invariants hold with these mix-ins:
+ *
+ * InitableTask.init - always the first method invoked on an AsyncStreamTask. It happens-before every subsequent
+ * invocation on AsyncStreamTask (for happens-before semantics, see https://docs.oracle.com/javase/tutorial/essential/concurrency/memconsist.html).
+ *
+ * CloseableTask.close - always the last method invoked on an AsyncStreamTask and all other AsyncStreamTask are guaranteed
+ * to happen-before it.
+ *
+ * AsyncStreamTask.processAsync - can run in either a serialized or parallel mode. In the serialized mode (task.process.max.inflight.messages=1),
+ * each invocation of processAsync is guaranteed to happen-before the next. In a parallel execution mode (task.process.max.inflight.messages&gt;1),
+ * there is no such happens-before constraint and the AsyncStreamTask is required to coordinate any shared state.
+ *
+ * WindowableTask.window - in either above mode, it is called when no invocations to processAsync are pending and no new
+ * processAsync invocations can be scheduled until it completes. Therefore, a guarantee that all previous processAsync invocations
+ * happen before an invocation of WindowableTask.window. An invocation to WindowableTask.window is guaranteed to happen-before
+ * any subsequent processAsync invocations. The Samza engine is responsible for ensuring that window is invoked in a timely manner.
+ *
+ * Similar to WindowableTask.window, commits are guaranteed to happen only when there are no pending processAsync or WindowableTask.window
+ * invocations. All preceding invocations happen-before commit and commit happens-before all subsequent invocations.
+ */
+public interface AsyncStreamTask {
+  /**
+   * Called once for each message that this AsyncStreamTask receives.
+   * @param envelope Contains the received deserialized message and key, and also information regarding the stream and
+   * partition of which the message was received from.
+   * @param collector Contains the means of sending message envelopes to the output stream. The collector must only
+   * be used during the current call to the process method; you should not reuse the collector between invocations
+   * of this method.
+   * @param coordinator Manages execution of tasks.
+   * @param callback Triggers the completion of the process.
+   */
+  void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback);
+}
\ No newline at end of file
diff --git a/samza-api/src/main/java/org/apache/samza/task/TaskCallback.java b/samza-api/src/main/java/org/apache/samza/task/TaskCallback.java
new file mode 100644 (file)
index 0000000..8ba7a36
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+/**
+ * A TaskCallback is fired by a {@link AsyncStreamTask} to notify when an asynchronous
+ * process has completed. If the callback is fired multiple times, it will throw IllegalStateException.
+ */
+public interface TaskCallback {
+
+  /**
+   * Invoke when the asynchronous process completed with success.
+   */
+  void complete();
+
+  /**
+   * Invoke when the asynchronous process failed with an error.
+   * @param t  error throwable
+   */
+  void failure(Throwable t);
+}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
new file mode 100644 (file)
index 0000000..a789d04
--- /dev/null
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.container;
+
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.system.SystemConsumers;
+import org.apache.samza.task.AsyncRunLoop;
+import org.apache.samza.task.AsyncStreamTask;
+import org.apache.samza.task.StreamTask;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.collection.JavaConversions;
+import scala.runtime.AbstractFunction1;
+
+import static org.apache.samza.util.Utils.defaultValue;
+import static org.apache.samza.util.Utils.defaultClock;
+
+/**
+ * Factory class to create runloop for a Samza task, based on the type
+ * of the task
+ */
+public class RunLoopFactory {
+  private static final Logger log = LoggerFactory.getLogger(RunLoopFactory.class);
+
+  private static final long DEFAULT_WINDOW_MS = -1L;
+  private static final long DEFAULT_COMMIT_MS = 60000L;
+  private static final long DEFAULT_CALLBACK_TIMEOUT_MS = -1L;
+
+  public static Runnable createRunLoop(scala.collection.immutable.Map<TaskName, TaskInstance<?>> taskInstances,
+      SystemConsumers consumerMultiplexer,
+      ExecutorService threadPool,
+      Executor executor,
+      SamzaContainerMetrics containerMetrics,
+      TaskConfig config) {
+
+    long taskWindowMs = config.getWindowMs().getOrElse(defaultValue(DEFAULT_WINDOW_MS));
+
+    log.info("Got window milliseconds: " + taskWindowMs);
+
+    long taskCommitMs = config.getCommitMs().getOrElse(defaultValue(DEFAULT_COMMIT_MS));
+
+    log.info("Got commit milliseconds: " + taskCommitMs);
+
+    int asyncTaskCount = taskInstances.values().count(new AbstractFunction1<TaskInstance<?>, Object>() {
+      @Override
+      public Boolean apply(TaskInstance<?> t) {
+        return t.isAsyncTask();
+      }
+    });
+
+    // asyncTaskCount should be either 0 or the number of all taskInstances
+    if (asyncTaskCount > 0 && asyncTaskCount < taskInstances.size()) {
+      throw new SamzaException("Mixing StreamTask and AsyncStreamTask is not supported");
+    }
+
+    if (asyncTaskCount == 0) {
+      log.info("Run loop in single thread mode.");
+
+      scala.collection.immutable.Map<TaskName, TaskInstance<StreamTask>> streamTaskInstances = (scala.collection.immutable.Map) taskInstances;
+      return new RunLoop(
+        streamTaskInstances,
+        consumerMultiplexer,
+        containerMetrics,
+        taskWindowMs,
+        taskCommitMs,
+        defaultClock(),
+        executor);
+    } else {
+      Integer taskMaxConcurrency = config.getMaxConcurrency().getOrElse(defaultValue(1));
+
+      log.info("Got max messages in flight: " + taskMaxConcurrency);
+
+      Long callbackTimeout = config.getCallbackTimeoutMs().getOrElse(defaultValue(DEFAULT_CALLBACK_TIMEOUT_MS));
+
+      log.info("Got callback timeout: " + callbackTimeout);
+
+      scala.collection.immutable.Map<TaskName, TaskInstance<AsyncStreamTask>> asyncStreamTaskInstances = (scala.collection.immutable.Map) taskInstances;
+
+      log.info("Run loop in asynchronous mode.");
+
+      return new AsyncRunLoop(
+        JavaConversions.asJavaMap(asyncStreamTaskInstances),
+        threadPool,
+        consumerMultiplexer,
+        taskMaxConcurrency,
+        taskWindowMs,
+        taskCommitMs,
+        callbackTimeout,
+        containerMetrics);
+    }
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java b/samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java
new file mode 100644 (file)
index 0000000..a510bb0
--- /dev/null
@@ -0,0 +1,619 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.SamzaException;
+import org.apache.samza.container.SamzaContainerMetrics;
+import org.apache.samza.container.TaskInstance;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemConsumers;
+import org.apache.samza.system.SystemStreamPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.collection.JavaConversions;
+
+
+/**
+ * The AsyncRunLoop supports multithreading execution of Samza {@link AsyncStreamTask}s.
+ */
+public class AsyncRunLoop implements Runnable {
+  private static final Logger log = LoggerFactory.getLogger(AsyncRunLoop.class);
+
+  private final Map<TaskName, AsyncTaskWorker> taskWorkers;
+  private final SystemConsumers consumerMultiplexer;
+  private final Map<SystemStreamPartition, List<AsyncTaskWorker>> sspToTaskWorkerMapping;
+  private final ExecutorService threadPool;
+  private final CoordinatorRequests coordinatorRequests;
+  private final Object latch;
+  private final int maxConcurrency;
+  private final long windowMs;
+  private final long commitMs;
+  private final long callbackTimeoutMs;
+  private final SamzaContainerMetrics containerMetrics;
+  private final ScheduledExecutorService workerTimer;
+  private final ScheduledExecutorService callbackTimer;
+  private volatile boolean shutdownNow = false;
+  private volatile Throwable throwable = null;
+
+  public AsyncRunLoop(Map<TaskName, TaskInstance<AsyncStreamTask>> taskInstances,
+      ExecutorService threadPool,
+      SystemConsumers consumerMultiplexer,
+      int maxConcurrency,
+      long windowMs,
+      long commitMs,
+      long callbackTimeoutMs,
+      SamzaContainerMetrics containerMetrics) {
+
+    this.threadPool = threadPool;
+    this.consumerMultiplexer = consumerMultiplexer;
+    this.containerMetrics = containerMetrics;
+    this.windowMs = windowMs;
+    this.commitMs = commitMs;
+    this.maxConcurrency = maxConcurrency;
+    this.callbackTimeoutMs = callbackTimeoutMs;
+    this.callbackTimer = (callbackTimeoutMs > 0) ? Executors.newSingleThreadScheduledExecutor() : null;
+    this.coordinatorRequests = new CoordinatorRequests(taskInstances.keySet());
+    this.latch = new Object();
+    this.workerTimer = Executors.newSingleThreadScheduledExecutor();
+    Map<TaskName, AsyncTaskWorker> workers = new HashMap<>();
+    for (TaskInstance<AsyncStreamTask> task : taskInstances.values()) {
+      workers.put(task.taskName(), new AsyncTaskWorker(task));
+    }
+    // Partions and tasks assigned to the container will not change during the run loop life time
+    this.taskWorkers = Collections.unmodifiableMap(workers);
+    this.sspToTaskWorkerMapping = Collections.unmodifiableMap(getSspToAsyncTaskWorkerMap(taskInstances, taskWorkers));
+  }
+
+  /**
+   * Returns mapping of the SystemStreamPartition to the AsyncTaskWorkers to efficiently route the envelopes
+   */
+  private static Map<SystemStreamPartition, List<AsyncTaskWorker>> getSspToAsyncTaskWorkerMap(
+      Map<TaskName, TaskInstance<AsyncStreamTask>> taskInstances, Map<TaskName, AsyncTaskWorker> taskWorkers) {
+    Map<SystemStreamPartition, List<AsyncTaskWorker>> sspToWorkerMap = new HashMap<>();
+    for (TaskInstance<AsyncStreamTask> task : taskInstances.values()) {
+      Set<SystemStreamPartition> ssps = JavaConversions.asJavaSet(task.systemStreamPartitions());
+      for (SystemStreamPartition ssp : ssps) {
+        if (sspToWorkerMap.get(ssp) == null) {
+          sspToWorkerMap.put(ssp, new ArrayList<AsyncTaskWorker>());
+        }
+        sspToWorkerMap.get(ssp).add(taskWorkers.get(task.taskName()));
+      }
+    }
+    return sspToWorkerMap;
+  }
+
+  /**
+   * The run loop chooses messages from the SystemConsumers, and run the ready tasks asynchronously.
+   * Window and commit are run in a thread pool, and they are mutual exclusive with task process.
+   * The loop thread will block if all tasks are busy, and resume if any task finishes.
+   */
+  @Override
+  public void run() {
+    try {
+      for (AsyncTaskWorker taskWorker : taskWorkers.values()) {
+        taskWorker.init();
+      }
+
+      long prevNs = System.nanoTime();
+
+      while (!shutdownNow) {
+        if (throwable != null) {
+          log.error("Caught throwable and stopping run loop", throwable);
+          throw new SamzaException(throwable);
+        }
+
+        long startNs = System.nanoTime();
+
+        IncomingMessageEnvelope envelope = chooseEnvelope();
+
+        long chooseNs = System.nanoTime();
+
+        containerMetrics.chooseNs().update(chooseNs - startNs);
+
+        runTasks(envelope);
+
+        long blockNs = System.nanoTime();
+
+        blockIfBusy(envelope);
+
+        long currentNs = System.nanoTime();
+        long activeNs = blockNs - chooseNs;
+        long totalNs = currentNs - prevNs;
+        prevNs = currentNs;
+        containerMetrics.blockNs().update(currentNs - blockNs);
+        containerMetrics.utilization().set(((double) activeNs) / totalNs);
+      }
+    } finally {
+      workerTimer.shutdown();
+      if (callbackTimer != null) callbackTimer.shutdown();
+    }
+  }
+
+  public void shutdown() {
+    shutdownNow = true;
+  }
+
+  /**
+   * Chooses an envelope from messageChooser without updating it. This enables flow control
+   * on the SSP level, meaning the task will not get further messages for the SSP if it cannot
+   * process it. The chooser is updated only after the callback to process is invoked, then the task
+   * is able to process more messages. This flow control does not block. so in case of empty message chooser,
+   * it will return null immediately without blocking, and the chooser will not poll the underlying system
+   * consumer since there are still messages in the SystemConsumers buffer.
+   */
+  private IncomingMessageEnvelope chooseEnvelope() {
+    IncomingMessageEnvelope envelope = consumerMultiplexer.choose(false);
+    if (envelope != null) {
+      log.trace("Choose envelope ssp {} offset {} for processing", envelope.getSystemStreamPartition(), envelope.getOffset());
+      containerMetrics.envelopes().inc();
+    } else {
+      log.trace("No envelope is available");
+      containerMetrics.nullEnvelopes().inc();
+    }
+    return envelope;
+  }
+
+  /**
+   * Insert the envelope into the task pending queues and run all the tasks
+   */
+  private void runTasks(IncomingMessageEnvelope envelope) {
+    if (envelope != null) {
+      PendingEnvelope pendingEnvelope = new PendingEnvelope(envelope);
+      for (AsyncTaskWorker worker : sspToTaskWorkerMapping.get(envelope.getSystemStreamPartition())) {
+        worker.state.insertEnvelope(pendingEnvelope);
+      }
+    }
+
+    for (AsyncTaskWorker worker: taskWorkers.values()) {
+      worker.run();
+    }
+  }
+
+  /**
+   * Block the runloop thread if all tasks are busy. Due to limitation of non-blocking for the flow control,
+   * we block the run loop when there are no runnable tasks, or all tasks are idle (no pending messages) while
+   * chooser is empty too. When a task worker finishes or window/commit completes, it will resume the runloop.
+   */
+  private void blockIfBusy(IncomingMessageEnvelope envelope) {
+    synchronized (latch) {
+      while (!shutdownNow && throwable == null) {
+        for (AsyncTaskWorker worker : taskWorkers.values()) {
+          if (worker.state.isReady() && (envelope != null || worker.state.hasPendingOps())) {
+            // should continue running since the worker state is ready and there is either new message
+            // or some pending operations for the worker
+            return;
+          }
+        }
+
+        try {
+          log.trace("Block loop thread");
+
+          if (envelope == null) {
+            // If the envelope is null then we will wait for a poll interval, otherwise next choose() will
+            // return null immediately and we will have a busy loop
+            latch.wait(consumerMultiplexer.pollIntervalMs());
+            return;
+          } else {
+            latch.wait();
+          }
+        } catch (InterruptedException e) {
+          throw new SamzaException("Run loop is interrupted", e);
+        }
+      }
+    }
+  }
+
+  /**
+   * Resume the runloop thread. It is triggered once a task becomes ready again or has failure.
+   */
+  private void resume() {
+    log.trace("Resume loop thread");
+    if (coordinatorRequests.shouldShutdownNow() && coordinatorRequests.commitRequests().isEmpty()) {
+      shutdownNow = true;
+    }
+    synchronized (latch) {
+      latch.notifyAll();
+    }
+  }
+
+  /**
+   * Set the throwable and abort run loop. The throwable will be thrown from the run loop thread
+   * @param t throwable
+   */
+  private void abort(Throwable t) {
+    throwable = t;
+  }
+
+  /**
+   * PendingEnvenlope contains an envelope that is not processed by this task, and
+   * a flag indicating whether it has been processed by any tasks.
+   */
+  private static final class PendingEnvelope {
+    private final IncomingMessageEnvelope envelope;
+    private boolean processed = false;
+
+    PendingEnvelope(IncomingMessageEnvelope envelope) {
+      this.envelope = envelope;
+    }
+
+    /**
+     * Returns true if the envelope has not been processed.
+     */
+    private boolean markProcessed() {
+      boolean oldValue = processed;
+      processed = true;
+      return !oldValue;
+    }
+  }
+
+
+  private enum WorkerOp {
+    WINDOW,
+    COMMIT,
+    PROCESS,
+    NO_OP
+  }
+
+  /**
+   * The AsyncTaskWorker encapsulates the states of an {@link AsyncStreamTask}. If the task becomes ready, it
+   * will run the task asynchronously. It runs window and commit in the provided thread pool.
+   */
+  private class AsyncTaskWorker implements TaskCallbackListener {
+    private final TaskInstance<AsyncStreamTask> task;
+    private final TaskCallbackManager callbackManager;
+    private volatile AsyncTaskState state;
+
+    AsyncTaskWorker(TaskInstance<AsyncStreamTask> task) {
+      this.task = task;
+      this.callbackManager = new TaskCallbackManager(this, task.metrics(), callbackTimer, callbackTimeoutMs);
+      this.state = new AsyncTaskState(task.taskName(), task.metrics());
+    }
+
+    private void init() {
+      // schedule the timer for windowing and commiting
+      if (task.isWindowableTask() && windowMs > 0L) {
+        workerTimer.scheduleAtFixedRate(new Runnable() {
+          @Override
+          public void run() {
+            log.trace("Task {} need window", task.taskName());
+            state.needWindow();
+            resume();
+          }
+        }, windowMs, windowMs, TimeUnit.MILLISECONDS);
+      }
+
+      if (commitMs > 0L) {
+        workerTimer.scheduleAtFixedRate(new Runnable() {
+          @Override
+          public void run() {
+            log.trace("Task {} need commit", task.taskName());
+            state.needCommit();
+            resume();
+          }
+        }, commitMs, commitMs, TimeUnit.MILLISECONDS);
+      }
+    }
+
+    /**
+     * Invoke next task operation based on its state
+     */
+    private void run() {
+      switch (state.nextOp()) {
+        case PROCESS:
+          process();
+          break;
+        case WINDOW:
+          window();
+          break;
+        case COMMIT:
+          commit();
+          break;
+        default:
+          //no op
+          break;
+      }
+    }
+
+    /**
+     * Process asynchronously. The callback needs to be fired once the processing is done.
+     */
+    private void process() {
+      final IncomingMessageEnvelope envelope = state.fetchEnvelope();
+      log.trace("Process ssp {} offset {}", envelope.getSystemStreamPartition(), envelope.getOffset());
+
+      final ReadableCoordinator coordinator = new ReadableCoordinator(task.taskName());
+      TaskCallbackFactory callbackFactory = new TaskCallbackFactory() {
+        @Override
+        public TaskCallback createCallback() {
+          state.startProcess();
+          containerMetrics.processes().inc();
+          return callbackManager.createCallback(task.taskName(), envelope, coordinator);
+        }
+      };
+
+      task.process(envelope, coordinator, callbackFactory);
+    }
+
+    /**
+     * Invoke window. Run window in thread pool if not the single thread mode.
+     */
+    private void window() {
+      state.startWindow();
+      Runnable windowWorker = new Runnable() {
+        @Override
+        public void run() {
+          try {
+            containerMetrics.windows().inc();
+
+            ReadableCoordinator coordinator = new ReadableCoordinator(task.taskName());
+            long startTime = System.nanoTime();
+            task.window(coordinator);
+            containerMetrics.windowNs().update(System.nanoTime() - startTime);
+            coordinatorRequests.update(coordinator);
+
+            state.doneWindowOrCommit();
+          } catch (Throwable t) {
+            log.error("Task {} window failed", task.taskName(), t);
+            abort(t);
+          } finally {
+            log.trace("Task {} window completed", task.taskName());
+            resume();
+          }
+        }
+      };
+
+      if (threadPool != null) {
+        log.trace("Task {} window on the thread pool", task.taskName());
+        threadPool.submit(windowWorker);
+      } else {
+        log.trace("Task {} window on the run loop thread", task.taskName());
+        windowWorker.run();
+      }
+    }
+
+    /**
+     * Invoke commit. Run commit in thread pool if not the single thread mode.
+     */
+    private void commit() {
+      state.startCommit();
+      Runnable commitWorker = new Runnable() {
+        @Override
+        public void run() {
+          try {
+            containerMetrics.commits().inc();
+
+            long startTime = System.nanoTime();
+            task.commit();
+            containerMetrics.commitNs().update(System.nanoTime() - startTime);
+
+            state.doneWindowOrCommit();
+          } catch (Throwable t) {
+            log.error("Task {} commit failed", task.taskName(), t);
+            abort(t);
+          } finally {
+            log.trace("Task {} commit completed", task.taskName());
+            resume();
+          }
+        }
+      };
+
+      if (threadPool != null) {
+        log.trace("Task {} commits on the thread pool", task.taskName());
+        threadPool.submit(commitWorker);
+      } else {
+        log.trace("Task {} commits on the run loop thread", task.taskName());
+        commitWorker.run();
+      }
+    }
+
+
+
+    /**
+     * Task process completes successfully, update the offsets based on the high-water mark.
+     * Then it will trigger the listener for task state change.
+     * * @param callback AsyncSteamTask.processAsync callback
+     */
+    @Override
+    public void onComplete(TaskCallback callback) {
+      try {
+        state.doneProcess();
+        TaskCallbackImpl callbackImpl = (TaskCallbackImpl) callback;
+        containerMetrics.processNs().update(System.nanoTime() - callbackImpl.timeCreatedNs);
+        log.trace("Got callback complete for task {}, ssp {}", callbackImpl.taskName, callbackImpl.envelope.getSystemStreamPartition());
+
+        TaskCallbackImpl callbackToUpdate = callbackManager.updateCallback(callbackImpl, true);
+        if (callbackToUpdate != null) {
+          IncomingMessageEnvelope envelope = callbackToUpdate.envelope;
+          log.trace("Update offset for ssp {}, offset {}", envelope.getSystemStreamPartition(), envelope.getOffset());
+
+          // update offset
+          task.offsetManager().update(task.taskName(), envelope.getSystemStreamPartition(), envelope.getOffset());
+
+          // update coordinator
+          coordinatorRequests.update(callbackToUpdate.coordinator);
+        }
+      } catch (Throwable t) {
+        log.error(t.getMessage(), t);
+        abort(t);
+      } finally {
+        resume();
+      }
+    }
+
+    /**
+     * Task process fails. Trigger the listener indicating failure.
+     * @param callback AsyncSteamTask.processAsync callback
+     * @param t throwable of the failure
+     */
+    @Override
+    public void onFailure(TaskCallback callback, Throwable t) {
+      try {
+        state.doneProcess();
+        abort(t);
+        // update pending count, but not offset
+        TaskCallbackImpl callbackImpl = (TaskCallbackImpl) callback;
+        callbackManager.updateCallback(callbackImpl, false);
+        log.error("Got callback failure for task {}", callbackImpl.taskName);
+      } catch (Throwable e) {
+        log.error(e.getMessage(), e);
+      } finally {
+        resume();
+      }
+    }
+  }
+
+
+  /**
+   * AsyncTaskState manages the state of the AsyncStreamTask. In summary, a worker has the following states:
+   * ready - ready for window, commit or process next incoming message.
+   * busy - doing window, commit or not able to process next message.
+   * idle - no pending messages, and no window/commit
+   */
+  private final class AsyncTaskState {
+    private volatile boolean needWindow = false;
+    private volatile boolean needCommit = false;
+    private volatile boolean windowOrCommitInFlight = false;
+    private final AtomicInteger messagesInFlight = new AtomicInteger(0);
+    private final ArrayDeque<PendingEnvelope> pendingEnvelopQueue;
+    private final TaskName taskName;
+    private final TaskInstanceMetrics taskMetrics;
+
+    AsyncTaskState(TaskName taskName, TaskInstanceMetrics taskMetrics) {
+      this.taskName = taskName;
+      this.taskMetrics = taskMetrics;
+      this.pendingEnvelopQueue = new ArrayDeque<>();
+    }
+
+    /**
+     * Returns whether the task is ready to do process/window/commit.
+     */
+    private boolean isReady() {
+      needCommit |= coordinatorRequests.commitRequests().remove(taskName);
+      if (needWindow || needCommit) {
+        // ready for window or commit only when no messages are in progress and
+        // no window/commit in flight
+        return messagesInFlight.get() == 0 && !windowOrCommitInFlight;
+      } else {
+        // ready for process only when the inflight message count does not exceed threshold
+        // and no window/commit in flight
+        return messagesInFlight.get() < maxConcurrency && !windowOrCommitInFlight;
+      }
+    }
+
+    private boolean hasPendingOps() {
+      return !pendingEnvelopQueue.isEmpty() || needCommit || needWindow;
+    }
+
+    /**
+     * Returns the next operation by this taskWorker
+     */
+    private WorkerOp nextOp() {
+      if (isReady()) {
+        if (needCommit) return WorkerOp.COMMIT;
+        else if (needWindow) return WorkerOp.WINDOW;
+        else if (!pendingEnvelopQueue.isEmpty()) return WorkerOp.PROCESS;
+      }
+      return WorkerOp.NO_OP;
+    }
+
+    private void needWindow() {
+      needWindow = true;
+    }
+
+    private void needCommit() {
+      needCommit = true;
+    }
+
+    private void startWindow() {
+      needWindow = false;
+      windowOrCommitInFlight = true;
+    }
+
+    private void startCommit() {
+      needCommit = false;
+      windowOrCommitInFlight = true;
+    }
+
+    private void startProcess() {
+      messagesInFlight.incrementAndGet();
+    }
+
+    private void doneWindowOrCommit() {
+      windowOrCommitInFlight = false;
+    }
+
+    private void doneProcess() {
+      messagesInFlight.decrementAndGet();
+    }
+
+    /**
+     * Insert an PendingEnvelope into the pending envelope queue.
+     * The function will be called in the run loop thread so no synchronization.
+     * @param pendingEnvelope
+     */
+    private void insertEnvelope(PendingEnvelope pendingEnvelope) {
+      pendingEnvelopQueue.add(pendingEnvelope);
+      int queueSize = pendingEnvelopQueue.size();
+      taskMetrics.pendingMessages().set(queueSize);
+      log.trace("Insert envelope to task {} queue.", taskName);
+      log.debug("Task {} pending envelope count is {} after insertion.", taskName, queueSize);
+    }
+
+    /**
+     * Fetch the pending envelope in the pending queue for the task to process.
+     * Update the chooser for flow control on the SSP level. Once it's updated, the AsyncRunLoop
+     * will be able to choose new messages from this SSP for the task to process. Note that we
+     * update only when the envelope is first time being processed. This solves the issue in
+     * Broadcast stream where a message need to be processed by multiple tasks. In that case,
+     * the envelope will be in the pendingEnvelopeQueue of each task. Only the first fetch updates
+     * the chooser with the next envelope in the broadcast stream partition.
+     * The function will be called in the run loop thread so no synchronization.
+     * @return
+     */
+    private IncomingMessageEnvelope fetchEnvelope() {
+      PendingEnvelope pendingEnvelope = pendingEnvelopQueue.remove();
+      int queueSize = pendingEnvelopQueue.size();
+      taskMetrics.pendingMessages().set(queueSize);
+      log.trace("fetch envelope ssp {} offset {} to process.", pendingEnvelope.envelope.getSystemStreamPartition(), pendingEnvelope.envelope.getOffset());
+      log.debug("Task {} pending envelopes count is {} after fetching.", taskName, queueSize);
+
+      if (pendingEnvelope.markProcessed()) {
+        SystemStreamPartition partition = pendingEnvelope.envelope.getSystemStreamPartition();
+        consumerMultiplexer.tryUpdate(partition);
+        log.debug("Update chooser for " + partition);
+      }
+      return pendingEnvelope.envelope;
+    }
+  }
+}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/task/AsyncStreamTaskAdapter.java b/samza-core/src/main/java/org/apache/samza/task/AsyncStreamTaskAdapter.java
new file mode 100644 (file)
index 0000000..1fc6456
--- /dev/null
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.concurrent.ExecutorService;
+import org.apache.samza.config.Config;
+import org.apache.samza.system.IncomingMessageEnvelope;
+
+
+/**
+ * AsyncStreamTaskAdapter allows a StreamTask to be executed in parallel.The class
+ * uses the build-in thread pool to invoke StreamTask.process and triggers
+ * the callbacks once it's done. If the thread pool is null, it follows the legacy
+ * synchronous model to execute the tasks on the run loop thread.
+ */
+public class AsyncStreamTaskAdapter implements AsyncStreamTask, InitableTask, WindowableTask, ClosableTask {
+  private final StreamTask wrappedTask;
+  private final ExecutorService executor;
+
+  public AsyncStreamTaskAdapter(StreamTask task, ExecutorService executor) {
+    this.wrappedTask = task;
+    this.executor = executor;
+  }
+
+  @Override
+  public void init(Config config, TaskContext context) throws Exception {
+    if (wrappedTask instanceof InitableTask) {
+      ((InitableTask) wrappedTask).init(config, context);
+    }
+  }
+
+  @Override
+  public void processAsync(final IncomingMessageEnvelope envelope,
+      final MessageCollector collector,
+      final TaskCoordinator coordinator,
+      final TaskCallback callback) {
+    if (executor != null) {
+      executor.submit(new Runnable() {
+        @Override
+        public void run() {
+          process(envelope, collector, coordinator, callback);
+        }
+      });
+    } else {
+      // legacy mode: running all tasks in the runloop thread
+      process(envelope, collector, coordinator, callback);
+    }
+  }
+
+  private void process(IncomingMessageEnvelope envelope,
+      MessageCollector collector,
+      TaskCoordinator coordinator,
+      TaskCallback callback) {
+    try {
+      wrappedTask.process(envelope, collector, coordinator);
+      callback.complete();
+    } catch (Throwable t) {
+      callback.failure(t);
+    }
+  }
+
+  @Override
+  public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+    if (wrappedTask instanceof WindowableTask) {
+      ((WindowableTask) wrappedTask).window(collector, coordinator);
+    }
+  }
+
+  @Override
+  public void close() throws Exception {
+    if (wrappedTask instanceof ClosableTask) {
+      ((ClosableTask) wrappedTask).close();
+    }
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/task/CoordinatorRequests.java b/samza-core/src/main/java/org/apache/samza/task/CoordinatorRequests.java
new file mode 100644 (file)
index 0000000..052b3b9
--- /dev/null
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.samza.container.TaskName;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * TaskCoordinatorRequests is used in run loop to collect the coordinator
+ * requests from tasks, including commit requests and shutdown requests.
+ * It is thread safe so it can be updated from multiple task threads.
+ */
+public class CoordinatorRequests {
+  private static final Logger log = LoggerFactory.getLogger(CoordinatorRequests.class);
+
+  private final Set<TaskName> taskNames;
+  private final Set<TaskName> taskShutdownRequests = Collections.synchronizedSet(new HashSet<TaskName>());
+  private final Set<TaskName> taskCommitRequests = Collections.synchronizedSet(new HashSet<TaskName>());
+  volatile private boolean shutdownNow = false;
+
+  public CoordinatorRequests(Set<TaskName> taskNames) {
+    this.taskNames = taskNames;
+  }
+
+  public void update(ReadableCoordinator coordinator) {
+    if (coordinator.commitRequest().isDefined() || coordinator.shutdownRequest().isDefined()) {
+      checkCoordinator(coordinator);
+    }
+  }
+
+  public Set<TaskName> commitRequests() {
+    return taskCommitRequests;
+  }
+
+  public boolean shouldShutdownNow() {
+    return shutdownNow;
+  }
+
+  /**
+   * A new TaskCoordinator object is passed to a task on every call to StreamTask.process
+   * and WindowableTask.window. This method checks whether the task requested that we
+   * do something that affects the run loop (such as commit or shut down), and updates
+   * run loop state accordingly.
+   */
+  private void checkCoordinator(ReadableCoordinator coordinator) {
+    if (coordinator.requestedCommitTask()) {
+      log.info("Task "  + coordinator.taskName() + " requested commit for current task only");
+      taskCommitRequests.add(coordinator.taskName());
+    }
+
+    if (coordinator.requestedCommitAll()) {
+      log.info("Task " + coordinator.taskName() + " requested commit for all tasks in the container");
+      taskCommitRequests.addAll(taskNames);
+    }
+
+    if (coordinator.requestedShutdownOnConsensus()) {
+      taskShutdownRequests.add(coordinator.taskName());
+      log.info("Shutdown has now been requested by tasks " + taskShutdownRequests);
+    }
+
+    if (coordinator.requestedShutdownNow() || taskShutdownRequests.size() == taskNames.size()) {
+      log.info("Shutdown requested.");
+      shutdownNow = true;
+    }
+  }
+}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/task/TaskCallbackFactory.java b/samza-core/src/main/java/org/apache/samza/task/TaskCallbackFactory.java
new file mode 100644 (file)
index 0000000..7dddb67
--- /dev/null
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+/**
+ * TaskCallbackFactory creates the {@link TaskCallback} for {@link org.apache.samza.container.TaskInstance}
+ * to process asynchronously
+ */
+public interface TaskCallbackFactory {
+  TaskCallback createCallback();
+}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/task/TaskCallbackImpl.java b/samza-core/src/main/java/org/apache/samza/task/TaskCallbackImpl.java
new file mode 100644 (file)
index 0000000..9b70099
--- /dev/null
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class implements {@link TaskCallback}. It triggers the
+ * {@link TaskCallbackListener} with the callback result. If the
+ * callback is called multiple times, it will throw IllegalStateException
+ * to the listener.
+ */
+class TaskCallbackImpl implements TaskCallback, Comparable<TaskCallbackImpl> {
+  private static final Logger log = LoggerFactory.getLogger(TaskCallbackImpl.class);
+
+  final TaskName taskName;
+  final IncomingMessageEnvelope envelope;
+  final ReadableCoordinator coordinator;
+  final long timeCreatedNs;
+  private final AtomicBoolean isComplete = new AtomicBoolean(false);
+  private final TaskCallbackListener listener;
+  private ScheduledFuture scheduledFuture = null;
+  private final long seqNum;
+
+  public TaskCallbackImpl(TaskCallbackListener listener,
+      TaskName taskName,
+      IncomingMessageEnvelope envelope,
+      ReadableCoordinator coordinator,
+      long seqNum) {
+    this.listener = listener;
+    this.taskName = taskName;
+    this.envelope = envelope;
+    this.coordinator = coordinator;
+    this.seqNum = seqNum;
+    this.timeCreatedNs = System.nanoTime();
+  }
+
+  @Override
+  public void complete() {
+    if (scheduledFuture != null) {
+      scheduledFuture.cancel(true);
+    }
+    log.trace("Callback complete for ssp {} offset {}.", envelope.getSystemStreamPartition(), envelope.getOffset());
+
+    if (isComplete.compareAndSet(false, true)) {
+      listener.onComplete(this);
+    } else {
+      Throwable throwable = new IllegalStateException("TaskCallback complete has been invoked after completion");
+      log.error("Callback for process task {}, envelope {}.", new Object[] {taskName, envelope}, throwable);
+      listener.onFailure(this, throwable);
+    }
+  }
+
+  @Override
+  public void failure(Throwable t) {
+    if (scheduledFuture != null) {
+      scheduledFuture.cancel(true);
+    }
+    log.error("Callback fails for task {} envelope {}.", new Object[] {taskName, envelope}, t);
+
+    if (isComplete.compareAndSet(false, true)) {
+      listener.onFailure(this, t);
+    } else {
+      Throwable throwable = new IllegalStateException("TaskCallback failure has been invoked after completion", t);
+      log.error("Callback for process task {}, envelope {}.", new Object[] {taskName, envelope}, throwable);
+      listener.onFailure(this, throwable);
+    }
+  }
+
+  void setScheduledFuture(ScheduledFuture scheduledFuture) {
+    this.scheduledFuture = scheduledFuture;
+  }
+
+  @Override
+  public int compareTo(TaskCallbackImpl callback) {
+    return Long.compare(this.seqNum, callback.seqNum);
+  }
+
+  boolean matchSeqNum(long seqNum) {
+    return this.seqNum == seqNum;
+  }
+}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/task/TaskCallbackListener.java b/samza-core/src/main/java/org/apache/samza/task/TaskCallbackListener.java
new file mode 100644 (file)
index 0000000..de4ee58
--- /dev/null
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+/**
+ * The interface of the listener to the {@link AsyncStreamTask}.processAsync
+ * callback events. If the callback completes with success, onComplete() will be fired.
+ * If the callback fails, onFailure() will be fired.
+ */
+interface TaskCallbackListener {
+  void onComplete(TaskCallback callback);
+  void onFailure(TaskCallback callback, Throwable t);
+}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/task/TaskCallbackManager.java b/samza-core/src/main/java/org/apache/samza/task/TaskCallbackManager.java
new file mode 100644 (file)
index 0000000..132cf59
--- /dev/null
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.PriorityQueue;
+import java.util.Queue;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.system.IncomingMessageEnvelope;
+
+
+/**
+ * TaskCallbackManager manages the life cycle of {@link AsyncStreamTask} callbacks,
+ * including creation, update and status. Internally it maintains a PriorityQueue
+ * for the callbacks based on the sequence number, and updates the offsets for checkpointing
+ * by always moving forward to the latest contiguous callback (uses the high watermark).
+ */
+class TaskCallbackManager {
+
+  private static final class TaskCallbacks {
+    private final Queue<TaskCallbackImpl> callbacks = new PriorityQueue<>();
+    private final Object lock = new Object();
+    private long nextSeqNum = 0L;
+
+    /**
+     * Adding the newly complete callback to the callback queue
+     * Move the queue to the last contiguous callback to commit offset
+     * @param cb new callback completed
+     * @return callback of highest watermark needed to be committed
+     */
+    TaskCallbackImpl update(TaskCallbackImpl cb) {
+      synchronized (lock) {
+        callbacks.add(cb);
+
+        TaskCallbackImpl callback = null;
+        TaskCallbackImpl callbackToCommit = null;
+        TaskCoordinator.RequestScope shutdownRequest = null;
+        // look for the last contiguous callback
+        while (!callbacks.isEmpty() && callbacks.peek().matchSeqNum(nextSeqNum)) {
+          ++nextSeqNum;
+          callback = callbacks.poll();
+
+          if (callback.coordinator.commitRequest().isDefined()) {
+            callbackToCommit = callback;
+          }
+
+          if (callback.coordinator.shutdownRequest().isDefined()) {
+            shutdownRequest = callback.coordinator.shutdownRequest().get();
+          }
+        }
+
+        // if there is no manual commit, use the highest contiguous callback message offset
+        if (callbackToCommit == null) {
+          callbackToCommit = callback;
+        }
+
+        // if there is a shutdown request, merge it into the coordinator to commit
+        if (shutdownRequest != null) {
+          callbackToCommit.coordinator.shutdown(shutdownRequest);
+        }
+
+        return callbackToCommit;
+      }
+    }
+  }
+
+  private long seqNum = 0L;
+  private final AtomicInteger pendingCount = new AtomicInteger(0);
+  private final TaskCallbacks completeCallbacks = new TaskCallbacks();
+  private final TaskInstanceMetrics metrics;
+  private final ScheduledExecutorService timer;
+  private final TaskCallbackListener listener;
+  private long timeout;
+
+  public TaskCallbackManager(TaskCallbackListener listener, TaskInstanceMetrics metrics, ScheduledExecutorService timer, long timeout) {
+    this.listener = listener;
+    this.metrics = metrics;
+    this.timer = timer;
+    this.timeout = timeout;
+  }
+
+  public TaskCallbackImpl createCallback(TaskName taskName,
+      IncomingMessageEnvelope envelope,
+      ReadableCoordinator coordinator) {
+    final TaskCallbackImpl callback = new TaskCallbackImpl(listener, taskName, envelope, coordinator, seqNum++);
+    int count = pendingCount.incrementAndGet();
+    metrics.messagesInFlight().set(count);
+
+    if (timer != null) {
+      Runnable timerTask = new Runnable() {
+        @Override
+        public void run() {
+          String msg = "Task " + callback.taskName + " callback times out";
+          callback.failure(new TaskCallbackTimeoutException(msg));
+        }
+      };
+      ScheduledFuture scheduledFuture = timer.schedule(timerTask, timeout, TimeUnit.MILLISECONDS);
+      callback.setScheduledFuture(scheduledFuture);
+    }
+
+    return callback;
+  }
+
+  /**
+   * Update the task callbacks with the new callback completed.
+   * It uses a high-watermark model to roll the callbacks for checkpointing.
+   * @param callback new completed callback
+   * @param success callback result status
+   * @return the callback for checkpointing
+   */
+  public TaskCallbackImpl updateCallback(TaskCallbackImpl callback, boolean success) {
+    TaskCallbackImpl callbackToCommit = null;
+    if (success) {
+      callbackToCommit = completeCallbacks.update(callback);
+    }
+    int count = pendingCount.decrementAndGet();
+    metrics.messagesInFlight().set(count);
+    return callbackToCommit;
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/task/TaskCallbackTimeoutException.java b/samza-core/src/main/java/org/apache/samza/task/TaskCallbackTimeoutException.java
new file mode 100644 (file)
index 0000000..bf7f13c
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import org.apache.samza.SamzaException;
+
+
+/**
+ * Specific {@link SamzaException}s thrown when a task callback times out
+ */
+public class TaskCallbackTimeoutException extends SamzaException {
+  private static final long serialVersionUID = -2342134146355610665L;
+
+  public TaskCallbackTimeoutException(Throwable e) {
+    super(e);
+  }
+
+  public TaskCallbackTimeoutException(String msg) {
+    super(msg);
+  }
+
+  public TaskCallbackTimeoutException(String msg, Throwable e) {
+    super(msg, e);
+  }
+}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/util/Utils.java b/samza-core/src/main/java/org/apache/samza/util/Utils.java
new file mode 100644 (file)
index 0000000..472e0a5
--- /dev/null
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.util;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.runtime.AbstractFunction0;
+
+
+public class Utils {
+  private static final Logger log = LoggerFactory.getLogger(Utils.class);
+
+  private Utils() {}
+
+  /**
+   * Returns a default value object for scala option.getOrDefault() to use
+   * @param value default value
+   * @param <T> value type
+   * @return object containing default value
+   */
+  public static <T> AbstractFunction0<T> defaultValue(final T value) {
+    return new AbstractFunction0<T>() {
+      @Override
+      public T apply() {
+        return value;
+      }
+    };
+  }
+
+  /**
+   * Creates a nanosecond clock using default system nanotime
+   * @return object invokes the system clock
+   */
+  public static AbstractFunction0<Object> defaultClock() {
+    return new AbstractFunction0<Object>() {
+      @Override
+      public Object apply() {
+        return System.nanoTime();
+      }
+    };
+  }
+}
index 00648e4..7245902 100644 (file)
@@ -19,6 +19,9 @@
 
 package org.apache.samza.checkpoint
 
+
+import java.util.concurrent.ConcurrentHashMap
+
 import org.apache.samza.system.SystemStream
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.system.SystemStreamMetadata
@@ -146,7 +149,7 @@ class OffsetManager(
   /**
    * Last offsets processed for each SystemStreamPartition.
    */
-  var lastProcessedOffsets = Map[TaskName, Map[SystemStreamPartition, String]]()
+  val lastProcessedOffsets = new ConcurrentHashMap[TaskName, ConcurrentHashMap[SystemStreamPartition, String]]()
 
   /**
    * Offsets to start reading from for each SystemStreamPartition. This
@@ -182,20 +185,15 @@ class OffsetManager(
    * Set the last processed offset for a given SystemStreamPartition.
    */
   def update(taskName: TaskName, systemStreamPartition: SystemStreamPartition, offset: String) {
-    lastProcessedOffsets.get(taskName) match {
-      case Some(sspToOffsets) => lastProcessedOffsets += taskName -> (sspToOffsets + (systemStreamPartition -> offset))
-      case None => lastProcessedOffsets += (taskName -> Map(systemStreamPartition -> offset))
-    }
+    lastProcessedOffsets.putIfAbsent(taskName, new ConcurrentHashMap[SystemStreamPartition, String]())
+    lastProcessedOffsets.get(taskName).put(systemStreamPartition, offset)
   }
 
   /**
    * Get the last processed offset for a SystemStreamPartition.
    */
-  def getLastProcessedOffset(taskName: TaskName, systemStreamPartition: SystemStreamPartition) = {
-    lastProcessedOffsets.get(taskName) match {
-      case Some(sspToOffsets) => sspToOffsets.get(systemStreamPartition)
-      case None => None
-    }
+  def getLastProcessedOffset(taskName: TaskName, systemStreamPartition: SystemStreamPartition): Option[String] = {
+    Option(lastProcessedOffsets.get(taskName)).map(_.get(systemStreamPartition))
   }
 
   /**
@@ -217,7 +215,7 @@ class OffsetManager(
       debug("Checkpointing offsets for taskName %s." format taskName)
 
       val sspsForTaskName = systemStreamPartitions.getOrElse(taskName, throw new SamzaException("No such SystemStreamPartition set " + taskName + " registered for this checkpointmanager")).toSet
-      val partitionOffsets = lastProcessedOffsets.get(taskName) match {
+      val partitionOffsets = Option(lastProcessedOffsets.get(taskName)) match {
         case Some(sspToOffsets) => sspToOffsets.filterKeys(sspsForTaskName.contains(_))
         case None => {
           warn(taskName + " is not found... ")
@@ -225,8 +223,9 @@ class OffsetManager(
         }
       }
 
+      partitionOffsets.foreach(p => info("task " + taskName + " checkpoint " + p._1 + ", " + p._2))
       checkpointManager.writeCheckpoint(taskName, new Checkpoint(partitionOffsets))
-      lastProcessedOffsets.get(taskName) match {
+      Option(lastProcessedOffsets.get(taskName)) match {
         case Some(sspToOffsets) => sspToOffsets.foreach { case (ssp, checkpoint) => offsetManagerMetrics.checkpointedOffsets(ssp).set(checkpoint) }
         case None =>
       }
@@ -270,9 +269,8 @@ class OffsetManager(
         .keys
         .flatMap(restoreOffsetsFromCheckpoint(_))
         .toMap
-      lastProcessedOffsets ++= result.map {
-        case (taskName, sspToOffset) => {
-          taskName -> sspToOffset.filter {
+      result.map { case (taskName, sspToOffset) => {
+          lastProcessedOffsets.put(taskName, new ConcurrentHashMap[SystemStreamPartition, String](sspToOffset.filter {
             case (systemStreamPartition, offset) =>
               val shouldKeep = offsetSettings.contains(systemStreamPartition.getSystemStream)
               if (!shouldKeep) {
@@ -280,7 +278,7 @@ class OffsetManager(
               }
               info("Checkpointed offset is currently %s for %s" format (offset, systemStreamPartition))
               shouldKeep
-          }
+          }))
         }
       }
     } else {
@@ -324,17 +322,15 @@ class OffsetManager(
       }
     }
 
-    lastProcessedOffsets = lastProcessedOffsets.map {
-      case (taskName, sspToOffsets) => {
-        taskName -> (sspToOffsets -- systemStreamPartitionsToReset(taskName))
-      }
+    lastProcessedOffsets.keys().foreach { taskName =>
+      lastProcessedOffsets.get(taskName).keySet().removeAll(systemStreamPartitionsToReset(taskName))
     }
   }
 
   /**
    * Returns a map of all SystemStreamPartitions in lastProcessedOffsets that need to be reset
    */
-  private def getSystemStreamPartitionsToReset(taskNameTosystemStreamPartitions: Map[TaskName, Map[SystemStreamPartition, String]]): Map[TaskName, Set[SystemStreamPartition]] = {
+  private def getSystemStreamPartitionsToReset(taskNameTosystemStreamPartitions: ConcurrentHashMap[TaskName, ConcurrentHashMap[SystemStreamPartition, String]]): Map[TaskName, Set[SystemStreamPartition]] = {
     taskNameTosystemStreamPartitions.map {
       case (taskName, sspToOffsets) => {
         taskName -> (sspToOffsets.filter {
index 49b08f6..13b72fa 100644 (file)
@@ -44,6 +44,8 @@ object JobConfig {
   val SAMZA_FWK_VERSION = "samza.fwk.version"
   val JOB_COORDINATOR_SYSTEM = "job.coordinator.system"
   val JOB_CONTAINER_COUNT = "job.container.count"
+  val jOB_CONTAINER_THREAD_POOL_SIZE = "job.container.thread.pool.size"
+  val JOB_CONTAINER_SINGLE_THREAD_MODE = "job.container.single.thread.mode"
   val JOB_REPLICATION_FACTOR = "job.coordinator.replication.factor"
   val JOB_SEGMENT_BYTES = "job.coordinator.segment.bytes"
   val SSP_GROUPER_FACTORY = "job.systemstreampartition.grouper.factory"
@@ -167,4 +169,13 @@ class JobConfig(config: Config) extends ScalaMapConfig(config) with Logging {
 
   def getSSPMatcherConfigJobFactoryRegex = getOrElse(JobConfig.SSP_MATCHER_CONFIG_JOB_FACTORY_REGEX, JobConfig.DEFAULT_SSP_MATCHER_CONFIG_JOB_FACTORY_REGEX)
 
+  def getThreadPoolSize = getOption(JobConfig.jOB_CONTAINER_THREAD_POOL_SIZE) match {
+    case Some(size) => size.toInt
+    case _ => 0
+  }
+
+  def getSingleThreadMode = getOption(JobConfig.JOB_CONTAINER_SINGLE_THREAD_MODE) match {
+    case Some(mode) => mode.toBoolean
+    case _ => false
+  }
 }
index 08a4deb..90c1904 100644 (file)
@@ -38,6 +38,8 @@ object TaskConfig {
   val DROP_SERIALIZATION_ERROR = "task.drop.serialization.errors" // define whether drop the messages or not when serialization fails
   val IGNORED_EXCEPTIONS = "task.ignored.exceptions" // exceptions to ignore in process and window
   val GROUPER_FACTORY = "task.name.grouper.factory" // class name for task grouper
+  val MAX_CONCURRENCY = "task.max.concurrency" // max number of concurrent process for a AsyncStreamTask
+  val CALLBACK_TIMEOUT_MS = "task.callback.timeout.ms"  // timeout period for triggering a callback
 
   /**
    * Samza's container polls for more messages under two conditions. The first
@@ -117,4 +119,13 @@ class TaskConfig(config: Config) extends ScalaMapConfig(config) with Logging {
     }
   }
 
+  def getMaxConcurrency: Option[Int] = getOption(TaskConfig.MAX_CONCURRENCY) match {
+    case Some(count) => Some(count.toInt)
+    case _ => None
+  }
+
+  def getCallbackTimeoutMs: Option[Long] = getOption(TaskConfig.CALLBACK_TIMEOUT_MS) match {
+    case Some(ms) => Some(ms.toLong)
+    case _ => None
+  }
 }
index cf05c15..bb2c376 100644 (file)
@@ -21,12 +21,18 @@ package org.apache.samza.container
 
 import java.util.concurrent.Executor
 
-import org.apache.samza.system.{SystemConsumers, SystemStreamPartition}
+import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.task.CoordinatorRequests
 import org.apache.samza.task.ReadableCoordinator
-import org.apache.samza.util.{Logging, TimerUtils}
+import org.apache.samza.task.StreamTask
+import org.apache.samza.util.Logging
+import org.apache.samza.util.TimerUtils
+
+import scala.collection.JavaConversions._
 
 /**
- * Each {@link SamzaContainer} uses a single-threaded execution model: activities for
+ * The run loop uses a single-threaded execution model: activities for
  * all {@link TaskInstance}s within a container are multiplexed onto one execution
  * thread. Those activities include task callbacks (such as StreamTask.process and
  * WindowableTask.window), committing checkpoints, etc.
@@ -34,31 +40,29 @@ import org.apache.samza.util.{Logging, TimerUtils}
  * <p>This class manages the execution of that run loop, determining what needs to
  * be done when.
  */
-class RunLoop(
-  val taskInstances: Map[TaskName, TaskInstance],
+class RunLoop (
+  val taskInstances: Map[TaskName, TaskInstance[StreamTask]],
   val consumerMultiplexer: SystemConsumers,
   val metrics: SamzaContainerMetrics,
   val windowMs: Long = -1,
   val commitMs: Long = 60000,
   val clock: () => Long = { System.nanoTime },
-  val shutdownMs: Long = 5000,
   val executor: Executor = new SameThreadExecutor()) extends Runnable with TimerUtils with Logging {
 
   private val metricsMsOffset = 1000000L
   private var lastWindowNs = clock()
   private var lastCommitNs = clock()
   private var activeNs = 0L
-  private var taskShutdownRequests: Set[TaskName] = Set()
-  private var taskCommitRequests: Set[TaskName] = Set()
   @volatile private var shutdownNow = false
+  private val coordinatorRequests: CoordinatorRequests = new CoordinatorRequests(taskInstances.keySet)
 
   // Messages come from the chooser with no connection to the TaskInstance they're bound for.
   // Keep a mapping of SystemStreamPartition to TaskInstance to efficiently route them.
   val systemStreamPartitionToTaskInstances = getSystemStreamPartitionToTaskInstancesMapping
 
-  def getSystemStreamPartitionToTaskInstancesMapping: Map[SystemStreamPartition, List[TaskInstance]] = {
+  def getSystemStreamPartitionToTaskInstancesMapping: Map[SystemStreamPartition, List[TaskInstance[StreamTask]]] = {
     // We could just pass in the SystemStreamPartitionMap during construction, but it's safer and cleaner to derive the information directly
-    def getSystemStreamPartitionToTaskInstance(taskInstance: TaskInstance) = taskInstance.systemStreamPartitions.map(_ -> taskInstance).toMap
+    def getSystemStreamPartitionToTaskInstance(taskInstance: TaskInstance[StreamTask]) = taskInstance.systemStreamPartitions.map(_ -> taskInstance).toMap
 
     taskInstances.values.map { getSystemStreamPartitionToTaskInstance }.flatten.groupBy(_._1).map {
       case (ssp, ssp2taskInstance) => ssp -> ssp2taskInstance.map(_._2).toList
@@ -70,8 +74,6 @@ class RunLoop(
    * unhandled exception is thrown.
    */
   def run {
-    addShutdownHook(Thread.currentThread())
-
     val runTask = new Runnable() {
       override def run(): Unit = {
         val loopStartTime = clock()
@@ -89,19 +91,8 @@ class RunLoop(
     }
   }
 
-  private def addShutdownHook(runLoopThread: Thread) {
-    Runtime.getRuntime().addShutdownHook(new Thread() {
-      override def run() = {
-        info("Shutting down, will wait up to %s ms" format shutdownMs)
-        shutdownNow = true
-        runLoopThread.join(shutdownMs)
-        if (runLoopThread.isAlive) {
-          warn("Did not shut down within %s ms, exiting" format shutdownMs)
-        } else {
-          info("Shutdown complete")
-        }
-      }
-    })
+  def shutdown = {
+    shutdownNow = true
   }
 
   /**
@@ -115,7 +106,7 @@ class RunLoop(
     // Exclude choose time from activeNs. Although it includes deserialization time,
     // it most closely captures idle time.
     val envelope = updateTimer(metrics.chooseNs) {
-     consumerMultiplexer.choose
+     consumerMultiplexer.choose()
     }
 
     activeNs += updateTimerAndGetDuration(metrics.processNs) ((currentTimeNs: Long) => {
@@ -128,11 +119,11 @@ class RunLoop(
         val taskInstances = systemStreamPartitionToTaskInstances(ssp)
         taskInstances.foreach {
           taskInstance =>
-            {
-              val coordinator = new ReadableCoordinator(taskInstance.taskName)
-              taskInstance.process(envelope, coordinator)
-              checkCoordinator(coordinator)
-            }
+          {
+            val coordinator = new ReadableCoordinator(taskInstance.taskName)
+            taskInstance.process(envelope, coordinator)
+            coordinatorRequests.update(coordinator)
+          }
         }
       } else {
         trace("No incoming message envelope was available.")
@@ -155,7 +146,7 @@ class RunLoop(
           case (taskName, task) =>
             val coordinator = new ReadableCoordinator(taskName)
             task.window(coordinator)
-            checkCoordinator(coordinator)
+            coordinatorRequests.update(coordinator)
         }
       }
     })
@@ -167,47 +158,20 @@ class RunLoop(
   private def commit {
     activeNs += updateTimerAndGetDuration(metrics.commitNs) ((currentTimeNs: Long) => {
       if (commitMs >= 0 && lastCommitNs + commitMs * metricsMsOffset < currentTimeNs) {
-        trace("Committing task instances because the commit interval has elapsed.")
+        info("Committing task instances because the commit interval has elapsed.")
         lastCommitNs = currentTimeNs
         metrics.commits.inc
         taskInstances.values.foreach(_.commit)
-      } else if (!taskCommitRequests.isEmpty) {
+      } else if (!coordinatorRequests.commitRequests.isEmpty){
         trace("Committing due to explicit commit request.")
         metrics.commits.inc
-        taskCommitRequests.foreach(taskName => {
+        coordinatorRequests.commitRequests.foreach(taskName => {
           taskInstances(taskName).commit
         })
       }
 
-      taskCommitRequests = Set()
+      shutdownNow |= coordinatorRequests.shouldShutdownNow
+      coordinatorRequests.commitRequests.clear()
     })
   }
-
-  /**
-   * A new TaskCoordinator object is passed to a task on every call to StreamTask.process
-   * and WindowableTask.window. This method checks whether the task requested that we
-   * do something that affects the run loop (such as commit or shut down), and updates
-   * run loop state accordingly.
-   */
-  private def checkCoordinator(coordinator: ReadableCoordinator) {
-    if (coordinator.requestedCommitTask) {
-      debug("Task %s requested commit for current task only" format coordinator.taskName)
-      taskCommitRequests += coordinator.taskName
-    }
-
-    if (coordinator.requestedCommitAll) {
-      debug("Task %s requested commit for all tasks in the container" format coordinator.taskName)
-      taskCommitRequests ++= taskInstances.keys
-    }
-
-    if (coordinator.requestedShutdownOnConsensus) {
-      taskShutdownRequests += coordinator.taskName
-      info("Shutdown has now been requested by tasks: %s" format taskShutdownRequests)
-    }
-
-    if (coordinator.requestedShutdownNow || taskShutdownRequests.size == taskInstances.size) {
-      info("Shutdown requested.")
-      shutdownNow = true
-    }
-  }
 }
index 18c0922..b8600d5 100644 (file)
 package org.apache.samza.container
 
 import java.io.File
+import java.lang.Thread.UncaughtExceptionHandler
+import java.net.URL
+import java.net.UnknownHostException
 import java.nio.file.Path
 import java.util
-import java.lang.Thread.UncaughtExceptionHandler
-import java.net.{URL, UnknownHostException}
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
+import java.util.concurrent.TimeUnit
+
 import org.apache.samza.SamzaException
-import org.apache.samza.checkpoint.{CheckpointManagerFactory, OffsetManager, OffsetManagerMetrics}
+import org.apache.samza.checkpoint.CheckpointManagerFactory
+import org.apache.samza.checkpoint.OffsetManager
+import org.apache.samza.checkpoint.OffsetManagerMetrics
 import org.apache.samza.config.JobConfig.Config2Job
 import org.apache.samza.config.MetricsConfig.Config2Metrics
 import org.apache.samza.config.SerializerConfig.Config2Serializer
@@ -34,18 +41,45 @@ import org.apache.samza.config.StorageConfig.Config2Storage
 import org.apache.samza.config.StreamConfig.Config2Stream
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
+import org.apache.samza.container.disk.DiskQuotaPolicyFactory
+import org.apache.samza.container.disk.DiskSpaceMonitor
 import org.apache.samza.container.disk.DiskSpaceMonitor.Listener
-import org.apache.samza.container.disk.{NoThrottlingDiskQuotaPolicyFactory, DiskQuotaPolicyFactory, PollingScanDiskSpaceMonitor, DiskSpaceMonitor}
+import org.apache.samza.container.disk.NoThrottlingDiskQuotaPolicyFactory
+import org.apache.samza.container.disk.PollingScanDiskSpaceMonitor
 import org.apache.samza.coordinator.stream.CoordinatorStreamSystemFactory
-import org.apache.samza.job.model.{ContainerModel, JobModel}
-import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, MetricsReporter, MetricsReporterFactory}
-import org.apache.samza.serializers.{SerdeFactory, SerdeManager}
+import org.apache.samza.job.model.ContainerModel
+import org.apache.samza.job.model.JobModel
+import org.apache.samza.metrics.JmxServer
+import org.apache.samza.metrics.JvmMetrics
+import org.apache.samza.metrics.MetricsRegistryMap
+import org.apache.samza.metrics.MetricsReporter
+import org.apache.samza.metrics.MetricsReporterFactory
+import org.apache.samza.serializers.SerdeFactory
+import org.apache.samza.serializers.SerdeManager
 import org.apache.samza.serializers.model.SamzaObjectMapper
-import org.apache.samza.storage.{StorageEngineFactory, TaskStorageManager}
-import org.apache.samza.system.{StreamMetadataCache, SystemConsumers, SystemConsumersMetrics, SystemFactory, SystemProducers, SystemProducersMetrics, SystemStream, SystemStreamPartition}
-import org.apache.samza.system.chooser.{DefaultChooser, MessageChooserFactory, RoundRobinChooserFactory}
-import org.apache.samza.task.{StreamTask, TaskInstanceCollector}
-import org.apache.samza.util.{ThrottlingExecutor, ExponentialSleepStrategy, Logging, Util}
+import org.apache.samza.storage.StorageEngineFactory
+import org.apache.samza.storage.TaskStorageManager
+import org.apache.samza.system.StreamMetadataCache
+import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemConsumersMetrics
+import org.apache.samza.system.SystemFactory
+import org.apache.samza.system.SystemProducers
+import org.apache.samza.system.SystemProducersMetrics
+import org.apache.samza.system.SystemStream
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.system.chooser.DefaultChooser
+import org.apache.samza.system.chooser.MessageChooserFactory
+import org.apache.samza.system.chooser.RoundRobinChooserFactory
+import org.apache.samza.task.AsyncRunLoop
+import org.apache.samza.task.AsyncStreamTask
+import org.apache.samza.task.AsyncStreamTaskAdapter
+import org.apache.samza.task.StreamTask
+import org.apache.samza.task.TaskInstanceCollector
+import org.apache.samza.util.ExponentialSleepStrategy
+import org.apache.samza.util.Logging
+import org.apache.samza.util.ThrottlingExecutor
+import org.apache.samza.util.Util
+
 import scala.collection.JavaConversions._
 
 object SamzaContainer extends Logging {
@@ -164,6 +198,12 @@ object SamzaContainer extends Logging {
 
     info("Got input stream metadata: %s" format inputStreamMetadata)
 
+    val taskClassName = config
+      .getTaskClass
+      .getOrElse(throw new SamzaException("No task class defined in configuration."))
+
+    info("Got stream task class: %s" format taskClassName)
+
     val consumers = inputSystems
       .map(systemName => {
         val systemFactory = systemFactories(systemName)
@@ -181,6 +221,9 @@ object SamzaContainer extends Logging {
 
     info("Got system consumers: %s" format consumers.keys)
 
+    val isAsyncTask = classOf[AsyncStreamTask].isAssignableFrom(Class.forName(taskClassName))
+    info("%s is AsyncStreamTask" format taskClassName)
+
     val producers = systemFactories
       .map {
         case (systemName, systemFactory) =>
@@ -360,26 +403,22 @@ object SamzaContainer extends Logging {
 
     info("Got storage engines: %s" format storageEngineFactories.keys)
 
-    val taskClassName = config
-      .getTaskClass
-      .getOrElse(throw new SamzaException("No task class defined in configuration."))
-
-    info("Got stream task class: %s" format taskClassName)
-
-    val taskWindowMs = config.getWindowMs.getOrElse(-1L)
-
-    info("Got window milliseconds: %s" format taskWindowMs)
+    val singleThreadMode = config.getSingleThreadMode
+    info("Got single thread mode: " + singleThreadMode)
 
-    val taskCommitMs = config.getCommitMs.getOrElse(60000L)
-
-    info("Got commit milliseconds: %s" format taskCommitMs)
+    if(singleThreadMode && isAsyncTask) {
+      throw new SamzaException("AsyncStreamTask %s cannot run on single thread mode." format taskClassName)
+    }
 
-    val taskShutdownMs = config.getShutdownMs.getOrElse(5000L)
+    val threadPoolSize = config.getThreadPoolSize
+    info("Got thread pool size: " + threadPoolSize)
 
-    info("Got shutdown timeout milliseconds: %s" format taskShutdownMs)
+    val taskThreadPool = if (!singleThreadMode && threadPoolSize > 0)
+      Executors.newFixedThreadPool(threadPoolSize)
+    else
+      null
 
     // Wire up all task-instance-level (unshared) objects.
-
     val taskNames = containerModel
       .getTasks
       .values
@@ -395,12 +434,18 @@ object SamzaContainer extends Logging {
     val storeWatchPaths = new util.HashSet[Path]()
     storeWatchPaths.add(defaultStoreBaseDir.toPath)
 
-    val taskInstances: Map[TaskName, TaskInstance] = containerModel.getTasks.values.map(taskModel => {
+    val taskInstances: Map[TaskName, TaskInstance[_]] = containerModel.getTasks.values.map(taskModel => {
       debug("Setting up task instance: %s" format taskModel)
 
       val taskName = taskModel.getTaskName
 
-      val task = Util.getObj[StreamTask](taskClassName)
+      val taskObj = Class.forName(taskClassName).newInstance
+
+      val task = if (!singleThreadMode && !isAsyncTask)
+        // Wrap the StreamTask into a AsyncStreamTask with the build-in thread pool
+        new AsyncStreamTaskAdapter(taskObj.asInstanceOf[StreamTask], taskThreadPool)
+      else
+        taskObj
 
       val taskInstanceMetrics = new TaskInstanceMetrics("TaskName-%s" format taskName)
 
@@ -487,20 +532,22 @@ object SamzaContainer extends Logging {
 
       info("Retrieved SystemStreamPartitions " + systemStreamPartitions + " for " + taskName)
 
-      val taskInstance = new TaskInstance(
-        task = task,
-        taskName = taskName,
-        config = config,
-        metrics = taskInstanceMetrics,
-        systemAdmins = systemAdmins,
-        consumerMultiplexer = consumerMultiplexer,
-        collector = collector,
-        containerContext = containerContext,
-        offsetManager = offsetManager,
-        storageManager = storageManager,
-        reporters = reporters,
-        systemStreamPartitions = systemStreamPartitions,
-        exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config))
+      def createTaskInstance[T] (task: T ): TaskInstance[T] = new TaskInstance[T](
+          task = task,
+          taskName = taskName,
+          config = config,
+          metrics = taskInstanceMetrics,
+          systemAdmins = systemAdmins,
+          consumerMultiplexer = consumerMultiplexer,
+          collector = collector,
+          containerContext = containerContext,
+          offsetManager = offsetManager,
+          storageManager = storageManager,
+          reporters = reporters,
+          systemStreamPartitions = systemStreamPartitions,
+          exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config))
+
+      val taskInstance = createTaskInstance(task)
 
       (taskName, taskInstance)
     }).toMap
@@ -533,14 +580,13 @@ object SamzaContainer extends Logging {
       info(s"Disk quotas disabled because polling interval is not set ($DISK_POLL_INTERVAL_KEY)")
     }
 
-    val runLoop = new RunLoop(
-      taskInstances = taskInstances,
-      consumerMultiplexer = consumerMultiplexer,
-      metrics = samzaContainerMetrics,
-      windowMs = taskWindowMs,
-      commitMs = taskCommitMs,
-      shutdownMs = taskShutdownMs,
-      executor = executor)
+    val runLoop = RunLoopFactory.createRunLoop(
+      taskInstances,
+      consumerMultiplexer,
+      taskThreadPool,
+      executor,
+      samzaContainerMetrics,
+      config)
 
     info("Samza container setup complete.")
 
@@ -557,14 +603,15 @@ object SamzaContainer extends Logging {
       reporters = reporters,
       jvm = jvm,
       jmxServer = jmxServer,
-      diskSpaceMonitor = diskSpaceMonitor)
+      diskSpaceMonitor = diskSpaceMonitor,
+      taskThreadPool = taskThreadPool)
   }
 }
 
 class SamzaContainer(
   containerContext: SamzaContainerContext,
-  taskInstances: Map[TaskName, TaskInstance],
-  runLoop: RunLoop,
+  taskInstances: Map[TaskName, TaskInstance[_]],
+  runLoop: Runnable,
   consumerMultiplexer: SystemConsumers,
   producerMultiplexer: SystemProducers,
   metrics: SamzaContainerMetrics,
@@ -574,7 +621,10 @@ class SamzaContainer(
   localityManager: LocalityManager = null,
   securityManager: SecurityManager = null,
   reporters: Map[String, MetricsReporter] = Map(),
-  jvm: JvmMetrics = null) extends Runnable with Logging {
+  jvm: JvmMetrics = null,
+  taskThreadPool: ExecutorService = null) extends Runnable with Logging {
+
+  val shutdownMs = containerContext.config.getShutdownMs.getOrElse(5000L)
 
   def run {
     try {
@@ -591,6 +641,7 @@ class SamzaContainer(
       startSecurityManger
 
       info("Entering run loop.")
+      addShutdownHook
       runLoop.run
     } catch {
       case e: Exception =>
@@ -710,7 +761,7 @@ class SamzaContainer(
     consumerMultiplexer.start
   }
 
-  def startSecurityManger: Unit = {
+  def startSecurityManger {
     if (securityManager != null) {
       info("Starting security manager.")
 
@@ -718,6 +769,25 @@ class SamzaContainer(
     }
   }
 
+  def addShutdownHook {
+    val runLoopThread = Thread.currentThread()
+    Runtime.getRuntime().addShutdownHook(new Thread() {
+      override def run() = {
+        info("Shutting down, will wait up to %s ms" format shutdownMs)
+        runLoop match {
+          case runLoop: RunLoop => runLoop.shutdown
+          case asyncRunLoop: AsyncRunLoop => asyncRunLoop.shutdown()
+        }
+        runLoopThread.join(shutdownMs)
+        if (runLoopThread.isAlive) {
+          warn("Did not shut down within %s ms, exiting" format shutdownMs)
+        } else {
+          info("Shutdown complete")
+        }
+      }
+    })
+  }
+
   def shutdownConsumers {
     info("Shutting down consumer multiplexer.")
 
@@ -733,6 +803,19 @@ class SamzaContainer(
   def shutdownTask {
     info("Shutting down task instance stream tasks.")
 
+
+    if (taskThreadPool != null) {
+      info("Shutting down task thread pool")
+      try {
+        taskThreadPool.shutdown()
+        if(taskThreadPool.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) {
+          taskThreadPool.shutdownNow()
+        }
+      } catch {
+        case e: Exception => error(e.getMessage, e)
+      }
+    }
+
     taskInstances.values.foreach(_.shutdownTask)
   }
 
index 2044ce0..e3891cf 100644 (file)
@@ -34,9 +34,11 @@ class SamzaContainerMetrics(
   val envelopes = newCounter("process-envelopes")
   val nullEnvelopes = newCounter("process-null-envelopes")
   val chooseNs = newTimer("choose-ns")
+  val chooserUpdateNs = newTimer("chooser-update-ns")
   val windowNs = newTimer("window-ns")
   val processNs = newTimer("process-ns")
   val commitNs = newTimer("commit-ns")
+  val blockNs = newTimer("block-ns")
   val utilization = newGauge("event-loop-utilization", 0.0F)
   val diskUsageBytes = newGauge("disk-usage-bytes", 0L)
   val diskQuotaBytes = newGauge("disk-quota-bytes", Long.MaxValue)
index d32a929..89f6857 100644 (file)
 
 package org.apache.samza.container
 
+
 import org.apache.samza.SamzaException
 import org.apache.samza.checkpoint.OffsetManager
 import org.apache.samza.config.Config
-import org.apache.samza.config.TaskConfig.Config2Task
 import org.apache.samza.metrics.MetricsReporter
 import org.apache.samza.storage.TaskStorageManager
 import org.apache.samza.system.IncomingMessageEnvelope
-import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.system.SystemAdmin
 import org.apache.samza.system.SystemConsumers
-import org.apache.samza.task.TaskContext
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.task.AsyncStreamTask
 import org.apache.samza.task.ClosableTask
 import org.apache.samza.task.InitableTask
-import org.apache.samza.task.WindowableTask
-import org.apache.samza.task.StreamTask
 import org.apache.samza.task.ReadableCoordinator
+import org.apache.samza.task.StreamTask
+import org.apache.samza.task.TaskCallbackFactory
+import org.apache.samza.task.TaskContext
 import org.apache.samza.task.TaskInstanceCollector
+import org.apache.samza.task.WindowableTask
 import org.apache.samza.util.Logging
+
 import scala.collection.JavaConversions._
-import org.apache.samza.system.SystemAdmin
 
-class TaskInstance(
-  task: StreamTask,
+class TaskInstance[T](
+  task: T,
   val taskName: TaskName,
   config: Config,
-  metrics: TaskInstanceMetrics,
+  val metrics: TaskInstanceMetrics,
   systemAdmins: Map[String, SystemAdmin],
   consumerMultiplexer: SystemConsumers,
   collector: TaskInstanceCollector,
   containerContext: SamzaContainerContext,
-  offsetManager: OffsetManager = new OffsetManager,
+  val offsetManager: OffsetManager = new OffsetManager,
   storageManager: TaskStorageManager = null,
   reporters: Map[String, MetricsReporter] = Map(),
   val systemStreamPartitions: Set[SystemStreamPartition] = Set(),
@@ -56,6 +59,8 @@ class TaskInstance(
   val isInitableTask = task.isInstanceOf[InitableTask]
   val isWindowableTask = task.isInstanceOf[WindowableTask]
   val isClosableTask = task.isInstanceOf[ClosableTask]
+  val isAsyncTask = task.isInstanceOf[AsyncStreamTask]
+
   val context = new TaskContext {
     def getMetricsRegistry = metrics.registry
     def getSystemStreamPartitions = systemStreamPartitions
@@ -133,7 +138,7 @@ class TaskInstance(
     })
   }
 
-  def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator) {
+  def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator, callbackFactory: TaskCallbackFactory = null) {
     metrics.processes.inc
 
     if (!ssp2catchedupMapping.getOrElse(envelope.getSystemStreamPartition,
@@ -146,13 +151,20 @@ class TaskInstance(
 
       trace("Processing incoming message envelope for taskName and SSP: %s, %s" format (taskName, envelope.getSystemStreamPartition))
 
-      exceptionHandler.maybeHandle {
-        task.process(envelope, collector, coordinator)
-      }
+      if (isAsyncTask) {
+        exceptionHandler.maybeHandle {
+          val callback = callbackFactory.createCallback()
+          task.asInstanceOf[AsyncStreamTask].processAsync(envelope, collector, coordinator, callback)
+        }
+      } else {
+        exceptionHandler.maybeHandle {
+         task.asInstanceOf[StreamTask].process(envelope, collector, coordinator)
+        }
 
-      trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
+        trace("Updating offset map for taskName, SSP and offset: %s, %s, %s" format (taskName, envelope.getSystemStreamPartition, envelope.getOffset))
 
-      offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset)
+        offsetManager.update(taskName, envelope.getSystemStreamPartition, envelope.getOffset)
+      }
     }
   }
 
index 8b86388..7bedadf 100644 (file)
@@ -35,6 +35,8 @@ class TaskInstanceMetrics(
   val sends = newCounter("send-calls")
   val flushes = newCounter("flush-calls")
   val messagesSent = newCounter("messages-sent")
+  val pendingMessages = newGauge("pending-messages", 0)
+  val messagesInFlight = newGauge("messages-in-flight", 0)
 
   def addOffsetGauge(systemStreamPartition: SystemStreamPartition, getValue: () => String) {
     newGauge("%s-%s-%d-offset" format (systemStreamPartition.getSystem, systemStreamPartition.getStream, systemStreamPartition.getPartition.getPartitionId), getValue)
index d3bd9b7..ba38b5c 100644 (file)
@@ -71,9 +71,12 @@ object JobModelManager extends Logging {
     coordinatorSystemConsumer.start
     debug("Bootstrapping coordinator system stream.")
     coordinatorSystemConsumer.bootstrap
+    val source = "Job-coordinator"
+    coordinatorSystemProducer.register(source)
+    info("Registering coordinator system stream producer.")
     val config = coordinatorSystemConsumer.getConfig
     info("Got config: %s" format config)
-    val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, "Job-coordinator")
+    val changelogManager = new ChangelogPartitionManager(coordinatorSystemProducer, coordinatorSystemConsumer, source)
     val localityManager = new LocalityManager(coordinatorSystemProducer, coordinatorSystemConsumer)
 
     val systemNames = getSystemNames(config)
index 2efe836..a8355b9 100644 (file)
@@ -99,7 +99,7 @@ class SystemConsumers (
    * with no remaining unprocessed messages, the SystemConsumers will poll for
    * it within 50ms of its availability in the stream system.</p>
    */
-  pollIntervalMs: Int,
+  val pollIntervalMs: Int,
 
   /**
    * Clock can be used to inject a custom clock when mocking this class in
@@ -203,28 +203,31 @@ class SystemConsumers (
     }
   }
 
-  def choose: IncomingMessageEnvelope = {
+  def choose (updateChooser: Boolean = true): IncomingMessageEnvelope = {
     val envelopeFromChooser = chooser.choose
 
     updateTimer(metrics.deserializationNs) {
       if (envelopeFromChooser == null) {
-       trace("Chooser returned null.")
+        trace("Chooser returned null.")
 
-       metrics.choseNull.inc
+        metrics.choseNull.inc
 
-       // Sleep for a while so we don't poll in a tight loop.
-       timeout = noNewMessagesTimeout
+        // Sleep for a while so we don't poll in a tight loop.
+        timeout = noNewMessagesTimeout
       } else {
-       val systemStreamPartition = envelopeFromChooser.getSystemStreamPartition
+        val systemStreamPartition = envelopeFromChooser.getSystemStreamPartition
 
-       trace("Chooser returned an incoming message envelope: %s" format envelopeFromChooser)
+        trace("Chooser returned an incoming message envelope: %s" format envelopeFromChooser)
 
-       // Ok to give the chooser a new message from this stream.
-       timeout = 0
-       metrics.choseObject.inc
-       metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition).inc
+        // Ok to give the chooser a new message from this stream.
+        timeout = 0
+        metrics.choseObject.inc
+        metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition).inc
 
-       tryUpdate(systemStreamPartition)
+        if (updateChooser) {
+          trace("Update chooser for " + systemStreamPartition.getPartition)
+          tryUpdate(systemStreamPartition)
+        }
       }
     }
 
@@ -287,7 +290,7 @@ class SystemConsumers (
     }
   }
 
-  private def tryUpdate(ssp: SystemStreamPartition) {
+  def tryUpdate(ssp: SystemStreamPartition) {
     var updated = false
     try {
       updated = update(ssp)
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
new file mode 100644 (file)
index 0000000..ca913de
--- /dev/null
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.Partition;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.config.Config;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.container.SamzaContainerMetrics;
+import org.apache.samza.container.TaskInstance;
+import org.apache.samza.container.TaskInstanceExceptionHandler;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemConsumers;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Before;
+import org.junit.Test;
+import scala.collection.JavaConversions;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class TestAsyncRunLoop {
+
+  Map<TaskName, TaskInstance<AsyncStreamTask>> tasks;
+  ExecutorService executor;
+  SystemConsumers consumerMultiplexer;
+  SamzaContainerMetrics containerMetrics;
+  OffsetManager offsetManager;
+  long windowMs;
+  long commitMs;
+  long callbackTimeoutMs;
+  int maxMessagesInFlight;
+  TaskCoordinator.RequestScope commitRequest;
+  TaskCoordinator.RequestScope shutdownRequest;
+
+  Partition p0 = new Partition(0);
+  Partition p1 = new Partition(1);
+  TaskName taskName0 = new TaskName(p0.toString());
+  TaskName taskName1 = new TaskName(p1.toString());
+  SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
+  SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
+  IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+
+  TestTask task0;
+  TestTask task1;
+  TaskInstance<AsyncStreamTask> t0;
+  TaskInstance<AsyncStreamTask> t1;
+
+  AsyncRunLoop createRunLoop() {
+    return new AsyncRunLoop(tasks,
+        executor,
+        consumerMultiplexer,
+        maxMessagesInFlight,
+        windowMs,
+        commitMs,
+        callbackTimeoutMs,
+        containerMetrics);
+  }
+
+  TaskInstance<AsyncStreamTask> createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp) {
+    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
+    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConversions.asScalaSet(Collections.singleton(ssp)).toSet();
+    return new TaskInstance<AsyncStreamTask>(task, taskName, mock(Config.class), taskInstanceMetrics,
+        null, consumerMultiplexer, mock(TaskInstanceCollector.class), mock(SamzaContainerContext.class),
+        offsetManager, null, null, sspSet, new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()));
+  }
+
+  ExecutorService callbackExecutor;
+  void triggerCallback(final TestTask task, final TaskCallback callback, final boolean success) {
+    callbackExecutor.submit(new Runnable() {
+      @Override
+      public void run() {
+        if (task.code != null) {
+          task.code.run(callback);
+        }
+
+        task.completed.incrementAndGet();
+
+        if (success) {
+          callback.complete();
+        } else {
+          callback.failure(new Exception("process failure"));
+        }
+      }
+    });
+  }
+
+  interface TestCode {
+    void run(TaskCallback callback);
+  }
+
+  class TestTask implements AsyncStreamTask, WindowableTask {
+    boolean shutdown = false;
+    boolean commit = false;
+    boolean success;
+    int processed = 0;
+    volatile int windowCount = 0;
+
+    AtomicInteger completed = new AtomicInteger(0);
+    TestCode code = null;
+
+    TestTask(boolean success, boolean commit, boolean shutdown) {
+      this.success = success;
+      this.shutdown = shutdown;
+      this.commit = commit;
+    }
+
+    @Override
+    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator,
+        TaskCallback callback) {
+
+      if (maxMessagesInFlight == 1) {
+        assertEquals(processed, completed.get());
+      }
+
+      processed++;
+
+      if (commit) {
+        coordinator.commit(commitRequest);
+      }
+
+      if (shutdown) {
+        coordinator.shutdown(shutdownRequest);
+      }
+      triggerCallback(this, callback, success);
+    }
+
+    @Override
+    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      windowCount++;
+
+      if (shutdown && windowCount == 4) {
+        coordinator.shutdown(shutdownRequest);
+      }
+    }
+  }
+
+  @Before
+  public void setup() {
+    executor = null;
+    consumerMultiplexer = mock(SystemConsumers.class);
+    windowMs = -1;
+    commitMs = -1;
+    maxMessagesInFlight = 1;
+    containerMetrics = new SamzaContainerMetrics("container", new MetricsRegistryMap());
+    callbackExecutor = Executors.newFixedThreadPool(2);
+    offsetManager = mock(OffsetManager.class);
+    shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
+
+    when(consumerMultiplexer.pollIntervalMs()).thenReturn(1000000);
+
+    tasks = new HashMap<>();
+    task0 = new TestTask(true, true, false);
+    task1 = new TestTask(true, false, true);
+    t0 = createTaskInstance(task0, taskName0, ssp0);
+    t1 = createTaskInstance(task1, taskName1, ssp1);
+    tasks.put(taskName0, t0);
+    tasks.put(taskName1, t1);
+  }
+
+  @Test
+  public void testProcessMultipleTasks() throws Exception {
+    AsyncRunLoop runLoop = createRunLoop();
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(1, task0.processed);
+    assertEquals(1, task0.completed.get());
+    assertEquals(1, task1.processed);
+    assertEquals(1, task1.completed.get());
+    assertEquals(2L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
+  }
+
+  @Test
+  public void testProcessInOrder() throws Exception {
+    AsyncRunLoop runLoop = createRunLoop();
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(2, task0.processed);
+    assertEquals(2, task0.completed.get());
+    assertEquals(1, task1.processed);
+    assertEquals(1, task1.completed.get());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(3L, containerMetrics.processes().getCount());
+  }
+
+  @Test
+  public void testProcessOutOfOrder() throws Exception {
+    maxMessagesInFlight = 2;
+
+    final CountDownLatch latch = new CountDownLatch(1);
+    task0.code = new TestCode() {
+      @Override
+      public void run(TaskCallback callback) {
+        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).envelope;
+        if (envelope == envelope0) {
+          // process first message will wait till the second one is processed
+          try {
+            latch.await();
+          } catch (InterruptedException e) {
+            e.printStackTrace();
+          }
+        } else {
+          // second envelope complete first
+          assertEquals(0, task0.completed.get());
+          latch.countDown();
+        }
+      }
+    };
+
+    AsyncRunLoop runLoop = createRunLoop();
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(2, task0.processed);
+    assertEquals(2, task0.completed.get());
+    assertEquals(1, task1.processed);
+    assertEquals(1, task1.completed.get());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(3L, containerMetrics.processes().getCount());
+  }
+
+  @Test
+  public void testWindow() throws Exception {
+    windowMs = 1;
+
+    AsyncRunLoop runLoop = createRunLoop();
+    when(consumerMultiplexer.choose(false)).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(4, task1.windowCount);
+  }
+
+  @Test
+  public void testCommitSingleTask() throws Exception {
+    commitRequest = TaskCoordinator.RequestScope.CURRENT_TASK;
+
+    AsyncRunLoop runLoop = createRunLoop();
+    //have a null message in between to make sure task0 finishes processing and invoke the commit
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(null).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    verify(offsetManager).checkpoint(taskName0);
+    verify(offsetManager, never()).checkpoint(taskName1);
+  }
+
+  @Test
+  public void testCommitAllTasks() throws Exception {
+    commitRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
+
+    AsyncRunLoop runLoop = createRunLoop();
+    //have a null message in between to make sure task0 finishes processing and invoke the commit
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(null).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    verify(offsetManager).checkpoint(taskName0);
+    verify(offsetManager).checkpoint(taskName1);
+  }
+
+  @Test
+  public void testShutdownOnConsensus() throws Exception {
+    shutdownRequest = TaskCoordinator.RequestScope.CURRENT_TASK;
+
+    tasks = new HashMap<>();
+    task0 = new TestTask(true, true, true);
+    task1 = new TestTask(true, false, true);
+    t0 = createTaskInstance(task0, taskName0, ssp0);
+    t1 = createTaskInstance(task1, taskName1, ssp1);
+    tasks.put(taskName0, t0);
+    tasks.put(taskName1, t1);
+
+    AsyncRunLoop runLoop = createRunLoop();
+    // consensus is reached after envelope1 is processed.
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    runLoop.run();
+
+    callbackExecutor.awaitTermination(100, TimeUnit.MILLISECONDS);
+
+    assertEquals(1, task0.processed);
+    assertEquals(1, task0.completed.get());
+    assertEquals(1, task1.processed);
+    assertEquals(1, task1.completed.get());
+    assertEquals(2L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java
new file mode 100644 (file)
index 0000000..99e1e18
--- /dev/null
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import org.apache.samza.config.Config;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+
+public class TestAsyncStreamAdapter {
+  TestStreamTask task;
+  AsyncStreamTaskAdapter taskAdaptor;
+  Exception e;
+  IncomingMessageEnvelope envelope;
+
+  class TestCallbackListener implements TaskCallbackListener {
+    boolean callbackComplete = false;
+    boolean callbackFailure = false;
+
+    @Override
+    public void onComplete(TaskCallback callback) {
+      callbackComplete = true;
+    }
+
+    @Override
+    public void onFailure(TaskCallback callback, Throwable t) {
+      callbackFailure = true;
+    }
+  }
+
+  class TestStreamTask implements StreamTask, InitableTask, ClosableTask, WindowableTask {
+    boolean inited = false;
+    boolean closed = false;
+    boolean processed = false;
+    boolean windowed = false;
+
+    @Override
+    public void close() throws Exception {
+      closed = true;
+    }
+
+    @Override
+    public void init(Config config, TaskContext context) throws Exception {
+      inited = true;
+    }
+
+    @Override
+    public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      processed = true;
+      if (e != null) {
+        throw e;
+      }
+    }
+
+    @Override
+    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      windowed = true;
+    }
+  }
+
+  @Before
+  public void setup() {
+    task = new TestStreamTask();
+    e = null;
+    envelope = mock(IncomingMessageEnvelope.class);
+  }
+
+  @Test
+  public void testAdapterWithoutThreadPool() throws Exception {
+    taskAdaptor = new AsyncStreamTaskAdapter(task, null);
+    TestCallbackListener listener = new TestCallbackListener();
+    TaskCallback callback = new TaskCallbackImpl(listener, null, envelope, null, 0L);
+
+    taskAdaptor.init(null, null);
+    assertTrue(task.inited);
+
+    taskAdaptor.processAsync(null, null, null, callback);
+    assertTrue(task.processed);
+    assertTrue(listener.callbackComplete);
+
+    e = new Exception("dummy exception");
+    taskAdaptor.processAsync(null, null, null, callback);
+    assertTrue(listener.callbackFailure);
+
+    taskAdaptor.window(null, null);
+    assertTrue(task.windowed);
+
+    taskAdaptor.close();
+    assertTrue(task.closed);
+  }
+
+  @Test
+  public void testAdapterWithThreadPool() throws Exception {
+    TestCallbackListener listener1 = new TestCallbackListener();
+    TaskCallback callback1 = new TaskCallbackImpl(listener1, null, envelope, null, 0L);
+
+    TestCallbackListener listener2 = new TestCallbackListener();
+    TaskCallback callback2 = new TaskCallbackImpl(listener2, null, envelope, null, 1L);
+
+    ExecutorService executor = Executors.newFixedThreadPool(2);
+    taskAdaptor = new AsyncStreamTaskAdapter(task, executor);
+    taskAdaptor.processAsync(null, null, null, callback1);
+    taskAdaptor.processAsync(null, null, null, callback2);
+
+    executor.awaitTermination(1, TimeUnit.SECONDS);
+    assertTrue(listener1.callbackComplete);
+    assertTrue(listener2.callbackComplete);
+
+    e = new Exception("dummy exception");
+    taskAdaptor.processAsync(null, null, null, callback1);
+    taskAdaptor.processAsync(null, null, null, callback2);
+
+    executor.awaitTermination(1, TimeUnit.SECONDS);
+    assertTrue(listener1.callbackFailure);
+    assertTrue(listener2.callbackFailure);
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java b/samza-core/src/test/java/org/apache/samza/task/TestCoordinatorRequests.java
new file mode 100644 (file)
index 0000000..d9c68d7
--- /dev/null
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.HashSet;
+import java.util.Set;
+import org.apache.samza.container.TaskName;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class TestCoordinatorRequests {
+  CoordinatorRequests coordinatorRequests;
+  TaskName taskA = new TaskName("a");
+  TaskName taskB = new TaskName("b");
+  TaskName taskC = new TaskName("c");
+
+
+  @Before
+  public void setup() {
+    Set<TaskName> taskNames = new HashSet<>();
+    taskNames.add(taskA);
+    taskNames.add(taskB);
+    taskNames.add(taskC);
+
+    coordinatorRequests = new CoordinatorRequests(taskNames);
+  }
+
+  @Test
+  public void testUpdateCommit() {
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskA);
+    coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.commitRequests().contains(taskA));
+
+    coordinator = new ReadableCoordinator(taskC);
+    coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.commitRequests().contains(taskC));
+    assertFalse(coordinatorRequests.commitRequests().contains(taskB));
+    assertTrue(coordinatorRequests.commitRequests().size() == 2);
+
+    coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.commitRequests().contains(taskB));
+    assertTrue(coordinatorRequests.commitRequests().size() == 3);
+  }
+
+  @Test
+  public void testUpdateShutdownOnConsensus() {
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskA);
+    coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertFalse(coordinatorRequests.shouldShutdownNow());
+
+    coordinator = new ReadableCoordinator(taskB);
+    coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertFalse(coordinatorRequests.shouldShutdownNow());
+
+    coordinator = new ReadableCoordinator(taskC);
+    coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.shouldShutdownNow());
+  }
+
+  @Test
+  public void testUpdateShutdownNow() {
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskA);
+    coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+    coordinatorRequests.update(coordinator);
+    assertTrue(coordinatorRequests.shouldShutdownNow());
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackImpl.java
new file mode 100644 (file)
index 0000000..f1dbf35
--- /dev/null
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+
+public class TestTaskCallbackImpl {
+
+  TaskCallbackListener listener = null;
+  AtomicInteger completeCount;
+  AtomicInteger failureCount;
+  TaskCallback callback = null;
+  Throwable throwable = null;
+
+  @Before
+  public void setup() {
+    completeCount = new AtomicInteger(0);
+    failureCount = new AtomicInteger(0);
+    throwable = null;
+
+    listener = new TaskCallbackListener() {
+
+      @Override
+      public void onComplete(TaskCallback callback) {
+        completeCount.incrementAndGet();
+      }
+
+      @Override
+      public void onFailure(TaskCallback callback, Throwable t) {
+        throwable = t;
+        failureCount.incrementAndGet();
+      }
+    };
+
+    callback = new TaskCallbackImpl(listener, null, mock(IncomingMessageEnvelope.class), null, 0);
+  }
+
+  @Test
+  public void testComplete() {
+    callback.complete();
+    assertEquals(1L, completeCount.get());
+    assertEquals(0L, failureCount.get());
+  }
+
+  @Test
+  public void testFailure() {
+    callback.failure(new Exception("dummy exception"));
+    assertEquals(0L, completeCount.get());
+    assertEquals(1L, failureCount.get());
+  }
+
+  @Test
+  public void testCallbackMultipleComplete() {
+    callback.complete();
+    assertEquals(1L, completeCount.get());
+
+    callback.complete();
+    assertEquals(1L, failureCount.get());
+    assertTrue(throwable instanceof IllegalStateException);
+  }
+
+  @Test
+  public void testCallbackFailureAfterComplete() {
+    callback.complete();
+    assertEquals(1L, completeCount.get());
+
+    callback.failure(new Exception("dummy exception"));
+    assertEquals(1L, failureCount.get());
+    assertTrue(throwable instanceof IllegalStateException);
+  }
+
+
+  @Test
+  public void testMultithreadedCallbacks() throws Exception {
+    final CyclicBarrier barrier = new CyclicBarrier(2);
+    ExecutorService executor = Executors.newFixedThreadPool(2);
+
+    for (int i = 0; i < 2; i++) {
+      executor.submit(new Runnable() {
+        @Override
+        public void run() {
+          try {
+            barrier.await();
+            callback.complete();
+          } catch (Exception e) {
+            e.printStackTrace();
+          }
+        }
+      });
+    }
+    executor.awaitTermination(1, TimeUnit.SECONDS);
+    assertEquals(1L, completeCount.get());
+    assertEquals(1L, failureCount.get());
+    assertTrue(throwable instanceof IllegalStateException);
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
new file mode 100644 (file)
index 0000000..d7110f3
--- /dev/null
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+import org.apache.samza.Partition;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+
+public class TestTaskCallbackManager {
+  TaskCallbackManager callbackManager = null;
+  TaskCallbackListener listener = null;
+
+  @Before
+  public void setup() {
+    TaskInstanceMetrics metrics = new TaskInstanceMetrics("Partition 0", new MetricsRegistryMap());
+    listener = new TaskCallbackListener() {
+      @Override
+      public void onComplete(TaskCallback callback) {
+      }
+      @Override
+      public void onFailure(TaskCallback callback, Throwable t) {
+      }
+    };
+    callbackManager = new TaskCallbackManager(listener, metrics, null, -1);
+
+  }
+
+  @Test
+  public void testCreateCallback() {
+    TaskCallbackImpl callback = callbackManager.createCallback(new TaskName("Partition 0"), null, null);
+    assertTrue(callback.matchSeqNum(0));
+
+    callback = callbackManager.createCallback(new TaskName("Partition 0"), null, null);
+    assertTrue(callback.matchSeqNum(1));
+  }
+
+  @Test
+  public void testUpdateCallbackInOrder() {
+    TaskName taskName = new TaskName("Partition 0");
+    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskName);
+
+    IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null);
+    TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0);
+    TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback0, true);
+    assertTrue(callbackToCommit.matchSeqNum(0));
+    assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition());
+    assertEquals("0", callbackToCommit.envelope.getOffset());
+
+    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null);
+    TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator, 1);
+    callbackToCommit = callbackManager.updateCallback(callback1, true);
+    assertTrue(callbackToCommit.matchSeqNum(1));
+    assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition());
+    assertEquals("1", callbackToCommit.envelope.getOffset());
+  }
+
+  @Test
+  public void testUpdateCallbackOutofOrder() {
+    TaskName taskName = new TaskName("Partition 0");
+    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskName);
+
+    // simulate out of order
+    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp, "2", null, null);
+    TaskCallbackImpl callback2 = new TaskCallbackImpl(listener, taskName, envelope2, coordinator, 2);
+    TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback2, true);
+    assertNull(callbackToCommit);
+
+    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null);
+    TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator, 1);
+    callbackToCommit = callbackManager.updateCallback(callback1, true);
+    assertNull(callbackToCommit);
+
+    IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null);
+    TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0);
+    callbackToCommit = callbackManager.updateCallback(callback0, true);
+    assertTrue(callbackToCommit.matchSeqNum(2));
+    assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition());
+    assertEquals("2", callbackToCommit.envelope.getOffset());
+  }
+
+  @Test
+  public void testUpdateCallbackWithCoordinatorRequests() {
+    TaskName taskName = new TaskName("Partition 0");
+    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "topic", new Partition(0));
+
+
+    // simulate out of order
+    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp, "2", null, null);
+    ReadableCoordinator coordinator2 = new ReadableCoordinator(taskName);
+    coordinator2.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+    TaskCallbackImpl callback2 = new TaskCallbackImpl(listener, taskName, envelope2, coordinator2, 2);
+    TaskCallbackImpl callbackToCommit = callbackManager.updateCallback(callback2, true);
+    assertNull(callbackToCommit);
+
+    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp, "1", null, null);
+    ReadableCoordinator coordinator1 = new ReadableCoordinator(taskName);
+    coordinator1.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+    TaskCallbackImpl callback1 = new TaskCallbackImpl(listener, taskName, envelope1, coordinator1, 1);
+    callbackToCommit = callbackManager.updateCallback(callback1, true);
+    assertNull(callbackToCommit);
+
+    IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp, "0", null, null);
+    ReadableCoordinator coordinator = new ReadableCoordinator(taskName);
+    TaskCallbackImpl callback0 = new TaskCallbackImpl(listener, taskName, envelope0, coordinator, 0);
+    callbackToCommit = callbackManager.updateCallback(callback0, true);
+    assertTrue(callbackToCommit.matchSeqNum(1));
+    assertEquals(ssp, callbackToCommit.envelope.getSystemStreamPartition());
+    assertEquals("1", callbackToCommit.envelope.getOffset());
+    assertTrue(callbackToCommit.coordinator.requestedShutdownNow());
+  }
+
+}
index e280daa..aa1a8d6 100644 (file)
 package org.apache.samza.container
 
 
-import org.apache.samza.metrics.{Timer, SlidingTimeWindowReservoir, MetricsRegistryMap}
+import org.apache.samza.Partition
+import org.apache.samza.metrics.MetricsRegistryMap
+import org.apache.samza.metrics.SlidingTimeWindowReservoir
+import org.apache.samza.metrics.Timer
+import org.apache.samza.system.IncomingMessageEnvelope
+import org.apache.samza.system.SystemConsumers
+import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.task.TaskCoordinator.RequestScope
+import org.apache.samza.task.ReadableCoordinator
+import org.apache.samza.task.StreamTask
 import org.apache.samza.util.Clock
-import org.junit.Test
 import org.junit.Assert._
+import org.junit.Test
 import org.mockito.Matchers
 import org.mockito.Mockito._
-import org.mockito.internal.util.reflection.Whitebox
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.scalatest.junit.AssertionsForJUnit
-import org.scalatest.{Matchers => ScalaTestMatchers}
 import org.scalatest.mock.MockitoSugar
-import org.apache.samza.Partition
-import org.apache.samza.system.{ IncomingMessageEnvelope, SystemConsumers, SystemStreamPartition }
-import org.apache.samza.task.ReadableCoordinator
-import org.apache.samza.task.TaskCoordinator.RequestScope
+import org.scalatest.{Matchers => ScalaTestMatchers}
 
 class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMatchers {
   class StopRunLoop extends RuntimeException
@@ -49,12 +53,12 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
   val envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0")
   val envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1")
 
-  def getMockTaskInstances: Map[TaskName, TaskInstance] = {
-    val ti0 = mock[TaskInstance]
+  def getMockTaskInstances: Map[TaskName, TaskInstance[StreamTask]] = {
+    val ti0 = mock[TaskInstance[StreamTask]]
     when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
     when(ti0.taskName).thenReturn(taskName0)
 
-    val ti1 = mock[TaskInstance]
+    val ti1 = mock[TaskInstance[StreamTask]]
     when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
     when(ti1.taskName).thenReturn(taskName1)
 
@@ -67,10 +71,10 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics)
 
-    when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
+    when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
     intercept[StopRunLoop] { runLoop.run }
-    verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject)
-    verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject)
+    verify(taskInstances(taskName0)).process(Matchers.eq(envelope0), anyObject, anyObject)
+    verify(taskInstances(taskName1)).process(Matchers.eq(envelope1), anyObject, anyObject)
     runLoop.metrics.envelopes.getCount should equal(2L)
     runLoop.metrics.nullEnvelopes.getCount should equal(0L)
   }
@@ -80,7 +84,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val map = getMockTaskInstances - taskName1 // This test only needs p0
     val runLoop = new RunLoop(map, consumers, new SamzaContainerMetrics)
-    when(consumers.choose).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop)
+    when(consumers.choose()).thenReturn(null).thenReturn(null).thenThrow(new StopRunLoop)
     intercept[StopRunLoop] { runLoop.run }
     runLoop.metrics.envelopes.getCount should equal(0L)
     runLoop.metrics.nullEnvelopes.getCount should equal(2L)
@@ -90,7 +94,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
   def testWindowAndCommitAreCalledRegularly {
     var now = 1400000000000L
     val consumers = mock[SystemConsumers]
-    when(consumers.choose).thenReturn(envelope0)
+    when(consumers.choose()).thenReturn(envelope0)
 
     val runLoop = new RunLoop(
       taskInstances = getMockTaskInstances,
@@ -118,7 +122,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
-    when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
+    when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1).thenThrow(new StopRunLoop)
     stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.CURRENT_TASK))
 
     intercept[StopRunLoop] { runLoop.run }
@@ -132,7 +136,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
-    when(consumers.choose).thenReturn(envelope0).thenThrow(new StopRunLoop)
+    when(consumers.choose()).thenReturn(envelope0).thenThrow(new StopRunLoop)
     stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.commit(RequestScope.ALL_TASKS_IN_CONTAINER))
 
     intercept[StopRunLoop] { runLoop.run }
@@ -146,13 +150,13 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
-    when(consumers.choose).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1)
+    when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope0).thenReturn(envelope1)
     stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
     stubProcess(taskInstances(taskName1), (envelope, coordinator) => coordinator.shutdown(RequestScope.CURRENT_TASK))
 
     runLoop.run
-    verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject)
-    verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject)
+    verify(taskInstances(taskName0), times(2)).process(Matchers.eq(envelope0), anyObject, anyObject)
+    verify(taskInstances(taskName1), times(1)).process(Matchers.eq(envelope1), anyObject, anyObject)
   }
 
   @Test
@@ -161,19 +165,19 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     val consumers = mock[SystemConsumers]
     val runLoop = new RunLoop(taskInstances, consumers, new SamzaContainerMetrics, windowMs = -1, commitMs = -1)
 
-    when(consumers.choose).thenReturn(envelope0).thenReturn(envelope1)
+    when(consumers.choose()).thenReturn(envelope0).thenReturn(envelope1)
     stubProcess(taskInstances(taskName0), (envelope, coordinator) => coordinator.shutdown(RequestScope.ALL_TASKS_IN_CONTAINER))
 
     runLoop.run
-    verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject)
-    verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject)
+    verify(taskInstances(taskName0), times(1)).process(anyObject, anyObject, anyObject)
+    verify(taskInstances(taskName1), times(0)).process(anyObject, anyObject, anyObject)
   }
 
   def anyObject[T] = Matchers.anyObject.asInstanceOf[T]
 
   // Stub out TaskInstance.process. Mockito really doesn't make this easy. :(
-  def stubProcess(taskInstance: TaskInstance, process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) {
-    when(taskInstance.process(anyObject, anyObject)).thenAnswer(new Answer[Unit]() {
+  def stubProcess(taskInstance: TaskInstance[StreamTask], process: (IncomingMessageEnvelope, ReadableCoordinator) => Unit) {
+    when(taskInstance.process(anyObject, anyObject, anyObject)).thenAnswer(new Answer[Unit]() {
       override def answer(invocation: InvocationOnMock) {
         val envelope = invocation.getArguments()(0).asInstanceOf[IncomingMessageEnvelope]
         val coordinator = invocation.getArguments()(1).asInstanceOf[ReadableCoordinator]
@@ -186,7 +190,7 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
   def testUpdateTimerCorrectly {
     var now = 0L
     val consumers = mock[SystemConsumers]
-    when(consumers.choose).thenReturn(envelope0)
+    when(consumers.choose()).thenReturn(envelope0)
     val clock = new Clock {
       var c = 0L
       def currentTimeMillis: Long = {
@@ -263,9 +267,9 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
 
   @Test
   def testGetSystemStreamPartitionToTaskInstancesMapping {
-    val ti0 = mock[TaskInstance]
-    val ti1 = mock[TaskInstance]
-    val ti2 = mock[TaskInstance]
+    val ti0 = mock[TaskInstance[StreamTask]]
+    val ti1 = mock[TaskInstance[StreamTask]]
+    val ti2 = mock[TaskInstance[StreamTask]]
     when(ti0.systemStreamPartitions).thenReturn(Set(ssp0))
     when(ti1.systemStreamPartitions).thenReturn(Set(ssp1))
     when(ti2.systemStreamPartitions).thenReturn(Set(ssp1))
index 1358fdd..cff6b96 100644 (file)
@@ -180,7 +180,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
       new SerdeManager)
     val collector = new TaskInstanceCollector(producerMultiplexer)
     val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName))
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
@@ -261,7 +261,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
       }
     })
 
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
@@ -314,4 +314,4 @@ class MockJobServlet(exceptionLimit: Int, jobModelRef: AtomicReference[JobModel]
       jobModel
     }
   }
-}
\ No newline at end of file
+}
index 5457f0e..3c83529 100644 (file)
@@ -71,7 +71,7 @@ class TestTaskInstance {
     val taskName = new TaskName("taskName")
     val collector = new TaskInstanceCollector(producerMultiplexer)
     val containerContext = new SamzaContainerContext(0, config, Set[TaskName](taskName))
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
@@ -169,7 +169,7 @@ class TestTaskInstance {
 
     val registry = new MetricsRegistryMap
     val taskMetrics = new TaskInstanceMetrics(registry = registry)
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
@@ -226,7 +226,7 @@ class TestTaskInstance {
 
     val registry = new MetricsRegistryMap
     val taskMetrics = new TaskInstanceMetrics(registry = registry)
-    val taskInstance: TaskInstance = new TaskInstance(
+    val taskInstance: TaskInstance[StreamTask] = new TaskInstance[StreamTask](
       task,
       taskName,
       config,
index 09da62e..db2249b 100644 (file)
@@ -54,14 +54,14 @@ class TestSystemConsumers {
     consumer.setResponseSizes(numEnvelopes)
 
     // Choose to trigger a refresh with data.
-    assertNull(consumers.choose)
+    assertNull(consumers.choose())
     // 2: First on start, second on choose.
     assertEquals(2, consumer.polls)
     assertEquals(2, consumer.lastPoll.size)
     assertTrue(consumer.lastPoll.contains(systemStreamPartition0))
     assertTrue(consumer.lastPoll.contains(systemStreamPartition1))
-    assertEquals(envelope, consumers.choose)
-    assertEquals(envelope, consumers.choose)
+    assertEquals(envelope, consumers.choose())
+    assertEquals(envelope, consumers.choose())
     // We aren't polling because we're getting non-null envelopes.
     assertEquals(2, consumer.polls)
 
@@ -69,7 +69,7 @@ class TestSystemConsumers {
     // messages.
     now = SystemConsumers.DEFAULT_POLL_INTERVAL_MS
 
-    assertEquals(envelope, consumers.choose)
+    assertEquals(envelope, consumers.choose())
 
     // We polled even though there are still 997 messages in the unprocessed
     // message buffer.
@@ -82,11 +82,11 @@ class TestSystemConsumers {
     // Now drain all messages for SSP0. There should be exactly 997 messages,
     // since we have chosen 3 already, and we started with 1000.
     (0 until (numEnvelopes - 3)).foreach { i =>
-      assertEquals(envelope, consumers.choose)
+      assertEquals(envelope, consumers.choose())
     }
 
     // Nothing left. Should trigger a poll here.
-    assertNull(consumers.choose)
+    assertNull(consumers.choose())
     assertEquals(4, consumer.polls)
     assertEquals(2, consumer.lastPoll.size)
 
@@ -117,31 +117,31 @@ class TestSystemConsumers {
     consumer.setResponseSizes(1)
 
     // Choose to trigger a refresh with data.
-    assertNull(consumers.choose)
+    assertNull(consumers.choose())
 
     // Choose should have triggered a second poll, since no messages are available.
     assertEquals(2, consumer.polls)
 
     // Choose a few times. This time there is no data.
-    assertEquals(envelope, consumers.choose)
-    assertNull(consumers.choose)
-    assertNull(consumers.choose)
+    assertEquals(envelope, consumers.choose())
+    assertNull(consumers.choose())
+    assertNull(consumers.choose())
 
     // Return more than one message this time.
     consumer.setResponseSizes(2)
 
     // Choose to trigger a refresh with data.
-    assertNull(consumers.choose)
+    assertNull(consumers.choose())
 
     // Increase clock interval.
     now = SystemConsumers.DEFAULT_POLL_INTERVAL_MS
 
     // We get two messages now.
-    assertEquals(envelope, consumers.choose)
+    assertEquals(envelope, consumers.choose())
     // Should not poll even though clock interval increases past interval threshold.
     assertEquals(2, consumer.polls)
-    assertEquals(envelope, consumers.choose)
-    assertNull(consumers.choose)
+    assertEquals(envelope, consumers.choose())
+    assertNull(consumers.choose())
   }
 
   @Test
@@ -238,7 +238,7 @@ class TestSystemConsumers {
 
     var caughtRightException = false
     try {
-      consumers.choose
+      consumers.choose()
     } catch {
       case e: SystemConsumersException => caughtRightException = true
       case _: Throwable => caughtRightException = false
@@ -256,13 +256,13 @@ class TestSystemConsumers {
 
     var notThrowException = true;
     try {
-      consumers2.choose
+      consumers2.choose()
     } catch {
       case e: Throwable => notThrowException = false
     }
     assertTrue("it should not throw any exception", notThrowException)
 
-    var msgEnvelope = Some(consumers2.choose)
+    var msgEnvelope = Some(consumers2.choose())
     assertTrue("Consumer did not succeed in receiving the second message after Serde exception in choose", msgEnvelope.get != null)
     consumers2.stop
 
@@ -279,7 +279,7 @@ class TestSystemConsumers {
     assertTrue("SystemConsumer start should not throw any Serde exception", notThrowException)
 
     msgEnvelope = null
-    msgEnvelope = Some(consumers2.choose)
+    msgEnvelope = Some(consumers2.choose())
     assertTrue("Consumer did not succeed in receiving the second message after Serde exception in poll", msgEnvelope.get != null)
     consumers2.stop
 
index 1f4b5c4..24bc8b5 100644 (file)
@@ -36,6 +36,7 @@ class HdfsSystemProducer(
   val clock: () => Long = () => System.currentTimeMillis) extends SystemProducer with Logging with TimerUtils {
   val dfs = FileSystem.get(new Configuration(true))
   val writers: MMap[String, HdfsWriter[_]] = MMap.empty[String, HdfsWriter[_]]
+  private val lock = new Object //synchronization lock for thread safe access
 
   def start(): Unit = {
     info("entering HdfsSystemProducer.start() call for system: " + systemName + ", client: " + clientId)
@@ -43,52 +44,65 @@ class HdfsSystemProducer(
 
   def stop(): Unit = {
     info("entering HdfsSystemProducer.stop() for system: " + systemName + ", client: " + clientId)
-    writers.values.map { _.close }
-    dfs.close
+
+    lock.synchronized {
+      writers.values.map(_.close)
+      dfs.close
+    }
   }
 
   def register(source: String): Unit = {
     info("entering HdfsSystemProducer.register(" + source + ") " +
       "call for system: " + systemName + ", client: " + clientId)
-    writers += (source -> HdfsWriter.getInstance(dfs, systemName, config))
+
+    lock.synchronized {
+      writers += (source -> HdfsWriter.getInstance(dfs, systemName, config))
+    }
   }
 
   def flush(source: String): Unit = {
     debug("entering HdfsSystemProducer.flush(" + source + ") " +
       "call for system: " + systemName + ", client: " + clientId)
-    try {
-      metrics.flushes.inc
-      updateTimer(metrics.flushMs) { writers.get(source).head.flush }
-      metrics.flushSuccess.inc
-    } catch {
-      case e: Exception => {
-        metrics.flushFailed.inc
-        warn("Exception thrown while client " + clientId + " flushed HDFS out stream, msg: " + e.getMessage)
-        debug("Detailed message from exception thrown by client " + clientId + " in HDFS flush: ", e)
-        writers.get(source).head.close
-        throw e
+
+    metrics.flushes.inc
+    lock.synchronized {
+      try {
+        updateTimer(metrics.flushMs) {
+          writers.get(source).head.flush
+        }
+      } catch {
+        case e: Exception => {
+          metrics.flushFailed.inc
+          warn("Exception thrown while client " + clientId + " flushed HDFS out stream, msg: " + e.getMessage)
+          debug("Detailed message from exception thrown by client " + clientId + " in HDFS flush: ", e)
+          writers.get(source).head.close
+          throw e
+        }
       }
     }
+    metrics.flushSuccess.inc
   }
 
   def send(source: String, ome: OutgoingMessageEnvelope) = {
     debug("entering HdfsSystemProducer.send(source = " + source + ", envelope) " +
       "call for system: " + systemName + ", client: " + clientId)
+
     metrics.sends.inc
-    try {
-      updateTimer(metrics.sendMs) {
-        writers.get(source).head.write(ome)
-      }
-      metrics.sendSuccess.inc
-    } catch {
-      case e: Exception => {
-        metrics.sendFailed.inc
-        warn("Exception thrown while client " + clientId + " wrote to HDFS, msg: " + e.getMessage)
-        debug("Detailed message from exception thrown by client " + clientId + " in HDFS write: ", e)
-        writers.get(source).head.close
-        throw e
+    lock.synchronized {
+      try {
+        updateTimer(metrics.sendMs) {
+          writers.get(source).head.write(ome)
+        }
+      } catch {
+        case e: Exception => {
+          metrics.sendFailed.inc
+          warn("Exception thrown while client " + clientId + " wrote to HDFS, msg: " + e.getMessage)
+          debug("Detailed message from exception thrown by client " + clientId + " in HDFS write: ", e)
+          writers.get(source).head.close
+          throw e
+        }
       }
     }
+    metrics.sendSuccess.inc
   }
-
-}
+}
\ No newline at end of file
index 5e8cc65..5d2641a 100644 (file)
@@ -140,6 +140,7 @@ class KafkaCheckpointMigration extends MigrationPlan with Logging {
   def migrationCompletionMark(coordinatorSystemProducer: CoordinatorStreamSystemProducer) = {
     info("Marking completion of migration %s" format migrationKey)
     val message = new SetMigrationMetaMessage(source, migrationKey, migrationVal)
+    coordinatorSystemProducer.register(source)
     coordinatorSystemProducer.start()
     coordinatorSystemProducer.send(message)
     coordinatorSystemProducer.stop()
index 3769e10..5a16580 100644 (file)
 
 package org.apache.samza.system.kafka
 
-import org.apache.samza.util.Logging
-import org.apache.kafka.clients.producer.{RecordMetadata, Callback, ProducerRecord, Producer}
-import org.apache.samza.system.SystemProducer
+
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.Future
+
+import org.apache.kafka.clients.producer.Callback
+import org.apache.kafka.clients.producer.Producer
+import org.apache.kafka.clients.producer.ProducerRecord
+import org.apache.kafka.clients.producer.RecordMetadata
+import org.apache.kafka.common.PartitionInfo
+import org.apache.samza.SamzaException
 import org.apache.samza.system.OutgoingMessageEnvelope
+import org.apache.samza.system.SystemProducer
 import org.apache.samza.util.ExponentialSleepStrategy
-import org.apache.samza.util.TimerUtils
 import org.apache.samza.util.KafkaUtil
-import java.util.concurrent.atomic.{AtomicInteger, AtomicReference, AtomicBoolean}
-import java.util.{Map => javaMap}
-import org.apache.samza.SamzaException
-import org.apache.kafka.common.errors.RetriableException
-import org.apache.kafka.common.PartitionInfo
-import java.util
-import java.util.concurrent.Future
-import scala.collection.JavaConversions._
+import org.apache.samza.util.Logging
+import org.apache.samza.util.TimerUtils
 
+import scala.collection.JavaConversions._
 
 class KafkaSystemProducer(systemName: String,
                           retryBackoff: ExponentialSleepStrategy = new ExponentialSleepStrategy,
                           getProducer: () => Producer[Array[Byte], Array[Byte]],
                           metrics: KafkaSystemProducerMetrics,
-                          val clock: () => Long = () => System.nanoTime,
-                          val maxRetries: Int = 30) extends SystemProducer with Logging with TimerUtils
+                          val clock: () => Long = () => System.nanoTime) extends SystemProducer with Logging with TimerUtils
 {
-  var producer: Producer[Array[Byte], Array[Byte]] = null
-  val latestFuture: javaMap[String, Future[RecordMetadata]] = new util.HashMap[String, Future[RecordMetadata]]()
-  val sendFailed: AtomicBoolean = new AtomicBoolean(false)
-  var exceptionThrown: AtomicReference[Exception] = new AtomicReference[Exception]()
-  val StreamNameNullOrEmptyErrorMsg = "Stream Name should be specified in the stream configuration file.";
-
-  // Backward-compatible constructor for Java clients
-  def this(systemName: String,
-           retryBackoff: ExponentialSleepStrategy,
-           getProducer: () => Producer[Array[Byte], Array[Byte]],
-           metrics: KafkaSystemProducerMetrics,
-           clock: () => Long) = this(systemName, retryBackoff, getProducer, metrics, clock, 30)
-
-  def start() {
+
+  class SourceData {
+    /**
+     * lock to make send() and store its future atomic
+     */
+    val sendLock: Object = new Object
+    /**
+     * The most recent send's Future handle
+     */
+    @volatile
+    var latestFuture: Future[RecordMetadata] = null
+    /**
+     * exceptionThrown: to store the exception in case of any "ultimate" send failure (ie. failure
+     * after exhausting max_retries in Kafka producer) in the I/O thread, we do not continue to queue up more send
+     * requests from the samza thread. It helps the samza thread identify if the failure happened in I/O thread or not.
+     */
+    @volatile
+    var exceptionThrown: SamzaException = null
+  }
+
+  @volatile var producer: Producer[Array[Byte], Array[Byte]] = null
+  var producerLock: Object = new Object
+  val StreamNameNullOrEmptyErrorMsg = "Stream Name should be specified in the stream configuration file."
+  val sources: ConcurrentHashMap[String, SourceData] = new ConcurrentHashMap[String, SourceData]
+
+  def start(): Unit = {
+    producerLock.synchronized {
+      if (producer == null) {
+        info("Creating a new producer for system %s." format systemName)
+        producer = getProducer()
+      }
+    }
   }
 
   def stop() {
-    if (producer != null) {
-      latestFuture.keys.foreach(flush(_))
-      producer.close
-      producer = null
+    producerLock.synchronized {
+      try {
+        if (producer != null) {
+          producer.close
+          producer = null
+
+          sources.foreach {p =>
+            if (p._2.exceptionThrown == null) {
+              flush(p._1)
+            }
+          }
+        }
+      } catch {
+        case e: Exception => logger.error(e.getMessage, e)
+      }
     }
   }
 
   def register(source: String) {
-    if(latestFuture.containsKey(source)) {
+    if(sources.putIfAbsent(source, new SourceData) != null) {
       throw new SamzaException("%s is already registered with the %s system producer" format (source, systemName))
     }
-    latestFuture.put(source, null)
   }
 
   def send(source: String, envelope: OutgoingMessageEnvelope) {
-    var numRetries: AtomicInteger = new AtomicInteger(0)
-    trace("Enqueueing message: %s, %s." format (source, envelope))
-    if(producer == null) {
-      info("Creating a new producer for system %s." format systemName)
-      producer = getProducer()
-      debug("Created a new producer for system %s." format systemName)
-    }
-    // Java-based Kafka producer API requires an "Integer" type partitionKey and does not allow custom overriding of Partitioners
-    // Any kind of custom partitioning has to be done on the client-side
+    trace("Enqueuing message: %s, %s." format (source, envelope))
+
     val topicName = envelope.getSystemStream.getStream
     if (topicName == null || topicName == "") {
       throw new IllegalArgumentException(StreamNameNullOrEmptyErrorMsg)
     }
-    val partitions: java.util.List[PartitionInfo]  = producer.partitionsFor(topicName)
-    val partitionKey = if(envelope.getPartitionKey != null) KafkaUtil.getIntegerPartitionKey(envelope, partitions) else null
+
+    val sourceData = sources.get(source)
+    if (sourceData == null) {
+      throw new IllegalArgumentException("Source %s must be registered first before send." format source)
+    }
+
+    val exception = sourceData.exceptionThrown
+    if (exception != null) {
+      metrics.sendFailed.inc
+      throw exception
+    }
+
+    val currentProducer = producer
+    if (currentProducer == null) {
+      throw new SamzaException("Kafka system producer is not available.")
+    }
+
+    // Java-based Kafka producer API requires an "Integer" type partitionKey and does not allow custom overriding of Partitioners
+    // Any kind of custom partitioning has to be done on the client-side
+    val partitions: java.util.List[PartitionInfo] = currentProducer.partitionsFor(topicName)
+    val partitionKey = if (envelope.getPartitionKey != null) KafkaUtil.getIntegerPartitionKey(envelope, partitions)
+    else null
     val record = new ProducerRecord(envelope.getSystemStream.getStream,
                                     partitionKey,
                                     envelope.getKey.asInstanceOf[Array[Byte]],
                                     envelope.getMessage.asInstanceOf[Array[Byte]])
 
-    sendFailed.set(false)
-
-    retryBackoff.run(
-      loop => {
-        if(sendFailed.get()) {
-          throw exceptionThrown.get()
-        }
+    try {
+      sourceData.sendLock.synchronized {
         val futureRef: Future[RecordMetadata] =
-          producer.send(record, new Callback {
+          currentProducer.send(record, new Callback {
             def onCompletion(metadata: RecordMetadata, exception: Exception): Unit = {
               if (exception == null) {
                 //send was successful. Don't retry
                 metrics.sendSuccess.inc
               }
               else {
-                //If there is an exception in the callback, it means that the Kafka producer has exhausted the max-retries
-                //Hence, fail container!
-                exceptionThrown.compareAndSet(null, exception)
-                sendFailed.set(true)
+                //If there is an exception in the callback, fail container!
+                //Close producer.
+                currentProducer.close
+                sourceData.exceptionThrown = new SamzaException("Unable to send message from %s to system %s." format(source, systemName),
+                                                     exception)
+                metrics.sendFailed.inc
+                logger.error("Unable to send message on Topic:%s Partition:%s" format(topicName, partitionKey),
+                             exception)
               }
             }
           })
-        latestFuture.put(source, futureRef)
-        metrics.sends.inc
-        if(!sendFailed.get())
-          loop.done
-      },
-      (exception, loop) => {
-        if((exception != null && !exception.isInstanceOf[RetriableException]) || numRetries.get() >= maxRetries) {
-          // Irrecoverable exceptions.
-          error("Exception detail : ", exception)
-          //Close producer
-          stop()
-          producer = null
-          //Mark loop as done as we are not going to retry
-          loop.done
-          metrics.sendFailed.inc
-          throw new SamzaException(("Failed to send message on Topic:%s Partition:%s NumRetries:%s Exception:\n %s,")
-            .format(topicName, partitionKey, numRetries, exception))
-        } else {
-          numRetries.incrementAndGet()
-          warn(("Retrying send due to RetriableException - %s for Topic:%s Partition:%s. " +
-            "Turn on debugging to get a full stack trace").format(exception, topicName, partitionKey))
-          debug("Exception detail:", exception)
-          metrics.retries.inc
-        }
+        sourceData.latestFuture = futureRef
+      }
+      metrics.sends.inc
+    } catch {
+      case e: Exception => {
+        currentProducer.close()
+        metrics.sendFailed.inc
+        throw new SamzaException(("Failed to send message on Topic:%s Partition:%s Exception:\n %s,")
+          .format(topicName, partitionKey, e))
       }
-    )
+    }
   }
 
   def flush(source: String) {
     updateTimer(metrics.flushNs) {
       metrics.flushes.inc
+
+      val sourceData = sources.get(source)
       //if latestFuture is null, it probably means that there has been no calls to "send" messages
       //Hence, nothing to do in flush
-      if(latestFuture.get(source) != null) {
-        while (!latestFuture.get(source).isDone && !sendFailed.get()) {
-          //do nothing
-        }
-        if (sendFailed.get()) {
-          logger.error("Unable to send message from %s to system %s" format(source, systemName))
-          //Close producer.
-          if (producer != null) {
-            producer.close
+      if(sourceData.latestFuture != null) {
+        while(!sourceData.latestFuture.isDone && sourceData.exceptionThrown == null) {
+          try {
+            sourceData.latestFuture.get()
+          } catch {
+            case t: Throwable => error(t.getMessage, t)
           }
-          producer = null
+        }
+
+        if (sourceData.exceptionThrown != null) {
           metrics.flushFailed.inc
-          throw new SamzaException("Unable to send message from %s to system %s." format(source, systemName), exceptionThrown.get)
+          throw sourceData.exceptionThrown
         } else {
           trace("Flushed %s." format (source))
         }
       }
     }
   }
-}
+}
\ No newline at end of file
index 04c9113..224ca2f 100644 (file)
@@ -50,9 +50,7 @@ public class TestKafkaSystemProducerJava {
       }
     });
 
-    // Default value should have been used.
-    assertEquals(30, ksp.maxRetries());
     long now = System.currentTimeMillis();
     assertTrue((Long)ksp.clock().apply() >= now);
   }
-}
\ No newline at end of file
+}
index 8e32bba..fab998a 100644 (file)
@@ -140,7 +140,7 @@ class TestKafkaSystemProducer {
       systemProducer.flush("test")
     }
     assertTrue(thrown.isInstanceOf[SamzaException])
-    assertEquals(2, mockProducer.getMsgsSent)
+    assertEquals(3, mockProducer.getMsgsSent) // msg1, msg2 and msg4 will be sent
     systemProducer.stop()
   }
 
@@ -150,14 +150,12 @@ class TestKafkaSystemProducer {
     val msg2 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), "b".getBytes)
     val msg3 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), "c".getBytes)
     val msg4 = new OutgoingMessageEnvelope(new SystemStream("test", "test"), "d".getBytes)
-    val numMaxRetries = 3
 
     val mockProducer = new MockKafkaProducer(1, "test", 1)
     val producerMetrics = new KafkaSystemProducerMetrics()
     val producer = new KafkaSystemProducer(systemName =  "test",
       getProducer = () => mockProducer,
-      metrics = producerMetrics,
-      maxRetries = numMaxRetries)
+      metrics = producerMetrics)
 
     producer.register("test")
     producer.start()
@@ -169,14 +167,15 @@ class TestKafkaSystemProducer {
     assertEquals(0, producerMetrics.retries.getCount)
     mockProducer.setErrorNext(true, new TimeoutException())
 
+    producer.send("test", msg4)
     val thrown = intercept[SamzaException] {
-      producer.send("test", msg4)
+      producer.flush("test")
     }
     assertTrue(thrown.isInstanceOf[SamzaException])
     assertTrue(thrown.getCause.isInstanceOf[TimeoutException])
-    assertEquals(true, producer.sendFailed.get())
     assertEquals(3, mockProducer.getMsgsSent)
-    assertEquals(numMaxRetries, producerMetrics.retries.getCount)
+    // retriable exception will be thrown immediately
+    assertEquals(0, producerMetrics.retries.getCount)
     producer.stop()
   }
 
@@ -199,12 +198,12 @@ class TestKafkaSystemProducer {
     producer.send("test", msg3)
     mockProducer.setErrorNext(true, new RecordTooLargeException())
 
+    producer.send("test", msg4)
     val thrown = intercept[SamzaException] {
-       producer.send("test", msg4)
+       producer.flush("test")
     }
     assertTrue(thrown.isInstanceOf[SamzaException])
     assertTrue(thrown.getCause.isInstanceOf[RecordTooLargeException])
-    assertEquals(true, producer.sendFailed.get())
     assertEquals(3, mockProducer.getMsgsSent)
     assertEquals(0, producerMetrics.retries.getCount)
     producer.stop()
index 72f25a3..4c245b6 100644 (file)
@@ -26,14 +26,14 @@ import java.util
 /**
  * In memory implementation of a key value store.
  *
- * This uses a TreeMap to store the keys in order
+ * This uses a ConcurrentSkipListMap to store the keys in order
  *
  * @param metrics A metrics instance to publish key-value store related statistics
  */
 class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStoreMetrics)
   extends KeyValueStore[Array[Byte], Array[Byte]] with Logging {
 
-  val underlying = new util.TreeMap[Array[Byte], Array[Byte]] (UnsignedBytes.lexicographicalComparator())
+  val underlying = new util.concurrent.ConcurrentSkipListMap[Array[Byte], Array[Byte]] (UnsignedBytes.lexicographicalComparator())
 
   override def flush(): Unit = {
     // No-op for In memory store.
@@ -47,7 +47,7 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor
 
     override def close(): Unit = Unit
 
-    override def remove(): Unit = iter.remove()
+    override def remove(): Unit = throw new UnsupportedOperationException("InMemoryKeyValueStore iterator doesn't support remove")
 
     override def next(): Entry[Array[Byte], Array[Byte]] = {
       val n = iter.next()
index 9b9b1f6..73b89f7 100644 (file)
@@ -106,7 +106,6 @@ class RocksDbKeyValueStore(
   // after the directories are created, which happens much later from now.
   private lazy val db = RocksDbKeyValueStore.openDB(dir, options, storeConfig, isLoggedStore, storeName)
   private val lexicographic = new LexicographicComparator()
-  private var deletesSinceLastCompaction = 0
 
   def get(key: Array[Byte]): Array[Byte] = {
     metrics.gets.inc
@@ -141,7 +140,6 @@ class RocksDbKeyValueStore(
     require(key != null, "Null key not allowed.")
     if (value == null) {
       db.remove(writeOptions, key)
-      deletesSinceLastCompaction += 1
     } else {
       metrics.bytesWritten.inc(key.size + value.size)
       db.put(writeOptions, key, value)
@@ -168,7 +166,6 @@ class RocksDbKeyValueStore(
     }
     metrics.puts.inc(wrote)
     metrics.deletes.inc(deletes)
-    deletesSinceLastCompaction += deletes
   }
 
   def delete(key: Array[Byte]) {
index b7f1cdc..05d39ea 100644 (file)
@@ -26,7 +26,7 @@ import java.util
 import org.apache.samza.config.MapConfig
 import org.apache.samza.util.ExponentialSleepStrategy
 import org.junit.{Assert, Test}
-import org.rocksdb.{RocksDB, FlushOptions, Options}
+import org.rocksdb.{RocksIterator, RocksDB, FlushOptions, Options}
 
 class TestRocksDbKeyValueStore
 {
@@ -85,4 +85,61 @@ class TestRocksDbKeyValueStore
     rocksDB.close()
     rocksDBReadOnly.close()
   }
+
+  @Test
+  def testIteratorWithRemoval(): Unit = {
+    val lock = new Object
+
+    val map = new util.HashMap[String, String]()
+    val config = new MapConfig(map)
+    val options = new Options()
+    options.setCreateIfMissing(true)
+    val rocksDB = RocksDbKeyValueStore.openDB(new File(System.getProperty("java.io.tmpdir")),
+                                              options,
+                                              config,
+                                              false,
+                                              "dbStore")
+
+    val key = "key".getBytes("UTF-8")
+    val key1 = "key1".getBytes("UTF-8")
+    val value = "val".getBytes("UTF-8")
+    val value1 = "val1".getBytes("UTF-8")
+
+    var iter: RocksIterator = null
+
+    lock.synchronized {
+      rocksDB.put(key, value)
+      rocksDB.put(key1, value1)
+      // SAMZA-836: Mysteriously,calling new FlushOptions() does not invoke the NativeLibraryLoader in rocksdbjni-3.13.1!
+      // Moving this line after calling new Options() resolve the issue.
+      val flushOptions = new FlushOptions().setWaitForFlush(true)
+      rocksDB.flush(flushOptions)
+
+      iter = rocksDB.newIterator()
+      iter.seekToFirst()
+    }
+
+    while (iter.isValid) {
+      iter.next()
+    }
+    iter.dispose()
+
+    lock.synchronized {
+      rocksDB.remove(key)
+      iter = rocksDB.newIterator()
+      iter.seek(key)
+    }
+
+    while (iter.isValid) {
+      iter.next()
+    }
+    iter.dispose()
+
+    val dbDir = new File(System.getProperty("java.io.tmpdir")).toString
+    val rocksDBReadOnly = RocksDB.openReadOnly(options, dbDir)
+    Assert.assertEquals(new String(rocksDBReadOnly.get(key1), "UTF-8"), "val1")
+    Assert.assertEquals(rocksDBReadOnly.get(key), null)
+    rocksDB.close()
+    rocksDBReadOnly.close()
+  }
 }
index e7e4ede..44f96b4 100644 (file)
@@ -37,7 +37,7 @@ import java.util.Arrays
  * that have not yet been written to disk. All writes go to the dirty list and when the list is long enough we call putAll on all those values at once. Dirty items
  * that time out of the cache before being written will also trigger a putAll of the dirty list.
  *
- * This class is very non-thread safe.
+ * This class is thread safe.
  *
  * @param store The store to cache
  * @param cacheSize The number of entries to hold in the in memory-cache
@@ -59,6 +59,9 @@ class CachedStore[K, V](
   /** the list of items to be written out on flush from newest to oldest */
   private var dirty = new mutable.DoubleLinkedList[K]()
 
+  /** the synchronization lock to protect access to the store from multiple threads **/
+  private val lock = new Object
+
   /** an lru cache of values that holds cacheEntries and calls putAll() on dirty entries if necessary when discarding */
   private val cache = new java.util.LinkedHashMap[K, CacheEntry[K, V]]((cacheSize * 1.2).toInt, 1.0f, true) {
     override def removeEldestEntry(eldest: java.util.Map.Entry[K, CacheEntry[K, V]]): Boolean = {
@@ -76,7 +79,7 @@ class CachedStore[K, V](
   }
 
   /** tracks whether an array has been used as a key. since this is dangerous with LinkedHashMap, we want to warn on it. **/
-  private var containsArrayKeys = false
+  @volatile private var containsArrayKeys = false
 
   // Use counters here, rather than directly accessing variables using .size
   // since metrics can be accessed in other threads, and cache.size is not
@@ -85,7 +88,7 @@ class CachedStore[K, V](
   metrics.setDirtyCount(() => dirtyCount)
   metrics.setCacheSize(() => cacheCount)
 
-  override def get(key: K) = {
+  override def get(key: K) = lock.synchronized({
     metrics.gets.inc
 
     val c = cache.get(key)
@@ -98,7 +101,7 @@ class CachedStore[K, V](
       cacheCount = cache.size
       v
     }
-  }
+  })
 
   private class CachedStoreIterator(val iter: KeyValueIterator[K, V])
     extends KeyValueIterator[K, V] {
@@ -107,10 +110,7 @@ class CachedStore[K, V](
 
     override def close(): Unit = iter.close()
 
-    override def remove(): Unit = {
-      iter.remove()
-      delete(last.getKey)
-    }
+    override def remove(): Unit = throw new UnsupportedOperationException("CachedStore iterator doesn't support remove")
 
     override def next() = {
       last = iter.next()
@@ -120,79 +120,85 @@ class CachedStore[K, V](
     override def hasNext: Boolean = iter.hasNext
   }
 
-  override def range(from: K, to: K): KeyValueIterator[K, V] = {
+  override def range(from: K, to: K): KeyValueIterator[K, V] = lock.synchronized({
     metrics.ranges.inc
     putAllDirtyEntries()
 
     new CachedStoreIterator(store.range(from, to))
-  }
+  })
 
-  override def all(): KeyValueIterator[K, V] = {
+  override def all(): KeyValueIterator[K, V] = lock.synchronized({
     metrics.alls.inc
     putAllDirtyEntries()
 
     new CachedStoreIterator(store.all())
-  }
+  })
 
   override def put(key: K, value: V) {
-    metrics.puts.inc
+    lock.synchronized({
+      metrics.puts.inc
 
-    checkKeyIsArray(key)
+      checkKeyIsArray(key)
 
-    // Add the key to the front of the dirty list (and remove any prior
-    // occurrences to dedupe).
-    val found = cache.get(key)
-    if (found == null || found.dirty == null) {
-      this.dirtyCount += 1
-    } else {
-      // If we are removing the head of the list, move the head to the next
-      // element. See SAMZA-45 for details.
-      if (found.dirty.prev == null) {
-        this.dirty = found.dirty.next
-        this.dirty.prev = null
+      // Add the key to the front of the dirty list (and remove any prior
+      // occurrences to dedupe).
+      val found = cache.get(key)
+      if (found == null || found.dirty == null) {
+        this.dirtyCount += 1
       } else {
-        found.dirty.remove()
+        // If we are removing the head of the list, move the head to the next
+        // element. See SAMZA-45 for details.
+        if (found.dirty.prev == null) {
+          this.dirty = found.dirty.next
+          this.dirty.prev = null
+        } else {
+          found.dirty.remove()
+        }
       }
-    }
-    this.dirty = new mutable.DoubleLinkedList(key, this.dirty)
+      this.dirty = new mutable.DoubleLinkedList(key, this.dirty)
 
-    // Add the key to the cache (but don't allocate a new cache entry if we
-    // already have one).
-    if (found == null) {
-      cache.put(key, new CacheEntry(value, this.dirty))
-      cacheCount = cache.size
-    } else {
-      found.value = value
-      found.dirty = this.dirty
-    }
+      // Add the key to the cache (but don't allocate a new cache entry if we
+      // already have one).
+      if (found == null) {
+        cache.put(key, new CacheEntry(value, this.dirty))
+        cacheCount = cache.size
+      } else {
+        found.value = value
+        found.dirty = this.dirty
+      }
 
-    // putAll() dirty values if the write list is full.
-    val purgeNeeded = if (dirtyCount >= writeBatchSize) {
-      debug("Dirty count %s >= write batch size %s. Calling putAll() on all dirty entries." format (dirtyCount, writeBatchSize))
-      true
-    } else if (hasArrayKeys) {
-      // Flush every time to support the following legacy behavior:
-      // If array keys are used with a cached store, get() will always miss the cache because of array equality semantics
-      // However, it will fall back to the underlying store which does support arrays.
-      true
-    } else {
-      false
-    }
+      // putAll() dirty values if the write list is full.
+      val purgeNeeded = if (dirtyCount >= writeBatchSize) {
+        debug("Dirty count %s >= write batch size %s. Calling putAll() on all dirty entries." format (dirtyCount, writeBatchSize))
+        true
+      } else if (hasArrayKeys) {
+        // Flush every time to support the following legacy behavior:
+        // If array keys are used with a cached store, get() will always miss the cache because of array equality semantics
+        // However, it will fall back to the underlying store which does support arrays.
+        true
+      } else {
+        false
+      }
 
-    if (purgeNeeded) {
-      putAllDirtyEntries()
-    }
+      if (purgeNeeded) {
+        putAllDirtyEntries()
+      }
+    })
   }
 
   override def flush() {
     trace("Purging dirty entries from CachedStore.")
-    metrics.flushes.inc
-    putAllDirtyEntries()
-    trace("Flushing store.")
-    store.flush()
+    lock.synchronized({
+      metrics.flushes.inc
+      putAllDirtyEntries()
+      store.flush()
+    })
     trace("Flushed store.")
   }
 
+  /**
+   * The synchronization lock must be held before calling this method.
+   */
   private def putAllDirtyEntries() {
     trace("Calling putAll() on dirty entries.")
     // write out the contents of the dirty list oldest first
@@ -212,26 +218,34 @@ class CachedStore[K, V](
   }
 
   override def putAll(entries: java.util.List[Entry[K, V]]) {
-    val iter = entries.iterator
-    while (iter.hasNext) {
-      val curr = iter.next
-      put(curr.getKey, curr.getValue)
-    }
+    lock.synchronized({
+      val iter = entries.iterator
+      while (iter.hasNext) {
+        val curr = iter.next
+        put(curr.getKey, curr.getValue)
+      }
+    })
   }
 
   override def delete(key: K) {
-    metrics.deletes.inc
-    put(key, null.asInstanceOf[V])
+    lock.synchronized({
+      metrics.deletes.inc
+      put(key, null.asInstanceOf[V])
+    })
   }
 
   override def close() {
-    trace("Closing.")
-    flush()
-    store.close()
+    lock.synchronized({
+      trace("Closing.")
+      flush()
+      store.close()
+    })
   }
 
   override def deleteAll(keys: java.util.List[K]) = {
-    KeyValueStore.Extension.deleteAll(this, keys)
+    lock.synchronized({
+      KeyValueStore.Extension.deleteAll(this, keys)
+    })
   }
 
   private def checkKeyIsArray(key: K) {
@@ -243,30 +257,32 @@ class CachedStore[K, V](
   }
 
   override def getAll(keys: java.util.List[K]): java.util.Map[K, V] = {
-    metrics.gets.inc(keys.size)
-    val returnValue = new java.util.HashMap[K, V](keys.size)
-    val misses = new java.util.ArrayList[K]
-    val keysIterator = keys.iterator
-    while (keysIterator.hasNext) {
-      val key = keysIterator.next
-      val cached = cache.get(key)
-      if (cached != null) {
-        metrics.cacheHits.inc
-        returnValue.put(key, cached.value)
-      } else {
-        misses.add(key)
+    lock.synchronized({
+      metrics.gets.inc(keys.size)
+      val returnValue = new java.util.HashMap[K, V](keys.size)
+      val misses = new java.util.ArrayList[K]
+      val keysIterator = keys.iterator
+      while (keysIterator.hasNext) {
+        val key = keysIterator.next
+        val cached = cache.get(key)
+        if (cached != null) {
+          metrics.cacheHits.inc
+          returnValue.put(key, cached.value)
+        } else {
+          misses.add(key)
+        }
       }
-    }
-    if (!misses.isEmpty) {
-      val entryIterator = store.getAll(misses).entrySet.iterator
-      while (entryIterator.hasNext) {
-        val entry = entryIterator.next
-        returnValue.put(entry.getKey, entry.getValue)
-        cache.put(entry.getKey, new CacheEntry(entry.getValue, null))
+      if (!misses.isEmpty) {
+        val entryIterator = store.getAll(misses).entrySet.iterator
+        while (entryIterator.hasNext) {
+          val entry = entryIterator.next
+          returnValue.put(entry.getKey, entry.getValue)
+          cache.put(entry.getKey, new CacheEntry(entry.getValue, null))
+        }
+        cacheCount = cache.size // update outside the loop since it's used for metrics and not for time-sensitive logic
       }
-      cacheCount = cache.size // update outside the loop since it's used for metrics and not for time-sensitive logic
-    }
-    returnValue
+      returnValue
+    })
   }
 
   def hasArrayKeys = containsArrayKeys
index 96eb5fa..e16bdc0 100644 (file)
@@ -127,21 +127,6 @@ class TestCachedStore {
     }
     assertFalse(iter.hasNext)
 
-    // test iterator remove
-    iter = store.all()
-    iter.next()
-    iter.remove()
-
-    assertNull(kv.get(keys.get(0)))
-    assertNull(store.get(keys.get(0)))
-
-    iter = store.range(keys.get(1), keys.get(2))
-    iter.next()
-    iter.remove()
-
-    assertFalse(iter.hasNext)
-    assertNull(kv.get(keys.get(1)))
-    assertNull(store.get(keys.get(1)))
   }
 
   @Test
index fd4e762..d7d23ec 100644 (file)
@@ -20,7 +20,9 @@
 package org.apache.samza.storage.kv
 
 import java.io.File
-import java.util.{Arrays, Random}
+import java.util.Arrays
+import java.util.Random
+import java.util.concurrent.CountDownLatch
 
 import org.apache.samza.config.MapConfig
 import org.apache.samza.serializers.Serde
@@ -29,7 +31,9 @@ import org.junit.Assert._
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
 import org.junit.runners.Parameterized.Parameters
-import org.junit.{After, Before, Test}
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
 import org.scalatest.Assertions.intercept
 
 import scala.collection.JavaConversions._
@@ -136,8 +140,9 @@ class TestKeyValueStores(typeOfStore: String, storeConfig: String) {
 
   @Test
   def putAndGet() {
-    store.put(b("k"), b("v"))
-    assertArrayEquals(b("v"), store.get(b("k")))
+    val k = b("k")
+    store.put(k, b("v"))
+    assertArrayEquals(b("v"), store.get(k))
   }
 
   @Test
@@ -379,6 +384,181 @@ class TestKeyValueStores(typeOfStore: String, storeConfig: String) {
     })
   }
 
+  @Test
+  def testParallelReadWriteSameKey(): Unit = {
+    // Make test deterministic by seeding the random number generator.
+    val key = b("key")
+    val val1 = "val1"
+    val val2 = "val2"
+
+    val runner1 = new Thread(new Runnable {
+      override def run(): Unit = {
+        store.put(key, b(val1))
+      }
+    })
+
+    val runner2 = new Thread(new Runnable {
+      override def run(): Unit = {
+        while(!val1.equals({store.get(key) match {
+          case null => ""
+          case _ => { new String(store.get(key), "UTF-8") }
+        }})) {}
+        store.put(key, b(val2))
+      }
+    })
+
+    runner2.start()
+    runner1.start()
+
+    runner2.join(1000)
+    runner1.join(1000)
+
+    assertEquals("val2", new String(store.get(key), "UTF-8"))
+
+    store.delete(key)
+    store.flush()
+  }
+
+  @Test
+  def testParallelReadWriteDiffKeys(): Unit = {
+    // Make test deterministic by seeding the random number generator.
+    val key1 = b("key1")
+    val key2 = b("key2")
+    val val1 = "val1"
+    val val2 = "val2"
+
+    val runner1 = new Thread(new Runnable {
+      override def run(): Unit = {
+        store.put(key1, b(val1))
+      }
+    })
+
+    val runner2 = new Thread(new Runnable {
+      override def run(): Unit = {
+        while(!val1.equals({store.get(key1) match {
+          case null => ""
+          case _ => { new String(store.get(key1), "UTF-8") }
+        }})) {}
+        store.delete(key1)
+      }
+    })
+
+    val runner3 = new Thread(new Runnable {
+      override def run(): Unit = {
+        store.put(key2, b(val2))
+      }
+    })
+
+    val runner4 = new Thread(new Runnable {
+      override def run(): Unit = {
+        while(!val2.equals({store.get(key2) match {
+          case null => ""
+          case _ => { new String(store.get(key2), "UTF-8") }
+        }})) {}
+        store.delete(key2)
+      }
+    })
+
+    runner2.start()
+    runner1.start()
+    runner3.start()
+    runner4.start()
+
+    runner2.join(1000)
+    runner1.join(1000)
+    runner3.join(1000)
+    runner4.join(1000)
+
+    assertNull(store.get(key1))
+    assertNull(store.get(key2))
+
+    store.flush()
+  }
+
+  @Test
+  def testParallelIteratorAndWrite(): Unit = {
+    // Make test deterministic by seeding the random number generator.
+    val key1 = b("key1")
+    val key2 = b("key2")
+    val val1 = "val1"
+    val val2 = "val2"
+    @volatile var throwable: Throwable = null
+
+    store.put(key1, b(val1))
+    store.put(key2, b(val2))
+
+    val runner1StartLatch = new CountDownLatch(1)
+    val runner2StartLatch = new CountDownLatch(1)
+
+    val runner1 = new Thread(new Runnable {
+      override def run(): Unit = {
+        runner1StartLatch.await()
+        store.put(key1, b("val1-2"))
+        store.delete(key2)
+        store.flush()
+        runner2StartLatch.countDown()
+      }
+    })
+
+    val runner2 = new Thread(new Runnable {
+      override def run(): Unit = {
+        runner2StartLatch.await()
+        val iter = store.all() //snapshot after change
+        try {
+          while (iter.hasNext) {
+            val e = iter.next()
+            if ("key1".equals(new String(e.getKey, "UTF-8"))) {
+              assertEquals("val1-2", new String(e.getValue, "UTF-8"))
+            }
+            System.out.println(String.format("iterator1: key: %s, value: %s", new String(e.getKey, "UTF-8"), new String(e.getValue, "UTF-8")))
+          }
+          iter.close()
+        } catch {
+          case t: Throwable => throwable = t
+        }
+      }
+    })
+
+    val runner3 = new Thread(new Runnable {
+      override def run(): Unit = {
+        val iter = store.all()  //snapshot
+        runner1StartLatch.countDown()
+        try {
+          while (iter.hasNext) {
+            val e = iter.next()
+            val key = new String(e.getKey, "UTF-8")
+            val value = new String(e.getValue, "UTF-8")
+            if (key.equals("key1")) {
+              assertEquals(val1, value)
+            }
+            else if (key.equals("key2") && !val2.equals(value)) {
+              assertEquals(val2, value)
+            }
+            else if (!key.equals("key1") && !key.equals("key2")) {
+              throw new Exception("unknow key " + new String(e.getKey, "UTF-8") + ", value: " + new String(e.getValue, "UTF-8"))
+            }
+            System.out.println(String.format("iterator2: key: %s, value: %s", new String(e.getKey, "UTF-8"), new String(e.getValue, "UTF-8")))
+          }
+          iter.close()
+        } catch {
+          case t: Throwable => throwable = t
+        }
+      }
+    })
+
+    runner2.start()
+    runner3.start()
+    runner1.start()
+
+    runner2.join()
+    runner1.join()
+    runner3.join()
+
+    if(throwable != null) throw throwable
+
+    store.flush()
+  }
+
   def checkRange(vals: IndexedSeq[String], iter: KeyValueIterator[Array[Byte], Array[Byte]]) {
     for (v <- vals) {
       assertTrue(iter.hasNext)
@@ -417,5 +597,6 @@ object TestKeyValueStores {
       Array("rocksdb","cache"),
       Array("rocksdb","serde"),
       Array("rocksdb","cache-and-serde"),
-      Array("rocksdb","none"))
+      Array("rocksdb","none")
+  )
 }