[SPARK-44901][SQL] Add API in Python UDTF 'analyze' method to return partitioning/ordering expressions

### What changes were proposed in this pull request?

This PR adds an API to the Python UDTF 'analyze' method for requesting partitioning and/or ordering properties of the input relation. Catalyst then performs any repartitioning and/or sorting needed to fulfill the requested properties.

For example, the following `analyze` result requests that Catalyst behave as if the UDTF call included `PARTITION BY partition_col ORDER BY input`:

```
from pyspark.sql.functions import udtf, AnalyzeResult, OrderingColumn, PartitioningColumn
from pyspark.sql.types import IntegerType, StructType

@udtf
class MyUDTF:
    @staticmethod
    def analyze(self):
        return AnalyzeResult(
            schema=StructType()
                .add("partition_col", IntegerType())
                .add("count", IntegerType())
                .add("total", IntegerType())
                .add("last", IntegerType()),
            partition_by=[
                PartitioningColumn("partition_col")
            ],
            order_by=[
                OrderingColumn("input")
            ])
    ...
```
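
With this `analyze` result in place, the query itself no longer needs to spell out the clauses. Below is a minimal usage sketch (not part of this PR's description): it assumes an active SparkSession `spark`, that `MyUDTF` above also implements `eval`/`terminate`, and a table `t` with `partition_col` and `input` columns.

```
# Hypothetical registration name; MyUDTF is the class from the example above.
spark.udtf.register("my_udtf", MyUDTF)

# Catalyst plans this call as if it had been written as
# my_udtf(TABLE(t) PARTITION BY partition_col ORDER BY input).
spark.sql("SELECT * FROM my_udtf(TABLE(t))").show()
```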

Or, the following `analyze` result requests that Catalyst behave as if the UDTF call included `WITH SINGLE PARTITION`:

```
from pyspark.sql.functions import udtf, AnalyzeResult
from pyspark.sql.types import IntegerType, StructType

@udtf
class MyUDTF:
    @staticmethod
    def analyze(self):
        return AnalyzeResult(
            schema=StructType()
                .add("partition_col", IntegerType())
                .add("count", IntegerType())
                .add("total", IntegerType())
                .add("last", IntegerType()),
            with_single_partition=True)
    ...
```
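
Likewise, a minimal usage sketch for this variant (same assumptions as above; the registered name is hypothetical):

```
# With with_single_partition=True returned from 'analyze', Catalyst plans this call
# as if it had been written as my_udtf_single(TABLE(t) WITH SINGLE PARTITION).
spark.udtf.register("my_udtf_single", MyUDTF)
spark.sql("SELECT * FROM my_udtf_single(TABLE(t))").show()
```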

### Why are the changes needed?

This gives Python UDTF authors the ability to write table functions that rely on guarantees about which rows each instance of the UDTF class consumes, and in what order.

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds unit test coverage in Scala and Python.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#42595 from dtenedor/anlayze-result.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
dtenedor authored and ueshin committed Sep 2, 2023
1 parent df534c3 commit e4ebb37
Showing 17 changed files with 1,716 additions and 55 deletions.
12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -2698,6 +2698,18 @@
"Failed to analyze the Python user defined table function: <msg>"
]
},
"TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL" : {
"message" : [
"Failed to evaluate the table function <functionName> because its table metadata <requestedMetadata>, but the function call <invalidFunctionCallProperty>."
],
"sqlState" : "22023"
},
"TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID" : {
"message" : [
"Failed to evaluate the table function <functionName> because its table metadata was invalid; <reason>."
],
"sqlState" : "22023"
},
"TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS" : {
"message" : [
"There are too many table arguments for table-valued function. It allows one table argument, but got: <num>. If you want to allow it, please set \"spark.sql.allowMultipleTableArguments.enabled\" to \"true\""
12 changes: 12 additions & 0 deletions docs/sql-error-conditions.md
@@ -1764,6 +1764,18 @@ SQLSTATE: none assigned

Failed to analyze the Python user defined table function: `<msg>`

### TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INCOMPATIBLE_WITH_CALL

[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception)

Failed to evaluate the table function `<functionName>` because its table metadata `<requestedMetadata>`, but the function call `<invalidFunctionCallProperty>`.

### TABLE_VALUED_FUNCTION_REQUIRED_METADATA_INVALID

[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception)

Failed to evaluate the table function `<functionName>` because its table metadata was invalid; `<reason>`.

### TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS

SQLSTATE: none assigned
1 change: 1 addition & 0 deletions python/pyspark/sql/functions.py
@@ -50,6 +50,7 @@
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult # noqa: F401
from pyspark.sql.udtf import OrderingColumn, PartitioningColumn # noqa: F401
from pyspark.sql.udtf import UserDefinedTableFunction, _create_py_udtf

# Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
142 changes: 136 additions & 6 deletions python/pyspark/sql/tests/test_udtf.py
@@ -41,6 +41,8 @@
udtf,
AnalyzeArgument,
AnalyzeResult,
OrderingColumn,
PartitioningColumn,
)
from pyspark.sql.types import (
ArrayType,
@@ -2002,7 +2004,7 @@ def terminate(self):
# This is a basic example.
func = udtf(TestUDTF, returnType="partition_col: int, total: int")
self.spark.udtf.register("test_udtf", func)
self.assertEqual(
assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
("123", "456", 123, 456),
("123", "NULL", None, 123),
):
self.assertEqual(
assertDataFrameEqual(
self.spark.sql(
f"""
WITH t AS (
# Combine a lateral join with a TABLE argument with PARTITION BY.
func = udtf(TestUDTF, returnType="partition_col: int, total: int")
self.spark.udtf.register("test_udtf", func)
self.assertEqual(
assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
@@ -2090,7 +2092,7 @@ def terminate(self):
("input DESC", 1),
("input - 1 DESC", 1),
):
self.assertEqual(
assertDataFrameEqual(
self.spark.sql(
f"""
WITH t AS (
@@ -2130,7 +2132,7 @@ def terminate(self):

func = udtf(TestUDTF, returnType="count: int, total: int, last: int")
self.spark.udtf.register("test_udtf", func)
self.assertEqual(
assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
ORDER BY 1, 2
"""
).collect(),
[Row(count=40, total=60, last=2)],
[
Row(count=40, total=60, last=2),
],
)

def test_udtf_with_table_argument_with_single_partition_from_analyze(self):
@udtf
class TestUDTF:
def __init__(self):
self._count = 0
self._sum = 0
self._last = None

@staticmethod
def analyze(self):
return AnalyzeResult(
schema=StructType()
.add("count", IntegerType())
.add("total", IntegerType())
.add("last", IntegerType()),
with_single_partition=True,
order_by=[OrderingColumn("input"), OrderingColumn("partition_col")],
)

def eval(self, row: Row):
# Make sure that the rows arrive in the expected order.
if self._last is not None and self._last > row["input"]:
raise Exception(
f"self._last was {self._last} but the row value was {row['input']}"
)
self._count += 1
self._last = row["input"]
self._sum += row["input"]

def terminate(self):
yield self._count, self._sum, self._last

self.spark.udtf.register("test_udtf", TestUDTF)

assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
SELECT id AS partition_col, 1 AS input FROM range(1, 21)
UNION ALL
SELECT id AS partition_col, 2 AS input FROM range(1, 21)
)
SELECT count, total, last
FROM test_udtf(TABLE(t))
ORDER BY 1, 2
"""
).collect(),
[
Row(count=40, total=60, last=2),
],
)

def test_udtf_with_table_argument_with_partition_by_and_order_by_from_analyze(self):
@udtf
class TestUDTF:
def __init__(self):
self._partition_col = None
self._count = 0
self._sum = 0
self._last = None

@staticmethod
def analyze(self):
return AnalyzeResult(
schema=StructType()
.add("partition_col", IntegerType())
.add("count", IntegerType())
.add("total", IntegerType())
.add("last", IntegerType()),
partition_by=[PartitioningColumn("partition_col")],
order_by=[
OrderingColumn(name="input", ascending=True, overrideNullsFirst=False)
],
)

def eval(self, row: Row):
# Make sure that all values of the partitioning column are the same
# for each row consumed by this method for this instance of the class.
if self._partition_col is not None and self._partition_col != row["partition_col"]:
raise Exception(
f"self._partition_col was {self._partition_col} but the row "
+ f"value was {row['partition_col']}"
)
# Make sure that the rows arrive in the expected order.
if (
self._last is not None
and row["input"] is not None
and self._last > row["input"]
):
raise Exception(
f"self._last was {self._last} but the row value was {row['input']}"
)
self._partition_col = row["partition_col"]
self._count += 1
self._last = row["input"]
if row["input"] is not None:
self._sum += row["input"]

def terminate(self):
yield self._partition_col, self._count, self._sum, self._last

self.spark.udtf.register("test_udtf", TestUDTF)

assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
SELECT id AS partition_col, 1 AS input FROM range(1, 21)
UNION ALL
SELECT id AS partition_col, 2 AS input FROM range(1, 21)
UNION ALL
SELECT 42 AS partition_col, NULL AS input
UNION ALL
SELECT 42 AS partition_col, 1 AS input
UNION ALL
SELECT 42 AS partition_col, 2 AS input
)
SELECT partition_col, count, total, last
FROM test_udtf(TABLE(t))
ORDER BY 1, 2
"""
).collect(),
[Row(partition_col=x, count=2, total=3, last=2) for x in range(1, 21)]
+ [Row(partition_col=42, count=3, total=3, last=None)],
)


44 changes: 42 additions & 2 deletions python/pyspark/sql/udtf.py
@@ -18,11 +18,11 @@
User-defined table function related classes and functions
"""
import pickle
from dataclasses import dataclass
from dataclasses import dataclass, field
import inspect
import sys
import warnings
from typing import Any, Type, TYPE_CHECKING, Optional, Union
from typing import Any, Type, TYPE_CHECKING, Optional, Sequence, Union

from py4j.java_gateway import JavaObject

@@ -61,6 +61,30 @@ class AnalyzeArgument:
is_table: bool


@dataclass(frozen=True)
class PartitioningColumn:
"""
Represents a UDTF column for purposes of returning metadata from the 'analyze' method.
"""

name: str


@dataclass(frozen=True)
class OrderingColumn:
"""
Represents a single ordering column name for purposes of returning metadata from the 'analyze'
method.
"""

name: str
ascending: bool = True
# If this is None, use the default behavior to sort NULL values first when sorting in ascending
# order, or last when sorting in descending order. Otherwise, if this is True or False, override
# the default behavior accordingly.
overrideNullsFirst: Optional[bool] = None


@dataclass(frozen=True)
class AnalyzeResult:
"""
----------
schema : :class:`StructType`
The schema that the Python UDTF will return.
with_single_partition : bool
If true, the UDTF is specifying for Catalyst to repartition all rows of the input TABLE
argument to one collection for consumption by exactly one instance of the corresponding
UDTF class.
partition_by : Sequence[PartitioningColumn]
If non-empty, this is a sequence of columns that the UDTF is specifying for Catalyst to
partition the input TABLE argument by. In this case, the UDTF call may not also include an
explicit PARTITION BY clause; if it does, Catalyst will return an error. This option is
mutually exclusive with 'with_single_partition'.
order_by: Sequence[OrderingColumn]
If non-empty, this is a sequence of columns that the UDTF is specifying for Catalyst to
sort the input TABLE argument by. Note that the 'partition_by' list must also be non-empty
in this case.
"""

schema: StructType
with_single_partition: bool = False
partition_by: Sequence[PartitioningColumn] = field(default_factory=tuple)
order_by: Sequence[OrderingColumn] = field(default_factory=tuple)


def _create_udtf(
18 changes: 18 additions & 0 deletions python/pyspark/sql/worker/analyze_udtf.py
@@ -127,6 +127,24 @@ def main(infile: IO, outfile: IO) -> None:

# Return the analyzed schema.
write_with_length(result.schema.json().encode("utf-8"), outfile)
# Return whether the "with single partition" property is requested.
write_int(1 if result.with_single_partition else 0, outfile)
# Return the list of partitioning columns, if any.
write_int(len(result.partition_by), outfile)
for partitioning_col in result.partition_by:
write_with_length(partitioning_col.name.encode("utf-8"), outfile)
# Return the requested input table ordering, if any.
write_int(len(result.order_by), outfile)
for ordering_col in result.order_by:
write_with_length(ordering_col.name.encode("utf-8"), outfile)
write_int(1 if ordering_col.ascending else 0, outfile)
if ordering_col.overrideNullsFirst is None:
write_int(0, outfile)
elif ordering_col.overrideNullsFirst:
write_int(1, outfile)
else:
write_int(2, outfile)

except BaseException as e:
try:
exc_info = None