[SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver
[spark.git] / sql / catalyst / src / main / scala / org / apache / spark / sql / internal / SQLConf.scala
index b00edca..0b1965c 100644 (file)
@@ -27,7 +27,7 @@ import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.util.ByteUnit
@@ -107,7 +107,13 @@ object SQLConf {
    * run tests in parallel. At the time this feature was implemented, this was a no-op since we
    * run unit tests (that does not involve SparkSession) in serial order.
    */
-  def get: SQLConf = confGetter.get()()
+  def get: SQLConf = {
+    if (Utils.isTesting && TaskContext.get != null) {
+      // we're accessing it during task execution, fail.
+      throw new IllegalStateException("SQLConf should only be created and accessed on the driver.")
+    }
+    confGetter.get()()
+  }
 
   val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
     .internal()
@@ -1274,12 +1280,6 @@ object SQLConf {
 class SQLConf extends Serializable with Logging {
   import SQLConf._
 
-  if (Utils.isTesting && SparkEnv.get != null) {
-    // assert that we're only accessing it on the driver.
-    assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
-      "SQLConf should only be created and accessed on the driver.")
-  }
-
   /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
   @transient protected[spark] val settings = java.util.Collections.synchronizedMap(
     new java.util.HashMap[String, String]())