[SPARK-24790][SQL] Allow complex aggregate expressions in Pivot
authormaryannxue <maryannxue@apache.org>
Thu, 12 Jul 2018 23:54:03 +0000 (16:54 -0700)
committerXiao Li <gatorsmile@gmail.com>
Thu, 12 Jul 2018 23:54:03 +0000 (16:54 -0700)
## What changes were proposed in this pull request?

Relax the check to allow complex aggregate expressions, like `ceil(sum(col1))` or `sum(col1) + 1`, which roughly means any aggregate expression that could appear in an Aggregate plan except pandas UDF (due to the fact that it is not supported in pivot yet).

## How was this patch tested?

Added 2 tests in pivot.sql

Author: maryannxue <maryannxue@apache.org>

Closes #21753 from maryannxue/pivot-relax-syntax.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
sql/core/src/test/resources/sql-tests/inputs/pivot.sql
sql/core/src/test/resources/sql-tests/results/pivot.sql.out

index c078efd..9749893 100644 (file)
@@ -509,12 +509,7 @@ class Analyzer(
         || !p.pivotColumn.resolved => p
       case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
         // Check all aggregate expressions.
-        aggregates.foreach { e =>
-          if (!isAggregateExpression(e)) {
-              throw new AnalysisException(
-                s"Aggregate expression required for pivot, found '$e'")
-          }
-        }
+        aggregates.foreach(checkValidAggregateExpression)
         // Group-by expressions coming from SQL are implicit and need to be deduced.
         val groupByExprs = groupByExprsOpt.getOrElse(
           (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq)
@@ -586,12 +581,17 @@ class Analyzer(
         }
     }
 
-    private def isAggregateExpression(expr: Expression): Boolean = {
-      expr match {
-        case Alias(e, _) => isAggregateExpression(e)
-        case AggregateExpression(_, _, _, _) => true
-        case _ => false
-      }
+    // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF.
+    // TODO: Support Pandas UDF.
+    private def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+      case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis.
+      case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) =>
+        failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.")
+      case e: Attribute =>
+        failAnalysis(
+          s"Aggregate expression required for pivot, but '${e.sql}' " +
+          s"did not appear in any aggregate function.")
+      case e => e.children.foreach(checkValidAggregateExpression)
     }
   }
 
index 01dea6c..b3d53ad 100644 (file)
@@ -111,3 +111,21 @@ PIVOT (
   sum(earnings)
   FOR year IN (2012, 2013)
 );
+
+-- pivot with complex aggregate expressions
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  ceil(sum(earnings)), avg(earnings) + 1 as a1
+  FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot with invalid arguments in aggregate expressions
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  sum(avg(earnings))
+  FOR course IN ('dotNET', 'Java')
+);
index 85e3488..922d8b9 100644 (file)
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 13
+-- Number of queries: 15
 
 
 -- !query 0
@@ -176,7 +176,7 @@ PIVOT (
 struct<>
 -- !query 11 output
 org.apache.spark.sql.AnalysisException
-Aggregate expression required for pivot, found 'abs(earnings#x)';
+Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.;
 
 
 -- !query 12
@@ -192,3 +192,33 @@ struct<>
 -- !query 12 output
 org.apache.spark.sql.AnalysisException
 cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0
+
+
+-- !query 13
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  ceil(sum(earnings)), avg(earnings) + 1 as a1
+  FOR course IN ('dotNET', 'Java')
+)
+-- !query 13 schema
+struct<year:int,dotNET_CEIL(sum(CAST(earnings AS BIGINT))):bigint,dotNET_a1:double,Java_CEIL(sum(CAST(earnings AS BIGINT))):bigint,Java_a1:double>
+-- !query 13 output
+2012   15000   7501.0  20000   20001.0
+2013   48000   48001.0 30000   30001.0
+
+
+-- !query 14
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  sum(avg(earnings))
+  FOR course IN ('dotNET', 'Java')
+)
+-- !query 14 schema
+struct<>
+-- !query 14 output
+org.apache.spark.sql.AnalysisException
+It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;