Add time_major handling for bidirectional lstms #498
Conversation
1. Adds time_major handling for the Bidirectional wrapper for all RNNs.
2. Reverses the order of Squeeze and Transpose in the LSTM conversion when return_sequences=True. This makes the ONNX model slightly more efficient when the model contains a stack of batch-major LSTMs: by doing the Squeeze first and the Transpose later, we get two consecutive Transposes that are eliminated by the optimizer, making the model more efficient.

Changes to be committed:
modified: ../keras2onnx/ke2onnx/gru.py
modified: ../keras2onnx/ke2onnx/lstm.py
modified: ../keras2onnx/ke2onnx/simplernn.py
modified: test_layers.py
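Why the reordering removes ops (an illustrative sketch, not the converter code itself; the shapes and variable names below are assumptions for the example): the ONNX LSTM output Y has shape [seq, num_dir, batch, hidden], and the next LSTM in a batch-major stack needs time-major input again, so squeezing first leaves a Transpose(perm=[1, 0, 2]) that is immediately undone by the Transpose(perm=[1, 0, 2]) inserted for the next layer.

```python
import numpy as np

# Illustrative only: names and sizes are made up for the example.
seq_len, batch, hidden = 5, 3, 4

# After Squeeze(axes=[1]) on the ONNX LSTM output Y of shape [seq, 1, batch, hidden]:
y = np.random.rand(seq_len, batch, hidden)

# The Keras layer is batch-major, so the converter emits Transpose(perm=[1, 0, 2]).
keras_out = np.transpose(y, (1, 0, 2))              # [batch, seq, hidden]

# The next ONNX LSTM in the stack wants time-major input, so another
# Transpose(perm=[1, 0, 2]) is inserted in front of it.
next_lstm_in = np.transpose(keras_out, (1, 0, 2))   # [seq, batch, hidden]

# The two back-to-back transposes are inverses of each other, so the ONNX
# optimizer can remove both of them.
assert np.array_equal(next_lstm_in, y)
```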
@@ -1833,6 +1833,30 @@ def test_bidirectional_with_bias(runner, rnn_class):
    onnx_model = keras2onnx.convert_keras(model, model.name)
    assert runner(onnx_model.graph.name, onnx_model, x, expected)

@pytest.mark.skipif((is_tensorflow_older_than('2.3.0') or (not is_tf_keras)),
Thanks for your contribution. Does this mean this test has no effect right now (before the TensorFlow 2.3 release)? Did you test it locally?
Yes, I tested it locally by installing tf-nightly, which had TF 2.3. Not all tests work with TF 2.3, so I just ran this test and checked that it passes.
You mentioned there is a bug in the tf.keras bidirectional LSTM. If we use this new conversion, would it trigger that bug? I don't see any special handling for it in the conversion code. If we merge this PR, does it break some tf.keras models affected by that potential issue?
Another thing: this PR greatly simplifies the original code for the dynamic shape case. Could you elaborate on why the original dynamic shape handling is no longer needed? Thanks.
The bug is only triggered when running inference with tf.keras predict on a tf.keras Bidirectional layer with time_major=True; in that case TensorFlow itself produces incorrect results. See the issue for more details. The keras2onnx converter produces the correct results, so we don't need any special handling for the bug.
I think the previous code was complicated because we were trying to extract the dynamic seq_len through a number of ops in order to use it for reshaping. In the new code we still reshape, but we use 0 in the Reshape, which keeps the batch and sequence dimensions of the input unchanged in the reshaped tensor without needing to know their dynamic values. All we need to do is merge the direction dimension and the feature dimension into one.
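A small sketch of what the 0 entries buy us (illustrative only; the names and sizes are assumptions, not converter code). In ONNX Reshape, a 0 in the target shape means "copy this dimension from the input", so batch and sequence can stay dynamic while the statically known direction and hidden dimensions are merged:

```python
import numpy as np

def reshape_like_onnx(x, desired_shape):
    # Emulate ONNX Reshape: a 0 entry copies the corresponding input dimension.
    target = [x.shape[i] if d == 0 else d for i, d in enumerate(desired_shape)]
    return x.reshape(target)

batch, seq_len, hidden_size = 2, 7, 8                       # batch/seq unknown until runtime in practice
lstm_out = np.random.rand(batch, seq_len, 2, hidden_size)   # (N, T, num_directions, C')

merged = reshape_like_onnx(lstm_out, [0, 0, 2 * hidden_size])
assert merged.shape == (batch, seq_len, 2 * hidden_size)    # (N, T, 2 * C')
```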
    # In this case, only one Keras output with shape (N, T, 2 * C') should be produced,
    # or (T, N, 2 * C') if it was time major.
    apply_reshape(scope, lstm_out, operator.outputs[0].full_name, container,
                  desired_shape=[0, 0, 2 * hidden_size])
else:
    # If merge_mode=None, two tensors should be generated. The first/second tensor is the output of
For the case merge_concat=False, we had apply_reshape to handle static/dynamic shapes. Now it seems the new code does not need that any more; it only calls apply_reshape for merge_concat=True. Why is there a difference here?
When merge_concat is False (i.e. merge_mode=None), we need to produce two output tensors, one for the forward direction and one for the backward direction. The output of the bidirectional LSTM has shape [seq_len, 2, batch_size, hidden_size].
In case of time_major=True:
- We first apply Split to split it across the num_directions axis: [seq_len, 1, batch_size, hidden_size] and [seq_len, 1, batch_size, hidden_size].
- Then we Squeeze on axis 1 to produce two tensors of shape [seq_len, batch_size, hidden_size] and [seq_len, batch_size, hidden_size].
In case of time_major=False:
- We first Transpose to make it batch major: [batch_size, 2, seq_len, hidden_size].
- We then apply Split to split it across the num_directions axis: [batch_size, 1, seq_len, hidden_size] and [batch_size, 1, seq_len, hidden_size].
- Then we Squeeze on axis 1 to produce two tensors of shape [batch_size, seq_len, hidden_size] and [batch_size, seq_len, hidden_size].
(A shape-level sketch follows below.)
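Here is a hedged numpy sketch of the shape bookkeeping described above (illustrative only; the real converter emits ONNX Split/Squeeze/Transpose nodes, and the sizes below are made up):

```python
import numpy as np

seq_len, batch, hidden = 5, 3, 4
y = np.random.rand(seq_len, 2, batch, hidden)        # ONNX LSTM Y: [seq, num_dir, batch, hidden]

# time_major=True: Split on the num_directions axis, then Squeeze(axes=[1]).
fwd, bwd = np.split(y, 2, axis=1)                    # two tensors of [seq, 1, batch, hidden]
fwd, bwd = fwd.squeeze(1), bwd.squeeze(1)            # two tensors of [seq, batch, hidden]

# time_major=False: Transpose to batch-major first, then Split and Squeeze.
y_bm = np.transpose(y, (2, 1, 0, 3))                 # [batch, num_dir, seq, hidden]
fwd_bm, bwd_bm = (t.squeeze(1) for t in np.split(y_bm, 2, axis=1))  # two tensors of [batch, seq, hidden]

assert fwd.shape == (seq_len, batch, hidden)
assert fwd_bm.shape == (batch, seq_len, hidden)
```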