Skip to content

Commit

Permalink
Preliminary fix for MiguelMarcelino#36
Browse files Browse the repository at this point in the history
  • Loading branch information
MiguelMarcelino committed Nov 5, 2022
1 parent 926fefe commit a779c84
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
5 changes: 3 additions & 2 deletions pyjl/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,8 +1704,9 @@ def _build_assignments(self, targets):
for (target, ann) in targets:
new_target = ast.Name(id=get_id(target))
default = get_default_val(target, ann)
assign = ast.Assign(
targets=[new_target],
assign = ast.AnnAssign(
target = new_target,
annotation = ann,
value = default,
scopes = getattr(target, "scopes", ScopeList()))
ast.fix_missing_locations(assign)
Expand Down
19 changes: 16 additions & 3 deletions pyjl/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,12 @@ def visit_While(self, node) -> str:
return "\n".join(buf)

def visit_BinOp(self, node: ast.BinOp) -> str:
# Attempts to find node annotations
left_jl_ann: str = self._typename_from_type_node(
getattr(node.left, "annotation", self._default_type),
self._get_annotation(node.left),
default=self._default_type)
right_jl_ann: str = self._typename_from_type_node(
getattr(node.right, "annotation", self._default_type),
self._get_annotation(node.right),
default=self._default_type)

is_list = lambda x: re.match(r"^Array|^Vector", x)
Expand Down Expand Up @@ -570,6 +571,18 @@ def visit_BinOp(self, node: ast.BinOp) -> str:
# By default, call super
return super().visit_BinOp(node)

def _get_annotation(self, node):
node_id = get_id(node)
n_inst = node.scopes.find(node_id)
if hasattr(node, "annotation"):
return node.annotation
elif hasattr(n_inst, "annotation"):
return n_inst.annotation
elif hasattr(n_inst, "assigned_from") and \
hasattr(n_inst.assigned_from, "annotation"):
return n_inst.assigned_from.annotation
return None

def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
if isinstance(node.operand, (ast.Call, ast.Num, ast.Name, ast.Constant)):
# Shortcut if parenthesis are not needed
Expand Down Expand Up @@ -1036,7 +1049,7 @@ def visit_Assert(self, node) -> str:
return "@assert({0})".format(self.visit(node.test))

def visit_AnnAssign(self, node: ast.AnnAssign) -> str:
if (val := self._generic_assign_visit(node, node.targets[0])) == "":
if (val := self._generic_assign_visit(node, node.target)) == "":
return val

target = self.visit(node.target)
Expand Down
12 changes: 12 additions & 0 deletions tests/cases/scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

def test():
# The definition of a depends on the for loop's range
# Therefore, the propagated static type should not be
# int, but Optional[int] (Union[int, None])
# PyJL's fix_scope_bounds flag should propagate the type
# and generate the expected code
for i in range(1):
a = 2
print(a*"aa")

test()

0 comments on commit a779c84

Please sign in to comment.