[SPARK-24466][SS] Fix TextSocketMicroBatchReader to be compatible with netcat again
author Jungtaek Lim <kabhwan@gmail.com>
Wed, 13 Jun 2018 04:34:46 +0000 (12:34 +0800)
committer hyukjinkwon <gurwls223@apache.org>
Wed, 13 Jun 2018 04:34:46 +0000 (12:34 +0800)
## What changes were proposed in this pull request?

TextSocketMicroBatchReader was no longer compatible with netcat because it launched a temporary reader to read the schema, closed that reader, and then re-opened a new one. While a reliable socket server should handle this without any issue, the nc command normally can't handle multiple connections and simply exits when the temporary reader's connection is closed.
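
In other words, the old flow opened and closed a throwaway connection before opening the real one. A minimal sketch of that sequence (host and port are placeholders; run `nc -l 9999` beforehand to reproduce):

```scala
import java.net.Socket

// Temporary connection opened only to derive the schema, then closed.
// A plain `nc -l 9999` process exits as soon as this connection drops.
val schemaProbe = new Socket("localhost", 9999)
schemaProbe.close()

// The subsequent "real" connection then fails, because nc is already gone.
val reader = new Socket("localhost", 9999) // throws java.net.ConnectException
```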

This patch fixes TextSocketMicroBatchReader to be compatible with netcat again by deferring the opening of the socket from the constructor to the first call of planInputPartitions().
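
The change itself is the usual one-time lazy initialization idiom built on AtomicBoolean.compareAndSet; roughly (a simplified sketch of the pattern, not the full reader):

```scala
import java.util.concurrent.atomic.AtomicBoolean

class DeferredInit {
  private val initialized = new AtomicBoolean(false)

  private def initialize(): Unit = {
    // open the socket and start the read thread here
  }

  def planInputPartitions(): Unit = synchronized {
    // compareAndSet flips false -> true exactly once, so initialize()
    // runs on the first call instead of in the constructor
    if (initialized.compareAndSet(false, true)) {
      initialize()
    }
    // ... slice the buffered batches as before ...
  }
}
```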

## How was this patch tested?

Added a unit test which fails on the current code and succeeds with the patch. Also tested manually.
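
For the manual test, the usual socket source example from the Structured Streaming guide applies (start `nc -lk 9999` in another terminal first):

```scala
// In spark-shell, with `nc -lk 9999` running in another terminal:
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()

val query = lines.writeStream
  .format("console")
  .start()

// Lines typed into the nc terminal should appear in the console sink;
// with this patch only a single connection is opened, so plain nc keeps working.
```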

Author: Jungtaek Lim <kabhwan@gmail.com>

Closes #21497 from HeartSaVioR/SPARK-24466.

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala

index 8240e06..91e3b71 100644
@@ -22,6 +22,7 @@ import java.net.Socket
 import java.sql.Timestamp
 import java.text.SimpleDateFormat
 import java.util.{Calendar, List => JList, Locale, Optional}
+import java.util.concurrent.atomic.AtomicBoolean
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
@@ -76,7 +77,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR
   @GuardedBy("this")
   private var lastOffsetCommitted: LongOffset = LongOffset(-1L)
 
-  initialize()
+  private val initialized: AtomicBoolean = new AtomicBoolean(false)
 
   /** This method is only used for unit test */
   private[sources] def getCurrentOffset(): LongOffset = synchronized {
@@ -149,6 +150,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR
 
     // Internal buffer only holds the batches after lastOffsetCommitted
     val rawList = synchronized {
+      if (initialized.compareAndSet(false, true)) {
+        initialize()
+      }
+
       val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
       val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
       batches.slice(sliceStart, sliceEnd)
index a15a980..52e8386 100644
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.sources
 
-import java.io.IOException
-import java.net.InetSocketAddress
+import java.net.{InetSocketAddress, SocketException}
 import java.nio.ByteBuffer
 import java.nio.channels.ServerSocketChannel
 import java.sql.Timestamp
@@ -33,9 +32,10 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
-import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
 
@@ -101,7 +101,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
     serverThread = new ServerThread()
     serverThread.start()
 
-    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+    withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") {
       val ref = spark
       import ref.implicits._
 
@@ -130,7 +130,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
     serverThread = new ServerThread()
     serverThread.start()
 
-    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+    withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") {
       val socket = spark
         .readStream
         .format("socket")
@@ -216,20 +216,11 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
       "socket source does not support a user-specified schema"))
   }
 
-  test("no server up") {
-    val provider = new TextSocketSourceProvider
-    val parameters = Map("host" -> "localhost", "port" -> "0")
-    intercept[IOException] {
-      batchReader = provider.createMicroBatchReader(
-        Optional.empty(), "", new DataSourceOptions(parameters.asJava))
-    }
-  }
-
   test("input row metrics") {
     serverThread = new ServerThread()
     serverThread.start()
 
-    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+    withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") {
       val ref = spark
       import ref.implicits._
 
@@ -256,6 +247,66 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
     }
   }
 
+  test("verify ServerThread only accepts the first connection") {
+    serverThread = new ServerThread()
+    serverThread.start()
+
+    withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") {
+      val ref = spark
+      import ref.implicits._
+
+      val socket = spark
+        .readStream
+        .format("socket")
+        .options(Map("host" -> "localhost", "port" -> serverThread.port.toString))
+        .load()
+        .as[String]
+
+      assert(socket.schema === StructType(StructField("value", StringType) :: Nil))
+
+      testStream(socket)(
+        StartStream(),
+        AddSocketData("hello"),
+        CheckAnswer("hello"),
+        AddSocketData("world"),
+        CheckLastBatch("world"),
+        CheckAnswer("hello", "world"),
+        StopStream
+      )
+
+      // Trying to connect to the server again should fail, since only the first connection is accepted
+      try {
+        val socket2 = spark
+          .readStream
+          .format("socket")
+          .options(Map("host" -> "localhost", "port" -> serverThread.port.toString))
+          .load()
+          .as[String]
+
+        testStream(socket2)(
+          StartStream(),
+          AddSocketData("hello"),
+          CheckAnswer("hello"),
+          AddSocketData("world"),
+          CheckLastBatch("world"),
+          CheckAnswer("hello", "world"),
+          StopStream
+        )
+
+        fail("StreamingQueryException is expected!")
+      } catch {
+        case e: StreamingQueryException if e.cause.isInstanceOf[SocketException] => // pass
+      }
+    }
+  }
+
+  /**
+   * This class tries to mimic the behavior of netcat, so that we can ensure
+   * TextSocketStream supports netcat, which only accepts the first connection
+   * and exits the process when the first connection is closed.
+   *
+   * Please refer to SPARK-24466 for more details.
+   */
   private class ServerThread extends Thread with Logging {
     private val serverSocketChannel = ServerSocketChannel.open()
     serverSocketChannel.bind(new InetSocketAddress(0))
@@ -265,36 +316,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
 
     override def run(): Unit = {
       try {
+        val clientSocketChannel = serverSocketChannel.accept()
+
+        // Close the server socket channel immediately to mimic netcat, which
+        // accepts only the first connection and denies any further ones.
+        // Note that the accepted client socket channel remains usable.
+        serverSocketChannel.close()
+
+        clientSocketChannel.configureBlocking(false)
+        clientSocketChannel.socket().setTcpNoDelay(true)
+
         while (true) {
-          val clientSocketChannel = serverSocketChannel.accept()
-          clientSocketChannel.configureBlocking(false)
-          clientSocketChannel.socket().setTcpNoDelay(true)
-
-          // Check whether remote client is closed but still send data to this closed socket.
-          // This happens in DataStreamReader where a source will be created to get the schema.
-          var remoteIsClosed = false
-          var cnt = 0
-          while (cnt < 3 && !remoteIsClosed) {
-            if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) {
-              cnt += 1
-              Thread.sleep(100)
-            } else {
-              remoteIsClosed = true
-            }
-          }
-
-          if (remoteIsClosed) {
-            logInfo(s"remote client ${clientSocketChannel.socket()} is closed")
-          } else {
-            while (true) {
-              val line = messageQueue.take() + "\n"
-              clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8")))
-            }
-          }
+          val line = messageQueue.take() + "\n"
+          clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8")))
         }
       } catch {
         case e: InterruptedException =>
       } finally {
+        // no harm to call close() again...
         serverSocketChannel.close()
       }
     }