Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Improve lower tuples and handle control flow #57650

Merged
merged 8 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
initial updates for lower tuples
  • Loading branch information
neginraoof committed May 18, 2021
commit 8d8749b378f1c574d997cbaef501b98b1e3e9a7c
40 changes: 40 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6781,6 +6781,46 @@ def forward(self, input):
input = torch.randn(2, 5, 7, dtype=torch.float64)
self.run_test(Celu(), (input,))

def test_lower_tuple(self):
    """Exercise the lower-tuples JIT pass: tuples flow through nested loops
    and an if-branch, being repacked/unpacked at every level, before the
    final elements are combined and returned."""
    class TupleModule(torch.nn.Module):
        def forward(self, input1, input2):
            # type: (torch.Tensor, torch.Tensor) -> torch.Tensor
            a = (input1, input2)
            b = a
            c = b
            for i in range(5):
                d = a[0]
                for j in range(2):
                    e, f = a
                    a = (d, f)
                # data-dependent branch: both arms rebuild the tuple `b`
                if f.size(0) != input1.size(-1):
                    g = a[1]
                    b = (g, f)
                else:
                    k = a[0:]  # tuple slice, lowered by prim::TupleSlice
                    b = (f, k[0])
                m, n = b
                c = (n, m)
            p, q = c
            return p + q

    input1 = torch.randn(2)
    input2 = torch.randn(2)
    self.run_test(TupleModule(), (input1, input2))

def test_lower_tuple_2(self):
    """Lower-tuples pass with a tuple that is both loop-carried and returned
    directly from forward (tuple-typed model output)."""
    class TupleModule(torch.nn.Module):
        def forward(self, input):
            # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
            a = (input, input)
            for x in range(5):
                c, d = a
                a = (c, d)
            return a

    input = torch.randn(2)
    self.run_test(TupleModule(), (input,))

@skipIfUnsupportedMinOpsetVersion(9)
def test_where(self):
class Model(torch.nn.Module):
Expand Down
191 changes: 136 additions & 55 deletions torch/csrc/jit/passes/lower_tuples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,87 @@ std::unordered_set<Symbol> supported_ops = {
prim::Return,
prim::PythonOp,
aten::format,
prim::Uninitialized,
aten::__getitem__
};

// Flatten a tuple-typed input of a block-owning node (e.g. prim::Loop):
// the tuple elements become individual node inputs / block params, and a
// prim::TupleConstruct is prepended inside the block so existing uses of
// the tuple param keep working unchanged.
static void flattenTupleInBlockParam(Node* n, size_t index) {
  auto input = n->inputs()[index];
  TupleTypePtr tt = input->type()->cast<TupleType>();
  // Caller guarantees a tuple-typed input fed by a TupleConstruct.
  TORCH_INTERNAL_ASSERT(
      tt != nullptr, "flattenTupleInBlockParam expects a tuple-typed input");

  Block* block = n->blocks().at(0);

  // Rebuild the tuple at the top of the block from the new flattened params.
  auto new_construct_node =
      block->prependNode(block->owningGraph()->create(prim::TupleConstruct));
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    auto new_block_in = block->addInput();
    new_construct_node->addInput(new_block_in);
    // Forward the j-th element of the TupleConstruct feeding `n`.
    n->addInput(input->node()->inputs().at(j));
  }
  // Block params are offset by one relative to the node inputs: the block's
  // first input is the loop iteration counter — hence `index - 1`.
  new_construct_node->output()->setType(block->inputs().at(index - 1)->type());
  block->inputs().at(index - 1)->replaceAllUsesWith(
      new_construct_node->output());
  block->eraseInput(index - 1);
  n->removeInput(index);
}

// Flatten a tuple-typed value returned from a block: register each tuple
// element as a separate block output, and — when there is an owning outer
// node whose corresponding output is still tuple-typed — append a
// TupleConstruct after that node so outer uses keep seeing a tuple.
static void flattenTupleInBlockReturn(Node* n, size_t index) {
  auto input = n->inputs()[index];
  Block* block = n->owningBlock();
  Node* block_node = block->owningNode();
  TupleTypePtr tt = input->type()->cast<TupleType>();
  TORCH_INTERNAL_ASSERT(
      tt != nullptr, "flattenTupleInBlockReturn expects a tuple-typed input");

  for (size_t j = 0; j < tt->elements().size(); ++j) {
    block->registerOutput(input->node()->inputs().at(j));
  }
  block->eraseOutput(index);

  // Top-level graph return: nothing outer to patch up.
  if (block_node == nullptr)
    return;

  if (block_node->kind() == prim::Loop)
    index = index - 1; // Loop block has an extra element (iter counter)
  auto tuple_output = block_node->outputs().at(index);
  if (!(tuple_output->type()->cast<TupleType>()))
    return; // When node has multiple blocks, do not flatten outputs again

  Node* new_construct_node =
      block->owningGraph()->create(prim::TupleConstruct);
  new_construct_node->insertAfter(block_node);
  for (size_t j = 0; j < tt->elements().size(); ++j) {
    auto new_block_out = block_node->addOutput();
    new_construct_node->addInput(new_block_out);
  }
  // Redirect outer uses of the tuple output to the new TupleConstruct.
  new_construct_node->output()->setType(tuple_output->type());
  tuple_output->replaceAllUsesWith(new_construct_node->output());
  block_node->eraseOutput(index);
}

void removeTupleNodes(Node* n, bool must_remove_tuples) {
if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
n->kind() != prim::TupleSlice) {
return;
}
// tuple index has two inputs, tuple and index
auto construct = n->inputs().at(0)->node();
if (construct->kind() != prim::TupleConstruct) {
auto construct_node = n->inputs().at(0)->node();
if (construct_node->kind() != prim::TupleConstruct) {
if (must_remove_tuples) {
AT_ERROR(n->kind().toQualString(), " not matched to tuple construct");
}
return;
}
if (n->kind() == prim::TupleUnpack) {
for (size_t i = 0; i < n->outputs().size(); ++i) {
n->outputs()[i]->replaceAllUsesWith(construct->inputs().at(i));
n->outputs()[i]->replaceAllUsesWith(construct_node->inputs().at(i));
}
} else if (n->kind() == prim::TupleIndex) {
if ((construct_node->kind() == prim::If) || (construct_node->kind() == prim::Loop)) {
return;
}
auto idx = n->inputs().at(1);
auto maybe_int = constant_as<int64_t>(idx);
if (!maybe_int) {
Expand All @@ -54,21 +115,21 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) {
return;
}
auto int_idx = *maybe_int;
auto len = construct->output()->type()->containedTypes().size();
size_t len = construct_node->output()->type()->containedTypes().size();
if (int_idx < 0) {
int_idx += len;
}
// currently, we allow non-constant tuple index if the tuple is of one type.
// so we need to check bounds here
if (int_idx >= 0 && static_cast<size_t>(int_idx) < len) {
n->output()->replaceAllUsesWith(construct->inputs().at(int_idx));
n->output()->replaceAllUsesWith(construct_node->inputs().at(int_idx));
}
} else if (n->kind() == prim::TupleSlice) {
std::vector<Value*> values;
int64_t beg = n->i(attr::beg);
int64_t end = n->i(attr::end);
for (int64_t i = beg; i < end; i += 1) {
values.push_back(construct->inputs().at(i));
values.push_back(construct_node->inputs().at(i));
}
auto graph = n->owningGraph();
auto tuple_out = graph->createTuple(values);
Expand Down Expand Up @@ -108,92 +169,112 @@ static void RemoveTupleConstants(Node* n) {
n->replaceAllUsesWith(tuple_construct);
}

static void VisitNode(Node* n, Node* insert_point) {
auto& graph = *n->owningGraph();

// tuple construction operators will become dead when the unpacks are replaced
if (n->kind() == prim::TupleConstruct) {
return;
}

// note: changing the second argument to false changes this pass from a
// complete lowering pass to one that removes tuples when possible. When
// tuples are first-class in the interpreter, we should still run this pass to
// remove extraneous uses

if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
n->kind() == prim::TupleSlice) {
removeTupleNodes(n, /*must_remove_tuples*/ true);
return;
}

// Flatten tuple-typed inputs of `n`:  op(a, tup, b) --> op(a, t0, t1, b).
// Loop and Return nodes carry tuples through blocks and get the dedicated
// block-aware flattening helpers; all other supported ops are flattened
// in place.
static void flattenInputs(Node* n, Node* insert_point) {
  for (size_t i = 0; i < n->inputs().size();) {
    auto input = n->inputs()[i];
    if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
      // A tuple use must be traceable back to its construction site.
      TORCH_CHECK(
          (input->node()->kind() == prim::TupleConstruct),
          "tuple use not matched to tuple construct. Instead found: ",
          n->kind().toQualString());
      if (supported_ops.count(n->kind()) > 0) {
        if (n->kind() == prim::Loop) {
          // Loop-carried tuple: flatten through the loop block params.
          flattenTupleInBlockParam(n, i);
        } else if (n->kind() == prim::Return) {
          // Tuple returned from a block: flatten through the block outputs.
          flattenTupleInBlockReturn(n, i);
        } else {
          // Splice the tuple elements in after position i, then drop the
          // tuple input itself.
          for (size_t j = 0; j < tt->elements().size(); ++j) {
            n->insertInput(i + 1 + j, input->node()->inputs().at(j));
          }
          n->removeInput(i);
        }
        // note: no update to i
        // since tuples might be nested we need to recursively scan
        // the new flattened inputs
      } else {
        TORCH_WARN(
            "tuple appears in op inputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}

// Flatten tuple-typed outputs of `n`:  (a, b, tup, c) -> (a, b, t0, t1, c),
// with a replacement
//   tup = (t0, t1)
// TupleConstruct placed at the current insertion point so existing uses of
// the tuple output keep working.
static void flattenOutputs(Node* n, Node* insert_point) {
  auto& graph = *n->owningGraph();
  for (size_t i = 0; i < n->outputs().size();) {
    Value* output = n->outputs()[i];
    if (!output->hasUses()) {
      // Dead output: nothing to rewrite, keep scanning.
      ++i;
      continue;
    }
    if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
      if (supported_ops.count(n->kind()) > 0) {
        // Insert one typed output per tuple element right after position i.
        for (size_t j = 0; j < tt->elements().size(); j++) {
          n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
        }
        auto new_tup =
            graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
        new_tup->insertBefore(insert_point);
        insert_point = new_tup;
        output->replaceAllUsesWith(new_tup->output());
        n->eraseOutput(i);
        // note: no update to i to handle nested tuples
      } else {
        TORCH_WARN(
            "tuple appears in the op outputs, but this op does not forward tuples, ",
            "unsupported kind: ",
            n->kind().toQualString());
        ++i;
      }
    } else {
      ++i;
    }
  }
}

// Visit one node of the graph and lower any tuples it touches:
// tuple-manipulating nodes are removed outright, and for all other nodes
// the tuple inputs are flattened, nested blocks are recursed into, and
// tuple outputs are flattened last.
static void VisitNode(Node* n, Node* insert_point) {
  // tuple construction operators will become dead when the unpacks are replaced
  if (n->kind() == prim::TupleConstruct) {
    return;
  }
  // note: changing the second argument to false changes this pass from a
  // complete lowering pass to one that removes tuples when possible. When
  // tuples are first-class in the interpreter, we should still run this pass to
  // remove extraneous uses
  if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
      n->kind() == prim::TupleSlice) {
    removeTupleNodes(n, /*must_remove_tuples*/ true);
    return;
  }
  // Inputs first, so block params see already-flattened tuples; then the
  // blocks themselves; outputs last, after inner returns have been lowered.
  flattenInputs(n, insert_point);
  for (auto b : n->blocks()) {
    LowerAllTuples(b);
  }
  flattenOutputs(n, insert_point);
}

static void LowerAllTuples(Block* block) {
// tuples in parameter lists of a block behave exactly the same as
// _outputs_ of normal instructions, since the param_node represents the
// parameters as outputs, we can handle it by simply visiting the node
VisitNode(block->param_node(), *block->nodes().begin());
for (auto it = block->nodes().begin(), end = block->nodes().end();
it != end;) {
it != end;) {
auto n = *it++;
RemoveTupleConstants(n);
VisitNode(n, *it);
Expand Down