[SPARK-24776][SQL] Avro unit test: use SQLTestUtils and replace deprecated methods
authorGengliang Wang <gengliang.wang@databricks.com>
Fri, 13 Jul 2018 15:55:46 +0000 (08:55 -0700)
committerXiao Li <gatorsmile@gmail.com>
Fri, 13 Jul 2018 15:55:46 +0000 (08:55 -0700)
## What changes were proposed in this pull request?
Improve Avro unit test:
1. use QueryTest/SharedSQLContext/SQLTestUtils, instead of the duplicated test utils.
2. replace deprecated methods

## How was this patch tested?

Unit test

Author: Gengliang Wang <gengliang.wang@databricks.com>

Closes #21760 from gengliangwang/improve_avro_test.

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala [deleted file]

index c6c1e40..108b347 100644 (file)
@@ -31,32 +31,24 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._
 
-class AvroSuite extends SparkFunSuite {
+class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   val episodesFile = "src/test/resources/episodes.avro"
   val testFile = "src/test/resources/test.avro"
 
-  private var spark: SparkSession = _
-
   override protected def beforeAll(): Unit = {
     super.beforeAll()
-    spark = SparkSession.builder()
-      .master("local[2]")
-      .appName("AvroSuite")
-      .config("spark.sql.files.maxPartitionBytes", 1024)
-      .getOrCreate()
-  }
-
-  override protected def afterAll(): Unit = {
-    try {
-      spark.sparkContext.stop()
-    } finally {
-      super.afterAll()
-    }
+    spark.conf.set("spark.sql.files.maxPartitionBytes", 1024)
+  }
+
+  def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = {
+    val originalEntries = spark.read.avro(testFile).collect()
+    val newEntries = spark.read.avro(newFile)
+    checkAnswer(newEntries, originalEntries)
   }
 
   test("reading from multiple paths") {
@@ -68,7 +60,7 @@ class AvroSuite extends SparkFunSuite {
     val df = spark.read.avro(episodesFile)
     val fields = List("title", "air_date", "doctor")
     for (field <- fields) {
-      TestUtils.withTempDir { dir =>
+      withTempPath { dir =>
         val outputDir = s"$dir/${UUID.randomUUID}"
         df.write.partitionBy(field).avro(outputDir)
         val input = spark.read.avro(outputDir)
@@ -82,12 +74,12 @@ class AvroSuite extends SparkFunSuite {
 
   test("request no fields") {
     val df = spark.read.avro(episodesFile)
-    df.registerTempTable("avro_table")
+    df.createOrReplaceTempView("avro_table")
     assert(spark.sql("select count(*) from avro_table").collect().head === Row(8))
   }
 
   test("convert formats") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val df = spark.read.avro(episodesFile)
       df.write.parquet(dir.getCanonicalPath)
       assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count)
@@ -95,15 +87,16 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("rearrange internal schema") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val df = spark.read.avro(episodesFile)
       df.select("doctor", "title").write.avro(dir.getCanonicalPath)
     }
   }
 
   test("test NULL avro type") {
-    TestUtils.withTempDir { dir =>
-      val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava
+    withTempPath { dir =>
+      val fields =
+        Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava
       val schema = Schema.createRecord("name", "docs", "namespace", false)
       schema.setFields(fields)
       val datumWriter = new GenericDatumWriter[GenericRecord](schema)
@@ -122,11 +115,11 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("union(int, long) is read as long") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val avroSchema: Schema = {
         val union =
           Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava)
-        val fields = Seq(new Field("field1", union, "doc", null)).asJava
+        val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
         val schema = Schema.createRecord("name", "docs", "namespace", false)
         schema.setFields(fields)
         schema
@@ -150,11 +143,11 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("union(float, double) is read as double") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val avroSchema: Schema = {
         val union =
           Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava)
-        val fields = Seq(new Field("field1", union, "doc", null)).asJava
+        val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
         val schema = Schema.createRecord("name", "docs", "namespace", false)
         schema.setFields(fields)
         schema
@@ -178,7 +171,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("union(float, double, null) is read as nullable double") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val avroSchema: Schema = {
         val union = Schema.createUnion(
           List(Schema.create(Type.FLOAT),
@@ -186,7 +179,7 @@ class AvroSuite extends SparkFunSuite {
             Schema.create(Type.NULL)
           ).asJava
         )
-        val fields = Seq(new Field("field1", union, "doc", null)).asJava
+        val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
         val schema = Schema.createRecord("name", "docs", "namespace", false)
         schema.setFields(fields)
         schema
@@ -210,9 +203,9 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("Union of a single type") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava)
-      val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava
+      val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava
       val schema = Schema.createRecord("name", "docs", "namespace", false)
       schema.setFields(fields)
 
@@ -233,16 +226,16 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("Complex Union Type") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4)
       val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava)
       val complexUnionType = Schema.createUnion(
         List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava)
       val fields = Seq(
-        new Field("field1", complexUnionType, "doc", null),
-        new Field("field2", complexUnionType, "doc", null),
-        new Field("field3", complexUnionType, "doc", null),
-        new Field("field4", complexUnionType, "doc", null)
+        new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]),
+        new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]),
+        new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]),
+        new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any])
       ).asJava
       val schema = Schema.createRecord("name", "docs", "namespace", false)
       schema.setFields(fields)
@@ -271,7 +264,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("Lots of nulls") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val schema = StructType(Seq(
         StructField("binary", BinaryType, true),
         StructField("timestamp", TimestampType, true),
@@ -290,7 +283,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("Struct field type") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val schema = StructType(Seq(
         StructField("float", FloatType, true),
         StructField("short", ShortType, true),
@@ -309,7 +302,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("Date field type") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val schema = StructType(Seq(
         StructField("float", FloatType, true),
         StructField("date", DateType, true)
@@ -329,7 +322,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("Array data types") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val testSchema = StructType(Seq(
         StructField("byte_array", ArrayType(ByteType), true),
         StructField("short_array", ArrayType(ShortType), true),
@@ -363,13 +356,12 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("write with compression") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
       val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
       val uncompressDir = s"$dir/uncompress"
       val deflateDir = s"$dir/deflate"
       val snappyDir = s"$dir/snappy"
-      val fakeDir = s"$dir/fake"
 
       val df = spark.read.avro(testFile)
       spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed")
@@ -439,7 +431,7 @@ class AvroSuite extends SparkFunSuite {
   test("sql test") {
     spark.sql(
       s"""
-         |CREATE TEMPORARY TABLE avroTable
+         |CREATE TEMPORARY VIEW avroTable
          |USING avro
          |OPTIONS (path "$episodesFile")
       """.stripMargin.replaceAll("\n", " "))
@@ -450,24 +442,24 @@ class AvroSuite extends SparkFunSuite {
   test("conversion to avro and back") {
     // Note that test.avro includes a variety of types, some of which are nullable. We expect to
     // get the same values back.
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val avroDir = s"$dir/avro"
       spark.read.avro(testFile).write.avro(avroDir)
-      TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir)
+      checkReloadMatchesSaved(testFile, avroDir)
     }
   }
 
   test("conversion to avro and back with namespace") {
     // Note that test.avro includes a variety of types, some of which are nullable. We expect to
     // get the same values back.
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       val name = "AvroTest"
       val namespace = "com.databricks.spark.avro"
       val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)
 
       val avroDir = tempDir + "/namedAvro"
       spark.read.avro(testFile).write.options(parameters).avro(avroDir)
-      TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir)
+      checkReloadMatchesSaved(testFile, avroDir)
 
       // Look at raw file and make sure has namespace info
       val rawSaved = spark.sparkContext.textFile(avroDir)
@@ -478,7 +470,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("converting some specific sparkSQL types to avro") {
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       val testSchema = StructType(Seq(
         StructField("Name", StringType, false),
         StructField("Length", IntegerType, true),
@@ -520,7 +512,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("correctly read long as date/timestamp type") {
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       val sparkSession = spark
       import sparkSession.implicits._
 
@@ -549,7 +541,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("does not coerce null date/timestamp value to 0 epoch.") {
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       val sparkSession = spark
       import sparkSession.implicits._
 
@@ -610,7 +602,7 @@ class AvroSuite extends SparkFunSuite {
 
     // Directory given has no avro files
     intercept[AnalysisException] {
-      TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath))
+      withTempPath(dir => spark.read.avro(dir.getCanonicalPath))
     }
 
     intercept[AnalysisException] {
@@ -624,7 +616,7 @@ class AvroSuite extends SparkFunSuite {
     }
 
     intercept[FileNotFoundException] {
-      TestUtils.withTempDir { dir =>
+      withTempPath { dir =>
         FileUtils.touch(new File(dir, "test"))
         spark.read.avro(dir.toString)
       }
@@ -633,19 +625,19 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("SQL test insert overwrite") {
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       val tempEmptyDir = s"$tempDir/sqlOverwrite"
       // Create a temp directory for table that will be overwritten
       new File(tempEmptyDir).mkdirs()
       spark.sql(
         s"""
-           |CREATE TEMPORARY TABLE episodes
+           |CREATE TEMPORARY VIEW episodes
            |USING avro
            |OPTIONS (path "$episodesFile")
          """.stripMargin.replaceAll("\n", " "))
       spark.sql(
         s"""
-           |CREATE TEMPORARY TABLE episodesEmpty
+           |CREATE TEMPORARY VIEW episodesEmpty
            |(name string, air_date string, doctor int)
            |USING avro
            |OPTIONS (path "$tempEmptyDir")
@@ -665,7 +657,7 @@ class AvroSuite extends SparkFunSuite {
 
   test("test save and load") {
     // Test if load works as expected
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       val df = spark.read.avro(episodesFile)
       assert(df.count == 8)
 
@@ -679,7 +671,7 @@ class AvroSuite extends SparkFunSuite {
 
   test("test load with non-Avro file") {
     // Test if load works as expected
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       val df = spark.read.avro(episodesFile)
       assert(df.count == 8)
 
@@ -737,7 +729,7 @@ class AvroSuite extends SparkFunSuite {
   }
 
   test("read avro file partitioned") {
-    TestUtils.withTempDir { dir =>
+    withTempPath { dir =>
       val sparkSession = spark
       import sparkSession.implicits._
       val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records")
@@ -756,7 +748,7 @@ class AvroSuite extends SparkFunSuite {
   case class NestedTop(id: Int, data: NestedMiddle)
 
   test("saving avro that has nested records with the same name") {
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       // Save avro file on output folder path
       val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1")))))
       val outputFolder = s"$tempDir/duplicate_names/"
@@ -773,7 +765,7 @@ class AvroSuite extends SparkFunSuite {
   case class NestedTopArray(id: Int, data: NestedMiddleArray)
 
   test("saving avro that has nested records with the same name inside an array") {
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       // Save avro file on output folder path
       val writeDf = spark.createDataFrame(
         List(NestedTopArray(1, NestedMiddleArray(2, Array(
@@ -794,7 +786,7 @@ class AvroSuite extends SparkFunSuite {
   case class NestedTopMap(id: Int, data: NestedMiddleMap)
 
   test("saving avro that has nested records with the same name inside a map") {
-    TestUtils.withTempDir { tempDir =>
+    withTempPath { tempDir =>
       // Save avro file on output folder path
       val writeDf = spark.createDataFrame(
         List(NestedTopMap(1, NestedMiddleMap(2, Map(
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala
deleted file mode 100755 (executable)
index 4ae9b14..0000000
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.avro
-
-import java.io.{File, IOException}
-import java.nio.ByteBuffer
-
-import scala.collection.immutable.HashSet
-import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
-
-import com.google.common.io.Files
-import java.util
-
-import org.apache.spark.sql.SparkSession
-
-private[avro] object TestUtils {
-
-  /**
-   * This function checks that all records in a file match the original
-   * record.
-   */
-  def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = {
-
-    def convertToString(elem: Any): String = {
-      elem match {
-        case null => "NULL" // HashSets can't have null in them, so we use a string instead
-        case arrayBuf: ArrayBuffer[_] =>
-          arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ")
-        case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ")
-        case other => other.toString
-      }
-    }
-
-    val originalEntries = spark.read.avro(testFile).collect()
-    val newEntries = spark.read.avro(avroDir).collect()
-
-    assert(originalEntries.length == newEntries.length)
-
-    val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]())
-    for (origEntry <- originalEntries) {
-      var idx = 0
-      for (origElement <- origEntry.toSeq) {
-        origEntrySet(idx) += convertToString(origElement)
-        idx += 1
-      }
-    }
-
-    for (newEntry <- newEntries) {
-      var idx = 0
-      for (newElement <- newEntry.toSeq) {
-        assert(origEntrySet(idx).contains(convertToString(newElement)))
-        idx += 1
-      }
-    }
-  }
-
-  def withTempDir(f: File => Unit): Unit = {
-    val dir = Files.createTempDir()
-    dir.delete()
-    try f(dir) finally deleteRecursively(dir)
-  }
-
-  /**
-   * This function deletes a file or a directory with everything that's in it. This function is
-   * copied from Spark with minor modifications made to it. See original source at:
-   * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala
-   */
-
-  def deleteRecursively(file: File) {
-    def listFilesSafely(file: File): Seq[File] = {
-      if (file.exists()) {
-        val files = file.listFiles()
-        if (files == null) {
-          throw new IOException("Failed to list files for dir: " + file)
-        }
-        files
-      } else {
-        List()
-      }
-    }
-
-    if (file != null) {
-      try {
-        if (file.isDirectory) {
-          var savedIOException: IOException = null
-          for (child <- listFilesSafely(file)) {
-            try {
-              deleteRecursively(child)
-            } catch {
-              // In case of multiple exceptions, only last one will be thrown
-              case ioe: IOException => savedIOException = ioe
-            }
-          }
-          if (savedIOException != null) {
-            throw savedIOException
-          }
-        }
-      } finally {
-        if (!file.delete()) {
-          // Delete can also fail if the file simply did not exist
-          if (file.exists()) {
-            throw new IOException("Failed to delete: " + file.getAbsolutePath)
-          }
-        }
-      }
-    }
-  }
-
-  /**
-   * This function generates a random map(string, int) of a given size.
-   */
-  private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = {
-    val jMap = new util.HashMap[String, Int]()
-    for (i <- 0 until size) {
-      jMap.put(rand.nextString(5), i)
-    }
-    jMap
-  }
-
-  /**
-   * This function generates a random array of booleans of a given size.
-   */
-  private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = {
-    val vec = new util.ArrayList[Boolean]()
-    for (i <- 0 until size) {
-      vec.add(rand.nextBoolean())
-    }
-    vec
-  }
-
-  /**
-   * This function generates a random ByteBuffer of a given size.
-   */
-  private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = {
-    val bb = ByteBuffer.allocate(size)
-    val arrayOfBytes = new Array[Byte](size)
-    rand.nextBytes(arrayOfBytes)
-    bb.put(arrayOfBytes)
-  }
-}