Skip to content

Commit

Permalink
Only view() rhs of index_put if we need to (pytorch#9424)
Browse files Browse the repository at this point in the history
Summary:
During tracing (and export) we are now introducing an unnecessary hard-coded view on the RHS of indexed assignments such as `tensor[idxs] = rhs`. This caused a regression in the PyTorch translate models because these expressions appear with variable sizes in the RHS. This change makes it so we only call view if we indeed need to strip leading 1-dimensions
Pull Request resolved: pytorch#9424

Reviewed By: colesbury

Differential Revision: D8838881

Pulled By: jamesr66a

fbshipit-source-id: 399e5daa7d021f4f59f6f92b9fae581f92bfc538
  • Loading branch information
James Reed authored and facebook-github-bot committed Jul 14, 2018
1 parent 5ac8a80 commit 7160846
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 1 deletion.
8 changes: 8 additions & 0 deletions test/expect/TestScript.test_index_put_trace_with_view.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
graph(%0 : Double(100)
%1 : Long(4)
%2 : Double(1, 1, 1, 4)) {
%3 : Double(4) = aten::view[size=[4]](%2)
%4 : Long(4) = aten::_cast_Long[non_blocking=0](%1)
%11 : Double(100) = aten::index_put(%0, %4, %3)
return (%11);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
graph(%0 : Double(100)
%1 : Long(4)
%2 : Double(4)) {
%3 : Long(4) = aten::_cast_Long[non_blocking=0](%1)
%10 : Double(100) = aten::index_put(%0, %3, %2)
return (%10);
}
16 changes: 16 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4212,6 +4212,22 @@ def forward(self, x):
def some_func(x):
return sm(x)

def test_index_put_trace_with_view(self):
@torch.jit.trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
def test_index_put(target, indices, rhs):
target[indices] = rhs
return target

self.assertExpected(str(test_index_put.graph))

def test_index_put_trace_without_view(self):
@torch.jit.trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
def test_index_put(target, indices, rhs):
target[indices] = rhs
return target

self.assertExpected(str(test_index_put.graph))


class TestEndToEndHybridFrontendModels(JitTestCase):

Expand Down
7 changes: 6 additions & 1 deletion torch/csrc/autograd/python_variable_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,12 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
}

IntList slicedValueSizes = slicePrefix1sSize(value.sizes());
auto valuesSliced = value.view(slicedValueSizes);
torch::autograd::Variable valuesSliced;
if (!value.sizes().equals(slicedValueSizes)) {
valuesSliced = value.view(slicedValueSizes);
} else {
valuesSliced = value;
}
dispatch_index_put_(sliced, variableIndices, valuesSliced);
return 0;
END_HANDLE_TH_ERRORS_RET(-1)
Expand Down

0 comments on commit 7160846

Please sign in to comment.