[SPARK-24822][PYSPARK] Python support for barrier execution mode
author Xingbo Jiang <xingbo.jiang@databricks.com>
Sat, 11 Aug 2018 13:44:45 +0000 (21:44 +0800)
committer Wenchen Fan <wenchen@databricks.com>
Sat, 11 Aug 2018 13:44:45 +0000 (21:44 +0800)
## What changes were proposed in this pull request?

This PR adds Python support for barrier execution mode, enabling users to launch a job containing barrier stage(s) from PySpark.

We simply fork the existing `RDDBarrier` and `RDD.barrier()` into the Python API.
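To make the plumbing concrete, here is a minimal, self-contained sketch (hypothetical `FakeRDD` / `FakeRDDBarrier` classes, not the PySpark code itself) of the flag threading this patch introduces: `RDDBarrier.mapPartitions()` is the only place the barrier flag is set, and the flag also propagates from the parent RDD:

```
# Hypothetical stand-ins, not PySpark classes; they only model the
# isFromBarrier bookkeeping added by this patch.
class FakeRDD(object):
    def __init__(self, prev=None, is_from_barrier=False):
        # Mirrors: self.is_barrier = prev._is_barrier() or isFromBarrier
        prev_barrier = prev._is_barrier() if prev is not None else False
        self.is_barrier = prev_barrier or is_from_barrier

    def _is_barrier(self):
        return self.is_barrier

    def barrier(self):
        return FakeRDDBarrier(self)

    def mapPartitions(self, f):
        return FakeRDD(prev=self)


class FakeRDDBarrier(object):
    def __init__(self, rdd):
        self.rdd = rdd

    def mapPartitions(self, f):
        # The only place the flag is ever set to True, as in the patch.
        return FakeRDD(prev=self.rdd, is_from_barrier=True)


rdd = FakeRDD()
assert not rdd.mapPartitions(None)._is_barrier()
assert rdd.barrier().mapPartitions(None)._is_barrier()
# The flag survives further chained transformations:
assert rdd.barrier().mapPartitions(None).mapPartitions(None)._is_barrier()
```

The real patch performs the same bookkeeping in `PipelinedRDD.__init__` and forwards the flag to the JVM-side `PythonRDD` constructor, whose `isBarrier_` also consults the parent dependencies.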

## How was this patch tested?

Manually tested:
```
>>> rdd = sc.parallelize([1, 2, 3, 4])
>>> def f(iterator): yield sum(iterator)
...
>>> rdd.barrier().mapPartitions(f)._is_barrier()
True
```

Unit tests will be added in a follow-up PR that implements BarrierTaskContext on the Python side.
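For reference, a hedged sketch of the kind of check such a follow-up test might add (a hypothetical test case, assuming the private `_is_barrier()` helper from this patch):

```
import unittest

from pyspark import SparkContext


class BarrierFlagTest(unittest.TestCase):
    """Hypothetical test case, not the actual follow-up suite."""

    def test_barrier_flag(self):
        sc = SparkContext.getOrCreate()
        rdd = sc.parallelize(range(4), 2)
        self.assertFalse(rdd.mapPartitions(lambda it: it)._is_barrier())
        self.assertTrue(rdd.barrier().mapPartitions(lambda it: it)._is_barrier())
```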

Closes #22011 from jiangxb1987/python.

Authored-by: Xingbo Jiang <xingbo.jiang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
python/pyspark/rdd.py

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 8bc0ff7..8c2ce88 100644
@@ -45,7 +45,8 @@ import org.apache.spark.util._
 private[spark] class PythonRDD(
     parent: RDD[_],
     func: PythonFunction,
-    preservePartitoning: Boolean)
+    preservePartitoning: Boolean,
+    isFromBarrier: Boolean = false)
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -63,6 +64,9 @@ private[spark] class PythonRDD(
     val runner = PythonRunner(func, bufferSize, reuseWorker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
+
+  @transient protected lazy override val isBarrier_ : Boolean =
+    isFromBarrier || dependencies.exists(_.rdd.isBarrier())
 }
 
 /**
core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
index 978e7c0..b399bf9 100644
@@ -19,7 +19,6 @@ package org.apache.spark.rdd
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.BarrierTaskContext
 import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{Experimental, Since}
 
python/pyspark/rdd.py
index 9518518..d17a8eb 100644
@@ -2406,6 +2406,22 @@ class RDD(object):
             sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
         return _load_from_socket(sock_info, self._jrdd_deserializer)
 
+    def barrier(self):
+        """
+        .. note:: Experimental
+
+        Indicates that Spark must launch the tasks together for the current stage.
+
+        .. versionadded:: 2.4.0
+        """
+        return RDDBarrier(self)
+
+    def _is_barrier(self):
+        """
+        Whether this RDD is in a barrier stage.
+        """
+        return self._jrdd.rdd().isBarrier()
+
 
 def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
@@ -2429,6 +2445,33 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None):
                                   sc.pythonVer, broadcast_vars, sc._javaAccumulator)
 
 
+class RDDBarrier(object):
+
+    """
+    .. note:: Experimental
+
+    An RDDBarrier turns an RDD into a barrier RDD, which forces Spark to launch tasks of the
+    stage that contains this RDD together.
+
+    .. versionadded:: 2.4.0
+    """
+
+    def __init__(self, rdd):
+        self.rdd = rdd
+
+    def mapPartitions(self, f, preservesPartitioning=False):
+        """
+        .. note:: Experimental
+
+        Return a new RDD by applying a function to each partition of this RDD.
+
+        .. versionadded:: 2.4.0
+        """
+        def func(s, iterator):
+            return f(iterator)
+        return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True)
+
+
 class PipelinedRDD(RDD):
 
     """
@@ -2448,7 +2491,7 @@ class PipelinedRDD(RDD):
     20
     """
 
-    def __init__(self, prev, func, preservesPartitioning=False):
+    def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False):
         if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
             # This transformation is the first in its stage:
             self.func = func
@@ -2474,6 +2517,7 @@ class PipelinedRDD(RDD):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self.partitioner = prev.partitioner if self.preservesPartitioning else None
+        self.is_barrier = prev._is_barrier() or isFromBarrier
 
     def getNumPartitions(self):
         return self._prev_jrdd.partitions().size()
@@ -2493,7 +2537,7 @@ class PipelinedRDD(RDD):
         wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
                                       self._jrdd_deserializer, profiler)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
-                                             self.preservesPartitioning)
+                                             self.preservesPartitioning, self.is_barrier)
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:
@@ -2509,6 +2553,9 @@ class PipelinedRDD(RDD):
     def _is_pipelinable(self):
         return not (self.is_cached or self.is_checkpointed)
 
+    def _is_barrier(self):
+        return self.is_barrier
+
 
 def _test():
     import doctest