[SPARK-44131][SQL][PYTHON][CONNECT][FOLLOWUP] Support qualified function name for call_function

### What changes were proposed in this pull request?
apache#41687 added `call_function` and deprecated `call_udf` in the Scala API.

Sometimes the function name is qualified; we should let users pass qualified names to `call_function` so it can invoke persistent functions as well.
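
For example, with a persistent function `custom_avg` created in the `default` database (a minimal sketch; the DataFrame `df` and the function itself come from this PR's tests):

```scala
import org.apache.spark.sql.functions.{call_function, col}

// Unqualified: resolves built-in, temporary, or current-database functions.
df.select(call_function("custom_avg", col("id")))

// Qualified: explicitly targets the persistent function by database (and catalog).
df.select(call_function("default.custom_avg", col("id")))
df.select(call_function("spark_catalog.default.custom_avg", col("id")))
```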

### Why are the changes needed?
Support qualified function name for `call_function`.

### Does this PR introduce _any_ user-facing change?
'No'.
New feature.

### How was this patch tested?
New test cases.

Closes apache#41932 from beliefer/SPARK-44131_followup.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
beliefer authored and cloud-fan committed Jul 25, 2023
1 parent a6ac63d commit d97a4e2
Showing 14 changed files with 238 additions and 76 deletions.
@@ -7923,15 +7923,19 @@ object functions {
def call_udf(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*)

/**
* Call a builtin or temp function.
* Call a SQL function.
*
* @param funcName
* function name
* function name that follows the SQL identifier syntax (can be quoted, can be qualified)
* @param cols
* the expression parameters of function
* @since 3.5.0
*/
@scala.annotation.varargs
def call_function(funcName: String, cols: Column*): Column = Column.fn(funcName, cols: _*)
def call_function(funcName: String, cols: Column*): Column = Column { builder =>
builder.getCallFunctionBuilder
.setFunctionName(funcName)
.addAllArguments(cols.map(_.expr).asJava)
}

}
@@ -239,4 +239,14 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
val output = runCommandsInShell(input)
assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
}

test("call_function") {
val input = """
|val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
|spark.udf.register("simpleUDF", (v: Int) => v * v)
|df.select($"id", call_function("simpleUDF", $"value")).collect()
""".stripMargin
val output = runCommandsInShell(input)
assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
}
}
@@ -46,6 +46,7 @@ message Expression {
UpdateFields update_fields = 13;
UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
CallFunction call_function = 16;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// relations they can add them here. During the planning the correct resolution is done.
@@ -371,3 +372,11 @@ message JavaUDF {
// (Required) Indicate if the Java user-defined function is an aggregate function
bool aggregate = 3;
}

message CallFunction {
// (Required) Unparsed name of the SQL function.
string function_name = 1;

// (Optional) Function arguments. Empty arguments are allowed.
repeated Expression arguments = 2;
}
@@ -12,7 +12,7 @@
}
},
"expressions": [{
"unresolvedFunction": {
"callFunction": {
"functionName": "lower",
"arguments": [{
"unresolvedAttribute": {
Binary file not shown.
@@ -1380,6 +1380,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
transformExpressionPlugin(exp.getExtension)
case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION =>
transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction)
case proto.Expression.ExprTypeCase.CALL_FUNCTION =>
transformCallFunction(exp.getCallFunction)
case _ =>
throw InvalidPlanInput(
s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -1484,6 +1486,23 @@
}
}

/**
* Translates a SQL function from proto to the Catalyst expression.
*
* @param fun
* Proto representation of the function call.
* @return
* Expression.
*/
private def transformCallFunction(fun: proto.CallFunction): Expression = {
val funcName = fun.getFunctionName
val nameParts = session.sessionState.sqlParser.parseMultipartIdentifier(funcName)
UnresolvedFunction(
nameParts,
fun.getArgumentsList.asScala.map(transformExpression).toSeq,
false)
}

private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
Utils.deserialize[UdfPacket](
fun.getScalarScalaUdf.getPayload.toByteArray,
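
The call to `parseMultipartIdentifier` above is what makes quoted and qualified names work. A minimal sketch of the splitting behavior, mirroring the new test cases (the parser is Spark-internal API, reached here the same way the planner reaches it):

```scala
val parser = session.sessionState.sqlParser

parser.parseMultipartIdentifier("custom_avg")
// Seq("custom_avg") -- one part: built-in, temp, or current-database function

parser.parseMultipartIdentifier("spark_catalog.default.custom_avg")
// Seq("spark_catalog", "default", "custom_avg") -- fully qualified persistent function

parser.parseMultipartIdentifier("`default.custom_func`")
// Seq("default.custom_func") -- backquotes keep a dotted name as a single part
```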
24 changes: 24 additions & 0 deletions python/pyspark/sql/connect/expressions.py
@@ -1027,3 +1027,27 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:

def __repr__(self) -> str:
return "DistributedSequenceID()"


class CallFunction(Expression):
def __init__(self, name: str, args: Sequence["Expression"]):
super().__init__()

assert isinstance(name, str)
self._name = name

assert isinstance(args, list) and all(isinstance(arg, Expression) for arg in args)
self._args = args

def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
expr.call_function.function_name = self._name
if len(self._args) > 0:
expr.call_function.arguments.extend([arg.to_plan(session) for arg in self._args])
return expr

def __repr__(self) -> str:
if len(self._args) > 0:
return f"CallFunction('{self._name}', {', '.join([str(arg) for arg in self._args])})"
else:
return f"CallFunction('{self._name}')"
6 changes: 4 additions & 2 deletions python/pyspark/sql/connect/functions.py
@@ -51,6 +51,7 @@
SQLExpression,
LambdaFunction,
UnresolvedNamedLambdaVariable,
CallFunction,
)
from pyspark.sql.connect.udf import _create_py_udf
from pyspark.sql.connect.udtf import AnalyzeArgument, AnalyzeResult # noqa: F401
@@ -3910,8 +3911,9 @@ def udtf(
udtf.__doc__ = pysparkfuncs.udtf.__doc__


def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
return _invoke_function(udfName, *[_to_col(c) for c in cols])
def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
expressions = [_to_col(c)._expr for c in cols]
return Column(CallFunction(funcName, expressions))


call_function.__doc__ = pysparkfuncs.call_function.__doc__
118 changes: 60 additions & 58 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1101,6 +1101,7 @@ class Expression(google.protobuf.message.Message):
UPDATE_FIELDS_FIELD_NUMBER: builtins.int
UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
CALL_FUNCTION_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def literal(self) -> global___Expression.Literal: ...
@@ -1135,6 +1136,8 @@
@property
def common_inline_user_defined_function(self) -> global___CommonInlineUserDefinedFunction: ...
@property
def call_function(self) -> global___CallFunction: ...
@property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins generate arbitrary
relations they can add them here. During the planning the correct resolution is done.
@@ -1158,13 +1161,16 @@
unresolved_named_lambda_variable: global___Expression.UnresolvedNamedLambdaVariable
| None = ...,
common_inline_user_defined_function: global___CommonInlineUserDefinedFunction | None = ...,
call_function: global___CallFunction | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"alias",
b"alias",
"call_function",
b"call_function",
"cast",
b"cast",
"common_inline_user_defined_function",
@@ -1204,6 +1210,8 @@
field_name: typing_extensions.Literal[
"alias",
b"alias",
"call_function",
b"call_function",
"cast",
b"cast",
"common_inline_user_defined_function",
@@ -1256,6 +1264,7 @@
"update_fields",
"unresolved_named_lambda_variable",
"common_inline_user_defined_function",
"call_function",
"extension",
] | None: ...

@@ -1469,3 +1478,30 @@
) -> typing_extensions.Literal["output_type"] | None: ...

global___JavaUDF = JavaUDF

class CallFunction(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

FUNCTION_NAME_FIELD_NUMBER: builtins.int
ARGUMENTS_FIELD_NUMBER: builtins.int
function_name: builtins.str
"""(Required) Unparsed name of the SQL function."""
@property
def arguments(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]:
"""(Optional) Function arguments. Empty arguments are allowed."""
def __init__(
self,
*,
function_name: builtins.str = ...,
arguments: collections.abc.Iterable[global___Expression] | None = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"arguments", b"arguments", "function_name", b"function_name"
],
) -> None: ...

global___CallFunction = CallFunction
23 changes: 18 additions & 5 deletions python/pyspark/sql/functions.py
@@ -14396,16 +14396,16 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:


@try_remote_functions
def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
"""
Call a builtin or temp function.
Call a SQL function.
.. versionadded:: 3.5.0
Parameters
----------
udfName : str
name of the function
funcName : str
function name that follows the SQL identifier syntax (can be quoted, can be qualified)
cols : :class:`~pyspark.sql.Column` or str
column names or :class:`~pyspark.sql.Column`\\s to be used in the function
@@ -14443,9 +14443,22 @@ def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
+-------+
| 2.0|
+-------+
>>> _ = spark.sql("CREATE FUNCTION custom_avg AS 'test.org.apache.spark.sql.MyDoubleAvg'")
>>> df.select(call_function("custom_avg", col("id"))).show()
+------------------------------------+
|spark_catalog.default.custom_avg(id)|
+------------------------------------+
| 102.0|
+------------------------------------+
>>> df.select(call_function("spark_catalog.default.custom_avg", col("id"))).show()
+------------------------------------+
|spark_catalog.default.custom_avg(id)|
+------------------------------------+
| 102.0|
+------------------------------------+
"""
sc = get_active_spark_context()
return _invoke_function("call_function", udfName, _to_seq(sc, cols, _to_java_column))
return _invoke_function("call_function", funcName, _to_seq(sc, cols, _to_java_column))


@try_remote_functions
22 changes: 16 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -8338,7 +8338,7 @@ object functions {
@scala.annotation.varargs
@deprecated("Use call_udf")
def callUDF(udfName: String, cols: Column*): Column =
call_function(udfName, cols: _*)
call_function(Seq(udfName), cols: _*)

/**
* Call a user-defined function.
@@ -8357,18 +8357,28 @@
*/
@scala.annotation.varargs
def call_udf(udfName: String, cols: Column*): Column =
call_function(udfName, cols: _*)
call_function(Seq(udfName), cols: _*)

/**
* Call a builtin or temp function.
* Call a SQL function.
*
* @param funcName function name
* @param funcName function name that follows the SQL identifier syntax
* (can be quoted, can be qualified)
* @param cols the expression parameters of function
* @since 3.5.0
*/
@scala.annotation.varargs
def call_function(funcName: String, cols: Column*): Column =
withExpr { UnresolvedFunction(funcName, cols.map(_.expr), false) }
def call_function(funcName: String, cols: Column*): Column = {
val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
new SparkSqlParser()
}
val nameParts = parser.parseMultipartIdentifier(funcName)
call_function(nameParts, cols: _*)
}

private def call_function(nameParts: Seq[String], cols: Column*): Column = withExpr {
UnresolvedFunction(nameParts, cols.map(_.expr), false)
}

/**
* Unwrap UDT data type column into its underlying type.
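
Note the split above: the public `call_function` parses the name, while the private overload takes already-split name parts, and the deprecated `callUDF`/`call_udf` go straight to the private overload with `Seq(udfName)`. Legacy names are therefore never parsed, so a temp function registered with dots in its name keeps resolving; with `call_function` such a name must be backquoted, as the new tests show. A minimal sketch (assuming a DataFrame `df` with an integer column `a`):

```scala
import org.apache.spark.sql.functions.{call_function, call_udf, col}

spark.udf.register("default.custom_func", (i: Int) => i + 2)

// call_udf takes the name verbatim as a single identifier part.
df.select(call_udf("default.custom_func", col("a")))

// call_function parses the name: without backquotes, "default.custom_func"
// would be looked up as function custom_func in database default.
df.select(call_function("`default.custom_func`", col("a")))
```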
@@ -5918,6 +5918,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {

test("call_function") {
checkAnswer(testData2.select(call_function("avg", $"a")), testData2.selectExpr("avg(a)"))

withUserDefinedFunction("custom_func" -> true, "custom_sum" -> false) {
spark.udf.register("custom_func", (i: Int) => { i + 2 })
checkAnswer(
testData2.select(call_function("custom_func", $"a")),
Seq(Row(3), Row(3), Row(4), Row(4), Row(5), Row(5)))
spark.udf.register("default.custom_func", (i: Int) => { i + 2 })
checkAnswer(
testData2.select(call_function("`default.custom_func`", $"a")),
Seq(Row(3), Row(3), Row(4), Row(4), Row(5), Row(5)))

sql("CREATE FUNCTION custom_sum AS 'test.org.apache.spark.sql.MyDoubleSum'")
checkAnswer(
testData2.select(
call_function("custom_sum", $"a"),
call_function("default.custom_sum", $"a"),
call_function("spark_catalog.default.custom_sum", $"a")),
Row(12.0, 12.0, 12.0))
}

}
}

@@ -37,7 +37,7 @@ import org.apache.spark.{SparkException, SparkFiles, TestUtils}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.functions.{call_function, max}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -552,6 +552,19 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}

test("Invoke a persist hive function with call_function") {
val testData = spark.range(5).repartition(1)
withUserDefinedFunction("custom_avg" -> false) {
sql(s"CREATE FUNCTION custom_avg AS '${classOf[GenericUDAFAverage].getName}'")
checkAnswer(
testData.select(
call_function("custom_avg", $"id"),
call_function("default.custom_avg", $"id"),
call_function("spark_catalog.default.custom_avg", $"id")),
Row(2.0, 2.0, 2.0))
}
}

test("Temp function has dots in the names") {
withUserDefinedFunction("test_avg" -> false, "`default.test_avg`" -> true) {
sql(s"CREATE FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
