Skip to content

Commit

Permalink
Updated Inference and external module import
Browse files Browse the repository at this point in the history
  • Loading branch information
MiguelMarcelino committed Jul 21, 2022
1 parent 1860558 commit f5dd625
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 85 deletions.
32 changes: 22 additions & 10 deletions py2many/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,7 @@ def core_transformers(tree, trees, args):
detect_nesting_levels(tree)
add_annotation_flags(tree)
add_imports(tree)
infer_meta = (
infer_types_typpete(tree) if args and args.typpete else infer_types(tree)
)
return tree, infer_meta
return tree


def _transpile(
Expand All @@ -142,6 +139,7 @@ def _transpile(
target language
"""
transpiler = settings.transpiler
inference = settings.inference
rewriters = settings.rewriters
transformers = settings.transformers
post_rewriters = settings.post_rewriters
Expand Down Expand Up @@ -196,6 +194,7 @@ def _transpile(
transformers,
post_rewriters,
optimization_rewriters,
inference,
config_handler,
args,
)
Expand Down Expand Up @@ -229,6 +228,7 @@ def _transpile_one(
transformers,
post_rewriters,
optimization_rewriters,
inference,
config_handler,
args,
):
Expand All @@ -242,7 +242,15 @@ def _transpile_one(
for rewriter in rewriters:
tree = rewriter.visit(tree)
# Language independent core transformers
tree, infer_meta = core_transformers(tree, trees, args)
tree = core_transformers(tree, trees, args)
# Type inference
if args and args.typpete:
infer_meta = infer_types_typpete(tree)
else:
if inference:
infer_meta = inference(tree)
else:
infer_meta = infer_types(tree)
# Language specific transformers
for tx in transformers:
tx(tree)
Expand All @@ -254,7 +262,7 @@ def _transpile_one(
tree = opt_rewriter.visit(tree)

# Rerun core transformers
tree, infer_meta = core_transformers(tree, trees, args)
tree = core_transformers(tree, trees, args)
out = []

transpile_output = transpiler.visit(tree)
Expand Down Expand Up @@ -406,10 +414,11 @@ def rust_settings(args, env=os.environ):
["rustfmt", "--edition=2018"],
None,
rewriters=[RustNoneCompareRewriter()],
transformers=[functools.partial(infer_rust_types, extension=args.extension)],
transformers=[],
post_rewriters=[RustLoopIndexRewriter(), RustStringJoinRewriter()],
create_project=["cargo", "new", "--bin"],
project_subdir="src",
inference = functools.partial(infer_rust_types, extension=args.extension),
)


Expand Down Expand Up @@ -444,7 +453,6 @@ def julia_settings(args, env=os.environ):
rewriters=[JuliaMainRewriter()],
transformers=[
parse_decorators,
infer_julia_types,
analyse_variable_scope,
optimize_loop_ranges,
find_ordered_collections,
Expand All @@ -468,6 +476,7 @@ def julia_settings(args, env=os.environ):
JuliaArbitraryPrecisionRewriter(),
],
optimization_rewriters=[AlgebraicSimplification(), OperationOptimizer()],
inference = infer_julia_types
)


Expand Down Expand Up @@ -526,11 +535,12 @@ def go_settings(args, env=os.environ):
["gofmt", "-w"],
None,
[GoNoneCompareRewriter(), GoVisibilityRewriter(), GoIfExpRewriter()],
[infer_go_types],
[],
[GoMethodCallRewriter(), GoPropagateTypeAnnotation()],
linter=(
["revive", "--config", str(revive_config)] if revive_config else ["revive"]
),
inference = infer_go_types
)


Expand All @@ -547,7 +557,8 @@ def vlang_settings(args, env=os.environ):
["v", *vfmt_args],
None,
[VNoneCompareRewriter(), VDictRewriter(), VComprehensionRewriter()],
[infer_v_types],
[],
inference = infer_v_types
)


Expand All @@ -562,6 +573,7 @@ def smt_settings(args, env=os.environ):
None,
[],
[infer_smt_types],
inference = infer_smt_types
)


Expand Down
2 changes: 0 additions & 2 deletions py2many/clike.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
c_uint64 as u64,
)

from py2many.external_modules import import_external_modules

ilong = i64
ulong = u64
isize = i64
Expand Down
75 changes: 39 additions & 36 deletions py2many/external_modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
import os
import imp
from pydoc import ispath
from types import ModuleType

MOD_DIR = f"external{os.sep}modules"
Expand Down Expand Up @@ -31,42 +33,43 @@
}


def import_external_modules(self, lang):
"""Updates all the dispatch maps to account for external modules"""
external_mods: list[tuple[str, str]] = _get_external_modules(lang)
for mod_path in external_mods:
split_name: tuple[str, str] = os.path.split(mod_path)
mod_name = split_name[1]
ext_mod: ModuleType = imp.load_source(mod_name, mod_path)
for attr_name, map_name in MOD_NAMES:
if attr_name in self.__dict__ and map_name in ext_mod.__dict__:
obj = ext_mod.__dict__[map_name]
curr_val = getattr(self, attr_name, None)
# Update value in default containers
if isinstance(curr_val, dict):
curr_val |= obj
elif isinstance(curr_val, list):
curr_val.extend(obj)
elif isinstance(curr_val, set):
curr_val.update(obj)
@dataclass
class ExternalBase():
"""Base class to add external modules"""


def _get_external_modules(lang) -> list[tuple[str, str]]:
p_lang = lang
if lang in LANG_MAP:
p_lang = LANG_MAP[lang]
else:
raise Exception("Language not supported")
# Get files
path = f"{os.getcwd()}{os.sep}{p_lang}{os.sep}{MOD_DIR}"
return [
f"{path}{os.sep}{file}"
for file in os.listdir(path)
if os.path.isfile(os.path.join(path, file)) and file != "__init__.py"
]
def import_external_modules(self, lang):
"""Updates all the dispatch maps to account for external modules"""
external_mods: list[tuple[str, str]] = self._get_external_modules(lang)
if external_mods:
for mod_path in external_mods:
split_name: tuple[str, str] = os.path.split(mod_path)
mod_name = split_name[1]
ext_mod: ModuleType = imp.load_source(mod_name, mod_path)
for attr_name, map_name in MOD_NAMES:
if attr_name in self.__dict__ and map_name in ext_mod.__dict__:
obj = ext_mod.__dict__[map_name]
curr_val = getattr(self, attr_name, None)
# Update value in default containers
if isinstance(curr_val, dict):
curr_val |= obj
elif isinstance(curr_val, list):
curr_val.extend(obj)
elif isinstance(curr_val, set):
curr_val.update(obj)


class ExternalWrapper():
"""Wrapper to add external modules"""
def __init__(self, lang):
import_external_modules(self, lang)
def _get_external_modules(self, lang) -> list[tuple[str, str]]:
p_lang = lang
if lang in LANG_MAP:
p_lang = LANG_MAP[lang]
else:
raise Exception("Language not supported")
# Get files
path = f"{os.getcwd()}{os.sep}{p_lang}{os.sep}{MOD_DIR}"
if not os.path.isdir(path):
return None
return [
f"{path}{os.sep}{file}"
for file in os.listdir(path)
if os.path.isfile(os.path.join(path, file)) and file != "__init__.py"
]
1 change: 1 addition & 0 deletions py2many/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class LanguageSettings:
transformers: List[Callable] = field(default_factory=list)
post_rewriters: List[ast.NodeVisitor] = field(default_factory=list)
optimization_rewriters: List[ast.NodeVisitor] = field(default_factory=list)
inference: Optional[Callable] = field(default_factory=list)
linter: Optional[List[str]] = None
# Create a language specific project structure
create_project: Optional[List[str]] = None
Expand Down
2 changes: 1 addition & 1 deletion pygo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_inferred_go_type(node):


# Copy pasta from rust. Double check for correctness
class InferGoTypesTransformer(ast.NodeTransformer):
class InferGoTypesTransformer(InferTypesTransformer):
"""Implements go type inference logic as opposed to python type inference logic"""

FIXED_WIDTH_INTS = InferTypesTransformer.FIXED_WIDTH_INTS
Expand Down
8 changes: 4 additions & 4 deletions pyjl/clike.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging

from py2many.clike import CLikeTranspiler as CommonCLikeTranspiler, class_for_typename
from py2many.external_modules import import_external_modules
from py2many.external_modules import ExternalBase
from py2many.tracer import find_node_by_type
from pyjl.helpers import get_ann_repr
from pyjl.juliaAst import JuliaNodeVisitor
Expand Down Expand Up @@ -149,7 +149,7 @@ def jl_symbol(node):
symbol_type = type(node)
return jl_symbols[symbol_type]

class CLikeTranspiler(CommonCLikeTranspiler, JuliaNodeVisitor):
class CLikeTranspiler(CommonCLikeTranspiler, JuliaNodeVisitor, ExternalBase):
def __init__(self):
super().__init__()
self._type_map = JULIA_TYPE_MAP
Expand All @@ -167,8 +167,8 @@ def __init__(self):
self._use_modules = None
self._external_type_map = {}
self._module_dispatch_table = MODULE_DISPATCH_TABLE
#
import_external_modules(self, "Julia")
# Get external module features
self.import_external_modules("Julia")

def usings(self):
usings = sorted(list(set(self._usings)))
Expand Down
36 changes: 8 additions & 28 deletions pyjl/inference.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import ast
from ctypes import c_int64
from typing import Any
from py2many.external_modules import import_external_modules
from py2many.external_modules import ExternalBase

from py2many.inference import InferTypesTransformer
from py2many.inference import InferMeta, InferTypesTransformer
from py2many.analysis import get_id
from py2many.exceptions import AstIncompatibleAssign, AstUnrecognisedBinOp
from py2many.clike import class_for_typename
from pyjl.clike import CLikeTranspiler
from pyjl.helpers import get_ann_repr
from pyjl.global_vars import NONE_TYPE
from pyjl.global_vars import DEFAULT_TYPE

def infer_julia_types(node, extension=False):
visitor = InferJuliaTypesTransformer()
visitor.visit(node)
return InferMeta(visitor.has_fixed_width_ints)

class InferJuliaTypesTransformer(ast.NodeTransformer):
class InferJuliaTypesTransformer(InferTypesTransformer, ExternalBase):
"""
Implements Julia type inference logic
"""
Expand All @@ -32,8 +32,8 @@ def __init__(self):
self._clike = CLikeTranspiler()
self._imported_names = None
self._func_type_map = InferTypesTransformer.FUNC_TYPE_MAP
#
import_external_modules(self, "Julia")
# Get external module features
self.import_external_modules("Julia")

def visit_Module(self, node: ast.Module) -> Any:
self._imported_names = node.imported_names
Expand Down Expand Up @@ -133,7 +133,7 @@ def _find_annotated_assign(self, node):
return None

def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
self.generic_visit(node)
super().visit(node)
self._verify_annotation(node, node.annotation, node.target, inferred=False)
return node

Expand Down Expand Up @@ -219,27 +219,7 @@ def visit_BinOp(self, node):

if left_id is not None and right_id is not None and (left_id, right_id, type(node.op)) in ILLEGAL_COMBINATIONS:
raise AstUnrecognisedBinOp(left_id, right_id, node)
return node

# def visit_Call(self, node: ast.Call) -> Any:
# # TODO: This is just to keep inference language-independent for now
# self.generic_visit(node)
# fname = get_id(node.func)
# if (func := class_for_typename(fname, None, locals=self._imported_names)) \
# in self._func_type_map:
# InferTypesTransformer._annotate(node, self._func_type_map[func](self, node, node.args))
# else:
# # Use annotation
# ann = getattr(node.func, "annotation", None)
# func_name = ast.unparse(ann) if ann else None
# if isinstance(node.func, ast.Attribute):
# ann = getattr(node.func.value, "annotation", None)
# if ann:
# func_name = f"{ast.unparse(ann)}.{node.func.attr}"
# if (func := class_for_typename(func_name, None, locals=self._imported_names)) \
# in self._func_type_map:
# InferTypesTransformer._annotate(node, self._func_type_map[func](self, node, node.args))
# return node
return super().visit(node)

######################################################
################# Inference Methods ##################
Expand Down
2 changes: 1 addition & 1 deletion pykt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_inferred_kotlin_type(node):


# Copy pasta from rust. Double check for correctness
class InferKotlinTypesTransformer(ast.NodeTransformer):
class InferKotlinTypesTransformer(InferTypesTransformer):
"""Implements kotlin type inference logic as opposed to python type inference logic"""

FIXED_WIDTH_INTS = InferTypesTransformer.FIXED_WIDTH_INTS
Expand Down
2 changes: 1 addition & 1 deletion pynim/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_inferred_nim_type(node):


# Copy pasta from rust. Double check for correctness
class InferNimTypesTransformer(ast.NodeTransformer):
class InferNimTypesTransformer(InferTypesTransformer):
"""Implements nim type inference logic as opposed to python type inference logic"""

FIXED_WIDTH_INTS = InferTypesTransformer.FIXED_WIDTH_INTS
Expand Down
2 changes: 1 addition & 1 deletion pyrs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
RUST_RANK_TO_TYPE = {v: k for k, v in RUST_WIDTH_RANK.items()}


class InferRustTypesTransformer(ast.NodeTransformer):
class InferRustTypesTransformer(InferTypesTransformer):
"""Implements rust type inference logic as opposed to python type inference logic"""

FIXED_WIDTH_INTS = InferTypesTransformer.FIXED_WIDTH_INTS
Expand Down
2 changes: 1 addition & 1 deletion pysmt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_inferred_smt_type(node):


# Copy pasta from rust. Double check for correctness
class InferSmtTypesTransformer(ast.NodeTransformer):
class InferSmtTypesTransformer(InferTypesTransformer):
"""Implements smt type inference logic as opposed to python type inference logic"""

FIXED_WIDTH_INTS = InferTypesTransformer.FIXED_WIDTH_INTS
Expand Down

0 comments on commit f5dd625

Please sign in to comment.