Skip to content

Commit

Permalink
Merge pull request #9385 from Tishj/python_string_annotations
Browse files Browse the repository at this point in the history
[Python] Support PEP 563
  • Loading branch information
Mytherin authored Oct 22, 2023
2 parents 1773893 + 8ce3f48 commit 812c4a9
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tools/pythonpkg/src/python_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,20 @@ struct PythonUDFData {
}
}

void AnalyzeSignature(const py::object &udf) {
py::object GetSignature(const py::object &udf) {
const int32_t PYTHON_3_10_HEX = 0x030a00f0;
auto python_version = PY_VERSION_HEX;

auto signature_func = py::module_::import("inspect").attr("signature");
auto signature = signature_func(udf);
if (python_version >= PYTHON_3_10_HEX) {
return signature_func(udf, py::arg("eval_str") = true);
} else {
return signature_func(udf);
}
}

void AnalyzeSignature(const py::object &udf) {
auto signature = GetSignature(udf);
auto sig_params = signature.attr("parameters");
auto return_annotation = signature.attr("return_annotation");
if (!py::none().is(return_annotation)) {
Expand Down
53 changes: 53 additions & 0 deletions tools/pythonpkg/tests/fast/test_string_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import duckdb
import pytest
import sys
from typing import Union


def make_annotated_function(type: str):
def test_base():
return None

import types

test_function = types.FunctionType(
test_base.__code__, test_base.__globals__, test_base.__name__, test_base.__defaults__, test_base.__closure__
)
# Add the 'type' string as return_annotation
test_function.__annotations__ = {'return': type}
return test_function


def python_version_lower_than_3_10():
import sys

if sys.version_info[0] < 3:
return True
if sys.version_info[1] < 10:
return True
return False


class TestStringAnnotation(object):
@pytest.mark.skipif(
python_version_lower_than_3_10(), reason="inspect.signature(eval_str=True) only supported since 3.10 and higher"
)
@pytest.mark.parametrize(
['input', 'expected'],
[
('str', 'VARCHAR'),
('list[str]', 'VARCHAR[]'),
('dict[str, str]', 'MAP(VARCHAR, VARCHAR)'),
('dict[Union[str, bool], str]', 'MAP(UNION(u1 VARCHAR, u2 BOOLEAN), VARCHAR)'),
],
)
def test_string_annotations(self, duckdb_cursor, input, expected):
from inspect import signature

func = make_annotated_function(input)
sig = signature(func)
assert sig.return_annotation.__class__ == str

duckdb_cursor.create_function("foo", func)
rel = duckdb_cursor.sql("select foo()")
assert rel.types == [expected]

0 comments on commit 812c4a9

Please sign in to comment.