forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ONNX] RNN scripting (pytorch#57564) (pytorch#58691)
Summary: Pull Request resolved: pytorch#58691 Note the first commit in this PR has its own pull request here since it seemed self-contained: pytorch#57082 * [ONNX] simplify batch_first logic in RNN tests * [ONNX] support GRU with packed input in scripting mode This required two changes: * Add as_tensor to symbolic_opset9.py * Change torch::jit::pushPackingPastRnn to recognize and properly replace another use of the batch_sizes output of prim::PackPadded. Previously the code assumed that the first use was as input to the RNN operator. However in some cases, it is also used to compute max_batch_size. For example in this code: https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815 With these changes the GRU tests now pass in scripting mode for opset version >= 11. Test Plan: Imported from OSS Reviewed By: driazati Differential Revision: D28714805 Pulled By: SplitInfinity fbshipit-source-id: f19647a04533d9ec76399a8793b3f712ea0337d2 Co-authored-by: Gary Miguel <garymiguel@microsoft.com>
- Loading branch information
1 parent
f037b00
commit bb83f69
Showing
3 changed files
with
88 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters