Update AMD transpose to match CUDA transpose.
jessebenson committed Dec 9, 2020
1 parent abdbb5f commit cc47cfc
Showing 6 changed files with 81 additions and 393 deletions.
34 changes: 18 additions & 16 deletions onnxruntime/core/providers/cuda/tensor/transpose.cc
@@ -9,21 +9,23 @@
namespace onnxruntime {
namespace cuda {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(Transpose,
kOnnxDomain,
1, 12,
kCudaExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

ONNX_OPERATOR_KERNEL_EX(Transpose,
kOnnxDomain,
13,
kCudaExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Transpose,
kOnnxDomain,
1, 12,
kCudaExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

ONNX_OPERATOR_KERNEL_EX(
Transpose,
kOnnxDomain,
13,
kCudaExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

// special case acceleration using cublas matrix transpose
static std::tuple<int, int> TryTransposeWithCublas(const std::vector<size_t>& perm, const TensorShape& input_shape) {
@@ -162,7 +164,7 @@ Status Transpose::DoTranspose(const cudaDeviceProp& prop,
if (CanDoTranspose3D(new_rank, new_input_dims, new_permutations)) {
return Transpose3DImpl(element_size, input_shape, tmp_input_strides,
input.DataRaw(), output.MutableDataRaw(), output.Shape().Size());
} else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) {
} else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) {
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
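For reference, the stride bookkeeping the hunks above rely on can be written as a plain CPU loop. The sketch below is illustrative only (TransposeRef is not an onnxruntime function): each flat output index is decoded into coordinates by dividing by the output strides (which the GPU kernels do with fast_divmod), and the matching input offset accumulates the stride of the permuted axis, mirroring tmp_output_strides[i] = new_output_strides[new_permutations[i]].

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative CPU reference for the stride bookkeeping above; not repository code.
std::vector<float> TransposeRef(const std::vector<float>& input,
                                const std::vector<int64_t>& in_dims,
                                const std::vector<size_t>& perm) {
  const size_t rank = in_dims.size();
  std::vector<int64_t> in_strides(rank, 1), out_dims(rank), out_strides(rank, 1);
  for (size_t i = rank - 1; i-- > 0;) in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
  for (size_t i = 0; i < rank; ++i) out_dims[i] = in_dims[perm[i]];  // output axis i = input axis perm[i]
  for (size_t i = rank - 1; i-- > 0;) out_strides[i] = out_strides[i + 1] * out_dims[i + 1];

  std::vector<float> output(input.size());
  for (size_t out_idx = 0; out_idx < input.size(); ++out_idx) {
    int64_t remainder = static_cast<int64_t>(out_idx), in_idx = 0;
    for (size_t i = 0; i < rank; ++i) {
      const int64_t coord = remainder / out_strides[i];  // fast_divmod on the GPU
      remainder %= out_strides[i];
      in_idx += coord * in_strides[perm[i]];             // permuted input stride
    }
    output[out_idx] = input[in_idx];
  }
  return output;
}

int main() {
  // Transpose a row-major 2x3 buffer: {0,1,2,3,4,5} -> {0,3,1,4,2,5}.
  const std::vector<float> x = {0, 1, 2, 3, 4, 5};
  for (float v : TransposeRef(x, {2, 3}, {1, 0})) std::cout << v << ' ';
  std::cout << '\n';
}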
121 changes: 54 additions & 67 deletions onnxruntime/core/providers/rocm/tensor/transpose.cc
@@ -9,21 +9,23 @@
namespace onnxruntime {
namespace rocm {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(Transpose,
kOnnxDomain,
1, 12,
kRocmExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

ONNX_OPERATOR_KERNEL_EX(Transpose,
kOnnxDomain,
13,
kRocmExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Transpose,
kOnnxDomain,
1, 12,
kRocmExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

ONNX_OPERATOR_KERNEL_EX(
Transpose,
kOnnxDomain,
13,
kRocmExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

// special case acceleration using rocblas matrix transpose
static std::tuple<int, int> TryTransposeWithRocblas(const std::vector<size_t>& perm, const TensorShape& input_shape) {
@@ -52,14 +54,14 @@ static std::tuple<int, int> TryTransposeWithRocblas(const std::vector<size_t>& p
}

template <typename T>
Status TransposeWithRocblas(rocblas_handle handle, const Tensor& input, Tensor& output, int M, int N) {
Status TransposeWithRocblas(rocblas_handle rocblas_handle, const Tensor& input, Tensor& output, int M, int N) {
typedef typename ToHipType<T>::MappedType HipT;
HipT one = ToHipType<T>::FromFloat(1.0f);
HipT zero = ToHipType<T>::FromFloat(0.0f);
const HipT* input_data = reinterpret_cast<const HipT*>(input.Data<T>());
HipT* output_data = reinterpret_cast<HipT*>(output.MutableData<T>());
ROCBLAS_RETURN_IF_ERROR(
rocblasTransposeHelper(handle,
rocblasTransposeHelper(rocblas_handle,
rocblas_operation_transpose, rocblas_operation_transpose, M, N,
&one,
input_data,
@@ -72,8 +74,15 @@ Status TransposeWithRocblas(rocblas_handle handle, const Tensor& input, Tensor&
return Status::OK();
}

Status Transpose::DoTranspose(const Transpose& kernel,
Status Transpose::DoTranspose(const Transpose& transpose_kernel,
const std::vector<size_t>& permutations, const Tensor& input, Tensor& output) {
return Transpose::DoTranspose(transpose_kernel.GetDeviceProp(), transpose_kernel.RocblasHandle(), permutations, input, output);
}

Status Transpose::DoTranspose(const hipDeviceProp_t& prop,
const rocblas_handle rocblas_handle,
const std::vector<size_t>& permutations, const Tensor& input, Tensor& output,
const TensorShape* input_shape_override) {
// special case when there is a dim value of 0 in the shape.
if (output.Shape().Size() == 0)
return Status::OK();
Expand All @@ -82,26 +91,23 @@ Status Transpose::DoTranspose(const Transpose& kernel,
if (element_type == utils::GetONNXTensorElementDataType<float>() ||
element_type == utils::GetONNXTensorElementDataType<double>() ||
element_type == utils::GetONNXTensorElementDataType<MLFloat16>()) {
auto mn = TryTransposeWithRocblas(permutations, input.Shape());
auto mn = TryTransposeWithRocblas(permutations, input_shape_override ? *input_shape_override : input.Shape());
int M = std::get<0>(mn);
int N = std::get<1>(mn);
if (M != 0 && N != 0) {
if (element_type == utils::GetONNXTensorElementDataType<float>()) {
return TransposeWithRocblas<float>(kernel.RocblasHandle(), input, output, M, N);
return TransposeWithRocblas<float>(rocblas_handle, input, output, M, N);
} else if (element_type == utils::GetONNXTensorElementDataType<double>()) {
return TransposeWithRocblas<double>(kernel.RocblasHandle(), input, output, M, N);
return TransposeWithRocblas<double>(rocblas_handle, input, output, M, N);
} else {
return TransposeWithRocblas<MLFloat16>(kernel.RocblasHandle(), input, output, M, N);
return TransposeWithRocblas<MLFloat16>(rocblas_handle, input, output, M, N);
}
}
}

const std::vector<int64_t>& input_dims = input.Shape().GetDims();
const std::vector<int64_t>& input_dims = input_shape_override ? input_shape_override->GetDims() : input.Shape().GetDims();
const std::vector<int64_t>& output_dims = output.Shape().GetDims();

auto rank = static_cast<int32_t>(input_dims.size());
TensorPitches original_input_strides(input_dims);
TensorPitches original_output_strides(output_dims);

// flatten the adjacent dimensions which are contiguous
// for example: permutations[0, 2, 3, 1] -> [0, 2, 1], permutations[0, 3, 1, 2] -> [0, 2, 1]
@@ -150,55 +156,36 @@ Status Transpose::DoTranspose(const Transpose& kernel,
TensorPitches new_input_strides(new_input_dims);
TensorPitches new_output_strides(new_output_dims);

// TArray<int64_t> input_strides(rank);
// for (auto i = 0; i < rank; i++) {
// input_strides[i] = original_input_strides[permutations[i]];
// }

// TArray<fast_divmod> output_strides(rank);
// for (auto i = 0; i < rank; i++) {
// output_strides[i] = fast_divmod(gsl::narrow_cast<int>(original_output_strides[i]));
// }
// Optimize the permutation of 3D/4D tensor
TArray<int64_t> input_shape(new_input_dims);
TArray<int64_t> tmp_input_strides(new_input_strides);

size_t element_size = input.DataType()->Size();
std::vector<int64_t> input_shape(new_rank);
std::vector<int64_t> tmp_input_strides(new_rank);
std::vector<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
input_shape[i] = new_input_dims[i];
tmp_input_strides[i] = new_input_strides[i];
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
}

if (CanDoTranspose3D(new_rank, new_input_dims, new_permutations)) {
return Transpose3DImpl(kernel, element_size, input_shape, tmp_input_strides,
return Transpose3DImpl(element_size, input_shape, tmp_input_strides,
input.DataRaw(), output.MutableDataRaw(), output.Shape().Size());

} else if (CanDoTranspose4D(kernel.GetDeviceProp(), element_size, new_rank, new_input_dims, new_permutations)) {
return Transpose4DImpl(kernel, element_size, input_shape, tmp_input_strides, input.DataRaw(),
} else if (CanDoTranspose4D(prop, element_size, new_rank, new_input_dims, new_permutations)) {
TArray<int64_t> tmp_output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
tmp_output_strides[i] = new_output_strides[new_permutations[i]];
}
return Transpose4DImpl(element_size, input_shape, tmp_input_strides, input.DataRaw(),
tmp_output_strides, output.MutableDataRaw(), output.Shape().Size());
}
}

RocmAsyncBuffer<int64_t> input_strides(&kernel, new_rank);
// General cases
TArray<int64_t> input_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
input_strides.CpuPtr()[i] = new_input_strides[new_permutations[i]];
input_strides[i] = new_input_strides[new_permutations[i]];
}

RocmAsyncBuffer<fast_divmod> output_strides(&kernel, new_rank);
ORT_ENFORCE(CalculateFdmStrides(output_strides.CpuSpan(), new_output_dims));

// TODO: use output shape in reverse order for uint24 math
// for (auto i = 0; i < rank; i++) {
// output_strides.CpuPtr()[i] = output_dims[rank - 1 - i];
// if (output_dims[rank-1-i] > 0x7FFFFF) {
// printf("shape size is: %lx\n", output_dims[rank-1-i]);
// }
// }
ORT_RETURN_IF_ERROR(input_strides.CopyToGpu());
ORT_RETURN_IF_ERROR(output_strides.CopyToGpu());
TArray<fast_divmod> output_strides(new_rank);
for (auto i = 0; i < new_rank; i++) {
output_strides[i] = fast_divmod(gsl::narrow_cast<int>(new_output_strides[i]));
}

auto status = TransposeImpl(element_size, new_rank, input_strides.GpuPtr(), input.DataRaw(),
output_strides.GpuPtr(), output.MutableDataRaw(), output.Shape().Size());
auto status = TransposeImpl(element_size, new_rank, input_strides, input.DataRaw(),
output_strides, output.MutableDataRaw(), output.Shape().Size());

return status;
}
@@ -221,8 +208,8 @@ Status Transpose::ComputeInternal(OpKernelContext* ctx) const {
TensorShape output_shape{output_dims};
Tensor* Y = ctx->Output(0, output_shape);

return DoTranspose(*this, *p_perm, X, *Y);
return DoTranspose(this->GetDeviceProp(), this->RocblasHandle(), *p_perm, X, *Y);
}

} // namespace rocm
} // namespace onnxruntime
} // namespace onnxruntime
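The "flatten the adjacent dimensions which are contiguous" comment in the diff above gives the examples perm [0,2,3,1] -> [0,2,1] and perm [0,3,1,2] -> [0,2,1]. A standalone sketch of that idea is below; FlattenTransposedDims is a hypothetical helper written for illustration, not the in-tree implementation. Axes that stay adjacent and in-order under the permutation are fused into one larger axis, shrinking the rank before the 3D/4D/general dispatch.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical sketch of adjacent-dimension flattening; not repository code.
void FlattenTransposedDims(std::vector<int64_t>& dims, std::vector<size_t>& perm) {
  // Group axes that are consecutive in both the input and the permutation.
  std::vector<std::pair<size_t, int64_t>> groups;  // (leading input axis, fused size)
  for (size_t i = 0; i < perm.size(); ++i) {
    if (i > 0 && perm[i] == perm[i - 1] + 1) {
      groups.back().second *= dims[perm[i]];  // fuse into the previous group
    } else {
      groups.emplace_back(perm[i], dims[perm[i]]);
    }
  }
  // Groups sorted by leading input axis give the new input dims; the new
  // permutation records where each group (in output order) lands in that order.
  std::vector<size_t> order(groups.size());
  for (size_t i = 0; i < order.size(); ++i) order[i] = i;
  std::sort(order.begin(), order.end(),
            [&](size_t a, size_t b) { return groups[a].first < groups[b].first; });
  std::vector<int64_t> new_dims(groups.size());
  std::vector<size_t> new_perm(groups.size()), pos_of(groups.size());
  for (size_t pos = 0; pos < order.size(); ++pos) {
    new_dims[pos] = groups[order[pos]].second;
    pos_of[order[pos]] = pos;
  }
  for (size_t i = 0; i < groups.size(); ++i) new_perm[i] = pos_of[i];
  dims = std::move(new_dims);
  perm = std::move(new_perm);
}

int main() {
  std::vector<int64_t> dims = {2, 3, 4, 5};
  std::vector<size_t> perm = {0, 2, 3, 1};
  FlattenTransposedDims(dims, perm);           // dims -> {2, 3, 20}, perm -> {0, 2, 1}
  for (auto d : dims) std::cout << d << ' ';
  std::cout << "| ";
  for (auto p : perm) std::cout << p << ' ';
  std::cout << '\n';
}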
11 changes: 9 additions & 2 deletions onnxruntime/core/providers/rocm/tensor/transpose.h
@@ -6,20 +6,27 @@
#include "gsl/gsl"
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/cpu/tensor/transpose.h"

namespace onnxruntime {
namespace rocm {

class Transpose final : public RocmKernel, public TransposeBase {
public:
Transpose(const OpKernelInfo& info) : RocmKernel(info), TransposeBase(info) {}
Transpose(const OpKernelInfo& info) : RocmKernel(info), TransposeBase(info) {
}

Status ComputeInternal(OpKernelContext* context) const override;

static Status DoTranspose(const Transpose& transpose_kernel,
const std::vector<size_t>& permutations, const Tensor& input, Tensor& output);

// `input_shape_override` (if provided) overrides the shape of `input` for compute purposes
static Status DoTranspose(const hipDeviceProp_t& prop,
const rocblas_handle rocblas_handle,
const std::vector<size_t>& permutations,
const Tensor& input, Tensor& output, const TensorShape* input_shape_override = nullptr);
};

} // namespace rocm
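Both providers keep the "special case acceleration using cublas/rocblas matrix transpose" fast path, whose body is collapsed in the diff. As a rough, hypothetical illustration of that kind of check — not the actual TryTransposeWithCublas / TryTransposeWithRocblas logic — a permutation reduces to a plain M x N matrix transpose when, ignoring size-1 axes, it simply swaps the two remaining axes:

#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

// Hypothetical sketch only; the real detection logic is not shown in this diff.
std::tuple<int, int> TryAsMatrixTranspose(const std::vector<size_t>& perm,
                                          const std::vector<int64_t>& dims) {
  std::vector<size_t> in_order, out_order;  // significant axes, input vs. output order
  for (size_t a = 0; a < dims.size(); ++a)
    if (dims[a] > 1) in_order.push_back(a);
  for (size_t i = 0; i < perm.size(); ++i)
    if (dims[perm[i]] > 1) out_order.push_back(perm[i]);
  if (in_order.size() == 2 && out_order[0] == in_order[1] && out_order[1] == in_order[0])
    return {static_cast<int>(dims[in_order[0]]), static_cast<int>(dims[in_order[1]])};
  return {0, 0};  // (0, 0) means "no simple matrix-transpose form", as in the M/N check above
}

int main() {
  auto mn = TryAsMatrixTranspose({0, 2, 1}, {1, 5, 7});  // effectively a 5x7 matrix transpose
  std::cout << std::get<0>(mn) << " x " << std::get<1>(mn) << '\n';  // prints "5 x 7"
}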
