Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the shuffle algorithm for take #11267

Merged
merged 26 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Implement shuffle
  • Loading branch information
phofl committed Jul 30, 2024
commit dd8d4fbe628e1df0b8ece328718c24c352da677c
119 changes: 119 additions & 0 deletions dask/array/_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from __future__ import annotations

from itertools import count, product

import numpy as np
import toolz

from dask.array.chunk import getitem
from dask.array.core import Array, concatenate3
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph


def shuffle(
    x,
    indexer: list[list[int]],
    axis,
):
    """Reorder the elements of ``x`` along ``axis`` according to ``indexer``.

    Parameters
    ----------
    x : dask.array.Array
        Array whose elements are regrouped.
    indexer : list[list[int]]
        Groups of positions along ``axis``. Each inner list stays together
        in a single output chunk; several groups may be packed into the
        same output chunk, but a group is never split.
    axis : int
        Axis along which ``indexer`` selects elements.

    Returns
    -------
    dask.array.Array
        Array with unchanged chunks on every other axis and re-packed
        chunks along ``axis``.
    """
    # Let output chunks grow to 25% above the current average chunk size
    # along ``axis`` before we close them.
    average_chunk_size = int(sum(x.chunks[axis]) / len(x.chunks[axis]) * 1.25)

    # Pack whole indexer groups into buckets; each bucket becomes one
    # output chunk along ``axis``.
    current_bucket, buckets = [], []
    for index in indexer:
        if (
            len(current_bucket) + len(index) > average_chunk_size
            and len(current_bucket) > 0
        ):
            # The next group would overflow this bucket; close it and
            # start a new one with the group.
            buckets.append(current_bucket)
            current_bucket = index.copy()
        else:
            current_bucket.extend(index)
        if len(current_bucket) > average_chunk_size / 1.25:
            # Bucket reached the previous average chunk size; close it.
            buckets.append(current_bucket)
            current_bucket = []
    if len(current_bucket) > 0:
        buckets.append(current_bucket)

    # Exclusive upper borders of the input chunks along ``axis``; used to
    # map a global position to the input chunk that holds it.
    chunk_borders = np.cumsum(x.chunks[axis])
    # Every combination of chunk indices over the non-shuffled axes.
    new_index = list(
        product(*(range(len(c)) for i, c in enumerate(x.chunks) if i != axis))
    )

    intermediates = dict()
    merges = dict()
    token = tokenize(x, indexer, axis)
    split_name = f"shuffle-split-{token}"
    merge_name = f"shuffle-merge-{token}"
    # Full slices for all axes except ``axis``; convert_key inserts the
    # per-chunk take indices at position ``axis``.
    slices = (slice(None),) * (len(x.chunks) - 1)
    split_name_suffixes = count()

    # Keys of all input blocks, addressable by chunk-index tuple.
    old_blocks = np.empty([len(c) for c in x.chunks], dtype="O")
    for index in np.ndindex(old_blocks.shape):
        old_blocks[index] = (x.name,) + index

    for final_chunk, bucket in enumerate(buckets):
        # Sort the bucket so positions from the same input chunk are
        # contiguous; ``sorter`` lets us restore the requested order later.
        arr = np.array(bucket)
        sorter = np.argsort(arr)
        sorted_array = arr[sorter]
        # Input chunks this bucket touches, and the offset in
        # ``sorted_array`` where each chunk's run of positions begins.
        chunk_nrs, borders = np.unique(
            np.searchsorted(chunk_borders, sorted_array, side="right"),
            return_index=True,
        )
        borders = borders.tolist()
        borders.append(len(bucket))

        # True when the bucket is already ascending; then sorting did not
        # change the order and split outputs need no reordering.
        ordered = bool((sorter == np.arange(len(sorter))).all())

        for nidx in new_index:
            keys = []

            # One getitem task per input chunk contributing to this
            # output chunk.
            for c, b_start, b_end in zip(chunk_nrs, borders[:-1], borders[1:]):
                key = convert_key(nidx, c, axis)
                name = (split_name, next(split_name_suffixes))
                intermediates[name] = (
                    getitem,
                    old_blocks[key],
                    convert_key(
                        slices,
                        # Translate global positions to chunk-local ones.
                        sorted_array[b_start:b_end]
                        - (chunk_borders[c - 1] if c > 0 else 0),
                        axis,
                    ),
                )
                keys.append(name)

            final_suffix = convert_key(nidx, final_chunk, axis)
            if len(keys) == 1 and ordered:
                # Single input chunk and the bucket was already in the
                # requested order: alias the split task directly.
                merges[(merge_name,) + final_suffix] = keys[0]
            elif len(keys) >= 1:
                # Concatenate the sorted pieces and undo the argsort so the
                # output follows the order given by the indexer.  This also
                # covers a single out-of-order input chunk, which a plain
                # alias would leave in sorted rather than requested order.
                merges[(merge_name,) + final_suffix] = (
                    concatenate_arrays,
                    keys,
                    sorter,
                    axis,
                )
            else:
                raise NotImplementedError

    layer = toolz.merge(merges, intermediates)
    graph = HighLevelGraph.from_collections(merge_name, layer, dependencies=[x])

    # New chunk sizes: bucket lengths along ``axis``, unchanged elsewhere.
    chunks = []
    for i, c in enumerate(x.chunks):
        if i == axis:
            chunks.append(tuple(map(len, buckets)))
        else:
            chunks.append(c)

    return Array(graph, merge_name, chunks, meta=x)


def concatenate_arrays(arrs, sorter, axis):
    """Concatenate ``arrs`` along ``axis`` and undo a prior argsort.

    ``sorter`` is the permutation that sorted the original positions before
    the pieces were extracted; taking with ``np.argsort(sorter)`` restores
    the pre-sort order.
    """
    combined = np.concatenate(arrs, axis=axis)
    inverse_permutation = np.argsort(sorter)
    return np.take(combined, inverse_permutation, axis=axis)


def convert_key(key, chunk, axis):
    """Return ``key`` as a tuple with ``chunk`` inserted at position ``axis``."""
    prefix = tuple(key[:axis])
    suffix = tuple(key[axis:])
    return prefix + (chunk,) + suffix
2 changes: 1 addition & 1 deletion dask/array/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def getitem(obj, index):

"""
try:
result = obj[index]
result = obj[*index]
except IndexError as e:
raise ValueError(
"Array chunk size or shape is unknown. "
Expand Down
9 changes: 9 additions & 0 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2762,6 +2762,15 @@ def rechunk(

return rechunk(self, chunks, threshold, block_size_limit, balance, method)

def shuffle(
    self,
    indexer: list[list[int]],
    axis,
):
    """Reorder this array along ``axis`` according to ``indexer``.

    Thin wrapper delegating to :func:`dask.array._shuffle.shuffle`; see
    that function for the full contract.  ``indexer`` is a list of groups
    of positions along ``axis``; each group is kept together in a single
    output chunk.
    """
    from dask.array._shuffle import shuffle

    return shuffle(self, indexer, axis)

@property
def real(self):
from dask.array.ufunc import real
Expand Down