[SPARK-24216][SQL] Spark TypedAggregateExpression uses getSimpleName that is not...
authorFangshi Li <fli@linkedin.com>
Tue, 12 Jun 2018 19:10:08 +0000 (12:10 -0700)
committerWenchen Fan <wenchen@databricks.com>
Tue, 12 Jun 2018 19:10:08 +0000 (12:10 -0700)
## What changes were proposed in this pull request?

When user create a aggregator object in scala and pass the aggregator to Spark Dataset's agg() method, Spark's will initialize TypedAggregateExpression with the nodeName field as aggregator.getClass.getSimpleName. However, getSimpleName is not safe in scala environment, depending on how user creates the aggregator object. For example, if the aggregator class full qualified name is "com.my.company.MyUtils$myAgg$2$", the getSimpleName will throw java.lang.InternalError "Malformed class name". This has been reported in scalatest https://github.com/scalatest/scalatest/pull/1044 and discussed in many scala upstream jiras such as SI-8110, SI-5425.

To fix this issue, we follow the solution in https://github.com/scalatest/scalatest/pull/1044 to add safer version of getSimpleName as a util method, and TypedAggregateExpression will invoke this util method rather than getClass.getSimpleName.

## How was this patch tested?
added unit test

Author: Fangshi Li <fli@linkedin.com>

Closes #21276 from fangshil/SPARK-24216.

core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
core/src/main/scala/org/apache/spark/util/Utils.scala
core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala

index 3b469a6..bf618b4 100644 (file)
@@ -200,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
   }
 
   override def toString: String = {
+    // getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
     if (metadata == null) {
-      "Un-registered Accumulator: " + getClass.getSimpleName
+      "Un-registered Accumulator: " + Utils.getSimpleName(getClass)
     } else {
-      getClass.getSimpleName + s"(id: $id, name: $name, value: $value)"
+      Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)"
     }
   }
 }
index f9191a5..7428db2 100644 (file)
@@ -19,6 +19,7 @@ package org.apache.spark.util
 
 import java.io._
 import java.lang.{Byte => JByte}
+import java.lang.InternalError
 import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
 import java.lang.reflect.InvocationTargetException
 import java.math.{MathContext, RoundingMode}
@@ -1820,7 +1821,7 @@ private[spark] object Utils extends Logging {
 
   /** Return the class name of the given object, removing all dollar signs */
   def getFormattedClassName(obj: AnyRef): String = {
-    obj.getClass.getSimpleName.replace("$", "")
+    getSimpleName(obj.getClass).replace("$", "")
   }
 
   /**
@@ -2715,6 +2716,62 @@ private[spark] object Utils extends Logging {
     HashCodes.fromBytes(secretBytes).toString()
   }
 
+  /**
+   * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala.
+   * This method mimicks scalatest's getSimpleNameOfAnObjectsClass.
+   */
+  def getSimpleName(cls: Class[_]): String = {
+    try {
+      return cls.getSimpleName
+    } catch {
+      case err: InternalError => return stripDollars(stripPackages(cls.getName))
+    }
+  }
+
+  /**
+   * Remove the packages from full qualified class name
+   */
+  private def stripPackages(fullyQualifiedName: String): String = {
+    fullyQualifiedName.split("\\.").takeRight(1)(0)
+  }
+
+  /**
+   * Remove trailing dollar signs from qualified class name,
+   * and return the trailing part after the last dollar sign in the middle
+   */
+  private def stripDollars(s: String): String = {
+    val lastDollarIndex = s.lastIndexOf('$')
+    if (lastDollarIndex < s.length - 1) {
+      // The last char is not a dollar sign
+      if (lastDollarIndex == -1 || !s.contains("$iw")) {
+        // The name does not have dollar sign or is not an intepreter
+        // generated class, so we should return the full string
+        s
+      } else {
+        // The class name is intepreter generated,
+        // return the part after the last dollar sign
+        // This is the same behavior as getClass.getSimpleName
+        s.substring(lastDollarIndex + 1)
+      }
+    }
+    else {
+      // The last char is a dollar sign
+      // Find last non-dollar char
+      val lastNonDollarChar = s.reverse.find(_ != '$')
+      lastNonDollarChar match {
+        case None => s
+        case Some(c) =>
+          val lastNonDollarIndex = s.lastIndexOf(c)
+          if (lastNonDollarIndex == -1) {
+            s
+          } else {
+            // Strip the trailing dollar signs
+            // Invoke stripDollars again to get the simple name
+            stripDollars(s.substring(0, lastNonDollarIndex + 1))
+          }
+      }
+    }
+  }
 }
 
 private[util] object CallerContext extends Logging {
index 3b42731..418d2f9 100644 (file)
@@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
       Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")
     }
   }
+
+  object MalformedClassObject {
+    class MalformedClass
+  }
+
+  test("Safe getSimpleName") {
+    // getSimpleName on class of MalformedClass will result in error: Malformed class name
+    // Utils.getSimpleName works
+    val err = intercept[java.lang.InternalError] {
+      classOf[MalformedClassObject.MalformedClass].getSimpleName
+    }
+    assert(err.getMessage === "Malformed class name")
+
+    assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) ===
+      "UtilsSuite$MalformedClassObject$MalformedClass")
+  }
 }
 
 private class SimpleExtension
index 3a1c166..11f46eb 100644 (file)
@@ -30,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.Param
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.util.Utils
 
 /**
  * A small wrapper that defines a training session for an estimator, and some methods to log
@@ -47,7 +48,9 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
 
   private val id = UUID.randomUUID()
   private val prefix = {
-    val className = estimator.getClass.getSimpleName
+    // estimator.getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
+    val className = Utils.getSimpleName(estimator.getClass)
     s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
   }
 
index aab8cc5..6d44890 100644 (file)
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 object TypedAggregateExpression {
   def apply[BUF : Encoder, OUT : Encoder](
@@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction {
     s"$nodeName($input)"
   }
 
-  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+  // aggregator.getClass.getSimpleName can cause Malformed class name error,
+  // call safer `Utils.getSimpleName` instead
+  override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$");
 }
 
 // TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.
index 693e67d..97e6c6d 100644 (file)
@@ -53,7 +53,9 @@ trait DataSourceV2StringFormat {
 
   private def sourceName: String = source match {
     case registered: DataSourceRegister => registered.shortName()
-    case _ => source.getClass.getSimpleName.stripSuffix("$")
+    // source.getClass.getSimpleName can cause Malformed class name error,
+    // call safer `Utils.getSimpleName` instead
+    case _ => Utils.getSimpleName(source.getClass)
   }
 
   def metadataString: String = {