[SPARK-24781][SQL] Using a reference from Dataset in Filter/Sort might not work
authorLiang-Chi Hsieh <viirya@gmail.com>
Fri, 13 Jul 2018 15:25:00 +0000 (08:25 -0700)
committerXiao Li <gatorsmile@gmail.com>
Fri, 13 Jul 2018 15:25:00 +0000 (08:25 -0700)
## What changes were proposed in this pull request?

When we use a reference from Dataset in filter or sort, which was not used in the prior select, an AnalysisException occurs, e.g.,

```scala
val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
df.select(df("name")).filter(df("id") === 0).show()
```

```scala
org.apache.spark.sql.AnalysisException: Resolved attribute(s) id#6 missing from name#5 in operator !Filter (id#6 = 0).;;
!Filter (id#6 = 0)
   +- AnalysisBarrier
      +- Project [name#5]
         +- Project [_1#2 AS name#5, _2#3 AS id#6]
            +- LocalRelation [_1#2, _2#3]
```
This change updates the rule `ResolveMissingReferences` so `Filter` and `Sort` with non-empty `missingInputs` will also be transformed.

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #21745 from viirya/SPARK-24781.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

index 960ee27..36f14cc 100644 (file)
@@ -1132,7 +1132,8 @@ class Analyzer(
       case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
       case sa @ Sort(_, _, child: Aggregate) => sa
 
-      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+      case s @ Sort(order, _, child)
+          if (!s.resolved || s.missingInput.nonEmpty) && child.resolved =>
         val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
         val ordering = newOrder.map(_.asInstanceOf[SortOrder])
         if (child.output == newChild.output) {
@@ -1143,7 +1144,7 @@ class Analyzer(
           Project(child.output, newSort)
         }
 
-      case f @ Filter(cond, child) if !f.resolved && child.resolved =>
+      case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
         val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
         if (child.output == newChild.output) {
           f.copy(condition = newCond.head)
@@ -1154,10 +1155,17 @@ class Analyzer(
         }
     }
 
+    /**
+     * This method tries to resolve expressions and find missing attributes recursively. Specially,
+     * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved
+     * attributes which are missed from child output. This method tries to find the missing
+     * attributes out and add into the projection.
+     */
     private def resolveExprsAndAddMissingAttrs(
         exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
-      if (exprs.forall(_.resolved)) {
-        // All given expressions are resolved, no need to continue anymore.
+      // Missing attributes can be unresolved attributes or resolved attributes which are not in
+      // the output attributes of the plan.
+      if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) {
         (exprs, plan)
       } else {
         plan match {
@@ -1168,15 +1176,19 @@ class Analyzer(
             (newExprs, AnalysisBarrier(newChild))
 
           case p: Project =>
+            // Resolving expressions against current plan.
             val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
+            // Recursively resolving expressions on the child of current plan.
             val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
-            val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+            // If some attributes used by expressions are resolvable only on the rewritten child
+            // plan, we need to add them into original projection.
+            val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet)
             (newExprs, Project(p.projectList ++ missingAttrs, newChild))
 
           case a @ Aggregate(groupExprs, aggExprs, child) =>
             val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
             val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
-            val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+            val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
             if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
               // All the missing attributes are grouping expressions, valid case.
               (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
@@ -1526,7 +1538,11 @@ class Analyzer(
 
         // Try resolving the ordering as though it is in the aggregate clause.
         try {
-          val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
+          // If a sort order is unresolved, containing references not in aggregate, or containing
+          // `AggregateExpression`, we need to push down it to the underlying aggregate operator.
+          val unresolvedSortOrders = sortOrder.filter { s =>
+            !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
+          }
           val aliasedOrdering =
             unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
           val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
index 9d7645d..5babdf6 100644 (file)
@@ -2387,4 +2387,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
     checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
   }
+
+  test("SPARK-24781: Using a reference from Dataset in Filter/Sort") {
+    val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
+    val filter1 = df.select(df("name")).filter(df("id") === 0)
+    val filter2 = df.select(col("name")).filter(col("id") === 0)
+    checkAnswer(filter1, filter2.collect())
+
+    val sort1 = df.select(df("name")).orderBy(df("id"))
+    val sort2 = df.select(col("name")).orderBy(col("id"))
+    checkAnswer(sort1, sort2.collect())
+  }
+
+  test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") {
+     withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
+      val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
+
+      val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
+      val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
+      checkAnswer(aggPlusSort1, aggPlusSort2.collect())
+
+      val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0)
+      val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0)
+      checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
+    }
+  }
 }