[SPARK-23914][SQL] Add array_union function
authorKazuaki Ishizaki <ishizaki@jp.ibm.com>
Thu, 12 Jul 2018 08:42:29 +0000 (17:42 +0900)
committerTakuya UESHIN <ueshin@databricks.com>
Thu, 12 Jul 2018 08:42:29 +0000 (17:42 +0900)
## What changes were proposed in this pull request?

The PR adds the SQL function `array_union`. The behavior of the function is based on Presto's one.

This function returns returns an array of the elements in the union of array1 and array2.

Note: The order of elements in the result is not defined.

## How was this patch tested?

Added UTs

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #21061 from kiszk/SPARK-23914.

python/pyspark/sql/functions.py
sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
sql/core/src/main/scala/org/apache/spark/sql/functions.scala
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

index 9f61e29..5ef7398 100644 (file)
@@ -2033,6 +2033,25 @@ def array_distinct(col):
     return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def array_union(col1, col2):
+    """
+    Collection function: returns an array of the elements in the union of col1 and col2,
+    without duplicates.
+
+    :param col1: name of column containing array
+    :param col2: name of column containing array
+
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+    >>> df.select(array_union(df.c1, df.c2)).collect()
+    [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2)))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.
index 4dd2b73..cf2a5ed 100644 (file)
@@ -450,7 +450,7 @@ public final class UnsafeArrayData extends ArrayData {
     return values;
   }
 
-  private static UnsafeArrayData fromPrimitiveArray(
+  public static UnsafeArrayData fromPrimitiveArray(
        Object arr, int offset, int length, int elementSize) {
     final long headerInBytes = calculateHeaderPortionInBytes(length);
     final long valueRegionInBytes = (long)elementSize * length;
@@ -463,14 +463,27 @@ public final class UnsafeArrayData extends ArrayData {
     final long[] data = new long[(int)totalSizeInLongs];
 
     Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
-    Platform.copyMemory(arr, offset, data,
-      Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
+    if (arr != null) {
+      Platform.copyMemory(arr, offset, data,
+        Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
+    }
 
     UnsafeArrayData result = new UnsafeArrayData();
     result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
     return result;
   }
 
+  public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) {
+    return fromPrimitiveArray(null, offset, length, elementSize);
+  }
+
+  public static boolean shouldUseGenericArrayData(int elementSize, int length) {
+    final long headerInBytes = calculateHeaderPortionInBytes(length);
+    final long valueRegionInBytes = (long)elementSize * length;
+    final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
+    return totalSizeInLongs > Integer.MAX_VALUE / 8;
+  }
+
   public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) {
     return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1);
   }
index e7517e8..1d9e470 100644 (file)
@@ -414,6 +414,7 @@ object FunctionRegistry {
     expression[ArrayJoin]("array_join"),
     expression[ArrayPosition]("array_position"),
     expression[ArraySort]("array_sort"),
+    expression[ArrayUnion]("array_union"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[ElementAt]("element_at"),
index b8f2aa3..0f4f4f1 100644 (file)
@@ -3486,3 +3486,322 @@ case class ArrayDistinct(child: Expression)
 
   override def prettyName: String = "array_distinct"
 }
+
+/**
+ * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
+ */
+abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
+  override def dataType: DataType = {
+    val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType])
+    ArrayType(elementType, dataTypes.exists(_.containsNull))
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val typeCheckResult = super.checkInputDataTypes()
+    if (typeCheckResult.isSuccess) {
+      TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType,
+        s"function $prettyName")
+    } else {
+      typeCheckResult
+    }
+  }
+
+  @transient protected lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(elementType)
+
+  @transient protected lazy val elementTypeSupportEquals = elementType match {
+    case BinaryType => false
+    case _: AtomicType => true
+    case _ => false
+  }
+}
+
+object ArraySetLike {
+  def throwUnionLengthOverflowException(length: Int): Unit = {
+    throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
+      s"elements due to exceeding the array size limit " +
+      s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+  }
+}
+
+
+/**
+ * Returns an array of the elements in the union of x and y, without duplicates
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
+      without duplicates.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
+       array(1, 2, 3, 5)
+  """,
+  since = "2.4.0")
+case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike {
+  var hsInt: OpenHashSet[Int] = _
+  var hsLong: OpenHashSet[Long] = _
+
+  def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
+    val elem = array.getInt(idx)
+    if (!hsInt.contains(elem)) {
+      if (resultArray != null) {
+        resultArray.setInt(pos, elem)
+      }
+      hsInt.add(elem)
+      true
+    } else {
+      false
+    }
+  }
+
+  def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
+    val elem = array.getLong(idx)
+    if (!hsLong.contains(elem)) {
+      if (resultArray != null) {
+        resultArray.setLong(pos, elem)
+      }
+      hsLong.add(elem)
+      true
+    } else {
+      false
+    }
+  }
+
+  def evalIntLongPrimitiveType(
+      array1: ArrayData,
+      array2: ArrayData,
+      resultArray: ArrayData,
+      isLongType: Boolean): Int = {
+    // store elements into resultArray
+    var nullElementSize = 0
+    var pos = 0
+    Seq(array1, array2).foreach { array =>
+      var i = 0
+      while (i < array.numElements()) {
+        val size = if (!isLongType) hsInt.size else hsLong.size
+        if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          ArraySetLike.throwUnionLengthOverflowException(size)
+        }
+        if (array.isNullAt(i)) {
+          if (nullElementSize == 0) {
+            if (resultArray != null) {
+              resultArray.setNullAt(pos)
+            }
+            pos += 1
+            nullElementSize = 1
+          }
+        } else {
+          val assigned = if (!isLongType) {
+            assignInt(array, i, resultArray, pos)
+          } else {
+            assignLong(array, i, resultArray, pos)
+          }
+          if (assigned) {
+            pos += 1
+          }
+        }
+        i += 1
+      }
+    }
+    pos
+  }
+
+  override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val array1 = input1.asInstanceOf[ArrayData]
+    val array2 = input2.asInstanceOf[ArrayData]
+
+    if (elementTypeSupportEquals) {
+      elementType match {
+        case IntegerType =>
+          // avoid boxing of primitive int array elements
+          // calculate result array size
+          hsInt = new OpenHashSet[Int]
+          val elements = evalIntLongPrimitiveType(array1, array2, null, false)
+          hsInt = new OpenHashSet[Int]
+          val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
+            IntegerType.defaultSize, elements)) {
+            new GenericArrayData(new Array[Any](elements))
+          } else {
+            UnsafeArrayData.forPrimitiveArray(
+              Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
+          }
+          evalIntLongPrimitiveType(array1, array2, resultArray, false)
+          resultArray
+        case LongType =>
+          // avoid boxing of primitive long array elements
+          // calculate result array size
+          hsLong = new OpenHashSet[Long]
+          val elements = evalIntLongPrimitiveType(array1, array2, null, true)
+          hsLong = new OpenHashSet[Long]
+          val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
+            LongType.defaultSize, elements)) {
+            new GenericArrayData(new Array[Any](elements))
+          } else {
+            UnsafeArrayData.forPrimitiveArray(
+              Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
+          }
+          evalIntLongPrimitiveType(array1, array2, resultArray, true)
+          resultArray
+        case _ =>
+          val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+          val hs = new OpenHashSet[Any]
+          var foundNullElement = false
+          Seq(array1, array2).foreach { array =>
+            var i = 0
+            while (i < array.numElements()) {
+              if (array.isNullAt(i)) {
+                if (!foundNullElement) {
+                  arrayBuffer += null
+                  foundNullElement = true
+                }
+              } else {
+                val elem = array.get(i, elementType)
+                if (!hs.contains(elem)) {
+                  if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+                    ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
+                  }
+                  arrayBuffer += elem
+                  hs.add(elem)
+                }
+              }
+              i += 1
+            }
+          }
+          new GenericArrayData(arrayBuffer)
+      }
+    } else {
+      ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val i = ctx.freshName("i")
+    val pos = ctx.freshName("pos")
+    val value = ctx.freshName("value")
+    val size = ctx.freshName("size")
+    val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
+      if (elementTypeSupportEquals) {
+        elementType match {
+          case ByteType | ShortType | IntegerType | LongType =>
+            val ptName = CodeGenerator.primitiveTypeName(elementType)
+            val unsafeArray = ctx.freshName("unsafeArray")
+            (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
+              if (elementType == LongType) "Long" else "Int",
+              s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
+              if (elementType == LongType) "(long)" else "(int)",
+              s"""
+                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
+                 |${ev.value} = $unsafeArray;
+               """.stripMargin)
+          case _ =>
+            val genericArrayData = classOf[GenericArrayData].getName
+            val et = ctx.addReferenceObj("elementType", elementType)
+            ("", "Object",
+              s"get($i, $et)", s"update($pos, $value)", "Object", "",
+              s"${ev.value} = new $genericArrayData(new Object[$size]);")
+        }
+      } else {
+        ("", "", "", "", "", "", "")
+      }
+
+    nullSafeCodeGen(ctx, ev, (array1, array2) => {
+      if (openHashElementType != "") {
+        // Here, we ensure elementTypeSupportEquals is true
+        val foundNullElement = ctx.freshName("foundNullElement")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
+        val hs = ctx.freshName("hs")
+        val arrayData = classOf[ArrayData].getName
+        val arrays = ctx.freshName("arrays")
+        val array = ctx.freshName("array")
+        val arrayDataIdx = ctx.freshName("arrayDataIdx")
+        s"""
+           |$openHashSet $hs = new $openHashSet$postFix($classTag);
+           |boolean $foundNullElement = false;
+           |$arrayData[] $arrays = new $arrayData[]{$array1, $array2};
+           |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
+           |  $arrayData $array = $arrays[$arrayDataIdx];
+           |  for (int $i = 0; $i < $array.numElements(); $i++) {
+           |    if ($array.isNullAt($i)) {
+           |      $foundNullElement = true;
+           |    } else {
+           |      $hs.add$postFix($array.$getter);
+           |    }
+           |  }
+           |}
+           |int $size = $hs.size() + ($foundNullElement ? 1 : 0);
+           |$arrayBuilder
+           |$hs = new $openHashSet$postFix($classTag);
+           |$foundNullElement = false;
+           |int $pos = 0;
+           |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
+           |  $arrayData $array = $arrays[$arrayDataIdx];
+           |  for (int $i = 0; $i < $array.numElements(); $i++) {
+           |    if ($array.isNullAt($i)) {
+           |      if (!$foundNullElement) {
+           |        ${ev.value}.setNullAt($pos++);
+           |        $foundNullElement = true;
+           |      }
+           |    } else {
+           |      $javaTypeName $value = $array.$getter;
+           |      if (!$hs.contains($castOp $value)) {
+           |        $hs.add$postFix($value);
+           |        ${ev.value}.$setter;
+           |        $pos++;
+           |      }
+           |    }
+           |  }
+           |}
+         """.stripMargin
+      } else {
+        val arrayUnion = classOf[ArrayUnion].getName
+        val et = ctx.addReferenceObj("elementTypeUnion", elementType)
+        val order = ctx.addReferenceObj("orderingUnion", ordering)
+        val method = "unionOrdering"
+        s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);"
+      }
+    })
+  }
+
+  override def prettyName: String = "array_union"
+}
+
+object ArrayUnion {
+  def unionOrdering(
+      array1: ArrayData,
+      array2: ArrayData,
+      elementType: DataType,
+      ordering: Ordering[Any]): ArrayData = {
+    val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+    var alreadyIncludeNull = false
+    Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
+      var found = false
+      if (elem == null) {
+        if (alreadyIncludeNull) {
+          found = true
+        } else {
+          alreadyIncludeNull = true
+        }
+      } else {
+        // check elem is already stored in arrayBuffer or not?
+        var j = 0
+        while (!found && j < arrayBuffer.size) {
+          val va = arrayBuffer(j)
+          if (va != null && ordering.equiv(va, elem)) {
+            found = true
+          }
+          j = j + 1
+        }
+      }
+      if (!found) {
+        if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
+        }
+        arrayBuffer += elem
+      }
+    }))
+    new GenericArrayData(arrayBuffer)
+  }
+}
index a838a2e..85d6a1b 100644 (file)
@@ -1304,4 +1304,85 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
     checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
   }
+
+  test("Array Union") {
+    val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
+    val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false))
+    val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true))
+    val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false))
+    val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
+    val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false))
+    val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false))
+    val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false))
+    val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false))
+
+    val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false))
+    val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false))
+    val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true))
+    val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false))
+    val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false))
+
+    val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false))
+    val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false))
+    val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true))
+
+    val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType))
+    val a31 = Literal.create(null, ArrayType(StringType))
+
+    checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4))
+    checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3))
+    checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5))
+    checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5))
+    checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4))
+    checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4))
+
+    checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L))
+    checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))
+    checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L))
+    checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L))
+
+    checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f"))
+    checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g"))
+
+    checkEvaluation(ArrayUnion(a30, a30), Seq(null))
+    checkEvaluation(ArrayUnion(a20, a31), null)
+    checkEvaluation(ArrayUnion(a31, a20), null)
+
+    val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
+      ArrayType(BinaryType))
+    val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
+      ArrayType(BinaryType))
+    val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)),
+      ArrayType(BinaryType))
+    val b3 = Literal.create(Seq[Array[Byte]](
+      Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType))
+    val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType))
+    val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType))
+    val b6 = Literal.create(Seq.empty, ArrayType(BinaryType))
+    val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType))
+
+    checkEvaluation(ArrayUnion(b0, b1),
+      Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3)))
+    checkEvaluation(ArrayUnion(b0, b2),
+      Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3)))
+    checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null))
+    checkEvaluation(ArrayUnion(b3, b0),
+      Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6)))
+    checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6)))
+    checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null))
+    checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null))
+    checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null))
+
+    val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+      ArrayType(ArrayType(IntegerType)))
+    val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+      ArrayType(ArrayType(IntegerType)))
+    checkEvaluation(ArrayUnion(aa0, aa1),
+      Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1)))
+
+    assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false)
+    assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true)
+    assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
+    assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
+  }
 }
index 6b956dd..b98ab11 100644 (file)
@@ -3204,6 +3204,7 @@ object functions {
 
   /**
    * Remove all elements that equal to element from the given array.
+   *
    * @group collection_funcs
    * @since 2.4.0
    */
@@ -3219,6 +3220,16 @@ object functions {
   def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) }
 
   /**
+   * Returns an array of the elements in the union of the given two arrays, without duplicates.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_union(col1: Column, col2: Column): Column = withExpr {
+    ArrayUnion(col1.expr, col2.expr)
+  }
+
+  /**
    * Creates a new row for each element in the given array or map column.
    *
    * @group collection_funcs
index d60ed7a..d461571 100644 (file)
@@ -1198,6 +1198,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       "argument 1 requires (array or map) type, however, '`_1`' is of string type"))
   }
 
+  test("array_union functions") {
+    val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
+    val ans1 = Row(Seq(1, 2, 3, 4))
+    checkAnswer(df1.select(array_union($"a", $"b")), ans1)
+    checkAnswer(df1.selectExpr("array_union(a, b)"), ans1)
+
+    val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b")
+    val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1))
+    checkAnswer(df2.select(array_union($"a", $"b")), ans2)
+    checkAnswer(df2.selectExpr("array_union(a, b)"), ans2)
+
+    val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b")
+    val ans3 = Row(Seq(1L, 2L, 3L, 4L))
+    checkAnswer(df3.select(array_union($"a", $"b")), ans3)
+    checkAnswer(df3.selectExpr("array_union(a, b)"), ans3)
+
+    val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L)))
+      .toDF("a", "b")
+    val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))
+    checkAnswer(df4.select(array_union($"a", $"b")), ans4)
+    checkAnswer(df4.selectExpr("array_union(a, b)"), ans4)
+
+    val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b")
+    val ans5 = Row(Seq("b", "a", "c", null, "g"))
+    checkAnswer(df5.select(array_union($"a", $"b")), ans5)
+    checkAnswer(df5.selectExpr("array_union(a, b)"), ans5)
+
+    val df6 = Seq((null, Array("a"))).toDF("a", "b")
+    intercept[AnalysisException] {
+      df6.select(array_union($"a", $"b"))
+    }
+    intercept[AnalysisException] {
+      df6.selectExpr("array_union(a, b)")
+    }
+
+    val df7 = Seq((null, null)).toDF("a", "b")
+    intercept[AnalysisException] {
+      df7.select(array_union($"a", $"b"))
+    }
+    intercept[AnalysisException] {
+      df7.selectExpr("array_union(a, b)")
+    }
+
+    val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b")
+    intercept[AnalysisException] {
+      df8.select(array_union($"a", $"b"))
+    }
+    intercept[AnalysisException] {
+      df8.selectExpr("array_union(a, b)")
+    }
+  }
+
   test("concat function - arrays") {
     val nseqi : Seq[Int] = null
     val nseqs : Seq[String] = null