Skip to content

Commit

Permalink
py2many: cross module type inference
Browse files Browse the repository at this point in the history
Passes the test case below. bar1() defined in another
file is correctly inferred as returning an int.

Tested via: py2many --rust=1 /tmp/testdir

containing the two files in the paste attached to the issue

Related: py2many#158
  • Loading branch information
adsharma committed Jun 26, 2021
1 parent 24ceddf commit 16df9e5
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 17 deletions.
13 changes: 7 additions & 6 deletions py2many/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
CWD = pathlib.Path.cwd()


def core_transformers(tree):
add_variable_context(tree)
def core_transformers(tree, trees):
add_variable_context(tree, trees)
add_scope_context(tree)
add_list_calls(tree)
detect_mutable_vars(tree)
Expand Down Expand Up @@ -102,11 +102,12 @@ def _transpile(
rewriters = settings.rewriters
transformers = settings.transformers
post_rewriters = settings.post_rewriters
trees = []
tree_list = []
for filename, source in zip(filenames, sources):
tree = ast.parse(source)
tree.__file__ = filename
trees.append(tree)
tree_list.append(tree)
trees = tuple(tree_list)
language = transpiler.NAME
generic_rewriters = [
ComplexDestructuringRewriter(language),
Expand Down Expand Up @@ -153,15 +154,15 @@ def _transpile_one(trees, tree, transpiler, rewriters, transformers, post_rewrit
for rewriter in rewriters:
tree = rewriter.visit(tree)
# Language independent core transformers
tree, infer_meta = core_transformers(tree)
tree, infer_meta = core_transformers(tree, trees)
# Language specific transformers
for tx in transformers:
tx(tree)
# Language specific rewriters that depend on previous steps
for rewriter in post_rewriters:
tree = rewriter.visit(tree)
# Rerun core transformers
tree, infer_meta = core_transformers(tree)
tree, infer_meta = core_transformers(tree, trees)
out = []
code = transpiler.visit(tree) + "\n"
headers = transpiler.headers(infer_meta)
Expand Down
20 changes: 18 additions & 2 deletions py2many/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ def add_list_calls(node):
return ListCallTransformer().visit(node)


def add_variable_context(node):
def add_variable_context(node, trees):
"""Provide context to Module and Function Def"""
return VariableTransformer().visit(node)
return VariableTransformer(trees).visit(node)


class ListCallTransformer(ast.NodeTransformer):
Expand Down Expand Up @@ -50,6 +50,13 @@ def is_list_addition(self, node):
class VariableTransformer(ast.NodeTransformer, ScopeMixin):
"""Adds all defined variables to scope block"""

def __init__(self, trees):
super().__init__()
if len(trees) == 1:
self._trees = {}
else:
self._trees = {t.__file__.stem: t for t in trees}

def visit_FunctionDef(self, node):
node.vars = []
# So function signatures are accessible even after they're
Expand All @@ -74,6 +81,15 @@ def visit_Import(self, node):
name.imported_from = node
return node

def visit_ImportFrom(self, node):
module_path = node.module
names = [n.name for n in node.names]
if module_path in self._trees:
m = self._trees[module_path]
resolved_names = [m.scopes.find(n) for n in names]
node.scopes[-1].vars += resolved_names
return node

def visit_If(self, node):
node.vars = []
self.visit(node.test)
Expand Down
2 changes: 1 addition & 1 deletion pycpp/tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def parse(*args):
source = ast.parse("\n".join(args))
add_scope_context(source)
add_variable_context(source)
add_variable_context(source, (source,))
return source


Expand Down
2 changes: 1 addition & 1 deletion pycpp/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def parse(*args):
source = ast.parse("\n".join(args))
add_scope_context(source)
add_variable_context(source)
add_variable_context(source, (source,))
return source


Expand Down
2 changes: 1 addition & 1 deletion pycpp/tests/test_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ def test_scope_added(self):
class TestScopeList:
def test_find_returns_most_upper_definition(self):
source = parse("x = 1", "def foo():", " x = 2")
add_variable_context(source)
add_variable_context(source, (source,))
definition = source.scopes.find("x")
assert definition.lineno == 1
2 changes: 1 addition & 1 deletion pycpp/tests/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def parse(*args):
source = ast.parse("\n".join(args))
add_variable_context(source)
add_variable_context(source, (source,))
add_scope_context(source)
return source

Expand Down
2 changes: 1 addition & 1 deletion pycpp/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def transpile(source, headers=False, testing=False):
tree = ast.parse(source)
rewriter = PythonMainRewriter("cpp")
tree = rewriter.visit(tree)
add_variable_context(tree)
add_variable_context(tree, (tree,))
add_scope_context(tree)
add_list_calls(tree)
add_imports(tree)
Expand Down
2 changes: 1 addition & 1 deletion pyrs/tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def parse(*args):
source = ast.parse("\n".join(args))
add_scope_context(source)
add_variable_context(source)
add_variable_context(source, (source,))
return source


Expand Down
2 changes: 1 addition & 1 deletion pyrs/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def parse(*args):
source = ast.parse("\n".join(args))
add_scope_context(source)
add_variable_context(source)
add_variable_context(source, (source,))
return source


Expand Down
2 changes: 1 addition & 1 deletion pyrs/tests/test_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ def test_scope_added(self):
class TestScopeList:
def test_find_returns_most_upper_definition(self):
source = parse("x = 1", "def foo():", " x = 2")
add_variable_context(source)
add_variable_context(source, (source,))
definition = source.scopes.find("x")
assert definition.lineno == 1
2 changes: 1 addition & 1 deletion pyrs/tests/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def parse(*args):
source = ast.parse("\n".join(args))
add_variable_context(source)
add_variable_context(source, (source,))
add_scope_context(source)
return source

Expand Down

0 comments on commit 16df9e5

Please sign in to comment.