Skip to content

Commit

Permalink
pygo: Finish moving dispatch to plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
jayvdb committed Jun 27, 2021
1 parent a4593a2 commit a5d0704
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 96 deletions.
59 changes: 22 additions & 37 deletions pygo/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,15 @@ def visit_ap_dataclass(self, cls):
return cls

def visit_range(self, node, vargs: List[str]) -> str:
self._usings.add('iter "github.com/hgfischer/go-iter"')
if len(node.args) == 1:
return "(0..{})".format(vargs[0])
return f"iter.NewIntSeq(iter.Start(0), iter.Stop({vargs[0]})).All()"
elif len(node.args) == 2:
return "({}..{})".format(vargs[0], vargs[1])
return (
f"iter.NewIntSeq(iter.Start({vargs[0]}), iter.Stop({vargs[1]})).All()"
)
elif len(node.args) == 3:
return "({}..{}).step_by({})".format(vargs[0], vargs[1], vargs[2])
return f"iter.NewIntSeq(iter.Start({vargs[0]}), iter.Stop({vargs[1]}), iter.Step({vargs[2]})).All()"

raise Exception(
"encountered range() call with unknown parameters: range({})".format(vargs)
Expand All @@ -103,57 +106,39 @@ def visit_range(self, node, vargs: List[str]) -> str:
def visit_print(self, node, vargs: List[str]) -> str:
placeholders = []
for n in node.args:
placeholders.append("{}")
return 'println!("{0}",{1});'.format(" ".join(placeholders), ", ".join(vargs))
placeholders.append("%v")
self._usings.add('"fmt"')
placeholders_str = " ".join(placeholders)
vargs_str = ", ".join(vargs)
return f'fmt.Printf("{placeholders_str}\\n",{vargs_str})'

def visit_min_max(self, node, vargs, is_max: bool) -> str:
self._usings.add("std::cmp")
min_max = "max" if is_max else "min"
self._typename_from_annotation(node.args[0])
if hasattr(node.args[0], "container_type"):
node.result_type = True
return f"{vargs[0]}.iter().{min_max}()"
else:
all_vargs = ", ".join(vargs)
return f"cmp::{min_max}({all_vargs})"
min_max = "math.Max" if is_max else "math.Min"
self._usings.add('"math"')
vargs_str = ", ".join(vargs)
return f"{min_max}({vargs_str})"

@staticmethod
def visit_cast(node, vargs, cast_to: str) -> str:
return f"{vargs[0]} as {cast_to}"

@staticmethod
def visit_asyncio_run(node, vargs) -> str:
return f"block_on({vargs[0]})"
def visit_floor(self, node, vargs) -> str:
self._usings.add('"math"')
return f"math.Floor({vargs[0]})"


# small one liners are inlined here as lambdas
SMALL_DISPATCH_MAP = {
"str": lambda n, vargs: f"&{vargs[0]}.to_string()",
"len": lambda n, vargs: f"{vargs[0]}.len()",
"enumerate": lambda n, vargs: f"{vargs[0]}.iter().enumerate()",
"sum": lambda n, vargs: f"{vargs[0]}.iter().sum()",
"int": functools.partial(GoTranspilerPlugins.visit_cast, cast_to="i32"),
"str": lambda n, vargs: f"String({vargs[0]})",
"bool": lambda n, vargs: f"({vargs[0]} != 0)",
"float": functools.partial(GoTranspilerPlugins.visit_cast, cast_to="f64"),
# as usize below is a hack to pass comb_sort.rs. Need a better solution
"floor": lambda n, vargs: f"{vargs[0]}.floor() as usize",
"reversed": lambda n, vargs: f"{vargs[0]}.iter().rev()",
"map": lambda n, vargs: f"{vargs[1]}.iter().map({vargs[0]})",
"filter": lambda n, vargs: f"{vargs[1]}.into_iter().filter({vargs[0]})",
"list": lambda n, vargs: f"{vargs[0]}.collect::<Vec<_>>()",
"asyncio.run": GoTranspilerPlugins.visit_asyncio_run,
}

SMALL_USINGS_MAP = {
"asyncio.run": "futures::executor::block_on",
}
SMALL_USINGS_MAP = {}

DISPATCH_MAP = {
"max": functools.partial(GoTranspilerPlugins.visit_min_max, is_max=True),
"min": functools.partial(GoTranspilerPlugins.visit_min_max, is_min=True),
"range": GoTranspilerPlugins.visit_range,
"range_": GoTranspilerPlugins.visit_range,
"xrange": GoTranspilerPlugins.visit_range,
"print": GoTranspilerPlugins.visit_print,
"floor": GoTranspilerPlugins.visit_floor,
}

MODULE_DISPATCH_TABLE = {}
Expand Down
61 changes: 2 additions & 59 deletions pygo/transpiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import ast
import functools
import textwrap

from typing import List

from .clike import CLikeTranspiler
from .inference import get_inferred_go_type
from .plugins import (
Expand All @@ -21,8 +22,6 @@
from py2many.rewriters import capitalize_first, rename, camel_case
from py2many.tracer import is_list, defined_before, is_class_or_module, is_enum

from typing import Optional, List


class GoMethodCallRewriter(ast.NodeTransformer):
def visit_Call(self, node):
Expand Down Expand Up @@ -253,62 +252,6 @@ def visit_Attribute(self, node):

return f"{value_id}.{attr}"

def visit_range(self, node, vargs: List[str]) -> str:
self._usings.add('iter "github.com/hgfischer/go-iter"')
if len(node.args) == 1:
return f"iter.NewIntSeq(iter.Start(0), iter.Stop({vargs[0]})).All()"
elif len(node.args) == 2:
return (
f"iter.NewIntSeq(iter.Start({vargs[0]}), iter.Stop({vargs[1]})).All()"
)
elif len(node.args) == 3:
return f"iter.NewIntSeq(iter.Start({vargs[0]}), iter.Stop({vargs[1]}), iter.Step({vargs[2]})).All()"

raise Exception(
"encountered range() call with unknown parameters: range({})".format(vargs)
)

def _visit_print(self, node, vargs: List[str]) -> str:
placeholders = []
for n in node.args:
placeholders.append("%v")
self._usings.add('"fmt"')
placeholders_str = " ".join(placeholders)
vargs_str = ", ".join(vargs)
return f'fmt.Printf("{placeholders_str}\\n",{vargs_str})'

def _dispatch(self, node, fname: str, vargs: List[str]) -> Optional[str]:
dispatch_map = {
"range_": self.visit_range,
"xrange": self.visit_range,
"print": self._visit_print,
}

if fname in dispatch_map:
return dispatch_map[fname](node, vargs)

def visit_min_max(is_max: bool) -> str:
min_max = "math.Max" if is_max else "math.Min"
self._usings.add('"math"')
vargs_str = ", ".join(vargs)
return f"{min_max}({vargs_str})"

def visit_floor():
self._usings.add('"math"')
return f"math.Floor({vargs[0]})"

# small one liners are inlined here as lambdas
small_dispatch_map = {
"str": lambda: f"String({vargs[0]})",
"bool": lambda: f"({vargs[0]} != 0)",
"max": functools.partial(visit_min_max, is_max=True),
"min": functools.partial(visit_min_max, is_min=True),
"floor": visit_floor,
}
if fname in small_dispatch_map:
return small_dispatch_map[fname]()
return None

def _visit_struct_literal(self, node, fname: str, fndef: ast.ClassDef):
vargs = [] # visited args
if not hasattr(fndef, "declarations"):
Expand Down

0 comments on commit a5d0704

Please sign in to comment.