Skip to content

Commit

Permalink
Implement return type inference for methods of builtin types (GH-5865)
Browse files Browse the repository at this point in the history
* Also, infer 'str' in Py3 now, since it's a safe type.

* Builtin types generally don't return None on slicing, and string types also never do it for indexing, so change "IndexNode.may_be_none()" accordingly.
  • Loading branch information
scoder authored Nov 25, 2023
1 parent 875293e commit 4847429
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 52 deletions.
172 changes: 170 additions & 2 deletions Cython/Compiler/Builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,7 @@ def declare_in_type(self, self_type):
BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
]),
("str", "&PyString_Type", [BuiltinMethod("join", "TO", "O", "__Pyx_PyString_Join",
builtin_return_type='basestring',
("str", "&PyString_Type", [BuiltinMethod("join", "TO", "T", "__Pyx_PyString_Join",
utility_code=UtilityCode.load("StringJoin", "StringTools.c")),
BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
Expand Down Expand Up @@ -411,6 +410,175 @@ def declare_in_type(self, self_type):
})


# Return types that can be inferred for calls to methods of builtin types.
#
# Maps the name of a builtin type to a dict of {method name: return type name}.
# The return type names are interpreted by find_return_type_of_builtin_method():
#   'T'           - the type that the method was looked up on
#   'bint'        - C boolean
#   'Py_ssize_t'  - C Py_ssize_t
#   'name[...]'   - the "[...]" part is currently stripped, only 'name' is used
#   anything else - the name of a builtin type
inferred_method_return_types = {
    'complex': dict(
        conjugate='complex',
    ),
    'int': dict(
        bit_length='T',
        bit_count='T',
        to_bytes='bytes',
        from_bytes='T',  # classmethod
        as_integer_ratio='tuple[int,int]',
        is_integer='bint',
    ),
    'float': dict(
        as_integer_ratio='tuple[int,int]',
        is_integer='bint',
        hex='unicode',
        fromhex='T',  # classmethod
    ),
    'list': dict(
        index='Py_ssize_t',
        count='Py_ssize_t',
    ),
    'unicode': dict(
        capitalize='T',
        casefold='T',
        center='T',
        count='Py_ssize_t',
        encode='bytes',
        endswith='bint',
        expandtabs='T',
        find='Py_ssize_t',
        format='T',
        format_map='T',
        index='Py_ssize_t',
        isalnum='bint',
        isalpha='bint',
        isascii='bint',
        isdecimal='bint',
        isdigit='bint',
        isidentifier='bint',
        islower='bint',
        isnumeric='bint',
        isprintable='bint',
        isspace='bint',
        istitle='bint',
        isupper='bint',
        join='T',
        ljust='T',
        lower='T',
        lstrip='T',
        maketrans='dict[int,object]',  # staticmethod
        partition='tuple[T,T,T]',
        removeprefix='T',
        removesuffix='T',
        replace='T',
        rfind='Py_ssize_t',
        rindex='Py_ssize_t',
        rjust='T',
        rpartition='tuple[T,T,T]',
        rsplit='list[T]',
        rstrip='T',
        split='list[T]',
        splitlines='list[T]',
        startswith='bint',
        strip='T',
        swapcase='T',
        title='T',
        translate='T',
        upper='T',
        zfill='T',
    ),
    'bytes': dict(
        hex='unicode',
        fromhex='T',  # classmethod
        count='Py_ssize_t',
        removeprefix='T',
        removesuffix='T',
        decode='unicode',
        endswith='bint',
        find='Py_ssize_t',
        index='Py_ssize_t',
        join='T',
        maketrans='bytes',  # staticmethod
        partition='tuple[T,T,T]',
        replace='T',
        rfind='Py_ssize_t',
        rindex='Py_ssize_t',
        rpartition='tuple[T,T,T]',
        startswith='bint',
        translate='T',
        center='T',
        ljust='T',
        lstrip='T',
        rjust='T',
        rsplit='list[T]',
        rstrip='T',
        split='list[T]',
        strip='T',
        capitalize='T',
        expandtabs='T',
        isalnum='bint',
        isalpha='bint',
        isascii='bint',
        isdigit='bint',
        islower='bint',
        isspace='bint',
        istitle='bint',
        isupper='bint',
        lower='T',
        splitlines='list[T]',
        swapcase='T',
        title='T',
        upper='T',
        zfill='T',
    ),
    'bytearray': dict(
        # Inherited from 'bytes' below.
    ),
    'memoryview': dict(
        tobytes='bytes',
        hex='unicode',
        tolist='list',
        toreadonly='T',
        cast='T',
    ),
    'set': dict(
        isdisjoint='bint',
        # The method is spelled "issubset" (double 's'); a misspelled key would
        # never match a real call and silently disable the inference for it.
        issubset='bint',
        issuperset='bint',
        union='T',
        intersection='T',
        difference='T',
        symmetric_difference='T',
        copy='T',
    ),
    'frozenset': dict(
        # Inherited from 'set' below.
    ),
    'dict': dict(
        copy='T',
    ),
}

# 'bytearray' and 'frozenset' get their own copies of the entries of their
# sibling types, whereas 'str' is a plain alias that shares the 'unicode' dict.
inferred_method_return_types['bytearray'].update(inferred_method_return_types['bytes'])
inferred_method_return_types['frozenset'].update(inferred_method_return_types['set'])
inferred_method_return_types['str'] = inferred_method_return_types['unicode']


def find_return_type_of_builtin_method(builtin_type, method_name):
    """Look up the inferred return type of a method of a builtin type.

    'builtin_type' is the type object that the method is called on (its
    '.name' selects the entry in 'inferred_method_return_types') and
    'method_name' is the name of the called method.

    Returns 'PyrexTypes.py_object_type' if nothing is known about the type
    or the method.  Otherwise returns 'builtin_type' itself for 'T', the C
    types for 'bint' / 'Py_ssize_t', or whatever type 'builtin_scope' has
    registered under the declared name.
    """
    owner_name = builtin_type.name
    if owner_name not in inferred_method_return_types:
        return PyrexTypes.py_object_type

    known_methods = inferred_method_return_types[owner_name]
    if method_name not in known_methods:
        return PyrexTypes.py_object_type

    declared = known_methods[method_name]
    if '[' in declared:
        # TODO: Keep the "[...]" part when we add support for generics.
        declared = declared.partition('[')[0]

    if declared == 'T':
        # The method returns an object of the type it was called on.
        return builtin_type
    if 'T' in declared:
        declared = declared.replace('T', builtin_type.name)

    if declared == 'bint':
        return PyrexTypes.c_bint_type
    if declared == 'Py_ssize_t':
        return PyrexTypes.c_py_ssize_t_type
    return builtin_scope.lookup(declared).type


builtin_structs_table = [
('Py_buffer', 'Py_buffer',
[("buf", PyrexTypes.c_void_ptr_type),
Expand Down
52 changes: 29 additions & 23 deletions Cython/Compiler/ExprNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3877,10 +3877,13 @@ def may_be_none(self):
if base_type:
if base_type.is_string:
return False
if base_type in (unicode_type, bytes_type, str_type, bytearray_type, basestring_type):
return False
if isinstance(self.index, SliceNode):
# slicing!
if base_type in (bytes_type, bytearray_type, str_type, unicode_type,
basestring_type, list_type, tuple_type):
if base_type.is_builtin_type:
# It seems that none of the builtin types can return None for "__getitem__[slice]".
# Slices are not hashable, and thus cannot be used as key in dicts, for example.
return False
return ExprNode.may_be_none(self)

Expand Down Expand Up @@ -5845,8 +5848,15 @@ class CallNode(ExprNode):
may_return_none = None

def infer_type(self, env):
# TODO(robertwb): Reduce redundancy with analyse_types.
function = self.function
if function.is_attribute:
method_obj_type = function.obj.infer_type(env)
if method_obj_type.is_builtin_type:
result_type = Builtin.find_return_type_of_builtin_method(method_obj_type, function.attribute)
if result_type is not py_object_type:
return result_type

# TODO(robertwb): Reduce redundancy with analyse_types.
func_type = function.infer_type(env)
if isinstance(function, NewExprNode):
# note: needs call to infer_type() above
Expand Down Expand Up @@ -5932,6 +5942,16 @@ def set_py_result_type(self, function, func_type=None):
self.type = function.type_entry.type
self.result_ctype = py_object_type
self.may_return_none = False
elif function.is_attribute and function.obj.type.is_builtin_type:
method_obj_type = function.obj.type
result_type = Builtin.find_return_type_of_builtin_method(method_obj_type, function.attribute)
self.may_return_none = result_type is py_object_type
if result_type.is_pyobject:
self.type = result_type
elif result_type.equivalent_type:
self.type = result_type.equivalent_type
else:
self.type = py_object_type
else:
self.type = py_object_type

Expand Down Expand Up @@ -11952,8 +11972,8 @@ def is_py_operation_types(self, type1, type2):

def infer_builtin_types_operation(self, type1, type2):
# b'abc' + 'abc' raises an exception in Py3,
# so we can safely infer the Py2 type for bytes here
string_types = (bytes_type, bytearray_type, str_type, basestring_type, unicode_type)
# so we can safely infer a mix here.
string_types = (bytes_type, bytearray_type, basestring_type, str_type, unicode_type)
if type1 in string_types and type2 in string_types:
return string_types[max(string_types.index(type1),
string_types.index(type2))]
Expand Down Expand Up @@ -12326,25 +12346,11 @@ def is_py_operation_types(self, type1, type2):
or NumBinopNode.is_py_operation_types(self, type1, type2))

def infer_builtin_types_operation(self, type1, type2):
# b'%s' % xyz raises an exception in Py3<3.5, so it's safe to infer the type for Py2 and later Py3's.
if type1 is unicode_type:
# None + xyz may be implemented by RHS
if type2.is_builtin_type or not self.operand1.may_be_none():
return type1
elif type1 in (bytes_type, str_type, basestring_type):
if type2 is unicode_type:
return type2
elif type2.is_numeric:
# b'%s' % xyz raises an exception in Py3<3.5, so it's safe to infer the type for later Py3's.
if type1 in (unicode_type, bytes_type, str_type, basestring_type):
# 'None % xyz' may be implemented by the RHS, but everything else will do string formatting.
if type2.is_builtin_type or not type2.is_pyobject or not self.operand1.may_be_none():
return type1
elif self.operand1.is_string_literal:
if type1 is str_type or type1 is bytes_type:
if set(_find_formatting_types(self.operand1.value)) <= _safe_bytes_formats:
return type1
return basestring_type
elif type1 is bytes_type and not type2.is_builtin_type:
return None # RHS might implement '% operator differently in Py3
else:
return basestring_type # either str or unicode, can't tell
return super().infer_builtin_types_operation(type1, type2)

def zero_division_message(self):
Expand Down
32 changes: 32 additions & 0 deletions Cython/Compiler/Tests/TestBuiltin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import builtins
import sys
import unittest

from ..Builtin import (
inferred_method_return_types, find_return_type_of_builtin_method,
builtin_scope,
)


class TestBuiltinReturnTypes(unittest.TestCase):
    """Checks the 'inferred_method_return_types' table against the type lookup."""

    def test_find_return_type_of_builtin_method(self):
        """Every table entry must resolve to the type named in its declaration."""
        # It's enough to test the method existence in a recent Python that likely has them.
        look_up_methods = sys.version_info >= (3,10)

        for type_name, methods in inferred_method_return_types.items():
            # The table's 'unicode' type is looked up as 'str' in the builtins module.
            py_type = getattr(builtins, 'str' if type_name == 'unicode' else type_name)

            for method_name, declared_name in methods.items():
                builtin_type = builtin_scope.lookup(type_name).type
                return_type = find_return_type_of_builtin_method(builtin_type, method_name)

                if not return_type.is_builtin_type:
                    # Non-builtin result types are compared by their declaration code.
                    self.assertEqual(return_type.empty_declaration_code(pyrex=True), declared_name)
                    continue

                # Drop any "[...]" suffix and resolve 'T' to the owning type.
                expected_name = declared_name.partition('[')[0]
                if expected_name == 'T':
                    expected_name = type_name
                self.assertEqual(return_type.name, expected_name)

                if look_up_methods:
                    self.assertTrue(hasattr(py_type, method_name), f"{type_name}.{method_name}")
9 changes: 1 addition & 8 deletions Cython/Compiler/TypeInference.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,14 +544,7 @@ def aggressive_spanning_type(types, might_overflow, scope):
def safe_spanning_type(types, might_overflow, scope):
result_type = simply_type(reduce(find_spanning_type, types))
if result_type.is_pyobject:
# In theory, any specific Python type is always safe to
# infer. However, inferring str can cause some existing code
# to break, since we are also now much more strict about
# coercion from str to char *. See trac #553.
if result_type.name == 'str':
return py_object_type
else:
return result_type
return result_type
elif (result_type is PyrexTypes.c_double_type or
result_type is PyrexTypes.c_float_type):
# Python's float type is just a C double, so it's safe to use
Expand Down
4 changes: 2 additions & 2 deletions tests/run/bytesmethods.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def bytes_join(bytes s, *args):
babab
"""
result = s.join(args)
assert cython.typeof(result) == 'Python object', cython.typeof(result)
assert cython.typeof(result) == 'bytes object', cython.typeof(result)
return result


Expand All @@ -275,7 +275,7 @@ def literal_join(*args):
b|b|b|b
"""
result = b'|'.join(args)
assert cython.typeof(result) == 'Python object', cython.typeof(result)
assert cython.typeof(result) == 'bytes object', cython.typeof(result)
return result

def fromhex(bytes b):
Expand Down
Loading

0 comments on commit 4847429

Please sign in to comment.