Skip to content

Commit

Permalink
[ONNX] Update special post process for SequenceInsert after SequenceE…
Browse files Browse the repository at this point in the history
…mpty (#56965)

`ONNX::SequenceEmpty` requires dtype to be provided, and is default to float. We updates previous dtype of created `ONNX::SequenceEmpty` node when dtype is later discovered to be other than float, through downstream `ONNX::SequenceInsert` node. This PR improves the algorithm to cover nested loop case.

Co-authored-by: BowenBao <bowbao@microsoft.com>
  • Loading branch information
BowenBao authored May 18, 2021
1 parent b98b622 commit 1d8059b
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 8 deletions.
18 changes: 18 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4797,6 +4797,24 @@ def forward(self, x):
x = torch.randn(4, 4, 3, 4)
self.run_test(model, (x, ))

@skipIfUnsupportedMinOpsetVersion(13)
def test_list_append_nested_mixed_dtype(self):
class ListModel(torch.nn.Module):
def forward(self, x, y):
res = []
for i in range(x.size(0)):
for j in range(x.size(1)):
if i == j:
res.append(x == y)
else:
res.append(x != y)
return res

model = torch.jit.script(ListModel())
x = torch.randn(4, 4, 3, 4)
y = torch.randn(3, 4)
self.run_test(model, (x, y))

@skipIfUnsupportedMinOpsetVersion(11)
def test_list_pop(self):
class ListModel(torch.nn.Module):
Expand Down
86 changes: 78 additions & 8 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,18 +1170,88 @@ void SpecialPostProcess(Node* n) {
// If the list to insert is empty, we set the elem type by
// looking at the tensor being inserted.
auto list_node = n->input(0)->node();
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto t_node = n->input(1)->node();
if (!list_node || list_node->kind() != prim::ListConstruct ||
list_node->inputs().size() != 0) {
break;
}
auto seq_node = n->input(0)->node();
auto t_type = n->input(1)->type()->cast<TensorType>();

auto update_sequence_empty_dtype = [](Node* n, TensorTypePtr t_type) {
TORCH_INTERNAL_ASSERT(n && n->kind() == ::c10::onnx::SequenceEmpty);
TORCH_INTERNAL_ASSERT(t_type && t_type->scalarType().has_value());
auto scalar_type = t_type->scalarType().value();
auto onnx_type = ATenTypeToOnnxType(scalar_type);
n->i_(attr::dtype, onnx_type);
n->output()->setType(ListType::create(t_type));
};

auto find_sequence_empty = [](Value* input,
TensorTypePtr t_type) -> Node* {
auto find_sequence_empty_impl =
[](Value* input,
TensorTypePtr t_type,
auto& find_sequence_empty_ref) -> Node* {
auto input_node = input->node();
TORCH_INTERNAL_ASSERT(input_node);

// 1. Input is from SequenceEmpty.
if (input_node->kind() == ::c10::onnx::SequenceEmpty) {
return input_node;
}

if (TensorTypePtr t_type = n->input(1)->type()->cast<TensorType>()) {
if (t_type->scalarType()) {
// 2. Input is subblock input of a Loop node, which takes outer block
// SequenceEmpty as input.
if (input_node->kind() == prim::Param) {
auto loop_n = input_node->owningBlock()->owningNode();
if (nullptr == loop_n || loop_n->kind() != ::c10::onnx::Loop) {
return nullptr;
}

auto it = std::find(
input_node->outputs().begin(),
input_node->outputs().end(),
input);
auto idx = std::distance(input_node->outputs().begin(), it);

auto outer_block_node = loop_n->input(idx)->node();
if (outer_block_node &&
outer_block_node->kind() == ::c10::onnx::SequenceEmpty) {
// Found SequenceEmpty
input->setType(ListType::create(t_type));
return outer_block_node;
} else {
// Outer block node still not SequenceEmpty, call recursively in
// case of nested loop.
auto found_n = find_sequence_empty_ref(
loop_n->input(idx), t_type, find_sequence_empty_ref);
if (found_n) {
input->setType(ListType::create(t_type));
}
return found_n;
}
}

// Could not find source SequenceEmpty node.
return nullptr;
};
return find_sequence_empty_impl(
input, t_type, find_sequence_empty_impl);
};

if (seq_node && t_type && t_type->scalarType()) {
if (seq_node->kind() == prim::ListConstruct &&
seq_node->inputs().size() != 0) {
// When prim::ListConstruct is not yet converted to
// onnx::SequenceEmpty
n->output()->setType(ListType::create(t_type));
} else if (seq_node->kind() == ::c10::onnx::SequenceEmpty) {
update_sequence_empty_dtype(seq_node, t_type);
} else if (seq_node->kind() == prim::Param) {
// Try to find original onnx::SequenceEmpty node in outer block.
auto seq_empty_n = find_sequence_empty(n->input(0), t_type);
if (seq_empty_n) {
update_sequence_empty_dtype(seq_empty_n, t_type);
}
}
}

break;
}
case ::c10::onnx::Cast: {
Expand Down

0 comments on commit 1d8059b

Please sign in to comment.