[SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver
[spark.git] / sql / core / src / main / scala / org / apache / spark / sql / SparkSession.scala
index c502e58..e2a1a57 100644 (file)
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -898,6 +898,7 @@ object SparkSession extends Logging {
      * @since 2.0.0
      */
     def getOrCreate(): SparkSession = synchronized {
+      assertOnDriver()
       // Get the session from current thread's active session.
       var session = activeThreadSession.get()
       if ((session ne null) && !session.sparkContext.isStopped) {
@@ -1022,14 +1023,20 @@ object SparkSession extends Logging {
    *
    * @since 2.2.0
    */
-  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
+  def getActiveSession: Option[SparkSession] = {
+    assertOnDriver()
+    Option(activeThreadSession.get)
+  }
 
   /**
    * Returns the default SparkSession that is returned by the builder.
    *
    * @since 2.2.0
    */
-  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+  def getDefaultSession: Option[SparkSession] = {
+    assertOnDriver()
+    Option(defaultSession.get)
+  }
 
   /**
    * Returns the currently active SparkSession, otherwise the default one. If there is no default
@@ -1062,6 +1069,14 @@ object SparkSession extends Logging {
     }
   }
 
+  private def assertOnDriver(): Unit = {
+    if (Utils.isTesting && TaskContext.get != null) {
+      // we're accessing it during task execution, fail.
+      throw new IllegalStateException(
+        "SparkSession should only be created and accessed on the driver.")
+    }
+  }
+
   /**
    * Helper method to create an instance of `SessionState` based on `className` from conf.
    * The result is either `SessionState` or a Hive based `SessionState`.