Skip to content

Commit

Permalink
[mypyc] Speed up and improve multiple assignment (python#9800)
Browse files Browse the repository at this point in the history
Speed up multiple assignment from variable-length lists and tuples. 
This speeds up the `multiple_assignment` benchmark by around 80%.

Fix multiple lvalues in fixed-length sequence assignments.

Optimize some cases of list expressions in assignments.

Fixes mypyc/mypyc#729.
  • Loading branch information
JukkaL authored Dec 29, 2020
1 parent 7c0c1e7 commit 18ab589
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 61 deletions.
40 changes: 35 additions & 5 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@
from mypyc.ir.rtypes import (
RType, RTuple, RInstance, int_rprimitive, dict_rprimitive,
none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive,
str_rprimitive, is_tagged
str_rprimitive, is_tagged, is_list_rprimitive, is_tuple_rprimitive, c_pyssize_t_rprimitive
)
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
from mypyc.primitives.registry import CFunctionDescription, function_ops
from mypyc.primitives.list_ops import to_list, list_pop_last
from mypyc.primitives.list_ops import to_list, list_pop_last, list_get_item_unsafe_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op
from mypyc.primitives.generic_ops import py_setattr_op, iter_op, next_op
from mypyc.primitives.misc_ops import import_op
from mypyc.primitives.misc_ops import import_op, check_unpack_count_op
from mypyc.crash import catch_errors
from mypyc.options import CompilerOptions
from mypyc.errors import Errors
Expand Down Expand Up @@ -460,8 +460,10 @@ def read(self, target: Union[Value, AssignmentTarget], line: int = -1) -> Value:

assert False, 'Unsupported lvalue: %r' % target

def assign(self, target: Union[Register, AssignmentTarget],
rvalue_reg: Value, line: int) -> None:
def assign(self,
target: Union[Register, AssignmentTarget],
rvalue_reg: Value,
line: int) -> None:
if isinstance(target, Register):
self.add(Assign(target, rvalue_reg))
elif isinstance(target, AssignmentTargetRegister):
Expand All @@ -486,11 +488,39 @@ def assign(self, target: Union[Register, AssignmentTarget],
for i in range(len(rtypes)):
item_value = self.add(TupleGet(rvalue_reg, i, line))
self.assign(target.items[i], item_value, line)
elif ((is_list_rprimitive(rvalue_reg.type) or is_tuple_rprimitive(rvalue_reg.type))
and target.star_idx is None):
self.process_sequence_assignment(target, rvalue_reg, line)
else:
self.process_iterator_tuple_assignment(target, rvalue_reg, line)
else:
assert False, 'Unsupported assignment target'

def process_sequence_assignment(self,
target: AssignmentTargetTuple,
rvalue: Value,
line: int) -> None:
"""Process assignment like 'x, y = s', where s is a variable-length list or tuple."""
# Check the length of sequence.
expected_len = self.add(LoadInt(len(target.items), rtype=c_pyssize_t_rprimitive))
self.builder.call_c(check_unpack_count_op, [rvalue, expected_len], line)

# Read sequence items.
values = []
for i in range(len(target.items)):
item = target.items[i]
index = self.builder.load_static_int(i)
if is_list_rprimitive(rvalue.type):
item_value = self.call_c(list_get_item_unsafe_op, [rvalue, index], line)
else:
item_value = self.builder.gen_method_call(
rvalue, '__getitem__', [index], item.type, line)
values.append(item_value)

# Assign sequence items to the target lvalues.
for lvalue, value in zip(target.items, values):
self.assign(lvalue, value, line)

def process_iterator_tuple_assignment_helper(self,
litem: AssignmentTarget,
ritem: Value, line: int) -> None:
Expand Down
31 changes: 20 additions & 11 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from mypy.nodes import (
Block, ExpressionStmt, ReturnStmt, AssignmentStmt, OperatorAssignmentStmt, IfStmt, WhileStmt,
ForStmt, BreakStmt, ContinueStmt, RaiseStmt, TryStmt, WithStmt, AssertStmt, DelStmt,
Expression, StrExpr, TempNode, Lvalue, Import, ImportFrom, ImportAll, TupleExpr
Expression, StrExpr, TempNode, Lvalue, Import, ImportFrom, ImportAll, TupleExpr, ListExpr,
StarExpr
)

from mypyc.ir.ops import (
Expand Down Expand Up @@ -69,39 +70,47 @@ def transform_return_stmt(builder: IRBuilder, stmt: ReturnStmt) -> None:


def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
assert len(stmt.lvalues) >= 1
builder.disallow_class_assignments(stmt.lvalues, stmt.line)
lvalue = stmt.lvalues[0]
lvalues = stmt.lvalues
assert len(lvalues) >= 1
builder.disallow_class_assignments(lvalues, stmt.line)
first_lvalue = lvalues[0]
if stmt.type and isinstance(stmt.rvalue, TempNode):
# This is actually a variable annotation without initializer. Don't generate
# an assignment but we need to call get_assignment_target since it adds a
# name binding as a side effect.
builder.get_assignment_target(lvalue, stmt.line)
builder.get_assignment_target(first_lvalue, stmt.line)
return

# multiple assignment
if (isinstance(lvalue, TupleExpr) and isinstance(stmt.rvalue, TupleExpr)
and len(lvalue.items) == len(stmt.rvalue.items)):
# Special case multiple assignments like 'x, y = e1, e2'.
if (isinstance(first_lvalue, (TupleExpr, ListExpr))
and isinstance(stmt.rvalue, (TupleExpr, ListExpr))
and len(first_lvalue.items) == len(stmt.rvalue.items)
and all(is_simple_lvalue(item) for item in first_lvalue.items)
and len(lvalues) == 1):
temps = []
for right in stmt.rvalue.items:
rvalue_reg = builder.accept(right)
temp = Register(rvalue_reg.type)
builder.assign(temp, rvalue_reg, stmt.line)
temps.append(temp)
for (left, temp) in zip(lvalue.items, temps):
for (left, temp) in zip(first_lvalue.items, temps):
assignment_target = builder.get_assignment_target(left)
builder.assign(assignment_target, temp, stmt.line)
return

line = stmt.rvalue.line
rvalue_reg = builder.accept(stmt.rvalue)
if builder.non_function_scope() and stmt.is_final_def:
builder.init_final_static(lvalue, rvalue_reg)
for lvalue in stmt.lvalues:
builder.init_final_static(first_lvalue, rvalue_reg)
for lvalue in lvalues:
target = builder.get_assignment_target(lvalue)
builder.assign(target, rvalue_reg, line)


def is_simple_lvalue(expr: Expression) -> bool:
return not isinstance(expr, (StarExpr, ListExpr, TupleExpr))


def transform_operator_assignment_stmt(builder: IRBuilder, stmt: OperatorAssignmentStmt) -> None:
"""Operator assignment statement such as x += 1"""
builder.disallow_class_assignments([stmt.lvalue], stmt.line)
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ void CPyDebug_Print(const char *msg);
void CPy_Init(void);
int CPyArg_ParseTupleAndKeywords(PyObject *, PyObject *,
const char *, char **, ...);
int CPySequence_CheckUnpackCount(PyObject *sequence, Py_ssize_t expected);


#ifdef __cplusplus
Expand Down
14 changes: 14 additions & 0 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,17 @@ void CPyDebug_Print(const char *msg) {
printf("%s\n", msg);
fflush(stdout);
}

int CPySequence_CheckUnpackCount(PyObject *sequence, Py_ssize_t expected) {
Py_ssize_t actual = Py_SIZE(sequence);
if (unlikely(actual != expected)) {
if (actual < expected) {
PyErr_Format(PyExc_ValueError, "not enough values to unpack (expected %zd, got %zd)",
expected, actual);
} else {
PyErr_Format(PyExc_ValueError, "too many values to unpack (expected %zd)", expected);
}
return -1;
}
return 0;
}
10 changes: 9 additions & 1 deletion mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from mypyc.ir.ops import ERR_NEVER, ERR_MAGIC, ERR_FALSE
from mypyc.ir.rtypes import (
bool_rprimitive, object_rprimitive, str_rprimitive, object_pointer_rprimitive,
int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive
int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive, c_pyssize_t_rprimitive
)
from mypyc.primitives.registry import (
function_op, custom_op, load_address_op, ERR_NEG_INT
Expand Down Expand Up @@ -176,3 +176,11 @@
return_type=bit_rprimitive,
c_function_name='CPyDataclass_SleightOfHand',
error_kind=ERR_FALSE)

# Raise ValueError if length of first argument is not equal to the second argument.
# The first argument must be a list or a variable-length tuple.
check_unpack_count_op = custom_op(
arg_types=[object_rprimitive, c_pyssize_t_rprimitive],
return_type=c_int_rprimitive,
c_function_name='CPySequence_CheckUnpackCount',
error_kind=ERR_NEG_INT)
42 changes: 0 additions & 42 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -3605,48 +3605,6 @@ L0:
r2 = truncate r0: int32 to builtins.bool
return r2

[case testMultipleAssignment]
from typing import Tuple

def f(x: int, y: int) -> Tuple[int, int]:
x, y = y, x
return (x, y)

def f2(x: int, y: str, z: float) -> Tuple[float, str, int]:
a, b, c = x, y, z
return (c, b, a)
[out]
def f(x, y):
x, y, r0, r1 :: int
r2 :: tuple[int, int]
L0:
r0 = y
r1 = x
x = r0
y = r1
r2 = (x, y)
return r2
def f2(x, y, z):
x :: int
y :: str
z :: float
r0 :: int
r1 :: str
r2 :: float
a :: int
b :: str
c :: float
r3 :: tuple[float, str, int]
L0:
r0 = x
r1 = y
r2 = z
a = r0
b = r1
c = r2
r3 = (c, b, a)
return r3

[case testLocalImportSubmodule]
def f() -> int:
import p.m
Expand Down
97 changes: 96 additions & 1 deletion mypyc/test-data/irbuild-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,63 @@ L9:
L10:
return s

[case testMultipleAssignment]
[case testMultipleAssignmentWithNoUnpacking]
from typing import Tuple

def f(x: int, y: int) -> Tuple[int, int]:
x, y = y, x
return (x, y)

def f2(x: int, y: str, z: float) -> Tuple[float, str, int]:
a, b, c = x, y, z
return (c, b, a)

def f3(x: int, y: int) -> Tuple[int, int]:
[x, y] = [y, x]
return (x, y)
[out]
def f(x, y):
x, y, r0, r1 :: int
r2 :: tuple[int, int]
L0:
r0 = y
r1 = x
x = r0
y = r1
r2 = (x, y)
return r2
def f2(x, y, z):
x :: int
y :: str
z :: float
r0 :: int
r1 :: str
r2 :: float
a :: int
b :: str
c :: float
r3 :: tuple[float, str, int]
L0:
r0 = x
r1 = y
r2 = z
a = r0
b = r1
c = r2
r3 = (c, b, a)
return r3
def f3(x, y):
x, y, r0, r1 :: int
r2 :: tuple[int, int]
L0:
r0 = y
r1 = x
x = r0
y = r1
r2 = (x, y)
return r2

[case testMultipleAssignmentBasicUnpacking]
from typing import Tuple, Any

def from_tuple(t: Tuple[int, str]) -> None:
Expand Down Expand Up @@ -596,6 +652,45 @@ L0:
z = r6
return 1

[case testMultipleAssignmentUnpackFromSequence]
from typing import List, Tuple

def f(l: List[int], t: Tuple[int, ...]) -> None:
x: object
y: int
x, y = l
x, y = t
[out]
def f(l, t):
l :: list
t :: tuple
x :: object
y :: int
r0 :: int32
r1 :: bit
r2, r3 :: object
r4 :: int
r5 :: int32
r6 :: bit
r7, r8 :: object
r9 :: int
L0:
r0 = CPySequence_CheckUnpackCount(l, 2)
r1 = r0 >= 0 :: signed
r2 = CPyList_GetItemUnsafe(l, 0)
r3 = CPyList_GetItemUnsafe(l, 2)
x = r2
r4 = unbox(int, r3)
y = r4
r5 = CPySequence_CheckUnpackCount(t, 2)
r6 = r5 >= 0 :: signed
r7 = CPySequenceTuple_GetItem(t, 0)
r8 = CPySequenceTuple_GetItem(t, 2)
r9 = unbox(int, r8)
x = r7
y = r9
return 1

[case testAssert]
from typing import Optional

Expand Down
Loading

0 comments on commit 18ab589

Please sign in to comment.