From bdd06f63109ddf85ead6f836f36f85bf4fb1bc9f Mon Sep 17 00:00:00 2001
From: Xavier Dupré
Date: Thu, 3 Dec 2020 00:38:18 +0100
Subject: [PATCH] Fix PR #5550 reverted in #5911 (performance improvement for
operator Transpose) (#5916)
* Improve implementation of the Transpose operator
* Fix issue mentioned in #5911
* Add unit test for function DoTransposeImpl
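
Notes: when only axes of size 1 change place, the transpose is equivalent to a
reshape and is now short-circuited with a single CopyCpuTensor call (see
IsTransposeReshape in the diff). Below is a minimal standalone sketch of that
check, using plain std containers and an illustrative main() rather than the
ORT types and signature:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Same logic as the IsTransposeReshape added by this patch, with the ORT
    // types replaced by plain std containers (sketch only).
    static bool IsTransposeReshapeSketch(const std::vector<size_t>& perm,
                                         const std::vector<int64_t>& input_dims) {
      size_t last_permuted_axis = 0;
      for (size_t i = 0; i < perm.size(); ++i) {
        if (input_dims[perm[i]] == 1) continue;  // size-1 axes may move freely
        if (perm[i] < last_permuted_axis) return false;
        last_permuted_axis = perm[i];
      }
      return true;
    }

    int main() {
      std::vector<int64_t> dims = {1, 1, 1024, 4096};
      std::vector<size_t> reshape_perm = {2, 0, 3, 1};    // example from the code comments
      std::vector<size_t> real_transpose = {3, 0, 2, 1};  // swaps the two large axes
      printf("%d %d\n", static_cast<int>(IsTransposeReshapeSketch(reshape_perm, dims)),
             static_cast<int>(IsTransposeReshapeSketch(real_transpose, dims)));  // prints: 1 0
      return 0;
    }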
---
.../core/providers/cpu/tensor/transpose.cc | 252 +++++++++++-------
.../core/providers/cpu/tensor/transpose.h | 10 +
.../providers/cpu/tensor/transpose_test.cc | 162 ++++++++++-
3 files changed, 322 insertions(+), 102 deletions(-)
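
Cover note: the main speed-up comes from replacing the per-block ComputeOffset
dot product with an incrementally maintained source position (MultiIndex plus
IncrementIndexAndComputeOffset). The sketch below illustrates that idea on a
2x3 -> 3x2 float transpose. It is a simplified illustration, not the ORT code:
it walks an integer offset where the patch walks a raw byte pointer, and
target_dims/stride are hard-coded for the example.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      // Transpose a 2x3 row-major float tensor with perm = {1, 0} -> output is 3x2.
      std::vector<float> source = {1, 2, 3, 4, 5, 6};
      std::vector<int64_t> target_dims = {3, 2};  // shape of the transposed tensor
      std::vector<int64_t> stride = {1, 3};       // source stride (in elements) of each target axis

      std::vector<int64_t> index(target_dims.size(), 0);
      int64_t src_off = 0;  // plays the role of local_source in the patch
      std::vector<float> target;
      for (size_t n = source.size(); n > 0; --n) {
        target.push_back(source[src_off]);
        // Increment the last axis first; on wrap-around, undo its contribution
        // to src_off, reset it to zero and carry into the previous axis.
        for (int pos = static_cast<int>(index.size()) - 1; pos >= 0; --pos) {
          src_off += stride[pos];
          if (++index[pos] < target_dims[pos]) break;
          src_off -= stride[pos] * index[pos];
          index[pos] = 0;
        }
      }
      for (float v : target) printf("%g ", v);  // prints: 1 4 2 5 3 6
      printf("\n");
      return 0;
    }

The invariant maintained by the real code is the same: local_source always
equals source plus the dot product of the current multi-index with the
per-axis strides, but it is updated with one addition per element instead of
being recomputed for every block.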
diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc
index 7503ab946b8ed..06777cd48aacf 100644
--- a/onnxruntime/core/providers/cpu/tensor/transpose.cc
+++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc
@@ -15,23 +15,85 @@ namespace onnxruntime {
etc.
*/
-// ComputeOffset: compute offset into a tensor. This is essentially the dot-product of
-// index and stride, restricted to the specified number of axes.
-static inline size_t ComputeOffset(const std::vector<int64_t>& index, const std::vector<size_t>& stride, int64_t num_axes) {
- size_t offset = 0;
- for (int64_t j = 0; j < num_axes; ++j) {
- offset += index[j] * stride[j];
+struct MultiIndex {
+ size_t n_axes;
+ std::vector<size_t> index;
+ std::vector<size_t> upper_bound;
+ std::vector<int64_t> stride;
+
+ /* A MultiIndex keeps track, for every axis of the tensor, of the current
+ * index, its upper bound, and the stride applied to a pointer walking
+ * through the data. Any function using it creates a MultiIndex and then
+ * calls IncrementIndexAndComputeOffsetSetup to initialize it.
+ * This constructor does not initialize anything because the values
+ * would be overwritten by IncrementIndexAndComputeOffsetSetup, which
+ * calls method Init to resize the internal vectors.
+ * Function IncrementIndexAndComputeOffset is then called to advance
+ * the MultiIndex to the next element of the tensor.
+ */
+ MultiIndex() : index(), upper_bound(), stride() { n_axes = 0; }
+
+ void Init(size_t num_axes) {
+ index.resize(num_axes);
+ upper_bound.resize(num_axes);
+ stride.resize(num_axes);
+ n_axes = num_axes;
}
- return offset;
+
+ void InitAxis(size_t n_axis, size_t i, size_t n, int64_t s) {
+ index[n_axis] = i;
+ upper_bound[n_axis] = n;
+ stride[n_axis] = s;
+ }
+};
+
+/* This function initializes a MultiIndex with num_axes entries (one per axis).
+* target_dims is the shape of the transposed tensor. stride refers to the tensor to
+* be transposed: if source_dims is its shape, stride[i] = source_dims[i+1] * source_dims[i+2] * ... * 1.
+* element_size is the size of one tensor element (sizeof(float), sizeof(double), ...).
+*/
+static void IncrementIndexAndComputeOffsetSetup(MultiIndex& mindex, size_t num_axes, const std::vector<int64_t>& target_dims,
+ const std::vector<size_t>& stride, size_t element_size) {
+ mindex.Init(num_axes);
+ size_t naxes = 0;
+ for (size_t i = 0; i < num_axes; ++i) {
+ if (target_dims[i] == 1)
+ continue;
+ mindex.InitAxis(naxes, 0, static_cast<size_t>(target_dims[i]), stride[i] * element_size);
+ ++naxes;
+ }
+ ORT_ENFORCE(naxes > 0, "Method IncrementIndexAndComputeOffset assumes this value is strictly positive.");
+ mindex.n_axes = naxes;
}
-// IncrementIndex: Increment an index into a tensor (in lexicographic ordering), wrapping
-// around the specified upper_bound.
-static inline void IncrementIndex(std::vector<int64_t>& index, const std::vector<int64_t>& upper_bound, int64_t num_axes) {
- for (int64_t k = num_axes - 1; k >= 0; --k) {
- index[k]++;
- if (index[k] < upper_bound[k]) break;
- index[k] = 0;
+/* This function increments a MultiIndex initialized by IncrementIndexAndComputeOffsetSetup.
+* It increments the last axis and checks whether it stays within its bound. If it does, the
+* function returns; otherwise it resets that axis to zero and increments the previous one.
+* Every modification made to the indices is applied at the same time to the pointer
+* local_source, which avoids recomputing local_source from the source tensor.
+* At any time, the following invariant holds:
+* local_source = source + sum_i (mindex.index[i] * mindex.stride[i])
+*/
+template <typename T>
+static inline void IncrementIndexAndComputeOffset(MultiIndex& mindex, const T*& local_source) {
+ // Increment the last dimension.
+ int pos = static_cast<int>(mindex.n_axes) - 1;
+ local_source += mindex.stride[pos];
+ // Check whether the index stays within its bound.
+ if (++mindex.index[pos] < mindex.upper_bound[pos])
+ return;
+ // Otherwise, carry into the previous axes.
+ // The test above is kept outside the loop because not carrying
+ // is by far the most common case, so it should stay fast.
+ local_source -= mindex.stride[pos] * mindex.index[pos];
+ mindex.index[pos] = 0;
+ --pos;
+ for (; pos >= 0; --pos) {
+ local_source += mindex.stride[pos];
+ if (++mindex.index[pos] < mindex.upper_bound[pos])
+ break;
+ local_source -= mindex.stride[pos] * mindex.index[pos];
+ mindex.index[pos] = 0;
}
}
@@ -55,17 +117,14 @@ static void DoTransposeImpl(int64_t num_axes, const std::vector<int64_t>& target
size_t num_blocks, size_t num_elts_in_block, const std::vector<size_t>& stride,
const uint8_t* source, uint8_t* target, size_t element_size) {
size_t blocksize = num_elts_in_block * element_size;
- // index used to iterate over target iteration-space
- std::vector<int64_t> target_index(num_axes, 0);
- for (size_t i = 0; i < num_blocks; ++i) {
- // convert target_index into an offset in source data
- size_t source_offset = ComputeOffset(target_index, stride, num_axes);
-
- // copy
- memcpy(target, source + source_offset * element_size, blocksize);
+ MultiIndex mindex;
+ IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, element_size);
- // increment target_index:
- IncrementIndex(target_index, target_dims, num_axes);
+ const uint8_t* local_source = source;
+ for (size_t i = 0; i < num_blocks; ++i) {
+ ORT_ENFORCE((local_source >= source) && (local_source < source + num_blocks * blocksize));
+ memcpy(target, local_source, blocksize);
+ IncrementIndexAndComputeOffset(mindex, local_source);
target += blocksize;
}
}
@@ -73,17 +132,15 @@ static void DoTransposeImpl(int64_t num_axes, const std::vector<int64_t>& target
static void DoTransposeImpl(int64_t num_axes, const std::vector<int64_t>& target_dims,
size_t num_blocks, size_t num_elts_in_block, const std::vector<size_t>& stride,
const std::string* source, std::string* target) {
- // index used to iterate over target iteration-space
- std::vector<int64_t> target_index(num_axes, 0);
- for (size_t i = 0; i < num_blocks; ++i) {
- // convert target_index into an offset in source data
- size_t source_offset = ComputeOffset(target_index, stride, num_axes);
-
- // copy
- DoTransposeSingleBlock(num_elts_in_block, source + source_offset, target);
+ ORT_ENFORCE(num_axes > 0, "Transpose not implemented for empty tensors.");
+ MultiIndex mindex;
+ IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, 1);
- // increment target_index:
- IncrementIndex(target_index, target_dims, num_axes);
+ const std::string* local_source = source;
+ for (size_t i = 0; i < num_blocks; ++i) {
+ ORT_ENFORCE((local_source >= source) && (local_source < source + num_blocks * num_elts_in_block));
+ DoTransposeSingleBlock(num_elts_in_block, local_source, target);
+ IncrementIndexAndComputeOffset(mindex, local_source);
target += num_elts_in_block;
}
}
@@ -93,67 +150,40 @@ inline void CopyPrim(uint8_t* target, const uint8_t* source) {
*reinterpret_cast<T*>(target) = *reinterpret_cast<const T*>(source);
}
+// The function does not check that num_axes > 0; the caller is expected to ensure it.
+template <typename T>
+static void TypedDoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
+ const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target) {
+ MultiIndex mindex;
+ IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, sizeof(T));
+
+ const uint8_t* local_source = source;
+ uint8_t* target_end = target + sizeof(T) * num_blocks;
+ for (; target != target_end; target += sizeof(T)) {
+ ORT_ENFORCE((local_source >= source) && (local_source < source + sizeof(T) * num_blocks));
+ CopyPrim<T>(target, local_source);
+ IncrementIndexAndComputeOffset(mindex, local_source);
+ }
+}
+
// DoTransposeEltWise: specialization of DoTranspose for the num_elts_in_block=1 case.
// copies source tensor to target, transposing elements.
// The stride vector indicates the transposition.
-static void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
- const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
- size_t element_size) {
- // index used to iterate over target iteration-space
- std::vector<int64_t> target_index(num_axes, 0);
-
+void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
+ const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
+ size_t element_size) {
switch (element_size) {
case sizeof(uint64_t):
- for (size_t i = 0; i < num_blocks; ++i) {
- // convert target_index into an offset in source data
- size_t source_offset = ComputeOffset(target_index, stride, num_axes);
-
- // copy
- CopyPrim<uint64_t>(target, source + (source_offset * element_size));
-
- // increment target_index:
- IncrementIndex(target_index, target_dims, num_axes);
- target += element_size;
- }
+ TypedDoTransposeEltWise<uint64_t>(num_axes, target_dims, num_blocks, stride, source, target);
break;
case sizeof(uint32_t):
- for (size_t i = 0; i < num_blocks; ++i) {
- // convert target_index into an offset in source data
- size_t source_offset = ComputeOffset(target_index, stride, num_axes);
-
- // copy
- CopyPrim<uint32_t>(target, source + (source_offset * element_size));
-
- // increment target_index:
- IncrementIndex(target_index, target_dims, num_axes);
- target += element_size;
- }
+ TypedDoTransposeEltWise<uint32_t>(num_axes, target_dims, num_blocks, stride, source, target);
break;
case sizeof(uint16_t):
- for (size_t i = 0; i < num_blocks; ++i) {
- // convert target_index into an offset in source data
- size_t source_offset = ComputeOffset(target_index, stride, num_axes);
-
- // copy
- CopyPrim<uint16_t>(target, source + (source_offset * element_size));
-
- // increment target_index:
- IncrementIndex(target_index, target_dims, num_axes);
- target += element_size;
- }
+ TypedDoTransposeEltWise<uint16_t>(num_axes, target_dims, num_blocks, stride, source, target);
break;
case sizeof(uint8_t):
- for (size_t i = 0; i < num_blocks; ++i) {
- // convert target_index into an offset in source data
- size_t source_offset = ComputeOffset(target_index, stride, num_axes);
-
- // copy
- *target = *(source + (source_offset * element_size));
-
- // increment target_index:
- IncrementIndex(target_index, target_dims, num_axes);
- target += element_size;
- }
+ TypedDoTransposeEltWise<uint8_t>(num_axes, target_dims, num_blocks, stride, source, target);
break;
default:
assert(false);
@@ -162,17 +192,16 @@ static void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& tar
static void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
const std::vector<size_t>& stride, const std::string* source, std::string* target) {
+ ORT_ENFORCE(num_axes > 0, "Transpose not implemented for empty tensors.");
+ MultiIndex mindex;
+ IncrementIndexAndComputeOffsetSetup(mindex, num_axes, target_dims, stride, 1);
+
// index used to iterate over target iteration-space
- std::vector<int64_t> target_index(num_axes, 0);
+ const std::string* local_source = source;
for (size_t i = 0; i < num_blocks; ++i) {
- // convert target_index into an offset in source data
- size_t source_offset = ComputeOffset(target_index, stride, num_axes);
-
- // copy
- *target = *(source + source_offset);
-
- // increment target_index:
- IncrementIndex(target_index, target_dims, num_axes);
+ ORT_ENFORCE((local_source >= source) && (local_source < source + num_blocks));
+ *target = *local_source;
+ IncrementIndexAndComputeOffset(mindex, local_source);
target++;
}
}
@@ -274,13 +303,15 @@ template <typename T>
static void SimpleTransposeSingleAxisOutwards(const T* input_data, T* output_data,
int64_t num_loops, int64_t num_writers,
int64_t writes_per_loop, int64_t writes_per_writer_per_loop) {
+ const T* end;
for (int64_t l = 0; l < num_loops; ++l) {
T* output_for_first_writer = output_data;
for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) {
T* output_for_current_writer = output_for_first_writer;
- for (int64_t w = 0; w < num_writers; ++w) {
+ end = input_data + num_writers;
+ for (; input_data != end;) {
*output_for_current_writer = *input_data++;
// skip to output position for next writer
@@ -379,13 +410,15 @@ template <typename T>
static void SimpleTransposeSingleAxisInwards(const T* input_data, T* output_data,
int64_t num_loops, int64_t num_readers,
int64_t reads_per_loop, int64_t reads_per_reader_per_loop) {
+ T* end;
for (int64_t l = 0; l < num_loops; ++l) {
const T* input_for_first_reader = input_data;
for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) {
const T* input_for_current_reader = input_for_first_reader;
- for (int64_t r = 0; r < num_readers; ++r) {
+ end = output_data + num_readers;
+ for (; output_data != end;) {
*output_data++ = *input_for_current_reader;
// skip to input position for next reader
input_for_current_reader += reads_per_reader_per_loop;
@@ -560,6 +593,20 @@ static bool IsMovingSingleAxis(const std::vector<size_t>& permutations, size_t&
return single_axis_moved;
}
+bool IsTransposeReshape(const std::vector<size_t>& perm, const std::vector<int64_t>& input_dims) {
+ // As long as the dims with values > 1 stay in the same order, it's a reshape.
+ // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
+ size_t last_permuted_axis = 0;
+ for (size_t i = 0; i < perm.size(); ++i) {
+ if (input_dims[perm[i]] == 1)
+ continue;
+ if (perm[i] < last_permuted_axis)
+ return false;
+ last_permuted_axis = perm[i];
+ }
+ return true;
+}
+
//`input_shape_override` overrides the shape of `input` for compute purposes.
Status TransposeBase::DoTranspose(const std::vector<size_t>& permutations, const Tensor& input, Tensor& output,
const TensorShape* input_shape_override) {
@@ -572,6 +619,14 @@ Status TransposeBase::DoTranspose(const std::vector<size_t>& permutations, const
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mismatched data types between input and output Tensors. ",
input_type, " != ", output_type);
} else {
+ TensorShape shape = input_shape_override ? *input_shape_override : input.Shape();
+ if (IsTransposeReshape(permutations, shape.GetDims())) {
+ // As long as the dims with values > 1 stay in the same order, it's a reshape.
+ // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
+ CopyCpuTensor(&input, &output);
+ return Status::OK();
+ }
+
size_t from = 0, to = 0;
bool moving_single_axis = IsMovingSingleAxis(permutations, from, to);
@@ -607,6 +662,13 @@ Status Transpose::Compute(OpKernelContext* ctx) const {
if (output_shape.Size() == 0)
return Status::OK();
+ if (IsTransposeReshape(*p_perm, input_dims)) {
+ // As long as the dims with values > 1 stay in the same order, it's a reshape.
+ // Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
+ CopyCpuTensor(&X, &Y);
+ return Status::OK();
+ }
+
size_t from = 0, to = 0;
bool moving_single_axis = IsMovingSingleAxis(*p_perm, from, to);
diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.h b/onnxruntime/core/providers/cpu/tensor/transpose.h
index 3cb56f6ddcf63..341975d475e7d 100644
--- a/onnxruntime/core/providers/cpu/tensor/transpose.h
+++ b/onnxruntime/core/providers/cpu/tensor/transpose.h
@@ -10,6 +10,16 @@
namespace onnxruntime {
+/** Tells whether the transpose is equivalent to a reshape:
+ dimensions of size 1 can change place, the other dimensions must keep
+ the same relative order in the permuted tensor.
+*/
+bool IsTransposeReshape(const std::vector<size_t>& perm, const std::vector<int64_t>& input_dims);
+
+void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
+ const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
+ size_t element_size);
+
class TransposeBase {
public:
/**
diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
index a40c805ca8fdc..6317f068c71c6 100644
--- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc
@@ -4,10 +4,23 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/providers/compare_provider_test_utils.h"
+#include "core/providers/cpu/tensor/transpose.h"
namespace onnxruntime {
namespace test {
+TEST(TransposeOpTest, IsTransposeReshapeTest) {
+ std::vector<int64_t> input_dims{1, 2, 3, 4, 1};
+ std::vector<size_t> perm{0, 1, 2, 3, 4};
+ ASSERT_TRUE(IsTransposeReshape(perm, input_dims));
+ perm = std::vector<size_t>{1, 2, 3, 0, 4};
+ ASSERT_TRUE(IsTransposeReshape(perm, input_dims));
+ perm = std::vector<size_t>{4, 1, 0, 2, 3};
+ ASSERT_TRUE(IsTransposeReshape(perm, input_dims));
+ perm = std::vector<size_t>{4, 1, 0, 3, 2};
+ ASSERT_FALSE(IsTransposeReshape(perm, input_dims));
+}
+
// Some of the tests can't run on TensorrtExecutionProvider because of errors.
// Those tests will fallback to other EPs.
@@ -124,11 +137,11 @@ TEST(TransposeOpTest, TwoDim_int16) {
2, 5,
3, 6};
- #if defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_VAD_M)
- TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, true, false);
- #else
- TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
- #endif
+#if defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_VAD_M)
+ TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, true, false);
+#else
+ TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
+#endif
}
TEST(TransposeOpTest, TwoDim_mlfloat16) {
@@ -246,6 +259,39 @@ TEST(TransposeOpTest, ThreeDimSuffix) {
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); //TensorRT: illegal error
}
+TEST(TransposeOpTest, TransposeReshape) {
+ std::vector<int64_t> input_shape({1, 4, 2, 1, 3});
+ std::vector<float> input_vals = {
+ 1.0f, 2.0f, 3.0f,
+ 4.0f, 5.0f, 6.0f,
+
+ 1.1f, 2.1f, 3.1f,
+ 4.1f, 5.1f, 6.1f,
+
+ 1.2f, 2.2f, 3.2f,
+ 4.2f, 5.2f, 6.2f,
+
+ 1.3f, 2.3f, 3.3f,
+ 4.3f, 5.3f, 6.3f};
+
+ std::vector<int64_t> perm = {1, 3, 2, 4, 0};
+ std::vector<int64_t> expected_shape({4, 1, 2, 3, 1});
+ auto expected_vals = {
+ 1.0f, 2.0f, 3.0f,
+ 4.0f, 5.0f, 6.0f,
+
+ 1.1f, 2.1f, 3.1f,
+ 4.1f, 5.1f, 6.1f,
+
+ 1.2f, 2.2f, 3.2f,
+ 4.2f, 5.2f, 6.2f,
+
+ 1.3f, 2.3f, 3.3f,
+ 4.3f, 5.3f, 6.3f};
+
+ TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false); //TensorRT: illegal error
+}
+
TEST(TransposeOpTest, ThreeDimStr) {
std::vector<int64_t> input_shape({4, 2, 3});
std::vector<std::string> input_vals = {
@@ -419,10 +465,112 @@ TEST(TransposeOpTest, SingleAxisMovingInwardsBlockCopy) {
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false);
}
+TEST(TransposeOpTest, NDim) {
+ std::vector<int64_t> input_shape({2, 2, 2, 2});
+ std::vector<float> input_vals = {1.0f, 2.0f, 3.0f, 4.0f,
+ 5.0f, 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f, 12.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f};
+
+ std::vector<int64_t> perm = {1, 0, 2, 3};
+ auto expected_vals = {1.0f, 2.0f, 3.0f, 4.0f,
+ 9.0f, 10.0f, 11.0f, 12.0f,
+ 5.0f, 6.0f, 7.0f, 8.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f};
+ TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals);
+
+ perm = {1, 0, 3, 2};
+ auto expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f,
+ 9.0f, 11.0f, 10.0f, 12.0f,
+ 5.0f, 7.0f, 6.0f, 8.0f,
+ 13.0f, 15.0f, 14.0f, 16.0f};
+ TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals2);
+}
+
+TEST(TransposeOpTest, DoTransposeImpl) {
+ std::vector<int64_t> input_shape({5, 2, 1, 3});
+ std::vector<float> input_vals(30);
+ for (auto it = input_vals.begin(); it != input_vals.end(); ++it) {
+ *it = static_cast<float>(std::distance(input_vals.begin(), it));
+ }
+ std::vector<int64_t> perm = {2, 1, 0, 3};
+ std::vector<int64_t> expected_shape({1, 2, 5, 3});
+ auto expected_vals = {0.0f, 1.0f, 2.0f, 6.0f, 7.0f, 8.0f,
+ 12.0f, 13.0f, 14.0f, 18.0f, 19.0f, 20.0f,
+ 24.0f, 25.0f, 26.0f, 3.0f, 4.0f, 5.0f,
+ 9.0f, 10.0f, 11.0f, 15.0f, 16.0f, 17.0f,
+ 21.0f, 22.0f, 23.0f, 27.0f, 28.0f, 29.0f};
+ TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
+}
+
+TEST(TransposeOpTest, DoTransposeImplString) {
+ std::vector<int64_t> input_shape({5, 2, 1, 3});
+ std::vector<std::string> input_vals(30);
+ for (auto it = input_vals.begin(); it != input_vals.end(); ++it) {
+ *it = std::string("n") + std::to_string(static_cast<int64_t>(std::distance(input_vals.begin(), it)));
+ }
+ std::vector<int64_t> perm = {2, 1, 0, 3};
+ std::vector<int64_t> expected_shape({1, 2, 5, 3});
+ std::initializer_list<std::string> expected_vals = {"n0", "n1", "n2", "n6", "n7", "n8",
+ "n12", "n13", "n14", "n18", "n19", "n20",
+ "n24", "n25", "n26", "n3", "n4", "n5",
+ "n9", "n10", "n11", "n15", "n16", "n17",
+ "n21", "n22", "n23", "n27", "n28", "n29"};
+ TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
+}
+
+TEST(TransposeOpTest, DoTransposeEltWise) {
+ // Configuration where DoTransposeEltWise is called.
+ std::vector<int64_t> input_shape({2, 2, 2, 2});
+ std::vector<float> input_vals = {1.0f, 2.0f, 3.0f, 4.0f,
+ 5.0f, 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f, 12.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f};
+
+ std::vector<int64_t> perm = {1, 0, 3, 2};
+ auto expected_vals2 = {1.0f, 3.0f, 2.0f, 4.0f,
+ 9.0f, 11.0f, 10.0f, 12.0f,
+ 5.0f, 7.0f, 6.0f, 8.0f,
+ 13.0f, 15.0f, 14.0f, 16.0f};
+ TransposeTest(input_shape, input_vals, &perm, input_shape, expected_vals2);
+
+ // Specific test checking that function DoTransposeEltWise does not
+ // copy any value outside the transposed region of the target buffer.
+ TensorShape tensor_shape(input_shape);
+ std::vector<size_t> stride(input_shape.size());
+ for (size_t i = 0; i < input_shape.size(); i++) {
+ size_t inpdim = perm[i];
+ if (inpdim + 1 < input_shape.size())
+ stride[i] = tensor_shape.SizeFromDimension(inpdim + 1);
+ else
+ stride[i] = 1;
+ }
+
+ std::vector<float> input_vals_end = {1.0f, 2.0f, 3.0f, 4.0f,
+ 5.0f, 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f, 12.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f,
+ -1.0f, -1.0f};
+ std::vector<float> target(input_vals_end.size(), 17.0f);
+
+ std::vector<float> expected_vals3 = {1.0f, 3.0f, 2.0f, 4.0f,
+ 9.0f, 11.0f, 10.0f, 12.0f,
+ 5.0f, 7.0f, 6.0f, 8.0f,
+ 13.0f, 15.0f, 14.0f, 16.0f,
+ 17.0f, 17.0f};
+
+ DoTransposeEltWise(input_shape.size(), input_shape, 16,
+ stride, (uint8_t*)input_vals_end.data(), (uint8_t*)target.data(),
+ sizeof(float));
+ for (size_t i = 0; i < input_vals_end.size(); ++i) {
+ ASSERT_TRUE(target[i] == expected_vals3[i]);
+ }
+}
+
#if USE_CUDA
- constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider;
+constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider;
#elif USE_ROCM
- constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider;
+constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider;
#endif
#if defined(USE_CUDA) || defined(USE_ROCM)