[SPARK-24691][SQL] Dispatch the type support check in FileFormat implementation
[spark.git] / sql / core / src / main / scala / org / apache / spark / sql / execution / datasources / json / JsonFileFormat.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.spark.sql.execution.datasources.json
19
20 import java.nio.charset.{Charset, StandardCharsets}
21
22 import org.apache.hadoop.conf.Configuration
23 import org.apache.hadoop.fs.{FileStatus, Path}
24 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
25
26 import org.apache.spark.internal.Logging
27 import org.apache.spark.sql.{AnalysisException, SparkSession}
28 import org.apache.spark.sql.catalyst.InternalRow
29 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead}
30 import org.apache.spark.sql.catalyst.util.CompressionCodecs
31 import org.apache.spark.sql.execution.datasources._
32 import org.apache.spark.sql.sources._
33 import org.apache.spark.sql.types._
34 import org.apache.spark.util.SerializableConfiguration
35
36 class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
37   override val shortName: String = "json"
38
39   override def isSplitable(
40       sparkSession: SparkSession,
41       options: Map[String, String],
42       path: Path): Boolean = {
43     val parsedOptions = new JSONOptionsInRead(
44       options,
45       sparkSession.sessionState.conf.sessionLocalTimeZone,
46       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
47     val jsonDataSource = JsonDataSource(parsedOptions)
48     jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
49   }
50
51   override def inferSchema(
52       sparkSession: SparkSession,
53       options: Map[String, String],
54       files: Seq[FileStatus]): Option[StructType] = {
55     val parsedOptions = new JSONOptionsInRead(
56       options,
57       sparkSession.sessionState.conf.sessionLocalTimeZone,
58       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
59     JsonDataSource(parsedOptions).inferSchema(
60       sparkSession, files, parsedOptions)
61   }
62
63   override def prepareWrite(
64       sparkSession: SparkSession,
65       job: Job,
66       options: Map[String, String],
67       dataSchema: StructType): OutputWriterFactory = {
68     val conf = job.getConfiguration
69     val parsedOptions = new JSONOptions(
70       options,
71       sparkSession.sessionState.conf.sessionLocalTimeZone,
72       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
73     parsedOptions.compressionCodec.foreach { codec =>
74       CompressionCodecs.setCodecConfiguration(conf, codec)
75     }
76
77     new OutputWriterFactory {
78       override def newInstance(
79           path: String,
80           dataSchema: StructType,
81           context: TaskAttemptContext): OutputWriter = {
82         new JsonOutputWriter(path, parsedOptions, dataSchema, context)
83       }
84
85       override def getFileExtension(context: TaskAttemptContext): String = {
86         ".json" + CodecStreams.getCompressionExtension(context)
87       }
88     }
89   }
90
91   override def buildReader(
92       sparkSession: SparkSession,
93       dataSchema: StructType,
94       partitionSchema: StructType,
95       requiredSchema: StructType,
96       filters: Seq[Filter],
97       options: Map[String, String],
98       hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
99     val broadcastedHadoopConf =
100       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
101
102     val parsedOptions = new JSONOptionsInRead(
103       options,
104       sparkSession.sessionState.conf.sessionLocalTimeZone,
105       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
106
107     val actualSchema =
108       StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
109     // Check a field requirement for corrupt records here to throw an exception in a driver side
110     dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
111       val f = dataSchema(corruptFieldIndex)
112       if (f.dataType != StringType || !f.nullable) {
113         throw new AnalysisException(
114           "The field for corrupt records must be string type and nullable")
115       }
116     }
117
118     if (requiredSchema.length == 1 &&
119       requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
120       throw new AnalysisException(
121         "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" +
122         "referenced columns only include the internal corrupt record column\n" +
123         s"(named _corrupt_record by default). For example:\n" +
124         "spark.read.schema(schema).json(file).filter($\"_corrupt_record\".isNotNull).count()\n" +
125         "and spark.read.schema(schema).json(file).select(\"_corrupt_record\").show().\n" +
126         "Instead, you can cache or save the parsed results and then send the same query.\n" +
127         "For example, val df = spark.read.schema(schema).json(file).cache() and then\n" +
128         "df.filter($\"_corrupt_record\".isNotNull).count()."
129       )
130     }
131
132     (file: PartitionedFile) => {
133       val parser = new JacksonParser(actualSchema, parsedOptions)
134       JsonDataSource(parsedOptions).readFile(
135         broadcastedHadoopConf.value.value,
136         file,
137         parser,
138         requiredSchema)
139     }
140   }
141
142   override def toString: String = "JSON"
143
144   override def hashCode(): Int = getClass.hashCode()
145
146   override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat]
147
148   override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
149     case _: AtomicType => true
150
151     case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
152
153     case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
154
155     case MapType(keyType, valueType, _) =>
156       supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
157
158     case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
159
160     case _: NullType => true
161
162     case _ => false
163   }
164 }
165
166 private[json] class JsonOutputWriter(
167     path: String,
168     options: JSONOptions,
169     dataSchema: StructType,
170     context: TaskAttemptContext)
171   extends OutputWriter with Logging {
172
173   private val encoding = options.encoding match {
174     case Some(charsetName) => Charset.forName(charsetName)
175     case None => StandardCharsets.UTF_8
176   }
177
178   if (JSONOptionsInRead.blacklist.contains(encoding)) {
179     logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" +
180          " which can be read back by Spark only if multiLine is enabled.")
181   }
182
183   private val writer = CodecStreams.createOutputStreamWriter(
184     context, new Path(path), encoding)
185
186   // create the Generator without separator inserted between 2 records
187   private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
188
189   override def write(row: InternalRow): Unit = {
190     gen.write(row)
191     gen.writeLineEnding()
192   }
193
194   override def close(): Unit = {
195     gen.close()
196     writer.close()
197   }
198 }