[SPARK-24691][SQL] Dispatch the type support check in FileFormat implementation
[spark.git] / sql / core / src / main / scala / org / apache / spark / sql / execution / datasources / FileFormatWriter.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.spark.sql.execution.datasources
19
20 import java.util.{Date, UUID}
21
22 import org.apache.hadoop.conf.Configuration
23 import org.apache.hadoop.fs.Path
24 import org.apache.hadoop.mapreduce._
25 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
26 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
27
28 import org.apache.spark._
29 import org.apache.spark.internal.Logging
30 import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
31 import org.apache.spark.shuffle.FetchFailedException
32 import org.apache.spark.sql.SparkSession
33 import org.apache.spark.sql.catalyst.InternalRow
34 import org.apache.spark.sql.catalyst.catalog.BucketSpec
35 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
36 import org.apache.spark.sql.catalyst.expressions._
37 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
38 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
39 import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
40 import org.apache.spark.util.{SerializableConfiguration, Utils}
41
42
43 /** A helper object for writing FileFormat data out to a location. */
44 object FileFormatWriter extends Logging {
45   /** Describes how output files should be placed in the filesystem. */
46   case class OutputSpec(
47       outputPath: String,
48       customPartitionLocations: Map[TablePartitionSpec, String],
49       outputColumns: Seq[Attribute])
50
51   /**
52    * Basic work flow of this command is:
53    * 1. Driver side setup, including output committer initialization and data source specific
54    *    preparation work for the write job to be issued.
55    * 2. Issues a write job consists of one or more executor side tasks, each of which writes all
56    *    rows within an RDD partition.
57    * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task;  If any
58    *    exception is thrown during task commitment, also aborts that task.
59    * 4. If all tasks are committed, commit the job, otherwise aborts the job;  If any exception is
60    *    thrown during job commitment, also aborts the job.
61    * 5. If the job is successfully committed, perform post-commit operations such as
62    *    processing statistics.
63    * @return The set of all partition paths that were updated during this write job.
64    */
65   def write(
66       sparkSession: SparkSession,
67       plan: SparkPlan,
68       fileFormat: FileFormat,
69       committer: FileCommitProtocol,
70       outputSpec: OutputSpec,
71       hadoopConf: Configuration,
72       partitionColumns: Seq[Attribute],
73       bucketSpec: Option[BucketSpec],
74       statsTrackers: Seq[WriteJobStatsTracker],
75       options: Map[String, String])
76     : Set[String] = {
77
78     val job = Job.getInstance(hadoopConf)
79     job.setOutputKeyClass(classOf[Void])
80     job.setOutputValueClass(classOf[InternalRow])
81     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
82
83     val partitionSet = AttributeSet(partitionColumns)
84     val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains)
85
86     val bucketIdExpression = bucketSpec.map { spec =>
87       val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
88       // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
89       // guarantee the data distribution is same between shuffle and bucketed data source, which
90       // enables us to only shuffle one side when join a bucketed table and a normal one.
91       HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
92     }
93     val sortColumns = bucketSpec.toSeq.flatMap {
94       spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
95     }
96
97     val caseInsensitiveOptions = CaseInsensitiveMap(options)
98
99     val dataSchema = dataColumns.toStructType
100     DataSourceUtils.verifyWriteSchema(fileFormat, dataSchema)
101     // Note: prepareWrite has side effect. It sets "job".
102     val outputWriterFactory =
103       fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema)
104
105     val description = new WriteJobDescription(
106       uuid = UUID.randomUUID().toString,
107       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
108       outputWriterFactory = outputWriterFactory,
109       allColumns = outputSpec.outputColumns,
110       dataColumns = dataColumns,
111       partitionColumns = partitionColumns,
112       bucketIdExpression = bucketIdExpression,
113       path = outputSpec.outputPath,
114       customPartitionLocations = outputSpec.customPartitionLocations,
115       maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
116         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
117       timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
118         .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
119       statsTrackers = statsTrackers
120     )
121
122     // We should first sort by partition columns, then bucket id, and finally sorting columns.
123     val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
124     // the sort order doesn't matter
125     val actualOrdering = plan.outputOrdering.map(_.child)
126     val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
127       false
128     } else {
129       requiredOrdering.zip(actualOrdering).forall {
130         case (requiredOrder, childOutputOrder) =>
131           requiredOrder.semanticEquals(childOutputOrder)
132       }
133     }
134
135     SQLExecution.checkSQLExecutionId(sparkSession)
136
137     // This call shouldn't be put into the `try` block below because it only initializes and
138     // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
139     committer.setupJob(job)
140
141     try {
142       val rdd = if (orderingMatched) {
143         plan.execute()
144       } else {
145         // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
146         // the physical plan may have different attribute ids due to optimizer removing some
147         // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
148         val orderingExpr = requiredOrdering
149           .map(SortOrder(_, Ascending))
150           .map(BindReferences.bindReference(_, outputSpec.outputColumns))
151         SortExec(
152           orderingExpr,
153           global = false,
154           child = plan).execute()
155       }
156
157       // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
158       // partition rdd to make sure we at least set up one write task to write the metadata.
159       val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
160         sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1)
161       } else {
162         rdd
163       }
164
165       val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length)
166       sparkSession.sparkContext.runJob(
167         rddWithNonEmptyPartitions,
168         (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
169           executeTask(
170             description = description,
171             sparkStageId = taskContext.stageId(),
172             sparkPartitionId = taskContext.partitionId(),
173             sparkAttemptNumber = taskContext.attemptNumber(),
174             committer,
175             iterator = iter)
176         },
177         rddWithNonEmptyPartitions.partitions.indices,
178         (index, res: WriteTaskResult) => {
179           committer.onTaskCommit(res.commitMsg)
180           ret(index) = res
181         })
182
183       val commitMsgs = ret.map(_.commitMsg)
184
185       committer.commitJob(job, commitMsgs)
186       logInfo(s"Job ${job.getJobID} committed.")
187
188       processStats(description.statsTrackers, ret.map(_.summary.stats))
189       logInfo(s"Finished processing stats for job ${job.getJobID}.")
190
191       // return a set of all the partition paths that were updated during this job
192       ret.map(_.summary.updatedPartitions).reduceOption(_ ++ _).getOrElse(Set.empty)
193     } catch { case cause: Throwable =>
194       logError(s"Aborting job ${job.getJobID}.", cause)
195       committer.abortJob(job)
196       throw new SparkException("Job aborted.", cause)
197     }
198   }
199
200   /** Writes data out in a single Spark task. */
201   private def executeTask(
202       description: WriteJobDescription,
203       sparkStageId: Int,
204       sparkPartitionId: Int,
205       sparkAttemptNumber: Int,
206       committer: FileCommitProtocol,
207       iterator: Iterator[InternalRow]): WriteTaskResult = {
208
209     val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
210     val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
211     val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber)
212
213     // Set up the attempt context required to use in the output committer.
214     val taskAttemptContext: TaskAttemptContext = {
215       // Set up the configuration object
216       val hadoopConf = description.serializableHadoopConf.value
217       hadoopConf.set("mapreduce.job.id", jobId.toString)
218       hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString)
219       hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString)
220       hadoopConf.setBoolean("mapreduce.task.ismap", true)
221       hadoopConf.setInt("mapreduce.task.partition", 0)
222
223       new TaskAttemptContextImpl(hadoopConf, taskAttemptId)
224     }
225
226     committer.setupTask(taskAttemptContext)
227
228     val dataWriter =
229       if (sparkPartitionId != 0 && !iterator.hasNext) {
230         // In case of empty job, leave first partition to save meta for file format like parquet.
231         new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
232       } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
233         new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
234       } else {
235         new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
236       }
237
238     try {
239       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
240         // Execute the task to write rows out and commit the task.
241         while (iterator.hasNext) {
242           dataWriter.write(iterator.next())
243         }
244         dataWriter.commit()
245       })(catchBlock = {
246         // If there is an error, abort the task
247         dataWriter.abort()
248         logError(s"Job $jobId aborted.")
249       })
250     } catch {
251       case e: FetchFailedException =>
252         throw e
253       case t: Throwable =>
254         throw new SparkException("Task failed while writing rows.", t)
255     }
256   }
257
258   /**
259    * For every registered [[WriteJobStatsTracker]], call `processStats()` on it, passing it
260    * the corresponding [[WriteTaskStats]] from all executors.
261    */
262   private def processStats(
263       statsTrackers: Seq[WriteJobStatsTracker],
264       statsPerTask: Seq[Seq[WriteTaskStats]])
265   : Unit = {
266
267     val numStatsTrackers = statsTrackers.length
268     assert(statsPerTask.forall(_.length == numStatsTrackers),
269       s"""Every WriteTask should have produced one `WriteTaskStats` object for every tracker.
270          |There are $numStatsTrackers statsTrackers, but some task returned
271          |${statsPerTask.find(_.length != numStatsTrackers).get.length} results instead.
272        """.stripMargin)
273
274     val statsPerTracker = if (statsPerTask.nonEmpty) {
275       statsPerTask.transpose
276     } else {
277       statsTrackers.map(_ => Seq.empty)
278     }
279
280     statsTrackers.zip(statsPerTracker).foreach {
281       case (statsTracker, stats) => statsTracker.processStats(stats)
282     }
283   }
284 }