Support for types in Samza SQL UDF (#911)
authorSrinivasulu Punuru <srinipunuru@users.noreply.github.com>
Mon, 11 Feb 2019 22:08:47 +0000 (14:08 -0800)
committerGitHub <noreply@github.com>
Mon, 11 Feb 2019 22:08:47 +0000 (14:08 -0800)
* Support for types in udf

* Option to disable argument check for dynamic functions

* Added comments

* Update based on review comments

* Adding license

* Added some comments

* fix for test

20 files changed:
samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdf.java
samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdfMethod.java
samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java
samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java
samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java
samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java
samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java
samza-sql/src/main/java/org/apache/samza/sql/impl/ConfigBasedUdfResolver.java
samza-sql/src/main/java/org/apache/samza/sql/interfaces/UdfMetadata.java
samza-sql/src/main/java/org/apache/samza/sql/planner/QueryPlanner.java
samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlUdfOperatorTable.java
samza-sql/src/test/java/org/apache/samza/sql/runner/TestSamzaSqlApplicationConfig.java
samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java
samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java [new file with mode: 0644]
samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java
samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java

index de1821e..6a55c07 100644 (file)
@@ -37,6 +37,11 @@ public @interface SamzaSqlUdf {
   String name();
 
   /**
+   * Description of the UDF
+   */
+  String description();
+
+  /**
    * Whether the UDF is enabled or not.
    */
   boolean enabled() default true;
index 9b1c7ef..cba2749 100644 (file)
@@ -34,6 +34,12 @@ import org.apache.samza.sql.schema.SamzaSqlFieldType;
 public @interface SamzaSqlUdfMethod {
 
   /**
+   * Whether the argument check needs to be disabled. This is useful if the udf takes in
+   * dynamic number of arguments
+   */
+  boolean disableArgumentCheck() default false;
+
+  /**
    * Type of the arguments for the Samza SQL udf method
    */
   SamzaSqlFieldType[] params() default {};
index 3307dc0..ff67487 100644 (file)
@@ -23,8 +23,8 @@ import org.apache.samza.config.Config;
 
 
 /**
- * The base class for the Scalar UDFs. All the scalar UDF classes needs to extend this and implement a method named
- * "execute". The number and type of arguments for the execute method in the UDF class should match the number and type of fields
+ * The base class for the Scalar UDFs. All the scalar UDF classes needs to extend this.
+ * The number and type of arguments for the method annotated with {@link SamzaSqlUdfMethod} in the UDF class should match the number and type of fields
  * used while invoking this UDF in SQL statement.
  * Say for e.g. User creates a UDF class with signature int execute(int var1, String var2). It can be used in a SQL query
  *     select myudf(id, name) from profile
@@ -36,5 +36,4 @@ public interface ScalarUdf {
    * @param udfConfig Config specific to the udf.
    */
   void init(Config udfConfig);
-
 }
index 091ca62..2a8f92e 100644 (file)
@@ -19,7 +19,9 @@
 
 package org.apache.samza.sql.data;
 
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -38,7 +40,10 @@ public class SamzaSqlExecutionContext implements Cloneable {
    * The variables that are shared among all cloned instance of {@link SamzaSqlExecutionContext}
    */
   private final SamzaSqlApplicationConfig sqlConfig;
-  private final Map<String, UdfMetadata> udfMetadata;
+
+  // Maps the UDF name to list of all UDF methods associated with the name.
+  // Since we support polymorphism there can be multiple udfMetadata associated with the single name.
+  private final Map<String, List<UdfMetadata>> udfMetadata;
 
   /**
    * The variable that are not shared among all cloned instance of {@link SamzaSqlExecutionContext}
@@ -52,8 +57,11 @@ public class SamzaSqlExecutionContext implements Cloneable {
 
   public SamzaSqlExecutionContext(SamzaSqlApplicationConfig config) {
     this.sqlConfig = config;
-    udfMetadata =
-        this.sqlConfig.getUdfMetadata().stream().collect(Collectors.toMap(UdfMetadata::getName, Function.identity()));
+    udfMetadata = new HashMap<>();
+    for(UdfMetadata udf : this.sqlConfig.getUdfMetadata()) {
+      udfMetadata.putIfAbsent(udf.getName(), new ArrayList<>());
+      udfMetadata.get(udf.getName()).add(udf);
+    }
   }
 
   public ScalarUdf getOrCreateUdf(String clazz, String udfName) {
@@ -61,7 +69,9 @@ public class SamzaSqlExecutionContext implements Cloneable {
   }
 
   public ScalarUdf createInstance(String clazz, String udfName) {
-    Config udfConfig = udfMetadata.get(udfName).getUdfConfig();
+
+    // Configs should be same for all the UDF methods within a UDF. Hence taking the first one.
+    Config udfConfig = udfMetadata.get(udfName).get(0).getUdfConfig();
     ScalarUdf scalarUdf = ReflectionUtils.createInstance(clazz);
     if (scalarUdf == null) {
       String msg = String.format("Couldn't create udf %s of class %s", udfName, clazz);
index dc928ab..e0c34f1 100644 (file)
@@ -61,13 +61,13 @@ import org.apache.samza.sql.udfs.ScalarUdf;
  * If no args is provided, it returns an empty SamzaSqlRelRecord (with empty field names and values list).
  */
 
-@SamzaSqlUdf(name="BuildOutputRecord")
+@SamzaSqlUdf(name = "BuildOutputRecord", description = "Creates an Output record.")
 public class BuildOutputRecordUdf implements ScalarUdf {
   @Override
   public void init(Config udfConfig) {
   }
 
-  @SamzaSqlUdfMethod
+  @SamzaSqlUdfMethod(disableArgumentCheck = true)
   public SamzaSqlRelRecord execute(Object... args) {
     int numOfArgs = args.length;
     Validate.isTrue(numOfArgs % 2 == 0, "numOfArgs should be an even number");
index dc482d8..659f7e3 100644 (file)
@@ -20,6 +20,7 @@
 package org.apache.samza.sql.fn;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
 import org.apache.samza.sql.udfs.ScalarUdf;
@@ -28,15 +29,15 @@ import org.apache.samza.sql.udfs.ScalarUdf;
 /**
  * UDF that converts an object to it's string representation.
  */
-@SamzaSqlUdf(name = "convertToString")
+@SamzaSqlUdf(name = "convertToString", description = "Converts the object to string.")
 public class ConvertToStringUdf implements ScalarUdf {
   @Override
   public void init(Config udfConfig) {
   }
 
-  @SamzaSqlUdfMethod
-  public String execute(Object... args) {
-    return args[0].toString();
+  @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ANY)
+  public String execute(Object args) {
+    return args.toString();
   }
 }
 
index fa3d15e..0734a3a 100644 (file)
@@ -21,20 +21,20 @@ package org.apache.samza.sql.fn;
 
 import java.util.List;
 import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
 import org.apache.samza.sql.udfs.ScalarUdf;
 
 
-@SamzaSqlUdf(name = "Flatten")
+@SamzaSqlUdf(name = "Flatten", description = "Flattens the array.")
 public class FlattenUdf implements ScalarUdf {
   @Override
   public void init(Config udfConfig) {
   }
 
-  @SamzaSqlUdfMethod
-  public Object execute(Object... arg) {
-    List value = (List) arg[0];
+  @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ARRAY)
+  public Object execute(List value) {
     return value != null && !value.isEmpty() ? value.get(0) : value;
   }
 }
\ No newline at end of file
index 8f5704c..de56fa0 100644 (file)
@@ -24,6 +24,7 @@ import java.util.Map;
 import org.apache.commons.lang.Validate;
 import org.apache.samza.config.Config;
 import org.apache.samza.sql.SamzaSqlRelRecord;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
 import org.apache.samza.sql.udfs.ScalarUdf;
@@ -51,22 +52,21 @@ import org.apache.samza.sql.udfs.ScalarUdf;
  *           - sessionKey (Scalar)
  *
  */
-@SamzaSqlUdf(name = "GetSqlField")
+@SamzaSqlUdf(name = "GetSqlField", description = "Get an element from complex Sql field as a String.")
 public class GetSqlFieldUdf implements ScalarUdf {
   @Override
   public void init(Config udfConfig) {
   }
 
-  @SamzaSqlUdfMethod
-  public String execute(Object... args) {
-    Object currentFieldOrValue = args[0];
+  @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.ANY, SamzaSqlFieldType.STRING})
+  public String execute(Object field, String fieldName) {
+    Object currentFieldOrValue = field;
     Validate.isTrue(currentFieldOrValue == null
         || currentFieldOrValue instanceof SamzaSqlRelRecord);
-    if (currentFieldOrValue != null && args.length > 1) {
-      String[] fieldNameChain = ((String) args[1]).split("\\.");
-      for (int i = 0; i < fieldNameChain.length && currentFieldOrValue != null; i++) {
-        currentFieldOrValue = extractField(fieldNameChain[i], currentFieldOrValue);
-      }
+
+    String[] fieldNameChain = fieldName.split("\\.");
+    for (int i = 0; i < fieldNameChain.length && currentFieldOrValue != null; i++) {
+      currentFieldOrValue = extractField(fieldNameChain[i], currentFieldOrValue);
     }
 
     if (currentFieldOrValue != null) {
index 00b5775..c157112 100644 (file)
@@ -21,6 +21,7 @@ package org.apache.samza.sql.fn;
 
 import java.util.regex.Pattern;
 import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
 import org.apache.samza.sql.udfs.ScalarUdf;
@@ -29,15 +30,15 @@ import org.apache.samza.sql.udfs.ScalarUdf;
 /**
  * Simple RegexMatch Udf.
  */
-@SamzaSqlUdf(name="RegexMatch")
+@SamzaSqlUdf(name="RegexMatch", description = "Function to perform the regex match.")
 public class RegexMatchUdf implements ScalarUdf {
   @Override
   public void init(Config config) {
 
   }
 
-  @SamzaSqlUdfMethod
-  public Boolean execute(Object... args) {
-    return Pattern.matches((String) args[0], (String) args[1]);
+  @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.STRING, SamzaSqlFieldType.STRING})
+  public Boolean match(String regexPattern, String input) {
+    return Pattern.matches(regexPattern, input);
   }
 }
index d21c1a6..1319a85 100644 (file)
@@ -23,7 +23,9 @@ import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Properties;
 import java.util.stream.Collectors;
 import org.apache.commons.lang.StringUtils;
@@ -31,6 +33,7 @@ import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.sql.interfaces.UdfMetadata;
 import org.apache.samza.sql.interfaces.UdfResolver;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
 import org.apache.samza.sql.udfs.ScalarUdf;
@@ -74,16 +77,15 @@ public class ConfigBasedUdfResolver implements UdfResolver {
       }
 
       SamzaSqlUdf sqlUdf;
+      Map<SamzaSqlUdfMethod, Method> udfMethods = new HashMap<>();
       SamzaSqlUdfMethod sqlUdfMethod = null;
-      Method udfMethod = null;
 
       sqlUdf = udfClass.getAnnotation(SamzaSqlUdf.class);
       Method[] methods = udfClass.getMethods();
       for (Method method : methods) {
         sqlUdfMethod = method.getAnnotation(SamzaSqlUdfMethod.class);
         if (sqlUdfMethod != null) {
-          udfMethod = method;
-          break;
+          udfMethods.put(sqlUdfMethod, method);
         }
       }
 
@@ -93,7 +95,7 @@ public class ConfigBasedUdfResolver implements UdfResolver {
         throw new SamzaException(msg);
       }
 
-      if (sqlUdfMethod == null) {
+      if (udfMethods.isEmpty()) {
         String msg = String.format("UdfClass %s doesn't have any methods annotated with SamzaSqlUdfMethod", udfClass);
         LOG.error(msg);
         throw new SamzaException(msg);
@@ -101,7 +103,11 @@ public class ConfigBasedUdfResolver implements UdfResolver {
 
       if (sqlUdf.enabled()) {
         String udfName = sqlUdf.name();
-        udfs.add(new UdfMetadata(udfName, udfMethod, udfConfig.subset(udfName + ".")));
+        for (Map.Entry<SamzaSqlUdfMethod, Method> udfMethod : udfMethods.entrySet()) {
+          List<SamzaSqlFieldType> params = Arrays.asList(udfMethod.getKey().params());
+          udfs.add(new UdfMetadata(udfName, udfMethod.getValue(), udfConfig.subset(udfName + "."), params,
+              udfMethod.getKey().disableArgumentCheck()));
+        }
       }
     }
   }
index b1a2d6d..4adb5ea 100644 (file)
@@ -21,7 +21,9 @@ package org.apache.samza.sql.interfaces;
 
 import java.lang.reflect.Method;
 
+import java.util.List;
 import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
 
 
 /**
@@ -30,15 +32,18 @@ import org.apache.samza.config.Config;
 public class UdfMetadata {
 
   private final String name;
-
   private final Method udfMethod;
-
   private final Config udfConfig;
+  private final boolean disableArgCheck;
+  private final List<SamzaSqlFieldType> arguments;
 
-  public UdfMetadata(String name, Method udfMethod, Config udfConfig) {
+  public UdfMetadata(String name, Method udfMethod, Config udfConfig, List<SamzaSqlFieldType> arguments,
+      boolean disableArgCheck) {
     this.name = name;
     this.udfMethod = udfMethod;
     this.udfConfig = udfConfig;
+    this.arguments = arguments;
+    this.disableArgCheck = disableArgCheck;
   }
 
   public Config getUdfConfig() {
@@ -58,4 +63,19 @@ public class UdfMetadata {
   public String getName() {
     return name;
   }
+
+  /**
+   * @return Returns the list of arguments that the udf should take.
+   */
+  public List<SamzaSqlFieldType> getArguments() {
+    return arguments;
+  }
+
+  /**
+   * @return Returns whether the argument check needs to be disabled.
+   */
+  public boolean isDisableArgCheck() {
+    return disableArgCheck;
+  }
+
 }
index b860b20..8bccc2e 100644 (file)
@@ -113,7 +113,7 @@ public class QueryPlanner {
       }
 
       List<SamzaSqlScalarFunctionImpl> samzaSqlFunctions = udfMetadata.stream()
-          .map(x -> new SamzaSqlScalarFunctionImpl(x.getName(), x.getUdfMethod()))
+          .map(x -> new SamzaSqlScalarFunctionImpl(x))
           .collect(Collectors.toList());
 
       final List<RelTraitDef> traitDefs = new ArrayList<>();
index 6894c86..c5d0121 100644 (file)
@@ -35,22 +35,27 @@ import org.apache.calcite.schema.ImplementableFunction;
 import org.apache.calcite.schema.ScalarFunction;
 import org.apache.calcite.schema.impl.ScalarFunctionImpl;
 import org.apache.samza.sql.data.SamzaSqlExecutionContext;
+import org.apache.samza.sql.interfaces.UdfMetadata;
 import org.apache.samza.sql.udfs.ScalarUdf;
 
-
+/**
+ * Calcite implementation for Samza SQL UDF.
+ * This class contains logic to generate the java code to execute {@link org.apache.samza.sql.udfs.SamzaSqlUdf}.
+ */
 public class SamzaSqlScalarFunctionImpl implements ScalarFunction, ImplementableFunction {
 
   private final ScalarFunction myIncFunction;
   private final Method udfMethod;
   private final Method getUdfMethod;
-
-
   private final String udfName;
+  private final UdfMetadata udfMetadata;
+
+  public SamzaSqlScalarFunctionImpl(UdfMetadata udfMetadata) {
 
-  public SamzaSqlScalarFunctionImpl(String udfName, Method udfMethod) {
-    myIncFunction = ScalarFunctionImpl.create(udfMethod);
-    this.udfName = udfName;
-    this.udfMethod = udfMethod;
+    myIncFunction = ScalarFunctionImpl.create(udfMetadata.getUdfMethod());
+    this.udfMetadata = udfMetadata;
+    this.udfName = udfMetadata.getName();
+    this.udfMethod = udfMetadata.getUdfMethod();
     this.getUdfMethod = Arrays.stream(SamzaSqlExecutionContext.class.getMethods())
         .filter(x -> x.getName().equals("getOrCreateUdf"))
         .findFirst()
@@ -61,17 +66,23 @@ public class SamzaSqlScalarFunctionImpl implements ScalarFunction, Implementable
     return udfName;
   }
 
+  public int numberOfArguments() {
+    return udfMetadata.getArguments().size();
+  }
+
+  public UdfMetadata getUdfMetadata() {
+    return udfMetadata;
+  }
+
   @Override
   public CallImplementor getImplementor() {
     return RexImpTable.createImplementor((translator, call, translatedOperands) -> {
       final Expression context = Expressions.parameter(SamzaSqlExecutionContext.class, "context");
       final Expression getUdfInstance = Expressions.call(ScalarUdf.class, context, getUdfMethod,
           Expressions.constant(udfMethod.getDeclaringClass().getName()), Expressions.constant(udfName));
-      final Expression callExpression = Expressions.convert_(Expressions.call(Expressions.convert_(getUdfInstance, udfMethod.getDeclaringClass()), udfMethod,
-          translatedOperands), Object.class);
-      // The Janino compiler which is used to compile the expressions doesn't seem to understand the Type of the ScalarUdf.execute
-      // because it is a generic. To work around that we are explicitly casting it to the return type.
-      return Expressions.convert_(callExpression, udfMethod.getReturnType());
+      final Expression callExpression = Expressions.call(Expressions.convert_(getUdfInstance, udfMethod.getDeclaringClass()), udfMethod,
+          translatedOperands);
+      return callExpression;
     }, NullPolicy.NONE, false);
   }
 
index 476e9b0..6ee10f8 100644 (file)
@@ -1,27 +1,26 @@
 /*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-*
-*   http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
 
 package org.apache.samza.sql.planner;
 
 import java.util.List;
 import java.util.stream.Collectors;
-
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlIdentifier;
 import org.apache.calcite.sql.SqlOperator;
@@ -30,6 +29,7 @@ import org.apache.calcite.sql.SqlSyntax;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.util.ListSqlOperatorTable;
 import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
+import org.apache.samza.sql.interfaces.UdfMetadata;
 
 
 public class SamzaSqlUdfOperatorTable implements SqlOperatorTable {
@@ -45,8 +45,18 @@ public class SamzaSqlUdfOperatorTable implements SqlOperatorTable {
   }
 
   private SqlOperator getSqlOperator(SamzaSqlScalarFunctionImpl scalarFunction) {
-    return new SqlUserDefinedFunction(new SqlIdentifier(scalarFunction.getUdfName(), SqlParserPos.ZERO),
-        o -> scalarFunction.getReturnType(o.getTypeFactory()), null, Checker.ANY_CHECKER, null, scalarFunction);
+    int numArguments = scalarFunction.numberOfArguments();
+    UdfMetadata udfMetadata = scalarFunction.getUdfMetadata();
+
+    if(udfMetadata.isDisableArgCheck()) {
+      return new SqlUserDefinedFunction(new SqlIdentifier(scalarFunction.getUdfName(), SqlParserPos.ZERO),
+          o -> scalarFunction.getReturnType(o.getTypeFactory()), null, Checker.ANY_CHECKER,
+          null, scalarFunction);
+    } else {
+      return new SqlUserDefinedFunction(new SqlIdentifier(scalarFunction.getUdfName(), SqlParserPos.ZERO),
+          o -> scalarFunction.getReturnType(o.getTypeFactory()), null, Checker.getChecker(numArguments, numArguments),
+          null, scalarFunction);
+    }
   }
 
   @Override
index 8d2c588..c6fb357 100644 (file)
@@ -55,7 +55,8 @@ public class TestSamzaSqlApplicationConfig {
             .collect(Collectors.toList()),
         queryInfo.stream().map(SamzaSqlQueryParser.QueryInfo::getSink).collect(Collectors.toList()));
 
-    Assert.assertEquals(numUdfs, samzaSqlApplicationConfig.getUdfMetadata().size());
+    // Two of the UDFs has an overload, hence + 1.
+    Assert.assertEquals(numUdfs + 2, samzaSqlApplicationConfig.getUdfMetadata().size());
     Assert.assertEquals(1, samzaSqlApplicationConfig.getInputSystemStreamConfigBySource().size());
     Assert.assertEquals(1, samzaSqlApplicationConfig.getOutputSystemStreamConfigsBySource().size());
   }
index c71813b..7f6ee50 100644 (file)
@@ -23,20 +23,20 @@ import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
 import org.apache.samza.sql.udfs.ScalarUdf;
 
 
-@SamzaSqlUdf(name = "MyTestArray")
+@SamzaSqlUdf(name = "MyTestArray", description = "Test udf that returns an array")
 public class MyTestArrayUdf implements ScalarUdf {
   @Override
   public void init(Config udfConfig) {
   }
 
-  @SamzaSqlUdfMethod
-  public List<String> execute(Object... args) {
-    Integer value = (Integer) args[0];
+  @SamzaSqlUdfMethod(params = SamzaSqlFieldType.INT32)
+  public List<String> execute(Integer value) {
     return IntStream.range(0, value).mapToObj(String::valueOf).collect(Collectors.toList());
   }
 }
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java
new file mode 100644 (file)
index 0000000..f4afbd6
--- /dev/null
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.sql.util;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
+import org.apache.samza.sql.udfs.SamzaSqlUdf;
+import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
+import org.apache.samza.sql.udfs.ScalarUdf;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * UDF to test polymorphism.
+ */
+@SamzaSqlUdf(name = "MyTestPoly", description = "Test Polymorphism UDF.")
+public class MyTestPolyUdf implements ScalarUdf {
+  private static final Logger LOG = LoggerFactory.getLogger(MyTestPolyUdf.class);
+
+  @SamzaSqlUdfMethod(params = SamzaSqlFieldType.INT32)
+  public Integer execute(Integer value) {
+    return value * 2;
+  }
+
+  @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ANY)
+  public Integer execute(String value) {
+    return value.length() * 2;
+  }
+
+
+  @Override
+  public void init(Config udfConfig) {
+    LOG.info("Init called with {}", udfConfig);
+  }
+}
index 6b714b4..35a44e3 100644 (file)
@@ -20,6 +20,7 @@
 package org.apache.samza.sql.util;
 
 import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
 import org.apache.samza.sql.udfs.SamzaSqlUdf;
 import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
 import org.apache.samza.sql.udfs.ScalarUdf;
@@ -30,16 +31,22 @@ import org.slf4j.LoggerFactory;
 /**
  * Test UDF used by unit and integration tests.
  */
-@SamzaSqlUdf(name = "MyTest")
+@SamzaSqlUdf(name = "MyTest", description = "Test UDF.")
 public class MyTestUdf implements ScalarUdf {
 
   private static final Logger LOG = LoggerFactory.getLogger(MyTestUdf.class);
 
-  @SamzaSqlUdfMethod
-  public Integer execute(Object... value) {
-    return ((Integer) value[0]) * 2;
+  @SamzaSqlUdfMethod(params = SamzaSqlFieldType.INT32)
+  public Integer execute(Integer value) {
+    return value * 2;
   }
 
+  @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ANY)
+  public Integer execute(Object value) {
+    return ((Integer) value) * 2;
+  }
+
+
   @Override
   public void init(Config udfConfig) {
     LOG.info("Init called with {}", udfConfig);
index 19a8638..627dc65 100644 (file)
@@ -96,7 +96,7 @@ public class SamzaSqlTestConfig {
         ConfigBasedUdfResolver.class.getName());
     staticConfigs.put(configUdfResolverDomain + ConfigBasedUdfResolver.CFG_UDF_CLASSES, Joiner.on(",")
         .join(MyTestUdf.class.getName(), RegexMatchUdf.class.getName(), FlattenUdf.class.getName(),
-            MyTestArrayUdf.class.getName(), BuildOutputRecordUdf.class.getName()));
+            MyTestArrayUdf.class.getName(), BuildOutputRecordUdf.class.getName(), MyTestPolyUdf.class.getName()));
 
     String avroSystemConfigPrefix =
         String.format(ConfigBasedIOResolverFactory.CFG_FMT_SAMZA_PREFIX, SAMZA_SYSTEM_TEST_AVRO);
index 76119e4..9deb561 100644 (file)
@@ -369,6 +369,30 @@ public class TestSamzaSqlEndToEnd extends SamzaSqlIntegrationTestHarness {
   }
 
   @Test
+  public void testEndToEndUdfPolymorphism() throws Exception {
+    int numMessages = 20;
+    TestAvroSystemFactory.messages.clear();
+    Map<String, String> staticConfigs = SamzaSqlTestConfig.fetchStaticConfigsWithFactories(configs, numMessages);
+    String sql1 = "Insert into testavro.outputTopic(id, long_value) "
+        + "select MyTestPoly(id) as long_value, MyTestPoly(name) as id from testavro.SIMPLE1";
+    List<String> sqlStmts = Collections.singletonList(sql1);
+    staticConfigs.put(SamzaSqlApplicationConfig.CFG_SQL_STMTS_JSON, JsonUtil.toJson(sqlStmts));
+    runApplication(new MapConfig(staticConfigs));
+
+    LOG.info("output Messages " + TestAvroSystemFactory.messages);
+
+    List<Integer> outMessages = TestAvroSystemFactory.messages.stream()
+        .map(x -> Integer.valueOf(((GenericRecord) x.getMessage()).get("long_value").toString()))
+        .sorted()
+        .collect(Collectors.toList());
+    Assert.assertEquals(outMessages.size(), numMessages);
+    MyTestUdf udf = new MyTestUdf();
+
+    Assert.assertTrue(
+        IntStream.range(0, numMessages).map(udf::execute).boxed().collect(Collectors.toList()).equals(outMessages));
+  }
+
+  @Test
   public void testRegexMatchUdfInWhereClause() throws Exception {
     int numMessages = 20;
     TestAvroSystemFactory.messages.clear();