Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Improve lower tuples and handle control flow #57650

Merged
merged 8 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix for comments
  • Loading branch information
neginraoof committed May 18, 2021
commit 2882cd2b99d0b97fe112aa83256d86a4cd130034
5 changes: 3 additions & 2 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6793,11 +6793,12 @@ def forward(self, input1, input2):
for j in range(2):
e, f = a
a = (d, f)
f = c[1]
if f.size(0) != input1.size(-1):
g = a[1]
g = b[1]
b = (g, f)
else:
k = a[0:]
k = c[0:]
b = (f, k[0])
m, n = b
c = (n, m)
Expand Down
13 changes: 5 additions & 8 deletions torch/csrc/jit/passes/lower_tuples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ std::unordered_set<Symbol> supported_ops = {
aten::__getitem__};

// Flatten block inputs and insert a tuple construct in the block
static void flattenTupleInBlockParam(Node* n, size_t index) {
static void flattenTupleInLoopParams(Node* n, size_t index) {
auto input = n->inputs()[index];
neginraoof marked this conversation as resolved.
Show resolved Hide resolved
TupleTypePtr tt = input->type()->cast<TupleType>();
neginraoof marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -40,9 +40,9 @@ static void flattenTupleInBlockParam(Node* n, size_t index) {
auto new_construct_node =
block->prependNode(block->owningGraph()->create(prim::TupleConstruct));
for (size_t j = 0; j < tt->elements().size(); ++j) {
auto new_block_in = block->addInput();
auto new_block_in = block->insertInput(index + j );
new_construct_node->addInput(new_block_in);
block_node->addInput(input->node()->inputs().at(j));
block_node->insertInput(index + j + 1, input->node()->inputs().at(j));
}
new_construct_node->output()->setType(block->inputs().at(index - 1)->type());
block->inputs().at(index - 1)->replaceAllUsesWith(
Expand Down Expand Up @@ -110,10 +110,6 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) {
n->outputs()[i]->replaceAllUsesWith(construct_node->inputs().at(i));
}
} else if (n->kind() == prim::TupleIndex) {
if ((construct_node->kind() == prim::If) ||
(construct_node->kind() == prim::Loop)) {
return;
}
auto idx = n->inputs().at(1);
auto maybe_int = constant_as<int64_t>(idx);
if (!maybe_int) {
Expand Down Expand Up @@ -189,7 +185,8 @@ static void flattenInputs(Node* n, Node* insert_point) {
if (supported_ops.count(n->kind()) > 0) {
if ((n->kind() == prim::Loop)) {
neginraoof marked this conversation as resolved.
Show resolved Hide resolved
if (input->node()->kind() == prim::TupleConstruct) {
flattenTupleInBlockParam(n, i);
// This function supports all node types with blocks that take tuple inputs.
flattenTupleInLoopParams(n, i);
}
} else if ((n->kind() == prim::Return)) {
neginraoof marked this conversation as resolved.
Show resolved Hide resolved
if (input->node()->kind() == prim::TupleConstruct) {
neginraoof marked this conversation as resolved.
Show resolved Hide resolved
Expand Down