From 1aabb8f98c42ff262473ff1362e166b3fb2c9f25 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 27 May 2021 12:03:59 -0700 Subject: [PATCH] [ONNX] handle aten::_set_item on Dict in convertInplaceOpsAndTrackAlias (#58317) (#58696) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58696 It seems the JIT produces an output for aten::_set_item on lists but not on dicts. Previously the code would crash because it assumed it was operating on a list. The different behavior can be seen with the following test: ```python class DictModule(torch.nn.Module): def forward(self, x_in: torch.Tensor) -> typing.Dict[str, torch.Tensor]: x_out = {} x_out["test_key_out"] = x_in return x_out x_in = torch.tensor(1) dms = torch.jit.script(DictModule()) torch.onnx.export(dms, (x_in,), "/dev/null", example_outputs=(dms(x_in),)) ``` Before this change: `RuntimeError: outputs_.size() == 1INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/ir.h":452, please report a bug to PyTorch.` After this change: `RuntimeError: Exporting the operator prim_DictConstruct to ONNX opset version 9 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.` This is a more useful error message. Test Plan: Imported from OSS Reviewed By: driazati Differential Revision: D28714804 Pulled By: SplitInfinity fbshipit-source-id: 1e5dc5fb44d1e3f971a22a79b5cf009d7590bf84 Co-authored-by: Gary Miguel --- test/jit/test_onnx_export.py | 17 +++++++++++++++++ .../passes/onnx/remove_inplace_ops_for_onnx.cpp | 11 +++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/test/jit/test_onnx_export.py b/test/jit/test_onnx_export.py index 9012fd36242f9..fdc30c22063a2 100644 --- a/test/jit/test_onnx_export.py +++ b/test/jit/test_onnx_export.py @@ -1,6 +1,7 @@ import io import os import sys +import typing import torch import torch.nn as nn @@ -381,3 +382,19 @@ def forward(self, x): f = io.BytesIO() torch.onnx.export_to_pretty_string( DynamicSliceExportMod(), (input,), f, example_outputs=example_outs, opset_version=10) + + def test_export_dict(self): + class DictModule(torch.nn.Module): + def forward(self, x_in: torch.Tensor) -> typing.Dict[str, torch.Tensor]: + return {"test_key_out": x_in} + + x_in = torch.tensor(1) + mod = DictModule() + mod.train(False) + + f = io.BytesIO() + torch.onnx.export_to_pretty_string(mod, (x_in,), f) + + with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."): + torch.onnx.export_to_pretty_string( + torch.jit.script(mod), (x_in,), f, example_outputs=(mod(x_in),)) diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index be34c73d134a2..d9c275e25ba86 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -270,8 +270,15 @@ static std::pair PrepareListAppendAndInsertForONNX(Node* n) { return std::make_pair(n->input(0), n->output()); } -static std::pair PrepareListSetItemForONNX(Node* n) { +static std::pair PrepareSetItemForONNX(Node* n) { TORCH_INTERNAL_ASSERT(n->kind() == aten::_set_item); + // It seems the JIT does not always produce an output for _set_item. + // In particular it seems to for list but not for dict. + // So we add one if needed. + if (n->outputs().size() == 0) { + n->addOutput(); + n->output()->setType(n->inputs().at(0)->type()); + } return std::make_pair(n->input(0), n->output()); } @@ -807,7 +814,7 @@ void InplaceConverter::convertInplaceOpsAndTrackAlias(Block* block) { } else if (nkind == aten::Delete) { std::tie(orig_data, new_out) = PrepareListDeleteForONNX(n); } else if (nkind == aten::_set_item) { - std::tie(orig_data, new_out) = PrepareListSetItemForONNX(n); + std::tie(orig_data, new_out) = PrepareSetItemForONNX(n); } else { // Not inplace op. continue;