From 7160846c81d9ebd298a4c1608d67b8f4bd76fee7 Mon Sep 17 00:00:00 2001 From: James Reed Date: Sat, 14 Jul 2018 00:04:40 -0700 Subject: [PATCH] Only view() rhs of index_put if we need to (#9424) Summary: During tracing (and export) we are now introducing an unnecessary hard-coded view on the RHS of indexed assignments such as `tensor[idxs] = rhs`. This caused a regression in the PyTorch translate models because these expressions appear with variable sizes in the RHS. This change makes it so we only call view() if we actually need to strip leading size-1 dimensions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/9424 Reviewed By: colesbury Differential Revision: D8838881 Pulled By: jamesr66a fbshipit-source-id: 399e5daa7d021f4f59f6f92b9fae581f92bfc538 --- ...tScript.test_index_put_trace_with_view.expect | 8 ++++++++ ...ript.test_index_put_trace_without_view.expect | 7 +++++++ test/test_jit.py | 16 ++++++++++++++++ torch/csrc/autograd/python_variable_indexing.cpp | 7 ++++++- 4 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 test/expect/TestScript.test_index_put_trace_with_view.expect create mode 100644 test/expect/TestScript.test_index_put_trace_without_view.expect diff --git a/test/expect/TestScript.test_index_put_trace_with_view.expect b/test/expect/TestScript.test_index_put_trace_with_view.expect new file mode 100644 index 0000000000000..24ff0fe32c451 --- /dev/null +++ b/test/expect/TestScript.test_index_put_trace_with_view.expect @@ -0,0 +1,8 @@ +graph(%0 : Double(100) + %1 : Long(4) + %2 : Double(1, 1, 1, 4)) { + %3 : Double(4) = aten::view[size=[4]](%2) + %4 : Long(4) = aten::_cast_Long[non_blocking=0](%1) + %11 : Double(100) = aten::index_put(%0, %4, %3) + return (%11); +} diff --git a/test/expect/TestScript.test_index_put_trace_without_view.expect b/test/expect/TestScript.test_index_put_trace_without_view.expect new file mode 100644 index 0000000000000..f483213b48146 --- /dev/null +++ 
b/test/expect/TestScript.test_index_put_trace_without_view.expect @@ -0,0 +1,7 @@ +graph(%0 : Double(100) + %1 : Long(4) + %2 : Double(4)) { + %3 : Long(4) = aten::_cast_Long[non_blocking=0](%1) + %10 : Double(100) = aten::index_put(%0, %3, %2) + return (%10); +} diff --git a/test/test_jit.py b/test/test_jit.py index f187c944d1043..0663b41b67e08 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4212,6 +4212,22 @@ def forward(self, x): def some_func(x): return sm(x) + def test_index_put_trace_with_view(self): + @torch.jit.trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4)) + def test_index_put(target, indices, rhs): + target[indices] = rhs + return target + + self.assertExpected(str(test_index_put.graph)) + + def test_index_put_trace_without_view(self): + @torch.jit.trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4)) + def test_index_put(target, indices, rhs): + target[indices] = rhs + return target + + self.assertExpected(str(test_index_put.graph)) + class TestEndToEndHybridFrontendModels(JitTestCase): diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 8c2023721bf2c..cd8329cad0143 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -371,7 +371,12 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { } IntList slicedValueSizes = slicePrefix1sSize(value.sizes()); - auto valuesSliced = value.view(slicedValueSizes); + torch::autograd::Variable valuesSliced; + if (!value.sizes().equals(slicedValueSizes)) { + valuesSliced = value.view(slicedValueSizes); + } else { + valuesSliced = value; + } dispatch_index_put_(sliced, variableIndices, valuesSliced); return 0; END_HANDLE_TH_ERRORS_RET(-1)