Skip to content

Commit

Permalink
cross module: toposort modules before transpiling
Browse files Browse the repository at this point in the history
This avoids referencing modules on which scopes haven't been
computed.
  • Loading branch information
adsharma committed Jun 26, 2021
1 parent 967d9fe commit b71d35e
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 20 deletions.
31 changes: 20 additions & 11 deletions py2many/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .mutability_transformer import detect_mutable_vars
from .nesting_transformer import detect_nesting_levels
from .scope import add_scope_context
from .toposort_modules import toposort

from pycpp.transpiler import CppTranspiler, CppListComparisonRewriter
from pyrs.inference import infer_rust_types
Expand Down Expand Up @@ -107,7 +108,8 @@ def _transpile(
tree = ast.parse(source)
tree.__file__ = filename
tree_list.append(tree)
trees = tuple(tree_list)
trees = toposort(tree_list)
topo_filenames = [t.__file__ for t in trees]
language = transpiler.NAME
generic_rewriters = [
ComplexDestructuringRewriter(language),
Expand All @@ -125,15 +127,15 @@ def _transpile(
]
rewriters = generic_rewriters + rewriters
post_rewriters = generic_post_rewriters + post_rewriters
outputs = []
outputs = {}
successful = []
for filename, tree in zip(filenames, trees):
for filename, tree in zip(topo_filenames, trees):
try:
output = _transpile_one(
trees, tree, transpiler, rewriters, transformers, post_rewriters
)
successful.append(filename)
outputs.append(output)
outputs[filename] = output
except Exception as e:
import traceback

Expand All @@ -142,8 +144,10 @@ def _transpile(
print(f"{filename}:{e.lineno}:{e.col_offset}: {formatted_lines[-1]}")
else:
print(f"{filename}: {formatted_lines[-1]}")
outputs.append("FAILED")
return outputs, successful
outputs[filename] = "FAILED"
# return output in the same order as input
output_list = [outputs[f] for f in filenames]
return output_list, successful


def _transpile_one(trees, tree, transpiler, rewriters, transformers, post_rewriters):
Expand Down Expand Up @@ -414,12 +418,14 @@ def _format_one(settings, output_path, env=None):
FileSet = Set[pathlib.Path]


def _process_many(settings, filenames, outdir, env=None) -> Tuple[FileSet, FileSet]:
def _process_many(
settings, basedir, filenames, outdir, env=None
) -> Tuple[FileSet, FileSet]:
"""Transpile and reformat many files."""

source_data = []
for filename in filenames:
with open(filename) as f:
with open(basedir / filename) as f:
source_data.append(f.read())

outputs, successful = _transpile(
Expand All @@ -429,7 +435,8 @@ def _process_many(settings, filenames, outdir, env=None) -> Tuple[FileSet, FileS
)

output_paths = [
_get_output_path(filename, settings.ext, outdir) for filename in filenames
_get_output_path(basedir / filename, settings.ext, outdir)
for filename in filenames
]
for filename, output, output_path in zip(filenames, outputs, output_paths):
with open(output_path, "w") as f:
Expand Down Expand Up @@ -461,9 +468,11 @@ def _process_dir(settings, source, outdir, env=None, _suppress_exceptions=True):
target_path = outdir / relative_path
target_dir = target_path.parent
os.makedirs(target_dir, exist_ok=True)
input_paths.append(path)
input_paths.append(relative_path)

successful, format_errors = _process_many(settings, input_paths, outdir, env=env)
successful, format_errors = _process_many(
settings, source, input_paths, outdir, env=env
)
failures = set(input_paths) - set(successful)

print("\nFinished!")
Expand Down
49 changes: 49 additions & 0 deletions py2many/toposort_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import ast
from pathlib import Path
from toposort import toposort_flatten
from collections import defaultdict
from typing import Tuple


def module_for_path(path: Path) -> str:
# strip out .py at the end
return ".".join(path.parts)[:-3]


class ImportDependencyVisitor(ast.NodeVisitor):
def __init__(self, modules):
self.deps = defaultdict(set)
self._modules = modules

def visit_Module(self, node):
self._current = module_for_path(node.__file__)
self.generic_visit(node)

def visit_ImportFrom(self, node):
if node.module in self._modules:
self.deps[self._current].add(node.module)
self.generic_visit(node)

def visit_Import(self, node):
names = [n.name for n in node.names]
for n in names:
if n in self._modules:
self.deps[self._current].add(n)
self.generic_visit(node)


def get_dependencies(trees):
modules = {module_for_path(node.__file__) for node in trees}
visitor = ImportDependencyVisitor(modules)
for t in trees:
visitor.visit(t)
for m in modules:
if m not in visitor.deps:
visitor.deps[m] = set()
return visitor.deps


def toposort(trees) -> Tuple:
deps = get_dependencies(trees)
tree_dict = {module_for_path(node.__file__): node for node in trees}
return tuple([tree_dict[t] for t in toposort_flatten(deps, sort=True)])
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__version__ = "0.2.1"

install_requires = []
install_requires = ["toposort"]
setup_requires = []
tests_require = ["pytest", "unittest-expander", "argparse_dataclass"]

Expand Down
13 changes: 5 additions & 8 deletions tests/test_transpile_self.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def test_kotlin_recursive(self):
"rewriters.py",
"scope.py",
"tracer.py",
"toposort_modules.py",
},
)

Expand Down Expand Up @@ -170,7 +171,6 @@ def test_nim_recursive(self):
"analysis.py",
"annotation_transformer.py",
"clike.py",
"context.py",
"declaration_extractor.py",
"exceptions.py",
"mutability_transformer.py",
Expand All @@ -191,17 +191,14 @@ def test_cpp_recursive(self):
OUT_DIR,
_suppress_exceptions=suppress_exceptions,
)
assert len(successful) == 10
assert set(failures) == {
transpiler_module / "plugins.py",
transpiler_module / "__init__.py",
}
assert len(successful) == 11
assert set(failures) == {Path("plugins.py")}

successful, format_errors, failures = _process_dir(
settings, PY2MANY_MODULE, OUT_DIR, _suppress_exceptions=suppress_exceptions
)
assert len(successful) == 16
assert len(failures) == 1
assert len(successful) == 15
assert len(failures) == 2

def test_julia_recursive(self):
settings = self.SETTINGS["julia"]
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ deps =
unittest-expander
pytest-cov
astpretty
toposort
git+https://github.com/mivade/argparse_dataclass/
git+https://github.com/adsharma/adt/
commands =
Expand Down

0 comments on commit b71d35e

Please sign in to comment.