[SPARK-44876][PYTHON] Fix Arrow-optimized Python UDF on Spark Connect
### What changes were proposed in this pull request?

Fixes Arrow-optimized Python UDFs on Spark Connect. Instead of wrapping the function into a pandas UDF on the client via `_create_arrow_py_udf`, the client now only selects `SQL_ARROW_BATCHED_UDF` as the eval type, and the Python worker performs the wrapping in a new `wrap_arrow_batch_udf`.

Also enables the missing test `pyspark.sql.tests.connect.test_parity_arrow_python_udf`.
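
For context, a minimal usage sketch of the feature being fixed (assumes an existing Spark Connect session `spark`, e.g. created via `SparkSession.builder.remote(...)`; the function and data are illustrative):

```python
from pyspark.sql.functions import udf

# Arrow-optimized Python UDF: useArrow=True selects the
# SQL_ARROW_BATCHED_UDF evaluation path instead of pickled rows.
@udf(returnType="string", useArrow=True)
def to_upper(s):
    return s.upper()

spark.range(1).selectExpr("'hello' AS v").select(to_upper("v")).show()
```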

### Why are the changes needed?

`pyspark.sql.tests.connect.test_parity_arrow_python_udf` is not listed in `dev/sparktestsupport/modules.py`, and it fails when run manually:

```
======================================================================
ERROR [0.072s]: test_register (pyspark.sql.tests.connect.test_parity_arrow_python_udf.ArrowPythonUDFParityTests)
----------------------------------------------------------------------
Traceback (most recent call last):
...
pyspark.errors.exceptions.base.PySparkRuntimeError: [SCHEMA_MISMATCH_FOR_PANDAS_UDF] Result vector from pandas_udf was not the required length: expected 1, got 38.
```

The failure had gone unnoticed because the test was missing from the `modules.py` file.
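
A sketch of the failing pattern exercised by `test_register` (assumes a Spark Connect session `spark`; names and data are illustrative):

```python
from pyspark.sql.functions import udf

add_one = udf(lambda x: x + 1, "long", useArrow=True)
spark.udf.register("add_one", add_one)

# Before this fix, evaluating the registered UDF over Spark Connect failed
# with [SCHEMA_MISMATCH_FOR_PANDAS_UDF], as in the traceback above.
spark.sql("SELECT add_one(id) FROM range(10)").show()
```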

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.
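
The newly enabled module can also be run on its own, e.g. with `python/run-tests --testnames pyspark.sql.tests.connect.test_parity_arrow_python_udf`.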

Closes apache#42568 from ueshin/issues/SPARK-44876/test_parity_arrow_python_udf.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
ueshin authored and HyukjinKwon committed Aug 21, 2023
1 parent 169af11 commit 75c0b8b
Showing 4 changed files with 73 additions and 62 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -840,6 +840,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_connect_function",
"pyspark.sql.tests.connect.test_connect_column",
"pyspark.sql.tests.connect.test_parity_arrow",
"pyspark.sql.tests.connect.test_parity_arrow_python_udf",
"pyspark.sql.tests.connect.test_parity_datasources",
"pyspark.sql.tests.connect.test_parity_errors",
"pyspark.sql.tests.connect.test_parity_catalog",
20 changes: 9 additions & 11 deletions python/pyspark/sql/connect/udf.py
@@ -54,8 +54,6 @@ def _create_py_udf(
returnType: "DataTypeOrString",
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
from pyspark.sql.udf import _create_arrow_py_udf

if useArrow is None:
is_arrow_enabled = False
try:
@@ -74,22 +72,22 @@
else:
is_arrow_enabled = useArrow

regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
eval_type: int = PythonEvalType.SQL_BATCHED_UDF

if is_arrow_enabled:
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
if is_func_with_args:
return _create_arrow_py_udf(regular_udf)
eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
else:
warnings.warn(
"Arrow optimization for Python UDFs cannot be enabled.",
UserWarning,
)
return regular_udf
else:
return regular_udf

return _create_udf(f, returnType, eval_type)


def _create_udf(
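
A note on the `try`/`except TypeError` around `getfullargspec` kept in the new code: some callables cannot be introspected, in which case `is_func_with_args` stays `False` and the UDF falls back to the non-Arrow path with a warning. A quick illustration:

```python
from inspect import getfullargspec

print(getfullargspec(lambda x: x).args)  # ['x']: plain functions are introspectable

try:
    getfullargspec(min)  # some builtins raise TypeError ("unsupported callable")
except TypeError as e:
    print(e)
```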
60 changes: 11 additions & 49 deletions python/pyspark/sql/udf.py
@@ -32,7 +32,6 @@
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
from pyspark.sql.types import (
BinaryType,
DataType,
StringType,
StructType,
@@ -131,58 +130,24 @@ def _create_py_udf(
else:
is_arrow_enabled = useArrow

regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
eval_type: int = PythonEvalType.SQL_BATCHED_UDF

if is_arrow_enabled:
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
if is_func_with_args:
return _create_arrow_py_udf(regular_udf)
require_minimum_pandas_version()
require_minimum_pyarrow_version()
eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
else:
warnings.warn(
"Arrow optimization for Python UDFs cannot be enabled.",
UserWarning,
)
return regular_udf
else:
return regular_udf


def _create_arrow_py_udf(regular_udf): # type: ignore
"""Create an Arrow-optimized Python UDF out of a regular Python UDF."""
require_minimum_pandas_version()
require_minimum_pyarrow_version()

import pandas as pd
from pyspark.sql.pandas.functions import _create_pandas_udf

f = regular_udf.func
return_type = regular_udf.returnType

# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
# optimization.
# Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
# string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
# successfully.
result_func = lambda pdf: pdf # noqa: E731
if type(return_type) == StringType:
result_func = lambda r: str(r) if r is not None else r # noqa: E731
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731

def vectorized_udf(*args: pd.Series) -> pd.Series:
return pd.Series(result_func(f(*a)) for a in zip(*args))

# Regular UDFs can take callable instances too.
vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
vectorized_udf.__doc__ = f.__doc__
pudf = _create_pandas_udf(vectorized_udf, return_type, PythonEvalType.SQL_ARROW_BATCHED_UDF)
# Keep the attributes as if this is a regular Python UDF.
pudf.func = f
pudf.returnType = return_type
return pudf
return _create_udf(f, returnType, eval_type)


class UserDefinedFunction:
@@ -637,10 +602,7 @@ def register(
evalType=f.evalType,
deterministic=f.deterministic,
)
if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
register_udf = _create_arrow_py_udf(source_udf)._unwrapped
else:
register_udf = source_udf._unwrapped # type: ignore[attr-defined]
register_udf = source_udf._unwrapped # type: ignore[attr-defined]
return_udf = register_udf
else:
if returnType is None:
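
With `_create_arrow_py_udf` removed, `register` no longer re-wraps Arrow-optimized UDFs; the eval type chosen on the client passes through unchanged, and the wrapping happens in the worker. A sketch of the resulting behavior (assumes a local `SparkSession` with pandas and pyarrow installed; the assertion reflects the new code path):

```python
from pyspark.sql import SparkSession
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import udf

spark = SparkSession.builder.getOrCreate()

to_upper = udf(lambda s: s.upper(), "string", useArrow=True)

# The UDF carries the Arrow-batched eval type end to end; registration
# passes it through instead of rebuilding a pandas UDF around it.
assert to_upper.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
spark.udf.register("to_upper", to_upper)
```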
54 changes: 52 additions & 2 deletions python/pyspark/worker.py
@@ -53,7 +53,7 @@
ApplyInPandasWithStateSerializer,
)
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StructType, _parse_datatype_json_string
from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string
from pyspark.util import fail_on_stopiteration, try_simplify_traceback
from pyspark import shuffle
from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -121,6 +121,54 @@ def verify_result_length(result, length):
)


def wrap_arrow_batch_udf(f, return_type):
import pandas as pd

arrow_return_type = to_arrow_type(return_type)

# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
# optimization.
# Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
# string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
# successfully.
result_func = lambda pdf: pdf # noqa: E731
if type(return_type) == StringType:
result_func = lambda r: str(r) if r is not None else r # noqa: E731
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731

def evaluate(*args: pd.Series) -> pd.Series:
return pd.Series(result_func(f(*a)) for a in zip(*args))

def verify_result_type(result):
if not hasattr(result, "__len__"):
pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
raise PySparkTypeError(
error_class="UDF_RETURN_TYPE",
message_parameters={
"expected": pd_type,
"actual": type(result).__name__,
},
)
return result

def verify_result_length(result, length):
if len(result) != length:
raise PySparkRuntimeError(
error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
message_parameters={
"expected": str(length),
"actual": str(len(result)),
},
)
return result

return lambda *a: (
verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
arrow_return_type,
)


def wrap_pandas_batch_iter_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
@@ -486,8 +534,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
func = fail_on_stopiteration(chained_func)

# the last returnType will be the return type of UDF
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
return arg_offsets, wrap_arrow_batch_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
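
The `result_func` coercion in the new `wrap_arrow_batch_udf` keeps results consistent with the non-Arrow path: a UDF declared to return `string` but actually returning, say, ints would otherwise fail under Arrow while succeeding without it. A standalone sketch of that logic (`make_result_func` is an illustrative helper, not part of the patch):

```python
import pandas as pd

def make_result_func(return_type_name):
    # Mirror the worker logic: coerce to str/bytes for string/binary
    # return types so pyarrow does not reject mismatched dtypes.
    if return_type_name == "string":
        return lambda r: str(r) if r is not None else r
    if return_type_name == "binary":
        return lambda r: bytes(r) if r is not None else r
    return lambda r: r

f = lambda x: x  # declared as "string" below, but actually returns ints
result_func = make_result_func("string")

batch = pd.Series([1, 2, 3])
out = pd.Series(result_func(f(*a)) for a in zip(batch))
print(out.tolist())  # ['1', '2', '3'] (coerced to Python strings)
```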
