[SPARK-22239][SQL][PYTHON] Enable grouped aggregate pandas UDFs as window functions...
authorLi Jin <ice.xelloss@gmail.com>
Wed, 13 Jun 2018 01:10:52 +0000 (09:10 +0800)
committerhyukjinkwon <gurwls223@apache.org>
Wed, 13 Jun 2018 01:10:52 +0000 (09:10 +0800)
## What changes were proposed in this pull request?
This PR enables using a grouped aggregate pandas UDFs as window functions. The semantics is the same as using SQL aggregation function as window functions.

```
       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
       >>> from pyspark.sql import Window
       >>> df = spark.createDataFrame(
       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
       ...     ("id", "v"))
       >>> pandas_udf("double", PandasUDFType.GROUPED_AGG)
       ... def mean_udf(v):
       ...     return v.mean()
       >>> w = Window.partitionBy('id')
       >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()
       +---+----+------+
       | id|   v|mean_v|
       +---+----+------+
       |  1| 1.0|   1.5|
       |  1| 2.0|   1.5|
       |  2| 3.0|   6.0|
       |  2| 5.0|   6.0|
       |  2|10.0|   6.0|
       +---+----+------+
```

The scope of this PR is somewhat limited in terms of:
(1) Only supports unbounded window, which acts essentially as group by.
(2) Only supports aggregation functions, not "transform" like window functions (n -> n mapping)

Both of these are left as future work. Especially, (1) needs careful thinking w.r.t. how to pass rolling window data to python efficiently. (2) is a bit easier but does require more changes therefore I think it's better to leave it as a separate PR.

## How was this patch tested?

WindowPandasUDFTests

Author: Li Jin <ice.xelloss@gmail.com>

Closes #21082 from icexelloss/SPARK-22239-window-udf.

15 files changed:
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
python/pyspark/rdd.py
python/pyspark/sql/functions.py
python/pyspark/sql/tests.py
python/pyspark/worker.py
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala [new file with mode: 0644]

index 41eac10..ebabedf 100644 (file)
@@ -40,6 +40,7 @@ private[spark] object PythonEvalType {
   val SQL_SCALAR_PANDAS_UDF = 200
   val SQL_GROUPED_MAP_PANDAS_UDF = 201
   val SQL_GROUPED_AGG_PANDAS_UDF = 202
+  val SQL_WINDOW_AGG_PANDAS_UDF = 203
 
   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
@@ -47,6 +48,7 @@ private[spark] object PythonEvalType {
     case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
     case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
     case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
+    case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
   }
 }
 
index 14d9128..7e7e582 100644 (file)
@@ -74,6 +74,7 @@ class PythonEvalType(object):
     SQL_SCALAR_PANDAS_UDF = 200
     SQL_GROUPED_MAP_PANDAS_UDF = 201
     SQL_GROUPED_AGG_PANDAS_UDF = 202
+    SQL_WINDOW_AGG_PANDAS_UDF = 203
 
 
 def portable_hash(x):
index 1cdbb8a..a5e3384 100644 (file)
@@ -2616,10 +2616,12 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        The returned scalar can be either a python primitive type, e.g., `int` or `float`
        or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
 
-       :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
-       output types.
+       :class:`MapType` and :class:`StructType` are currently not supported as output types.
 
-       Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
+       Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
+       :class:`pyspark.sql.Window`
+
+       This example shows using grouped aggregated UDFs with groupby:
 
        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> df = spark.createDataFrame(
@@ -2636,7 +2638,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        |  2|        6.0|
        +---+-----------+
 
-       .. seealso:: :meth:`pyspark.sql.GroupedData.agg`
+       This example shows using grouped aggregated UDFs as window functions. Note that only
+       unbounded window frame is supported at the moment:
+
+       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+       >>> from pyspark.sql import Window
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))
+       >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
+       ... def mean_udf(v):
+       ...     return v.mean()
+       >>> w = Window.partitionBy('id') \\
+       ...           .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+       >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()  # doctest: +SKIP
+       +---+----+------+
+       | id|   v|mean_v|
+       +---+----+------+
+       |  1| 1.0|   1.5|
+       |  1| 2.0|   1.5|
+       |  2| 3.0|   6.0|
+       |  2| 5.0|   6.0|
+       |  2|10.0|   6.0|
+       +---+----+------+
+
+       .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
 
     .. note:: The user-defined functions are considered deterministic by default. Due to
         optimization, duplicate invocations may be eliminated or the function may even be invoked
index 4a3941d..2d7a4f6 100644 (file)
@@ -5454,6 +5454,15 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
             expected1 = df.groupby(df.id).agg(sum(df.v))
             self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
 
+    def test_array_type(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        df = self.data
+
+        array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
+        result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
+        self.assertEquals(result1.first()['v2'], [1.0, 2.0])
+
     def test_invalid_args(self):
         from pyspark.sql.functions import mean
 
@@ -5479,6 +5488,235 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                     'mixture.*aggregate function.*group aggregate pandas UDF'):
                 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
+
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
+class WindowPandasUDFTests(ReusedSQLTestCase):
+    @property
+    def data(self):
+        from pyspark.sql.functions import array, explode, col, lit
+        return self.spark.range(10).toDF('id') \
+            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
+            .withColumn("v", explode(col('vs'))) \
+            .drop('vs') \
+            .withColumn('w', lit(1.0))
+
+    @property
+    def python_plus_one(self):
+        from pyspark.sql.functions import udf
+        return udf(lambda v: v + 1, 'double')
+
+    @property
+    def pandas_scalar_time_two(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+        return pandas_udf(lambda v: v * 2, 'double')
+
+    @property
+    def pandas_agg_mean_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+        def avg(v):
+            return v.mean()
+        return avg
+
+    @property
+    def pandas_agg_max_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+        def max(v):
+            return v.max()
+        return max
+
+    @property
+    def pandas_agg_min_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+        def min(v):
+            return v.min()
+        return min
+
+    @property
+    def unbounded_window(self):
+        return Window.partitionBy('id') \
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+
+    @property
+    def ordered_window(self):
+        return Window.partitionBy('id').orderBy('v')
+
+    @property
+    def unpartitioned_window(self):
+        return Window.partitionBy()
+
+    def test_simple(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max
+
+        df = self.data
+        w = self.unbounded_window
+
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
+        expected1 = df.withColumn('mean_v', mean(df['v']).over(w))
+
+        result2 = df.select(mean_udf(df['v']).over(w))
+        expected2 = df.select(mean(df['v']).over(w))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+    def test_multiple_udfs(self):
+        from pyspark.sql.functions import max, min, mean
+
+        df = self.data
+        w = self.unbounded_window
+
+        result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
+                    .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
+                    .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
+
+        expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
+                      .withColumn('max_v', max(df['v']).over(w)) \
+                      .withColumn('min_w', min(df['w']).over(w))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_replace_existing(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w = self.unbounded_window
+
+        result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
+        expected1 = df.withColumn('v', mean(df['v']).over(w))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_mixed_sql(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w = self.unbounded_window
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
+        expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_mixed_udf(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w = self.unbounded_window
+
+        plus_one = self.python_plus_one
+        time_two = self.pandas_scalar_time_two
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn(
+            'v2',
+            plus_one(mean_udf(plus_one(df['v'])).over(w)))
+        expected1 = df.withColumn(
+            'v2',
+            plus_one(mean(plus_one(df['v'])).over(w)))
+
+        result2 = df.withColumn(
+            'v2',
+            time_two(mean_udf(time_two(df['v'])).over(w)))
+        expected2 = df.withColumn(
+            'v2',
+            time_two(mean(time_two(df['v'])).over(w)))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+    def test_without_partitionBy(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w = self.unpartitioned_window
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('v2', mean_udf(df['v']).over(w))
+        expected1 = df.withColumn('v2', mean(df['v']).over(w))
+
+        result2 = df.select(mean_udf(df['v']).over(w))
+        expected2 = df.select(mean(df['v']).over(w))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+    def test_mixed_sql_and_udf(self):
+        from pyspark.sql.functions import max, min, rank, col
+
+        df = self.data
+        w = self.unbounded_window
+        ow = self.ordered_window
+        max_udf = self.pandas_agg_max_udf
+        min_udf = self.pandas_agg_min_udf
+
+        result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w))
+        expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w))
+
+        # Test mixing sql window function and window udf in the same expression
+        result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w))
+        expected2 = expected1
+
+        # Test chaining sql aggregate function and udf
+        result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
+                    .withColumn('min_v', min(df['v']).over(w)) \
+                    .withColumn('v_diff', col('max_v') - col('min_v')) \
+                    .drop('max_v', 'min_v')
+        expected3 = expected1
+
+        # Test mixing sql window function and udf
+        result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
+                    .withColumn('rank', rank().over(ow))
+        expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
+                      .withColumn('rank', rank().over(ow))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+
+    def test_array_type(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        df = self.data
+        w = self.unbounded_window
+
+        array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
+        result1 = df.withColumn('v2', array_udf(df['v']).over(w))
+        self.assertEquals(result1.first()['v2'], [1.0, 2.0])
+
+    def test_invalid_args(self):
+        from pyspark.sql.functions import mean, pandas_udf, PandasUDFType
+
+        df = self.data
+        w = self.unbounded_window
+        ow = self.ordered_window
+        mean_udf = self.pandas_agg_mean_udf
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    '.*not supported within a window function'):
+                foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
+                df.withColumn('v2', foo_udf(df['v']).over(w))
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    '.*Only unbounded window frame is supported.*'):
+                df.withColumn('mean_v', mean_udf(df['v']).over(ow))
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:
index a30d6bf..38fe2ef 100644 (file)
@@ -128,6 +128,21 @@ def wrap_grouped_agg_pandas_udf(f, return_type):
     return lambda *a: (wrapped(*a), arrow_return_type)
 
 
+def wrap_window_agg_pandas_udf(f, return_type):
+    # This is similar to grouped_agg_pandas_udf, the only difference
+    # is that window_agg_pandas_udf needs to repeat the return value
+    # to match window length, where grouped_agg_pandas_udf just returns
+    # the scalar value.
+    arrow_return_type = to_arrow_type(return_type)
+
+    def wrapped(*series):
+        import pandas as pd
+        result = f(*series)
+        return pd.Series([result]).repeat(len(series[0]))
+
+    return lambda *a: (wrapped(*a), arrow_return_type)
+
+
 def read_single_udf(pickleSer, infile, eval_type):
     num_arg = read_int(infile)
     arg_offsets = [read_int(infile) for i in range(num_arg)]
@@ -151,6 +166,8 @@ def read_single_udf(pickleSer, infile, eval_type):
         return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
+    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
+        return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
         return arg_offsets, wrap_udf(func, return_type)
     else:
@@ -195,7 +212,8 @@ def read_udfs(pickleSer, infile, eval_type):
 
     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
+                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
         timezone = utf8_deserializer.loads(infile)
         ser = ArrowStreamPandasSerializer(timezone)
     else:
index f9947d1..6e3107f 100644 (file)
@@ -1739,9 +1739,10 @@ class Analyzer(
    * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
    *    it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
    *    all regular expressions.
-   * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
-   * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
-   *    it into the plan tree.
+   * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s
+   *    and [[WindowFunctionType]]s.
+   * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a
+   *    [[Window]] operator and inserts it into the plan tree.
    */
   object ExtractWindowExpressions extends Rule[LogicalPlan] {
     private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
@@ -1901,7 +1902,7 @@ class Analyzer(
             s"Please file a bug report with this error message, stack trace, and the query.")
         } else {
           val spec = distinctWindowSpec.head
-          (spec.partitionSpec, spec.orderSpec)
+          (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr))
         }
       }.toSeq
 
@@ -1909,7 +1910,7 @@ class Analyzer(
       // setting this to the child of the next Window operator.
       val windowOps =
         groupedWindowExpressions.foldLeft(child) {
-          case (last, ((partitionSpec, orderSpec), windowExpressions)) =>
+          case (last, ((partitionSpec, orderSpec, _), windowExpressions)) =>
             Window(windowExpressions, partitionSpec, orderSpec, last)
         }
 
index 90bda2a..af256b9 100644 (file)
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper {
             failAnalysis("An offset window function can only be evaluated in an ordered " +
               s"row-based window frame with a single offset: $w")
 
+          case _ @ WindowExpression(_: PythonUDF,
+            WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
+              if !frame.isUnbounded =>
+            failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")
+
           case w @ WindowExpression(e, s) =>
             // Only allow window functions with an aggregate expression or an offset window
-            // function.
+            // function or a Pandas window UDF.
             e match {
               case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction =>
                 w
+              case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) =>
+                w
               case _ =>
                 failAnalysis(s"Expression '$e' not supported within a window function.")
             }
@@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper {
 
           case Aggregate(groupingExprs, aggregateExprs, child) =>
             def isAggregateExpression(expr: Expression) = {
-              expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr)
+              expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
             }
 
             def checkValidAggregateExpression(expr: Expression): Unit = expr match {
index efd664d..6530b17 100644 (file)
@@ -34,10 +34,14 @@ object PythonUDF {
     e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType)
   }
 
-  def isGroupAggPandasUDF(e: Expression): Boolean = {
+  def isGroupedAggPandasUDF(e: Expression): Boolean = {
     e.isInstanceOf[PythonUDF] &&
       e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
   }
+
+  // This is currently same as GroupedAggPandasUDF, but we might support new types in the future,
+  // e.g, N -> N transform.
+  def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e)
 }
 
 /**
index 9fe2fb2..f957aaa 100644 (file)
@@ -21,7 +21,7 @@ import java.util.Locale
 
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp}
 import org.apache.spark.sql.types._
 
 /**
@@ -298,6 +298,37 @@ trait WindowFunction extends Expression {
 }
 
 /**
+ * Case objects that describe whether a window function is a SQL window function or a Python
+ * user-defined window function.
+ */
+sealed trait WindowFunctionType
+
+object WindowFunctionType {
+  case object SQL extends WindowFunctionType
+  case object Python extends WindowFunctionType
+
+  def functionType(windowExpression: NamedExpression): WindowFunctionType = {
+    val t = windowExpression.collectFirst {
+      case _: WindowFunction | _: AggregateFunction => SQL
+      case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python
+    }
+
+    // Normally a window expression would either have a SQL window function, a SQL
+    // aggregate function or a python window UDF. However, sometimes the optimizer will replace
+    // the window function if the value of the window function can be predetermined.
+    // For example, for query:
+    //
+    // select count(NULL) over () from values 1.0, 2.0, 3.0 T(a)
+    //
+    // The window function will be replaced by expression literal(0)
+    // To handle this case, if a window expression doesn't have a regular window function, we
+    // consider its type to be SQL as literal(0) is also a SQL expression.
+    t.getOrElse(SQL)
+  }
+}
+
+
+/**
  * An offset window function is a window function that returns the value of the input column offset
  * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with
  * offset -2, will get the value of x 2 rows back in the partition.
index bfa6111..aa992de 100644 (file)
@@ -621,12 +621,15 @@ object CollapseRepartition extends Rule[LogicalPlan] {
 /**
  * Collapse Adjacent Window Expression.
  * - If the partition specs and order specs are the same and the window expression are
- *   independent, collapse into the parent.
+ *   independent and are of the same window function type, collapse into the parent.
  */
 object CollapseWindow extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
-        if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty =>
+        if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty &&
+          // This assumes Window contains the same type of window expressions. This is ensured
+          // by ExtractWindowFunctions.
+          WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) =>
       w1.copy(windowExpressions = we2 ++ we1, child = grandChild)
   }
 }
index 626f905..84be677 100644 (file)
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.planning
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
@@ -215,7 +216,7 @@ object PhysicalAggregation {
           case agg: AggregateExpression
             if !equivalentAggregateExpressions.addExpr(agg) => agg
           case udf: PythonUDF
-            if PythonUDF.isGroupAggPandasUDF(udf) &&
+            if PythonUDF.isGroupedAggPandasUDF(udf) &&
               !equivalentAggregateExpressions.addExpr(udf) => udf
         }
       }
@@ -245,7 +246,7 @@ object PhysicalAggregation {
             equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
               .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
             // Similar to AggregateExpression
-          case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) =>
+          case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
             equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
               .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
           case expression =>
@@ -268,3 +269,40 @@ object PhysicalAggregation {
     case _ => None
   }
 }
+
+/**
+ * An extractor used when planning physical execution of a window. This extractor outputs
+ * the window function type of the logical window.
+ *
+ * The input logical window must contain same type of window functions, which is ensured by
+ * the rule ExtractWindowExpressions in the analyzer.
+ */
+object PhysicalWindow {
+  // windowFunctionType, windowExpression, partitionSpec, orderSpec, child
+  private type ReturnType =
+    (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)
+
+  def unapply(a: Any): Option[ReturnType] = a match {
+    case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>
+
+      // The window expression should not be empty here, otherwise it's a bug.
+      if (windowExpressions.isEmpty) {
+        throw new AnalysisException(s"Window expression is empty in $expr")
+      }
+
+      val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType)
+        .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) =>
+          if (t1 != t2) {
+            // We shouldn't have different window function type here, otherwise it's a bug.
+            throw new AnalysisException(
+              s"Found different window function type in $windowExpressions")
+          } else {
+            t1
+          }
+        }
+
+      Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child))
+
+    case _ => None
+  }
+}
index 7404887..75f5ec0 100644 (file)
@@ -41,6 +41,7 @@ class SparkPlanner(
       DataSourceStrategy(conf) ::
       SpecialLimits ::
       Aggregation ::
+      Window ::
       JoinSelection ::
       InMemoryScans ::
       BasicOperators :: Nil)
index be34387..d6951ad 100644 (file)
@@ -327,7 +327,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case PhysicalAggregation(
         namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
 
-        if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) {
+        if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) {
           throw new AnalysisException(
             "Streaming aggregation doesn't support group aggregate pandas UDF")
         }
@@ -428,6 +428,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  object Window extends Strategy {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case PhysicalWindow(
+        WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) =>
+        execution.window.WindowExec(
+          windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
+
+      case PhysicalWindow(
+        WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) =>
+        execution.python.WindowInPandasExec(
+          windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
+
+      case _ => Nil
+    }
+  }
+
   protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
 
   object InMemoryScans extends Strategy {
@@ -548,8 +564,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
       case e @ logical.Expand(_, _, child) =>
         execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
-        execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
         execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data, _) =>
index 9d56f48..1e09610 100644 (file)
@@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
    */
   private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
     e.isInstanceOf[AggregateExpression] ||
-      PythonUDF.isGroupAggPandasUDF(e) ||
+      PythonUDF.isGroupedAggPandasUDF(e) ||
       agg.groupingExpressions.exists(_.semanticEquals(e))
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
new file mode 100644 (file)
index 0000000..c76832a
--- /dev/null
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+case class WindowInPandasExec(
+    windowExpression: Seq[NamedExpression],
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    child: SparkPlan) extends UnaryExecNode {
+
+  override def output: Seq[Attribute] =
+    child.output ++ windowExpression.map(_.toAttribute)
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (partitionSpec.isEmpty) {
+      // Only show warning when the number of bytes is larger than 100 MB?
+      logWarning("No Partition Defined for Window operation! Moving all data to a single "
+        + "partition, this can cause serious performance degradation.")
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(partitionSpec) :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be evaluated directly.
+        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+    }
+  }
+
+  /**
+   * Create the resulting projection.
+   *
+   * This method uses Code Generation. It can only be used on the executor side.
+   *
+   * @param expressions unbound ordered function expressions.
+   * @return the final resulting projection.
+   */
+  private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
+    val references = expressions.zipWithIndex.map { case (e, i) =>
+      // Results of window expressions will be on the right side of child's output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
+    }
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
+    UnsafeProjection.create(
+      child.output ++ patchedWindowExpression,
+      child.output)
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute()
+
+    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+    val sessionLocalTimeZone = conf.sessionLocalTimeZone
+    val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
+
+    // Extract window expressions and window functions
+    val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e })
+
+    val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF])
+
+    val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
+
+    // Filter child output attributes down to only those that are UDF inputs.
+    // Also eliminate duplicate UDF inputs.
+    val allInputs = new ArrayBuffer[Expression]
+    val dataTypes = new ArrayBuffer[DataType]
+    val argOffsets = inputs.map { input =>
+      input.map { e =>
+        if (allInputs.exists(_.semanticEquals(e))) {
+          allInputs.indexWhere(_.semanticEquals(e))
+        } else {
+          allInputs += e
+          dataTypes += e.dataType
+          allInputs.length - 1
+        }
+      }.toArray
+    }.toArray
+
+    // Schema of input rows to the python runner
+    val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
+      StructField(s"_$i", dt)
+    })
+
+    inputRDD.mapPartitionsInternal { iter =>
+      val context = TaskContext.get()
+
+      val grouped = if (partitionSpec.isEmpty) {
+        // Use an empty unsafe row as a place holder for the grouping key
+        Iterator((new UnsafeRow(), iter))
+      } else {
+        GroupedIterator(iter, partitionSpec, child.output)
+      }
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = HybridRowQueue(context.taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+      context.addTaskCompletionListener { _ =>
+        queue.close()
+      }
+
+      val inputProj = UnsafeProjection.create(allInputs, child.output)
+      val pythonInput = grouped.map { case (_, rows) =>
+        rows.map { row =>
+          queue.add(row.asInstanceOf[UnsafeRow])
+          inputProj(row)
+        }
+      }
+
+      val windowFunctionResult = new ArrowPythonRunner(
+        pyFuncs, bufferSize, reuseWorker,
+        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
+        argOffsets, windowInputSchema,
+        sessionLocalTimeZone, pandasRespectSessionTimeZone)
+        .compute(pythonInput, context.partitionId(), context)
+
+      val joined = new JoinedRow
+      val resultProj = createResultProjection(expressions)
+
+      windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
+        val leftRow = queue.remove()
+        val joinedRow = joined(leftRow, windowOutput)
+        resultProj(joinedRow)
+      }
+    }
+  }
+}