[SPARK-24774][SQL] Avro: Support logical decimal type
authorGengliang Wang <gengliang.wang@databricks.com>
Mon, 13 Aug 2018 00:29:07 +0000 (08:29 +0800)
committerWenchen Fan <wenchen@databricks.com>
Mon, 13 Aug 2018 00:29:07 +0000 (08:29 +0800)
## What changes were proposed in this pull request?

Support Avro logical date type:
https://avro.apache.org/docs/1.8.2/spec.html#Decimal

## How was this patch tested?
Unit test

Closes #22037 from gengliangwang/avro_decimal.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala

index 74677a2..272e7d5 100644 (file)
 
 package org.apache.spark.sql.avro
 
+import java.math.{BigDecimal}
 import java.nio.ByteBuffer
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.avro.{Schema, SchemaBuilder}
+import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.avro.Conversions.DecimalConversion
 import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
 import org.apache.avro.Schema.Type._
 import org.apache.avro.generic._
@@ -38,6 +40,8 @@ import org.apache.spark.unsafe.types.UTF8String
  * A deserializer to deserialize data in avro format to data in catalyst format.
  */
 class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
+  private lazy val decimalConversions = new DecimalConversion()
+
   private val converter: Any => Any = rootCatalystType match {
     // A shortcut for empty schema.
     case st: StructType if st.isEmpty =>
@@ -138,10 +142,21 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
             bytes
           case b: Array[Byte] => b
           case other => throw new RuntimeException(s"$other is not a valid avro binary.")
-
         }
         updater.set(ordinal, bytes)
 
+      case (FIXED, d: DecimalType) => (updater, ordinal, value) =>
+        val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType,
+          LogicalTypes.decimal(d.precision, d.scale))
+        val decimal = createDecimal(bigDecimal, d.precision, d.scale)
+        updater.setDecimal(ordinal, decimal)
+
+      case (BYTES, d: DecimalType) => (updater, ordinal, value) =>
+        val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType,
+          LogicalTypes.decimal(d.precision, d.scale))
+        val decimal = createDecimal(bigDecimal, d.precision, d.scale)
+        updater.setDecimal(ordinal, decimal)
+
       case (RECORD, st: StructType) =>
         val writeRecord = getRecordWriter(avroType, st, path)
         (updater, ordinal, value) =>
@@ -263,6 +278,17 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
             s"Target Catalyst type: $rootCatalystType")
     }
 
+  // TODO: move the following method in Decimal object on creating Decimal from BigDecimal?
+  private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
+    if (precision <= Decimal.MAX_LONG_DIGITS) {
+      // Constructs a `Decimal` with an unscaled `Long` value if possible.
+      Decimal(decimal.unscaledValue().longValue(), precision, scale)
+    } else {
+      // Otherwise, resorts to an unscaled `BigInteger` instead.
+      Decimal(decimal, precision, scale)
+    }
+  }
+
   private def getRecordWriter(
       avroType: Schema,
       sqlType: StructType,
@@ -334,6 +360,7 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
     def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
     def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
     def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+    def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
   }
 
   final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
@@ -347,6 +374,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
     override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
     override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
     override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
+    override def setDecimal(ordinal: Int, value: Decimal): Unit =
+      row.setDecimal(ordinal, value, value.precision)
   }
 
   final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
@@ -360,5 +389,6 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
     override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
     override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
     override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
+    override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
   }
 }
index 216c52a..3a9544c 100644 (file)
@@ -21,15 +21,17 @@ import java.nio.ByteBuffer
 
 import scala.collection.JavaConverters._
 
+import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.Conversions.DecimalConversion
 import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
 import org.apache.avro.Schema
 import org.apache.avro.Schema.Type
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
+import org.apache.avro.generic.GenericData.Record
 import org.apache.avro.util.Utf8
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -67,6 +69,8 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
 
   private type Converter = (SpecializedGetters, Int) => Any
 
+  private lazy val decimalConversions = new DecimalConversion()
+
   private def newConverter(catalystType: DataType, avroType: Schema): Converter = {
     catalystType match {
       case NullType =>
@@ -86,7 +90,11 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
       case DoubleType =>
         (getter, ordinal) => getter.getDouble(ordinal)
       case d: DecimalType =>
-        (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString
+        (getter, ordinal) =>
+          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+          decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
+            LogicalTypes.decimal(d.precision, d.scale))
+
       case StringType => avroType.getType match {
         case Type.ENUM =>
           import scala.collection.JavaConverters._
index 245e68d..7b33cf6 100644 (file)
 package org.apache.spark.sql.avro
 
 import scala.collection.JavaConverters._
+import scala.util.Random
 
+import com.fasterxml.jackson.annotation.ObjectIdGenerators.UUIDGenerator
 import org.apache.avro.{LogicalType, LogicalTypes, Schema, SchemaBuilder}
-import org.apache.avro.LogicalTypes.{Date, TimestampMicros, TimestampMillis}
+import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis}
 import org.apache.avro.Schema.Type._
 
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
 import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.Decimal.{maxPrecisionForBytes, minBytesForPrecision}
 
 /**
  * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice
  * versa.
  */
 object SchemaConverters {
+  private lazy val uuidGenerator = RandomUUIDGenerator(new Random().nextLong())
+
+  private lazy val nullSchema = Schema.create(Schema.Type.NULL)
+
   case class SchemaType(dataType: DataType, nullable: Boolean)
 
   /**
@@ -44,14 +53,20 @@ object SchemaConverters {
       }
       case STRING => SchemaType(StringType, nullable = false)
       case BOOLEAN => SchemaType(BooleanType, nullable = false)
-      case BYTES => SchemaType(BinaryType, nullable = false)
+      case BYTES | FIXED => avroSchema.getLogicalType match {
+        // For FIXED type, if the precision requires more bytes than fixed size, the logical
+        // type will be null, which is handled by Avro library.
+        case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false)
+        case _ => SchemaType(BinaryType, nullable = false)
+      }
+
       case DOUBLE => SchemaType(DoubleType, nullable = false)
       case FLOAT => SchemaType(FloatType, nullable = false)
       case LONG => avroSchema.getLogicalType match {
         case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
         case _ => SchemaType(LongType, nullable = false)
       }
-      case FIXED => SchemaType(BinaryType, nullable = false)
+
       case ENUM => SchemaType(StringType, nullable = false)
 
       case RECORD =>
@@ -114,20 +129,14 @@ object SchemaConverters {
       prevNameSpace: String = "",
       outputTimestampType: AvroOutputTimestampType.Value = AvroOutputTimestampType.TIMESTAMP_MICROS)
     : Schema = {
-    val builder = if (nullable) {
-      SchemaBuilder.builder().nullable()
-    } else {
-      SchemaBuilder.builder()
-    }
+    val builder = SchemaBuilder.builder()
 
-    catalystType match {
+    val schema = catalystType match {
       case BooleanType => builder.booleanType()
       case ByteType | ShortType | IntegerType => builder.intType()
       case LongType => builder.longType()
-      case DateType => builder
-        .intBuilder()
-        .prop(LogicalType.LOGICAL_TYPE_PROP, LogicalTypes.date().getName)
-        .endInt()
+      case DateType =>
+        LogicalTypes.date().addToSchema(builder.intType())
       case TimestampType =>
         val timestampType = outputTimestampType match {
           case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis()
@@ -135,11 +144,21 @@ object SchemaConverters {
           case other =>
             throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.")
         }
-        builder.longBuilder().prop(LogicalType.LOGICAL_TYPE_PROP, timestampType.getName).endLong()
+        timestampType.addToSchema(builder.longType())
 
       case FloatType => builder.floatType()
       case DoubleType => builder.doubleType()
-      case _: DecimalType | StringType => builder.stringType()
+      case StringType => builder.stringType()
+      case d: DecimalType =>
+        val avroType = LogicalTypes.decimal(d.precision, d.scale)
+        val fixedSize = minBytesForPrecision(d.precision)
+        // Need to avoid naming conflict for the fixed fields
+        val name = prevNameSpace match {
+          case "" => s"$recordName.fixed"
+          case _ => s"$prevNameSpace.$recordName.fixed"
+        }
+        avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize))
+
       case BinaryType => builder.bytesType()
       case ArrayType(et, containsNull) =>
         builder.array()
@@ -164,6 +183,11 @@ object SchemaConverters {
       // This should never happen.
       case other => throw new IncompatibleSchemaException(s"Unexpected type $other.")
     }
+    if (nullable) {
+      Schema.createUnion(schema, nullSchema)
+    } else {
+      schema
+    }
   }
 }
 
index 06d5477..4b3bf0c 100644 (file)
@@ -64,8 +64,6 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalH
     BinaryType)
 
   protected def prepareExpectedResult(expected: Any): Any = expected match {
-    // Spark decimal is converted to avro string=
-    case d: Decimal => UTF8String.fromString(d.toString)
     // Spark byte and short both map to avro int
     case b: Byte => b.toInt
     case s: Short => s.toInt
index ada9980..3fa43bf 100644 (file)
@@ -25,7 +25,8 @@ import java.util.{TimeZone, UUID}
 
 import scala.collection.JavaConverters._
 
-import org.apache.avro.Schema
+import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.Conversions.DecimalConversion
 import org.apache.avro.Schema.{Field, Type}
 import org.apache.avro.file.{DataFileReader, DataFileWriter}
 import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
@@ -494,6 +495,104 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     checkAnswer(df, expected)
   }
 
+  test("Logical type: Decimal") {
+    val precision = 4
+    val scale = 2
+    val bytesFieldName = "bytes"
+    val bytesSchema = s"""{
+         "type":"bytes",
+         "logicalType":"decimal",
+         "precision":$precision,
+         "scale":$scale
+      }
+    """
+
+    val fixedFieldName = "fixed"
+    val fixedSchema = s"""{
+         "type":"fixed",
+         "size":5,
+         "logicalType":"decimal",
+         "precision":$precision,
+         "scale":$scale,
+         "name":"foo"
+      }
+    """
+    val avroSchema = s"""
+      {
+        "namespace": "logical",
+        "type": "record",
+        "name": "test",
+        "fields": [
+          {"name": "$bytesFieldName", "type": $bytesSchema},
+          {"name": "$fixedFieldName", "type": $fixedSchema}
+        ]
+      }
+    """
+    val schema = new Schema.Parser().parse(avroSchema)
+    val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+    val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
+    val decimalConversion = new DecimalConversion
+    withTempDir { dir =>
+      val avroFile = s"$dir.avro"
+      dataFileWriter.create(schema, new File(avroFile))
+      val logicalType = LogicalTypes.decimal(precision, scale)
+      val data = Seq("1.23", "4.56", "78.90", "-1", "-2.31")
+      data.map { x =>
+        val avroRec = new GenericData.Record(schema)
+        val decimal = new java.math.BigDecimal(x).setScale(scale)
+        val bytes =
+          decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType)
+        avroRec.put(bytesFieldName, bytes)
+        val fixed =
+          decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType)
+        avroRec.put(fixedFieldName, fixed)
+        dataFileWriter.append(avroRec)
+      }
+      dataFileWriter.flush()
+      dataFileWriter.close()
+
+      val expected = data.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) }
+      val df = spark.read.format("avro").load(avroFile)
+      checkAnswer(df, expected)
+      checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile),
+        expected)
+
+      withTempPath { path =>
+        df.write.format("avro").save(path.toString)
+        checkAnswer(spark.read.format("avro").load(path.toString), expected)
+      }
+    }
+  }
+
+  test("Logical type: Decimal with too large precision") {
+    withTempDir { dir =>
+      val schema = new Schema.Parser().parse("""{
+        "namespace": "logical",
+        "type": "record",
+        "name": "test",
+        "fields": [{
+          "name": "decimal",
+          "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}
+        }]
+      }""")
+      val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+      val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
+      dataFileWriter.create(schema, new File(s"$dir.avro"))
+      val avroRec = new GenericData.Record(schema)
+      val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678")
+      val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38))
+      avroRec.put("decimal", bytes)
+      dataFileWriter.append(avroRec)
+      dataFileWriter.flush()
+      dataFileWriter.close()
+
+      val msg = intercept[SparkException] {
+        spark.read.format("avro").load(s"$dir.avro").collect()
+      }.getCause.getMessage
+      assert(msg.contains("Unscaled value too large for precision"))
+    }
+  }
+
   test("Array data types") {
     withTempPath { dir =>
       val testSchema = StructType(Seq(
@@ -689,7 +788,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
       // DecimalType should be converted to string
       val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect()
-      assert(decimals.map(_(0)).contains("3.14"))
+      assert(decimals.map(_(0)).contains(new java.math.BigDecimal("3.14")))
 
       // There should be a null entry
       val length = spark.read.format("avro").load(avroDir).select("Length").collect()
index 6da4f28..9eed2eb 100644 (file)
@@ -479,6 +479,25 @@ object Decimal {
     dec
   }
 
+  // Max precision of a decimal value stored in `numBytes` bytes
+  def maxPrecisionForBytes(numBytes: Int): Int = {
+    Math.round(                               // convert double to long
+      Math.floor(Math.log10(                  // number of base-10 digits
+        Math.pow(2, 8 * numBytes - 1) - 1)))  // max value stored in numBytes
+      .asInstanceOf[Int]
+  }
+
+  // Returns the minimum number of bytes needed to store a decimal with a given `precision`.
+  lazy val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision)
+
+  private def computeMinBytesForPrecision(precision : Int) : Int = {
+    var numBytes = 1
+    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
+      numBytes += 1
+    }
+    numBytes
+  }
+
   // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two
   // parameters inheriting from a common trait since both traits define mkNumericOps.
   // See scala.math's Numeric.scala for examples for Scala's built-in types.
index 70f42f2..8ce8a86 100644 (file)
@@ -26,7 +26,6 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
 import org.apache.parquet.schema.Type.Repetition._
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.maxPrecisionForBytes
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -171,7 +170,7 @@ class ParquetToSparkSchemaConverter(
 
       case FIXED_LEN_BYTE_ARRAY =>
         originalType match {
-          case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength))
+          case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength))
           case INTERVAL => typeNotImplemented()
           case _ => illegalType()
         }
@@ -411,7 +410,7 @@ class SparkToParquetSchemaConverter(
           .as(DECIMAL)
           .precision(precision)
           .scale(scale)
-          .length(ParquetSchemaConverter.minBytesForPrecision(precision))
+          .length(Decimal.minBytesForPrecision(precision))
           .named(field.name)
 
       // ========================
@@ -445,7 +444,7 @@ class SparkToParquetSchemaConverter(
           .as(DECIMAL)
           .precision(precision)
           .scale(scale)
-          .length(ParquetSchemaConverter.minBytesForPrecision(precision))
+          .length(Decimal.minBytesForPrecision(precision))
           .named(field.name)
 
       // ===================================
@@ -584,23 +583,4 @@ private[sql] object ParquetSchemaConverter {
       throw new AnalysisException(message)
     }
   }
-
-  private def computeMinBytesForPrecision(precision : Int) : Int = {
-    var numBytes = 1
-    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
-      numBytes += 1
-    }
-    numBytes
-  }
-
-  // Returns the minimum number of bytes needed to store a decimal with a given `precision`.
-  val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision)
-
-  // Max precision of a decimal value stored in `numBytes` bytes
-  def maxPrecisionForBytes(numBytes: Int): Int = {
-    Math.round(                               // convert double to long
-      Math.floor(Math.log10(                  // number of base-10 digits
-        Math.pow(2, 8 * numBytes - 1) - 1)))  // max value stored in numBytes
-      .asInstanceOf[Int]
-  }
 }
index af4e143..b40b8c2 100644 (file)
@@ -33,7 +33,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.minBytesForPrecision
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -73,7 +72,8 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit
   private val timestampBuffer = new Array[Byte](12)
 
   // Reusable byte array used to write decimal values
-  private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION))
+  private val decimalBuffer =
+    new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION))
 
   override def init(configuration: Configuration): WriteContext = {
     val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA)
@@ -212,7 +212,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit
       precision <= DecimalType.MAX_PRECISION,
       s"Decimal precision $precision exceeds max precision ${DecimalType.MAX_PRECISION}")
 
-    val numBytes = minBytesForPrecision(precision)
+    val numBytes = Decimal.minBytesForPrecision(precision)
 
     val int32Writer =
       (row: SpecializedGetters, ordinal: Int) => {