[SPARK-23931][SQL] Adds arrays_zip function to sparksql
authorDylanGuedes <djmgguedes@gmail.com>
Tue, 12 Jun 2018 18:57:25 +0000 (11:57 -0700)
committerTakuya UESHIN <ueshin@databricks.com>
Tue, 12 Jun 2018 18:57:25 +0000 (11:57 -0700)
Signed-off-by: DylanGuedes <djmgguedesgmail.com>
## What changes were proposed in this pull request?

Addition of arrays_zip function to spark sql functions.

## How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
Unit tests that checks if the results are correct.

Author: DylanGuedes <djmgguedes@gmail.com>

Closes #21045 from DylanGuedes/SPARK-23931.

python/pyspark/sql/functions.py
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
sql/core/src/main/scala/org/apache/spark/sql/functions.scala
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

index 1759195..0715297 100644 (file)
@@ -2394,6 +2394,23 @@ def array_repeat(col, count):
     return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
 
 
+@since(2.4)
+def arrays_zip(*cols):
+    """
+    Collection function: Returns a merged array of structs in which the N-th struct contains all
+    N-th values of input arrays.
+
+    :param cols: columns of arrays to be merged.
+
+    >>> from pyspark.sql.functions import arrays_zip
+    >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2'])
+    >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect()
+    [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column)))
+
+
 # ---------------------------- User Defined Function ----------------------------------
 
 class PandasUDFType(object):
index 49fb35b..3c0b728 100644 (file)
@@ -423,6 +423,7 @@ object FunctionRegistry {
     expression[Size]("size"),
     expression[Slice]("slice"),
     expression[Size]("cardinality"),
+    expression[ArraysZip]("arrays_zip"),
     expression[SortArray]("sort_array"),
     expression[ArrayMin]("array_min"),
     expression[ArrayMax]("array_max"),
index 176995a..d76f301 100644 (file)
@@ -128,6 +128,172 @@ case class MapKeys(child: Expression)
   override def prettyName: String = "map_keys"
 }
 
+@ExpressionDescription(
+  usage = """
+    _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all
+    N-th values of input arrays.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
+        [[1, 2], [2, 3], [3, 4]]
+      > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
+        [[1, 2, 3], [2, 3, 4]]
+  """,
+  since = "2.4.0")
+case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
+
+  override def dataType: DataType = ArrayType(mountSchema)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])
+
+  private lazy val arrayElementTypes = arrayTypes.map(_.elementType)
+
+  @transient private lazy val mountSchema: StructType = {
+    val fields = children.zip(arrayElementTypes).zipWithIndex.map {
+      case ((expr: NamedExpression, elementType), _) =>
+        StructField(expr.name, elementType, nullable = true)
+      case ((_, elementType), idx) =>
+        StructField(idx.toString, elementType, nullable = true)
+    }
+    StructType(fields)
+  }
+
+  @transient lazy val numberOfArrays: Int = children.length
+
+  @transient lazy val genericArrayData = classOf[GenericArrayData].getName
+
+  def emptyInputGenCode(ev: ExprCode): ExprCode = {
+    ev.copy(code"""
+      |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]);
+      |boolean ${ev.isNull} = false;
+    """.stripMargin)
+  }
+
+  def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val genericInternalRow = classOf[GenericInternalRow].getName
+    val arrVals = ctx.freshName("arrVals")
+    val biggestCardinality = ctx.freshName("biggestCardinality")
+
+    val currentRow = ctx.freshName("currentRow")
+    val j = ctx.freshName("j")
+    val i = ctx.freshName("i")
+    val args = ctx.freshName("args")
+
+    val evals = children.map(_.genCode(ctx))
+    val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) =>
+      s"""
+        |if ($biggestCardinality != -1) {
+        |  ${eval.code}
+        |  if (!${eval.isNull}) {
+        |    $arrVals[$index] = ${eval.value};
+        |    $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements());
+        |  } else {
+        |    $biggestCardinality = -1;
+        |  }
+        |}
+      """.stripMargin
+    }
+
+    val splittedGetValuesAndCardinalities = ctx.splitExpressions(
+      expressions = getValuesAndCardinalities,
+      funcName = "getValuesAndCardinalities",
+      returnType = "int",
+      makeSplitFunction = body =>
+        s"""
+          |$body
+          |return $biggestCardinality;
+        """.stripMargin,
+      foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
+      arguments =
+        ("ArrayData[]", arrVals) ::
+        ("int", biggestCardinality) :: Nil)
+
+    val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) =>
+      val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i)
+      s"""
+        |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) {
+        |  $currentRow[$idx] = $g;
+        |} else {
+        |  $currentRow[$idx] = null;
+        |}
+      """.stripMargin
+    }
+
+    val getValueForTypeSplitted = ctx.splitExpressions(
+      expressions = getValueForType,
+      funcName = "extractValue",
+      arguments =
+        ("int", i) ::
+        ("Object[]", currentRow) ::
+        ("ArrayData[]", arrVals) :: Nil)
+
+    val initVariables = s"""
+      |ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
+      |int $biggestCardinality = 0;
+      |${CodeGenerator.javaType(dataType)} ${ev.value} = null;
+    """.stripMargin
+
+    ev.copy(code"""
+      |$initVariables
+      |$splittedGetValuesAndCardinalities
+      |boolean ${ev.isNull} = $biggestCardinality == -1;
+      |if (!${ev.isNull}) {
+      |  Object[] $args = new Object[$biggestCardinality];
+      |  for (int $i = 0; $i < $biggestCardinality; $i ++) {
+      |    Object[] $currentRow = new Object[$numberOfArrays];
+      |    $getValueForTypeSplitted
+      |    $args[$i] = new $genericInternalRow($currentRow);
+      |  }
+      |  ${ev.value} = new $genericArrayData($args);
+      |}
+    """.stripMargin)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    if (numberOfArrays == 0) {
+      emptyInputGenCode(ev)
+    } else {
+      nonEmptyInputGenCode(ctx, ev)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData])
+    if (inputArrays.contains(null)) {
+      null
+    } else {
+      val biggestCardinality = if (inputArrays.isEmpty) {
+        0
+      } else {
+        inputArrays.map(_.numElements()).max
+      }
+
+      val result = new Array[InternalRow](biggestCardinality)
+      val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex
+
+      for (i <- 0 until biggestCardinality) {
+        val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) =>
+          if (i < arr.numElements() && !arr.isNullAt(i)) {
+            arr.get(i, arrayElementTypes(index))
+          } else {
+            null
+          }
+        }
+
+        result(i) = InternalRow.apply(currentLayer: _*)
+      }
+      new GenericArrayData(result)
+    }
+  }
+
+  override def prettyName: String = "arrays_zip"
+}
+
 /**
  * Returns an unordered array containing the values of the map.
  */
index f8ad624..85e692b 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
@@ -315,6 +316,91 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Some(Literal.create(null, StringType))), null)
   }
 
+  test("ArraysZip") {
+    val literals = Seq(
+      Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)),
+      Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)),
+      Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)),
+      Literal.create(Seq("a", null, "c"), ArrayType(StringType)),
+      Literal.create(Seq(null, false, true), ArrayType(BooleanType)),
+      Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)),
+      Literal.create(Seq(), ArrayType(NullType)),
+      Literal.create(Seq(null), ArrayType(NullType)),
+      Literal.create(Seq(192.toByte), ArrayType(ByteType)),
+      Literal.create(
+        Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))),
+      Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType))
+    )
+
+    checkEvaluation(ArraysZip(Seq(literals(0), literals(1))),
+      List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L)))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), literals(2))),
+      List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), literals(3))),
+      List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), literals(4))),
+      List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), literals(5))),
+      List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), literals(6))),
+      List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), literals(7))),
+      List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))),
+      List(
+        Row(9001, null, -1, "a"),
+        Row(9002, 1L, -3, null),
+        Row(9003, null, 900, "c"),
+        Row(null, 4L, null, null),
+        Row(null, 11L, null, null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))),
+      List(
+        Row(null, 1.1, null, null, 192.toByte),
+        Row(false, null, null, null, null),
+        Row(true, 1.3, null, null, null),
+        Row(null, null, null, null, null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(9), literals(0))),
+      List(
+        Row(List(1, 2, 3), 9001),
+        Row(null, 9002),
+        Row(List(4, 5), 9003),
+        Row(List(1, null, 3), null)))
+
+    checkEvaluation(ArraysZip(Seq(literals(7), literals(10))),
+      List(Row(null, Array[Byte](1.toByte, 5.toByte))))
+
+    val longLiteral =
+      Literal.create((0 to 1000).toSeq, ArrayType(IntegerType))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)),
+      List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++
+      (3 to 1000).map { Row(null, _) }.toList)
+
+    val manyLiterals = (0 to 1000).map { _ =>
+      Literal.create(Seq(1), ArrayType(IntegerType))
+    }.toSeq
+
+    val numbers = List(
+      Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*),
+      Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*),
+      Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*),
+      Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*))
+    checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals),
+      List(numbers(0), numbers(1), numbers(2), numbers(3)))
+
+    checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null)
+    checkEvaluation(ArraysZip(Seq()), List())
+  }
+
   test("Array Min") {
     checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
     checkEvaluation(
index a2aae9a..266a136 100644 (file)
@@ -3508,6 +3508,14 @@ object functions {
    */
   def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
 
+  /**
+   * Returns a merged array of structs in which the N-th struct contains all N-th values of input
+   * arrays.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) }
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Mask functions
   //////////////////////////////////////////////////////////////////////////////////////////////
index 59119bb..959a77a 100644 (file)
@@ -479,6 +479,53 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("dataframe arrays_zip function") {
+    val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2")
+    val df2 = Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3")
+    val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2")
+    val df4 = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2")
+    val df5 = Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4")
+    val df6 = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null)))
+      .toDF("v1", "v2", "v3", "v4")
+    val df7 = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2")
+    val df8 = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2")
+
+    val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6)))
+    checkAnswer(df1.select(arrays_zip($"val1", $"val2")), expectedValue1)
+    checkAnswer(df1.selectExpr("arrays_zip(val1, val2)"), expectedValue1)
+
+    val expectedValue2 = Row(Seq(Row("a", true, 10), Row("b", false, 11)))
+    checkAnswer(df2.select(arrays_zip($"val1", $"val2", $"val3")), expectedValue2)
+    checkAnswer(df2.selectExpr("arrays_zip(val1, val2, val3)"), expectedValue2)
+
+    val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6)))
+    checkAnswer(df3.select(arrays_zip($"val1", $"val2")), expectedValue3)
+    checkAnswer(df3.selectExpr("arrays_zip(val1, val2)"), expectedValue3)
+
+    val expectedValue4 = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null)))
+    checkAnswer(df4.select(arrays_zip($"val1", $"val2")), expectedValue4)
+    checkAnswer(df4.selectExpr("arrays_zip(val1, val2)"), expectedValue4)
+
+    val expectedValue5 = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null)))
+    checkAnswer(df5.select(arrays_zip($"val1", $"val2", $"val3", $"val4")), expectedValue5)
+    checkAnswer(df5.selectExpr("arrays_zip(val1, val2, val3, val4)"), expectedValue5)
+
+    val expectedValue6 = Row(Seq(
+      Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null)))
+    checkAnswer(df6.select(arrays_zip($"v1", $"v2", $"v3", $"v4")), expectedValue6)
+    checkAnswer(df6.selectExpr("arrays_zip(v1, v2, v3, v4)"), expectedValue6)
+
+    val expectedValue7 = Row(Seq(
+      Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2)))
+    checkAnswer(df7.select(arrays_zip($"v1", $"v2")), expectedValue7)
+    checkAnswer(df7.selectExpr("arrays_zip(v1, v2)"), expectedValue7)
+
+    val expectedValue8 = Row(Seq(
+      Row(Array[Byte](1.toByte, 5.toByte), null)))
+    checkAnswer(df8.select(arrays_zip($"v1", $"v2")), expectedValue8)
+    checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8)
+  }
+
   test("map size function") {
     val df = Seq(
       (Map[Int, Int](1 -> 1, 2 -> 2), "x"),