SAMZA-1719: Add caching support to table-api
authorPeng Du <pdu@linkedin.com>
Thu, 31 May 2018 17:43:30 +0000 (10:43 -0700)
committerJagadish <jvenkatraman@linkedin.com>
Thu, 31 May 2018 17:43:30 +0000 (10:43 -0700)
This change adds caching support for Samza tables. This is especially
useful for remote table where the accesses can have high latency for
applications that can tolerate staleness. Caching is added in the form
of a composite table that wraps the actual table and a cache. We reuse
the ReadWriteTable interface for the cache. A simple Guava cache-based
table is provided in this change.

Original PR was inadvertently closed: https://github.com/apache/samza/pull/524

Author: Peng Du <pdu@linkedin.com>

Reviewers: Jagadish <jagadish@apache.org>, Wei <wsong@linkedin.com>

Closes #531 from pdu-mn1/table-cache

14 files changed:
samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/caching/CachingTableDescriptor.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/caching/CachingTableProvider.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/caching/CachingTableProviderFactory.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/caching/SupplierGauge.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableDescriptor.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableProvider.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableProviderFactory.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableDescriptor.java
samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
samza-core/src/main/java/org/apache/samza/table/utils/SerdeUtils.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java [new file with mode: 0644]
samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java

diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java b/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
new file mode 100644 (file)
index 0000000..989828c
--- /dev/null
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.Lock;
+
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.storage.kv.Entry;
+import org.apache.samza.table.ReadWriteTable;
+import org.apache.samza.table.ReadableTable;
+import org.apache.samza.task.TaskContext;
+
+import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.Striped;
+
+
+/**
+ * A composite table incorporating a cache with a Samza table. The cache is
+ * represented as a {@link ReadWriteTable}.
+ *
+ * The intented use case is to optimize the latency of accessing the actual table, eg.
+ * remote tables, when eventual consistency between cache and table is acceptable.
+ * The cache is expected to support TTL such that the values can be refreshed at some
+ * point.
+ *
+ * If the actual table is read-write table, CachingTable supports both write-through
+ * and write-around (writes bypassing cache) policies. For write-through policy, it
+ * supports read-after-write semantics because the value is cached after written to
+ * the table.
+ *
+ * Table and cache are updated (put/delete) in an atomic manner as such it is thread
+ * safe for concurrent accesses. Strip locks are used for fine-grained synchronization
+ * and the number of stripes is configurable.
+ *
+ * NOTE: Cache get is not synchronized with put for better parallelism in the read path.
+ * As such, cache table implementation is expected to be thread-safe for concurrent
+ * accesses.
+ *
+ * @param <K> type of the table key
+ * @param <V> type of the table value
+ */
+public class CachingTable<K, V> implements ReadWriteTable<K, V> {
+  private static final String GROUP_NAME = CachingTable.class.getSimpleName();
+
+  private final String tableId;
+  private final ReadableTable<K, V> rdTable;
+  private final ReadWriteTable<K, V> rwTable;
+  private final ReadWriteTable<K, V> cache;
+  private final boolean isWriteAround;
+
+  // Use stripe based locking to allow parallelism of disjoint keys.
+  private final Striped<Lock> stripedLocks;
+
+  // Common caching stats
+  private AtomicLong hitCount = new AtomicLong();
+  private AtomicLong missCount = new AtomicLong();
+
+  public CachingTable(String tableId, ReadableTable<K, V> table, ReadWriteTable<K, V> cache, int stripes, boolean isWriteAround) {
+    this.tableId = tableId;
+    this.rdTable = table;
+    this.rwTable = table instanceof ReadWriteTable ? (ReadWriteTable) table : null;
+    this.cache = cache;
+    this.isWriteAround = isWriteAround;
+    this.stripedLocks = Striped.lazyWeakLock(stripes);
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
+    MetricsRegistry metricsRegistry = taskContext.getMetricsRegistry();
+    metricsRegistry.newGauge(GROUP_NAME, new SupplierGauge(tableId + "-hit-rate", () -> hitRate()));
+    metricsRegistry.newGauge(GROUP_NAME, new SupplierGauge(tableId + "-miss-rate", () -> missRate()));
+    metricsRegistry.newGauge(GROUP_NAME, new SupplierGauge(tableId + "-req-count", () -> requestCount()));
+  }
+
+  @Override
+  public V get(K key) {
+    V value = cache.get(key);
+    if (value == null) {
+      missCount.incrementAndGet();
+      Lock lock = stripedLocks.get(key);
+      try {
+        lock.lock();
+        if (cache.get(key) == null) {
+          // Due to the lack of contains() API in ReadableTable, there is
+          // no way to tell whether a null return by cache.get(key) means
+          // cache miss or the value is actually null. As such, we cannot
+          // support negative cache semantics.
+          value = rdTable.get(key);
+          if (value != null) {
+            cache.put(key, value);
+          }
+        }
+      } finally {
+        lock.unlock();
+      }
+    } else {
+      hitCount.incrementAndGet();
+    }
+    return value;
+  }
+
+  @Override
+  public Map<K, V> getAll(List<K> keys) {
+    Map<K, V> getAllResult = new HashMap<>();
+    keys.stream().forEach(k -> getAllResult.put(k, get(k)));
+    return getAllResult;
+  }
+
+  @Override
+  public void put(K key, V value) {
+    Preconditions.checkNotNull(rwTable, "Cannot write to a read-only table: " + rdTable);
+    Lock lock = stripedLocks.get(key);
+    try {
+      lock.lock();
+      rwTable.put(key, value);
+      if (!isWriteAround) {
+        cache.put(key, value);
+      }
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public void putAll(List<Entry<K, V>> entries) {
+    Preconditions.checkNotNull(rwTable, "Cannot write to a read-only table: " + rdTable);
+    entries.forEach(e -> put(e.getKey(), e.getValue()));
+  }
+
+  @Override
+  public void delete(K key) {
+    Preconditions.checkNotNull(rwTable, "Cannot delete from a read-only table: " + rdTable);
+    Lock lock = stripedLocks.get(key);
+    try {
+      lock.lock();
+      rwTable.delete(key);
+      cache.delete(key);
+    } finally {
+      lock.unlock();
+    }
+  }
+
+  @Override
+  public void deleteAll(List<K> keys) {
+    Preconditions.checkNotNull(rwTable, "Cannot delete from a read-only table: " + rdTable);
+    keys.stream().forEach(k -> delete(k));
+  }
+
+  @Override
+  public synchronized void flush() {
+    Preconditions.checkNotNull(rwTable, "Cannot flush a read-only table: " + rdTable);
+    rwTable.flush();
+  }
+
+  @Override
+  public void close() {
+    this.cache.close();
+    this.rdTable.close();
+  }
+
+  double hitRate() {
+    long reqs = requestCount();
+    return reqs == 0 ? 1.0 : (double) hitCount.get() / reqs;
+  }
+
+  double missRate() {
+    long reqs = requestCount();
+    return reqs == 0 ? 1.0 : (double) missCount.get() / reqs;
+  }
+
+  long requestCount() {
+    return hitCount.get() + missCount.get();
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/CachingTableDescriptor.java b/samza-core/src/main/java/org/apache/samza/table/caching/CachingTableDescriptor.java
new file mode 100644 (file)
index 0000000..eb74825
--- /dev/null
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching;
+
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.samza.operators.BaseTableDescriptor;
+import org.apache.samza.operators.KV;
+import org.apache.samza.operators.TableImpl;
+import org.apache.samza.table.Table;
+import org.apache.samza.table.TableSpec;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Table descriptor for {@link CachingTable}.
+ * @param <K> type of the key in the cache
+ * @param <V> type of the value in the cache
+ */
+public class CachingTableDescriptor<K, V> extends BaseTableDescriptor<K, V, CachingTableDescriptor<K, V>> {
+  private Duration readTtl;
+  private Duration writeTtl;
+  private long cacheSize;
+  private Table<KV<K, V>> cache;
+  private Table<KV<K, V>> table;
+  private int stripes = 16;
+  private boolean isWriteAround;
+
+  /**
+   * Constructs a table descriptor instance
+   * @param tableId Id of the table
+   */
+  public CachingTableDescriptor(String tableId) {
+    super(tableId);
+  }
+
+  @Override
+  public TableSpec getTableSpec() {
+    validate();
+
+    Map<String, String> tableSpecConfig = new HashMap<>();
+    generateTableSpecConfig(tableSpecConfig);
+
+    if (cache != null) {
+      tableSpecConfig.put(CachingTableProvider.CACHE_TABLE_ID, ((TableImpl) cache).getTableSpec().getId());
+    } else {
+      if (readTtl != null) {
+        tableSpecConfig.put(CachingTableProvider.READ_TTL_MS, String.valueOf(readTtl.toMillis()));
+      }
+      if (writeTtl != null) {
+        tableSpecConfig.put(CachingTableProvider.WRITE_TTL_MS, String.valueOf(writeTtl.toMillis()));
+      }
+      if (cacheSize > 0) {
+        tableSpecConfig.put(CachingTableProvider.CACHE_SIZE, String.valueOf(cacheSize));
+      }
+    }
+
+    tableSpecConfig.put(CachingTableProvider.REAL_TABLE_ID, ((TableImpl) table).getTableSpec().getId());
+    tableSpecConfig.put(CachingTableProvider.LOCK_STRIPES, String.valueOf(stripes));
+    tableSpecConfig.put(CachingTableProvider.WRITE_AROUND, String.valueOf(isWriteAround));
+
+    return new TableSpec(tableId, serde, CachingTableProviderFactory.class.getName(), tableSpecConfig);
+  }
+
+  /**
+   * Specify a cache instance (as Table abstraction) to be used for caching.
+   * Cache get is not synchronized with put for better parallelism in the read path
+   * of {@link CachingTable}. As such, cache table implementation is expected to be
+   * thread-safe for concurrent accesses.
+   * @param cache cache instance
+   * @return this descriptor
+   */
+  public CachingTableDescriptor withCache(Table<KV<K, V>> cache) {
+    this.cache = cache;
+    return this;
+  }
+
+  /**
+   * Specify the table instance for the actual table input/output.
+   * @param table table instance
+   * @return this descriptor
+   */
+  public CachingTableDescriptor withTable(Table<KV<K, V>> table) {
+    this.table = table;
+    return this;
+  }
+
+  /**
+   * Specify the TTL for each read access, ie. record is expired after
+   * the TTL duration since last read access of each key.
+   * @param readTtl read TTL
+   * @return this descriptor
+   */
+  public CachingTableDescriptor withReadTtl(Duration readTtl) {
+    this.readTtl = readTtl;
+    return this;
+  }
+
+  /**
+   * Specify the TTL for each write access, ie. record is expired after
+   * the TTL duration since last write access of each key.
+   * @param writeTtl write TTL
+   * @return this descriptor
+   */
+  public CachingTableDescriptor withWriteTtl(Duration writeTtl) {
+    this.writeTtl = writeTtl;
+    return this;
+  }
+
+  /**
+   * Specify the max cache size for size-based eviction.
+   * @param cacheSize max size of the cache
+   * @return this descriptor
+   */
+  public CachingTableDescriptor withCacheSize(long cacheSize) {
+    this.cacheSize = cacheSize;
+    return this;
+  }
+
+  /**
+   * Specify the number of stripes for striped locking for atomically updating
+   * cache and the actual table. Default number of stripes is 16.
+   * @param stripes number of stripes for locking
+   * @return this descriptor
+   */
+  public CachingTableDescriptor withStripes(int stripes) {
+    this.stripes = stripes;
+    return this;
+  }
+
+  /**
+   * Specify if write-around policy should be used to bypass writing
+   * to cache for put operations. This is useful when put() is the
+   * dominant operation and get() has no locality with recent puts.
+   * @return this descriptor
+   */
+  public CachingTableDescriptor withWriteAround() {
+    this.isWriteAround = true;
+    return this;
+  }
+
+  @Override
+  protected void validate() {
+    super.validate();
+    Preconditions.checkNotNull(table, "Actual table is required.");
+    if (cache == null) {
+      Preconditions.checkNotNull(readTtl, "readTtl must be specified.");
+    } else {
+      Preconditions.checkArgument(readTtl == null && writeTtl == null && cacheSize == 0,
+          "Invalid to specify both {cache} and {readTtl|writeTtl|cacheSize} at the same time.");
+    }
+    Preconditions.checkArgument(stripes > 0, "Number of cache stripes must be positive.");
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/CachingTableProvider.java b/samza-core/src/main/java/org/apache/samza/table/caching/CachingTableProvider.java
new file mode 100644 (file)
index 0000000..52d0c94
--- /dev/null
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.samza.config.JavaTableConfig;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.table.ReadWriteTable;
+import org.apache.samza.table.ReadableTable;
+import org.apache.samza.table.Table;
+import org.apache.samza.table.TableProvider;
+import org.apache.samza.table.TableSpec;
+import org.apache.samza.table.caching.guava.GuavaCacheTable;
+import org.apache.samza.task.TaskContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.cache.CacheBuilder;
+
+/**
+ * Table provider for {@link CachingTable}.
+ */
+public class CachingTableProvider implements TableProvider {
+  private static final Logger LOG = LoggerFactory.getLogger(CachingTableProvider.class);
+
+  public static final String REAL_TABLE_ID = "realTableId";
+  public static final String CACHE_TABLE_ID = "cacheTableId";
+  public static final String READ_TTL_MS = "readTtl";
+  public static final String WRITE_TTL_MS = "writeTtl";
+  public static final String CACHE_SIZE = "cacheSize";
+  public static final String LOCK_STRIPES = "lockStripes";
+  public static final String WRITE_AROUND = "writeAround";
+
+  private final TableSpec cachingTableSpec;
+
+  // Store the cache instances created by default
+  private final List<ReadWriteTable> defaultCaches = new ArrayList<>();
+
+  private SamzaContainerContext containerContext;
+  private TaskContext taskContext;
+
+  public CachingTableProvider(TableSpec tableSpec) {
+    this.cachingTableSpec = tableSpec;
+  }
+
+  @Override
+  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
+    this.taskContext = taskContext;
+    this.containerContext = containerContext;
+  }
+
+  @Override
+  public Table getTable() {
+    String realTableId = cachingTableSpec.getConfig().get(REAL_TABLE_ID);
+    ReadableTable table = (ReadableTable) taskContext.getTable(realTableId);
+
+    String cacheTableId = cachingTableSpec.getConfig().get(CACHE_TABLE_ID);
+    ReadWriteTable cache;
+
+    if (cacheTableId != null) {
+      cache = (ReadWriteTable) taskContext.getTable(cacheTableId);
+    } else {
+      cache = createDefaultCacheTable(realTableId);
+      defaultCaches.add(cache);
+    }
+
+    int stripes = Integer.parseInt(cachingTableSpec.getConfig().get(LOCK_STRIPES));
+    boolean isWriteAround = Boolean.parseBoolean(cachingTableSpec.getConfig().get(WRITE_AROUND));
+    return new CachingTable(cachingTableSpec.getId(), table, cache, stripes, isWriteAround);
+  }
+
+  @Override
+  public Map<String, String> generateConfig(Map<String, String> config) {
+    Map<String, String> tableConfig = new HashMap<>();
+
+    // Insert table_id prefix to config entries
+    cachingTableSpec.getConfig().forEach((k, v) -> {
+        String realKey = String.format(JavaTableConfig.TABLE_ID_PREFIX, cachingTableSpec.getId()) + "." + k;
+        tableConfig.put(realKey, v);
+      });
+
+    LOG.info("Generated configuration for table " + cachingTableSpec.getId());
+
+    return tableConfig;
+  }
+
+  @Override
+  public void close() {
+    defaultCaches.forEach(c -> c.close());
+  }
+
+  private ReadWriteTable createDefaultCacheTable(String tableId) {
+    long readTtlMs = Long.parseLong(cachingTableSpec.getConfig().getOrDefault(READ_TTL_MS, "-1"));
+    long writeTtlMs = Long.parseLong(cachingTableSpec.getConfig().getOrDefault(WRITE_TTL_MS, "-1"));
+    long cacheSize = Long.parseLong(cachingTableSpec.getConfig().getOrDefault(CACHE_SIZE, "-1"));
+
+    CacheBuilder cacheBuilder = CacheBuilder.newBuilder();
+    if (readTtlMs != -1) {
+      cacheBuilder.expireAfterAccess(readTtlMs, TimeUnit.MILLISECONDS);
+    }
+    if (writeTtlMs != -1) {
+      cacheBuilder.expireAfterWrite(writeTtlMs, TimeUnit.MILLISECONDS);
+    }
+    if (cacheSize != -1) {
+      cacheBuilder.maximumSize(cacheSize);
+    }
+
+    LOG.info(String.format("Creating default cache with: readTtl=%d, writeTtl=%d, maxSize=%d",
+        readTtlMs, writeTtlMs, cacheSize));
+
+    GuavaCacheTable cacheTable = new GuavaCacheTable(tableId + "-def-cache", cacheBuilder.build());
+    cacheTable.init(containerContext, taskContext);
+
+    return cacheTable;
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/CachingTableProviderFactory.java b/samza-core/src/main/java/org/apache/samza/table/caching/CachingTableProviderFactory.java
new file mode 100644 (file)
index 0000000..9262207
--- /dev/null
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching;
+
+import org.apache.samza.table.TableProvider;
+import org.apache.samza.table.TableProviderFactory;
+import org.apache.samza.table.TableSpec;
+
+/**
+ * Table provider factory for {@link CachingTable}.
+ */
+public class CachingTableProviderFactory implements TableProviderFactory {
+  @Override
+  public TableProvider getTableProvider(TableSpec tableSpec) {
+    return new CachingTableProvider(tableSpec);
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/SupplierGauge.java b/samza-core/src/main/java/org/apache/samza/table/caching/SupplierGauge.java
new file mode 100644 (file)
index 0000000..8c8d4dc
--- /dev/null
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching;
+
+import java.util.function.Supplier;
+
+import org.apache.samza.metrics.Gauge;
+
+import com.google.common.base.Preconditions;
+
+
+/**
+ * Simple Gauge backed by an external supplier expression.
+ * @param <T> data type of the gauge
+ */
+public class SupplierGauge<T> extends Gauge<T> {
+  private Supplier<T> supplier;
+
+  public SupplierGauge(String name, Supplier<T> supplier) {
+    super(name, null);
+    Preconditions.checkNotNull(supplier);
+    this.supplier = supplier;
+  }
+
+  @Override
+  public T getValue() {
+    return supplier.get();
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java b/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
new file mode 100644 (file)
index 0000000..27bf971
--- /dev/null
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching.guava;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.storage.kv.Entry;
+import org.apache.samza.table.ReadWriteTable;
+import org.apache.samza.table.caching.SupplierGauge;
+import org.apache.samza.task.TaskContext;
+
+import com.google.common.cache.Cache;
+
+
+/**
+ * Simple cache table backed by a Guava cache instance. Application is expect to build
+ * a cache instance with desired parameters and specify it to the table descriptor.
+ *
+ * @param <K> type of the key in the cache
+ * @param <V> type of the value in the cache
+ */
+public class GuavaCacheTable<K, V> implements ReadWriteTable<K, V> {
+  private static final String GROUP_NAME = GuavaCacheTableProvider.class.getSimpleName();
+
+  private final String tableId;
+  private final Cache<K, V> cache;
+
+  public GuavaCacheTable(String tableId, Cache<K, V> cache) {
+    this.tableId = tableId;
+    this.cache = cache;
+  }
+
+  private void registerMetrics(String tableId, Cache cache, MetricsRegistry metricsReg) {
+    // hit- and miss-rate are provided by CachingTable.
+    metricsReg.newGauge(GROUP_NAME, new SupplierGauge(tableId + "-evict-count", () -> cache.stats().evictionCount()));
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
+    registerMetrics(tableId, cache, taskContext.getMetricsRegistry());
+  }
+
+  @Override
+  public void put(K key, V value) {
+    cache.put(key, value);
+  }
+
+  @Override
+  public void putAll(List<Entry<K, V>> entries) {
+    entries.forEach(e -> put(e.getKey(), e.getValue()));
+  }
+
+  @Override
+  public void delete(K key) {
+    cache.invalidate(key);
+  }
+
+  @Override
+  public void deleteAll(List<K> keys) {
+    keys.forEach(k -> delete(k));
+  }
+
+  @Override
+  public synchronized void flush() {
+    cache.cleanUp();
+  }
+
+  @Override
+  public V get(K key) {
+    return cache.getIfPresent(key);
+  }
+
+  @Override
+  public Map<K, V> getAll(List<K> keys) {
+    Map<K, V> getAllResult = new HashMap<>();
+    keys.stream().forEach(k -> getAllResult.put(k, get(k)));
+    return getAllResult;
+  }
+
+  @Override
+  public synchronized void close() {
+    cache.invalidateAll();
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableDescriptor.java b/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableDescriptor.java
new file mode 100644 (file)
index 0000000..ce125c0
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching.guava;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.samza.operators.BaseTableDescriptor;
+import org.apache.samza.table.TableSpec;
+import org.apache.samza.table.utils.SerdeUtils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.cache.Cache;
+
+
+/**
+ * Table descriptor for {@link GuavaCacheTable}.
+ * @param <K> type of the key in the cache
+ * @param <V> type of the value in the cache
+ */
+public class GuavaCacheTableDescriptor<K, V> extends BaseTableDescriptor<K, V, GuavaCacheTableDescriptor<K, V>> {
+  private Cache<K, V> cache;
+
+  /**
+   * Constructs a table descriptor instance
+   * @param tableId Id of the table
+   */
+  public GuavaCacheTableDescriptor(String tableId) {
+    super(tableId);
+  }
+
+  @Override
+  public TableSpec getTableSpec() {
+    validate();
+
+    Map<String, String> tableSpecConfig = new HashMap<>();
+    generateTableSpecConfig(tableSpecConfig);
+
+    tableSpecConfig.put(GuavaCacheTableProvider.GUAVA_CACHE, SerdeUtils.serialize("Guava cache", cache));
+
+    return new TableSpec(tableId, serde, GuavaCacheTableProviderFactory.class.getName(), tableSpecConfig);
+  }
+
+  /**
+   * Specify a pre-configured Guava cache instance to be used for caching table.
+   * @param cache Guava cache instance
+   * @return this descriptor
+   */
+  public GuavaCacheTableDescriptor withCache(Cache<K, V> cache) {
+    this.cache = cache;
+    return this;
+  }
+
+  @Override
+  protected void validate() {
+    super.validate();
+    Preconditions.checkArgument(cache != null, "Must provide a Guava cache instance.");
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableProvider.java b/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableProvider.java
new file mode 100644 (file)
index 0000000..7395b22
--- /dev/null
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching.guava;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.samza.config.JavaTableConfig;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.table.Table;
+import org.apache.samza.table.TableProvider;
+import org.apache.samza.table.TableSpec;
+import org.apache.samza.table.utils.SerdeUtils;
+import org.apache.samza.task.TaskContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.cache.Cache;
+
+
+/**
+ * Table provider for {@link GuavaCacheTable}.
+ */
+public class GuavaCacheTableProvider implements TableProvider {
+  private static final Logger LOG = LoggerFactory.getLogger(GuavaCacheTableProvider.class);
+
+  public static final String GUAVA_CACHE = "guavaCache";
+
+  private final TableSpec guavaCacheTableSpec;
+
+  private SamzaContainerContext containerContext;
+  private TaskContext taskContext;
+
+  private List<GuavaCacheTable> guavaTables = new ArrayList<>();
+
+  public GuavaCacheTableProvider(TableSpec tableSpec) {
+    this.guavaCacheTableSpec = tableSpec;
+  }
+
+  @Override
+  public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
+    this.taskContext = taskContext;
+    this.containerContext = containerContext;
+  }
+
+  @Override
+  public Table getTable() {
+    Cache guavaCache = SerdeUtils.deserialize(GUAVA_CACHE, guavaCacheTableSpec.getConfig().get(GUAVA_CACHE));
+    GuavaCacheTable table = new GuavaCacheTable(guavaCacheTableSpec.getId(), guavaCache);
+    guavaTables.add(table);
+    return table;
+  }
+
+  @Override
+  public Map<String, String> generateConfig(Map<String, String> config) {
+    Map<String, String> tableConfig = new HashMap<>();
+
+    // Insert table_id prefix to config entries
+    guavaCacheTableSpec.getConfig().forEach((k, v) -> {
+        String realKey = String.format(JavaTableConfig.TABLE_ID_PREFIX, guavaCacheTableSpec.getId()) + "." + k;
+        tableConfig.put(realKey, v);
+      });
+
+    LOG.info("Generated configuration for table " + guavaCacheTableSpec.getId());
+
+    return tableConfig;
+  }
+
+  @Override
+  public void close() {
+    guavaTables.forEach(t -> t.close());
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableProviderFactory.java b/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTableProviderFactory.java
new file mode 100644 (file)
index 0000000..066c6f9
--- /dev/null
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching.guava;
+
+import org.apache.samza.table.TableProvider;
+import org.apache.samza.table.TableProviderFactory;
+import org.apache.samza.table.TableSpec;
+
+/**
+ * Table provider factory for {@link GuavaCacheTable}.
+ */
+public class GuavaCacheTableProviderFactory implements TableProviderFactory {
+  @Override
+  public TableProvider getTableProvider(TableSpec tableSpec) {
+    return new GuavaCacheTableProvider(tableSpec);
+  }
+}
index 7bc369d..bad4639 100644 (file)
 
 package org.apache.samza.table.remote;
 
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.ObjectOutputStream;
-import java.util.Base64;
 import java.util.HashMap;
 import java.util.Map;
 
-import org.apache.samza.SamzaException;
 import org.apache.samza.operators.BaseTableDescriptor;
 import org.apache.samza.table.TableSpec;
+import org.apache.samza.table.utils.SerdeUtils;
 import org.apache.samza.util.EmbeddedTaggedRateLimiter;
 import org.apache.samza.util.RateLimiter;
 
@@ -89,10 +85,10 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
     generateTableSpecConfig(tableSpecConfig);
 
     // Serialize and store reader/writer functions
-    tableSpecConfig.put(RemoteTableProvider.READ_FN, serializeObject("read function", readFn));
+    tableSpecConfig.put(RemoteTableProvider.READ_FN, SerdeUtils.serialize("read function", readFn));
 
     if (writeFn != null) {
-      tableSpecConfig.put(RemoteTableProvider.WRITE_FN, serializeObject("write function", writeFn));
+      tableSpecConfig.put(RemoteTableProvider.WRITE_FN, SerdeUtils.serialize("write function", writeFn));
     }
 
     // Serialize the rate limiter if specified
@@ -101,17 +97,17 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
     }
 
     if (rateLimiter != null) {
-      tableSpecConfig.put(RemoteTableProvider.RATE_LIMITER, serializeObject("rate limiter", rateLimiter));
+      tableSpecConfig.put(RemoteTableProvider.RATE_LIMITER, SerdeUtils.serialize("rate limiter", rateLimiter));
     }
 
     // Serialize the readCredit and writeCredit functions
     if (readCreditFn != null) {
-      tableSpecConfig.put(RemoteTableProvider.READ_CREDIT_FN, serializeObject(
+      tableSpecConfig.put(RemoteTableProvider.READ_CREDIT_FN, SerdeUtils.serialize(
           "read credit function", readCreditFn));
     }
 
     if (writeCreditFn != null) {
-      tableSpecConfig.put(RemoteTableProvider.WRITE_CREDIT_FN, serializeObject(
+      tableSpecConfig.put(RemoteTableProvider.WRITE_CREDIT_FN, SerdeUtils.serialize(
           "write credit function", writeCreditFn));
     }
 
@@ -188,22 +184,6 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
     return this;
   }
 
-  /**
-   * Helper method to serialize Java objects as Base64 strings
-   * @param name name of the object (for error reporting)
-   * @param object object to be serialized
-   * @return Base64 representation of the object
-   */
-  private <T> String serializeObject(String name, T object) {
-    try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
-        ObjectOutputStream oos = new ObjectOutputStream(baos)) {
-      oos.writeObject(object);
-      return Base64.getEncoder().encodeToString(baos.toByteArray());
-    } catch (IOException e) {
-      throw new SamzaException("Failed to serialize " + name, e);
-    }
-  }
-
   @Override
   protected void validate() {
     super.validate();
index 8b9001a..b4051cb 100644 (file)
 
 package org.apache.samza.table.remote;
 
-import java.io.ByteArrayInputStream;
-import java.io.ObjectInputStream;
 import java.util.ArrayList;
-import java.util.Base64;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
-import org.apache.samza.SamzaException;
 import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.container.SamzaContainerContext;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.TableProvider;
 import org.apache.samza.table.TableSpec;
+import org.apache.samza.table.utils.SerdeUtils;
 import org.apache.samza.task.TaskContext;
 import org.apache.samza.util.RateLimiter;
 import org.slf4j.Logger;
@@ -101,7 +98,7 @@ public class RemoteTableProvider implements TableProvider {
   public Map<String, String> generateConfig(Map<String, String> config) {
     Map<String, String> tableConfig = new HashMap<>();
 
-    // Insert table_id prefix to config entires
+    // Insert table_id prefix to config entries
     tableSpec.getConfig().forEach((k, v) -> {
         String realKey = String.format(JavaTableConfig.TABLE_ID_PREFIX, tableSpec.getId()) + "." + k;
         tableConfig.put(realKey, v);
@@ -125,14 +122,7 @@ public class RemoteTableProvider implements TableProvider {
     if (entry.isEmpty()) {
       return null;
     }
-
-    try {
-      byte [] bytes = Base64.getDecoder().decode(entry);
-      return (T) new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject();
-    } catch (Exception e) {
-      String errMsg = "Failed to deserialize " + key;
-      throw new SamzaException(errMsg, e);
-    }
+    return SerdeUtils.deserialize(key, entry);
   }
 
   private TableReadFunction<?, ?> getReadFn() {
diff --git a/samza-core/src/main/java/org/apache/samza/table/utils/SerdeUtils.java b/samza-core/src/main/java/org/apache/samza/table/utils/SerdeUtils.java
new file mode 100644 (file)
index 0000000..a7b66e5
--- /dev/null
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.utils;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.Base64;
+
+import org.apache.samza.SamzaException;
+
+
+public final class SerdeUtils {
+  /**
+   * Helper method to serialize Java objects as Base64 strings
+   * @param name name of the object (for error reporting)
+   * @param object object to be serialized
+   * @return Base64 representation of the object
+   * @param <T> type of the object
+   */
+  public static <T> String serialize(String name, T object) {
+    try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+        ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+      oos.writeObject(object);
+      return Base64.getEncoder().encodeToString(baos.toByteArray());
+    } catch (IOException e) {
+      throw new SamzaException("Failed to serialize " + name, e);
+    }
+  }
+
+  /**
+   * Helper method to deserialize Java objects from Base64 strings
+   * @param name name of the object (for error reporting)
+   * @param strObject base64 string of the serialized object
+   * @return deserialized object instance
+   * @param <T> type of the object
+   */
+  public static <T> T deserialize(String name, String strObject) {
+    try {
+      byte [] bytes = Base64.getDecoder().decode(strObject);
+      return (T) new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject();
+    } catch (Exception e) {
+      String errMsg = "Failed to deserialize " + name;
+      throw new SamzaException(errMsg, e);
+    }
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java b/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
new file mode 100644 (file)
index 0000000..769fb7d
--- /dev/null
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.caching;
+
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.samza.operators.TableImpl;
+import org.apache.samza.table.ReadWriteTable;
+import org.apache.samza.table.ReadableTable;
+import org.apache.samza.table.Table;
+import org.apache.samza.table.TableSpec;
+import org.apache.samza.table.caching.guava.GuavaCacheTableDescriptor;
+import org.apache.samza.table.caching.guava.GuavaCacheTableProvider;
+import org.apache.samza.task.TaskContext;
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.google.common.cache.CacheBuilder;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+
+public class TestCachingTable {
+  @Test
+  public void testSerializeSimple() {
+    doTestSerialize(null);
+  }
+
+  @Test
+  public void testSerializeWithCacheInstance() {
+    GuavaCacheTableDescriptor guavaTableDesc = new GuavaCacheTableDescriptor("guavaCacheId");
+    guavaTableDesc.withCache(CacheBuilder.newBuilder().build());
+    TableSpec spec = guavaTableDesc.getTableSpec();
+    Assert.assertTrue(spec.getConfig().containsKey(GuavaCacheTableProvider.GUAVA_CACHE));
+    doTestSerialize(new TableImpl(guavaTableDesc.getTableSpec()));
+  }
+
+  private void doTestSerialize(Table cache) {
+    CachingTableDescriptor desc = new CachingTableDescriptor("1");
+    desc.withTable(new TableImpl(new TableSpec("2", null, null, new HashMap<>())));
+    if (cache == null) {
+      desc.withReadTtl(Duration.ofMinutes(3));
+      desc.withWriteTtl(Duration.ofMinutes(3));
+      desc.withCacheSize(1000);
+    } else {
+      desc.withCache(cache);
+    }
+
+    desc.withStripes(32);
+    desc.withWriteAround();
+
+    TableSpec spec = desc.getTableSpec();
+    Assert.assertTrue(spec.getConfig().containsKey(CachingTableProvider.REAL_TABLE_ID));
+
+    if (cache == null) {
+      Assert.assertTrue(spec.getConfig().containsKey(CachingTableProvider.READ_TTL_MS));
+      Assert.assertTrue(spec.getConfig().containsKey(CachingTableProvider.WRITE_TTL_MS));
+    } else {
+      Assert.assertTrue(spec.getConfig().containsKey(CachingTableProvider.CACHE_TABLE_ID));
+    }
+
+    Assert.assertEquals("32", spec.getConfig().get(CachingTableProvider.LOCK_STRIPES));
+    Assert.assertEquals("true", spec.getConfig().get(CachingTableProvider.WRITE_AROUND));
+
+    desc.validate();
+  }
+
+  private static Pair<ReadWriteTable<String, String>, Map<String, String>> getMockCache() {
+    // To allow concurrent writes for disjoint keys by testConcurrentAccess, we must use CHM here.
+    // This is okay because the atomic section in CachingTable covers both cache and table so using
+    // CHM for each does not serialize such two-step operation so the atomicity is still tested.
+    // Regular HashMap is not thread-safe even for disjoint keys.
+    final Map<String, String> cacheStore = new ConcurrentHashMap<>();
+    final ReadWriteTable tableCache = mock(ReadWriteTable.class);
+
+    doAnswer(invocation -> {
+        String key = invocation.getArgumentAt(0, String.class);
+        String value = invocation.getArgumentAt(1, String.class);
+        cacheStore.put(key, value);
+        return null;
+      }).when(tableCache).put(any(), any());
+
+    doAnswer(invocation -> {
+        String key = invocation.getArgumentAt(0, String.class);
+        return cacheStore.get(key);
+      }).when(tableCache).get(any());
+
+    doAnswer(invocation -> {
+        String key = invocation.getArgumentAt(0, String.class);
+        return cacheStore.remove(key);
+      }).when(tableCache).delete(any());
+
+    return Pair.of(tableCache, cacheStore);
+  }
+
+  private void doTestCacheOps(boolean isWriteAround) {
+    CachingTableDescriptor desc = new CachingTableDescriptor("1");
+    desc.withTable(new TableImpl(new TableSpec("realTable", null, null, new HashMap<>())));
+    desc.withCache(new TableImpl(new TableSpec("cacheTable", null, null, new HashMap<>())));
+    if (isWriteAround) {
+      desc.withWriteAround();
+    }
+    CachingTableProvider tableProvider = new CachingTableProvider(desc.getTableSpec());
+
+    TaskContext taskContext = mock(TaskContext.class);
+    final ReadWriteTable tableCache = getMockCache().getLeft();
+
+    final ReadWriteTable realTable = mock(ReadWriteTable.class);
+
+    doAnswer(invocation -> {
+        String key = invocation.getArgumentAt(0, String.class);
+        return "test-data-" + key;
+      }).when(realTable).get(any());
+
+    doAnswer(invocation -> {
+        String tableId = invocation.getArgumentAt(0, String.class);
+        if (tableId.equals("realTable")) {
+          // cache
+          return realTable;
+        } else if (tableId.equals("cacheTable")) {
+          return tableCache;
+        }
+
+        Assert.fail();
+        return null;
+      }).when(taskContext).getTable(anyString());
+
+    tableProvider.init(null, taskContext);
+
+    CachingTable cacheTable = (CachingTable) tableProvider.getTable();
+
+    Assert.assertEquals("test-data-1", cacheTable.get("1"));
+    verify(realTable, times(1)).get(any());
+    verify(tableCache, times(2)).get(any()); // cache miss leads to 2 more get() calls
+    verify(tableCache, times(1)).put(any(), any());
+    Assert.assertEquals(cacheTable.hitRate(), 0.0, 0.0); // 0 hit, 1 request
+    Assert.assertEquals(cacheTable.missRate(), 1.0, 0.0);
+
+    Assert.assertEquals("test-data-1", cacheTable.get("1"));
+    verify(realTable, times(1)).get(any()); // no change
+    verify(tableCache, times(3)).get(any());
+    verify(tableCache, times(1)).put(any(), any()); // no change
+    Assert.assertEquals(cacheTable.hitRate(), 0.5, 0.0); // 1 hit, 2 requests
+    Assert.assertEquals(cacheTable.missRate(), 0.5, 0.0);
+
+    cacheTable.put("2", "test-data-XXXX");
+    verify(tableCache, times(isWriteAround ? 1 : 2)).put(any(), any());
+    verify(realTable, times(1)).put(any(), any());
+
+    if (isWriteAround) {
+      Assert.assertEquals("test-data-2", cacheTable.get("2")); // expects value from table
+      verify(tableCache, times(5)).get(any()); // cache miss leads to 2 more get() calls
+      Assert.assertEquals(cacheTable.hitRate(), 0.33, 0.1); // 1 hit, 3 requests
+    } else {
+      Assert.assertEquals("test-data-XXXX", cacheTable.get("2")); // expect value from cache
+      verify(tableCache, times(4)).get(any()); // cache hit
+      Assert.assertEquals(cacheTable.hitRate(), 0.66, 0.1); // 2 hits, 3 requests
+    }
+  }
+
+  @Test
+  public void testCacheOps() {
+    doTestCacheOps(false);
+  }
+
+  @Test
+  public void testCacheOpsWriteAround() {
+    doTestCacheOps(true);
+  }
+
+  @Test
+  public void testNonexistentKeyInTable() {
+    ReadableTable<String, String> table = mock(ReadableTable.class);
+    doReturn(null).when(table).get(any());
+    ReadWriteTable<String, String> cache = getMockCache().getLeft();
+    CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, 16, false);
+    Assert.assertNull(cachingTable.get("abc"));
+    verify(cache, times(2)).get(any());
+    Assert.assertNull(cache.get("abc"));
+    verify(cache, times(0)).put(any(), any());
+  }
+
+  @Test
+  public void testKeyEviction() {
+    ReadableTable<String, String> table = mock(ReadableTable.class);
+    doReturn("3").when(table).get(any());
+    ReadWriteTable<String, String> cache = mock(ReadWriteTable.class);
+
+    // no handler added to mock cache so get/put are noop, this can simulate eviction
+    CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, 16, false);
+    cachingTable.get("abc");
+    verify(table, times(1)).get(any());
+
+    // get() should go to table again
+    cachingTable.get("abc");
+    verify(table, times(2)).get(any());
+  }
+
+  /**
+   * Test the atomic operations in CachingTable by simulating 10 threads each executing
+   * 5000 random operations (GET/PUT/DELETE) with random keys, which are picked from a
+   * narrow range (0-9) for higher concurrency. Consistency is verified by comparing
+   * the cache content and table content both of which should match exactly. Eviction
+   * is not simulated because it would be impossible to compare the cache/table.
+   * @throws InterruptedException
+   */
+  @Test
+  public void testConcurrentAccess() throws InterruptedException {
+    final int numThreads = 10;
+    final int iterations = 5000;
+    ExecutorService executor = Executors.newFixedThreadPool(numThreads);
+
+    // Ensure all threads to reach rendezvous before starting the simulation
+    final CountDownLatch startLatch = new CountDownLatch(numThreads);
+
+    Pair<ReadWriteTable<String, String>, Map<String, String>> tableMapPair = getMockCache();
+    final ReadableTable<String, String> table = tableMapPair.getLeft();
+
+    Pair<ReadWriteTable<String, String>, Map<String, String>> cacheMapPair = getMockCache();
+    final CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cacheMapPair.getLeft(), 16, false);
+
+    Map<String, String> cacheMap = cacheMapPair.getRight();
+    Map<String, String> tableMap = tableMapPair.getRight();
+
+    final Random rand = new Random(System.currentTimeMillis());
+
+    for (int i = 0; i < numThreads; i++) {
+      executor.submit(() -> {
+          try {
+            startLatch.countDown();
+            startLatch.await();
+          } catch (InterruptedException e) {
+            Assert.fail();
+          }
+
+          String lastPutKey = null;
+          for (int j = 0; j < iterations; j++) {
+            int cmd = rand.nextInt(3);
+            String key = String.valueOf(rand.nextInt(10));
+            switch (cmd) {
+              case 0:
+                cachingTable.get(key);
+                break;
+              case 1:
+                cachingTable.put(key, "test-data-" + rand.nextInt());
+                lastPutKey = key;
+                break;
+              case 2:
+                if (lastPutKey != null) {
+                  cachingTable.delete(lastPutKey);
+                }
+                break;
+            }
+          }
+        });
+    }
+
+    executor.shutdown();
+
+    // Wait up to 1 minute for all threads to finish
+    Assert.assertTrue(executor.awaitTermination(60, TimeUnit.MINUTES));
+
+    // Verify cache and table contents fully match
+    Assert.assertEquals(cacheMap.size(), tableMap.size());
+    cacheMap.keySet().forEach(k -> Assert.assertEquals(cacheMap.get(k), tableMap.get(k)));
+  }
+}
\ No newline at end of file
index 208c670..8d07570 100644 (file)
@@ -21,11 +21,14 @@ package org.apache.samza.test.table;
 
 import java.io.IOException;
 import java.io.ObjectInputStream;
+import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.LinkedList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -37,9 +40,12 @@ import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.Timer;
 import org.apache.samza.operators.KV;
+import org.apache.samza.operators.StreamGraph;
 import org.apache.samza.runtime.LocalApplicationRunner;
 import org.apache.samza.serializers.NoOpSerde;
 import org.apache.samza.table.Table;
+import org.apache.samza.table.caching.CachingTableDescriptor;
+import org.apache.samza.table.caching.guava.GuavaCacheTableDescriptor;
 import org.apache.samza.table.remote.TableReadFunction;
 import org.apache.samza.table.remote.TableWriteFunction;
 import org.apache.samza.table.remote.RemoteReadableTable;
@@ -52,6 +58,8 @@ import org.apache.samza.util.RateLimiter;
 import org.junit.Assert;
 import org.junit.Test;
 
+import com.google.common.cache.CacheBuilder;
+
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doReturn;
@@ -61,8 +69,7 @@ import static org.mockito.Mockito.mock;
 
 public class TestRemoteTable extends AbstractIntegrationTestHarness {
 
-  static List<TestTableData.EnrichedPageView> writtenRecords = new LinkedList<>();
-  static List<TestTableData.PageView> received = new LinkedList<>();
+  static Map<String, List<TestTableData.EnrichedPageView>> writtenRecords = new HashMap<>();
 
   static class InMemoryReadFunction implements TableReadFunction<Integer, TestTableData.Profile> {
     private final String serializedProfiles;
@@ -90,13 +97,18 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
 
   static class InMemoryWriteFunction implements TableWriteFunction<Integer, TestTableData.EnrichedPageView> {
     private transient List<TestTableData.EnrichedPageView> records;
+    private String testName;
+
+    public InMemoryWriteFunction(String testName) {
+      this.testName = testName;
+    }
 
     // Verify serializable functionality
     private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
       in.defaultReadObject();
 
       // Write to the global list for verification
-      records = writtenRecords;
+      records = writtenRecords.get(testName);
     }
 
     @Override
@@ -115,9 +127,26 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
     }
   }
 
-  @Test
-  public void testStreamTableJoinRemoteTable() throws Exception {
-    final InMemoryWriteFunction writer = new InMemoryWriteFunction();
+  private <K, V> Table<KV<K, V>> getCachingTable(Table<KV<K, V>> actualTable, boolean defaultCache, String id, StreamGraph streamGraph) {
+    CachingTableDescriptor<K, V> cachingDesc = new CachingTableDescriptor<>("caching-table-" + id);
+    if (defaultCache) {
+      cachingDesc.withReadTtl(Duration.ofMinutes(5));
+      cachingDesc.withWriteTtl(Duration.ofMinutes(5));
+    } else {
+      GuavaCacheTableDescriptor<K, V> guavaDesc = new GuavaCacheTableDescriptor<>("guava-table-" + id);
+      guavaDesc.withCache(CacheBuilder.newBuilder().expireAfterAccess(5, TimeUnit.MINUTES).build());
+      Table<KV<K, V>> guavaTable = streamGraph.getTable(guavaDesc);
+      cachingDesc.withCache(guavaTable);
+    }
+
+    cachingDesc.withTable(actualTable);
+    return streamGraph.getTable(cachingDesc);
+  }
+
+  private void doTestStreamTableJoinRemoteTable(boolean withCache, boolean defaultCache, String testName) throws Exception {
+    final InMemoryWriteFunction writer = new InMemoryWriteFunction(testName);
+
+    writtenRecords.put(testName, new ArrayList<>());
 
     int count = 10;
     TestTableData.PageView[] pageViews = TestTableData.generatePageViews(count);
@@ -145,12 +174,20 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
           .withWriteFunction(writer)
           .withRateLimiter(writeRateLimiter, null, null);
 
-      Table<KV<Integer, TestTableData.Profile>> inputTable = streamGraph.getTable(inputTableDesc);
       Table<KV<Integer, TestTableData.EnrichedPageView>> outputTable = streamGraph.getTable(outputTableDesc);
 
+      if (withCache) {
+        outputTable = getCachingTable(outputTable, defaultCache, "output", streamGraph);
+      }
+
+      Table<KV<Integer, TestTableData.Profile>> inputTable = streamGraph.getTable(inputTableDesc);
+
+      if (withCache) {
+        inputTable = getCachingTable(inputTable, defaultCache, "input", streamGraph);
+      }
+
       streamGraph.getInputStream("PageView", new NoOpSerde<TestTableData.PageView>())
           .map(pv -> {
-              received.add(pv);
               return new KV<Integer, TestTableData.PageView>(pv.getMemberId(), pv);
             })
           .join(inputTable, new TestLocalTable.PageViewToProfileJoinFunction())
@@ -162,9 +199,23 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
     runner.waitForFinish();
 
     int numExpected = count * partitionCount;
-    Assert.assertEquals(numExpected, received.size());
-    Assert.assertEquals(numExpected, writtenRecords.size());
-    Assert.assertTrue(writtenRecords.get(0) instanceof TestTableData.EnrichedPageView);
+    Assert.assertEquals(numExpected, writtenRecords.get(testName).size());
+    Assert.assertTrue(writtenRecords.get(testName).get(0) instanceof TestTableData.EnrichedPageView);
+  }
+
+  @Test
+  public void testStreamTableJoinRemoteTable() throws Exception {
+    doTestStreamTableJoinRemoteTable(false, false, "testStreamTableJoinRemoteTable");
+  }
+
+  @Test
+  public void testStreamTableJoinRemoteTableWithCache() throws Exception {
+    doTestStreamTableJoinRemoteTable(true, false, "testStreamTableJoinRemoteTableWithCache");
+  }
+
+  @Test
+  public void testStreamTableJoinRemoteTableWithDefaultCache() throws Exception {
+    doTestStreamTableJoinRemoteTable(true, true, "testStreamTableJoinRemoteTableWithDefaultCache");
   }
 
   private TaskContext createMockTaskContext() {