Skip to content

Commit

Permalink
Improve Arbitrary Precision translation
Browse files Browse the repository at this point in the history
  • Loading branch information
MiguelMarcelino committed Sep 16, 2022
1 parent 8258b2f commit 687befd
Showing 1 changed file with 29 additions and 27 deletions.
56 changes: 29 additions & 27 deletions pyjl/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from pyclbr import Function
import re
from typing import Any, Dict, cast
from typing import Any, Dict, Union, cast

from py2many.clike import class_for_typename

Expand Down Expand Up @@ -908,58 +908,60 @@ class JuliaArbitraryPrecisionRewriter(ast.NodeTransformer):
def __init__(self) -> None:
super().__init__()
self._use_arbitrary_precision = False
self._arbitrary_precision_vars = set()

def visit_Module(self, node: ast.Module) -> Any:
self._use_arbitrary_precision = getattr(node, "use_arbitrary_precision", False)
self._arbitrary_precision_vars = set()
self.generic_visit(node)
return node

def visit_Name(self, node: ast.Name) -> Any:
if get_id(node) in self._arbitrary_precision_vars:
node.is_arbitrary_precision_var = True
return node

def visit_Assign(self, node: ast.Assign) -> Any:
# self.generic_visit(node)
for t in node.targets:
self.visit(t)
self._generic_assign_visit(node)
self._generic_assign_visit(node, target=node.targets[0])
return node

def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
# self.generic_visit(node)
self.visit(node.target)
self._generic_assign_visit(node)
self._generic_assign_visit(node, target=node.target)
return node

def _generic_assign_visit(self, node):
type_comment = getattr(node, "type_comment", None)
if isinstance(node.value, ast.Constant):
node.value = self.visit_Constant(node.value, type_comment)
else:
if getattr(node, "value", None):
self.visit(node.value)

def visit_Constant(self, node: ast.Constant, type_comment=None):
def _generic_assign_visit(self, node: Union[ast.Assign, ast.AnnAssign], target):
self.generic_visit(node)
ann = getattr(node, "annotation", None)
if ann:
is_int = lambda x: get_id(x) == "int"
is_float = lambda x: get_id(x) == "float"
func_name = "BigInt" if is_int(ann) else "BigFloat"
if (type_comment == "BigInt" or type_comment == "BigFloat" or
(self._use_arbitrary_precision and
(is_int(ann) or is_float(ann)))):
annotation = get_id(getattr(target, "annotation", None))
if annotation:
if (annotation == "BigInt" or annotation == "BigFloat" or
(self._use_arbitrary_precision and
(annotation == "int" or annotation == "float")))\
and not getattr(node.value, "ignore_wrap", None):
self._arbitrary_precision_vars.add(get_id(target))
func_name = "BigInt" if annotation == "int" else "BigFloat"
lineno = getattr(node, "lineno", 0)
col_offset = getattr(node, "col_offset", 0)
return ast.Call(
node.value = ast.Call(
func = ast.Name(id=func_name),
args = [ast.Constant(
value = node.value,
annotation = ann,
scopes = node.scopes)],
args = [node.value],
keywords = [],
lineno = lineno,
col_offset = col_offset,
annotation = ann,
annotation = ast.Name(id=annotation),
scopes = node.scopes)
return node
ast.fix_missing_locations(node.value)

def visit_BinOp(self, node: ast.BinOp) -> Any:
self.generic_visit(node)
node.ignore_wrap = (
getattr(node.left, "is_arbitrary_precision_var", False) or
getattr(node.right, "is_arbitrary_precision_var", False))
return node

###########################################################
############### Removing nested constructs ################
Expand Down

0 comments on commit 687befd

Please sign in to comment.