Skip to content

Commit

Permalink
pykt: Finish moving dispatch to plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
jayvdb committed Jun 27, 2021
1 parent a5d0704 commit 316c227
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 107 deletions.
61 changes: 27 additions & 34 deletions pykt/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import ast
import functools
import re
import textwrap

from tempfile import NamedTemporaryFile
Expand Down Expand Up @@ -90,70 +91,62 @@ def visit_ap_dataclass(self, cls):

def visit_range(self, node, vargs: List[str]) -> str:
if len(node.args) == 1:
return "(0..{})".format(vargs[0])
return "(0..{}-1)".format(vargs[0])
elif len(node.args) == 2:
return "({}..{})".format(vargs[0], vargs[1])
return "({}..{}-1)".format(vargs[0], vargs[1])
elif len(node.args) == 3:
return "({}..{}).step_by({})".format(vargs[0], vargs[1], vargs[2])
return "({}..{}-1 step {})".format(vargs[0], vargs[1], vargs[2])

raise Exception(
"encountered range() call with unknown parameters: range({})".format(vargs)
)

def visit_print(self, node, vargs: List[str]) -> str:
placeholders = []
for n in node.args:
placeholders.append("{}")
return 'println!("{0}",{1});'.format(" ".join(placeholders), ", ".join(vargs))
def _format(arg):
if arg.isdigit():
return arg
if re.match(r"'.*'", arg) or re.match(r'".*"', arg):
return arg[1:-1]
else:
return f"${arg}"

vargs_str = " ".join([f"{_format(arg)}" for arg in vargs])
return f'println("{vargs_str}")'

def visit_min_max(self, node, vargs, is_max: bool) -> str:
self._usings.add("std::cmp")
min_max = "max" if is_max else "min"
self._usings.add(f"kotlin.math.{min_max}")
self._typename_from_annotation(node.args[0])
if hasattr(node.args[0], "container_type"):
node.result_type = True
return f"{vargs[0]}.iter().{min_max}()"
return f"maxOf({vargs[0]})"
else:
all_vargs = ", ".join(vargs)
return f"cmp::{min_max}({all_vargs})"
return f"{min_max}({all_vargs})"

@staticmethod
def visit_cast(node, vargs, cast_to: str) -> str:
return f"{vargs[0]} as {cast_to}"

@staticmethod
def visit_asyncio_run(node, vargs) -> str:
return f"block_on({vargs[0]})"
def visit_floor(self, node, vargs) -> str:
self._usings.add("kotlin.math.floor")
return f"floor({vargs[0]}).toInt()"


# small one liners are inlined here as lambdas
SMALL_DISPATCH_MAP = {
"str": lambda n, vargs: f"&{vargs[0]}.to_string()",
"len": lambda n, vargs: f"{vargs[0]}.len()",
"enumerate": lambda n, vargs: f"{vargs[0]}.iter().enumerate()",
"sum": lambda n, vargs: f"{vargs[0]}.iter().sum()",
"int": functools.partial(KotlinTranspilerPlugins.visit_cast, cast_to="i32"),
"str": lambda n, vargs: f"{vargs[0]}.toString()",
# TODO: strings use .length
"len": lambda n, vargs: f"{vargs[0]}.size",
"int": lambda n, vargs: f"{vargs[0]}.toInt()",
"bool": lambda n, vargs: f"({vargs[0]} != 0)",
"float": functools.partial(KotlinTranspilerPlugins.visit_cast, cast_to="f64"),
# as usize below is a hack to pass comb_sort.rs. Need a better solution
"floor": lambda n, vargs: f"{vargs[0]}.floor() as usize",
"reversed": lambda n, vargs: f"{vargs[0]}.iter().rev()",
"map": lambda n, vargs: f"{vargs[1]}.iter().map({vargs[0]})",
"filter": lambda n, vargs: f"{vargs[1]}.into_iter().filter({vargs[0]})",
"list": lambda n, vargs: f"{vargs[0]}.collect::<Vec<_>>()",
"asyncio.run": KotlinTranspilerPlugins.visit_asyncio_run,
"reversed": lambda n, vargs: f"{vargs[0]}.reversed()",
}

SMALL_USINGS_MAP = {
"asyncio.run": "futures::executor::block_on",
}
SMALL_USINGS_MAP = {}

DISPATCH_MAP = {
"max": functools.partial(KotlinTranspilerPlugins.visit_min_max, is_max=True),
"min": functools.partial(KotlinTranspilerPlugins.visit_min_max, is_min=True),
"range": KotlinTranspilerPlugins.visit_range,
"xrange": KotlinTranspilerPlugins.visit_range,
"print": KotlinTranspilerPlugins.visit_print,
"floor": KotlinTranspilerPlugins.visit_floor,
}

MODULE_DISPATCH_TABLE = {}
Expand Down
81 changes: 8 additions & 73 deletions pykt/transpiler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast
import functools
import re

from typing import List, Tuple

from .clike import CLikeTranspiler
from .inference import get_inferred_kotlin_type
Expand All @@ -19,25 +19,19 @@
from py2many.declaration_extractor import DeclarationExtractor
from py2many.tracer import is_list, defined_before, is_class_or_module, is_self_arg

from typing import Optional, List, Tuple


class KotlinPrintRewriter(ast.NodeTransformer):
def __init__(self):
super().__init__()
self._temp = 0
self._dispatch_map = DISPATCH_MAP
self._small_dispatch_map = SMALL_DISPATCH_MAP
self._small_usings_map = SMALL_USINGS_MAP
self._func_dispatch_table = FUNC_DISPATCH_TABLE
self._attr_dispatch_table = ATTR_DISPATCH_TABLE

def _get_temp(self):
self._temp += 1
return f"__tmp{self._temp}"

def visit_Call(self, node):
fname = self.visit(node.func)

if (
get_id(fname) == "print"
and len(node.args) == 1
Expand Down Expand Up @@ -100,6 +94,11 @@ def __init__(self):
super().__init__()
self._default_type = ""
self._container_type_map = self.CONTAINER_TYPE_MAP
self._dispatch_map = DISPATCH_MAP
self._small_dispatch_map = SMALL_DISPATCH_MAP
self._small_usings_map = SMALL_USINGS_MAP
self._func_dispatch_table = FUNC_DISPATCH_TABLE
self._attr_dispatch_table = ATTR_DISPATCH_TABLE
self._main_signature_arg_names = ["argv"]

def usings(self):
Expand Down Expand Up @@ -203,70 +202,6 @@ def visit_Attribute(self, node):

return f"{value_id}.{attr}"

def visit_range(self, node, vargs: List[str]) -> str:
if len(node.args) == 1:
return "(0..{}-1)".format(vargs[0])
elif len(node.args) == 2:
return "({}..{}-1)".format(vargs[0], vargs[1])
elif len(node.args) == 3:
return "({}..{}-1 step {})".format(vargs[0], vargs[1], vargs[2])

raise Exception(
"encountered range() call with unknown parameters: range({})".format(vargs)
)

def _visit_print(self, node, vargs: List[str]) -> str:
def _format(arg):
if arg.isdigit():
return arg
if re.match(r"'.*'", arg) or re.match(r'".*"', arg):
return arg[1:-1]
else:
return f"${arg}"

vargs_str = " ".join([f"{_format(arg)}" for arg in vargs])
return f'println("{vargs_str}")'

def _dispatch(self, node, fname: str, vargs: List[str]) -> Optional[str]:
dispatch_map = {
"range": self.visit_range,
"xrange": self.visit_range,
"print": self._visit_print,
}

if fname in dispatch_map:
return dispatch_map[fname](node, vargs)

def visit_min_max(is_max: bool) -> str:
min_max = "max" if is_max else "min"
self._usings.add(f"kotlin.math.{min_max}")
self._typename_from_annotation(node.args[0])
if hasattr(node.args[0], "container_type"):
return f"maxOf({vargs[0]})"
else:
all_vargs = ", ".join(vargs)
return f"{min_max}({all_vargs})"

def visit_floor():
self._usings.add("kotlin.math.floor")
return f"floor({vargs[0]}).toInt()"

# small one liners are inlined here as lambdas
small_dispatch_map = {
"int": lambda: f"{vargs[0]}.toInt()",
"bool": lambda: f"({vargs[0]} != 0)",
"str": lambda: f"{vargs[0]}.toString()",
# TODO: strings use .length
"len": lambda: f"{vargs[0]}.size",
"max": functools.partial(visit_min_max, is_max=True),
"min": functools.partial(visit_min_max, is_min=True),
"floor": visit_floor,
"reversed": lambda: f"{vargs[0]}.reversed()",
}
if fname in small_dispatch_map:
return small_dispatch_map[fname]()
return None

def visit_Call(self, node):
fname = self.visit(node.func)
vargs = []
Expand Down

0 comments on commit 316c227

Please sign in to comment.