[SPARK-19355][SQL] Use map output statistics to improve global limit's parallelism
authorLiang-Chi Hsieh <viirya@gmail.com>
Fri, 10 Aug 2018 09:32:15 +0000 (11:32 +0200)
committerHerman van Hovell <hvanhovell@databricks.com>
Fri, 10 Aug 2018 09:32:15 +0000 (11:32 +0200)
## What changes were proposed in this pull request?

A logical `Limit` is performed physically by two operations `LocalLimit` and `GlobalLimit`.

Most of time, we gather all data into a single partition in order to run `GlobalLimit`. If we use a very big limit number, shuffling data causes performance issue also reduces parallelism.

We can avoid shuffling into single partition if we don't care data ordering. This patch implements this idea by doing a map stage during global limit. It collects the info of row numbers at each partition. For each partition, we locally retrieves limited data without any shuffling to finish this global limit.

For example, we have three partitions with rows (100, 100, 50) respectively. In global limit of 100 rows, we may take (34, 33, 33) rows for each partition locally. After global limit we still have three partitions.

If the data partition has certain ordering, we can't distribute required rows evenly to each partitions because it could change data ordering. But we still can avoid shuffling.

## How was this patch tested?

Jenkins tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #16677 from viirya/improve-global-limit-parallelism.

26 files changed:
core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
core/src/main/scala/org/apache/spark/MapOutputStatistics.scala
core/src/main/scala/org/apache/spark/MapOutputTracker.scala
core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
core/src/test/scala/org/apache/spark/ShuffleSuite.scala
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
sql/core/src/test/resources/sql-tests/inputs/limit.sql
sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-limit.sql
sql/core/src/test/resources/sql-tests/results/limit.sql.out
sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-limit.sql.out
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala

index 323a5d3..e3bd549 100644 (file)
@@ -125,7 +125,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     if (!records.hasNext()) {
       partitionLengths = new long[numPartitions];
       shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
-      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, 0);
       return;
     }
     final SerializerInstance serInstance = serializer.newInstance();
@@ -167,7 +167,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
       }
     }
-    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(
+      blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten());
   }
 
   @VisibleForTesting
index 4839d04..069e6d5 100644 (file)
@@ -248,7 +248,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
       }
     }
-    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(
+      blockManager.shuffleServerId(), partitionLengths, writeMetrics.recordsWritten());
   }
 
   @VisibleForTesting
index f8a6f1d..ff85e11 100644 (file)
@@ -23,5 +23,9 @@ package org.apache.spark
  * @param shuffleId ID of the shuffle
  * @param bytesByPartitionId approximate number of output bytes for each map output partition
  *   (may be inexact due to use of compressed map statuses)
+ * @param recordsByPartitionId number of output records for each map output partition
  */
-private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long])
+private[spark] class MapOutputStatistics(
+    val shuffleId: Int,
+    val bytesByPartitionId: Array[Long],
+    val recordsByPartitionId: Array[Long])
index 1c4fa4b..41575ce 100644 (file)
@@ -522,16 +522,19 @@ private[spark] class MapOutputTrackerMaster(
   def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
     shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
       val totalSizes = new Array[Long](dep.partitioner.numPartitions)
+      val recordsByMapTask = new Array[Long](statuses.length)
+
       val parallelAggThreshold = conf.get(
         SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
       val parallelism = math.min(
         Runtime.getRuntime.availableProcessors(),
         statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt
       if (parallelism <= 1) {
-        for (s <- statuses) {
+        statuses.zipWithIndex.foreach { case (s, index) =>
           for (i <- 0 until totalSizes.length) {
             totalSizes(i) += s.getSizeForBlock(i)
           }
+          recordsByMapTask(index) = s.numberOfOutput
         }
       } else {
         val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
@@ -548,8 +551,11 @@ private[spark] class MapOutputTrackerMaster(
         } finally {
           threadPool.shutdown()
         }
+        statuses.zipWithIndex.foreach { case (s, index) =>
+          recordsByMapTask(index) = s.numberOfOutput
+        }
       }
-      new MapOutputStatistics(dep.shuffleId, totalSizes)
+      new MapOutputStatistics(dep.shuffleId, totalSizes, recordsByMapTask)
     }
   }
 
index 659694d..7e1d75f 100644 (file)
@@ -31,7 +31,8 @@ import org.apache.spark.util.Utils
 
 /**
  * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
- * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
+ * task ran on, the sizes of outputs for each reducer, and the number of outputs of the map task,
+ * for passing on to the reduce tasks.
  */
 private[spark] sealed trait MapStatus {
   /** Location where this task was run. */
@@ -44,18 +45,23 @@ private[spark] sealed trait MapStatus {
    * necessary for correctness, since block fetchers are allowed to skip zero-size blocks.
    */
   def getSizeForBlock(reduceId: Int): Long
+
+  /**
+   * The number of outputs for the map task.
+   */
+  def numberOfOutput: Long
 }
 
 
 private[spark] object MapStatus {
 
-  def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
+  def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long): MapStatus = {
     if (uncompressedSizes.length >  Option(SparkEnv.get)
       .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
       .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) {
-      HighlyCompressedMapStatus(loc, uncompressedSizes)
+      HighlyCompressedMapStatus(loc, uncompressedSizes, numOutput)
     } else {
-      new CompressedMapStatus(loc, uncompressedSizes)
+      new CompressedMapStatus(loc, uncompressedSizes, numOutput)
     }
   }
 
@@ -98,29 +104,34 @@ private[spark] object MapStatus {
  */
 private[spark] class CompressedMapStatus(
     private[this] var loc: BlockManagerId,
-    private[this] var compressedSizes: Array[Byte])
+    private[this] var compressedSizes: Array[Byte],
+    private[this] var numOutput: Long)
   extends MapStatus with Externalizable {
 
-  protected def this() = this(null, null.asInstanceOf[Array[Byte]])  // For deserialization only
+  protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1)  // For deserialization only
 
-  def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
-    this(loc, uncompressedSizes.map(MapStatus.compressSize))
+  def this(loc: BlockManagerId, uncompressedSizes: Array[Long], numOutput: Long) {
+    this(loc, uncompressedSizes.map(MapStatus.compressSize), numOutput)
   }
 
   override def location: BlockManagerId = loc
 
+  override def numberOfOutput: Long = numOutput
+
   override def getSizeForBlock(reduceId: Int): Long = {
     MapStatus.decompressSize(compressedSizes(reduceId))
   }
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     loc.writeExternal(out)
+    out.writeLong(numOutput)
     out.writeInt(compressedSizes.length)
     out.write(compressedSizes)
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     loc = BlockManagerId(in)
+    numOutput = in.readLong()
     val len = in.readInt()
     compressedSizes = new Array[Byte](len)
     in.readFully(compressedSizes)
@@ -143,17 +154,20 @@ private[spark] class HighlyCompressedMapStatus private (
     private[this] var numNonEmptyBlocks: Int,
     private[this] var emptyBlocks: RoaringBitmap,
     private[this] var avgSize: Long,
-    private var hugeBlockSizes: Map[Int, Byte])
+    private var hugeBlockSizes: Map[Int, Byte],
+    private[this] var numOutput: Long)
   extends MapStatus with Externalizable {
 
   // loc could be null when the default constructor is called during deserialization
   require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
     "Average size can only be zero for map stages that produced no output")
 
-  protected def this() = this(null, -1, null, -1, null)  // For deserialization only
+  protected def this() = this(null, -1, null, -1, null, -1)  // For deserialization only
 
   override def location: BlockManagerId = loc
 
+  override def numberOfOutput: Long = numOutput
+
   override def getSizeForBlock(reduceId: Int): Long = {
     assert(hugeBlockSizes != null)
     if (emptyBlocks.contains(reduceId)) {
@@ -168,6 +182,7 @@ private[spark] class HighlyCompressedMapStatus private (
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     loc.writeExternal(out)
+    out.writeLong(numOutput)
     emptyBlocks.writeExternal(out)
     out.writeLong(avgSize)
     out.writeInt(hugeBlockSizes.size)
@@ -179,6 +194,7 @@ private[spark] class HighlyCompressedMapStatus private (
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     loc = BlockManagerId(in)
+    numOutput = in.readLong()
     emptyBlocks = new RoaringBitmap()
     emptyBlocks.readExternal(in)
     avgSize = in.readLong()
@@ -194,7 +210,10 @@ private[spark] class HighlyCompressedMapStatus private (
 }
 
 private[spark] object HighlyCompressedMapStatus {
-  def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+  def apply(
+      loc: BlockManagerId,
+      uncompressedSizes: Array[Long],
+      numOutput: Long): HighlyCompressedMapStatus = {
     // We must keep track of which blocks are empty so that we don't report a zero-sized
     // block as being non-empty (or vice-versa) when using the average block size.
     var i = 0
@@ -235,6 +254,6 @@ private[spark] object HighlyCompressedMapStatus {
     emptyBlocks.trim()
     emptyBlocks.runOptimize()
     new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
-      hugeBlockSizesArray.toMap)
+      hugeBlockSizesArray.toMap, numOutput)
   }
 }
index 274399b..91fc267 100644 (file)
@@ -70,7 +70,8 @@ private[spark] class SortShuffleWriter[K, V, C](
       val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
       val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
       shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
-      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
+      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths,
+        writeMetrics.recordsWritten)
     } finally {
       if (tmp.exists() && !tmp.delete()) {
         logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
index 0d5c5ea..faa70f2 100644 (file)
@@ -233,6 +233,7 @@ public class UnsafeShuffleWriterSuite {
     writer.write(Iterators.emptyIterator());
     final Option<MapStatus> mapStatus = writer.stop(true);
     assertTrue(mapStatus.isDefined());
+    assertEquals(0, mapStatus.get().numberOfOutput());
     assertTrue(mergedOutputFile.exists());
     assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
     assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten());
@@ -252,6 +253,7 @@ public class UnsafeShuffleWriterSuite {
     writer.write(dataToWrite.iterator());
     final Option<MapStatus> mapStatus = writer.stop(true);
     assertTrue(mapStatus.isDefined());
+    assertEquals(NUM_PARTITITONS, mapStatus.get().numberOfOutput());
     assertTrue(mergedOutputFile.exists());
 
     long sumOfPartitionSizes = 0;
index 21f481d..e797396 100644 (file)
@@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
     tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
-        Array(1000L, 10000L)))
+        Array(1000L, 10000L), 10))
     tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
-        Array(10000L, 1000L)))
+        Array(10000L, 1000L), 10))
     val statuses = tracker.getMapSizesByExecutorId(10, 0)
     assert(statuses.toSet ===
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))),
@@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val compressedSize1000 = MapStatus.compressSize(1000L)
     val compressedSize10000 = MapStatus.compressSize(10000L)
     tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
-      Array(compressedSize1000, compressedSize10000)))
+      Array(compressedSize1000, compressedSize10000), 10))
     tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
-      Array(compressedSize10000, compressedSize1000)))
+      Array(compressedSize10000, compressedSize1000), 10))
     assert(tracker.containsShuffle(10))
     assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
     assert(0 == tracker.getNumCachedSerializedBroadcast)
@@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val compressedSize1000 = MapStatus.compressSize(1000L)
     val compressedSize10000 = MapStatus.compressSize(10000L)
     tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
-        Array(compressedSize1000, compressedSize1000, compressedSize1000)))
+        Array(compressedSize1000, compressedSize1000, compressedSize1000), 10))
     tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
-        Array(compressedSize10000, compressedSize1000, compressedSize1000)))
+        Array(compressedSize10000, compressedSize1000, compressedSize1000), 10))
 
     assert(0 == tracker.getNumCachedSerializedBroadcast)
     // As if we had two simultaneous fetch failures
@@ -145,7 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
 
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     masterTracker.registerMapOutput(10, 0, MapStatus(
-      BlockManagerId("a", "hostA", 1000), Array(1000L)))
+      BlockManagerId("a", "hostA", 1000), Array(1000L), 10))
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
@@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     // Message size should be ~123B, and no exception should be thrown
     masterTracker.registerShuffle(10, 1)
     masterTracker.registerMapOutput(10, 0, MapStatus(
-      BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0)))
+      BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 0))
     val senderAddress = RpcAddress("localhost", 12345)
     val rpcCallContext = mock(classOf[RpcCallContext])
     when(rpcCallContext.senderAddress).thenReturn(senderAddress)
@@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     // on hostB with output size 3
     tracker.registerShuffle(10, 3)
     tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
-        Array(2L)))
+        Array(2L), 1))
     tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000),
-        Array(2L)))
+        Array(2L), 1))
     tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000),
-        Array(3L)))
+        Array(3L), 1))
 
     // When the threshold is 50%, only host A should be returned as a preferred location
     // as it has 4 out of 7 bytes of output.
@@ -260,7 +260,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       masterTracker.registerShuffle(20, 100)
       (0 until 100).foreach { i =>
         masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
-          BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
+          BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 0))
       }
       val senderAddress = RpcAddress("localhost", 12345)
       val rpcCallContext = mock(classOf[RpcCallContext])
@@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
     tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
-      Array(size0, size1000, size0, size10000)))
+      Array(size0, size1000, size0, size10000), 1))
     tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
-      Array(size10000, size0, size1000, size0)))
+      Array(size10000, size0, size1000, size0), 1))
     assert(tracker.containsShuffle(10))
     assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq ===
         Seq(
index ced5a06..d11eaf8 100644 (file)
@@ -391,6 +391,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     assert(mapOutput2.isDefined)
     assert(mapOutput1.get.location === mapOutput2.get.location)
     assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0))
+    assert(mapOutput1.get.numberOfOutput === mapOutput2.get.numberOfOutput)
 
     // register one of the map outputs -- doesn't matter which one
     mapOutput1.foreach { case mapStatus =>
index 211002b..5e095ce 100644 (file)
@@ -423,17 +423,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     // map stage1 completes successfully, with one task on each executor
     complete(taskSets(0), Seq(
       (Success,
-        MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))),
+        MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)),
       (Success,
-        MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))),
+        MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)),
       (Success, makeMapStatus("hostB", 1))
     ))
     // map stage2 completes successfully, with one task on each executor
     complete(taskSets(1), Seq(
       (Success,
-        MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))),
+        MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), 1)),
       (Success,
-        MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))),
+        MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), 1)),
       (Success, makeMapStatus("hostB", 1))
     ))
     // make sure our test setup is correct
@@ -2576,7 +2576,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
 object DAGSchedulerSuite {
   def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus =
-    MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes))
+    MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 1)
 
   def makeBlockManagerId(host: String): BlockManagerId =
     BlockManagerId("exec-" + host, host, 12345)
index 354e638..555e48b 100644 (file)
@@ -60,7 +60,7 @@ class MapStatusSuite extends SparkFunSuite {
       stddev <- Seq(0.0, 0.01, 0.5, 1.0)
     ) {
       val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean)
-      val status = MapStatus(BlockManagerId("a", "b", 10), sizes)
+      val status = MapStatus(BlockManagerId("a", "b", 10), sizes, 1)
       val status1 = compressAndDecompressMapStatus(status)
       for (i <- 0 until numSizes) {
         if (sizes(i) != 0) {
@@ -74,7 +74,7 @@ class MapStatusSuite extends SparkFunSuite {
 
   test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
     val sizes = Array.fill[Long](2001)(150L)
-    val status = MapStatus(null, sizes)
+    val status = MapStatus(null, sizes, 1)
     assert(status.isInstanceOf[HighlyCompressedMapStatus])
     assert(status.getSizeForBlock(10) === 150L)
     assert(status.getSizeForBlock(50) === 150L)
@@ -86,7 +86,7 @@ class MapStatusSuite extends SparkFunSuite {
     val sizes = Array.tabulate[Long](3000) { i => i.toLong }
     val avg = sizes.sum / sizes.count(_ != 0)
     val loc = BlockManagerId("a", "b", 10)
-    val status = MapStatus(loc, sizes)
+    val status = MapStatus(loc, sizes, 1)
     val status1 = compressAndDecompressMapStatus(status)
     assert(status1.isInstanceOf[HighlyCompressedMapStatus])
     assert(status1.location == loc)
@@ -108,7 +108,7 @@ class MapStatusSuite extends SparkFunSuite {
     val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold)
     val avg = smallBlockSizes.sum / smallBlockSizes.length
     val loc = BlockManagerId("a", "b", 10)
-    val status = MapStatus(loc, sizes)
+    val status = MapStatus(loc, sizes, 1)
     val status1 = compressAndDecompressMapStatus(status)
     assert(status1.isInstanceOf[HighlyCompressedMapStatus])
     assert(status1.location == loc)
@@ -164,7 +164,7 @@ class MapStatusSuite extends SparkFunSuite {
     SparkEnv.set(env)
     // Value of element in sizes is equal to the corresponding index.
     val sizes = (0L to 2000L).toArray
-    val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes)
+    val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 1)
     val arrayStream = new ByteArrayOutputStream(102400)
     val objectOutputStream = new ObjectOutputStream(arrayStream)
     assert(status1.isInstanceOf[HighlyCompressedMapStatus])
@@ -196,19 +196,19 @@ class MapStatusSuite extends SparkFunSuite {
     SparkEnv.set(env)
     val sizes = Array.fill[Long](500)(150L)
     // Test default value
-    val status = MapStatus(null, sizes)
+    val status = MapStatus(null, sizes, 1)
     assert(status.isInstanceOf[CompressedMapStatus])
     // Test Non-positive values
     for (s <- -1 to 0) {
       assertThrows[IllegalArgumentException] {
         conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s)
-        val status = MapStatus(null, sizes)
+        val status = MapStatus(null, sizes, 1)
       }
     }
     // Test positive values
     Seq(1, 100, 499, 500, 501).foreach { s =>
       conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s)
-      val status = MapStatus(null, sizes)
+      val status = MapStatus(null, sizes, 1)
       if(sizes.length > s) {
         assert(status.isInstanceOf[HighlyCompressedMapStatus])
       } else {
index fc78655..240f8cf 100644 (file)
@@ -345,7 +345,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
     val denseBlockSizes = new Array[Long](5000)
     val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L)
     Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes =>
-      ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes))
+      ser.serialize(
+        HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 1))
     }
   }
 
index cc1a5e8..cd28c73 100644 (file)
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.physical
 
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
@@ -207,6 +209,18 @@ case object SinglePartition extends Partitioning {
 }
 
 /**
+ * Represents a partitioning where rows are only serialized/deserialized locally. The number
+ * of partitions are not changed and also the distribution of rows. This is mainly used to
+ * obtain some statistics of map tasks such as number of outputs.
+ */
+case class LocalPartitioning(childRDD: RDD[InternalRow]) extends Partitioning {
+  val numPartitions = childRDD.getNumPartitions
+
+  // We will perform this partitioning no matter what the data distribution is.
+  override def satisfies0(required: Distribution): Boolean = false
+}
+
+/**
  * Represents a partitioning where rows are split up across partitions based on the hash
  * of `expressions`.  All rows where `expressions` evaluate to the same values are guaranteed to be
  * in the same partition.
index 979a554..603c070 100644 (file)
@@ -214,6 +214,13 @@ object SQLConf {
     .intConf
     .createWithDefault(4)
 
+  val LIMIT_FLAT_GLOBAL_LIMIT = buildConf("spark.sql.limit.flatGlobalLimit")
+    .internal()
+    .doc("During global limit, try to evenly distribute limited rows across data " +
+      "partitions. If disabled, scanning data partitions sequentially until reaching limit number.")
+    .booleanConf
+    .createWithDefault(true)
+
   val ADVANCED_PARTITION_PREDICATE_PUSHDOWN =
     buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled")
       .internal()
@@ -1682,6 +1689,8 @@ class SQLConf extends Serializable with Logging {
 
   def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)
 
+  def limitFlatGlobalLimit: Boolean = getConf(LIMIT_FLAT_GLOBAL_LIMIT)
+
   def advancedPartitionPredicatePushdownEnabled: Boolean =
     getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN)
 
index b892037..50f10c3 100644 (file)
@@ -231,6 +231,11 @@ object ShuffleExchangeExec {
           override def numPartitions: Int = 1
           override def getPartition(key: Any): Int = 0
         }
+      case l: LocalPartitioning =>
+        new Partitioner {
+          override def numPartitions: Int = l.numPartitions
+          override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+        }
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
       // TODO: Handle BroadcastPartitioning.
     }
@@ -247,6 +252,9 @@ object ShuffleExchangeExec {
         val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
         row => projection(row).getInt(0)
       case RangePartitioning(_, _) | SinglePartition => identity
+      case _: LocalPartitioning =>
+        val partitionId = TaskContext.get().partitionId()
+        _ => partitionId
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
 
index 66bcda8..392ca13 100644 (file)
@@ -47,13 +47,16 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
 }
 
 /**
- * Helper trait which defines methods that are shared by both
- * [[LocalLimitExec]] and [[GlobalLimitExec]].
+ * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
  */
-trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
-  val limit: Int
+case class LocalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+
   override def output: Seq[Attribute] = child.output
 
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
     iter.take(limit)
   }
@@ -93,25 +96,93 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
 }
 
 /**
- * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
+ * Take the `limit` elements of the child output.
  */
-case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+case class GlobalLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
 
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+  override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
-}
 
-/**
- * Take the first `limit` elements of the child's single output partition.
- */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
 
-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  protected override def doExecute(): RDD[InternalRow] = {
+    val childRDD = child.execute()
+    val partitioner = LocalPartitioning(childRDD)
+    val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency(
+      childRDD, child.output, partitioner, serializer)
+    val numberOfOutput: Seq[Long] = if (shuffleDependency.rdd.getNumPartitions != 0) {
+      // submitMapStage does not accept RDD with 0 partition.
+      // So, we will not submit this dependency.
+      val submittedStageFuture = sparkContext.submitMapStage(shuffleDependency)
+      submittedStageFuture.get().recordsByPartitionId.toSeq
+    } else {
+      Nil
+    }
 
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+    // During global limit, try to evenly distribute limited rows across data
+    // partitions. If disabled, scanning data partitions sequentially until reaching limit number.
+    // Besides, if child output has certain ordering, we can't evenly pick up rows from
+    // each parititon.
+    val flatGlobalLimit = sqlContext.conf.limitFlatGlobalLimit && child.outputOrdering == Nil
+
+    val shuffled = new ShuffledRowRDD(shuffleDependency)
+
+    val sumOfOutput = numberOfOutput.sum
+    if (sumOfOutput <= limit) {
+      shuffled
+    } else if (!flatGlobalLimit) {
+      var numRowTaken = 0
+      val takeAmounts = numberOfOutput.map { num =>
+        if (numRowTaken + num < limit) {
+          numRowTaken += num.toInt
+          num.toInt
+        } else {
+          val toTake = limit - numRowTaken
+          numRowTaken += toTake
+          toTake
+        }
+      }
+      val broadMap = sparkContext.broadcast(takeAmounts)
+      shuffled.mapPartitionsWithIndexInternal { case (index, iter) =>
+        iter.take(broadMap.value(index).toInt)
+      }
+    } else {
+      // We try to evenly require the asked limit number of rows across all child rdd's partitions.
+      var rowsNeedToTake: Long = limit
+      val takeAmountByPartition: Array[Long] = Array.fill[Long](numberOfOutput.length)(0L)
+      val remainingRowsByPartition: Array[Long] = Array(numberOfOutput: _*)
+
+      while (rowsNeedToTake > 0) {
+        val nonEmptyParts = remainingRowsByPartition.count(_ > 0)
+        // If the rows needed to take are less the number of non-empty partitions, take one row from
+        // each non-empty partitions until we reach `limit` rows.
+        // Otherwise, evenly divide the needed rows to each non-empty partitions.
+        val takePerPart = math.max(1, rowsNeedToTake / nonEmptyParts)
+        remainingRowsByPartition.zipWithIndex.foreach { case (num, index) =>
+          // In case `rowsNeedToTake` < `nonEmptyParts`, we may run out of `rowsNeedToTake` during
+          // the traversal, so we need to add this check.
+          if (rowsNeedToTake > 0 && num > 0) {
+            if (num >= takePerPart) {
+              rowsNeedToTake -= takePerPart
+              takeAmountByPartition(index) += takePerPart
+              remainingRowsByPartition(index) -= takePerPart
+            } else {
+              rowsNeedToTake -= num
+              takeAmountByPartition(index) += num
+              remainingRowsByPartition(index) -= num
+            }
+          }
+        }
+      }
+      val broadMap = sparkContext.broadcast(takeAmountByPartition)
+      shuffled.mapPartitionsWithIndexInternal { case (index, iter) =>
+        iter.take(broadMap.value(index).toInt)
+      }
+    }
+  }
 }
 
 /**
index b4c73cf..e33cd81 100644 (file)
@@ -1,3 +1,5 @@
+-- Disable global limit parallel
+set spark.sql.limit.flatGlobalLimit=false;
 
 -- limit on various data types
 SELECT * FROM testdata LIMIT 2;
index a40ee08..a862e09 100644 (file)
@@ -1,6 +1,9 @@
 -- A test suite for IN LIMIT in parent side, subquery, and both predicate subquery
 -- It includes correlated cases.
 
+-- Disable global limit optimization
+set spark.sql.limit.flatGlobalLimit=false;
+
 create temporary view t1 as select * from values
   ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'),
   ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'),
@@ -97,4 +100,4 @@ WHERE  t1d NOT IN (SELECT t2d
                    LIMIT 1)
 GROUP  BY t1b
 ORDER BY t1b NULLS last
-LIMIT  1;
\ No newline at end of file
+LIMIT  1;
index 02fe1de..187f3bd 100644 (file)
@@ -1,63 +1,62 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 14
+-- Number of queries: 15
 
 
 -- !query 0
-SELECT * FROM testdata LIMIT 2
+set spark.sql.limit.flatGlobalLimit=false
 -- !query 0 schema
-struct<key:int,value:string>
+struct<key:string,value:string>
 -- !query 0 output
-1      1
-2      2
+spark.sql.limit.flatGlobalLimit        false
 
 
 -- !query 1
-SELECT * FROM arraydata LIMIT 2
+SELECT * FROM testdata LIMIT 2
 -- !query 1 schema
-struct<arraycol:array<int>,nestedarraycol:array<array<int>>>
+struct<key:int,value:string>
 -- !query 1 output
-[1,2,3]        [[1,2,3]]
-[2,3,4]        [[2,3,4]]
+1      1
+2      2
 
 
 -- !query 2
-SELECT * FROM mapdata LIMIT 2
+SELECT * FROM arraydata LIMIT 2
 -- !query 2 schema
-struct<mapcol:map<int,string>>
+struct<arraycol:array<int>,nestedarraycol:array<array<int>>>
 -- !query 2 output
-{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"}
-{1:"a2",2:"b2",3:"c2",4:"d2"}
+[1,2,3]        [[1,2,3]]
+[2,3,4]        [[2,3,4]]
 
 
 -- !query 3
-SELECT * FROM testdata LIMIT 2 + 1
+SELECT * FROM mapdata LIMIT 2
 -- !query 3 schema
-struct<key:int,value:string>
+struct<mapcol:map<int,string>>
 -- !query 3 output
-1      1
-2      2
-3      3
+{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"}
+{1:"a2",2:"b2",3:"c2",4:"d2"}
 
 
 -- !query 4
-SELECT * FROM testdata LIMIT CAST(1 AS int)
+SELECT * FROM testdata LIMIT 2 + 1
 -- !query 4 schema
 struct<key:int,value:string>
 -- !query 4 output
 1      1
+2      2
+3      3
 
 
 -- !query 5
-SELECT * FROM testdata LIMIT -1
+SELECT * FROM testdata LIMIT CAST(1 AS int)
 -- !query 5 schema
-struct<>
+struct<key:int,value:string>
 -- !query 5 output
-org.apache.spark.sql.AnalysisException
-The limit expression must be equal to or greater than 0, but got -1;
+1      1
 
 
 -- !query 6
-SELECT * FROM testData TABLESAMPLE (-1 ROWS)
+SELECT * FROM testdata LIMIT -1
 -- !query 6 schema
 struct<>
 -- !query 6 output
@@ -66,61 +65,70 @@ The limit expression must be equal to or greater than 0, but got -1;
 
 
 -- !query 7
-SELECT * FROM testdata LIMIT CAST(1 AS INT)
+SELECT * FROM testData TABLESAMPLE (-1 ROWS)
 -- !query 7 schema
-struct<key:int,value:string>
+struct<>
 -- !query 7 output
-1      1
+org.apache.spark.sql.AnalysisException
+The limit expression must be equal to or greater than 0, but got -1;
 
 
 -- !query 8
-SELECT * FROM testdata LIMIT CAST(NULL AS INT)
+SELECT * FROM testdata LIMIT CAST(1 AS INT)
 -- !query 8 schema
-struct<>
+struct<key:int,value:string>
 -- !query 8 output
-org.apache.spark.sql.AnalysisException
-The evaluated limit expression must not be null, but got CAST(NULL AS INT);
+1      1
 
 
 -- !query 9
-SELECT * FROM testdata LIMIT key > 3
+SELECT * FROM testdata LIMIT CAST(NULL AS INT)
 -- !query 9 schema
 struct<>
 -- !query 9 output
 org.apache.spark.sql.AnalysisException
-The limit expression must evaluate to a constant value, but got (testdata.`key` > 3);
+The evaluated limit expression must not be null, but got CAST(NULL AS INT);
 
 
 -- !query 10
-SELECT * FROM testdata LIMIT true
+SELECT * FROM testdata LIMIT key > 3
 -- !query 10 schema
 struct<>
 -- !query 10 output
 org.apache.spark.sql.AnalysisException
-The limit expression must be integer type, but got boolean;
+The limit expression must evaluate to a constant value, but got (testdata.`key` > 3);
 
 
 -- !query 11
-SELECT * FROM testdata LIMIT 'a'
+SELECT * FROM testdata LIMIT true
 -- !query 11 schema
 struct<>
 -- !query 11 output
 org.apache.spark.sql.AnalysisException
-The limit expression must be integer type, but got string;
+The limit expression must be integer type, but got boolean;
 
 
 -- !query 12
-SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3
+SELECT * FROM testdata LIMIT 'a'
 -- !query 12 schema
-struct<id:bigint>
+struct<>
 -- !query 12 output
-4
+org.apache.spark.sql.AnalysisException
+The limit expression must be integer type, but got string;
 
 
 -- !query 13
-SELECT * FROM testdata WHERE key < 3 LIMIT ALL
+SELECT * FROM (SELECT * FROM range(10) LIMIT 5) WHERE id > 3
 -- !query 13 schema
-struct<key:int,value:string>
+struct<id:bigint>
 -- !query 13 output
+4
+
+
+-- !query 14
+SELECT * FROM testdata WHERE key < 3 LIMIT ALL
+-- !query 14 schema
+struct<key:int,value:string>
+-- !query 14 output
 1      1
 2      2
index 71ca1f8..9eb5b33 100644 (file)
@@ -1,8 +1,16 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 8
+-- Number of queries: 9
 
 
 -- !query 0
+set spark.sql.limit.flatGlobalLimit=false
+-- !query 0 schema
+struct<key:string,value:string>
+-- !query 0 output
+spark.sql.limit.flatGlobalLimit        false
+
+
+-- !query 1
 create temporary view t1 as select * from values
   ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'),
   ("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'),
@@ -17,13 +25,13 @@ create temporary view t1 as select * from values
   ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:02:00.001', date '2014-04-04'),
   ("val1e", 10S, null, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04')
   as t1(t1a, t1b, t1c, t1d, t1e, t1f, t1g, t1h, t1i)
--- !query 0 schema
+-- !query 1 schema
 struct<>
--- !query 0 output
+-- !query 1 output
 
 
 
--- !query 1
+-- !query 2
 create temporary view t2 as select * from values
   ("val2a", 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'),
   ("val1b", 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'),
@@ -39,13 +47,13 @@ create temporary view t2 as select * from values
   ("val1f", 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'),
   ("val1b", null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null)
   as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i)
--- !query 1 schema
+-- !query 2 schema
 struct<>
--- !query 1 output
+-- !query 2 output
 
 
 
--- !query 2
+-- !query 3
 create temporary view t3 as select * from values
   ("val3a", 6S, 12, 110L, float(15), 20D, 20E2, timestamp '2014-04-04 01:02:00.000', date '2014-04-04'),
   ("val3a", 6S, 12, 10L, float(15), 20D, 20E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'),
@@ -60,27 +68,27 @@ create temporary view t3 as select * from values
   ("val3b", 8S, null, 719L, float(17), 25D, 26E2, timestamp '2014-05-04 01:02:00.000', date '2014-05-04'),
   ("val3b", 8S, null, 19L, float(17), 25D, 26E2, timestamp '2015-05-04 01:02:00.000', date '2015-05-04')
   as t3(t3a, t3b, t3c, t3d, t3e, t3f, t3g, t3h, t3i)
--- !query 2 schema
+-- !query 3 schema
 struct<>
--- !query 2 output
+-- !query 3 output
 
 
 
--- !query 3
+-- !query 4
 SELECT *
 FROM   t1
 WHERE  t1a IN (SELECT t2a
                FROM   t2
                WHERE  t1d = t2d)
 LIMIT  2
--- !query 3 schema
+-- !query 4 schema
 struct<t1a:string,t1b:smallint,t1c:int,t1d:bigint,t1e:float,t1f:double,t1g:decimal(2,-2),t1h:timestamp,t1i:date>
--- !query 3 output
+-- !query 4 output
 val1b  8       16      19      17.0    25.0    2600    2014-05-04 01:01:00     2014-05-04
 val1c  8       16      19      17.0    25.0    2600    2014-05-04 01:02:00.001 2014-05-05
 
 
--- !query 4
+-- !query 5
 SELECT *
 FROM   t1
 WHERE  t1c IN (SELECT t2c
@@ -88,16 +96,16 @@ WHERE  t1c IN (SELECT t2c
                WHERE  t2b >= 8
                LIMIT  2)
 LIMIT 4
--- !query 4 schema
+-- !query 5 schema
 struct<t1a:string,t1b:smallint,t1c:int,t1d:bigint,t1e:float,t1f:double,t1g:decimal(2,-2),t1h:timestamp,t1i:date>
--- !query 4 output
+-- !query 5 output
 val1a  16      12      10      15.0    20.0    2000    2014-07-04 01:01:00     2014-07-04
 val1a  16      12      21      15.0    20.0    2000    2014-06-04 01:02:00.001 2014-06-04
 val1b  8       16      19      17.0    25.0    2600    2014-05-04 01:01:00     2014-05-04
 val1c  8       16      19      17.0    25.0    2600    2014-05-04 01:02:00.001 2014-05-05
 
 
--- !query 5
+-- !query 6
 SELECT Count(DISTINCT( t1a )),
        t1b
 FROM   t1
@@ -108,29 +116,29 @@ WHERE  t1d IN (SELECT t2d
 GROUP  BY t1b
 ORDER  BY t1b DESC NULLS FIRST
 LIMIT  1
--- !query 5 schema
+-- !query 6 schema
 struct<count(DISTINCT t1a):bigint,t1b:smallint>
--- !query 5 output
+-- !query 6 output
 1      NULL
 
 
--- !query 6
+-- !query 7
 SELECT *
 FROM   t1
 WHERE  t1b NOT IN (SELECT t2b
                    FROM   t2
                    WHERE  t2b > 6
                    LIMIT  2)
--- !query 6 schema
+-- !query 7 schema
 struct<t1a:string,t1b:smallint,t1c:int,t1d:bigint,t1e:float,t1f:double,t1g:decimal(2,-2),t1h:timestamp,t1i:date>
--- !query 6 output
+-- !query 7 output
 val1a  16      12      10      15.0    20.0    2000    2014-07-04 01:01:00     2014-07-04
 val1a  16      12      21      15.0    20.0    2000    2014-06-04 01:02:00.001 2014-06-04
 val1a  6       8       10      15.0    20.0    2000    2014-04-04 01:00:00     2014-04-04
 val1a  6       8       10      15.0    20.0    2000    2014-04-04 01:02:00.001 2014-04-04
 
 
--- !query 7
+-- !query 8
 SELECT Count(DISTINCT( t1a )),
        t1b
 FROM   t1
@@ -141,7 +149,7 @@ WHERE  t1d NOT IN (SELECT t2d
 GROUP  BY t1b
 ORDER BY t1b NULLS last
 LIMIT  1
--- !query 7 schema
+-- !query 8 schema
 struct<count(DISTINCT t1a):bigint,t1b:smallint>
--- !query 7 output
+-- !query 8 output
 1      6
index d0106c4..85b3ca1 100644 (file)
@@ -557,11 +557,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-18004 limit + aggregates") {
-    val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
-    val limit2Df = df.limit(2)
-    checkAnswer(
-      limit2Df.groupBy("id").count().select($"id"),
-      limit2Df.select($"id"))
+    withSQLConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT.key -> "true") {
+      val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
+      val limit2Df = df.limit(2)
+      checkAnswer(
+        limit2Df.groupBy("id").count().select($"id"),
+        limit2Df.select($"id"))
+    }
   }
 
   test("SPARK-17237 remove backticks in a pivot result schema") {
index 3a393d7..c1a5f50 100644 (file)
@@ -524,6 +524,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     sortTest()
   }
 
+  test("limit for skew dataframe") {
+    // Create a skew dataframe.
+    val df = testData.repartition(100).union(testData).limit(50)
+    // Because `rdd` of dataframe will add a `DeserializeToObject` on top of `GlobalLimit`,
+    // the `GlobalLimit` will not be replaced with `CollectLimit`. So we can test if `GlobalLimit`
+    // work on skew partitions.
+    assert(df.rdd.count() == 50L)
+  }
+
   test("CTE feature") {
     checkAnswer(
       sql("with q1 as (select * from testData limit 10) select * from q1"),
@@ -1935,7 +1944,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     // TODO: support subexpression elimination in whole stage codegen
     withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
       // select from a table to prevent constant folding.
-      val df = sql("SELECT a, b from testData2 limit 1")
+      val df = sql("SELECT a, b from testData2 order by a, b limit 1")
       checkAnswer(df, Row(1, 1))
 
       checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
index b736d43..41de731 100644 (file)
@@ -50,7 +50,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       expectedPartitionStartIndices: Array[Int]): Unit = {
     val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map {
       case (bytesByPartitionId, index) =>
-        new MapOutputStatistics(index, bytesByPartitionId)
+        new MapOutputStatistics(index, bytesByPartitionId, Array[Long](1))
     }
     val estimatedPartitionStartIndices =
       coordinator.estimatePartitionStartIndices(mapOutputStatistics)
@@ -114,8 +114,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0)
       val mapOutputStatistics =
         Array(
-          new MapOutputStatistics(0, bytesByPartitionId1),
-          new MapOutputStatistics(1, bytesByPartitionId2))
+          new MapOutputStatistics(0, bytesByPartitionId1, Array[Long](0)),
+          new MapOutputStatistics(1, bytesByPartitionId2, Array[Long](0)))
       intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics))
     }
 
index bdc1063..3db89ec 100644 (file)
@@ -262,7 +262,7 @@ class PlannerSuite extends SharedSQLContext {
           ).queryExecution.executedPlan.collect {
             case exchange: ShuffleExchangeExec => exchange
           }.length
-          assert(numExchanges === 5)
+          assert(numExchanges === 3)
         }
 
         {
@@ -277,7 +277,7 @@ class PlannerSuite extends SharedSQLContext {
           ).queryExecution.executedPlan.collect {
             case exchange: ShuffleExchangeExec => exchange
           }.length
-          assert(numExchanges === 5)
+          assert(numExchanges === 3)
         }
 
       }
index cebaad5..b9b2b7d 100644 (file)
@@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
   private val originalColumnBatchSize = TestHive.conf.columnBatchSize
   private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
   private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
+  private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit
   private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone
 
   def testCases: Seq[(String, File)] = {
@@ -59,6 +60,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
     // Ensures that cross joins are enabled so that we can test them
     TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
+    // Ensure that limit operation returns rows in the same order as Hive
+    TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false)
     // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests
     // (timestamp_*)
     TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles")
@@ -73,6 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
       TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
       TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
       TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
+      TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit)
       TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone)
 
       // For debugging dump some statistics about how much time was spent in various optimizer rules
index cc592cf..1654129 100644 (file)
@@ -22,21 +22,29 @@ import scala.collection.JavaConverters._
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * A set of test cases that validate partition and column pruning.
  */
 class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
 
+  private val originalLimitFlatGlobalLimit = TestHive.conf.limitFlatGlobalLimit
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     TestHive.setCacheTables(false)
+    TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, false)
     // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet,
     // need to reset the environment to ensure all referenced tables in this suites are
     // not cached in-memory. Refer to https://issues.apache.org/jira/browse/SPARK-2283
     // for details.
     TestHive.reset()
   }
+   override def afterAll() {
+    TestHive.setConf(SQLConf.LIMIT_FLAT_GLOBAL_LIMIT, originalLimitFlatGlobalLimit)
+    super.afterAll()
+  }
 
   // Column pruning tests