Skip to content

Commit

Permalink
Update on "[ONNX] Fix Gather replacement in RNN peephole"
Browse files Browse the repository at this point in the history
From PR: #58691, Replacing the second input of `Gather` 0 to 1 affects other innocent Nodes. In Issue #91526 onnx::range starts from 0, the 0 is changed by this mechanism, as it's shared by onnx::Gather. This PR intends to create a whole independent Constant 0 for replacement. NOTE: The PR passes all existing RNN tests locally in case CI doesn't include RNN test.

TODO: test

[ghstack-poisoned]
  • Loading branch information
titaiwangms committed Jan 31, 2023
2 parents bad05cf + e905b6c commit 712bc35
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
52 changes: 52 additions & 0 deletions test/onnx/test_pytorch_onnx_no_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,58 @@ def forward(self, x, seq_lens):
f = io.BytesIO()
torch.onnx.export(m, (x, seq_lens), f, verbose=False)

def test_pushpackingpastrnn_in_peephole_create_own_gather_input(self):
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

num_layers = 3
T, B, C = 11, 5, 7
mask_start_point = 0
class LSTMTraceWrapper(torch.nn.Module):
def __init__(self):
super(LSTMTraceWrapper, self).__init__()

self.rnn = torch.nn.LSTM(
input_size=C, hidden_size=C, num_layers=num_layers
)

def forward(self, x, seq_lens):
mask = torch.arange(mask_start_point, x.shape[1])
seq_lens = seq_lens[mask]
x = pack_padded_sequence(x, seq_lens)
# Calculate sizes and prepare views to our zero buffer to pass as hx
max_batch_size = x.batch_sizes[0]
hx = torch.randn(num_layers, max_batch_size, C)
cx = torch.randn(num_layers, max_batch_size, C)
x, _ = self.rnn(x, (hx, cx))
x, _ = pad_packed_sequence(x)
return x

x = torch.ones(T, B, C)
# length 5 because of B
seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
m = LSTMTraceWrapper()

f = io.BytesIO()
torch.onnx.export(m, (x, seq_lens), f, verbose=True, input_names=["input", "seq_len"], dynamic_axes={"input":{1:"B"}})
onnx_proto = onnx.load_model_from_string(f.getvalue())
# the first argument in onnx::Range should be constant node with value 0
const_node = []
constant_input_name = None
for n in onnx_proto.graph.node:
if n.op_type == "Constant":
const_node.append(n)
elif n.op_type == "Range":
constant_input_name = n.input[0]
self.assertNotEqual(constant_input_name, None)
self.assertNotEqual(len(const_node), 0)

value = None
for n in const_node:
if n.output[0] == constant_input_name:
value = np.frombuffer(n.attribute[0].t.raw_data, dtype=np.int64)
self.assertEqual(value, 0)


def test_trace_fork_wait_inline_onnx(self):
def fork_body(x):
return torch.neg(x), torch.neg(x)
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9420,7 +9420,7 @@ def forward(self, input: rnn_utils.PackedSequence):
)
)
else:
model = ElmanWithStateModel(
model = ElmanWithoutStateModel(
layers=layers,
bidirect=bidirectional,
nonlinearity=nonlinearity,
Expand Down

0 comments on commit 712bc35

Please sign in to comment.