Skip to content

Commit

Permalink
Improved translation of ctypes
Browse files Browse the repository at this point in the history
  • Loading branch information
MiguelMarcelino committed Sep 27, 2022
1 parent d90cf6b commit b934a4e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 33 deletions.
5 changes: 5 additions & 0 deletions pyjl/external/modules/ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def visit_wintypes(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[st
self._usings.add("WinTypes")
return f"WinTypes({', '.join(vargs)})"

def visit_functype(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,str]]):
return f"func -> @cfunction($func, {vargs[0]}, ({', '.join(vargs[1:])}))"

# Hacks
def visit_Libdl(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,str]]):
self._usings.add("Libdl")
Expand All @@ -93,6 +96,7 @@ def visit_Libdl(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,s
ctypes.byref: (JuliaExternalModulePlugins.visit_byref, True),
ctypes.sizeof: (lambda self, node, vargs, kwargs: f"sizeof({self._map_type(vargs[0])})"
if vargs else "sizeof", True),
ctypes.CFUNCTYPE: (JuliaExternalModulePlugins.visit_functype, True),
# Using PythonCall
ctypes.POINTER: (JuliaExternalModulePlugins.visit_pointer, True),
ctypes.create_unicode_buffer: (JuliaExternalModulePlugins.visit_create_unicode_buffer, True),
Expand Down Expand Up @@ -131,6 +135,7 @@ def visit_Libdl(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,s
# ctypes.GetLastError: (lambda self, node, vargs, kwargs: "Base.Libc.GetLastError", True),
ctypes.FormatError: (lambda self, node, vargs, kwargs: f"Base.Libc.FormatMessage({', '.join(vargs)})", True),
wintypes: (JuliaExternalModulePlugins.visit_wintypes, True),
ctypes.WINFUNCTYPE: (JuliaExternalModulePlugins.visit_functype, True),
}
FUNC_DISPATCH_TABLE: Dict[FuncType, Tuple[Callable, bool]] = GENERIC_DISPATCH_TABLE | WIN_DISPATCH_TABLE
else:
Expand Down
45 changes: 12 additions & 33 deletions pyjl/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,9 +1986,9 @@ class JuliaCtypesRewriter(ast.NodeTransformer):
ctypes.c_ssize_t, ctypes.c_char_p, ctypes.c_wchar_p, ctypes.c_void_p,
]

SPECIAL_CALLS = {
"ctypes.WINFUNCTYPE"
}
# SPECIAL_CALLS = {
# "ctypes.WINFUNCTYPE"
# }

NONE_TYPES = {"c_void_p", "HANDLE", "HMODULE"}

Expand All @@ -2012,7 +2012,6 @@ def __init__(self) -> None:
# Mapps assignment target id's to ctypes call types
self._assign_ctypes_funcs = {}
# Mapps special assignment target ids to their respective values
self._special_assignments: dict[str, ast.Call] = {}

def visit_Module(self, node: ast.Module) -> Any:
self._imported_names = getattr(node, "imported_names", None)
Expand Down Expand Up @@ -2040,15 +2039,6 @@ def _ctypes_assign_visit(self, node, target) -> Any:
self._assign_ctypes_funcs[get_id(target)] = self._ctypes_func_types
self._ctypes_func_types = None

# Check for any special calls to replace
if isinstance(node.value, ast.Call) and \
get_id(node.value.func) in self.SPECIAL_CALLS:
if isinstance(node, ast.Assign):
self._special_assignments[get_id(node.targets[0])] = node.value
elif isinstance(node, ast.AnnAssign):
self._special_assignments[get_id(node.target)] = node.value
return None

# Check if the target is what we are looking for
admissible_args = re.match(r".*argtypes$|.*restype$|.*errcheck$", get_id(target)) \
if get_id(target) else False
Expand Down Expand Up @@ -2103,23 +2093,6 @@ def visit_Call(self, node: ast.Call) -> Any:
node.func.is_call_func = True
self.generic_visit(node)

# Check for any special calls to replace
if get_id(node.func) in self._special_assignments:
func_name = node.args[0]
if isinstance(func_name, ast.Name):
func_name.id = f"${get_id(func_name)}"
call_node = self._special_assignments[get_id(node.func)]
restype = call_node.args[0]
argtypes = ast.Tuple(elts = call_node.args[1:])
cfunc = ast.Call(
func = ast.Name(id = "@cfunction"),
args = [func_name, restype, argtypes],
keywords = [],
scopes = node.scopes
)
ast.fix_missing_locations(cfunc)
return cfunc

func = node.func
mod_name = ""
# Retrieve module name
Expand Down Expand Up @@ -2220,6 +2193,8 @@ def visit_Call(self, node: ast.Call) -> Any:
ptr_node = self._make_ptr("Cvoid")
replace_cond = lambda x: get_id(getattr(x, "annotation", None)) in \
{"PyObject", "ctypes._FuncPointer", "_FuncPointer", "ctypes.POINTER"}
# Save old argument types
old_argtypes = argtypes.elts
argtypes.elts = list(map(lambda x: ptr_node if replace_cond(x) else x, argtypes.elts))
# Set all as annotation
for arg in argtypes.elts:
Expand Down Expand Up @@ -2262,16 +2237,20 @@ def visit_Call(self, node: ast.Call) -> Any:
# If it is a factory, build a lamdba expression
var_list: list[str] = [f"a{i}" for i in range(len(argtypes.elts))]
args = ast.arguments(args=[ast.arg(arg=var) for var in var_list], defaults=[])
for var, typ in zip(var_list, argtypes.elts):
annotations = list(map(lambda x: getattr(x, "annotation", None), old_argtypes))
for var, typ, ann in zip(var_list, argtypes.elts, annotations):
mapped_arg = None
if get_id(typ) in self.WRAP_TYPES:
mapped_arg = self.WRAP_TYPES[get_id(typ)](var)
# elif get_id(ann) in {"ctypes._FuncPointer", "_FuncPointer"}:
# mapped_arg = f"Ptr[Cvoid]({var})"
if mapped_arg:
arg_node = cast(ast.Expr, create_ast_node(mapped_arg)).value
fill_attributes(arg_node, node.scopes, no_rewrite=True,
preserve_keyword=True, is_annotation=True)
ccall.args.append(arg_node)
else:
ccall.args.append(ast.Name(id=var))
# ccall.args.extend([ast.Name(id=var) for var in var_list])
if errcheck_call:
# Pass the arguments to the errcheck function
errcheck_call.args.append(
Expand All @@ -2296,7 +2275,7 @@ def visit_Call(self, node: ast.Call) -> Any:
# Pass the arguments to the errcheck function
errcheck_call.args.append(ast.Tuple(elts=node.args))
# TODO: ccall with error check not yet supported with
# non-factory expressionss
# non-factory expressions
return ccall
else:
return ccall
Expand Down

0 comments on commit b934a4e

Please sign in to comment.