[SPARK-23933][SQL] Add map_from_arrays function
authorKazuaki Ishizaki <ishizaki@jp.ibm.com>
Tue, 12 Jun 2018 19:31:22 +0000 (12:31 -0700)
committerTakuya UESHIN <ueshin@databricks.com>
Tue, 12 Jun 2018 19:31:22 +0000 (12:31 -0700)
## What changes were proposed in this pull request?

The PR adds the SQL function `map_from_arrays`. The behavior of the function is based on Presto's `map`. Since SparkSQL already had a `map` function, we prepared the different name for this behavior.

This function returns returns a map from a pair of arrays for keys and values.

## How was this patch tested?

Added UTs

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #21258 from kiszk/SPARK-23933.

python/pyspark/sql/functions.py
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
sql/core/src/main/scala/org/apache/spark/sql/functions.scala
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

index 0715297..1cdbb8a 100644 (file)
@@ -1819,6 +1819,25 @@ def create_map(*cols):
     return Column(jc)
 
 
+@since(2.4)
+def map_from_arrays(col1, col2):
+    """Creates a new map from two arrays.
+
+    :param col1: name of column containing a set of keys. All elements should not be null
+    :param col2: name of column containing a set of values
+
+    >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
+    >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
+    +----------------+
+    |             map|
+    +----------------+
+    |[2 -> a, 5 -> b]|
+    +----------------+
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2)))
+
+
 @since(1.4)
 def array(*cols):
     """Creates a new array column.
index 3c0b728..3700c63 100644 (file)
@@ -417,6 +417,7 @@ object FunctionRegistry {
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[ElementAt]("element_at"),
+    expression[MapFromArrays]("map_from_arrays"),
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
     expression[MapEntries]("map_entries"),
index a9867aa..0a5f8a9 100644 (file)
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -237,6 +237,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
 }
 
 /**
+ * Returns a catalyst Map containing the two arrays in children expressions as keys and values.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements
+      in keys should not be null""",
+  examples = """
+    Examples:
+      > SELECT _FUNC_([1.0, 3.0], ['2', '4']);
+       {1.0:"2",3.0:"4"}
+  """, since = "2.4.0")
+case class MapFromArrays(left: Expression, right: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
+
+  override def dataType: DataType = {
+    MapType(
+      keyType = left.dataType.asInstanceOf[ArrayType].elementType,
+      valueType = right.dataType.asInstanceOf[ArrayType].elementType,
+      valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
+  }
+
+  override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
+    val keyArrayData = keyArray.asInstanceOf[ArrayData]
+    val valueArrayData = valueArray.asInstanceOf[ArrayData]
+    if (keyArrayData.numElements != valueArrayData.numElements) {
+      throw new RuntimeException("The given two arrays should have the same length")
+    }
+    val leftArrayType = left.dataType.asInstanceOf[ArrayType]
+    if (leftArrayType.containsNull) {
+      var i = 0
+      while (i < keyArrayData.numElements) {
+        if (keyArrayData.isNullAt(i)) {
+          throw new RuntimeException("Cannot use null as map key!")
+        }
+        i += 1
+      }
+    }
+    new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy())
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => {
+      val arrayBasedMapData = classOf[ArrayBasedMapData].getName
+      val leftArrayType = left.dataType.asInstanceOf[ArrayType]
+      val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else {
+        val i = ctx.freshName("i")
+        s"""
+           |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
+           |  if ($keyArrayData.isNullAt($i)) {
+           |    throw new RuntimeException("Cannot use null as map key!");
+           |  }
+           |}
+         """.stripMargin
+      }
+      s"""
+         |if ($keyArrayData.numElements() != $valueArrayData.numElements()) {
+         |  throw new RuntimeException("The given two arrays should have the same length");
+         |}
+         |$keyArrayElemNullCheck
+         |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy());
+       """.stripMargin
+    })
+  }
+
+  override def prettyName: String = "map_from_arrays"
+}
+
+/**
  * An expression representing a not yet available attribute name. This expression is unevaluable
  * and as its name suggests it is a temporary place holder until we're able to determine the
  * actual attribute name.
index b4138ce..726193b 100644 (file)
@@ -186,6 +186,50 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("MapFromArrays") {
+    def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
+      // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order.
+      scala.collection.immutable.ListMap(keys.zip(values): _*)
+    }
+
+    val intSeq = Seq(5, 10, 15, 20, 25)
+    val longSeq = intSeq.map(_.toLong)
+    val strSeq = intSeq.map(_.toString)
+    val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25)
+    val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25)
+    val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_))
+
+    val intArray = Literal.create(intSeq, ArrayType(IntegerType, false))
+    val longArray = Literal.create(longSeq, ArrayType(LongType, false))
+    val strArray = Literal.create(strSeq, ArrayType(StringType, false))
+
+    val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true))
+    val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true))
+    val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true))
+
+    val nullArray = Literal.create(null, ArrayType(StringType, false))
+
+    checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
+    checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
+    checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))
+
+    checkEvaluation(
+      MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq))
+    checkEvaluation(
+      MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
+    checkEvaluation(
+      MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
+    checkEvaluation(MapFromArrays(nullArray, nullArray), null)
+
+    intercept[RuntimeException] {
+      checkEvaluation(MapFromArrays(intWithNullArray, strArray), null)
+    }
+    intercept[RuntimeException] {
+      checkEvaluation(
+        MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
+    }
+  }
+
   test("CreateStruct") {
     val row = create_row(1, 2, 3)
     val c1 = 'a.int.at(0)
index 266a136..87bd7b3 100644 (file)
@@ -1071,6 +1071,17 @@ object functions {
   def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }
 
   /**
+   * Creates a new map column. The array in the first column is used for keys. The array in the
+   * second column is used for values. All elements in the array for key should not be null.
+   *
+   * @group normal_funcs
+   * @since 2.4
+   */
+  def map_from_arrays(keys: Column, values: Column): Column = withExpr {
+    MapFromArrays(keys.expr, values.expr)
+  }
+
+  /**
    * Marks a DataFrame as small enough for use in broadcast joins.
    *
    * The following example marks the right DataFrame for broadcast hash join using `joinKey`.
index 959a77a..4e5c1c5 100644 (file)
@@ -62,6 +62,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(row.getMap[Int, String](0) === Map(2 -> "a"))
   }
 
+  test("map with arrays") {
+    val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v")
+    val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
+    val row = df1.select(map_from_arrays($"k", $"v")).first()
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b"))
+    checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b"))))
+
+    val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
+    checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))
+
+    val df3 = Seq((null, null)).toDF("k", "v")
+    checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))
+
+    val df4 = Seq((1, "a")).toDF("k", "v")
+    intercept[AnalysisException] {
+      df4.select(map_from_arrays($"k", $"v"))
+    }
+
+    val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
+    intercept[RuntimeException] {
+      df5.select(map_from_arrays($"k", $"v")).collect
+    }
+
+    val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")
+    intercept[RuntimeException] {
+      df6.select(map_from_arrays($"k", $"v")).collect
+    }
+  }
+
   test("struct with column name") {
     val df = Seq((1, "str")).toDF("a", "b")
     val row = df.select(struct("a", "b")).first()