[SPARK-22389][SQL] data source v2 partitioning reporting interface
authorWenchen Fan <wenchen@databricks.com>
Mon, 22 Jan 2018 23:21:09 +0000 (15:21 -0800)
committergatorsmile <gatorsmile@gmail.com>
Mon, 22 Jan 2018 23:21:09 +0000 (15:21 -0800)
## What changes were proposed in this pull request?

a new interface which allows data source to report partitioning and avoid shuffle at Spark side.

The design is pretty like the internal distribution/partitioing framework. Spark defines a `Distribution` interfaces and several concrete implementations, and ask the data source to report a `Partitioning`, the `Partitioning` should tell Spark if it can satisfy a `Distribution` or not.

## How was this patch tested?

new test

Author: Wenchen Fan <wenchen@databricks.com>

Closes #20201 from cloud-fan/partition-reporting.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java [new file with mode: 0644]
sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java [new file with mode: 0644]
sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java [new file with mode: 0644]
sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java [new file with mode: 0644]
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala [new file with mode: 0644]
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java [new file with mode: 0644]
sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala

index 0189bd7..4d9a992 100644 (file)
@@ -153,7 +153,7 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution {
  *   1. number of partitions.
  *   2. if it can satisfy a given distribution.
  */
-sealed trait Partitioning {
+trait Partitioning {
   /** Returns the number of partitions that the data is split across */
   val numPartitions: Int
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java
new file mode 100644 (file)
index 0000000..7346500
--- /dev/null
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * A concrete implementation of {@link Distribution}. Represents a distribution where records that
+ * share the same values for the {@link #clusteredColumns} will be produced by the same
+ * {@link ReadTask}.
+ */
+@InterfaceStability.Evolving
+public class ClusteredDistribution implements Distribution {
+
+  /**
+   * The names of the clustered columns. Note that they are order insensitive.
+   */
+  public final String[] clusteredColumns;
+
+  public ClusteredDistribution(String[] clusteredColumns) {
+    this.clusteredColumns = clusteredColumns;
+  }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java
new file mode 100644 (file)
index 0000000..a6201a2
--- /dev/null
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface to represent data distribution requirement, which specifies how the records should
+ * be distributed among the {@link ReadTask}s that are returned by
+ * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with
+ * the data ordering inside one partition(the output records of a single {@link ReadTask}).
+ *
+ * The instance of this interface is created and provided by Spark, then consumed by
+ * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to
+ * implement this interface, but need to catch as more concrete implementations of this interface
+ * as possible in {@link Partitioning#satisfy(Distribution)}.
+ *
+ * Concrete implementations until now:
+ * <ul>
+ *   <li>{@link ClusteredDistribution}</li>
+ * </ul>
+ */
+@InterfaceStability.Evolving
+public interface Distribution {}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java
new file mode 100644 (file)
index 0000000..199e45d
--- /dev/null
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface to represent the output data partitioning for a data source, which is returned by
+ * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a
+ * snapshot. Once created, it should be deterministic and always report the same number of
+ * partitions and the same "satisfy" result for a certain distribution.
+ */
+@InterfaceStability.Evolving
+public interface Partitioning {
+
+  /**
+   * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs.
+   */
+  int numPartitions();
+
+  /**
+   * Returns true if this partitioning can satisfy the given distribution, which means Spark does
+   * not need to shuffle the output data of this data source for some certain operations.
+   *
+   * Note that, Spark may add new concrete implementations of {@link Distribution} in new releases.
+   * This method should be aware of it and always return false for unrecognized distributions. It's
+   * recommended to check every Spark new release and support new distributions if possible, to
+   * avoid shuffle at Spark side for more cases.
+   */
+  boolean satisfy(Distribution distribution);
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
new file mode 100644 (file)
index 0000000..f786472
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this
+ * interface to report data partitioning and try to avoid shuffle at Spark side.
+ */
+@InterfaceStability.Evolving
+public interface SupportsReportPartitioning {
+
+  /**
+   * Returns the output data partitioning that this reader guarantees.
+   */
+  Partitioning outputPartitioning();
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala
new file mode 100644 (file)
index 0000000..943d010
--- /dev/null
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
+import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning}
+
+/**
+ * An adapter from public data source partitioning to catalyst internal `Partitioning`.
+ */
+class DataSourcePartitioning(
+    partitioning: Partitioning,
+    colNames: AttributeMap[String]) extends physical.Partitioning {
+
+  override val numPartitions: Int = partitioning.numPartitions()
+
+  override def satisfies(required: physical.Distribution): Boolean = {
+    super.satisfies(required) || {
+      required match {
+        case d: physical.ClusteredDistribution if isCandidate(d.clustering) =>
+          val attrs = d.clustering.map(_.asInstanceOf[Attribute])
+          partitioning.satisfy(
+            new ClusteredDistribution(attrs.map { a =>
+              val name = colNames.get(a)
+              assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output")
+              name.get
+            }.toArray))
+
+        case _ => false
+      }
+    }
+  }
+
+  private def isCandidate(clustering: Seq[Expression]): Boolean = {
+    clustering.forall {
+      case a: Attribute => colNames.contains(a)
+      case _ => false
+    }
+  }
+}
index beb6673..69d871d 100644 (file)
@@ -24,6 +24,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.sources.v2.reader._
@@ -42,6 +43,14 @@ case class DataSourceV2ScanExec(
 
   override def producedAttributes: AttributeSet = AttributeSet(fullOutput)
 
+  override def outputPartitioning: physical.Partitioning = reader match {
+    case s: SupportsReportPartitioning =>
+      new DataSourcePartitioning(
+        s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
+
+    case _ => super.outputPartitioning
+  }
+
   private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match {
     case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks()
     case _ =>
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
new file mode 100644 (file)
index 0000000..806d0bc
--- /dev/null
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupport;
+import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport {
+
+  class Reader implements DataSourceV2Reader, SupportsReportPartitioning {
+    private final StructType schema = new StructType().add("a", "int").add("b", "int");
+
+    @Override
+    public StructType readSchema() {
+      return schema;
+    }
+
+    @Override
+    public List<ReadTask<Row>> createReadTasks() {
+      return java.util.Arrays.asList(
+        new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
+        new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
+    }
+
+    @Override
+    public Partitioning outputPartitioning() {
+      return new MyPartitioning();
+    }
+  }
+
+  static class MyPartitioning implements Partitioning {
+
+    @Override
+    public int numPartitions() {
+      return 2;
+    }
+
+    @Override
+    public boolean satisfy(Distribution distribution) {
+      if (distribution instanceof ClusteredDistribution) {
+        String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns;
+        return Arrays.asList(clusteredCols).contains("a");
+      }
+
+      return false;
+    }
+  }
+
+  static class SpecificReadTask implements ReadTask<Row>, DataReader<Row> {
+    private int[] i;
+    private int[] j;
+    private int current = -1;
+
+    SpecificReadTask(int[] i, int[] j) {
+      assert i.length == j.length;
+      this.i = i;
+      this.j = j;
+    }
+
+    @Override
+    public boolean next() throws IOException {
+      current += 1;
+      return current < i.length;
+    }
+
+    @Override
+    public Row get() {
+      return new GenericRow(new Object[] {i[current], j[current]});
+    }
+
+    @Override
+    public void close() throws IOException {
+
+    }
+
+    @Override
+    public DataReader<Row> createDataReader() {
+      return this;
+    }
+  }
+
+  @Override
+  public DataSourceV2Reader createReader(DataSourceV2Options options) {
+    return new Reader();
+  }
+}
index 0ca2952..0620693 100644 (file)
@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.sources.v2._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
 import org.apache.spark.sql.sources.v2.reader._
@@ -95,6 +96,40 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("partitioning reporting") {
+    import org.apache.spark.sql.functions.{count, sum}
+    Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls =>
+      withClue(cls.getName) {
+        val df = spark.read.format(cls.getName).load()
+        checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2)))
+
+        val groupByColA = df.groupBy('a).agg(sum('b))
+        checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4)))
+        assert(groupByColA.queryExecution.executedPlan.collectFirst {
+          case e: ShuffleExchangeExec => e
+        }.isEmpty)
+
+        val groupByColAB = df.groupBy('a, 'b).agg(count("*"))
+        checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2)))
+        assert(groupByColAB.queryExecution.executedPlan.collectFirst {
+          case e: ShuffleExchangeExec => e
+        }.isEmpty)
+
+        val groupByColB = df.groupBy('b).agg(sum('a))
+        checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
+        assert(groupByColB.queryExecution.executedPlan.collectFirst {
+          case e: ShuffleExchangeExec => e
+        }.isDefined)
+
+        val groupByAPlusB = df.groupBy('a + 'b).agg(count("*"))
+        checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
+        assert(groupByAPlusB.queryExecution.executedPlan.collectFirst {
+          case e: ShuffleExchangeExec => e
+        }.isDefined)
+      }
+    }
+  }
+
   test("simple writable data source") {
     // TODO: java implementation.
     Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
@@ -365,3 +400,47 @@ class BatchReadTask(start: Int, end: Int)
 
   override def close(): Unit = batch.close()
 }
+
+class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
+
+  class Reader extends DataSourceV2Reader with SupportsReportPartitioning {
+    override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int")
+
+    override def createReadTasks(): JList[ReadTask[Row]] = {
+      // Note that we don't have same value of column `a` across partitions.
+      java.util.Arrays.asList(
+        new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)),
+        new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2)))
+    }
+
+    override def outputPartitioning(): Partitioning = new MyPartitioning
+  }
+
+  class MyPartitioning extends Partitioning {
+    override def numPartitions(): Int = 2
+
+    override def satisfy(distribution: Distribution): Boolean = distribution match {
+      case c: ClusteredDistribution => c.clusteredColumns.contains("a")
+      case _ => false
+    }
+  }
+
+  override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
+}
+
+class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] {
+  assert(i.length == j.length)
+
+  private var current = -1
+
+  override def createDataReader(): DataReader[Row] = this
+
+  override def next(): Boolean = {
+    current += 1
+    current < i.length
+  }
+
+  override def get(): Row = Row(i(current), j(current))
+
+  override def close(): Unit = {}
+}