Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ONNX] RNN scripting (#57564) #58691

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[ONNX] RNN scripting (#57564)
Note: the first commit in this PR has its own pull request here, since it seemed self-contained:
#57082

* [ONNX] simplify batch_first logic in RNN tests

* [ONNX] support GRU with packed input in scripting mode

This required two changes:
* Add as_tensor to symbolic_opset9.py
* Change torch::jit::pushPackingPastRnn to recognize and properly
  replace another use of the batch_sizes output of prim::PackPadded.
  Previously the code assumed that the first use was as input to the
  RNN operator. However in some cases, it is also used to compute
  max_batch_size. For example in this code:
  https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815

With these changes the GRU tests now pass in scripting mode for opset
version >= 11.

Co-authored-by: Gary Miguel <garymiguel@microsoft.com>

[ghstack-poisoned]
  • Loading branch information
garymm authored and BowenBao committed May 20, 2021
commit 6fecb390ac30fb883b9d91932415d9f729328fd3
82 changes: 39 additions & 43 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6953,25 +6953,19 @@ def __init__(self, layers, nonlinearity, bidirect, dropout, batch_first):
def forward(self, input: PackedSequence):
return self.inner_model(input)

batch_first = True if packed_sequence == 2 else False
batch_first = packed_sequence == 2

if initial_state:
model = ElmanWithStateModel(layers=layers, bidirect=bidirectional, nonlinearity=nonlinearity,
dropout=dropout, batch_first=batch_first)

if packed_sequence == 1:
model = RnnModelWithPackedSequenceWithState(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequenceWithState(model, True)
if packed_sequence:
model = RnnModelWithPackedSequenceWithState(model, batch_first)
else:
model = ElmanWithStateModel(layers=layers, bidirect=bidirectional,
nonlinearity=nonlinearity, dropout=dropout,
batch_first=batch_first)

if packed_sequence == 1:
model = RnnModelWithPackedSequenceWithoutState(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequenceWithoutState(model, True)
if packed_sequence:
model = RnnModelWithPackedSequenceWithoutState(model, batch_first)

def make_input(batch_size):
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
Expand Down Expand Up @@ -7002,24 +6996,18 @@ def make_input(batch_size):

def _lstm_test(self, layers, bidirectional, initial_state,
packed_sequence, dropout):
batch_first = True if packed_sequence == 2 else False
batch_first = packed_sequence == 2

if packed_sequence == 0:
model = LstmFlatteningResultWithoutSeqLength(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
bidirectional, dropout, batch_first)
else:
if packed_sequence:
model = LstmFlatteningResultWithSeqLength(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
bidirectional, dropout, batch_first)
if initial_state:
if packed_sequence == 1:
model = RnnModelWithPackedSequenceWithState(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequenceWithState(model, True)
model = RnnModelWithPackedSequenceWithState(model, batch_first)
else:
if packed_sequence == 1:
model = RnnModelWithPackedSequenceWithoutState(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequenceWithoutState(model, True)
model = RnnModelWithPackedSequenceWithoutState(model, batch_first)
else:
model = LstmFlatteningResultWithoutSeqLength(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, layers,
bidirectional, dropout, batch_first)

def make_input(batch_size):
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
Expand Down Expand Up @@ -7097,30 +7085,24 @@ def __init__(self, layers, bidirect, dropout, batch_first):
def forward(self, input, hx):
return self.inner_model(input, hx)

batch_first = True if packed_sequence == 2 else False
batch_first = packed_sequence == 2

if packed_sequence == 0:
if initial_state:
model = GRUNoSeqLengthWithStateModel(layers=layers, bidirect=bidirectional,
dropout=dropout, batch_first=batch_first)
else:
model = GRUNoSeqLengthWithoutStateModel(layers=layers, bidirect=bidirectional,
dropout=dropout, batch_first=batch_first)
else:
if packed_sequence:
if initial_state:
model = GRUWithStateModel(layers=layers, bidirect=bidirectional, dropout=dropout,
batch_first=batch_first)
if packed_sequence == 1:
model = RnnModelWithPackedSequenceWithState(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequenceWithState(model, True)
model = RnnModelWithPackedSequenceWithState(model, batch_first)
else:
model = GRUWithoutStateModel(layers=layers, bidirect=bidirectional, dropout=dropout,
batch_first=batch_first)
if packed_sequence == 1:
model = RnnModelWithPackedSequenceWithoutState(model, False)
if packed_sequence == 2:
model = RnnModelWithPackedSequenceWithoutState(model, True)
model = RnnModelWithPackedSequenceWithoutState(model, batch_first)
else:
if initial_state:
model = GRUNoSeqLengthWithStateModel(layers=layers, bidirect=bidirectional,
dropout=dropout, batch_first=batch_first)
else:
model = GRUNoSeqLengthWithoutStateModel(layers=layers, bidirect=bidirectional,
dropout=dropout, batch_first=batch_first)

def make_input(batch_size):
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
Expand Down Expand Up @@ -8757,7 +8739,7 @@ def forward(self, x):
dynamic_axes={'input_1': [0, 1]})

def make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout,
variable_length, dropout, script_test_min_opset_version,
**extra_kwargs):
test_name = str('_'.join([
'test', name, layer[1],
Expand All @@ -8776,6 +8758,7 @@ def make_test(name, base, layer, bidirectional, initial_state,
@disableScriptTest()
@skipIfUnsupportedMinOpsetVersion(9)
def f(self):
self.is_script_test_enabled = self.opset_version >= script_test_min_opset_version
self._dispatch_rnn_test(
base,
layers=layer[0],
Expand Down Expand Up @@ -8825,8 +8808,21 @@ def setup_rnn_tests():
('lstm', 'lstm', {}),
('gru', 'gru', {})
):
# Need Add between list of tensors
script_test_min_opset_version = 11

if ( # compiling in script mode fails with errors like:
# torch.jit.frontend.UnsupportedNodeError: annotated assignments
# without assigned value aren't supported
# https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
base == 'elman' or
# compiling in script mode fails with errors like:
# RuntimeError: Arguments for call are not valid.
# https://msdata.visualstudio.com/Vienna/_workitems/edit/1160723
base == 'lstm'):
script_test_min_opset_version = float("inf")
make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout,
variable_length, dropout, script_test_min_opset_version,
**extra_kwargs)
test_count += 1

Expand Down
52 changes: 47 additions & 5 deletions torch/csrc/jit/passes/onnx/peephole.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,9 @@ void pushPackingPastRnn(Block* b) {
continue;
}

if (rnn->owningBlock() != n->owningBlock())
if (rnn->owningBlock() != n->owningBlock()) {
continue;
}

// Packing only has an effect on a network when its outputs are actually
// used, so we can remove it here.
Expand All @@ -286,9 +287,47 @@ void pushPackingPastRnn(Block* b) {
// remove PackPadded from in front of the RNN
n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));

// note there can be multiple uses of the length blob. If we are
// translating a multi-level RNN it will be an input to each level.
n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
Value* batch_sizes = n->outputs().at(1);
while (batch_sizes->uses().size()) {
Use use_0 = batch_sizes->uses().at(0);
Node* user = use_0.user;
// Make calculation of max_batch_size not depend on batch_sizes.
// This looks for a pattern generated by code such as
// https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815.
//
// Replace onnx::Gather[axis=0](batch_sizes, 0)
// with onnx::Gather[axis=0](onnx::Shape(rnn_input), 1)
if (use_0.offset == 0 && user->kind() == onnx::Gather &&
user->i(attr::axis) == 0 &&
user->inputs().at(1)->node()->kind() == onnx::Constant &&
user->inputs().at(1)->node()->hasAttribute(attr::value)) {
const at::Tensor& const_val_t =
user->inputs().at(1)->node()->t(attr::value);
if (const_val_t.item().toInt() != 0) {
// We'll likely produce an invalid graph if this happens.
break;
}
Value* rnn_input = rnn->inputs().at(0);
Node* shape = b->owningGraph()->create(onnx::Shape);
shape->insertAfter(rnn_input->node());
shape->addInput(rnn_input);
batch_sizes->replaceFirstUseWith(shape->output());
user->inputs().at(1)->node()->t_(
attr::value, at::native::ones_like(const_val_t));
}
// Make RNN not depend on batch_sizes.
else if (user == rnn) {
batch_sizes->replaceFirstUseWith(n->inputs().at(1));
} else {
// If there are other uses that are not:
// * PadPacked (which will be removed in removeNopPacking),
// * Dead code (which will be removed in dead code elimination),
// then we likely have produced an invalid graph, since there will be a
// use of the output of PackPadded, but the PackPadded (and that output)
// will be removed.
break;
}
}

// and insert new PackPadded after the RNN
Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
Expand All @@ -298,7 +337,7 @@ void pushPackingPastRnn(Block* b) {
next->outputs().at(0)->replaceAllUsesWith(newPackPadded->outputs().at(0));
n->outputs().at(1)->replaceAllUsesWith(newPackPadded->outputs().at(1));

// setup the new PackPadded's inputs
// set up the new PackPadded's inputs
newPackPadded->addInput(next->outputs().at(0));
newPackPadded->addInput(n->inputs().at(1));

Expand Down Expand Up @@ -328,6 +367,9 @@ void pushPackingPastRnn(Block* b) {
}
}

// Despite the name, this actually removes the PadPacked node and leaves
// the PackPadded node. The PackPadded should become dead code which will
// be eliminated later.
void removeNopPacking(Block* graph) {
for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
auto* n = *it;
Expand Down
2 changes: 2 additions & 0 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,8 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False):
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
return g.op("Cast", data, to_i=sym_help.scalar_type_to_onnx[dtype])

def as_tensor(g, data, dtype=None, device=None):
return tensor(g, data, dtype, device)

@parse_args('v', 'i', 'v', 'v', 'v')
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
Expand Down