diff --git a/keras2onnx/ke2onnx/gru.py b/keras2onnx/ke2onnx/gru.py
index 0a090776..b9642e77 100644
--- a/keras2onnx/ke2onnx/gru.py
+++ b/keras2onnx/ke2onnx/gru.py
@@ -126,8 +126,13 @@ def convert_keras_gru(scope, operator, container, bidirectional=False):
         output_seq = op.return_sequences
         reset_after = op.reset_after
 
+    time_major = simplernn.is_time_major(op, bidirectional)
+
     # Inputs
-    gru_x = _name('X')
+    gru_x = operator.inputs[0].full_name
+    if not time_major:
+        gru_x = _name('X')
+        apply_transpose(scope, operator.inputs[0].full_name, gru_x, container, perm=[1, 0, 2])
     tensor_w, tensor_r, tensor_b = build_parameters(scope, operator, container, bidirectional)
     sequence_lengths = simplernn.build_sequence_lengths(scope, operator, container)
     initial_h = simplernn.build_initial_states(scope, operator, container, bidirectional)
@@ -147,10 +152,6 @@ def convert_keras_gru(scope, operator, container, bidirectional=False):
     # Outputs
     output_names = [_name('Y'), _name('Y_h')]
 
-    # Transpose input values
-    input_name = operator.inputs[0].full_name
-    apply_transpose(scope, input_name, gru_x, container, perm=[1, 0, 2])
-
     oopb = OnnxOperatorBuilder(container, scope)
     oopb.apply_op_with_output('apply_gru',
                               input_names,
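Note on the gru.py hunks: ONNX GRU consumes time-major input of shape (T, N, C), while Keras layers default to batch-major (N, T, C). The change makes the input Transpose conditional, so a layer that is already time-major feeds the ONNX node directly. A minimal numpy sketch of the layout conversion (shapes and names here are illustrative, not part of the patch):

    import numpy as np

    # Batch-major Keras input: (batch, seq_len, features)
    batch_major = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)

    # perm=[1, 0, 2] swaps the batch and time axes, giving the
    # (seq_len, batch, features) layout that ONNX GRU/LSTM/RNN expect.
    time_major = np.transpose(batch_major, (1, 0, 2))
    assert time_major.shape == (3, 2, 4)
    assert time_major[1, 0, 2] == batch_major[0, 1, 2]

    # A layer that is already time-major needs no Transpose node, which
    # is why apply_transpose is now guarded by `if not time_major`.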
diff --git a/keras2onnx/ke2onnx/lstm.py b/keras2onnx/ke2onnx/lstm.py
index ce9ad735..b42bd1d4 100644
--- a/keras2onnx/ke2onnx/lstm.py
+++ b/keras2onnx/ke2onnx/lstm.py
@@ -197,22 +197,20 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
     output_name = operator.outputs[0].full_name
 
-    time_major = op.time_major if hasattr(op, "time_major") else False
+    time_major = simplernn.is_time_major(op, bidirectional)
 
     # Create output-adjusting operators
     if output_seq:
-        lstm_out = lstm_y
+        # Squeeze the num_direction dim as we know its size is 1 for
+        # lstm(forward/reverse).
+        lstm_out = output_name if time_major else _name('y_squeezed')
+        apply_squeeze(scope, lstm_y, lstm_out, container, axes=[1])
         if not time_major:
             # Onnx LSTM produces time major output. Add a transpose operator to
             # make it batch_major, if the keras op was not time_major.
-            # This transforms [ S, 1, B, I] -> [ B, 1, S, I ] where B is
+            # This transforms [ S, B, I] -> [ B, S, I ] where B is
             # batch_size and S is seq_len.
-            perm = [2, 1, 0, 3]
-            lstm_out = _name('y_transposed')
-            apply_transpose(scope, lstm_y, lstm_out, container, perm=perm)
-
-            # Squeeze the num_direction dim as we know its size is 1 for
-            # lstm(forward/reverse).
-            apply_squeeze(scope, lstm_out, output_name, container, axes=[1])
+            perm = [1, 0, 2]
+            apply_transpose(scope, lstm_out, output_name, container, perm=perm)
     else:
         apply_squeeze(scope, lstm_h, output_name, container, axes=[0])
@@ -272,10 +270,10 @@ def convert_keras_lstm(scope, operator, container, bidirectional=False):
     if bidirectional:
         output_seq = op.forward_layer.return_sequences
-        time_major = op.forward_layer.time_major if hasattr(op.forward_layer, "time_major") else False
     else:
         output_seq = op.return_sequences
-        time_major = op.time_major if hasattr(op, "time_major") else False
+
+    time_major = simplernn.is_time_major(op, bidirectional)
 
     # Inputs
     lstm_x = operator.inputs[0].full_name
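Note on the lstm.py build_output hunk: the reordering is behavior-preserving. Squeezing the num_directions axis first and then transposing the 3-D result with perm=[1, 0, 2] matches the old 4-D transpose with perm=[2, 1, 0, 3] followed by the squeeze, while letting the time-major case skip the transpose entirely. A quick numpy check of the equivalence (illustrative only):

    import numpy as np

    # ONNX LSTM Y output: (seq_len, num_directions=1, batch, hidden)
    y = np.random.rand(5, 1, 2, 7).astype(np.float32)

    # Old path: 4-D transpose, then squeeze the direction axis.
    old = np.transpose(y, (2, 1, 0, 3)).squeeze(axis=1)

    # New path: squeeze first, then a cheaper 3-D transpose.
    new = np.transpose(y.squeeze(axis=1), (1, 0, 2))

    assert old.shape == new.shape == (2, 5, 7)
    assert np.array_equal(old, new)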
diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py
index 80c7f7f3..8433319b 100644
--- a/keras2onnx/ke2onnx/simplernn.py
+++ b/keras2onnx/ke2onnx/simplernn.py
@@ -9,7 +9,6 @@ from ..common.onnx_ops import (
     apply_cast,
     apply_concat,
-    apply_identity,
     apply_reshape,
     apply_slice,
     apply_split,
@@ -260,6 +259,7 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
         apply_slice(scope, input_shape_tensor, seq_dim, container, [1], [2], axes=[0])
 
     if bidirectional:
+        time_major = is_time_major(op, bidirectional)
        forward_layer = op.forward_layer
         hidden_size = forward_layer.units
@@ -274,94 +274,31 @@ def build_output(scope, operator, container, output_names, bidirectional=False):
             merge_concat = True
 
         if output_seq:
-            # The output shape of runtime is 3-D while ONNX says 4-D, so we do a Reshape to fix it.
-            if is_static_shape:
-                rnn_y_fixed = _name('Y_fixed')
-                apply_reshape(scope, rnn_y, rnn_y_fixed, container,
-                              desired_shape=[seq_length, 2, -1, hidden_size])
+            lstm_out = _name('y_transposed')
+            if not time_major:
+                # Transpose ONNX RNN Y with shape (T, D, N, C') into (N, T, D, C')
+                apply_transpose(scope, rnn_y, lstm_out, container, perm=[2, 0, 1, 3])
             else:
-                shape_tensor = oopb.add_node('Concat',
-                                             [seq_dim,
-                                              ('_a', oopb.int64, np.array([2], dtype='int64')),
-                                              ('_b', oopb.int64, np.array([-1], dtype='int64')),
-                                              ('_c', oopb.int64, np.array([hidden_size], dtype='int64'))
-                                              ],
-                                             input_name + '_output_seq_shape', axis=0)
-                rnn_y_fixed = oopb.add_node('Reshape',
-                                            [rnn_y,
-                                             shape_tensor
-                                             ],
-                                            input_name + '_output_seq_shape_1')
-
+                # Transpose RNN Y with shape (T, D, N, C) into (T, N, D, C)
+                apply_transpose(scope, rnn_y, lstm_out, container, perm=[0, 2, 1, 3])
             if merge_concat:
-                # In this case, only one Keras output with shape (N, T, 2 * C') should be produced
-
-                # Transpose ONNX RNN Y with shape (T, D, N, C') into (T, N, D, C')
-                transposed_y = _name('Y_transposed')
-                apply_transpose(scope, rnn_y_fixed, transposed_y, container, perm=[2, 0, 1, 3])
-
-                # Change shape (T, N, D, C') to (N, T, D * C') to meet Keras spec
-                if is_static_shape:
-                    apply_reshape(scope, transposed_y, operator.outputs[0].full_name, container,
-                                  desired_shape=[-1, seq_length, 2 * hidden_size])
-                else:
-                    attrs = {'axis': 0}
-                    shape_tensor_2 = oopb.add_node('Concat',
-                                                   [('_a', oopb.int64, np.array([-1], dtype='int64')),
-                                                    seq_dim,
-                                                    ('_b', oopb.int64, np.array([2 * hidden_size], dtype='int64'))
-                                                    ],
-                                                   input_name + '_output_seq_shape_2', **attrs)
-                    shape_tensor_output = oopb.add_node('Reshape',
-                                                        [transposed_y,
-                                                         shape_tensor_2
-                                                         ],
-                                                        input_name + '_output_merge_concat')
-                    apply_identity(scope, shape_tensor_output, operator.outputs[0].full_name, container)
+                # In this case, only one Keras output with shape (N, T, 2 * C') should be produced,
+                # or (T, N, 2 * C') if it was time major.
+                apply_reshape(scope, lstm_out, operator.outputs[0].full_name, container,
+                              desired_shape=[0, 0, 2 * hidden_size])
             else:
                 # If merge_mode=None, two tensors should be generated. The first/second tensor is the output of
                 # forward/backward pass.
-
-                # Transpose ONNX RNN Y with shape (T, D, N, C') into (T, N, D, C')
-                transposed_y = _name('Y_transposed')
-                apply_transpose(scope, rnn_y_fixed, transposed_y, container, perm=[2, 0, 1, 3])
-
                 # Split the transposed Y with shape (T, N, D, C') into (T, N, 1, C') and (T, N, 1, C')
                 forward_y = _name('Y_forward')
                 backward_y = _name('Y_backward')
                 axis_direction = 2
-                apply_split(scope, transposed_y, [forward_y, backward_y], container, axis=axis_direction)
+                apply_split(scope, lstm_out, [forward_y, backward_y], container, axis=axis_direction)
 
                 # Change (T, N, 1, C') into (T, N, C') to meet Keras spec
-                forward_y_1 = _name('Y_forward_1')
-                backward_y_1 = _name('Y_backward_1')
-                apply_squeeze(scope, forward_y, forward_y_1, container, axes=[axis_direction])
-                apply_squeeze(scope, backward_y, backward_y_1, container, axes=[axis_direction])
-
-                if is_static_shape:
-                    apply_reshape(scope, forward_y_1, operator.outputs[0].full_name, container,
-                                  desired_shape=[-1, seq_length, hidden_size])
-                    apply_reshape(scope, backward_y_1, operator.outputs[1].full_name, container,
-                                  desired_shape=[-1, seq_length, hidden_size])
-                else:
-                    shape_tensor_3 = oopb.add_node('Concat',
-                                                   [('_a', oopb.int64, np.array([-1], dtype='int64')),
-                                                    seq_dim,
-                                                    ('_b', oopb.int64, np.array([hidden_size], dtype='int64'))
-                                                    ],
-                                                   input_name + '_output_seq_shape_3', **attrs)
-                    shape_tensor_output_0 = oopb.add_node('Reshape',
-                                                          [forward_y_1,
-                                                           shape_tensor_3
-                                                           ],
-                                                          input_name + '_shape_tensor_output_0')
-                    shape_tensor_output_1 = oopb.add_node('Reshape',
-                                                          [backward_y_1,
-                                                           shape_tensor_3
-                                                           ],
-                                                          input_name + '_shape_tensor_output_1')
-                    apply_identity(scope, shape_tensor_output_0, operator.outputs[0].full_name, container)
-                    apply_identity(scope, shape_tensor_output_1, operator.outputs[1].full_name, container)
+                apply_squeeze(scope, forward_y, operator.outputs[0].full_name, container, axes=[axis_direction])
+                apply_squeeze(scope, backward_y, operator.outputs[1].full_name, container, axes=[axis_direction])
         else:
             perm = [1, 0, 2]
             if merge_concat:
@@ -458,6 +395,12 @@ def build_output_states(scope, operator, container, output_names, bidirectional=
         output_h = operator.outputs[1].full_name
         apply_squeeze(scope, rnn_h, output_h, container)
 
+def is_time_major(op, bidirectional):
+    if bidirectional:
+        time_major = op.forward_layer.time_major if hasattr(op.forward_layer, "time_major") else False
+    else:
+        time_major = op.time_major if hasattr(op, "time_major") else False
+    return time_major
 
 def convert_keras_simple_rnn(scope, operator, container, bidirectional=False):
     op = operator.raw_operator
@@ -468,9 +411,13 @@ def convert_keras_simple_rnn(scope, operator, container, bidirectional=False):
         output_seq = op.forward_layer.return_sequences
     else:
         output_seq = op.return_sequences
+    time_major = is_time_major(op, bidirectional)
 
     # Inputs
-    rnn_x = _name('X')
+    rnn_x = operator.inputs[0].full_name
+    if not time_major:
+        rnn_x = _name('X')
+        apply_transpose(scope, operator.inputs[0].full_name, rnn_x, container, perm=[1, 0, 2])
     tensor_w, tensor_r, tensor_b = build_parameters(scope, operator, container, bidirectional)
     sequence_lengths = build_sequence_lengths(scope, operator, container)
     initial_h = build_initial_states(scope, operator, container, bidirectional)
@@ -490,10 +437,6 @@ def convert_keras_simple_rnn(scope, operator, container, bidirectional=False):
     # Outputs
     output_names = [_name('Y'), _name('Y_h')]
 
-    # Transpose input values
-    input_name = operator.inputs[0].full_name
-    apply_transpose(scope, input_name, rnn_x, container, perm=[1, 0, 2])
-
     oopb = OnnxOperatorBuilder(container, scope)
     oopb.apply_op_with_output('apply_rnn',
                               input_names,
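Note on the simplernn.py hunks: the static/dynamic shape branches collapse into a single Reshape because desired_shape=[0, 0, 2 * hidden_size] relies on ONNX Reshape semantics, where a 0 copies the corresponding input dimension. That is what removes all of the Concat/Reshape/apply_identity plumbing. A sketch of those semantics; onnx_style_reshape is a stand-in helper written for illustration, not part of the patch:

    import numpy as np

    def onnx_style_reshape(x, desired_shape):
        # Mimics the ONNX Reshape rule used in the patch: a 0 in the
        # target shape means "keep the input dimension at that index";
        # numpy's reshape handles any -1 natively.
        shape = [x.shape[i] if d == 0 else d for i, d in enumerate(desired_shape)]
        return x.reshape(shape)

    hidden_size = 7
    # (N, T, D=2, C') after the transpose added in the patch (batch-major case)
    y = np.random.rand(2, 5, 2, hidden_size).astype(np.float32)

    # [0, 0, 2 * hidden_size] folds (D, C') into 2*C' while leaving the
    # leading batch/time dims untouched, so one node serves both static
    # and dynamic sequence lengths.
    merged = onnx_style_reshape(y, [0, 0, 2 * hidden_size])
    assert merged.shape == (2, 5, 2 * hidden_size)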
diff --git a/tests/test_layers.py b/tests/test_layers.py
index a9b589a4..cb861c92 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1854,6 +1854,30 @@ def test_bidirectional_with_bias(runner, rnn_class):
     onnx_model = keras2onnx.convert_keras(model, model.name)
     assert runner(onnx_model.graph.name, onnx_model, x, expected)
 
+@pytest.mark.skipif((is_tensorflow_older_than('2.3.0') or (not is_tf_keras)),
+                    reason=("keras LSTM has no time_major attribute, and tf.keras bidirectional LSTM with time_major=True was broken before tf-2.3; see https://github.com/tensorflow/tensorflow/issues/39635"))
+@pytest.mark.parametrize("rnn_class", RNN_CLASSES)
+def test_bidirectional_time_major_true(runner, rnn_class):
+    feature_dim = 1
+    seq_len = 3
+    x = np.ones((1, seq_len, feature_dim), dtype=np.float32)
+
+    for ret_seq in [True, False]:
+        for merge_mode in ['concat', None]:
+            K.clear_session()
+            input = keras.Input(shape=(seq_len, feature_dim))
+            # Transpose input to be time major
+            input_transposed = tf.transpose(input, perm=[1, 0, 2])
+            output = Bidirectional(rnn_class(1, return_sequences=ret_seq,
+                                             time_major=True),
+                                   name='bi', merge_mode=merge_mode)(input_transposed)
+            if ret_seq and merge_mode == 'concat':
+                output = tf.transpose(output, perm=[1, 0, 2])
+            model = keras.Model(inputs=input, outputs=output)
+
+            expected = model.predict(x)
+            onnx_model = keras2onnx.convert_keras(model, model.name)
+            assert runner(onnx_model.graph.name, onnx_model, x, expected)
 
 @pytest.mark.parametrize("rnn_class", RNN_CLASSES)
 def test_bidirectional_with_initial_states(runner, rnn_class):
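Note on the new test: the model boundary stays batch-major, and the tf.transpose pair makes only the wrapped layer time-major, so the runner can still feed the same batch-major x to both Keras and ONNX. A standalone sketch of that wiring outside pytest (assumes tf.keras on TF >= 2.3; the time_major argument was removed in Keras 3):

    import numpy as np
    import tensorflow as tf
    from tensorflow import keras

    seq_len, feature_dim = 3, 1
    inp = keras.Input(shape=(seq_len, feature_dim))
    # Transpose on the way in: only the wrapped layer is time-major.
    x = tf.transpose(inp, perm=[1, 0, 2])
    y = keras.layers.Bidirectional(
        keras.layers.LSTM(1, return_sequences=True, time_major=True),
        merge_mode='concat')(x)
    # ...and back on the way out when the layer returns sequences.
    y = tf.transpose(y, perm=[1, 0, 2])
    model = keras.Model(inp, y)
    print(model.predict(np.ones((1, seq_len, feature_dim), np.float32)).shape)  # (1, 3, 2)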