case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
case sa @ Sort(_, _, child: Aggregate) => sa
- case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+ case s @ Sort(order, _, child)
+ if (!s.resolved || s.missingInput.nonEmpty) && child.resolved =>
val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
val ordering = newOrder.map(_.asInstanceOf[SortOrder])
if (child.output == newChild.output) {
Project(child.output, newSort)
}
- case f @ Filter(cond, child) if !f.resolved && child.resolved =>
+ case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
if (child.output == newChild.output) {
f.copy(condition = newCond.head)
}
}
+ /**
+ * This method tries to resolve expressions and find missing attributes recursively. Specially,
+ * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved
+ * attributes which are missed from child output. This method tries to find the missing
+ * attributes out and add into the projection.
+ */
private def resolveExprsAndAddMissingAttrs(
exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
- if (exprs.forall(_.resolved)) {
- // All given expressions are resolved, no need to continue anymore.
+ // Missing attributes can be unresolved attributes or resolved attributes which are not in
+ // the output attributes of the plan.
+ if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) {
(exprs, plan)
} else {
plan match {
(newExprs, AnalysisBarrier(newChild))
case p: Project =>
+ // Resolving expressions against current plan.
val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
+ // Recursively resolving expressions on the child of current plan.
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
- val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+ // If some attributes used by expressions are resolvable only on the rewritten child
+ // plan, we need to add them into original projection.
+ val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet)
(newExprs, Project(p.projectList ++ missingAttrs, newChild))
case a @ Aggregate(groupExprs, aggExprs, child) =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
- val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
+ val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
// All the missing attributes are grouping expressions, valid case.
(newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
// Try resolving the ordering as though it is in the aggregate clause.
try {
- val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
+ // If a sort order is unresolved, containing references not in aggregate, or containing
+ // `AggregateExpression`, we need to push down it to the underlying aggregate operator.
+ val unresolvedSortOrders = sortOrder.filter { s =>
+ !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
+ }
val aliasedOrdering =
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
}
+
+ test("SPARK-24781: Using a reference from Dataset in Filter/Sort") {
+ val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
+ val filter1 = df.select(df("name")).filter(df("id") === 0)
+ val filter2 = df.select(col("name")).filter(col("id") === 0)
+ checkAnswer(filter1, filter2.collect())
+
+ val sort1 = df.select(df("name")).orderBy(df("id"))
+ val sort2 = df.select(col("name")).orderBy(col("id"))
+ checkAnswer(sort1, sort2.collect())
+ }
+
+ test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") {
+ withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
+ val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
+
+ val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
+ val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
+ checkAnswer(aggPlusSort1, aggPlusSort2.collect())
+
+ val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0)
+ val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0)
+ checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
+ }
+ }
}