Revert "[SPARK-24776][SQL] Avro unit test: use SQLTestUtils and replace deprecated...
authorXiao Li <gatorsmile@gmail.com>
Fri, 13 Jul 2018 17:06:26 +0000 (10:06 -0700)
committerXiao Li <gatorsmile@gmail.com>
Fri, 13 Jul 2018 17:06:26 +0000 (10:06 -0700)
This reverts commit c1b62e420a43aa7da36733ccdbec057d87ac1b43.

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala [new file with mode: 0755]

index 108b347..c6c1e40 100644 (file)
@@ -31,24 +31,32 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
 
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException
-import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._
 
-class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+class AvroSuite extends SparkFunSuite {
   val episodesFile = "src/test/resources/episodes.avro"
   val testFile = "src/test/resources/test.avro"
 
+  private var spark: SparkSession = _
+
   override protected def beforeAll(): Unit = {
     super.beforeAll()
-    spark.conf.set("spark.sql.files.maxPartitionBytes", 1024)
-  }
-
-  def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = {
-    val originalEntries = spark.read.avro(testFile).collect()
-    val newEntries = spark.read.avro(newFile)
-    checkAnswer(newEntries, originalEntries)
+    spark = SparkSession.builder()
+      .master("local[2]")
+      .appName("AvroSuite")
+      .config("spark.sql.files.maxPartitionBytes", 1024)
+      .getOrCreate()
+  }
+
+  override protected def afterAll(): Unit = {
+    try {
+      spark.sparkContext.stop()
+    } finally {
+      super.afterAll()
+    }
   }
 
   test("reading from multiple paths") {
@@ -60,7 +68,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     val df = spark.read.avro(episodesFile)
     val fields = List("title", "air_date", "doctor")
     for (field <- fields) {
-      withTempPath { dir =>
+      TestUtils.withTempDir { dir =>
         val outputDir = s"$dir/${UUID.randomUUID}"
         df.write.partitionBy(field).avro(outputDir)
         val input = spark.read.avro(outputDir)
@@ -74,12 +82,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
   test("request no fields") {
     val df = spark.read.avro(episodesFile)
-    df.createOrReplaceTempView("avro_table")
+    df.registerTempTable("avro_table")
     assert(spark.sql("select count(*) from avro_table").collect().head === Row(8))
   }
 
   test("convert formats") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val df = spark.read.avro(episodesFile)
       df.write.parquet(dir.getCanonicalPath)
       assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count)
@@ -87,16 +95,15 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("rearrange internal schema") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val df = spark.read.avro(episodesFile)
       df.select("doctor", "title").write.avro(dir.getCanonicalPath)
     }
   }
 
   test("test NULL avro type") {
-    withTempPath { dir =>
-      val fields =
-        Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava
+    TestUtils.withTempDir { dir =>
+      val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava
       val schema = Schema.createRecord("name", "docs", "namespace", false)
       schema.setFields(fields)
       val datumWriter = new GenericDatumWriter[GenericRecord](schema)
@@ -115,11 +122,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("union(int, long) is read as long") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val avroSchema: Schema = {
         val union =
           Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava)
-        val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
+        val fields = Seq(new Field("field1", union, "doc", null)).asJava
         val schema = Schema.createRecord("name", "docs", "namespace", false)
         schema.setFields(fields)
         schema
@@ -143,11 +150,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("union(float, double) is read as double") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val avroSchema: Schema = {
         val union =
           Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava)
-        val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
+        val fields = Seq(new Field("field1", union, "doc", null)).asJava
         val schema = Schema.createRecord("name", "docs", "namespace", false)
         schema.setFields(fields)
         schema
@@ -171,7 +178,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("union(float, double, null) is read as nullable double") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val avroSchema: Schema = {
         val union = Schema.createUnion(
           List(Schema.create(Type.FLOAT),
@@ -179,7 +186,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
             Schema.create(Type.NULL)
           ).asJava
         )
-        val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
+        val fields = Seq(new Field("field1", union, "doc", null)).asJava
         val schema = Schema.createRecord("name", "docs", "namespace", false)
         schema.setFields(fields)
         schema
@@ -203,9 +210,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("Union of a single type") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava)
-      val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava
+      val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava
       val schema = Schema.createRecord("name", "docs", "namespace", false)
       schema.setFields(fields)
 
@@ -226,16 +233,16 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("Complex Union Type") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4)
       val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava)
       val complexUnionType = Schema.createUnion(
         List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava)
       val fields = Seq(
-        new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]),
-        new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]),
-        new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]),
-        new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any])
+        new Field("field1", complexUnionType, "doc", null),
+        new Field("field2", complexUnionType, "doc", null),
+        new Field("field3", complexUnionType, "doc", null),
+        new Field("field4", complexUnionType, "doc", null)
       ).asJava
       val schema = Schema.createRecord("name", "docs", "namespace", false)
       schema.setFields(fields)
@@ -264,7 +271,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("Lots of nulls") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val schema = StructType(Seq(
         StructField("binary", BinaryType, true),
         StructField("timestamp", TimestampType, true),
@@ -283,7 +290,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("Struct field type") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val schema = StructType(Seq(
         StructField("float", FloatType, true),
         StructField("short", ShortType, true),
@@ -302,7 +309,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("Date field type") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val schema = StructType(Seq(
         StructField("float", FloatType, true),
         StructField("date", DateType, true)
@@ -322,7 +329,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("Array data types") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val testSchema = StructType(Seq(
         StructField("byte_array", ArrayType(ByteType), true),
         StructField("short_array", ArrayType(ShortType), true),
@@ -356,12 +363,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("write with compression") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
       val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
       val uncompressDir = s"$dir/uncompress"
       val deflateDir = s"$dir/deflate"
       val snappyDir = s"$dir/snappy"
+      val fakeDir = s"$dir/fake"
 
       val df = spark.read.avro(testFile)
       spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed")
@@ -431,7 +439,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   test("sql test") {
     spark.sql(
       s"""
-         |CREATE TEMPORARY VIEW avroTable
+         |CREATE TEMPORARY TABLE avroTable
          |USING avro
          |OPTIONS (path "$episodesFile")
       """.stripMargin.replaceAll("\n", " "))
@@ -442,24 +450,24 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   test("conversion to avro and back") {
     // Note that test.avro includes a variety of types, some of which are nullable. We expect to
     // get the same values back.
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val avroDir = s"$dir/avro"
       spark.read.avro(testFile).write.avro(avroDir)
-      checkReloadMatchesSaved(testFile, avroDir)
+      TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir)
     }
   }
 
   test("conversion to avro and back with namespace") {
     // Note that test.avro includes a variety of types, some of which are nullable. We expect to
     // get the same values back.
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       val name = "AvroTest"
       val namespace = "com.databricks.spark.avro"
       val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)
 
       val avroDir = tempDir + "/namedAvro"
       spark.read.avro(testFile).write.options(parameters).avro(avroDir)
-      checkReloadMatchesSaved(testFile, avroDir)
+      TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir)
 
       // Look at raw file and make sure has namespace info
       val rawSaved = spark.sparkContext.textFile(avroDir)
@@ -470,7 +478,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("converting some specific sparkSQL types to avro") {
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       val testSchema = StructType(Seq(
         StructField("Name", StringType, false),
         StructField("Length", IntegerType, true),
@@ -512,7 +520,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("correctly read long as date/timestamp type") {
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       val sparkSession = spark
       import sparkSession.implicits._
 
@@ -541,7 +549,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("does not coerce null date/timestamp value to 0 epoch.") {
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       val sparkSession = spark
       import sparkSession.implicits._
 
@@ -602,7 +610,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
     // Directory given has no avro files
     intercept[AnalysisException] {
-      withTempPath(dir => spark.read.avro(dir.getCanonicalPath))
+      TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath))
     }
 
     intercept[AnalysisException] {
@@ -616,7 +624,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     }
 
     intercept[FileNotFoundException] {
-      withTempPath { dir =>
+      TestUtils.withTempDir { dir =>
         FileUtils.touch(new File(dir, "test"))
         spark.read.avro(dir.toString)
       }
@@ -625,19 +633,19 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("SQL test insert overwrite") {
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       val tempEmptyDir = s"$tempDir/sqlOverwrite"
       // Create a temp directory for table that will be overwritten
       new File(tempEmptyDir).mkdirs()
       spark.sql(
         s"""
-           |CREATE TEMPORARY VIEW episodes
+           |CREATE TEMPORARY TABLE episodes
            |USING avro
            |OPTIONS (path "$episodesFile")
          """.stripMargin.replaceAll("\n", " "))
       spark.sql(
         s"""
-           |CREATE TEMPORARY VIEW episodesEmpty
+           |CREATE TEMPORARY TABLE episodesEmpty
            |(name string, air_date string, doctor int)
            |USING avro
            |OPTIONS (path "$tempEmptyDir")
@@ -657,7 +665,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
   test("test save and load") {
     // Test if load works as expected
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       val df = spark.read.avro(episodesFile)
       assert(df.count == 8)
 
@@ -671,7 +679,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
   test("test load with non-Avro file") {
     // Test if load works as expected
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       val df = spark.read.avro(episodesFile)
       assert(df.count == 8)
 
@@ -729,7 +737,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }
 
   test("read avro file partitioned") {
-    withTempPath { dir =>
+    TestUtils.withTempDir { dir =>
       val sparkSession = spark
       import sparkSession.implicits._
       val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records")
@@ -748,7 +756,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   case class NestedTop(id: Int, data: NestedMiddle)
 
   test("saving avro that has nested records with the same name") {
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       // Save avro file on output folder path
       val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1")))))
       val outputFolder = s"$tempDir/duplicate_names/"
@@ -765,7 +773,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   case class NestedTopArray(id: Int, data: NestedMiddleArray)
 
   test("saving avro that has nested records with the same name inside an array") {
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       // Save avro file on output folder path
       val writeDf = spark.createDataFrame(
         List(NestedTopArray(1, NestedMiddleArray(2, Array(
@@ -786,7 +794,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   case class NestedTopMap(id: Int, data: NestedMiddleMap)
 
   test("saving avro that has nested records with the same name inside a map") {
-    withTempPath { tempDir =>
+    TestUtils.withTempDir { tempDir =>
       // Save avro file on output folder path
       val writeDf = spark.createDataFrame(
         List(NestedTopMap(1, NestedMiddleMap(2, Map(
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala
new file mode 100755 (executable)
index 0000000..4ae9b14
--- /dev/null
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.io.{File, IOException}
+import java.nio.ByteBuffer
+
+import scala.collection.immutable.HashSet
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+
+import com.google.common.io.Files
+import java.util
+
+import org.apache.spark.sql.SparkSession
+
+private[avro] object TestUtils {
+
+  /**
+   * This function checks that all records in a file match the original
+   * record.
+   */
+  def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = {
+
+    def convertToString(elem: Any): String = {
+      elem match {
+        case null => "NULL" // HashSets can't have null in them, so we use a string instead
+        case arrayBuf: ArrayBuffer[_] =>
+          arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ")
+        case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ")
+        case other => other.toString
+      }
+    }
+
+    val originalEntries = spark.read.avro(testFile).collect()
+    val newEntries = spark.read.avro(avroDir).collect()
+
+    assert(originalEntries.length == newEntries.length)
+
+    val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]())
+    for (origEntry <- originalEntries) {
+      var idx = 0
+      for (origElement <- origEntry.toSeq) {
+        origEntrySet(idx) += convertToString(origElement)
+        idx += 1
+      }
+    }
+
+    for (newEntry <- newEntries) {
+      var idx = 0
+      for (newElement <- newEntry.toSeq) {
+        assert(origEntrySet(idx).contains(convertToString(newElement)))
+        idx += 1
+      }
+    }
+  }
+
+  def withTempDir(f: File => Unit): Unit = {
+    val dir = Files.createTempDir()
+    dir.delete()
+    try f(dir) finally deleteRecursively(dir)
+  }
+
+  /**
+   * This function deletes a file or a directory with everything that's in it. This function is
+   * copied from Spark with minor modifications made to it. See original source at:
+   * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala
+   */
+
+  def deleteRecursively(file: File) {
+    def listFilesSafely(file: File): Seq[File] = {
+      if (file.exists()) {
+        val files = file.listFiles()
+        if (files == null) {
+          throw new IOException("Failed to list files for dir: " + file)
+        }
+        files
+      } else {
+        List()
+      }
+    }
+
+    if (file != null) {
+      try {
+        if (file.isDirectory) {
+          var savedIOException: IOException = null
+          for (child <- listFilesSafely(file)) {
+            try {
+              deleteRecursively(child)
+            } catch {
+              // In case of multiple exceptions, only last one will be thrown
+              case ioe: IOException => savedIOException = ioe
+            }
+          }
+          if (savedIOException != null) {
+            throw savedIOException
+          }
+        }
+      } finally {
+        if (!file.delete()) {
+          // Delete can also fail if the file simply did not exist
+          if (file.exists()) {
+            throw new IOException("Failed to delete: " + file.getAbsolutePath)
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * This function generates a random map(string, int) of a given size.
+   */
+  private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = {
+    val jMap = new util.HashMap[String, Int]()
+    for (i <- 0 until size) {
+      jMap.put(rand.nextString(5), i)
+    }
+    jMap
+  }
+
+  /**
+   * This function generates a random array of booleans of a given size.
+   */
+  private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = {
+    val vec = new util.ArrayList[Boolean]()
+    for (i <- 0 until size) {
+      vec.add(rand.nextBoolean())
+    }
+    vec
+  }
+
+  /**
+   * This function generates a random ByteBuffer of a given size.
+   */
+  private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = {
+    val bb = ByteBuffer.allocate(size)
+    val arrayOfBytes = new Array[Byte](size)
+    rand.nextBytes(arrayOfBytes)
+    bb.put(arrayOfBytes)
+  }
+}