52da8356ab8350a05e409b2024e3f4bae8fca5ad
[spark.git] / sql / core / src / main / scala / org / apache / spark / sql / execution / datasources / FileFormatWriter.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.spark.sql.execution.datasources
19
20 import java.util.{Date, UUID}
21
22 import org.apache.hadoop.conf.Configuration
23 import org.apache.hadoop.fs.Path
24 import org.apache.hadoop.mapreduce._
25 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
26 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
27
28 import org.apache.spark._
29 import org.apache.spark.internal.Logging
30 import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
31 import org.apache.spark.shuffle.FetchFailedException
32 import org.apache.spark.sql.SparkSession
33 import org.apache.spark.sql.catalyst.InternalRow
34 import org.apache.spark.sql.catalyst.catalog.BucketSpec
35 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
36 import org.apache.spark.sql.catalyst.expressions._
37 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
38 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
39 import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
40 import org.apache.spark.util.{SerializableConfiguration, Utils}
41
42
43 /** A helper object for writing FileFormat data out to a location. */
44 object FileFormatWriter extends Logging {
45   /** Describes how output files should be placed in the filesystem. */
46   case class OutputSpec(
47       outputPath: String,
48       customPartitionLocations: Map[TablePartitionSpec, String],
49       outputColumns: Seq[Attribute])
50
51   /**
52    * Basic work flow of this command is:
53    * 1. Driver side setup, including output committer initialization and data source specific
54    *    preparation work for the write job to be issued.
55    * 2. Issues a write job consists of one or more executor side tasks, each of which writes all
56    *    rows within an RDD partition.
57    * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task;  If any
58    *    exception is thrown during task commitment, also aborts that task.
59    * 4. If all tasks are committed, commit the job, otherwise aborts the job;  If any exception is
60    *    thrown during job commitment, also aborts the job.
61    * 5. If the job is successfully committed, perform post-commit operations such as
62    *    processing statistics.
63    * @return The set of all partition paths that were updated during this write job.
64    */
65   def write(
66       sparkSession: SparkSession,
67       plan: SparkPlan,
68       fileFormat: FileFormat,
69       committer: FileCommitProtocol,
70       outputSpec: OutputSpec,
71       hadoopConf: Configuration,
72       partitionColumns: Seq[Attribute],
73       bucketSpec: Option[BucketSpec],
74       statsTrackers: Seq[WriteJobStatsTracker],
75       options: Map[String, String])
76     : Set[String] = {
77
78     val job = Job.getInstance(hadoopConf)
79     job.setOutputKeyClass(classOf[Void])
80     job.setOutputValueClass(classOf[InternalRow])
81     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
82
83     val partitionSet = AttributeSet(partitionColumns)
84     val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains)
85
86     val bucketIdExpression = bucketSpec.map { spec =>
87       val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
88       // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
89       // guarantee the data distribution is same between shuffle and bucketed data source, which
90       // enables us to only shuffle one side when join a bucketed table and a normal one.
91       HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
92     }
93     val sortColumns = bucketSpec.toSeq.flatMap {
94       spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
95     }
96
97     val caseInsensitiveOptions = CaseInsensitiveMap(options)
98
99     // Note: prepareWrite has side effect. It sets "job".
100     val outputWriterFactory =
101       fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType)
102
103     val description = new WriteJobDescription(
104       uuid = UUID.randomUUID().toString,
105       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
106       outputWriterFactory = outputWriterFactory,
107       allColumns = outputSpec.outputColumns,
108       dataColumns = dataColumns,
109       partitionColumns = partitionColumns,
110       bucketIdExpression = bucketIdExpression,
111       path = outputSpec.outputPath,
112       customPartitionLocations = outputSpec.customPartitionLocations,
113       maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
114         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
115       timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
116         .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
117       statsTrackers = statsTrackers
118     )
119
120     // We should first sort by partition columns, then bucket id, and finally sorting columns.
121     val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
122     // the sort order doesn't matter
123     val actualOrdering = plan.outputOrdering.map(_.child)
124     val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
125       false
126     } else {
127       requiredOrdering.zip(actualOrdering).forall {
128         case (requiredOrder, childOutputOrder) =>
129           requiredOrder.semanticEquals(childOutputOrder)
130       }
131     }
132
133     SQLExecution.checkSQLExecutionId(sparkSession)
134
135     // This call shouldn't be put into the `try` block below because it only initializes and
136     // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
137     committer.setupJob(job)
138
139     try {
140       val rdd = if (orderingMatched) {
141         plan.execute()
142       } else {
143         // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
144         // the physical plan may have different attribute ids due to optimizer removing some
145         // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
146         val orderingExpr = requiredOrdering
147           .map(SortOrder(_, Ascending))
148           .map(BindReferences.bindReference(_, outputSpec.outputColumns))
149         SortExec(
150           orderingExpr,
151           global = false,
152           child = plan).execute()
153       }
154
155       // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
156       // partition rdd to make sure we at least set up one write task to write the metadata.
157       val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
158         sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1)
159       } else {
160         rdd
161       }
162
163       val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length)
164       sparkSession.sparkContext.runJob(
165         rddWithNonEmptyPartitions,
166         (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
167           executeTask(
168             description = description,
169             sparkStageId = taskContext.stageId(),
170             sparkPartitionId = taskContext.partitionId(),
171             sparkAttemptNumber = taskContext.attemptNumber(),
172             committer,
173             iterator = iter)
174         },
175         rddWithNonEmptyPartitions.partitions.indices,
176         (index, res: WriteTaskResult) => {
177           committer.onTaskCommit(res.commitMsg)
178           ret(index) = res
179         })
180
181       val commitMsgs = ret.map(_.commitMsg)
182
183       committer.commitJob(job, commitMsgs)
184       logInfo(s"Job ${job.getJobID} committed.")
185
186       processStats(description.statsTrackers, ret.map(_.summary.stats))
187       logInfo(s"Finished processing stats for job ${job.getJobID}.")
188
189       // return a set of all the partition paths that were updated during this job
190       ret.map(_.summary.updatedPartitions).reduceOption(_ ++ _).getOrElse(Set.empty)
191     } catch { case cause: Throwable =>
192       logError(s"Aborting job ${job.getJobID}.", cause)
193       committer.abortJob(job)
194       throw new SparkException("Job aborted.", cause)
195     }
196   }
197
198   /** Writes data out in a single Spark task. */
199   private def executeTask(
200       description: WriteJobDescription,
201       sparkStageId: Int,
202       sparkPartitionId: Int,
203       sparkAttemptNumber: Int,
204       committer: FileCommitProtocol,
205       iterator: Iterator[InternalRow]): WriteTaskResult = {
206
207     val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
208     val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
209     val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber)
210
211     // Set up the attempt context required to use in the output committer.
212     val taskAttemptContext: TaskAttemptContext = {
213       // Set up the configuration object
214       val hadoopConf = description.serializableHadoopConf.value
215       hadoopConf.set("mapreduce.job.id", jobId.toString)
216       hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString)
217       hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString)
218       hadoopConf.setBoolean("mapreduce.task.ismap", true)
219       hadoopConf.setInt("mapreduce.task.partition", 0)
220
221       new TaskAttemptContextImpl(hadoopConf, taskAttemptId)
222     }
223
224     committer.setupTask(taskAttemptContext)
225
226     val dataWriter =
227       if (sparkPartitionId != 0 && !iterator.hasNext) {
228         // In case of empty job, leave first partition to save meta for file format like parquet.
229         new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
230       } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
231         new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
232       } else {
233         new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
234       }
235
236     try {
237       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
238         // Execute the task to write rows out and commit the task.
239         while (iterator.hasNext) {
240           dataWriter.write(iterator.next())
241         }
242         dataWriter.commit()
243       })(catchBlock = {
244         // If there is an error, abort the task
245         dataWriter.abort()
246         logError(s"Job $jobId aborted.")
247       })
248     } catch {
249       case e: FetchFailedException =>
250         throw e
251       case t: Throwable =>
252         throw new SparkException("Task failed while writing rows.", t)
253     }
254   }
255
256   /**
257    * For every registered [[WriteJobStatsTracker]], call `processStats()` on it, passing it
258    * the corresponding [[WriteTaskStats]] from all executors.
259    */
260   private def processStats(
261       statsTrackers: Seq[WriteJobStatsTracker],
262       statsPerTask: Seq[Seq[WriteTaskStats]])
263   : Unit = {
264
265     val numStatsTrackers = statsTrackers.length
266     assert(statsPerTask.forall(_.length == numStatsTrackers),
267       s"""Every WriteTask should have produced one `WriteTaskStats` object for every tracker.
268          |There are $numStatsTrackers statsTrackers, but some task returned
269          |${statsPerTask.find(_.length != numStatsTrackers).get.length} results instead.
270        """.stripMargin)
271
272     val statsPerTracker = if (statsPerTask.nonEmpty) {
273       statsPerTask.transpose
274     } else {
275       statsTrackers.map(_ => Seq.empty)
276     }
277
278     statsTrackers.zip(statsPerTracker).foreach {
279       case (statsTracker, stats) => statsTracker.processStats(stats)
280     }
281   }
282 }