From 316c227d24924044a0f3f2109471deeda0ab57fb Mon Sep 17 00:00:00 2001 From: John Vandenberg Date: Mon, 28 Jun 2021 07:53:54 +0800 Subject: [PATCH] pykt: Finish moving dispatch to plugins --- pykt/plugins.py | 61 ++++++++++++++++------------------ pykt/transpiler.py | 81 +++++----------------------------------------- 2 files changed, 35 insertions(+), 107 deletions(-) diff --git a/pykt/plugins.py b/pykt/plugins.py index 77bf6456d..7bd47b36b 100644 --- a/pykt/plugins.py +++ b/pykt/plugins.py @@ -2,6 +2,7 @@ import os import ast import functools +import re import textwrap from tempfile import NamedTemporaryFile @@ -90,63 +91,54 @@ def visit_ap_dataclass(self, cls): def visit_range(self, node, vargs: List[str]) -> str: if len(node.args) == 1: - return "(0..{})".format(vargs[0]) + return "(0..{}-1)".format(vargs[0]) elif len(node.args) == 2: - return "({}..{})".format(vargs[0], vargs[1]) + return "({}..{}-1)".format(vargs[0], vargs[1]) elif len(node.args) == 3: - return "({}..{}).step_by({})".format(vargs[0], vargs[1], vargs[2]) + return "({}..{}-1 step {})".format(vargs[0], vargs[1], vargs[2]) raise Exception( "encountered range() call with unknown parameters: range({})".format(vargs) ) def visit_print(self, node, vargs: List[str]) -> str: - placeholders = [] - for n in node.args: - placeholders.append("{}") - return 'println!("{0}",{1});'.format(" ".join(placeholders), ", ".join(vargs)) + def _format(arg): + if arg.isdigit(): + return arg + if re.match(r"'.*'", arg) or re.match(r'".*"', arg): + return arg[1:-1] + else: + return f"${arg}" + + vargs_str = " ".join([f"{_format(arg)}" for arg in vargs]) + return f'println("{vargs_str}")' def visit_min_max(self, node, vargs, is_max: bool) -> str: - self._usings.add("std::cmp") min_max = "max" if is_max else "min" + self._usings.add(f"kotlin.math.{min_max}") self._typename_from_annotation(node.args[0]) if hasattr(node.args[0], "container_type"): - node.result_type = True - return f"{vargs[0]}.iter().{min_max}()" + return f"maxOf({vargs[0]})" else: all_vargs = ", ".join(vargs) - return f"cmp::{min_max}({all_vargs})" + return f"{min_max}({all_vargs})" - @staticmethod - def visit_cast(node, vargs, cast_to: str) -> str: - return f"{vargs[0]} as {cast_to}" - - @staticmethod - def visit_asyncio_run(node, vargs) -> str: - return f"block_on({vargs[0]})" + def visit_floor(self, node, vargs) -> str: + self._usings.add("kotlin.math.floor") + return f"floor({vargs[0]}).toInt()" # small one liners are inlined here as lambdas SMALL_DISPATCH_MAP = { - "str": lambda n, vargs: f"&{vargs[0]}.to_string()", - "len": lambda n, vargs: f"{vargs[0]}.len()", - "enumerate": lambda n, vargs: f"{vargs[0]}.iter().enumerate()", - "sum": lambda n, vargs: f"{vargs[0]}.iter().sum()", - "int": functools.partial(KotlinTranspilerPlugins.visit_cast, cast_to="i32"), + "str": lambda n, vargs: f"{vargs[0]}.toString()", + # TODO: strings use .length + "len": lambda n, vargs: f"{vargs[0]}.size", + "int": lambda n, vargs: f"{vargs[0]}.toInt()", "bool": lambda n, vargs: f"({vargs[0]} != 0)", - "float": functools.partial(KotlinTranspilerPlugins.visit_cast, cast_to="f64"), - # as usize below is a hack to pass comb_sort.rs. Need a better solution - "floor": lambda n, vargs: f"{vargs[0]}.floor() as usize", - "reversed": lambda n, vargs: f"{vargs[0]}.iter().rev()", - "map": lambda n, vargs: f"{vargs[1]}.iter().map({vargs[0]})", - "filter": lambda n, vargs: f"{vargs[1]}.into_iter().filter({vargs[0]})", - "list": lambda n, vargs: f"{vargs[0]}.collect::>()", - "asyncio.run": KotlinTranspilerPlugins.visit_asyncio_run, + "reversed": lambda n, vargs: f"{vargs[0]}.reversed()", } -SMALL_USINGS_MAP = { - "asyncio.run": "futures::executor::block_on", -} +SMALL_USINGS_MAP = {} DISPATCH_MAP = { "max": functools.partial(KotlinTranspilerPlugins.visit_min_max, is_max=True), @@ -154,6 +146,7 @@ def visit_asyncio_run(node, vargs) -> str: "range": KotlinTranspilerPlugins.visit_range, "xrange": KotlinTranspilerPlugins.visit_range, "print": KotlinTranspilerPlugins.visit_print, + "floor": KotlinTranspilerPlugins.visit_floor, } MODULE_DISPATCH_TABLE = {} diff --git a/pykt/transpiler.py b/pykt/transpiler.py index 18a065629..b3de46c02 100644 --- a/pykt/transpiler.py +++ b/pykt/transpiler.py @@ -1,6 +1,6 @@ import ast -import functools -import re + +from typing import List, Tuple from .clike import CLikeTranspiler from .inference import get_inferred_kotlin_type @@ -19,18 +19,11 @@ from py2many.declaration_extractor import DeclarationExtractor from py2many.tracer import is_list, defined_before, is_class_or_module, is_self_arg -from typing import Optional, List, Tuple - class KotlinPrintRewriter(ast.NodeTransformer): def __init__(self): super().__init__() self._temp = 0 - self._dispatch_map = DISPATCH_MAP - self._small_dispatch_map = SMALL_DISPATCH_MAP - self._small_usings_map = SMALL_USINGS_MAP - self._func_dispatch_table = FUNC_DISPATCH_TABLE - self._attr_dispatch_table = ATTR_DISPATCH_TABLE def _get_temp(self): self._temp += 1 @@ -38,6 +31,7 @@ def _get_temp(self): def visit_Call(self, node): fname = self.visit(node.func) + if ( get_id(fname) == "print" and len(node.args) == 1 @@ -100,6 +94,11 @@ def __init__(self): super().__init__() self._default_type = "" self._container_type_map = self.CONTAINER_TYPE_MAP + self._dispatch_map = DISPATCH_MAP + self._small_dispatch_map = SMALL_DISPATCH_MAP + self._small_usings_map = SMALL_USINGS_MAP + self._func_dispatch_table = FUNC_DISPATCH_TABLE + self._attr_dispatch_table = ATTR_DISPATCH_TABLE self._main_signature_arg_names = ["argv"] def usings(self): @@ -203,70 +202,6 @@ def visit_Attribute(self, node): return f"{value_id}.{attr}" - def visit_range(self, node, vargs: List[str]) -> str: - if len(node.args) == 1: - return "(0..{}-1)".format(vargs[0]) - elif len(node.args) == 2: - return "({}..{}-1)".format(vargs[0], vargs[1]) - elif len(node.args) == 3: - return "({}..{}-1 step {})".format(vargs[0], vargs[1], vargs[2]) - - raise Exception( - "encountered range() call with unknown parameters: range({})".format(vargs) - ) - - def _visit_print(self, node, vargs: List[str]) -> str: - def _format(arg): - if arg.isdigit(): - return arg - if re.match(r"'.*'", arg) or re.match(r'".*"', arg): - return arg[1:-1] - else: - return f"${arg}" - - vargs_str = " ".join([f"{_format(arg)}" for arg in vargs]) - return f'println("{vargs_str}")' - - def _dispatch(self, node, fname: str, vargs: List[str]) -> Optional[str]: - dispatch_map = { - "range": self.visit_range, - "xrange": self.visit_range, - "print": self._visit_print, - } - - if fname in dispatch_map: - return dispatch_map[fname](node, vargs) - - def visit_min_max(is_max: bool) -> str: - min_max = "max" if is_max else "min" - self._usings.add(f"kotlin.math.{min_max}") - self._typename_from_annotation(node.args[0]) - if hasattr(node.args[0], "container_type"): - return f"maxOf({vargs[0]})" - else: - all_vargs = ", ".join(vargs) - return f"{min_max}({all_vargs})" - - def visit_floor(): - self._usings.add("kotlin.math.floor") - return f"floor({vargs[0]}).toInt()" - - # small one liners are inlined here as lambdas - small_dispatch_map = { - "int": lambda: f"{vargs[0]}.toInt()", - "bool": lambda: f"({vargs[0]} != 0)", - "str": lambda: f"{vargs[0]}.toString()", - # TODO: strings use .length - "len": lambda: f"{vargs[0]}.size", - "max": functools.partial(visit_min_max, is_max=True), - "min": functools.partial(visit_min_max, is_min=True), - "floor": visit_floor, - "reversed": lambda: f"{vargs[0]}.reversed()", - } - if fname in small_dispatch_map: - return small_dispatch_map[fname]() - return None - def visit_Call(self, node): fname = self.visit(node.func) vargs = []