[SPARK-24537][R] Add array_remove / array_zip / map_from_arrays / array_distinct
authorHuaxin Gao <huaxing@us.ibm.com>
Fri, 13 Jul 2018 02:40:58 +0000 (10:40 +0800)
committerhyukjinkwon <gurwls223@apache.org>
Fri, 13 Jul 2018 02:40:58 +0000 (10:40 +0800)
## What changes were proposed in this pull request?
Add array_remove / array_zip / map_from_arrays / array_distinct functions in SparkR.

## How was this patch tested?
Add tests in test_sparkSQL.R

Author: Huaxin Gao <huaxing@us.ibm.com>

Closes #21645 from huaxingao/spark-24537.

R/pkg/NAMESPACE
R/pkg/R/functions.R
R/pkg/R/generics.R
R/pkg/tests/fulltests/test_sparkSQL.R

index 9696f69..adfd387 100644 (file)
@@ -201,13 +201,16 @@ exportMethods("%<=>%",
               "approxCountDistinct",
               "approxQuantile",
               "array_contains",
+              "array_distinct",
               "array_join",
               "array_max",
               "array_min",
               "array_position",
+              "array_remove",
               "array_repeat",
               "array_sort",
               "arrays_overlap",
+              "arrays_zip",
               "asc",
               "ascii",
               "asin",
@@ -306,6 +309,7 @@ exportMethods("%<=>%",
               "lpad",
               "ltrim",
               "map_entries",
+              "map_from_arrays",
               "map_keys",
               "map_values",
               "max",
index 3bff633..2929a00 100644 (file)
@@ -194,10 +194,12 @@ NULL
 #'          \itemize{
 #'          \item \code{array_contains}: a value to be checked if contained in the column.
 #'          \item \code{array_position}: a value to locate in the given array.
+#'          \item \code{array_remove}: a value to remove in the given array.
 #'          }
 #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains
 #'            additional named properties to control how it is converted, accepts the same
-#'            options as the JSON data source.
+#'            options as the JSON data source.  In \code{arrays_zip}, this contains additional
+#'            Columns of arrays to be merged.
 #' @name column_collection_functions
 #' @rdname column_collection_functions
 #' @family collection functions
@@ -207,9 +209,9 @@ NULL
 #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
 #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
 #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
-#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1)))
+#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1)))
 #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
-#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1)))
+#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21)))
 #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
 #' head(tmp2)
 #' head(select(tmp, posexplode(tmp$v1)))
@@ -221,6 +223,7 @@ NULL
 #' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
 #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
 #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
+#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5)))
 #' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))
 #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model))
 #' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))}
@@ -1978,7 +1981,7 @@ setMethod("levenshtein", signature(y = "Column"),
           })
 
 #' @details
-#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. 
+#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}.
 #' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x}
 #' are on the same day of month, or both are the last day of month, time of day will be ignored.
 #' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits.
@@ -3009,6 +3012,19 @@ setMethod("array_contains",
           })
 
 #' @details
+#' \code{array_distinct}: Removes duplicate values from the array.
+#'
+#' @rdname column_collection_functions
+#' @aliases array_distinct array_distinct,Column-method
+#' @note array_distinct since 2.4.0
+setMethod("array_distinct",
+          signature(x = "Column"),
+          function(x) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "array_distinct", x@jc)
+            column(jc)
+          })
+
+#' @details
 #' \code{array_join}: Concatenates the elements of column using the delimiter.
 #' Null values are replaced with nullReplacement if set, otherwise they are ignored.
 #'
@@ -3072,6 +3088,19 @@ setMethod("array_position",
           })
 
 #' @details
+#' \code{array_remove}: Removes all elements that equal to element from the given array.
+#'
+#' @rdname column_collection_functions
+#' @aliases array_remove array_remove,Column-method
+#' @note array_remove since 2.4.0
+setMethod("array_remove",
+          signature(x = "Column", value = "ANY"),
+          function(x, value) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "array_remove", x@jc, value)
+            column(jc)
+          })
+
+#' @details
 #' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times
 #' given by \code{count}.
 #'
@@ -3121,6 +3150,24 @@ setMethod("arrays_overlap",
           })
 
 #' @details
+#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th
+#' values of input arrays.
+#'
+#' @rdname column_collection_functions
+#' @aliases arrays_zip arrays_zip,Column-method
+#' @note arrays_zip since 2.4.0
+setMethod("arrays_zip",
+          signature(x = "Column"),
+          function(x, ...) {
+            jcols <- lapply(list(x, ...), function(arg) {
+              stopifnot(class(arg) == "Column")
+              arg@jc
+            })
+            jc <- callJStatic("org.apache.spark.sql.functions", "arrays_zip", jcols)
+            column(jc)
+          })
+
+#' @details
 #' \code{flatten}: Creates a single array from an array of arrays.
 #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed.
 #'
@@ -3148,6 +3195,21 @@ setMethod("map_entries",
          })
 
 #' @details
+#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for
+#' keys. The array in the second column is used for values. All elements in the array for key
+#' should not be null.
+#'
+#' @rdname column_collection_functions
+#' @aliases map_from_arrays map_from_arrays,Column-method
+#' @note map_from_arrays since 2.4.0
+setMethod("map_from_arrays",
+          signature(x = "Column", y = "Column"),
+          function(x, y) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "map_from_arrays", x@jc, y@jc)
+            column(jc)
+         })
+
+#' @details
 #' \code{map_keys}: Returns an unordered array containing the keys of the map.
 #'
 #' @rdname column_collection_functions
index 9321bba..4a7210b 100644 (file)
@@ -759,6 +759,10 @@ setGeneric("array_contains", function(x, value) { standardGeneric("array_contain
 
 #' @rdname column_collection_functions
 #' @name NULL
+setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") })
+
+#' @rdname column_collection_functions
+#' @name NULL
 setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") })
 
 #' @rdname column_collection_functions
@@ -775,6 +779,10 @@ setGeneric("array_position", function(x, value) { standardGeneric("array_positio
 
 #' @rdname column_collection_functions
 #' @name NULL
+setGeneric("array_remove", function(x, value) { standardGeneric("array_remove") })
+
+#' @rdname column_collection_functions
+#' @name NULL
 setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") })
 
 #' @rdname column_collection_functions
@@ -785,6 +793,10 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") })
 #' @name NULL
 setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") })
 
+#' @rdname column_collection_functions
+#' @name NULL
+setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") })
+
 #' @rdname column_string_functions
 #' @name NULL
 setGeneric("ascii", function(x) { standardGeneric("ascii") })
@@ -1052,6 +1064,10 @@ setGeneric("map_entries", function(x) { standardGeneric("map_entries") })
 
 #' @rdname column_collection_functions
 #' @name NULL
+setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") })
+
+#' @rdname column_collection_functions
+#' @name NULL
 setGeneric("map_keys", function(x) { standardGeneric("map_keys") })
 
 #' @rdname column_collection_functions
index 36e0f78..adcbbff 100644 (file)
@@ -1503,6 +1503,27 @@ test_that("column functions", {
   result <- collect(select(df2, reverse(df2[[1]])))[[1]]
   expect_equal(result, "cba")
 
+  # Test array_distinct() and array_remove()
+  df <- createDataFrame(list(list(list(1L, 2L, 3L, 1L, 2L)), list(list(6L, 5L, 5L, 4L, 6L))))
+  result <- collect(select(df, array_distinct(df[[1]])))[[1]]
+  expect_equal(result, list(list(1L, 2L, 3L), list(6L, 5L, 4L)))
+
+  result <- collect(select(df, array_remove(df[[1]], 2L)))[[1]]
+  expect_equal(result, list(list(1L, 3L, 1L), list(6L, 5L, 5L, 4L, 6L)))
+
+  # Test arrays_zip()
+  df <- createDataFrame(list(list(list(1L, 2L), list(3L, 4L))), schema = c("c1", "c2"))
+  result <- collect(select(df, arrays_zip(df[[1]], df[[2]])))[[1]]
+  expected_entries <-  list(listToStruct(list(c1 = 1L, c2 = 3L)),
+                            listToStruct(list(c1 = 2L, c2 = 4L)))
+  expect_equal(result, list(expected_entries))
+
+  # Test map_from_arrays()
+  df <- createDataFrame(list(list(list("x", "y"), list(1, 2))), schema = c("k", "v"))
+  result <- collect(select(df, map_from_arrays(df$k, df$v)))[[1]]
+  expected_entries <- list(as.environment(list(x = 1, y = 2)))
+  expect_equal(result, expected_entries)
+
   # Test array_repeat()
   df <- createDataFrame(list(list("a", 3L), list("b", 2L)))
   result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]]