sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala (spark.git, commit df488a748e3e5f41b6da17d7f58ac2e6db72b2be)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.datasources.orc

import java.io._
import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.orc._
import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA}
import org.apache.orc.mapred.OrcStruct
import org.apache.orc.mapreduce._

import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration

private[sql] object OrcFileFormat {
  private def checkFieldName(name: String): Unit = {
    try {
      TypeDescription.fromString(s"struct<$name:int>")
    } catch {
      case _: IllegalArgumentException =>
        throw new AnalysisException(
          s"""Column name "$name" contains invalid character(s).
             |Please use alias to rename it.
           """.stripMargin.split("\n").mkString(" ").trim)
    }
  }

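  /**
   * Checks every field name against ORC's `TypeDescription` parser so that an invalid name
   * fails fast with an `AnalysisException` instead of a low-level ORC parse error at write
   * time. A minimal, hypothetical call site:
   * {{{
   *   OrcFileFormat.checkFieldNames(dataSchema.fieldNames)
   * }}}
   */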
  def checkFieldNames(names: Seq[String]): Unit = {
    names.foreach(checkFieldName)
  }
}

/**
 * New ORC File Format based on Apache ORC.
 */
class OrcFileFormat
  extends FileFormat
  with DataSourceRegister
  with Serializable {

  override def shortName(): String = "orc"

  override def toString: String = "ORC"

  override def hashCode(): Int = getClass.hashCode()

  override def equals(other: Any): Boolean = other.isInstanceOf[OrcFileFormat]

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
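    // For reference: OrcUtils.readSchema walks the given files and returns the schema read
    // from the footer of the first ORC file that yields one, or None if no schema is found.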
    OrcUtils.readSchema(sparkSession, files)
  }

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    DataSourceUtils.verifyWriteSchema(this, dataSchema)

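    // `options` carries the user-facing writer options; an illustrative call site would be
    //   df.write.format("orc").option("compression", "snappy").save(path)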
    val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)

    val conf = job.getConfiguration

    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString)

    conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)

    conf.asInstanceOf[JobConf]
      .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]])

    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new OrcOutputWriter(path, dataSchema, context)
      }

      override def getFileExtension(context: TaskAttemptContext): String = {
        val compressionExtension: String = {
          val name = context.getConfiguration.get(COMPRESS.getAttribute)
          OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "")
        }

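        // e.g. with snappy compression this yields file names ending in ".snappy.orc".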
        compressionExtension + ".orc"
      }
    }
  }

  override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
    val conf = sparkSession.sessionState.conf
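    // The checks below correspond to the session confs spark.sql.orc.enableVectorizedReader,
    // spark.sql.codegen.wholeStage and spark.sql.codegen.maxFields; batches are only used
    // when every column is of an atomic (non-nested) type.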
    conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled &&
      schema.length <= conf.wholeStageMaxNumFields &&
      schema.forall(_.dataType.isInstanceOf[AtomicType])
  }

  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
    true
  }

  override def buildReaderWithPartitionValues(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    DataSourceUtils.verifyReadSchema(this, dataSchema)

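    // When ORC filter pushdown is enabled, convertible data source filters (e.g. an
    // illustrative GreaterThan("id", 1)) are turned into an ORC SearchArgument so the
    // reader can skip stripes/row groups whose statistics rule them out.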
    if (sparkSession.sessionState.conf.orcFilterPushDown) {
      OrcFilters.createFilter(dataSchema, filters).foreach { f =>
        OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
      }
    }

    val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
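    // Note that vectorized support is decided on the full output layout, i.e. the required
    // data columns followed by the partition columns.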
    val sqlConf = sparkSession.sessionState.conf
    val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled
    val enableVectorizedReader = supportBatch(sparkSession, resultSchema)
    val capacity = sqlConf.orcVectorizedReaderBatchSize
    val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK)

    val broadcastedConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
    val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis

    (file: PartitionedFile) => {
      val conf = broadcastedConf.value.value

      val filePath = new Path(new URI(file.filePath))

      val fs = filePath.getFileSystem(conf)
      val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
      val reader = OrcFile.createReader(filePath, readerOptions)

      val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds(
        isCaseSensitive, dataSchema, requiredSchema, reader, conf)

      if (requestedColIdsOrEmptyFile.isEmpty) {
        Iterator.empty
      } else {
        val requestedColIds = requestedColIdsOrEmptyFile.get
        assert(requestedColIds.length == requiredSchema.length,
          "[BUG] requested column IDs do not match required schema")
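        // requestedColIds holds, per required column, its physical index in the ORC file;
        // -1 marks a required column that is absent from this particular file. Only the
        // present columns are handed to ORC via INCLUDE_COLUMNS, e.g. (illustrative) for
        // dataSchema (a, b, c) and requiredSchema (c, a) this is "0,2".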
        val taskConf = new Configuration(conf)
        taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute,
          requestedColIds.filter(_ != -1).sorted.mkString(","))

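        // A synthetic TaskAttemptID is enough here; it only serves to satisfy the Hadoop
        // mapreduce RecordReader API when constructing the TaskAttemptContext.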
        val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
        val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
        val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)

        val taskContext = Option(TaskContext.get())
        if (enableVectorizedReader) {
          val batchReader = new OrcColumnarBatchReader(
            enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity)
          // SPARK-23399 Register a task completion listener first to call `close()` in all cases.
          // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM)
          // after opening a file.
          val iter = new RecordReaderIterator(batchReader)
          Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))

          batchReader.initialize(fileSplit, taskAttemptContext)
          batchReader.initBatch(
            reader.getSchema,
            requestedColIds,
            requiredSchema.fields,
            partitionSchema,
            file.partitionValues)

          iter.asInstanceOf[Iterator[InternalRow]]
        } else {
          val orcRecordReader = new OrcInputFormat[OrcStruct]
            .createRecordReader(fileSplit, taskAttemptContext)
          val iter = new RecordReaderIterator[OrcStruct](orcRecordReader)
          Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))

          val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
          val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
          val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds)

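          // Each OrcStruct is deserialized into an InternalRow of the required columns; for
          // partitioned files the partition values are appended via JoinedRow before the
          // unsafe projection copies the result into the full output layout.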
          if (partitionSchema.length == 0) {
            iter.map(value => unsafeProjection(deserializer.deserialize(value)))
          } else {
            val joinedRow = new JoinedRow()
            iter.map(value =>
              unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues)))
          }
        }
      }
    }
  }
}