Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] enable some RNN tests in scripting mode #57082

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[ONNX] enable some RNN tests in scripting mode
Previously they were all disabled, even though many of them actually
pass today.
  • Loading branch information
garymm committed Apr 28, 2021
commit 455d1ee5a6dda42e3f9048a292ca62e81cb9cfab
22 changes: 13 additions & 9 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8445,7 +8445,7 @@ def forward(self, x, filled_value: int):
self.run_test(FillModule(), (x, filled_value))

def make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout,
variable_length, dropout, is_script_test_enabled,
**extra_kwargs):
test_name = str('_'.join([
'test', name, layer[1],
Expand All @@ -8455,15 +8455,9 @@ def make_test(name, base, layer, bidirectional, initial_state,

# Cannot export with older opsets because of 'ConstantFill' op
# ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime
# There are still some issues prevent us from enabling script test for these scenarios:
# test_gru_*:
# Operator aten::as_tensor is not supported by exporter yet.
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382
# Operator aten::_pack_padded_sequence is not supported by exporter yet.
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384
@disableScriptTest()
@skipIfUnsupportedMinOpsetVersion(9)
def f(self):
self.is_script_test_enabled = is_script_test_enabled
Copy link
Contributor

@neginraoof neginraoof May 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this PR still needed after #57564 ?

self._dispatch_rnn_test(
base,
layers=layer[0],
Expand Down Expand Up @@ -8513,8 +8507,18 @@ def setup_rnn_tests():
('lstm', 'lstm', {}),
('gru', 'gru', {})
):
# There are still some issues prevent us from enabling script test for these scenarios.
is_script_test_enabled = not (
# gru with sequence lengths blocked by:
# Operator aten::as_tensor is not supported by exporter yet.
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055382
# Operator aten::_pack_padded_sequence is not supported by exporter yet.
# - https://msdata.visualstudio.com/Vienna/_workitems/edit/1055384
(base == 'gru' and variable_length[0] != 0) or
(base == 'elman') or
(base == 'lstm'))
make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout,
variable_length, dropout, is_script_test_enabled,
**extra_kwargs)
test_count += 1

Expand Down