Skip to content

Commit

Permalink
[Perf] Optimize Tile CPU and CUDA kernels for a corner case (microsof…
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharans29 authored Jan 21, 2021
1 parent d9e4795 commit 8574854
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 19 deletions.
63 changes: 58 additions & 5 deletions onnxruntime/core/providers/cpu/tensor/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,31 @@ Status TileCoreForFixedSizeTypes(const Tensor& input_tensor, Tensor& output_tens
return Status::OK();
}

namespace TileOp {
// Decides whether a Tile call degenerates into writing the complete input
// buffer back-to-back several times.
//
// The repeats are scanned from the last axis towards axis 0 to locate the
// innermost repeat that is not 1. If the input dimensions to the left of that
// axis multiply to 1, the output is simply the input buffer repeated, and the
// number of repetitions is the product of repeats[0..axis].
//
// input_shape   - shape of the tensor being tiled
// repeats       - per-axis repeat counts; indices [0, rank) are read
// rank          - number of axes
// num_of_copies - written only when the function returns true; left
//                 untouched otherwise
//
// Returns true when the fast path applies, false when every repeat is 1 or
// when the leading input dimensions do not collapse to a single element.
bool IsTileMemcpy(const TensorShape& input_shape,
                  const int64_t* repeats,
                  size_t rank,
                  /*out*/ size_t& num_of_copies) {
  // Locate the innermost axis whose repeat differs from 1.
  int64_t axis = static_cast<int64_t>(rank) - 1;
  while (axis >= 0 && repeats[axis] == 1) {
    --axis;
  }

  // No such axis: all repeats are 1, which is not this function's case.
  if (axis < 0) {
    return false;
  }

  // Only that innermost axis is examined; outer axes are never retried.
  if (input_shape.SizeToDimension(axis) != 1) {
    return false;
  }

  // Each repeat at or before `axis` multiplies the number of whole-buffer copies.
  num_of_copies = 1;
  for (int64_t k = 0; k <= axis; ++k) {
    num_of_copies *= repeats[k];
  }
  return true;
}
}  // namespace TileOp

Status Tile::Compute(OpKernelContext* ctx) const {
const auto* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty");
Expand All @@ -116,19 +141,47 @@ Status Tile::Compute(OpKernelContext* ctx) const {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor");

// Calculate the shape of the output tensor
auto* repeats = repeats_tensor.template Data<int64_t>();
const auto* repeats = repeats_tensor.template Data<int64_t>();
std::vector<int64_t> output_dims = input_shape.GetDims();
for (size_t axis = 0; axis < input_rank; axis++) {
output_dims[axis] *= repeats[axis];
}

TensorShape outputShape(output_dims);
auto& output_tensor = *ctx->Output(0, outputShape);
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);

// Repeat tensor input can have 0 as a valid value
// check if the computed outputshape size is 0 and
// check if the computed output_shape size is 0 and
// return an empty tensor if so.
if (outputShape.Size() == 0) {
if (output_shape.Size() == 0) {
return Status::OK();
}

// Repeat tensor has all 1s in it
if (output_shape == input_shape) {
// TODO: Handle string copies when the kernel eventually supports string type.
// For now, it shouldn't throw in the enforce as the kernel doesn't claim string support
ORT_ENFORCE(!input_tensor.IsDataType<std::string>(), "Tile doesn't support string type yet");
memcpy(output_tensor.MutableDataRaw(), input_tensor.DataRaw(), input_tensor.SizeInBytes());
return Status::OK();
}

size_t num_of_copies = 1;
if (TileOp::IsTileMemcpy(input_shape, repeats, input_rank, num_of_copies)) {
// TODO: Handle string copies when the kernel eventually supports string type.
// For now, it shouldn't throw in the enforce as the kernel doesn't claim string support
ORT_ENFORCE(!input_tensor.IsDataType<std::string>(), "Tile doesn't support string type yet");

int8_t* output_data_casted = reinterpret_cast<int8_t*>(output_tensor.MutableDataRaw());
const void* input_data_raw = input_tensor.DataRaw();
size_t tensor_size_in_bytes = input_tensor.SizeInBytes();

// TODO: Add multi-threading logic if num_of_copies is large enough
for (size_t i = 0; i < num_of_copies; ++i) {
memcpy(static_cast<void*>(output_data_casted), input_data_raw, tensor_size_in_bytes);
output_data_casted += tensor_size_in_bytes;
}

return Status::OK();
}

Expand Down
17 changes: 15 additions & 2 deletions onnxruntime/core/providers/cpu/tensor/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,21 @@

namespace onnxruntime {

struct Tile final : OpKernel {
Tile(const OpKernelInfo& info) : OpKernel(info) {
namespace TileOp {
// Determines whether the tiling operation amounts to writing the whole input
// data buffer several times in a row.
// E.g.: input_shape: [1, 1, 256 * 50]
// repeats: [1, 200, 1]
// output shape: [1, 200, 256 * 50]
// Here the output is the input buffer copied 200 times.
//
// input_shape:   shape of the tensor being tiled
// repeats:       per-axis repeat counts; entries [0, rank) are read
// rank:          number of axes
// num_of_copies: set to the number of whole-buffer copies when the function
//                returns true; not modified when it returns false
//
// Returns true if the output can be produced by repeated whole-buffer copies.

bool IsTileMemcpy(const TensorShape& input_shape,
const int64_t* repeats,
size_t rank,
/*out*/ size_t& num_of_copies);
} // namespace TileOp

struct Tile : OpKernel {
explicit Tile(const OpKernelInfo& info) : OpKernel(info) {
}

Status Compute(OpKernelContext* context) const override;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {

if (input_data != output_data) {
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
cudaMemcpyAsync(output_data, input_data, element_size * input_shape.Size(), cudaMemcpyDeviceToDevice);
cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice);
}

// Bail out early
Expand Down
59 changes: 52 additions & 7 deletions onnxruntime/core/providers/cuda/tensor/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/cuda/tensor/tile.h"
#include "core/providers/cpu/tensor/utils.h"
#include "tile_impl.h"

using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {
Expand Down Expand Up @@ -51,22 +52,66 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const {

// Calculate the shape of the output tensor
auto* repeats = repeats_tensor.template Data<int64_t>();
const auto& input_shape = input_tensor.Shape().GetDims();
std::vector<int64_t> output_dims(input_shape);
const auto& input_shape = input_tensor.Shape();
const auto& input_dims = input_shape.GetDims();
std::vector<int64_t> output_dims(input_dims);
for (auto axis = 0; axis < rank; axis++)
output_dims[axis] *= repeats[axis];
TensorShape outputShape(output_dims);
auto& output_tensor = *ctx->Output(0, outputShape);
TensorShape output_shape(output_dims);
auto& output_tensor = *ctx->Output(0, output_shape);

void* output_data = output_tensor.MutableDataRaw();
const void* input_data = input_tensor.DataRaw();

TensorPitches input_pitches(input_shape);
// Repeat tensor input can have 0 as a valid value
// check if the computed output_shape size is 0 and
// return an empty tensor if so.
if (output_shape.Size() == 0) {
return Status::OK();
}

// Repeat tensor has all 1s in it
if (output_shape == input_shape) {
cudaMemcpyAsync(output_tensor.MutableDataRaw(), input_tensor.DataRaw(), input_tensor.SizeInBytes(), cudaMemcpyDeviceToDevice);
return Status::OK();
}

size_t num_of_copies = 1;
if (TileOp::IsTileMemcpy(input_shape, repeats, rank, num_of_copies)) {
if (input_tensor.IsDataType<float>() ||
input_tensor.IsDataType<int32_t>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<float>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<float>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<double>() ||
input_tensor.IsDataType<int64_t>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<double>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<double>::MappedType*>(output_data),
output_shape.Size());
} else if (input_tensor.IsDataType<MLFloat16>()) {
TileMemcpyImpl(
reinterpret_cast<const typename ToCudaType<MLFloat16>::MappedType*>(input_data),
input_shape.Size(),
reinterpret_cast<typename ToCudaType<MLFloat16>::MappedType*>(output_data),
output_shape.Size());
} else {
// Won't hit this as the kernel doesn't claim support for any type that will trigger this
ORT_THROW("Tile doesn't have an implementation yet for the type: ", input_tensor.DataType());
}

return Status::OK();
}

TensorPitches input_pitches(input_dims);
TArray<int64_t> input_strides(input_pitches);

TArray<fast_divmod> fdm_input_shape(rank);
for (int32_t i = 0; i < input_shape.size(); ++i) {
fdm_input_shape[i] = fast_divmod(gsl::narrow_cast<int>(input_shape[i]));
for (int32_t i = 0; i < input_dims.size(); ++i) {
fdm_input_shape[i] = fast_divmod(gsl::narrow_cast<int>(input_dims[i]));
}

TArray<fast_divmod> fdm_output_strides(rank);
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cpu/tensor/tile.h"

namespace onnxruntime {
namespace cuda {

struct Tile final : CudaKernel {
Tile(const OpKernelInfo& info) : CudaKernel(info) {
explicit Tile(const OpKernelInfo& info) : CudaKernel(info) {
}

Status ComputeInternal(OpKernelContext* context) const override;
Expand Down
27 changes: 25 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/tile_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,31 @@ void TileImpl(
fdm_output_strides, output_data, (CUDA_LONG)N);
}

#define SPECIALIZED_IMPL(T) \
template void TileImpl<T>(const size_t shape_rank, const TArray<fast_divmod>& fdm_input_shape, const TArray<int64_t>& input_stride, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N);
// Kernel for the Tile fast path where the output is the input buffer repeated
// end-to-end: each thread writes one output element, reading the input element
// at the output index wrapped around the input length.
//
// input_data         - source elements (num_input_elements of them are addressed)
// num_input_elements - length of the input buffer, used as the wrap-around modulus
// output_data        - destination buffer
// N                  - element count handed to the index macro as its bound
template <typename T>
__global__ void _TileMemcpyKernel(
    const T* input_data,
    const size_t num_input_elements,
    T* output_data,
    const size_t N) {
  // NOTE(review): the macro presumably declares this thread's flat index `id`
  // and exits early once `id` is not below N - confirm against its definition.
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
  output_data[id] = input_data[id % num_input_elements];
}

// Host-side launcher for _TileMemcpyKernel.
//
// Fills output_data with num_output_elements values, where output element i is
// taken from input element (i % num_input_elements), i.e. the input buffer
// repeated end-to-end.
//
// input_data          - source buffer, dereferenced inside the kernel
// num_input_elements  - number of elements in the source buffer
// output_data         - destination buffer, dereferenced inside the kernel
// num_output_elements - number of elements to produce; callers are expected to
//                       have handled the empty-output case before calling
template <typename T>
void TileMemcpyImpl(
    const T* input_data,
    const size_t num_input_elements,
    T* output_data,
    const size_t num_output_elements) {
  // Integer ceiling division is used rather than ceil() on a float: a float
  // cannot represent every count above 2^24 exactly, and a value that rounds
  // down would yield too few blocks and leave the tail of the output unwritten.
  // The quotient/remainder form also avoids overflow of (n + threads - 1).
  const size_t threads_per_block = static_cast<size_t>(GridDim::maxThreadsPerBlock);
  const size_t full_blocks = num_output_elements / threads_per_block;
  const size_t partial_block = (num_output_elements % threads_per_block != 0) ? 1 : 0;
  const int blocksPerGrid = static_cast<int>(full_blocks + partial_block);

  _TileMemcpyKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
      input_data, num_input_elements, output_data, static_cast<CUDA_LONG>(num_output_elements));
}

// Explicitly instantiates both Tile launchers (the general TileImpl and the
// whole-buffer-copy TileMemcpyImpl) for element type T.
#define SPECIALIZED_IMPL(T) \
template void TileImpl<T>(const size_t shape_rank, const TArray<fast_divmod>& fdm_input_shape, const TArray<int64_t>& input_stride, const T* input_data, const TArray<fast_divmod>& fdm_output_strides, T* output_data, const size_t N); \
template void TileMemcpyImpl<T>(const T* input_data, const size_t num_input_elements, T* output_data, const size_t num_output_elements);

SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/tile_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,12 @@ void TileImpl(
T* output_data,
const size_t N);

// Tile fast path for the case where the output is the input buffer repeated
// end-to-end: output element i is read from input element
// (i % num_input_elements).
//
// input_data:          source buffer holding num_input_elements elements
// num_input_elements:  element count of the source buffer
// output_data:         destination buffer holding num_output_elements elements
// num_output_elements: total number of elements to write
template <typename T>
void TileMemcpyImpl(
const T* input_data,
const size_t num_input_elements,
T* output_data,
const size_t num_output_elements);

} // namespace cuda
} // namespace onnxruntime
20 changes: 19 additions & 1 deletion onnxruntime/test/providers/cpu/tensor/tile_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void RunTest(std::initializer_list<T> input,
test.AddInput<int64_t>("repeats", repeat_dims, repeat);
test.AddOutput<T>("output", output_dims, output);
if (std::is_same<T, int8_t>::value)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT reports error: Assertion Error in makePaddedScale: 0 (regionRanges != nullptr)
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT reports error: Assertion Error in makePaddedScale: 0 (regionRanges != nullptr)
else
test.Run();
}
Expand All @@ -43,6 +43,15 @@ void RunTestWrapper() {

// Tile3D
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3});

// Tile3DWithAllOneRepeats (every repeat is 1, so the output equals the input)
RunTest<T>({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 1, 1}, {3}, {111, 112, 113, 122, 123, 124}, {2, 1, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 1
RunTest<T>({111, 112, 113}, {1, 1, 3}, {2, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 111, 112, 113, 111, 112, 113}, {2, 2, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 2
RunTest<T>({111, 112, 113}, {1, 1, 3}, {3, 1, 1}, {3}, {111, 112, 113, 111, 112, 113, 111, 112, 113}, {3, 1, 3});
}

template <>
Expand All @@ -64,6 +73,15 @@ void RunTestWrapper<bool>() {

// Tile3D
RunTest<bool>({true, false, true, false, true, false}, {2, 1, 3}, {1, 2, 1}, {3}, {true, false, true, true, false, true, false, true, false, false, true, false}, {2, 2, 3});

// Tile3DWithAllOneRepeats (every repeat is 1, so the output equals the input)
RunTest<bool>({true, false, true, false, true, true}, {2, 1, 3}, {1, 1, 1}, {3}, {true, false, true, false, true, true}, {2, 1, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 1
RunTest<bool>({true, false, true}, {1, 1, 3}, {2, 2, 1}, {3}, {true, false, true, true, false, true, true, false, true, true, false, true}, {2, 2, 3});

// TileWhichIsBasicallyCopiesOfInputBuffer - 2
RunTest<bool>({true, false, true}, {1, 1, 3}, {3, 1, 1}, {3}, {true, false, true, true, false, true, true, false, true}, {3, 1, 3});
}

TEST(TensorOpTest, TileFloatType) {
Expand Down

0 comments on commit 8574854

Please sign in to comment.