[SPARK-24691][SQL] Dispatch the type support check in FileFormat implementation
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.orc

import java.io._
import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.orc._
import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA}
import org.apache.orc.mapred.OrcStruct
import org.apache.orc.mapreduce._

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration

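/**
 * Companion helpers for the ORC file format. `checkFieldNames` validates that every column
 * name can be embedded in an ORC `TypeDescription`; for example, a name containing a comma
 * such as `"a,b"` cannot be parsed by `TypeDescription.fromString("struct<a,b:int>")` and is
 * rejected with an `AnalysisException` asking the user to rename the column via an alias.
 */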
private[sql] object OrcFileFormat {
  private def checkFieldName(name: String): Unit = {
    try {
      TypeDescription.fromString(s"struct<$name:int>")
    } catch {
      case _: IllegalArgumentException =>
        throw new AnalysisException(
          s"""Column name "$name" contains invalid character(s).
             |Please use alias to rename it.
           """.stripMargin.split("\n").mkString(" ").trim)
    }
  }

  def checkFieldNames(names: Seq[String]): Unit = {
    names.foreach(checkFieldName)
  }
}

/**
 * New ORC File Format based on Apache ORC.
 */
class OrcFileFormat
  extends FileFormat
  with DataSourceRegister
  with Serializable {

  override def shortName(): String = "orc"

  override def toString: String = "ORC"

  override def hashCode(): Int = getClass.hashCode()

  override def equals(other: Any): Boolean = other.isInstanceOf[OrcFileFormat]

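  /**
   * Schema inference delegates to `OrcUtils.readSchema`, which reads the schema recorded in
   * the ORC files themselves. A usage sketch (`spark` and the path are placeholders):
   *
   * {{{
   *   // Loading without an explicit schema ends up calling inferSchema.
   *   val df = spark.read.format("orc").load("/path/to/orc")
   * }}}
   */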
  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    OrcUtils.readSchema(sparkSession, files)
  }

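  /**
   * Prepares the Hadoop job for writing: the Catalyst schema and the compression codec are
   * pushed into the ORC configuration, and the returned factory creates one `OrcOutputWriter`
   * per task. The codec can be chosen through the data source option, e.g. (a sketch, `df`
   * and the output path are placeholders):
   *
   * {{{
   *   df.write.format("orc").option("compression", "snappy").save("/path/to/output")
   * }}}
   */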
  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)

    val conf = job.getConfiguration

    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString)

    conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)

    conf.asInstanceOf[JobConf]
      .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]])

    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new OrcOutputWriter(path, dataSchema, context)
      }

      override def getFileExtension(context: TaskAttemptContext): String = {
        val compressionExtension: String = {
          val name = context.getConfiguration.get(COMPRESS.getAttribute)
          OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "")
        }

        compressionExtension + ".orc"
      }
    }
  }

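  /**
   * Columnar batch (vectorized) reading is used only when the ORC vectorized reader and
   * whole-stage codegen are enabled, the result schema is small enough for codegen, and every
   * column is of an atomic type; nested types fall back to the row-by-row reader.
   */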
  override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
    val conf = sparkSession.sessionState.conf
    conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled &&
      schema.length <= conf.wholeStageMaxNumFields &&
      schema.forall(_.dataType.isInstanceOf[AtomicType])
  }

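  /**
   * ORC files are laid out in independently readable stripes, so a single file can always be
   * split across multiple partitions.
   */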
  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
    true
  }

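  /**
   * Builds the per-file read function. On the driver side it optionally converts the pushed
   * filters into an ORC `SearchArgument` and broadcasts the Hadoop configuration; on the
   * executor side it prunes columns via `OrcConf.INCLUDE_COLUMNS` and then reads either through
   * the vectorized `OrcColumnarBatchReader` or through the row-based `OrcInputFormat` plus
   * `OrcDeserializer` path, appending partition values when needed.
   */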
  override def buildReaderWithPartitionValues(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    if (sparkSession.sessionState.conf.orcFilterPushDown) {
      OrcFilters.createFilter(dataSchema, filters).foreach { f =>
        OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
      }
    }

    val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
    val sqlConf = sparkSession.sessionState.conf
    val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled
    val enableVectorizedReader = supportBatch(sparkSession, resultSchema)
    val capacity = sqlConf.orcVectorizedReaderBatchSize
    val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK)

    val broadcastedConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
    val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis

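    // Everything below runs on the executors: the returned closure is invoked once per
    // PartitionedFile and re-opens the ORC reader using the broadcast Hadoop configuration.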
    (file: PartitionedFile) => {
      val conf = broadcastedConf.value.value

      val filePath = new Path(new URI(file.filePath))

      val fs = filePath.getFileSystem(conf)
      val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
      val reader = OrcFile.createReader(filePath, readerOptions)

      val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds(
        isCaseSensitive, dataSchema, requiredSchema, reader, conf)

      if (requestedColIdsOrEmptyFile.isEmpty) {
        Iterator.empty
      } else {
        val requestedColIds = requestedColIdsOrEmptyFile.get
        assert(requestedColIds.length == requiredSchema.length,
          "[BUG] requested column IDs do not match required schema")
        val taskConf = new Configuration(conf)
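        // Tell ORC which physical columns to read. An index of -1 marks a required column that
        // is not present in this file's schema, so it is filtered out here and only existing
        // columns are listed in INCLUDE_COLUMNS.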
        taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute,
          requestedColIds.filter(_ != -1).sorted.mkString(","))

        val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
        val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
        val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)

        val taskContext = Option(TaskContext.get())
        if (enableVectorizedReader) {
          val batchReader = new OrcColumnarBatchReader(
            enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity)
          // SPARK-23399 Register a task completion listener first to call `close()` in all cases.
          // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM)
          // after opening a file.
          val iter = new RecordReaderIterator(batchReader)
          Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))

          batchReader.initialize(fileSplit, taskAttemptContext)
          batchReader.initBatch(
            reader.getSchema,
            requestedColIds,
            requiredSchema.fields,
            partitionSchema,
            file.partitionValues)

          iter.asInstanceOf[Iterator[InternalRow]]
        } else {
          val orcRecordReader = new OrcInputFormat[OrcStruct]
            .createRecordReader(fileSplit, taskAttemptContext)
          val iter = new RecordReaderIterator[OrcStruct](orcRecordReader)
          Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))

          val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
          val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
          val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds)

          if (partitionSchema.length == 0) {
            iter.map(value => unsafeProjection(deserializer.deserialize(value)))
          } else {
            val joinedRow = new JoinedRow()
            iter.map(value =>
              unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues)))
          }
        }
      }
    }
  }

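  /**
   * SPARK-24691 dispatches the type support check to each `FileFormat` implementation. ORC
   * accepts atomic types and, recursively, structs, arrays, maps, and UDTs built from them;
   * `NullType` is supported only on the read path, and anything else (for example
   * `CalendarIntervalType`) is rejected.
   */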
  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
    case _: AtomicType => true

    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }

    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)

    case MapType(keyType, valueType, _) =>
      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)

    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)

    case _: NullType => isReadPath

    case _ => false
  }
}