Skip to content

Commit

Permalink
Merge branch 'main' of github.com:MiguelMarcelino/py2many into main
Browse files Browse the repository at this point in the history
  • Loading branch information
MiguelMarcelino committed Oct 6, 2022
2 parents 1274739 + 3ff654d commit 2f7f60a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 60 deletions.
1 change: 1 addition & 0 deletions py2many/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class InferTypesTransformer(ast.NodeTransformer):
bytes.translate: lambda self, node, vargs, kwargs: "bytes",
bytearray.translate: lambda self, node, vargs, kwargs: "bytearray",
argparse.ArgumentParser: lambda self, node, vargs, kwargs: "argparse.ArgumentParser",
isinstance: lambda self, node, vargs, kwargs: "bool",
zip: FuncTypeDispatch.visit_zip,
max: FuncTypeDispatch.visit_min_max,
min: FuncTypeDispatch.visit_min_max,
Expand Down
12 changes: 8 additions & 4 deletions pyjl/external/modules/ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def visit_wintypes(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[st
self._usings.add("WinTypes")
return f"WinTypes({', '.join(vargs)})"

def visit_functype(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,str]]):
return f"func -> @cfunction($func, {vargs[0]}, ({', '.join(vargs[1:])}))"

# Hacks
def visit_Libdl(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,str]]):
self._usings.add("Libdl")
Expand All @@ -93,6 +96,7 @@ def visit_Libdl(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,s
ctypes.byref: (JuliaExternalModulePlugins.visit_byref, True),
ctypes.sizeof: (lambda self, node, vargs, kwargs: f"sizeof({self._map_type(vargs[0])})"
if vargs else "sizeof", True),
ctypes.CFUNCTYPE: (JuliaExternalModulePlugins.visit_functype, True),
# Using PythonCall
ctypes.POINTER: (JuliaExternalModulePlugins.visit_pointer, True),
ctypes.create_unicode_buffer: (JuliaExternalModulePlugins.visit_create_unicode_buffer, True),
Expand All @@ -105,7 +109,7 @@ def visit_Libdl(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,s

GENERIC_SMALL_DISPATCH_MAP = {
"ctypes.memset": lambda node, vargs, kwargs: f"ccall(\"memset\", Ptr{{Cvoid}}, (Ptr{{Cvoid}}, Cint, Csize_t), {vargs[0]}, {vargs[1]}, {vargs[2]})",
"LPCWSTR": lambda node, vargs, kwargs: f"isa({vargs[0]}, String) ? Cwstring(pointer(transcode(Cwchar_t, {vargs[0]}))) : Cwstring(Ptr{{Cwchar_t}}({vargs[0]}))"
"LPCWSTR": lambda node, vargs, kwargs: f"isa({vargs[0]}, String) ? Cwstring(pointer_from_objref({vargs[0]})) : Cwstring(Ptr{{Cwchar_t}}({vargs[0]}))"
}

GENERIC_EXTERNAL_TYPE_MAP = {
Expand Down Expand Up @@ -171,15 +175,14 @@ def visit_Libdl(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,s
else "Base.Libc.GetLastError()",
}

SMALL_DISPATCH_MAP = GENERIC_SMALL_DISPATCH_MAP | WIN_SMALL_DISPATCH_MAP

WIN_DISPATCH_TABLE = {
ctypes.WinDLL: (JuliaExternalModulePlugins.visit_load_library, True),
ctypes.windll.LoadLibrary: (JuliaExternalModulePlugins.visit_load_library, True),
# ctypes.WinDLL: (JuliaExternalModulePlugins.visit_windll, True),
# ctypes.GetLastError: (lambda self, node, vargs, kwargs: "Base.Libc.GetLastError", True),
ctypes.FormatError: (lambda self, node, vargs, kwargs: f"Base.Libc.FormatMessage({', '.join(vargs)})", True),
wintypes: (JuliaExternalModulePlugins.visit_wintypes, True),
ctypes.WINFUNCTYPE: (JuliaExternalModulePlugins.visit_functype, True),
}

WIN_EXTERNAL_TYPE_MAP = {
Expand All @@ -194,8 +197,9 @@ def visit_Libdl(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,s
}

FUNC_DISPATCH_TABLE: Dict[FuncType, Tuple[Callable, bool]] = GENERIC_DISPATCH_TABLE | WIN_DISPATCH_TABLE
SMALL_DISPATCH_MAP = GENERIC_SMALL_DISPATCH_MAP | WIN_SMALL_DISPATCH_MAP
EXTERNAL_TYPE_MAP = WIN_EXTERNAL_TYPE_MAP | GENERIC_EXTERNAL_TYPE_MAP
FUNC_TYPE_MAP = WIN_FUNC_TYPE_MAP + GENERIC_FUNC_TYPE_MAP
FUNC_TYPE_MAP = WIN_FUNC_TYPE_MAP | GENERIC_FUNC_TYPE_MAP
else:
FUNC_DISPATCH_TABLE: Dict[FuncType, Tuple[Callable, bool]] = GENERIC_DISPATCH_TABLE
EXTERNAL_TYPE_MAP = GENERIC_EXTERNAL_TYPE_MAP
Expand Down
3 changes: 2 additions & 1 deletion pyjl/external/modules/shapely.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from pyjl.helpers import pycall_import

try:
import shapely
from shapely.geometry.base import BaseGeometry
from shapely.geometry import Point
from shapely.ops import transform
except:
except ImportError:
shapely = None

from typing import Callable, Dict, Optional, Tuple, Union
Expand Down
91 changes: 39 additions & 52 deletions pyjl/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,26 +423,29 @@ def _generic_test_visit(self, node):

annotation = getattr(node.test, "annotation", None)
ann_id = get_ann_repr(annotation, sep=SEP)
if ann_id:
if ann_id == "int" or ann_id == "float":
node.test = self._build_compare(node.test,
[ast.NotEq()], [ast.Constant(value=0)])
elif re.match(r"^list|^List", ann_id):
# Compare with empty list
node.test = self._build_compare(node.test,
[ast.IsNot()], [ast.List(elts=[])])
elif re.match(r"^tuple|^Tuple", ann_id):
# Compare with empty tuple
node.test = self._build_compare(node.test,
[ast.IsNot()], [ast.Tuple(elts=[])])
elif re.match(r"^Optional", ann_id):
# Compare with type None
node.test = self._build_compare(node.test,
[ast.IsNot()], [ast.Constant(value=None)])
if not isinstance(node.test, ast.Compare) and \
not isinstance(node.test, ast.UnaryOp):
if ann_id:
if ann_id != "bool":
if ann_id == "int" or ann_id == "float":
node.test = self._build_compare(node.test,
[ast.NotEq()], [ast.Constant(value=0)])
elif re.match(r"^list|^List", ann_id):
# Compare with empty list
node.test = self._build_compare(node.test,
[ast.IsNot()], [ast.List(elts=[])])
elif re.match(r"^tuple|^Tuple", ann_id):
# Compare with empty tuple
node.test = self._build_compare(node.test,
[ast.IsNot()], [ast.Tuple(elts=[])])
elif re.match(r"^Optional", ann_id):
# Compare with type None
node.test = self._build_compare(node.test,
[ast.IsNot()], [ast.Constant(value=None)])
else:
node.test = self._build_runtime_comparison(node)
else:
node.test = self._build_runtime_comparison(node)
else:
node.test = self._build_runtime_comparison(node)

def _build_compare(self, node, ops, comp_values):
for comp_value in comp_values:
Expand Down Expand Up @@ -484,7 +487,12 @@ def _build_runtime_comparison(self, node):
self._build_compare(node.test, [ast.NotEq()], [ast.List(elts=[])])]),
ast.BoolOp(
op = ast.And(),
values = [self._build_compare(node.test, [ast.Is()], [ast.Constant(value=None)])])
values = [self._build_compare(node.test, [ast.Is()], [ast.Constant(value=None)])]),
ast.BoolOp(
op = ast.And(),
values = [
instance_check([node.test, ast.Name(id="bool")]),
node.test]),
]
)
ast.fix_missing_locations(node.test)
Expand Down Expand Up @@ -2027,9 +2035,9 @@ class JuliaCtypesRewriter(ast.NodeTransformer):
ctypes.c_ssize_t, ctypes.c_char_p, ctypes.c_wchar_p, ctypes.c_void_p,
]

SPECIAL_CALLS = {
"ctypes.WINFUNCTYPE"
}
# SPECIAL_CALLS = {
# "ctypes.WINFUNCTYPE"
# }

NONE_TYPES = {"c_void_p", "HANDLE", "HMODULE"}

Expand All @@ -2053,7 +2061,6 @@ def __init__(self) -> None:
# Mapps assignment target id's to ctypes call types
self._assign_ctypes_funcs = {}
# Mapps special assignment target ids to their respective values
self._special_assignments: dict[str, ast.Call] = {}

def visit_Module(self, node: ast.Module) -> Any:
self._imported_names = getattr(node, "imported_names", None)
Expand Down Expand Up @@ -2081,15 +2088,6 @@ def _ctypes_assign_visit(self, node, target) -> Any:
self._assign_ctypes_funcs[get_id(target)] = self._ctypes_func_types
self._ctypes_func_types = None

# Check for any special calls to replace
if isinstance(node.value, ast.Call) and \
get_id(node.value.func) in self.SPECIAL_CALLS:
if isinstance(node, ast.Assign):
self._special_assignments[get_id(node.targets[0])] = node.value
elif isinstance(node, ast.AnnAssign):
self._special_assignments[get_id(node.target)] = node.value
return None

# Check if the target is what we are looking for
admissible_args = re.match(r".*argtypes$|.*restype$|.*errcheck$", get_id(target)) \
if get_id(target) else False
Expand Down Expand Up @@ -2144,23 +2142,6 @@ def visit_Call(self, node: ast.Call) -> Any:
node.func.is_call_func = True
self.generic_visit(node)

# Check for any special calls to replace
if get_id(node.func) in self._special_assignments:
func_name = node.args[0]
if isinstance(func_name, ast.Name):
func_name.id = f"${get_id(func_name)}"
call_node = self._special_assignments[get_id(node.func)]
restype = call_node.args[0]
argtypes = ast.Tuple(elts = call_node.args[1:])
cfunc = ast.Call(
func = ast.Name(id = "@cfunction"),
args = [func_name, restype, argtypes],
keywords = [],
scopes = node.scopes
)
ast.fix_missing_locations(cfunc)
return cfunc

func = node.func
mod_name = ""
# Retrieve module name
Expand Down Expand Up @@ -2261,6 +2242,8 @@ def visit_Call(self, node: ast.Call) -> Any:
ptr_node = self._make_ptr("Cvoid")
replace_cond = lambda x: get_id(getattr(x, "annotation", None)) in \
{"PyObject", "ctypes._FuncPointer", "_FuncPointer", "ctypes.POINTER"}
# Save old argument types
old_argtypes = argtypes.elts
argtypes.elts = list(map(lambda x: ptr_node if replace_cond(x) else x, argtypes.elts))
# Set all as annotation
for arg in argtypes.elts:
Expand Down Expand Up @@ -2303,16 +2286,20 @@ def visit_Call(self, node: ast.Call) -> Any:
# If it is a factory, build a lamdba expression
var_list: list[str] = [f"a{i}" for i in range(len(argtypes.elts))]
args = ast.arguments(args=[ast.arg(arg=var) for var in var_list], defaults=[])
for var, typ in zip(var_list, argtypes.elts):
annotations = list(map(lambda x: getattr(x, "annotation", None), old_argtypes))
for var, typ, ann in zip(var_list, argtypes.elts, annotations):
mapped_arg = None
if get_id(typ) in self.WRAP_TYPES:
mapped_arg = self.WRAP_TYPES[get_id(typ)](var)
# elif get_id(ann) in {"ctypes._FuncPointer", "_FuncPointer"}:
# mapped_arg = f"Ptr[Cvoid]({var})"
if mapped_arg:
arg_node = cast(ast.Expr, create_ast_node(mapped_arg)).value
fill_attributes(arg_node, node.scopes, no_rewrite=True,
preserve_keyword=True, is_annotation=True)
ccall.args.append(arg_node)
else:
ccall.args.append(ast.Name(id=var))
# ccall.args.extend([ast.Name(id=var) for var in var_list])
if errcheck_call:
# Pass the arguments to the errcheck function
errcheck_call.args.append(
Expand All @@ -2337,7 +2324,7 @@ def visit_Call(self, node: ast.Call) -> Any:
# Pass the arguments to the errcheck function
errcheck_call.args.append(ast.Tuple(elts=node.args))
# TODO: ccall with error check not yet supported with
# non-factory expressionss
# non-factory expressions
return ccall
else:
return ccall
Expand Down
8 changes: 5 additions & 3 deletions pyjl/setup_files/setup.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
; oop_nested_funcs=True
; optimize_loop_ranges=True
;
; use_arbitrary_precision=True
;
;##############################
; For win32ctypes
loop_scope_warning=True
fix_scope_bounds=True
use_modules=True
use_resumables=True
remove_nested=True
remove_nested_resumables=True
;
; use_arbitrary_precision=True
remove_nested_resumables=True

0 comments on commit 2f7f60a

Please sign in to comment.