[SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver
[spark.git] / sql / catalyst / src / test / scala / org / apache / spark / sql / catalyst / analysis / TypeCoercionSuite.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.spark.sql.catalyst.analysis
19
20 import java.sql.Timestamp
21
22 import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
23 import org.apache.spark.sql.catalyst.dsl.expressions._
24 import org.apache.spark.sql.catalyst.expressions._
25 import org.apache.spark.sql.catalyst.plans.PlanTest
26 import org.apache.spark.sql.catalyst.plans.logical._
27 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
28 import org.apache.spark.sql.internal.SQLConf
29 import org.apache.spark.sql.types._
30 import org.apache.spark.unsafe.types.CalendarInterval
31
32 class TypeCoercionSuite extends AnalysisTest {
33
34   // scalastyle:off line.size.limit
35   // The following table shows all implicit data type conversions that are not visible to the user.
36   // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
37   // | Source Type\CAST TO  | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType  | MapType  | StructType  | NullType | CalendarIntervalType |     DecimalType     | NumericType | IntegralType |
38   // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
39   // | ByteType             | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(3, 0)   | ByteType    | ByteType     |
40   // | ShortType            | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(5, 0)   | ShortType   | ShortType    |
41   // | IntegerType          | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(10, 0)  | IntegerType | IntegerType  |
42   // | LongType             | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(20, 0)  | LongType    | LongType     |
  // | DoubleType           | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(30, 15) | DoubleType  | X            |
  // | FloatType            | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(14, 7)  | FloatType   | X            |
  // | Dec(10, 2)           | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X          | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | DecimalType(10, 2)  | Dec(10, 2)  | X            |
46   // | BinaryType           | X        | X         | X           | X        | X          | X         | X          | BinaryType | X           | StringType | X        | X             | X          | X        | X           | X        | X                    | X                   | X           | X            |
47   // | BooleanType          | X        | X         | X           | X        | X          | X         | X          | X          | BooleanType | StringType | X        | X             | X          | X        | X           | X        | X                    | X                   | X           | X            |
48   // | StringType           | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | X           | StringType | DateType | TimestampType | X          | X        | X           | X        | X                    | DecimalType(38, 18) | DoubleType  | X            |
49   // | DateType             | X        | X         | X           | X        | X          | X         | X          | X          | X           | StringType | DateType | TimestampType | X          | X        | X           | X        | X                    | X                   | X           | X            |
50   // | TimestampType        | X        | X         | X           | X        | X          | X         | X          | X          | X           | StringType | DateType | TimestampType | X          | X        | X           | X        | X                    | X                   | X           | X            |
51   // | ArrayType            | X        | X         | X           | X        | X          | X         | X          | X          | X           | X          | X        | X             | ArrayType* | X        | X           | X        | X                    | X                   | X           | X            |
52   // | MapType              | X        | X         | X           | X        | X          | X         | X          | X          | X           | X          | X        | X             | X          | MapType* | X           | X        | X                    | X                   | X           | X            |
53   // | StructType           | X        | X         | X           | X        | X          | X         | X          | X          | X           | X          | X        | X             | X          | X        | StructType* | X        | X                    | X                   | X           | X            |
54   // | NullType             | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType  | MapType  | StructType  | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType  | IntegerType  |
55   // | CalendarIntervalType | X        | X         | X           | X        | X          | X         | X          | X          | X           | X          | X        | X             | X          | X        | X           | X        | CalendarIntervalType | X                   | X           | X            |
56   // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
57   // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable.
58   // Note: ArrayType* is castable when the element type is castable according to the table.
59   // scalastyle:on line.size.limit
60
61   private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
62     // Check default value
63     val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to)
64     assert(DataType.equalsIgnoreCompatibleNullability(
65       castDefault.map(_.dataType).getOrElse(null), expected),
66       s"Failed to cast $from to $to")
67
68     // Check null value
69     val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to)
70     assert(DataType.equalsIgnoreCaseAndNullability(
71       castNull.map(_.dataType).getOrElse(null), expected),
72       s"Failed to cast $from to $to")
73   }
74
75   private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
76     // Check default value
77     val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to)
78     assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault")
79
80     // Check null value
81     val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to)
82     assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull")
83   }
84
85   private def default(dataType: DataType): Expression = dataType match {
86     case ArrayType(internalType: DataType, _) =>
87       CreateArray(Seq(Literal.default(internalType)))
88     case MapType(keyDataType: DataType, valueDataType: DataType, _) =>
89       CreateMap(Seq(Literal.default(keyDataType), Literal.default(valueDataType)))
90     case _ => Literal.default(dataType)
91   }
92
93   private def createNull(dataType: DataType): Expression = dataType match {
94     case ArrayType(internalType: DataType, _) =>
95       CreateArray(Seq(Literal.create(null, internalType)))
96     case MapType(keyDataType: DataType, valueDataType: DataType, _) =>
97       CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, valueDataType)))
98     case _ => Literal.create(null, dataType)
99   }
100
101   val integralTypes: Seq[DataType] =
102     Seq(ByteType, ShortType, IntegerType, LongType)
103   val fractionalTypes: Seq[DataType] =
104     Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
105   val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
106   val atomicTypes: Seq[DataType] =
107     numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
108   val complexTypes: Seq[DataType] =
109     Seq(ArrayType(IntegerType),
110       ArrayType(StringType),
111       MapType(StringType, StringType),
112       new StructType().add("a1", StringType),
113       new StructType().add("a1", StringType).add("a2", IntegerType))
114   val allTypes: Seq[DataType] =
115     atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)
116
117   // Check whether the type `checkedType` can be cast to all the types in `castableTypes`,
118   // but cannot be cast to the other types in `allTypes`.
119   private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = {
120     val nonCastableTypes = allTypes.filterNot(castableTypes.contains)
121
122     castableTypes.foreach { tpe =>
123       shouldCast(checkedType, tpe, tpe)
124     }
125     nonCastableTypes.foreach { tpe =>
126       shouldNotCast(checkedType, tpe)
127     }
128   }
129
130   private def checkWidenType(
131       widenFunc: (DataType, DataType, Boolean) => Option[DataType],
132       t1: DataType,
133       t2: DataType,
134       expected: Option[DataType],
135       isSymmetric: Boolean = true): Unit = {
136     var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis)
137     assert(found == expected,
138       s"Expected $expected as wider common type for $t1 and $t2, found $found")
139     // Test both directions to make sure the widening is symmetric.
140     if (isSymmetric) {
141       found = widenFunc(t2, t1, conf.caseSensitiveAnalysis)
142       assert(found == expected,
143         s"Expected $expected as wider common type for $t2 and $t1, found $found")
144     }
145   }
146
147   test("implicit type cast - ByteType") {
148     val checkedType = ByteType
149     checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
150     shouldCast(checkedType, DecimalType, DecimalType.ByteDecimal)
151     shouldCast(checkedType, NumericType, checkedType)
152     shouldCast(checkedType, IntegralType, checkedType)
153   }
154
155   test("implicit type cast - ShortType") {
156     val checkedType = ShortType
157     checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
158     shouldCast(checkedType, DecimalType, DecimalType.ShortDecimal)
159     shouldCast(checkedType, NumericType, checkedType)
160     shouldCast(checkedType, IntegralType, checkedType)
161   }
162
163   test("implicit type cast - IntegerType") {
164     val checkedType = IntegerType
165     checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
166     shouldCast(IntegerType, DecimalType, DecimalType.IntDecimal)
167     shouldCast(checkedType, NumericType, checkedType)
168     shouldCast(checkedType, IntegralType, checkedType)
169   }
170
171   test("implicit type cast - LongType") {
172     val checkedType = LongType
173     checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
174     shouldCast(checkedType, DecimalType, DecimalType.LongDecimal)
175     shouldCast(checkedType, NumericType, checkedType)
176     shouldCast(checkedType, IntegralType, checkedType)
177   }
178
179   test("implicit type cast - FloatType") {
180     val checkedType = FloatType
181     checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
182     shouldCast(checkedType, DecimalType, DecimalType.FloatDecimal)
183     shouldCast(checkedType, NumericType, checkedType)
184     shouldNotCast(checkedType, IntegralType)
185   }
186
187   test("implicit type cast - DoubleType") {
188     val checkedType = DoubleType
189     checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
190     shouldCast(checkedType, DecimalType, DecimalType.DoubleDecimal)
191     shouldCast(checkedType, NumericType, checkedType)
192     shouldNotCast(checkedType, IntegralType)
193   }
194
195   test("implicit type cast - DecimalType(10, 2)") {
196     val checkedType = DecimalType(10, 2)
197     checkTypeCasting(checkedType, castableTypes = numericTypes ++ Seq(StringType))
198     shouldCast(checkedType, DecimalType, checkedType)
199     shouldCast(checkedType, NumericType, checkedType)
200     shouldNotCast(checkedType, IntegralType)
201   }
202
203   test("implicit type cast - BinaryType") {
204     val checkedType = BinaryType
205     checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType))
206     shouldNotCast(checkedType, DecimalType)
207     shouldNotCast(checkedType, NumericType)
208     shouldNotCast(checkedType, IntegralType)
209   }
210
211   test("implicit type cast - BooleanType") {
212     val checkedType = BooleanType
213     checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType))
214     shouldNotCast(checkedType, DecimalType)
215     shouldNotCast(checkedType, NumericType)
216     shouldNotCast(checkedType, IntegralType)
217   }
218
219   test("implicit type cast - StringType") {
220     val checkedType = StringType
221     val nonCastableTypes =
222       complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType)
223     checkTypeCasting(checkedType, castableTypes = allTypes.filterNot(nonCastableTypes.contains))
224     shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT)
225     shouldCast(checkedType, NumericType, NumericType.defaultConcreteType)
226     shouldNotCast(checkedType, IntegralType)
227   }
228
229   test("implicit type cast - DateType") {
230     val checkedType = DateType
231     checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType, TimestampType))
232     shouldNotCast(checkedType, DecimalType)
233     shouldNotCast(checkedType, NumericType)
234     shouldNotCast(checkedType, IntegralType)
235   }
236
237   test("implicit type cast - TimestampType") {
238     val checkedType = TimestampType
239     checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType, DateType))
240     shouldNotCast(checkedType, DecimalType)
241     shouldNotCast(checkedType, NumericType)
242     shouldNotCast(checkedType, IntegralType)
243   }
244
245   test("implicit type cast - ArrayType(StringType)") {
246     val checkedType = ArrayType(StringType)
247     val nonCastableTypes =
248       complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType)
249     checkTypeCasting(checkedType,
250       castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_)))
251     nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _))
252     shouldNotCast(ArrayType(DoubleType, containsNull = false),
253       ArrayType(LongType, containsNull = false))
254     shouldNotCast(checkedType, DecimalType)
255     shouldNotCast(checkedType, NumericType)
256     shouldNotCast(checkedType, IntegralType)
257   }
258
259   test("implicit type cast - MapType(StringType, StringType)") {
260     val checkedType = MapType(StringType, StringType)
261     checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
262     shouldNotCast(checkedType, DecimalType)
263     shouldNotCast(checkedType, NumericType)
264     shouldNotCast(checkedType, IntegralType)
265   }
266
267   test("implicit type cast - StructType().add(\"a1\", StringType)") {
268     val checkedType = new StructType().add("a1", StringType)
269     checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
270     shouldNotCast(checkedType, DecimalType)
271     shouldNotCast(checkedType, NumericType)
272     shouldNotCast(checkedType, IntegralType)
273   }
274
275   test("implicit type cast - NullType") {
276     val checkedType = NullType
277     checkTypeCasting(checkedType, castableTypes = allTypes)
278     shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT)
279     shouldCast(checkedType, NumericType, NumericType.defaultConcreteType)
280     shouldCast(checkedType, IntegralType, IntegralType.defaultConcreteType)
281   }
282
283   test("implicit type cast - CalendarIntervalType") {
284     val checkedType = CalendarIntervalType
285     checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
286     shouldNotCast(checkedType, DecimalType)
287     shouldNotCast(checkedType, NumericType)
288     shouldNotCast(checkedType, IntegralType)
289   }
290
291   test("eligible implicit type cast - TypeCollection") {
292     shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType)
293
294     shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType)
295     shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType)
296     shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType)
297
298     shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType)
299     shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
300     shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
301     shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
302
303     shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
304     shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
305
306     shouldCast(DecimalType.SYSTEM_DEFAULT,
307       TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT)
308     shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
309     shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
310     shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
311
312     shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType)
313
314     shouldCast(
315       ArrayType(StringType, false),
316       TypeCollection(ArrayType(StringType), StringType),
317       ArrayType(StringType, false))
318
319     shouldCast(
320       ArrayType(StringType, true),
321       TypeCollection(ArrayType(StringType), StringType),
322       ArrayType(StringType, true))
323   }
324
325   test("ineligible implicit type cast - TypeCollection") {
326     shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
327   }
328
329   test("tightest common bound for types") {
330     def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit =
331       checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected)
332
333     // Null
334     widenTest(NullType, NullType, Some(NullType))
335
336     // Boolean
337     widenTest(NullType, BooleanType, Some(BooleanType))
338     widenTest(BooleanType, BooleanType, Some(BooleanType))
339     widenTest(IntegerType, BooleanType, None)
340     widenTest(LongType, BooleanType, None)
341
342     // Integral
343     widenTest(NullType, ByteType, Some(ByteType))
344     widenTest(NullType, IntegerType, Some(IntegerType))
345     widenTest(NullType, LongType, Some(LongType))
346     widenTest(ShortType, IntegerType, Some(IntegerType))
347     widenTest(ShortType, LongType, Some(LongType))
348     widenTest(IntegerType, LongType, Some(LongType))
349     widenTest(LongType, LongType, Some(LongType))
350
351     // Floating point
352     widenTest(NullType, FloatType, Some(FloatType))
353     widenTest(NullType, DoubleType, Some(DoubleType))
354     widenTest(FloatType, DoubleType, Some(DoubleType))
355     widenTest(FloatType, FloatType, Some(FloatType))
356     widenTest(DoubleType, DoubleType, Some(DoubleType))
357
358     // Integral mixed with floating point.
359     widenTest(IntegerType, FloatType, Some(FloatType))
360     widenTest(IntegerType, DoubleType, Some(DoubleType))
361     widenTest(IntegerType, DoubleType, Some(DoubleType))
362     widenTest(LongType, FloatType, Some(FloatType))
363     widenTest(LongType, DoubleType, Some(DoubleType))
364
365     // No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
366     widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
367     widenTest(DecimalType(2, 1), DoubleType, None)
368     widenTest(DecimalType(2, 1), IntegerType, None)
369     widenTest(DoubleType, DecimalType(2, 1), None)
370
371     // StringType
372     widenTest(NullType, StringType, Some(StringType))
373     widenTest(StringType, StringType, Some(StringType))
374     widenTest(IntegerType, StringType, None)
375     widenTest(LongType, StringType, None)
376
377     // TimestampType
378     widenTest(NullType, TimestampType, Some(TimestampType))
379     widenTest(TimestampType, TimestampType, Some(TimestampType))
380     widenTest(DateType, TimestampType, Some(TimestampType))
381     widenTest(IntegerType, TimestampType, None)
382     widenTest(StringType, TimestampType, None)
383
384     // ComplexType
385     widenTest(NullType,
386       MapType(IntegerType, StringType, false),
387       Some(MapType(IntegerType, StringType, false)))
388     widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
389     widenTest(StringType, MapType(IntegerType, StringType, true), None)
390     widenTest(ArrayType(IntegerType), StructType(Seq()), None)
391
392     widenTest(
393       StructType(Seq(StructField("a", IntegerType))),
394       StructType(Seq(StructField("b", IntegerType))),
395       None)
396     widenTest(
397       StructType(Seq(StructField("a", IntegerType, nullable = false))),
398       StructType(Seq(StructField("a", DoubleType, nullable = false))),
399       None)
400
401     widenTest(
402       StructType(Seq(StructField("a", IntegerType, nullable = false))),
403       StructType(Seq(StructField("a", IntegerType, nullable = false))),
404       Some(StructType(Seq(StructField("a", IntegerType, nullable = false)))))
405     widenTest(
406       StructType(Seq(StructField("a", IntegerType, nullable = false))),
407       StructType(Seq(StructField("a", IntegerType, nullable = true))),
408       Some(StructType(Seq(StructField("a", IntegerType, nullable = true)))))
409     widenTest(
410       StructType(Seq(StructField("a", IntegerType, nullable = true))),
411       StructType(Seq(StructField("a", IntegerType, nullable = false))),
412       Some(StructType(Seq(StructField("a", IntegerType, nullable = true)))))
413     widenTest(
414       StructType(Seq(StructField("a", IntegerType, nullable = true))),
415       StructType(Seq(StructField("a", IntegerType, nullable = true))),
416       Some(StructType(Seq(StructField("a", IntegerType, nullable = true)))))
417
418     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
419       widenTest(
420         StructType(Seq(StructField("a", IntegerType))),
421         StructType(Seq(StructField("A", IntegerType))),
422         None)
423     }
424     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
425       checkWidenType(
426         TypeCoercion.findTightestCommonType,
427         StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType))),
428         StructType(Seq(StructField("A", IntegerType), StructField("b", IntegerType))),
429         Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))),
430         isSymmetric = false)
431     }
432
433     widenTest(
434       ArrayType(IntegerType, containsNull = true),
435       ArrayType(IntegerType, containsNull = false),
436       Some(ArrayType(IntegerType, containsNull = true)))
437
438     widenTest(
439       MapType(IntegerType, StringType, valueContainsNull = true),
440       MapType(IntegerType, StringType, valueContainsNull = false),
441       Some(MapType(IntegerType, StringType, valueContainsNull = true)))
442
443     widenTest(
444       new StructType()
445         .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false),
446       new StructType()
447         .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true),
448       Some(new StructType()
449         .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true)))
450   }
451
452   test("wider common type for decimal and array") {
453     def widenTestWithStringPromotion(
454         t1: DataType,
455         t2: DataType,
456         expected: Option[DataType]): Unit = {
457       checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
458     }
459
460     def widenTestWithoutStringPromotion(
461         t1: DataType,
462         t2: DataType,
463         expected: Option[DataType]): Unit = {
464       checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected)
465     }
466
467     // Decimal
468     widenTestWithStringPromotion(
469       DecimalType(2, 1), DecimalType(3, 2), Some(DecimalType(3, 2)))
470     widenTestWithStringPromotion(
471       DecimalType(2, 1), DoubleType, Some(DoubleType))
472     widenTestWithStringPromotion(
473       DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1)))
474     widenTestWithStringPromotion(
475       DecimalType(2, 1), LongType, Some(DecimalType(21, 1)))
476
477     // ArrayType
478     widenTestWithStringPromotion(
479       ArrayType(ShortType, containsNull = true),
480       ArrayType(DoubleType, containsNull = false),
481       Some(ArrayType(DoubleType, containsNull = true)))
482     widenTestWithStringPromotion(
483       ArrayType(TimestampType, containsNull = false),
484       ArrayType(StringType, containsNull = true),
485       Some(ArrayType(StringType, containsNull = true)))
486     widenTestWithStringPromotion(
487       ArrayType(ArrayType(IntegerType), containsNull = false),
488       ArrayType(ArrayType(LongType), containsNull = false),
489       Some(ArrayType(ArrayType(LongType), containsNull = false)))
490
491     // Without string promotion
492     widenTestWithoutStringPromotion(IntegerType, StringType, None)
493     widenTestWithoutStringPromotion(StringType, TimestampType, None)
494     widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None)
495     widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None)
496
497     // String promotion
498     widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
499     widenTestWithStringPromotion(StringType, TimestampType, Some(StringType))
500     widenTestWithStringPromotion(
501       ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType)))
502     widenTestWithStringPromotion(
503       ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType)))
504   }
505
506   private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
507     ruleTest(Seq(rule), initial, transformed)
508   }
509
510   private def ruleTest(
511       rules: Seq[Rule[LogicalPlan]],
512       initial: Expression,
513       transformed: Expression): Unit = {
514     val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
515     val analyzer = new RuleExecutor[LogicalPlan] {
516       override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*))
517     }
518
519     comparePlans(
520       analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)),
521       Project(Seq(Alias(transformed, "a")()), testRelation))
522   }
523
524   test("cast NullType for expressions that implement ExpectsInputTypes") {
525     import TypeCoercionSuite._
526
527     ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
528       AnyTypeUnaryExpression(Literal.create(null, NullType)),
529       AnyTypeUnaryExpression(Literal.create(null, NullType)))
530
531     ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
532       NumericTypeUnaryExpression(Literal.create(null, NullType)),
533       NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
534   }
535
536   test("cast NullType for binary operators") {
537     import TypeCoercionSuite._
538
539     ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
540       AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
541       AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
542
543     ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
544       NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
545       NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
546   }
547
548   test("coalesce casts") {
549     val rule = TypeCoercion.FunctionArgumentConversion(conf)
550
551     val intLit = Literal(1)
552     val longLit = Literal.create(1L)
553     val doubleLit = Literal(1.0)
554     val stringLit = Literal.create("c", StringType)
555     val nullLit = Literal.create(null, NullType)
556     val floatNullLit = Literal.create(null, FloatType)
557     val floatLit = Literal.create(1.0f, FloatType)
558     val timestampLit = Literal.create("2017-04-12", TimestampType)
559     val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
560     val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis())))
561     val strArrayLit = Literal(Array("c"))
562     val intArrayLit = Literal(Array(1))
563
564     ruleTest(rule,
565       Coalesce(Seq(doubleLit, intLit, floatLit)),
566       Coalesce(Seq(Cast(doubleLit, DoubleType),
567         Cast(intLit, DoubleType), Cast(floatLit, DoubleType))))
568
569     ruleTest(rule,
570       Coalesce(Seq(longLit, intLit, decimalLit)),
571       Coalesce(Seq(Cast(longLit, DecimalType(22, 0)),
572         Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0)))))
573
574     ruleTest(rule,
575       Coalesce(Seq(nullLit, intLit)),
576       Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType))))
577
578     ruleTest(rule,
579       Coalesce(Seq(timestampLit, stringLit)),
580       Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType))))
581
582     ruleTest(rule,
583       Coalesce(Seq(nullLit, floatNullLit, intLit)),
584       Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType),
585         Cast(intLit, FloatType))))
586
587     ruleTest(rule,
588       Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)),
589       Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType),
590         Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType))))
591
592     ruleTest(rule,
593       Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)),
594       Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType),
595         Cast(doubleLit, StringType), Cast(stringLit, StringType))))
596
597     ruleTest(rule,
598       Coalesce(Seq(timestampLit, intLit, stringLit)),
599       Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType),
600         Cast(stringLit, StringType))))
601
602     ruleTest(rule,
603       Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)),
604       Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)),
605         Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType)))))
606   }
607
608   test("CreateArray casts") {
609     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
610       CreateArray(Literal(1.0)
611         :: Literal(1)
612         :: Literal.create(1.0, FloatType)
613         :: Nil),
614       CreateArray(Cast(Literal(1.0), DoubleType)
615         :: Cast(Literal(1), DoubleType)
616         :: Cast(Literal.create(1.0, FloatType), DoubleType)
617         :: Nil))
618
619     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
620       CreateArray(Literal(1.0)
621         :: Literal(1)
622         :: Literal("a")
623         :: Nil),
624       CreateArray(Cast(Literal(1.0), StringType)
625         :: Cast(Literal(1), StringType)
626         :: Cast(Literal("a"), StringType)
627         :: Nil))
628
629     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
630       CreateArray(Literal.create(null, DecimalType(5, 3))
631         :: Literal(1)
632         :: Nil),
633       CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3))
634         :: Literal(1).cast(DecimalType(13, 3))
635         :: Nil))
636
637     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
638       CreateArray(Literal.create(null, DecimalType(5, 3))
639         :: Literal.create(null, DecimalType(22, 10))
640         :: Literal.create(null, DecimalType(38, 38))
641         :: Nil),
642       CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38))
643         :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38))
644         :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
645         :: Nil))
646   }
647
648   test("CreateMap casts") {
649     // type coercion for map keys
650     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
651       CreateMap(Literal(1)
652         :: Literal("a")
653         :: Literal.create(2.0, FloatType)
654         :: Literal("b")
655         :: Nil),
656       CreateMap(Cast(Literal(1), FloatType)
657         :: Literal("a")
658         :: Cast(Literal.create(2.0, FloatType), FloatType)
659         :: Literal("b")
660         :: Nil))
661     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
662       CreateMap(Literal.create(null, DecimalType(5, 3))
663         :: Literal("a")
664         :: Literal.create(2.0, FloatType)
665         :: Literal("b")
666         :: Nil),
667       CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
668         :: Literal("a")
669         :: Literal.create(2.0, FloatType).cast(DoubleType)
670         :: Literal("b")
671         :: Nil))
672     // type coercion for map values
673     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
674       CreateMap(Literal(1)
675         :: Literal("a")
676         :: Literal(2)
677         :: Literal(3.0)
678         :: Nil),
679       CreateMap(Literal(1)
680         :: Cast(Literal("a"), StringType)
681         :: Literal(2)
682         :: Cast(Literal(3.0), StringType)
683         :: Nil))
684     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
685       CreateMap(Literal(1)
686         :: Literal.create(null, DecimalType(38, 0))
687         :: Literal(2)
688         :: Literal.create(null, DecimalType(38, 38))
689         :: Nil),
690       CreateMap(Literal(1)
691         :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38))
692         :: Literal(2)
693         :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
694         :: Nil))
695     // type coercion for both map keys and values
696     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
697       CreateMap(Literal(1)
698         :: Literal("a")
699         :: Literal(2.0)
700         :: Literal(3.0)
701         :: Nil),
702       CreateMap(Cast(Literal(1), DoubleType)
703         :: Cast(Literal("a"), StringType)
704         :: Cast(Literal(2.0), DoubleType)
705         :: Cast(Literal(3.0), StringType)
706         :: Nil))
707   }
708
709   test("greatest/least cast") {
710     for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
711       ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
712         operator(Literal(1.0)
713           :: Literal(1)
714           :: Literal.create(1.0, FloatType)
715           :: Nil),
716         operator(Cast(Literal(1.0), DoubleType)
717           :: Cast(Literal(1), DoubleType)
718           :: Cast(Literal.create(1.0, FloatType), DoubleType)
719           :: Nil))
720       ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
721         operator(Literal(1L)
722           :: Literal(1)
723           :: Literal(new java.math.BigDecimal("1000000000000000000000"))
724           :: Nil),
725         operator(Cast(Literal(1L), DecimalType(22, 0))
726           :: Cast(Literal(1), DecimalType(22, 0))
727           :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
728           :: Nil))
729       ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
730         operator(Literal(1.0)
731           :: Literal.create(null, DecimalType(10, 5))
732           :: Literal(1)
733           :: Nil),
734         operator(Literal(1.0).cast(DoubleType)
735           :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
736           :: Literal(1).cast(DoubleType)
737           :: Nil))
738       ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
739         operator(Literal.create(null, DecimalType(15, 0))
740           :: Literal.create(null, DecimalType(10, 5))
741           :: Literal(1)
742           :: Nil),
743         operator(Literal.create(null, DecimalType(15, 0)).cast(DecimalType(20, 5))
744           :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
745           :: Literal(1).cast(DecimalType(20, 5))
746           :: Nil))
747       ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
748         operator(Literal.create(2L, LongType)
749           :: Literal(1)
750           :: Literal.create(null, DecimalType(10, 5))
751           :: Nil),
752         operator(Literal.create(2L, LongType).cast(DecimalType(25, 5))
753           :: Literal(1).cast(DecimalType(25, 5))
754           :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(25, 5))
755           :: Nil))
756     }
757   }
758
759   test("nanvl casts") {
760     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
761       NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)),
762       NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
763     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
764       NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)),
765       NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType)))
766     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
767       NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
768       NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
769     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
770       NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)),
771       NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType)))
772     ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
773       NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)),
774       NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType)))
775   }
776
777   test("type coercion for If") {
778     val rule = TypeCoercion.IfCoercion(conf)
779     val intLit = Literal(1)
780     val doubleLit = Literal(1.0)
781     val trueLit = Literal.create(true, BooleanType)
782     val falseLit = Literal.create(false, BooleanType)
783     val stringLit = Literal.create("c", StringType)
784     val floatLit = Literal.create(1.0f, FloatType)
785     val timestampLit = Literal.create("2017-04-12", TimestampType)
786     val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
787
788     ruleTest(rule,
789       If(Literal(true), Literal(1), Literal(1L)),
790       If(Literal(true), Cast(Literal(1), LongType), Literal(1L)))
791
792     ruleTest(rule,
793       If(Literal.create(null, NullType), Literal(1), Literal(1)),
794       If(Literal.create(null, BooleanType), Literal(1), Literal(1)))
795
796     ruleTest(rule,
797       If(AssertTrue(trueLit), Literal(1), Literal(2)),
798       If(Cast(AssertTrue(trueLit), BooleanType), Literal(1), Literal(2)))
799
800     ruleTest(rule,
801       If(AssertTrue(falseLit), Literal(1), Literal(2)),
802       If(Cast(AssertTrue(falseLit), BooleanType), Literal(1), Literal(2)))
803
804     ruleTest(rule,
805       If(trueLit, intLit, doubleLit),
806       If(trueLit, Cast(intLit, DoubleType), doubleLit))
807
808     ruleTest(rule,
809       If(trueLit, floatLit, doubleLit),
810       If(trueLit, Cast(floatLit, DoubleType), doubleLit))
811
812     ruleTest(rule,
813       If(trueLit, floatLit, decimalLit),
814       If(trueLit, Cast(floatLit, DoubleType), Cast(decimalLit, DoubleType)))
815
816     ruleTest(rule,
817       If(falseLit, stringLit, doubleLit),
818       If(falseLit, stringLit, Cast(doubleLit, StringType)))
819
820     ruleTest(rule,
821       If(trueLit, timestampLit, stringLit),
822       If(trueLit, Cast(timestampLit, StringType), stringLit))
823   }
824
825   test("type coercion for CaseKeyWhen") {
826     ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
827       CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
828       CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
829     )
830     ruleTest(TypeCoercion.CaseWhenCoercion(conf),
831       CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
832       CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
833     )
834     ruleTest(TypeCoercion.CaseWhenCoercion(conf),
835       CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
836       CaseWhen(Seq((Literal(true), Literal(1.2))),
837         Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
838     )
839     ruleTest(TypeCoercion.CaseWhenCoercion(conf),
840       CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
841       CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
842         Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
843     )
844   }
845
846   test("type coercion for Stack") {
847     val rule = TypeCoercion.StackCoercion
848
849     ruleTest(rule,
850       Stack(Seq(Literal(3), Literal(1), Literal(2), Literal(null))),
851       Stack(Seq(Literal(3), Literal(1), Literal(2), Literal.create(null, IntegerType))))
852     ruleTest(rule,
853       Stack(Seq(Literal(3), Literal(1.0), Literal(null), Literal(3.0))),
854       Stack(Seq(Literal(3), Literal(1.0), Literal.create(null, DoubleType), Literal(3.0))))
855     ruleTest(rule,
856       Stack(Seq(Literal(3), Literal(null), Literal("2"), Literal("3"))),
857       Stack(Seq(Literal(3), Literal.create(null, StringType), Literal("2"), Literal("3"))))
858     ruleTest(rule,
859       Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))),
860       Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))))
861
862     ruleTest(rule,
863       Stack(Seq(Literal(2),
864         Literal(1), Literal("2"),
865         Literal(null), Literal(null))),
866       Stack(Seq(Literal(2),
867         Literal(1), Literal("2"),
868         Literal.create(null, IntegerType), Literal.create(null, StringType))))
869
870     ruleTest(rule,
871       Stack(Seq(Literal(2),
872         Literal(1), Literal(null),
873         Literal(null), Literal("2"))),
874       Stack(Seq(Literal(2),
875         Literal(1), Literal.create(null, StringType),
876         Literal.create(null, IntegerType), Literal("2"))))
877
878     ruleTest(rule,
879       Stack(Seq(Literal(2),
880         Literal(null), Literal(1),
881         Literal("2"), Literal(null))),
882       Stack(Seq(Literal(2),
883         Literal.create(null, StringType), Literal(1),
884         Literal("2"), Literal.create(null, IntegerType))))
885
886     ruleTest(rule,
887       Stack(Seq(Literal(2),
888         Literal(null), Literal(null),
889         Literal(1), Literal("2"))),
890       Stack(Seq(Literal(2),
891         Literal.create(null, IntegerType), Literal.create(null, StringType),
892         Literal(1), Literal("2"))))
893
894     ruleTest(rule,
895       Stack(Seq(Subtract(Literal(3), Literal(1)),
896         Literal(1), Literal("2"),
897         Literal(null), Literal(null))),
898       Stack(Seq(Subtract(Literal(3), Literal(1)),
899         Literal(1), Literal("2"),
900         Literal.create(null, IntegerType), Literal.create(null, StringType))))
901   }
902
903   test("type coercion for Concat") {
904     val rule = TypeCoercion.ConcatCoercion(conf)
905
906     ruleTest(rule,
907       Concat(Seq(Literal("ab"), Literal("cde"))),
908       Concat(Seq(Literal("ab"), Literal("cde"))))
909     ruleTest(rule,
910       Concat(Seq(Literal(null), Literal("abc"))),
911       Concat(Seq(Cast(Literal(null), StringType), Literal("abc"))))
912     ruleTest(rule,
913       Concat(Seq(Literal(1), Literal("234"))),
914       Concat(Seq(Cast(Literal(1), StringType), Literal("234"))))
915     ruleTest(rule,
916       Concat(Seq(Literal("1"), Literal("234".getBytes()))),
917       Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType))))
918     ruleTest(rule,
919       Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))),
920       Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
921         Cast(Literal(0.1), StringType))))
922     ruleTest(rule,
923       Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))),
924       Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
925         Cast(Literal(3.toShort), StringType))))
926     ruleTest(rule,
927       Concat(Seq(Literal(1L), Literal(0.1))),
928       Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
929     ruleTest(rule,
930       Concat(Seq(Literal(Decimal(10)))),
931       Concat(Seq(Cast(Literal(Decimal(10)), StringType))))
932     ruleTest(rule,
933       Concat(Seq(Literal(BigDecimal.valueOf(10)))),
934       Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType))))
935     ruleTest(rule,
936       Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))),
937       Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
938     ruleTest(rule,
939       Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
940       Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType),
941         Cast(Literal(new Timestamp(0)), StringType))))
942
943     withSQLConf("spark.sql.function.concatBinaryAsString" -> "true") {
944       ruleTest(rule,
945         Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
946         Concat(Seq(Cast(Literal("123".getBytes), StringType),
947           Cast(Literal("456".getBytes), StringType))))
948     }
949
950     withSQLConf("spark.sql.function.concatBinaryAsString" -> "false") {
951       ruleTest(rule,
952         Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
953         Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))))
954     }
955   }
956
957   test("type coercion for Elt") {
958     val rule = TypeCoercion.EltCoercion(conf)
959
960     ruleTest(rule,
961       Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))),
962       Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))))
963     ruleTest(rule,
964       Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))),
965       Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde"))))
966     ruleTest(rule,
967       Elt(Seq(Literal(2), Literal(null), Literal("abc"))),
968       Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc"))))
969     ruleTest(rule,
970       Elt(Seq(Literal(2), Literal(1), Literal("234"))),
971       Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234"))))
972     ruleTest(rule,
973       Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))),
974       Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
975         Cast(Literal(0.1), StringType))))
976     ruleTest(rule,
977       Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))),
978       Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
979         Cast(Literal(3.toShort), StringType))))
980     ruleTest(rule,
981       Elt(Seq(Literal(1), Literal(1L), Literal(0.1))),
982       Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
983     ruleTest(rule,
984       Elt(Seq(Literal(1), Literal(Decimal(10)))),
985       Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType))))
986     ruleTest(rule,
987       Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))),
988       Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType))))
989     ruleTest(rule,
990       Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))),
991       Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
992     ruleTest(rule,
993       Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
994       Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType),
995         Cast(Literal(new Timestamp(0)), StringType))))
996
997     withSQLConf("spark.sql.function.eltOutputAsString" -> "true") {
998       ruleTest(rule,
999         Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
1000         Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType),
1001           Cast(Literal("456".getBytes), StringType))))
1002     }
1003
1004     withSQLConf("spark.sql.function.eltOutputAsString" -> "false") {
1005       ruleTest(rule,
1006         Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
1007         Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))))
1008     }
1009   }
1010
1011   test("BooleanEquality type cast") {
1012     val be = TypeCoercion.BooleanEquality
1013     // Use something more than a literal to avoid triggering the simplification rules.
1014     val one = Add(Literal(Decimal(1)), Literal(Decimal(0)))
1015
1016     ruleTest(be,
1017       EqualTo(Literal(true), one),
1018       EqualTo(Cast(Literal(true), one.dataType), one)
1019     )
1020
1021     ruleTest(be,
1022       EqualTo(one, Literal(true)),
1023       EqualTo(one, Cast(Literal(true), one.dataType))
1024     )
1025
1026     ruleTest(be,
1027       EqualNullSafe(Literal(true), one),
1028       EqualNullSafe(Cast(Literal(true), one.dataType), one)
1029     )
1030
1031     ruleTest(be,
1032       EqualNullSafe(one, Literal(true)),
1033       EqualNullSafe(one, Cast(Literal(true), one.dataType))
1034     )
1035   }
1036
1037   test("BooleanEquality simplification") {
1038     val be = TypeCoercion.BooleanEquality
1039
1040     ruleTest(be,
1041       EqualTo(Literal(true), Literal(1)),
1042       Literal(true)
1043     )
1044     ruleTest(be,
1045       EqualTo(Literal(true), Literal(0)),
1046       Not(Literal(true))
1047     )
1048     ruleTest(be,
1049       EqualNullSafe(Literal(true), Literal(1)),
1050       And(IsNotNull(Literal(true)), Literal(true))
1051     )
1052     ruleTest(be,
1053       EqualNullSafe(Literal(true), Literal(0)),
1054       And(IsNotNull(Literal(true)), Not(Literal(true)))
1055     )
1056
1057     ruleTest(be,
1058       EqualTo(Literal(true), Literal(1L)),
1059       Literal(true)
1060     )
1061     ruleTest(be,
1062       EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)),
1063       Literal(true)
1064     )
1065     ruleTest(be,
1066       EqualTo(Literal(BigDecimal(0)), Literal(true)),
1067       Not(Literal(true))
1068     )
1069     ruleTest(be,
1070       EqualTo(Literal(Decimal(1)), Literal(true)),
1071       Literal(true)
1072     )
1073     ruleTest(be,
1074       EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)),
1075       Literal(true)
1076     )
1077   }
1078
1079   private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
1080     logical.output.zip(expectTypes).foreach { case (attr, dt) =>
1081       assert(attr.dataType === dt)
1082     }
1083   }
1084
1085   private val timeZoneResolver = ResolveTimeZone(new SQLConf)
1086
1087   private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
1088     timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan))
1089   }
1090
1091   test("WidenSetOperationTypes for except and intersect") {
1092     val firstTable = LocalRelation(
1093       AttributeReference("i", IntegerType)(),
1094       AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
1095       AttributeReference("b", ByteType)(),
1096       AttributeReference("d", DoubleType)())
1097     val secondTable = LocalRelation(
1098       AttributeReference("s", StringType)(),
1099       AttributeReference("d", DecimalType(2, 1))(),
1100       AttributeReference("f", FloatType)(),
1101       AttributeReference("l", LongType)())
1102
1103     val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
1104
1105     val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except]
1106     val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
1107     checkOutput(r1.left, expectedTypes)
1108     checkOutput(r1.right, expectedTypes)
1109     checkOutput(r2.left, expectedTypes)
1110     checkOutput(r2.right, expectedTypes)
1111
1112     // Check if a Project is added
1113     assert(r1.left.isInstanceOf[Project])
1114     assert(r1.right.isInstanceOf[Project])
1115     assert(r2.left.isInstanceOf[Project])
1116     assert(r2.right.isInstanceOf[Project])
1117   }
1118
1119   test("WidenSetOperationTypes for union") {
1120     val firstTable = LocalRelation(
1121       AttributeReference("i", IntegerType)(),
1122       AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
1123       AttributeReference("b", ByteType)(),
1124       AttributeReference("d", DoubleType)())
1125     val secondTable = LocalRelation(
1126       AttributeReference("s", StringType)(),
1127       AttributeReference("d", DecimalType(2, 1))(),
1128       AttributeReference("f", FloatType)(),
1129       AttributeReference("l", LongType)())
1130     val thirdTable = LocalRelation(
1131       AttributeReference("m", StringType)(),
1132       AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
1133       AttributeReference("p", FloatType)(),
1134       AttributeReference("q", DoubleType)())
1135     val forthTable = LocalRelation(
1136       AttributeReference("m", StringType)(),
1137       AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
1138       AttributeReference("p", ByteType)(),
1139       AttributeReference("q", DoubleType)())
1140
1141     val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
1142
1143     val unionRelation = widenSetOperationTypes(
1144       Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
1145     assert(unionRelation.children.length == 4)
1146     checkOutput(unionRelation.children.head, expectedTypes)
1147     checkOutput(unionRelation.children(1), expectedTypes)
1148     checkOutput(unionRelation.children(2), expectedTypes)
1149     checkOutput(unionRelation.children(3), expectedTypes)
1150
1151     assert(unionRelation.children.head.isInstanceOf[Project])
1152     assert(unionRelation.children(1).isInstanceOf[Project])
1153     assert(unionRelation.children(2).isInstanceOf[Project])
1154     assert(unionRelation.children(3).isInstanceOf[Project])
1155   }
1156
1157   test("Transform Decimal precision/scale for union except and intersect") {
1158     def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
1159       logical.output.zip(expectTypes).foreach { case (attr, dt) =>
1160         assert(attr.dataType === dt)
1161       }
1162     }
1163
1164     val left1 = LocalRelation(
1165       AttributeReference("l", DecimalType(10, 8))())
1166     val right1 = LocalRelation(
1167       AttributeReference("r", DecimalType(5, 5))())
1168     val expectedType1 = Seq(DecimalType(10, 8))
1169
1170     val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union]
1171     val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except]
1172     val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect]
1173
1174     checkOutput(r1.children.head, expectedType1)
1175     checkOutput(r1.children.last, expectedType1)
1176     checkOutput(r2.left, expectedType1)
1177     checkOutput(r2.right, expectedType1)
1178     checkOutput(r3.left, expectedType1)
1179     checkOutput(r3.right, expectedType1)
1180
1181     val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))())
1182
1183     val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
1184     val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
1185       DecimalType(25, 5), DoubleType, DoubleType)
1186
1187     rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
1188       val plan2 = LocalRelation(
1189         AttributeReference("r", rType)())
1190
1191       val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union]
1192       val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except]
1193       val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect]
1194
1195       checkOutput(r1.children.last, Seq(expectedType))
1196       checkOutput(r2.right, Seq(expectedType))
1197       checkOutput(r3.right, Seq(expectedType))
1198
1199       val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union]
1200       val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except]
1201       val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect]
1202
1203       checkOutput(r4.children.last, Seq(expectedType))
1204       checkOutput(r5.left, Seq(expectedType))
1205       checkOutput(r6.left, Seq(expectedType))
1206     }
1207   }
1208
1209   test("rule for date/timestamp operations") {
1210     val dateTimeOperations = TypeCoercion.DateTimeOperations
1211     val date = Literal(new java.sql.Date(0L))
1212     val timestamp = Literal(new Timestamp(0L))
1213     val interval = Literal(new CalendarInterval(0, 0))
1214     val str = Literal("2015-01-01")
1215
1216     ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType))
1217     ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType))
1218     ruleTest(dateTimeOperations, Add(timestamp, interval),
1219       Cast(TimeAdd(timestamp, interval), TimestampType))
1220     ruleTest(dateTimeOperations, Add(interval, timestamp),
1221       Cast(TimeAdd(timestamp, interval), TimestampType))
1222     ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType))
1223     ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType))
1224
1225     ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType))
1226     ruleTest(dateTimeOperations, Subtract(timestamp, interval),
1227       Cast(TimeSub(timestamp, interval), TimestampType))
1228     ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType))
1229
1230     // interval operations should not be effected
1231     ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval))
1232     ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval))
1233   }
1234
1235   /**
1236    * There are rules that need to not fire before child expressions get resolved.
1237    * We use this test to make sure those rules do not fire early.
1238    */
1239   test("make sure rules do not fire early") {
1240     // InConversion
1241     val inConversion = TypeCoercion.InConversion(conf)
1242     ruleTest(inConversion,
1243       In(UnresolvedAttribute("a"), Seq(Literal(1))),
1244       In(UnresolvedAttribute("a"), Seq(Literal(1)))
1245     )
1246     ruleTest(inConversion,
1247       In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))),
1248       In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1)))
1249     )
1250     ruleTest(inConversion,
1251       In(Literal("a"), Seq(Literal(1), Literal("b"))),
1252       In(Cast(Literal("a"), StringType),
1253         Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
1254     )
1255   }
1256
1257   test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
1258     "in aggregation function like sum") {
1259     val rules = Seq(FunctionArgumentConversion(conf), Division)
1260     // Casts Integer to Double
1261     ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
1262     // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
1263     // cast the right expression to Double.
1264     ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3)))
1265     // Left expression is Int, right expression is Double
1266     ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType))))
1267     // Casts Float to Double
1268     ruleTest(
1269       rules,
1270       sum(Divide(4.0f, 3)),
1271       sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType))))
1272     // Left expression is Decimal, right expression is Int. Another rule DecimalPrecision will cast
1273     // the right expression to Decimal.
1274     ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3)))
1275   }
1276
1277   test("SPARK-17117 null type coercion in divide") {
1278     val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf))
1279     val nullLit = Literal.create(null, NullType)
1280     ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
1281     ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
1282   }
1283
1284   test("binary comparison with string promotion") {
1285     val rule = TypeCoercion.PromoteStrings(conf)
1286     ruleTest(rule,
1287       GreaterThan(Literal("123"), Literal(1)),
1288       GreaterThan(Cast(Literal("123"), IntegerType), Literal(1)))
1289     ruleTest(rule,
1290       LessThan(Literal(true), Literal("123")),
1291       LessThan(Literal(true), Cast(Literal("123"), BooleanType)))
1292     ruleTest(rule,
1293       EqualTo(Literal(Array(1, 2)), Literal("123")),
1294       EqualTo(Literal(Array(1, 2)), Literal("123")))
1295     ruleTest(rule,
1296       GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))),
1297       GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")),
1298         DoubleType)))
1299     Seq(true, false).foreach { convertToTS =>
1300       withSQLConf(
1301         "spark.sql.typeCoercion.compareDateTimestampInTimestamp" -> convertToTS.toString) {
1302         val date0301 = Literal(java.sql.Date.valueOf("2017-03-01"))
1303         val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00"))
1304         val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01"))
1305         if (convertToTS) {
1306           // `Date` should be treated as timestamp at 00:00:00 See SPARK-23549
1307           ruleTest(rule, EqualTo(date0301, timestamp0301000000),
1308             EqualTo(Cast(date0301, TimestampType), timestamp0301000000))
1309           ruleTest(rule, LessThan(date0301, timestamp0301000001),
1310             LessThan(Cast(date0301, TimestampType), timestamp0301000001))
1311         } else {
1312           ruleTest(rule, LessThan(date0301, timestamp0301000000),
1313             LessThan(Cast(date0301, StringType), Cast(timestamp0301000000, StringType)))
1314           ruleTest(rule, LessThan(date0301, timestamp0301000001),
1315             LessThan(Cast(date0301, StringType), Cast(timestamp0301000001, StringType)))
1316         }
1317       }
1318     }
1319   }
1320
1321   test("cast WindowFrame boundaries to the type they operate upon") {
1322     // Can cast frame boundaries to order dataType.
1323     ruleTest(WindowFrameCoercion,
1324       windowSpec(
1325         Seq(UnresolvedAttribute("a")),
1326         Seq(SortOrder(Literal(1L), Ascending)),
1327         SpecifiedWindowFrame(RangeFrame, Literal(3), Literal(2147483648L))),
1328       windowSpec(
1329         Seq(UnresolvedAttribute("a")),
1330         Seq(SortOrder(Literal(1L), Ascending)),
1331         SpecifiedWindowFrame(RangeFrame, Cast(3, LongType), Literal(2147483648L)))
1332     )
1333     // Cannot cast frame boundaries to order dataType.
1334     ruleTest(WindowFrameCoercion,
1335       windowSpec(
1336         Seq(UnresolvedAttribute("a")),
1337         Seq(SortOrder(Literal.default(DateType), Ascending)),
1338         SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))),
1339       windowSpec(
1340         Seq(UnresolvedAttribute("a")),
1341         Seq(SortOrder(Literal.default(DateType), Ascending)),
1342         SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L)))
1343     )
1344     // Should not cast SpecialFrameBoundary.
1345     ruleTest(WindowFrameCoercion,
1346       windowSpec(
1347         Seq(UnresolvedAttribute("a")),
1348         Seq(SortOrder(Literal(1L), Ascending)),
1349         SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)),
1350       windowSpec(
1351         Seq(UnresolvedAttribute("a")),
1352         Seq(SortOrder(Literal(1L), Ascending)),
1353         SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing))
1354     )
1355   }
1356 }
1357
1358
object TypeCoercionSuite {

  /**
   * Test-only unary expression whose single input may be of any type. Its own type is
   * [[NullType]], and it mixes in [[Unevaluable]], so it is meant to be type-checked, not
   * evaluated.
   */
  case class AnyTypeUnaryExpression(child: Expression)
    extends UnaryExpression
    with ExpectsInputTypes
    with Unevaluable {

    override def dataType: DataType = NullType

    override def inputTypes: Seq[AbstractDataType] = AnyDataType :: Nil
  }

  /**
   * Test-only unary expression whose single input is expected to be numeric. Its own type is
   * [[NullType]], and it mixes in [[Unevaluable]].
   */
  case class NumericTypeUnaryExpression(child: Expression)
    extends UnaryExpression
    with ExpectsInputTypes
    with Unevaluable {

    override def dataType: DataType = NullType

    override def inputTypes: Seq[AbstractDataType] = NumericType :: Nil
  }

  /**
   * Test-only binary operator accepting operands of any type, rendered with the symbol
   * "anytype". Its own type is [[NullType]].
   */
  case class AnyTypeBinaryOperator(left: Expression, right: Expression)
    extends BinaryOperator
    with Unevaluable {

    override def symbol: String = "anytype"

    override def inputType: AbstractDataType = AnyDataType

    override def dataType: DataType = NullType
  }

  /**
   * Test-only binary operator expecting numeric operands, rendered with the symbol
   * "numerictype". Its own type is [[NullType]].
   */
  case class NumericTypeBinaryOperator(left: Expression, right: Expression)
    extends BinaryOperator
    with Unevaluable {

    override def symbol: String = "numerictype"

    override def inputType: AbstractDataType = NumericType

    override def dataType: DataType = NullType
  }
}