Update on "[ONNX] handle aten::_set_item on Dict in convertInplaceOps…
Browse files Browse the repository at this point in the history
…AndTrackAlias (#58317)"

It seems the JIT produces an output for aten::_set_item on lists but
not on dicts. Previously the code would crash because it assumed it
was operating on a list.

The different behavior can be seen with the following test:

```python
import typing

import torch


class DictModule(torch.nn.Module):
    def forward(self, x_in: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
        x_out = {}
        x_out["test_key_out"] = x_in
        return x_out

x_in = torch.tensor(1)
dms = torch.jit.script(DictModule())
torch.onnx.export(dms, (x_in,), "/dev/null", example_outputs=(dms(x_in),))
```

Before this change:
`RuntimeError: outputs_.size() == 1INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/ir.h":452, please report a bug to PyTorch.`

After this change:
`RuntimeError: Exporting the operator prim_DictConstruct to ONNX opset version 9 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.`

This is a more useful error message.
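
One way to see why the old pass crashed (a sketch of my own, not from the PR; the exact graph contents may vary across PyTorch versions) is to inspect the scripted graph and look at the `aten::_set_item` node directly:

```python
from typing import Dict

import torch


@torch.jit.script
def set_item(x_in: torch.Tensor) -> Dict[str, torch.Tensor]:
    x_out: Dict[str, torch.Tensor] = {}
    x_out["test_key_out"] = x_in
    return x_out


for node in set_item.graph.nodes():
    if node.kind() == "aten::_set_item":
        # On a Dict this node has no outputs, so code that unconditionally
        # reads its single output (as the pass did before this change)
        # trips the "outputs_.size() == 1" internal assert.
        print(node.kind(), "number of outputs:", len(list(node.outputs())))
```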

Co-authored-by: Gary Miguel <garymiguel@microsoft.com>

[ghstack-poisoned]
BowenBao committed May 20, 2021
2 parents 2dfb4c1 + f00100d commit ef8bd1b
Showing 18 changed files with 82 additions and 257 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3642,7 +3642,7 @@
device_check: NoCheck
device_guard: False

- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)
variants: function, method
device_check: NoCheck
device_guard: False
2 changes: 1 addition & 1 deletion aten/src/ATen/templates/RegisterSchema.cpp
@@ -23,7 +23,7 @@ TORCH_LIBRARY(aten, m) {
// Implementations located in torch/csrc/jit/runtime/register_prim_ops.cpp
m.def(TORCH_SELECTIVE_SCHEMA("aten::splitlines(str self, bool keepends=False) -> str[]"));
m.def(TORCH_SELECTIVE_SCHEMA(
"aten::slice.str(str string, int? start=None, int? end=None, int step=1) -> str"));
"aten::slice.str(str string, int? start=0, int? end=9223372036854775807, int step=1) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::isupper(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::islower(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::capitalize(str self) -> str"));
2 changes: 1 addition & 1 deletion caffe2/python/operator_test/adagrad_test.py
@@ -226,7 +226,7 @@ def test_row_wise_sparse_adagrad(self, inputs, lr, epsilon, weight_decay, gc, dc
),
**hu.gcs
)
@settings(deadline=1000)
@settings(deadline=None)
def test_row_wise_sparse_adagrad_empty(self, inputs, lr, epsilon, gc, dc):
param, momentum = inputs
grad = np.empty(shape=(0,) + param.shape[1:], dtype=np.float32)
@@ -83,7 +83,7 @@
("aten::cumprod_backward", datetime.date(2021, 5, 1)),
("aten::_triangular_solve_helper", datetime.date(9999, 1, 1)),
("aten::_addmv_impl_", datetime.date(2021, 5, 15)),
("aten::slice", datetime.date(2021, 5, 31)),
("aten::slice", datetime.date(2021, 6, 15)),
("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
("aten::_amp_update_scale", datetime.date(2021, 6, 1)),
4 changes: 2 additions & 2 deletions test/cpp/jit/test_interpreter.cpp
@@ -159,7 +159,7 @@ TEST(InterpreterTest, IgnorableArgsInSchema) {
auto op_to_specified_args = function.op_to_num_specified_args();
ASSERT_TRUE(op_to_specified_args.size() == 2);
ASSERT_TRUE(op_to_specified_args["aten::slice.Tensor"] == 4);
ASSERT_TRUE(op_to_specified_args["aten::slice.str"] == 4);
ASSERT_TRUE(op_to_specified_args["aten::slice.str"] == 1);
auto graph_vararg = build_mobile_export_analysis_graph_with_vararg();
MobileCode function_vararg(graph_vararg, "");
auto op_to_specified_args_vararg = function_vararg.op_to_num_specified_args();
@@ -172,7 +172,7 @@ TEST(InterpreterTest, IgnorableArgsInSchema) {
MobileCode function_nested(graph_nested, "");
auto op_to_specified_args_nested = function_nested.op_to_num_specified_args();
ASSERT_TRUE(op_to_specified_args_nested["aten::slice.Tensor"] == 4);
ASSERT_TRUE(op_to_specified_args_nested["aten::slice.str"] == 4);
ASSERT_TRUE(op_to_specified_args_nested["aten::slice.str"] == 1);

auto graph_non_const = build_mobile_export_analysis_graph_non_const();
MobileCode function_non_const(graph_non_const, "");
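
For context (a sketch of my own, not part of the diff): `op_to_num_specified_args` counts how many leading arguments of a call actually need to be serialized, so when the trailing arguments of an `aten::slice.str` call match the schema defaults, only the string itself has to be kept. A scripted string slice goes through this overload:

```python
import torch


@torch.jit.script
def str_prefix(s: str) -> str:
    # s[:3] lowers to a string slice; arguments that match the schema
    # defaults do not need to be serialized explicitly for mobile export.
    return s[:3]


print(str_prefix("hello"))  # "hel"
print(str_prefix.graph)     # shows the aten::slice call on the str input
```
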
10 changes: 5 additions & 5 deletions test/cpp/jit/test_utils.cpp
@@ -95,9 +95,9 @@ std::shared_ptr<Graph> build_lstm() {

std::shared_ptr<Graph> build_mobile_export_analysis_graph() {
// We use following two schemas for this graph:
// 1. slice.Tensor(Tensor(a) self, int dim=0, int? start=None,
// int? end=None, int step=1) -> Tensor(a)
// 2. slice.str(str string, int? start=None, int? end=None,
// 1. slice.Tensor(Tensor(a) self, int dim=0, int? start=0,
// int? end=9223372036854775807, int step=1) -> Tensor(a)
// 2. slice.str(str string, int? start=0, int? end=9223372036854775807,
// int step=1) -> str
// %3 and %4 use slice.Tensor while %5 use slice.str.
// Since we can see %3 and %4 have the same last argument that is never used
@@ -114,7 +114,7 @@ std::shared_ptr<Graph> build_mobile_export_analysis_graph() {
%22 : str = prim::Constant[value="value"]()
%3 : Tensor = aten::slice(%0, %1, %20, %2, %1)
%4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
%5 : str = aten::slice(%22, %20, %21, %2)
%5 : str = aten::slice(%22, %20, %21, %1)
return (%3, %4, %5))IR";

auto g = std::make_shared<Graph>();
@@ -139,7 +139,7 @@ std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested() {
%c : Tensor = prim::If(%23)
block0():
%4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
%5 : str = aten::slice(%22, %20, %21, %2)
%5 : str = aten::slice(%22, %20, %21, %1)
%c.1 : Tensor = aten::slice(%0, %1, %20, %2, %1)
-> (%c.1)
block1():
8 changes: 5 additions & 3 deletions test/jit/test_ignorable_args.py
@@ -1,5 +1,6 @@
import os
import sys

from torch._C import parse_ir
from torch.testing import FileCheck

@@ -17,6 +18,7 @@
class TestIgnorableArgs(JitTestCase):
def test_slice_ignorable_args_for_slice(self):
graph_str = """graph():
%15 : int = prim::Constant[value=9223372036854775807]()
%13 : int = prim::Constant[value=0]()
%10 : bool = prim::Constant[value=0]()
%8 : NoneType = prim::Constant()
@@ -29,8 +31,8 @@ def test_slice_ignorable_args_for_slice(self):
%6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
%7 : int[][] = prim::ListConstruct(%5, %6)
%val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
%16 : Tensor = aten::slice(%val.1, %13, %1, %8, %0)
%20 : Tensor = aten::slice(%16, %0, %8, %0, %0)
%16 : Tensor = aten::slice(%val.1, %13, %1, %15, %0)
%20 : Tensor = aten::slice(%16, %0, %13, %0, %0)
return (%20)"""
graph = parse_ir(graph_str)
function = self.createFunctionFromGraph(graph)
@@ -41,5 +43,5 @@ def test_slice_ignorable_args_for_slice(self):
# We ignore trailing arguments after start=2 for dim 0
# and after end=1 for dim 1
# because in %16, %15 and %0 are default values for the schema.
FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src)
FileCheck().check("torch.slice(torch.tensor(_0), 0, 2), 1, 0, 1)").run(src)
self.assertEqual(function(), function_copy())
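
As a quick way to see how scripted slices and their default arguments show up in practice (my own sketch, not part of the diff; the exact printed form depends on the PyTorch version):

```python
import torch


@torch.jit.script
def slice_rows(x: torch.Tensor) -> torch.Tensor:
    return x[2:, :1]


# The printed code and graph show the aten::slice calls together with the
# constant arguments the compiler filled in for the omitted start/end/step.
print(slice_rows.code)
print(slice_rows.graph)
```
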
187 changes: 0 additions & 187 deletions test/test_jit.py
@@ -4626,193 +4626,6 @@ def test(backward=False):
test(backward=True)
test(backward=True)

def test_index(self):
def consec(size, start=0):
numel = torch.tensor(size).prod().item()
return torch.arange(numel).view(size)

def consec_list(size):
return list(range(size))

def random_string(size):
letters = string.ascii_lowercase
return "".join(random.choice(letters) for i in range(size))

def check_indexing(indexing, tensor):
template = dedent("""
def func(x):
return x{}
""")

self._check_code(template.format(indexing), "func", [tensor])

def check_dynamic_indexing(indexing, tensor, value1, value2):
value1 = torch.tensor(value1)
value2 = torch.tensor(value2)

template = dedent("""
def func(x, value1, value2):
i = int(value1)
j = int(value2)
return x{}
""")

self._check_code(template.format(indexing), "func", [tensor, value1, value2])

# Torchscript assumes type Tensor by default, so we need this explicit
# declaration.
def check_indexing_list_int(indexing, list):
template = dedent("""
def func(x):
# type: (List[int]) -> Any
return x{}
""")

self._check_code(template.format(indexing), "func", [list])

def check_indexing_str(indexing, str):
template = dedent("""
def func(x):
# type: (str) -> Any
return x{}
""")

self._check_code(template.format(indexing), "func", [str])

# basic slices
check_indexing('[0]', consec((3, 3)))
check_indexing('[1]', consec((3, 3), 10))
check_indexing('[2]', consec((3, 3), 19))
check_indexing('[2]', consec((3,)))
check_indexing('[-1]', consec((3, 3), 19))
check_indexing('[0:2]', consec((3, 3, 3)))
check_indexing('[1:-1]', consec((3, 3, 3)))
check_indexing('[-3:-1]', consec((6, 3)))
check_indexing('[1:]', consec((3, 3)))
check_indexing('[:1]', consec((3, 3)))
check_indexing('[:]', consec((3, 2)))

# multi-dim: indexes
check_indexing('[0, 1]', consec((3, 3)))
check_indexing('[0, 1]', consec((3, 3, 2)))
check_indexing('[1, 0, 2]', consec((3, 3, 3)))
check_indexing('[2, -1]', consec((3, 3)))

# multi-dim: mixed slicing and indexing
check_indexing('[0, 1:2]', consec((3, 3)))
check_indexing('[0, :1]', consec((3, 3, 2)))
check_indexing('[1, 2:]', consec((3, 3, 3)))
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))

# zero-sized slices
check_indexing('[0:0]', consec((2, 2)))
check_indexing('[0:0, 1]', consec((3, 3)))

# trivial expression usage
check_indexing('[1+1]', consec((3, 3)))
check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))

# None for new dimensions
check_indexing('[None, 0]', consec((3, 3)))
check_indexing('[1, None]', consec((3, 3), 10))
check_indexing('[None, None, 2]', consec((3, 3), 19))
check_indexing('[None, 2, None]', consec((3,)))
check_indexing('[0:2, None]', consec((3, 3, 3)))
check_indexing('[None, 1:-1]', consec((3, 3, 3)))
check_indexing('[None, -3:-1, None]', consec((6, 3)))
check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3)))
check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3)))

# dynamic expression usage
check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)

# positive striding
check_indexing_list_int('[0]', consec_list(6))
check_indexing_list_int('[1]', consec_list(7))
check_indexing_list_int('[2]', consec_list(8))
check_indexing_list_int('[2]', consec_list(9))
check_indexing_list_int('[-1]', consec_list(10))
check_indexing_list_int('[0:2]', consec_list(11))
check_indexing_list_int('[1:-1]', consec_list(12))
check_indexing_list_int('[-3:-1]', consec_list(13))
check_indexing_list_int('[1:]', consec_list(15))
check_indexing_list_int('[:1]', consec_list(16))
check_indexing_list_int('[:]', consec_list(17))
check_indexing_list_int('[::]', consec_list(0))
check_indexing_list_int('[1000::]', consec_list(0))
check_indexing_list_int('[:1000:]', consec_list(0))

# negative striding
check_indexing_list_int('[::-1]', consec_list(7))
check_indexing_list_int('[:3:-1]', consec_list(7))
check_indexing_list_int('[3::-1]', consec_list(7))
check_indexing_list_int('[1000::-1]', consec_list(7))
check_indexing_list_int('[3:0:-1]', consec_list(7))
check_indexing_list_int('[3:-1000:-1]', consec_list(7))
check_indexing_list_int('[0:0:-1]', consec_list(7))
check_indexing_list_int('[0:-1000:-1]', consec_list(7))

# only step is specified
check_indexing_list_int('[::-1]', consec_list(0))
check_indexing_list_int('[::-1]', consec_list(7))
check_indexing_list_int('[::-2]', consec_list(7))
check_indexing_list_int('[::2]', consec_list(7))
check_indexing_list_int('[::42]', consec_list(7))
check_indexing_list_int('[::-42]', consec_list(7))
check_indexing_list_int('[::42]', consec_list(0))
check_indexing_list_int('[::-42]', consec_list(0))
check_indexing_list_int('[::9223372036854775807]', consec_list(42))
check_indexing_list_int('[::-9223372036854775807]', consec_list(42))
with self.assertRaisesRegex(RuntimeError, "out of bounds"):
check_indexing_list_int('[::-9223372036854775808]', consec_list(42))
with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
check_indexing_list_int('[::0]', consec_list(42))

# striding strings
check_indexing_str('[0]', random_string(6))
check_indexing_str('[1]', random_string(7))
check_indexing_str('[2]', random_string(8))
check_indexing_str('[2]', random_string(9))
check_indexing_str('[-1]', random_string(10))
check_indexing_str('[0:2]', random_string(11))
check_indexing_str('[1:-1]', random_string(12))
check_indexing_str('[-3:-1]', random_string(13))
check_indexing_str('[1:]', random_string(15))
check_indexing_str('[:1]', random_string(16))
check_indexing_str('[:]', random_string(17))
check_indexing_str('[::]', random_string(0))
check_indexing_str('[1000::]', random_string(0))
check_indexing_str('[:1000:]', random_string(0))

check_indexing_str('[::-1]', random_string(7))
check_indexing_str('[:3:-1]', random_string(7))
check_indexing_str('[3::-1]', random_string(7))
check_indexing_str('[1000::-1]', random_string(7))
check_indexing_str('[3:0:-1]', random_string(7))
check_indexing_str('[3:-1000:-1]', random_string(7))
check_indexing_str('[0:0:-1]', random_string(7))
check_indexing_str('[0:-1000:-1]', random_string(7))

check_indexing_str('[::-1]', random_string(0))
check_indexing_str('[::-1]', random_string(7))
check_indexing_str('[::-2]', random_string(7))
check_indexing_str('[::2]', random_string(7))
check_indexing_str('[::42]', random_string(7))
check_indexing_str('[::-42]', random_string(7))
check_indexing_str('[::42]', random_string(0))
check_indexing_str('[::-42]', random_string(0))
check_indexing_str('[::9223372036854775807]', random_string(42))
check_indexing_str('[::-9223372036854775807]', random_string(42))
with self.assertRaisesRegex(RuntimeError, "out of bounds"):
check_indexing_str('[::-9223372036854775808]', random_string(42))
with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
check_indexing_str('[::0]', random_string(42))

def test_module_copy_with_attributes(self):
class Vocabulary(torch.jit.ScriptModule):
def __init__(self, vocab_list):
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -1052,7 +1052,7 @@
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh().conj()

- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step)
result: auto_linear

26 changes: 17 additions & 9 deletions torch/csrc/jit/frontend/ir_emitter.cpp
@@ -3758,7 +3758,7 @@ struct to_ir {
Value* end,
Value* step) {
std::vector<NamedValue> args;
args.reserve(5);
args.reserve(4);
args.emplace_back(loc, "self", sliceable);

// XXX: If list slicing becomes more complicated or stops using
@@ -3770,10 +3770,11 @@
} else {
AT_ASSERT(!sliceable->type()->isSubtypeOf(TensorType::get()));
}

// TODO for now let's deal with TupleType first. Ideally all list, tensor,
// string, and tuple slicing should be same (tugsbayasgalan)
if (sliceable->type()->cast<TupleType>()) {
std::vector<at::optional<NamedValue>> tuple_args;
// since we are only dealing with tuple slicing, we try to keep
// since we are only dealing with tuple slicing for now, we try to keep
// tuple args seperate for now
tuple_args.reserve(3);

@@ -3787,15 +3788,22 @@
return emitTupleSlice(loc, args[0], tuple_args);
}

// handling cases like x[0:2]. x[0:2:] is already handled from python
// TODO this needs to be cleaned for list slicing
// Default value for start is 0.
if (!start) {
start = graph->insertConstant(0, loc);
}
args.emplace_back(loc, "start", start);

if (end) {
args.emplace_back(loc, "end", end);
}

if (!step) {
step = graph->insertConstant(1, loc);
}

args.emplace_back(loc, "start", start);
args.emplace_back(loc, "end", end);
args.emplace_back(loc, "step", step);
return emitBuiltinCall(loc, *graph, aten::slice, args, {});
NamedValue step_nv = NamedValue(loc, "step", step);
return emitBuiltinCall(loc, *graph, aten::slice, args, {step_nv});
}

// Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,
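
A minimal scripted example (my own sketch, not from the diff) that goes through the list-slicing code path touched above, where an omitted `start` gets a default constant inserted by the emitter:

```python
from typing import List

import torch


@torch.jit.script
def take_prefix(xs: List[int]) -> List[int]:
    # xs[:2] omits start, so the emitter fills in the default start value.
    return xs[:2]


print(take_prefix([1, 2, 3, 4]))  # [1, 2]
print(take_prefix.graph)          # shows the aten::slice call on the list
```
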
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/shape_analysis.cpp
@@ -884,7 +884,7 @@ class ShapePropagator {
"aten::trunc(Tensor self) -> Tensor",
"aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
"aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
"aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor",
"aten::slice(Tensor self, int dim, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor",
"aten::alias(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {