/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive.orc

import java.net.URI
import java.util.Properties

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.ql.io.orc._
import org.apache.hadoop.hive.serde2.objectinspector.{SettableStructObjectInspector, StructObjectInspector}
import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils}
import org.apache.hadoop.io.{NullWritable, Writable}
import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter}
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
import org.apache.orc.OrcConf.COMPRESS

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.orc.OrcOptions
import org.apache.spark.sql.hive.{HiveInspectors, HiveShim}
import org.apache.spark.sql.sources.{Filter, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

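// A minimal usage sketch, assuming this Hive-based implementation is the one backing the "orc"
// data source (on newer Spark versions that may require `spark.sql.orc.impl` to be set to
// "hive"; the config name and the example path below are illustrative assumptions):
//
//   spark.range(10).write.mode("overwrite").format("orc").save("/tmp/orc_demo")
//   spark.read.format("orc").load("/tmp/orc_demo").show()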
/**
 * `FileFormat` for reading ORC files. If this is moved or renamed, please update
 * `DataSource`'s backwardCompatibilityMap.
 */
class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable {

  override def shortName(): String = "orc"

  override def toString: String = "ORC"

  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles
    OrcFileOperator.readSchema(
      files.map(_.getPath.toString),
      Some(sparkSession.sessionState.newHadoopConf()),
      ignoreCorruptFiles
    )
  }

  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    DataSourceUtils.verifyWriteSchema(this, dataSchema)

    val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)

    val configuration = job.getConfiguration

    configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec)
    configuration match {
      case conf: JobConf =>
        conf.setOutputFormat(classOf[OrcOutputFormat])
      case conf =>
        conf.setClass(
          "mapred.output.format.class",
          classOf[OrcOutputFormat],
          classOf[MapRedOutputFormat[_, _]])
    }

    new OutputWriterFactory {
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        new OrcOutputWriter(path, dataSchema, context)
      }

      override def getFileExtension(context: TaskAttemptContext): String = {
        val compressionExtension: String = {
          val name = context.getConfiguration.get(COMPRESS.getAttribute)
          OrcFileFormat.extensionsForCompressionCodecNames.getOrElse(name, "")
        }

        compressionExtension + ".orc"
      }
    }
  }
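  // A write-side sketch, assuming the standard DataFrameWriter API: the ORC "compression" option
  // (or the session's ORC compression codec setting) flows through `OrcOptions` into COMPRESS
  // above, so `getFileExtension` produces names such as `part-*.zlib.orc`. The path below is a
  // hypothetical example:
  //
  //   df.write.format("orc").option("compression", "zlib").save("/tmp/orc_zlib_demo")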

  override def isSplitable(
      sparkSession: SparkSession,
      options: Map[String, String],
      path: Path): Boolean = {
    true
  }

  override def buildReader(
      sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
    DataSourceUtils.verifyReadSchema(this, dataSchema)

    if (sparkSession.sessionState.conf.orcFilterPushDown) {
      // Sets pushed predicates
      OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f =>
        hadoopConf.set(OrcFileFormat.SARG_PUSHDOWN, f.toKryo)
        hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true)
      }
    }
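    // A hedged illustration of the pushdown above: for a source filter such as
    // `GreaterThan("id", 1)` (the column name is only an example), `OrcFilters.createFilter`
    // builds a Hive `SearchArgument`, which is Kryo-serialized into the "sarg.pushdown"
    // property and, together with HIVEOPTINDEXFILTER, lets the ORC reader skip row groups
    // whose statistics cannot match the predicate.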

    val broadcastedHadoopConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
    val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles

    (file: PartitionedFile) => {
      val conf = broadcastedHadoopConf.value.value

      val filePath = new Path(new URI(file.filePath))

      // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this
      // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file
      // using the given physical schema. Instead, we simply return an empty iterator.
      val isEmptyFile =
        OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf), ignoreCorruptFiles).isEmpty
      if (isEmptyFile) {
        Iterator.empty
      } else {
        OrcFileFormat.setRequiredColumns(conf, dataSchema, requiredSchema)

        val orcRecordReader = {
          val job = Job.getInstance(conf)
          FileInputFormat.setInputPaths(job, file.filePath)

          val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
          // A custom OrcRecordReader is used so that the ObjectInspector is obtained while the
          // record reader is created, avoiding an extra NameNode call per file in
          // unwrapOrcStructs. This is especially helpful for partitioned datasets.
          val orcReader = OrcFile.createReader(filePath, OrcFile.readerOptions(conf))
          new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength)
        }

        val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader)
        Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close()))

        // Unwraps `OrcStruct`s to `UnsafeRow`s
        OrcFileFormat.unwrapOrcStructs(
          conf,
          dataSchema,
          requiredSchema,
          Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
          recordsIterator)
      }
    }
  }
}

private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration)
  extends HiveInspectors {

  def serialize(row: InternalRow): Writable = {
    wrapOrcStruct(cachedOrcStruct, structOI, row)
    serializer.serialize(cachedOrcStruct, structOI)
  }

  private[this] val serializer = {
    val table = new Properties()
    table.setProperty("columns", dataSchema.fieldNames.mkString(","))
    table.setProperty("columns.types", dataSchema.map(_.dataType.catalogString).mkString(":"))

    val serde = new OrcSerde
    serde.initialize(conf, table)
    serde
  }
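  // Worked example with an assumed schema: for a relation with fields `a INT, b STRING`, the
  // properties above become columns = "a,b" and columns.types = "int:string", which is the
  // table metadata Hive's `OrcSerde` expects at initialization time.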

  // Object inspector converted from the schema of the relation to be serialized.
  private[this] val structOI = {
    val typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(dataSchema.catalogString)
    OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo])
      .asInstanceOf[SettableStructObjectInspector]
  }

  private[this] val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct]

  // Wrapper functions that convert Spark SQL values into the Hive-specific representations
  // expected by the corresponding object inspectors.
  private[this] val wrappers = dataSchema.zip(structOI.getAllStructFieldRefs().asScala.toSeq).map {
    case (f, i) => wrapperFor(i.getFieldObjectInspector, f.dataType)
  }

  private[this] def wrapOrcStruct(
      struct: OrcStruct,
      oi: SettableStructObjectInspector,
      row: InternalRow): Unit = {
    val fieldRefs = oi.getAllStructFieldRefs
    var i = 0
    val size = fieldRefs.size
    while (i < size) {
      oi.setStructFieldData(
        struct,
        fieldRefs.get(i),
        wrappers(i)(row.get(i, dataSchema(i).dataType))
      )
      i += 1
    }
  }
}

private[orc] class OrcOutputWriter(
    path: String,
    dataSchema: StructType,
    context: TaskAttemptContext)
  extends OutputWriter {

  private[this] val serializer = new OrcSerializer(dataSchema, context.getConfiguration)

  // `OrcRecordWriter.close()` creates an empty file if no rows are written at all.  We use this
  // flag to decide whether `OrcRecordWriter.close()` needs to be called.
  private var recordWriterInstantiated = false

  private lazy val recordWriter: RecordWriter[NullWritable, Writable] = {
    recordWriterInstantiated = true
    new OrcOutputFormat().getRecordWriter(
      new Path(path).getFileSystem(context.getConfiguration),
      context.getConfiguration.asInstanceOf[JobConf],
      path,
      Reporter.NULL
    ).asInstanceOf[RecordWriter[NullWritable, Writable]]
  }

  override def write(row: InternalRow): Unit = {
    recordWriter.write(NullWritable.get(), serializer.serialize(row))
  }

  override def close(): Unit = {
    if (recordWriterInstantiated) {
      recordWriter.close(Reporter.NULL)
    }
  }
}

private[orc] object OrcFileFormat extends HiveInspectors {
  // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public.
  private[orc] val SARG_PUSHDOWN = "sarg.pushdown"

  // The extensions for ORC compression codecs
  val extensionsForCompressionCodecNames = Map(
    "NONE" -> "",
    "SNAPPY" -> ".snappy",
    "ZLIB" -> ".zlib",
    "LZO" -> ".lzo")

  def unwrapOrcStructs(
      conf: Configuration,
      dataSchema: StructType,
      requiredSchema: StructType,
      maybeStructOI: Option[StructObjectInspector],
      iterator: Iterator[Writable]): Iterator[InternalRow] = {
    val deserializer = new OrcSerde
    val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType))
    val unsafeProjection = UnsafeProjection.create(requiredSchema)

    def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = {
      val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map {
        case (field, ordinal) =>
          var ref = oi.getStructFieldRef(field.name)
          if (ref == null) {
            ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name))
          }
          ref -> ordinal
      }.unzip
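      // Example of the by-position fallback above (field names are illustrative): ORC files
      // written by some older Hive versions store physical column names "_col0", "_col1", ...,
      // so a requested field "b" that cannot be resolved by name falls back to
      // "_col" + dataSchema.fieldIndex("b"), e.g. "_col1".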

      val unwrappers = fieldRefs.map(r => if (r == null) null else unwrapperFor(r))

      iterator.map { value =>
        val raw = deserializer.deserialize(value)
        var i = 0
        val length = fieldRefs.length
        while (i < length) {
          val fieldRef = fieldRefs(i)
          val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef)
          if (fieldValue == null) {
            mutableRow.setNullAt(fieldOrdinals(i))
          } else {
            unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i))
          }
          i += 1
        }
        unsafeProjection(mutableRow)
      }
    }

    maybeStructOI.map(unwrap).getOrElse(Iterator.empty)
  }

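  // A worked example for `setRequiredColumns` below, with assumed field names: given dataSchema
  // fields (a, b, c) and requestedSchema (c, a), `ids` is (2, 0); after sorting by id this
  // becomes ids (0, 2) with names ("a", "c"), which `HiveShim.appendReadColumns` records in the
  // Hive column-projection settings (e.g. `hive.io.file.readcolumn.ids`).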
  def setRequiredColumns(
      conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = {
    val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer)
    val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip
    HiveShim.appendReadColumns(conf, sortedIDs, sortedNames)
  }
}