[SPARK-24691][SQL] Dispatch the type support check in FileFormat implementation
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.text

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StringType, StructType}
import org.apache.spark.util.SerializableConfiguration

/**
 * A data source for reading text files.
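 *
 * Reading yields a single string column named "value". A minimal usage sketch, assuming an
 * existing `SparkSession` named `spark` and a hypothetical input path:
 * {{{
 *   // The resulting DataFrame has the schema `value: string`.
 *   val df = spark.read.format("text").load("/path/to/files")
 * }}}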
 */
class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {

  override def shortName(): String = "text"

  override def toString: String = "Text"

  private def verifySchema(schema: StructType): Unit = {
    if (schema.size != 1) {
      throw new AnalysisException(
        s"Text data source supports only a single column, and you have ${schema.size} columns.")
    }
  }

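  // Text files are splittable on the line separator, except in `wholeText` mode, where each
  // file is read as a single row and therefore must not be split.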
  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
    val textOptions = new TextOptions(options)
    super.isSplitable(sparkSession, options, path) && !textOptions.wholeText
  }

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = Some(new StructType().add("value", StringType))

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    verifySchema(dataSchema)

    val textOptions = new TextOptions(options)
    val conf = job.getConfiguration

    textOptions.compressionCodec.foreach { codec =>
      CompressionCodecs.setCodecConfiguration(conf, codec)
    }

    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new TextOutputWriter(path, dataSchema, textOptions.lineSeparatorInWrite, context)
      }

      override def getFileExtension(context: TaskAttemptContext): String = {
        ".txt" + CodecStreams.getCompressionExtension(context)
      }
    }
  }

  override def buildReader(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
    assert(
      requiredSchema.length <= 1,
      "Text data source only produces a single data column named \"value\".")
    val textOptions = new TextOptions(options)
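    // Wrap and broadcast the Hadoop configuration so that each executor task can build its own
    // line or whole-file reader from it.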
    val broadcastedHadoopConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

    readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions)
  }

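  // Builds the per-file read function: each line (or the whole file in `wholeText` mode) is
  // copied into an UnsafeRow, reusing a single UnsafeRowWriter across all rows of a file.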
  private def readToUnsafeMem(
      conf: Broadcast[SerializableConfiguration],
      requiredSchema: StructType,
      textOptions: TextOptions): (PartitionedFile) => Iterator[UnsafeRow] = {

    (file: PartitionedFile) => {
      val confValue = conf.value.value
      val reader = if (!textOptions.wholeText) {
        new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
      } else {
        new HadoopFileWholeTextReader(file, confValue)
      }
      Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close()))
      if (requiredSchema.isEmpty) {
        val emptyUnsafeRow = new UnsafeRow(0)
        reader.map(_ => emptyUnsafeRow)
      } else {
        val unsafeRowWriter = new UnsafeRowWriter(1)

        reader.map { line =>
          // Writes to an UnsafeRow directly
          unsafeRowWriter.reset()
          unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
          unsafeRowWriter.getRow()
        }
      }
    }
  }

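  // SPARK-24691: the type support check is dispatched to each FileFormat implementation.
  // The text source supports only StringType, on both the read and the write path.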
  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean =
    dataType == StringType
}

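/**
 * Writes each row's single string value to the output stream, followed by the configured
 * line separator.
 */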
class TextOutputWriter(
    path: String,
    dataSchema: StructType,
    lineSeparator: Array[Byte],
    context: TaskAttemptContext)
  extends OutputWriter {

  private val writer = CodecStreams.createOutputStream(context, new Path(path))

  override def write(row: InternalRow): Unit = {
    if (!row.isNullAt(0)) {
      val utf8string = row.getUTF8String(0)
      utf8string.writeTo(writer)
    }
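    // A null value produces an empty line: only the line separator is written.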
    writer.write(lineSeparator)
  }

  override def close(): Unit = {
    writer.close()
  }
}