forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Only view() rhs of index_put if we need to (pytorch#9424)
Summary: During tracing (and export) we are now introducing an unnecessary hard-coded view on the RHS of indexed assignments such as `tensor[idxs] = rhs`. This caused a regression in the PyTorch translate models because these expressions appear with variable sizes in the RHS. This change makes it so we only call view if we indeed need to strip leading 1-dimensions Pull Request resolved: pytorch#9424 Reviewed By: colesbury Differential Revision: D8838881 Pulled By: jamesr66a fbshipit-source-id: 399e5daa7d021f4f59f6f92b9fae581f92bfc538
- Loading branch information
1 parent
5ac8a80
commit 7160846
Showing
4 changed files
with
37 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
graph(%0 : Double(100) | ||
%1 : Long(4) | ||
%2 : Double(1, 1, 1, 4)) { | ||
%3 : Double(4) = aten::view[size=[4]](%2) | ||
%4 : Long(4) = aten::_cast_Long[non_blocking=0](%1) | ||
%11 : Double(100) = aten::index_put(%0, %4, %3) | ||
return (%11); | ||
} |
7 changes: 7 additions & 0 deletions
7
test/expect/TestScript.test_index_put_trace_without_view.expect
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
graph(%0 : Double(100) | ||
%1 : Long(4) | ||
%2 : Double(4)) { | ||
%3 : Long(4) = aten::_cast_Long[non_blocking=0](%1) | ||
%10 : Double(100) = aten::index_put(%0, %3, %2) | ||
return (%10); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters