Skip to content

Commit

Permalink
[SYCLCompat] Optimize/(fix?) permute_sub_group_by_xor if `logical_sub…
Browse files Browse the repository at this point in the history
…_group_size == 32` (#16646)

`syclcompat::permute_sub_group_by_xor` was reported to flakily fail on
L0. Closer inspection revealed that the implementation of
`permute_sub_group_by_xor` is incorrect for cases where
`logical_sub_group_size != 32`, which is one of the test cases. This
implies that the test itself is wrong.

In this PR we first optimize the part of the implementation that is
valid assuming that Intel spirv builtins are correct (which is also the
only case realistically a user will program): case
`logical_sub_group_size == 32`, in order to:
- Ensure the only useful case is working via the correct optimized
route.
- Check that this improvement doesn't break the suspicious test.

A follow on PR can fix the other cases where `logical_sub_group_size !=
32`: this is better to do later, since
- the only use case I know of for this is to implement non-uniform group
algorithms that we already have implemented (e.g. see
#9671) and any user is advised to use
such algorithms instead of reimplementing them themselves.
- This must I think require a complete reworking of the test and would
otherwise delay the more important change here.

---------

Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
JackAKirk authored Jan 20, 2025
1 parent e257292 commit 291eeee
Showing 2 changed files with 18 additions and 4 deletions.
3 changes: 3 additions & 0 deletions sycl/include/syclcompat/util.hpp
Original file line number Diff line number Diff line change
@@ -410,6 +410,9 @@ T shift_sub_group_right(sycl::sub_group g, T x, unsigned int delta,
template <typename T>
T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
int logical_sub_group_size = 32) {
if (logical_sub_group_size == 32) {
return permute_group_by_xor(g, x, mask);
}
unsigned int id = g.get_local_linear_id();
unsigned int start_index =
id / logical_sub_group_size * logical_sub_group_size;
19 changes: 15 additions & 4 deletions sycl/test-e2e/syclcompat/util/util_permute_sub_group_by_xor.cpp
Original file line number Diff line number Diff line change
@@ -86,11 +86,9 @@ void test_permute_sub_group_by_xor() {
syclcompat::device_ext &dev_ct1 = syclcompat::get_current_device();
sycl::queue *q_ct1 = dev_ct1.default_queue();
bool Result = true;
int *dev_data = nullptr;
unsigned int *dev_data_u = nullptr;
sycl::range<3> GridSize(1, 1, 1);
sycl::range<3> BlockSize(1, 1, 1);
dev_data = sycl::malloc_device<int>(DATA_NUM, *q_ct1);
dev_data_u = sycl::malloc_device<unsigned int>(DATA_NUM, *q_ct1);

GridSize = sycl::range<3>(1, 1, 2);
@@ -120,6 +118,19 @@ void test_permute_sub_group_by_xor() {
q_ct1->memcpy(host_dev_data_u, dev_data_u, DATA_NUM * sizeof(unsigned int))
.wait();
verify_data<unsigned int>(host_dev_data_u, expect1, DATA_NUM);
sycl::free(dev_data_u, *q_ct1);
}

void test_permute_sub_group_by_xor_extra_arg() {
std::cout << __PRETTY_FUNCTION__ << std::endl;

syclcompat::device_ext &dev_ct1 = syclcompat::get_current_device();
sycl::queue *q_ct1 = dev_ct1.default_queue();
bool Result = true;
unsigned int *dev_data_u = nullptr;
sycl::range<3> GridSize(1, 1, 1);
sycl::range<3> BlockSize(1, 1, 1);
dev_data_u = sycl::malloc_device<unsigned int>(DATA_NUM, *q_ct1);

GridSize = sycl::range<3>(1, 1, 2);
BlockSize = sycl::range<3>(1, 2, 32);
@@ -133,6 +144,7 @@ void test_permute_sub_group_by_xor() {
91, 90, 93, 92, 95, 94, 97, 96, 99, 98, 101, 100, 103, 102, 105,
104, 107, 106, 109, 108, 111, 110, 113, 112, 115, 114, 117, 116, 119, 118,
121, 120, 123, 122, 125, 124, 127, 126};
unsigned int host_dev_data_u[DATA_NUM];
init_data<unsigned int>(host_dev_data_u, DATA_NUM);

q_ct1->memcpy(dev_data_u, host_dev_data_u, DATA_NUM * sizeof(unsigned int))
@@ -147,13 +159,12 @@ void test_permute_sub_group_by_xor() {
q_ct1->memcpy(host_dev_data_u, dev_data_u, DATA_NUM * sizeof(unsigned int))
.wait();
verify_data<unsigned int>(host_dev_data_u, expect2, DATA_NUM);

sycl::free(dev_data, *q_ct1);
sycl::free(dev_data_u, *q_ct1);
}

int main() {
test_permute_sub_group_by_xor();
test_permute_sub_group_by_xor_extra_arg();

return 0;
}

0 comments on commit 291eeee

Please sign in to comment.