Skip to content

Commit

Permalink
Minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MiguelMarcelino committed Sep 21, 2022
1 parent af9c2a8 commit d90cf6b
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 28 deletions.
21 changes: 14 additions & 7 deletions py2many/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any, cast, Optional

from py2many.scope import ScopeList
from py2many.tracer import find_node_by_name_and_type, find_node_by_type
from py2many.tracer import find_node_by_name_and_type, find_node_by_type, find_parent_of_type


class InferredAnnAssignRewriter(ast.NodeTransformer):
Expand Down Expand Up @@ -612,11 +612,10 @@ def _generic_loop_visit(self, node):
if len(node.orelse) > 0:
lineno = node.orelse[0].lineno
if_expr = ast.If(
test=ast.Compare(
left=ast.Name(id=self._has_break_var_name),
ops=[ast.NotEq()],
comparators=[ast.Constant(value=True, scopes=node.scopes)],
scopes=node.scopes,
test= ast.UnaryOp(
op = ast.Not(),
operand=ast.Name(id=self._has_break_var_name),
scopes=node.scopes
),
body=[oe for oe in node.orelse],
orelse=[],
Expand All @@ -627,20 +626,28 @@ def _generic_loop_visit(self, node):

def _visit_Scope(self, node) -> Any:
self.generic_visit(node)
parent_loop = find_parent_of_type(ast.For, node.scopes)
is_local = False
if parent_loop and \
getattr(parent_loop, "orelse", None):
# Check if the node is a local assignment
is_local = True
assign = ast.Assign(
targets=[ast.Name(id=self._has_break_var_name)],
value=None,
scopes=node.scopes,
scopes=node.scopes
)
ast.fix_missing_locations(assign)
body = []
for n in node.body:
if hasattr(n, "if_expr"):
assign.value = ast.Constant(value=False, scopes=ScopeList())
assign.local = is_local
body.append(assign)
body.append(n)
body.append(n.if_expr)
elif isinstance(n, ast.Break):
# assign.local = False
for_node = find_node_by_type(ast.For, node.scopes)
if hasattr(for_node, "if_expr"):
assign.value = ast.Constant(value=True, scopes=ScopeList())
Expand Down
2 changes: 1 addition & 1 deletion pyjl/external/modules/ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Callable, Dict, Optional, Tuple, Union

from py2many.ast_helpers import get_id
from pyjl.helpers import get_python_dll_path, pycall_import
from pyjl.helpers import pycall_import

class JuliaExternalModulePlugins():
def visit_load_library(self, node: ast.Call, vargs: list[str], kwargs: list[tuple[str,str]]):
Expand Down
2 changes: 0 additions & 2 deletions pyjl/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Gets range from for loop
import ast
import subprocess
import os
import random
from typing import Optional

Expand Down
21 changes: 6 additions & 15 deletions pyjl/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,21 +196,12 @@ def visit_AugAssign(self, node: ast.AugAssign) -> Any:
right = ast.Constant(value=1)
)
ast.fix_missing_locations(value)
if isinstance(node.target, ast.List) and \
len(node.target.elts) == 1:
repeat_arg = ast.Call(
func = ast.Name(id="fill"),
args = [node.target.elts[0], value],
keywords = [],
scopes = node.scopes
)
else:
repeat_arg = ast.Call(
func = ast.Name(id="repeat"),
args = [node.target, value],
keywords = [],
scopes = node.scopes
)
repeat_arg = ast.Call(
func = ast.Name(id="repeat"),
args = [node.target, value],
keywords = [],
scopes = node.scopes
)
ast.fix_missing_locations(repeat_arg)
call.args.append(node.target)
call.args.append(repeat_arg)
Expand Down
16 changes: 13 additions & 3 deletions pyjl/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,14 +520,20 @@ def visit_BinOp(self, node: ast.BinOp) -> str:
if isinstance(node.op, ast.Mult):
# Cover multiplication between List/Tuple and Int
if isinstance(node.right, ast.Num) or is_num(right_jl_ann):
if ((isinstance(node.left, ast.List) or is_list(left_jl_ann)) or
if isinstance(node.left, ast.List) and \
len(node.left.elts) == 1:
return f"fill({self.visit(node.left.elts[0])},{right})"
elif ((isinstance(node.left, ast.List) or is_list(left_jl_ann)) or
(isinstance(node.left, ast.Str) or left_jl_ann == "String")):
return f"repeat({left},{right})"
elif isinstance(node.left, ast.Tuple) or is_tuple(left_jl_ann):
return f"repeat([{left}...],{right})"

if isinstance(node.left, ast.Num) or is_num(left_jl_ann):
if ((isinstance(node.right, ast.List) or is_list(right_jl_ann)) or
if isinstance(node.right, ast.List) and \
len(node.right.elts) == 1:
return f"fill({self.visit(node.right.elts[0])},{left})"
elif ((isinstance(node.right, ast.List) or is_list(right_jl_ann)) or
(isinstance(node.right, ast.Str) or right_jl_ann == "String")):
return f"repeat({right},{left})"
elif isinstance(node.right, ast.Tuple) or is_tuple(right_jl_ann):
Expand Down Expand Up @@ -1095,7 +1101,11 @@ def visit_Assign(self, node: ast.Assign) -> str:
# Optimization to use global constants
if getattr(node, "use_constant", None):
return f"const {targets[0]} {op} {value}"


# Support for local variables
if getattr(node, "local", None):
return f"local {'='.join(targets)} {op} {value}"

return f"{'='.join(targets)} {op} {value}"

def visit_Delete(self, node: ast.Delete) -> str:
Expand Down

0 comments on commit d90cf6b

Please sign in to comment.