diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d5ffa2a0197b3..f530de9a5e9f9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -352,7 +352,7 @@ variants: function, method dispatch: SparseCPU, SparseCUDA: add_sparse - SparseCsrCPU: add_sparse_csr + SparseCsrCPU, SparseCsrCUDA: add_sparse_csr MkldnnCPU: mkldnn_add - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) @@ -361,7 +361,7 @@ structured_delegate: add.out dispatch: SparseCPU, SparseCUDA: add_sparse_ - SparseCsrCPU: add_sparse_csr_ + SparseCsrCPU, SparseCsrCUDA: add_sparse_csr_ MkldnnCPU: mkldnn_add_ - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) @@ -373,6 +373,7 @@ SparseCPU: add_out_sparse_cpu SparseCUDA: add_out_sparse_cuda SparseCsrCPU: add_out_sparse_csr_cpu + SparseCsrCUDA: add_out_sparse_csr_cuda MkldnnCPU: mkldnn_add_out - func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor @@ -4581,7 +4582,7 @@ variants: function dispatch: SparseCPU, SparseCUDA: resize_as_sparse_ - SparseCsrCPU: resize_as_sparse_csr_ + SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_csr_ - func: zero_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4866,7 +4867,7 @@ - func: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor variants: method dispatch: - SparseCPU, SparseCUDA, SparseCsrCPU: sparse_to_dense + SparseCPU, SparseCUDA, SparseCsrCPU, SparseCsrCUDA: sparse_to_dense MkldnnCPU: mkldnn_to_dense - func: to_dense_backward(Tensor grad, Tensor input) -> Tensor diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp index 750440fc7a84e..f8e8d81ae29b4 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp @@ -288,12 +288,12 @@ Tensor& add_out_dense_sparse_csr_cpu( auto out_strides0 = resultBuffer.strides()[0]; auto out_strides1 = resultBuffer.strides()[1]; - for (int32_t irow = 0; irow < src_crow_indices.size(0) - 1; + for (index_t irow = 0; irow < src_crow_indices.size(0) - 1; ++irow) { - int32_t start_index = crow_indices_accessor[irow]; - int32_t end_index = crow_indices_accessor[irow + 1]; + index_t start_index = crow_indices_accessor[irow]; + index_t end_index = crow_indices_accessor[irow + 1]; - for (int i = start_index; i < end_index; ++i) { + for (index_t i = start_index; i < end_index; ++i) { auto icol = col_indices_accessor[i]; auto index = resultBuffer.storage_offset() + irow * out_strides0 + icol * out_strides1; diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu new file mode 100644 index 0000000000000..08eaeb8a12b2f --- /dev/null +++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu @@ -0,0 +1,160 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at { +namespace native { + +using namespace at::sparse_csr; +// certain utiliy functions are usable from sparse COO. +using namespace at::sparse; + +Tensor& add_out_dense_sparse_csr_cuda( + Tensor& output, + const Tensor& dense, + const SparseCsrTensor& src, + const Scalar& alpha) { + TORCH_INTERNAL_ASSERT(dense.layout() == kStrided); + TORCH_INTERNAL_ASSERT(src.is_sparse_csr()); + TORCH_INTERNAL_ASSERT(dense.is_cuda()); + + TORCH_CHECK( + output.is_contiguous(), + "out argument must be contiguous, but got: ", + output.suggest_memory_format()); + TORCH_CHECK( + output.is_cuda(), + "add: expected 'out' to be CUDA tensor, but got tensor on device: ", + output.device()); + + TORCH_CHECK( + src.is_cuda(), + "add: expected 'other' to be a CUDA tensor, but got tensor on device: ", + src.device()); + + TORCH_CHECK( + dense.sizes().equals(src.sizes()), + "add: expected 'self' and 'other' to have same size, but self has size ", + dense.sizes(), + " while other has size ", + src.sizes(), + " (FYI: op2-sparse addition does not currently support broadcasting)"); + + auto commonDtype = promoteTypes(dense.scalar_type(), src.scalar_type()); + TORCH_CHECK( + canCast(commonDtype, output.scalar_type()), + "Can't convert result type ", + commonDtype, + " to output ", + output.scalar_type(), + " in add operation"); + + Tensor src_values = src.values(); + Tensor src_crow_indices = src.crow_indices(); + Tensor src_col_indices = src.col_indices(); + + resize_output(output, dense.sizes()); + + Tensor resultBuffer = output; + Tensor valuesBuffer = src_values.to(commonDtype); + if (output.scalar_type() != commonDtype) { + resultBuffer = dense.to(commonDtype); + } else if (!is_same_tensor(output, dense)) { + resultBuffer.copy_(dense); + } + AT_DISPATCH_ALL_TYPES( + commonDtype, + "add_out_op2_sparse_csr", + [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() { + AT_DISPATCH_INDEX_TYPES( + src_crow_indices.scalar_type(), + "csr_add_out_crow_indices", + [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() { + scalar_t* values_accessor = valuesBuffer.data_ptr(); + scalar_t* out_ptr = resultBuffer.data_ptr(); + scalar_t cast_value = alpha.to(); + + index_t* crow_indices_accessor = src_crow_indices.data_ptr(); + index_t* col_indices_accessor = src_col_indices.data_ptr(); + int64_t out_storage_offset = resultBuffer.storage_offset(); + + auto out_strides = resultBuffer.strides(); + int64_t out_strides0 = out_strides[0]; + int64_t out_strides1 = out_strides[1]; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + + // Note that this could be wildly imbalanced if the sparsity pattern varies a lot between rows. + thrust::for_each( + policy, + thrust::make_counting_iterator(int64_t(0)), + thrust::make_counting_iterator(int64_t(src_crow_indices.size(0) - 1)), + [values_accessor, + crow_indices_accessor, + col_indices_accessor, + out_ptr, + out_storage_offset, + out_strides0, + cast_value, + out_strides1 + ]__device__(int64_t irow) { + index_t start_index = crow_indices_accessor[irow]; + index_t end_index = crow_indices_accessor[irow + 1]; + + for (index_t i = start_index; i < end_index; ++i) { + auto icol = col_indices_accessor[i]; + auto index = out_storage_offset + irow * out_strides0 + icol * out_strides1; + out_ptr[index] += cast_value * values_accessor[i]; + } + }); + }); + }); + if (output.scalar_type() != commonDtype) { + output.copy_(resultBuffer); + } + return output; +} + +Tensor& add_out_sparse_csr_cuda( + const Tensor& self, + const SparseCsrTensor& other, + const Scalar& alpha, + SparseCsrTensor& out) { + if (self.layout() == kStrided) { + return add_out_dense_sparse_csr_cuda(out, self, other, alpha); + } else { + TORCH_CHECK( + false, + "NotImplementedError: Addition of sparse CSR tensors is not yet implemented.") + } + return out; +} + +} // namespace native +} // namespace at diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 89c8e90465bfd..3e2577aa1b5b5 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -278,7 +278,6 @@ def test_sparse_csr_from_dense(self, device): self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices()) self.assertEqual(torch.tensor([2] * 9), sparse.values()) - @onlyCPU @dtypes(torch.double) def test_dense_convert(self, device, dtype): size = (5, 5) @@ -400,7 +399,35 @@ def test_shape(di, dj, dk, nnz): for k in range(2, 8): test_shape(i, j, k, i * j // 2) - @onlyCPU + @dtypes(torch.float, torch.double) + def test_add(self, device, dtype): + def _test_spadd_shape(nnz, shape): + x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32) + y = torch.randn(*shape, dtype=dtype, device=device) + r = random.random() + + res = torch.add(y, x, alpha=r) + expected = y + r * x.to_dense() + self.assertEqual(res, expected) + + # Non contiguous dense tensor + s = list(shape) + s[0] = shape[-1] + s[-1] = shape[0] + y = torch.randn(*s, dtype=torch.double, device=device) + y.transpose_(0, len(s) - 1) + r = random.random() + + res = torch.add(y, x, alpha=r) + expected = y + r * x.to_dense() + + self.assertEqual(res, expected) + + _test_spadd_shape(10, [100, 100]) + _test_spadd_shape(0, [100, 100]) + _test_spadd_shape(10, [100, 1]) + _test_spadd_shape(10, [1, 100]) + @dtypes(*torch.testing.floating_types()) def test_coo_csr_conversion(self, device, dtype): size = (5, 5)