SAMZA-1763: Add async methods to Table API
authorPeng Du <pdu@linkedin.com>
Fri, 10 Aug 2018 18:21:12 +0000 (11:21 -0700)
committerJagadish <jvenkatraman@linkedin.com>
Fri, 10 Aug 2018 18:21:12 +0000 (11:21 -0700)
Currently, Table API only has blocking/sync methods which limit the
throughput of remote tables. This change adds async methods to the API
to enable high throughput remote table accesses through usage of async
IO. The new methods are added to ReadableTable and ReadWriteTable. A
high level summary of the change is below:

- added async methods to table ReadableTable and ReadWriteTable.
- added async methods to TableRead/WriteFunction
- CompletableFuture is used for the async abstraction
- CachingTable are updated to support async methods
- added default impls for sync methods backed by async in table functions
- added helper class, Throttler/AsyncHelper to ease table development
- fixed existing test cases with table implementations
- added more thorough unit tests to RemoteTable CRUD methods

Additionally remove explicit check of config entries for remote table in
TestTableDescriptorsProvider since there is already a test case on
RemoteTableDescriptor.

Author: Peng Du <pdu@linkedin.com>

Reviewers: Jagadish<jagadish@apache.org>, Wei <wsong@linkedin.com>

Closes #593 from pdu-mn1/async-table-api-futures

26 files changed:
samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java
samza-api/src/main/java/org/apache/samza/table/ReadableTable.java
samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
samza-core/src/main/java/org/apache/samza/table/caching/CachingTableDescriptor.java
samza-core/src/main/java/org/apache/samza/table/caching/CachingTableProvider.java
samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
samza-core/src/main/java/org/apache/samza/table/remote/CreditFunction.java [deleted file]
samza-core/src/main/java/org/apache/samza/table/remote/RemoteReadWriteTable.java
samza-core/src/main/java/org/apache/samza/table/remote/RemoteReadableTable.java
samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableDescriptor.java
samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
samza-core/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java [new file with mode: 0644]
samza-core/src/main/java/org/apache/samza/table/remote/TableReadFunction.java
samza-core/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java
samza-core/src/main/java/org/apache/samza/table/utils/DefaultTableReadMetrics.java
samza-core/src/main/java/org/apache/samza/table/utils/DefaultTableWriteMetrics.java
samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java [new file with mode: 0644]
samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTableDescriptor.java
samza-core/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java [new file with mode: 0644]
samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java
samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java
samza-sql/src/test/java/org/apache/samza/sql/testutil/TestIOResolverFactory.java
samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java
samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java
samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java

index def5afb..083a1b5 100644 (file)
@@ -19,6 +19,8 @@
 package org.apache.samza.table;
 
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.storage.kv.Entry;
 
@@ -33,7 +35,8 @@ import org.apache.samza.storage.kv.Entry;
 public interface ReadWriteTable<K, V> extends ReadableTable<K, V> {
 
   /**
-   * Updates the mapping of the specified key-value pair; Associates the specified {@code key} with the specified {@code value}.
+   * Updates the mapping of the specified key-value pair;
+   * Associates the specified {@code key} with the specified {@code value}.
    *
    * The key is deleted from the table if value is {@code null}.
    *
@@ -44,6 +47,18 @@ public interface ReadWriteTable<K, V> extends ReadableTable<K, V> {
   void put(K key, V value);
 
   /**
+   * Asynchronously updates the mapping of the specified key-value pair;
+   * Associates the specified {@code key} with the specified {@code value}.
+   * The key is deleted from the table if value is {@code null}.
+   *
+   * @param key the key with which the specified {@code value} is to be associated.
+   * @param value the value with which the specified {@code key} is to be associated.
+   * @throws NullPointerException if the specified {@code key} is {@code null}.
+   * @return CompletableFuture for the operation
+   */
+  CompletableFuture<Void> putAsync(K key, V value);
+
+  /**
    * Updates the mappings of the specified key-value {@code entries}.
    *
    * A key is deleted from the table if its corresponding value is {@code null}.
@@ -54,6 +69,16 @@ public interface ReadWriteTable<K, V> extends ReadableTable<K, V> {
   void putAll(List<Entry<K, V>> entries);
 
   /**
+   * Asynchronously updates the mappings of the specified key-value {@code entries}.
+   * A key is deleted from the table if its corresponding value is {@code null}.
+   *
+   * @param entries the updated mappings to put into this table.
+   * @throws NullPointerException if any of the specified {@code entries} has {@code null} as key.
+   * @return CompletableFuture for the operation
+   */
+  CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries);
+
+  /**
    * Deletes the mapping for the specified {@code key} from this table (if such mapping exists).
    *
    * @param key the key for which the mapping is to be deleted.
@@ -62,6 +87,14 @@ public interface ReadWriteTable<K, V> extends ReadableTable<K, V> {
   void delete(K key);
 
   /**
+   * Asynchronously deletes the mapping for the specified {@code key} from this table (if such mapping exists).
+   * @param key the key for which the mapping is to be deleted.
+   * @throws NullPointerException if the specified {@code key} is {@code null}.
+   * @return CompletableFuture for the operation
+   */
+  CompletableFuture<Void> deleteAsync(K key);
+
+  /**
    * Deletes the mappings for the specified {@code keys} from this table.
    *
    * @param keys the keys for which the mappings are to be deleted.
@@ -69,10 +102,16 @@ public interface ReadWriteTable<K, V> extends ReadableTable<K, V> {
    */
   void deleteAll(List<K> keys);
 
+  /**
+   * Asynchronously deletes the mappings for the specified {@code keys} from this table.
+   * @param keys the keys for which the mappings are to be deleted.
+   * @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
+   * @return CompletableFuture for the operation
+   */
+  CompletableFuture<Void> deleteAllAsync(List<K> keys);
 
   /**
    * Flushes the underlying store of this table, if applicable.
    */
   void flush();
-
 }
index 15b6115..490acc0 100644 (file)
@@ -20,6 +20,7 @@ package org.apache.samza.table;
 
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.container.SamzaContainerContext;
@@ -55,6 +56,15 @@ public interface ReadableTable<K, V> extends Table<KV<K, V>> {
   V get(K key);
 
   /**
+   * Asynchronously gets the value associated with the specified {@code key}.
+   *
+   * @param key the key with which the associated value is to be fetched.
+   * @return completableFuture for the requested value
+   * @throws NullPointerException if the specified {@code key} is {@code null}.
+   */
+  CompletableFuture<V> getAsync(K key);
+
+  /**
    * Gets the values with which the specified {@code keys} are associated.
    *
    * @param keys the keys with which the associated values are to be fetched.
@@ -64,6 +74,15 @@ public interface ReadableTable<K, V> extends Table<KV<K, V>> {
   Map<K, V> getAll(List<K> keys);
 
   /**
+   * Asynchronously gets the values with which the specified {@code keys} are associated.
+   *
+   * @param keys the keys with which the associated values are to be fetched.
+   * @return completableFuture for the requested entries
+   * @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
+   */
+  CompletableFuture<Map<K, V>> getAllAsync(List<K> keys);
+
+  /**
    * Close the table and release any resources acquired
    */
   void close();
index 23e4f7f..b7aa33c 100644 (file)
 
 package org.apache.samza.table.caching;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicLong;
-import java.util.concurrent.locks.Lock;
+import java.util.stream.Collectors;
 
+import org.apache.samza.SamzaException;
 import org.apache.samza.container.SamzaContainerContext;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadWriteTable;
@@ -35,7 +38,6 @@ import org.apache.samza.table.utils.TableMetricsUtil;
 import org.apache.samza.task.TaskContext;
 
 import com.google.common.base.Preconditions;
-import com.google.common.util.concurrent.Striped;
 
 
 /**
@@ -52,29 +54,23 @@ import com.google.common.util.concurrent.Striped;
  * supports read-after-write semantics because the value is cached after written to
  * the table.
  *
- * Table and cache are updated (put/delete) in an atomic manner as such it is thread
- * safe for concurrent accesses. Strip locks are used for fine-grained synchronization
- * and the number of stripes is configurable.
- *
- * NOTE: Cache get is not synchronized with put for better parallelism in the read path.
- * As such, cache table implementation is expected to be thread-safe for concurrent
- * accesses.
+ * Note that there is no synchronization in CachingTable because it is impossible to
+ * implement a critical section between table read/write and cache update in the async
+ * code paths without serializing all async operations for the same keys. Given stale
+ * data is a presumed trade off for using a cache for table, it should be acceptable
+ * for the data in table and cache are out-of-sync. Moreover, unsynchronized operations
+ * in CachingTable also deliver higher performance when there is contention.
  *
  * @param <K> type of the table key
  * @param <V> type of the table value
  */
 public class CachingTable<K, V> implements ReadWriteTable<K, V> {
-  private static final String GROUP_NAME = CachingTable.class.getSimpleName();
-
   private final String tableId;
   private final ReadableTable<K, V> rdTable;
   private final ReadWriteTable<K, V> rwTable;
   private final ReadWriteTable<K, V> cache;
   private final boolean isWriteAround;
 
-  // Use stripe based locking to allow parallelism of disjoint keys.
-  private final Striped<Lock> stripedLocks;
-
   // Metrics
   private DefaultTableReadMetrics readMetrics;
   private DefaultTableWriteMetrics writeMetrics;
@@ -83,13 +79,12 @@ public class CachingTable<K, V> implements ReadWriteTable<K, V> {
   private AtomicLong hitCount = new AtomicLong();
   private AtomicLong missCount = new AtomicLong();
 
-  public CachingTable(String tableId, ReadableTable<K, V> table, ReadWriteTable<K, V> cache, int stripes, boolean isWriteAround) {
+  public CachingTable(String tableId, ReadableTable<K, V> table, ReadWriteTable<K, V> cache, boolean isWriteAround) {
     this.tableId = tableId;
     this.rdTable = table;
     this.rwTable = table instanceof ReadWriteTable ? (ReadWriteTable) table : null;
     this.cache = cache;
     this.isWriteAround = isWriteAround;
-    this.stripedLocks = Striped.lazyWeakLock(stripes);
   }
 
   /**
@@ -105,96 +100,208 @@ public class CachingTable<K, V> implements ReadWriteTable<K, V> {
     tableMetricsUtil.newGauge("req-count", () -> requestCount());
   }
 
+  /**
+   * Lookup the cache and return the keys that are missed in cache
+   * @param keys keys to be looked up
+   * @param records result map
+   * @return list of keys missed in the cache
+   */
+  private List<K> lookupCache(List<K> keys, Map<K, V> records) {
+    List<K> missKeys = new ArrayList<>();
+    records.putAll(cache.getAll(keys));
+    keys.forEach(k -> {
+        if (!records.containsKey(k)) {
+          missKeys.add(k);
+        }
+      });
+    return missKeys;
+  }
+
   @Override
   public V get(K key) {
+    try {
+      return getAsync(key).get();
+    } catch (InterruptedException e) {
+      throw new SamzaException(e);
+    } catch (Exception e) {
+      throw (SamzaException) e.getCause();
+    }
+  }
+
+  @Override
+  public CompletableFuture<V> getAsync(K key) {
     readMetrics.numGets.inc();
-    long startNs = System.nanoTime();
     V value = cache.get(key);
-    if (value == null) {
-      missCount.incrementAndGet();
-      Lock lock = stripedLocks.get(key);
-      try {
-        lock.lock();
-        if (cache.get(key) == null) {
-          // Due to the lack of contains() API in ReadableTable, there is
-          // no way to tell whether a null return by cache.get(key) means
-          // cache miss or the value is actually null. As such, we cannot
-          // support negative cache semantics.
-          value = rdTable.get(key);
-          if (value != null) {
-            cache.put(key, value);
-          }
-        }
-      } finally {
-        lock.unlock();
-      }
-    } else {
+    if (value != null) {
       hitCount.incrementAndGet();
+      return CompletableFuture.completedFuture(value);
     }
-    readMetrics.getNs.update(System.nanoTime() - startNs);
-    return value;
+
+    long startNs = System.nanoTime();
+    missCount.incrementAndGet();
+
+    return rdTable.getAsync(key).handle((result, e) -> {
+        if (e != null) {
+          throw new SamzaException("Failed to get the record for " + key, e);
+        } else {
+          if (result != null) {
+            cache.put(key, result);
+          }
+          readMetrics.getNs.update(System.nanoTime() - startNs);
+          return result;
+        }
+      });
   }
 
   @Override
   public Map<K, V> getAll(List<K> keys) {
+    try {
+      return getAllAsync(keys).get();
+    } catch (InterruptedException e) {
+      throw new SamzaException(e);
+    } catch (Exception e) {
+      throw (SamzaException) e.getCause();
+    }
+  }
+
+  @Override
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
     readMetrics.numGetAlls.inc();
-    long startNs = System.nanoTime();
+    // Make a copy of entries which might be immutable
     Map<K, V> getAllResult = new HashMap<>();
-    keys.stream().forEach(k -> getAllResult.put(k, get(k)));
-    readMetrics.getAllNs.update(System.nanoTime() - startNs);
-    return getAllResult;
+    List<K> missingKeys = lookupCache(keys, getAllResult);
+
+    if (missingKeys.isEmpty()) {
+      return CompletableFuture.completedFuture(getAllResult);
+    }
+
+    long startNs = System.nanoTime();
+    return rdTable.getAllAsync(missingKeys).handle((records, e) -> {
+        if (e != null) {
+          throw new SamzaException("Failed to get records for " + keys, e);
+        } else {
+          if (records != null) {
+            cache.putAll(records.entrySet().stream()
+                .map(r -> new Entry<>(r.getKey(), r.getValue()))
+                .collect(Collectors.toList()));
+            getAllResult.putAll(records);
+          }
+          readMetrics.getAllNs.update(System.nanoTime() - startNs);
+          return getAllResult;
+        }
+      });
   }
 
   @Override
   public void put(K key, V value) {
+    try {
+      putAsync(key, value).get();
+    } catch (InterruptedException e) {
+      throw new SamzaException(e);
+    } catch (Exception e) {
+      throw (SamzaException) e.getCause();
+    }
+  }
+
+  @Override
+  public CompletableFuture<Void> putAsync(K key, V value) {
     writeMetrics.numPuts.inc();
-    long startNs = System.nanoTime();
     Preconditions.checkNotNull(rwTable, "Cannot write to a read-only table: " + rdTable);
-    Lock lock = stripedLocks.get(key);
+
+    long startNs = System.nanoTime();
+    return rwTable.putAsync(key, value).handle((result, e) -> {
+        if (e != null) {
+          throw new SamzaException(String.format("Failed to put a record, key=%s, value=%s", key, value), e);
+        } else if (!isWriteAround) {
+          if (value == null) {
+            cache.delete(key);
+          } else {
+            cache.put(key, value);
+          }
+        }
+        writeMetrics.putNs.update(System.nanoTime() - startNs);
+        return result;
+      });
+  }
+
+  @Override
+  public void putAll(List<Entry<K, V>> records) {
     try {
-      lock.lock();
-      rwTable.put(key, value);
-      if (!isWriteAround) {
-        cache.put(key, value);
-      }
-    } finally {
-      lock.unlock();
+      putAllAsync(records).get();
+    } catch (InterruptedException e) {
+      throw new SamzaException(e);
+    } catch (Exception e) {
+      throw (SamzaException) e.getCause();
     }
-    writeMetrics.putNs.update(System.nanoTime() - startNs);
   }
 
   @Override
-  public void putAll(List<Entry<K, V>> entries) {
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records) {
     writeMetrics.numPutAlls.inc();
     long startNs = System.nanoTime();
     Preconditions.checkNotNull(rwTable, "Cannot write to a read-only table: " + rdTable);
-    entries.forEach(e -> put(e.getKey(), e.getValue()));
-    writeMetrics.putAllNs.update(System.nanoTime() - startNs);
+    return rwTable.putAllAsync(records).handle((result, e) -> {
+        if (e != null) {
+          throw new SamzaException("Failed to put records " + records, e);
+        } else if (!isWriteAround) {
+          cache.putAll(records);
+        }
+
+        writeMetrics.putAllNs.update(System.nanoTime() - startNs);
+        return result;
+      });
   }
 
   @Override
   public void delete(K key) {
+    try {
+      deleteAsync(key).get();
+    } catch (InterruptedException e) {
+      throw new SamzaException(e);
+    } catch (Exception e) {
+      throw (SamzaException) e.getCause();
+    }
+  }
+
+  @Override
+  public CompletableFuture<Void> deleteAsync(K key) {
     writeMetrics.numDeletes.inc();
     long startNs = System.nanoTime();
     Preconditions.checkNotNull(rwTable, "Cannot delete from a read-only table: " + rdTable);
-    Lock lock = stripedLocks.get(key);
+    return rwTable.deleteAsync(key).handle((result, e) -> {
+        if (e != null) {
+          throw new SamzaException("Failed to delete the record for " + key, e);
+        } else if (!isWriteAround) {
+          cache.delete(key);
+        }
+        writeMetrics.deleteNs.update(System.nanoTime() - startNs);
+        return result;
+      });
+  }
+
+  @Override
+  public void deleteAll(List<K> keys) {
     try {
-      lock.lock();
-      rwTable.delete(key);
-      cache.delete(key);
-    } finally {
-      lock.unlock();
+      deleteAllAsync(keys).get();
+    } catch (Exception e) {
+      throw new SamzaException(e);
     }
-    writeMetrics.deleteNs.update(System.nanoTime() - startNs);
   }
 
   @Override
-  public void deleteAll(List<K> keys) {
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
     writeMetrics.numDeleteAlls.inc();
     long startNs = System.nanoTime();
     Preconditions.checkNotNull(rwTable, "Cannot delete from a read-only table: " + rdTable);
-    keys.stream().forEach(k -> delete(k));
-    writeMetrics.deleteAllNs.update(System.nanoTime() - startNs);
+    return rwTable.deleteAllAsync(keys).handle((result, e) -> {
+        if (e != null) {
+          throw new SamzaException("Failed to delete the record for " + keys, e);
+        } else if (!isWriteAround) {
+          cache.deleteAll(keys);
+        }
+        writeMetrics.deleteAllNs.update(System.nanoTime() - startNs);
+        return result;
+      });
   }
 
   @Override
index eb74825..21463c2 100644 (file)
@@ -42,7 +42,6 @@ public class CachingTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Cach
   private long cacheSize;
   private Table<KV<K, V>> cache;
   private Table<KV<K, V>> table;
-  private int stripes = 16;
   private boolean isWriteAround;
 
   /**
@@ -75,7 +74,6 @@ public class CachingTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Cach
     }
 
     tableSpecConfig.put(CachingTableProvider.REAL_TABLE_ID, ((TableImpl) table).getTableSpec().getId());
-    tableSpecConfig.put(CachingTableProvider.LOCK_STRIPES, String.valueOf(stripes));
     tableSpecConfig.put(CachingTableProvider.WRITE_AROUND, String.valueOf(isWriteAround));
 
     return new TableSpec(tableId, serde, CachingTableProviderFactory.class.getName(), tableSpecConfig);
@@ -137,17 +135,6 @@ public class CachingTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Cach
   }
 
   /**
-   * Specify the number of stripes for striped locking for atomically updating
-   * cache and the actual table. Default number of stripes is 16.
-   * @param stripes number of stripes for locking
-   * @return this descriptor
-   */
-  public CachingTableDescriptor withStripes(int stripes) {
-    this.stripes = stripes;
-    return this;
-  }
-
-  /**
    * Specify if write-around policy should be used to bypass writing
    * to cache for put operations. This is useful when put() is the
    * dominant operation and get() has no locality with recent puts.
@@ -168,6 +155,5 @@ public class CachingTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Cach
       Preconditions.checkArgument(readTtl == null && writeTtl == null && cacheSize == 0,
           "Invalid to specify both {cache} and {readTtl|writeTtl|cacheSize} at the same time.");
     }
-    Preconditions.checkArgument(stripes > 0, "Number of cache stripes must be positive.");
   }
 }
index 797d963..a7d65bc 100644 (file)
@@ -50,7 +50,6 @@ public class CachingTableProvider implements TableProvider {
   public static final String READ_TTL_MS = "readTtl";
   public static final String WRITE_TTL_MS = "writeTtl";
   public static final String CACHE_SIZE = "cacheSize";
-  public static final String LOCK_STRIPES = "lockStripes";
   public static final String WRITE_AROUND = "writeAround";
 
   private final TableSpec cachingTableSpec;
@@ -86,9 +85,8 @@ public class CachingTableProvider implements TableProvider {
       defaultCaches.add(cache);
     }
 
-    int stripes = Integer.parseInt(cachingTableSpec.getConfig().get(LOCK_STRIPES));
     boolean isWriteAround = Boolean.parseBoolean(cachingTableSpec.getConfig().get(WRITE_AROUND));
-    CachingTable cachingTable = new CachingTable(cachingTableSpec.getId(), table, cache, stripes, isWriteAround);
+    CachingTable cachingTable = new CachingTable(cachingTableSpec.getId(), table, cache, isWriteAround);
     cachingTable.init(containerContext, taskContext);
     return cachingTable;
   }
index fcded2f..a8beb3b 100644 (file)
 
 package org.apache.samza.table.caching.guava;
 
-import java.util.HashMap;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 
+import org.apache.samza.SamzaException;
 import org.apache.samza.container.SamzaContainerContext;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadWriteTable;
@@ -40,8 +42,6 @@ import com.google.common.cache.Cache;
  * @param <V> type of the value in the cache
  */
 public class GuavaCacheTable<K, V> implements ReadWriteTable<K, V> {
-  private static final String GROUP_NAME = GuavaCacheTableProvider.class.getSimpleName();
-
   private final String tableId;
   private final Cache<K, V> cache;
 
@@ -61,44 +61,148 @@ public class GuavaCacheTable<K, V> implements ReadWriteTable<K, V> {
   }
 
   @Override
+  public V get(K key) {
+    try {
+      return getAsync(key).get();
+    } catch (Exception e) {
+      throw new SamzaException("GET failed for " + key, e);
+    }
+  }
+
+  @Override
+  public CompletableFuture<V> getAsync(K key) {
+    CompletableFuture<V> future = new CompletableFuture<>();
+    try {
+      future.complete(cache.getIfPresent(key));
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
+  }
+
+  @Override
+  public Map<K, V> getAll(List<K> keys) {
+    try {
+      return getAllAsync(keys).get();
+    } catch (Exception e) {
+      throw new SamzaException("GET_ALL failed for " + keys, e);
+    }
+  }
+
+  @Override
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+    CompletableFuture<Map<K, V>> future = new CompletableFuture<>();
+    try {
+      future.complete(cache.getAllPresent(keys));
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
+  }
+
+  @Override
   public void put(K key, V value) {
-    if (value != null) {
+    try {
+      putAsync(key, value).get();
+    } catch (Exception e) {
+      throw new SamzaException("PUT failed for " + key, e);
+    }
+  }
+
+  @Override
+  public CompletableFuture<Void> putAsync(K key, V value) {
+    if (key == null) {
+      return deleteAsync(key);
+    }
+
+    CompletableFuture<Void> future = new CompletableFuture<>();
+    try {
       cache.put(key, value);
-    } else {
-      delete(key);
+      future.complete(null);
+    } catch (Exception e) {
+      future.completeExceptionally(e);
     }
+    return future;
   }
 
   @Override
   public void putAll(List<Entry<K, V>> entries) {
-    entries.forEach(e -> put(e.getKey(), e.getValue()));
+    try {
+      putAllAsync(entries).get();
+    } catch (Exception e) {
+      throw new SamzaException("PUT_ALL failed", e);
+    }
+  }
+
+  @Override
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
+    CompletableFuture<Void> future = new CompletableFuture<>();
+    try {
+      // Separate out put vs delete records
+      List<K> delKeys = new ArrayList<>();
+      List<Entry<K, V>> putRecords = new ArrayList<>();
+      entries.forEach(r -> {
+          if (r.getValue() != null) {
+            putRecords.add(r);
+          } else {
+            delKeys.add(r.getKey());
+          }
+        });
+
+      cache.invalidateAll(delKeys);
+      putRecords.forEach(e -> put(e.getKey(), e.getValue()));
+      future.complete(null);
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
   }
 
   @Override
   public void delete(K key) {
-    cache.invalidate(key);
+    try {
+      deleteAsync(key).get();
+    } catch (Exception e) {
+      throw new SamzaException("DELETE failed", e);
+    }
   }
 
   @Override
-  public void deleteAll(List<K> keys) {
-    keys.forEach(k -> delete(k));
+  public CompletableFuture<Void> deleteAsync(K key) {
+    CompletableFuture<Void> future = new CompletableFuture<>();
+    try {
+      cache.invalidate(key);
+      future.complete(null);
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
   }
 
   @Override
-  public synchronized void flush() {
-    cache.cleanUp();
+  public void deleteAll(List<K> keys) {
+    try {
+      deleteAllAsync(keys).get();
+    } catch (Exception e) {
+      throw new SamzaException("DELETE_ALL failed", e);
+    }
   }
 
   @Override
-  public V get(K key) {
-    return cache.getIfPresent(key);
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+    CompletableFuture<Void> future = new CompletableFuture<>();
+    try {
+      cache.invalidateAll(keys);
+      future.complete(null);
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
   }
 
   @Override
-  public Map<K, V> getAll(List<K> keys) {
-    Map<K, V> getAllResult = new HashMap<>();
-    keys.stream().forEach(k -> getAllResult.put(k, get(k)));
-    return getAllResult;
+  public synchronized void flush() {
+    cache.cleanUp();
   }
 
   @Override
diff --git a/samza-core/src/main/java/org/apache/samza/table/remote/CreditFunction.java b/samza-core/src/main/java/org/apache/samza/table/remote/CreditFunction.java
deleted file mode 100644 (file)
index 0d30098..0000000
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.table.remote;
-
-import java.io.Serializable;
-import java.util.function.Function;
-
-import org.apache.samza.operators.KV;
-
-
-/**
- * Function interface for providing rate limiting credits for each table record.
- * This interface allows callers to pass in lambda expressions which are otherwise
- * non-serializable as-is.
- * @param <K> the type of the key
- * @param <V> the type of the value
- */
-public interface CreditFunction<K, V> extends Function<KV<K, V>, Integer>, Serializable {
-}
\ No newline at end of file
index 95f8cfa..88bc7df 100644 (file)
 package org.apache.samza.table.remote;
 
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.stream.Collectors;
 
 import org.apache.samza.SamzaException;
 import org.apache.samza.container.SamzaContainerContext;
-import org.apache.samza.metrics.Timer;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.table.utils.DefaultTableWriteMetrics;
 import org.apache.samza.table.utils.TableMetricsUtil;
 import org.apache.samza.task.TaskContext;
-import org.apache.samza.util.RateLimiter;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 
-import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_WRITE_TAG;
-
 
 /**
  * Remote store backed read writable table
@@ -43,22 +43,20 @@ import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_WRITE_TAG;
  * @param <V> the type of the value in this table
  */
 public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implements ReadWriteTable<K, V> {
+  private final TableWriteFunction<K, V> writeFn;
 
-  protected final TableWriteFunction<K, V> writeFn;
-  protected final CreditFunction<K, V> writeCreditFn;
-  protected final boolean rateLimitWrites;
+  private DefaultTableWriteMetrics writeMetrics;
 
-  protected DefaultTableWriteMetrics writeMetrics;
-  protected Timer putThrottleNs; // use single timer for all write operations
+  @VisibleForTesting
+  final TableRateLimiter writeRateLimiter;
 
   public RemoteReadWriteTable(String tableId, TableReadFunction readFn, TableWriteFunction writeFn,
-      RateLimiter ratelimiter, CreditFunction<K, V> readCreditFn, CreditFunction<K, V> writeCreditFn) {
-    super(tableId, readFn, ratelimiter, readCreditFn);
+      TableRateLimiter<K, V> readRateLimiter, TableRateLimiter<K, V> writeRateLimiter,
+      ExecutorService tableExecutor, ExecutorService callbackExecutor) {
+    super(tableId, readFn, readRateLimiter, tableExecutor, callbackExecutor);
     Preconditions.checkNotNull(writeFn, "null write function");
     this.writeFn = writeFn;
-    this.writeCreditFn = writeCreditFn;
-    this.rateLimitWrites = rateLimiter != null && rateLimiter.getSupportedTags().contains(RL_WRITE_TAG);
-    logger.info("Rate limiting is {} for remote write operations", rateLimitWrites ? "enabled" : "disabled");
+    this.writeRateLimiter = writeRateLimiter;
   }
 
   /**
@@ -69,7 +67,7 @@ public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implem
     super.init(containerContext, taskContext);
     writeMetrics = new DefaultTableWriteMetrics(containerContext, taskContext, this, tableId);
     TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(containerContext, taskContext, this, tableId);
-    putThrottleNs = tableMetricsUtil.newTimer("put-throttle-ns");
+    writeRateLimiter.setTimerMetric(tableMetricsUtil.newTimer("put-throttle-ns"));
   }
 
   /**
@@ -77,25 +75,28 @@ public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implem
    */
   @Override
   public void put(K key, V value) {
+    try {
+      putAsync(key, value).get();
+    } catch (Exception e) {
+      throw new SamzaException(e);
+    }
+  }
 
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public CompletableFuture<Void> putAsync(K key, V value) {
+    Preconditions.checkNotNull(key);
     if (value == null) {
-      delete(key);
-      return;
+      return deleteAsync(key);
     }
 
-    try {
-      writeMetrics.numPuts.inc();
-      if (rateLimitWrites) {
-        throttle(key, value, RL_WRITE_TAG, writeCreditFn, putThrottleNs);
-      }
-      long startNs = System.nanoTime();
-      writeFn.put(key, value);
-      writeMetrics.putNs.update(System.nanoTime() - startNs);
-    } catch (Exception e) {
-      String errMsg = String.format("Failed to put a record, key=%s, value=%s", key, value);
-      logger.error(errMsg, e);
-      throw new SamzaException(errMsg, e);
-    }
+    writeMetrics.numPuts.inc();
+    return execute(writeRateLimiter, key, value, writeFn::putAsync, writeMetrics.putNs)
+        .exceptionally(e -> {
+            throw new SamzaException("Failed to put a record with key=" + key, (Throwable) e);
+          });
   }
 
   /**
@@ -104,14 +105,9 @@ public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implem
   @Override
   public void putAll(List<Entry<K, V>> entries) {
     try {
-      writeMetrics.numPutAlls.inc();
-      long startNs = System.nanoTime();
-      writeFn.putAll(entries);
-      writeMetrics.putAllNs.update(System.nanoTime() - startNs);
+      putAllAsync(entries).get();
     } catch (Exception e) {
-      String errMsg = String.format("Failed to put records: %s", entries);
-      logger.error(errMsg, e);
-      throw new SamzaException(errMsg, e);
+      throw new SamzaException(e);
     }
   }
 
@@ -119,19 +115,42 @@ public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implem
    * {@inheritDoc}
    */
   @Override
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records) {
+    Preconditions.checkNotNull(records);
+    if (records.isEmpty()) {
+      return CompletableFuture.completedFuture(null);
+    }
+
+    writeMetrics.numPutAlls.inc();
+
+    List<K> deleteKeys = records.stream()
+        .filter(e -> e.getValue() == null).map(Entry::getKey).collect(Collectors.toList());
+
+    CompletableFuture<Void> deleteFuture = deleteKeys.isEmpty()
+        ? CompletableFuture.completedFuture(null) : deleteAllAsync(deleteKeys);
+
+    List<Entry<K, V>> putRecords = records.stream()
+        .filter(e -> e.getValue() != null).collect(Collectors.toList());
+
+    // Return the combined future
+    return CompletableFuture.allOf(
+        deleteFuture,
+        executeRecords(writeRateLimiter, putRecords, writeFn::putAllAsync, writeMetrics.putAllNs))
+        .exceptionally(e -> {
+            String strKeys = records.stream().map(r -> r.getKey().toString()).collect(Collectors.joining(","));
+            throw new SamzaException(String.format("Failed to put records with keys=" + strKeys), e);
+          });
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
   public void delete(K key) {
     try {
-      writeMetrics.numDeletes.inc();
-      if (rateLimitWrites) {
-        throttle(key, null, RL_WRITE_TAG, writeCreditFn, putThrottleNs);
-      }
-      long startNs = System.nanoTime();
-      writeFn.delete(key);
-      writeMetrics.deleteNs.update(System.nanoTime() - startNs);
+      deleteAsync(key).get();
     } catch (Exception e) {
-      String errMsg = String.format("Failed to delete a record, key=%s", key);
-      logger.error(errMsg, e);
-      throw new SamzaException(errMsg, e);
+      throw new SamzaException(e);
     }
   }
 
@@ -139,16 +158,24 @@ public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implem
    * {@inheritDoc}
    */
   @Override
+  public CompletableFuture<Void> deleteAsync(K key) {
+    Preconditions.checkNotNull(key);
+    writeMetrics.numDeletes.inc();
+    return execute(writeRateLimiter, key, writeFn::deleteAsync, writeMetrics.deleteNs)
+        .exceptionally(e -> {
+            throw new SamzaException(String.format("Failed to delete the record for " + key), (Throwable) e);
+          });
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
   public void deleteAll(List<K> keys) {
     try {
-      writeMetrics.numDeleteAlls.inc();
-      writeFn.deleteAll(keys);
-      long startNs = System.nanoTime();
-      writeMetrics.deleteAllNs.update(System.nanoTime() - startNs);
+      deleteAllAsync(keys).get();
     } catch (Exception e) {
-      String errMsg = String.format("Failed to delete records, keys=%s", keys);
-      logger.error(errMsg, e);
-      throw new SamzaException(errMsg, e);
+      throw new SamzaException(e);
     }
   }
 
@@ -156,12 +183,26 @@ public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implem
    * {@inheritDoc}
    */
   @Override
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+    Preconditions.checkNotNull(keys);
+    if (keys.isEmpty()) {
+      return CompletableFuture.completedFuture(null);
+    }
+
+    writeMetrics.numDeleteAlls.inc();
+    return execute(writeRateLimiter, keys, writeFn::deleteAllAsync, writeMetrics.deleteAllNs)
+        .exceptionally(e -> {
+            throw new SamzaException(String.format("Failed to delete records for " + keys), (Throwable) e);
+          });
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
   public void flush() {
     try {
       writeMetrics.numFlushes.inc();
-      if (rateLimitWrites) {
-        throttle(null, null, RL_WRITE_TAG, writeCreditFn, putThrottleNs);
-      }
       long startNs = System.nanoTime();
       writeFn.flush();
       writeMetrics.flushNs.update(System.nanoTime() - startNs);
@@ -177,7 +218,7 @@ public class RemoteReadWriteTable<K, V> extends RemoteReadableTable<K, V> implem
    */
   @Override
   public void close() {
-    super.close();
     writeFn.close();
+    super.close();
   }
 }
index d919d2f..24edbce 100644 (file)
 
 package org.apache.samza.table.remote;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BiFunction;
+import java.util.function.Function;
 
 import org.apache.samza.SamzaException;
 import org.apache.samza.container.SamzaContainerContext;
 import org.apache.samza.metrics.Timer;
-import org.apache.samza.operators.KV;
+import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.utils.DefaultTableReadMetrics;
 import org.apache.samza.table.utils.TableMetricsUtil;
 import org.apache.samza.task.TaskContext;
-import org.apache.samza.util.RateLimiter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 
-import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_READ_TAG;
-
 
 /**
  * A Samza {@link org.apache.samza.table.Table} backed by a remote data-store or service.
@@ -60,6 +63,12 @@ import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_READ_TAG;
  * these reader and writer functions, sub-classes of {@link RemoteReadableTable} may provide rich functionality like
  * caching or throttling on top of them.
  *
+ * For async IO methods, requests are dispatched by a single-threaded executor after invoking the rateLimiter.
+ * Optionally, an executor can be specified for invoking the future callbacks which otherwise are
+ * executed on the threads of the underlying native data store client. This could be useful when
+ * application might execute long-running operations upon future completions; another use case is to increase
+ * throughput with more parallelism in the callback executions.
+ *
  * @param <K> the type of the key in this table
  * @param <V> the type of the value in this table
  */
@@ -67,34 +76,34 @@ public class RemoteReadableTable<K, V> implements ReadableTable<K, V> {
 
   protected final String tableId;
   protected final Logger logger;
-  protected final TableReadFunction<K, V> readFn;
-  protected final String groupName;
-  protected final RateLimiter rateLimiter;
-  protected final CreditFunction<K, V> readCreditFn;
-  protected final boolean rateLimitReads;
 
-  protected DefaultTableReadMetrics readMetrics;
-  protected Timer getThrottleNs;
+  protected final ExecutorService callbackExecutor;
+  protected final ExecutorService tableExecutor;
+
+  private final TableReadFunction<K, V> readFn;
+  private DefaultTableReadMetrics readMetrics;
+
+  @VisibleForTesting
+  final TableRateLimiter<K, V> readRateLimiter;
 
   /**
    * Construct a RemoteReadableTable instance
    * @param tableId table id
    * @param readFn {@link TableReadFunction} for read operations
-   * @param rateLimiter optional {@link RateLimiter} for throttling reads
-   * @param readCreditFn function returning a credit to be charged for rate limiting per record
+   * @param rateLimiter helper for rate limiting
+   * @param tableExecutor executor for issuing async requests
+   * @param callbackExecutor executor for invoking async callbacks
    */
-  public RemoteReadableTable(String tableId, TableReadFunction<K, V> readFn, RateLimiter rateLimiter,
-      CreditFunction<K, V> readCreditFn) {
+  public RemoteReadableTable(String tableId, TableReadFunction<K, V> readFn,
+      TableRateLimiter<K, V> rateLimiter, ExecutorService tableExecutor, ExecutorService callbackExecutor) {
     Preconditions.checkArgument(tableId != null && !tableId.isEmpty(), "invalid table id");
     Preconditions.checkNotNull(readFn, "null read function");
     this.tableId = tableId;
     this.readFn = readFn;
-    this.rateLimiter = rateLimiter;
-    this.readCreditFn = readCreditFn;
-    this.groupName = getClass().getSimpleName();
-    this.logger = LoggerFactory.getLogger(groupName + tableId);
-    this.rateLimitReads = rateLimiter != null && rateLimiter.getSupportedTags().contains(RL_READ_TAG);
-    logger.info("Rate limiting is {} for remote read operations", rateLimitReads ? "enabled" : "disabled");
+    this.readRateLimiter = rateLimiter;
+    this.callbackExecutor = callbackExecutor;
+    this.tableExecutor = tableExecutor;
+    this.logger = LoggerFactory.getLogger(getClass().getName() + "-" + tableId);
   }
 
   /**
@@ -104,7 +113,7 @@ public class RemoteReadableTable<K, V> implements ReadableTable<K, V> {
   public void init(SamzaContainerContext containerContext, TaskContext taskContext) {
     readMetrics = new DefaultTableReadMetrics(containerContext, taskContext, this, tableId);
     TableMetricsUtil tableMetricsUtil = new TableMetricsUtil(containerContext, taskContext, this, tableId);
-    getThrottleNs = tableMetricsUtil.newTimer("get-throttle-ns");
+    readRateLimiter.setTimerMetric(tableMetricsUtil.newTimer("get-throttle-ns"));
   }
 
   /**
@@ -113,73 +122,177 @@ public class RemoteReadableTable<K, V> implements ReadableTable<K, V> {
   @Override
   public V get(K key) {
     try {
-      readMetrics.numGets.inc();
-      if (rateLimitReads) {
-        throttle(key, null, RL_READ_TAG, readCreditFn, getThrottleNs);
-      }
-      long startNs = System.nanoTime();
-      V result = readFn.get(key);
-      readMetrics.getNs.update(System.nanoTime() - startNs);
-      return result;
+      return getAsync(key).get();
     } catch (Exception e) {
-      String errMsg = String.format("Failed to get a record, key=%s", key);
-      logger.error(errMsg, e);
-      throw new SamzaException(errMsg, e);
+      throw new SamzaException(e);
     }
   }
 
+  @Override
+  public CompletableFuture<V> getAsync(K key) {
+    Preconditions.checkNotNull(key);
+    readMetrics.numGets.inc();
+    return execute(readRateLimiter, key, readFn::getAsync, readMetrics.getNs)
+        .exceptionally(e -> {
+            throw new SamzaException("Failed to get the record for " + key, e);
+          });
+  }
+
   /**
    * {@inheritDoc}
    */
   @Override
   public Map<K, V> getAll(List<K> keys) {
-    Map<K, V> result;
+    readMetrics.numGetAlls.inc();
     try {
-      readMetrics.numGetAlls.inc();
-      long startNs = System.nanoTime();
-      result = readFn.getAll(keys);
-      readMetrics.getAllNs.update(System.nanoTime() - startNs);
+      return getAllAsync(keys).get();
     } catch (Exception e) {
-      String errMsg = "Failed to get some records";
-      logger.error(errMsg, e);
-      throw new SamzaException(errMsg, e);
+      throw new SamzaException(e);
     }
+  }
 
-    if (result == null) {
-      String errMsg = String.format("Received null records, keys=%s", keys);
-      logger.error(errMsg);
-      throw new SamzaException(errMsg);
+  @Override
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+    Preconditions.checkNotNull(keys);
+    if (keys.isEmpty()) {
+      return CompletableFuture.completedFuture(Collections.EMPTY_MAP);
     }
+    readMetrics.numGetAlls.inc();
+    return execute(readRateLimiter, keys, readFn::getAllAsync, readMetrics.getAllNs)
+        .handle((result, e) -> {
+            if (e != null) {
+              throw new SamzaException("Failed to get the records for " + keys, e);
+            }
+            return result;
+          });
+  }
 
-    if (result.size() < keys.size()) {
-      String errMsg = String.format("Received insufficient number of records (%d), keys=%s", result.size(), keys);
-      logger.error(errMsg);
-      throw new SamzaException(errMsg);
+  /**
+   * Execute an async request given a table key
+   * @param key key of the table record
+   * @param method method to be executed
+   * @param timer latency metric to be updated
+   * @param <T> return type
+   * @return CompletableFuture of the operation
+   */
+  protected <T> CompletableFuture<T> execute(TableRateLimiter<K, V> rateLimiter,
+      K key, Function<K, CompletableFuture<T>> method, Timer timer) {
+    final long startNs = System.nanoTime();
+    CompletableFuture<T> ioFuture = rateLimiter.isRateLimited() ?
+        CompletableFuture
+            .runAsync(() -> rateLimiter.throttle(key), tableExecutor)
+            .thenCompose((r) -> method.apply(key)) :
+        method.apply(key);
+    if (callbackExecutor != null) {
+      ioFuture.thenApplyAsync(r -> {
+          timer.update(System.nanoTime() - startNs);
+          return r;
+        }, callbackExecutor);
+    } else {
+      ioFuture.thenApply(r -> {
+          timer.update(System.nanoTime() - startNs);
+          return r;
+        });
     }
+    return ioFuture;
+  }
 
-    return result;
+  /**
+   * Execute an async request given a table record (key+value)
+   * @param key key of the table record
+   * @param value value of the table record
+   * @param method method to be executed
+   * @param timer latency metric to be updated
+   * @return CompletableFuture of the operation
+   */
+  protected CompletableFuture<Void> execute(TableRateLimiter<K, V> rateLimiter,
+      K key, V value, BiFunction<K, V, CompletableFuture<Void>> method, Timer timer) {
+    final long startNs = System.nanoTime();
+    CompletableFuture<Void> ioFuture = rateLimiter.isRateLimited() ?
+        CompletableFuture
+            .runAsync(() -> rateLimiter.throttle(key, value), tableExecutor)
+            .thenCompose((r) -> method.apply(key, value)) :
+        method.apply(key, value);
+    if (callbackExecutor != null) {
+      ioFuture.thenApplyAsync(r -> {
+          timer.update(System.nanoTime() - startNs);
+          return r;
+        }, callbackExecutor);
+    } else {
+      ioFuture.thenApply(r -> {
+          timer.update(System.nanoTime() - startNs);
+          return r;
+        });
+    }
+    return ioFuture;
   }
 
   /**
-   * {@inheritDoc}
+   * Execute an async request given a collection of table keys
+   * @param keys collection of keys
+   * @param method method to be executed
+   * @param timer latency metric to be updated
+   * @return CompletableFuture of the operation
    */
-  @Override
-  public void close() {
-    readFn.close();
+  protected <T> CompletableFuture<T> execute(TableRateLimiter<K, V> rateLimiter,
+      Collection<K> keys, Function<Collection<K>, CompletableFuture<T>> method, Timer timer) {
+    final long startNs = System.nanoTime();
+    CompletableFuture<T> ioFuture = rateLimiter.isRateLimited() ?
+        CompletableFuture
+            .runAsync(() -> rateLimiter.throttle(keys), tableExecutor)
+            .thenCompose((r) -> method.apply(keys)) :
+        method.apply(keys);
+    if (callbackExecutor != null) {
+      ioFuture.thenApplyAsync(r -> {
+          timer.update(System.nanoTime() - startNs);
+          return r;
+        }, callbackExecutor);
+    } else {
+      ioFuture.thenApply(r -> {
+          timer.update(System.nanoTime() - startNs);
+          return r;
+        });
+    }
+    return ioFuture;
   }
 
   /**
-   * Throttle requests given a table record (key, value) with rate limiter and credit function
-   * @param key key of the table record (nullable)
-   * @param value value of the table record (nullable)
-   * @param tag tag for rate limiter
-   * @param creditFn mapper function from KV to credits to be charged
-   * @param timer timer metric to track throttling delays
+   * Execute an async request given a collection of table records
+   * @param records list of records
+   * @param method method to be executed
+   * @param timer latency metric to be updated
+   * @return CompletableFuture of the operation
    */
-  protected void throttle(K key, V value, String tag, CreditFunction<K, V> creditFn, Timer timer) {
-    long startNs = System.nanoTime();
-    int credits = (creditFn == null) ? 1 : creditFn.apply(KV.of(key, value));
-    rateLimiter.acquire(Collections.singletonMap(tag, credits));
-    timer.update(System.nanoTime() - startNs);
+  protected CompletableFuture<Void> executeRecords(TableRateLimiter<K, V> rateLimiter,
+      Collection<Entry<K, V>> records, Function<Collection<Entry<K, V>>, CompletableFuture<Void>> method, Timer timer) {
+    final long startNs = System.nanoTime();
+    CompletableFuture<Void> ioFuture;
+    if (rateLimiter.isRateLimited()) {
+      ioFuture = CompletableFuture
+          .runAsync(() -> rateLimiter.throttleRecords(records), tableExecutor)
+          .thenCompose((r) -> method.apply(records));
+    } else {
+      ioFuture = method.apply(records);
+    }
+    if (callbackExecutor != null) {
+      ioFuture.thenApplyAsync(r -> {
+          timer.update(System.nanoTime() - startNs);
+          return r;
+        }, callbackExecutor);
+    } else {
+      ioFuture.thenApply(r -> {
+          timer.update(System.nanoTime() - startNs);
+          return r;
+        });
+    }
+    return ioFuture;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void close() {
+    readFn.close();
   }
 }
index bad4639..e405096 100644 (file)
@@ -41,15 +41,16 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
   /**
    * Tag to be used for provision credits for rate limiting read operations from the remote table.
    * Caller must pre-populate the credits with this tag when specifying a custom rate limiter instance
-   * through {@link RemoteTableDescriptor#withRateLimiter(RateLimiter, CreditFunction, CreditFunction)}
+   * through {@link RemoteTableDescriptor#withRateLimiter(RateLimiter, TableRateLimiter.CreditFunction,
+   * TableRateLimiter.CreditFunction)}
    */
   public static final String RL_READ_TAG = "readTag";
 
   /**
    * Tag to be used for provision credits for rate limiting write operations into the remote table.
    * Caller can optionally populate the credits with this tag when specifying a custom rate limiter instance
-   * through {@link RemoteTableDescriptor#withRateLimiter(RateLimiter, CreditFunction, CreditFunction)}
-   * and it needs the write functionality.
+   * through {@link RemoteTableDescriptor#withRateLimiter(RateLimiter, TableRateLimiter.CreditFunction,
+   * TableRateLimiter.CreditFunction)} and it needs the write functionality.
    */
   public static final String RL_WRITE_TAG = "writeTag";
 
@@ -66,8 +67,12 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
   // Rates for constructing the default rate limiter when they are non-zero
   private Map<String, Integer> tagCreditsMap = new HashMap<>();
 
-  private CreditFunction<K, V> readCreditFn;
-  private CreditFunction<K, V> writeCreditFn;
+  private TableRateLimiter.CreditFunction<K, V> readCreditFn;
+  private TableRateLimiter.CreditFunction<K, V> writeCreditFn;
+
+  // By default execute future callbacks on the native client threads
+  // ie. no additional thread pool for callbacks.
+  private int asyncCallbackPoolSize = -1;
 
   /**
    * Construct a table descriptor instance
@@ -111,6 +116,8 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
           "write credit function", writeCreditFn));
     }
 
+    tableSpecConfig.put(RemoteTableProvider.ASYNC_CALLBACK_POOL_SIZE, String.valueOf(asyncCallbackPoolSize));
+
     return new TableSpec(tableId, serde, RemoteTableProviderFactory.class.getName(), tableSpecConfig);
   }
 
@@ -149,8 +156,9 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
    * @param writeCreditFn credit function for rate limiting write operations
    * @return this table descriptor instance
    */
-  public RemoteTableDescriptor<K, V> withRateLimiter(RateLimiter rateLimiter, CreditFunction<K, V> readCreditFn,
-      CreditFunction<K, V> writeCreditFn) {
+  public RemoteTableDescriptor<K, V> withRateLimiter(RateLimiter rateLimiter,
+      TableRateLimiter.CreditFunction<K, V> readCreditFn,
+      TableRateLimiter.CreditFunction<K, V> writeCreditFn) {
     Preconditions.checkNotNull(rateLimiter, "null read rate limiter");
     this.rateLimiter = rateLimiter;
     this.readCreditFn = readCreditFn;
@@ -160,7 +168,8 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
 
   /**
    * Specify the rate limit for table read operations. If the read rate limit is set with this method
-   * it is invalid to call {@link RemoteTableDescriptor#withRateLimiter(RateLimiter, CreditFunction, CreditFunction)}
+   * it is invalid to call {@link RemoteTableDescriptor#withRateLimiter(RateLimiter,
+   * TableRateLimiter.CreditFunction, TableRateLimiter.CreditFunction)}
    * and vice versa.
    * @param creditsPerSec rate limit for read operations; must be positive
    * @return this table descriptor instance
@@ -173,7 +182,8 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
 
   /**
    * Specify the rate limit for table write operations. If the write rate limit is set with this method
-   * it is invalid to call {@link RemoteTableDescriptor#withRateLimiter(RateLimiter, CreditFunction, CreditFunction)}
+   * it is invalid to call {@link RemoteTableDescriptor#withRateLimiter(RateLimiter,
+   * TableRateLimiter.CreditFunction, TableRateLimiter.CreditFunction)}
    * and vice versa.
    * @param creditsPerSec rate limit for write operations; must be positive
    * @return this table descriptor instance
@@ -184,11 +194,30 @@ public class RemoteTableDescriptor<K, V> extends BaseTableDescriptor<K, V, Remot
     return this;
   }
 
+  /**
+   * Specify the size of the thread pool for the executor used to execute
+   * callbacks of CompletableFutures of async Table operations. By default, these
+   * futures are completed (called) by the threads of the native store client. Depending
+   * on the implementation of the native client, it may or may not allow executing long
+   * running operations in the callbacks. This config can be used to execute the callbacks
+   * from a separate executor to decouple from the native client. If configured, this
+   * thread pool is shared by all read and write operations.
+   * @param poolSize max number of threads in the executor for async callbacks
+   * @return this table descriptor instance
+   */
+  public RemoteTableDescriptor<K, V> withAsyncCallbackExecutorPoolSize(int poolSize) {
+    this.asyncCallbackPoolSize = poolSize;
+    return this;
+  }
+
   @Override
   protected void validate() {
     super.validate();
     Preconditions.checkNotNull(readFn, "TableReadFunction is required.");
     Preconditions.checkArgument(rateLimiter == null || tagCreditsMap.isEmpty(),
         "Only one of rateLimiter instance or read/write limits can be specified");
+    // Assume callback executor pool should have no more than 20 threads
+    Preconditions.checkArgument(asyncCallbackPoolSize <= 20,
+        "too many threads for async callback executor.");
   }
 }
index b4051cb..f09c6fd 100644 (file)
@@ -23,6 +23,9 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.container.SamzaContainerContext;
@@ -35,6 +38,9 @@ import org.apache.samza.util.RateLimiter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_READ_TAG;
+import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_WRITE_TAG;
+
 
 /**
  * Provide for remote table instances
@@ -42,11 +48,12 @@ import org.slf4j.LoggerFactory;
 public class RemoteTableProvider implements TableProvider {
   private static final Logger LOG = LoggerFactory.getLogger(RemoteTableProvider.class);
 
-  static final String READ_FN = "io.readFn";
-  static final String WRITE_FN = "io.writeFn";
+  static final String READ_FN = "io.read.func";
+  static final String WRITE_FN = "io.write.func";
   static final String RATE_LIMITER = "io.ratelimiter";
-  static final String READ_CREDIT_FN = "io.readCreditFn";
-  static final String WRITE_CREDIT_FN = "io.writeCreditFn";
+  static final String READ_CREDIT_FN = "io.read.credit.func";
+  static final String WRITE_CREDIT_FN = "io.write.credit.func";
+  static final String ASYNC_CALLBACK_POOL_SIZE = "io.async.callback.pool.size";
 
   private final TableSpec tableSpec;
   private final boolean readOnly;
@@ -54,9 +61,17 @@ public class RemoteTableProvider implements TableProvider {
   private SamzaContainerContext containerContext;
   private TaskContext taskContext;
 
+  /**
+   * Map of tableId -> executor service for async table IO and callbacks. The same executors
+   * are shared by both read/write operations such that tables of the same tableId all share
+   * the set same of executors globally whereas table itself is per-task.
+   */
+  private static Map<String, ExecutorService> tableExecutors = new ConcurrentHashMap<>();
+  private static Map<String, ExecutorService> callbackExecutors = new ConcurrentHashMap<>();
+
   public RemoteTableProvider(TableSpec tableSpec) {
     this.tableSpec = tableSpec;
-    readOnly = !tableSpec.getConfig().containsKey(WRITE_FN);
+    this.readOnly = !tableSpec.getConfig().containsKey(WRITE_FN);
   }
 
   /**
@@ -74,18 +89,56 @@ public class RemoteTableProvider implements TableProvider {
   @Override
   public Table getTable() {
     RemoteReadableTable table;
+    String tableId = tableSpec.getId();
+
     TableReadFunction<?, ?> readFn = getReadFn();
     RateLimiter rateLimiter = deserializeObject(RATE_LIMITER);
     if (rateLimiter != null) {
       rateLimiter.init(containerContext.config, taskContext);
     }
-    CreditFunction<?, ?> readCreditFn = deserializeObject(READ_CREDIT_FN);
+    TableRateLimiter.CreditFunction<?, ?> readCreditFn = deserializeObject(READ_CREDIT_FN);
+    TableRateLimiter readRateLimiter = new TableRateLimiter(tableSpec.getId(), rateLimiter, readCreditFn, RL_READ_TAG);
+
+    TableRateLimiter.CreditFunction<?, ?> writeCreditFn;
+    TableRateLimiter writeRateLimiter = null;
+
+    boolean isRateLimited = readRateLimiter.isRateLimited();
+    if (!readOnly) {
+      writeCreditFn = deserializeObject(WRITE_CREDIT_FN);
+      writeRateLimiter = new TableRateLimiter(tableSpec.getId(), rateLimiter, writeCreditFn, RL_WRITE_TAG);
+      isRateLimited |= writeRateLimiter.isRateLimited();
+    }
+
+    // Optional executor for future callback/completion. Shared by both read and write operations.
+    int callbackPoolSize = Integer.parseInt(tableSpec.getConfig().get(ASYNC_CALLBACK_POOL_SIZE));
+    if (callbackPoolSize > 0) {
+      callbackExecutors.computeIfAbsent(tableId, (arg) ->
+          Executors.newFixedThreadPool(callbackPoolSize, (runnable) -> {
+              Thread thread = new Thread(runnable);
+              thread.setName("table-" + tableId + "-async-callback-pool");
+              thread.setDaemon(true);
+              return thread;
+            }));
+    }
+
+    if (isRateLimited) {
+      tableExecutors.computeIfAbsent(tableId, (arg) ->
+          Executors.newSingleThreadExecutor(runnable -> {
+              Thread thread = new Thread(runnable);
+              thread.setName("table-" + tableId + "-async-executor");
+              thread.setDaemon(true);
+              return thread;
+            }));
+    }
+
     if (readOnly) {
-      table = new RemoteReadableTable(tableSpec.getId(), readFn, rateLimiter, readCreditFn);
+      table = new RemoteReadableTable(tableSpec.getId(), readFn, readRateLimiter,
+          tableExecutors.get(tableId), callbackExecutors.get(tableId));
     } else {
-      CreditFunction<?, ?> writeCreditFn = deserializeObject(WRITE_CREDIT_FN);
-      table = new RemoteReadWriteTable(tableSpec.getId(), readFn, getWriteFn(), rateLimiter, readCreditFn, writeCreditFn);
+      table = new RemoteReadWriteTable(tableSpec.getId(), readFn, getWriteFn(), readRateLimiter,
+          writeRateLimiter, tableExecutors.get(tableId), callbackExecutors.get(tableId));
     }
+
     table.init(containerContext, taskContext);
     tables.add(table);
     return table;
@@ -115,6 +168,8 @@ public class RemoteTableProvider implements TableProvider {
   @Override
   public void close() {
     tables.forEach(t -> t.close());
+    tableExecutors.values().forEach(e -> e.shutdown());
+    callbackExecutors.values().forEach(e -> e.shutdown());
   }
 
   private <T> T deserializeObject(String key) {
diff --git a/samza-core/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java b/samza-core/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java
new file mode 100644 (file)
index 0000000..c67a648
--- /dev/null
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.remote;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.samza.annotation.InterfaceStability;
+import org.apache.samza.metrics.Timer;
+import org.apache.samza.storage.kv.Entry;
+import org.apache.samza.util.RateLimiter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+
+
+/**
+ * Helper class for remote table to throttle table IO requests with the configured rate limiter.
+ * For each request, the needed credits are calculated with the configured credit functions.
+ * The throttle methods are overloaded to support the possible CRUD operations.
+ *
+ * @param <K> type of the table key
+ * @param <V> type of the table record
+ */
+public class TableRateLimiter<K, V> {
+  private static final Logger LOG = LoggerFactory.getLogger(TableRateLimiter.class);
+
+  private final String tag;
+  private final boolean rateLimited;
+  private final CreditFunction<K, V> creditFn;
+
+  @VisibleForTesting
+  final RateLimiter rateLimiter;
+
+  private Timer waitTimeMetric;
+
+  /**
+   * Function interface for providing rate limiting credits for each table record.
+   * This interface allows callers to pass in lambda expressions which are otherwise
+   * non-serializable as-is.
+   * @param <K> the type of the key
+   * @param <V> the type of the value
+   */
+  @InterfaceStability.Unstable
+  public interface CreditFunction<K, V> extends Serializable {
+    /**
+     * Get the number of credits required for the {@code key} and {@code value} pair.
+     * @param key table key
+     * @param value table record
+     * @return number of credits
+     */
+    int getCredits(K key, V value);
+  }
+
+  /**
+   * @param tableId table id of the table to be rate limited
+   * @param rateLimiter actual rate limiter instance to be used
+   * @param creditFn function for deriving the credits for each request
+   * @param tag tag to be used with the rate limiter
+   */
+  public TableRateLimiter(String tableId, RateLimiter rateLimiter, CreditFunction<K, V> creditFn, String tag) {
+    this.rateLimiter = rateLimiter;
+    this.creditFn = creditFn;
+    this.tag = tag;
+    this.rateLimited = rateLimiter != null && rateLimiter.getSupportedTags().contains(tag);
+    LOG.info("Rate limiting is {} for {}", rateLimited ? "enabled" : "disabled", tableId);
+  }
+
+  /**
+   * Set up waitTimeMetric metric for latency reporting due to throttling.
+   * @param timer waitTimeMetric metric
+   */
+  public void setTimerMetric(Timer timer) {
+    Preconditions.checkNotNull(timer);
+    this.waitTimeMetric = timer;
+  }
+
+  int getCredits(K key, V value) {
+    return (creditFn == null) ? 1 : creditFn.getCredits(key, value);
+  }
+
+  int getCredits(Collection<K> keys) {
+    if (creditFn == null) {
+      return keys.size();
+    } else {
+      return keys.stream().mapToInt(k -> creditFn.getCredits(k, null)).sum();
+    }
+  }
+
+  int getEntryCredits(Collection<Entry<K, V>> entries) {
+    if (creditFn == null) {
+      return entries.size();
+    } else {
+      return entries.stream().mapToInt(e -> creditFn.getCredits(e.getKey(), e.getValue())).sum();
+    }
+  }
+
+  private void throttle(int credits) {
+    if (!rateLimited) {
+      return;
+    }
+
+    long startNs = System.nanoTime();
+    rateLimiter.acquire(Collections.singletonMap(tag, credits));
+    waitTimeMetric.update(System.nanoTime() - startNs);
+  }
+
+  /**
+   * Throttle a request with a key argument if necessary.
+   * @param key key used for the table request
+   */
+  public void throttle(K key) {
+    throttle(getCredits(key, null));
+  }
+
+  /**
+   * Throttle a request with both the key and value arguments if necessary.
+   * @param key key used for the table request
+   * @param value value used for the table request
+   */
+  public void throttle(K key, V value) {
+    throttle(getCredits(key, value));
+  }
+
+  /**
+   * Throttle a request with a collection of keys as the argument if necessary.
+   * @param keys collection of keys used for the table request
+   */
+  public void throttle(Collection<K> keys) {
+    throttle(getCredits(keys));
+  }
+
+  /**
+   * Throttle a request with a collection of table records as the argument if necessary.
+   * @param records collection of records used for the table request
+   */
+  public void throttleRecords(Collection<Entry<K, V>> records) {
+    throttle(getEntryCredits(records));
+  }
+
+  /**
+   * @return whether rate limiting is enabled for the associated table
+   */
+  public boolean isRateLimited() {
+    return rateLimited;
+  }
+}
index dbd386c..5d0f963 100644 (file)
@@ -21,13 +21,18 @@ package org.apache.samza.table.remote;
 
 import java.io.Serializable;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
 
+import org.apache.samza.SamzaException;
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.operators.functions.ClosableFunction;
 import org.apache.samza.operators.functions.InitableFunction;
 
+import com.google.common.collect.Iterables;
+
 
 /**
  * A function object to be used with a {@link RemoteReadableTable} implementation. It encapsulates the functionality
@@ -44,22 +49,55 @@ import org.apache.samza.operators.functions.InitableFunction;
 public interface TableReadFunction<K, V> extends Serializable, InitableFunction, ClosableFunction {
   /**
    * Fetch single table record for a specified {@code key}. This method must be thread-safe.
+   * The default implementation calls getAsync and blocks on the completion afterwards.
    * @param key key for the table record
    * @return table record for the specified {@code key}
    */
-  V get(K key);
+  default V get(K key) {
+    try {
+      return getAsync(key).get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new SamzaException("GET failed for " + key, e);
+    }
+  }
+
+  /**
+   * Asynchronously fetch single table record for a specified {@code key}. This method must be thread-safe.
+   * @param key key for the table record
+   * @return CompletableFuture for the get request
+   */
+  CompletableFuture<V> getAsync(K key);
 
   /**
    * Fetch the table {@code records} for specified {@code keys}. This method must be thread-safe.
+   * The default implementation calls getAllAsync and blocks on the completion afterwards.
    * @param keys keys for the table records
-   * @return all records for the specified keys if succeeded; depending on the implementation
-   * of {@link TableReadFunction#get(Object)} it either returns records for a subset of the
-   * keys or throws exception when there is any failure.
+   * @return all records for the specified keys.
    */
   default Map<K, V> getAll(Collection<K> keys) {
-    Map<K, V> records = new HashMap<>();
-    keys.forEach(k -> records.put(k, get(k)));
-    return records;
+    try {
+      return getAllAsync(keys).get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new SamzaException("GET_ALL failed for " + keys, e);
+    }
+  }
+
+  /**
+   * Asynchronously fetch the table {@code records} for specified {@code keys}. This method must be thread-safe.
+   * The default implementation calls getAsync for each key and return a combined future.
+   * @param keys keys for the table records
+   * @return CompletableFuture for the get request
+   */
+  default CompletableFuture<Map<K, V>> getAllAsync(Collection<K> keys) {
+    Map<K, CompletableFuture<V>> getFutures =  keys.stream().collect(
+        Collectors.toMap(k -> k, k -> getAsync(k)));
+
+    return CompletableFuture.allOf(
+        Iterables.toArray(getFutures.values(), CompletableFuture.class))
+        .thenApply(future ->
+          getFutures.entrySet()
+            .stream()
+            .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue().join())));
   }
 
   // optionally implement readObject() to initialize transient states
index df54878..0ac3a0c 100644 (file)
@@ -22,11 +22,18 @@ package org.apache.samza.table.remote;
 import java.io.Serializable;
 import java.util.Collection;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+
+import org.apache.samza.SamzaException;
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.operators.functions.ClosableFunction;
 import org.apache.samza.operators.functions.InitableFunction;
 import org.apache.samza.storage.kv.Entry;
 
+import com.google.common.collect.Iterables;
+
 
 /**
  * A function object to be used with a {@link RemoteReadWriteTable} implementation. It encapsulates the functionality
@@ -43,37 +50,96 @@ import org.apache.samza.storage.kv.Entry;
 public interface TableWriteFunction<K, V> extends Serializable, InitableFunction, ClosableFunction {
   /**
    * Store single table {@code record} with specified {@code key}. This method must be thread-safe.
+   * The default implementation calls putAsync and blocks on the completion afterwards.
    *
-   * The key is deleted if record is {@code null}.
-   *
    * @param key key for the table record
    * @param record table record to be written
    */
-  void put(K key, V record);
+  default void put(K key, V record) {
+    try {
+      putAsync(key, record).get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new SamzaException("PUT failed for " + key, e);
+    }
+  }
+
+  /**
+   * Asynchronously store single table {@code record} with specified {@code key}. This method must be thread-safe.
+   * @param key key for the table record
+   * @param record table record to be written
+   * @return CompletableFuture for the put request
+   */
+  CompletableFuture<Void> putAsync(K key, V record);
 
   /**
    * Store the table {@code records} with specified {@code keys}. This method must be thread-safe.
-   *
-   * A key is deleted if its corresponding record is {@code null}.
-   *
+   * The default implementation calls putAllAsync and blocks on the completion afterwards.
    * @param records table records to be written
    */
   default void putAll(List<Entry<K, V>> records) {
-    records.forEach(e -> put(e.getKey(), e.getValue()));
+    try {
+      putAllAsync(records).get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new SamzaException("PUT_ALL failed for " + records, e);
+    }
   }
 
   /**
-   * Delete the {@code record} with specified {@code key} from the remote store
+   * Asynchronously store the table {@code records} with specified {@code keys}. This method must be thread-safe.
+   * The default implementation calls putAsync for each entry and return a combined future.
+   * @param records table records to be written
+   * @return CompletableFuture for the put request
+   */
+  default CompletableFuture<Void> putAllAsync(Collection<Entry<K, V>> records) {
+    List<CompletableFuture<Void>> putFutures =
+        records.stream().map(e -> putAsync(e.getKey(), e.getValue())).collect(Collectors.toList());
+    return CompletableFuture.allOf(Iterables.toArray(putFutures, CompletableFuture.class));
+  }
+
+  /**
+   * Delete the {@code record} with specified {@code key} from the remote store.
+   * The default implementation calls deleteAsync and blocks on the completion afterwards.
+   * @param key key to the table record to be deleted
+   */
+  default void delete(K key) {
+    try {
+      deleteAsync(key).get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new SamzaException("DELETE failed for " + key, e);
+    }
+  }
+
+  /**
+   * Asynchronously delete the {@code record} with specified {@code key} from the remote store
    * @param key key to the table record to be deleted
+   * @return CompletableFuture for the delete request
    */
-  void delete(K key);
+  CompletableFuture<Void> deleteAsync(K key);
 
   /**
    * Delete all {@code records} with the specified {@code keys} from the remote store
+   * The default implementation calls deleteAllAsync and blocks on the completion afterwards.
    * @param keys keys for the table records to be written
    */
   default void deleteAll(Collection<K> keys) {
-    keys.stream().forEach(k -> delete(k));
+    try {
+      deleteAllAsync(keys).get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new SamzaException("DELETE failed for " + keys, e);
+    }
+  }
+
+  /**
+   * Asynchronously delete all {@code records} with the specified {@code keys} from the remote store.
+   * The default implementation calls deleteAsync for each key and return a combined future.
+   *
+   * @param keys keys for the table records to be written
+   * @return CompletableFuture for the deleteAll request
+   */
+  default CompletableFuture<Void> deleteAllAsync(Collection<K> keys) {
+    List<CompletableFuture<Void>> deleteFutures =
+        keys.stream().map(this::deleteAsync).collect(Collectors.toList());
+    return CompletableFuture.allOf(Iterables.toArray(deleteFutures, CompletableFuture.class));
   }
 
   /**
index a327ae3..2acd082 100644 (file)
@@ -34,6 +34,7 @@ public class DefaultTableReadMetrics {
   public final Timer getAllNs;
   public final Counter numGets;
   public final Counter numGetAlls;
+  public final Timer getCallbackNs;
 
   /**
    * Constructor based on container and task container context
@@ -50,6 +51,7 @@ public class DefaultTableReadMetrics {
     getAllNs = tableMetricsUtil.newTimer("getAll-ns");
     numGets = tableMetricsUtil.newCounter("num-gets");
     numGetAlls = tableMetricsUtil.newCounter("num-getAlls");
+    getCallbackNs = tableMetricsUtil.newTimer("get-callback-ns");
   }
 
 }
index 150ee9a..a32d6d5 100644 (file)
@@ -37,6 +37,8 @@ public class DefaultTableWriteMetrics {
   public final Counter numDeletes;
   public final Counter numDeleteAlls;
   public final Counter numFlushes;
+  public final Timer putCallbackNs;
+  public final Timer deleteCallbackNs;
 
   /**
    * Utility class that contains the default set of write metrics.
@@ -59,5 +61,7 @@ public class DefaultTableWriteMetrics {
     numDeletes = tableMetricsUtil.newCounter("num-deletes");
     numDeleteAlls = tableMetricsUtil.newCounter("num-deleteAlls");
     numFlushes = tableMetricsUtil.newCounter("num-flushes");
+    putCallbackNs = tableMetricsUtil.newTimer("put-callback-ns");
+    deleteCallbackNs = tableMetricsUtil.newTimer("delete-callback-ns");
   }
 }
index 2e40358..49c72dc 100644 (file)
 package org.apache.samza.table.caching;
 
 import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
-import java.util.Random;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
 
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.metrics.Counter;
+import org.apache.samza.metrics.Gauge;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.metrics.Timer;
 import org.apache.samza.operators.TableImpl;
+import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.TableSpec;
+import org.apache.samza.table.caching.guava.GuavaCacheTable;
 import org.apache.samza.table.caching.guava.GuavaCacheTableDescriptor;
 import org.apache.samza.table.caching.guava.GuavaCacheTableProvider;
+import org.apache.samza.table.remote.TableRateLimiter;
+import org.apache.samza.table.remote.RemoteReadWriteTable;
+import org.apache.samza.table.remote.TableReadFunction;
+import org.apache.samza.table.remote.TableWriteFunction;
 import org.apache.samza.task.TaskContext;
 import org.apache.samza.util.NoOpMetricsRegistry;
 import org.junit.Assert;
 import org.junit.Test;
 
+import com.google.common.cache.Cache;
 import com.google.common.cache.CacheBuilder;
 
 import static org.mockito.Matchers.any;
@@ -81,7 +93,6 @@ public class TestCachingTable {
       desc.withCache(cache);
     }
 
-    desc.withStripes(32);
     desc.withWriteAround();
 
     TableSpec spec = desc.getTableSpec();
@@ -94,7 +105,6 @@ public class TestCachingTable {
       Assert.assertTrue(spec.getConfig().containsKey(CachingTableProvider.CACHE_TABLE_ID));
     }
 
-    Assert.assertEquals("32", spec.getConfig().get(CachingTableProvider.LOCK_STRIPES));
     Assert.assertEquals("true", spec.getConfig().get(CachingTableProvider.WRITE_AROUND));
 
     desc.validate();
@@ -106,33 +116,39 @@ public class TestCachingTable {
     // CHM for each does not serialize such two-step operation so the atomicity is still tested.
     // Regular HashMap is not thread-safe even for disjoint keys.
     final Map<String, String> cacheStore = new ConcurrentHashMap<>();
-    final ReadWriteTable tableCache = mock(ReadWriteTable.class);
+    final ReadWriteTable cacheTable = mock(ReadWriteTable.class);
 
     doAnswer(invocation -> {
         String key = invocation.getArgumentAt(0, String.class);
         String value = invocation.getArgumentAt(1, String.class);
         cacheStore.put(key, value);
         return null;
-      }).when(tableCache).put(any(), any());
+      }).when(cacheTable).put(any(), any());
 
     doAnswer(invocation -> {
         String key = invocation.getArgumentAt(0, String.class);
         return cacheStore.get(key);
-      }).when(tableCache).get(any());
+      }).when(cacheTable).get(any());
 
     doAnswer(invocation -> {
         String key = invocation.getArgumentAt(0, String.class);
         return cacheStore.remove(key);
-      }).when(tableCache).delete(any());
+      }).when(cacheTable).delete(any());
 
-    return Pair.of(tableCache, cacheStore);
+    return Pair.of(cacheTable, cacheStore);
   }
 
-  private void initTable(CachingTable cachingTable) {
+  private void initTables(ReadableTable ... tables) {
     SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
     TaskContext taskContext = mock(TaskContext.class);
-    when(taskContext.getMetricsRegistry()).thenReturn(new NoOpMetricsRegistry());
-    cachingTable.init(containerContext, taskContext);
+    MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
+    doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString());
+    doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString());
+    doReturn(mock(Gauge.class)).when(metricsRegistry).newGauge(anyString(), any());
+    when(taskContext.getMetricsRegistry()).thenReturn(metricsRegistry);
+    for (ReadableTable table : tables) {
+      table.init(containerContext, taskContext);
+    }
   }
 
   private void doTestCacheOps(boolean isWriteAround) {
@@ -147,14 +163,16 @@ public class TestCachingTable {
     SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
 
     TaskContext taskContext = mock(TaskContext.class);
-    final ReadWriteTable tableCache = getMockCache().getLeft();
+    final ReadWriteTable cacheTable = getMockCache().getLeft();
 
     final ReadWriteTable realTable = mock(ReadWriteTable.class);
 
     doAnswer(invocation -> {
         String key = invocation.getArgumentAt(0, String.class);
-        return "test-data-" + key;
-      }).when(realTable).get(any());
+        return CompletableFuture.completedFuture("test-data-" + key);
+      }).when(realTable).getAsync(any());
+
+    doReturn(CompletableFuture.completedFuture(null)).when(realTable).putAsync(any(), any());
 
     doAnswer(invocation -> {
         String tableId = invocation.getArgumentAt(0, String.class);
@@ -162,7 +180,7 @@ public class TestCachingTable {
           // cache
           return realTable;
         } else if (tableId.equals("cacheTable")) {
-          return tableCache;
+          return cacheTable;
         }
 
         Assert.fail();
@@ -173,39 +191,39 @@ public class TestCachingTable {
 
     tableProvider.init(containerContext, taskContext);
 
-    CachingTable cacheTable = (CachingTable) tableProvider.getTable();
+    CachingTable cachingTable = (CachingTable) tableProvider.getTable();
 
-    Assert.assertEquals("test-data-1", cacheTable.get("1"));
-    verify(realTable, times(1)).get(any());
-    verify(tableCache, times(2)).get(any()); // cache miss leads to 2 more get() calls
-    verify(tableCache, times(1)).put(any(), any());
-    Assert.assertEquals(cacheTable.hitRate(), 0.0, 0.0); // 0 hit, 1 request
-    Assert.assertEquals(cacheTable.missRate(), 1.0, 0.0);
+    Assert.assertEquals("test-data-1", cachingTable.get("1"));
+    verify(realTable, times(1)).getAsync(any());
+    verify(cacheTable, times(1)).get(any()); // cache miss
+    verify(cacheTable, times(1)).put(any(), any());
+    Assert.assertEquals(cachingTable.hitRate(), 0.0, 0.0); // 0 hit, 1 request
+    Assert.assertEquals(cachingTable.missRate(), 1.0, 0.0);
 
-    Assert.assertEquals("test-data-1", cacheTable.get("1"));
-    verify(realTable, times(1)).get(any()); // no change
-    verify(tableCache, times(3)).get(any());
-    verify(tableCache, times(1)).put(any(), any()); // no change
-    Assert.assertEquals(cacheTable.hitRate(), 0.5, 0.0); // 1 hit, 2 requests
-    Assert.assertEquals(cacheTable.missRate(), 0.5, 0.0);
+    Assert.assertEquals("test-data-1", cachingTable.get("1"));
+    verify(realTable, times(1)).getAsync(any()); // no change
+    verify(cacheTable, times(2)).get(any());
+    verify(cacheTable, times(1)).put(any(), any()); // no change
+    Assert.assertEquals(0.5, cachingTable.hitRate(), 0.0); // 1 hit, 2 requests
+    Assert.assertEquals(0.5, cachingTable.missRate(), 0.0);
 
-    cacheTable.put("2", "test-data-XXXX");
-    verify(tableCache, times(isWriteAround ? 1 : 2)).put(any(), any());
-    verify(realTable, times(1)).put(any(), any());
+    cachingTable.put("2", "test-data-XXXX");
+    verify(cacheTable, times(isWriteAround ? 1 : 2)).put(any(), any());
+    verify(realTable, times(1)).putAsync(any(), any());
 
     if (isWriteAround) {
-      Assert.assertEquals("test-data-2", cacheTable.get("2")); // expects value from table
-      verify(tableCache, times(5)).get(any()); // cache miss leads to 2 more get() calls
-      Assert.assertEquals(cacheTable.hitRate(), 0.33, 0.1); // 1 hit, 3 requests
+      Assert.assertEquals("test-data-2", cachingTable.get("2")); // expects value from table
+      verify(realTable, times(2)).getAsync(any()); // should have one more fetch
+      Assert.assertEquals(cachingTable.hitRate(), 0.33, 0.1); // 1 hit, 3 requests
     } else {
-      Assert.assertEquals("test-data-XXXX", cacheTable.get("2")); // expect value from cache
-      verify(tableCache, times(4)).get(any()); // cache hit
-      Assert.assertEquals(cacheTable.hitRate(), 0.66, 0.1); // 2 hits, 3 requests
+      Assert.assertEquals("test-data-XXXX", cachingTable.get("2")); // expect value from cache
+      verify(realTable, times(1)).getAsync(any()); // no change
+      Assert.assertEquals(cachingTable.hitRate(), 0.66, 0.1); // 2 hits, 3 requests
     }
   }
 
   @Test
-  public void testCacheOps() {
+  public void testCacheOpsWriteThrough() {
     doTestCacheOps(false);
   }
 
@@ -217,12 +235,12 @@ public class TestCachingTable {
   @Test
   public void testNonexistentKeyInTable() {
     ReadableTable<String, String> table = mock(ReadableTable.class);
-    doReturn(null).when(table).get(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(table).getAsync(any());
     ReadWriteTable<String, String> cache = getMockCache().getLeft();
-    CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, 16, false);
-    initTable(cachingTable);
+    CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, false);
+    initTables(cachingTable);
     Assert.assertNull(cachingTable.get("abc"));
-    verify(cache, times(2)).get(any());
+    verify(cache, times(1)).get(any());
     Assert.assertNull(cache.get("abc"));
     verify(cache, times(0)).put(any(), any());
   }
@@ -230,86 +248,119 @@ public class TestCachingTable {
   @Test
   public void testKeyEviction() {
     ReadableTable<String, String> table = mock(ReadableTable.class);
-    doReturn("3").when(table).get(any());
+    doReturn(CompletableFuture.completedFuture("3")).when(table).getAsync(any());
     ReadWriteTable<String, String> cache = mock(ReadWriteTable.class);
 
     // no handler added to mock cache so get/put are noop, this can simulate eviction
-    CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, 16, false);
-    initTable(cachingTable);
+    CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, false);
+    initTables(cachingTable);
     cachingTable.get("abc");
-    verify(table, times(1)).get(any());
+    verify(table, times(1)).getAsync(any());
 
     // get() should go to table again
     cachingTable.get("abc");
-    verify(table, times(2)).get(any());
+    verify(table, times(2)).getAsync(any());
   }
 
   /**
-   * Test the atomic operations in CachingTable by simulating 10 threads each executing
-   * 5000 random operations (GET/PUT/DELETE) with random keys, which are picked from a
-   * narrow range (0-9) for higher concurrency. Consistency is verified by comparing
-   * the cache content and table content both of which should match exactly. Eviction
-   * is not simulated because it would be impossible to compare the cache/table.
-   * @throws InterruptedException
+   * Testing caching in a more realistic scenario with Guava cache + remote table
    */
   @Test
-  public void testConcurrentAccess() throws InterruptedException {
-    final int numThreads = 10;
-    final int iterations = 5000;
-    ExecutorService executor = Executors.newFixedThreadPool(numThreads);
-
-    // Ensure all threads to reach rendezvous before starting the simulation
-    final CountDownLatch startLatch = new CountDownLatch(numThreads);
-
-    Pair<ReadWriteTable<String, String>, Map<String, String>> tableMapPair = getMockCache();
-    final ReadableTable<String, String> table = tableMapPair.getLeft();
-
-    Pair<ReadWriteTable<String, String>, Map<String, String>> cacheMapPair = getMockCache();
-    final CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cacheMapPair.getLeft(), 16, false);
-
-    Map<String, String> cacheMap = cacheMapPair.getRight();
-    Map<String, String> tableMap = tableMapPair.getRight();
-
-    final Random rand = new Random(System.currentTimeMillis());
-
-    for (int i = 0; i < numThreads; i++) {
-      executor.submit(() -> {
-          try {
-            startLatch.countDown();
-            startLatch.await();
-          } catch (InterruptedException e) {
-            Assert.fail();
-          }
-
-          String lastPutKey = null;
-          for (int j = 0; j < iterations; j++) {
-            int cmd = rand.nextInt(3);
-            String key = String.valueOf(rand.nextInt(10));
-            switch (cmd) {
-              case 0:
-                cachingTable.get(key);
-                break;
-              case 1:
-                cachingTable.put(key, "test-data-" + rand.nextInt());
-                lastPutKey = key;
-                break;
-              case 2:
-                if (lastPutKey != null) {
-                  cachingTable.delete(lastPutKey);
-                }
-                break;
-            }
-          }
-        });
+  public void testGuavaCacheAndRemoteTable() throws Exception {
+    String tableId = "testGuavaCacheAndRemoteTable";
+    Cache<String, String> guavaCache = CacheBuilder.newBuilder().initialCapacity(100).build();
+    final ReadWriteTable<String, String> guavaTable = new GuavaCacheTable<>(tableId, guavaCache);
+
+    // It is okay to share rateLimitHelper and async helper for read/write in test
+    TableRateLimiter<String, String> rateLimitHelper = mock(TableRateLimiter.class);
+    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    final RemoteReadWriteTable<String, String> remoteTable = new RemoteReadWriteTable<>(
+        tableId, readFn, writeFn, rateLimitHelper, rateLimitHelper,
+        Executors.newSingleThreadExecutor(), Executors.newSingleThreadExecutor());
+
+    final CachingTable<String, String> cachingTable = new CachingTable<>(
+        tableId, remoteTable, guavaTable, false);
+
+    initTables(cachingTable, guavaTable, remoteTable);
+
+    // GET
+    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
+    Assert.assertEquals(cachingTable.getAsync("foo").get(), "bar");
+    // Ensure cache is updated
+    Assert.assertEquals(guavaCache.getIfPresent("foo"), "bar");
+
+    // PUT
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
+    cachingTable.putAsync("foo", "baz").get();
+    // Ensure cache is updated
+    Assert.assertEquals(guavaCache.getIfPresent("foo"), "baz");
+
+    // DELETE
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
+    cachingTable.deleteAsync("foo").get();
+    // Ensure cache is updated
+    Assert.assertNull(guavaCache.getIfPresent("foo"));
+
+    // GET-ALL
+    Map<String, String> records = new HashMap<>();
+    records.put("foo1", "bar1");
+    records.put("foo2", "bar2");
+    doReturn(CompletableFuture.completedFuture(records)).when(readFn).getAllAsync(any());
+    Assert.assertEquals(cachingTable.getAllAsync(Arrays.asList("foo1", "foo2")).get(), records);
+    // Ensure cache is updated
+    Assert.assertEquals(guavaCache.getIfPresent("foo1"), "bar1");
+    Assert.assertEquals(guavaCache.getIfPresent("foo2"), "bar2");
+
+    // GET-ALL with partial miss
+    doReturn(CompletableFuture.completedFuture(Collections.singletonMap("foo3", "bar3"))).when(readFn).getAllAsync(any());
+    records = cachingTable.getAllAsync(Arrays.asList("foo1", "foo2", "foo3")).get();
+    Assert.assertEquals(records.get("foo3"), "bar3");
+    // Ensure cache is updated
+    Assert.assertEquals(guavaCache.getIfPresent("foo3"), "bar3");
+
+    // Calling again for the same keys should not trigger IO, ie. no exception is thrown
+    CompletableFuture<String> exFuture = new CompletableFuture<>();
+    exFuture.completeExceptionally(new RuntimeException("Test exception"));
+    doReturn(exFuture).when(readFn).getAllAsync(any());
+    cachingTable.getAllAsync(Arrays.asList("foo1", "foo2", "foo3")).get();
+
+    // Partial results should throw
+    try {
+      cachingTable.getAllAsync(Arrays.asList("foo1", "foo2", "foo5")).get();
+      Assert.fail();
+    } catch (Exception e) {
     }
 
-    executor.shutdown();
-
-    // Wait up to 1 minute for all threads to finish
-    Assert.assertTrue(executor.awaitTermination(60, TimeUnit.MINUTES));
-
-    // Verify cache and table contents fully match
-    Assert.assertEquals(cacheMap.size(), tableMap.size());
-    cacheMap.keySet().forEach(k -> Assert.assertEquals(cacheMap.get(k), tableMap.get(k)));
+    // PUT-ALL
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
+    List<Entry<String, String>> entries = new ArrayList<>();
+    entries.add(new Entry<>("foo1", "bar111"));
+    entries.add(new Entry<>("foo2", "bar222"));
+    cachingTable.putAllAsync(entries).get();
+    // Ensure cache is updated
+    Assert.assertEquals(guavaCache.getIfPresent("foo1"), "bar111");
+    Assert.assertEquals(guavaCache.getIfPresent("foo2"), "bar222");
+
+    // PUT-ALL with delete
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any());
+    entries = new ArrayList<>();
+    entries.add(new Entry<>("foo1", "bar111"));
+    entries.add(new Entry<>("foo2", null));
+    cachingTable.putAllAsync(entries).get();
+    // Ensure cache is updated
+    Assert.assertNull(guavaCache.getIfPresent("foo2"));
+
+    // At this point, foo1 and foo3 should still exist
+    Assert.assertNotNull(guavaCache.getIfPresent("foo1"));
+    Assert.assertNotNull(guavaCache.getIfPresent("foo3"));
+
+    // DELETE-ALL
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any());
+    cachingTable.deleteAllAsync(Arrays.asList("foo1", "foo3")).get();
+    // Ensure foo1 and foo3 are gone
+    Assert.assertNull(guavaCache.getIfPresent("foo1"));
+    Assert.assertNull(guavaCache.getIfPresent("foo3"));
   }
 }
\ No newline at end of file
diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java
new file mode 100644 (file)
index 0000000..21fc6a5
--- /dev/null
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.remote;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.metrics.Counter;
+import org.apache.samza.metrics.Gauge;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.metrics.Timer;
+import org.apache.samza.storage.kv.Entry;
+import org.apache.samza.task.TaskContext;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+import junit.framework.Assert;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyCollection;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+
+public class TestRemoteTable {
+  private <K, V, T extends RemoteReadableTable<K, V>> T getTable(String tableId,
+      TableReadFunction<K, V> readFn, TableWriteFunction<K, V> writeFn) {
+    return getTable(tableId, readFn, writeFn, null);
+  }
+
+  private <K, V, T extends RemoteReadableTable<K, V>> T getTable(String tableId,
+      TableReadFunction<K, V> readFn, TableWriteFunction<K, V> writeFn, ExecutorService cbExecutor) {
+    RemoteReadableTable<K, V> table;
+
+    TableRateLimiter<K, V> readRateLimiter = mock(TableRateLimiter.class);
+    TableRateLimiter<K, V> writeRateLimiter = mock(TableRateLimiter.class);
+    doReturn(true).when(readRateLimiter).isRateLimited();
+    doReturn(true).when(writeRateLimiter).isRateLimited();
+
+    ExecutorService tableExecutor = Executors.newSingleThreadExecutor();
+
+    if (writeFn == null) {
+      table = new RemoteReadableTable<K, V>(tableId, readFn, readRateLimiter, tableExecutor, cbExecutor);
+    } else {
+      table = new RemoteReadWriteTable<K, V>(tableId, readFn, writeFn, readRateLimiter, writeRateLimiter, tableExecutor, cbExecutor);
+    }
+
+    TaskContext taskContext = mock(TaskContext.class);
+    MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
+    doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString());
+    doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString());
+    doReturn(mock(Gauge.class)).when(metricsRegistry).newGauge(anyString(), any());
+    doReturn(metricsRegistry).when(taskContext).getMetricsRegistry();
+
+    SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
+
+    table.init(containerContext, taskContext);
+
+    return (T) table;
+  }
+
+  private void doTestGet(boolean sync, boolean error) throws Exception {
+    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+    // Sync is backed by async so needs to mock the async method
+    CompletableFuture<String> future;
+    if (error) {
+      future = new CompletableFuture();
+      future.completeExceptionally(new RuntimeException("Test exception"));
+    } else {
+      future = CompletableFuture.completedFuture("bar");
+    }
+    doReturn(future).when(readFn).getAsync(anyString());
+    RemoteReadableTable<String, String> table = getTable("testGet-" + sync + error, readFn, null);
+    Assert.assertEquals("bar", sync ? table.get("foo") : table.getAsync("foo").get());
+    verify(table.readRateLimiter, times(1)).throttle(anyString());
+  }
+
+  @Test
+  public void testGet() throws Exception {
+    doTestGet(true, false);
+  }
+
+  @Test
+  public void testGetAsync() throws Exception {
+    doTestGet(false, false);
+  }
+
+  @Test(expected = ExecutionException.class)
+  public void testGetAsyncError() throws Exception {
+    doTestGet(false, true);
+  }
+
+  @Test
+  public void testGetMultipleTables() {
+    TableReadFunction<String, String> readFn1 = mock(TableReadFunction.class);
+    TableReadFunction<String, String> readFn2 = mock(TableReadFunction.class);
+
+    // Sync is backed by async so needs to mock the async method
+    doReturn(CompletableFuture.completedFuture("bar1")).when(readFn1).getAsync(anyString());
+    doReturn(CompletableFuture.completedFuture("bar2")).when(readFn1).getAsync(anyString());
+
+    RemoteReadableTable<String, String> table1 = getTable("testGetMultipleTables-1", readFn1, null);
+    RemoteReadableTable<String, String> table2 = getTable("testGetMultipleTables-2", readFn2, null);
+
+    CompletableFuture<String> future1 = table1.getAsync("foo1");
+    CompletableFuture<String> future2 = table2.getAsync("foo2");
+
+    CompletableFuture.allOf(future1, future2)
+        .thenAccept(u -> {
+            Assert.assertEquals(future1.join(), "bar1");
+            Assert.assertEquals(future2.join(), "bar1");
+          });
+  }
+
+  private void doTestPut(boolean sync, boolean error, boolean isDelete) throws Exception {
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    RemoteReadWriteTable<String, String> table = getTable("testPut-" + sync + error + isDelete,
+        mock(TableReadFunction.class), writeFn);
+    CompletableFuture<Void> future;
+    if (error) {
+      future = new CompletableFuture();
+      future.completeExceptionally(new RuntimeException("Test exception"));
+    } else {
+      future = CompletableFuture.completedFuture(null);
+    }
+    // Sync is backed by async so needs to mock the async method
+    if (isDelete) {
+      doReturn(future).when(writeFn).deleteAsync(any());
+    } else {
+      doReturn(future).when(writeFn).putAsync(any(), any());
+    }
+    if (sync) {
+      table.put("foo", isDelete ? null : "bar");
+    } else {
+      table.putAsync("foo", isDelete ? null : "bar").get();
+    }
+    ArgumentCaptor<String> keyCaptor = ArgumentCaptor.forClass(String.class);
+    ArgumentCaptor<String> valCaptor = ArgumentCaptor.forClass(String.class);
+    if (isDelete) {
+      verify(writeFn, times(1)).deleteAsync(keyCaptor.capture());
+    } else {
+      verify(writeFn, times(1)).putAsync(keyCaptor.capture(), valCaptor.capture());
+      Assert.assertEquals("bar", valCaptor.getValue());
+    }
+    Assert.assertEquals("foo", keyCaptor.getValue());
+    if (isDelete) {
+      verify(table.writeRateLimiter, times(1)).throttle(anyString());
+    } else {
+      verify(table.writeRateLimiter, times(1)).throttle(anyString(), anyString());
+    }
+  }
+
+  @Test
+  public void testPut() throws Exception {
+    doTestPut(true, false, false);
+  }
+
+  @Test
+  public void testPutDelete() throws Exception {
+    doTestPut(true, false, true);
+  }
+
+  @Test
+  public void testPutAsync() throws Exception {
+    doTestPut(false, false, false);
+  }
+
+  @Test
+  public void testPutAsyncDelete() throws Exception {
+    doTestPut(false, false, true);
+  }
+
+  @Test(expected = ExecutionException.class)
+  public void testPutAsyncError() throws Exception {
+    doTestPut(false, true, false);
+  }
+
+  private void doTestDelete(boolean sync, boolean error) throws Exception {
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    RemoteReadWriteTable<String, String> table = getTable("testDelete-" + sync + error,
+        mock(TableReadFunction.class), writeFn);
+    CompletableFuture<Void> future;
+    if (error) {
+      future = new CompletableFuture();
+      future.completeExceptionally(new RuntimeException("Test exception"));
+    } else {
+      future = CompletableFuture.completedFuture(null);
+    }
+    // Sync is backed by async so needs to mock the async method
+    doReturn(future).when(writeFn).deleteAsync(any());
+    ArgumentCaptor<String> argCaptor = ArgumentCaptor.forClass(String.class);
+    if (sync) {
+      table.delete("foo");
+    } else {
+      table.deleteAsync("foo").get();
+    }
+    verify(writeFn, times(1)).deleteAsync(argCaptor.capture());
+    Assert.assertEquals("foo", argCaptor.getValue());
+    verify(table.writeRateLimiter, times(1)).throttle(anyString());
+  }
+
+  @Test
+  public void testDelete() throws Exception {
+    doTestDelete(true, false);
+  }
+
+  @Test
+  public void testDeleteAsync() throws Exception {
+    doTestDelete(false, false);
+  }
+
+  @Test(expected = ExecutionException.class)
+  public void testDeleteAsyncError() throws Exception {
+    doTestDelete(false, true);
+  }
+
+  private void doTestGetAll(boolean sync, boolean error, boolean partial) throws Exception {
+    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+    Map<String, String> res = new HashMap<>();
+    res.put("foo1", "bar1");
+    if (!partial) {
+      res.put("foo2", "bar2");
+    }
+    CompletableFuture<Map<String, String>> future;
+    if (error) {
+      future = new CompletableFuture();
+      future.completeExceptionally(new RuntimeException("Test exception"));
+    } else {
+      future = CompletableFuture.completedFuture(res);
+    }
+    // Sync is backed by async so needs to mock the async method
+    doReturn(future).when(readFn).getAllAsync(any());
+    RemoteReadableTable<String, String> table = getTable("testGetAll-" + sync + error + partial, readFn, null);
+    Assert.assertEquals(res, sync ? table.getAll(Arrays.asList("foo1", "foo2"))
+        : table.getAllAsync(Arrays.asList("foo1", "foo2")).get());
+    verify(table.readRateLimiter, times(1)).throttle(anyCollection());
+  }
+
+  @Test
+  public void testGetAll() throws Exception {
+    doTestGetAll(true, false, false);
+  }
+
+  @Test
+  public void testGetAllAsync() throws Exception {
+    doTestGetAll(false, false, false);
+  }
+
+  @Test(expected = ExecutionException.class)
+  public void testGetAllAsyncError() throws Exception {
+    doTestGetAll(false, true, false);
+  }
+
+  // Partial result is an acceptable scenario
+  @Test
+  public void testGetAllPartialResult() throws Exception {
+    doTestGetAll(false, false, true);
+  }
+
+  public void doTestPutAll(boolean sync, boolean error, boolean hasDelete) throws Exception {
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    RemoteReadWriteTable<String, String> table = getTable("testPutAll-" + sync + error + hasDelete,
+        mock(TableReadFunction.class), writeFn);
+    CompletableFuture<Void> future;
+    if (error) {
+      future = new CompletableFuture();
+      future.completeExceptionally(new RuntimeException("Test exception"));
+    } else {
+      future = CompletableFuture.completedFuture(null);
+    }
+    // Sync is backed by async so needs to mock the async method
+    doReturn(future).when(writeFn).putAllAsync(any());
+    if (hasDelete) {
+      doReturn(future).when(writeFn).deleteAllAsync(any());
+    }
+    List<Entry<String, String>> entries = Arrays.asList(
+        new Entry<>("foo1", "bar1"), new Entry<>("foo2", hasDelete ? null : "bar2"));
+    ArgumentCaptor<List> argCaptor = ArgumentCaptor.forClass(List.class);
+    if (sync) {
+      table.putAll(entries);
+    } else {
+      table.putAllAsync(entries).get();
+    }
+    verify(writeFn, times(1)).putAllAsync(argCaptor.capture());
+    if (hasDelete) {
+      ArgumentCaptor<List> delArgCaptor = ArgumentCaptor.forClass(List.class);
+      verify(writeFn, times(1)).deleteAllAsync(delArgCaptor.capture());
+      Assert.assertEquals(Arrays.asList("foo2"), delArgCaptor.getValue());
+      Assert.assertEquals(1, argCaptor.getValue().size());
+      Assert.assertEquals("foo1", ((Entry) argCaptor.getValue().get(0)).getKey());
+      verify(table.writeRateLimiter, times(1)).throttle(anyCollection());
+    } else {
+      Assert.assertEquals(entries, argCaptor.getValue());
+    }
+    verify(table.writeRateLimiter, times(1)).throttleRecords(anyCollection());
+  }
+
+  @Test
+  public void testPutAll() throws Exception {
+    doTestPutAll(true, false, false);
+  }
+
+  @Test
+  public void testPutAllHasDelete() throws Exception {
+    doTestPutAll(true, false, true);
+  }
+
+  @Test
+  public void testPutAllAsync() throws Exception {
+    doTestPutAll(false, false, false);
+  }
+
+  @Test
+  public void testPutAllAsyncHasDelete() throws Exception {
+    doTestPutAll(false, false, true);
+  }
+
+  @Test(expected = ExecutionException.class)
+  public void testPutAllAsyncError() throws Exception {
+    doTestPutAll(false, true, false);
+  }
+
+  public void doTestDeleteAll(boolean sync, boolean error) throws Exception {
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    RemoteReadWriteTable<String, String> table = getTable("testDeleteAll-" + sync + error,
+        mock(TableReadFunction.class), writeFn);
+    CompletableFuture<Void> future;
+    if (error) {
+      future = new CompletableFuture();
+      future.completeExceptionally(new RuntimeException("Test exception"));
+    } else {
+      future = CompletableFuture.completedFuture(null);
+    }
+    // Sync is backed by async so needs to mock the async method
+    doReturn(future).when(writeFn).deleteAllAsync(any());
+    List<String> keys = Arrays.asList("foo1", "foo2");
+    ArgumentCaptor<List> argCaptor = ArgumentCaptor.forClass(List.class);
+    if (sync) {
+      table.deleteAll(keys);
+    } else {
+      table.deleteAllAsync(keys).get();
+    }
+    verify(writeFn, times(1)).deleteAllAsync(argCaptor.capture());
+    Assert.assertEquals(keys, argCaptor.getValue());
+    verify(table.writeRateLimiter, times(1)).throttle(anyCollection());
+  }
+
+  @Test
+  public void testDeleteAll() throws Exception {
+    doTestDeleteAll(true, false);
+  }
+
+  @Test
+  public void testDeleteAllAsync() throws Exception {
+    doTestDeleteAll(false, false);
+  }
+
+  @Test(expected = ExecutionException.class)
+  public void testDeleteAllAsyncError() throws Exception {
+    doTestDeleteAll(false, true);
+  }
+
+  @Test
+  public void testFlush() {
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    RemoteReadWriteTable<String, String> table = getTable("testFlush", mock(TableReadFunction.class), writeFn);
+    table.flush();
+    verify(writeFn, times(1)).flush();
+  }
+
+  @Test
+  public void testGetWithCallbackExecutor() throws Exception {
+    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+    // Sync is backed by async so needs to mock the async method
+    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(anyString());
+    RemoteReadableTable<String, String> table = getTable("testGetWithCallbackExecutor", readFn, null,
+        Executors.newSingleThreadExecutor());
+    Thread testThread = Thread.currentThread();
+
+    table.getAsync("foo").thenAccept(result -> {
+        Assert.assertEquals("bar", result);
+        // Must be executed on the executor thread
+        Assert.assertNotSame(testThread, Thread.currentThread());
+      });
+  }
+}
index acf3d61..e30da12 100644 (file)
@@ -22,13 +22,13 @@ package org.apache.samza.table.remote;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.ThreadPoolExecutor;
 
 import org.apache.samza.container.SamzaContainerContext;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.Timer;
-import org.apache.samza.operators.KV;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.TableSpec;
 import org.apache.samza.task.TaskContext;
@@ -39,19 +39,16 @@ import org.junit.Test;
 
 import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_READ_TAG;
 import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_WRITE_TAG;
-import static org.mockito.Matchers.anyMap;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
 
 
 public class TestRemoteTableDescriptor {
   private void doTestSerialize(RateLimiter rateLimiter,
-      CreditFunction readCredFn,
-      CreditFunction writeCredFn) {
+      TableRateLimiter.CreditFunction readCredFn,
+      TableRateLimiter.CreditFunction writeCredFn) {
     RemoteTableDescriptor desc = new RemoteTableDescriptor("1");
     desc.withReadFunction(mock(TableReadFunction.class));
     desc.withWriteFunction(mock(TableWriteFunction.class));
@@ -79,17 +76,17 @@ public class TestRemoteTableDescriptor {
 
   @Test
   public void testSerializeWithLimiterAndReadCredFn() {
-    doTestSerialize(mock(RateLimiter.class), kv -> 1, null);
+    doTestSerialize(mock(RateLimiter.class), (k, v) -> 1, null);
   }
 
   @Test
   public void testSerializeWithLimiterAndWriteCredFn() {
-    doTestSerialize(mock(RateLimiter.class), null, kv -> 1);
+    doTestSerialize(mock(RateLimiter.class), null, (k, v) -> 1);
   }
 
   @Test
   public void testSerializeWithLimiterAndReadWriteCredFns() {
-    doTestSerialize(mock(RateLimiter.class), kv -> 1, kv -> 1);
+    doTestSerialize(mock(RateLimiter.class), (key, value) -> 1, (key, value) -> 1);
   }
 
   @Test
@@ -129,10 +126,10 @@ public class TestRemoteTableDescriptor {
     return taskContext;
   }
 
-  static class CountingCreditFunction<K, V> implements CreditFunction<K, V> {
+  static class CountingCreditFunction<K, V> implements TableRateLimiter.CreditFunction<K, V> {
     int numCalls = 0;
     @Override
-    public Integer apply(KV<K, V> kv) {
+    public int getCredits(K key, V value) {
       numCalls++;
       return 1;
     }
@@ -143,6 +140,8 @@ public class TestRemoteTableDescriptor {
     RemoteTableDescriptor<String, String> desc = new RemoteTableDescriptor("1");
     desc.withReadFunction(mock(TableReadFunction.class));
     desc.withWriteFunction(mock(TableWriteFunction.class));
+    desc.withAsyncCallbackExecutorPoolSize(10);
+
     if (rateOnly) {
       if (rlGets) {
         desc.withReadRateLimit(1000);
@@ -172,39 +171,13 @@ public class TestRemoteTableDescriptor {
     Table table = provider.getTable();
     Assert.assertTrue(table instanceof RemoteReadWriteTable);
     RemoteReadWriteTable rwTable = (RemoteReadWriteTable) table;
-    Assert.assertNotNull(rwTable.readFn);
-    Assert.assertNotNull(rwTable.writeFn);
     if (numRateLimitOps > 0) {
-      Assert.assertNotNull(rwTable.rateLimiter);
+      Assert.assertTrue(!rlGets || rwTable.readRateLimiter != null);
+      Assert.assertTrue(!rlPuts || rwTable.writeRateLimiter != null);
     }
 
-    // Verify rate limiter usage
-    if (numRateLimitOps > 0) {
-      rwTable.get("xxx");
-      rwTable.put("yyy", "zzz");
-
-      if (!rateOnly) {
-        verify(rwTable.rateLimiter, times(numRateLimitOps)).acquire(anyMap());
-
-        CountingCreditFunction<?, ?> readCreditFn = (CountingCreditFunction<?, ?>) rwTable.readCreditFn;
-        CountingCreditFunction<?, ?> writeCreditFn = (CountingCreditFunction<?, ?>) rwTable.writeCreditFn;
-
-        Assert.assertNotNull(readCreditFn);
-        Assert.assertNotNull(writeCreditFn);
-
-        Assert.assertEquals(readCreditFn.numCalls, rlGets ? 1 : 0);
-        Assert.assertEquals(writeCreditFn.numCalls, rlPuts ? 1 : 0);
-      } else {
-        Assert.assertTrue(rwTable.rateLimiter instanceof EmbeddedTaggedRateLimiter);
-        Assert.assertEquals(rwTable.rateLimiter.getSupportedTags().size(), numRateLimitOps);
-        if (rlGets) {
-          Assert.assertTrue(rwTable.rateLimiter.getSupportedTags().contains(RL_READ_TAG));
-        }
-        if (rlPuts) {
-          Assert.assertTrue(rwTable.rateLimiter.getSupportedTags().contains(RL_WRITE_TAG));
-        }
-      }
-    }
+    ThreadPoolExecutor callbackExecutor = (ThreadPoolExecutor) rwTable.callbackExecutor;
+    Assert.assertEquals(10, callbackExecutor.getCorePoolSize());
   }
 
   @Test
diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java
new file mode 100644 (file)
index 0000000..ea9acbd
--- /dev/null
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.remote;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+import org.apache.samza.metrics.Timer;
+import org.apache.samza.storage.kv.Entry;
+import org.apache.samza.util.RateLimiter;
+import org.junit.Test;
+
+import junit.framework.Assert;
+
+import static org.mockito.Matchers.anyLong;
+import static org.mockito.Matchers.anyMap;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+
+public class TestTableRateLimiter {
+  private static final String DEFAULT_TAG = "mytag";
+
+  public TableRateLimiter<String, String> getThrottler() {
+    return getThrottler(DEFAULT_TAG);
+  }
+
+  public TableRateLimiter<String, String> getThrottler(String tag) {
+    TableRateLimiter.CreditFunction<String, String> credFn =
+        (TableRateLimiter.CreditFunction<String, String>) (key, value) -> {
+      int credits = key == null ? 0 : 3;
+      credits += value == null ? 0 : 3;
+      return credits;
+    };
+    RateLimiter rateLimiter = mock(RateLimiter.class);
+    doReturn(Collections.singleton(DEFAULT_TAG)).when(rateLimiter).getSupportedTags();
+    TableRateLimiter<String, String> rateLimitHelper = new TableRateLimiter<>("foo", rateLimiter, credFn, tag);
+    Timer timer = mock(Timer.class);
+    rateLimitHelper.setTimerMetric(timer);
+    return rateLimitHelper;
+  }
+
+  @Test
+  public void testCreditKeyOnly() {
+    TableRateLimiter<String, String> rateLimitHelper = getThrottler();
+    Assert.assertEquals(3, rateLimitHelper.getCredits("abc", null));
+  }
+
+  @Test
+  public void testCreditKeyValue() {
+    TableRateLimiter<String, String> rateLimitHelper = getThrottler();
+    Assert.assertEquals(6, rateLimitHelper.getCredits("abc", "efg"));
+  }
+
+  @Test
+  public void testCreditKeys() {
+    TableRateLimiter<String, String> rateLimitHelper = getThrottler();
+    Assert.assertEquals(9, rateLimitHelper.getCredits(Arrays.asList("abc", "efg", "hij")));
+  }
+
+  @Test
+  public void testCreditEntries() {
+    TableRateLimiter<String, String> rateLimitHelper = getThrottler();
+    Assert.assertEquals(12, rateLimitHelper.getEntryCredits(
+        Arrays.asList(new Entry<>("abc", "efg"), new Entry<>("hij", "lmn"))));
+  }
+
+  @Test
+  public void testThrottle() {
+    TableRateLimiter<String, String> rateLimitHelper = getThrottler();
+    Timer timer = mock(Timer.class);
+    rateLimitHelper.setTimerMetric(timer);
+    rateLimitHelper.throttle("foo");
+    verify(rateLimitHelper.rateLimiter, times(1)).acquire(anyMap());
+    verify(timer, times(1)).update(anyLong());
+  }
+
+  @Test
+  public void testThrottleUnknownTag() {
+    TableRateLimiter<String, String> rateLimitHelper = getThrottler("unknown_tag");
+    rateLimitHelper.throttle("foo");
+    verify(rateLimitHelper.rateLimiter, times(0)).acquire(anyMap());
+  }
+}
index 882ae0d..98c3e3c 100644 (file)
@@ -19,6 +19,7 @@
 package org.apache.samza.storage.kv;
 
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
 
 import org.apache.samza.container.SamzaContainerContext;
 import org.apache.samza.table.ReadWriteTable;
@@ -67,6 +68,18 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab
   }
 
   @Override
+  public CompletableFuture<Void> putAsync(K key, V value) {
+    CompletableFuture<Void> future = new CompletableFuture();
+    try {
+      put(key, value);
+      future.complete(null);
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
+  }
+
+  @Override
   public void putAll(List<Entry<K, V>> entries) {
     writeMetrics.numPutAlls.inc();
     long startNs = System.nanoTime();
@@ -75,6 +88,18 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab
   }
 
   @Override
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
+    CompletableFuture<Void> future = new CompletableFuture();
+    try {
+      putAll(entries);
+      future.complete(null);
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
+  }
+
+  @Override
   public void delete(K key) {
     writeMetrics.numDeletes.inc();
     long startNs = System.nanoTime();
@@ -83,6 +108,18 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab
   }
 
   @Override
+  public CompletableFuture<Void> deleteAsync(K key) {
+    CompletableFuture<Void> future = new CompletableFuture();
+    try {
+      delete(key);
+      future.complete(null);
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
+  }
+
+  @Override
   public void deleteAll(List<K> keys) {
     writeMetrics.numDeleteAlls.inc();
     long startNs = System.nanoTime();
@@ -91,6 +128,18 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab
   }
 
   @Override
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+    CompletableFuture<Void> future = new CompletableFuture();
+    try {
+      deleteAll(keys);
+      future.complete(null);
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
+  }
+
+  @Override
   public void flush() {
     writeMetrics.numFlushes.inc();
     long startNs = System.nanoTime();
index 8d79e0d..1c59eb6 100644 (file)
@@ -20,6 +20,7 @@ package org.apache.samza.storage.kv;
 
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
 
 import com.google.common.base.Preconditions;
 import org.apache.samza.container.SamzaContainerContext;
@@ -70,6 +71,17 @@ public class LocalStoreBackedReadableTable<K, V> implements ReadableTable<K, V>
   }
 
   @Override
+  public CompletableFuture<V> getAsync(K key) {
+    CompletableFuture<V> future = new CompletableFuture();
+    try {
+      future.complete(get(key));
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
+  }
+
+  @Override
   public Map<K, V> getAll(List<K> keys) {
     readMetrics.numGetAlls.inc();
     long startNs = System.nanoTime();
@@ -79,6 +91,17 @@ public class LocalStoreBackedReadableTable<K, V> implements ReadableTable<K, V>
   }
 
   @Override
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+    CompletableFuture<Map<K, V>> future = new CompletableFuture();
+    try {
+      future.complete(getAll(keys));
+    } catch (Exception e) {
+      future.completeExceptionally(e);
+    }
+    return future;
+  }
+
+  @Override
   public void close() {
     // The KV store is not closed here as it may still be needed by downstream operators,
     // it will be closed by the SamzaContainer
index 8a20239..7068e9b 100644 (file)
@@ -23,6 +23,8 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.samza.config.Config;
 import org.apache.samza.container.SamzaContainerContext;
@@ -78,11 +80,21 @@ public class TestIOResolverFactory implements SqlIOResolverFactory {
     }
 
     @Override
+    public CompletableFuture getAsync(Object key) {
+      throw new NotImplementedException();
+    }
+
+    @Override
     public Map getAll(List keys) {
       throw new NotImplementedException();
     }
 
     @Override
+    public CompletableFuture<Map> getAllAsync(List keys) {
+      throw new NotImplementedException();
+    }
+
+    @Override
     public void close() {
     }
 
@@ -98,16 +110,36 @@ public class TestIOResolverFactory implements SqlIOResolverFactory {
     }
 
     @Override
+    public CompletableFuture<Void> putAsync(Object key, Object value) {
+      throw new NotImplementedException();
+    }
+
+    @Override
+    public CompletableFuture<Void> putAllAsync(List list) {
+      throw new NotImplementedException();
+    }
+
+    @Override
     public void delete(Object key) {
       records.remove(key);
     }
 
     @Override
+    public CompletableFuture<Void> deleteAsync(Object key) {
+      throw new NotImplementedException();
+    }
+
+    @Override
     public void deleteAll(List keys) {
       records.clear();
     }
 
     @Override
+    public CompletableFuture<Void> deleteAllAsync(List keys) {
+      throw new NotImplementedException();
+    }
+
+    @Override
     public void flush() {
     }
 
index d7f0570..14ef751 100644 (file)
@@ -20,6 +20,7 @@
 package org.apache.samza.test.table;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
@@ -32,7 +33,12 @@ import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.JobCoordinatorConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.TaskConfig;
+import org.apache.samza.container.SamzaContainerContext;
 import org.apache.samza.container.grouper.task.SingleContainerGrouperFactory;
+import org.apache.samza.metrics.Counter;
+import org.apache.samza.metrics.Gauge;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.metrics.Timer;
 import org.apache.samza.operators.KV;
 import org.apache.samza.operators.MessageStream;
 import org.apache.samza.operators.functions.MapFunction;
@@ -43,6 +49,9 @@ import org.apache.samza.serializers.KVSerde;
 import org.apache.samza.serializers.NoOpSerde;
 import org.apache.samza.standalone.PassthroughCoordinationUtilsFactory;
 import org.apache.samza.standalone.PassthroughJobCoordinatorFactory;
+import org.apache.samza.storage.kv.Entry;
+import org.apache.samza.storage.kv.KeyValueStore;
+import org.apache.samza.storage.kv.LocalStoreBackedReadWriteTable;
 import org.apache.samza.storage.kv.inmemory.InMemoryTableDescriptor;
 import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.Table;
@@ -61,6 +70,14 @@ import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyList;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
 
 /**
  * This test class tests sendTo() and join() for local tables
@@ -360,4 +377,48 @@ public class TestLocalTable extends AbstractIntegrationTestHarness {
       return record.getKey();
     }
   }
+
+  @Test
+  public void testAsyncOperation() throws Exception {
+    KeyValueStore kvStore = mock(KeyValueStore.class);
+    LocalStoreBackedReadWriteTable<String, String> table = new LocalStoreBackedReadWriteTable<>("table1", kvStore);
+    TaskContext taskContext = mock(TaskContext.class);
+    MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
+    doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString());
+    doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString());
+    doReturn(mock(Gauge.class)).when(metricsRegistry).newGauge(anyString(), any());
+    doReturn(metricsRegistry).when(taskContext).getMetricsRegistry();
+
+    SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
+
+    table.init(containerContext, taskContext);
+
+    // GET
+    doReturn("bar").when(kvStore).get(anyString());
+    Assert.assertEquals("bar", table.getAsync("foo").get());
+
+    // GET-ALL
+    Map<String, String> recordMap = new HashMap<>();
+    recordMap.put("foo1", "bar1");
+    recordMap.put("foo2", "bar2");
+    doReturn(recordMap).when(kvStore).getAll(anyList());
+    Assert.assertEquals(recordMap, table.getAllAsync(Arrays.asList("foo1", "foo2")).get());
+
+    // PUT
+    table.putAsync("foo1", "bar1").get();
+    verify(kvStore, times(1)).put(anyString(), anyString());
+
+    // PUT-ALL
+    List<Entry<String, String>> records = Arrays.asList(new Entry<>("foo1", "bar1"), new Entry<>("foo2", "bar2"));
+    table.putAllAsync(records).get();
+    verify(kvStore, times(1)).putAll(anyList());
+
+    // DELETE
+    table.deleteAsync("foo").get();
+    verify(kvStore, times(1)).delete(anyString());
+
+    // DELETE-ALL
+    table.deleteAllAsync(Arrays.asList("foo1", "foo2")).get();
+    verify(kvStore, times(1)).deleteAll(anyList());
+  }
 }
index 8d07570..2d07b01 100644 (file)
@@ -24,10 +24,11 @@ import java.io.ObjectInputStream;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -46,6 +47,7 @@ import org.apache.samza.serializers.NoOpSerde;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.caching.CachingTableDescriptor;
 import org.apache.samza.table.caching.guava.GuavaCacheTableDescriptor;
+import org.apache.samza.table.remote.TableRateLimiter;
 import org.apache.samza.table.remote.TableReadFunction;
 import org.apache.samza.table.remote.TableWriteFunction;
 import org.apache.samza.table.remote.RemoteReadableTable;
@@ -63,7 +65,6 @@ import com.google.common.cache.CacheBuilder;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 
 
@@ -86,8 +87,8 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
     }
 
     @Override
-    public TestTableData.Profile get(Integer key) {
-      return profileMap.getOrDefault(key, null);
+    public CompletableFuture<TestTableData.Profile> getAsync(Integer key) {
+      return CompletableFuture.completedFuture(profileMap.get(key));
     }
 
     static InMemoryReadFunction getInMemoryReadFunction(String serializedProfiles) {
@@ -112,18 +113,15 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
     }
 
     @Override
-    public void put(Integer key, TestTableData.EnrichedPageView record) {
+    public CompletableFuture<Void> putAsync(Integer key, TestTableData.EnrichedPageView record) {
       records.add(record);
+      return CompletableFuture.completedFuture(null);
     }
 
     @Override
-    public void delete(Integer key) {
+    public CompletableFuture<Void> deleteAsync(Integer key) {
       records.remove(key);
-    }
-
-    @Override
-    public void deleteAll(Collection<Integer> keys) {
-      records.removeAll(keys);
+      return CompletableFuture.completedFuture(null);
     }
   }
 
@@ -187,9 +185,7 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
       }
 
       streamGraph.getInputStream("PageView", new NoOpSerde<TestTableData.PageView>())
-          .map(pv -> {
-              return new KV<Integer, TestTableData.PageView>(pv.getMemberId(), pv);
-            })
+          .map(pv -> new KV<>(pv.getMemberId(), pv))
           .join(inputTable, new TestLocalTable.PageViewToProfileJoinFunction())
           .map(m -> new KV(m.getMemberId(), m))
           .sendTo(outputTable);
@@ -230,8 +226,12 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
   @Test(expected = SamzaException.class)
   public void testCatchReaderException() {
     TableReadFunction<String, ?> reader = mock(TableReadFunction.class);
-    doThrow(new RuntimeException("Expected test exception")).when(reader).get(anyString());
-    RemoteReadableTable<String, ?> table = new RemoteReadableTable<>("table1", reader, null, null);
+    CompletableFuture<String> future = new CompletableFuture<>();
+    future.completeExceptionally(new RuntimeException("Expected test exception"));
+    doReturn(future).when(reader).getAsync(anyString());
+    TableRateLimiter rateLimitHelper = mock(TableRateLimiter.class);
+    RemoteReadableTable<String, ?> table = new RemoteReadableTable<>(
+        "table1", reader, rateLimitHelper, Executors.newSingleThreadExecutor(), null);
     table.init(mock(SamzaContainerContext.class), createMockTaskContext());
     table.get("abc");
   }
@@ -240,8 +240,12 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness {
   public void testCatchWriterException() {
     TableReadFunction<String, String> reader = mock(TableReadFunction.class);
     TableWriteFunction<String, String> writer = mock(TableWriteFunction.class);
-    doThrow(new RuntimeException("Expected test exception")).when(writer).put(anyString(), any());
-    RemoteReadWriteTable<String, String> table = new RemoteReadWriteTable<>("table1", reader, writer, null, null, null);
+    CompletableFuture<String> future = new CompletableFuture<>();
+    future.completeExceptionally(new RuntimeException("Expected test exception"));
+    doReturn(future).when(writer).putAsync(anyString(), any());
+    TableRateLimiter rateLimitHelper = mock(TableRateLimiter.class);
+    RemoteReadWriteTable<String, String> table = new RemoteReadWriteTable<String, String>(
+        "table1", reader, writer, rateLimitHelper, rateLimitHelper, Executors.newSingleThreadExecutor(), null);
     table.init(mock(SamzaContainerContext.class), createMockTaskContext());
     table.put("abc", "efg");
   }
index 817fb9f..38cc47c 100644 (file)
@@ -81,7 +81,8 @@ public class TestTableDescriptorsProvider {
     String tableRewriterName = "tableRewriter";
     configs.put("tables.descriptors.provider.class", MySampleTableDescriptorsProvider.class.getName());
     Config resultConfig = new MySampleTableConfigRewriter().rewrite(tableRewriterName, new MapConfig(configs));
-    Assert.assertTrue(resultConfig.size() == 18);
+    Assert.assertNotNull(resultConfig);
+    Assert.assertTrue(!resultConfig.isEmpty());
 
     String localTableId = "local-table-1";
     String remoteTableId = "remote-table-1";