[ONNX] RNN scripting (pytorch#57564) (pytorch#58691)
Summary:
Pull Request resolved: pytorch#58691

Note: the first commit in this PR has its own pull request, since it seemed self-contained:
pytorch#57082

* [ONNX] simplify batch_first logic in RNN tests

* [ONNX] support GRU with packed input in scripting mode

This required two changes:
* Add as_tensor to symbolic_opset9.py
* Change torch::jit::pushPackingPastRnn to recognize and properly
  replace another use of the batch_sizes output of prim::PackPadded.
  Previously the code assumed that the first use was as input to the
  RNN operator. However, in some cases it is also used to compute
  max_batch_size, for example in this code:
  https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815

With these changes, the GRU tests now pass in scripting mode for opset
versions >= 11.
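
As context, here is a minimal sketch (mine, not part of this PR) of the scenario these changes enable: exporting a TorchScript-compiled GRU that consumes a packed sequence. The model and names are illustrative; on PyTorch releases of this era (~1.9), exporting a ScriptModule also needs the example_outputs argument, which later releases removed.

import io
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class PackedGRU(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(input_size=4, hidden_size=3, num_layers=1)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        # pack -> prim::PackPadded in the scripted graph
        packed = pack_padded_sequence(x, lengths)
        output, hidden = self.gru(packed)
        # unpack -> prim::PadPacked, later removed by removeNopPacking
        padded, out_lengths = pad_packed_sequence(output)
        return padded, out_lengths, hidden

model = torch.jit.script(PackedGRU())
x = torch.randn(5, 2, 4)            # (seq_len, batch, input_size)
lengths = torch.tensor([5, 3])      # descending, since enforce_sorted=True
torch.onnx.export(model, (x, lengths), io.BytesIO(),
                  opset_version=11,  # scripted GRU needs opset >= 11
                  example_outputs=model(x, lengths))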

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D28714805

Pulled By: SplitInfinity

fbshipit-source-id: f19647a04533d9ec76399a8793b3f712ea0337d2

Co-authored-by: Gary Miguel <garymiguel@microsoft.com>
2 people authored and deniskokarev committed Jun 9, 2021
1 parent f037b00 commit bb83f69
Showing 3 changed files with 88 additions and 48 deletions.
82 changes: 39 additions & 43 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -7066,25 +7066,19 @@ def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
            def forward(self, input: PackedSequence):
                return self.inner_model(input)

-        batch_first = True if packed_sequence == 2 else False
+        batch_first = packed_sequence == 2

        if initial_state:
            model = ElmanWithStateModel(layers=layers, bidirect=bidirectional, nonlinearity=nonlinearity,
                                        dropout=dropout, batch_first=batch_first)
-
-            if packed_sequence == 1:
-                model = RnnModelWithPackedSequenceWithState(model, False)
-            if packed_sequence == 2:
-                model = RnnModelWithPackedSequenceWithState(model, True)
+            if packed_sequence:
+                model = RnnModelWithPackedSequenceWithState(model, batch_first)
        else:
            model = ElmanWithStateModel(layers=layers, bidirect=bidirectional,
                                        nonlinearity=nonlinearity, dropout=dropout,
                                        batch_first=batch_first)
-
-            if packed_sequence == 1:
-                model = RnnModelWithPackedSequenceWithoutState(model, False)
-            if packed_sequence == 2:
-                model = RnnModelWithPackedSequenceWithoutState(model, True)
+            if packed_sequence:
+                model = RnnModelWithPackedSequenceWithoutState(model, batch_first)

        def make_input(batch_size):
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
@@ -7115,24 +7109,18 @@ def make_input(batch_size):

    def _lstm_test(self, layers, bidirectional, initial_state,
                   packed_sequence, dropout):
-        batch_first = True if packed_sequence == 2 else False
+        batch_first = packed_sequence == 2

-        if packed_sequence == 0:
-            model = LstmFlatteningResultWithoutSeqLength(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
-                                                         bidirectional, dropout, batch_first)
-        else:
+        if packed_sequence:
            model = LstmFlatteningResultWithSeqLength(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
                                                      bidirectional, dropout, batch_first)
            if initial_state:
-                if packed_sequence == 1:
-                    model = RnnModelWithPackedSequenceWithState(model, False)
-                if packed_sequence == 2:
-                    model = RnnModelWithPackedSequenceWithState(model, True)
+                model = RnnModelWithPackedSequenceWithState(model, batch_first)
            else:
-                if packed_sequence == 1:
-                    model = RnnModelWithPackedSequenceWithoutState(model, False)
-                if packed_sequence == 2:
-                    model = RnnModelWithPackedSequenceWithoutState(model, True)
+                model = RnnModelWithPackedSequenceWithoutState(model, batch_first)
+        else:
+            model = LstmFlatteningResultWithoutSeqLength(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
+                                                         bidirectional, dropout, batch_first)

        def make_input(batch_size):
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
@@ -7210,30 +7198,24 @@ def __init__(self, layers, bidirect, dropout, batch_first):
            def forward(self, input, hx):
                return self.inner_model(input, hx)

-        batch_first = True if packed_sequence == 2 else False
+        batch_first = packed_sequence == 2

-        if packed_sequence == 0:
-            if initial_state:
-                model = GRUNoSeqLengthWithStateModel(layers=layers, bidirect=bidirectional,
-                                                     dropout=dropout, batch_first=batch_first)
-            else:
-                model = GRUNoSeqLengthWithoutStateModel(layers=layers, bidirect=bidirectional,
-                                                        dropout=dropout, batch_first=batch_first)
-        else:
+        if packed_sequence:
            if initial_state:
                model = GRUWithStateModel(layers=layers, bidirect=bidirectional, dropout=dropout,
                                          batch_first=batch_first)
-                if packed_sequence == 1:
-                    model = RnnModelWithPackedSequenceWithState(model, False)
-                if packed_sequence == 2:
-                    model = RnnModelWithPackedSequenceWithState(model, True)
+                model = RnnModelWithPackedSequenceWithState(model, batch_first)
            else:
                model = GRUWithoutStateModel(layers=layers, bidirect=bidirectional, dropout=dropout,
                                             batch_first=batch_first)
-                if packed_sequence == 1:
-                    model = RnnModelWithPackedSequenceWithoutState(model, False)
-                if packed_sequence == 2:
-                    model = RnnModelWithPackedSequenceWithoutState(model, True)
+                model = RnnModelWithPackedSequenceWithoutState(model, batch_first)
+        else:
+            if initial_state:
+                model = GRUNoSeqLengthWithStateModel(layers=layers, bidirect=bidirectional,
+                                                     dropout=dropout, batch_first=batch_first)
+            else:
+                model = GRUNoSeqLengthWithoutStateModel(layers=layers, bidirect=bidirectional,
+                                                        dropout=dropout, batch_first=batch_first)

        def make_input(batch_size):
            seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
@@ -8899,7 +8881,7 @@ def forward(self, x):
                          dynamic_axes={'input_1': [0, 1]})

def make_test(name, base, layer, bidirectional, initial_state,
-              variable_length, dropout,
+              variable_length, dropout, script_test_min_opset_version,
              **extra_kwargs):
    test_name = str('_'.join([
        'test', name, layer[1],
@@ -8918,6 +8900,7 @@ def make_test(name, base, layer, bidirectional, initial_state,
    @disableScriptTest()
    @skipIfUnsupportedMinOpsetVersion(9)
    def f(self):
+        self.is_script_test_enabled = self.opset_version >= script_test_min_opset_version
        self._dispatch_rnn_test(
            base,
            layers=layer[0],
@@ -8967,8 +8950,21 @@ def setup_rnn_tests():
                ('lstm', 'lstm', {}),
                ('gru', 'gru', {})
        ):
+            # Need Add between list of tensors
+            script_test_min_opset_version = 11
+
+            if ( # compiling in script mode fails with errors like:
+                 # torch.jit.frontend.UnsupportedNodeError: annotated assignments
+                 # without assigned value aren't supported
+                 # https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
+                 base == 'elman' or
+                 # compiling in script mode fails with errors like:
+                 # RuntimeError: Arguments for call are not valid.
+                 # https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
+                 base == 'lstm'):
+                script_test_min_opset_version = float("inf")
            make_test(name, base, layer, bidirectional, initial_state,
-                      variable_length, dropout,
+                      variable_length, dropout, script_test_min_opset_version,
                      **extra_kwargs)
            test_count += 1
52 changes: 47 additions & 5 deletions torch/csrc/jit/passes/onnx/peephole.cpp
@@ -258,8 +258,9 @@ void pushPackingPastRnn(Block* b) {
      continue;
    }

-    if (rnn->owningBlock() != n->owningBlock())
+    if (rnn->owningBlock() != n->owningBlock()) {
      continue;
+    }

    // Packing only has an effect on a network when its outputs are actually
    // used, so we can remove it here.
@@ -286,9 +287,47 @@
    // remove PackPadded from in front of the RNN
    n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));

-    // note there can be multiple uses of the length blob. If we are
-    // translating a multi-level RNN it will be an input to each level.
-    n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
+    Value* batch_sizes = n->outputs().at(1);
+    while (batch_sizes->uses().size()) {
+      Use use_0 = batch_sizes->uses().at(0);
+      Node* user = use_0.user;
+      // Make calculation of max_batch_size not depend on batch_sizes.
+      // This looks for a pattern generated by code such as
+      // https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815.
+      //
+      // Replace onnx::Gather[axis=0](batch_sizes, 0)
+      // with onnx::Gather[axis=0](onnx::Shape(rnn_input), 1)
+      if (use_0.offset == 0 && user->kind() == onnx::Gather &&
+          user->i(attr::axis) == 0 &&
+          user->inputs().at(1)->node()->kind() == onnx::Constant &&
+          user->inputs().at(1)->node()->hasAttribute(attr::value)) {
+        const at::Tensor& const_val_t =
+            user->inputs().at(1)->node()->t(attr::value);
+        if (const_val_t.item().toInt() != 0) {
+          // We'll likely produce an invalid graph if this happens.
+          break;
+        }
+        Value* rnn_input = rnn->inputs().at(0);
+        Node* shape = b->owningGraph()->create(onnx::Shape);
+        shape->insertAfter(rnn_input->node());
+        shape->addInput(rnn_input);
+        batch_sizes->replaceFirstUseWith(shape->output());
+        user->inputs().at(1)->node()->t_(
+            attr::value, at::native::ones_like(const_val_t));
+      }
+      // Make RNN not depend on batch_sizes.
+      else if (user == rnn) {
+        batch_sizes->replaceFirstUseWith(n->inputs().at(1));
+      } else {
+        // If there are other uses that are not:
+        // * PadPacked (which will be removed in removeNopPacking),
+        // * Dead code (which will be removed in dead code elimination),
+        // then we likely have produced an invalid graph, since there will be a
+        // use of the output of PackPadded, but the PackPadded (and that output)
+        // will be removed.
+        break;
+      }
+    }

    // and insert new PackPadded after the RNN
    Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
Expand All @@ -298,7 +337,7 @@ void pushPackingPastRnn(Block* b) {
next->outputs().at(0)->replaceAllUsesWith(newPackPadded->outputs().at(0));
n->outputs().at(1)->replaceAllUsesWith(newPackPadded->outputs().at(1));

// setup the new PackPadded's inputs
// set up the new PackPadded's inputs
newPackPadded->addInput(next->outputs().at(0));
newPackPadded->addInput(n->inputs().at(1));

@@ -328,6 +367,9 @@
  }
}

+// Despite the name, this actually removes the PadPacked node and leaves
+// the PackPadded node. The PackPadded should become dead code which will
+// be eliminated later.
void removeNopPacking(Block* graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
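For reference, a runnable sketch (mine, not from the PR) of the rnn.py pattern that the new Gather rewrite above targets:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

x = torch.randn(5, 2, 4)                       # (seq_len, batch, input_size)
packed = pack_padded_sequence(x, torch.tensor([5, 3]))

# Paraphrased from the rnn.py permalink above: nn.GRU reads the max batch
# size from the first entry of batch_sizes. Under scripting this lowers to
# onnx::Gather[axis=0](batch_sizes, 0) -- the second use of PackPadded's
# batch_sizes output -- which the pass rewrites to
# onnx::Gather[axis=0](onnx::Shape(rnn_input), 1).
max_batch_size = int(packed.batch_sizes[0])
assert max_batch_size == 2                     # batch dim of the padded input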
2 changes: 2 additions & 0 deletions torch/onnx/symbolic_opset9.py
@@ -1758,6 +1758,8 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False):
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
    return g.op("Cast", data, to_i=sym_help.scalar_type_to_onnx[dtype])

+def as_tensor(g, data, dtype=None, device=None):
+    return tensor(g, data, dtype, device)

@parse_args('v', 'i', 'v', 'v', 'v')
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
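And a hedged usage sketch (again mine, not from the PR) of the new as_tensor symbolic: a scripted module calling torch.as_tensor should now export the same way torch.tensor does, emitting a Cast when a dtype is given. As in the earlier sketch, example_outputs reflects the era-appropriate ScriptModule export signature.

import io
import torch

class AsTensor(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # aten::as_tensor now dispatches to the same symbolic as aten::tensor
        return torch.as_tensor(x, dtype=torch.float32)

scripted = torch.jit.script(AsTensor())
inp = torch.arange(4)
torch.onnx.export(scripted, (inp,), io.BytesIO(), opset_version=11,
                  example_outputs=scripted(inp))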
