SAMZA-2018: State restore improvements using RocksDB bulk load
authorRay Matharu <rmatharu@linkedin.com>
Sat, 12 Jan 2019 00:23:01 +0000 (16:23 -0800)
committerPrateek Maheshwari <pmaheshwari@apache.org>
Sat, 12 Jan 2019 00:23:01 +0000 (16:23 -0800)
This PR makes the following changes:
* Moves all the state-restore code from TaskStorageManager.scala to ContainerStorageManager (and its internal private java classes).
* Introduces a StoreMode in StorageEngineFactory.getStorageEngine to add a StoreMode enum.
* Changes RocksDB store creation to use that enum and use Rocksdb's bulk load option when creating store in bulk-load mode.
* Changes the ContainerStorageManager to create stores in BulkLoad mode when restoring, then closes such persistent and changelogged stores, and re-opens them in Read-Write mode.
* Adds tests for ContainerStorageManager and changes tests for TaskStorageManager accordingly.

Author: Ray Matharu <rmatharu@linkedin.com>
Author: rmatharu <40646191+rmatharu@users.noreply.github.com>

Reviewers: Jagadish Venkatraman <vjagadish1989@gmail.com>

Closes #843 from rmatharu/refactoringCSM

20 files changed:
samza-api/src/main/java/org/apache/samza/storage/StorageEngineFactory.java
samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java
samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
samza-core/src/main/scala/org/apache/samza/util/ScalaJavaUtil.scala
samza-core/src/test/java/org/apache/samza/storage/MockStorageEngineFactory.java
samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java
samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala
samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStorageEngineFactory.scala
samza-kv-rocksdb/src/main/java/org/apache/samza/storage/kv/RocksDbKeyValueReader.java
samza-kv-rocksdb/src/main/java/org/apache/samza/storage/kv/RocksDbOptionsHelper.java
samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStorageEngineFactory.scala
samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
samza-rest/src/main/java/org/apache/samza/monitor/LocalStoreMonitor.java
samza-test/src/main/scala/org/apache/samza/test/performance/TestKeyValuePerformance.scala
samza-test/src/test/java/org/apache/samza/test/table/TestLocalTableWithSideInputs.java

index 2425cf3..26d6e75 100644 (file)
@@ -34,6 +34,23 @@ import org.apache.samza.task.MessageCollector;
 public interface StorageEngineFactory<K, V> {
 
   /**
+   * Enum to describe different modes a {@link StorageEngine} can be created in.
+   * The BulkLoad mode is used when bulk loading of data onto the store, e.g., store-restoration at Samza
+   * startup. In this mode, the underlying store will tailor itself for write-intensive ops -- tune its params,
+   * adapt its compaction behaviour, etc.
+   *
+   * The ReadWrite mode is used for normal read-write ops by the application.
+   */
+  enum StoreMode {
+    BulkLoad("bulk"), ReadWrite("rw");
+    public final String mode;
+
+    StoreMode(String mode) {
+      this.mode = mode;
+    }
+  }
+
+  /**
    * Create an instance of the given storage engine.
    *
    * @param storeName The name of the storage engine.
@@ -45,6 +62,7 @@ public interface StorageEngineFactory<K, V> {
    * @param changeLogSystemStreamPartition Samza stream partition from which to receive the changelog.
    * @param jobContext Information about the job in which the task is executing
    * @param containerContext Information about the container in which the task is executing.
+   * @param storeMode The mode in which the instance should be created in.
    * @return The storage engine instance.
    */
   StorageEngine getStorageEngine(
@@ -56,5 +74,5 @@ public interface StorageEngineFactory<K, V> {
     MetricsRegistry registry,
     SystemStreamPartition changeLogSystemStreamPartition,
     JobContext jobContext,
-    ContainerContext containerContext);
+    ContainerContext containerContext, StoreMode storeMode);
 }
index e7301ea..2086aa4 100644 (file)
@@ -21,6 +21,7 @@ package org.apache.samza.storage;
 
 import com.google.common.collect.ImmutableMap;
 import java.io.File;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.util.FileUtil;
@@ -139,4 +140,16 @@ public class StorageManagerUtil {
 
     return offset;
   }
+
+  /**
+   * Creates and returns a File pointing to the directory for the given store and task, given a particular base directory.
+   *
+   * @param storeBaseDir the base directory to use
+   * @param storeName the store name to use
+   * @param taskName the task name which is referencing the store
+   * @return the partition directory for the store
+   */
+  public static File getStorePartitionDir(File storeBaseDir, String storeName, TaskName taskName) {
+    return new File(storeBaseDir, (storeName + File.separator + taskName.toString()).replace(' ', '_'));
+  }
 }
index 5442d6e..cf8338a 100644 (file)
 package org.apache.samza.storage;
 
 import java.io.File;
-import java.time.Duration;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Map.Entry;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JavaStorageConfig;
 import org.apache.samza.config.JavaSystemConfig;
-import org.apache.samza.config.StorageConfig;
-import org.apache.samza.container.TaskName;
+import org.apache.samza.config.SerializerConfig;
+import org.apache.samza.container.SamzaContainerMetrics;
 import org.apache.samza.context.ContainerContext;
 import org.apache.samza.context.ContainerContextImpl;
 import org.apache.samza.context.JobContextImpl;
@@ -41,15 +38,12 @@ import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
-import org.apache.samza.serializers.ByteSerde;
 import org.apache.samza.serializers.Serde;
-import org.apache.samza.system.SSPMetadataCache;
+import org.apache.samza.serializers.SerdeFactory;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemAdmins;
-import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemFactory;
 import org.apache.samza.system.SystemStream;
-import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.CommandLine;
 import org.apache.samza.util.ScalaJavaUtil;
@@ -58,6 +52,8 @@ import org.apache.samza.util.SystemClock;
 import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.Option;
+
 
 /**
  * Recovers the state storages from the changelog streams and store the storages
@@ -70,13 +66,13 @@ public class StorageRecovery extends CommandLine {
   private int maxPartitionNumber = 0;
   private File storeBaseDir = null;
   private HashMap<String, SystemStream> changeLogSystemStreams = new HashMap<>();
-  private HashMap<String, StorageEngineFactory<?, ?>> storageEngineFactories = new HashMap<>();
+  private HashMap<String, StorageEngineFactory<Object, Object>> storageEngineFactories = new HashMap<>();
   private Map<String, ContainerModel> containers = new HashMap<>();
-  private ContainerStorageManager containerStorageManager;
+  private Map<String, ContainerStorageManager> containerStorageManagers = new HashMap<>();
+
   private Logger log = LoggerFactory.getLogger(StorageRecovery.class);
   private SystemAdmins systemAdmins = null;
 
-
   /**
    * Construct the StorageRecovery
    *
@@ -101,7 +97,7 @@ public class StorageRecovery extends CommandLine {
     getContainerModels();
     getChangeLogSystemStreamsAndStorageFactories();
     getChangeLogMaxPartitionNumber();
-    getContainerStorageManager();
+    getContainerStorageManagers();
   }
 
   /**
@@ -113,16 +109,19 @@ public class StorageRecovery extends CommandLine {
     log.info("start recovering...");
 
     systemAdmins.start();
-    this.containerStorageManager.start();
-    this.containerStorageManager.shutdown();
+    this.containerStorageManagers.forEach((containerName, containerStorageManager) -> {
+        containerStorageManager.start();
+      });
+    this.containerStorageManagers.forEach((containerName, containerStorageManager) -> {
+        containerStorageManager.shutdown();
+      });
     systemAdmins.stop();
 
     log.info("successfully recovered in " + storeBaseDir.toString());
   }
 
   /**
-   * build the ContainerModels from job config file and put the results in the
-   * map
+   * Build ContainerModels from job config file and put the results in the containerModels map.
    */
   private void getContainerModels() {
     MetricsRegistryMap metricsRegistryMap = new MetricsRegistryMap();
@@ -137,7 +136,7 @@ public class StorageRecovery extends CommandLine {
   }
 
   /**
-   * get the changelog streams and the storage factories from the config file
+   * Get the changelog streams and the storage factories from the config file
    * and put them into the maps
    */
   private void getChangeLogSystemStreamsAndStorageFactories() {
@@ -165,24 +164,6 @@ public class StorageRecovery extends CommandLine {
   }
 
   /**
-   * get the SystemConsumers for the stores
-   */
-  private HashMap<String, SystemConsumer> getStoreConsumers() {
-    HashMap<String, SystemConsumer> storeConsumers = new HashMap<>();
-    Map<String, SystemFactory> systemFactories = new JavaSystemConfig(jobConfig).getSystemFactories();
-
-    for (Entry<String, SystemStream> entry : changeLogSystemStreams.entrySet()) {
-      String storeSystem = entry.getValue().getSystem();
-      if (!systemFactories.containsKey(storeSystem)) {
-        throw new SamzaException("Changelog system " + storeSystem + " for store " + entry.getKey() + " does not exist in the config.");
-      }
-      storeConsumers.put(entry.getKey(), systemFactories.get(storeSystem).getConsumer(storeSystem, jobConfig, new MetricsRegistryMap()));
-    }
-
-    return storeConsumers;
-  }
-
-  /**
    * get the max partition number of the changelog stream
    */
   private void getChangeLogMaxPartitionNumber() {
@@ -195,68 +176,49 @@ public class StorageRecovery extends CommandLine {
     maxPartitionNumber = maxPartitionId + 1;
   }
 
+  private Map<String, Serde<Object>> getSerdes() {
+    Map<String, Serde<Object>> serdeMap = new HashMap<>();
+    SerializerConfig serializerConfig = new SerializerConfig(jobConfig);
+
+    // Adding all serdes from factories
+    ScalaJavaUtil.toJavaCollection(serializerConfig.getSerdeNames())
+        .stream()
+        .forEach(serdeName -> {
+            Option<String> serdeClassName = serializerConfig.getSerdeClass(serdeName);
+
+            if (serdeClassName.isEmpty()) {
+              serdeClassName = Option.apply(SerializerConfig.getSerdeFactoryName(serdeName));
+            }
+
+            Serde serde = Util.getObj(serdeClassName.get(), SerdeFactory.class).getSerde(serdeName, serializerConfig);
+            serdeMap.put(serdeName, serde);
+          });
+
+    return serdeMap;
+  }
+
   /**
    * create one TaskStorageManager for each task. Add all of them to the
    * List<TaskStorageManager>
    */
-  @SuppressWarnings({ "unchecked", "rawtypes" })
-  private void getContainerStorageManager() {
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private void getContainerStorageManagers() {
     Clock clock = SystemClock.instance();
-    Map<TaskName, TaskStorageManager> taskStorageManagers = new HashMap<>();
-    HashMap<String, SystemConsumer> storeConsumers = getStoreConsumers();
     StreamMetadataCache streamMetadataCache = new StreamMetadataCache(systemAdmins, 5000, clock);
     // don't worry about prefetching for this; looks like the tool doesn't flush to offset files anyways
-    SSPMetadataCache sspMetadataCache =
-        new SSPMetadataCache(systemAdmins, Duration.ofSeconds(5), clock, Collections.emptySet());
+
+    Map<String, SystemFactory> systemFactories = new JavaSystemConfig(jobConfig).getSystemFactories();
 
     for (ContainerModel containerModel : containers.values()) {
-      HashMap<String, StorageEngine> taskStores = new HashMap<String, StorageEngine>();
       ContainerContext containerContext = new ContainerContextImpl(containerModel, new MetricsRegistryMap());
 
-      for (TaskModel taskModel : containerModel.getTasks().values()) {
-
-        for (Entry<String, StorageEngineFactory<?, ?>> entry : storageEngineFactories.entrySet()) {
-          String storeName = entry.getKey();
-
-          if (changeLogSystemStreams.containsKey(storeName)) {
-            SystemStreamPartition changeLogSystemStreamPartition = new SystemStreamPartition(changeLogSystemStreams.get(storeName),
-                taskModel.getChangelogPartition());
-            File storePartitionDir = TaskStorageManager.getStorePartitionDir(storeBaseDir, storeName, taskModel.getTaskName());
-
-            log.info("Got storage engine directory: " + storePartitionDir);
-
-            StorageEngine storageEngine = (entry.getValue()).getStorageEngine(
-                storeName,
-                storePartitionDir,
-                (Serde) new ByteSerde(),
-                (Serde) new ByteSerde(),
-                null,
-                new MetricsRegistryMap(),
-                changeLogSystemStreamPartition,
-                JobContextImpl.fromConfigWithDefaults(jobConfig),
-                containerContext);
-            taskStores.put(storeName, storageEngine);
-          }
-        }
-        TaskStorageManager taskStorageManager = new TaskStorageManager(
-            taskModel.getTaskName(),
-            ScalaJavaUtil.toScalaMap(taskStores),
-            ScalaJavaUtil.toScalaMap(storeConsumers),
-            ScalaJavaUtil.toScalaMap(changeLogSystemStreams),
-            maxPartitionNumber,
-            streamMetadataCache,
-            sspMetadataCache,
-            storeBaseDir,
-            storeBaseDir,
-            taskModel.getChangelogPartition(),
-            systemAdmins,
-            new StorageConfig(jobConfig).getChangeLogDeleteRetentionsInMs(),
-            new SystemClock());
-
-        taskStorageManagers.put(taskModel.getTaskName(), taskStorageManager);
-      }
+      ContainerStorageManager containerStorageManager =
+          new ContainerStorageManager(containerModel, streamMetadataCache, systemAdmins, changeLogSystemStreams,
+              storageEngineFactories, systemFactories, this.getSerdes(), jobConfig, new HashMap<>(),
+              new SamzaContainerMetrics(containerModel.getId(), new MetricsRegistryMap()),
+              JobContextImpl.fromConfigWithDefaults(jobConfig), containerContext, new HashMap<>(),
+              storeBaseDir, storeBaseDir, maxPartitionNumber, new SystemClock());
+      this.containerStorageManagers.put(containerModel.getId(), containerStorageManager);
     }
-
-    this.containerStorageManager = new ContainerStorageManager(taskStorageManagers, storeConsumers, null);
   }
 }
index e38a451..ec7360a 100644 (file)
@@ -26,7 +26,7 @@ import java.net.{URL, UnknownHostException}
 import java.nio.file.Path
 import java.time.Duration
 import java.util
-import java.util.Base64
+import java.util.{Base64}
 import java.util.concurrent.{ExecutorService, Executors, ScheduledExecutorService, TimeUnit}
 
 import com.google.common.annotations.VisibleForTesting
@@ -50,6 +50,7 @@ import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, Metr
 import org.apache.samza.serializers._
 import org.apache.samza.serializers.model.SamzaObjectMapper
 import org.apache.samza.startpoint.StartpointManager
+import org.apache.samza.storage.StorageEngineFactory.StoreMode
 import org.apache.samza.storage._
 import org.apache.samza.system._
 import org.apache.samza.system.chooser.{DefaultChooser, MessageChooserFactory, RoundRobinChooserFactory}
@@ -81,7 +82,9 @@ object SamzaContainer extends Logging {
         classOf[JobModel])
   }
 
-  // TODO: SAMZA-1701 SamzaContainer should not contain any logic related to store directories
+  /**
+    * If a base-directory was NOT explicitly provided in config, a default base directory is returned.
+    */
   def getNonLoggedStorageBaseDir(config: Config, defaultStoreBaseDir: File) = {
     config.getNonLoggedStorePath match {
       case Some(nonLoggedStorePath) =>
@@ -91,7 +94,10 @@ object SamzaContainer extends Logging {
     }
   }
 
-  // TODO: SAMZA-1701 SamzaContainer should not contain any logic related to store directories
+  /**
+    * If a base-directory was NOT explicitly provided in config or via an environment variable
+    * (see ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR), then a default base directory is returned.
+    */
   def getLoggedStorageBaseDir(config: Config, defaultStoreBaseDir: File) = {
     val defaultLoggedStorageBaseDir = config.getLoggedStorePath match {
       case Some(durableStorePath) =>
@@ -504,6 +510,7 @@ object SamzaContainer extends Logging {
       .map(_.getTaskName)
       .toSet
 
+    val taskModels = containerModel.getTasks.values.asScala
     val containerContext = new ContainerContextImpl(containerModel, samzaContainerMetrics.registry)
     val applicationContainerContextOption = applicationContainerContextFactoryOption
       .map(_.create(externalContextOption.orNull, jobContext, containerContext))
@@ -512,23 +519,37 @@ object SamzaContainer extends Logging {
 
     val timerExecutor = Executors.newSingleThreadScheduledExecutor
 
-    // We create a map of store SystemName to its respective SystemConsumer
-    val storeSystemConsumers: Map[String, SystemConsumer] = changeLogSystemStreams.mapValues {
-      case (changeLogSystemStream) => (changeLogSystemStream.getSystem)
-    }.values.toSet.map {
-      systemName: String =>
-        (systemName, systemFactories
-          .getOrElse(systemName,
-            throw new SamzaException("Changelog system %s exist in the config." format (systemName)))
-          .getConsumer(systemName, config, samzaContainerMetrics.registry))
-    }.toMap
+    var taskStorageManagers : Map[TaskName, TaskStorageManager] = Map()
 
-    info("Created store system consumers: %s" format storeSystemConsumers)
+    val taskInstanceMetrics: Map[TaskName, TaskInstanceMetrics] = taskModels.map(taskModel => {
+      (taskModel.getTaskName, new TaskInstanceMetrics("TaskName-%s" format taskModel.getTaskName))
+    }).toMap
 
-    var taskStorageManagers : Map[TaskName, TaskStorageManager] = Map()
+    val taskCollectors : Map[TaskName, TaskInstanceCollector] = taskModels.map(taskModel => {
+      (taskModel.getTaskName, new TaskInstanceCollector(producerMultiplexer, taskInstanceMetrics.get(taskModel.getTaskName).get))
+    }).toMap
+
+    val defaultStoreBaseDir = new File(System.getProperty("user.dir"), "state")
+    info("Got default storage engine base directory: %s" format defaultStoreBaseDir)
+
+    val nonLoggedStorageBaseDir = getNonLoggedStorageBaseDir(config, defaultStoreBaseDir)
+    info("Got base directory for non logged data stores: %s" format nonLoggedStorageBaseDir)
+
+    val loggedStorageBaseDir = getLoggedStorageBaseDir(config, defaultStoreBaseDir)
+    info("Got base directory for logged data stores: %s" format loggedStorageBaseDir)
+
+    val sideInputStorageEngineFactories = storageEngineFactories.filterKeys(storeName => sideInputStoresToSystemStreams.contains(storeName))
+    val nonSideInputStorageEngineFactories = (storageEngineFactories.toSet diff sideInputStorageEngineFactories.toSet).toMap
+
+    val containerStorageManager = new ContainerStorageManager(containerModel, streamMetadataCache, systemAdmins,
+      changeLogSystemStreams.asJava, nonSideInputStorageEngineFactories.asJava, systemFactories.asJava, serdes.asJava, config,
+      taskInstanceMetrics.asJava, samzaContainerMetrics, jobContext, containerContext, taskCollectors.asJava,
+      loggedStorageBaseDir, nonLoggedStorageBaseDir, maxChangeLogStreamPartitions, new SystemClock)
+
+    storeWatchPaths.addAll(containerStorageManager.getStoreDirectoryPaths)
 
     // Create taskInstances
-    val taskInstances: Map[TaskName, TaskInstance] = containerModel.getTasks.values.asScala.map(taskModel => {
+    val taskInstances: Map[TaskName, TaskInstance] = taskModels.map(taskModel => {
       debug("Setting up task instance: %s" format taskModel)
 
       val taskName = taskModel.getTaskName
@@ -538,32 +559,7 @@ object SamzaContainer extends Logging {
         case tf: StreamTaskFactory => tf.asInstanceOf[StreamTaskFactory].createInstance()
       }
 
-      val taskInstanceMetrics = new TaskInstanceMetrics("TaskName-%s" format taskName)
-
-      val collector = new TaskInstanceCollector(producerMultiplexer, taskInstanceMetrics)
-
-      // Re-use the storeConsumers, stored in storeSystemConsumers
-      val storeConsumers : Map[String, SystemConsumer] = changeLogSystemStreams
-        .map {
-          case (storeName, changeLogSystemStream) =>
-            val systemConsumer = storeSystemConsumers.get(changeLogSystemStream.getSystem).get
-            samzaContainerMetrics.addStoreRestorationGauge(taskName, storeName)
-            (storeName, systemConsumer)
-        }
-
-      info("Got store consumers: %s" format storeConsumers)
-
-      val defaultStoreBaseDir = new File(System.getProperty("user.dir"), "state")
-      info("Got default storage engine base directory: %s" format defaultStoreBaseDir)
-
-      val nonLoggedStorageBaseDir = getNonLoggedStorageBaseDir(config, defaultStoreBaseDir)
-      info("Got base directory for non logged data stores: %s" format nonLoggedStorageBaseDir)
-
-      val loggedStorageBaseDir = getLoggedStorageBaseDir(config, defaultStoreBaseDir)
-      info("Got base directory for logged data stores: %s" format loggedStorageBaseDir)
-
-      val taskStores = storageEngineFactories
-        .map {
+      val sideInputStores = sideInputStorageEngineFactories.map {
           case (storeName, storageEngineFactory) =>
             val changeLogSystemStreamPartition = if (changeLogSystemStreams.contains(storeName)) {
               new SystemStreamPartition(changeLogSystemStreams(storeName), taskModel.getChangelogPartition)
@@ -583,37 +579,29 @@ object SamzaContainer extends Logging {
               case _ => null
             }
 
-            // We use the logged storage base directory for change logged and side input stores since side input stores
+            // We use the logged storage base directory for side input stores since side input stores
             // dont have changelog configured.
-            val storeDir = if (changeLogSystemStreamPartition != null || sideInputStoresToSystemStreams.contains(storeName)) {
-              TaskStorageManager.getStorePartitionDir(loggedStorageBaseDir, storeName, taskName)
-            } else {
-              TaskStorageManager.getStorePartitionDir(nonLoggedStorageBaseDir, storeName, taskName)
-            }
-
+            val storeDir = StorageManagerUtil.getStorePartitionDir(loggedStorageBaseDir, storeName, taskName)
             storeWatchPaths.add(storeDir.toPath)
 
-            val storageEngine = storageEngineFactory.getStorageEngine(
+            val sideInputStorageEngine = storageEngineFactory.getStorageEngine(
               storeName,
               storeDir,
               keySerde,
               msgSerde,
-              collector,
-              taskInstanceMetrics.registry,
+              taskCollectors.get(taskName).get,
+              taskInstanceMetrics.get(taskName).get.registry,
               changeLogSystemStreamPartition,
               jobContext,
-              containerContext)
-            (storeName, storageEngine)
+              containerContext, StoreMode.ReadWrite)
+            (storeName, sideInputStorageEngine)
         }
 
-      info("Got task stores: %s" format taskStores)
+      info("Got side input stores: %s" format sideInputStores)
 
       val taskSSPs = taskModel.getSystemStreamPartitions.asScala.toSet
       info("Got task SSPs: %s" format taskSSPs)
 
-      val (sideInputStores, nonSideInputStores) =
-        taskStores.partition { case (storeName, _) => sideInputStoresToSystemStreams.contains(storeName)}
-
       val sideInputStoresToSSPs = sideInputStoresToSystemStreams.mapValues(sideInputSystemStreams =>
         taskSSPs.filter(ssp => sideInputSystemStreams.contains(ssp.getSystemStream)).asJava)
 
@@ -627,24 +615,17 @@ object SamzaContainer extends Logging {
               (storeName, SerdeUtils.deserialize("Side Inputs Processor", serializedInstance)))
             .orElse(config.getSideInputsProcessorFactory(storeName).map(factoryClassName =>
               (storeName, Util.getObj(factoryClassName, classOf[SideInputsProcessorFactory])
-                .getSideInputsProcessor(config, taskInstanceMetrics.registry))))
+                .getSideInputsProcessor(config, taskInstanceMetrics.get(taskName).get.registry))))
             .get
         }).toMap
 
       val storageManager = new TaskStorageManager(
         taskName = taskName,
-        taskStores = nonSideInputStores,
-        storeConsumers = storeConsumers,
+        containerStorageManager = containerStorageManager,
         changeLogSystemStreams = changeLogSystemStreams,
-        maxChangeLogStreamPartitions,
-        streamMetadataCache = streamMetadataCache,
         sspMetadataCache = changelogSSPMetadataCache,
-        nonLoggedStoreBaseDir = nonLoggedStorageBaseDir,
         loggedStoreBaseDir = loggedStorageBaseDir,
-        partition = taskModel.getChangelogPartition,
-        systemAdmins = systemAdmins,
-        new StorageConfig(config).getChangeLogDeleteRetentionsInMs,
-        new SystemClock)
+        partition = taskModel.getChangelogPartition)
 
       var sideInputStorageManager: TaskSideInputStorageManager = null
       if (sideInputStores.nonEmpty) {
@@ -667,16 +648,16 @@ object SamzaContainer extends Logging {
       def createTaskInstance(task: Any): TaskInstance = new TaskInstance(
           task = task,
           taskModel = taskModel,
-          metrics = taskInstanceMetrics,
+          metrics = taskInstanceMetrics.get(taskName).get,
           systemAdmins = systemAdmins,
           consumerMultiplexer = consumerMultiplexer,
-          collector = collector,
+          collector = taskCollectors.get(taskName).get,
           offsetManager = offsetManager,
           storageManager = storageManager,
           tableManager = tableManager,
           reporters = reporters,
           systemStreamPartitions = taskSSPs,
-          exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config),
+          exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics.get(taskName).get, config),
           jobModel = jobModel,
           streamMetadataCache = streamMetadataCache,
           timerExecutor = timerExecutor,
@@ -694,10 +675,6 @@ object SamzaContainer extends Logging {
       (taskName, taskInstance)
     }).toMap
 
-
-    val containerStorageManager = new ContainerStorageManager(taskStorageManagers.asJava, storeSystemConsumers.asJava,
-      samzaContainerMetrics)
-
     val maxThrottlingDelayMs = config.getLong("container.disk.quota.delay.max.ms", TimeUnit.SECONDS.toMillis(1))
 
     val runLoop = RunLoopFactory.createRunLoop(
index c39d6e7..b896267 100644 (file)
  */
 package org.apache.samza.storage;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import java.io.File;
+import java.nio.file.Path;
 import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.StorageConfig;
 import org.apache.samza.container.SamzaContainerMetrics;
+import org.apache.samza.container.TaskInstanceMetrics;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.Gauge;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.system.SystemStreamPartitionIterator;
+import org.apache.samza.task.TaskInstanceCollector;
+import org.apache.samza.util.Clock;
+import org.apache.samza.util.FileUtil;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
 
 
 /**
@@ -43,53 +75,279 @@ import org.slf4j.LoggerFactory;
  *  a) performing all container-level actions for restore such as, initializing and shutting down
  *  taskStorage managers, starting, registering and stopping consumers, etc.
  *
- *  b) performing individual taskStorageManager restores in parallel.
+ *  b) performing individual task stores' restores in parallel.
  *
  */
 public class ContainerStorageManager {
-
   private static final Logger LOG = LoggerFactory.getLogger(ContainerStorageManager.class);
-  private final Map<TaskName, TaskStorageManager> taskStorageManagers;
+  private static final String RESTORE_THREAD_NAME = "Samza Restore Thread-%d";
+
+  /** Maps containing relevant per-task objects */
+  private final Map<TaskName, Map<String, StorageEngine>> taskStores;
+  private final Map<TaskName, TaskRestoreManager> taskRestoreManagers;
+  private final Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics;
+  private final Map<TaskName, TaskInstanceCollector> taskInstanceCollectors;
+
+  private final Map<String, SystemConsumer> systemConsumers; // Mapping from storeSystemNames to SystemConsumers
+  private final Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories; // Map of storageEngineFactories indexed by store name
+  private final Map<String, SystemStream> changelogSystemStreams; // Map of changelog system-streams indexed by store name
+  private final Map<String, Serde<Object>> serdes; // Map of Serde objects indexed by serde name (specified in config)
+
+  private final StreamMetadataCache streamMetadataCache;
   private final SamzaContainerMetrics samzaContainerMetrics;
 
-  // Mapping of from storeSystemNames to SystemConsumers
-  private final Map<String, SystemConsumer> systemConsumers;
+  /* Parameters required to re-create taskStores post-restoration */
+  private final ContainerModel containerModel;
+  private final JobContext jobContext;
+  private final ContainerContext containerContext;
+
+  private final File loggedStoreBaseDirectory;
+  private final File nonLoggedStoreBaseDirectory;
+  private final Set<Path> storeDirectoryPaths; // the set of store directory paths, used by SamzaContainer to initialize its disk-space-monitor
 
-  // Size of thread-pool to be used for parallel restores
   private final int parallelRestoreThreadPoolSize;
+  private final int maxChangeLogStreamPartitions; // The partition count of each changelog-stream topic. This is used for validating changelog streams before restoring.
 
-  // Naming convention to be used for restore threads
-  private static final String RESTORE_THREAD_NAME = "Samza Restore Thread-%d";
+  private final Config config;
+
+  public ContainerStorageManager(ContainerModel containerModel, StreamMetadataCache streamMetadataCache,
+      SystemAdmins systemAdmins, Map<String, SystemStream> changelogSystemStreams,
+      Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, SystemFactory> systemFactories, Map<String, Serde<Object>> serdes, Config config,
+      Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics, SamzaContainerMetrics samzaContainerMetrics,
+      JobContext jobContext, ContainerContext containerContext,
+      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors, File loggedStoreBaseDirectory,
+      File nonLoggedStoreBaseDirectory, int maxChangeLogStreamPartitions, Clock clock) {
+
+    this.containerModel = containerModel;
+    this.changelogSystemStreams = changelogSystemStreams;
+    this.storageEngineFactories = storageEngineFactories;
+    this.serdes = serdes;
+    this.loggedStoreBaseDirectory = loggedStoreBaseDirectory;
+    this.nonLoggedStoreBaseDirectory = nonLoggedStoreBaseDirectory;
+
+    // set the config
+    this.config = config;
 
-  public ContainerStorageManager(Map<TaskName, TaskStorageManager> taskStorageManagers,
-      Map<String, SystemConsumer> systemConsumers, SamzaContainerMetrics samzaContainerMetrics) {
-    this.taskStorageManagers = taskStorageManagers;
-    this.systemConsumers = systemConsumers;
+    this.taskInstanceMetrics = taskInstanceMetrics;
+
+    // Setting the metrics registry
     this.samzaContainerMetrics = samzaContainerMetrics;
 
-    // Setting thread pool size equal to the number of tasks
-    this.parallelRestoreThreadPoolSize = taskStorageManagers.size();
+    this.jobContext = jobContext;
+    this.containerContext = containerContext;
+
+    this.taskInstanceCollectors = taskInstanceCollectors;
+
+    // initializing the set of store directory paths
+    this.storeDirectoryPaths = new HashSet<>();
+
+    // Setting the restore thread pool size equal to the number of taskInstances
+    this.parallelRestoreThreadPoolSize = containerModel.getTasks().size();
+
+    this.maxChangeLogStreamPartitions = maxChangeLogStreamPartitions;
+    this.streamMetadataCache = streamMetadataCache;
+
+    // create taskStores for all tasks in the containerModel and each store in storageEngineFactories
+    this.taskStores = createTaskStores(containerModel, jobContext, containerContext, storageEngineFactories, changelogSystemStreams,
+        serdes, taskInstanceMetrics, taskInstanceCollectors, StorageEngineFactory.StoreMode.BulkLoad);
+
+    // create system consumers (1 per store system)
+    this.systemConsumers = createStoreConsumers(changelogSystemStreams, systemFactories, config, this.samzaContainerMetrics.registry());
+
+    // creating task restore managers
+    this.taskRestoreManagers = createTaskRestoreManagers(systemAdmins, clock);
+  }
+
+  /**
+   *  Creates SystemConsumer objects for store restoration, creating one consumer per system.
+   */
+  private static Map<String, SystemConsumer> createStoreConsumers(Map<String, SystemStream> changelogSystemStreams,
+      Map<String, SystemFactory> systemFactories, Config config, MetricsRegistry registry) {
+    // Determine the set of systems being used across all stores
+    Set<String> storeSystems =
+        changelogSystemStreams.values().stream().map(SystemStream::getSystem).collect(Collectors.toSet());
+
+    // Create one consumer for each system in use, map with one entry for each such system
+    Map<String, SystemConsumer> storeSystemConsumers = new HashMap<>();
+
+    // Map of each storeName to its respective systemConsumer
+    Map<String, SystemConsumer> storeConsumers = new HashMap<>();
+
+    // Iterate over the list of storeSystems and create one sysConsumer per system
+    for (String storeSystemName : storeSystems) {
+      SystemFactory systemFactory = systemFactories.get(storeSystemName);
+      if (systemFactory == null) {
+        throw new SamzaException("Changelog system " + storeSystemName + " does not exist in config");
+      }
+      storeSystemConsumers.put(storeSystemName,
+          systemFactory.getConsumer(storeSystemName, config, registry));
+    }
+
+    // Populate the map of storeName to its relevant systemConsumer
+    for (String storeName : changelogSystemStreams.keySet()) {
+      storeConsumers.put(storeName, storeSystemConsumers.get(changelogSystemStreams.get(storeName).getSystem()));
+    }
+
+    return storeConsumers;
+  }
+
+  private Map<TaskName, TaskRestoreManager> createTaskRestoreManagers(SystemAdmins systemAdmins, Clock clock) {
+    Map<TaskName, TaskRestoreManager> taskRestoreManagers = new HashMap<>();
+    containerModel.getTasks().forEach((taskName, taskModel) ->
+      taskRestoreManagers.put(taskName, new TaskRestoreManager(taskModel, changelogSystemStreams, taskStores.get(taskName), systemAdmins, clock)));
+    return taskRestoreManagers;
+  }
+
+  /**
+   * Create taskStores with the given store mode for all stores in storageEngineFactories.
+   */
+  private Map<TaskName, Map<String, StorageEngine>> createTaskStores(ContainerModel containerModel, JobContext jobContext, ContainerContext containerContext,
+      Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, SystemStream> changelogSystemStreams, Map<String, Serde<Object>> serdes,
+      Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
+      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors, StorageEngineFactory.StoreMode storeMode) {
+
+    Map<TaskName, Map<String, StorageEngine>> taskStores = new HashMap<>();
+
+    // iterate over each task in the containerModel, and each store in storageEngineFactories
+    for (Map.Entry<TaskName, TaskModel> task : containerModel.getTasks().entrySet()) {
+      TaskName taskName = task.getKey();
+      TaskModel taskModel = task.getValue();
+
+      if (!taskStores.containsKey(taskName)) {
+        taskStores.put(taskName, new HashMap<>());
+      }
+
+      for (String storeName : storageEngineFactories.keySet()) {
+
+        StorageEngine storageEngine =
+            createStore(storeName, taskName, taskModel, jobContext, containerContext, storageEngineFactories,
+                changelogSystemStreams, serdes, taskInstanceMetrics, taskInstanceCollectors, storeMode);
+
+        // add created store to map
+        taskStores.get(taskName).put(storeName, storageEngine);
+
+        LOG.info("Created store {} for task {}", storeName, taskName);
+      }
+    }
+
+    return taskStores;
+  }
+
+  /**
+   * Recreate all persistent stores in ReadWrite mode.
+   *
+   */
+  private void recreatePersistentTaskStoresInReadWriteMode(ContainerModel containerModel, JobContext jobContext,
+      ContainerContext containerContext, Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, SystemStream> changelogSystemStreams, Map<String, Serde<Object>> serdes,
+      Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
+      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors) {
+
+    // iterate over each task and each storeName
+    for (Map.Entry<TaskName, TaskModel> task : containerModel.getTasks().entrySet()) {
+      TaskName taskName = task.getKey();
+      TaskModel taskModel = task.getValue();
+
+      for (String storeName : storageEngineFactories.keySet()) {
+
+        // if this store has been already created in the taskStores, then re-create and overwrite it only if it is a persistentStore
+        if (this.taskStores.get(taskName).containsKey(storeName) && this.taskStores.get(taskName)
+            .get(storeName)
+            .getStoreProperties()
+            .isPersistedToDisk()) {
+
+          StorageEngine storageEngine =
+              createStore(storeName, taskName, taskModel, jobContext, containerContext, storageEngineFactories,
+                  changelogSystemStreams, serdes, taskInstanceMetrics, taskInstanceCollectors,
+                  StorageEngineFactory.StoreMode.ReadWrite);
+
+          // add created store to map
+          this.taskStores.get(taskName).put(storeName, storageEngine);
+
+          LOG.info("Re-created store {} in read-write mode for task {} because it a persistent store", storeName, taskName);
+        } else {
+
+          LOG.info("Skipping re-creation of store {} for task {} because it a non-persistent store", storeName, taskName);
+        }
+      }
+    }
+  }
+
+  /**
+   * Method to instantiate a StorageEngine with the given parameters, and populate the storeDirectory paths (used to monitor
+   * disk space).
+   */
+  private StorageEngine createStore(String storeName, TaskName taskName, TaskModel taskModel, JobContext jobContext,
+      ContainerContext containerContext, Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, SystemStream> changelogSystemStreams, Map<String, Serde<Object>> serdes,
+      Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
+      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors, StorageEngineFactory.StoreMode storeMode) {
+
+    StorageConfig storageConfig = new StorageConfig(config);
+
+    SystemStreamPartition changeLogSystemStreamPartition =
+        (changelogSystemStreams.containsKey(storeName)) ? new SystemStreamPartition(
+            changelogSystemStreams.get(storeName), taskModel.getChangelogPartition()) : null;
+
+    // Use the logged-store-base-directory for change logged stores, and non-logged-store-base-dir for non logged stores
+    File storeDirectory =
+        (changeLogSystemStreamPartition != null) ? StorageManagerUtil.getStorePartitionDir(this.loggedStoreBaseDirectory,
+            storeName, taskName)
+            : StorageManagerUtil.getStorePartitionDir(this.nonLoggedStoreBaseDirectory, storeName, taskName);
+    this.storeDirectoryPaths.add(storeDirectory.toPath());
+
+    if (storageConfig.getStorageKeySerde(storeName).isEmpty()) {
+      throw new SamzaException("No key serde defined for store: " + storeName);
+    }
+
+    Serde keySerde = serdes.get(storageConfig.getStorageKeySerde(storeName).get());
+    if (keySerde == null) {
+      throw new SamzaException(
+          "StorageKeySerde: No class defined for serde: " + storageConfig.getStorageKeySerde(storeName));
+    }
+
+    if (storageConfig.getStorageMsgSerde(storeName).isEmpty()) {
+      throw new SamzaException("No msg serde defined for store: " + storeName);
+    }
+
+    Serde messageSerde = serdes.get(storageConfig.getStorageMsgSerde(storeName).get());
+    if (messageSerde == null) {
+      throw new SamzaException(
+          "StorageMsgSerde: No class defined for serde: " + storageConfig.getStorageMsgSerde(storeName));
+    }
+
+    // if taskInstanceMetrics are specified use those for store metrics,
+    // otherwise (in case of StorageRecovery) use a blank MetricsRegistryMap
+    MetricsRegistry storeMetricsRegistry =
+        taskInstanceMetrics.get(taskName) != null ? taskInstanceMetrics.get(taskName).registry()
+            : new MetricsRegistryMap();
+
+    return storageEngineFactories.get(storeName)
+        .getStorageEngine(storeName, storeDirectory, keySerde, messageSerde, taskInstanceCollectors.get(taskName),
+            storeMetricsRegistry, changeLogSystemStreamPartition, jobContext, containerContext, storeMode);
   }
 
   public void start() throws SamzaException {
     LOG.info("Restore started");
 
     // initialize each TaskStorageManager
-    this.taskStorageManagers.values().forEach(taskStorageManager -> taskStorageManager.init());
+    this.taskRestoreManagers.values().forEach(taskStorageManager -> taskStorageManager.initialize());
 
     // Start consumers
     this.systemConsumers.values().forEach(systemConsumer -> systemConsumer.start());
 
-    // Create a thread pool for parallel restores
+    // Create a thread pool for parallel restores (and stopping of persistent stores)
     ExecutorService executorService = Executors.newFixedThreadPool(this.parallelRestoreThreadPoolSize,
         new ThreadFactoryBuilder().setNameFormat(RESTORE_THREAD_NAME).build());
 
-    List<Future> taskRestoreFutures = new ArrayList<>(this.taskStorageManagers.entrySet().size());
+    List<Future> taskRestoreFutures = new ArrayList<>(this.taskRestoreManagers.entrySet().size());
 
     // Submit restore callable for each taskInstance
-    this.taskStorageManagers.forEach((taskInstance, taskStorageManager) -> {
-        taskRestoreFutures.add(
-            executorService.submit(new TaskRestoreCallable(this.samzaContainerMetrics, taskInstance, taskStorageManager)));
+    this.taskRestoreManagers.forEach((taskInstance, taskRestoreManager) -> {
+        taskRestoreFutures.add(executorService.submit(
+            new TaskRestoreCallable(this.samzaContainerMetrics, taskInstance, taskRestoreManager)));
       });
 
     // loop-over the future list to wait for each thread to finish, catch any exceptions during restore and throw
@@ -108,14 +366,50 @@ public class ContainerStorageManager {
     // Stop consumers
     this.systemConsumers.values().forEach(systemConsumer -> systemConsumer.stop());
 
+    // Now re-create persistent stores in read-write mode, leave non-persistent stores as-is
+    recreatePersistentTaskStoresInReadWriteMode(this.containerModel, jobContext, containerContext,
+        storageEngineFactories, changelogSystemStreams, serdes, taskInstanceMetrics, taskInstanceCollectors);
+
     LOG.info("Restore complete");
   }
 
+  /**
+   * Get the {@link StorageEngine} instance with a given name for a given task.
+   * @param taskName the task name for which the storage engine is desired.
+   * @param storeName the desired store's name.
+   * @return the task store.
+   */
+  public Optional<StorageEngine> getStore(TaskName taskName, String storeName) {
+    return Optional.ofNullable(this.taskStores.get(taskName).get(storeName));
+  }
+
+  /**
+   *  Get all {@link StorageEngine} instance used by a given task.
+   * @param taskName  the task name, all stores for which are desired.
+   * @return map of stores used by the given task, indexed by storename
+   */
+  public Map<String, StorageEngine> getAllStores(TaskName taskName) {
+    return this.taskStores.get(taskName);
+  }
+
+  /**
+   * Set of directory paths for all stores restored by this {@link ContainerStorageManager}.
+   * @return the set of all store directory paths
+   */
+  public Set<Path> getStoreDirectoryPaths() {
+    return this.storeDirectoryPaths;
+  }
+
+  @VisibleForTesting
+  public void stopStores() {
+    this.taskStores.forEach((taskName, storeMap) -> storeMap.forEach((storeName, store) -> store.stop()));
+  }
+
   public void shutdown() {
-    this.taskStorageManagers.forEach((taskInstance, taskStorageManager) -> {
-        if (taskStorageManager != null) {
+    this.taskRestoreManagers.forEach((taskInstance, taskRestoreManager) -> {
+        if (taskRestoreManager != null) {
           LOG.debug("Shutting down task storage manager for taskName: {} ", taskInstance);
-          taskStorageManager.stop();
+          taskRestoreManager.stop();
         } else {
           LOG.debug("Skipping task storage manager shutdown for taskName: {}", taskInstance);
         }
@@ -124,27 +418,36 @@ public class ContainerStorageManager {
     LOG.info("Shutdown complete");
   }
 
-  /** Callable for performing the restoreStores on a taskStorage manager and emitting task-restoration metric.
+  /**
+   * Callable for performing the restoreStores on a task restore manager and emitting the task-restoration metric.
+   * After restoration, all persistent stores are stopped (which will invoke compaction in case of certain persistent
+   * stores that were opened in bulk-load mode).
+   * Performing stop here parallelizes this compaction, which is a time-intensive operation.
    *
    */
   private class TaskRestoreCallable implements Callable<Void> {
 
     private TaskName taskName;
-    private TaskStorageManager taskStorageManager;
+    private TaskRestoreManager taskRestoreManager;
     private SamzaContainerMetrics samzaContainerMetrics;
 
     public TaskRestoreCallable(SamzaContainerMetrics samzaContainerMetrics, TaskName taskName,
-        TaskStorageManager taskStorageManager) {
+        TaskRestoreManager taskRestoreManager) {
       this.samzaContainerMetrics = samzaContainerMetrics;
       this.taskName = taskName;
-      this.taskStorageManager = taskStorageManager;
+      this.taskRestoreManager = taskRestoreManager;
     }
 
     @Override
     public Void call() {
       long startTime = System.currentTimeMillis();
       LOG.info("Starting stores in task instance {}", this.taskName.getTaskName());
-      taskStorageManager.restoreStores();
+      taskRestoreManager.restoreStores();
+
+      // Stop all persistent stores after restoring. Certain persistent stores opened in BulkLoad mode are compacted
+      // on stop, so paralleling stop() also parallelizes their compaction (a time-intensive operation).
+      taskRestoreManager.stopPersistentStores();
+
       long timeToRestore = System.currentTimeMillis() - startTime;
 
       if (this.samzaContainerMetrics != null) {
@@ -157,4 +460,280 @@ public class ContainerStorageManager {
       return null;
     }
   }
+
+  /**
+   * Restore logic for all stores of a task including directory cleanup, setup, changelogSSP validation, registering
+   * with the respective consumer, restoring stores, and stopping stores.
+   */
+  private class TaskRestoreManager {
+
+    private final static String OFFSET_FILE_NAME = "OFFSET";
+    private final Map<String, StorageEngine> taskStores; // Map of all StorageEngines for this task indexed by store name
+    private final Set<String> taskStoresToRestore;
+    // Set of store names which need to be restored by consuming using system-consumers (see registerStartingOffsets)
+
+    private final TaskModel taskModel;
+    private final Clock clock; // Clock value used to validate base-directories for staleness. See isLoggedStoreValid.
+    private Map<SystemStream, String> changeLogOldestOffsets; // Map of changelog oldest known offsets
+    private final Map<SystemStreamPartition, String> fileOffsets; // Map of offsets read from offset file indexed by changelog SSP
+    private final Map<String, SystemStream> changelogSystemStreams; // Map of change log system-streams indexed by store name
+    private final SystemAdmins systemAdmins;
+
+    public TaskRestoreManager(TaskModel taskModel, Map<String, SystemStream> changelogSystemStreams,
+        Map<String, StorageEngine> taskStores, SystemAdmins systemAdmins, Clock clock) {
+      this.taskStores = taskStores;
+      this.taskModel = taskModel;
+      this.clock = clock;
+      this.changelogSystemStreams = changelogSystemStreams;
+      this.systemAdmins = systemAdmins;
+      this.fileOffsets = new HashMap<>();
+      this.taskStoresToRestore = this.taskStores.entrySet().stream()
+          .filter(x -> x.getValue().getStoreProperties().isLoggedStore())
+          .map(x -> x.getKey()).collect(Collectors.toSet());
+    }
+
+    /**
+     * Cleans up and sets up store directories, validates changeLog SSPs for all stores of this task,
+     * and registers SSPs with the respective consumers.
+     */
+    public void initialize() {
+      cleanBaseDirsAndReadOffsetFiles();
+      setupBaseDirs();
+      validateChangelogStreams();
+      getOldestChangeLogOffsets();
+      registerStartingOffsets();
+    }
+
+    /**
+     * For each store for this task,
+     * a. Deletes the corresponding non-logged-store base dir.
+     * b. Deletes the logged-store-base-dir if it not valid. See {@link #isLoggedStoreValid} for validation semantics.
+     * c. If the logged-store-base-dir is valid, this method reads the offset file and stores each offset.
+     */
+    private void cleanBaseDirsAndReadOffsetFiles() {
+      LOG.debug("Cleaning base directories for stores.");
+
+      taskStores.keySet().forEach(storeName -> {
+          File nonLoggedStorePartitionDir =
+              StorageManagerUtil.getStorePartitionDir(nonLoggedStoreBaseDirectory, storeName, taskModel.getTaskName());
+          LOG.info("Got non logged storage partition directory as " + nonLoggedStorePartitionDir.toPath().toString());
+
+          if (nonLoggedStorePartitionDir.exists()) {
+            LOG.info("Deleting non logged storage partition directory " + nonLoggedStorePartitionDir.toPath().toString());
+            FileUtil.rm(nonLoggedStorePartitionDir);
+          }
+
+          File loggedStorePartitionDir =
+              StorageManagerUtil.getStorePartitionDir(loggedStoreBaseDirectory, storeName, taskModel.getTaskName());
+          LOG.info("Got logged storage partition directory as " + loggedStorePartitionDir.toPath().toString());
+
+          // Delete the logged store if it is not valid.
+          if (!isLoggedStoreValid(storeName, loggedStorePartitionDir)) {
+            LOG.info("Deleting logged storage partition directory " + loggedStorePartitionDir.toPath().toString());
+            FileUtil.rm(loggedStorePartitionDir);
+          } else {
+            String offset = StorageManagerUtil.readOffsetFile(loggedStorePartitionDir, OFFSET_FILE_NAME);
+            LOG.info("Read offset " + offset + " for the store " + storeName + " from logged storage partition directory "
+                + loggedStorePartitionDir);
+
+            if (offset != null) {
+              fileOffsets.put(
+                  new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition()),
+                  offset);
+            }
+          }
+        });
+    }
+
+    /**
+     * Directory loggedStoreDir associated with the logged store storeName is determined to be valid
+     * if all of the following conditions are true.
+     * a) If the store has to be persisted to disk.
+     * b) If there is a valid offset file associated with the logged store.
+     * c) If the logged store has not gone stale.
+     *
+     * @return true if the logged store is valid, false otherwise.
+     */
+    private boolean isLoggedStoreValid(String storeName, File loggedStoreDir) {
+      long changeLogDeleteRetentionInMs = StorageConfig.DEFAULT_CHANGELOG_DELETE_RETENTION_MS();
+
+      if (new StorageConfig(config).getChangeLogDeleteRetentionsInMs().get(storeName).isDefined()) {
+        changeLogDeleteRetentionInMs =
+            (long) new StorageConfig(config).getChangeLogDeleteRetentionsInMs().get(storeName).get();
+      }
+
+      return this.taskStores.get(storeName).getStoreProperties().isPersistedToDisk()
+          && StorageManagerUtil.isOffsetFileValid(loggedStoreDir, OFFSET_FILE_NAME) && !StorageManagerUtil.isStaleStore(
+          loggedStoreDir, OFFSET_FILE_NAME, changeLogDeleteRetentionInMs, clock.currentTimeMillis());
+    }
+
+    /**
+     * Create stores' base directories for logged-stores if they dont exist.
+     */
+    private void setupBaseDirs() {
+      LOG.debug("Setting up base directories for stores.");
+      taskStores.forEach((storeName, storageEngine) -> {
+          if (storageEngine.getStoreProperties().isLoggedStore()) {
+
+            File loggedStorePartitionDir =
+                StorageManagerUtil.getStorePartitionDir(loggedStoreBaseDirectory, storeName, taskModel.getTaskName());
+
+            LOG.info("Using logged storage partition directory: " + loggedStorePartitionDir.toPath().toString()
+                + " for store: " + storeName);
+
+            if (!loggedStorePartitionDir.exists()) {
+              loggedStorePartitionDir.mkdirs();
+            }
+          } else {
+            File nonLoggedStorePartitionDir =
+                StorageManagerUtil.getStorePartitionDir(nonLoggedStoreBaseDirectory, storeName, taskModel.getTaskName());
+            LOG.info("Using non logged storage partition directory: " + nonLoggedStorePartitionDir.toPath().toString()
+                + " for store: " + storeName);
+            nonLoggedStorePartitionDir.mkdirs();
+          }
+        });
+    }
+
+    /**
+     *  Validates each changelog system-stream with its respective SystemAdmin.
+     */
+    private void validateChangelogStreams() {
+      LOG.info("Validating change log streams: " + changelogSystemStreams);
+
+      for (SystemStream changelogSystemStream : changelogSystemStreams.values()) {
+        SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStream.getSystem());
+        StreamSpec changelogSpec =
+            StreamSpec.createChangeLogStreamSpec(changelogSystemStream.getStream(), changelogSystemStream.getSystem(),
+                maxChangeLogStreamPartitions);
+
+        systemAdmin.validateStream(changelogSpec);
+      }
+    }
+
+    /**
+     * Get the oldest offset for each changelog SSP based on the stream's metadata (obtained from streamMetadataCache).
+     */
+    private void getOldestChangeLogOffsets() {
+
+      Map<SystemStream, SystemStreamMetadata> changeLogMetadata = JavaConverters.mapAsJavaMapConverter(
+          streamMetadataCache.getStreamMetadata(
+              JavaConverters.asScalaSetConverter(new HashSet<>(changelogSystemStreams.values())).asScala().toSet(),
+              false)).asJava();
+
+      LOG.info("Got change log stream metadata: {}", changeLogMetadata);
+
+      changeLogOldestOffsets =
+          getChangeLogOldestOffsetsForPartition(taskModel.getChangelogPartition(), changeLogMetadata);
+      LOG.info("Assigning oldest change log offsets for taskName {} : {}", taskModel.getTaskName(),
+          changeLogOldestOffsets);
+    }
+
+    /**
+     * Builds a map from SystemStreamPartition to oldest offset for changelogs.
+     */
+    private Map<SystemStream, String> getChangeLogOldestOffsetsForPartition(Partition partition,
+        Map<SystemStream, SystemStreamMetadata> inputStreamMetadata) {
+
+      Map<SystemStream, String> retVal = new HashMap<>();
+
+      // NOTE: do not use Collectors.Map because of https://bugs.openjdk.java.net/browse/JDK-8148463
+      inputStreamMetadata.entrySet()
+          .stream()
+          .filter(x -> x.getValue().getSystemStreamPartitionMetadata().get(partition) != null)
+          .forEach(e -> retVal.put(e.getKey(),
+              e.getValue().getSystemStreamPartitionMetadata().get(partition).getOldestOffset()));
+
+      return retVal;
+    }
+
+    /**
+     * Determines the starting offset for each store SSP (based on {@link #getStartingOffset(SystemStreamPartition, SystemAdmin)}) and
+     * registers it with the respective SystemConsumer for starting consumption.
+     */
+    private void registerStartingOffsets() {
+
+      for (Map.Entry<String, SystemStream> changelogSystemStreamEntry : changelogSystemStreams.entrySet()) {
+        SystemStreamPartition systemStreamPartition =
+            new SystemStreamPartition(changelogSystemStreamEntry.getValue(), taskModel.getChangelogPartition());
+        SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStreamEntry.getValue().getSystem());
+        SystemConsumer systemConsumer = systemConsumers.get(changelogSystemStreamEntry.getKey());
+
+        String offset = getStartingOffset(systemStreamPartition, systemAdmin);
+
+        if (offset != null) {
+          LOG.info("Registering change log consumer with offset " + offset + " for %" + systemStreamPartition);
+          systemConsumer.register(systemStreamPartition, offset);
+        } else {
+          LOG.info("Skipping change log restoration for {} because stream appears to be empty (offset was null).",
+              systemStreamPartition);
+          taskStoresToRestore.remove(changelogSystemStreamEntry.getKey());
+        }
+      }
+    }
+
+    /**
+     * Returns the offset with which the changelog consumer should be initialized for the given SystemStreamPartition.
+     *
+     * If a file offset exists, it represents the last changelog offset which is also reflected in the on-disk state.
+     * In that case, we use the next offset after the file offset, as long as it is newer than the oldest offset
+     * currently available in the stream.
+     *
+     * If there isn't a file offset or it's older than the oldest available offset, we simply start with the oldest.
+     *
+     * @param systemStreamPartition  the changelog partition for which the offset is needed.
+     * @param systemAdmin                  the [[SystemAdmin]] for the changelog.
+     * @return the offset to from which the changelog consumer should be initialized.
+     */
+    private String getStartingOffset(SystemStreamPartition systemStreamPartition, SystemAdmin systemAdmin) {
+      String fileOffset = fileOffsets.get(systemStreamPartition);
+
+      // NOTE: changeLogOldestOffsets may contain a null-offset for the given SSP (signifying an empty stream)
+      // therefore, we need to differentiate that from the case where the offset is simply missing
+      if (!changeLogOldestOffsets.containsKey(systemStreamPartition.getSystemStream())) {
+        throw new SamzaException("Missing a change log offset for " + systemStreamPartition);
+      }
+
+      String oldestOffset = changeLogOldestOffsets.get(systemStreamPartition.getSystemStream());
+      return StorageManagerUtil.getStartingOffset(systemStreamPartition, systemAdmin, fileOffset, oldestOffset);
+    }
+
+
+    /**
+     * Restore each store in taskStoresToRestore sequentially
+     */
+    public void restoreStores() {
+      LOG.debug("Restoring stores for task: {}", taskModel.getTaskName());
+
+      for (String storeName : taskStoresToRestore) {
+        SystemConsumer systemConsumer = systemConsumers.get(storeName);
+        SystemStream systemStream = changelogSystemStreams.get(storeName);
+
+        SystemStreamPartitionIterator systemStreamPartitionIterator = new SystemStreamPartitionIterator(systemConsumer,
+            new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()));
+
+        taskStores.get(storeName).restore(systemStreamPartitionIterator);
+      }
+    }
+
+    /**
+     * Stop all stores.
+     */
+    public void stop() {
+      this.taskStores.values().forEach(storageEngine -> {
+          storageEngine.stop();
+        });
+    }
+
+    /**
+     * Stop only persistent stores. In case of certain stores and store mode (such as RocksDB), this
+     * can invoke compaction.
+     */
+    public void stopPersistentStores() {
+      this.taskStores.values().stream().filter(storageEngine -> {
+          return storageEngine.getStoreProperties().isPersistedToDisk();
+        }).forEach(storageEngine -> {
+            storageEngine.stop();
+          });
+    }
+  }
 }
index 4bcf2d3..f2c4679 100644 (file)
 package org.apache.samza.storage
 
 import java.io._
-import java.util
 
-import org.apache.samza.config.StorageConfig
-import org.apache.samza.{Partition, SamzaException}
+import com.google.common.annotations.VisibleForTesting
+import org.apache.samza.Partition
 import org.apache.samza.container.TaskName
 import org.apache.samza.system._
-import org.apache.samza.util.{Clock, FileUtil, Logging}
+import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
+import org.apache.samza.util.{FileUtil, Logging}
 
-object TaskStorageManager {
-  def getStoreDir(storeBaseDir: File, storeName: String) = {
-    new File(storeBaseDir, storeName)
-  }
-
-  def getStorePartitionDir(storeBaseDir: File, storeName: String, taskName: TaskName) = {
-    // TODO: Sanitize, check and clean taskName string as a valid value for a file
-    new File(storeBaseDir, (storeName + File.separator + taskName.toString).replace(' ', '_'))
-  }
-}
+import scala.collection.JavaConverters._
 
 /**
  * Manage all the storage engines for a given task
  */
 class TaskStorageManager(
   taskName: TaskName,
-  taskStores: Map[String, StorageEngine] = Map(),
-  storeConsumers: Map[String, SystemConsumer] = Map(),
+  containerStorageManager: ContainerStorageManager,
   changeLogSystemStreams: Map[String, SystemStream] = Map(),
-  changeLogStreamPartitions: Int,
-  streamMetadataCache: StreamMetadataCache,
   sspMetadataCache: SSPMetadataCache,
-  nonLoggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
   loggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
-  partition: Partition,
-  systemAdmins: SystemAdmins,
-  changeLogDeleteRetentionsInMs: Map[String, Long],
-  clock: Clock) extends Logging {
+  partition: Partition) extends Logging {
 
-  var taskStoresToRestore = taskStores.filter{
-    case (storeName, storageEngine) => storageEngine.getStoreProperties.isLoggedStore
-  }
-  val persistedStores = taskStores.filter{
+  val persistedStores = containerStorageManager.getAllStores(taskName).asScala.filter{
     case (storeName, storageEngine) => storageEngine.getStoreProperties.isPersistedToDisk
   }
 
-  var changeLogOldestOffsets: Map[SystemStream, String] = Map()
-  val fileOffsets: util.Map[SystemStreamPartition, String] = new util.HashMap[SystemStreamPartition, String]()
   val offsetFileName = "OFFSET"
 
-  def getStore(storeName: String): Option[StorageEngine] = taskStores.get(storeName)
+  def getStore(storeName: String): Option[StorageEngine] =  JavaOptionals.toRichOptional(containerStorageManager.getStore(taskName, storeName)).toOption
 
   def init {
-    cleanBaseDirs()
-    setupBaseDirs()
-    validateChangelogStreams()
-    registerSSPs()
-  }
-
-  private def cleanBaseDirs() {
-    debug("Cleaning base directories for stores.")
-
-    taskStores.keys.foreach(storeName => {
-      val nonLoggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(nonLoggedStoreBaseDir, storeName, taskName)
-      info("Got non logged storage partition directory as %s" format nonLoggedStorePartitionDir.toPath.toString)
-
-      if(nonLoggedStorePartitionDir.exists()) {
-        info("Deleting non logged storage partition directory %s" format nonLoggedStorePartitionDir.toPath.toString)
-        FileUtil.rm(nonLoggedStorePartitionDir)
-      }
-
-      val loggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
-      info("Got logged storage partition directory as %s" format loggedStorePartitionDir.toPath.toString)
-
-      // Delete the logged store if it is not valid.
-      if (!isLoggedStoreValid(storeName, loggedStorePartitionDir)) {
-        info("Deleting logged storage partition directory %s." format loggedStorePartitionDir.toPath.toString)
-        FileUtil.rm(loggedStorePartitionDir)
-      } else {
-        val offset = StorageManagerUtil.readOffsetFile(loggedStorePartitionDir, offsetFileName)
-        info("Read offset %s for the store %s from logged storage partition directory %s." format(offset, storeName, loggedStorePartitionDir))
-        if (offset != null) {
-          fileOffsets.put(new SystemStreamPartition(changeLogSystemStreams(storeName), partition), offset)
-        }
-      }
-    })
-  }
-
-  /**
-    * Directory loggedStoreDir associated with the logged store storeName is valid
-    * if all of the following conditions are true.
-    * a) If the store has to be persisted to disk.
-    * b) If there is a valid offset file associated with the logged store.
-    * c) If the logged store has not gone stale.
-    *
-    * @return true if the logged store is valid, false otherwise.
-    */
-  private def isLoggedStoreValid(storeName: String, loggedStoreDir: File): Boolean = {
-    val changeLogDeleteRetentionInMs = changeLogDeleteRetentionsInMs
-      .getOrElse(storeName, StorageConfig.DEFAULT_CHANGELOG_DELETE_RETENTION_MS)
-
-    persistedStores.contains(storeName) &&
-      StorageManagerUtil.isOffsetFileValid(loggedStoreDir, offsetFileName) &&
-      !StorageManagerUtil.isStaleStore(loggedStoreDir, offsetFileName, changeLogDeleteRetentionInMs, clock.currentTimeMillis())
-  }
-
-  private def setupBaseDirs() {
-    debug("Setting up base directories for stores.")
-    taskStores.foreach {
-      case (storeName, storageEngine) =>
-        if (storageEngine.getStoreProperties.isLoggedStore) {
-          val loggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
-          info("Using logged storage partition directory: %s for store: %s." format(loggedStorePartitionDir.toPath.toString, storeName))
-          if (!loggedStorePartitionDir.exists()) loggedStorePartitionDir.mkdirs()
-        } else {
-          val nonLoggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(nonLoggedStoreBaseDir, storeName, taskName)
-          info("Using non logged storage partition directory: %s for store: %s." format(nonLoggedStorePartitionDir.toPath.toString, storeName))
-          nonLoggedStorePartitionDir.mkdirs()
-        }
-    }
-  }
-
-  private def validateChangelogStreams() = {
-    info("Validating change log streams: " + changeLogSystemStreams)
-
-    for ((storeName, systemStream) <- changeLogSystemStreams) {
-      val systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem)
-      val changelogSpec = StreamSpec.createChangeLogStreamSpec(systemStream.getStream, systemStream.getSystem, changeLogStreamPartitions)
-
-      systemAdmin.validateStream(changelogSpec)
-    }
-
-    val changeLogMetadata = streamMetadataCache.getStreamMetadata(changeLogSystemStreams.values.toSet)
-    info("Got change log stream metadata: %s" format changeLogMetadata)
-
-    changeLogOldestOffsets = getChangeLogOldestOffsetsForPartition(partition, changeLogMetadata)
-    info("Assigning oldest change log offsets for taskName %s: %s" format (taskName, changeLogOldestOffsets))
-  }
-
-  private def registerSSPs() {
-    debug("Starting consumers for stores.")
-
-    for ((storeName, systemStream) <- changeLogSystemStreams) {
-      val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-      val admin = systemAdmins.getSystemAdmin(systemStream.getSystem)
-      val consumer = storeConsumers(storeName)
-
-      val offset = getStartingOffset(systemStreamPartition, admin)
-      if (offset != null) {
-        info("Registering change log consumer with offset %s for %s." format (offset, systemStreamPartition))
-        consumer.register(systemStreamPartition, offset)
-      } else {
-        info("Skipping change log restoration for %s because stream appears to be empty (offset was null)." format systemStreamPartition)
-        taskStoresToRestore -= storeName
-      }
-    }
-  }
-
-  /**
-    * Returns the offset with which the changelog consumer should be initialized for the given SystemStreamPartition.
-    *
-    * If a file offset exists, it represents the last changelog offset which is also reflected in the on-disk state.
-    * In that case, we use the next offset after the file offset, as long as it is newer than the oldest offset
-    * currently available in the stream.
-    *
-    * If there isn't a file offset or it's older than the oldest available offset, we simply start with the oldest.
-    *
-    * @param systemStreamPartition  the changelog partition for which the offset is needed.
-    * @param admin                  the [[SystemAdmin]] for the changelog.
-    * @return                       the offset to from which the changelog consumer should be initialized.
-    */
-  private def getStartingOffset(systemStreamPartition: SystemStreamPartition, admin: SystemAdmin) = {
-    val fileOffset = fileOffsets.get(systemStreamPartition)
-    val oldestOffset = changeLogOldestOffsets
-      .getOrElse(systemStreamPartition.getSystemStream,
-        throw new SamzaException("Missing a change log offset for %s." format systemStreamPartition))
-
-    StorageManagerUtil.getStartingOffset(systemStreamPartition, admin, fileOffset, oldestOffset)
-  }
-
-  def restoreStores() {
-    debug("Restoring stores for task: %s." format taskName.getTaskName)
-
-    for ((storeName, store) <- taskStoresToRestore) {
-      if (changeLogSystemStreams.contains(storeName)) {
-        val systemStream = changeLogSystemStreams(storeName)
-        val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-        val systemConsumer = storeConsumers(storeName)
-        val systemConsumerIterator = new SystemStreamPartitionIterator(systemConsumer, systemStreamPartition)
-        store.restore(systemConsumerIterator)
-      }
-    }
   }
 
   def flush() {
     debug("Flushing stores.")
 
-    taskStores.values.foreach(_.flush)
+    containerStorageManager.getAllStores(taskName).asScala.values.foreach(_.flush)
     flushChangelogOffsetFiles()
   }
 
   def stopStores() {
     debug("Stopping stores.")
-    taskStores.values.foreach(_.stop)
+    containerStorageManager.stopStores();
   }
 
+  @VisibleForTesting
   def stop() {
     stopStores()
 
@@ -249,7 +90,7 @@ class TaskStorageManager(
         val newestOffset = if (sspMetadata == null) null else sspMetadata.getNewestOffset
         debug("Got offset %s for store %s" format(newestOffset, storeName))
 
-        val loggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
+        val loggedStorePartitionDir = StorageManagerUtil.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
         val offsetFile = new File(loggedStorePartitionDir, offsetFileName)
         if (newestOffset != null) {
           debug("Storing offset for store in OFFSET file ")
@@ -270,14 +111,4 @@ class TaskStorageManager(
 
     debug("Done persisting logged key value stores")
   }
-
-  /**
-   * Builds a map from SystemStreamPartition to oldest offset for changelogs.
-   */
-  private def getChangeLogOldestOffsetsForPartition(partition: Partition, inputStreamMetadata: Map[SystemStream, SystemStreamMetadata]): Map[SystemStream, String] = {
-    inputStreamMetadata
-      .mapValues(_.getSystemStreamPartitionMetadata.get(partition))
-      .filter(_._2 != null)
-      .mapValues(_.getOldestOffset)
-  }
 }
index a359cd5..049feba 100644 (file)
 
 package org.apache.samza.util
 
-import java.util.function
+import java.util
+import java.util.{Optional, function}
 
+import scala.collection.JavaConverters
 import scala.collection.immutable.Map
 import scala.collection.JavaConverters._
 import scala.runtime.AbstractFunction0
@@ -37,6 +39,13 @@ object ScalaJavaUtil {
   }
 
   /**
+    * Convert a scala iterable to a Java collection.
+    */
+  def toJavaCollection[T](iterable : Iterable[T]): util.Collection[T] = {
+    JavaConverters.asJavaCollectionConverter(iterable).asJavaCollection
+  }
+
+  /**
     * Wraps the provided value in an Scala Function, e.g. for use in [[Option#getOrDefault]]
     *
     * @param value the value to be wrapped
@@ -71,4 +80,30 @@ object ScalaJavaUtil {
   def toScalaFunction[T, R](javaFunction: java.util.function.Function[T, R]): Function1[T, R] = {
     t => javaFunction.apply(t)
   }
+
+
+  /**
+    * Conversions between Scala Option and Java 8 Optional.
+    */
+  object JavaOptionals {
+    implicit def toRichOption[T](opt: Option[T]): RichOption[T] = new RichOption[T](opt)
+    implicit def toRichOptional[T](optional: Optional[T]): RichOptional[T] = new RichOptional[T](optional)
+  }
+
+  class RichOption[T] (opt: Option[T]) {
+
+    /**
+      * Transform this Option to an equivalent Java Optional
+      */
+    def toOptional: Optional[T] = Optional.ofNullable(opt.getOrElse(null).asInstanceOf[T])
+  }
+
+  class RichOptional[T] (opt: Optional[T]) {
+
+    /**
+      * Transform this Optional to an equivalent Scala Option
+      */
+    def toOption: Option[T] = if (opt.isPresent) Some(opt.get()) else None
+  }
+
 }
index 8eff4ad..9a47705 100644 (file)
@@ -37,7 +37,7 @@ public class MockStorageEngineFactory implements StorageEngineFactory<Object, Ob
       MetricsRegistry registry,
       SystemStreamPartition changeLogSystemStreamPartition,
       JobContext jobContext,
-      ContainerContext containerContext) {
+      ContainerContext containerContext, StoreMode storeMode) {
     StoreProperties storeProperties = new StoreProperties.StorePropertiesBuilder().setLoggedStore(true).build();
     return new MockStorageEngine(storeName, storeDir, changeLogSystemStreamPartition, storeProperties);
   }
index 7c1647e..0bd33fa 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.coordinator.stream.MockCoordinatorStreamSystemFactory;
+import org.apache.samza.serializers.ByteSerdeFactory;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.MockSystemFactory;
 import org.apache.samza.system.SystemStreamPartition;
@@ -33,11 +34,12 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.*;
 
 public class TestStorageRecovery {
 
   public Config config = null;
+  String path = "/tmp/testing";
   private static final String SYSTEM_STREAM_NAME = "changelog";
   private static final String INPUT_STREAM = "input";
   private static final String STORE_NAME = "testStore";
@@ -59,7 +61,7 @@ public class TestStorageRecovery {
   public void testStorageEngineReceivedAllValues() {
     MockCoordinatorStreamSystemFactory.enableMockConsumerCache();
 
-    String path = "/tmp/testing";
+
     StorageRecovery storageRecovery = new StorageRecovery(config, path);
     storageRecovery.run();
 
@@ -79,6 +81,9 @@ public class TestStorageRecovery {
     map.put("systems.mockSystem.samza.factory", MockSystemFactory.class.getCanonicalName());
     map.put(String.format("stores.%s.factory", STORE_NAME), MockStorageEngineFactory.class.getCanonicalName());
     map.put(String.format("stores.%s.changelog", STORE_NAME), "mockSystem." + SYSTEM_STREAM_NAME);
+    map.put(String.format("stores.%s.key.serde", STORE_NAME), "byteserde");
+    map.put(String.format("stores.%s.msg.serde", STORE_NAME), "byteserde");
+    map.put("serializers.registry.byteserde.class", ByteSerdeFactory.class.getName());
     map.put("task.inputs", "mockSystem.input");
     map.put("job.coordinator.system", "coordinator");
     map.put("systems.coordinator.samza.factory", MockCoordinatorStreamSystemFactory.class.getCanonicalName());
index 5e71efc..9517eec 100644 (file)
  */
 package org.apache.samza.storage;
 
+import java.io.File;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Map;
-import java.util.concurrent.CountDownLatch;
+import org.apache.samza.Partition;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.SamzaContainerMetrics;
 import org.apache.samza.container.TaskInstance;
+import org.apache.samza.container.TaskInstanceMetrics;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.Gauge;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.serializers.StringSerdeFactory;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.util.SystemClock;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import scala.collection.JavaConverters;
 
 
 public class TestContainerStorageManager {
 
+  private static final String STORE_NAME = "store";
+  private static final String SYSTEM_NAME = "kafka";
+  private static final String STREAM_NAME = "store-stream";
+  private static final File DEFAULT_STORE_BASE_DIR = new File(System.getProperty("java.io.tmpdir") + File.separator + "store");
+  private static final File
+      DEFAULT_LOGGED_STORE_BASE_DIR = new File(System.getProperty("java.io.tmpdir") + File.separator + "loggedStore");
+
   private ContainerStorageManager containerStorageManager;
-  private Map<String, SystemConsumer> systemConsumers;
-  private Map<TaskName, TaskStorageManager> taskStorageManagers;
+  private Map<TaskName, Gauge<Object>> taskRestoreMetricGauges;
+  private Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics;
   private SamzaContainerMetrics samzaContainerMetrics;
+  private Map<TaskName, TaskModel> tasks;
 
-  private CountDownLatch taskStorageManagersRestoreStoreCount;
-  private CountDownLatch taskStorageManagersInitCount;
-  private CountDownLatch taskStorageManagersRestoreStopCount;
-
-  private CountDownLatch systemConsumerStartCount;
-  private CountDownLatch systemConsumerStopCount;
-
-  private Map<TaskName, Gauge<Object>> taskRestoreMetricGauges;
+  private volatile int systemConsumerCreationCount;
+  private volatile int systemConsumerStartCount;
+  private volatile int systemConsumerStopCount;
+  private volatile int storeRestoreCallCount;
 
   /**
    * Utility method for creating a mocked taskInstance and taskStorageManager and adding it to the map.
    * @param taskname the desired taskname.
    */
-  private void addMockedTask(String taskname) {
+  private void addMockedTask(String taskname, int changelogPartition) {
     TaskInstance mockTaskInstance = Mockito.mock(TaskInstance.class);
     Mockito.doAnswer(invocation -> {
         return new TaskName(taskname);
       }).when(mockTaskInstance).taskName();
 
-    TaskStorageManager mockTaskStorageManager = Mockito.mock(TaskStorageManager.class);
-    Mockito.doAnswer(invocation -> {
-        taskStorageManagersInitCount.countDown();
-        return null;
-      }).when(mockTaskStorageManager).init();
-
-    Mockito.doAnswer(invocation -> {
-        taskStorageManagersRestoreStopCount.countDown();
-        return null;
-      }).when(mockTaskStorageManager).stop();
-
-    Mockito.doAnswer(invocation -> {
-        taskStorageManagersRestoreStoreCount.countDown();
-        return null;
-      }).when(mockTaskStorageManager).restoreStores();
-
-    taskStorageManagers.put(new TaskName(taskname), mockTaskStorageManager);
-
     Gauge testGauge = Mockito.mock(Gauge.class);
+    this.tasks.put(new TaskName(taskname),
+        new TaskModel(new TaskName(taskname), new HashSet<>(), new Partition(changelogPartition)));
     this.taskRestoreMetricGauges.put(new TaskName(taskname), testGauge);
+    this.taskInstanceMetrics.put(new TaskName(taskname), Mockito.mock(TaskInstanceMetrics.class));
   }
 
+  /**
+   * Method to create a containerStorageManager with mocked dependencies
+   */
   @Before
   public void setUp() {
     taskRestoreMetricGauges = new HashMap<>();
-    systemConsumers = new HashMap<>();
-    taskStorageManagers = new HashMap<>();
-
-    // add two mocked tasks
-    addMockedTask("task 1");
-    addMockedTask("task 2");
-
-    // define the expected number of invocations on taskStorageManagers' init, stop and restore count
-    // and the expected number of sysConsumer start and stop
-    this.taskStorageManagersInitCount = new CountDownLatch(2);
-    this.taskStorageManagersRestoreStoreCount = new CountDownLatch(2);
-    this.taskStorageManagersRestoreStopCount = new CountDownLatch(2);
-    this.systemConsumerStartCount = new CountDownLatch(1);
-    this.systemConsumerStopCount = new CountDownLatch(1);
-
-    // mock container metrics
+    this.tasks = new HashMap<>();
+    this.taskInstanceMetrics = new HashMap<>();
+
+    // Add two mocked tasks
+    addMockedTask("task 0", 0);
+    addMockedTask("task 1", 1);
+
+    // Mock container metrics
     samzaContainerMetrics = Mockito.mock(SamzaContainerMetrics.class);
     Mockito.when(samzaContainerMetrics.taskStoreRestorationMetrics()).thenReturn(taskRestoreMetricGauges);
 
-    // mock and setup sysconsumers
+    // Create a map of test changeLogSSPs
+    Map<String, SystemStream> changelogSystemStreams = new HashMap<>();
+    changelogSystemStreams.put(STORE_NAME, new SystemStream(SYSTEM_NAME, STREAM_NAME));
+
+    // Create mocked storage engine factories
+    Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories = new HashMap<>();
+    StorageEngineFactory mockStorageEngineFactory =
+        (StorageEngineFactory<Object, Object>) Mockito.mock(StorageEngineFactory.class);
+    StorageEngine mockStorageEngine = Mockito.mock(StorageEngine.class);
+    Mockito.doAnswer(invocation -> {
+        return mockStorageEngine;
+      }).when(mockStorageEngineFactory).getStorageEngine(Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
+            Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
+
+    storageEngineFactories.put(STORE_NAME, mockStorageEngineFactory);
+
+    // Add instrumentation to mocked storage engine, to record the number of store.restore() calls
+    Mockito.doAnswer(invocation -> {
+        storeRestoreCallCount++;
+        return null;
+      }).when(mockStorageEngine).restore(Mockito.any());
+
+    // Set the mocked stores' properties to be persistent
+    Mockito.doAnswer(invocation -> {
+        return new StoreProperties.StorePropertiesBuilder().setLoggedStore(true).build();
+      }).when(mockStorageEngine).getStoreProperties();
+
+    // Mock and setup sysconsumers
     SystemConsumer mockSystemConsumer = Mockito.mock(SystemConsumer.class);
     Mockito.doAnswer(invocation -> {
-        systemConsumerStartCount.countDown();
+        systemConsumerStartCount++;
         return null;
       }).when(mockSystemConsumer).start();
     Mockito.doAnswer(invocation -> {
-        systemConsumerStopCount.countDown();
+        systemConsumerStopCount++;
         return null;
       }).when(mockSystemConsumer).stop();
 
-    systemConsumers.put("kafka", mockSystemConsumer);
+    // Create mocked system factories
+    Map<String, SystemFactory> systemFactories = new HashMap<>();
 
+    // Count the number of sysConsumers created
+    SystemFactory mockSystemFactory = Mockito.mock(SystemFactory.class);
+    Mockito.doAnswer(invocation -> {
+        this.systemConsumerCreationCount++;
+        return mockSystemConsumer;
+      }).when(mockSystemFactory).getConsumer(Mockito.anyString(), Mockito.any(), Mockito.any());
+
+    systemFactories.put(SYSTEM_NAME, mockSystemFactory);
+
+    // Create mocked configs for specifying serdes
+    Map<String, String> configMap = new HashMap<>();
+    configMap.put("stores." + STORE_NAME + ".key.serde", "stringserde");
+    configMap.put("stores." + STORE_NAME + ".msg.serde", "stringserde");
+    configMap.put("serializers.registry.stringserde.class", StringSerdeFactory.class.getName());
+    Config config = new MapConfig(configMap);
+
+    Map<String, Serde<Object>> serdes = new HashMap<>();
+    serdes.put("stringserde", Mockito.mock(Serde.class));
+
+    // Create mocked system admins
+    SystemAdmin mockSystemAdmin = Mockito.mock(SystemAdmin.class);
+    Mockito.doAnswer(new Answer<Void>() {
+        public Void answer(InvocationOnMock invocation) {
+          Object[] args = invocation.getArguments();
+          System.out.println("called with arguments: " + Arrays.toString(args));
+          return null;
+        }
+      }).when(mockSystemAdmin).validateStream(Mockito.any());
+    SystemAdmins mockSystemAdmins = Mockito.mock(SystemAdmins.class);
+    Mockito.when(mockSystemAdmins.getSystemAdmin("kafka")).thenReturn(mockSystemAdmin);
+
+    // Create a mocked mockStreamMetadataCache
+    SystemStreamMetadata.SystemStreamPartitionMetadata sspMetadata =
+        new SystemStreamMetadata.SystemStreamPartitionMetadata("0", "50", "51");
+    Map<Partition, SystemStreamMetadata.SystemStreamPartitionMetadata> partitionMetadata = new HashMap<>();
+    partitionMetadata.put(new Partition(0), sspMetadata);
+    partitionMetadata.put(new Partition(1), sspMetadata);
+    SystemStreamMetadata systemStreamMetadata = new SystemStreamMetadata(STREAM_NAME, partitionMetadata);
+    StreamMetadataCache mockStreamMetadataCache = Mockito.mock(StreamMetadataCache.class);
+
+    Mockito.when(mockStreamMetadataCache.
+        getStreamMetadata(JavaConverters.
+            asScalaSetConverter(new HashSet<SystemStream>(changelogSystemStreams.values())).asScala().toSet(), false))
+        .thenReturn(
+            new scala.collection.immutable.Map.Map1(new SystemStream(SYSTEM_NAME, STREAM_NAME), systemStreamMetadata));
+
+    // Reset the  expected number of sysConsumer create, start and stop calls, and store.restore() calls
+    this.systemConsumerCreationCount = 0;
+    this.systemConsumerStartCount = 0;
+    this.systemConsumerStopCount = 0;
+    this.storeRestoreCallCount = 0;
+
+    // Create the container storage manager
     this.containerStorageManager =
-        new ContainerStorageManager(taskStorageManagers, systemConsumers, samzaContainerMetrics);
+        new ContainerStorageManager(new ContainerModel("samza-container-test", tasks), mockStreamMetadataCache,
+            mockSystemAdmins, changelogSystemStreams, storageEngineFactories, systemFactories, serdes, config,
+            taskInstanceMetrics, samzaContainerMetrics, Mockito.mock(JobContext.class),
+            Mockito.mock(ContainerContext.class), Mockito.mock(Map.class), DEFAULT_LOGGED_STORE_BASE_DIR,
+            DEFAULT_STORE_BASE_DIR, 2, new SystemClock());
   }
 
   @Test
   public void testParallelismAndMetrics() {
     this.containerStorageManager.start();
     this.containerStorageManager.shutdown();
-    Assert.assertTrue("init count should be 0", this.taskStorageManagersInitCount.getCount() == 0);
-    Assert.assertTrue("Restore count should be 0", this.taskStorageManagersRestoreStoreCount.getCount() == 0);
-    Assert.assertTrue("stop count should be 0", this.taskStorageManagersRestoreStopCount.getCount() == 0);
-
-    Assert.assertTrue("systemConsumerStopCount count should be 0", this.systemConsumerStopCount.getCount() == 0);
-    Assert.assertTrue("systemConsumerStartCount count should be 0", this.systemConsumerStartCount.getCount() == 0);
 
     for (Gauge gauge : taskRestoreMetricGauges.values()) {
-      Assert.assertTrue("Restoration time gauge value should be invoked atleast once", Mockito.mockingDetails(gauge).getInvocations().size() >= 1);
+      Assert.assertTrue("Restoration time gauge value should be invoked atleast once",
+          Mockito.mockingDetails(gauge).getInvocations().size() >= 1);
     }
-  }
 
+    Assert.assertTrue("Store restore count should be 2 because there are 2 tasks", this.storeRestoreCallCount == 2);
+    Assert.assertTrue("systemConsumerCreation count should be 1 (1 consumer per system)",
+        this.systemConsumerCreationCount == 1);
+    Assert.assertTrue("systemConsumerStopCount count should be 1", this.systemConsumerStopCount == 1);
+    Assert.assertTrue("systemConsumerStartCount count should be 1", this.systemConsumerStartCount == 1);
+  }
 }
index ffdceca..3da564e 100644 (file)
@@ -22,23 +22,31 @@ package org.apache.samza.storage
 
 import java.io.{File, FileOutputStream, ObjectOutputStream}
 import java.util
+import java.util.Optional
 
 import org.apache.samza.Partition
-import org.apache.samza.config.{MapConfig, StorageConfig}
-import org.apache.samza.container.TaskName
+import org.apache.samza.config._
+import org.apache.samza.container.{SamzaContainerMetrics, TaskInstanceMetrics, TaskName}
+import org.apache.samza.context.{ContainerContext, JobContext}
+import org.apache.samza.job.model.{ContainerModel, TaskModel}
+import org.apache.samza.serializers.{Serde, StringSerdeFactory}
 import org.apache.samza.storage.StoreProperties.StorePropertiesBuilder
 import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
 import org.apache.samza.system._
+import org.apache.samza.task.TaskInstanceCollector
 import org.apache.samza.util.{FileUtil, SystemClock}
 import org.junit.Assert._
 import org.junit.{After, Before, Test}
 import org.mockito.Matchers._
+import org.mockito.Mockito
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.scalatest.mockito.MockitoSugar
 
 import scala.collection.JavaConverters._
+import scala.collection.immutable.HashMap
+import scala.collection.mutable
 
 class TestTaskStorageManager extends MockitoSugar {
 
@@ -48,10 +56,10 @@ class TestTaskStorageManager extends MockitoSugar {
 
   @Before
   def setupTestDirs() {
-    TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultStoreBaseDir, store , taskName)
-                      .mkdirs()
-    TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
-                      .mkdirs()
+    StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultStoreBaseDir, store, taskName)
+      .mkdirs()
+    StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
+      .mkdirs()
   }
 
   @After
@@ -61,17 +69,17 @@ class TestTaskStorageManager extends MockitoSugar {
   }
 
   /**
-   * This tests the entire TaskStorageManager lifecycle for a Persisted Logged Store
-   * For example, a RocksDb store with changelog needs to continuously update the offset file on flush & stop
-   * When the task is restarted, it should restore correctly from the offset in the OFFSET file on disk (if available)
-   */
+    * This tests the entire TaskStorageManager lifecycle for a Persisted Logged Store
+    * For example, a RocksDb store with changelog needs to continuously update the offset file on flush & stop
+    * When the task is restarted, it should restore correctly from the offset in the OFFSET file on disk (if available)
+    */
   @Test
   def testStoreLifecycleForLoggedPersistedStore(): Unit = {
     // Basic test setup of SystemStream, SystemStreamPartition for this task
     val ss = new SystemStream("kafka", "testStream")
     val partition = new Partition(0)
     val ssp = new SystemStreamPartition(ss, partition)
-    val storeDirectory = TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
+    val storeDirectory = StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
     val storeFile = new File(storeDirectory, "store.sst")
     val offsetFile = new File(storeDirectory, "OFFSET")
 
@@ -97,18 +105,16 @@ class TestTaskStorageManager extends MockitoSugar {
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
     when(mockSSPMetadataCache.getMetadata(ssp)).thenReturn(sspMetadata)
 
-    val taskManager = new TaskStorageManagerBuilder()
+    var taskManager = new TaskStorageManagerBuilder()
       .addStore(loggedStore, mockStorageEngine, mockSystemConsumer)
       .setStreamMetadataCache(mockStreamMetadataCache)
       .setSSPMetadataCache(mockSSPMetadataCache)
       .setSystemAdmin("kafka", mockSystemAdmin)
+      .initializeContainerStorageManager()
       .build
 
     taskManager.init
 
-    // mocking restore (issued by ContainerStorageManager)
-    mockStorageEngine.restore(mock[util.Iterator[IncomingMessageEnvelope]])
-
     assertTrue(storeFile.exists())
     assertFalse(offsetFile.exists())
     verify(mockSystemConsumer).register(ssp, "0")
@@ -135,6 +141,15 @@ class TestTaskStorageManager extends MockitoSugar {
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
     when(mockSSPMetadataCache.getMetadata(ssp)).thenReturn(sspMetadata)
     when(mockSystemAdmin.getOffsetsAfter(Map(ssp -> "100").asJava)).thenReturn(Map(ssp -> "101").asJava)
+    Mockito.reset(mockSystemConsumer)
+
+    taskManager = new TaskStorageManagerBuilder()
+      .addStore(loggedStore, mockStorageEngine, mockSystemConsumer)
+      .setStreamMetadataCache(mockStreamMetadataCache)
+      .setSSPMetadataCache(mockSSPMetadataCache)
+      .setSystemAdmin("kafka", mockSystemAdmin)
+        .initializeContainerStorageManager()
+      .build
 
     taskManager.init
 
@@ -144,17 +159,17 @@ class TestTaskStorageManager extends MockitoSugar {
   }
 
   /**
-   * This tests the entire TaskStorageManager lifecycle for an InMemory Logged Store
-   * For example, an InMemory KV store with changelog should not update the offset file on flush & stop
-   * When the task is restarted, it should ALWAYS restore correctly from the earliest offset
-   */
+    * This tests the entire TaskStorageManager lifecycle for an InMemory Logged Store
+    * For example, an InMemory KV store with changelog should not update the offset file on flush & stop
+    * When the task is restarted, it should ALWAYS restore correctly from the earliest offset
+    */
   @Test
   def testStoreLifecycleForLoggedInMemoryStore(): Unit = {
     // Basic test setup of SystemStream, SystemStreamPartition for this task
     val ss = new SystemStream("kafka", "testStream")
     val partition = new Partition(0)
     val ssp = new SystemStreamPartition(ss, partition)
-    val storeDirectory = TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, store, taskName)
+    val storeDirectory = StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, store, taskName)
 
     val mockStorageEngine: StorageEngine = createMockStorageEngine(isLoggedStore = true, isPersistedStore = false, null)
 
@@ -176,10 +191,11 @@ class TestTaskStorageManager extends MockitoSugar {
       }
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
-    val taskManager = new TaskStorageManagerBuilder()
+    var taskManager = new TaskStorageManagerBuilder()
       .addStore(store, mockStorageEngine, mockSystemConsumer)
       .setStreamMetadataCache(mockStreamMetadataCache)
       .setSystemAdmin("kafka", mockSystemAdmin)
+      .initializeContainerStorageManager()
       .build
 
     taskManager.init
@@ -210,6 +226,13 @@ class TestTaskStorageManager extends MockitoSugar {
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
 
+    taskManager = new TaskStorageManagerBuilder()
+      .addStore(store, mockStorageEngine, mockSystemConsumer)
+      .setStreamMetadataCache(mockStreamMetadataCache)
+      .setSystemAdmin("kafka", mockSystemAdmin)
+      .initializeContainerStorageManager()
+      .build
+
     taskManager.init
 
     assertTrue(storeDirectory.list().isEmpty)
@@ -219,83 +242,68 @@ class TestTaskStorageManager extends MockitoSugar {
 
   @Test
   def testStoreDirsWithoutOffsetFileAreDeletedInCleanBaseDirs() {
-    val checkFilePath1 = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultStoreBaseDir, store, taskName), "check")
+    val checkFilePath1 = new File(StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultStoreBaseDir, store, taskName), "check")
     checkFilePath1.createNewFile()
-    val checkFilePath2 = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName), "check")
+    val checkFilePath2 = new File(StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName), "check")
     checkFilePath2.createNewFile()
 
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addStore(store, false)
       .addLoggedStore(loggedStore, true)
+      .setStreamMetadataCache(createMockStreamMetadataCache(null, null, null)) //empty store
+      .initializeContainerStorageManager()
       .build
 
-    //Invoke test method
-    val cleanDirMethod = taskStorageManager
-                          .getClass
-                          .getDeclaredMethod("cleanBaseDirs",
-                                             new Array[java.lang.Class[_]](0):_*)
-    cleanDirMethod.setAccessible(true)
-    cleanDirMethod.invoke(taskStorageManager, new Array[Object](0):_*)
-
     assertTrue("check file was found in store partition directory. Clean up failed!", !checkFilePath1.exists())
     assertTrue("check file was found in logged store partition directory. Clean up failed!", !checkFilePath2.exists())
   }
 
   @Test
   def testLoggedStoreDirsWithOffsetFileAreNotDeletedInCleanBaseDirs() {
-    val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName), "OFFSET")
+    val offsetFilePath = new File(StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName), "OFFSET")
     FileUtil.writeWithChecksum(offsetFilePath, "100")
 
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addLoggedStore(loggedStore, true)
+      .setStreamMetadataCache(createMockStreamMetadataCache(null, null, null)) // empty store
+      .initializeContainerStorageManager()
       .build
 
-    val cleanDirMethod = taskStorageManager.getClass.getDeclaredMethod("cleanBaseDirs",
-      new Array[java.lang.Class[_]](0):_*)
-    cleanDirMethod.setAccessible(true)
-    cleanDirMethod.invoke(taskStorageManager, new Array[Object](0):_*)
-
     assertTrue("Offset file was removed. Clean up failed!", offsetFilePath.exists())
-    assertEquals("Offset read does not match what was in the file", "100", taskStorageManager.fileOffsets.get(new SystemStreamPartition("kafka", "testStream", new Partition(0))))
   }
 
   @Test
   def testStoreDeletedWhenOffsetFileOlderThanDeleteRetention() {
     // This test ensures that store gets deleted when lastModifiedTime of the offset file
     // is older than deletionRetention of the changeLog.
-    val storeDirectory = TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
+    val storeDirectory = StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
+    storeDirectory.setLastModified(0)
     val offsetFile = new File(storeDirectory, "OFFSET")
     offsetFile.createNewFile()
     FileUtil.writeWithChecksum(offsetFile, "Test Offset Data")
     offsetFile.setLastModified(0)
+
     val taskStorageManager = new TaskStorageManagerBuilder().addStore(store, false)
       .addLoggedStore(loggedStore, true)
+      .setStreamMetadataCache(createMockStreamMetadataCache("0", "1", "2"))
+      .initializeContainerStorageManager()
       .build
 
-    val cleanDirMethod = taskStorageManager.getClass
-      .getDeclaredMethod("cleanBaseDirs",
-        new Array[java.lang.Class[_]](0):_*)
-    cleanDirMethod.setAccessible(true)
-    cleanDirMethod.invoke(taskStorageManager, new Array[Object](0):_*)
-
     assertTrue("Offset file was found in store partition directory. Clean up failed!", !offsetFile.exists())
-    assertTrue("Store directory exists. Clean up failed!", !storeDirectory.exists())
+    assertTrue("Store directory should be deleted and re-created with new last modified time", storeDirectory.lastModified() > 0)
   }
 
   @Test
   def testOffsetFileIsRemovedInCleanBaseDirsForInMemoryLoggedStore() {
-    val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName), "OFFSET")
+    val offsetFilePath = new File(StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName), "OFFSET")
     FileUtil.writeWithChecksum(offsetFilePath, "100")
 
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addLoggedStore(loggedStore, false)
+      .setStreamMetadataCache(createMockStreamMetadataCache(null, null, null)) // empty store
+      .initializeContainerStorageManager()
       .build
 
-    val cleanDirMethod = taskStorageManager.getClass.getDeclaredMethod("cleanBaseDirs",
-      new Array[java.lang.Class[_]](0):_*)
-    cleanDirMethod.setAccessible(true)
-    cleanDirMethod.invoke(taskStorageManager, new Array[Object](0):_*)
-
     assertFalse("Offset file was not removed. Clean up failed!", offsetFilePath.exists())
   }
 
@@ -303,26 +311,38 @@ class TestTaskStorageManager extends MockitoSugar {
   def testStopCreatesOffsetFileForLoggedStore() {
     val partition = new Partition(0)
 
-    val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
+    val storeDirectory = StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
+    val offsetFile = new File(storeDirectory, "OFFSET")
 
     val sspMetadataCache = mock[SSPMetadataCache]
     val sspMetadata = new SystemStreamPartitionMetadata("20", "100", "101")
     when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", "testStream", partition)))
       .thenReturn(sspMetadata)
 
+    var metadata = new SystemStreamMetadata("testStream", new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
+      {
+        put(partition, sspMetadata)
+      }
+    })
+
+    val mockStreamMetadataCache = mock[StreamMetadataCache]
+    when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(new SystemStream("kafka", "testStream") -> metadata))
+
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
-      .addStore(loggedStore, true)
+      .addLoggedStore(loggedStore, true)
+      .setStreamMetadataCache(mockStreamMetadataCache)
       .setSSPMetadataCache(sspMetadataCache)
       .setPartition(partition)
+      .initializeContainerStorageManager()
       .build
 
     //Invoke test method
     taskStorageManager.stop()
 
     //Check conditions
-    assertTrue("Offset file doesn't exist!", offsetFilePath.exists())
-    assertEquals("Found incorrect value in offset file!", "100", FileUtil.readWithChecksum(offsetFilePath))
+    assertTrue("Offset file doesn't exist!", offsetFile.exists())
+    assertEquals("Found incorrect value in offset file!", "100", FileUtil.readWithChecksum(offsetFile))
   }
 
   /**
@@ -332,9 +352,9 @@ class TestTaskStorageManager extends MockitoSugar {
   def testFlushCreatesOffsetFileForLoggedStore() {
     val partition = new Partition(0)
 
-    val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
+    val offsetFilePath = new File(StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
     val anotherOffsetPath = new File(
-      TaskStorageManager.getStorePartitionDir(
+      StorageManagerUtil.getStorePartitionDir(
         TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, store, taskName) + File.separator + "OFFSET")
 
     val sspMetadataCache = mock[SSPMetadataCache]
@@ -342,13 +362,24 @@ class TestTaskStorageManager extends MockitoSugar {
     when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", "testStream", partition)))
       .thenReturn(sspMetadata)
 
+    var metadata = new SystemStreamMetadata("testStream", new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
+      {
+        put(partition, sspMetadata)
+      }
+    })
+
+    val mockStreamMetadataCache = mock[StreamMetadataCache]
+    when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(new SystemStream("kafka", "testStream") -> metadata))
+
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
-            .addStore(loggedStore, true)
-            .addStore(store, false)
-            .setSSPMetadataCache(sspMetadataCache)
-            .setPartition(partition)
-            .build
+      .addLoggedStore(loggedStore, true)
+      .addStore(store, false)
+      .setSSPMetadataCache(sspMetadataCache)
+      .setStreamMetadataCache(mockStreamMetadataCache)
+      .setPartition(partition)
+      .initializeContainerStorageManager()
+      .build
 
     //Invoke test method
     taskStorageManager.flush()
@@ -367,21 +398,33 @@ class TestTaskStorageManager extends MockitoSugar {
   def testFlushDeletesOffsetFileForLoggedStoreForEmptyPartition() {
     val partition = new Partition(0)
 
-    val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
+    val offsetFilePath = new File(StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
 
     val sspMetadataCache = mock[SSPMetadataCache]
+    val sspMetadata = new SystemStreamPartitionMetadata("0", "100", "101")
     when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", "testStream", partition)))
       // first return some metadata
-      .thenReturn(new SystemStreamPartitionMetadata("0", "100", "101"))
+      .thenReturn(sspMetadata)
       // then return no metadata to trigger the delete
       .thenReturn(null)
 
+    var metadata = new SystemStreamMetadata("testStream", new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
+      {
+        put(partition, sspMetadata)
+      }
+    })
+
+    val mockStreamMetadataCache = mock[StreamMetadataCache]
+    when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(new SystemStream("kafka", "testStream") -> metadata))
+
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
-            .addStore(loggedStore, true)
-            .setSSPMetadataCache(sspMetadataCache)
-            .setPartition(partition)
-            .build
+      .addLoggedStore(loggedStore, true)
+      .setSSPMetadataCache(sspMetadataCache)
+      .setStreamMetadataCache(mockStreamMetadataCache)
+      .setPartition(partition)
+      .initializeContainerStorageManager()
+      .build
 
     //Invoke test method
     taskStorageManager.flush()
@@ -402,18 +445,30 @@ class TestTaskStorageManager extends MockitoSugar {
     val partition = new Partition(0)
     val ssp = new SystemStreamPartition("kafka", "testStream", partition)
 
-    val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
+    val offsetFilePath = new File(StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
     FileUtil.writeWithChecksum(offsetFilePath, "100")
 
     val sspMetadataCache = mock[SSPMetadataCache]
-    when(sspMetadataCache.getMetadata(ssp)).thenReturn(new SystemStreamPartitionMetadata("20", "139", "140"))
+    val sspMetadata = new SystemStreamPartitionMetadata("20", "139", "140")
+    when(sspMetadataCache.getMetadata(ssp)).thenReturn(sspMetadata)
+
+    var metadata = new SystemStreamMetadata("testStream", new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
+      {
+        put(partition, sspMetadata)
+      }
+    })
+
+    val mockStreamMetadataCache = mock[StreamMetadataCache]
+    when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(new SystemStream("kafka", "testStream") -> metadata))
 
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
-            .addStore(loggedStore, true)
-            .setSSPMetadataCache(sspMetadataCache)
-            .setPartition(partition)
-            .build
+      .addLoggedStore(loggedStore, true)
+      .setSSPMetadataCache(sspMetadataCache)
+      .setPartition(partition)
+      .setStreamMetadataCache(mockStreamMetadataCache)
+      .initializeContainerStorageManager()
+      .build
 
     //Invoke test method
     taskStorageManager.flush()
@@ -437,16 +492,19 @@ class TestTaskStorageManager extends MockitoSugar {
   def testStopShouldNotCreateOffsetFileForEmptyStore() {
     val partition = new Partition(0)
 
-    val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
+    val offsetFilePath = new File(StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
+
 
     val sspMetadataCache = mock[SSPMetadataCache]
     when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", "testStream", partition))).thenReturn(null)
 
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
-      .addStore(loggedStore, true)
+      .addLoggedStore(loggedStore, true)
       .setSSPMetadataCache(sspMetadataCache)
       .setPartition(partition)
+      .setStreamMetadataCache(createMockStreamMetadataCache(null, null, null)) // null offsets for empty store
+      .initializeContainerStorageManager()
       .build
 
     //Invoke test method
@@ -517,7 +575,7 @@ class TestTaskStorageManager extends MockitoSugar {
     val ss = new SystemStream(systemName, streamName)
     val partition = new Partition(0)
     val ssp = new SystemStreamPartition(ss, partition)
-    val storeDirectory = TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
+    val storeDirectory = StorageManagerUtil.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName)
     val storeFile = new File(storeDirectory, "store.sst")
 
     if (writeOffsetFile) {
@@ -589,6 +647,7 @@ class TestTaskStorageManager extends MockitoSugar {
       .addStore(loggedStore, mockStorageEngine, mockSystemConsumer)
       .setStreamMetadataCache(mockStreamMetadataCache)
       .setSystemAdmin(systemName, mockSystemAdmin)
+      .initializeContainerStorageManager()
       .build
 
     taskManager.init
@@ -596,6 +655,19 @@ class TestTaskStorageManager extends MockitoSugar {
     verify(mockSystemConsumer).register(any(classOf[SystemStreamPartition]), anyString())
   }
 
+  private def createMockStreamMetadataCache(oldestOffset: String, newestOffset: String, upcomingOffset: String) = {
+    // an empty store would return a SSPMetadata with oldest, newest and upcoming offset set to null
+    var metadata = new SystemStreamMetadata("testStream", new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
+      {
+        put(new Partition(0), new SystemStreamPartitionMetadata(oldestOffset, newestOffset, upcomingOffset))
+      }
+    })
+
+    val mockStreamMetadataCache = mock[StreamMetadataCache]
+    when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(new SystemStream("kafka", "testStream") -> metadata))
+    mockStreamMetadataCache
+  }
+
   private def createMockStorageEngine(isLoggedStore: Boolean, isPersistedStore: Boolean, storeFile: File) = {
     val mockStorageEngine = mock[StorageEngine]
     // getStoreProperties should always return the same StoreProperties
@@ -619,7 +691,7 @@ class TestTaskStorageManager extends MockitoSugar {
 }
 
 object TaskStorageManagerBuilder {
-  val defaultStoreBaseDir =  new File(System.getProperty("java.io.tmpdir") + File.separator + "store")
+  val defaultStoreBaseDir = new File(System.getProperty("java.io.tmpdir") + File.separator + "store")
   val defaultLoggedStoreBaseDir = new File(System.getProperty("java.io.tmpdir") + File.separator + "loggedStore")
 }
 
@@ -633,12 +705,13 @@ class TaskStorageManagerBuilder extends MockitoSugar {
   var systemAdminsMap: Map[String, SystemAdmin] = Map("kafka" -> mock[SystemAdmin])
   var taskName: TaskName = new TaskName("testTask")
   var storeBaseDir: File = TaskStorageManagerBuilder.defaultStoreBaseDir
-  var loggedStoreBaseDir: File =  TaskStorageManagerBuilder.defaultLoggedStoreBaseDir
+  var loggedStoreBaseDir: File = TaskStorageManagerBuilder.defaultLoggedStoreBaseDir
   var changeLogStreamPartitions: Int = 1
+  var containerStorageManager: ContainerStorageManager = mock[ContainerStorageManager]
 
   def addStore(storeName: String, storageEngine: StorageEngine, systemConsumer: SystemConsumer): TaskStorageManagerBuilder = {
     taskStores = taskStores ++ Map(storeName -> storageEngine)
-    storeConsumers = storeConsumers ++ Map(storeName -> systemConsumer)
+    storeConsumers = storeConsumers ++ Map("kafka" -> systemConsumer)
     changeLogSystemStreams = changeLogSystemStreams ++ Map(storeName -> new SystemStream("kafka", "testStream"))
     this
   }
@@ -653,7 +726,7 @@ class TaskStorageManagerBuilder extends MockitoSugar {
   def addLoggedStore(storeName: String, isPersistedToDisk: Boolean): TaskStorageManagerBuilder = {
     val mockStorageEngine = mock[StorageEngine]
     when(mockStorageEngine.getStoreProperties)
-    .thenReturn(new StorePropertiesBuilder().setPersistedToDisk(isPersistedToDisk).setLoggedStore(true).build())
+      .thenReturn(new StorePropertiesBuilder().setPersistedToDisk(isPersistedToDisk).setLoggedStore(true).build())
     addStore(storeName, mockStorageEngine, mock[SystemConsumer])
   }
 
@@ -687,21 +760,69 @@ class TaskStorageManagerBuilder extends MockitoSugar {
     this
   }
 
+  /**
+    * This method creates and starts a {@link ContainerStorageManager}
+    */
+  def initializeContainerStorageManager() = {
+    var tasks: Map[TaskName, TaskModel] = HashMap[TaskName, TaskModel]((taskName, new TaskModel(taskName, new util.HashSet[SystemStreamPartition], new Partition(0))))
+    var containerModel = new ContainerModel("container", tasks.asJava)
+
+    val mockSystemAdmins = Mockito.mock(classOf[SystemAdmins])
+    Mockito.when(mockSystemAdmins.getSystemAdmin(org.mockito.Matchers.eq("kafka"))).thenReturn(systemAdminsMap.get("kafka").get)
+
+    var mockStorageEngineFactory : StorageEngineFactory[AnyRef, AnyRef] = Mockito.mock(classOf[StorageEngineFactory[AnyRef, AnyRef]])
+
+    var storageEngineFactories : mutable.Map[String, StorageEngineFactory[AnyRef, AnyRef]] =  scala.collection.mutable.Map[String, StorageEngineFactory[AnyRef, AnyRef]]()
+
+    if(taskStores.contains("store1")) {
+      Mockito.when(mockStorageEngineFactory.getStorageEngine(org.mockito.Matchers.eq("store1"), any(), any(), any(), any(), any(), any(), any(), any(), any()))
+        .thenReturn(taskStores.get("store1").get)
+      storageEngineFactories += ("store1" -> mockStorageEngineFactory)
+    }
+
+    if(taskStores.contains("loggedStore1")) {
+      Mockito.when(mockStorageEngineFactory.getStorageEngine(org.mockito.Matchers.eq("loggedStore1"), any(), any(), any(), any(), any(), any(), any(), any(), any()))
+        .thenReturn(taskStores.get("loggedStore1").get)
+      storageEngineFactories += ("loggedStore1" -> mockStorageEngineFactory)
+    }
+
+
+    var mockSystemFactory = Mockito.mock(classOf[SystemFactory])
+    Mockito.when(mockSystemFactory.getConsumer(org.mockito.Matchers.eq("kafka"),any(), any())).thenReturn(storeConsumers.get("kafka").get)
+    var systemFactories : Map[String, SystemFactory] = HashMap[String, SystemFactory](("kafka", mockSystemFactory))
+
+    var config =  new MapConfig(mutable.Map(
+      "stores.store1.key.serde" -> classOf[StringSerdeFactory].getCanonicalName,
+      "stores.store1.msg.serde" -> classOf[StringSerdeFactory].getCanonicalName,
+      "stores.loggedStore1.key.serde" -> classOf[StringSerdeFactory].getCanonicalName,
+      "stores.loggedStore1.msg.serde" -> classOf[StringSerdeFactory].getCanonicalName).asJava)
+
+    var mockSerdes: Map[String, Serde[AnyRef]] = HashMap[String, Serde[AnyRef]]((classOf[StringSerdeFactory].getCanonicalName, Mockito.mock(classOf[Serde[AnyRef]])))
+
+
+    containerStorageManager = new ContainerStorageManager(containerModel, streamMetadataCache, mockSystemAdmins,
+      changeLogSystemStreams.asJava, storageEngineFactories.asJava, systemFactories.asJava, mockSerdes.asJava, config,
+      new HashMap[TaskName, TaskInstanceMetrics]().asJava, Mockito.mock(classOf[SamzaContainerMetrics]), Mockito.mock(classOf[JobContext]),
+      Mockito.mock(classOf[ContainerContext]), new HashMap[TaskName, TaskInstanceCollector].asJava, loggedStoreBaseDir, TaskStorageManagerBuilder.defaultStoreBaseDir, 1,
+      new SystemClock)
+    this
+  }
+
+
+
   def build: TaskStorageManager = {
+
+    if (containerStorageManager != null) {
+      containerStorageManager.start()
+    }
+
     new TaskStorageManager(
       taskName = taskName,
-      taskStores = taskStores,
-      storeConsumers = storeConsumers,
+      containerStorageManager = containerStorageManager,
       changeLogSystemStreams = changeLogSystemStreams,
-      changeLogStreamPartitions = changeLogStreamPartitions,
-      streamMetadataCache = streamMetadataCache,
       sspMetadataCache = sspMetadataCache,
-      nonLoggedStoreBaseDir = storeBaseDir,
       loggedStoreBaseDir = loggedStoreBaseDir,
-      partition = partition,
-      systemAdmins = buildSystemAdmins(systemAdminsMap),
-      new StorageConfig(new MapConfig()).getChangeLogDeleteRetentionsInMs,
-      SystemClock.instance
+      partition = partition
     )
   }
 
index e30328a..9a58760 100644 (file)
@@ -23,6 +23,7 @@ import java.io.File
 
 import org.apache.samza.context.{ContainerContext, JobContext}
 import org.apache.samza.metrics.MetricsRegistry
+import org.apache.samza.storage.StorageEngineFactory.StoreMode
 import org.apache.samza.storage.kv.{BaseKeyValueStorageEngineFactory, KeyValueStore, KeyValueStoreMetrics}
 import org.apache.samza.system.SystemStreamPartition
 
@@ -33,7 +34,7 @@ class InMemoryKeyValueStorageEngineFactory[K, V] extends BaseKeyValueStorageEngi
     registry: MetricsRegistry,
     changeLogSystemStreamPartition: SystemStreamPartition,
     jobContext: JobContext,
-    containerContext: ContainerContext): KeyValueStore[Array[Byte], Array[Byte]] = {
+    containerContext: ContainerContext, storeMode: StoreMode): KeyValueStore[Array[Byte], Array[Byte]] = {
     val metrics = new KeyValueStoreMetrics(storeName, registry)
     val inMemoryDb = new InMemoryKeyValueStore (metrics)
     inMemoryDb
index 0734fe6..ce9f974 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.samza.config.JavaStorageConfig;
 import org.apache.samza.config.SerializerConfig$;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.SerdeFactory;
+import org.apache.samza.storage.StorageEngineFactory;
 import org.apache.samza.util.Util;
 import org.rocksdb.Options;
 import org.rocksdb.RocksDB;
@@ -60,7 +61,7 @@ public class RocksDbKeyValueReader {
     valueSerde = getSerdeFromName(storageConfig.getStorageMsgSerde(storeName), serializerConfig);
 
     // get db options
-    Options options = RocksDbOptionsHelper.options(config, 1);
+    Options options = RocksDbOptionsHelper.options(config, 1, StorageEngineFactory.StoreMode.ReadWrite);
 
     // open the db
     RocksDB.loadLibrary();
index 7beb066..f9b8ea9 100644 (file)
@@ -22,6 +22,7 @@ package org.apache.samza.storage.kv;
 import org.apache.samza.config.Config;
 import org.apache.samza.context.ContainerContext;
 import org.apache.samza.context.JobContext;
+import org.apache.samza.storage.StorageEngineFactory;
 import org.rocksdb.BlockBasedTableConfig;
 import org.rocksdb.CompactionStyle;
 import org.rocksdb.CompressionType;
@@ -42,7 +43,7 @@ public class RocksDbOptionsHelper {
   private static final String ROCKSDB_MAX_LOG_FILE_SIZE_BYTES = "rocksdb.max.log.file.size.bytes";
   private static final String ROCKSDB_KEEP_LOG_FILE_NUM = "rocksdb.keep.log.file.num";
 
-  public static Options options(Config storeConfig, int numTasksForContainer) {
+  public static Options options(Config storeConfig, int numTasksForContainer, StorageEngineFactory.StoreMode storeMode) {
     Options options = new Options();
     Long writeBufSize = storeConfig.getLong("container.write.buffer.size.bytes", 32 * 1024 * 1024);
     // Cache size and write buffer size are specified on a per-container basis.
@@ -106,6 +107,10 @@ public class RocksDbOptionsHelper {
     options.setMaxLogFileSize(storeConfig.getLong(ROCKSDB_MAX_LOG_FILE_SIZE_BYTES, 64 * 1024 * 1024L));
     options.setKeepLogFileNum(storeConfig.getLong(ROCKSDB_KEEP_LOG_FILE_NUM, 2));
 
+    if(storeMode.equals(StorageEngineFactory.StoreMode.BulkLoad)) {
+      options.prepareForBulkLoad();
+    }
+
     return options;
   }
 
index 704af4a..31d3221 100644 (file)
@@ -24,6 +24,7 @@ import java.io.File
 import org.apache.samza.config.StorageConfig._
 import org.apache.samza.context.{ContainerContext, JobContext}
 import org.apache.samza.metrics.MetricsRegistry
+import org.apache.samza.storage.StorageEngineFactory.StoreMode
 import org.apache.samza.system.SystemStreamPartition
 import org.rocksdb.{FlushOptions, WriteOptions}
 
@@ -42,7 +43,7 @@ class RocksDbKeyValueStorageEngineFactory [K, V] extends BaseKeyValueStorageEngi
     registry: MetricsRegistry,
     changeLogSystemStreamPartition: SystemStreamPartition,
     jobContext: JobContext,
-    containerContext: ContainerContext): KeyValueStore[Array[Byte], Array[Byte]] = {
+    containerContext: ContainerContext, storeMode: StoreMode): KeyValueStore[Array[Byte], Array[Byte]] = {
     val storageConfig = jobContext.getConfig.subset("stores." + storeName + ".", true)
     val isLoggedStore = jobContext.getConfig.getChangelogStream(storeName).isDefined
     val rocksDbMetrics = new KeyValueStoreMetrics(storeName, registry)
@@ -50,7 +51,7 @@ class RocksDbKeyValueStorageEngineFactory [K, V] extends BaseKeyValueStorageEngi
     rocksDbMetrics.newGauge("rocksdb.block-cache-size",
       () => RocksDbOptionsHelper.getBlockCacheSize(storageConfig, numTasksForContainer))
 
-    val rocksDbOptions = RocksDbOptionsHelper.options(storageConfig, numTasksForContainer)
+    val rocksDbOptions = RocksDbOptionsHelper.options(storageConfig, numTasksForContainer, storeMode)
     val rocksDbWriteOptions = new WriteOptions().setDisableWAL(true)
     val rocksDbFlushOptions = new FlushOptions().setWaitForFlush(true)
     val rocksDb = new RocksDbKeyValueStore(
index c5a89d9..e4f78d3 100644 (file)
@@ -237,7 +237,15 @@ class RocksDbKeyValueStore(
   }
 
   def close(): Unit = {
+    trace("Calling compact range.")
     stateChangeLock.writeLock().lock()
+
+    // if auto-compaction is disabled, e.g., when bulk-loading
+    if(options.disableAutoCompactions()) {
+      trace("Auto compaction is disabled, invoking compact range.")
+      db.compactRange()
+    }
+
     try {
       trace("Closing.")
       if (stackAtFirstClose == null) { // first close
index e1e7642..6f1e0f6 100644 (file)
@@ -26,6 +26,7 @@ import org.apache.samza.config.MetricsConfig.Config2Metrics
 import org.apache.samza.context.{ContainerContext, JobContext}
 import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.serializers.Serde
+import org.apache.samza.storage.StorageEngineFactory.StoreMode
 import org.apache.samza.storage.{StorageEngine, StorageEngineFactory, StoreProperties}
 import org.apache.samza.system.SystemStreamPartition
 import org.apache.samza.task.MessageCollector
@@ -57,7 +58,7 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
     registry: MetricsRegistry,
     changeLogSystemStreamPartition: SystemStreamPartition,
     jobContext: JobContext,
-    containerContext: ContainerContext): KeyValueStore[Array[Byte], Array[Byte]]
+    containerContext: ContainerContext, storeMode: StoreMode): KeyValueStore[Array[Byte], Array[Byte]]
 
   /**
    * Constructs a key-value StorageEngine and returns it to the caller
@@ -79,7 +80,7 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
     registry: MetricsRegistry,
     changeLogSystemStreamPartition: SystemStreamPartition,
     jobContext: JobContext,
-    containerContext: ContainerContext): StorageEngine = {
+    containerContext: ContainerContext, storeMode : StoreMode): StorageEngine = {
     val storageConfig = jobContext.getConfig.subset("stores." + storeName + ".", true)
     val storeFactory = storageConfig.get("factory")
     var storePropertiesBuilder = new StoreProperties.StorePropertiesBuilder()
@@ -109,7 +110,7 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
     }
 
     val rawStore =
-      getKVStore(storeName, storeDir, registry, changeLogSystemStreamPartition, jobContext, containerContext)
+      getKVStore(storeName, storeDir, registry, changeLogSystemStreamPartition, jobContext, containerContext, storeMode)
 
     // maybe wrap with logging
     val maybeLoggedStore = if (changeLogSystemStreamPartition == null) {
index 01b8ed7..5ac8f9a 100644 (file)
@@ -31,6 +31,8 @@ import org.apache.samza.container.TaskName;
 import org.apache.samza.rest.model.JobStatus;
 import org.apache.samza.rest.model.Task;
 import org.apache.samza.rest.proxy.job.JobInstance;
+import org.apache.samza.storage.ContainerStorageManager;
+import org.apache.samza.storage.StorageManagerUtil;
 import org.apache.samza.storage.TaskStorageManager;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.SystemClock;
@@ -99,7 +101,7 @@ public class LocalStoreMonitor implements Monitor {
               LOG.info(String.format("Local store: %s is actively used by the task: %s.", storeName, task.getTaskName()));
             } else {
               LOG.info(String.format("Local store: %s not used by the task: %s.", storeName, task.getTaskName()));
-              markSweepTaskStore(TaskStorageManager.getStorePartitionDir(jobDir, storeName, new TaskName(task.getTaskName())));
+              markSweepTaskStore(StorageManagerUtil.getStorePartitionDir(jobDir, storeName, new TaskName(task.getTaskName())));
             }
           }
         }
index 1c2b333..9c7657d 100644 (file)
@@ -32,6 +32,7 @@ import org.apache.samza.job.model.{ContainerModel, TaskModel}
 import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.serializers.{ByteSerde, SerdeManager, UUIDSerde}
 import org.apache.samza.storage.StorageEngineFactory
+import org.apache.samza.storage.StorageEngineFactory.StoreMode
 import org.apache.samza.storage.kv.{KeyValueStorageEngine, KeyValueStore}
 import org.apache.samza.system.{SystemProducer, SystemProducers, SystemStreamPartition}
 import org.apache.samza.task.TaskInstanceCollector
@@ -123,7 +124,7 @@ object TestKeyValuePerformance extends Logging {
           new MetricsRegistryMap,
           null,
           JobContextImpl.fromConfigWithDefaults(config),
-          new ContainerContextImpl(new ContainerModel("0", tasks.asJava), new MetricsRegistryMap)
+          new ContainerContextImpl(new ContainerModel("0", tasks.asJava), new MetricsRegistryMap), StoreMode.ReadWrite
         )
 
         val db = if(!engine.isInstanceOf[KeyValueStorageEngine[_,_]]) {
index 4e410d9..ce55dd0 100644 (file)
@@ -118,8 +118,7 @@ public class TestLocalTableWithSideInputs extends AbstractIntegrationTestHarness
           .collect(Collectors.toList());
 
       boolean successfulJoin = results.stream().allMatch(expectedEnrichedPageviews::contains);
-      assertEquals("Mismatch between the expected and actual join count", results.size(),
-          expectedEnrichedPageviews.size());
+      assertEquals("Mismatch between the expected and actual join count", expectedEnrichedPageviews.size(), results.size());
       assertTrue("Pageview profile join did not succeed for all inputs", successfulJoin);
     } catch (SamzaException e) {
       e.printStackTrace();