[ONNX] handle aten::_set_item on Dict in convertInplaceOpsAndTrackAlias (#58317)

It seems the JIT produces an output for aten::_set_item on lists but
not on dicts. Previously the code would crash because it assumed it
was operating on a list.

The difference in behavior can be seen with the following test:

```python
import typing

import torch


class DictModule(torch.nn.Module):
    def forward(self, x_in: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
        x_out = {}
        x_out["test_key_out"] = x_in
        return x_out

x_in = torch.tensor(1)
dms = torch.jit.script(DictModule())
torch.onnx.export(dms, (x_in,), "/dev/null", example_outputs=(dms(x_in),))
```
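The schema difference behind this crash can be inspected directly. The following is a minimal sketch (not part of this commit, and the function names are illustrative): scripting an item assignment on a list and on a dict produces `aten::_set_item` in both graphs, but only the list overload binds an output value, which is what the old list-only code assumed.

```python
import typing

import torch


@torch.jit.script
def set_list_item(xs: typing.List[int]) -> typing.List[int]:
    # Compiles to aten::_set_item on a list, which returns the list.
    xs[0] = 1
    return xs


@torch.jit.script
def set_dict_item(d: typing.Dict[str, int]) -> typing.Dict[str, int]:
    # Compiles to aten::_set_item on a dict, which returns nothing.
    d["k"] = 1
    return d


# Both graphs contain aten::_set_item; printing them shows that only
# the list version binds a result value.
print(set_list_item.graph)
print(set_dict_item.graph)
```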

Before this change:
`RuntimeError: outputs_.size() == 1INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/ir.h":452, please report a bug to PyTorch.`

After this change:
`RuntimeError: Exporting the operator prim_DictConstruct to ONNX opset version 9 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.`

This is a more useful error message.

Co-authored-by: Gary Miguel <garymiguel@microsoft.com>

[ghstack-poisoned]
garymm authored and BowenBao committed May 20, 2021
1 parent eca525e commit 2dfb4c1
Showing 2 changed files with 26 additions and 2 deletions.
test/jit/test_onnx_export.py (17 additions, 0 deletions)

```diff
--- a/test/jit/test_onnx_export.py
+++ b/test/jit/test_onnx_export.py
@@ -1,6 +1,7 @@
 import io
 import os
 import sys
+import typing
 
 import torch
 import torch.nn as nn
@@ -381,3 +382,19 @@ def forward(self, x):
         f = io.BytesIO()
         torch.onnx.export_to_pretty_string(
             DynamicSliceExportMod(), (input,), f, example_outputs=example_outs, opset_version=10)
+
+    def test_export_dict(self):
+        class DictModule(torch.nn.Module):
+            def forward(self, x_in: torch.Tensor) -> typing.Dict[str, torch.Tensor]:
+                return {"test_key_out": x_in}
+
+        x_in = torch.tensor(1)
+        mod = DictModule()
+        mod.train(False)
+
+        f = io.BytesIO()
+        torch.onnx.export_to_pretty_string(mod, (x_in,), f)
+
+        with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
+            torch.onnx.export_to_pretty_string(
+                torch.jit.script(mod), (x_in,), f, example_outputs=(mod(x_in),))
```
torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp (9 additions, 2 deletions)

```diff
--- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
+++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
@@ -270,8 +270,15 @@ static std::pair<Value*, Value*> PrepareListAppendAndInsertForONNX(Node* n) {
   return std::make_pair(n->input(0), n->output());
 }
 
-static std::pair<Value*, Value*> PrepareListSetItemForONNX(Node* n) {
+static std::pair<Value*, Value*> PrepareSetItemForONNX(Node* n) {
   TORCH_INTERNAL_ASSERT(n->kind() == aten::_set_item);
+  // It seems the JIT does not always produce an output for _set_item.
+  // In particular it seems to for list but not for dict.
+  // So we add one if needed.
+  if (n->outputs().size() == 0) {
+    n->addOutput();
+    n->output()->setType(n->inputs().at(0)->type());
+  }
   return std::make_pair(n->input(0), n->output());
 }
 
@@ -807,7 +814,7 @@ void InplaceConverter::convertInplaceOpsAndTrackAlias(Block* block) {
     } else if (nkind == aten::Delete) {
       std::tie(orig_data, new_out) = PrepareListDeleteForONNX(n);
     } else if (nkind == aten::_set_item) {
-      std::tie(orig_data, new_out) = PrepareListSetItemForONNX(n);
+      std::tie(orig_data, new_out) = PrepareSetItemForONNX(n);
     } else {
       // Not inplace op.
       continue;
```
