SAMZA-1714: Creating shared context factory for shared context objects
authorCameron Lee <calee@linkedin.com>
Wed, 10 Oct 2018 22:34:19 +0000 (15:34 -0700)
committerPrateek Maheshwari <pmaheshwari@apache.org>
Wed, 10 Oct 2018 22:34:19 +0000 (15:34 -0700)
<s>This includes changes in https://github.com/apache/samza/pull/626.</s>
Update: PR #626 has been merged, so the diff here should no longer show those changes.

Author: Cameron Lee <calee@linkedin.com>

Reviewers: Prateek Maheshwari <pmaheshwari@apache.org>

Closes #672 from cameronlee314/shared_context_impl

138 files changed:
docs/learn/documentation/versioned/jobs/configuration-table.html
docs/learn/tutorials/versioned/hello-samza-high-level-code.md
samza-api/src/main/java/org/apache/samza/application/ApplicationDescriptor.java
samza-api/src/main/java/org/apache/samza/container/SamzaContainerContext.java [deleted file]
samza-api/src/main/java/org/apache/samza/context/ApplicationContainerContext.java
samza-api/src/main/java/org/apache/samza/context/ApplicationTaskContext.java
samza-api/src/main/java/org/apache/samza/context/JobContext.java
samza-api/src/main/java/org/apache/samza/operators/ContextManager.java [deleted file]
samza-api/src/main/java/org/apache/samza/operators/functions/InitableFunction.java
samza-api/src/main/java/org/apache/samza/scheduler/CallbackScheduler.java
samza-api/src/main/java/org/apache/samza/scheduler/ScheduledCallback.java
samza-api/src/main/java/org/apache/samza/storage/StorageEngineFactory.java
samza-api/src/main/java/org/apache/samza/table/ReadableTable.java
samza-api/src/main/java/org/apache/samza/table/TableProvider.java
samza-api/src/main/java/org/apache/samza/task/InitableTask.java
samza-api/src/main/java/org/apache/samza/task/TaskContext.java [deleted file]
samza-api/src/main/java/org/apache/samza/util/RateLimiter.java
samza-core/src/main/java/org/apache/samza/application/ApplicationDescriptorImpl.java
samza-core/src/main/java/org/apache/samza/container/TaskContextImpl.java [deleted file]
samza-core/src/main/java/org/apache/samza/context/ContextImpl.java
samza-core/src/main/java/org/apache/samza/context/JobContextImpl.java
samza-core/src/main/java/org/apache/samza/context/TaskContextImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/InputOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
samza-core/src/main/java/org/apache/samza/operators/impl/OutputOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/SendToTableOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/SinkOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/StreamTableJoinOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/impl/WindowOperatorImpl.java
samza-core/src/main/java/org/apache/samza/operators/spec/FilterOperatorSpec.java
samza-core/src/main/java/org/apache/samza/operators/spec/MapOperatorSpec.java
samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
samza-core/src/main/java/org/apache/samza/runtime/LocalApplicationRunner.java
samza-core/src/main/java/org/apache/samza/runtime/LocalContainerRunner.java
samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
samza-core/src/main/java/org/apache/samza/table/TableManager.java
samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
samza-core/src/main/java/org/apache/samza/table/caching/CachingTableProvider.java
samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableProvider.java
samza-core/src/main/java/org/apache/samza/table/remote/RemoteReadWriteTable.java
samza-core/src/main/java/org/apache/samza/table/remote/RemoteReadableTable.java
samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
samza-core/src/main/java/org/apache/samza/table/utils/BaseTableProvider.java
samza-core/src/main/java/org/apache/samza/table/utils/DefaultTableReadMetrics.java
samza-core/src/main/java/org/apache/samza/table/utils/DefaultTableWriteMetrics.java
samza-core/src/main/java/org/apache/samza/table/utils/TableMetricsUtil.java
samza-core/src/main/java/org/apache/samza/task/AsyncRunLoop.java
samza-core/src/main/java/org/apache/samza/task/AsyncStreamTaskAdapter.java
samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java
samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java
samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
samza-core/src/test/java/org/apache/samza/application/TestStreamApplicationDescriptorImpl.java
samza-core/src/test/java/org/apache/samza/application/TestTaskApplicationDescriptorImpl.java
samza-core/src/test/java/org/apache/samza/context/MockContext.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/context/TestContextImpl.java
samza-core/src/test/java/org/apache/samza/context/TestTaskContextImpl.java
samza-core/src/test/java/org/apache/samza/execution/TestJobNodeConfigurationGenerator.java
samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestSinkOperatorImpl.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamTableJoinOperatorImpl.java
samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java
samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpec.java
samza-core/src/test/java/org/apache/samza/operators/spec/TestPartitionByOperatorSpec.java
samza-core/src/test/java/org/apache/samza/operators/spec/TestWindowOperatorSpec.java
samza-core/src/test/java/org/apache/samza/processor/TestStreamProcessor.java
samza-core/src/test/java/org/apache/samza/storage/MockStorageEngineFactory.java
samza-core/src/test/java/org/apache/samza/table/TestTableManager.java
samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java
samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTableDescriptor.java
samza-core/src/test/java/org/apache/samza/table/retry/TestRetriableTableFunctions.java
samza-core/src/test/java/org/apache/samza/task/IdentityStreamTask.java
samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
samza-core/src/test/java/org/apache/samza/task/TestAsyncStreamAdapter.java
samza-core/src/test/java/org/apache/samza/task/TestEpochTimeScheduler.java
samza-core/src/test/java/org/apache/samza/task/TestStreamOperatorTask.java
samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
samza-core/src/test/scala/org/apache/samza/processor/StreamProcessorTestUtils.scala
samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStorageEngineFactory.scala
samza-kv-rocksdb/src/main/java/org/apache/samza/storage/kv/RocksDbKeyValueReader.java
samza-kv-rocksdb/src/main/java/org/apache/samza/storage/kv/RocksDbOptionsHelper.java
samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStorageEngineFactory.scala
samza-kv-rocksdb/src/test/java/org/apache/samza/storage/kv/TestRocksDbTableDescriptor.java
samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java
samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java
samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java
samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
samza-kv/src/test/java/org/apache/samza/storage/kv/TestBaseLocalStoreBackedTableProvider.java
samza-sql/src/main/java/org/apache/samza/sql/runner/SamzaSqlApplicationContext.java [new file with mode: 0644]
samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/ModifyTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/ProjectTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/QueryTranslator.java
samza-sql/src/main/java/org/apache/samza/sql/translator/ScanTranslator.java
samza-sql/src/test/java/org/apache/samza/sql/e2e/TestSamzaSqlTable.java
samza-sql/src/test/java/org/apache/samza/sql/runner/TestSamzaSqlApplicationRunner.java
samza-sql/src/test/java/org/apache/samza/sql/system/TestAvroSystemFactory.java
samza-sql/src/test/java/org/apache/samza/sql/testutil/TestIOResolverFactory.java
samza-sql/src/test/java/org/apache/samza/sql/testutil/TestSamzaSqlFileParser.java
samza-sql/src/test/java/org/apache/samza/sql/translator/TestFilterTranslator.java
samza-sql/src/test/java/org/apache/samza/sql/translator/TestProjectTranslator.java
samza-sql/src/test/java/org/apache/samza/sql/translator/TestQueryTranslator.java
samza-test/src/main/java/org/apache/samza/example/KeyValueStoreExample.java
samza-test/src/main/java/org/apache/samza/test/framework/MessageStreamAssert.java
samza-test/src/main/java/org/apache/samza/test/integration/NegateNumberTask.java
samza-test/src/main/java/org/apache/samza/test/integration/SimpleStatefulTask.java
samza-test/src/main/java/org/apache/samza/test/integration/StatePerfTestTask.java
samza-test/src/main/java/org/apache/samza/test/integration/join/Checker.java
samza-test/src/main/java/org/apache/samza/test/integration/join/Emitter.java
samza-test/src/main/java/org/apache/samza/test/integration/join/Joiner.java
samza-test/src/main/java/org/apache/samza/test/integration/join/Watcher.java
samza-test/src/main/scala/org/apache/samza/test/performance/TestKeyValuePerformance.scala
samza-test/src/main/scala/org/apache/samza/test/performance/TestPerformanceTask.scala
samza-test/src/test/java/org/apache/samza/processor/TestZkStreamProcessorBase.java
samza-test/src/test/java/org/apache/samza/test/framework/TestSchedulingApp.java
samza-test/src/test/java/org/apache/samza/test/processor/IdentityStreamTask.java
samza-test/src/test/java/org/apache/samza/test/processor/TestStreamProcessor.java
samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java
samza-test/src/test/java/org/apache/samza/test/table/TestLocalTableWithSideInputs.java
samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java
samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java
samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala
samza-test/src/test/scala/org/apache/samza/test/integration/TestShutdownStatefulTask.scala
samza-test/src/test/scala/org/apache/samza/test/integration/TestStatefulTask.scala

index 35ddcab..44d43f3 100644 (file)
                         store any <span class="store">store-name</span> except <em>default</em> (the <span class="store">store-name</span>
                         <em>default</em> is reserved for defining default store parameters), and use that name to get a
                         reference to the store in your stream task (call
-                        <a href="../api/javadocs/org/apache/samza/task/TaskContext.html#getStore(java.lang.String)">TaskContext.getStore()</a>
+                        <a href="../api/javadocs/org/apache/samza/context/TaskContext.html#getStore(java.lang.String)">TaskContext.getStore()</a>
                         in your task's
-                        <a href="../api/javadocs/org/apache/samza/task/InitableTask.html#init(org.apache.samza.config.Config, org.apache.samza.task.TaskContext)">init()</a>
+                        <a href="../api/javadocs/org/apache/samza/task/InitableTask.html#init(org.apache.samza.context.Context)">init()</a>
                         method). The value of this property is the fully-qualified name of a Java class that implements
                         <a href="../api/javadocs/org/apache/samza/storage/StorageEngineFactory.html">StorageEngineFactory</a>.
                         Samza currently ships with one storage engine implementation:
index 2f6a4a6..1c06116 100644 (file)
@@ -357,7 +357,7 @@ To use the store in the application, we need to get it from the [TaskContext](/l
 private KeyValueStore<String, Integer> store;
 {% endhighlight %}
 
-Then override the [init](/learn/documentation/{{site.version}}/api/javadocs/org/apache/samza/operators/functions/InitableFunction.html#init-org.apache.samza.config.Config-org.apache.samza.task.TaskContext-) method in `WikipediaStatsAggregator` to initialize the store.
+Then override the [init](/learn/documentation/{{site.version}}/api/javadocs/org/apache/samza/operators/functions/InitableFunction.html#init-org.apache.samza.context.Context-) method in `WikipediaStatsAggregator` to initialize the store.
 {% highlight java %}
 @Override
 public void init(Config config, TaskContext context) {
index 178fdee..e806aad 100644 (file)
@@ -21,8 +21,9 @@ package org.apache.samza.application;
 import java.util.Map;
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.ApplicationContainerContextFactory;
+import org.apache.samza.context.ApplicationTaskContextFactory;
 import org.apache.samza.metrics.MetricsReporterFactory;
-import org.apache.samza.operators.ContextManager;
 import org.apache.samza.runtime.ProcessorLifecycleListenerFactory;
 
 
@@ -44,17 +45,30 @@ public interface ApplicationDescriptor<S extends ApplicationDescriptor> {
   Config getConfig();
 
   /**
-   * Sets the {@link ContextManager} for this application.
+   * Sets the {@link ApplicationContainerContextFactory} for this application. Each task will be given access to a
+   * different instance of the {@link org.apache.samza.context.ApplicationContainerContext} that this creates. The
+   * context can be accessed through the {@link org.apache.samza.context.Context}.
    * <p>
-   * Setting the {@link ContextManager} is optional. The provided {@link ContextManager} can be used to build the shared
-   * context between the operator functions within a task instance
+   * Setting this is optional.
    *
-   * TODO: this should be replaced by the shared context factory when SAMZA-1714 is fixed.
+   * @param factory the {@link ApplicationContainerContextFactory} for this application
+   * @return type {@code S} of {@link ApplicationDescriptor} with {@code factory} set as its
+   * {@link ApplicationContainerContextFactory}
+   */
+  S withApplicationContainerContextFactory(ApplicationContainerContextFactory<?> factory);
 
-   * @param contextManager the {@link ContextManager} to use for the application
-   * @return type {@code S} of {@link ApplicationDescriptor} with {@code contextManager} set as its {@link ContextManager}
+  /**
+   * Sets the {@link ApplicationTaskContextFactory} for this application. Each task will be given access to a different
+   * instance of the {@link org.apache.samza.context.ApplicationTaskContext} that this creates. The context can be
+   * accessed through the {@link org.apache.samza.context.Context}.
+   * <p>
+   * Setting this is optional.
+   *
+   * @param factory the {@link ApplicationTaskContextFactory} for this application
+   * @return type {@code S} of {@link ApplicationDescriptor} with {@code factory} set as its
+   * {@link ApplicationTaskContextFactory}
    */
-  S withContextManager(ContextManager contextManager);
+  S withApplicationTaskContextFactory(ApplicationTaskContextFactory<?> factory);
 
   /**
    * Sets the {@link ProcessorLifecycleListenerFactory} for this application.
diff --git a/samza-api/src/main/java/org/apache/samza/container/SamzaContainerContext.java b/samza-api/src/main/java/org/apache/samza/container/SamzaContainerContext.java
deleted file mode 100644 (file)
index 6e13f7a..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.container;
-
-import org.apache.samza.config.Config;
-import org.apache.samza.metrics.MetricsRegistry;
-
-import java.util.Collection;
-import java.util.Collections;
-
-/**
- * A SamzaContainerContext maintains per-container information for the tasks it executes.
- */
-public class SamzaContainerContext {
-  public final String id;
-  public final Config config;
-  public final Collection<TaskName> taskNames;
-  public final MetricsRegistry metricsRegistry;
-
-  /**
-   * An immutable context object that can passed to tasks to give them information
-   * about the container in which they are executing.
-   * @param id The id of the container.
-   * @param config The job configuration.
-   * @param taskNames The set of taskName keys for which this container is responsible.
-   * @param metricsRegistry the {@link MetricsRegistry} for the container metrics
-   */
-  public SamzaContainerContext(
-      String id,
-      Config config,
-      Collection<TaskName> taskNames,
-      MetricsRegistry metricsRegistry) {
-    this.id = id;
-    this.config = config;
-    this.taskNames = Collections.unmodifiableCollection(taskNames);
-    this.metricsRegistry = metricsRegistry;
-  }
-}
index 08e0b60..aab8c7f 100644 (file)
@@ -20,11 +20,16 @@ package org.apache.samza.context;
 
 /**
  * An application should implement this to contain any runtime objects required by processing logic which can be shared
- * across all tasks in a container. A single instance of this will be created in each container.
+ * across all tasks in a container. A single instance of this will be created in each container. Note that if the
+ * container moves or the container model changes (e.g. container failure/rebalancing), then this will be recreated.
  * <p>
  * This needs to be created by an implementation of {@link ApplicationContainerContextFactory}. The factory should
  * create the runtime objects contained within this context.
  * <p>
+ * This is related to {@link ContainerContext} in that they are both associated with the container lifecycle. In order
+ * to access this in application code, use {@link Context#getApplicationContainerContext()}. The
+ * {@link ContainerContext} is accessible through {@link Context#getContainerContext()}.
+ * <p>
  * If it is necessary to have a separate instance per task, then use {@link ApplicationTaskContext} instead.
  * <p>
  * This class does not need to be {@link java.io.Serializable} and instances are not persisted across deployments.
index ffc5383..6afbf23 100644 (file)
@@ -25,6 +25,10 @@ package org.apache.samza.context;
  * This needs to be created by an implementation of {@link ApplicationTaskContextFactory}. The factory should create
  * the runtime objects contained within this context.
  * <p>
+ * This is related to {@link TaskContext} in that they are both associated with a task lifecycle. In order to access
+ * this in application code, use {@link Context#getApplicationTaskContext()}. The {@link TaskContext} is accessible
+ * through {@link Context#getTaskContext()}.
+ * <p>
  * If it is possible to share an instance of this across tasks in a container, then use
  * {@link ApplicationContainerContext} instead.
  * <p>
index 9b09fa9..239a011 100644 (file)
@@ -35,6 +35,7 @@ public interface JobContext {
   /**
    * Returns the name of the job.
    * @return name of the job
+   * @throws org.apache.samza.SamzaException if the job name was not configured
    */
   String getJobName();
 
diff --git a/samza-api/src/main/java/org/apache/samza/operators/ContextManager.java b/samza-api/src/main/java/org/apache/samza/operators/ContextManager.java
deleted file mode 100644 (file)
index 5f2c020..0000000
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.operators;
-
-import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.config.Config;
-import org.apache.samza.task.TaskContext;
-
-
-/**
- * Manages custom context that is shared across multiple operator functions in a task.
- */
-@InterfaceStability.Unstable
-public interface ContextManager {
-
-  /**
-   * Allows initializing and setting a custom context that is shared across multiple operator functions in a task.
-   * <p>
-   * This method is invoked before any {@link org.apache.samza.operators.functions.InitableFunction}s are initialized.
-   * Use {@link TaskContext#setUserContext(Object)} to set the context here and {@link TaskContext#getUserContext()} to
-   * get it in InitableFunctions.
-   *
-   * @param config the {@link Config} for the application
-   * @param context the {@link TaskContext} for this task
-   */
-  void init(Config config, TaskContext context);
-
-  /**
-   * Allows closing the custom context that is shared across multiple operator functions in a task.
-   */
-  void close();
-
-}
index 8a5d83b..7f950de 100644 (file)
@@ -20,8 +20,7 @@
 package org.apache.samza.operators.functions;
 
 import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.config.Config;
-import org.apache.samza.task.TaskContext;
+import org.apache.samza.context.Context;
 
 /**
  * A function that can be initialized before execution.
@@ -33,12 +32,10 @@ import org.apache.samza.task.TaskContext;
  */
 @InterfaceStability.Evolving
 public interface InitableFunction {
-
   /**
    * Initializes the function before any messages are processed.
    *
-   * @param config the {@link Config} for the application
-   * @param context the {@link TaskContext} for this task
+   * @param context the {@link Context} for this task
    */
-  default void init(Config config, TaskContext context) { }
+  default void init(Context context) { }
 }
index e230304..5c8d77d 100644 (file)
@@ -18,6 +18,7 @@
  */
 package org.apache.samza.scheduler;
 
+
 /**
  * Provides a way for applications to register some logic to be executed at a future time.
  */
index 8745422..546ca37 100644 (file)
@@ -24,8 +24,9 @@ import org.apache.samza.task.TaskCoordinator;
 
 
 /**
- * The callback that is invoked when its corresponding schedule time registered via
- * {@link org.apache.samza.task.TaskContext} is reached.
+ * The callback that is invoked when its corresponding schedule time registered via {@link CallbackScheduler} is
+ * reached. The {@link CallbackScheduler} is available through
+ * {@link org.apache.samza.context.TaskContext#getCallbackScheduler()}.
  * @param <K> type of the callback key
  */
 public interface ScheduledCallback<K> {
index 800deeb..2425cf3 100644 (file)
@@ -20,8 +20,8 @@
 package org.apache.samza.storage;
 
 import java.io.File;
-
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.SystemStreamPartition;
@@ -43,6 +43,7 @@ public interface StorageEngineFactory<K, V> {
    * @param collector MessageCollector the storage engine uses to persist changes.
    * @param registry MetricsRegistry to which to publish storage-engine specific metrics.
    * @param changeLogSystemStreamPartition Samza stream partition from which to receive the changelog.
+   * @param jobContext Information about the job in which the task is executing
    * @param containerContext Information about the container in which the task is executing.
    * @return The storage engine instance.
    */
@@ -54,5 +55,6 @@ public interface StorageEngineFactory<K, V> {
     MessageCollector collector,
     MetricsRegistry registry,
     SystemStreamPartition changeLogSystemStreamPartition,
-    SamzaContainerContext containerContext);
+    JobContext jobContext,
+    ContainerContext containerContext);
 }
index 490acc0..6c88fd3 100644 (file)
@@ -21,11 +21,9 @@ package org.apache.samza.table;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-
 import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.KV;
-import org.apache.samza.task.TaskContext;
 
 
 /**
@@ -40,10 +38,9 @@ public interface ReadableTable<K, V> extends Table<KV<K, V>> {
   /**
    * Initializes the table during container initialization.
    * Guaranteed to be invoked as the first operation on the table.
-   * @param containerContext Samza container context
-   * @param taskContext nullable for global table
+   * @param context {@link Context} corresponding to this table
    */
-  default void init(SamzaContainerContext containerContext, TaskContext taskContext) {
+  default void init(Context context) {
   }
 
   /**
index 99446e4..350324c 100644 (file)
 package org.apache.samza.table;
 
 import java.util.Map;
-
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.config.Config;
-import org.apache.samza.container.SamzaContainerContext;
-import org.apache.samza.task.TaskContext;
+import org.apache.samza.context.Context;
 
 /**
  * A table provider provides the implementation for a table. It ensures a table is
@@ -33,10 +31,9 @@ import org.apache.samza.task.TaskContext;
 public interface TableProvider {
   /**
    * Initialize TableProvider with container and task context
-   * @param containerContext Samza container context
-   * @param taskContext nullable for global table
+   * @param context context for the task
    */
-  void init(SamzaContainerContext containerContext, TaskContext taskContext);
+  void init(Context context);
 
   /**
    * Get an instance of the table for read/write operations
index 6926f16..8bf0619 100644 (file)
@@ -19,7 +19,8 @@
 
 package org.apache.samza.task;
 
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
+
 
 /**
  * Used as an interface for user processing StreamTasks that need to have specific functionality performed as their StreamTasks
@@ -28,9 +29,8 @@ import org.apache.samza.config.Config;
 public interface InitableTask {
   /**
    * Called by TaskRunner each time an implementing task is created.
-   * @param config Allows accessing of fields in the configuration files that this StreamTask is specified in.
    * @param context Allows accessing of contextual data of this StreamTask.
    * @throws Exception Any exception types encountered during the execution of the processing task.
    */
-  void init(Config config, TaskContext context) throws Exception;
+  void init(Context context) throws Exception;
 }
diff --git a/samza-api/src/main/java/org/apache/samza/task/TaskContext.java b/samza-api/src/main/java/org/apache/samza/task/TaskContext.java
deleted file mode 100644 (file)
index 007028a..0000000
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.task;
-
-import java.util.Set;
-
-import org.apache.samza.container.SamzaContainerContext;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.metrics.MetricsRegistry;
-import org.apache.samza.scheduler.ScheduledCallback;
-import org.apache.samza.system.SystemStreamPartition;
-import org.apache.samza.table.Table;
-
-
-/**
- * A TaskContext provides resources about the {@link org.apache.samza.task.StreamTask}, particularly during
- * initialization in an {@link org.apache.samza.task.InitableTask}.
- * TODO this will be replaced by {@link org.apache.samza.context.TaskContext} in the near future by SAMZA-1714
- */
-public interface TaskContext {
-  MetricsRegistry getMetricsRegistry();
-
-  Set<SystemStreamPartition> getSystemStreamPartitions();
-
-  Object getStore(String name);
-
-  Table getTable(String tableId);
-
-  TaskName getTaskName();
-
-  SamzaContainerContext getSamzaContainerContext();
-
-  /**
-   * Set the starting offset for the given {@link org.apache.samza.system.SystemStreamPartition}. Offsets
-   * can only be set for a {@link org.apache.samza.system.SystemStreamPartition} assigned to this task
-   * (as returned by {@link #getSystemStreamPartitions()}); trying to set the offset for any other partition
-   * will have no effect.
-   *
-   * NOTE: this feature is experimental, and the API may change in a future release.
-   *
-   * @param ssp {@link org.apache.samza.system.SystemStreamPartition} whose offset should be set
-   * @param offset to set for the given {@link org.apache.samza.system.SystemStreamPartition}
-   *
-   */
-  void setStartingOffset(SystemStreamPartition ssp, String offset);
-
-  /**
-   * Sets the user-defined context.
-   *
-   * @param context the user-defined context to set
-   */
-  default void setUserContext(Object context) { }
-
-  /**
-   * Gets the user-defined context.
-   *
-   * @return the user-defined context if set, else null
-   */
-  default Object getUserContext() {
-    return null;
-  }
-
-  /**
-   * Schedule the {@code callback} for the provided {@code key} to be invoked at epoch-time {@code timestamp}.
-   * The callback will be invoked exclusively with any other operations for this task,
-   * e.g. processing, windowing and commit.
-   * @param key key for the callback
-   * @param timestamp epoch time when the callback will be fired, in milliseconds
-   * @param callback callback to call when the {@code timestamp} is reached
-   * @param <K> type of the key
-   */
-  <K> void scheduleCallback(K key, long timestamp, ScheduledCallback<K> callback);
-
-  /**
-   * Delete the scheduled {@code callback} for the {@code key}.
-   * Deletion only happens if the callback hasn't been fired. Otherwise it will not interrupt.
-   * @param key callback key
-   * @param <K> type of the key
-   */
-  <K> void deleteScheduledCallback(K key);
-}
index ad40d35..83532e4 100644 (file)
@@ -22,10 +22,8 @@ import java.io.Serializable;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
-
 import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.config.Config;
-import org.apache.samza.task.TaskContext;
+import org.apache.samza.context.Context;
 
 /**
  * A rate limiter interface used by Samza components to limit throughput of operations
@@ -53,10 +51,9 @@ public interface RateLimiter extends Serializable {
   /**
    * Initialize this rate limiter, this method should be called during container initialization.
    *
-   * @param config job configuration
-   * @param taskContext task context that owns this rate limiter
+   * @param context {@link Context} that corresponds to this rate limiter
    */
-  void init(Config config, TaskContext taskContext);
+  void init(Context context);
 
   /**
    * Attempt to acquire the provided number of credits, blocks indefinitely until
index b58d5a5..5416af5 100644 (file)
@@ -25,8 +25,11 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.ApplicationContainerContext;
+import org.apache.samza.context.ApplicationContainerContextFactory;
+import org.apache.samza.context.ApplicationTaskContext;
+import org.apache.samza.context.ApplicationTaskContextFactory;
 import org.apache.samza.metrics.MetricsReporterFactory;
-import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.TableDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.InputDescriptor;
@@ -38,7 +41,6 @@ import org.apache.samza.runtime.ProcessorLifecycleListenerFactory;
 import org.apache.samza.serializers.KVSerde;
 import org.apache.samza.serializers.NoOpSerde;
 import org.apache.samza.serializers.Serde;
-import org.apache.samza.task.TaskContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -47,7 +49,8 @@ import org.slf4j.LoggerFactory;
  * This is the base class that implements interface {@link ApplicationDescriptor}.
  * <p>
  * This base class contains the common objects that are used by both high-level and low-level API applications, such as
- * {@link Config}, {@link ContextManager}, and {@link ProcessorLifecycleListenerFactory}.
+ * {@link Config}, {@link ApplicationContainerContextFactory}, {@link ApplicationTaskContextFactory}, and
+ * {@link ProcessorLifecycleListenerFactory}.
  *
  * @param <S> the type of {@link ApplicationDescriptor} interface this implements. It has to be either
  *            {@link StreamApplicationDescriptor} or {@link TaskApplicationDescriptor}
@@ -64,17 +67,8 @@ public abstract class ApplicationDescriptorImpl<S extends ApplicationDescriptor>
   private final Map<String, KV<Serde, Serde>> tableSerdes = new HashMap<>();
   final Config config;
 
-  // Default to no-op functions in ContextManager
-  // TODO: this should be replaced by shared context factory defined in SAMZA-1714
-  ContextManager contextManager = new ContextManager() {
-    @Override
-    public void init(Config config, TaskContext context) {
-    }
-
-    @Override
-    public void close() {
-    }
-  };
+  private Optional<ApplicationContainerContextFactory<?>> applicationContainerContextFactoryOptional = Optional.empty();
+  private Optional<ApplicationTaskContextFactory<?>> applicationTaskContextFactoryOptional = Optional.empty();
 
   // Default to no-op  ProcessorLifecycleListenerFactory
   ProcessorLifecycleListenerFactory listenerFactory = (pcontext, cfg) -> new ProcessorLifecycleListener() { };
@@ -90,8 +84,14 @@ public abstract class ApplicationDescriptorImpl<S extends ApplicationDescriptor>
   }
 
   @Override
-  public S withContextManager(ContextManager contextManager) {
-    this.contextManager = contextManager;
+  public S withApplicationContainerContextFactory(ApplicationContainerContextFactory<?> factory) {
+    this.applicationContainerContextFactoryOptional = Optional.of(factory);
+    return (S) this;
+  }
+
+  @Override
+  public S withApplicationTaskContextFactory(ApplicationTaskContextFactory<?> factory) {
+    this.applicationTaskContextFactoryOptional = Optional.of(factory);
     return (S) this;
   }
 
@@ -118,12 +118,27 @@ public abstract class ApplicationDescriptorImpl<S extends ApplicationDescriptor>
   }
 
   /**
-   * Get the {@link ContextManager} associated with this application
+   * Get the {@link ApplicationContainerContextFactory} specified by the application.
+   *
+   * @return {@link ApplicationContainerContextFactory} if application specified it; empty otherwise
+   */
+  public Optional<ApplicationContainerContextFactory<ApplicationContainerContext>> getApplicationContainerContextFactory() {
+    @SuppressWarnings("unchecked") // ok because all context types are at least ApplicationContainerContext
+    Optional<ApplicationContainerContextFactory<ApplicationContainerContext>> factoryOptional =
+        (Optional) this.applicationContainerContextFactoryOptional;
+    return factoryOptional;
+  }
+
+  /**
+   * Get the {@link ApplicationTaskContextFactory} specified by the application.
    *
-   * @return the {@link ContextManager} for this application
+   * @return {@link ApplicationTaskContextFactory} if application specified it; empty otherwise
    */
-  public ContextManager getContextManager() {
-    return contextManager;
+  public Optional<ApplicationTaskContextFactory<ApplicationTaskContext>> getApplicationTaskContextFactory() {
+    @SuppressWarnings("unchecked") // ok because all context types are at least ApplicationTaskContext
+    Optional<ApplicationTaskContextFactory<ApplicationTaskContext>> factoryOptional =
+        (Optional) this.applicationTaskContextFactoryOptional;
+    return factoryOptional;
   }
 
   /**
diff --git a/samza-core/src/main/java/org/apache/samza/container/TaskContextImpl.java b/samza-core/src/main/java/org/apache/samza/container/TaskContextImpl.java
deleted file mode 100644 (file)
index 25ffe8f..0000000
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.container;
-
-import com.google.common.collect.ImmutableSet;
-import java.util.function.Function;
-import org.apache.samza.checkpoint.OffsetManager;
-import org.apache.samza.job.model.JobModel;
-import org.apache.samza.metrics.ReadableMetricsRegistry;
-import org.apache.samza.storage.kv.KeyValueStore;
-import org.apache.samza.system.StreamMetadataCache;
-import org.apache.samza.system.SystemStreamPartition;
-import org.apache.samza.table.Table;
-import org.apache.samza.table.TableManager;
-import org.apache.samza.task.EpochTimeScheduler;
-import org.apache.samza.task.TaskContext;
-import org.apache.samza.scheduler.ScheduledCallback;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.ScheduledExecutorService;
-
-
-/**
- * TODO this will be replaced by {@link org.apache.samza.context.TaskContextImpl} in the near future by SAMZA-1714
- */
-public class TaskContextImpl implements TaskContext {
-  private static final Logger LOG = LoggerFactory.getLogger(TaskContextImpl.class);
-
-  private final TaskName taskName;
-  private final TaskInstanceMetrics metrics;
-  private final SamzaContainerContext containerContext;
-  private final Set<SystemStreamPartition> systemStreamPartitions;
-  private final OffsetManager offsetManager;
-  private final Function<String, KeyValueStore> kvStoreSupplier;
-  private final TableManager tableManager;
-  private final JobModel jobModel;
-  private final StreamMetadataCache streamMetadataCache;
-  private final Map<String, Object> objectRegistry = new HashMap<>();
-  private final EpochTimeScheduler timerScheduler;
-
-  private Object userContext = null;
-
-  public TaskContextImpl(TaskName taskName,
-                         TaskInstanceMetrics metrics,
-                         SamzaContainerContext containerContext,
-                         Set<SystemStreamPartition> systemStreamPartitions,
-                         OffsetManager offsetManager,
-                         Function<String, KeyValueStore> kvStoreSupplier,
-                         TableManager tableManager,
-                         JobModel jobModel,
-                         StreamMetadataCache streamMetadataCache,
-                         ScheduledExecutorService timerExecutor) {
-    this.taskName = taskName;
-    this.metrics = metrics;
-    this.containerContext = containerContext;
-    this.systemStreamPartitions = ImmutableSet.copyOf(systemStreamPartitions);
-    this.offsetManager = offsetManager;
-    this.kvStoreSupplier = kvStoreSupplier;
-    this.tableManager = tableManager;
-    this.jobModel = jobModel;
-    this.streamMetadataCache = streamMetadataCache;
-    this.timerScheduler = EpochTimeScheduler.create(timerExecutor);
-  }
-
-  @Override
-  public ReadableMetricsRegistry getMetricsRegistry() {
-    return metrics.registry();
-  }
-
-  @Override
-  public Set<SystemStreamPartition> getSystemStreamPartitions() {
-    return systemStreamPartitions;
-  }
-
-  @Override
-  public KeyValueStore getStore(String storeName) {
-    KeyValueStore store = kvStoreSupplier.apply(storeName);
-    if (store == null) {
-      LOG.warn("No store found for name: {}", storeName);
-    }
-    return store;
-  }
-
-  @Override
-  public Table getTable(String tableId) {
-    if (tableManager != null) {
-      return tableManager.getTable(tableId);
-    } else {
-      LOG.warn("No table manager found");
-      return null;
-    }
-  }
-
-  @Override
-  public TaskName getTaskName() {
-    return taskName;
-  }
-
-  @Override
-  public SamzaContainerContext getSamzaContainerContext() {
-    return containerContext;
-  }
-
-  @Override
-  public void setStartingOffset(SystemStreamPartition ssp, String offset) {
-    offsetManager.setStartingOffset(taskName, ssp, offset);
-  }
-
-  @Override
-  public void setUserContext(Object context) {
-    userContext = context;
-  }
-
-  @Override
-  public Object getUserContext() {
-    return userContext;
-  }
-
-  @Override
-  public <K> void scheduleCallback(K key, long timestamp, ScheduledCallback<K> callback) {
-    timerScheduler.setTimer(key, timestamp, callback);
-  }
-
-  @Override
-  public <K> void deleteScheduledCallback(K key) {
-    timerScheduler.deleteTimer(key);
-  }
-
-  public void registerObject(String name, Object value) {
-    objectRegistry.put(name, value);
-  }
-
-  public Object fetchObject(String name) {
-    return objectRegistry.get(name);
-  }
-
-  public JobModel getJobModel() {
-    return jobModel;
-  }
-
-  public StreamMetadataCache getStreamMetadataCache() {
-    return streamMetadataCache;
-  }
-
-  public EpochTimeScheduler getTimerScheduler() {
-    return timerScheduler;
-  }
-}
index c2c182f..93c7eb1 100644 (file)
  */
 package org.apache.samza.context;
 
+import com.google.common.base.Preconditions;
+
+import java.util.Objects;
+import java.util.Optional;
+
+
 public class ContextImpl implements Context {
   private final JobContext jobContext;
   private final ContainerContext containerContext;
   private final TaskContext taskContext;
-  private final ApplicationContainerContext applicationContainerContext;
-  private final ApplicationTaskContext applicationTaskContext;
+  private final Optional<ApplicationContainerContext> applicationContainerContextOptional;
+  private final Optional<ApplicationTaskContext> applicationTaskContextOptional;
 
   /**
    * @param jobContext non-null job context
    * @param containerContext non-null framework container context
    * @param taskContext non-null framework task context
-   * @param applicationContainerContext nullable application-defined container context
-   * @param applicationTaskContext nullable application-defined task context
+   * @param applicationContainerContextOptional optional application-defined container context
+   * @param applicationTaskContextOptional optional application-defined task context
    */
-  public ContextImpl(JobContext jobContext, ContainerContext containerContext, TaskContext taskContext,
-      ApplicationContainerContext applicationContainerContext, ApplicationTaskContext applicationTaskContext) {
-    this.jobContext = jobContext;
-    this.containerContext = containerContext;
-    this.taskContext = taskContext;
-    this.applicationContainerContext = applicationContainerContext;
-    this.applicationTaskContext = applicationTaskContext;
+  public ContextImpl(JobContext jobContext,
+      ContainerContext containerContext,
+      TaskContext taskContext,
+      Optional<ApplicationContainerContext> applicationContainerContextOptional,
+      Optional<ApplicationTaskContext> applicationTaskContextOptional) {
+    this.jobContext = Preconditions.checkNotNull(jobContext, "Job context can not be null");
+    this.containerContext = Preconditions.checkNotNull(containerContext, "Container context can not be null");
+    this.taskContext = Preconditions.checkNotNull(taskContext, "Task context can not be null");
+    this.applicationContainerContextOptional = applicationContainerContextOptional;
+    this.applicationTaskContextOptional = applicationTaskContextOptional;
   }
 
   @Override
@@ -58,17 +67,38 @@ public class ContextImpl implements Context {
 
   @Override
   public ApplicationContainerContext getApplicationContainerContext() {
-    if (this.applicationContainerContext == null) {
+    if (!this.applicationContainerContextOptional.isPresent()) {
       throw new IllegalStateException("No application-defined container context exists");
     }
-    return this.applicationContainerContext;
+    return this.applicationContainerContextOptional.get();
   }
 
   @Override
   public ApplicationTaskContext getApplicationTaskContext() {
-    if (this.applicationTaskContext == null) {
+    if (!this.applicationTaskContextOptional.isPresent()) {
       throw new IllegalStateException("No application-defined task context exists");
     }
-    return this.applicationTaskContext;
+    return this.applicationTaskContextOptional.get();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    ContextImpl context = (ContextImpl) o;
+    return Objects.equals(jobContext, context.jobContext) && Objects.equals(containerContext, context.containerContext)
+        && Objects.equals(taskContext, context.taskContext) && Objects.equals(applicationContainerContextOptional,
+        context.applicationContainerContextOptional) && Objects.equals(applicationTaskContextOptional,
+        context.applicationTaskContextOptional);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(jobContext, containerContext, taskContext, applicationContainerContextOptional,
+        applicationTaskContextOptional);
   }
 }
index 8fe44e4..797e2ca 100644 (file)
@@ -19,6 +19,8 @@
 package org.apache.samza.context;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import scala.Option;
 
 
 public class JobContextImpl implements JobContext {
@@ -26,12 +28,30 @@ public class JobContextImpl implements JobContext {
   private final String jobName;
   private final String jobId;
 
-  public JobContextImpl(Config config, String jobName, String jobId) {
+  private JobContextImpl(Config config, String jobName, String jobId) {
     this.config = config;
     this.jobName = jobName;
     this.jobId = jobId;
   }
 
+  /**
+   * Build a {@link JobContextImpl} from a {@link Config} object.
+   * This extracts some information like job name and job id.
+   *
+   * @param config used to extract job information
+   * @return {@link JobContextImpl} corresponding to the {@code config}
+   * @throws IllegalArgumentException if job name is not defined in the {@code config}
+   */
+  public static JobContextImpl fromConfigWithDefaults(Config config) {
+    JobConfig jobConfig = new JobConfig(config);
+    Option<String> jobName = jobConfig.getName();
+    if (jobName.isEmpty()) {
+      throw new IllegalArgumentException("Job name is not defined in configuration");
+    }
+    String jobId = jobConfig.getJobId();
+    return new JobContextImpl(config, jobName.get(), jobId);
+  }
+
   @Override
   public Config getConfig() {
     return this.config;
index e975dcd..ec52f8a 100644 (file)
  */
 package org.apache.samza.context;
 
-import java.util.function.Function;
 import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.job.model.JobModel;
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.scheduler.CallbackScheduler;
 import org.apache.samza.storage.kv.KeyValueStore;
+import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.TableManager;
 
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+
 
 public class TaskContextImpl implements TaskContext {
   private final TaskModel taskModel;
@@ -36,19 +41,26 @@ public class TaskContextImpl implements TaskContext {
   private final TableManager tableManager;
   private final CallbackScheduler callbackScheduler;
   private final OffsetManager offsetManager;
+  private final JobModel jobModel;
+  private final StreamMetadataCache streamMetadataCache;
+  private final Map<String, Object> objectRegistry = new HashMap<>();
 
   public TaskContextImpl(TaskModel taskModel,
       MetricsRegistry taskMetricsRegistry,
       Function<String, KeyValueStore> keyValueStoreProvider,
       TableManager tableManager,
       CallbackScheduler callbackScheduler,
-      OffsetManager offsetManager) {
+      OffsetManager offsetManager,
+      JobModel jobModel,
+      StreamMetadataCache streamMetadataCache) {
     this.taskModel = taskModel;
     this.taskMetricsRegistry = taskMetricsRegistry;
     this.keyValueStoreProvider = keyValueStoreProvider;
     this.tableManager = tableManager;
     this.callbackScheduler = callbackScheduler;
     this.offsetManager = offsetManager;
+    this.jobModel = jobModel;
+    this.streamMetadataCache = streamMetadataCache;
   }
 
   @Override
@@ -84,4 +96,22 @@ public class TaskContextImpl implements TaskContext {
   public void setStartingOffset(SystemStreamPartition systemStreamPartition, String offset) {
     this.offsetManager.setStartingOffset(this.taskModel.getTaskName(), systemStreamPartition, offset);
   }
+
+  // TODO SAMZA-1935: below methods are used by operator code; they should be decoupled from this client API
+
+  public void registerObject(String name, Object value) {
+    this.objectRegistry.put(name, value);
+  }
+
+  public Object fetchObject(String name) {
+    return this.objectRegistry.get(name);
+  }
+
+  public JobModel getJobModel() {
+    return this.jobModel;
+  }
+
+  public StreamMetadataCache getStreamMetadataCache() {
+    return this.streamMetadataCache;
+  }
 }
index 99ed089..4965f7b 100644 (file)
@@ -19,7 +19,7 @@
 
 package org.apache.samza.operators.impl;
 
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.spec.BroadcastOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.system.ControlMessage;
@@ -28,7 +28,6 @@ import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.WatermarkMessage;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
 import java.util.Collection;
@@ -40,14 +39,14 @@ class BroadcastOperatorImpl<M> extends OperatorImpl<M, Void> {
   private final SystemStream systemStream;
   private final String taskName;
 
-  BroadcastOperatorImpl(BroadcastOperatorSpec<M> broadcastOpSpec, SystemStream systemStream, TaskContext context) {
+  BroadcastOperatorImpl(BroadcastOperatorSpec<M> broadcastOpSpec, SystemStream systemStream, Context context) {
     this.broadcastOpSpec = broadcastOpSpec;
     this.systemStream = systemStream;
-    this.taskName = context.getTaskName().getTaskName();
+    this.taskName = context.getTaskContext().getTaskModel().getTaskName().getTaskName();
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
+  protected void handleInit(Context context) {
   }
 
   @Override
index 6cc57e0..2a73064 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.functions.InputTransformer;
 import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
 import java.util.Collection;
@@ -44,7 +43,7 @@ public final class InputOperatorImpl extends OperatorImpl<IncomingMessageEnvelop
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
+  protected void handleInit(Context context) {
   }
 
   @Override
index 5cafd26..675b211 100644 (file)
@@ -21,23 +21,23 @@ package org.apache.samza.operators.impl;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MetricsConfig;
-import org.apache.samza.container.TaskContextImpl;
 import org.apache.samza.container.TaskName;
-import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.TaskContextImpl;
 import org.apache.samza.job.model.TaskModel;
-import org.apache.samza.operators.Scheduler;
-import org.apache.samza.operators.functions.ScheduledFunction;
-import org.apache.samza.operators.functions.WatermarkFunction;
-import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.Timer;
+import org.apache.samza.operators.Scheduler;
+import org.apache.samza.operators.functions.ScheduledFunction;
+import org.apache.samza.operators.functions.WatermarkFunction;
 import org.apache.samza.operators.spec.OperatorSpec;
+import org.apache.samza.scheduler.CallbackScheduler;
+import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.system.WatermarkMessage;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.util.HighResolutionClock;
 import org.slf4j.Logger;
@@ -82,16 +82,15 @@ public abstract class OperatorImpl<M, RM> {
   private EndOfStreamStates eosStates;
   // watermark states
   private WatermarkStates watermarkStates;
-  private TaskContext taskContext;
+  private CallbackScheduler callbackScheduler;
   private ControlMessageSender controlMessageSender;
 
   /**
    * Initialize this {@link OperatorImpl} and its user-defined functions.
    *
-   * @param config  the {@link Config} for the task
-   * @param context  the {@link TaskContext} for the task
+   * @param context the {@link Context} for the task
    */
-  public final void init(Config config, TaskContext context) {
+  public final void init(Context context) {
     String opId = getOpImplId();
 
     if (initialized) {
@@ -102,32 +101,24 @@ public abstract class OperatorImpl<M, RM> {
       throw new IllegalStateException(String.format("Attempted to initialize Operator %s after it was closed.", opId));
     }
 
-    this.highResClock = createHighResClock(config);
+    this.highResClock = createHighResClock(context.getJobContext().getConfig());
     registeredOperators = new HashSet<>();
     prevOperators = new HashSet<>();
     inputStreams = new HashSet<>();
-    MetricsRegistry metricsRegistry = context.getMetricsRegistry();
+    // TODO SAMZA-1935: the objects that are only accessible through TaskContextImpl should be moved somewhere else
+    TaskContextImpl taskContext = (TaskContextImpl) context.getTaskContext();
+    MetricsRegistry metricsRegistry = taskContext.getTaskMetricsRegistry();
     this.numMessage = metricsRegistry.newCounter(METRICS_GROUP, opId + "-messages");
     this.handleMessageNs = metricsRegistry.newTimer(METRICS_GROUP, opId + "-handle-message-ns");
     this.handleTimerNs = metricsRegistry.newTimer(METRICS_GROUP, opId + "-handle-timer-ns");
-    this.taskName = context.getTaskName();
+    this.taskName = taskContext.getTaskModel().getTaskName();
 
-    TaskContextImpl taskContext = (TaskContextImpl) context;
     this.eosStates = (EndOfStreamStates) taskContext.fetchObject(EndOfStreamStates.class.getName());
     this.watermarkStates = (WatermarkStates) taskContext.fetchObject(WatermarkStates.class.getName());
     this.controlMessageSender = new ControlMessageSender(taskContext.getStreamMetadataCache());
-
-    if (taskContext.getJobModel() != null) {
-      ContainerModel containerModel = taskContext.getJobModel().getContainers()
-          .get(context.getSamzaContainerContext().id);
-      this.taskModel = containerModel.getTasks().get(taskName);
-    } else {
-      this.taskModel = null;
-      this.usedInCurrentTask = true;
-    }
-
-    this.taskContext = taskContext;
-    handleInit(config, taskContext);
+    this.taskModel = taskContext.getTaskModel();
+    this.callbackScheduler = taskContext.getCallbackScheduler();
+    handleInit(context);
 
     initialized = true;
   }
@@ -135,10 +126,9 @@ public abstract class OperatorImpl<M, RM> {
   /**
    * Initialize this {@link OperatorImpl} and its user-defined functions.
    *
-   * @param config  the {@link Config} for the task
-   * @param context  the {@link TaskContext} for the task
+   * @param context the {@link Context} for the task
    */
-  protected abstract void handleInit(Config config, TaskContext context);
+  protected abstract void handleInit(Context context);
 
   /**
    * Register an operator that this operator should propagate its results to.
@@ -448,7 +438,7 @@ public abstract class OperatorImpl<M, RM> {
     return new Scheduler<K>() {
       @Override
       public void schedule(K key, long time) {
-        taskContext.scheduleCallback(key, time, (k, collector, coordinator) -> {
+        callbackScheduler.scheduleCallback(key, time, (k, collector, coordinator) -> {
             final ScheduledFunction<K, RM> scheduledFn = getOperatorSpec().getScheduledFn();
             if (scheduledFn != null) {
               final Collection<RM> output = scheduledFn.onCallback(key, time);
@@ -468,7 +458,7 @@ public abstract class OperatorImpl<M, RM> {
 
       @Override
       public void delete(K key) {
-        taskContext.deleteScheduledCallback(key);
+        callbackScheduler.deleteCallback(key);
       }
     };
   }
index d76c7de..e668b91 100644 (file)
@@ -23,20 +23,18 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Multimap;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.StreamConfig;
-import org.apache.samza.container.SamzaContainerContext;
-import org.apache.samza.container.TaskContextImpl;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.TaskContextImpl;
 import org.apache.samza.job.model.JobModel;
-import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.operators.KV;
+import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.operators.Scheduler;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.functions.PartialJoinFunction;
-import org.apache.samza.util.TimestampedValue;
 import org.apache.samza.operators.spec.BroadcastOperatorSpec;
 import org.apache.samza.operators.spec.InputOperatorSpec;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
-import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.operators.spec.OutputOperatorSpec;
 import org.apache.samza.operators.spec.PartitionByOperatorSpec;
 import org.apache.samza.operators.spec.SendToTableOperatorSpec;
@@ -46,8 +44,8 @@ import org.apache.samza.operators.spec.StreamTableJoinOperatorSpec;
 import org.apache.samza.operators.spec.WindowOperatorSpec;
 import org.apache.samza.storage.kv.KeyValueStore;
 import org.apache.samza.system.SystemStream;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.util.Clock;
+import org.apache.samza.util.TimestampedValue;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -93,14 +91,14 @@ public class OperatorImplGraph {
    * in the {@code specGraph}.
    *
    * @param specGraph  the {@link OperatorSpecGraph} containing the logical {@link OperatorSpec} DAG
-   * @param config  the {@link Config} required to instantiate operators
-   * @param context  the {@link TaskContext} required to instantiate operators
+   * @param context  the {@link Context} required to instantiate operators
    * @param clock  the {@link Clock} to get current time
    */
-  public OperatorImplGraph(OperatorSpecGraph specGraph, Config config, TaskContext context, Clock clock) {
+  public OperatorImplGraph(OperatorSpecGraph specGraph, Context context, Clock clock) {
     this.clock = clock;
-    StreamConfig streamConfig = new StreamConfig(config);
-    TaskContextImpl taskContext = (TaskContextImpl) context;
+    StreamConfig streamConfig = new StreamConfig(context.getJobContext().getConfig());
+    // TODO SAMZA-1935: the objects that are only accessible through TaskContextImpl should be moved somewhere else
+    TaskContextImpl taskContext = (TaskContextImpl) context.getTaskContext();
     Map<SystemStream, Integer> producerTaskCounts =
         hasIntermediateStreams(specGraph)
             ? getProducerTaskCountForIntermediateStreams(
@@ -113,15 +111,16 @@ public class OperatorImplGraph {
 
     // set states for end-of-stream
     taskContext.registerObject(EndOfStreamStates.class.getName(),
-        new EndOfStreamStates(context.getSystemStreamPartitions(), producerTaskCounts));
+        new EndOfStreamStates(taskContext.getTaskModel().getSystemStreamPartitions(), producerTaskCounts));
     // set states for watermark
     taskContext.registerObject(WatermarkStates.class.getName(),
-        new WatermarkStates(context.getSystemStreamPartitions(), producerTaskCounts, getMetricsRegistry(context)));
+        new WatermarkStates(taskContext.getTaskModel().getSystemStreamPartitions(), producerTaskCounts,
+            context.getContainerContext().getContainerMetricsRegistry()));
 
     specGraph.getInputOperators().forEach((streamId, inputOpSpec) -> {
         SystemStream systemStream = streamConfig.streamIdToSystemStream(streamId);
         InputOperatorImpl inputOperatorImpl =
-            (InputOperatorImpl) createAndRegisterOperatorImpl(null, inputOpSpec, systemStream, config, context);
+            (InputOperatorImpl) createAndRegisterOperatorImpl(null, inputOpSpec, systemStream, context);
         this.inputOperators.put(systemStream, inputOperatorImpl);
       });
   }
@@ -158,18 +157,16 @@ public class OperatorImplGraph {
    * @param prevOperatorSpec  the parent of the current {@code operatorSpec} in the traversal
    * @param operatorSpec  the {@link OperatorSpec} to create the {@link OperatorImpl} for
    * @param inputStream  the source input stream that we traverse the {@link OperatorSpecGraph} from
-   * @param config  the {@link Config} required to instantiate operators
-   * @param context  the {@link TaskContext} required to instantiate operators
+   * @param context the {@link Context} required to instantiate operators
    * @return  the operator implementation for the operatorSpec
    */
   private OperatorImpl createAndRegisterOperatorImpl(OperatorSpec prevOperatorSpec, OperatorSpec operatorSpec,
-      SystemStream inputStream, Config config, TaskContext context) {
-
+      SystemStream inputStream, Context context) {
     if (!operatorImpls.containsKey(operatorSpec.getOpId()) || operatorSpec instanceof JoinOperatorSpec) {
       // Either this is the first time we've seen this operatorSpec, or this is a join operator spec
       // and we need to create 2 partial join operator impls for it. Initialize and register the sub-DAG.
-      OperatorImpl operatorImpl = createOperatorImpl(prevOperatorSpec, operatorSpec, config, context);
-      operatorImpl.init(config, context);
+      OperatorImpl operatorImpl = createOperatorImpl(prevOperatorSpec, operatorSpec, context);
+      operatorImpl.init(context);
       operatorImpl.registerInputStream(inputStream);
 
       if (operatorSpec.getScheduledFn() != null) {
@@ -185,8 +182,7 @@ public class OperatorImplGraph {
       Collection<OperatorSpec> registeredSpecs = operatorSpec.getRegisteredOperatorSpecs();
       registeredSpecs.forEach(registeredSpec -> {
           LOG.debug("Creating operator {} with opCode: {}", registeredSpec.getOpId(), registeredSpec.getOpCode());
-          OperatorImpl nextImpl =
-              createAndRegisterOperatorImpl(operatorSpec, registeredSpec, inputStream, config, context);
+          OperatorImpl nextImpl = createAndRegisterOperatorImpl(operatorSpec, registeredSpec, inputStream, context);
           operatorImpl.registerNextOperator(nextImpl);
         });
       return operatorImpl;
@@ -197,9 +193,8 @@ public class OperatorImplGraph {
 
       // We still need to traverse the DAG further to register the input streams.
       Collection<OperatorSpec> registeredSpecs = operatorSpec.getRegisteredOperatorSpecs();
-      registeredSpecs.forEach(registeredSpec -> {
-          createAndRegisterOperatorImpl(operatorSpec, registeredSpec, inputStream, config, context);
-        });
+      registeredSpecs.forEach(
+          registeredSpec -> createAndRegisterOperatorImpl(operatorSpec, registeredSpec, inputStream, context));
       return operatorImpl;
     }
   }
@@ -209,19 +204,18 @@ public class OperatorImplGraph {
    *
    * @param prevOperatorSpec the original {@link OperatorSpec} that produces output for {@code operatorSpec} from {@link OperatorSpecGraph}
    * @param operatorSpec  the original {@link OperatorSpec} from {@link OperatorSpecGraph}
-   * @param config  the {@link Config} required to instantiate operators
-   * @param context  the {@link TaskContext} required to instantiate operators
+   * @param context  the {@link Context} required to instantiate operators
    * @return  the {@link OperatorImpl} implementation instance
    */
-  OperatorImpl createOperatorImpl(OperatorSpec prevOperatorSpec, OperatorSpec operatorSpec,
-      Config config, TaskContext context) {
+  OperatorImpl createOperatorImpl(OperatorSpec prevOperatorSpec, OperatorSpec operatorSpec, Context context) {
+    Config config = context.getJobContext().getConfig();
     StreamConfig streamConfig = new StreamConfig(config);
     if (operatorSpec instanceof InputOperatorSpec) {
       return new InputOperatorImpl((InputOperatorSpec) operatorSpec);
     } else if (operatorSpec instanceof StreamOperatorSpec) {
       return new StreamOperatorImpl((StreamOperatorSpec) operatorSpec);
     } else if (operatorSpec instanceof SinkOperatorSpec) {
-      return new SinkOperatorImpl((SinkOperatorSpec) operatorSpec, config, context);
+      return new SinkOperatorImpl((SinkOperatorSpec) operatorSpec);
     } else if (operatorSpec instanceof OutputOperatorSpec) {
       String streamId = ((OutputOperatorSpec) operatorSpec).getOutputStream().getStreamId();
       SystemStream systemStream = streamConfig.streamIdToSystemStream(streamId);
@@ -234,12 +228,11 @@ public class OperatorImplGraph {
       return new WindowOperatorImpl((WindowOperatorSpec) operatorSpec, clock);
     } else if (operatorSpec instanceof JoinOperatorSpec) {
       return getOrCreatePartialJoinOpImpls((JoinOperatorSpec) operatorSpec,
-          prevOperatorSpec.equals(((JoinOperatorSpec) operatorSpec).getLeftInputOpSpec()),
-          config, context, clock);
+          prevOperatorSpec.equals(((JoinOperatorSpec) operatorSpec).getLeftInputOpSpec()), clock);
     } else if (operatorSpec instanceof StreamTableJoinOperatorSpec) {
-      return new StreamTableJoinOperatorImpl((StreamTableJoinOperatorSpec) operatorSpec, config, context);
+      return new StreamTableJoinOperatorImpl((StreamTableJoinOperatorSpec) operatorSpec, context);
     } else if (operatorSpec instanceof SendToTableOperatorSpec) {
-      return new SendToTableOperatorImpl((SendToTableOperatorSpec) operatorSpec, config, context);
+      return new SendToTableOperatorImpl((SendToTableOperatorSpec) operatorSpec, context);
     } else if (operatorSpec instanceof BroadcastOperatorSpec) {
       String streamId = ((BroadcastOperatorSpec) operatorSpec).getOutputStream().getStreamId();
       SystemStream systemStream = streamConfig.streamIdToSystemStream(streamId);
@@ -250,14 +243,14 @@ public class OperatorImplGraph {
   }
 
   private PartialJoinOperatorImpl getOrCreatePartialJoinOpImpls(JoinOperatorSpec joinOpSpec, boolean isLeft,
-      Config config, TaskContext context, Clock clock) {
+      Clock clock) {
     // get the per task pair of PartialJoinOperatorImpl for the corresponding {@code joinOpSpec}
     KV<PartialJoinOperatorImpl, PartialJoinOperatorImpl> partialJoinOpImpls = joinOpImpls.computeIfAbsent(joinOpSpec.getOpId(),
         joinOpId -> {
         PartialJoinFunction leftJoinFn = createLeftJoinFn(joinOpSpec);
         PartialJoinFunction rightJoinFn = createRightJoinFn(joinOpSpec);
-        return new KV(new PartialJoinOperatorImpl(joinOpSpec, true, leftJoinFn, rightJoinFn, config, context, clock),
-            new PartialJoinOperatorImpl(joinOpSpec, false, rightJoinFn, leftJoinFn, config, context, clock));
+        return new KV(new PartialJoinOperatorImpl(joinOpSpec, true, leftJoinFn, rightJoinFn, clock),
+            new PartialJoinOperatorImpl(joinOpSpec, false, rightJoinFn, leftJoinFn, clock));
       });
 
     if (isLeft) { // we got here from the left side of the join
@@ -288,12 +281,13 @@ public class OperatorImplGraph {
       }
 
       @Override
-      public void init(Config config, TaskContext context) {
+      public void init(Context context) {
         String leftStoreName = joinOpSpec.getLeftOpId();
-        leftStreamState = (KeyValueStore<Object, TimestampedValue<Object>>) context.getStore(leftStoreName);
+        leftStreamState =
+            (KeyValueStore<Object, TimestampedValue<Object>>) context.getTaskContext().getStore(leftStoreName);
 
         // user-defined joinFn should only be initialized once, so we do it only in left partial join function.
-        joinFn.init(config, context);
+        joinFn.init(context);
       }
 
       @Override
@@ -320,9 +314,10 @@ public class OperatorImplGraph {
       }
 
       @Override
-      public void init(Config config, TaskContext context) {
+      public void init(Context context) {
         String rightStoreName = joinOpSpec.getRightOpId();
-        rightStreamState = (KeyValueStore<Object, TimestampedValue<Object>>) context.getStore(rightStoreName);
+        rightStreamState =
+            (KeyValueStore<Object, TimestampedValue<Object>>) context.getTaskContext().getStore(rightStoreName);
 
         // user-defined joinFn should only be initialized once,
         // so we do it only in left partial join function and not here again.
@@ -405,9 +400,4 @@ public class OperatorImplGraph {
   private boolean hasIntermediateStreams(OperatorSpecGraph specGraph) {
     return !Collections.disjoint(specGraph.getInputOperators().keySet(), specGraph.getOutputStreams().keySet());
   }
-
-  private static MetricsRegistry getMetricsRegistry(TaskContext context) {
-    final SamzaContainerContext containerContext = context.getSamzaContainerContext();
-    return containerContext != null ? containerContext.metricsRegistry : context.getMetricsRegistry();
-  }
 }
index 22fbb1b..407cdd9 100644 (file)
@@ -18,7 +18,7 @@
  */
 package org.apache.samza.operators.impl;
 
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.OutputOperatorSpec;
@@ -26,7 +26,6 @@ import org.apache.samza.operators.spec.OutputStreamImpl;
 import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
 import java.util.Collection;
@@ -49,7 +48,7 @@ class OutputOperatorImpl<M> extends OperatorImpl<M, Void> {
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
+  protected void handleInit(Context context) {
   }
 
   @Override
index 0cdde49..55658eb 100644 (file)
 package org.apache.samza.operators.impl;
 
 import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.functions.PartialJoinFunction;
-import org.apache.samza.util.TimestampedValue;
 import org.apache.samza.operators.spec.JoinOperatorSpec;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.storage.kv.KeyValueStore;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.util.Clock;
+import org.apache.samza.util.TimestampedValue;
 
 import java.util.Collection;
 import java.util.Collections;
@@ -54,7 +53,7 @@ class PartialJoinOperatorImpl<K, M, OM, JM> extends OperatorImpl<M, JM> {
   PartialJoinOperatorImpl(JoinOperatorSpec<K, M, OM, JM> joinOpSpec, boolean isLeftSide,
       PartialJoinFunction<K, M, OM, JM> thisPartialJoinFn,
       PartialJoinFunction<K, OM, M, JM> otherPartialJoinFn,
-      Config config, TaskContext context, Clock clock) {
+      Clock clock) {
     this.joinOpSpec = joinOpSpec;
     this.isLeftSide = isLeftSide;
     this.thisPartialJoinFn = thisPartialJoinFn;
@@ -64,8 +63,8 @@ class PartialJoinOperatorImpl<K, M, OM, JM> extends OperatorImpl<M, JM> {
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
-    this.thisPartialJoinFn.init(config, context);
+  protected void handleInit(Context context) {
+    this.thisPartialJoinFn.init(context);
   }
 
   @Override
index 63e269d..88644ce 100644 (file)
@@ -18,8 +18,8 @@
  */
 package org.apache.samza.operators.impl;
 
-import org.apache.samza.config.Config;
-import org.apache.samza.container.TaskContextImpl;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.TaskContextImpl;
 import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.PartitionByOperatorSpec;
@@ -30,7 +30,6 @@ import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.WatermarkMessage;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
 import java.util.Collection;
@@ -50,20 +49,20 @@ class PartitionByOperatorImpl<M, K, V> extends OperatorImpl<M, Void> {
   private final ControlMessageSender controlMessageSender;
 
   PartitionByOperatorImpl(PartitionByOperatorSpec<M, K, V> partitionByOpSpec,
-      SystemStream systemStream, TaskContext context) {
+      SystemStream systemStream, Context context) {
     this.partitionByOpSpec = partitionByOpSpec;
     this.systemStream = systemStream;
     this.keyFunction = partitionByOpSpec.getKeyFunction();
     this.valueFunction = partitionByOpSpec.getValueFunction();
-    this.taskName = context.getTaskName().getTaskName();
-    StreamMetadataCache streamMetadataCache = ((TaskContextImpl) context).getStreamMetadataCache();
+    this.taskName = context.getTaskContext().getTaskModel().getTaskName().getTaskName();
+    StreamMetadataCache streamMetadataCache = ((TaskContextImpl) context.getTaskContext()).getStreamMetadataCache();
     this.controlMessageSender = new ControlMessageSender(streamMetadataCache);
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
-    this.keyFunction.init(config, context);
-    this.valueFunction.init(config, context);
+  protected void handleInit(Context context) {
+    this.keyFunction.init(context);
+    this.valueFunction.init(context);
   }
 
   @Override
index 5ce1328..be3e0a3 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
-import java.util.Collection;
-import java.util.Collections;
-
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.SendToTableOperatorSpec;
 import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
+import java.util.Collection;
+import java.util.Collections;
+
 
 /**
  * Implementation of a send-stream-to-table operator that stores the record
@@ -43,13 +42,13 @@ public class SendToTableOperatorImpl<K, V> extends OperatorImpl<KV<K, V>, Void>
   private final SendToTableOperatorSpec<K, V> sendToTableOpSpec;
   private final ReadWriteTable<K, V> table;
 
-  SendToTableOperatorImpl(SendToTableOperatorSpec<K, V> sendToTableOpSpec, Config config, TaskContext context) {
+  SendToTableOperatorImpl(SendToTableOperatorSpec<K, V> sendToTableOpSpec, Context context) {
     this.sendToTableOpSpec = sendToTableOpSpec;
-    this.table = (ReadWriteTable) context.getTable(sendToTableOpSpec.getTableSpec().getId());
+    this.table = (ReadWriteTable) context.getTaskContext().getTable(sendToTableOpSpec.getTableSpec().getId());
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
+  protected void handleInit(Context context) {
   }
 
   @Override
index 5dbe27f..6fe9006 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.functions.SinkFunction;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.SinkOperatorSpec;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
 import java.util.Collection;
@@ -38,14 +37,14 @@ class SinkOperatorImpl<M> extends OperatorImpl<M, Void> {
   private final SinkOperatorSpec<M> sinkOpSpec;
   private final SinkFunction<M> sinkFn;
 
-  SinkOperatorImpl(SinkOperatorSpec<M> sinkOpSpec, Config config, TaskContext context) {
+  SinkOperatorImpl(SinkOperatorSpec<M> sinkOpSpec) {
     this.sinkOpSpec = sinkOpSpec;
     this.sinkFn = sinkOpSpec.getSinkFn();
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
-    this.sinkFn.init(config, context);
+  protected void handleInit(Context context) {
+    this.sinkFn.init(context);
   }
 
   @Override
index 6cd426b..1a615bd 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.functions.FlatMapFunction;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.StreamOperatorSpec;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
 import java.util.Collection;
@@ -46,8 +45,8 @@ class StreamOperatorImpl<M, RM> extends OperatorImpl<M, RM> {
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
-    transformFn.init(config, context);
+  protected void handleInit(Context context) {
+    transformFn.init(context);
   }
 
   @Override
index 54a5770..d44241d 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
-import java.util.Collection;
-import java.util.Collections;
-
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.StreamTableJoinOperatorSpec;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 
+import java.util.Collection;
+import java.util.Collections;
+
 
 /**
  * Implementation of a stream-table join operator that first retrieve the value of
@@ -45,15 +44,14 @@ class StreamTableJoinOperatorImpl<K, M, R extends KV, JM> extends OperatorImpl<M
   private final StreamTableJoinOperatorSpec<K, M, R, JM> joinOpSpec;
   private final ReadableTable<K, ?> table;
 
-  StreamTableJoinOperatorImpl(StreamTableJoinOperatorSpec<K, M, R, JM> joinOpSpec,
-      Config config, TaskContext context) {
+  StreamTableJoinOperatorImpl(StreamTableJoinOperatorSpec<K, M, R, JM> joinOpSpec, Context context) {
     this.joinOpSpec = joinOpSpec;
-    this.table = (ReadableTable) context.getTable(joinOpSpec.getTableSpec().getId());
+    this.table = (ReadableTable) context.getTaskContext().getTable(joinOpSpec.getTableSpec().getId());
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
-    this.joinOpSpec.getJoinFn().init(config, context);
+  protected void handleInit(Context context) {
+    this.joinOpSpec.getJoinFn().init(context);
   }
 
   @Override
index b175671..c09c5f8 100644 (file)
 package org.apache.samza.operators.impl;
 
 import com.google.common.base.Preconditions;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.functions.FoldLeftFunction;
 import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.functions.SupplierFunction;
 import org.apache.samza.operators.impl.store.TimeSeriesKey;
 import org.apache.samza.operators.impl.store.TimeSeriesStore;
 import org.apache.samza.operators.impl.store.TimeSeriesStoreImpl;
-import org.apache.samza.util.TimestampedValue;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.operators.spec.WindowOperatorSpec;
 import org.apache.samza.operators.triggers.FiringType;
@@ -45,9 +44,9 @@ import org.apache.samza.operators.windows.internal.WindowType;
 import org.apache.samza.storage.kv.ClosableIterator;
 import org.apache.samza.storage.kv.KeyValueStore;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.util.Clock;
+import org.apache.samza.util.TimestampedValue;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -111,23 +110,23 @@ public class WindowOperatorImpl<M, K> extends OperatorImpl<M, WindowPane<K, Obje
   }
 
   @Override
-  protected void handleInit(Config config, TaskContext context) {
+  protected void handleInit(Context context) {
 
     KeyValueStore<TimeSeriesKey<K>, Object> store =
-        (KeyValueStore<TimeSeriesKey<K>, Object>) context.getStore(windowOpSpec.getOpId());
+        (KeyValueStore<TimeSeriesKey<K>, Object>) context.getTaskContext().getStore(windowOpSpec.getOpId());
 
     if (initializer != null) {
-      initializer.init(config, context);
+      initializer.init(context);
     }
 
     if (keyFn != null) {
-      keyFn.init(config, context);
+      keyFn.init(context);
     }
 
     // For aggregating windows, we use the store in over-write mode since we only retain the aggregated
     // value. Else, we use the store in append-mode.
     if (foldLeftFn != null) {
-      foldLeftFn.init(config, context);
+      foldLeftFn.init(context);
       timeSeriesStore = new TimeSeriesStoreImpl(store, false);
     } else {
       timeSeriesStore = new TimeSeriesStoreImpl(store, true);
index 4e640dc..c1d62f5 100644 (file)
@@ -20,12 +20,11 @@ package org.apache.samza.operators.spec;
 
 import java.util.ArrayList;
 import java.util.Collection;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.functions.FilterFunction;
 import org.apache.samza.operators.functions.FlatMapFunction;
 import org.apache.samza.operators.functions.ScheduledFunction;
 import org.apache.samza.operators.functions.WatermarkFunction;
-import org.apache.samza.task.TaskContext;
 
 
 /**
@@ -50,8 +49,8 @@ class FilterOperatorSpec<M> extends StreamOperatorSpec<M, M> {
       }
 
       @Override
-      public void init(Config config, TaskContext context) {
-        filterFn.init(config, context);
+      public void init(Context context) {
+        filterFn.init(context);
       }
 
       @Override
index 6ce522f..d3a587a 100644 (file)
@@ -20,12 +20,11 @@ package org.apache.samza.operators.spec;
 
 import java.util.ArrayList;
 import java.util.Collection;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.functions.FlatMapFunction;
 import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.operators.functions.ScheduledFunction;
 import org.apache.samza.operators.functions.WatermarkFunction;
-import org.apache.samza.task.TaskContext;
 
 
 /**
@@ -53,8 +52,8 @@ class MapOperatorSpec<M, OM> extends StreamOperatorSpec<M, OM> {
       }
 
       @Override
-      public void init(Config config, TaskContext context) {
-        mapFn.init(config, context);
+      public void init(Context context) {
+        mapFn.init(context);
       }
 
       @Override
index 26e52f2..3149989 100644 (file)
@@ -23,6 +23,7 @@ import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -35,6 +36,11 @@ import org.apache.samza.config.JobCoordinatorConfig;
 import org.apache.samza.config.TaskConfigJava;
 import org.apache.samza.container.SamzaContainer;
 import org.apache.samza.container.SamzaContainerListener;
+import org.apache.samza.context.ApplicationContainerContext;
+import org.apache.samza.context.ApplicationContainerContextFactory;
+import org.apache.samza.context.ApplicationTaskContext;
+import org.apache.samza.context.ApplicationTaskContextFactory;
+import org.apache.samza.context.JobContextImpl;
 import org.apache.samza.coordinator.JobCoordinator;
 import org.apache.samza.coordinator.JobCoordinatorFactory;
 import org.apache.samza.coordinator.JobCoordinatorListener;
@@ -46,6 +52,8 @@ import org.apache.samza.util.ScalaJavaUtil;
 import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.Option;
+
 
 /**
  * StreamProcessor can be embedded in any application or executed in a distributed environment (aka cluster) as an
@@ -97,6 +105,16 @@ public class StreamProcessor {
   private final JobCoordinator jobCoordinator;
   private final ProcessorLifecycleListener processorListener;
   private final TaskFactory taskFactory;
+  /**
+   * Type parameter needs to be {@link ApplicationContainerContext} so that we can eventually call the base methods of
+   * the context object.
+   */
+  private final Optional<ApplicationContainerContextFactory<ApplicationContainerContext>> applicationDefinedContainerContextFactoryOptional;
+  /**
+   * Type parameter needs to be {@link ApplicationTaskContext} so that we can eventually call the base methods of the
+   * context object.
+   */
+  private final Optional<ApplicationTaskContextFactory<ApplicationTaskContext>> applicationDefinedTaskContextFactoryOptional;
   private final Map<String, MetricsReporter> customMetricsReporter;
   private final Config config;
   private final long taskShutdownMs;
@@ -143,57 +161,60 @@ public class StreamProcessor {
   JobCoordinatorListener jobCoordinatorListener = null;
 
   /**
-   * StreamProcessor encapsulates and manages the lifecycle of {@link JobCoordinator} and {@link SamzaContainer}.
-   *
-   * <p>
-   * On startup, StreamProcessor starts the JobCoordinator. Schedules the SamzaContainer to run in a ExecutorService
-   * when it receives new {@link JobModel} from JobCoordinator.
-   * <p>
-   *
-   * <b>Note:</b> Lifecycle of the ExecutorService is fully managed by the StreamProcessor.
+   * Same as {@link #StreamProcessor(Config, Map, TaskFactory, ProcessorLifecycleListener, JobCoordinator)}, except
+   * it creates a {@link JobCoordinator} instead of accepting it as an argument.
    *
-   * @param config configuration required to launch {@link JobCoordinator} and {@link SamzaContainer}.
-   * @param customMetricsReporters metricReporter instances that will be used by SamzaContainer and JobCoordinator to report metrics.
-   * @param taskFactory the {@link TaskFactory} to be used for creating task instances.
-   * @param processorListener listener to the StreamProcessor life cycle.
+   * Deprecated: Use {@link #StreamProcessor(Config, Map, TaskFactory, Optional, Optional,
+   * StreamProcessorLifecycleListenerFactory, JobCoordinator)} instead.
    */
+  @Deprecated
   public StreamProcessor(Config config, Map<String, MetricsReporter> customMetricsReporters, TaskFactory taskFactory,
       ProcessorLifecycleListener processorListener) {
     this(config, customMetricsReporters, taskFactory, processorListener, null);
   }
 
   /**
-   * Same as {@link #StreamProcessor(Config, Map, TaskFactory, ProcessorLifecycleListener)}, except the
-   * {@link JobCoordinator} is given for this {@link StreamProcessor}.
-   * @param config configuration required to launch {@link JobCoordinator} and {@link SamzaContainer}
-   * @param customMetricsReporters metric Reporter
-   * @param taskFactory task factory to instantiate the Task
+   * Same as {@link #StreamProcessor(Config, Map, TaskFactory, Optional, Optional,
+   * StreamProcessorLifecycleListenerFactory, JobCoordinator)}, with the following differences:
+   * <ol>
+   *   <li>Passes null for application-defined context factories</li>
+   *   <li>Accepts a {@link ProcessorLifecycleListener} directly instead of a
+   *   {@link StreamProcessorLifecycleListenerFactory}</li>
+   * </ol>
+   * Deprecated: Use {@link #StreamProcessor(Config, Map, TaskFactory, Optional, Optional,
+   * StreamProcessorLifecycleListenerFactory, JobCoordinator)} instead.
+   *
    * @param processorListener listener to the StreamProcessor life cycle
-   * @param jobCoordinator the instance of {@link JobCoordinator}
    */
+  @Deprecated
   public StreamProcessor(Config config, Map<String, MetricsReporter> customMetricsReporters, TaskFactory taskFactory,
       ProcessorLifecycleListener processorListener, JobCoordinator jobCoordinator) {
-    this(config, customMetricsReporters, taskFactory, sp -> processorListener, jobCoordinator);
+    this(config, customMetricsReporters, taskFactory, Optional.empty(), Optional.empty(), sp -> processorListener,
+        jobCoordinator);
   }
 
   /**
-   * Same as {@link #StreamProcessor(Config, Map, TaskFactory, ProcessorLifecycleListener, JobCoordinator)}, except
-   * there is a {@link StreamProcessorLifecycleListenerFactory} as input instead of {@link ProcessorLifecycleListener}.
-   * This is useful to create a {@link ProcessorLifecycleListener} with a reference to this {@link StreamProcessor}
+   * Builds a {@link StreamProcessor} with full specification of processing components.
    *
    * @param config configuration required to launch {@link JobCoordinator} and {@link SamzaContainer}
-   * @param customMetricsReporters metric Reporter
+   * @param customMetricsReporters registered with the metrics system to report metrics
    * @param taskFactory task factory to instantiate the Task
-   * @param listenerFactory listener to the StreamProcessor life cycle
+   * @param applicationDefinedContainerContextFactoryOptional optional factory for application-defined container context
+   * @param applicationDefinedTaskContextFactoryOptional optional factory for application-defined task context
+   * @param listenerFactory factory for creating a listener to the StreamProcessor life cycle
    * @param jobCoordinator the instance of {@link JobCoordinator}
    */
   public StreamProcessor(Config config, Map<String, MetricsReporter> customMetricsReporters, TaskFactory taskFactory,
+      Optional<ApplicationContainerContextFactory<ApplicationContainerContext>> applicationDefinedContainerContextFactoryOptional,
+      Optional<ApplicationTaskContextFactory<ApplicationTaskContext>> applicationDefinedTaskContextFactoryOptional,
       StreamProcessorLifecycleListenerFactory listenerFactory, JobCoordinator jobCoordinator) {
     Preconditions.checkNotNull(listenerFactory, "StreamProcessorListenerFactory cannot be null.");
-    this.taskFactory = taskFactory;
     this.config = config;
-    this.taskShutdownMs = new TaskConfigJava(config).getShutdownMs();
     this.customMetricsReporter = customMetricsReporters;
+    this.taskFactory = taskFactory;
+    this.applicationDefinedContainerContextFactoryOptional = applicationDefinedContainerContextFactoryOptional;
+    this.applicationDefinedTaskContextFactoryOptional = applicationDefinedTaskContextFactoryOptional;
+    this.taskShutdownMs = new TaskConfigJava(config).getShutdownMs();
     this.jobCoordinator = (jobCoordinator != null) ? jobCoordinator : createJobCoordinator();
     this.jobCoordinatorListener = createJobCoordinatorListener();
     this.jobCoordinator.setListener(jobCoordinatorListener);
@@ -283,7 +304,10 @@ public class StreamProcessor {
 
   @VisibleForTesting
   SamzaContainer createSamzaContainer(String processorId, JobModel jobModel) {
-    return SamzaContainer.apply(processorId, jobModel, config, ScalaJavaUtil.toScalaMap(customMetricsReporter), taskFactory);
+    return SamzaContainer.apply(processorId, jobModel, ScalaJavaUtil.toScalaMap(this.customMetricsReporter),
+        this.taskFactory, JobContextImpl.fromConfigWithDefaults(this.config),
+        Option.apply(this.applicationDefinedContainerContextFactoryOptional.orElse(null)),
+        Option.apply(this.applicationDefinedTaskContextFactoryOptional.orElse(null)));
   }
 
   private JobCoordinator createJobCoordinator() {
index 7100482..a5eeba1 100644 (file)
@@ -167,7 +167,8 @@ public class LocalApplicationRunner implements ApplicationRunner {
     // TODO: the null processorId has to be fixed after SAMZA-1835
     appDesc.getMetricsReporterFactories().forEach((name, factory) ->
         reporters.put(name, factory.getMetricsReporter(name, null, config)));
-    return new StreamProcessor(config, reporters, taskFactory, listenerFactory, null);
+    return new StreamProcessor(config, reporters, taskFactory, appDesc.getApplicationContainerContextFactory(),
+        appDesc.getApplicationTaskContextFactory(), listenerFactory, null);
   }
 
   /**
index add7e69..94ff1eb 100644 (file)
@@ -25,8 +25,8 @@ import java.util.Random;
 import org.slf4j.MDC;
 import org.apache.samza.SamzaException;
 import org.apache.samza.application.ApplicationDescriptor;
-import org.apache.samza.application.ApplicationDescriptorUtil;
 import org.apache.samza.application.ApplicationDescriptorImpl;
+import org.apache.samza.application.ApplicationDescriptorUtil;
 import org.apache.samza.application.ApplicationUtil;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
@@ -36,6 +36,7 @@ import org.apache.samza.container.ContainerHeartbeatMonitor;
 import org.apache.samza.container.SamzaContainer;
 import org.apache.samza.container.SamzaContainer$;
 import org.apache.samza.container.SamzaContainerListener;
+import org.apache.samza.context.JobContextImpl;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.metrics.MetricsReporter;
 import org.apache.samza.task.TaskFactory;
@@ -44,6 +45,8 @@ import org.apache.samza.util.SamzaUncaughtExceptionHandler;
 import org.apache.samza.util.ScalaJavaUtil;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.Option;
+
 
 /**
  * Launches and manages the lifecycle for {@link SamzaContainer}s in YARN.
@@ -93,9 +96,11 @@ public class LocalContainerRunner {
     SamzaContainer container = SamzaContainer$.MODULE$.apply(
         containerId,
         jobModel,
-        config,
         ScalaJavaUtil.toScalaMap(loadMetricsReporters(appDesc, containerId, config)),
-        taskFactory);
+        taskFactory,
+        JobContextImpl.fromConfigWithDefaults(config),
+        Option.apply(appDesc.getApplicationContainerContextFactory().orElse(null)),
+        Option.apply(appDesc.getApplicationTaskContextFactory().orElse(null)));
 
     ProcessorLifecycleListener listener = appDesc.getProcessorLifecycleListenerFactory()
         .createInstance(new ProcessorContext() { }, config);
index 9a76d75..be074ee 100644 (file)
@@ -27,13 +27,14 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
-
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JavaStorageConfig;
 import org.apache.samza.config.JavaSystemConfig;
 import org.apache.samza.config.StorageConfig;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.ContainerContextImpl;
+import org.apache.samza.context.JobContextImpl;
 import org.apache.samza.coordinator.JobModelManager;
 import org.apache.samza.coordinator.stream.CoordinatorStreamManager;
 import org.apache.samza.job.model.ContainerModel;
@@ -209,8 +210,7 @@ public class StorageRecovery extends CommandLine {
 
     for (ContainerModel containerModel : containers.values()) {
       HashMap<String, StorageEngine> taskStores = new HashMap<String, StorageEngine>();
-      SamzaContainerContext containerContext = new SamzaContainerContext(containerModel.getId(), jobConfig, containerModel.getTasks()
-          .keySet(), new MetricsRegistryMap());
+      ContainerContext containerContext = new ContainerContextImpl(containerModel, new MetricsRegistryMap());
 
       for (TaskModel taskModel : containerModel.getTasks().values()) {
         HashMap<String, SystemConsumer> storeConsumers = getStoreConsumers();
@@ -233,6 +233,7 @@ public class StorageRecovery extends CommandLine {
                 null,
                 new MetricsRegistryMap(),
                 changeLogSystemStreamPartition,
+                JobContextImpl.fromConfigWithDefaults(jobConfig),
                 containerContext);
             taskStores.put(storeName, storageEngine);
           }
index ae72414..d7b15a4 100644 (file)
  */
 package org.apache.samza.table;
 
-import java.util.HashMap;
-import java.util.Map;
-
+import com.google.common.base.Preconditions;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JavaTableConfig;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.serializers.KVSerde;
 import org.apache.samza.serializers.Serde;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.base.Preconditions;
+import java.util.HashMap;
+import java.util.Map;
 
 
 /**
@@ -97,12 +95,10 @@ public class TableManager {
 
   /**
    * Initialize table providers with container and task contexts
-   * @param containerContext context for the Samza container
-   * @param taskContext context for the current task, nullable for global tables
+   * @param context context for the task
    */
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-    Preconditions.checkNotNull(containerContext, "null container context.");
-    tableContexts.values().forEach(ctx -> ctx.tableProvider.init(containerContext, taskContext));
+  public void init(Context context) {
+    tableContexts.values().forEach(ctx -> ctx.tableProvider.init(context));
     initialized = true;
   }
 
index b7aa33c..32d2bed 100644 (file)
 
 package org.apache.samza.table.caching;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.stream.Collectors;
-
+import com.google.common.base.Preconditions;
 import org.apache.samza.SamzaException;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.utils.DefaultTableReadMetrics;
 import org.apache.samza.table.utils.DefaultTableWriteMetrics;
 import org.apache.samza.table.utils.TableMetricsUtil;
-import org.apache.samza.task.TaskContext;
 
-import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.stream.Collectors;
 
 
 /**
@@ -91,10 +89,10 @@ public class CachingTable<K, V> implements ReadWriteTable<K, V> {
    * {@inheritDoc}
    */
   @Override
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-    readMetrics = new DefaultTableReadMetrics(containerContext, taskContext, this, tableId);
-    writeMetrics = new DefaultTableWriteMetrics(containerContext, taskContext, this, tableId);
-    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(containerContext, taskContext, this, tableId);
+  public void init(Context context) {
+    readMetrics = new DefaultTableReadMetrics(context, this, tableId);
+    writeMetrics = new DefaultTableWriteMetrics(context, this, tableId);
+    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(context, this, tableId);
     tableMetricsUtil.newGauge("hit-rate", () -> hitRate());
     tableMetricsUtil.newGauge("miss-rate", () -> missRate());
     tableMetricsUtil.newGauge("req-count", () -> requestCount());
index d5f7767..c959a56 100644 (file)
@@ -54,13 +54,13 @@ public class CachingTableProvider extends BaseTableProvider {
   @Override
   public Table getTable() {
     String realTableId = tableSpec.getConfig().get(REAL_TABLE_ID);
-    ReadableTable table = (ReadableTable) taskContext.getTable(realTableId);
+    ReadableTable table = (ReadableTable) this.context.getTaskContext().getTable(realTableId);
 
     String cacheTableId = tableSpec.getConfig().get(CACHE_TABLE_ID);
     ReadWriteTable cache;
 
     if (cacheTableId != null) {
-      cache = (ReadWriteTable) taskContext.getTable(cacheTableId);
+      cache = (ReadWriteTable) this.context.getTaskContext().getTable(cacheTableId);
     } else {
       cache = createDefaultCacheTable(realTableId);
       defaultCaches.add(cache);
@@ -68,7 +68,7 @@ public class CachingTableProvider extends BaseTableProvider {
 
     boolean isWriteAround = Boolean.parseBoolean(tableSpec.getConfig().get(WRITE_AROUND));
     CachingTable cachingTable = new CachingTable(tableSpec.getId(), table, cache, isWriteAround);
-    cachingTable.init(containerContext, taskContext);
+    cachingTable.init(this.context);
     return cachingTable;
   }
 
@@ -97,7 +97,7 @@ public class CachingTableProvider extends BaseTableProvider {
         readTtlMs, writeTtlMs, cacheSize));
 
     GuavaCacheTable cacheTable = new GuavaCacheTable(tableId + "-def-cache", cacheBuilder.build());
-    cacheTable.init(containerContext, taskContext);
+    cacheTable.init(this.context);
 
     return cacheTable;
   }
index a8beb3b..5f77ee4 100644 (file)
 
 package org.apache.samza.table.caching.guava;
 
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-
+import com.google.common.cache.Cache;
 import org.apache.samza.SamzaException;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.table.utils.TableMetricsUtil;
-import org.apache.samza.task.TaskContext;
 
-import com.google.common.cache.Cache;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 
 
 /**
@@ -54,8 +52,8 @@ public class GuavaCacheTable<K, V> implements ReadWriteTable<K, V> {
    * {@inheritDoc}
    */
   @Override
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(containerContext, taskContext, this, tableId);
+  public void init(Context context) {
+    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(context, this, tableId);
     // hit- and miss-rate are provided by CachingTable.
     tableMetricsUtil.newGauge("evict-count", () -> cache.stats().evictionCount());
   }
index 1513249..39f332e 100644 (file)
@@ -47,7 +47,7 @@ public class GuavaCacheTableProvider extends BaseTableProvider {
   public Table getTable() {
     Cache guavaCache = SerdeUtils.deserialize(GUAVA_CACHE, tableSpec.getConfig().get(GUAVA_CACHE));
     GuavaCacheTable table = new GuavaCacheTable(tableSpec.getId(), guavaCache);
-    table.init(containerContext, taskContext);
+    table.init(this.context);
     guavaTables.add(table);
     return table;
   }
index 9ef4c1b..4cbc270 100644 (file)
 
 package org.apache.samza.table.remote;
 
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.stream.Collectors;
-
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import org.apache.samza.SamzaException;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.table.utils.DefaultTableWriteMetrics;
 import org.apache.samza.table.utils.TableMetricsUtil;
-import org.apache.samza.task.TaskContext;
 
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.stream.Collectors;
 
 
 /**
@@ -63,10 +61,10 @@ public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implem
    * {@inheritDoc}
    */
   @Override
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-    super.init(containerContext, taskContext);
-    writeMetrics = new DefaultTableWriteMetrics(containerContext, taskContext, this, tableId);
-    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(containerContext, taskContext, this, tableId);
+  public void init(Context context) {
+    super.init(context);
+    writeMetrics = new DefaultTableWriteMetrics(context, this, tableId);
+    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(context, this, tableId);
     writeRateLimiter.setTimerMetric(tableMetricsUtil.newTimer("put-throttle-ns"));
   }
 
index b3d82f3..9487e39 100644 (file)
 
 package org.apache.samza.table.remote;
 
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.function.BiFunction;
-import java.util.function.Function;
-
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import org.apache.samza.SamzaException;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.metrics.Timer;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.utils.DefaultTableReadMetrics;
 import org.apache.samza.table.utils.TableMetricsUtil;
-import org.apache.samza.task.TaskContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BiFunction;
+import java.util.function.Function;
 
 
 /**
@@ -110,9 +108,9 @@ public class RemoteReadableTable<K, V> implements ReadableTable<K, V> {
    * {@inheritDoc}
    */
   @Override
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-    readMetrics = new DefaultTableReadMetrics(containerContext, taskContext, this, tableId);
-    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(containerContext, taskContext, this, tableId);
+  public void init(Context context) {
+    readMetrics = new DefaultTableReadMetrics(context, this, tableId);
+    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(context, this, tableId);
     readRateLimiter.setTimerMetric(tableMetricsUtil.newTimer("get-throttle-ns"));
   }
 
index cae0bbd..9415e70 100644 (file)
 
 package org.apache.samza.table.remote;
 
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
-
 import org.apache.samza.table.Table;
 import org.apache.samza.table.TableSpec;
 import org.apache.samza.table.retry.RetriableReadFunction;
@@ -37,6 +29,14 @@ import org.apache.samza.table.utils.SerdeUtils;
 import org.apache.samza.table.utils.TableMetricsUtil;
 import org.apache.samza.util.RateLimiter;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+
 import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_READ_TAG;
 import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_WRITE_TAG;
 
@@ -83,7 +83,7 @@ public class RemoteTableProvider extends BaseTableProvider {
     TableReadFunction readFn = getReadFn();
     RateLimiter rateLimiter = deserializeObject(RATE_LIMITER);
     if (rateLimiter != null) {
-      rateLimiter.init(containerContext.config, taskContext);
+      rateLimiter.init(this.context);
     }
     TableRateLimiter.CreditFunction<?, ?> readCreditFn = deserializeObject(READ_CREDIT_FN);
     TableRateLimiter readRateLimiter = new TableRateLimiter(tableSpec.getId(), rateLimiter, readCreditFn, RL_READ_TAG);
@@ -150,7 +150,7 @@ public class RemoteTableProvider extends BaseTableProvider {
           writeRateLimiter, tableExecutors.get(tableId), callbackExecutors.get(tableId));
     }
 
-    TableMetricsUtil metricsUtil = new TableMetricsUtil(containerContext, taskContext, table, tableId);
+    TableMetricsUtil metricsUtil = new TableMetricsUtil(this.context, table, tableId);
     if (readRetryPolicy != null) {
       ((RetriableReadFunction) readFn).setMetrics(metricsUtil);
     }
@@ -158,7 +158,7 @@ public class RemoteTableProvider extends BaseTableProvider {
       ((RetriableWriteFunction) writeFn).setMetrics(metricsUtil);
     }
 
-    table.init(containerContext, taskContext);
+    table.init(this.context);
     tables.add(table);
     return table;
   }
@@ -184,7 +184,7 @@ public class RemoteTableProvider extends BaseTableProvider {
   private TableReadFunction<?, ?> getReadFn() {
     TableReadFunction<?, ?> readFn = deserializeObject(READ_FN);
     if (readFn != null) {
-      readFn.init(containerContext.config, taskContext);
+      readFn.init(this.context);
     }
     return readFn;
   }
@@ -192,7 +192,7 @@ public class RemoteTableProvider extends BaseTableProvider {
   private TableWriteFunction<?, ?> getWriteFn() {
     TableWriteFunction<?, ?> writeFn = deserializeObject(WRITE_FN);
     if (writeFn != null) {
-      writeFn.init(containerContext.config, taskContext);
+      writeFn.init(this.context);
     }
     return writeFn;
   }
index 960e2a4..dfbd835 100644 (file)
@@ -22,10 +22,9 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JavaTableConfig;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.table.TableProvider;
 import org.apache.samza.table.TableSpec;
-import org.apache.samza.task.TaskContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -39,8 +38,7 @@ abstract public class BaseTableProvider implements TableProvider {
 
   final protected TableSpec tableSpec;
 
-  protected SamzaContainerContext containerContext;
-  protected TaskContext taskContext;
+  protected Context context;
 
   public BaseTableProvider(TableSpec tableSpec) {
     this.tableSpec = tableSpec;
@@ -50,9 +48,8 @@ abstract public class BaseTableProvider implements TableProvider {
    * {@inheritDoc}
    */
   @Override
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-    this.containerContext = containerContext;
-    this.taskContext = taskContext;
+  public void init(Context context) {
+    this.context = context;
   }
 
   /**
index 2acd082..090c8c1 100644 (file)
  */
 package org.apache.samza.table.utils;
 
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.Timer;
 import org.apache.samza.table.Table;
-import org.apache.samza.task.TaskContext;
 
 
 /**
@@ -39,14 +38,12 @@ public class DefaultTableReadMetrics {
   /**
    * Constructor based on container and task container context
    *
-   * @param containerContext container context
-   * @param taskContext task context
+   * @param context {@link Context} for this task
    * @param table underlying table
    * @param tableId table Id
    */
-  public DefaultTableReadMetrics(SamzaContainerContext containerContext, TaskContext taskContext,
-      Table table, String tableId) {
-    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(containerContext, taskContext, table, tableId);
+  public DefaultTableReadMetrics(Context context, Table table, String tableId) {
+    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(context, table, tableId);
     getNs = tableMetricsUtil.newTimer("get-ns");
     getAllNs = tableMetricsUtil.newTimer("getAll-ns");
     numGets = tableMetricsUtil.newCounter("num-gets");
index a32d6d5..69d4ef2 100644 (file)
  */
 package org.apache.samza.table.utils;
 
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.Timer;
 import org.apache.samza.table.Table;
-import org.apache.samza.task.TaskContext;
 
 
 public class DefaultTableWriteMetrics {
@@ -43,14 +42,12 @@ public class DefaultTableWriteMetrics {
   /**
    * Utility class that contains the default set of write metrics.
    *
-   * @param containerContext container context
-   * @param taskContext task context
+   * @param context {@link Context} for this task
    * @param table underlying table
    * @param tableId table Id
    */
-  public DefaultTableWriteMetrics(SamzaContainerContext containerContext, TaskContext taskContext,
-      Table table, String tableId) {
-    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(containerContext, taskContext, table, tableId);
+  public DefaultTableWriteMetrics(Context context, Table table, String tableId) {
+    TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(context, table, tableId);
     putNs = tableMetricsUtil.newTimer("put-ns");
     putAllNs = tableMetricsUtil.newTimer("putAll-ns");
     deleteNs = tableMetricsUtil.newTimer("delete-ns");
index 6805c64..1b19272 100644 (file)
 
 package org.apache.samza.table.utils;
 
-import java.util.function.Supplier;
-
 import com.google.common.base.Preconditions;
-
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.Gauge;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.Timer;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.caching.SupplierGauge;
-import org.apache.samza.task.TaskContext;
+
+import java.util.function.Supplier;
 
 
 /**
@@ -46,21 +44,16 @@ public class TableMetricsUtil {
   /**
    * Constructor based on container context
    *
-   * @param containerContext container context
-   * @param taskContext task context
+   * @param context {@link Context} for this task
    * @param table underlying table
    * @param tableId table Id
    */
-  public TableMetricsUtil(SamzaContainerContext containerContext, TaskContext taskContext,
-      Table table, String tableId) {
-
-    Preconditions.checkNotNull(containerContext);
+  public TableMetricsUtil(Context context, Table table, String tableId) {
+    Preconditions.checkNotNull(context);
     Preconditions.checkNotNull(table);
     Preconditions.checkNotNull(tableId);
 
-    this.metricsRegistry = taskContext == null // The table is at container level, when the task
-        ? containerContext.metricsRegistry     // context passed in is null
-        : taskContext.getMetricsRegistry();
+    this.metricsRegistry = context.getTaskContext().getTaskMetricsRegistry();
     this.groupName = table.getClass().getSimpleName();
     this.tableId = tableId;
   }
index 111869c..6c255f1 100644 (file)
@@ -33,16 +33,15 @@ import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
-
 import org.apache.samza.SamzaException;
 import org.apache.samza.container.SamzaContainerMetrics;
 import org.apache.samza.container.TaskInstance;
 import org.apache.samza.container.TaskInstanceMetrics;
 import org.apache.samza.container.TaskName;
-import org.apache.samza.util.HighResolutionClock;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemConsumers;
 import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.util.HighResolutionClock;
 import org.apache.samza.util.Throttleable;
 import org.apache.samza.util.ThrottlingScheduler;
 import org.slf4j.Logger;
@@ -374,7 +373,7 @@ public class AsyncRunLoop implements Runnable, Throttleable {
         }, commitMs, commitMs, TimeUnit.MILLISECONDS);
       }
 
-      final EpochTimeScheduler epochTimeScheduler = task.context().getTimerScheduler();
+      final EpochTimeScheduler epochTimeScheduler = task.epochTimeScheduler();
       if (epochTimeScheduler != null) {
         epochTimeScheduler.registerListener(() -> {
             state.needScheduler();
index e2fea95..fcd9766 100644 (file)
@@ -20,7 +20,7 @@
 package org.apache.samza.task;
 
 import java.util.concurrent.ExecutorService;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.system.IncomingMessageEnvelope;
 
 
@@ -40,9 +40,9 @@ public class AsyncStreamTaskAdapter implements AsyncStreamTask, InitableTask, Wi
   }
 
   @Override
-  public void init(Config config, TaskContext context) throws Exception {
+  public void init(Context context) throws Exception {
     if (wrappedTask instanceof InitableTask) {
-      ((InitableTask) wrappedTask).init(config, context);
+      ((InitableTask) wrappedTask).init(context);
     }
   }
 
index aa896c2..218ba5d 100644 (file)
  */
 package org.apache.samza.task;
 
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.OperatorSpecGraph;
-import org.apache.samza.system.EndOfStreamMessage;
-import org.apache.samza.system.MessageType;
-import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.impl.InputOperatorImpl;
 import org.apache.samza.operators.impl.OperatorImplGraph;
+import org.apache.samza.system.EndOfStreamMessage;
 import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.MessageType;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.WatermarkMessage;
 import org.apache.samza.util.Clock;
@@ -42,8 +41,6 @@ public class StreamOperatorTask implements StreamTask, InitableTask, WindowableT
   private static final Logger LOG = LoggerFactory.getLogger(StreamOperatorTask.class);
 
   private final OperatorSpecGraph specGraph;
-  // TODO: to be replaced by proper scope of shared context factory in SAMZA-1714
-  private final ContextManager contextManager;
   private final Clock clock;
 
   private OperatorImplGraph operatorImplGraph;
@@ -52,17 +49,15 @@ public class StreamOperatorTask implements StreamTask, InitableTask, WindowableT
    * Constructs an adaptor task to run the user-implemented {@link OperatorSpecGraph}.
    * @param specGraph the serialized version of user-implemented {@link OperatorSpecGraph}
    *                  that includes the logical DAG
-   * @param contextManager the {@link ContextManager} used to set up the shared context used by operators in the DAG
    * @param clock the {@link Clock} to use for time-keeping
    */
-  public StreamOperatorTask(OperatorSpecGraph specGraph, ContextManager contextManager, Clock clock) {
+  public StreamOperatorTask(OperatorSpecGraph specGraph, Clock clock) {
     this.specGraph = specGraph.clone();
-    this.contextManager = contextManager;
     this.clock = clock;
   }
 
-  public StreamOperatorTask(OperatorSpecGraph specGraph, ContextManager contextManager) {
-    this(specGraph, contextManager, SystemClock.instance());
+  public StreamOperatorTask(OperatorSpecGraph specGraph) {
+    this(specGraph, SystemClock.instance());
   }
 
   /**
@@ -75,20 +70,13 @@ public class StreamOperatorTask implements StreamTask, InitableTask, WindowableT
    * an immutable {@link OperatorSpecGraph} accordingly, which is passed in to this class to create the {@link OperatorImplGraph}
    * corresponding to the logical DAG.
    *
-   * @param config allows accessing of fields in the configuration files that this StreamTask is specified in
    * @param context allows initializing and accessing contextual data of this StreamTask
    * @throws Exception in case of initialization errors
    */
   @Override
-  public final void init(Config config, TaskContext context) throws Exception {
-
-    // get the user-implemented per task context manager and initialize it
-    if (this.contextManager != null) {
-      this.contextManager.init(config, context);
-    }
-
+  public final void init(Context context) throws Exception {
     // create the operator impl DAG corresponding to the logical operator spec DAG
-    this.operatorImplGraph = new OperatorImplGraph(specGraph, config, context, clock);
+    this.operatorImplGraph = new OperatorImplGraph(specGraph, context, clock);
   }
 
   /**
@@ -133,9 +121,6 @@ public class StreamOperatorTask implements StreamTask, InitableTask, WindowableT
 
   @Override
   public void close() throws Exception {
-    if (this.contextManager != null) {
-      this.contextManager.close();
-    }
     if (operatorImplGraph != null) {
       operatorImplGraph.close();
     }
index 834777b..c312fac 100644 (file)
@@ -48,8 +48,8 @@ public class TaskFactoryUtil {
     if (appDesc instanceof TaskApplicationDescriptorImpl) {
       return ((TaskApplicationDescriptorImpl) appDesc).getTaskFactory();
     } else if (appDesc instanceof StreamApplicationDescriptorImpl) {
-      return (StreamTaskFactory) () -> new StreamOperatorTask(((StreamApplicationDescriptorImpl) appDesc).getOperatorSpecGraph(),
-          ((StreamApplicationDescriptorImpl) appDesc).getContextManager());
+      return (StreamTaskFactory) () -> new StreamOperatorTask(
+          ((StreamApplicationDescriptorImpl) appDesc).getOperatorSpecGraph());
     }
     throw new IllegalArgumentException(String.format("ApplicationDescriptorImpl has to be either TaskApplicationDescriptorImpl or "
         + "StreamApplicationDescriptorImpl. class %s is not supported", appDesc.getClass().getName()));
index 1cf9a9c..a91d663 100644 (file)
  */
 package org.apache.samza.util;
 
+import com.google.common.base.Preconditions;
+import com.google.common.base.Stopwatch;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.context.Context;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.util.Collections;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.samza.config.Config;
-import org.apache.samza.task.TaskContext;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import com.google.common.base.Preconditions;
-import com.google.common.base.Stopwatch;
-
 import static java.util.concurrent.TimeUnit.NANOSECONDS;
 
 
@@ -106,16 +105,15 @@ public class EmbeddedTaggedRateLimiter implements RateLimiter {
   }
 
   @Override
-  public void init(Config config, TaskContext taskContext) {
+  public void init(Context context) {
     this.tagToRateLimiterMap = Collections.unmodifiableMap(tagToTargetRateMap.entrySet().stream()
         .map(e -> {
             String tag = e.getKey();
-            int effectiveRate = e.getValue();
-            if (taskContext != null) {
-              effectiveRate /= taskContext.getSamzaContainerContext().taskNames.size();
-              LOGGER.info(String.format("Effective rate limit for task %s and tag %s is %d",
-                  taskContext.getTaskName(), tag, effectiveRate));
-            }
+            int numTasksInContainer = context.getContainerContext().getContainerModel().getTasks().keySet().size();
+            int effectiveRate = e.getValue() / numTasksInContainer;
+            TaskName taskName = context.getTaskContext().getTaskModel().getTaskName();
+            LOGGER.info(String.format("Effective rate limit for task %s and tag %s is %d", taskName, tag,
+                effectiveRate));
             return new ImmutablePair<>(tag, com.google.common.util.concurrent.RateLimiter.create(effectiveRate));
           })
         .collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue))
index 3c10aae..3292986 100644 (file)
@@ -43,6 +43,7 @@ import org.apache.samza.config._
 import org.apache.samza.container.disk.DiskSpaceMonitor.Listener
 import org.apache.samza.container.disk.{DiskQuotaPolicyFactory, DiskSpaceMonitor, NoThrottlingDiskQuotaPolicyFactory, PollingScanDiskSpaceMonitor}
 import org.apache.samza.container.host.{StatisticsMonitorImpl, SystemMemoryStatistics, SystemStatisticsMonitor}
+import org.apache.samza.context._
 import org.apache.samza.job.model.{ContainerModel, JobModel}
 import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, MetricsReporter}
 import org.apache.samza.serializers._
@@ -122,9 +123,13 @@ object SamzaContainer extends Logging {
   def apply(
     containerId: String,
     jobModel: JobModel,
-    config: Config,
     customReporters: Map[String, MetricsReporter] = Map[String, MetricsReporter](),
-    taskFactory: TaskFactory[_]) = {
+    taskFactory: TaskFactory[_],
+    jobContext: JobContext,
+    applicationContainerContextFactoryOption: Option[ApplicationContainerContextFactory[ApplicationContainerContext]],
+    applicationTaskContextFactoryOption: Option[ApplicationTaskContextFactory[ApplicationTaskContext]]
+  ) = {
+    val config = jobContext.getConfig
     val containerModel = jobModel.getContainers.get(containerId)
     val containerName = "samza-container-%s" format containerId
     val maxChangeLogStreamPartitions = jobModel.maxChangeLogStreamPartitions
@@ -488,8 +493,10 @@ object SamzaContainer extends Logging {
       .asScala
       .map(_.getTaskName)
       .toSet
-    val containerContext = new SamzaContainerContext(containerId, config, taskNames.asJava, samzaContainerMetrics.registry)
 
+    val containerContext = new ContainerContextImpl(containerModel, samzaContainerMetrics.registry)
+    val applicationContainerContextOption = applicationContainerContextFactoryOption
+      .map(_.create(jobContext, containerContext))
 
     val storeWatchPaths = new util.HashSet[Path]()
 
@@ -571,6 +578,7 @@ object SamzaContainer extends Logging {
               collector,
               taskInstanceMetrics.registry,
               changeLogSystemStreamPartition,
+              jobContext,
               containerContext)
             (storeName, storageEngine)
         }
@@ -635,13 +643,11 @@ object SamzaContainer extends Logging {
 
       def createTaskInstance(task: Any): TaskInstance = new TaskInstance(
           task = task,
-          taskName = taskName,
-          config = config,
+          taskModel = taskModel,
           metrics = taskInstanceMetrics,
           systemAdmins = systemAdmins,
           consumerMultiplexer = consumerMultiplexer,
           collector = collector,
-          containerContext = containerContext,
           offsetManager = offsetManager,
           storageManager = storageManager,
           tableManager = tableManager,
@@ -652,7 +658,11 @@ object SamzaContainer extends Logging {
           streamMetadataCache = streamMetadataCache,
           timerExecutor = timerExecutor,
           sideInputSSPs = taskSideInputSSPs,
-          sideInputStorageManager = sideInputStorageManager)
+          sideInputStorageManager = sideInputStorageManager,
+          jobContext = jobContext,
+          containerContext = containerContext,
+          applicationContainerContextOption = applicationContainerContextOption,
+          applicationTaskContextFactoryOption = applicationTaskContextFactoryOption)
 
       val taskInstance = createTaskInstance(task)
 
@@ -708,7 +718,7 @@ object SamzaContainer extends Logging {
     info("Samza container setup complete.")
 
     new SamzaContainer(
-      containerContext = containerContext,
+      config = config,
       taskInstances = taskInstances,
       runLoop = runLoop,
       systemAdmins = systemAdmins,
@@ -722,10 +732,11 @@ object SamzaContainer extends Logging {
       diskSpaceMonitor = diskSpaceMonitor,
       hostStatisticsMonitor = memoryStatisticsMonitor,
       taskThreadPool = taskThreadPool,
-      timerExecutor = timerExecutor)
+      timerExecutor = timerExecutor,
+      containerContext = containerContext,
+      applicationContainerContextOption = applicationContainerContextOption)
   }
 
-
   /**
     * Builds the set of SSPs for all changelogs on this container.
     */
@@ -741,7 +752,7 @@ object SamzaContainer extends Logging {
 }
 
 class SamzaContainer(
-  containerContext: SamzaContainerContext,
+  config: Config,
   taskInstances: Map[TaskName, TaskInstance],
   runLoop: Runnable,
   systemAdmins: SystemAdmins,
@@ -756,12 +767,14 @@ class SamzaContainer(
   reporters: Map[String, MetricsReporter] = Map(),
   jvm: JvmMetrics = null,
   taskThreadPool: ExecutorService = null,
-  timerExecutor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor) extends Runnable with Logging {
+  timerExecutor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor,
+  containerContext: ContainerContext,
+  applicationContainerContextOption: Option[ApplicationContainerContext]) extends Runnable with Logging {
 
-  val shutdownMs = containerContext.config.getShutdownMs.getOrElse(TaskConfigJava.DEFAULT_TASK_SHUTDOWN_MS)
+  val shutdownMs = config.getShutdownMs.getOrElse(TaskConfigJava.DEFAULT_TASK_SHUTDOWN_MS)
   var shutdownHookThread: Thread = null
   var jmxServer: JmxServer = null
-  val isAutoCommitEnabled = containerContext.config.isAutoCommitEnabled
+  val isAutoCommitEnabled = config.isAutoCommitEnabled
 
   @volatile private var status = SamzaContainerStatus.NOT_STARTED
   private var exceptionSeen: Throwable = null
@@ -789,6 +802,7 @@ class SamzaContainer(
       status = SamzaContainerStatus.STARTING
 
       jmxServer = new JmxServer()
+      applicationContainerContextOption.foreach(_.start)
 
       startMetrics
       startDiagnostics
@@ -841,6 +855,8 @@ class SamzaContainer(
       shutdownSecurityManger
       shutdownAdmins
 
+      applicationContainerContextOption.foreach(_.stop)
+
       if (!status.equals(SamzaContainerStatus.FAILED)) {
         status = SamzaContainerStatus.STOPPED
       }
@@ -930,18 +946,18 @@ class SamzaContainer(
   }
 
   def startDiagnostics {
-    if (containerContext.config.getDiagnosticsEnabled) {
+    if (config.getDiagnosticsEnabled) {
       info("Starting diagnostics.")
 
       try {
-        val diagnosticsAppender = Class.forName(containerContext.config.getDiagnosticsAppenderClass).
+        val diagnosticsAppender = Class.forName(config.getDiagnosticsAppenderClass).
           getDeclaredConstructor(classOf[SamzaContainerMetrics]).newInstance(this.metrics);
       }
       catch {
         case e@(_: ClassNotFoundException | _: InstantiationException | _: InvocationTargetException) => {
           error("Failed to instantiate diagnostic appender", e)
           throw new ConfigException("Failed to instantiate diagnostic appender class " +
-            containerContext.config.getDiagnosticsAppenderClass, e)
+            config.getDiagnosticsAppenderClass, e)
         }
       }
     }
@@ -958,24 +974,25 @@ class SamzaContainer(
   }
 
   def storeContainerLocality {
-    val isHostAffinityEnabled: Boolean = new ClusterManagerConfig(containerContext.config).getHostAffinityEnabled
+    val isHostAffinityEnabled: Boolean = new ClusterManagerConfig(config).getHostAffinityEnabled
     if (isHostAffinityEnabled) {
-      val localityManager: LocalityManager = new LocalityManager(containerContext.config, containerContext.metricsRegistry)
-      val containerName = "SamzaContainer-" + String.valueOf(containerContext.id)
+      val localityManager: LocalityManager = new LocalityManager(config, containerContext.getContainerMetricsRegistry)
+      val containerId = containerContext.getContainerModel.getId
+      val containerName = "SamzaContainer-" + containerId
       info("Registering %s with metadata store" format containerName)
       try {
         val hostInet = Util.getLocalHost
         val jmxUrl = if (jmxServer != null) jmxServer.getJmxUrl else ""
         val jmxTunnelingUrl = if (jmxServer != null) jmxServer.getTunnelingJmxUrl else ""
         info("Writing container locality and JMX address to metadata store")
-        localityManager.writeContainerToHostMapping(containerContext.id, hostInet.getHostName)
+        localityManager.writeContainerToHostMapping(containerId, hostInet.getHostName)
       } catch {
         case uhe: UnknownHostException =>
           warn("Received UnknownHostException when persisting locality info for container %s: " +
-            "%s" format (containerContext.id, uhe.getMessage))  //No-op
+            "%s" format (containerId, uhe.getMessage))  //No-op
         case unknownException: Throwable =>
           warn("Received an exception when persisting locality info for container %s: " +
-            "%s" format (containerContext.id, unknownException.getMessage))
+            "%s" format (containerId, unknownException.getMessage))
       } finally {
         info("Shutting down locality manager.")
         localityManager.close()
@@ -1016,7 +1033,6 @@ class SamzaContainer(
     systemAdmins.start
   }
 
-
   def startProducers {
     info("Registering task instances with producers.")
 
@@ -1092,7 +1108,6 @@ class SamzaContainer(
     systemAdmins.stop
   }
 
-
   def shutdownProducers {
     info("Shutting down producer multiplexer.")
 
@@ -1185,4 +1200,4 @@ class SamzaContainer(
       hostStatisticsMonitor.stop()
     }
   }
-}
+}
\ No newline at end of file
index 9f4fd17..f8e9c63 100644 (file)
 package org.apache.samza.container
 
 
+import java.util.Optional
 import java.util.concurrent.ScheduledExecutorService
 
 import org.apache.samza.SamzaException
 import org.apache.samza.checkpoint.OffsetManager
 import org.apache.samza.config.Config
 import org.apache.samza.config.StreamConfig.Config2Stream
-import org.apache.samza.job.model.JobModel
+import org.apache.samza.context._
+import org.apache.samza.job.model.{JobModel, TaskModel}
 import org.apache.samza.metrics.MetricsReporter
-import org.apache.samza.scheduler.ScheduledCallback
+import org.apache.samza.scheduler.{CallbackSchedulerImpl, ScheduledCallback}
 import org.apache.samza.storage.kv.KeyValueStore
 import org.apache.samza.storage.{TaskSideInputStorageManager, TaskStorageManager}
 import org.apache.samza.system._
@@ -36,19 +38,17 @@ import org.apache.samza.table.TableManager
 import org.apache.samza.task._
 import org.apache.samza.util.{Logging, ScalaJavaUtil}
 
-import scala.collection.JavaConverters._
 import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
 import scala.collection.Map
 
 class TaskInstance(
   val task: Any,
-  val taskName: TaskName,
-  config: Config,
+  taskModel: TaskModel,
   val metrics: TaskInstanceMetrics,
   systemAdmins: SystemAdmins,
   consumerMultiplexer: SystemConsumers,
   collector: TaskInstanceCollector,
-  containerContext: SamzaContainerContext,
   val offsetManager: OffsetManager = new OffsetManager,
   storageManager: TaskStorageManager = null,
   tableManager: TableManager = null,
@@ -59,15 +59,23 @@ class TaskInstance(
   streamMetadataCache: StreamMetadataCache = null,
   timerExecutor : ScheduledExecutorService = null,
   sideInputSSPs: Set[SystemStreamPartition] = Set(),
-  sideInputStorageManager: TaskSideInputStorageManager = null) extends Logging {
-
+  sideInputStorageManager: TaskSideInputStorageManager = null,
+  jobContext: JobContext,
+  containerContext: ContainerContext,
+  applicationContainerContextOption: Option[ApplicationContainerContext],
+  applicationTaskContextFactoryOption: Option[ApplicationTaskContextFactory[ApplicationTaskContext]]
+) extends Logging {
+
+  val taskName: TaskName = taskModel.getTaskName
   val isInitableTask = task.isInstanceOf[InitableTask]
   val isWindowableTask = task.isInstanceOf[WindowableTask]
   val isEndOfStreamListenerTask = task.isInstanceOf[EndOfStreamListenerTask]
   val isClosableTask = task.isInstanceOf[ClosableTask]
   val isAsyncTask = task.isInstanceOf[AsyncStreamTask]
 
-  val kvStoreSupplier = ScalaJavaUtil.toJavaFunction(
+  val epochTimeScheduler: EpochTimeScheduler = EpochTimeScheduler.create(timerExecutor)
+
+  private val kvStoreSupplier = ScalaJavaUtil.toJavaFunction(
     (storeName: String) => {
       if (storageManager != null && storageManager.getStore(storeName).isDefined) {
         storageManager.getStore(storeName).get.asInstanceOf[KeyValueStore[_, _]]
@@ -77,9 +85,14 @@ class TaskInstance(
         null
       }
     })
-
-  val context = new TaskContextImpl(taskName, metrics, containerContext, systemStreamPartitions.asJava, offsetManager,
-                                    kvStoreSupplier, tableManager, jobModel, streamMetadataCache, timerExecutor)
+  private val taskContext = new TaskContextImpl(taskModel, metrics.registry, kvStoreSupplier, tableManager,
+    new CallbackSchedulerImpl(epochTimeScheduler), offsetManager, jobModel, streamMetadataCache)
+  // need separate field for this instead of using it through Context, since Context throws an exception if it is null
+  private val applicationTaskContextOption = applicationTaskContextFactoryOption.map(_.create(jobContext,
+    containerContext, taskContext, applicationContainerContextOption.orNull))
+  val context = new ContextImpl(jobContext, containerContext, taskContext,
+    Optional.ofNullable(applicationContainerContextOption.orNull),
+    Optional.ofNullable(applicationTaskContextOption.orNull))
 
   // store the (ssp -> if this ssp has caught up) mapping. "caught up"
   // means the same ssp in other taskInstances have the same offset as
@@ -88,6 +101,8 @@ class TaskInstance(
     scala.collection.mutable.Map[SystemStreamPartition, Boolean]()
   systemStreamPartitions.foreach(ssp2CaughtupMapping += _ -> false)
 
+  private val config: Config = jobContext.getConfig
+
   val intermediateStreams: Set[String] = config.getStreamIds.filter(config.getIsIntermediateStream).toSet
 
   val streamsToDeleteCommittedMessages: Set[String] = config.getStreamIds.filter(config.getDeleteCommittedMessages).map(config.getPhysicalName).toSet
@@ -126,7 +141,7 @@ class TaskInstance(
     if (tableManager != null) {
       debug("Starting table manager for taskName: %s" format taskName)
 
-      tableManager.init(containerContext, context)
+      tableManager.init(context)
     } else {
       debug("Skipping table manager initialization for taskName: %s" format taskName)
     }
@@ -136,10 +151,14 @@ class TaskInstance(
     if (isInitableTask) {
       debug("Initializing task for taskName: %s" format taskName)
 
-      task.asInstanceOf[InitableTask].init(config, context)
+      task.asInstanceOf[InitableTask].init(context)
     } else {
       debug("Skipping task initialization for taskName: %s" format taskName)
     }
+    applicationTaskContextOption.foreach(applicationTaskContext => {
+      debug("Starting application-defined task context for taskName: %s" format taskName)
+      applicationTaskContext.start()
+    })
   }
 
   def registerProducers {
@@ -226,7 +245,7 @@ class TaskInstance(
     trace("Scheduler for taskName: %s" format taskName)
 
     exceptionHandler.maybeHandle {
-      context.getTimerScheduler.removeReadyTimers().entrySet().foreach { entry =>
+      epochTimeScheduler.removeReadyTimers().entrySet().foreach { entry =>
         entry.getValue.asInstanceOf[ScheduledCallback[Any]].onCallback(entry.getKey.getKey, collector, coordinator)
       }
     }
@@ -266,6 +285,10 @@ class TaskInstance(
   }
 
   def shutdownTask {
+    applicationTaskContextOption.foreach(applicationTaskContext => {
+      debug("Stopping application-defined task context for taskName: %s" format taskName)
+      applicationTaskContext.stop()
+    })
     if (task.isInstanceOf[ClosableTask]) {
       debug("Shutting down stream task for taskName: %s" format taskName)
 
index bec4ec0..929d6a4 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.samza.config.JobConfig._
 import org.apache.samza.config.ShellCommandConfig._
 import org.apache.samza.config.{Config, TaskConfigJava}
 import org.apache.samza.container.{SamzaContainer, SamzaContainerListener, TaskName}
+import org.apache.samza.context.JobContextImpl
 import org.apache.samza.coordinator.JobModelManager
 import org.apache.samza.coordinator.stream.CoordinatorStreamManager
 import org.apache.samza.job.{StreamJob, StreamJobFactory}
@@ -112,9 +113,12 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
       val container = SamzaContainer(
         containerId,
         jobModel,
-        config,
         Map[String, MetricsReporter](),
-        taskFactory)
+        taskFactory,
+        JobContextImpl.fromConfigWithDefaults(config),
+        Option(appDesc.getApplicationContainerContextFactory.orElse(null)),
+        Option(appDesc.getApplicationTaskContextFactory.orElse(null))
+      )
       container.setContainerListener(containerListener)
 
       val threadJob = new ThreadJob(container)
index 84f5dbb..de16ef2 100644 (file)
 package org.apache.samza.application;
 
 import com.google.common.collect.ImmutableList;
-
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Optional;
 import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
+import org.apache.samza.context.ApplicationContainerContextFactory;
+import org.apache.samza.context.ApplicationTaskContextFactory;
 import org.apache.samza.operators.BaseTableDescriptor;
-import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.data.TestMessageEnvelope;
 import org.apache.samza.operators.descriptors.GenericInputDescriptor;
 import org.apache.samza.operators.descriptors.GenericOutputDescriptor;
@@ -521,11 +522,35 @@ public class TestStreamApplicationDescriptorImpl {
   }
 
   @Test
-  public void testContextManager() {
-    ContextManager cntxMan = mock(ContextManager.class);
-    StreamApplication testApp = appDesc -> appDesc.withContextManager(cntxMan);
+  public void testApplicationContainerContextFactory() {
+    ApplicationContainerContextFactory factory = mock(ApplicationContainerContextFactory.class);
+    StreamApplication testApp = appDesc -> appDesc.withApplicationContainerContextFactory(factory);
+    StreamApplicationDescriptorImpl appSpec = new StreamApplicationDescriptorImpl(testApp, mock(Config.class));
+    assertEquals(appSpec.getApplicationContainerContextFactory(), Optional.of(factory));
+  }
+
+  @Test
+  public void testNoApplicationContainerContextFactory() {
+    StreamApplication testApp = appDesc -> {
+    };
+    StreamApplicationDescriptorImpl appSpec = new StreamApplicationDescriptorImpl(testApp, mock(Config.class));
+    assertEquals(appSpec.getApplicationContainerContextFactory(), Optional.empty());
+  }
+
+  @Test
+  public void testApplicationTaskContextFactory() {
+    ApplicationTaskContextFactory factory = mock(ApplicationTaskContextFactory.class);
+    StreamApplication testApp = appDesc -> appDesc.withApplicationTaskContextFactory(factory);
+    StreamApplicationDescriptorImpl appSpec = new StreamApplicationDescriptorImpl(testApp, mock(Config.class));
+    assertEquals(appSpec.getApplicationTaskContextFactory(), Optional.of(factory));
+  }
+
+  @Test
+  public void testNoApplicationTaskContextFactory() {
+    StreamApplication testApp = appDesc -> {
+    };
     StreamApplicationDescriptorImpl appSpec = new StreamApplicationDescriptorImpl(testApp, mock(Config.class));
-    assertEquals(appSpec.getContextManager(), cntxMan);
+    assertEquals(appSpec.getApplicationTaskContextFactory(), Optional.empty());
   }
 
   @Test
index abe5ce1..e79e25b 100644 (file)
@@ -21,10 +21,12 @@ package org.apache.samza.application;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Optional;
 import java.util.Set;
 import org.apache.samza.config.Config;
+import org.apache.samza.context.ApplicationContainerContextFactory;
+import org.apache.samza.context.ApplicationTaskContextFactory;
 import org.apache.samza.operators.BaseTableDescriptor;
-import org.apache.samza.operators.ContextManager;
 import org.apache.samza.operators.TableDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.InputDescriptor;
 import org.apache.samza.operators.descriptors.base.stream.OutputDescriptor;
@@ -127,13 +129,35 @@ public class TestTaskApplicationDescriptorImpl {
   }
 
   @Test
-  public void testContextManager() {
-    ContextManager cntxMan = mock(ContextManager.class);
+  public void testApplicationContainerContextFactory() {
+    ApplicationContainerContextFactory factory = mock(ApplicationContainerContextFactory.class);
+    TaskApplication testApp = appDesc -> appDesc.withApplicationContainerContextFactory(factory);
+    TaskApplicationDescriptorImpl appSpec = new TaskApplicationDescriptorImpl(testApp, mock(Config.class));
+    assertEquals(appSpec.getApplicationContainerContextFactory(), Optional.of(factory));
+  }
+
+  @Test
+  public void testNoApplicationContainerContextFactory() {
     TaskApplication testApp = appDesc -> {
-      appDesc.withContextManager(cntxMan);
     };
-    TaskApplicationDescriptorImpl appDesc = new TaskApplicationDescriptorImpl(testApp, config);
-    assertEquals(appDesc.getContextManager(), cntxMan);
+    TaskApplicationDescriptorImpl appSpec = new TaskApplicationDescriptorImpl(testApp, mock(Config.class));
+    assertEquals(appSpec.getApplicationContainerContextFactory(), Optional.empty());
+  }
+
+  @Test
+  public void testApplicationTaskContextFactory() {
+    ApplicationTaskContextFactory factory = mock(ApplicationTaskContextFactory.class);
+    TaskApplication testApp = appDesc -> appDesc.withApplicationTaskContextFactory(factory);
+    TaskApplicationDescriptorImpl appSpec = new TaskApplicationDescriptorImpl(testApp, mock(Config.class));
+    assertEquals(appSpec.getApplicationTaskContextFactory(), Optional.of(factory));
+  }
+
+  @Test
+  public void testNoApplicationTaskContextFactory() {
+    TaskApplication testApp = appDesc -> {
+    };
+    TaskApplicationDescriptorImpl appSpec = new TaskApplicationDescriptorImpl(testApp, mock(Config.class));
+    assertEquals(appSpec.getApplicationTaskContextFactory(), Optional.empty());
   }
 
   @Test
diff --git a/samza-core/src/test/java/org/apache/samza/context/MockContext.java b/samza-core/src/test/java/org/apache/samza/context/MockContext.java
new file mode 100644 (file)
index 0000000..778d486
--- /dev/null
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.context;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+
+import static org.mockito.Mockito.*;
+
+
+public class MockContext implements Context {
+  private final JobContext jobContext = mock(JobContext.class);
+  private final ContainerContext containerContext = mock(ContainerContext.class);
+  /**
+   * This is {@link TaskContextImpl} because some tests need more than just the interface.
+   */
+  private final TaskContextImpl taskContext = mock(TaskContextImpl.class);
+  private final ApplicationContainerContext applicationContainerContext = mock(ApplicationContainerContext.class);
+  private final ApplicationTaskContext applicationTaskContext = mock(ApplicationTaskContext.class);
+
+  public MockContext() {
+    this(new MapConfig());
+  }
+
+  /**
+   * @param config config is widely used, so help wire it in here
+   */
+  public MockContext(Config config) {
+    when(this.jobContext.getConfig()).thenReturn(config);
+  }
+
+  @Override
+  public JobContext getJobContext() {
+    return jobContext;
+  }
+
+  @Override
+  public ContainerContext getContainerContext() {
+    return containerContext;
+  }
+
+  @Override
+  public TaskContext getTaskContext() {
+    return taskContext;
+  }
+
+  @Override
+  public ApplicationContainerContext getApplicationContainerContext() {
+    return applicationContainerContext;
+  }
+
+  @Override
+  public ApplicationTaskContext getApplicationTaskContext() {
+    return applicationTaskContext;
+  }
+}
index 33ad3a5..40526db 100644 (file)
  */
 package org.apache.samza.context;
 
+import java.util.Optional;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
 
 
 public class TestContextImpl {
@@ -63,11 +65,17 @@ public class TestContextImpl {
   }
 
   private static Context buildWithApplicationContainerContext(ApplicationContainerContext applicationContainerContext) {
-    return new ContextImpl(null, null, null, applicationContainerContext, null);
+    return buildWithApplicationContext(applicationContainerContext, mock(ApplicationTaskContext.class));
   }
 
   private static Context buildWithApplicationTaskContext(ApplicationTaskContext applicationTaskContext) {
-    return new ContextImpl(null, null, null, null, applicationTaskContext);
+    return buildWithApplicationContext(mock(ApplicationContainerContext.class), applicationTaskContext);
+  }
+
+  private static Context buildWithApplicationContext(ApplicationContainerContext applicationContainerContext,
+      ApplicationTaskContext applicationTaskContext) {
+    return new ContextImpl(mock(JobContext.class), mock(ContainerContext.class), mock(TaskContext.class),
+        Optional.ofNullable(applicationContainerContext), Optional.ofNullable(applicationTaskContext));
   }
 
   /**
index 78f886c..3d3803b 100644 (file)
@@ -34,6 +34,7 @@ import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -62,7 +63,7 @@ public class TestTaskContextImpl {
     MockitoAnnotations.initMocks(this);
     taskContext =
         new TaskContextImpl(taskModel, taskMetricsRegistry, keyValueStoreProvider, tableManager, callbackScheduler,
-            offsetManager);
+            offsetManager, null, null);
     when(this.taskModel.getTaskName()).thenReturn(TASK_NAME);
   }
 
@@ -95,4 +96,16 @@ public class TestTaskContextImpl {
     taskContext.setStartingOffset(ssp, "123");
     verify(offsetManager).setStartingOffset(TASK_NAME, ssp, "123");
   }
+
+  /**
+   * Given a registered object, fetchObject should get it. If an object is not registered at a key, then fetchObject
+   * should return null.
+   */
+  @Test
+  public void testRegisterAndFetchObject() {
+    String value = "hello world";
+    taskContext.registerObject("key", value);
+    assertEquals(value, taskContext.fetchObject("key"));
+    assertNull(taskContext.fetchObject("not a key"));
+  }
 }
\ No newline at end of file
index 51a9523..4618e52 100644 (file)
 package org.apache.samza.execution;
 
 import com.google.common.base.Joiner;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
 import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.application.TaskApplicationDescriptorImpl;
 import org.apache.samza.config.Config;
@@ -36,7 +28,7 @@ import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.SerializerConfig;
 import org.apache.samza.config.TaskConfig;
 import org.apache.samza.config.TaskConfigJava;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.BaseTableDescriptor;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.TableDescriptor;
@@ -52,9 +44,17 @@ import org.apache.samza.table.Table;
 import org.apache.samza.table.TableProvider;
 import org.apache.samza.table.TableProviderFactory;
 import org.apache.samza.table.TableSpec;
-import org.apache.samza.task.TaskContext;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
@@ -445,7 +445,7 @@ public class TestJobNodeConfigurationGenerator extends ExecutionPlannerTestBase
     }
 
     @Override
-    public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
+    public void init(Context context) {
 
     }
 
index 6fa9ed1..1315912 100644 (file)
 package org.apache.samza.operators;
 
 import com.google.common.collect.ImmutableSet;
-
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
 import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.TaskContextImpl;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.operators.descriptors.GenericInputDescriptor;
 import org.apache.samza.operators.descriptors.GenericSystemDescriptor;
@@ -40,7 +41,6 @@ import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.StreamOperatorTask;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.testUtils.StreamTestUtils;
 import org.apache.samza.testUtils.TestClock;
@@ -56,10 +56,10 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.assertEquals;
-import static org.mockito.Matchers.eq;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -304,22 +304,23 @@ public class TestJoinOperator {
     mapConfig.put("job.id", "jobId");
     StreamTestUtils.addStreamConfigs(mapConfig, "inStream", "insystem", "instream");
     StreamTestUtils.addStreamConfigs(mapConfig, "inStream2", "insystem", "instream2");
-    Config config = new MapConfig(mapConfig);
-    TaskContextImpl taskContext = mock(TaskContextImpl.class);
-    when(taskContext.getSystemStreamPartitions()).thenReturn(ImmutableSet
+    Context context = new MockContext(new MapConfig(mapConfig));
+    TaskModel taskModel = mock(TaskModel.class);
+    when(taskModel.getSystemStreamPartitions()).thenReturn(ImmutableSet
         .of(new SystemStreamPartition("insystem", "instream", new Partition(0)),
             new SystemStreamPartition("insystem", "instream2", new Partition(0))));
-    when(taskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    when(context.getTaskContext().getTaskModel()).thenReturn(taskModel);
+    when(context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
     // need to return different stores for left and right side
     IntegerSerde integerSerde = new IntegerSerde();
     TimestampedValueSerde timestampedValueSerde = new TimestampedValueSerde(new KVSerde(integerSerde, integerSerde));
-    when(taskContext.getStore(eq("jobName-jobId-join-j1-L")))
+    when(context.getTaskContext().getStore(eq("jobName-jobId-join-j1-L")))
         .thenReturn(new TestInMemoryStore(integerSerde, timestampedValueSerde));
-    when(taskContext.getStore(eq("jobName-jobId-join-j1-R")))
+    when(context.getTaskContext().getStore(eq("jobName-jobId-join-j1-R")))
         .thenReturn(new TestInMemoryStore(integerSerde, timestampedValueSerde));
 
-    StreamOperatorTask sot = new StreamOperatorTask(graphSpec.getOperatorSpecGraph(), graphSpec.getContextManager(), clock);
-    sot.init(config, taskContext);
+    StreamOperatorTask sot = new StreamOperatorTask(graphSpec.getOperatorSpecGraph(), clock);
+    sot.init(context);
     return sot;
   }
 
@@ -357,7 +358,7 @@ public class TestJoinOperator {
     private int numCloseCalls = 0;
 
     @Override
-    public void init(Config config, TaskContext context) {
+    public void init(Context context) {
       numInitCalls++;
     }
 
index 6d12d99..0ff2e0d 100644 (file)
@@ -21,9 +21,9 @@ package org.apache.samza.operators.impl;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Set;
-
-import org.apache.samza.config.Config;
-import org.apache.samza.container.TaskContextImpl;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.metrics.ReadableMetricsRegistry;
@@ -32,8 +32,8 @@ import org.apache.samza.operators.functions.ScheduledFunction;
 import org.apache.samza.operators.functions.WatermarkFunction;
 import org.apache.samza.operators.spec.OperatorSpec;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
+import org.junit.Before;
 import org.junit.Test;
 
 import static org.mockito.Matchers.anyLong;
@@ -46,14 +46,20 @@ import static org.mockito.Mockito.when;
 
 
 public class TestOperatorImpl {
+  private Context context;
+
+  @Before
+  public void setup() {
+    this.context = new MockContext();
+    when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    when(this.context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class));
+  }
 
   @Test(expected = IllegalStateException.class)
   public void testMultipleInitShouldThrow() {
     OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-    opImpl.init(mock(Config.class), mockTaskContext);
-    opImpl.init(mock(Config.class), mockTaskContext);
+    opImpl.init(this.context);
+    opImpl.init(this.context);
   }
 
   @Test(expected = IllegalStateException.class)
@@ -64,24 +70,21 @@ public class TestOperatorImpl {
 
   @Test
   public void testOnMessagePropagatesResults() {
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-
     Object mockTestOpImplOutput = mock(Object.class);
     OperatorImpl<Object, Object> opImpl = new TestOpImpl(mockTestOpImplOutput);
-    opImpl.init(mock(Config.class), mockTaskContext);
+    opImpl.init(this.context);
 
     // register a couple of operators
     OperatorImpl mockNextOpImpl1 = mock(OperatorImpl.class);
     when(mockNextOpImpl1.getOperatorSpec()).thenReturn(new TestOpSpec());
     when(mockNextOpImpl1.handleMessage(anyObject(), anyObject(), anyObject())).thenReturn(Collections.emptyList());
-    mockNextOpImpl1.init(mock(Config.class), mockTaskContext);
+    mockNextOpImpl1.init(this.context);
     opImpl.registerNextOperator(mockNextOpImpl1);
 
     OperatorImpl mockNextOpImpl2 = mock(OperatorImpl.class);
     when(mockNextOpImpl2.getOperatorSpec()).thenReturn(new TestOpSpec());
     when(mockNextOpImpl2.handleMessage(anyObject(), anyObject(), anyObject())).thenReturn(Collections.emptyList());
-    mockNextOpImpl2.init(mock(Config.class), mockTaskContext);
+    mockNextOpImpl2.init(this.context);
     opImpl.registerNextOperator(mockNextOpImpl2);
 
     // send a message to this operator
@@ -96,9 +99,8 @@ public class TestOperatorImpl {
 
   @Test
   public void testOnMessageUpdatesMetrics() {
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
     ReadableMetricsRegistry mockMetricsRegistry = mock(ReadableMetricsRegistry.class);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(mockMetricsRegistry);
+    when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(mockMetricsRegistry);
     Counter mockCounter = mock(Counter.class);
     Timer mockTimer = mock(Timer.class);
     when(mockMetricsRegistry.newCounter(anyString(), anyString())).thenReturn(mockCounter);
@@ -106,7 +108,7 @@ public class TestOperatorImpl {
 
     Object mockTestOpImplOutput = mock(Object.class);
     OperatorImpl<Object, Object> opImpl = new TestOpImpl(mockTestOpImplOutput);
-    opImpl.init(mock(Config.class), mockTaskContext);
+    opImpl.init(this.context);
 
     // send a message to this operator
     MessageCollector mockCollector = mock(MessageCollector.class);
@@ -120,24 +122,21 @@ public class TestOperatorImpl {
 
   @Test
   public void testOnTimerPropagatesResultsAndTimer() {
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-
     Object mockTestOpImplOutput = mock(Object.class);
     OperatorImpl<Object, Object> opImpl = new TestOpImpl(mockTestOpImplOutput);
-    opImpl.init(mock(Config.class), mockTaskContext);
+    opImpl.init(this.context);
 
     // register a couple of operators
     OperatorImpl mockNextOpImpl1 = mock(OperatorImpl.class);
     when(mockNextOpImpl1.getOperatorSpec()).thenReturn(new TestOpSpec());
     when(mockNextOpImpl1.handleMessage(anyObject(), anyObject(), anyObject())).thenReturn(Collections.emptyList());
-    mockNextOpImpl1.init(mock(Config.class), mockTaskContext);
+    mockNextOpImpl1.init(this.context);
     opImpl.registerNextOperator(mockNextOpImpl1);
 
     OperatorImpl mockNextOpImpl2 = mock(OperatorImpl.class);
     when(mockNextOpImpl2.getOperatorSpec()).thenReturn(new TestOpSpec());
     when(mockNextOpImpl2.handleMessage(anyObject(), anyObject(), anyObject())).thenReturn(Collections.emptyList());
-    mockNextOpImpl2.init(mock(Config.class), mockTaskContext);
+    mockNextOpImpl2.init(this.context);
     opImpl.registerNextOperator(mockNextOpImpl2);
 
     // send a timer tick to this operator
@@ -156,9 +155,8 @@ public class TestOperatorImpl {
 
   @Test
   public void testOnTimerUpdatesMetrics() {
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
     ReadableMetricsRegistry mockMetricsRegistry = mock(ReadableMetricsRegistry.class);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(mockMetricsRegistry);
+    when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(mockMetricsRegistry);
     Counter mockMessageCounter = mock(Counter.class);
     Timer mockTimer = mock(Timer.class);
     when(mockMetricsRegistry.newCounter(anyString(), anyString())).thenReturn(mockMessageCounter);
@@ -166,7 +164,7 @@ public class TestOperatorImpl {
 
     Object mockTestOpImplOutput = mock(Object.class);
     OperatorImpl<Object, Object> opImpl = new TestOpImpl(mockTestOpImplOutput);
-    opImpl.init(mock(Config.class), mockTaskContext);
+    opImpl.init(this.context);
 
     // send a message to this operator
     MessageCollector mockCollector = mock(MessageCollector.class);
@@ -188,7 +186,7 @@ public class TestOperatorImpl {
     }
 
     @Override
-    protected void handleInit(Config config, TaskContext context) {}
+    protected void handleInit(Context context) {}
 
     @Override
     public Collection<Object> handleMessage(Object message,
index 3abd502..d760805 100644 (file)
@@ -21,27 +21,16 @@ package org.apache.samza.operators.impl;
 
 import com.google.common.collect.HashMultimap;
 import com.google.common.collect.Multimap;
-import java.io.Serializable;
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.function.BiFunction;
-import java.util.function.Function;
 import org.apache.samza.Partition;
 import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.StreamConfig;
-import org.apache.samza.container.SamzaContainerContext;
-import org.apache.samza.container.TaskContextImpl;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
+import org.apache.samza.context.TaskContextImpl;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.job.model.TaskModel;
@@ -67,15 +56,28 @@ import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.apache.samza.testUtils.StreamTestUtils;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.SystemClock;
 import org.apache.samza.util.TimestampedValue;
 import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 
+import java.io.Serializable;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertTrue;
@@ -84,7 +86,6 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 public class TestOperatorImplGraph {
-
   private void addOperatorRecursively(HashSet<OperatorImpl> s, OperatorImpl op) {
     List<OperatorImpl> operators = new ArrayList<>();
     operators.add(op);
@@ -193,25 +194,39 @@ public class TestOperatorImplGraph {
     }
 
     @Override
-    public void init(Config config, TaskContext context) {
-      if (perTaskFunctionMap.get(context.getTaskName()) == null) {
-        perTaskFunctionMap.put(context.getTaskName(), new HashMap<String, BaseTestFunction>() { { this.put(opId, BaseTestFunction.this); } });
+    public void init(Context context) {
+      TaskName taskName = context.getTaskContext().getTaskModel().getTaskName();
+      if (perTaskFunctionMap.get(taskName) == null) {
+        perTaskFunctionMap.put(taskName, new HashMap<String, BaseTestFunction>() { { this.put(opId, BaseTestFunction.this); } });
       } else {
-        if (perTaskFunctionMap.get(context.getTaskName()).containsKey(opId)) {
+        if (perTaskFunctionMap.get(taskName).containsKey(opId)) {
           throw new IllegalStateException(String.format("Multiple init called for op %s in the same task instance %s", opId, this.taskName.getTaskName()));
         }
-        perTaskFunctionMap.get(context.getTaskName()).put(opId, this);
+        perTaskFunctionMap.get(taskName).put(opId, this);
       }
-      if (perTaskInitList.get(context.getTaskName()) == null) {
-        perTaskInitList.put(context.getTaskName(), new ArrayList<String>() { { this.add(opId); } });
+      if (perTaskInitList.get(taskName) == null) {
+        perTaskInitList.put(taskName, new ArrayList<String>() { { this.add(opId); } });
       } else {
-        perTaskInitList.get(context.getTaskName()).add(opId);
+        perTaskInitList.get(taskName).add(opId);
       }
-      this.taskName = context.getTaskName();
+      this.taskName = taskName;
       this.numInitCalled++;
     }
   }
 
+  private Context context;
+
+  @Before
+  public void setup() {
+    this.context = new MockContext();
+    // individual tests can override this config if necessary
+    when(this.context.getJobContext().getConfig()).thenReturn(mock(Config.class));
+    TaskModel taskModel = mock(TaskModel.class);
+    when(taskModel.getTaskName()).thenReturn(new TaskName("task 0"));
+    when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);
+    when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+  }
+
   @After
   public void tearDown() {
     BaseTestFunction.reset();
@@ -220,8 +235,7 @@ public class TestOperatorImplGraph {
   @Test
   public void testEmptyChain() {
     StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { }, mock(Config.class));
-    OperatorImplGraph opGraph =
-        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mock(TaskContextImpl.class), mock(Clock.class));
+    OperatorImplGraph opGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), context, mock(Clock.class));
     assertEquals(0, opGraph.getAllInputOperators().size());
   }
 
@@ -242,6 +256,7 @@ public class TestOperatorImplGraph {
     StreamTestUtils.addStreamConfigs(configs, inputStreamId, inputSystem, inputPhysicalName);
     StreamTestUtils.addStreamConfigs(configs, outputStreamId, outputSystem, outputPhysicalName);
     Config config = new MapConfig(configs);
+    when(this.context.getJobContext().getConfig()).thenReturn(config);
 
     StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
         GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
@@ -256,11 +271,8 @@ public class TestOperatorImplGraph {
             .sendTo(outputStream);
       }, config);
 
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-    when(mockTaskContext.getTaskName()).thenReturn(new TaskName("task 0"));
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), config, mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));
 
     InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName));
     assertEquals(1, inputOpImpl.registeredOperators.size());
@@ -296,6 +308,7 @@ public class TestOperatorImplGraph {
     StreamTestUtils.addStreamConfigs(configs, inputStreamId, inputSystem, inputPhysicalName);
     StreamTestUtils.addStreamConfigs(configs, outputStreamId, outputSystem, outputPhysicalName);
     Config config = new MapConfig(configs);
+    when(this.context.getJobContext().getConfig()).thenReturn(config);
 
     StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
         GenericSystemDescriptor isd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
@@ -312,21 +325,15 @@ public class TestOperatorImplGraph {
             .sendTo(outputStream);
       }, config);
 
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-    when(mockTaskContext.getTaskName()).thenReturn(new TaskName("task 0"));
     JobModel jobModel = mock(JobModel.class);
     ContainerModel containerModel = mock(ContainerModel.class);
     TaskModel taskModel = mock(TaskModel.class);
     when(jobModel.getContainers()).thenReturn(Collections.singletonMap("0", containerModel));
     when(containerModel.getTasks()).thenReturn(Collections.singletonMap(new TaskName("task 0"), taskModel));
     when(taskModel.getSystemStreamPartitions()).thenReturn(Collections.emptySet());
-    when(mockTaskContext.getJobModel()).thenReturn(jobModel);
-    SamzaContainerContext containerContext =
-        new SamzaContainerContext("0", config, Collections.singleton(new TaskName("task 0")), new MetricsRegistryMap());
-    when(mockTaskContext.getSamzaContainerContext()).thenReturn(containerContext);
+    when(((TaskContextImpl) this.context.getTaskContext()).getJobModel()).thenReturn(jobModel);
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), config, mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));
 
     InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName));
     assertEquals(1, inputOpImpl.registeredOperators.size());
@@ -352,6 +359,7 @@ public class TestOperatorImplGraph {
     HashMap<String, String> configMap = new HashMap<>();
     StreamTestUtils.addStreamConfigs(configMap, inputStreamId, inputSystem, inputPhysicalName);
     Config config = new MapConfig(configMap);
+    when(this.context.getJobContext().getConfig()).thenReturn(config);
     StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
         GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
         GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class));
@@ -360,10 +368,8 @@ public class TestOperatorImplGraph {
         inputStream.map(mock(MapFunction.class));
       }, config);
 
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), config, mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));
 
     InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName));
     assertEquals(2, inputOpImpl.registeredOperators.size());
@@ -377,10 +383,6 @@ public class TestOperatorImplGraph {
   public void testMergeChain() {
     String inputStreamId = "input";
     String inputSystem = "input-system";
-    String inputPhysicalName = "input-stream";
-    HashMap<String, String> configs = new HashMap<>();
-    StreamTestUtils.addStreamConfigs(configs, inputStreamId, inputSystem, inputPhysicalName);
-    Config config = new MapConfig(configs);
     StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
         GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
         GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class));
@@ -390,13 +392,14 @@ public class TestOperatorImplGraph {
         stream1.merge(Collections.singleton(stream2))
             .map(new TestMapFunction<Object, Object>("test-map-1", (Function & Serializable) m -> m));
       }, mock(Config.class));
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
+
     TaskName mockTaskName = mock(TaskName.class);
-    when(mockTaskContext.getTaskName()).thenReturn(mockTaskName);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    TaskModel taskModel = mock(TaskModel.class);
+    when(taskModel.getTaskName()).thenReturn(mockTaskName);
+    when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);
 
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));
 
     Set<OperatorImpl> opSet = opImplGraph.getAllInputOperators().stream().collect(HashSet::new,
         (s, op) -> addOperatorRecursively(s, op), HashSet::addAll);
@@ -423,6 +426,7 @@ public class TestOperatorImplGraph {
     StreamTestUtils.addStreamConfigs(configs, inputStreamId1, inputSystem, inputPhysicalName1);
     StreamTestUtils.addStreamConfigs(configs, inputStreamId2, inputSystem, inputPhysicalName2);
     Config config = new MapConfig(configs);
+    when(this.context.getJobContext().getConfig()).thenReturn(config);
 
     Integer joinKey = new Integer(1);
     Function<Object, Integer> keyFn = (Function & Serializable) m -> joinKey;
@@ -441,15 +445,16 @@ public class TestOperatorImplGraph {
       }, config);
 
     TaskName mockTaskName = mock(TaskName.class);
-    TaskContextImpl mockTaskContext = mock(TaskContextImpl.class);
-    when(mockTaskContext.getTaskName()).thenReturn(mockTaskName);
-    when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    TaskModel taskModel = mock(TaskModel.class);
+    when(taskModel.getTaskName()).thenReturn(mockTaskName);
+    when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);
+
     KeyValueStore mockLeftStore = mock(KeyValueStore.class);
-    when(mockTaskContext.getStore(eq("jobName-jobId-join-j1-L"))).thenReturn(mockLeftStore);
+    when(this.context.getTaskContext().getStore(eq("jobName-jobId-join-j1-L"))).thenReturn(mockLeftStore);
     KeyValueStore mockRightStore = mock(KeyValueStore.class);
-    when(mockTaskContext.getStore(eq("jobName-jobId-join-j1-R"))).thenReturn(mockRightStore);
+    when(this.context.getTaskContext().getStore(eq("jobName-jobId-join-j1-R"))).thenReturn(mockRightStore);
     OperatorImplGraph opImplGraph =
-        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), config, mockTaskContext, mock(Clock.class));
+        new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));
 
     // verify that join function is initialized once.
     assertEquals(TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1").numInitCalled, 1);
@@ -491,10 +496,12 @@ public class TestOperatorImplGraph {
     String inputStreamId2 = "input2";
     String inputSystem = "input-system";
     Config mockConfig = mock(Config.class);
+
     TaskName mockTaskName = mock(TaskName.class);
-    TaskContextImpl mockContext = mock(TaskContextImpl.class);
-    when(mockContext.getTaskName()).thenReturn(mockTaskName);
-    when(mockContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    TaskModel taskModel = mock(TaskModel.class);
+    when(taskModel.getTaskName()).thenReturn(mockTaskName);
+    when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);
+
     StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
         GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
         GenericInputDescriptor inputDescriptor1 = sd.getInputDescriptor(inputStreamId1, mock(Serde.class));
@@ -510,7 +517,7 @@ public class TestOperatorImplGraph {
             .map(new TestMapFunction<Object, Object>("4", mapFn));
       }, mockConfig);
 
-    OperatorImplGraph opImplGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mockConfig, mockContext, SystemClock.instance());
+    OperatorImplGraph opImplGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, SystemClock.instance());
 
     List<String> initializedOperators = BaseTestFunction.getInitListByTaskName(mockTaskName);
 
@@ -541,6 +548,7 @@ public class TestOperatorImplGraph {
     StreamTestUtils.addStreamConfigs(configs, streamId0, system, streamId0);
     StreamTestUtils.addStreamConfigs(configs, streamId1, system, streamId1);
     Config config = new MapConfig(configs);
+    when(this.context.getJobContext().getConfig()).thenReturn(config);
 
     SystemStreamPartition ssp0 = new SystemStreamPartition(system, streamId0, new Partition(0));
     SystemStreamPartition ssp1 = new SystemStreamPartition(system, streamId0, new Partition(1));
@@ -590,6 +598,7 @@ public class TestOperatorImplGraph {
     StreamTestUtils.addStreamConfigs(configs, outputStreamId1, outputSystem, outputStreamId1);
     StreamTestUtils.addStreamConfigs(configs, outputStreamId2, outputSystem, outputStreamId2);
     Config config = new MapConfig(configs);
+    when(this.context.getJobContext().getConfig()).thenReturn(config);
 
     StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
         GenericSystemDescriptor isd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
@@ -640,14 +649,6 @@ public class TestOperatorImplGraph {
     String inputSystem1 = "system1";
     String inputSystem2 = "system2";
 
-    HashMap<String, String> configs = new HashMap<>();
-    configs.put(JobConfig.JOB_NAME(), "test-app");
-    configs.put(JobConfig.JOB_DEFAULT_SYSTEM(), inputSystem1);
-    StreamTestUtils.addStreamConfigs(configs, inputStreamId1, inputSystem1, inputStreamId1);
-    StreamTestUtils.addStreamConfigs(configs, inputStreamId2, inputSystem2, inputStreamId2);
-    StreamTestUtils.addStreamConfigs(configs, inputStreamId3, inputSystem2, inputStreamId3);
-    Config config = new MapConfig(configs);
-
     SystemStream input1 = new SystemStream("system1", "intput1");
     SystemStream input2 = new SystemStream("system2", "intput2");
     SystemStream input3 = new SystemStream("system2", "intput3");
index dc94e36..dfd8657 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
-import org.apache.samza.config.Config;
 import org.apache.samza.operators.data.TestOutputMessageEnvelope;
 import org.apache.samza.operators.functions.SinkFunction;
 import org.apache.samza.operators.spec.SinkOperatorSpec;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.junit.Test;
 
@@ -69,9 +67,6 @@ public class TestSinkOperatorImpl {
   private SinkOperatorImpl createSinkOperator(SinkFunction<TestOutputMessageEnvelope> sinkFn) {
     SinkOperatorSpec<TestOutputMessageEnvelope> sinkOp = mock(SinkOperatorSpec.class);
     when(sinkOp.getSinkFn()).thenReturn(sinkFn);
-
-    Config mockConfig = mock(Config.class);
-    TaskContext mockContext = mock(TaskContext.class);
-    return new SinkOperatorImpl<>(sinkOp, mockConfig, mockContext);
+    return new SinkOperatorImpl<>(sinkOp);
   }
 }
index 873cd3c..ae05305 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
-import org.apache.samza.config.Config;
 import org.apache.samza.operators.data.TestMessageEnvelope;
 import org.apache.samza.operators.data.TestOutputMessageEnvelope;
 import org.apache.samza.operators.functions.FlatMapFunction;
 import org.apache.samza.operators.spec.StreamOperatorSpec;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.junit.Test;
 
@@ -45,8 +43,6 @@ public class TestStreamOperatorImpl {
     StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> mockOp = mock(StreamOperatorSpec.class);
     FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> txfmFn = mock(FlatMapFunction.class);
     when(mockOp.getTransformFn()).thenReturn(txfmFn);
-    Config mockConfig = mock(Config.class);
-    TaskContext mockContext = mock(TaskContext.class);
     StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl =
         new StreamOperatorImpl<>(mockOp);
     TestMessageEnvelope inMsg = mock(TestMessageEnvelope.class);
@@ -65,8 +61,6 @@ public class TestStreamOperatorImpl {
     StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> mockOp = mock(StreamOperatorSpec.class);
     FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> txfmFn = mock(FlatMapFunction.class);
     when(mockOp.getTransformFn()).thenReturn(txfmFn);
-    Config mockConfig = mock(Config.class);
-    TaskContext mockContext = mock(TaskContext.class);
 
     StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl =
         new StreamOperatorImpl<>(mockOp);
index d8b2e8d..9083495 100644 (file)
  */
 package org.apache.samza.operators.impl;
 
-import java.util.Collection;
-
+import junit.framework.Assert;
 import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.data.TestMessageEnvelope;
 import org.apache.samza.operators.functions.StreamTableJoinFunction;
@@ -29,11 +29,10 @@ import org.apache.samza.operators.spec.StreamTableJoinOperatorSpec;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.TableSpec;
 import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.task.TaskCoordinator;
 import org.junit.Test;
 
-import junit.framework.Assert;
+import java.util.Collection;
 
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -75,18 +74,16 @@ public class TestStreamTableJoinOperatorImpl {
             return record.getKey();
           }
         });
-    Config config = mock(Config.class);
     ReadableTable table = mock(ReadableTable.class);
     when(table.get("1")).thenReturn("r1");
     when(table.get("2")).thenReturn(null);
-    TaskContext mockTaskContext = mock(TaskContext.class);
-    when(mockTaskContext.getTable(tableId)).thenReturn(table);
+    Context context = new MockContext();
+    when(context.getTaskContext().getTable(tableId)).thenReturn(table);
 
     MessageCollector mockMessageCollector = mock(MessageCollector.class);
     TaskCoordinator mockTaskCoordinator = mock(TaskCoordinator.class);
 
-    StreamTableJoinOperatorImpl streamTableJoinOperator = new StreamTableJoinOperatorImpl(
-        mockJoinOpSpec, config, mockTaskContext);
+    StreamTableJoinOperatorImpl streamTableJoinOperator = new StreamTableJoinOperatorImpl(mockJoinOpSpec, context);
 
     // Table has the key
     Collection<TestMessageEnvelope> result;
index 7d468c9..20d5e25 100644 (file)
@@ -30,13 +30,15 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import org.apache.samza.Partition;
-import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.application.StreamApplication;
+import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.TaskContextImpl;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
+import org.apache.samza.context.TaskContextImpl;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
@@ -67,8 +69,6 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -77,41 +77,41 @@ import static org.mockito.Mockito.when;
 public class TestWindowOperator {
   private final TaskCoordinator taskCoordinator = mock(TaskCoordinator.class);
   private final List<Integer> integers = ImmutableList.of(1, 2, 1, 2, 1, 2, 1, 2, 3);
+  private Context context;
   private Config config;
-  private TaskContextImpl taskContext;
 
   @Before
-  public void setup() throws Exception {
-    config = mock(Config.class);
-    when(config.get(JobConfig.JOB_NAME())).thenReturn("jobName");
-    when(config.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId");
-    taskContext = mock(TaskContextImpl.class);
+  public void setup() {
+    Map<String, String> configMap = new HashMap<>();
+    configMap.put("job.default.system", "kafka");
+    configMap.put("job.name", "jobName");
+    configMap.put("job.id", "jobId");
+    this.config = new MapConfig(configMap);
+
+    this.context = new MockContext();
+    when(this.context.getJobContext().getConfig()).thenReturn(this.config);
     Serde storeKeySerde = new TimeSeriesKeySerde(new IntegerSerde());
     Serde storeValSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde());
 
-    when(taskContext.getSystemStreamPartitions()).thenReturn(ImmutableSet
+    TaskModel taskModel = mock(TaskModel.class);
+    when(taskModel.getSystemStreamPartitions()).thenReturn(ImmutableSet
         .of(new SystemStreamPartition("kafka", "integTestExecutionPlannerers", new Partition(0))));
-    when(taskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
-    when(taskContext.getStore("jobName-jobId-window-w1"))
+    when(taskModel.getTaskName()).thenReturn(new TaskName("task 1"));
+    when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);
+    when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
+    when(this.context.getTaskContext().getStore("jobName-jobId-window-w1"))
         .thenReturn(new TestInMemoryStore<>(storeKeySerde, storeValSerde));
-
-    Map<String, String> mapConfig = new HashMap<>();
-    mapConfig.put("job.default.system", "kafka");
-    mapConfig.put("job.name", "jobName");
-    mapConfig.put("job.id", "jobId");
-    config = new MapConfig(mapConfig);
   }
 
   @Test
   public void testTumblingWindowsDiscardingMode() throws Exception {
-
     OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.DISCARDING,
         Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
     integers.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator));
@@ -143,8 +143,8 @@ public class TestWindowOperator {
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
 
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
@@ -163,8 +163,7 @@ public class TestWindowOperator {
 
   @Test
   public void testTumblingAggregatingWindowsDiscardingMode() throws Exception {
-
-    when(taskContext.getStore("jobName-jobId-window-w1"))
+    when(this.context.getTaskContext().getStore("jobName-jobId-window-w1"))
         .thenReturn(new TestInMemoryStore<>(new TimeSeriesKeySerde(new IntegerSerde()), new IntegerSerde()));
 
     OperatorSpecGraph sgb = this.getAggregateTumblingWindowStreamGraph(AccumulationMode.DISCARDING,
@@ -172,8 +171,8 @@ public class TestWindowOperator {
     List<WindowPane<Integer, Integer>> windowPanes = new ArrayList<>();
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
     MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Integer>) envelope.getMessage());
     integers.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator));
     testClock.advanceTime(Duration.ofSeconds(1));
@@ -193,8 +192,8 @@ public class TestWindowOperator {
         Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
 
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
@@ -222,8 +221,8 @@ public class TestWindowOperator {
         this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
     task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator);
@@ -267,12 +266,12 @@ public class TestWindowOperator {
     OperatorSpecGraph sgb = this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING,
         Duration.ofMillis(500)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
-    task.init(config, taskContext);
+    task.init(this.context);
 
     task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator);
     task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator);
@@ -299,8 +298,8 @@ public class TestWindowOperator {
     OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING,
         Duration.ofSeconds(1), Triggers.count(2)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
 
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
     MessageCollector messageCollector =
@@ -343,8 +342,8 @@ public class TestWindowOperator {
     OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1),
         Triggers.any(Triggers.count(2), Triggers.timeSinceFirstMessage(Duration.ofMillis(500)))).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
 
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
     MessageCollector messageCollector =
@@ -406,8 +405,8 @@ public class TestWindowOperator {
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
 
     task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator);
 
@@ -439,17 +438,18 @@ public class TestWindowOperator {
     EndOfStreamStates endOfStreamStates = new EndOfStreamStates(ImmutableSet.of(new SystemStreamPartition("kafka",
         "integers", new Partition(0))), Collections.emptyMap());
 
-    when(taskContext.getTaskName()).thenReturn(new TaskName("task 1"));
-    when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates);
-    when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
+    when(((TaskContextImpl) this.context.getTaskContext()).fetchObject(EndOfStreamStates.class.getName())).thenReturn(
+        endOfStreamStates);
+    when(((TaskContextImpl) this.context.getTaskContext()).fetchObject(WatermarkStates.class.getName())).thenReturn(
+        mock(WatermarkStates.class));
 
     OperatorSpecGraph sgb = this.getTumblingWindowStreamGraph(AccumulationMode.DISCARDING,
         Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
 
     TestClock testClock = new TestClock();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
 
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
@@ -480,16 +480,17 @@ public class TestWindowOperator {
     EndOfStreamStates endOfStreamStates = new EndOfStreamStates(ImmutableSet.of(new SystemStreamPartition("kafka",
         "integers", new Partition(0))), Collections.emptyMap());
 
-    when(taskContext.getTaskName()).thenReturn(new TaskName("task 1"));
-    when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates);
-    when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
+    when(((TaskContextImpl) this.context.getTaskContext()).fetchObject(EndOfStreamStates.class.getName())).thenReturn(
+        endOfStreamStates);
+    when(((TaskContextImpl) this.context.getTaskContext()).fetchObject(WatermarkStates.class.getName())).thenReturn(
+        mock(WatermarkStates.class));
 
     OperatorSpecGraph sgb =
         this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
 
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
@@ -517,16 +518,17 @@ public class TestWindowOperator {
     EndOfStreamStates endOfStreamStates = new EndOfStreamStates(ImmutableSet.of(new SystemStreamPartition("kafka",
         "integers", new Partition(0))), Collections.emptyMap());
 
-    when(taskContext.getTaskName()).thenReturn(new TaskName("task 1"));
-    when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates);
-    when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
+    when(((TaskContextImpl) this.context.getTaskContext()).fetchObject(EndOfStreamStates.class.getName())).thenReturn(
+        endOfStreamStates);
+    when(((TaskContextImpl) this.context.getTaskContext()).fetchObject(WatermarkStates.class.getName())).thenReturn(
+        mock(WatermarkStates.class));
 
     OperatorSpecGraph sgb =
         this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph();
     TestClock testClock = new TestClock();
     List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>();
-    StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock);
-    task.init(config, taskContext);
+    StreamOperatorTask task = new StreamOperatorTask(sgb, testClock);
+    task.init(this.context);
 
     MessageCollector messageCollector =
         envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage());
index 860e630..6e91e2a 100644 (file)
@@ -31,9 +31,9 @@ import org.apache.samza.operators.functions.FilterFunction;
 import org.apache.samza.operators.functions.FlatMapFunction;
 import org.apache.samza.operators.functions.JoinFunction;
 import org.apache.samza.operators.functions.MapFunction;
+import org.apache.samza.operators.functions.ScheduledFunction;
 import org.apache.samza.operators.functions.SinkFunction;
 import org.apache.samza.operators.functions.StreamTableJoinFunction;
-import org.apache.samza.operators.functions.ScheduledFunction;
 import org.apache.samza.operators.functions.WatermarkFunction;
 import org.apache.samza.serializers.JsonSerdeV2;
 import org.apache.samza.serializers.KVSerde;
index e1342e3..fd4a7fb 100644 (file)
@@ -23,9 +23,9 @@ import java.util.Map;
 import org.apache.samza.application.StreamApplicationDescriptorImpl;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
-import org.apache.samza.operators.Scheduler;
 import org.apache.samza.operators.MessageStream;
 import org.apache.samza.operators.OperatorSpecGraph;
+import org.apache.samza.operators.Scheduler;
 import org.apache.samza.operators.descriptors.GenericInputDescriptor;
 import org.apache.samza.operators.descriptors.GenericSystemDescriptor;
 import org.apache.samza.operators.functions.MapFunction;
index 41973b2..b73f8e3 100644 (file)
 package org.apache.samza.operators.spec;
 
 import org.apache.samza.operators.Scheduler;
-import org.apache.samza.operators.functions.ScheduledFunction;
-import org.apache.samza.operators.functions.WatermarkFunction;
-import org.apache.samza.serializers.Serde;
 import org.apache.samza.operators.functions.FoldLeftFunction;
 import org.apache.samza.operators.functions.MapFunction;
+import org.apache.samza.operators.functions.ScheduledFunction;
 import org.apache.samza.operators.functions.SupplierFunction;
+import org.apache.samza.operators.functions.WatermarkFunction;
 import org.apache.samza.operators.triggers.Trigger;
 import org.apache.samza.operators.triggers.Triggers;
 import org.apache.samza.operators.windows.internal.WindowInternal;
 import org.apache.samza.operators.windows.internal.WindowType;
+import org.apache.samza.serializers.Serde;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
@@ -38,7 +38,8 @@ import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 import static org.mockito.Mockito.mock;
 
 public class TestWindowOperatorSpec {
index 93b157a..b002e2a 100644 (file)
@@ -49,7 +49,8 @@ import org.mockito.Mockito;
 import org.powermock.api.mockito.PowerMockito;
 
 import static org.junit.Assert.assertEquals;
-import static org.mockito.Matchers.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
@@ -465,8 +466,10 @@ public class TestStreamProcessor {
   @Test
   public void testStreamProcessorWithStreamProcessorListenerFactory() {
     AtomicReference<MockStreamProcessorLifecycleListener> mockListener = new AtomicReference<>();
-    StreamProcessor streamProcessor = new StreamProcessor(mock(Config.class), new HashMap<>(), mock(TaskFactory.class),
-        sp -> mockListener.updateAndGet(old -> new MockStreamProcessorLifecycleListener(sp)), mock(JobCoordinator.class));
+    StreamProcessor streamProcessor =
+        new StreamProcessor(mock(Config.class), new HashMap<>(), mock(TaskFactory.class), null, null,
+            sp -> mockListener.updateAndGet(old -> new MockStreamProcessorLifecycleListener(sp)),
+            mock(JobCoordinator.class));
     assertEquals(streamProcessor, mockListener.get().processor);
   }
 
index d483ae6..8eff4ad 100644 (file)
@@ -20,8 +20,8 @@
 package org.apache.samza.storage;
 
 import java.io.File;
-
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.SystemStreamPartition;
@@ -29,9 +29,15 @@ import org.apache.samza.task.MessageCollector;
 
 public class MockStorageEngineFactory implements StorageEngineFactory<Object, Object> {
   @Override
-  public StorageEngine getStorageEngine(String storeName, File storeDir, Serde<Object> keySerde, Serde<Object> msgSerde,
-      MessageCollector collector, MetricsRegistry registry, SystemStreamPartition changeLogSystemStreamPartition,
-      SamzaContainerContext containerContext) {
+  public StorageEngine getStorageEngine(String storeName,
+      File storeDir,
+      Serde<Object> keySerde,
+      Serde<Object> msgSerde,
+      MessageCollector collector,
+      MetricsRegistry registry,
+      SystemStreamPartition changeLogSystemStreamPartition,
+      JobContext jobContext,
+      ContainerContext containerContext) {
     StoreProperties storeProperties = new StoreProperties.StorePropertiesBuilder().setLoggedStore(true).build();
     return new MockStorageEngine(storeName, storeDir, changeLogSystemStreamPartition, storeProperties);
   }
index 42f05c0..0952a87 100644 (file)
  */
 package org.apache.samza.table;
 
-import java.lang.reflect.Field;
-import java.util.Base64;
-import java.util.HashMap;
-import java.util.Map;
-
+import junit.framework.Assert;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.SerializerConfig;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.MockContext;
 import org.apache.samza.serializers.IntegerSerde;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.SerializableSerde;
 import org.apache.samza.serializers.StringSerde;
 import org.apache.samza.storage.StorageEngine;
-import org.apache.samza.task.TaskContext;
 import org.junit.Test;
 
-import junit.framework.Assert;
+import java.lang.reflect.Field;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.mockito.Matchers.anyObject;
 import static org.mockito.Mockito.mock;
@@ -122,11 +120,11 @@ public class TestTableManager {
           });
 
     TableManager tableManager = new TableManager(new MapConfig(map), serdeMap);
-    tableManager.init(mock(SamzaContainerContext.class), mock(TaskContext.class));
+    tableManager.init(new MockContext());
 
     for (int i = 0; i < 2; i++) {
       Table table = tableManager.getTable(TABLE_ID);
-      verify(DummyTableProviderFactory.tableProvider, times(1)).init(anyObject(), anyObject());
+      verify(DummyTableProviderFactory.tableProvider, times(1)).init(anyObject());
       verify(DummyTableProviderFactory.tableProvider, times(1)).getTable();
       Assert.assertEquals(DummyTableProviderFactory.table, table);
     }
index ec1c915..dc13d00 100644 (file)
 
 package org.apache.samza.table.caching;
 
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.Executors;
-
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
 import org.apache.commons.lang3.tuple.Pair;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.Gauge;
 import org.apache.samza.metrics.MetricsRegistry;
@@ -45,17 +37,24 @@ import org.apache.samza.table.TableSpec;
 import org.apache.samza.table.caching.guava.GuavaCacheTable;
 import org.apache.samza.table.caching.guava.GuavaCacheTableDescriptor;
 import org.apache.samza.table.caching.guava.GuavaCacheTableProvider;
-import org.apache.samza.table.remote.TableRateLimiter;
 import org.apache.samza.table.remote.RemoteReadWriteTable;
+import org.apache.samza.table.remote.TableRateLimiter;
 import org.apache.samza.table.remote.TableReadFunction;
 import org.apache.samza.table.remote.TableWriteFunction;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.util.NoOpMetricsRegistry;
 import org.junit.Assert;
 import org.junit.Test;
 
-import com.google.common.cache.Cache;
-import com.google.common.cache.CacheBuilder;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executors;
 
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyString;
@@ -139,15 +138,14 @@ public class TestCachingTable {
   }
 
   private void initTables(ReadableTable ... tables) {
-    SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
-    TaskContext taskContext = mock(TaskContext.class);
+    Context context = new MockContext();
     MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
     doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString());
     doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString());
     doReturn(mock(Gauge.class)).when(metricsRegistry).newGauge(anyString(), any());
-    when(taskContext.getMetricsRegistry()).thenReturn(metricsRegistry);
+    when(context.getTaskContext().getTaskMetricsRegistry()).thenReturn(metricsRegistry);
     for (ReadableTable table : tables) {
-      table.init(containerContext, taskContext);
+      table.init(context);
     }
   }
 
@@ -160,9 +158,7 @@ public class TestCachingTable {
     }
     CachingTableProvider tableProvider = new CachingTableProvider(desc.getTableSpec());
 
-    SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
-
-    TaskContext taskContext = mock(TaskContext.class);
+    Context context = new MockContext();
     final ReadWriteTable cacheTable = getMockCache().getLeft();
 
     final ReadWriteTable realTable = mock(ReadWriteTable.class);
@@ -185,11 +181,11 @@ public class TestCachingTable {
 
         Assert.fail();
         return null;
-      }).when(taskContext).getTable(anyString());
+      }).when(context.getTaskContext()).getTable(anyString());
 
-    when(taskContext.getMetricsRegistry()).thenReturn(new NoOpMetricsRegistry());
+    when(context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new NoOpMetricsRegistry());
 
-    tableProvider.init(containerContext, taskContext);
+    tableProvider.init(context);
 
     CachingTable cachingTable = (CachingTable) tableProvider.getTable();
 
index 3e844c3..571f87b 100644 (file)
 
 package org.apache.samza.table.remote;
 
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
-
-import org.apache.samza.container.SamzaContainerContext;
+import junit.framework.Assert;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.Gauge;
 import org.apache.samza.metrics.MetricsRegistry;
@@ -38,11 +30,18 @@ import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.retry.RetriableReadFunction;
 import org.apache.samza.table.retry.RetriableWriteFunction;
 import org.apache.samza.table.retry.TableRetryPolicy;
-import org.apache.samza.task.TaskContext;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 
-import junit.framework.Assert;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
 
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyCollection;
@@ -57,14 +56,14 @@ import static org.mockito.Mockito.verify;
 public class TestRemoteTable {
   private final ScheduledExecutorService schedExec = Executors.newSingleThreadScheduledExecutor();
 
-  public static TaskContext getMockTaskContext() {
+  public static Context getMockContext() {
+    Context context = new MockContext();
     MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
     doAnswer(args -> new Timer((String) args.getArguments()[0])).when(metricsRegistry).newTimer(anyString(), anyString());
     doAnswer(args -> new Counter((String) args.getArguments()[0])).when(metricsRegistry).newCounter(anyString(), anyString());
     doAnswer(args -> new Gauge((String) args.getArguments()[0], 0)).when(metricsRegistry).newGauge(anyString(), any());
-    TaskContext taskContext = mock(TaskContext.class);
-    doReturn(metricsRegistry).when(taskContext).getMetricsRegistry();
-    return taskContext;
+    doReturn(metricsRegistry).when(context.getTaskContext()).getTaskMetricsRegistry();
+    return context;
   }
 
   private <K, V, T extends RemoteReadableTable<K, V>> T getTable(String tableId,
@@ -89,11 +88,9 @@ public class TestRemoteTable {
       table = new RemoteReadWriteTable<K, V>(tableId, readFn, writeFn, readRateLimiter, writeRateLimiter, tableExecutor, cbExecutor);
     }
 
-    TaskContext taskContext = getMockTaskContext();
-
-    SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
+    Context context = getMockContext();
 
-    table.init(containerContext, taskContext);
+    table.init(context);
 
     return (T) table;
   }
index efe1acf..f587885 100644 (file)
 
 package org.apache.samza.table.remote;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.concurrent.ThreadPoolExecutor;
-
-import org.apache.samza.container.SamzaContainerContext;
+import com.google.common.collect.ImmutableMap;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.Timer;
@@ -34,18 +33,22 @@ import org.apache.samza.table.TableSpec;
 import org.apache.samza.table.retry.RetriableReadFunction;
 import org.apache.samza.table.retry.RetriableWriteFunction;
 import org.apache.samza.table.retry.TableRetryPolicy;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.util.EmbeddedTaggedRateLimiter;
 import org.apache.samza.util.RateLimiter;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ThreadPoolExecutor;
+
 import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_READ_TAG;
 import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_WRITE_TAG;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
 
 
 public class TestRemoteTableDescriptor {
@@ -117,16 +120,24 @@ public class TestRemoteTableDescriptor {
     desc.getTableSpec();
   }
 
-  private TaskContext createMockTaskContext() {
+  private Context createMockContext() {
+    Context context = new MockContext();
+
     MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
     doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString());
     doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString());
-    TaskContext taskContext = mock(TaskContext.class);
-    doReturn(metricsRegistry).when(taskContext).getMetricsRegistry();
-    SamzaContainerContext containerCtx = new SamzaContainerContext(
-        "1", null, Collections.singleton(new TaskName("MyTask")), null);
-    doReturn(containerCtx).when(taskContext).getSamzaContainerContext();
-    return taskContext;
+    doReturn(metricsRegistry).when(context.getTaskContext()).getTaskMetricsRegistry();
+
+    TaskName taskName = new TaskName("MyTask");
+    TaskModel taskModel = mock(TaskModel.class);
+    when(taskModel.getTaskName()).thenReturn(taskName);
+    when(context.getTaskContext().getTaskModel()).thenReturn(taskModel);
+
+    ContainerModel containerModel = mock(ContainerModel.class);
+    when(containerModel.getTasks()).thenReturn(ImmutableMap.of(taskName, taskModel));
+    when(context.getContainerContext().getContainerModel()).thenReturn(containerModel);
+
+    return context;
   }
 
   static class CountingCreditFunction<K, V> implements TableRateLimiter.CreditFunction<K, V> {
@@ -172,7 +183,7 @@ public class TestRemoteTableDescriptor {
 
     TableSpec spec = desc.getTableSpec();
     RemoteTableProvider provider = new RemoteTableProvider(spec);
-    provider.init(mock(SamzaContainerContext.class), createMockTaskContext());
+    provider.init(createMockContext());
     Table table = provider.getTable();
     Assert.assertTrue(table instanceof RemoteReadWriteTable);
     RemoteReadWriteTable rwTable = (RemoteReadWriteTable) table;
index 9dd5a74..050ea55 100644 (file)
@@ -29,19 +29,16 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
-
-import org.apache.samza.container.SamzaContainerContext;
+import junit.framework.Assert;
+import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.remote.TableReadFunction;
 import org.apache.samza.table.remote.TableWriteFunction;
 import org.apache.samza.table.remote.TestRemoteTable;
 import org.apache.samza.table.utils.TableMetricsUtil;
-import org.apache.samza.task.TaskContext;
 import org.junit.Test;
 
-import junit.framework.Assert;
-
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.atLeast;
@@ -57,9 +54,8 @@ public class TestRetriableTableFunctions {
 
   public TableMetricsUtil getMetricsUtil(String tableId) {
     Table table = mock(Table.class);
-    SamzaContainerContext cntCtx = mock(SamzaContainerContext.class);
-    TaskContext taskCtx = TestRemoteTable.getMockTaskContext();
-    return new TableMetricsUtil(cntCtx, taskCtx, table, tableId);
+    Context context = TestRemoteTable.getMockContext();
+    return new TableMetricsUtil(context, table, tableId);
   }
 
   @Test
index 1f71abd..13ce5f4 100644 (file)
 package org.apache.samza.task;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.OutgoingMessageEnvelope;
 import org.apache.samza.system.SystemStream;
 
 
-public class IdentityStreamTask implements StreamTask , InitableTask  {
+public class IdentityStreamTask implements StreamTask, InitableTask  {
   private int processedMessageCount = 0;
   private int expectedMessageCount;
   private String outputTopic;
   private String outputSystem;
 
   @Override
-  public void init(Config config, TaskContext taskContext) throws Exception {
+  public void init(Context context) throws Exception {
+    Config config = context.getJobContext().getConfig();
     this.expectedMessageCount = config.getInt("app.messageCount");
     this.outputTopic = config.get("app.outputTopic", "output");
     this.outputSystem = config.get("app.outputSystem", "test-system");
index 9cdbfe6..be8d344 100644 (file)
@@ -28,17 +28,17 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.atomic.AtomicInteger;
-
 import org.apache.samza.Partition;
 import org.apache.samza.checkpoint.Checkpoint;
 import org.apache.samza.checkpoint.OffsetManager;
-import org.apache.samza.config.Config;
-import org.apache.samza.container.SamzaContainerContext;
 import org.apache.samza.container.SamzaContainerMetrics;
 import org.apache.samza.container.TaskInstance;
 import org.apache.samza.container.TaskInstanceExceptionHandler;
 import org.apache.samza.container.TaskInstanceMetrics;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemConsumer;
@@ -47,13 +47,20 @@ import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.system.TestSystemConsumers;
 import org.junit.Rule;
 import org.junit.Test;
-
 import org.junit.rules.Timeout;
 import scala.Option;
 import scala.collection.JavaConverters;
 
 import static org.junit.Assert.assertEquals;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyLong;
+import static org.mockito.Mockito.anyObject;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 public class TestAsyncRunLoop {
   // Immutable objects shared by all test methods.
@@ -77,12 +84,31 @@ public class TestAsyncRunLoop {
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
   TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
+    TaskModel taskModel = mock(TaskModel.class);
+    when(taskModel.getTaskName()).thenReturn(taskName);
     TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
     scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task, taskName, mock(Config.class), taskInstanceMetrics,
-        null, consumers, mock(TaskInstanceCollector.class), mock(SamzaContainerContext.class),
-        manager, null, null, null, sspSet, new TaskInstanceExceptionHandler(taskInstanceMetrics,
-        new scala.collection.immutable.HashSet<String>()), null, null, null, new scala.collection.immutable.HashSet<>(), null);
+    return new TaskInstance(task,
+        taskModel,
+        taskInstanceMetrics,
+        null,
+        consumers,
+        mock(TaskInstanceCollector.class),
+        manager,
+        null,
+        null,
+        null,
+        sspSet,
+        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
+        null,
+        null,
+        null,
+        new scala.collection.immutable.HashSet<>(),
+        null,
+        mock(JobContext.class),
+        mock(ContainerContext.class),
+        Option.apply(null),
+        Option.apply(null));
   }
 
   interface TestCode {
index d0b820a..0538980 100644 (file)
@@ -22,7 +22,7 @@ package org.apache.samza.task;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.junit.Before;
 import org.junit.Test;
@@ -64,7 +64,7 @@ public class TestAsyncStreamAdapter {
     }
 
     @Override
-    public void init(Config config, TaskContext context) throws Exception {
+    public void init(Context context) throws Exception {
       inited = true;
     }
 
@@ -95,7 +95,7 @@ public class TestAsyncStreamAdapter {
     TestCallbackListener listener = new TestCallbackListener();
     TaskCallback callback = new TaskCallbackImpl(listener, null, envelope, null, 0L, 0L);
 
-    taskAdaptor.init(null, null);
+    taskAdaptor.init(null);
     assertTrue(task.inited);
 
     taskAdaptor.processAsync(null, null, null, callback);
index e0da2e9..da137e6 100644 (file)
 
 package org.apache.samza.task;
 
-import org.junit.Test;
-
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledFuture;
+import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
index 1bc23d4..ab5e295 100644 (file)
 
 package org.apache.samza.task;
 
-import org.apache.samza.config.Config;
-import org.apache.samza.operators.ContextManager;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.JobContext;
 import org.apache.samza.operators.OperatorSpecGraph;
 import org.apache.samza.operators.impl.OperatorImplGraph;
+import org.apache.samza.util.Clock;
 import org.junit.Test;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 
 public class TestStreamOperatorTask {
@@ -36,20 +39,19 @@ public class TestStreamOperatorTask {
   }
 
   @Test
-  public void testCloseDuringInitializationErrors() {
-    ContextManager mockContextManager = mock(ContextManager.class);
-    StreamOperatorTask operatorTask = new StreamOperatorTask(mock(OperatorSpecGraph.class), mockContextManager);
-
-    doThrow(new RuntimeException("Failed to initialize context manager"))
-        .when(mockContextManager).init(any(), any());
-
+  public void testCloseDuringInitializationErrors() throws Exception {
+    Context context = mock(Context.class);
+    JobContext jobContext = mock(JobContext.class);
+    when(context.getJobContext()).thenReturn(jobContext);
+    doThrow(new RuntimeException("Failed to get config")).when(jobContext).getConfig();
+    StreamOperatorTask operatorTask = new StreamOperatorTask(mock(OperatorSpecGraph.class), mock(Clock.class));
     try {
-      operatorTask.init(mock(Config.class), mock(TaskContext.class));
-      operatorTask.close();
-    } catch (Exception e) {
+      operatorTask.init(context);
+    } catch (RuntimeException e) {
       if (e instanceof NullPointerException) {
         fail("Unexpected null pointer exception");
       }
     }
+    operatorTask.close();
   }
 }
index ef606c0..8559bb3 100644 (file)
  */
 package org.apache.samza.util;
 
-import java.lang.reflect.Field;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
-import org.apache.samza.container.SamzaContainerContext;
-import org.apache.samza.task.TaskContext;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.MockContext;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
 import org.junit.Assert;
 import org.junit.Ignore;
 import org.junit.Test;
 
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -206,25 +209,14 @@ public class TestEmbeddedTaggedRateLimiter {
   }
 
   static void initRateLimiter(RateLimiter rateLimiter) {
-    Config config = mock(Config.class);
-    TaskContext taskContext = mock(TaskContext.class);
-    SamzaContainerContext containerContext = mockSamzaContainerContext();
-    when(taskContext.getSamzaContainerContext()).thenReturn(containerContext);
-    rateLimiter.init(config, taskContext);
-  }
-
-  static SamzaContainerContext mockSamzaContainerContext() {
-    try {
-      Collection<String> taskNames = mock(Collection.class);
-      when(taskNames.size()).thenReturn(NUMBER_OF_TASKS);
-      SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
-      Field taskNamesField = SamzaContainerContext.class.getDeclaredField("taskNames");
-      taskNamesField.setAccessible(true);
-      taskNamesField.set(containerContext, taskNames);
-      taskNamesField.setAccessible(false);
-      return containerContext;
-    } catch (Exception ex) {
-      throw new SamzaException(ex);
-    }
+    Context context = new MockContext(mock(Config.class));
+    when(context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class));
+    ContainerModel containerModel = mock(ContainerModel.class);
+    Map<TaskName, TaskModel> tasks = IntStream.range(0, NUMBER_OF_TASKS)
+        .mapToObj(i -> new TaskName("task-" + i))
+        .collect(Collectors.toMap(Function.identity(), x -> mock(TaskModel.class)));
+    when(containerModel.getTasks()).thenReturn(tasks);
+    when(context.getContainerContext().getContainerModel()).thenReturn(containerModel);
+    rateLimiter.init(context);
   }
 }
index 57c0bf0..a35366d 100644 (file)
@@ -22,7 +22,8 @@ package org.apache.samza.container
 import java.util
 import java.util.concurrent.atomic.AtomicReference
 
-import org.apache.samza.config.MapConfig
+import org.apache.samza.config.{Config, MapConfig}
+import org.apache.samza.context.{ApplicationContainerContext, ContainerContext}
 import org.apache.samza.coordinator.JobModelManager
 import org.apache.samza.coordinator.server.{HttpServer, JobServlet}
 import org.apache.samza.job.model.{ContainerModel, JobModel, TaskModel}
@@ -46,7 +47,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   private val TASK_NAME = new TaskName("taskName")
 
   @Mock
-  private var containerContext: SamzaContainerContext = null
+  private var config: Config = null
   @Mock
   private var taskInstance: TaskInstance = null
   @Mock
@@ -60,6 +61,10 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   @Mock
   private var metrics: SamzaContainerMetrics = null
   @Mock
+  private var containerContext: ContainerContext = null
+  @Mock
+  private var applicationContainerContext: ApplicationContainerContext = null
+  @Mock
   private var samzaContainerListener: SamzaContainerListener = null
 
   private var samzaContainer: SamzaContainer = null
@@ -67,15 +72,7 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   @Before
   def setup(): Unit = {
     MockitoAnnotations.initMocks(this)
-    this.samzaContainer = new SamzaContainer(
-      this.containerContext,
-      Map(TASK_NAME -> this.taskInstance),
-      this.runLoop,
-      this.systemAdmins,
-      this.consumerMultiplexer,
-      this.producerMultiplexer,
-      metrics)
-    this.samzaContainer.setContainerListener(this.samzaContainerListener)
+    setupSamzaContainer(Some(this.applicationContainerContext))
     when(this.metrics.containerStartupTime).thenReturn(mock[Timer])
   }
 
@@ -173,6 +170,24 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   }
 
   @Test
+  def testApplicationContainerContext() {
+    val orderVerifier = inOrder(this.applicationContainerContext, this.runLoop)
+    this.samzaContainer.run
+    orderVerifier.verify(this.applicationContainerContext).start()
+    orderVerifier.verify(this.runLoop).run()
+    orderVerifier.verify(this.applicationContainerContext).stop()
+  }
+
+  @Test
+  def testNullApplicationContainerContextFactory() {
+    setupSamzaContainer(None)
+    this.samzaContainer.run
+    verify(this.runLoop).run()
+    // applicationContainerContext is not even wired into the container anymore, but just double check it is not used
+    verifyZeroInteractions(this.applicationContainerContext)
+  }
+
+  @Test
   def testReadJobModel() {
     val config = new MapConfig(Map("a" -> "b").asJava)
     val offsets = new util.HashMap[SystemStreamPartition, String]()
@@ -258,6 +273,20 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
     assertEquals(Set(), SamzaContainer.getChangelogSSPsForContainer(containerModel, Map()))
   }
 
+  private def setupSamzaContainer(applicationContainerContext: Option[ApplicationContainerContext]) {
+    this.samzaContainer = new SamzaContainer(
+      this.config,
+      Map(TASK_NAME -> this.taskInstance),
+      this.runLoop,
+      this.systemAdmins,
+      this.consumerMultiplexer,
+      this.producerMultiplexer,
+      metrics,
+      containerContext = this.containerContext,
+      applicationContainerContextOption = applicationContainerContext)
+    this.samzaContainer.setContainerListener(this.samzaContainerListener)
+  }
+
   class MockJobServlet(exceptionLimit: Int, jobModelRef: AtomicReference[JobModel]) extends JobServlet(jobModelRef) {
     var exceptionCount = 0
 
index b196131..15534cd 100644 (file)
@@ -22,7 +22,8 @@ package org.apache.samza.container
 
 import org.apache.samza.Partition
 import org.apache.samza.checkpoint.{Checkpoint, OffsetManager}
-import org.apache.samza.config.Config
+import org.apache.samza.context.{TaskContext => _, _}
+import org.apache.samza.job.model.TaskModel
 import org.apache.samza.metrics.Counter
 import org.apache.samza.storage.TaskStorageManager
 import org.apache.samza.system.{IncomingMessageEnvelope, SystemAdmin, SystemConsumers, SystemStream, _}
@@ -34,11 +35,12 @@ import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.mockito.{Matchers, Mock, MockitoAnnotations}
+import org.scalatest.junit.AssertionsForJUnit
 import org.scalatest.mockito.MockitoSugar
 
 import scala.collection.JavaConverters._
 
-class TestTaskInstance extends MockitoSugar {
+class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
   private val SYSTEM_NAME = "test-system"
   private val TASK_NAME = new TaskName("taskName")
   private val SYSTEM_STREAM_PARTITION =
@@ -48,7 +50,7 @@ class TestTaskInstance extends MockitoSugar {
   @Mock
   private var task: AllTask = null
   @Mock
-  private var config: Config = null
+  private var taskModel: TaskModel = null
   @Mock
   private var metrics: TaskInstanceMetrics = null
   @Mock
@@ -60,13 +62,21 @@ class TestTaskInstance extends MockitoSugar {
   @Mock
   private var collector: TaskInstanceCollector = null
   @Mock
-  private var containerContext: SamzaContainerContext = null
-  @Mock
   private var offsetManager: OffsetManager = null
   @Mock
   private var taskStorageManager: TaskStorageManager = null
   // not a mock; using MockTaskInstanceExceptionHandler
   private var taskInstanceExceptionHandler: MockTaskInstanceExceptionHandler = null
+  @Mock
+  private var jobContext: JobContext = null
+  @Mock
+  private var containerContext: ContainerContext = null
+  @Mock
+  private var applicationContainerContext: ApplicationContainerContext = null
+  @Mock
+  private var applicationTaskContextFactory: ApplicationTaskContextFactory[ApplicationTaskContext] = null
+  @Mock
+  private var applicationTaskContext: ApplicationTaskContext = null
 
   private var taskInstance: TaskInstance = null
 
@@ -75,19 +85,12 @@ class TestTaskInstance extends MockitoSugar {
     MockitoAnnotations.initMocks(this)
     // not using Mockito mock since Mockito doesn't work well with the call-by-name argument in maybeHandle
     this.taskInstanceExceptionHandler = new MockTaskInstanceExceptionHandler
-    this.taskInstance = new TaskInstance(this.task,
-      TASK_NAME,
-      this.config,
-      this.metrics,
-      this.systemAdmins,
-      this.consumerMultiplexer,
-      this.collector,
-      this.containerContext,
-      this.offsetManager,
-      storageManager = this.taskStorageManager,
-      systemStreamPartitions = SYSTEM_STREAM_PARTITIONS,
-      exceptionHandler = this.taskInstanceExceptionHandler)
+    when(this.taskModel.getTaskName).thenReturn(TASK_NAME)
+    when(this.applicationTaskContextFactory.create(Matchers.eq(this.jobContext), Matchers.eq(this.containerContext),
+      any(), Matchers.eq(this.applicationContainerContext)))
+      .thenReturn(this.applicationTaskContext)
     when(this.systemAdmins.getSystemAdmin(SYSTEM_NAME)).thenReturn(this.systemAdmin)
+    setupTaskInstance(Some(this.applicationTaskContextFactory))
   }
 
   @Test
@@ -133,10 +136,10 @@ class TestTaskInstance extends MockitoSugar {
    */
   @Test
   def testManualOffsetReset() {
-    when(this.task.init(any(), any())).thenAnswer(new Answer[Void] {
+    when(this.task.init(any())).thenAnswer(new Answer[Void] {
       override def answer(invocation: InvocationOnMock): Void = {
-        val taskContext = invocation.getArgumentAt(1, classOf[TaskContext])
-        taskContext.setStartingOffset(SYSTEM_STREAM_PARTITION, "10")
+        val context = invocation.getArgumentAt(0, classOf[Context])
+        context.getTaskContext.setStartingOffset(SYSTEM_STREAM_PARTITION, "10")
         null
       }
     })
@@ -198,6 +201,35 @@ class TestTaskInstance extends MockitoSugar {
     verify(commitsCounter).inc()
   }
 
+  /**
+    * Given that an application task context factory is provided, then lifecycle calls should be made and the context
+    * should be accessible.
+    */
+  @Test
+  def testApplicationTaskContextFactoryProvided(): Unit = {
+    assertEquals(this.applicationTaskContext, this.taskInstance.context.getApplicationTaskContext)
+    this.taskInstance.initTask
+    verify(this.applicationTaskContext).start()
+    verify(this.applicationTaskContext, never()).stop()
+    this.taskInstance.shutdownTask
+    verify(this.applicationTaskContext).stop()
+  }
+
+  /**
+    * Given that no application task context factory is provided, then no lifecycle calls should be made. Also, an
+    * exception should be thrown if the application task context is accessed.
+    */
+  @Test
+  def testNoApplicationTaskContextFactoryProvided() {
+    setupTaskInstance(None)
+    this.taskInstance.initTask
+    this.taskInstance.shutdownTask
+    verifyZeroInteractions(this.applicationTaskContext)
+    intercept[IllegalStateException] {
+      this.taskInstance.context.getApplicationTaskContext
+    }
+  }
+
   @Test(expected = classOf[SystemProducerException])
   def testProducerExceptionsIsPropagated() {
     when(this.metrics.commits).thenReturn(mock[Counter])
@@ -210,6 +242,24 @@ class TestTaskInstance extends MockitoSugar {
     }
   }
 
+  private def setupTaskInstance(
+    applicationTaskContextFactory: Option[ApplicationTaskContextFactory[ApplicationTaskContext]]): Unit = {
+    this.taskInstance = new TaskInstance(this.task,
+      this.taskModel,
+      this.metrics,
+      this.systemAdmins,
+      this.consumerMultiplexer,
+      this.collector,
+      offsetManager = this.offsetManager,
+      storageManager = this.taskStorageManager,
+      systemStreamPartitions = SYSTEM_STREAM_PARTITIONS,
+      exceptionHandler = this.taskInstanceExceptionHandler,
+      jobContext = this.jobContext,
+      containerContext = this.containerContext,
+      applicationContainerContextOption = Some(this.applicationContainerContext),
+      applicationTaskContextFactoryOption = applicationTaskContextFactory)
+  }
+
   /**
     * Task type which has all task traits, which can be mocked.
     */
index 0b951f4..59f8662 100644 (file)
  */
 package org.apache.samza.processor
 
-import java.util.Collections
+import java.util
 
+import org.apache.samza.Partition
 import org.apache.samza.config.MapConfig
 import org.apache.samza.container._
-import org.apache.samza.metrics.MetricsRegistryMap
+import org.apache.samza.context.{ContainerContext, JobContext}
+import org.apache.samza.job.model.TaskModel
 import org.apache.samza.serializers.SerdeManager
-import org.apache.samza.system.chooser.RoundRobinChooser
 import org.apache.samza.system._
+import org.apache.samza.system.chooser.RoundRobinChooser
 import org.apache.samza.task.{StreamTask, TaskInstanceCollector}
+import org.mockito.Mockito
 
 
 object StreamProcessorTestUtils {
   def getDummyContainer(mockRunloop: RunLoop, streamTask: StreamTask) = {
-    val config = new MapConfig
+    val config = new MapConfig()
     val taskName = new TaskName("taskName")
+    val taskModel = new TaskModel(taskName, new util.HashSet[SystemStreamPartition](), new Partition(0))
     val adminMultiplexer = new SystemAdmins(config)
     val consumerMultiplexer = new SystemConsumers(
       new RoundRobinChooser,
@@ -41,26 +45,29 @@ object StreamProcessorTestUtils {
       Map[String, SystemProducer](),
       new SerdeManager)
     val collector = new TaskInstanceCollector(producerMultiplexer)
-    val containerContext = new SamzaContainerContext("0", config, Collections.singleton[TaskName](taskName), new MetricsRegistryMap)
+    val containerContext = Mockito.mock(classOf[ContainerContext])
     val taskInstance: TaskInstance = new TaskInstance(
       streamTask,
-      taskName,
-      config,
+      taskModel,
       new TaskInstanceMetrics,
-      null,
+      adminMultiplexer,
       consumerMultiplexer,
       collector,
-      containerContext
-    )
+      jobContext = Mockito.mock(classOf[JobContext]),
+      containerContext = containerContext,
+      applicationContainerContextOption = None,
+      applicationTaskContextFactoryOption = None)
 
     val container = new SamzaContainer(
-      containerContext = containerContext,
+      config = config,
       taskInstances = Map(taskName -> taskInstance),
       runLoop = mockRunloop,
       systemAdmins = adminMultiplexer,
       consumerMultiplexer = consumerMultiplexer,
       producerMultiplexer = producerMultiplexer,
-      metrics = new SamzaContainerMetrics)
+      metrics = new SamzaContainerMetrics,
+      containerContext = containerContext,
+      applicationContainerContextOption = None)
     container
   }
 }
\ No newline at end of file
index 53147ad..e30328a 100644 (file)
@@ -21,18 +21,19 @@ package org.apache.samza.storage.kv.inmemory
 
 import java.io.File
 
-import org.apache.samza.container.SamzaContainerContext
+import org.apache.samza.context.{ContainerContext, JobContext}
 import org.apache.samza.metrics.MetricsRegistry
-import org.apache.samza.storage.kv.{KeyValueStoreMetrics, BaseKeyValueStorageEngineFactory, KeyValueStore}
+import org.apache.samza.storage.kv.{BaseKeyValueStorageEngineFactory, KeyValueStore, KeyValueStoreMetrics}
 import org.apache.samza.system.SystemStreamPartition
 
 class InMemoryKeyValueStorageEngineFactory[K, V] extends BaseKeyValueStorageEngineFactory[K, V] {
 
   override def getKVStore(storeName: String,
-                          storeDir: File,
-                          registry: MetricsRegistry,
-                          changeLogSystemStreamPartition: SystemStreamPartition,
-                          containerContext: SamzaContainerContext): KeyValueStore[Array[Byte], Array[Byte]] = {
+    storeDir: File,
+    registry: MetricsRegistry,
+    changeLogSystemStreamPartition: SystemStreamPartition,
+    jobContext: JobContext,
+    containerContext: ContainerContext): KeyValueStore[Array[Byte], Array[Byte]] = {
     val metrics = new KeyValueStoreMetrics(storeName, registry)
     val inMemoryDb = new InMemoryKeyValueStore (metrics)
     inMemoryDb
index 9dca23c..0734fe6 100644 (file)
 
 package org.apache.samza.storage.kv;
 
-import java.util.ArrayList;
-
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JavaSerializerConfig;
 import org.apache.samza.config.JavaStorageConfig;
 import org.apache.samza.config.SerializerConfig$;
-import org.apache.samza.container.SamzaContainerContext;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.SerdeFactory;
 import org.apache.samza.util.Util;
@@ -65,11 +60,7 @@ public class RocksDbKeyValueReader {
     valueSerde = getSerdeFromName(storageConfig.getStorageMsgSerde(storeName), serializerConfig);
 
     // get db options
-    ArrayList<TaskName> taskNameList = new ArrayList<TaskName>();
-    taskNameList.add(new TaskName("read-rocks-db"));
-    SamzaContainerContext samzaContainerContext =
-        new SamzaContainerContext("0",  config, taskNameList, new MetricsRegistryMap());
-    Options options = RocksDbOptionsHelper.options(config, samzaContainerContext);
+    Options options = RocksDbOptionsHelper.options(config, 1);
 
     // open the db
     RocksDB.loadLibrary();
index 9389681..7beb066 100644 (file)
@@ -20,7 +20,8 @@
 package org.apache.samza.storage.kv;
 
 import org.apache.samza.config.Config;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
 import org.rocksdb.BlockBasedTableConfig;
 import org.rocksdb.CompactionStyle;
 import org.rocksdb.CompressionType;
@@ -41,12 +42,11 @@ public class RocksDbOptionsHelper {
   private static final String ROCKSDB_MAX_LOG_FILE_SIZE_BYTES = "rocksdb.max.log.file.size.bytes";
   private static final String ROCKSDB_KEEP_LOG_FILE_NUM = "rocksdb.keep.log.file.num";
 
-  public static Options options(Config storeConfig, SamzaContainerContext containerContext) {
+  public static Options options(Config storeConfig, int numTasksForContainer) {
     Options options = new Options();
     Long writeBufSize = storeConfig.getLong("container.write.buffer.size.bytes", 32 * 1024 * 1024);
     // Cache size and write buffer size are specified on a per-container basis.
-    int numTasks = containerContext.taskNames.size();
-    options.setWriteBufferSize((int) (writeBufSize / numTasks));
+    options.setWriteBufferSize((int) (writeBufSize / numTasksForContainer));
 
     CompressionType compressionType = CompressionType.SNAPPY_COMPRESSION;
     String compressionInConfig = storeConfig.get(ROCKSDB_COMPRESSION, "snappy");
@@ -75,7 +75,7 @@ public class RocksDbOptionsHelper {
     }
     options.setCompressionType(compressionType);
 
-    long blockCacheSize = getBlockCacheSize(storeConfig, containerContext);
+    long blockCacheSize = getBlockCacheSize(storeConfig, numTasksForContainer);
     int blockSize = storeConfig.getInt(ROCKSDB_BLOCK_SIZE_BYTES, 4096);
     BlockBasedTableConfig tableOptions = new BlockBasedTableConfig();
     tableOptions.setBlockCacheSize(blockCacheSize).setBlockSize(blockSize);
@@ -109,9 +109,8 @@ public class RocksDbOptionsHelper {
     return options;
   }
 
-  public static Long getBlockCacheSize(Config storeConfig, SamzaContainerContext containerContext) {
-    int numTasks = containerContext.taskNames.size();
+  public static Long getBlockCacheSize(Config storeConfig, int numTasksForContainer) {
     long cacheSize = storeConfig.getLong("container.cache.size.bytes", 100 * 1024 * 1024L);
-    return cacheSize / numTasks;
+    return cacheSize / numTasksForContainer;
   }
 }
\ No newline at end of file
index 2b7ffb5..704af4a 100644 (file)
 package org.apache.samza.storage.kv
 
 import java.io.File
-import org.apache.samza.container.SamzaContainerContext
+
+import org.apache.samza.config.StorageConfig._
+import org.apache.samza.context.{ContainerContext, JobContext}
 import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.system.SystemStreamPartition
 import org.rocksdb.{FlushOptions, WriteOptions}
-import org.apache.samza.config.StorageConfig._
 
 class RocksDbKeyValueStorageEngineFactory [K, V] extends BaseKeyValueStorageEngineFactory[K, V] {
   /**
@@ -37,17 +38,19 @@ class RocksDbKeyValueStorageEngineFactory [K, V] extends BaseKeyValueStorageEngi
    * @return A valid KeyValueStore instance
    */
   override def getKVStore(storeName: String,
-                          storeDir: File,
-                          registry: MetricsRegistry,
-                          changeLogSystemStreamPartition: SystemStreamPartition,
-                          containerContext: SamzaContainerContext): KeyValueStore[Array[Byte], Array[Byte]] = {
-    val storageConfig = containerContext.config.subset("stores." + storeName + ".", true)
-    val isLoggedStore = containerContext.config.getChangelogStream(storeName).isDefined
+    storeDir: File,
+    registry: MetricsRegistry,
+    changeLogSystemStreamPartition: SystemStreamPartition,
+    jobContext: JobContext,
+    containerContext: ContainerContext): KeyValueStore[Array[Byte], Array[Byte]] = {
+    val storageConfig = jobContext.getConfig.subset("stores." + storeName + ".", true)
+    val isLoggedStore = jobContext.getConfig.getChangelogStream(storeName).isDefined
     val rocksDbMetrics = new KeyValueStoreMetrics(storeName, registry)
+    val numTasksForContainer = containerContext.getContainerModel.getTasks.keySet().size()
     rocksDbMetrics.newGauge("rocksdb.block-cache-size",
-      () => RocksDbOptionsHelper.getBlockCacheSize(storageConfig, containerContext))
+      () => RocksDbOptionsHelper.getBlockCacheSize(storageConfig, numTasksForContainer))
 
-    val rocksDbOptions = RocksDbOptionsHelper.options(storageConfig, containerContext)
+    val rocksDbOptions = RocksDbOptionsHelper.options(storageConfig, numTasksForContainer)
     val rocksDbWriteOptions = new WriteOptions().setDisableWAL(true)
     val rocksDbFlushOptions = new FlushOptions().setWaitForFlush(true)
     val rocksDb = new RocksDbKeyValueStore(
index 35a66e8..cd7e85c 100644 (file)
  */
 package org.apache.samza.storage.kv;
 
+import junit.framework.Assert;
 import org.apache.samza.serializers.IntegerSerde;
 import org.apache.samza.serializers.KVSerde;
 import org.apache.samza.serializers.StringSerde;
 import org.apache.samza.table.TableSpec;
 import org.junit.Test;
 
-import junit.framework.Assert;
-
 
 public class TestRocksDbTableDescriptor {
 
index 8231905..e56c977 100644 (file)
  */
 package org.apache.samza.storage.kv;
 
+import com.google.common.base.Preconditions;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.regex.Pattern;
-
 import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
@@ -31,15 +31,12 @@ import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.StorageConfig;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.TableSpec;
 import org.apache.samza.table.utils.BaseTableProvider;
 import org.apache.samza.table.utils.SerdeUtils;
-import org.apache.samza.task.TaskContext;
-
-import com.google.common.base.Preconditions;
 
 
 /**
@@ -59,13 +56,12 @@ abstract public class BaseLocalStoreBackedTableProvider extends BaseTableProvide
   }
 
   @Override
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-
-    super.init(containerContext, taskContext);
+  public void init(Context context) {
+    super.init(context);
 
-    Preconditions.checkNotNull(this.taskContext, "Must specify task context for local tables.");
+    Preconditions.checkNotNull(this.context, "Must specify context for local tables.");
 
-    kvStore = (KeyValueStore) taskContext.getStore(tableSpec.getId());
+    kvStore = (KeyValueStore) this.context.getTaskContext().getStore(tableSpec.getId());
 
     if (kvStore == null) {
       throw new SamzaException(String.format(
@@ -81,7 +77,7 @@ abstract public class BaseLocalStoreBackedTableProvider extends BaseTableProvide
       throw new SamzaException("Store not initialized for table " + tableSpec.getId());
     }
     ReadableTable table = new LocalStoreBackedReadWriteTable(tableSpec.getId(), kvStore);
-    table.init(containerContext, taskContext);
+    table.init(this.context);
     return table;
   }
 
index 9eeb55e..804df43 100644 (file)
@@ -20,11 +20,9 @@ package org.apache.samza.storage.kv;
 
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
-
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.table.utils.DefaultTableWriteMetrics;
-import org.apache.samza.task.TaskContext;
 
 
 /**
@@ -51,9 +49,9 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab
    * {@inheritDoc}
    */
   @Override
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-    super.init(containerContext, taskContext);
-    writeMetrics = new DefaultTableWriteMetrics(containerContext, taskContext, this, tableId);
+  public void init(Context context) {
+    super.init(context);
+    writeMetrics = new DefaultTableWriteMetrics(context, this, tableId);
   }
 
   @Override
index d0629c4..d440d42 100644 (file)
  */
 package org.apache.samza.storage.kv;
 
+import com.google.common.base.Preconditions;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-
-import com.google.common.base.Preconditions;
-import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.context.Context;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.utils.DefaultTableReadMetrics;
-import org.apache.samza.task.TaskContext;
 
 
 /**
@@ -58,8 +56,8 @@ public class LocalStoreBackedReadableTable<K, V> implements ReadableTable<K, V>
    * {@inheritDoc}
    */
   @Override
-  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
-    readMetrics = new DefaultTableReadMetrics(containerContext, taskContext, this, tableId);
+  public void init(Context context) {
+    readMetrics = new DefaultTableReadMetrics(context, this, tableId);
   }
 
   @Override
index da80560..d962e93 100644 (file)
@@ -22,14 +22,14 @@ package org.apache.samza.storage.kv
 import java.io.File
 
 import org.apache.samza.SamzaException
-import org.apache.samza.container.SamzaContainerContext
+import org.apache.samza.config.MetricsConfig.Config2Metrics
+import org.apache.samza.context.{ContainerContext, JobContext}
 import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.serializers.Serde
 import org.apache.samza.storage.{StorageEngine, StorageEngineFactory, StoreProperties}
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.task.MessageCollector
-import org.apache.samza.config.MetricsConfig.Config2Metrics
-import org.apache.samza.util.{HighResolutionClock, ScalaJavaUtil}
+import org.apache.samza.util.HighResolutionClock
 
 /**
  * A key value storage engine factory implementation
@@ -52,11 +52,12 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
    * @param containerContext Information about the container in which the task is executing.
    * @return A valid KeyValueStore instance
    */
-  def getKVStore( storeName: String,
-                  storeDir: File,
-                  registry: MetricsRegistry,
-                  changeLogSystemStreamPartition: SystemStreamPartition,
-                  containerContext: SamzaContainerContext): KeyValueStore[Array[Byte], Array[Byte]]
+  def getKVStore(storeName: String,
+    storeDir: File,
+    registry: MetricsRegistry,
+    changeLogSystemStreamPartition: SystemStreamPartition,
+    jobContext: JobContext,
+    containerContext: ContainerContext): KeyValueStore[Array[Byte], Array[Byte]]
 
   /**
    * Constructs a key-value StorageEngine and returns it to the caller
@@ -70,15 +71,16 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
    * @param changeLogSystemStreamPartition Samza stream partition from which to receive the changelog.
    * @param containerContext Information about the container in which the task is executing.
    **/
-  def getStorageEngine( storeName: String,
-                        storeDir: File,
-                        keySerde: Serde[K],
-                        msgSerde: Serde[V],
-                        collector: MessageCollector,
-                        registry: MetricsRegistry,
-                        changeLogSystemStreamPartition: SystemStreamPartition,
-                        containerContext: SamzaContainerContext): StorageEngine = {
-    val storageConfig = containerContext.config.subset("stores." + storeName + ".", true)
+  def getStorageEngine(storeName: String,
+    storeDir: File,
+    keySerde: Serde[K],
+    msgSerde: Serde[V],
+    collector: MessageCollector,
+    registry: MetricsRegistry,
+    changeLogSystemStreamPartition: SystemStreamPartition,
+    jobContext: JobContext,
+    containerContext: ContainerContext): StorageEngine = {
+    val storageConfig = jobContext.getConfig.subset("stores." + storeName + ".", true)
     val storeFactory = storageConfig.get("factory")
     var storePropertiesBuilder = new StoreProperties.StorePropertiesBuilder()
     val accessLog = storageConfig.getBoolean("accesslog.enabled", false)
@@ -106,7 +108,8 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
       throw new SamzaException("Must define a message serde when using key value storage.")
     }
 
-    val rawStore = getKVStore(storeName, storeDir, registry, changeLogSystemStreamPartition, containerContext)
+    val rawStore =
+      getKVStore(storeName, storeDir, registry, changeLogSystemStreamPartition, jobContext, containerContext)
 
     // maybe wrap with logging
     val maybeLoggedStore = if (changeLogSystemStreamPartition == null) {
@@ -141,7 +144,7 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
     // create the storage engine and return
     // TODO: Decide if we should use raw bytes when restoring
     val keyValueStorageEngineMetrics = new KeyValueStorageEngineMetrics(storeName, registry)
-    val clock = if (containerContext.config.getMetricsTimerEnabled) {
+    val clock = if (jobContext.getConfig.getMetricsTimerEnabled) {
       new HighResolutionClock {
         override def nanoTime(): Long = System.nanoTime()
       }
index 2b0166c..399f9fd 100644 (file)
@@ -28,33 +28,33 @@ import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.StorageConfig;
-import org.apache.samza.container.SamzaContainerContext;
-import org.apache.samza.storage.StorageEngine;
+import org.apache.samza.context.Context;
+import org.apache.samza.context.TaskContext;
 import org.apache.samza.table.TableProvider;
 import org.apache.samza.table.TableSpec;
-import org.apache.samza.task.TaskContext;
 import org.apache.samza.util.NoOpMetricsRegistry;
 import org.junit.Test;
 
 import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 
 public class TestBaseLocalStoreBackedTableProvider {
 
   @Test
   public void testInit() {
-    StorageEngine store = mock(KeyValueStorageEngine.class);
-    SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
+    Context context = mock(Context.class);
     TaskContext taskContext = mock(TaskContext.class);
-    when(taskContext.getStore(any())).thenReturn(store);
-    when(taskContext.getMetricsRegistry()).thenReturn(new NoOpMetricsRegistry());
+    when(context.getTaskContext()).thenReturn(taskContext);
+    when(taskContext.getStore(any())).thenReturn(mock(KeyValueStore.class));
+    when(taskContext.getTaskMetricsRegistry()).thenReturn(new NoOpMetricsRegistry());
 
     TableSpec tableSpec = mock(TableSpec.class);
     when(tableSpec.getId()).thenReturn("t1");
 
     TableProvider tableProvider = createTableProvider(tableSpec);
-    tableProvider.init(containerContext, taskContext);
+    tableProvider.init(context);
     Assert.assertNotNull(tableProvider.getTable());
   }
 
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/runner/SamzaSqlApplicationContext.java b/samza-sql/src/main/java/org/apache/samza/sql/runner/SamzaSqlApplicationContext.java
new file mode 100644 (file)
index 0000000..6841e15
--- /dev/null
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.sql.runner;
+
+import org.apache.samza.context.ApplicationTaskContext;
+import org.apache.samza.sql.translator.TranslatorContext;
+
+
+public class SamzaSqlApplicationContext implements ApplicationTaskContext {
+  private final TranslatorContext translatorContext;
+
+  public SamzaSqlApplicationContext(TranslatorContext translatorContext) {
+    this.translatorContext = translatorContext;
+  }
+
+  public TranslatorContext getTranslatorContext() {
+    return translatorContext;
+  }
+
+  @Override
+  public void start() {
+  }
+
+  @Override
+  public void stop() {
+  }
+}
index f33c5ca..77a24f8 100644 (file)
@@ -21,14 +21,13 @@ package org.apache.samza.sql.translator;
 
 import java.util.Arrays;
 import java.util.Collections;
-
 import org.apache.calcite.rel.logical.LogicalFilter;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.MessageStream;
 import org.apache.samza.operators.functions.FilterFunction;
 import org.apache.samza.sql.data.Expression;
 import org.apache.samza.sql.data.SamzaSqlRelMessage;
-import org.apache.samza.task.TaskContext;
+import org.apache.samza.sql.runner.SamzaSqlApplicationContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -53,8 +52,8 @@ class FilterTranslator {
     }
 
     @Override
-    public void init(Config config, TaskContext context) {
-      this.context = (TranslatorContext) context.getUserContext();
+    public void init(Context context) {
+      this.context = ((SamzaSqlApplicationContext) context.getApplicationTaskContext()).getTranslatorContext();
       this.filter = (LogicalFilter) this.context.getRelNode(filterId);
       this.expr = this.context.getExpressionCompiler().compile(filter.getInputs(), Collections.singletonList(filter.getCondition()));
     }
index 965338f..435a2cc 100644 (file)
@@ -26,7 +26,7 @@ import org.apache.calcite.rel.core.TableModify;
 import org.apache.commons.lang.Validate;
 import org.apache.samza.SamzaException;
 import org.apache.samza.application.StreamApplicationDescriptor;
-import org.apache.samza.config.Config;
+import org.apache.samza.context.Context;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
 import org.apache.samza.operators.MessageStreamImpl;
@@ -39,8 +39,8 @@ import org.apache.samza.serializers.NoOpSerde;
 import org.apache.samza.sql.data.SamzaSqlRelMessage;
 import org.apache.samza.sql.interfaces.SamzaRelConverter;
 import org.apache.samza.sql.interfaces.SqlIOConfig;
+import org.apache.samza.sql.runner.SamzaSqlApplicationContext;
 import org.apache.samza.table.Table;
-import org.apache.samza.task.TaskContext;
 
 
 /**
@@ -70,9 +70,10 @@ class ModifyTranslator {
     }
 
     @Override
-    public void init(Config config, TaskContext taskContext) {
-      TranslatorContext context = (TranslatorContext) taskContext.getUserContext();
-      this.samzaMsgConverter = context.getMsgConverter(outputTopic);
+    public void init(Context context) {
+      TranslatorContext transl