[SPARK-24208][SQL][FOLLOWUP] Move test cases to proper locations
authorMarco Gaido <marcogaido91@gmail.com>
Thu, 12 Jul 2018 22:13:26 +0000 (15:13 -0700)
committerXiao Li <gatorsmile@gmail.com>
Thu, 12 Jul 2018 22:13:26 +0000 (15:13 -0700)
## What changes were proposed in this pull request?

The PR is a followup to move the test cases introduced by the original PR in their proper location.

## How was this patch tested?

moved UTs

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #21751 from mgaido91/SPARK-24208_followup.

python/pyspark/sql/tests.py
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala

index 4404dbe..565654e 100644 (file)
@@ -5471,6 +5471,22 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
                 self.assertEqual(r.a, 'hi')
                 self.assertEqual(r.b, 1)
 
+    def test_self_join_with_pandas(self):
+        import pyspark.sql.functions as F
+
+        @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+        def dummy_pandas_udf(df):
+            return df[['key', 'col']]
+
+        df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
+                                         Row(key=2, col='C')])
+        df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf)
+
+        # this was throwing an AnalysisException before SPARK-24208
+        res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
+                                                 F.col('temp0.key') == F.col('temp1.key'))
+        self.assertEquals(res.count(), 5)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
@@ -5925,22 +5941,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                     'mixture.*aggregate function.*group aggregate pandas UDF'):
                 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
-    def test_self_join_with_pandas(self):
-        import pyspark.sql.functions as F
-
-        @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
-        def dummy_pandas_udf(df):
-            return df[['key', 'col']]
-
-        df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
-                                         Row(key=2, col='C')])
-        dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
-
-        # this was throwing an AnalysisException before SPARK-24208
-        res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
-                                               F.col('temp0.key') == F.col('temp1.key'))
-        self.assertEquals(res.count(), 5)
-
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
index cd85795..bbcdf6c 100644 (file)
@@ -21,6 +21,7 @@ import java.util.TimeZone
 
 import org.scalatest.Matchers
 
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -557,4 +558,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       SubqueryAlias("tbl", testRelation)))
     assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'"))
   }
+
+  test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
+    val pythonUdf = PythonUDF("pyUDF", null,
+      StructType(Seq(StructField("a", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      true)
+    val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
+    val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
+    val flatMapGroupsInPandas = FlatMapGroupsInPandas(
+      Seq(UnresolvedAttribute("a")), pythonUdf, output, project)
+    val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
+    val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
+    val join = Join(left, right, Inner, None)
+    assertAnalysisSuccess(
+      Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
+  }
 }
index bd54ea4..147c0b6 100644 (file)
@@ -93,16 +93,4 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext {
     }
     datasetWithUDF.unpersist(true)
   }
-
-  test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
-    val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
-      "pyUDF",
-      null,
-      StructType(Seq(StructField("s", LongType))),
-      Seq.empty,
-      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-      true))
-    val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s")
-    df1.queryExecution.assertAnalyzed()
-  }
 }