faster varint
authorspupyrev <spupyrev@fb.com>
Wed, 23 Mar 2016 17:34:05 +0000 (10:34 -0700)
committerIgor Kabiljo <ikabiljo@fb.com>
Wed, 23 Mar 2016 17:34:05 +0000 (10:34 -0700)
Summary:
Varint is improved in two ways:
- faster readLong and readInt
- making sure that negative numbers can be encoded

JIRA: https://issues.apache.org/jira/browse/GIRAPH-1049

Test Plan: TestVarint.java

Reviewers: dionysis.logothetis, maja.kabiljo, sergey.edunov, ikabiljo

Reviewed By: ikabiljo

Differential Revision: https://reviews.facebook.net/D55755

giraph-core/src/main/java/org/apache/giraph/utils/Varint.java
giraph-core/src/test/java/org/apache/giraph/utils/TestVarint.java [new file with mode: 0644]

index 89d4e90..174d1f5 100644 (file)
 package org.apache.giraph.utils;
 
 /**
- * This Code is Copied from main/java/org/apache/mahout/math/Varint.java
- *
- * Only modification is throwing exceptions for passing negative values to
- * unsigned functions, instead of serializing them.
+ * This Code is adapted from main/java/org/apache/mahout/math/Varint.java
  *
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
@@ -43,6 +40,8 @@ import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
 
+import com.google.common.base.Preconditions;
+
 /**
  * <p>
  * Encodes signed and unsigned values using a common variable-length scheme,
@@ -68,131 +67,284 @@ public final class Varint {
   /**
    * Encodes a value using the variable-length encoding from <a
    * href="http://code.google.com/apis/protocolbuffers/docs/encoding.html">
-   * Google Protocol Buffers</a>. Zig-zag is not used, so input must not be
-   * negative. If values can be negative, use
-   * {@link #writeSignedVarLong(long, DataOutput)} instead. This method treats
-   * negative input as like a large unsigned value.
+   * Google Protocol Buffers</a>.
    *
-   * @param value
-   *          value to encode
-   * @param out
-   *          to write bytes to
+   * @param value to encode
+   * @param out to write bytes to
    * @throws IOException
    *           if {@link DataOutput} throws {@link IOException}
    */
-  public static void writeUnsignedVarLong(
-      long value, DataOutput out) throws IOException {
-    if (value < 0) {
-      throw new IllegalArgumentException(
-          "Negative value passed into writeUnsignedVarLong - " + value);
-    }
-    while ((value & 0xFFFFFFFFFFFFFF80L) != 0L) {
-      out.writeByte(((int) value & 0x7F) | 0x80);
+  private static void writeVarLong(
+    long value,
+    DataOutput out
+  ) throws IOException {
+    while (true) {
+      int bits = ((int) value) & 0x7f;
       value >>>= 7;
+      if (value == 0) {
+        out.writeByte((byte) bits);
+        return;
+      }
+      out.writeByte((byte) (bits | 0x80));
     }
-    out.writeByte((int) value & 0x7F);
   }
 
   /**
-   * @see #writeUnsignedVarLong(long, DataOutput)
-   * @param value
-   *          value to encode
-   * @param out
-   *          to write bytes to
+   * Encodes a value using the variable-length encoding from <a
+   * href="http://code.google.com/apis/protocolbuffers/docs/encoding.html">
+   * Google Protocol Buffers</a>.
+   *
+   * @param value to encode
+   * @param out to write bytes to
+   * @throws IOException
+   *           if {@link DataOutput} throws {@link IOException}
    */
-  public static void writeUnsignedVarInt(
-      int value, DataOutput out) throws IOException {
-    if (value < 0) {
-      throw new IllegalArgumentException(
-          "Negative value passed into writeUnsignedVarInt - " + value);
-    }
-    while ((value & 0xFFFFFF80) != 0L) {
-      out.writeByte((value & 0x7F) | 0x80);
+  public static void writeUnsignedVarLong(
+    long value,
+    DataOutput out
+  ) throws IOException {
+    Preconditions.checkState(
+      value >= 0,
+      "Negative value passed into writeUnsignedVarLong - " + value
+    );
+    writeVarLong(value, out);
+  }
+
+  /**
+   * Zig-zag encoding for signed longs
+   *
+   * @param value to encode
+   * @param out to write bytes to
+   * @throws IOException
+   *           if {@link DataOutput} throws {@link IOException}
+   */
+  public static void writeSignedVarLong(
+    long value,
+    DataOutput out
+  ) throws IOException {
+    writeVarLong((value << 1) ^ (value >> 63), out);
+  }
+
+  /**
+   * @see #writeVarLong(long, DataOutput)
+   * @param value to encode
+   * @param out to write bytes to
+   * @throws IOException
+   */
+  private static void writeVarInt(
+    int value,
+    DataOutput out
+  ) throws IOException {
+    while (true) {
+      int bits = value & 0x7f;
       value >>>= 7;
+      if (value == 0) {
+        out.writeByte((byte) bits);
+        return;
+      }
+      out.writeByte((byte) (bits | 0x80));
     }
-    out.writeByte(value & 0x7F);
   }
 
   /**
-   * @param in
-   *          to read bytes from
+   * @see #writeVarLong(long, DataOutput)
+   * @param value to encode
+   * @param out to write bytes to
+   * @throws IOException
+   */
+  public static void writeUnsignedVarInt(
+    int value,
+    DataOutput out
+  ) throws IOException {
+    Preconditions.checkState(
+      value >= 0,
+      "Negative value passed into writeUnsignedVarInt - " + value
+    );
+    writeVarInt(value, out);
+  }
+
+  /**
+   * Zig-zag encoding for signed ints
+   *
+   * @see #writeUnsignedVarInt(int, DataOutput)
+   * @param value to encode
+   * @param out to write bytes to
+   * @throws IOException
+   */
+  public static void writeSignedVarInt(
+    int value,
+    DataOutput out
+  ) throws IOException {
+    writeVarInt((value << 1) ^ (value >> 31), out);
+  }
+
+  /**
+   * @param in to read bytes from
    * @return decode value
    * @throws IOException
    *           if {@link DataInput} throws {@link IOException}
-   * @throws IllegalArgumentException
-   *           if variable-length value does not terminate after 9 bytes have
-   *           been read
-   * @see #writeUnsignedVarLong(long, DataOutput)
    */
   public static long readUnsignedVarLong(DataInput in) throws IOException {
-    long value = 0L;
-    int i = 0;
-    long b = in.readByte();
-    while ((b & 0x80L) != 0) {
-      value |= (b & 0x7F) << i;
-      i += 7;
-      if (i > 63) {
-        throw new IllegalArgumentException(
-            "Variable length quantity is too long");
+    long tmp;
+    // CHECKSTYLE: stop InnerAssignment
+    if ((tmp = in.readByte()) >= 0) {
+      return tmp;
+    }
+    long result = tmp & 0x7f;
+    if ((tmp = in.readByte()) >= 0) {
+      result |= tmp << 7;
+    } else {
+      result |= (tmp & 0x7f) << 7;
+      if ((tmp = in.readByte()) >= 0) {
+        result |= tmp << 14;
+      } else {
+        result |= (tmp & 0x7f) << 14;
+        if ((tmp = in.readByte()) >= 0) {
+          result |= tmp << 21;
+        } else {
+          result |= (tmp & 0x7f) << 21;
+          if ((tmp = in.readByte()) >= 0) {
+            result |= tmp << 28;
+          } else {
+            result |= (tmp & 0x7f) << 28;
+            if ((tmp = in.readByte()) >= 0) {
+              result |= tmp << 35;
+            } else {
+              result |= (tmp & 0x7f) << 35;
+              if ((tmp = in.readByte()) >= 0) {
+                result |= tmp << 42;
+              } else {
+                result |= (tmp & 0x7f) << 42;
+                if ((tmp = in.readByte()) >= 0) {
+                  result |= tmp << 49;
+                } else {
+                  result |= (tmp & 0x7f) << 49;
+                  if ((tmp = in.readByte()) >= 0) {
+                    result |= tmp << 56;
+                  } else {
+                    result |= (tmp & 0x7f) << 56;
+                    result |= ((long) in.readByte()) << 63;
+                  }
+                }
+              }
+            }
+          }
+        }
       }
-      b = in.readByte();
     }
-    return value | (b << i);
+    // CHECKSTYLE: resume InnerAssignment
+    return result;
+  }
+
+  /**
+   * @param in to read bytes from
+   * @return decode value
+   * @throws IOException
+   *           if {@link DataInput} throws {@link IOException}
+   */
+  public static long readSignedVarLong(DataInput in) throws IOException {
+    long raw = readUnsignedVarLong(in);
+    long temp = (((raw << 63) >> 63) ^ raw) >> 1;
+    return temp ^ (raw & (1L << 63));
   }
 
   /**
-   * @throws IllegalArgumentException
-   *           if variable-length value does not terminate after
-   *           5 bytes have been read
    * @throws IOException
    *           if {@link DataInput} throws {@link IOException}
-   * @param in to read bytes from.
-   * @return decode value.
+   * @param in to read bytes from
+   * @return decode value
    */
   public static int readUnsignedVarInt(DataInput in) throws IOException {
-    int value = 0;
-    int i = 0;
-    int b = in.readByte();
-    while ((b & 0x80) != 0) {
-      value |= (b & 0x7F) << i;
-      i += 7;
-      if (i > 35) {
-        throw new IllegalArgumentException(
-            "Variable length quantity is too long");
+    int tmp;
+    // CHECKSTYLE: stop InnerAssignment
+    if ((tmp = in.readByte()) >= 0) {
+      return tmp;
+    }
+    int result = tmp & 0x7f;
+    if ((tmp = in.readByte()) >= 0) {
+      result |= tmp << 7;
+    } else {
+      result |= (tmp & 0x7f) << 7;
+      if ((tmp = in.readByte()) >= 0) {
+        result |= tmp << 14;
+      } else {
+        result |= (tmp & 0x7f) << 14;
+        if ((tmp = in.readByte()) >= 0) {
+          result |= tmp << 21;
+        } else {
+          result |= (tmp & 0x7f) << 21;
+          result |= (in.readByte()) << 28;
+        }
       }
-      b = in.readByte();
     }
-    return value | (b << i);
+    // CHECKSTYLE: resume InnerAssignment
+    return result;
   }
+
+  /**
+   * @throws IOException
+   *           if {@link DataInput} throws {@link IOException}
+   * @param in to read bytes from
+   * @return decode value
+   */
+  public static int readSignedVarInt(DataInput in) throws IOException {
+    int raw = readUnsignedVarInt(in);
+    int temp = (((raw << 31) >> 31) ^ raw) >> 1;
+    return temp ^ (raw & (1 << 31));
+  }
+
   /**
    * Simulation for what will happen when writing an unsigned long value
    * as varlong.
-   * @param value the value
+   * @param value to consider
    * @return the number of bytes needed to write value.
    * @throws IOException
    */
   public static long sizeOfUnsignedVarLong(long value) throws IOException {
-    long cnt = 0;
-    while ((value & 0xFFFFFFFFFFFFFF80L) != 0L) {
-      cnt++;
+    int result = 0;
+    do {
+      result++;
       value >>>= 7;
-    }
-    return ++cnt;
+    } while (value != 0);
+    return result;
+  }
+
+  /**
+   * Simulation for what will happen when writing a signed long value
+   * as varlong.
+   * @param value to consider
+   * @return the number of bytes needed to write value.
+   * @throws IOException
+   */
+  public static long sizeOfSignedVarLong(long value) throws IOException {
+    return sizeOfUnsignedVarLong((value << 1) ^ (value >> 63));
   }
 
   /**
    * Simulation for what will happen when writing an unsigned int value
    * as varint.
-   * @param value the value
+   * @param value to consider
    * @return the number of bytes needed to write value.
    * @throws IOException
    */
-  public static long sizeOfUnsignedVarInt(int value) throws IOException {
-    long cnt = 0;
-    while ((value & 0xFFFFFF80) != 0L) {
+  public static int sizeOfUnsignedVarInt(int value) throws IOException {
+    int cnt = 0;
+    do {
       cnt++;
       value >>>= 7;
-    }
-    return ++cnt;
+    } while (value != 0);
+    return cnt;
   }
+
+  /**
+   * Simulation for what will happen when writing a signed int value
+   * as varint.
+   * @param value to consider
+   * @return the number of bytes needed to write value.
+   * @throws IOException
+   */
+  public static int sizeOfSignedVarInt(int value) throws IOException {
+    return sizeOfUnsignedVarInt((value << 1) ^ (value >> 31));
+  }
+
 }
diff --git a/giraph-core/src/test/java/org/apache/giraph/utils/TestVarint.java b/giraph-core/src/test/java/org/apache/giraph/utils/TestVarint.java
new file mode 100644 (file)
index 0000000..70bebd8
--- /dev/null
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.giraph.utils;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.concurrent.ThreadLocalRandom;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestVarint {
+  private long[] genLongs(int n) {
+    long[] res = new long[n];
+    for (int i = 0; i < n; i++) {
+      res[i] = ThreadLocalRandom.current().nextLong();
+    }
+    return res;
+  }
+
+  private int[] genInts(int n) {
+    int[] res = new int[n];
+    for (int i = 0; i < n; i++) {
+      res[i] = ThreadLocalRandom.current().nextInt();
+    }
+    return res;
+  }
+
+  private void writeLongs(DataOutput out, long[] array) throws IOException {
+    for (int i = 0; i < array.length; i++) {
+      Varint.writeSignedVarLong(array[i], out);
+    }
+  }
+
+  private void writeInts(DataOutput out, int[] array) throws IOException {
+    for (int i = 0; i < array.length; i++) {
+      Varint.writeSignedVarInt(array[i], out);
+    }
+  }
+
+  private void readLongs(DataInput in, long[] array) throws IOException {
+    for (int i = 0; i < array.length; i++) {
+      array[i] = Varint.readSignedVarLong(in);
+    }
+  }
+
+  private void readInts(DataInput in, int[] array) throws IOException {
+    for (int i = 0; i < array.length; i++) {
+      array[i] = Varint.readSignedVarInt(in);
+    }
+  }
+
+  private void testVarLong(long value) throws IOException {
+    UnsafeByteArrayOutputStream os = new UnsafeByteArrayOutputStream();
+    Varint.writeSignedVarLong(value, os);
+
+    UnsafeByteArrayInputStream is = new UnsafeByteArrayInputStream(os.getByteArray());
+    long newValue = Varint.readSignedVarLong(is);
+
+    Assert.assertEquals(Varint.sizeOfSignedVarLong(value), os.getPos());
+    Assert.assertEquals(value, newValue);
+
+    if (value >= 0) {
+      os = new UnsafeByteArrayOutputStream();
+      Varint.writeUnsignedVarLong(value, os);
+      is = new UnsafeByteArrayInputStream(os.getByteArray());
+      newValue = Varint.readUnsignedVarLong(is);
+      Assert.assertEquals(Varint.sizeOfUnsignedVarLong(value), os.getPos());
+      Assert.assertEquals(value, newValue);
+    }
+  }
+
+  private void testVarInt(int value) throws IOException {
+    UnsafeByteArrayOutputStream os = new UnsafeByteArrayOutputStream();
+    Varint.writeSignedVarInt(value, os);
+
+    UnsafeByteArrayInputStream is = new UnsafeByteArrayInputStream(os.getByteArray());
+    int newValue = Varint.readSignedVarInt(is);
+
+    Assert.assertEquals(Varint.sizeOfSignedVarLong(value), os.getPos());
+    Assert.assertEquals(value, newValue);
+
+    if (value >= 0) {
+      os = new UnsafeByteArrayOutputStream();
+      Varint.writeUnsignedVarInt(value, os);
+      is = new UnsafeByteArrayInputStream(os.getByteArray());
+      newValue = Varint.readUnsignedVarInt(is);
+      Assert.assertEquals(Varint.sizeOfUnsignedVarInt(value), os.getPos());
+      Assert.assertEquals(value, newValue);
+    }
+  }
+
+  @Test
+  public void testVars() throws IOException {
+    testVarLong(0);
+    testVarLong(Long.MIN_VALUE);
+    testVarLong(Long.MAX_VALUE);
+    testVarLong(-123456789999l);
+    testVarLong(12342356789999l);
+    testVarInt(0);
+    testVarInt(4);
+    testVarInt(-1);
+    testVarInt(1);
+    testVarInt(Integer.MIN_VALUE);
+    testVarInt(Integer.MAX_VALUE);
+    testVarInt(Integer.MAX_VALUE - 1);
+  }
+
+  @Test
+  public void testVarLongSmall() throws IOException {
+    long[] array = new long[] {1, 2, 3, -5, 0, 12345678987l, Long.MIN_VALUE};
+    UnsafeByteArrayOutputStream os = new UnsafeByteArrayOutputStream();
+    writeLongs(os, array);
+
+    long[] resArray = new long[array.length];
+    UnsafeByteArrayInputStream is = new UnsafeByteArrayInputStream(os.getByteArray());
+    readLongs(is, resArray);
+
+    Assert.assertArrayEquals(array, resArray);
+  }
+
+  @Test
+  public void testVarIntSmall() throws IOException {
+    int[] array = new int[] {13, -2, 3, 0, 123456789, Integer.MIN_VALUE, Integer.MAX_VALUE};
+    UnsafeByteArrayOutputStream os = new UnsafeByteArrayOutputStream();
+    writeInts(os, array);
+
+    int[] resArray = new int[array.length];
+    UnsafeByteArrayInputStream is = new UnsafeByteArrayInputStream(os.getByteArray());
+    readInts(is, resArray);
+
+    Assert.assertArrayEquals(array, resArray);
+  }
+
+  @Test
+  public void testVarLongLarge() throws IOException {
+    int n = 1000000;
+    long[] array = genLongs(n);
+    UnsafeByteArrayOutputStream os = new UnsafeByteArrayOutputStream();
+
+    long startTime = System.currentTimeMillis();
+    writeLongs(os, array);
+    long endTime = System.currentTimeMillis();
+    System.out.println("Write time: " + (endTime - startTime) / 1000.0);
+
+    long[] resArray = new long[array.length];
+    UnsafeByteArrayInputStream is = new UnsafeByteArrayInputStream(os.getByteArray());
+    startTime = System.currentTimeMillis();
+    readLongs(is, resArray);
+    endTime = System.currentTimeMillis();
+    System.out.println("Read time: " + (endTime - startTime) / 1000.0);
+
+    Assert.assertArrayEquals(array, resArray);
+  }
+
+  @Test
+  public void testVarIntLarge() throws IOException {
+    int n = 1000000;
+    int[] array = genInts(n);
+    UnsafeByteArrayOutputStream os = new UnsafeByteArrayOutputStream();
+
+    long startTime = System.currentTimeMillis();
+    writeInts(os, array);
+    long endTime = System.currentTimeMillis();
+    System.out.println("Write time: " + (endTime - startTime) / 1000.0);
+
+    int[] resArray = new int[array.length];
+    UnsafeByteArrayInputStream is = new UnsafeByteArrayInputStream(os.getByteArray());
+    startTime = System.currentTimeMillis();
+    readInts(is, resArray);
+    endTime = System.currentTimeMillis();
+    System.out.println("Read time: " + (endTime - startTime) / 1000.0);
+
+    Assert.assertArrayEquals(array, resArray);
+  }
+
+  @Test
+  public void testSmall() throws IOException {
+    for (int i = -100000; i <= 100000; i++) {
+      testVarInt(i);
+      testVarLong(i);
+    }
+  }
+
+}