SAMZA-1670: When fetching a newest offset for a partition, also prefetch and cache...
authorCameron Lee <calee@linkedin.com>
Wed, 13 Jun 2018 23:04:02 +0000 (16:04 -0700)
committerJagadish <jvenkatraman@linkedin.com>
Wed, 13 Jun 2018 23:04:02 +0000 (16:04 -0700)
Author: Cameron Lee <calee@linkedin.com>

Reviewers: Jagadish <jagadish@apache.org>

Closes #520 from cameronlee314/partition_metadata

build.gradle
samza-api/src/main/java/org/apache/samza/system/ExtendedSystemAdmin.java
samza-api/src/main/java/org/apache/samza/system/SystemAdmin.java
samza-api/src/test/java/org/apache/samza/system/TestSystemAdmin.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
samza-core/src/main/scala/org/apache/samza/system/SSPMetadataCache.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/system/TestSSPMetadataCache.java [new file with mode: 0644]
samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala

index a94fcfa..2e90604 100644 (file)
@@ -143,6 +143,7 @@ project(':samza-api') {
     compile "org.codehaus.jackson:jackson-mapper-asl:$jacksonVersion"
     testCompile "junit:junit:$junitVersion"
     testCompile "org.mockito:mockito-core:$mockitoVersion"
+    testCompile "com.google.guava:guava:$guavaVersion"
   }
   checkstyle {
     configFile = new File(rootDir, "checkstyle/checkstyle.xml")
index ac5b1aa..ba239dc 100644 (file)
@@ -28,6 +28,10 @@ import java.util.Set;
 public interface ExtendedSystemAdmin extends SystemAdmin {
   Map<String, SystemStreamMetadata> getSystemStreamPartitionCounts(Set<String> streamNames, long cacheTTL);
 
-  // Makes fewer offset requests than getSystemStreamMetadata
+  /**
+   * Deprecated: Use {@link SystemAdmin#getSSPMetadata}, ideally combined with caching (i.e. SSPMetadataCache).
+   * Makes fewer offset requests than getSystemStreamMetadata
+   */
+  @Deprecated
   String getNewestOffset(SystemStreamPartition ssp, Integer maxRetries);
 }
index 13566a6..16f90e9 100644 (file)
 
 package org.apache.samza.system;
 
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
+
 
 /**
  * Helper interface attached to an underlying system to fetch information about
@@ -61,6 +64,34 @@ public interface SystemAdmin {
   Map<String, SystemStreamMetadata> getSystemStreamMetadata(Set<String> streamNames);
 
   /**
+   * Fetch metadata from a system for a set of SSPs.
+   * Implementors should override this if there is a more efficient implementation than delegating to
+   * {@link #getSystemStreamMetadata}.
+   *
+   * @param ssps SSPs for which to get metadata
+   * @return A map from SystemStreamPartition to the SystemStreamPartitionMetadata, with an entry for each SSP in
+   * {@code ssps} for which metadata could be found
+   * @throws RuntimeException if there was an error fetching metadata
+   */
+  default Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> getSSPMetadata(
+      Set<SystemStreamPartition> ssps) {
+    Set<String> streams = ssps.stream().map(SystemStream::getStream).collect(Collectors.toSet());
+    Map<String, SystemStreamMetadata> streamToSystemStreamMetadata = getSystemStreamMetadata(streams);
+    Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> sspToSSPMetadata = new HashMap<>();
+    for (SystemStreamPartition ssp : ssps) {
+      SystemStreamMetadata systemStreamMetadata = streamToSystemStreamMetadata.get(ssp.getStream());
+      if (systemStreamMetadata != null) {
+        SystemStreamMetadata.SystemStreamPartitionMetadata sspMetadata =
+            systemStreamMetadata.getSystemStreamPartitionMetadata().get(ssp.getPartition());
+        if (sspMetadata != null) {
+          sspToSSPMetadata.put(ssp, sspMetadata);
+        }
+      }
+    }
+    return sspToSSPMetadata;
+  }
+
+  /**
    * Compare the two offsets. -1, 0, +1 means offset1 &lt; offset2,
    * offset1 == offset2 and offset1 &gt; offset2 respectively. Return
    * null if those two offsets are not comparable
diff --git a/samza-api/src/test/java/org/apache/samza/system/TestSystemAdmin.java b/samza-api/src/test/java/org/apache/samza/system/TestSystemAdmin.java
new file mode 100644 (file)
index 0000000..85245e3
--- /dev/null
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.system;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.Partition;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+
+public class TestSystemAdmin {
+  private static final String SYSTEM = "system";
+  private static final String STREAM = "stream";
+  private static final String OTHER_STREAM = "otherStream";
+
+  /**
+   * Given some SSPs, getSSPMetadata should delegate to getSystemStreamMetadata and properly extract the results for the
+   * requested SSPs.
+   */
+  @Test
+  public void testGetSSPMetadata() {
+    SystemStreamPartition streamPartition0 = new SystemStreamPartition(SYSTEM, STREAM, new Partition(0));
+    SystemStreamPartition streamPartition1 = new SystemStreamPartition(SYSTEM, STREAM, new Partition(1));
+    SystemStreamPartition otherStreamPartition0 = new SystemStreamPartition(SYSTEM, OTHER_STREAM, new Partition(0));
+    SystemAdmin systemAdmin = mock(MySystemAdmin.class);
+    SystemStreamMetadata.SystemStreamPartitionMetadata streamPartition0Metadata =
+        new SystemStreamMetadata.SystemStreamPartitionMetadata("1", "2", "3");
+    SystemStreamMetadata.SystemStreamPartitionMetadata streamPartition1Metadata =
+        new SystemStreamMetadata.SystemStreamPartitionMetadata("11", "12", "13");
+    SystemStreamMetadata.SystemStreamPartitionMetadata otherStreamPartition0Metadata =
+        new SystemStreamMetadata.SystemStreamPartitionMetadata("21", "22", "23");
+    when(systemAdmin.getSystemStreamMetadata(ImmutableSet.of(STREAM, OTHER_STREAM))).thenReturn(ImmutableMap.of(
+        STREAM, new SystemStreamMetadata(STREAM, ImmutableMap.of(
+            new Partition(0), streamPartition0Metadata,
+            new Partition(1), streamPartition1Metadata)),
+        OTHER_STREAM, new SystemStreamMetadata(OTHER_STREAM, ImmutableMap.of(
+            new Partition(0), otherStreamPartition0Metadata))));
+    Set<SystemStreamPartition> ssps = ImmutableSet.of(streamPartition0, streamPartition1, otherStreamPartition0);
+    when(systemAdmin.getSSPMetadata(ssps)).thenCallRealMethod();
+    Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> expected = ImmutableMap.of(
+        streamPartition0, streamPartition0Metadata,
+        streamPartition1, streamPartition1Metadata,
+        otherStreamPartition0, otherStreamPartition0Metadata);
+    assertEquals(expected, systemAdmin.getSSPMetadata(ssps));
+    verify(systemAdmin).getSystemStreamMetadata(ImmutableSet.of(STREAM, OTHER_STREAM));
+  }
+
+  /**
+   * Given some SSPs, but missing metadata for one of the streams, getSSPMetadata should delegate to
+   * getSystemStreamMetadata and only fill in results for the SSPs corresponding to streams with metadata.
+   */
+  @Test
+  public void testGetSSPMetadataMissingStream() {
+    SystemStreamPartition streamPartition0 = new SystemStreamPartition(SYSTEM, STREAM, new Partition(0));
+    SystemStreamPartition otherStreamPartition0 = new SystemStreamPartition(SYSTEM, OTHER_STREAM, new Partition(0));
+    SystemAdmin systemAdmin = mock(MySystemAdmin.class);
+    SystemStreamMetadata.SystemStreamPartitionMetadata streamPartition0Metadata =
+        new SystemStreamMetadata.SystemStreamPartitionMetadata("1", "2", "3");
+    when(systemAdmin.getSystemStreamMetadata(ImmutableSet.of(STREAM, OTHER_STREAM))).thenReturn(ImmutableMap.of(
+        STREAM, new SystemStreamMetadata(STREAM, ImmutableMap.of(new Partition(0), streamPartition0Metadata))));
+    Set<SystemStreamPartition> ssps = ImmutableSet.of(streamPartition0, otherStreamPartition0);
+    when(systemAdmin.getSSPMetadata(ssps)).thenCallRealMethod();
+    Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> expected =
+        ImmutableMap.of(streamPartition0, streamPartition0Metadata);
+    assertEquals(expected, systemAdmin.getSSPMetadata(ssps));
+    verify(systemAdmin).getSystemStreamMetadata(ImmutableSet.of(STREAM, OTHER_STREAM));
+  }
+
+  /**
+   * Given some SSPs, but missing metadata for one of the SSPs, getSSPMetadata should delegate to
+   * getSystemStreamMetadata and only fill in results for the SSPs that have metadata.
+   */
+  @Test
+  public void testGetSSPMetadataMissingPartition() {
+    SystemStreamPartition streamPartition0 = new SystemStreamPartition(SYSTEM, STREAM, new Partition(0));
+    SystemStreamPartition streamPartition1 = new SystemStreamPartition(SYSTEM, STREAM, new Partition(1));
+    SystemAdmin systemAdmin = mock(MySystemAdmin.class);
+    SystemStreamMetadata.SystemStreamPartitionMetadata streamPartition0Metadata =
+        new SystemStreamMetadata.SystemStreamPartitionMetadata("1", "2", "3");
+    when(systemAdmin.getSystemStreamMetadata(ImmutableSet.of(STREAM))).thenReturn(ImmutableMap.of(
+        STREAM, new SystemStreamMetadata(STREAM, ImmutableMap.of(new Partition(0), streamPartition0Metadata))));
+    Set<SystemStreamPartition> ssps = ImmutableSet.of(streamPartition0, streamPartition1);
+    when(systemAdmin.getSSPMetadata(ssps)).thenCallRealMethod();
+    Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> expected =
+        ImmutableMap.of(streamPartition0, streamPartition0Metadata);
+    assertEquals(expected, systemAdmin.getSSPMetadata(ssps));
+    verify(systemAdmin).getSystemStreamMetadata(ImmutableSet.of(STREAM));
+  }
+
+  /**
+   * Looks like Mockito 1.x does not support using thenCallRealMethod with default methods for interfaces, but it works
+   * to use this placeholder abstract class.
+   */
+  private abstract class MySystemAdmin implements ExtendedSystemAdmin { }
+}
\ No newline at end of file
index db6f0d9..c807b02 100644 (file)
@@ -20,7 +20,9 @@
 package org.apache.samza.storage;
 
 import java.io.File;
+import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -40,12 +42,14 @@ import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.serializers.ByteSerde;
 import org.apache.samza.serializers.Serde;
+import org.apache.samza.system.SSPMetadataCache;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemFactory;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.util.Clock;
 import org.apache.samza.util.CommandLine;
 import org.apache.samza.util.ScalaJavaUtil;
 import org.apache.samza.util.SystemClock;
@@ -196,7 +200,11 @@ public class StorageRecovery extends CommandLine {
    */
   @SuppressWarnings({ "unchecked", "rawtypes" })
   private void getTaskStorageManagers() {
-    StreamMetadataCache streamMetadataCache = new StreamMetadataCache(systemAdmins, 5000, SystemClock.instance());
+    Clock clock = SystemClock.instance();
+    StreamMetadataCache streamMetadataCache = new StreamMetadataCache(systemAdmins, 5000, clock);
+    // don't worry about prefetching for this; looks like the tool doesn't flush to offset files anyways
+    SSPMetadataCache sspMetadataCache =
+        new SSPMetadataCache(systemAdmins, Duration.ofSeconds(5), clock, Collections.emptySet());
 
     for (ContainerModel containerModel : containers.values()) {
       HashMap<String, StorageEngine> taskStores = new HashMap<String, StorageEngine>();
@@ -235,6 +243,7 @@ public class StorageRecovery extends CommandLine {
             ScalaJavaUtil.toScalaMap(changeLogSystemStreams),
             maxPartitionNumber,
             streamMetadataCache,
+            sspMetadataCache,
             storeBaseDir,
             storeBaseDir,
             taskModel.getChangelogPartition(),
index 5380fc9..38d0d9c 100644 (file)
@@ -23,10 +23,12 @@ import java.io.File
 import java.lang.management.ManagementFactory
 import java.net.{URL, UnknownHostException}
 import java.nio.file.Path
+import java.time.Duration
 import java.util
 import java.util.Base64
 import java.util.concurrent.{ExecutorService, Executors, ScheduledExecutorService, TimeUnit}
 
+import com.google.common.annotations.VisibleForTesting
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 import org.apache.samza.checkpoint.{CheckpointListener, CheckpointManagerFactory, OffsetManager, OffsetManagerMetrics}
 import org.apache.samza.config.JobConfig.Config2Job
@@ -41,7 +43,7 @@ import org.apache.samza.container.disk.DiskSpaceMonitor.Listener
 import org.apache.samza.container.disk.{DiskQuotaPolicyFactory, DiskSpaceMonitor, NoThrottlingDiskQuotaPolicyFactory, PollingScanDiskSpaceMonitor}
 import org.apache.samza.container.host.{StatisticsMonitorImpl, SystemMemoryStatistics, SystemStatisticsMonitor}
 import org.apache.samza.coordinator.stream.{CoordinatorStreamManager, CoordinatorStreamSystemProducer}
-import org.apache.samza.job.model.JobModel
+import org.apache.samza.job.model.{ContainerModel, JobModel}
 import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, MetricsReporter}
 import org.apache.samza.serializers._
 import org.apache.samza.serializers.model.SamzaObjectMapper
@@ -327,6 +329,23 @@ object SamzaContainer extends Logging {
 
     info("Got change log system streams: %s" format changeLogSystemStreams)
 
+    /*
+     * This keeps track of the changelog SSPs that are associated with the whole container. This is used so that we can
+     * prefetch the metadata about the all of the changelog SSPs associated with the container whenever we need the
+     * metadata about some of the changelog SSPs.
+     * An example use case is when Samza writes offset files for stores ({@link TaskStorageManager}). Each task is
+     * responsible for its own offset file, but if we can do prefetching, then most tasks will already have cached
+     * metadata by the time they need the offset metadata.
+     * Note: By using all changelog streams to build the sspsToPrefetch, any fetches done for persisted stores will
+     * include the ssps for non-persisted stores, so this is slightly suboptimal. However, this does not increase the
+     * actual number of calls to the {@link SystemAdmin}, and we can decouple this logic from the per-task objects (e.g.
+     * {@link TaskStorageManager}).
+     */
+    val changelogSSPMetadataCache = new SSPMetadataCache(systemAdmins,
+      Duration.ofSeconds(5),
+      SystemClock.instance,
+      getChangelogSSPsForContainer(containerModel, changeLogSystemStreams).asJava)
+
     val intermediateStreams = config
       .getStreamIds
       .filter(config.getIsIntermediateStream(_))
@@ -564,6 +583,7 @@ object SamzaContainer extends Logging {
         changeLogSystemStreams = changeLogSystemStreams,
         maxChangeLogStreamPartitions,
         streamMetadataCache = streamMetadataCache,
+        sspMetadataCache = changelogSSPMetadataCache,
         nonLoggedStoreBaseDir = nonLoggedStorageBaseDir,
         loggedStoreBaseDir = loggedStorageBaseDir,
         partition = taskModel.getChangelogPartition,
@@ -673,6 +693,20 @@ object SamzaContainer extends Logging {
       taskThreadPool = taskThreadPool,
       timerExecutor = timerExecutor)
   }
+
+
+  /**
+    * Builds the set of SSPs for all changelogs on this container.
+    */
+  @VisibleForTesting
+  private[container] def getChangelogSSPsForContainer(containerModel: ContainerModel,
+    changeLogSystemStreams: Map[String, SystemStream]): Set[SystemStreamPartition] = {
+    containerModel.getTasks.values().asScala
+      .map(taskModel => taskModel.getChangelogPartition)
+      .flatMap(changelogPartition => changeLogSystemStreams.map { case (_, systemStream) =>
+        new SystemStreamPartition(systemStream, changelogPartition) })
+      .toSet
+  }
 }
 
 class SamzaContainer(
index 09744cf..62b59fb 100644 (file)
@@ -26,7 +26,7 @@ import org.apache.samza.config.StorageConfig
 import org.apache.samza.{Partition, SamzaException}
 import org.apache.samza.container.TaskName
 import org.apache.samza.system._
-import org.apache.samza.util.{Clock, FileUtil, Logging, Util}
+import org.apache.samza.util.{Clock, FileUtil, Logging}
 
 import scala.collection.JavaConverters._
 
@@ -51,6 +51,7 @@ class TaskStorageManager(
   changeLogSystemStreams: Map[String, SystemStream] = Map(),
   changeLogStreamPartitions: Int,
   streamMetadataCache: StreamMetadataCache,
+  sspMetadataCache: SSPMetadataCache,
   nonLoggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
   loggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
   partition: Partition,
@@ -329,23 +330,11 @@ class TaskStorageManager(
     debug("Persisting logged key value stores")
 
     for ((storeName, systemStream) <- changeLogSystemStreams.filterKeys(storeName => persistedStores.contains(storeName))) {
-      val systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem)
-
       debug("Fetching newest offset for store %s" format(storeName))
       try {
-        val newestOffset = if (systemAdmin.isInstanceOf[ExtendedSystemAdmin]) {
-          // This approach is much more efficient because it only fetches the newest offset for 1 SSP
-          // rather than newest and oldest offsets for all SSPs. Use it if we can.
-          systemAdmin.asInstanceOf[ExtendedSystemAdmin].getNewestOffset(new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition), 3)
-        } else {
-          val streamToMetadata = systemAdmins.getSystemAdmin(systemStream.getSystem)
-                  .getSystemStreamMetadata(Set(systemStream.getStream).asJava)
-          val sspMetadata = streamToMetadata
-                  .get(systemStream.getStream)
-                  .getSystemStreamPartitionMetadata
-                  .get(partition)
-          sspMetadata.getNewestOffset
-        }
+        val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
+        val sspMetadata = sspMetadataCache.getMetadata(ssp)
+        val newestOffset = if (sspMetadata == null) null else sspMetadata.getNewestOffset
         debug("Got offset %s for store %s" format(newestOffset, storeName))
 
         val loggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
diff --git a/samza-core/src/main/scala/org/apache/samza/system/SSPMetadataCache.java b/samza-core/src/main/scala/org/apache/samza/system/SSPMetadataCache.java
new file mode 100644 (file)
index 0000000..bbe81a8
--- /dev/null
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.system;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.samza.util.Clock;
+
+
+/**
+ * Fetches and caches metadata about a set of SSPs. When fetching metadata for a stale SSP, this will also prefetch
+ * metadata for all other SSPs specified for prefetching which are stale and in the same system.
+ */
+public class SSPMetadataCache {
+  private final SystemAdmins systemAdmins;
+  private final Duration cacheTTL;
+  private final Clock clock;
+  private final Set<SystemStreamPartition> sspsToPrefetch;
+
+  private final Object metadataRefreshLock;
+  private final ConcurrentHashMap<SystemStreamPartition, CacheEntry> cache;
+
+  public SSPMetadataCache(SystemAdmins systemAdmins, Duration cacheTTL, Clock clock,
+      Set<SystemStreamPartition> sspsToPrefetch) {
+    this.systemAdmins = systemAdmins;
+    this.cacheTTL = cacheTTL;
+    this.clock = clock;
+    this.sspsToPrefetch = sspsToPrefetch;
+
+    this.metadataRefreshLock = new Object();
+    this.cache = new ConcurrentHashMap<>();
+  }
+
+  /**
+   * Gets the metadata for an SSP. This will return a cached value if it is fresh enough. Otherwise, it will fetch the
+   * metadata from a source-of-truth.
+   * If the metadata for the SSP needs to be fetched, then this will also prefetch and cache the metadata for any stale
+   * {@link #sspsToPrefetch} that can be included in the same fetch call (e.g. same system).
+   *
+   * @param ssp SSP for which to get metadata
+   * @return metadata for the SSP; null if the source-of-truth returned no metadata
+   * @throws RuntimeException if there was an error in fetching metadata
+   */
+  public SystemStreamMetadata.SystemStreamPartitionMetadata getMetadata(SystemStreamPartition ssp) {
+    maybeRefreshMetadata(ssp);
+    CacheEntry cacheEntry = cache.get(ssp);
+    /*
+     * cacheEntry itself should not be null once the refresh is done, but check anyways to be safe.
+     * The metadata inside a non-null cacheEntry might still be null, so this will return null in that case.
+     */
+    return cacheEntry == null ? null : cacheEntry.getMetadata();
+  }
+
+  private void maybeRefreshMetadata(SystemStreamPartition requestedSSP) {
+    synchronized (this.metadataRefreshLock) {
+      Instant refreshRequestedAt = Instant.ofEpochMilli(this.clock.currentTimeMillis());
+      if (shouldRefresh(requestedSSP, refreshRequestedAt)) {
+        String system = requestedSSP.getSystem();
+        Set<SystemStreamPartition> sspsToFetchFor = new HashSet<>();
+        sspsToFetchFor.add(requestedSSP);
+        for (SystemStreamPartition sspToPrefetch : this.sspsToPrefetch) {
+          if (system.equals(sspToPrefetch.getSystem()) && shouldRefresh(sspToPrefetch, refreshRequestedAt)) {
+            sspsToFetchFor.add(sspToPrefetch);
+          }
+        }
+        SystemAdmin systemAdmin = this.systemAdmins.getSystemAdmin(system);
+        Map<SystemStreamPartition, SystemStreamMetadata.SystemStreamPartitionMetadata> fetchedMetadata =
+            systemAdmin.getSSPMetadata(sspsToFetchFor);
+        Instant updatedAt = Instant.ofEpochMilli(this.clock.currentTimeMillis());
+        // we want to add an entry even if there was no metadata, so iterate over sspsToFetchFor
+        sspsToFetchFor.forEach(ssp -> this.cache.put(ssp, new CacheEntry(fetchedMetadata.get(ssp), updatedAt)));
+      }
+    }
+  }
+
+  private boolean shouldRefresh(SystemStreamPartition ssp, Instant now) {
+    CacheEntry cacheEntry = cache.get(ssp);
+    if (cacheEntry == null) {
+      return true;
+    } else {
+      Instant isFreshUntil = cacheEntry.getLastUpdatedAt().plus(cacheTTL);
+      return now.isAfter(isFreshUntil);
+    }
+  }
+
+  private static class CacheEntry {
+    /**
+     * Nullable so that we can cache that there was no metadata for the last fetch.
+     */
+    private final SystemStreamMetadata.SystemStreamPartitionMetadata metadata;
+    private final Instant lastUpdatedAt;
+
+    private CacheEntry(SystemStreamMetadata.SystemStreamPartitionMetadata metadata, Instant lastUpdatedAt) {
+      this.metadata = metadata;
+      this.lastUpdatedAt = lastUpdatedAt;
+    }
+
+    private SystemStreamMetadata.SystemStreamPartitionMetadata getMetadata() {
+      return metadata;
+    }
+
+    private Instant getLastUpdatedAt() {
+      return lastUpdatedAt;
+    }
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/system/TestSSPMetadataCache.java b/samza-core/src/test/java/org/apache/samza/system/TestSSPMetadataCache.java
new file mode 100644 (file)
index 0000000..efe09d1
--- /dev/null
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.system;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.time.Duration;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.util.Clock;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+
+public class TestSSPMetadataCache {
+  private static final String SYSTEM = "system";
+  private static final String STREAM = "stream";
+  private static final Duration CACHE_TTL = Duration.ofMillis(100);
+
+  @Mock
+  private SystemAdmin systemAdmin;
+  @Mock
+  private SystemAdmins systemAdmins;
+  @Mock
+  private Clock clock;
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    when(systemAdmins.getSystemAdmin(SYSTEM)).thenReturn(systemAdmin);
+  }
+
+  /**
+   * Given that there are sspsToPrefetch, getMetadata should call the admin (when necessary) to get the metadata for the
+   * requested and "prefetch" SSPs. It should also cache the data.
+   */
+  @Test
+  public void testGetMetadataWithPrefetch() {
+    SystemStreamPartition ssp = buildSSP(0);
+    SystemStreamPartition otherSSP = buildSSP(1);
+    SSPMetadataCache cache = buildSSPMetadataCache(ImmutableSet.of(ssp, otherSSP));
+
+    // t = 10: first read, t = 11: first write
+    when(clock.currentTimeMillis()).thenReturn(10L, 11L);
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp, otherSSP))).thenReturn(
+        ImmutableMap.of(ssp, sspMetadata(1), otherSSP, sspMetadata(2)));
+    assertEquals(sspMetadata(1), cache.getMetadata(ssp));
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp, otherSSP));
+
+    // stay within TTL: use cached data
+    when(clock.currentTimeMillis()).thenReturn(11 + CACHE_TTL.toMillis());
+    assertEquals(sspMetadata(1), cache.getMetadata(ssp));
+    assertEquals(sspMetadata(2), cache.getMetadata(otherSSP));
+    // still only one call to the admin from the initial fill
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp, otherSSP));
+
+    // now entries are stale
+    when(clock.currentTimeMillis()).thenReturn(12 + CACHE_TTL.toMillis());
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp, otherSSP))).thenReturn(
+        ImmutableMap.of(ssp, sspMetadata(10), otherSSP, sspMetadata(11)));
+    // flip the order; prefetching should still be done correctly
+    assertEquals(sspMetadata(11), cache.getMetadata(otherSSP));
+    assertEquals(sspMetadata(10), cache.getMetadata(ssp));
+    verify(systemAdmin, times(2)).getSSPMetadata(ImmutableSet.of(ssp, otherSSP));
+  }
+
+  /**
+   * Given that an SSP has empty metadata, getMetadata should return and cache that.
+   */
+  @Test
+  public void testGetMetadataEmptyMetadata() {
+    SystemStreamPartition ssp = buildSSP(0);
+    SSPMetadataCache cache = buildSSPMetadataCache(ImmutableSet.of(ssp));
+
+    // t = 10: first read, t = 11: first write
+    when(clock.currentTimeMillis()).thenReturn(10L, 11L);
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of());
+    assertNull(cache.getMetadata(ssp));
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp));
+
+    // stay within TTL: use cached data
+    when(clock.currentTimeMillis()).thenReturn(11 + CACHE_TTL.toMillis());
+    assertNull(cache.getMetadata(ssp));
+    // still only one call to the admin from the initial fill
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp));
+
+    // now entries are stale
+    when(clock.currentTimeMillis()).thenReturn(12 + CACHE_TTL.toMillis());
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of());
+    assertNull(cache.getMetadata(ssp));
+    verify(systemAdmin, times(2)).getSSPMetadata(ImmutableSet.of(ssp));
+  }
+
+  /**
+   * Given that the there are sspsToPrefetch with systems that do not match the requested SSP, getMetadata should not
+   * prefetch all sspsToPrefetch.
+   */
+  @Test
+  public void testGetMetadataMultipleSystemsForPrefetch() {
+    // add one more extended system admin so we can have two of them for this test
+    SystemAdmin otherSystemAdmin = mock(SystemAdmin.class);
+    String otherSystem = "otherSystem";
+    when(systemAdmins.getSystemAdmin(otherSystem)).thenReturn(otherSystemAdmin);
+    SystemStreamPartition ssp = buildSSP(0);
+    // different system should not get prefetched
+    SystemStreamPartition sspOtherSystem = new SystemStreamPartition(otherSystem, "otherStream", new Partition(1));
+    SSPMetadataCache cache = buildSSPMetadataCache(ImmutableSet.of(ssp, sspOtherSystem));
+
+    // t = 10: first read for ssp, t = 11: first write for ssp
+    when(clock.currentTimeMillis()).thenReturn(10L, 11L);
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata(1)));
+    assertEquals(sspMetadata(1), cache.getMetadata(ssp));
+    // does not call for sspOtherSystem
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp));
+
+    // t = 12: first read for sspOtherSystem, t = 13: first write for sspOtherSystem
+    when(clock.currentTimeMillis()).thenReturn(12L, 13L);
+    when(otherSystemAdmin.getSSPMetadata(ImmutableSet.of(sspOtherSystem))).thenReturn(
+        ImmutableMap.of(sspOtherSystem, sspMetadata(2)));
+    assertEquals(sspMetadata(2), cache.getMetadata(sspOtherSystem));
+    // does not call for ssp
+    verify(otherSystemAdmin).getSSPMetadata(ImmutableSet.of(sspOtherSystem));
+
+    // now entries are stale, do another round of individual fetches
+    when(clock.currentTimeMillis()).thenReturn(14 + CACHE_TTL.toMillis());
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata(10)));
+    assertEquals(sspMetadata(10), cache.getMetadata(ssp));
+    verify(systemAdmin, times(2)).getSSPMetadata(ImmutableSet.of(ssp));
+    when(otherSystemAdmin.getSSPMetadata(ImmutableSet.of(sspOtherSystem))).thenReturn(
+        ImmutableMap.of(sspOtherSystem, sspMetadata(11)));
+    assertEquals(sspMetadata(11), cache.getMetadata(sspOtherSystem));
+    verify(otherSystemAdmin, times(2)).getSSPMetadata(ImmutableSet.of(sspOtherSystem));
+  }
+
+  /**
+   * Given that there are no sspsToPrefetch, getMetadata should still fetch and cache metadata for a requested SSP.
+   */
+  @Test
+  public void testGetMetadataNoSSPsToPrefetch() {
+    SystemStreamPartition ssp = buildSSP(0);
+    SSPMetadataCache cache = buildSSPMetadataCache(ImmutableSet.of());
+
+    // t = 10: first read, t = 11: first write
+    when(clock.currentTimeMillis()).thenReturn(10L, 11L);
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata(1)));
+    assertEquals(sspMetadata(1), cache.getMetadata(ssp));
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp));
+
+    // stay within TTL: use cached data
+    when(clock.currentTimeMillis()).thenReturn(11 + CACHE_TTL.toMillis());
+    assertEquals(sspMetadata(1), cache.getMetadata(ssp));
+
+    // now entry is stale
+    when(clock.currentTimeMillis()).thenReturn(12 + CACHE_TTL.toMillis());
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata(10)));
+    assertEquals(sspMetadata(10), cache.getMetadata(ssp));
+    verify(systemAdmin, times(2)).getSSPMetadata(ImmutableSet.of(ssp));
+  }
+
+  /**
+   * Given that the sspsToPrefetch does not contain the requested SSP, getMetadata should still fetch and cache metadata
+   * for it.
+   */
+  @Test
+  public void testGetMetadataRequestedSSPNotInSSPsToPrefetch() {
+    SystemStreamPartition ssp = buildSSP(0);
+    SystemStreamPartition otherSSP = buildSSP(1);
+    // do not include ssp in sspsToPrefetch
+    SSPMetadataCache cache = buildSSPMetadataCache(ImmutableSet.of(otherSSP));
+
+    // t = 10: first read, t = 11: first write
+    when(clock.currentTimeMillis()).thenReturn(10L, 11L);
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp, otherSSP))).thenReturn(
+        ImmutableMap.of(ssp, sspMetadata(1), otherSSP, sspMetadata(2)));
+    assertEquals(sspMetadata(1), cache.getMetadata(ssp));
+    // still will fetch metadata for both
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp, otherSSP));
+
+    // stay within TTL: use cached data
+    when(clock.currentTimeMillis()).thenReturn(11 + CACHE_TTL.toMillis());
+    assertEquals(sspMetadata(1), cache.getMetadata(ssp));
+    assertEquals(sspMetadata(2), cache.getMetadata(otherSSP));
+    // still only one call to the admin from the initial fill
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp, otherSSP));
+
+    // now entries are stale
+    when(clock.currentTimeMillis()).thenReturn(12 + CACHE_TTL.toMillis());
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(
+        ImmutableMap.of(ssp, sspMetadata(10), otherSSP, sspMetadata(11)));
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(otherSSP))).thenReturn(
+        ImmutableMap.of(otherSSP, sspMetadata(11)));
+    // call for otherSSP first; no prefetching since ssp is not in sspsToPrefetch
+    assertEquals(sspMetadata(11), cache.getMetadata(otherSSP));
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(otherSSP));
+    // call for ssp also has no prefetching since the otherSSP metadata is fresh at this point
+    assertEquals(sspMetadata(10), cache.getMetadata(ssp));
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp));
+    // still only one call for both at the same time from the initial fill
+    verify(systemAdmin).getSSPMetadata(ImmutableSet.of(ssp, otherSSP));
+  }
+
+  /**
+   * Given concurrent access to getMetadata, there should be only single calls to fetch metadata.
+   */
+  @Test
+  public void testGetMetadataConcurrentAccess() throws ExecutionException, InterruptedException {
+    int numPartitions = 50;
+    // initial fetch
+    when(clock.currentTimeMillis()).thenReturn(10L);
+    Set<SystemStreamPartition> ssps =
+        IntStream.range(0, numPartitions).mapToObj(TestSSPMetadataCache::buildSSP).collect(Collectors.toSet());
+    SSPMetadataCache cache = buildSSPMetadataCache(ssps);
+    ExecutorService executorService = Executors.newFixedThreadPool(10);
+    when(systemAdmin.getSSPMetadata(ssps)).thenAnswer(invocation -> {
+        // have the admin call wait so that it forces the threads to overlap on the lock
+        Thread.sleep(500);
+        return IntStream.range(0, numPartitions)
+            .boxed()
+            .collect(Collectors.toMap(TestSSPMetadataCache::buildSSP, i -> sspMetadata((long) i)));
+      });
+
+    // send concurrent requests for metadata
+    List<Future<SystemStreamMetadata.SystemStreamPartitionMetadata>> getMetadataFutures =
+        IntStream.range(0, numPartitions)
+            .mapToObj(i -> executorService.submit(() -> cache.getMetadata(buildSSP(i))))
+            .collect(Collectors.toList());
+    for (int i = 0; i < numPartitions; i++) {
+      assertEquals(sspMetadata(i), getMetadataFutures.get(i).get());
+    }
+    // should only see one call to fetch metadata
+    verify(systemAdmin).getSSPMetadata(ssps);
+
+    // make entries stale
+    when(clock.currentTimeMillis()).thenReturn(11 + CACHE_TTL.toMillis());
+    getMetadataFutures = IntStream.range(0, numPartitions)
+        .mapToObj(i -> executorService.submit(() -> cache.getMetadata(buildSSP(i))))
+        .collect(Collectors.toList());
+    for (int i = 0; i < numPartitions; i++) {
+      assertEquals(sspMetadata(i), getMetadataFutures.get(i).get());
+    }
+    // should see two total calls to fetch metadata
+    verify(systemAdmin, times(2)).getSSPMetadata(ssps);
+  }
+
+  /**
+   * Given that the admin throws an exception when trying to get the metadata for the first time, getMetadata should
+   * propagate the exception.
+   */
+  @Test(expected = SamzaException.class)
+  public void testGetMetadataExceptionFirstFetch() {
+    SystemStreamPartition ssp = buildSSP(0);
+    SSPMetadataCache cache = buildSSPMetadataCache(ImmutableSet.of(ssp));
+    when(clock.currentTimeMillis()).thenReturn(10L);
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenThrow(new SamzaException());
+    cache.getMetadata(ssp);
+  }
+
+  /**
+   * Given that the admin throws an exception when trying to get the metadata after a successful fetch, getMetadata
+   * should propagate the exception.
+   */
+  @Test(expected = SamzaException.class)
+  public void testGetMetadataExceptionAfterSuccessfulFetch() {
+    SystemStreamPartition ssp = buildSSP(0);
+    SSPMetadataCache cache = buildSSPMetadataCache(ImmutableSet.of(ssp));
+
+    // do a successful fetch first
+    when(clock.currentTimeMillis()).thenReturn(10L);
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata(1)));
+    cache.getMetadata(ssp);
+
+    // throw an exception on the next fetch
+    when(clock.currentTimeMillis()).thenReturn(11 + CACHE_TTL.toMillis());
+    when(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenThrow(new SamzaException());
+    cache.getMetadata(ssp);
+  }
+
+  private SSPMetadataCache buildSSPMetadataCache(Set<SystemStreamPartition> sspsToPrefetch) {
+    return new SSPMetadataCache(systemAdmins, CACHE_TTL, clock, sspsToPrefetch);
+  }
+
+  private static SystemStreamPartition buildSSP(int partition) {
+    return new SystemStreamPartition(SYSTEM, STREAM, new Partition(partition));
+  }
+
+  private static SystemStreamMetadata.SystemStreamPartitionMetadata sspMetadata(long baseOffset) {
+    return new SystemStreamMetadata.SystemStreamPartitionMetadata(Long.toString(baseOffset),
+        Long.toString(baseOffset * 100), Long.toString(baseOffset * 100 + 1));
+  }
+}
\ No newline at end of file
index c002f76..b27b151 100644 (file)
@@ -599,6 +599,40 @@ class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
     assertTrue(containerMetrics.taskStoreRestorationMetrics.get(taskName).getValue >= 1)
 
   }
+
+  @Test
+  def testGetChangelogSSPsForContainer() = {
+    val taskName0 = new TaskName("task0")
+    val taskName1 = new TaskName("task1")
+    val taskModel0 = new TaskModel(taskName0,
+      Set(new SystemStreamPartition("input", "stream", new Partition(0))),
+      new Partition(10))
+    val taskModel1 = new TaskModel(taskName1,
+      Set(new SystemStreamPartition("input", "stream", new Partition(1))),
+      new Partition(11))
+    val containerModel = new ContainerModel("processorId", 0, Map(taskName0 -> taskModel0, taskName1 -> taskModel1))
+    val changeLogSystemStreams = Map("store0" -> new SystemStream("changelogSystem0", "store0-changelog"),
+      "store1" -> new SystemStream("changelogSystem1", "store1-changelog"))
+    val expected = Set(new SystemStreamPartition("changelogSystem0", "store0-changelog", new Partition(10)),
+      new SystemStreamPartition("changelogSystem1", "store1-changelog", new Partition(10)),
+      new SystemStreamPartition("changelogSystem0", "store0-changelog", new Partition(11)),
+      new SystemStreamPartition("changelogSystem1", "store1-changelog", new Partition(11)))
+    assertEquals(expected, SamzaContainer.getChangelogSSPsForContainer(containerModel, changeLogSystemStreams))
+  }
+
+  @Test
+  def testGetChangelogSSPsForContainerNoChangelogs() = {
+    val taskName0 = new TaskName("task0")
+    val taskName1 = new TaskName("task1")
+    val taskModel0 = new TaskModel(taskName0,
+      Set(new SystemStreamPartition("input", "stream", new Partition(0))),
+      new Partition(10))
+    val taskModel1 = new TaskModel(taskName1,
+      Set(new SystemStreamPartition("input", "stream", new Partition(1))),
+      new Partition(11))
+    val containerModel = new ContainerModel("processorId", 0, Map(taskName0 -> taskModel0, taskName1 -> taskModel1))
+    assertEquals(Set(), SamzaContainer.getChangelogSSPsForContainer(containerModel, Map()))
+  }
 }
 
 class MockCheckpointManager extends CheckpointManager {
index d092577..3bb4e99 100644 (file)
@@ -79,48 +79,36 @@ class TestTaskStorageManager extends MockitoSugar {
 
     // Mock for StreamMetadataCache, SystemConsumer, SystemAdmin
     val mockStreamMetadataCache = mock[StreamMetadataCache]
+    val mockSSPMetadataCache = mock[SSPMetadataCache]
     val mockSystemConsumer = mock[SystemConsumer]
     val mockSystemAdmin = mock[SystemAdmin]
     val changelogSpec = StreamSpec.createChangeLogStreamSpec("testStream", "kafka", 1)
     doNothing().when(mockSystemAdmin).validateStream(changelogSpec)
-    var registerOffset = "0"
-    when(mockSystemConsumer.register(any(), any())).thenAnswer(new Answer[Unit] {
-      override def answer(invocation: InvocationOnMock): Unit = {
-        val args = invocation.getArguments
-        if (ssp.equals(args.apply(0).asInstanceOf[SystemStreamPartition])) {
-          val offset = args.apply(1).asInstanceOf[String]
-          assertNotNull(offset)
-          assertEquals(registerOffset, offset)
-        }
-      }
-    })
     doNothing().when(mockSystemConsumer).stop()
 
     // Test 1: Initial invocation - No store on disk (only changelog has data)
     // Setup initial sspMetadata
-    val sspMetadata = new SystemStreamPartitionMetadata("0", "50", "51")
+    var sspMetadata = new SystemStreamPartitionMetadata("0", "50", "51")
     var metadata = new SystemStreamMetadata("testStream", new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
       {
         put(partition, sspMetadata)
       }
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
-    when(mockSystemAdmin.getSystemStreamMetadata(any())).thenReturn(new util.HashMap[String, SystemStreamMetadata](){
-      {
-        put("testStream", metadata)
-      }
-    })
+    when(mockSSPMetadataCache.getMetadata(ssp)).thenReturn(sspMetadata)
+
     val taskManager = new TaskStorageManagerBuilder()
       .addStore(loggedStore, mockStorageEngine, mockSystemConsumer)
       .setStreamMetadataCache(mockStreamMetadataCache)
+      .setSSPMetadataCache(mockSSPMetadataCache)
       .setSystemAdmin("kafka", mockSystemAdmin)
       .build
 
-
     taskManager.init
 
     assertTrue(storeFile.exists())
     assertFalse(offsetFile.exists())
+    verify(mockSystemConsumer).register(ssp, "0")
 
     // Test 2: flush should update the offset file
     taskManager.flush()
@@ -128,41 +116,28 @@ class TestTaskStorageManager extends MockitoSugar {
     assertEquals("50", FileUtil.readWithChecksum(offsetFile))
 
     // Test 3: Update sspMetadata before shutdown and verify that offset file is updated correctly
-    metadata = new SystemStreamMetadata("testStream", new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
-      {
-        put(partition, new SystemStreamPartitionMetadata("0", "100", "101"))
-      }
-    })
-    when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
-    when(mockSystemAdmin.getSystemStreamMetadata(any())).thenReturn(new util.HashMap[String, SystemStreamMetadata](){
-      {
-        put("testStream", metadata)
-      }
-    })
+    when(mockSSPMetadataCache.getMetadata(ssp)).thenReturn(new SystemStreamPartitionMetadata("0", "100", "101"))
     taskManager.stop()
     assertTrue(storeFile.exists())
     assertTrue(offsetFile.exists())
     assertEquals("100", FileUtil.readWithChecksum(offsetFile))
 
-
     // Test 4: Initialize again with an updated sspMetadata; Verify that it restores from the correct offset
+    sspMetadata = new SystemStreamPartitionMetadata("0", "150", "151")
     metadata = new SystemStreamMetadata("testStream", new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
       {
-        put(partition, new SystemStreamPartitionMetadata("0", "150", "151"))
+        put(partition, sspMetadata)
       }
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
-    when(mockSystemAdmin.getSystemStreamMetadata(any())).thenReturn(new util.HashMap[String, SystemStreamMetadata](){
-      {
-        put("testStream", metadata)
-      }
-    })
-    registerOffset = "100"
+    when(mockSSPMetadataCache.getMetadata(ssp)).thenReturn(sspMetadata)
+    when(mockSystemAdmin.getOffsetsAfter(Map(ssp -> "100").asJava)).thenReturn(Map(ssp -> "101").asJava)
 
     taskManager.init
 
     assertTrue(storeFile.exists())
     assertTrue(offsetFile.exists())
+    verify(mockSystemConsumer).register(ssp, "101")
   }
 
   /**
@@ -187,16 +162,6 @@ class TestTaskStorageManager extends MockitoSugar {
     doNothing().when(mockSystemAdmin).validateStream(changelogSpec)
 
     val mockSystemConsumer = mock[SystemConsumer]
-    when(mockSystemConsumer.register(any(), any())).thenAnswer(new Answer[Unit] {
-      override def answer(invocation: InvocationOnMock): Unit = {
-        val args = invocation.getArguments
-        if (ssp.equals(args.apply(0).asInstanceOf[SystemStreamPartition])) {
-          val offset = args.apply(1).asInstanceOf[String]
-          assertNotNull(offset)
-          assertEquals("0", offset) // Should always restore from earliest offset
-        }
-      }
-    })
     doNothing().when(mockSystemConsumer).stop()
 
     // Test 1: Initial invocation - No store data (only changelog has data)
@@ -214,11 +179,11 @@ class TestTaskStorageManager extends MockitoSugar {
       .setSystemAdmin("kafka", mockSystemAdmin)
       .build
 
-
     taskManager.init
 
     // Verify that the store directory doesn't have ANY files
     assertNull(storeDirectory.listFiles())
+    verify(mockSystemConsumer).register(ssp, "0")
 
     // Test 2: flush should NOT create/update the offset file. Store directory has no files
     taskManager.flush()
@@ -245,6 +210,8 @@ class TestTaskStorageManager extends MockitoSugar {
     taskManager.init
 
     assertNull(storeDirectory.listFiles())
+    // second time to register; make sure it starts from beginning
+    verify(mockSystemConsumer, times(2)).register(ssp, "0")
   }
 
   @Test
@@ -335,15 +302,15 @@ class TestTaskStorageManager extends MockitoSugar {
 
     val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
 
-    val mockSystemAdmin = mock[SystemAdmin]
-    val mockSspMetadata = Map("testStream" -> new SystemStreamMetadata("testStream" , Map(partition -> new SystemStreamPartitionMetadata("20", "100", "101")).asJava))
-    val myMap = mockSspMetadata.asJava
-    when(mockSystemAdmin.getSystemStreamMetadata(any(Set("").asJava.getClass))).thenReturn(myMap)
+    val sspMetadataCache = mock[SSPMetadataCache]
+    val sspMetadata = new SystemStreamPartitionMetadata("20", "100", "101")
+    when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", "testStream", partition)))
+      .thenReturn(sspMetadata)
 
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addStore(loggedStore, true)
-      .setSystemAdmin("kafka", mockSystemAdmin)
+      .setSSPMetadataCache(sspMetadataCache)
       .setPartition(partition)
       .build
 
@@ -356,8 +323,7 @@ class TestTaskStorageManager extends MockitoSugar {
   }
 
   /**
-    * For instances of SystemAdmin, the store manager should call the slow getSystemStreamMetadata() method
-    * which gets offsets for ALL n partitions of the changelog, regardless of how many we need for the current task.
+    * Given that the SSPMetadataCache returns metadata, flush should create the offset files.
     */
   @Test
   def testFlushCreatesOffsetFileForLoggedStore() {
@@ -368,16 +334,16 @@ class TestTaskStorageManager extends MockitoSugar {
       TaskStorageManager.getStorePartitionDir(
         TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, store, taskName) + File.separator + "OFFSET")
 
-    val mockSystemAdmin = mock[SystemAdmin]
-    val mockSspMetadata = Map("testStream" -> new SystemStreamMetadata("testStream" , Map(partition -> new SystemStreamPartitionMetadata("20", "100", "101")).asJava))
-    val myMap = mockSspMetadata.asJava
-    when(mockSystemAdmin.getSystemStreamMetadata(any(Set("").asJava.getClass))).thenReturn(myMap)
+    val sspMetadataCache = mock[SSPMetadataCache]
+    val sspMetadata = new SystemStreamPartitionMetadata("20", "100", "101")
+    when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", "testStream", partition)))
+      .thenReturn(sspMetadata)
 
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
             .addStore(loggedStore, true)
             .addStore(store, false)
-            .setSystemAdmin("kafka", mockSystemAdmin)
+            .setSSPMetadataCache(sspMetadataCache)
             .setPartition(partition)
             .build
 
@@ -392,22 +358,25 @@ class TestTaskStorageManager extends MockitoSugar {
   }
 
   /**
-    * For instances of ExtendedSystemAdmin, the store manager should call the optimized getNewestOffset() method.
-    * Flush should also delete the existing OFFSET file if the changelog partition (for some reason) becomes empty
+    * Flush should delete the existing OFFSET file if the changelog partition (for some reason) becomes empty
     */
   @Test
-  def testFlushCreatesOffsetFileForLoggedStoreExtendedSystemAdmin() {
+  def testFlushDeletesOffsetFileForLoggedStoreForEmptyPartition() {
     val partition = new Partition(0)
 
     val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
 
-    val mockSystemAdmin = mock[ExtendedSystemAdmin]
-    when(mockSystemAdmin.getNewestOffset(any(classOf[SystemStreamPartition]), anyInt())).thenReturn("100").thenReturn(null)
+    val sspMetadataCache = mock[SSPMetadataCache]
+    when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", "testStream", partition)))
+      // first return some metadata
+      .thenReturn(new SystemStreamPartitionMetadata("0", "100", "101"))
+      // then return no metadata to trigger the delete
+      .thenReturn(null)
 
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
             .addStore(loggedStore, true)
-            .setSystemAdmin("kafka", mockSystemAdmin)
+            .setSSPMetadataCache(sspMetadataCache)
             .setPartition(partition)
             .build
 
@@ -428,19 +397,18 @@ class TestTaskStorageManager extends MockitoSugar {
   @Test
   def testFlushOverwritesOffsetFileForLoggedStore() {
     val partition = new Partition(0)
+    val ssp = new SystemStreamPartition("kafka", "testStream", partition)
 
     val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
     FileUtil.writeWithChecksum(offsetFilePath, "100")
 
-    val mockSystemAdmin = mock[SystemAdmin]
-    var mockSspMetadata = Map("testStream" -> new SystemStreamMetadata("testStream" , Map(partition -> new SystemStreamPartitionMetadata("20", "139", "140")).asJava))
-    var myMap = mockSspMetadata.asJava
-    when(mockSystemAdmin.getSystemStreamMetadata(any(Set("").asJava.getClass))).thenReturn(myMap)
+    val sspMetadataCache = mock[SSPMetadataCache]
+    when(sspMetadataCache.getMetadata(ssp)).thenReturn(new SystemStreamPartitionMetadata("20", "139", "140"))
 
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
             .addStore(loggedStore, true)
-            .setSystemAdmin("kafka", mockSystemAdmin)
+            .setSSPMetadataCache(sspMetadataCache)
             .setPartition(partition)
             .build
 
@@ -452,9 +420,7 @@ class TestTaskStorageManager extends MockitoSugar {
     assertEquals("Found incorrect value in offset file!", "139", FileUtil.readWithChecksum(offsetFilePath))
 
     // Flush again
-    mockSspMetadata = Map("testStream" -> new SystemStreamMetadata("testStream" , Map(partition -> new SystemStreamPartitionMetadata("20", "193", "194")).asJava))
-    myMap = mockSspMetadata.asJava
-    when(mockSystemAdmin.getSystemStreamMetadata(any(Set("").asJava.getClass))).thenReturn(myMap)
+    when(sspMetadataCache.getMetadata(ssp)).thenReturn(new SystemStreamPartitionMetadata("20", "193", "194"))
 
     //Invoke test method
     taskStorageManager.flush()
@@ -470,15 +436,13 @@ class TestTaskStorageManager extends MockitoSugar {
 
     val offsetFilePath = new File(TaskStorageManager.getStorePartitionDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName) + File.separator + "OFFSET")
 
-    val mockSystemAdmin = mock[SystemAdmin]
-    val mockSspMetadata = Map("testStream" -> new SystemStreamMetadata("testStream" , Map(partition -> new SystemStreamPartitionMetadata("20", null, null)).asJava))
-    val myMap = mockSspMetadata.asJava
-    when(mockSystemAdmin.getSystemStreamMetadata(any(Set("").asJava.getClass))).thenReturn(myMap)
+    val sspMetadataCache = mock[SSPMetadataCache]
+    when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", "testStream", partition))).thenReturn(null)
 
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addStore(loggedStore, true)
-      .setSystemAdmin("kafka", mockSystemAdmin)
+      .setSSPMetadataCache(sspMetadataCache)
       .setPartition(partition)
       .build
 
@@ -661,6 +625,7 @@ class TaskStorageManagerBuilder extends MockitoSugar {
   var storeConsumers: Map[String, SystemConsumer] = Map()
   var changeLogSystemStreams: Map[String, SystemStream] = Map()
   var streamMetadataCache = mock[StreamMetadataCache]
+  var sspMetadataCache = mock[SSPMetadataCache]
   var partition: Partition = new Partition(0)
   var systemAdminsMap: Map[String, SystemAdmin] = Map("kafka" -> mock[SystemAdmin])
   var taskName: TaskName = new TaskName("testTask")
@@ -707,6 +672,11 @@ class TaskStorageManagerBuilder extends MockitoSugar {
     this
   }
 
+  def setSSPMetadataCache(cache: SSPMetadataCache) = {
+    sspMetadataCache = cache
+    this
+  }
+
   def build: TaskStorageManager = {
     new TaskStorageManager(
       taskName = taskName,
@@ -715,6 +685,7 @@ class TaskStorageManagerBuilder extends MockitoSugar {
       changeLogSystemStreams = changeLogSystemStreams,
       changeLogStreamPartitions = changeLogStreamPartitions,
       streamMetadataCache = streamMetadataCache,
+      sspMetadataCache = sspMetadataCache,
       nonLoggedStoreBaseDir = storeBaseDir,
       loggedStoreBaseDir = loggedStoreBaseDir,
       partition = partition,