Skip to content

Commit

Permalink
[mypyc] Rewrite CPyStr_Build using a simplification of _PyUnicode_Joi…
Browse files Browse the repository at this point in the history
…nArray (python#10762)

This makes specialized `format()` calls faster.

Closes mypyc/mypyc#876.
  • Loading branch information
97littleleaf11 authored Jul 6, 2021
1 parent 6eafc5e commit 178df79
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 12 deletions.
7 changes: 4 additions & 3 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
)
from mypyc.ir.rtypes import (
RType, RTuple, str_rprimitive, list_rprimitive, dict_rprimitive, set_rprimitive,
bool_rprimitive, is_dict_rprimitive, c_int_rprimitive, is_str_rprimitive
bool_rprimitive, is_dict_rprimitive, c_int_rprimitive, is_str_rprimitive,
c_pyssize_t_rprimitive
)
from mypyc.primitives.dict_ops import (
dict_keys_op, dict_values_op, dict_items_op, dict_setdefault_spec_init_op
Expand Down Expand Up @@ -378,7 +379,7 @@ def translate_str_format(

# The first parameter is the total size of the following PyObject* merged from
# two lists alternatively.
result_list: List[Value] = [Integer(0, c_int_rprimitive)]
result_list: List[Value] = [Integer(0, c_pyssize_t_rprimitive)]
for a, b in zip(literals, variables):
if a:
result_list.append(builder.load_str(a))
Expand All @@ -393,7 +394,7 @@ def translate_str_format(
if not variables and len(result_list) == 2:
return result_list[1]

result_list[0] = Integer(len(result_list) - 1, c_int_rprimitive)
result_list[0] = Integer(len(result_list) - 1, c_pyssize_t_rprimitive)
return builder.call_c(str_build_op, result_list, expr.line)
return None

Expand Down
2 changes: 1 addition & 1 deletion mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
// Str operations


PyObject *CPyStr_Build(int len, ...);
PyObject *CPyStr_Build(Py_ssize_t len, ...);
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
Expand Down
86 changes: 79 additions & 7 deletions mypyc/lib-rt/str_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,90 @@ PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
}
}

PyObject *CPyStr_Build(int len, ...) {
int i;
// A simplification of _PyUnicode_JoinArray() from CPython 3.9.6
PyObject *CPyStr_Build(Py_ssize_t len, ...) {
Py_ssize_t i;
va_list args;

// Calculate the total amount of space and check
// whether all components have the same kind.
Py_ssize_t sz = 0;
Py_UCS4 maxchar = 0;
int use_memcpy = 1; // Use memcpy by default
PyObject *last_obj = NULL;

va_start(args, len);
for (i = 0; i < len; i++) {
PyObject *item = va_arg(args, PyObject *);
if (!PyUnicode_Check(item)) {
PyErr_Format(PyExc_TypeError,
"sequence item %zd: expected str instance,"
" %.80s found",
i, Py_TYPE(item)->tp_name);
return NULL;
}
if (PyUnicode_READY(item) == -1)
return NULL;

PyObject *res = PyUnicode_FromObject(va_arg(args, PyObject *));
for (i = 1; i < len; i++) {
PyObject *str = va_arg(args, PyObject *);
PyUnicode_Append(&res, str);
}
size_t add_sz = PyUnicode_GET_LENGTH(item);
Py_UCS4 item_maxchar = PyUnicode_MAX_CHAR_VALUE(item);
maxchar = Py_MAX(maxchar, item_maxchar);

// Using size_t to avoid overflow during arithmetic calculation
if (add_sz > (size_t)(PY_SSIZE_T_MAX - sz)) {
PyErr_SetString(PyExc_OverflowError,
"join() result is too long for a Python string");
return NULL;
}
sz += add_sz;

// If these strings have different kind, we would call
// _PyUnicode_FastCopyCharacters() in the following part.
if (use_memcpy && last_obj != NULL) {
if (PyUnicode_KIND(last_obj) != PyUnicode_KIND(item))
use_memcpy = 0;
}
last_obj = item;
}
va_end(args);

// Construct the string
PyObject *res = PyUnicode_New(sz, maxchar);
if (res == NULL)
return NULL;

if (use_memcpy) {
unsigned char *res_data = PyUnicode_1BYTE_DATA(res);
unsigned int kind = PyUnicode_KIND(res);

va_start(args, len);
for (i = 0; i < len; ++i) {
PyObject *item = va_arg(args, PyObject *);
Py_ssize_t itemlen = PyUnicode_GET_LENGTH(item);
if (itemlen != 0) {
memcpy(res_data, PyUnicode_DATA(item), kind * itemlen);
res_data += kind * itemlen;
}
}
va_end(args);
assert(res_data == PyUnicode_1BYTE_DATA(res) + kind * PyUnicode_GET_LENGTH(res));
} else {
Py_ssize_t res_offset = 0;

va_start(args, len);
for (i = 0; i < len; ++i) {
PyObject *item = va_arg(args, PyObject *);
Py_ssize_t itemlen = PyUnicode_GET_LENGTH(item);
if (itemlen != 0) {
_PyUnicode_FastCopyCharacters(res, res_offset, item, 0, itemlen);
res_offset += itemlen;
}
}
va_end(args);
assert(res_offset == PyUnicode_GET_LENGTH(res));
}

assert(_PyUnicode_CheckConsistency(res, 1));
return res;
}

Expand Down
2 changes: 1 addition & 1 deletion mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)

str_build_op = custom_op(
arg_types=[c_int_rprimitive],
arg_types=[c_pyssize_t_rprimitive],
return_type=str_rprimitive,
c_function_name='CPyStr_Build',
error_kind=ERR_MAGIC,
Expand Down
14 changes: 14 additions & 0 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ def test_fstring_python_doc() -> None:
from typing import Tuple

def test_format_method_basics() -> None:
x = str()
assert 'x{}'.format(x) == 'x'
assert 'ā{}'.format(x) == 'ā'
assert '😀{}'.format(x) == '😀'
assert ''.format() == ''
assert 'abc'.format() == 'abc'
assert '{}{}'.format(1, 2) == '12'
Expand Down Expand Up @@ -342,6 +346,16 @@ def test_format_method_args() -> None:
assert format_kwargs(x=10, y=2, z=1) == 'c10d2'
assert format_kwargs_self(x=10, y=2, z=1) == "{'x': 10, 'y': 2, 'z': 1}"

def test_format_method_different_kind() -> None:
s1 = "Literal['😀']"
assert 'Revealed type is {}'.format(s1) == "Revealed type is Literal['😀']"
s2 = "Revealed type is"
assert "{} Literal['😀']".format(s2) == "Revealed type is Literal['😀']"
s3 = "测试:"
assert "{}{} {}".format(s3, s2, s1) == "测试:Revealed type is Literal['😀']"
assert "Test: {}{}".format(s3, s1) == "Test: 测试:Literal['😀']"
assert "Test: {}{}".format(s3, s2) == "Test: 测试:Revealed type is"

class Point:
def __init__(self, x, y):
self.x, self.y = x, y
Expand Down

0 comments on commit 178df79

Please sign in to comment.