-
Notifications
You must be signed in to change notification settings - Fork 108
Add time_major handling for bidirectional lstms #498
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
1. Adds time major handling for bidirectional for all rnns. 2. Reverse order of squeeze and transpose in lstm conversion when return seq_true: This makes the onnx model slightly more efficent when the model has a stack of lstms that are batch_major. By doing squeeze first and transpose later, we get two consucutive transposes that get eliminated by the optimizer hereby making model more efficient. Changes to be committed: modified: ../keras2onnx/ke2onnx/gru.py modified: ../keras2onnx/ke2onnx/lstm.py modified: ../keras2onnx/ke2onnx/simplernn.py modified: test_layers.py
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1833,6 +1833,30 @@ def test_bidirectional_with_bias(runner, rnn_class): | |
onnx_model = keras2onnx.convert_keras(model, model.name) | ||
assert runner(onnx_model.graph.name, onnx_model, x, expected) | ||
|
||
@pytest.mark.skipif((is_tensorflow_older_than('2.3.0') or (not is_tf_keras)), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your contribution. Does this mean this test does not make effects right now (before tensorflow 2.3 release)? Then do you test this locally? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I tested it locally by installing the tf-nightly which had tf-2.3. Not all tests work with tf-2.3 so I just ran this test and checked that it works. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mentioned there is a bug in tf.keras bidirectional lstm. Then if we use this new conversion, would it trigger that bug? I don't see there is special handling in conversion code. -- If we merge this PR, does it break some tf.keras model with that potential issue? Another thing is that this PR amazingly simplifies the original code for dynamic shape case. Could you elaborate on why the original dynamic shape case not needed? Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The bug is only triggered when running inference on tf.keras.bdirectional with time_major=True with tf.keras predict and inccorect results are produced by tensorflow. See the issue for more details. The keras2onnx converter produces teh correct results. We don't need any special handling for the bug. I think the previous stuff was complicated because we were trying to extract out the dynamic seq_len through a number of ops to use it for reshaping. In the new code, we still do reshape but we use 0 in reshape, so that we keep the batch_dim and seq_dim same in the reshaped tensor without knowing their dynamic value. All we need is to just merge the direction dim and feature dim into 1. |
||
reason=("keras LSTM does not have time_major attribute. There was a bug in tf.keras bidirectional lstm with time_major true which will be fixed in tf-2.3, See - https://github.com/tensorflow/tensorflow/issues/39635")) | ||
@pytest.mark.parametrize("rnn_class", RNN_CLASSES) | ||
def test_bidirectional_time_major_true(runner, rnn_class): | ||
feature_dim = 1 | ||
seq_len = 3 | ||
x = np.ones((1, seq_len, feature_dim), dtype=np.float32) | ||
|
||
for ret_seq in [True, False]: | ||
for merge_mode in ['concat', None]: | ||
K.clear_session() | ||
input = keras.Input(shape=(seq_len, feature_dim)) | ||
# Transpose input to be time major | ||
input_transposed = tf.transpose(input, perm=[1,0,2]) | ||
output = Bidirectional(rnn_class(1, return_sequences=ret_seq, | ||
time_major=True), | ||
name='bi', merge_mode=merge_mode)(input_transposed) | ||
if ret_seq and merge_mode == 'concat': | ||
output = tf.transpose(output, perm=[1,0,2]) | ||
model = keras.Model(inputs=input, outputs=output) | ||
|
||
expected = model.predict(x) | ||
onnx_model = keras2onnx.convert_keras(model, model.name) | ||
assert runner(onnx_model.graph.name, onnx_model, x, expected) | ||
|
||
@pytest.mark.parametrize("rnn_class", RNN_CLASSES) | ||
def test_bidirectional_with_initial_states(runner, rnn_class): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the case
merge_concat=False
, we haveapply_reshape
to handle static/dynamic shapes. Now seems that the new code does not need that any more. It onlyapply_reshape
formerge_concat=True
. Why we have difference here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When merge_concat is false or None, we need to produce two tensors as output - one for forward part and one for backward. The output of bidirectional lstm is of shape [seq_len, 2, batch_size, input_size].
In case of time_major=true
In case of time_major=false