[SPARK-25090][ML] Enforce implicit type coercion in ParamGridBuilder
authorMarco Gaido <marcogaido91@gmail.com>
Mon, 13 Aug 2018 01:11:37 +0000 (09:11 +0800)
committerhyukjinkwon <gurwls223@apache.org>
Mon, 13 Aug 2018 01:11:37 +0000 (09:11 +0800)
## What changes were proposed in this pull request?

When the grid of the parameters is created in `ParamGridBuilder`, the implicit type coercion is not enforced. So using an integer in the list of parameters to set for a parameter accepting a double can cause a class cast exception.

The PR proposes to enforce the type coercion when building the parameters.

## How was this patch tested?

added UT

Closes #22076 from mgaido91/SPARK-25090.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
python/pyspark/ml/tests.py
python/pyspark/ml/tuning.py

index 3d8883b..a770bad 100755 (executable)
@@ -950,6 +950,13 @@ class CrossValidatorTests(SparkSessionTestCase):
                          "Best model should have zero induced error")
         self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
 
+    def test_param_grid_type_coercion(self):
+        lr = LogisticRegression(maxIter=10)
+        paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build()
+        for param in paramGrid:
+            for v in param.values():
+                assert(type(v) == float)
+
     def test_save_load_trained_model(self):
         # This tests saving and loading the trained model only.
         # Save/load for CrossValidator will be added later: SPARK-13786
index 0c8029f..1f4abf5 100644 (file)
@@ -115,7 +115,11 @@ class ParamGridBuilder(object):
         """
         keys = self._param_grid.keys()
         grid_values = self._param_grid.values()
-        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
+
+        def to_key_value_pairs(keys, values):
+            return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]
+
+        return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
 class ValidatorParams(HasSeed):