[SPARK-24500][SQL] Make sure streams are materialized during Tree transforms.
authorHerman van Hovell <hvanhovell@databricks.com>
Wed, 13 Jun 2018 14:09:48 +0000 (07:09 -0700)
committerWenchen Fan <wenchen@databricks.com>
Wed, 13 Jun 2018 14:09:48 +0000 (07:09 -0700)
## What changes were proposed in this pull request?
If you construct catalyst trees using `scala.collection.immutable.Stream` you can run into situations where valid transformations do not seem to have any effect. There are two causes for this behavior:
- `Stream` is evaluated lazily. Note that default implementation will generally only evaluate a function for the first element (this makes testing a bit tricky).
- `TreeNode` and `QueryPlan` use side effects to detect if a tree has changed. Mapping over a stream is lazy and does not need to trigger this side effect. If this happens the node will invalidly assume that it did not change and return itself instead if the newly created node (this is for GC reasons).

This PR fixes this issue by forcing materialization on streams in `TreeNode` and `QueryPlan`.

## How was this patch tested?
Unit tests were added to `TreeNodeSuite` and `LogicalPlanSuite`. An integration test was added to the `PlannerSuite`

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #21539 from hvanhovell/SPARK-24500.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

index 64cb8c7..e431c95 100644 (file)
@@ -119,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
       case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
+      case stream: Stream[_] => stream.map(recursiveTransform).force
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
       case null => null
index 9c7d47f..becfa8d 100644 (file)
@@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     var changed = false
     val remainingNewChildren = newChildren.toBuffer
     val remainingOldChildren = children.toBuffer
+    def mapTreeNode(node: TreeNode[_]): TreeNode[_] = {
+      val newChild = remainingNewChildren.remove(0)
+      val oldChild = remainingOldChildren.remove(0)
+      if (newChild fastEquals oldChild) {
+        oldChild
+      } else {
+        changed = true
+        newChild
+      }
+    }
+    def mapChild(child: Any): Any = child match {
+      case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
+      case nonChild: AnyRef => nonChild
+      case null => null
+    }
     val newArgs = mapProductIterator {
       case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
       // Handle Seq[TreeNode] in TreeNode parameters.
-      case s: Seq[_] => s.map {
-        case arg: TreeNode[_] if containsChild(arg) =>
-          val newChild = remainingNewChildren.remove(0)
-          val oldChild = remainingOldChildren.remove(0)
-          if (newChild fastEquals oldChild) {
-            oldChild
-          } else {
-            changed = true
-            newChild
-          }
-        case nonChild: AnyRef => nonChild
-        case null => null
-      }
-      case m: Map[_, _] => m.mapValues {
-        case arg: TreeNode[_] if containsChild(arg) =>
-          val newChild = remainingNewChildren.remove(0)
-          val oldChild = remainingOldChildren.remove(0)
-          if (newChild fastEquals oldChild) {
-            oldChild
-          } else {
-            changed = true
-            newChild
-          }
-        case nonChild: AnyRef => nonChild
-        case null => null
-      }.view.force // `mapValues` is lazy and we need to force it to materialize
-      case arg: TreeNode[_] if containsChild(arg) =>
-        val newChild = remainingNewChildren.remove(0)
-        val oldChild = remainingOldChildren.remove(0)
-        if (newChild fastEquals oldChild) {
-          oldChild
-        } else {
-          changed = true
-          newChild
-        }
+      case s: Stream[_] =>
+        // Stream is lazy so we need to force materialization
+        s.map(mapChild).force
+      case s: Seq[_] =>
+        s.map(mapChild)
+      case m: Map[_, _] =>
+        // `mapValues` is lazy and we need to force it to materialize
+        m.mapValues(mapChild).view.force
+      case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
       case nonChild: AnyRef => nonChild
       case null => null
     }
@@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   def mapChildren(f: BaseType => BaseType): BaseType = {
     if (children.nonEmpty) {
       var changed = false
+      def mapChild(child: Any): Any = child match {
+        case arg: TreeNode[_] if containsChild(arg) =>
+          val newChild = f(arg.asInstanceOf[BaseType])
+          if (!(newChild fastEquals arg)) {
+            changed = true
+            newChild
+          } else {
+            arg
+          }
+        case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
+          val newChild1 = if (containsChild(arg1)) {
+            f(arg1.asInstanceOf[BaseType])
+          } else {
+            arg1.asInstanceOf[BaseType]
+          }
+
+          val newChild2 = if (containsChild(arg2)) {
+            f(arg2.asInstanceOf[BaseType])
+          } else {
+            arg2.asInstanceOf[BaseType]
+          }
+
+          if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+            changed = true
+            (newChild1, newChild2)
+          } else {
+            tuple
+          }
+        case other => other
+      }
+
       val newArgs = mapProductIterator {
         case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = f(arg.asInstanceOf[BaseType])
@@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
           case other => other
         }.view.force // `mapValues` is lazy and we need to force it to materialize
         case d: DataType => d // Avoid unpacking Structs
-        case args: Traversable[_] => args.map {
-          case arg: TreeNode[_] if containsChild(arg) =>
-            val newChild = f(arg.asInstanceOf[BaseType])
-            if (!(newChild fastEquals arg)) {
-              changed = true
-              newChild
-            } else {
-              arg
-            }
-          case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
-            val newChild1 = if (containsChild(arg1)) {
-              f(arg1.asInstanceOf[BaseType])
-            } else {
-              arg1.asInstanceOf[BaseType]
-            }
-
-            val newChild2 = if (containsChild(arg2)) {
-              f(arg2.asInstanceOf[BaseType])
-            } else {
-              arg2.asInstanceOf[BaseType]
-            }
-
-            if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
-              changed = true
-              (newChild1, newChild2)
-            } else {
-              tuple
-            }
-          case other => other
-        }
+        case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+        case args: Traversable[_] => args.map(mapChild)
         case nonChild: AnyRef => nonChild
         case null => null
       }
index 1404174..bf569cb 100644 (file)
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
 
@@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite {
     assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
     assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
   }
+
+  test("transformExpressions works with a Stream") {
+    val id1 = NamedExpression.newExprId
+    val id2 = NamedExpression.newExprId
+    val plan = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(2), "b")(exprId = id2)),
+      OneRowRelation())
+    val result = plan.transformExpressions {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    val expected = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(3), "b")(exprId = id2)),
+      OneRowRelation())
+    assert(result.sameResult(expected))
+  }
 }
index 84d0ba7..b7092f4 100644 (file)
@@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource}
+import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.dsl.expressions.DslString
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
-import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
 case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback {
@@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite {
     val right = JsonMethods.parse(rightJson)
     assert(left == right)
   }
+
+  test("transform works on stream of children") {
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
+    // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
+    // stream's first element is typically materialized, so in order to not trip the TreeNode change
+    // detection logic, we should not change the first element in the sequence.
+    val result = before.transform {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("withNewChildren on stream of children") {
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
 }
index 98a50fb..ed0ff1b 100644 (file)
@@ -21,8 +21,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
@@ -679,6 +679,13 @@ class PlannerSuite extends SharedSQLContext {
     }
     assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0))
   }
+
+  test("SPARK-24500: create union with stream of children") {
+    val df = Union(Stream(
+      Range(1, 1, 1, 1),
+      Range(1, 2, 1, 1)))
+    df.queryExecution.executedPlan.execute()
+  }
 }
 
 // Used for unit-testing EnsureRequirements