[SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver
[spark.git] / sql / catalyst / src / main / scala / org / apache / spark / sql / catalyst / analysis / TypeCoercion.scala
index b2817b0..a7ba201 100644 (file)
@@ -48,18 +48,18 @@ object TypeCoercion {
 
   def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] =
     InConversion(conf) ::
-      WidenSetOperationTypes ::
+      WidenSetOperationTypes(conf) ::
       PromoteStrings(conf) ::
       DecimalPrecision ::
       BooleanEquality ::
-      FunctionArgumentConversion ::
+      FunctionArgumentConversion(conf) ::
       ConcatCoercion(conf) ::
       EltCoercion(conf) ::
-      CaseWhenCoercion ::
-      IfCoercion ::
+      CaseWhenCoercion(conf) ::
+      IfCoercion(conf) ::
       StackCoercion ::
       Division ::
-      new ImplicitTypeCasts(conf) ::
+      ImplicitTypeCasts(conf) ::
       DateTimeOperations ::
       WindowFrameCoercion ::
       Nil
@@ -83,7 +83,10 @@ object TypeCoercion {
    * with primitive types, because in that case the precision and scale of the result depends on
    * the operation. Those rules are implemented in [[DecimalPrecision]].
    */
-  val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
+  def findTightestCommonType(
+      left: DataType,
+      right: DataType,
+      caseSensitive: Boolean): Option[DataType] = (left, right) match {
     case (t1, t2) if t1 == t2 => Some(t1)
     case (NullType, t1) => Some(t1)
     case (t1, NullType) => Some(t1)
@@ -102,22 +105,32 @@ object TypeCoercion {
     case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
       Some(TimestampType)
 
-    case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) =>
-      Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
-        // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType
-        // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
-        // - Different names: use f1.name
-        // - Different nullabilities: `nullable` is true iff one of them is nullable.
-        val dataType = findTightestCommonType(f1.dataType, f2.dataType).get
-        StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
-      }))
+    case (t1 @ StructType(fields1), t2 @ StructType(fields2)) =>
+      val isSameType = if (caseSensitive) {
+        DataType.equalsIgnoreNullability(t1, t2)
+      } else {
+        DataType.equalsIgnoreCaseAndNullability(t1, t2)
+      }
+
+      if (isSameType) {
+        Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
+          // Since t1 is same type of t2, two StructTypes have the same DataType
+          // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
+          // - Different names: use f1.name
+          // - Different nullabilities: `nullable` is true iff one of them is nullable.
+          val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get
+          StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
+        }))
+      } else {
+        None
+      }
 
     case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) =>
-      findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2))
+      findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2))
 
     case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) =>
-      val keyType = findTightestCommonType(kt1, kt2)
-      val valueType = findTightestCommonType(vt1, vt2)
+      val keyType = findTightestCommonType(kt1, kt2, caseSensitive)
+      val valueType = findTightestCommonType(vt1, vt2, caseSensitive)
       Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2))
 
     case _ => None
@@ -172,13 +185,14 @@ object TypeCoercion {
    * i.e. the main difference with [[findTightestCommonType]] is that here we allow some
    * loss of precision when widening decimal and double, and promotion to string.
    */
-  def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
-    findTightestCommonType(t1, t2)
+  def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = {
+    findTightestCommonType(t1, t2, caseSensitive)
       .orElse(findWiderTypeForDecimal(t1, t2))
       .orElse(stringPromotion(t1, t2))
       .orElse((t1, t2) match {
         case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
-          findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
+          findWiderTypeForTwo(et1, et2, caseSensitive)
+            .map(ArrayType(_, containsNull1 || containsNull2))
         case _ => None
       })
   }
@@ -193,7 +207,8 @@ object TypeCoercion {
     case _ => false
   }
 
-  private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = {
+  private def findWiderCommonType(
+      types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = {
     // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal
     // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType.
     // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance,
@@ -201,7 +216,7 @@ object TypeCoercion {
     val (stringTypes, nonStringTypes) = types.partition(hasStringType(_))
     (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) =>
       r match {
-        case Some(d) => findWiderTypeForTwo(d, c)
+        case Some(d) => findWiderTypeForTwo(d, c, caseSensitive)
         case _ => None
       })
   }
@@ -213,20 +228,22 @@ object TypeCoercion {
    */
   private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
       t1: DataType,
-      t2: DataType): Option[DataType] = {
-    findTightestCommonType(t1, t2)
+      t2: DataType,
+      caseSensitive: Boolean): Option[DataType] = {
+    findTightestCommonType(t1, t2, caseSensitive)
       .orElse(findWiderTypeForDecimal(t1, t2))
       .orElse((t1, t2) match {
         case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
-          findWiderTypeWithoutStringPromotionForTwo(et1, et2)
+          findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive)
             .map(ArrayType(_, containsNull1 || containsNull2))
         case _ => None
       })
   }
 
-  def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
+  def findWiderTypeWithoutStringPromotion(
+      types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = {
     types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
-      case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c)
+      case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive)
       case None => None
     })
   }
@@ -279,29 +296,32 @@ object TypeCoercion {
    *
    * This rule is only applied to Union/Except/Intersect
    */
-  object WidenSetOperationTypes extends Rule[LogicalPlan] {
+  case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] {
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
       case s @ SetOperation(left, right) if s.childrenResolved &&
           left.output.length == right.output.length && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+        val newChildren: Seq[LogicalPlan] =
+          buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis)
         assert(newChildren.length == 2)
         s.makeCopy(Array(newChildren.head, newChildren.last))
 
       case s: Union if s.childrenResolved &&
           s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+        val newChildren: Seq[LogicalPlan] =
+          buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis)
         s.makeCopy(Array(newChildren))
     }
 
     /** Build new children with the widest types for each attribute among all the children */
-    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
+    private def buildNewChildrenWithWiderTypes(
+        children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = {
       require(children.forall(_.output.length == children.head.output.length))
 
       // Get a sequence of data types, each of which is the widest type of this specific attribute
       // in all the children
       val targetTypes: Seq[DataType] =
-        getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
+        getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive)
 
       if (targetTypes.nonEmpty) {
         // Add an extra Project if the targetTypes are different from the original types.
@@ -316,18 +336,19 @@ object TypeCoercion {
     @tailrec private def getWidestTypes(
         children: Seq[LogicalPlan],
         attrIndex: Int,
-        castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
+        castedTypes: mutable.Queue[DataType],
+        caseSensitive: Boolean): Seq[DataType] = {
       // Return the result after the widen data types have been found for all the children
       if (attrIndex >= children.head.output.length) return castedTypes.toSeq
 
       // For the attrIndex-th attribute, find the widest type
-      findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
+      findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match {
         // If unable to find an appropriate widen type for this column, return an empty Seq
         case None => Seq.empty[DataType]
         // Otherwise, record the result in the queue and find the type for the next column
         case Some(widenType) =>
           castedTypes.enqueue(widenType)
-          getWidestTypes(children, attrIndex + 1, castedTypes)
+          getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive)
       }
     }
 
@@ -432,7 +453,7 @@ object TypeCoercion {
 
         val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
           findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf)
-            .orElse(findTightestCommonType(l.dataType, r.dataType))
+            .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis))
         }
 
         // The number of columns/expressions must match between LHS and RHS of an
@@ -461,7 +482,7 @@ object TypeCoercion {
         }
 
       case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
-        findWiderCommonType(i.children.map(_.dataType)) match {
+        findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match {
           case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
           case None => i
         }
@@ -515,7 +536,7 @@ object TypeCoercion {
   /**
    * This ensure that the types for various functions are as expected.
    */
-  object FunctionArgumentConversion extends TypeCoercionRule {
+  case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule {
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       // Skip nodes who's children have not been resolved yet.
@@ -523,7 +544,7 @@ object TypeCoercion {
 
       case a @ CreateArray(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findWiderCommonType(types) match {
+        findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
           case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
           case None => a
         }
@@ -531,7 +552,7 @@ object TypeCoercion {
       case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
         !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findWiderCommonType(types) match {
+        findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
           case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
           case None => c
         }
@@ -542,7 +563,7 @@ object TypeCoercion {
           m.keys
         } else {
           val types = m.keys.map(_.dataType)
-          findWiderCommonType(types) match {
+          findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
             case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
             case None => m.keys
           }
@@ -552,7 +573,7 @@ object TypeCoercion {
           m.values
         } else {
           val types = m.values.map(_.dataType)
-          findWiderCommonType(types) match {
+          findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
             case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
             case None => m.values
           }
@@ -580,7 +601,7 @@ object TypeCoercion {
       // compatible with every child column.
       case c @ Coalesce(es) if !haveSameType(es) =>
         val types = es.map(_.dataType)
-        findWiderCommonType(types) match {
+        findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
           case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
           case None => c
         }
@@ -590,14 +611,14 @@ object TypeCoercion {
       // string.g
       case g @ Greatest(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findWiderTypeWithoutStringPromotion(types) match {
+        findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match {
           case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
           case None => g
         }
 
       case l @ Least(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findWiderTypeWithoutStringPromotion(types) match {
+        findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match {
           case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
           case None => l
         }
@@ -637,11 +658,11 @@ object TypeCoercion {
   /**
    * Coerces the type of different branches of a CASE WHEN statement to a common type.
    */
-  object CaseWhenCoercion extends TypeCoercionRule {
+  case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule {
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
-        val maybeCommonType = findWiderCommonType(c.valueTypes)
+        val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis)
         maybeCommonType.map { commonType =>
           var changed = false
           val newBranches = c.branches.map { case (condition, value) =>
@@ -668,16 +689,17 @@ object TypeCoercion {
   /**
    * Coerces the type of different branches of If statement to a common type.
    */
-  object IfCoercion extends TypeCoercionRule {
+  case class IfCoercion(conf: SQLConf) extends TypeCoercionRule {
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case e if !e.childrenResolved => e
       // Find tightest common type for If, if the true value and false value have different types.
       case i @ If(pred, left, right) if left.dataType != right.dataType =>
-        findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
-          val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
-          val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
-          If(pred, newLeft, newRight)
+        findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map {
+          widestType =>
+            val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
+            val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
+            If(pred, newLeft, newRight)
         }.getOrElse(i)  // If there is no applicable conversion, leave expression unchanged.
       case If(Literal(null, NullType), left, right) =>
         If(Literal.create(null, BooleanType), left, right)
@@ -776,12 +798,11 @@ object TypeCoercion {
   /**
    * Casts types according to the expected input types for [[Expression]]s.
    */
-  class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule {
+  case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule {
 
     private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING)
 
-    override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -804,17 +825,18 @@ object TypeCoercion {
         }
 
       case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
-        findTightestCommonType(left.dataType, right.dataType).map { commonType =>
-          if (b.inputType.acceptsType(commonType)) {
-            // If the expression accepts the tightest common type, cast to that.
-            val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
-            val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
-            b.withNewChildren(Seq(newLeft, newRight))
-          } else {
-            // Otherwise, don't do anything with the expression.
-            b
-          }
-        }.getOrElse(b)  // If there is no applicable conversion, leave expression unchanged.
+        findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map {
+          commonType =>
+            if (b.inputType.acceptsType(commonType)) {
+              // If the expression accepts the tightest common type, cast to that.
+              val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
+              val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
+              b.withNewChildren(Seq(newLeft, newRight))
+            } else {
+              // Otherwise, don't do anything with the expression.
+              b
+            }
+        }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
 
       case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>