Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Add 'bit' primitive type and streamline branching #9606

Merged
merged 13 commits into from
Oct 17, 2020
9 changes: 5 additions & 4 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
is_float_rprimitive, is_bool_rprimitive, is_int_rprimitive, is_short_int_rprimitive,
is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive,
is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive,
int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive, is_int64_rprimitive
int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive,
is_int64_rprimitive, is_bit_rprimitive
)
from mypyc.ir.func_ir import FuncDecl
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
Expand Down Expand Up @@ -413,7 +414,7 @@ def emit_cast(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
prefix = 'PyUnicode'
elif is_int_rprimitive(typ):
prefix = 'PyLong'
elif is_bool_rprimitive(typ):
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
prefix = 'PyBool'
else:
assert False, 'unexpected primitive type'
Expand Down Expand Up @@ -602,7 +603,7 @@ def emit_unbox(self, src: str, dest: str, typ: RType, custom_failure: Optional[s
self.emit_line('else {')
self.emit_lines(*failure)
self.emit_line('}')
elif is_bool_rprimitive(typ):
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
# Whether we are borrowing or not makes no difference.
if declare_dest:
self.emit_line('char {};'.format(dest))
Expand Down Expand Up @@ -681,7 +682,7 @@ def emit_box(self, src: str, dest: str, typ: RType, declare_dest: bool = False,
if is_int_rprimitive(typ) or is_short_int_rprimitive(typ):
# Steal the existing reference if it exists.
self.emit_line('{}{} = CPyTagged_StealAsObject({});'.format(declaration, dest, src))
elif is_bool_rprimitive(typ):
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
# N.B: bool is special cased to produce a borrowed value
# after boxing, so we don't need to increment the refcount
# when this comes directly from a Box op.
Expand Down
5 changes: 1 addition & 4 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,9 @@ def visit_branch(self, op: Branch) -> None:
neg = '!' if op.negated else ''

cond = ''
if op.op == Branch.BOOL_EXPR:
if op.op == Branch.BOOL:
expr_result = self.reg(op.left) # right isn't used
cond = '{}{}'.format(neg, expr_result)
elif op.op == Branch.NEG_INT_EXPR:
expr_result = self.reg(op.left)
cond = '{} < 0'.format(expr_result)
elif op.op == Branch.IS_ERROR:
typ = op.left.type
compare = '!=' if op.negated else '=='
Expand Down
44 changes: 29 additions & 15 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from mypyc.ir.rtypes import (
RType, RInstance, RTuple, RVoid, is_bool_rprimitive, is_int_rprimitive,
is_short_int_rprimitive, is_none_rprimitive, object_rprimitive, bool_rprimitive,
short_int_rprimitive, int_rprimitive, void_rtype, pointer_rprimitive, is_pointer_rprimitive
short_int_rprimitive, int_rprimitive, void_rtype, pointer_rprimitive, is_pointer_rprimitive,
bit_rprimitive, is_bit_rprimitive
)
from mypyc.common import short_name

Expand Down Expand Up @@ -300,10 +301,8 @@ def terminated(self) -> bool:
ERR_MAGIC = 1 # type: Final
# Generates false (bool) on exception
ERR_FALSE = 2 # type: Final
# Generates negative integer on exception
ERR_NEG_INT = 3 # type: Final
# Always fails
ERR_ALWAYS = 4 # type: Final
ERR_ALWAYS = 3 # type: Final

# Hack: using this line number for an op will suppress it in tracebacks
NO_TRACEBACK_LINE_NO = -10000
Expand Down Expand Up @@ -416,20 +415,25 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:


class Branch(ControlOp):
"""if [not] r1 goto 1 else goto 2"""
"""Branch based on a value.

If op is BOOL, branch based on a bit/bool value:
if [not] r1 goto L1 else goto L2

If op is IS_ERROR, branch based on whether there is an error value:
if [not] is_error(r1) goto L1 else goto L2
"""

# Branch ops must *not* raise an exception. If a comparison, for example, can raise an
# exception, it needs to split into two opcodes and only the first one may fail.
error_kind = ERR_NEVER

BOOL_EXPR = 100 # type: Final
BOOL = 100 # type: Final
IS_ERROR = 101 # type: Final
NEG_INT_EXPR = 102 # type: Final

op_names = {
BOOL_EXPR: ('%r', 'bool'),
BOOL: ('%r', 'bool'),
IS_ERROR: ('is_error(%r)', ''),
NEG_INT_EXPR: ('%r < 0', ''),
} # type: Final

def __init__(self,
Expand All @@ -445,7 +449,7 @@ def __init__(self,
self.left = left
self.true = true_label
self.false = false_label
# BOOL_EXPR (boolean check) or IS_ERROR (error value check)
# BOOL (boolean check) or IS_ERROR (error value check)
self.op = op
self.negated = False
# If not None, the true label should generate a traceback entry (func name, line number)
Expand Down Expand Up @@ -1073,7 +1077,9 @@ def __init__(self, src: Value, line: int = -1) -> None:
self.src = src
self.type = object_rprimitive
# When we box None and bool values, we produce a borrowed result
if is_none_rprimitive(self.src.type) or is_bool_rprimitive(self.src.type):
if (is_none_rprimitive(self.src.type)
or is_bool_rprimitive(self.src.type)
or is_bit_rprimitive(self.src.type)):
self.is_borrowed = True

def sources(self) -> List[Value]:
Expand Down Expand Up @@ -1315,12 +1321,20 @@ def accept(self, visitor: 'OpVisitor[T]') -> T:


class ComparisonOp(RegisterOp):
"""Comparison ops
"""Low-level comparison op.

The result type will always be boolean.
Both unsigned and signed comparisons are supported.

Support comparison between integer types and pointer types
The operands are assumed to be fixed-width integers/pointers. Python
semantics, such as calling __eq__, are not supported.

The result is always a bit.

Supports comparisons between fixed-width integer types and pointer
types.
"""
# Must be ERR_NEVER or ERR_FALSE. ERR_FALSE means that a false result
# indicates that an exception has been raised and should be propagated.
error_kind = ERR_NEVER

# S for signed and U for unsigned
Expand Down Expand Up @@ -1350,7 +1364,7 @@ class ComparisonOp(RegisterOp):

def __init__(self, lhs: Value, rhs: Value, op: int, line: int = -1) -> None:
super().__init__(line)
self.type = bool_rprimitive
self.type = bit_rprimitive
self.lhs = lhs
self.rhs = rhs
self.op = op
Expand Down
15 changes: 13 additions & 2 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,18 @@ def __repr__(self) -> str:
float_rprimitive = RPrimitive('builtins.float', is_unboxed=False,
is_refcounted=True) # type: Final

# An unboxed boolean value. This actually has three possible values
# (0 -> False, 1 -> True, 2 -> error).
# An unboxed Python bool value. This actually has three possible values
# (0 -> False, 1 -> True, 2 -> error). If you only need True/False, use
# bit_rprimitive instead.
bool_rprimitive = RPrimitive('builtins.bool', is_unboxed=True, is_refcounted=False,
ctype='char', size=1) # type: Final

# A low-level boolean value with two possible values: 0 and 1. Any
# other value results in undefined behavior. Undefined or error values
# are not supported.
bit_rprimitive = RPrimitive('bit', is_unboxed=True, is_refcounted=False,
ctype='char', size=1) # type: Final

# The 'None' value. The possible values are 0 -> None and 2 -> error.
none_rprimitive = RPrimitive('builtins.None', is_unboxed=True, is_refcounted=False,
ctype='char', size=1) # type: Final
Expand Down Expand Up @@ -329,6 +336,10 @@ def is_bool_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.bool'


def is_bit_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == 'bit'


def is_object_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.object'

Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def process_iterator_tuple_assignment(self,
condition = self.binary_op(post_star_len, iter_list_len, '<=', line)

error_block, ok_block = BasicBlock(), BasicBlock()
self.add(Branch(condition, ok_block, error_block, Branch.BOOL_EXPR))
self.add(Branch(condition, ok_block, error_block, Branch.BOOL))

self.activate_block(error_block)
self.add(RaiseStandardError(RaiseStandardError.VALUE_ERROR,
Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> FuncIR:
builder.translate_is_op(eqval, not_implemented, 'is', line),
not_implemented_block,
regular_block,
Branch.BOOL_EXPR))
Branch.BOOL))

builder.activate_block(regular_block)
retval = builder.coerce(
Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def gen_condition(self) -> None:

should_continue = builder.add(TupleGet(self.next_tuple, 0, line))
builder.add(
Branch(should_continue, self.body_block, self.loop_exit, Branch.BOOL_EXPR)
Branch(should_continue, self.body_block, self.loop_exit, Branch.BOOL)
)

def gen_step(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def except_body() -> None:
val = builder.add(TupleGet(res, 1, o.line))

ok, stop = BasicBlock(), BasicBlock()
builder.add(Branch(to_stop, stop, ok, Branch.BOOL_EXPR))
builder.add(Branch(to_stop, stop, ok, Branch.BOOL))

# The exception got swallowed. Continue, yielding the returned value
builder.activate_block(ok)
Expand Down
57 changes: 35 additions & 22 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
LoadStatic, MethodCall, PrimitiveOp, OpDescription, RegisterOp, CallC, Truncate,
RaiseStandardError, Unreachable, LoadErrorValue, LoadGlobal,
NAMESPACE_TYPE, NAMESPACE_MODULE, NAMESPACE_STATIC, BinaryIntOp, GetElementPtr,
LoadMem, ComparisonOp, LoadAddress, TupleGet, SetMem
LoadMem, ComparisonOp, LoadAddress, TupleGet, SetMem, ERR_NEVER, ERR_FALSE
)
from mypyc.ir.rtypes import (
RType, RUnion, RInstance, optional_value_type, int_rprimitive, float_rprimitive,
bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive,
c_pyssize_t_rprimitive, is_short_int_rprimitive, is_tagged, PyVarObject, short_int_rprimitive,
is_list_rprimitive, is_tuple_rprimitive, is_dict_rprimitive, is_set_rprimitive, PySetObject,
none_rprimitive, RTuple, is_bool_rprimitive, is_str_rprimitive, c_int_rprimitive,
pointer_rprimitive, PyObject, PyListObject
pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive
)
from mypyc.ir.func_ir import FuncDecl, FuncSignature
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
Expand All @@ -40,7 +40,7 @@
)
from mypyc.primitives.registry import (
func_ops, c_method_call_ops, CFunctionDescription, c_function_ops,
c_binary_ops, c_unary_ops
c_binary_ops, c_unary_ops, ERR_NEG_INT
)
from mypyc.primitives.list_ops import (
list_extend_op, new_list_op
Expand Down Expand Up @@ -612,9 +612,9 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
else:
# for non-equal logical ops(less than, greater than, etc.), need to check both side
check_rhs = self.check_tagged_short_int(rhs, line)
check = self.binary_int_op(bool_rprimitive, check_lhs,
check = self.binary_int_op(bit_rprimitive, check_lhs,
check_rhs, BinaryIntOp.AND, line)
branch = Branch(check, short_int_block, int_block, Branch.BOOL_EXPR)
branch = Branch(check, short_int_block, int_block, Branch.BOOL)
branch.negated = False
self.add(branch)
self.activate_block(short_int_block)
Expand Down Expand Up @@ -643,15 +643,15 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
compare_error_check = self.add(ComparisonOp(compare_result,
error_constant, ComparisonOp.EQ, line))
exception_check, propagate, final_compare = BasicBlock(), BasicBlock(), BasicBlock()
branch = Branch(compare_error_check, exception_check, final_compare, Branch.BOOL_EXPR)
branch = Branch(compare_error_check, exception_check, final_compare, Branch.BOOL)
branch.negated = False
self.add(branch)
self.activate_block(exception_check)
check_error_result = self.call_c(err_occurred_op, [], line)
null = self.add(LoadInt(0, line, pointer_rprimitive))
compare_error_check = self.add(ComparisonOp(check_error_result,
null, ComparisonOp.NEQ, line))
branch = Branch(compare_error_check, propagate, final_compare, Branch.BOOL_EXPR)
branch = Branch(compare_error_check, propagate, final_compare, Branch.BOOL)
branch.negated = False
self.add(branch)
self.activate_block(propagate)
Expand Down Expand Up @@ -698,9 +698,9 @@ def compare_tuples(self,
if not is_bool_rprimitive(compare.type):
compare = self.call_c(bool_op, [compare], line)
if i < len(lhs.type.types) - 1:
branch = Branch(compare, early_stop, check_blocks[i + 1], Branch.BOOL_EXPR)
branch = Branch(compare, early_stop, check_blocks[i + 1], Branch.BOOL)
else:
branch = Branch(compare, early_stop, final, Branch.BOOL_EXPR)
branch = Branch(compare, early_stop, final, Branch.BOOL)
# if op is ==, we branch on false, else branch on true
branch.negated = equal
self.add(branch)
Expand All @@ -726,14 +726,14 @@ def bool_bitwise_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value
def unary_not(self,
value: Value,
line: int) -> Value:
mask = self.add(LoadInt(1, line, rtype=bool_rprimitive))
return self.binary_int_op(bool_rprimitive, value, mask, BinaryIntOp.XOR, line)
mask = self.add(LoadInt(1, line, rtype=value.type))
return self.binary_int_op(value.type, value, mask, BinaryIntOp.XOR, line)

def unary_op(self,
lreg: Value,
expr_op: str,
line: int) -> Value:
if is_bool_rprimitive(lreg.type) and expr_op == 'not':
if (is_bool_rprimitive(lreg.type) or is_bit_rprimitive(lreg.type)) and expr_op == 'not':
return self.unary_not(lreg, line)
call_c_ops_candidates = c_unary_ops.get(expr_op, [])
target = self.matching_call_c(call_c_ops_candidates, [lreg], line)
Expand Down Expand Up @@ -841,7 +841,7 @@ def shortcircuit_helper(self, op: str,

def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> None:
if is_runtime_subtype(value.type, int_rprimitive):
zero = self.add(LoadInt(0))
zero = self.add(LoadInt(0, rtype=value.type))
value = self.binary_op(value, zero, '!=', value.line)
elif is_same_type(value.type, list_rprimitive):
length = self.builtin_len(value, value.line)
Expand All @@ -855,7 +855,7 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
value_type = optional_value_type(value.type)
if value_type is not None:
is_none = self.translate_is_op(value, self.none_object(), 'is not', value.line)
branch = Branch(is_none, true, false, Branch.BOOL_EXPR)
branch = Branch(is_none, true, false, Branch.BOOL)
self.add(branch)
always_truthy = False
if isinstance(value_type, RInstance):
Expand All @@ -873,42 +873,55 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) ->
remaining = self.unbox_or_cast(value, value_type, value.line)
self.add_bool_branch(remaining, true, false)
return
elif not is_same_type(value.type, bool_rprimitive):
elif not is_bool_rprimitive(value.type) and not is_bit_rprimitive(value.type):
value = self.call_c(bool_op, [value], value.line)
self.add(Branch(value, true, false, Branch.BOOL_EXPR))
self.add(Branch(value, true, false, Branch.BOOL))

def call_c(self,
desc: CFunctionDescription,
args: List[Value],
line: int,
result_type: Optional[RType] = None) -> Value:
# handle void function via singleton RVoid instance
"""Call function using C/native calling convention (not a Python callable)."""
# Handle void function via singleton RVoid instance
coerced = []
# coerce fixed number arguments
# Coerce fixed number arguments
for i in range(min(len(args), len(desc.arg_types))):
formal_type = desc.arg_types[i]
arg = args[i]
arg = self.coerce(arg, formal_type, line)
coerced.append(arg)
# reorder args if necessary
# Reorder args if necessary
if desc.ordering is not None:
assert desc.var_arg_type is None
coerced = [coerced[i] for i in desc.ordering]
# coerce any var_arg
# Coerce any var_arg
var_arg_idx = -1
if desc.var_arg_type is not None:
var_arg_idx = len(desc.arg_types)
for i in range(len(desc.arg_types), len(args)):
arg = args[i]
arg = self.coerce(arg, desc.var_arg_type, line)
coerced.append(arg)
# add extra integer constant if any
# Add extra integer constant if any
for item in desc.extra_int_constants:
val, typ = item
extra_int_constant = self.add(LoadInt(val, line, rtype=typ))
coerced.append(extra_int_constant)
error_kind = desc.error_kind
if error_kind == ERR_NEG_INT:
# Handled with an explicit comparison
error_kind = ERR_NEVER
target = self.add(CallC(desc.c_function_name, coerced, desc.return_type, desc.steals,
desc.is_borrowed, desc.error_kind, line, var_arg_idx))
desc.is_borrowed, error_kind, line, var_arg_idx))
if desc.error_kind == ERR_NEG_INT:
comp = ComparisonOp(target,
self.add(LoadInt(0, line, desc.return_type)),
ComparisonOp.SGE,
line)
comp.error_kind = ERR_FALSE
self.add(comp)

if desc.truncated_type is None:
result = target
else:
Expand Down
Loading