Skip to content

Commit

Permalink
[mypyc] Simplify IR for tagged integer comparisons (python#9607)
Browse files Browse the repository at this point in the history
In a conditional context, such as in an if condition, simplify the 
IR for tagged integer comparisons. Also perform some additional 
optimizations if an operand is known to be a short integer. 

This slightly improves performance when compiling with no 
optimizations. The impact should be pretty negligible otherwise.

This is a bit simple-minded, and some further optimizations are 
possible. For example, `3 < x < 6` could be made faster. This 
covers the most common cases, however.

Closes mypyc/mypyc#758.
  • Loading branch information
JukkaL authored Oct 18, 2020
1 parent 3acbf3f commit 27a9830
Show file tree
Hide file tree
Showing 10 changed files with 821 additions and 1,019 deletions.
42 changes: 37 additions & 5 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mypy.nodes import (
MypyFile, SymbolNode, Statement, OpExpr, IntExpr, NameExpr, LDEF, Var, UnaryExpr,
CallExpr, IndexExpr, Expression, MemberExpr, RefExpr, Lvalue, TupleExpr,
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, GDEF, ARG_POS, ARG_NAMED
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, ComparisonExpr, GDEF, ARG_POS, ARG_NAMED
)
from mypy.types import (
Type, Instance, TupleType, UninhabitedType, get_proper_type
Expand All @@ -39,7 +39,7 @@
from mypyc.ir.rtypes import (
RType, RTuple, RInstance, int_rprimitive, dict_rprimitive,
none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive,
str_rprimitive,
str_rprimitive, is_tagged
)
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
Expand Down Expand Up @@ -813,11 +813,45 @@ def process_conditional(self, e: Expression, true: BasicBlock, false: BasicBlock
self.process_conditional(e.right, true, false)
elif isinstance(e, UnaryExpr) and e.op == 'not':
self.process_conditional(e.expr, false, true)
# Catch-all for arbitrary expressions.
else:
res = self.maybe_process_conditional_comparison(e, true, false)
if res:
return
# Catch-all for arbitrary expressions.
reg = self.accept(e)
self.add_bool_branch(reg, true, false)

def maybe_process_conditional_comparison(self,
e: Expression,
true: BasicBlock,
false: BasicBlock) -> bool:
"""Transform simple tagged integer comparisons in a conditional context.
Return True if the operation is supported (and was transformed). Otherwise,
do nothing and return False.
Args:
e: Arbitrary expression
true: Branch target if comparison is true
false: Branch target if comparison is false
"""
if not isinstance(e, ComparisonExpr) or len(e.operands) != 2:
return False
ltype = self.node_type(e.operands[0])
rtype = self.node_type(e.operands[1])
if not is_tagged(ltype) or not is_tagged(rtype):
return False
op = e.operators[0]
if op not in ('==', '!=', '<', '<=', '>', '>='):
return False
left = self.accept(e.operands[0])
right = self.accept(e.operands[1])
# "left op right" for two tagged integers
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
return True

# Basic helpers

def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[ClassIR]]:
"""Flatten classes in isinstance(obj, (A, (B, C))).
Expand All @@ -841,8 +875,6 @@ def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[Class
return None
return res

# Basic helpers

def enter(self, fn_info: Union[FuncInfo, str] = '') -> None:
if isinstance(fn_info, str):
fn_info = FuncInfo(name=fn_info)
Expand Down
72 changes: 64 additions & 8 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,17 +589,21 @@ def binary_op(self,
assert target, 'Unsupported binary operation: %s' % op
return target

def check_tagged_short_int(self, val: Value, line: int) -> Value:
"""Check if a tagged integer is a short integer"""
def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
"""Check if a tagged integer is a short integer.
Return the result of the check (value of type 'bit').
"""
int_tag = self.add(LoadInt(1, line, rtype=c_pyssize_t_rprimitive))
bitwise_and = self.binary_int_op(c_pyssize_t_rprimitive, val,
int_tag, BinaryIntOp.AND, line)
zero = self.add(LoadInt(0, line, rtype=c_pyssize_t_rprimitive))
check = self.comparison_op(bitwise_and, zero, ComparisonOp.EQ, line)
op = ComparisonOp.NEQ if negated else ComparisonOp.EQ
check = self.comparison_op(bitwise_and, zero, op, line)
return check

def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two tagged integers using given op"""
"""Compare two tagged integers using given operator (value context)."""
# generate fast binary logic ops on short ints
if is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type):
return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
Expand All @@ -610,13 +614,11 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
if op in ("==", "!="):
check = check_lhs
else:
# for non-equal logical ops(less than, greater than, etc.), need to check both side
# for non-equality logical ops (less/greater than, etc.), need to check both sides
check_rhs = self.check_tagged_short_int(rhs, line)
check = self.binary_int_op(bit_rprimitive, check_lhs,
check_rhs, BinaryIntOp.AND, line)
branch = Branch(check, short_int_block, int_block, Branch.BOOL)
branch.negated = False
self.add(branch)
self.add(Branch(check, short_int_block, int_block, Branch.BOOL))
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Assign(result, eq, line))
Expand All @@ -636,6 +638,60 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
self.goto_and_activate(out)
return result

def compare_tagged_condition(self,
lhs: Value,
rhs: Value,
op: str,
true: BasicBlock,
false: BasicBlock,
line: int) -> None:
"""Compare two tagged integers using given operator (conditional context).
Assume lhs and and rhs are tagged integers.
Args:
lhs: Left operand
rhs: Right operand
op: Operation, one of '==', '!=', '<', '<=', '>', '<='
true: Branch target if comparison is true
false: Branch target if comparison is false
"""
is_eq = op in ("==", "!=")
if ((is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type))
or (is_eq and (is_short_int_rprimitive(lhs.type) or
is_short_int_rprimitive(rhs.type)))):
# We can skip the tag check
check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
self.add(Branch(check, true, false, Branch.BOOL))
return
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
int_block, short_int_block = BasicBlock(), BasicBlock()
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
if is_eq or is_short_int_rprimitive(rhs.type):
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
else:
# For non-equality logical ops (less/greater than, etc.), need to check both sides
rhs_block = BasicBlock()
self.add(Branch(check_lhs, int_block, rhs_block, Branch.BOOL))
self.activate_block(rhs_block)
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
# Arbitrary integers (slow path)
self.activate_block(int_block)
if swap_op:
args = [rhs, lhs]
else:
args = [lhs, rhs]
call = self.call_c(c_func_desc, args, line)
if negate_result:
self.add(Branch(call, false, true, Branch.BOOL))
else:
self.add(Branch(call, true, false, Branch.BOOL))
# Short integers (fast path)
self.activate_block(short_int_block)
eq = self.comparison_op(lhs, rhs, op_type, line)
self.add(Branch(eq, true, false, Branch.BOOL))

def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
"""Compare two strings"""
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
Expand Down
Loading

0 comments on commit 27a9830

Please sign in to comment.