Skip to content

Commit

Permalink
InferredAnnAssignRewriter: Dont redeclare type
Browse files Browse the repository at this point in the history
Also dont add annotation to subscript assignments.

Closes py2many#360
  • Loading branch information
jayvdb committed Jul 4, 2021
1 parent d0d30b9 commit da9845b
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 12 deletions.
21 changes: 18 additions & 3 deletions py2many/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

class InferredAnnAssignRewriter(ast.NodeTransformer):
def visit_Assign(self, node):
target = node.targets[0]
target = node.targets[0] # Assumes all targets have same annotation
if isinstance(target, ast.Subscript):
return node
annotation = getattr(target, "annotation", False)
if not annotation:
return node
Expand All @@ -23,8 +25,21 @@ def visit_Assign(self, node):

assigns = []
for assign_target in node.targets:
print(assign_target.__class__)
# assert False
definition = node.scopes.find(get_id(assign_target))
if definition is not assign_target:
previous_type = get_inferred_type(definition)
if get_id(previous_type) == get_id(annotation):
if len(node.targets) == 1:
return node
else:
new_node = ast.Assign(
targets=[assign_target],
value=node.value,
lineno=node.lineno,
col_offset=col_offset,
)
assigns.append(new_node)
continue
new_node = ast.AnnAssign(
target=assign_target,
value=node.value,
Expand Down
4 changes: 2 additions & 2 deletions tests/expected/binit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def bisect_right(data: List[int], item: int) -> int:
while low < high:
middle: int = int((low + high) / 2)
if item < data[middle]:
high: int = middle
high = middle
else:
low: int = middle + 1
low = middle + 1
return low


Expand Down
4 changes: 2 additions & 2 deletions tests/expected/bubble_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def bubble_sort(seq: List[int]) -> List[int]:
if seq[n] < seq[n - 1]:
if True:
(__tmp1, __tmp2) = (seq[n], seq[n - 1])
seq[n - 1]: int = __tmp1
seq[n]: int = __tmp2
seq[n - 1] = __tmp1
seq[n] = __tmp2
return seq


Expand Down
8 changes: 4 additions & 4 deletions tests/expected/comb_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ def comb_sort(seq: List[int]) -> List[int]:
swap: bool = True
while gap > 1 or swap:
gap: int = max(1, floor(gap / 1.25))
swap: bool = False
swap = False
for i in range(len(seq) - gap):
if seq[i] > seq[i + gap]:
if True:
(__tmp1, __tmp2) = (seq[i + gap], seq[i])
seq[i]: int = __tmp1
seq[i + gap]: int = __tmp2
swap: bool = True
seq[i] = __tmp1
seq[i + gap] = __tmp2
swap = True
return seq


Expand Down
2 changes: 1 addition & 1 deletion tests/expected/rect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def is_square(self) -> bool:
def show():
r: Rectangle = Rectangle(height=1, length=1)
assert r.is_square()
r: Rectangle = Rectangle(height=1, length=2)
r = Rectangle(height=1, length=2)
assert not r.is_square()
print(r.height)
print(r.length)
Expand Down

0 comments on commit da9845b

Please sign in to comment.