From 48474297d1a6f323bafe24de62d7b659ae26fdd5 Mon Sep 17 00:00:00 2001 From: scoder Date: Sat, 25 Nov 2023 20:51:24 +0100 Subject: [PATCH] Implement return type inference for methods of builtin types (GH-5865) * Also, infer 'str' in Py3 now, since it's a safe type. * Builtin types generally don't return None on slicing, and string types also never do it for indexing, so change "IndexNode.may_be_none()" accordingly. --- Cython/Compiler/Builtin.py | 172 ++++++++++++++++++++++++++- Cython/Compiler/ExprNodes.py | 52 ++++---- Cython/Compiler/Tests/TestBuiltin.py | 32 +++++ Cython/Compiler/TypeInference.py | 9 +- tests/run/bytesmethods.pyx | 4 +- tests/run/strmethods.pyx | 38 +++--- tests/run/strmethods_ll2.pyx | 9 ++ tests/run/type_inference.pyx | 32 ++++- 8 files changed, 296 insertions(+), 52 deletions(-) create mode 100644 Cython/Compiler/Tests/TestBuiltin.py create mode 100644 tests/run/strmethods_ll2.pyx diff --git a/Cython/Compiler/Builtin.py b/Cython/Compiler/Builtin.py index 743a6adba35..55d005f89b6 100644 --- a/Cython/Compiler/Builtin.py +++ b/Cython/Compiler/Builtin.py @@ -302,8 +302,7 @@ def declare_in_type(self, self_type): BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), ]), - ("str", "&PyString_Type", [BuiltinMethod("join", "TO", "O", "__Pyx_PyString_Join", - builtin_return_type='basestring', + ("str", "&PyString_Type", [BuiltinMethod("join", "TO", "T", "__Pyx_PyString_Join", utility_code=UtilityCode.load("StringJoin", "StringTools.c")), BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), @@ -411,6 +410,175 @@ def declare_in_type(self, self_type): }) +inferred_method_return_types = { + 'complex': dict( + conjugate='complex', + ), + 'int': dict( + bit_length='T', + bit_count='T', + to_bytes='bytes', + from_bytes='T', # classmethod + as_integer_ratio='tuple[int,int]', + is_integer='bint', + ), + 'float': dict( + as_integer_ratio='tuple[int,int]', + is_integer='bint', + hex='unicode', + fromhex='T', # classmethod + ), + 'list': dict( + index='Py_ssize_t', + count='Py_ssize_t', + ), + 'unicode': dict( + capitalize='T', + casefold='T', + center='T', + count='Py_ssize_t', + encode='bytes', + endswith='bint', + expandtabs='T', + find='Py_ssize_t', + format='T', + format_map='T', + index='Py_ssize_t', + isalnum='bint', + isalpha='bint', + isascii='bint', + isdecimal='bint', + isdigit='bint', + isidentifier='bint', + islower='bint', + isnumeric='bint', + isprintable='bint', + isspace='bint', + istitle='bint', + isupper='bint', + join='T', + ljust='T', + lower='T', + lstrip='T', + maketrans='dict[int,object]', # staticmethod + partition='tuple[T,T,T]', + removeprefix='T', + removesuffix='T', + replace='T', + rfind='Py_ssize_t', + rindex='Py_ssize_t', + rjust='T', + rpartition='tuple[T,T,T]', + rsplit='list[T]', + rstrip='T', + split='list[T]', + splitlines='list[T]', + startswith='bint', + strip='T', + swapcase='T', + title='T', + translate='T', + upper='T', + zfill='T', + ), + 'bytes': dict( + hex='unicode', + fromhex='T', # classmethod + count='Py_ssize_t', + removeprefix='T', + removesuffix='T', + decode='unicode', + endswith='bint', + find='Py_ssize_t', + index='Py_ssize_t', + join='T', + maketrans='bytes', # staticmethod + partition='tuple[T,T,T]', + replace='T', + rfind='Py_ssize_t', + rindex='Py_ssize_t', + rpartition='tuple[T,T,T]', + startswith='bint', + translate='T', + center='T', + ljust='T', + lstrip='T', + rjust='T', + rsplit='list[T]', + rstrip='T', + split='list[T]', + strip='T', + capitalize='T', + expandtabs='T', + isalnum='bint', + isalpha='bint', + isascii='bint', + isdigit='bint', + islower='bint', + isspace='bint', + istitle='bint', + isupper='bint', + lower='T', + splitlines='list[T]', + swapcase='T', + title='T', + upper='T', + zfill='T', + ), + 'bytearray': dict( + # Inherited from 'bytes' below. + ), + 'memoryview': dict( + tobytes='bytes', + hex='unicode', + tolist='list', + toreadonly='T', + cast='T', + ), + 'set': dict( + isdisjoint='bint', + isubset='bint', + issuperset='bint', + union='T', + intersection='T', + difference='T', + symmetric_difference='T', + copy='T', + ), + 'frozenset': dict( + # Inherited from 'set' below. + ), + 'dict': dict( + copy='T', + ), +} + +inferred_method_return_types['bytearray'].update(inferred_method_return_types['bytes']) +inferred_method_return_types['frozenset'].update(inferred_method_return_types['set']) +inferred_method_return_types['str'] = inferred_method_return_types['unicode'] + + +def find_return_type_of_builtin_method(builtin_type, method_name): + type_name = builtin_type.name + if type_name in inferred_method_return_types: + methods = inferred_method_return_types[type_name] + if method_name in methods: + return_type_name = methods[method_name] + if '[' in return_type_name: + # TODO: Keep the "[...]" part when we add support for generics. + return_type_name = return_type_name.partition('[')[0] + if return_type_name == 'T': + return builtin_type + if 'T' in return_type_name: + return_type_name = return_type_name.replace('T', builtin_type.name) + if return_type_name == 'bint': + return PyrexTypes.c_bint_type + elif return_type_name == 'Py_ssize_t': + return PyrexTypes.c_py_ssize_t_type + return builtin_scope.lookup(return_type_name).type + return PyrexTypes.py_object_type + + builtin_structs_table = [ ('Py_buffer', 'Py_buffer', [("buf", PyrexTypes.c_void_ptr_type), diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 526fcc0e4c6..6fde12a078e 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3877,10 +3877,13 @@ def may_be_none(self): if base_type: if base_type.is_string: return False + if base_type in (unicode_type, bytes_type, str_type, bytearray_type, basestring_type): + return False if isinstance(self.index, SliceNode): # slicing! - if base_type in (bytes_type, bytearray_type, str_type, unicode_type, - basestring_type, list_type, tuple_type): + if base_type.is_builtin_type: + # It seems that none of the builtin types can return None for "__getitem__[slice]". + # Slices are not hashable, and thus cannot be used as key in dicts, for example. return False return ExprNode.may_be_none(self) @@ -5845,8 +5848,15 @@ class CallNode(ExprNode): may_return_none = None def infer_type(self, env): - # TODO(robertwb): Reduce redundancy with analyse_types. function = self.function + if function.is_attribute: + method_obj_type = function.obj.infer_type(env) + if method_obj_type.is_builtin_type: + result_type = Builtin.find_return_type_of_builtin_method(method_obj_type, function.attribute) + if result_type is not py_object_type: + return result_type + + # TODO(robertwb): Reduce redundancy with analyse_types. func_type = function.infer_type(env) if isinstance(function, NewExprNode): # note: needs call to infer_type() above @@ -5932,6 +5942,16 @@ def set_py_result_type(self, function, func_type=None): self.type = function.type_entry.type self.result_ctype = py_object_type self.may_return_none = False + elif function.is_attribute and function.obj.type.is_builtin_type: + method_obj_type = function.obj.type + result_type = Builtin.find_return_type_of_builtin_method(method_obj_type, function.attribute) + self.may_return_none = result_type is py_object_type + if result_type.is_pyobject: + self.type = result_type + elif result_type.equivalent_type: + self.type = result_type.equivalent_type + else: + self.type = py_object_type else: self.type = py_object_type @@ -11952,8 +11972,8 @@ def is_py_operation_types(self, type1, type2): def infer_builtin_types_operation(self, type1, type2): # b'abc' + 'abc' raises an exception in Py3, - # so we can safely infer the Py2 type for bytes here - string_types = (bytes_type, bytearray_type, str_type, basestring_type, unicode_type) + # so we can safely infer a mix here. + string_types = (bytes_type, bytearray_type, basestring_type, str_type, unicode_type) if type1 in string_types and type2 in string_types: return string_types[max(string_types.index(type1), string_types.index(type2))] @@ -12326,25 +12346,11 @@ def is_py_operation_types(self, type1, type2): or NumBinopNode.is_py_operation_types(self, type1, type2)) def infer_builtin_types_operation(self, type1, type2): - # b'%s' % xyz raises an exception in Py3<3.5, so it's safe to infer the type for Py2 and later Py3's. - if type1 is unicode_type: - # None + xyz may be implemented by RHS - if type2.is_builtin_type or not self.operand1.may_be_none(): - return type1 - elif type1 in (bytes_type, str_type, basestring_type): - if type2 is unicode_type: - return type2 - elif type2.is_numeric: + # b'%s' % xyz raises an exception in Py3<3.5, so it's safe to infer the type for later Py3's. + if type1 in (unicode_type, bytes_type, str_type, basestring_type): + # 'None % xyz' may be implemented by the RHS, but everything else will do string formatting. + if type2.is_builtin_type or not type2.is_pyobject or not self.operand1.may_be_none(): return type1 - elif self.operand1.is_string_literal: - if type1 is str_type or type1 is bytes_type: - if set(_find_formatting_types(self.operand1.value)) <= _safe_bytes_formats: - return type1 - return basestring_type - elif type1 is bytes_type and not type2.is_builtin_type: - return None # RHS might implement '% operator differently in Py3 - else: - return basestring_type # either str or unicode, can't tell return super().infer_builtin_types_operation(type1, type2) def zero_division_message(self): diff --git a/Cython/Compiler/Tests/TestBuiltin.py b/Cython/Compiler/Tests/TestBuiltin.py new file mode 100644 index 00000000000..ebc5278a328 --- /dev/null +++ b/Cython/Compiler/Tests/TestBuiltin.py @@ -0,0 +1,32 @@ +import builtins +import sys +import unittest + +from ..Builtin import ( + inferred_method_return_types, find_return_type_of_builtin_method, + builtin_scope, +) + + +class TestBuiltinReturnTypes(unittest.TestCase): + def test_find_return_type_of_builtin_method(self): + # It's enough to test the method existence in a recent Python that likely has them. + look_up_methods = sys.version_info >= (3,10) + + for type_name, methods in inferred_method_return_types.items(): + py_type = getattr(builtins, type_name if type_name != 'unicode' else 'str') + + for method_name, return_type_name in methods.items(): + builtin_type = builtin_scope.lookup(type_name).type + return_type = find_return_type_of_builtin_method(builtin_type, method_name) + + if return_type.is_builtin_type: + if '[' in return_type_name: + return_type_name = return_type_name.partition('[')[0] + if return_type_name == 'T': + return_type_name = type_name + self.assertEqual(return_type.name, return_type_name) + if look_up_methods: + self.assertTrue(hasattr(py_type, method_name), f"{type_name}.{method_name}") + else: + self.assertEqual(return_type.empty_declaration_code(pyrex=True), return_type_name) diff --git a/Cython/Compiler/TypeInference.py b/Cython/Compiler/TypeInference.py index d40d191534e..ffd48d45179 100644 --- a/Cython/Compiler/TypeInference.py +++ b/Cython/Compiler/TypeInference.py @@ -544,14 +544,7 @@ def aggressive_spanning_type(types, might_overflow, scope): def safe_spanning_type(types, might_overflow, scope): result_type = simply_type(reduce(find_spanning_type, types)) if result_type.is_pyobject: - # In theory, any specific Python type is always safe to - # infer. However, inferring str can cause some existing code - # to break, since we are also now much more strict about - # coercion from str to char *. See trac #553. - if result_type.name == 'str': - return py_object_type - else: - return result_type + return result_type elif (result_type is PyrexTypes.c_double_type or result_type is PyrexTypes.c_float_type): # Python's float type is just a C double, so it's safe to use diff --git a/tests/run/bytesmethods.pyx b/tests/run/bytesmethods.pyx index 5973a7334d4..ac4e9b719d0 100644 --- a/tests/run/bytesmethods.pyx +++ b/tests/run/bytesmethods.pyx @@ -259,7 +259,7 @@ def bytes_join(bytes s, *args): babab """ result = s.join(args) - assert cython.typeof(result) == 'Python object', cython.typeof(result) + assert cython.typeof(result) == 'bytes object', cython.typeof(result) return result @@ -275,7 +275,7 @@ def literal_join(*args): b|b|b|b """ result = b'|'.join(args) - assert cython.typeof(result) == 'Python object', cython.typeof(result) + assert cython.typeof(result) == 'bytes object', cython.typeof(result) return result def fromhex(bytes b): diff --git a/tests/run/strmethods.pyx b/tests/run/strmethods.pyx index 58d7a7801ef..77730929cf8 100644 --- a/tests/run/strmethods.pyx +++ b/tests/run/strmethods.pyx @@ -1,5 +1,15 @@ +# mode: run + +# cython: language_level=3 + cimport cython +# Also used by the language_level=2 tests in "strmethods_ll2.pyx" +assert cython.typeof(1 / 2) in ('long', 'double') +IS_LANGUAGE_LEVEL_3 = cython.typeof(1 / 2) == 'double' +str_type = "unicode object" if IS_LANGUAGE_LEVEL_3 else "str object" + + @cython.test_assert_path_exists( "//PythonCapiCallNode") def str_startswith(str s, sub, start=None, stop=None): @@ -75,33 +85,31 @@ def str_as_name(str): return str.endswith("x") -@cython.test_assert_path_exists( - "//SimpleCallNode", - "//SimpleCallNode//NoneCheckNode", - "//SimpleCallNode//AttributeNode[@is_py_attr = false]") +#@cython.test_fail_if_path_exists( +# "//SimpleCallNode", +# "//SimpleCallNode//NoneCheckNode", +# "//SimpleCallNode//AttributeNode[@is_py_attr = false]") def str_join(str s, args): """ >>> print(str_join('a', list('bbb'))) babab """ result = s.join(args) - assert cython.typeof(result) == 'basestring object', cython.typeof(result) + assert cython.typeof(result) == str_type, (cython.typeof(result), str_type) return result -@cython.test_fail_if_path_exists( - "//SimpleCallNode//NoneCheckNode", -) -@cython.test_assert_path_exists( - "//SimpleCallNode", - "//SimpleCallNode//AttributeNode[@is_py_attr = false]") +#@cython.test_fail_if_path_exists( +# "//SimpleCallNode", +# "//SimpleCallNode//NoneCheckNode", +# "//SimpleCallNode//AttributeNode[@is_py_attr = false]") def literal_join(args): """ >>> print(literal_join(list('abcdefg'))) a|b|c|d|e|f|g """ result = '|'.join(args) - assert cython.typeof(result) == 'basestring object', cython.typeof(result) + assert cython.typeof(result) == str_type, (cython.typeof(result), str_type) return result @@ -125,7 +133,7 @@ def mod_format(str s, values): >>> mod_format(None, RMod()) 123 """ - assert cython.typeof(s % values) == 'basestring object', cython.typeof(s % values) + assert cython.typeof(s % values) == "Python object", cython.typeof(s % values) return s % values @@ -138,7 +146,7 @@ def mod_format_literal(values): >>> mod_format_literal(['sa']) == "abc['sa']def" or mod_format(format1, ['sa']) True """ - assert cython.typeof('abc%sdef' % values) == 'basestring object', cython.typeof('abc%sdef' % values) + assert cython.typeof('abc%sdef' % values) == str_type, (cython.typeof('abc%sdef' % values), str_type) return 'abc%sdef' % values @@ -150,5 +158,5 @@ def mod_format_tuple(*values): Traceback (most recent call last): TypeError: not enough arguments for format string """ - assert cython.typeof('abc%sdef' % values) == 'basestring object', cython.typeof('abc%sdef' % values) + assert cython.typeof('abc%sdef' % values) == str_type, (cython.typeof('abc%sdef' % values), str_type) return 'abc%sdef' % values diff --git a/tests/run/strmethods_ll2.pyx b/tests/run/strmethods_ll2.pyx new file mode 100644 index 00000000000..63faae059a5 --- /dev/null +++ b/tests/run/strmethods_ll2.pyx @@ -0,0 +1,9 @@ +# mode: run + +# cython: language_level=2 + +""" +Same tests as 'strmethods.pyx', but using 'language_level=2'. +""" + +include "strmethods.pyx" diff --git a/tests/run/type_inference.pyx b/tests/run/type_inference.pyx index feb18817e6f..595ef1297be 100644 --- a/tests/run/type_inference.pyx +++ b/tests/run/type_inference.pyx @@ -270,6 +270,35 @@ def builtin_type_methods(): append(1) assert l == [1], str(l) + u = u'abc def' + split = u.split() + assert typeof(split) == 'list object', typeof(split) + + str_result1 = u.upper() + assert typeof(str_result1) == 'unicode object', typeof(str_result1) + str_result2 = u.upper().lower() + assert typeof(str_result2) == 'unicode object', typeof(str_result2) + str_result3 = u.upper().lower().strip() + assert typeof(str_result3) == 'unicode object', typeof(str_result3) + str_result4 = u.upper().lower().strip().lstrip() + assert typeof(str_result4) == 'unicode object', typeof(str_result4) + str_result5 = u.upper().lower().strip().lstrip().rstrip() + assert typeof(str_result5) == 'unicode object', typeof(str_result5) + str_result6 = u.upper().lower().strip().lstrip().rstrip().center(20) + assert typeof(str_result6) == 'unicode object', typeof(str_result6) + str_result7 = u.upper().lower().strip().lstrip().rstrip().center(20).format() + assert typeof(str_result7) == 'unicode object', typeof(str_result7) + str_result8 = u.upper().lower().strip().lstrip().rstrip().center(20).format().expandtabs(4) + assert typeof(str_result8) == 'unicode object', typeof(str_result8) + str_result9 = u.upper().lower().strip().lstrip().rstrip().center(20).format().expandtabs(4).swapcase() + assert typeof(str_result9) == 'unicode object', typeof(str_result9) + + predicate1 = u.isupper() + assert typeof(predicate1) == 'bint', typeof(predicate1) + predicate2 = u.istitle() + assert typeof(predicate2) == 'bint', typeof(predicate2) + + cdef int cfunc(int x): return x+1 @@ -544,9 +573,8 @@ def safe_only(): div_res = pyint_val / 7 assert typeof(div_res) == ("double" if IS_LANGUAGE_LEVEL_3 else "Python object"), typeof(div_res) - # we special-case inference to type str s = "abc" - assert typeof(s) == ("unicode object" if IS_LANGUAGE_LEVEL_3 else "Python object"), (typeof(s), str_type) + assert typeof(s) == str_type, (typeof(s), str_type) cdef str t = "def" assert typeof(t) == str_type, (typeof(t), str_type)