[SPARK-24235][SS] Implement continuous shuffle writer for single reader partition.
author Jose Torres <torres.joseph.f+github@gmail.com>
Wed, 13 Jun 2018 20:13:01 +0000 (13:13 -0700)
committer Shixiong Zhu <zsxwing@gmail.com>
Wed, 13 Jun 2018 20:13:01 +0000 (13:13 -0700)
## What changes were proposed in this pull request?

https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit

Implement the continuous shuffle writer for a single reader partition. (I don't believe any implementation changes are actually required to support multiple reader partitions, but this PR is already very large, so I want to exclude them for now to keep the size down.)
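
For context, a minimal sketch (not part of the patch) of how the new writer is wired to the existing read RDD in the single-reader-partition case. Here `sparkContext`, `unsafeRow`, and `taskContext` are stand-ins for the test harness fixtures (the suite's UnsafeRow projection helper and the TaskContextImpl it installs):

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.sql.execution.streaming.continuous.shuffle._

// One reader partition, backed by an RPCContinuousShuffleReader endpoint.
val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val endpoint = reader.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint

// writerId 0; HashPartitioner(1) because only a single reader partition is supported so far.
val writer = new RPCContinuousShuffleWriter(0, new HashPartitioner(1), Array(endpoint))

// Each write() call is one epoch: rows go out via askSync, then an epoch marker per endpoint.
writer.write(Iterator(unsafeRow(1), unsafeRow(2), unsafeRow(3)))

// Draining compute() for the reader partition yields exactly that epoch's rows.
assert(reader.compute(reader.partitions(0), taskContext).map(_.getInt(0)).toSeq == Seq(1, 2, 3))
```

The reader side only finishes an epoch once every writerId has sent a `ReceiverEpochMarker`, which is why `write()` sends a marker to every endpoint after draining the row iterator.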

## How was this patch tested?

New unit tests in ContinuousShuffleSuite (renamed from ContinuousShuffleReadSuite), covering the existing reader paths and the new writer.
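
Concretely, the renamed suite keeps the reader-level tests, which feed `ReceiverRow`/`ReceiverEpochMarker` messages straight to the reader endpoint, and adds end-to-end tests through the writer. A condensed (not verbatim) sketch of the reader-level pattern, again using the suite's `unsafeRow` helper and `ctx` task context:

```scala
// Three writer partitions feeding one reader partition; every message is tagged with a writerId.
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint

Seq(
  ReceiverRow(0, unsafeRow(1)), ReceiverEpochMarker(0),
  ReceiverEpochMarker(1),                               // writer 1 contributes an empty epoch
  ReceiverRow(2, unsafeRow(2)), ReceiverEpochMarker(2)
).foreach(endpoint.askSync[Unit](_))

// compute() only ends the epoch once markers from all three writers have arrived.
assert(rdd.compute(rdd.partitions(0), ctx).map(_.getInt(0)).toSet == Set(1, 2))
```

The new writer-path tests assert the same epoch semantics when the rows are produced through `RPCContinuousShuffleWriter.write()` instead of raw messages.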

Author: Jose Torres <torres.joseph.f+github@gmail.com>

Closes #21428 from jose-torres/writerTask.

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala [new file with mode: 0644]
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala [moved from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala with 86% similarity]
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala [new file with mode: 0644]
sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala [moved from sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala with 65% similarity]

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
index 801b28b..cf6572d 100644
@@ -34,8 +34,10 @@ case class ContinuousShuffleReadPartition(
   // Initialized only on the executor, and only once even as we call compute() multiple times.
   lazy val (reader: ContinuousShuffleReader, endpoint) = {
     val env = SparkEnv.get.rpcEnv
-    val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env)
-    val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver)
+    val receiver = new RPCContinuousShuffleReader(
+      queueSize, numShuffleWriters, epochIntervalMs, env)
+    val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver)
+
     TaskContext.get().addTaskCompletionListener { ctx =>
       env.stop(endpoint)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala
new file mode 100644
index 0000000..47b1f78
--- /dev/null
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Trait for writing to a continuous processing shuffle.
+ */
+trait ContinuousShuffleWriter {
+  def write(epoch: Iterator[UnsafeRow]): Unit
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
similarity index 86%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
@@ -20,26 +20,24 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle
 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
 
-import scala.collection.mutable
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.util.NextIterator
 
 /**
- * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker.
+ * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker.
  *
  * Each message comes tagged with writerId, identifying which writer the message is coming
  * from. The receiver will only begin the next epoch once all writers have sent an epoch
  * marker ending the current epoch.
  */
-private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable {
+private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable {
   def writerId: Int
 }
 private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow)
-  extends UnsafeRowReceiverMessage
-private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage
+  extends RPCContinuousShuffleMessage
+private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage
 
 /**
  * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle
@@ -48,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow
  * TODO: Support multiple source tasks. We need to output a single epoch marker once all
  * source tasks have sent one.
  */
-private[shuffle] class UnsafeRowReceiver(
+private[shuffle] class RPCContinuousShuffleReader(
       queueSize: Int,
       numShuffleWriters: Int,
       epochIntervalMs: Long,
@@ -57,7 +55,7 @@ private[shuffle] class UnsafeRowReceiver(
   // Note that this queue will be drained from the main task thread and populated in the RPC
   // response thread.
   private val queues = Array.fill(numShuffleWriters) {
-    new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
+    new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize)
   }
 
   // Exposed for testing to determine if the endpoint gets stopped on task end.
@@ -68,7 +66,9 @@ private[shuffle] class UnsafeRowReceiver(
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-    case r: UnsafeRowReceiverMessage =>
+    case r: RPCContinuousShuffleMessage =>
+      // Note that this will block a thread in the shared RPC handler pool!
+      // The TCP-based shuffle handler (SPARK-24541) will avoid this problem.
       queues(r.writerId).put(r)
       context.reply(())
   }
@@ -79,10 +79,10 @@ private[shuffle] class UnsafeRowReceiver(
       private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false)
 
       private val executor = Executors.newFixedThreadPool(numShuffleWriters)
-      private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor)
+      private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor)
 
-      private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] {
-        override def call(): UnsafeRowReceiverMessage = queues(writerId).take()
+      private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] {
+        override def call(): RPCContinuousShuffleMessage = queues(writerId).take()
       }
 
       // Initialize by submitting tasks to read the first row from each writer.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala
new file mode 100644
index 0000000..1c6f3dd
--- /dev/null
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.Partitioner
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances.
+ *
+ * @param writerId The partition ID of this writer.
+ * @param outputPartitioner The partitioner on the reader side of the shuffle.
+ * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by
+ *                  partition ID within outputPartitioner.
+ */
+class RPCContinuousShuffleWriter(
+    writerId: Int,
+    outputPartitioner: Partitioner,
+    endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter {
+
+  if (outputPartitioner.numPartitions != 1) {
+    throw new IllegalArgumentException("multiple readers not yet supported")
+  }
+
+  if (outputPartitioner.numPartitions != endpoints.length) {
+    throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " +
+      s"not match endpoint count ${endpoints.length}")
+  }
+
+  def write(epoch: Iterator[UnsafeRow]): Unit = {
+    while (epoch.hasNext) {
+      val row = epoch.next()
+      endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row))
+    }
+
+    val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq
+    implicit val ec = ThreadUtils.sameThread
+    ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
similarity index 65%
rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
 
 package org.apache.spark.sql.execution.streaming.continuous.shuffle
 
-import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl}
 import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
-class ContinuousShuffleReadSuite extends StreamTest {
-
-  private def unsafeRow(value: Int) = {
-    UnsafeProjection.create(Array(IntegerType : DataType))(
-      new GenericInternalRow(Array(value: Any)))
-  }
-
-  private def unsafeRow(value: String) = {
-    UnsafeProjection.create(Array(StringType : DataType))(
-      new GenericInternalRow(Array(UTF8String.fromString(value): Any)))
-  }
-
-  private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = {
-    messages.foreach(endpoint.askSync[Unit](_))
-  }
-
+class ContinuousShuffleSuite extends StreamTest {
   // In this unit test, we emulate that we're in the task thread where
   // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context
   // thread local to be set.
@@ -58,39 +43,29 @@ class ContinuousShuffleReadSuite extends StreamTest {
     super.afterEach()
   }
 
-  test("receiver stopped with row last") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
-    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    send(
-      endpoint,
-      ReceiverEpochMarker(0),
-      ReceiverRow(0, unsafeRow(111))
-    )
+  private implicit def unsafeRow(value: Int) = {
+    UnsafeProjection.create(Array(IntegerType : DataType))(
+      new GenericInternalRow(Array(value: Any)))
+  }
 
-    ctx.markTaskCompleted(None)
-    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
-    eventually(timeout(streamingTimeout)) {
-      assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
-    }
+  private def unsafeRow(value: String) = {
+    UnsafeProjection.create(Array(StringType : DataType))(
+      new GenericInternalRow(Array(UTF8String.fromString(value): Any)))
   }
 
-  test("receiver stopped with marker last") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
-    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    send(
-      endpoint,
-      ReceiverRow(0, unsafeRow(111)),
-      ReceiverEpochMarker(0)
-    )
+  private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = {
+    messages.foreach(endpoint.askSync[Unit](_))
+  }
 
-    ctx.markTaskCompleted(None)
-    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
-    eventually(timeout(streamingTimeout)) {
-      assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
-    }
+  private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = {
+    rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
   }
 
-  test("one epoch") {
+  private def readEpoch(rdd: ContinuousShuffleReadRDD) = {
+    rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0))
+  }
+
+  test("reader - one epoch") {
     val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
     send(
@@ -105,7 +80,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
     assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333))
   }
 
-  test("multiple epochs") {
+  test("reader - multiple epochs") {
     val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
     send(
@@ -124,7 +99,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
     assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333))
   }
 
-  test("empty epochs") {
+  test("reader - empty epochs") {
     val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
 
@@ -148,7 +123,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
     assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
   }
 
-  test("multiple partitions") {
+  test("reader - multiple partitions") {
     val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5)
     // Send all data before processing to ensure there's no crossover.
     for (p <- rdd.partitions) {
@@ -169,7 +144,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
     }
   }
 
-  test("blocks waiting for new rows") {
+  test("reader - blocks waiting for new rows") {
     val rdd = new ContinuousShuffleReadRDD(
       sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue)
     val epoch = rdd.compute(rdd.partitions(0), ctx)
@@ -195,7 +170,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
     }
   }
 
-  test("multiple writers") {
+  test("reader - multiple writers") {
     val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
     send(
@@ -213,7 +188,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
       Set("writer0-row0", "writer1-row0", "writer2-row0"))
   }
 
-  test("epoch only ends when all writers send markers") {
+  test("reader - epoch only ends when all writers send markers") {
     val rdd = new ContinuousShuffleReadRDD(
       sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue)
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
@@ -233,6 +208,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
 
     // After checking the right rows, block until we get an epoch marker indicating there's no next.
     // (Also fail the assertion if for some reason we get a row.)
+
     val readEpochMarkerThread = new Thread {
       override def run(): Unit = {
         assert(!epoch.hasNext)
@@ -251,10 +227,10 @@ class ContinuousShuffleReadSuite extends StreamTest {
     }
 
     // Join to pick up assertion failures.
-    readEpochMarkerThread.join()
+    readEpochMarkerThread.join(streamingTimeout.toMillis)
   }
 
-  test("writer epochs non aligned") {
+  test("reader - writer epochs non aligned") {
     val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
     // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should
@@ -288,4 +264,153 @@ class ContinuousShuffleReadSuite extends StreamTest {
     val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
     assert(thirdEpoch == Set("writer1-row1", "writer2-row0"))
   }
+
+  test("one epoch") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator(1, 2, 3))
+
+    assert(readEpoch(reader) == Seq(1, 2, 3))
+  }
+
+  test("multiple epochs") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator(1, 2, 3))
+    writer.write(Iterator(4, 5, 6))
+
+    assert(readEpoch(reader) == Seq(1, 2, 3))
+    assert(readEpoch(reader) == Seq(4, 5, 6))
+  }
+
+  test("empty epochs") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator())
+    writer.write(Iterator(1, 2))
+    writer.write(Iterator())
+    writer.write(Iterator())
+    writer.write(Iterator(3, 4))
+    writer.write(Iterator())
+
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq(1, 2))
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq(3, 4))
+    assert(readEpoch(reader) == Seq())
+  }
+
+  test("blocks waiting for writer") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    val readerEpoch = reader.compute(reader.partitions(0), ctx)
+
+    val readRowThread = new Thread {
+      override def run(): Unit = {
+        assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1))
+      }
+    }
+    readRowThread.start()
+
+    eventually(timeout(streamingTimeout)) {
+      assert(readRowThread.getState == Thread.State.TIMED_WAITING)
+    }
+
+    // Once we write the epoch the thread should stop waiting and succeed.
+    writer.write(Iterator(1))
+    readRowThread.join(streamingTimeout.toMillis)
+  }
+
+  test("multiple writer partitions") {
+    val numWriterPartitions = 3
+
+    val reader = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
+    val writers = (0 until numWriterPartitions).map { idx =>
+      new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+    }
+
+    writers(0).write(Iterator(1, 4, 7))
+    writers(1).write(Iterator(2, 5))
+    writers(2).write(Iterator(3, 6))
+
+    writers(0).write(Iterator(4, 7, 10))
+    writers(1).write(Iterator(5, 8))
+    writers(2).write(Iterator(6, 9))
+
+    // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed.
+    // The epochs should be deterministically preserved, however.
+    assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet)
+    assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet)
+  }
+
+  test("reader epoch only ends when all writer partitions write it") {
+    val numWriterPartitions = 3
+
+    val reader = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
+    val writers = (0 until numWriterPartitions).map { idx =>
+      new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+    }
+
+    writers(1).write(Iterator())
+    writers(2).write(Iterator())
+
+    val readerEpoch = reader.compute(reader.partitions(0), ctx)
+
+    val readEpochMarkerThread = new Thread {
+      override def run(): Unit = {
+        assert(!readerEpoch.hasNext)
+      }
+    }
+
+    readEpochMarkerThread.start()
+    eventually(timeout(streamingTimeout)) {
+      assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
+    }
+
+    writers(0).write(Iterator())
+    readEpochMarkerThread.join(streamingTimeout.toMillis)
+  }
+
+  test("receiver stopped with row last") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow(111))
+    )
+
+    ctx.markTaskCompleted(None)
+    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+    eventually(timeout(streamingTimeout)) {
+      assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
+    }
+  }
+
+  test("receiver stopped with marker last") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverEpochMarker(0)
+    )
+
+    ctx.markTaskCompleted(None)
+    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+    eventually(timeout(streamingTimeout)) {
+      assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
+    }
+  }
 }