Skip to content

Commit

Permalink
[data] Fix arrow dataset sort on empty blocks (#19707)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjyao authored Oct 26, 2021
1 parent 3e81506 commit 47744d2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
24 changes: 20 additions & 4 deletions python/ray/data/impl/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,14 @@ def zip(self, other: "Block[T]") -> "Block[T]":
def builder() -> ArrowBlockBuilder[T]:
return ArrowBlockBuilder()

def sample(self, n_samples: int, key: SortKeyT) -> "pyarrow.Table":
    """Return a random sample of the sort-key column(s) of this block.

    Args:
        n_samples: Upper bound on the number of rows to draw; fewer rows
            are returned when the table holds fewer than ``n_samples``.
        key: Sort key. Each entry is indexed with ``[0]`` to obtain the
            column name to select.

    Returns:
        A table restricted to the key column(s) for the sampled rows, or
        an empty, column-less table when this block has no rows.

    Raises:
        NotImplementedError: If ``key`` is ``None`` or a callable.
    """
    if key is None or callable(key):
        raise NotImplementedError(
            "Arrow sort key must be a column name, was: {}".format(key))
    num_rows = self._table.num_rows
    if num_rows == 0:
        # An empty pyarrow table may carry no schema, in which case
        # table.select() would raise; return an empty table instead.
        return pyarrow.Table.from_pydict({})
    indices = random.sample(range(num_rows), min(n_samples, num_rows))
    return self._table.select([k[0] for k in key]).take(indices)
Expand All @@ -284,6 +288,14 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
raise NotImplementedError(
"sorting by multiple columns is not supported yet")

if self._table.num_rows == 0:
# If the pyarrow table is empty we may not have schema
# so calling sort_indices() will raise an error.
return [
pyarrow.Table.from_pydict({})
for _ in range(len(boundaries) + 1)
]

import pyarrow.compute as pac

indices = pac.sort_indices(self._table, sort_keys=key)
Expand Down Expand Up @@ -330,9 +342,13 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
def merge_sorted_blocks(
        blocks: List[Block[T]], key: SortKeyT,
        _descending: bool) -> Tuple[Block[T], BlockMetadata]:
    """Concatenate the given Arrow blocks and sort the result by ``key``.

    Args:
        blocks: Arrow tables to merge; tables with zero rows are ignored.
        key: Sort key, passed to ``sort_indices`` as ``sort_keys``.
        _descending: Not read by this function.

    Returns:
        A ``(table, metadata)`` pair. The table is empty and column-less
        when no input block has any rows.
    """
    # Drop empty tables first: they may carry no schema, so they cannot be
    # concatenated with, or sorted like, the populated ones.
    non_empty = [b for b in blocks if b.num_rows > 0]
    if not non_empty:
        ret = pyarrow.Table.from_pydict({})
    else:
        # Import the submodule explicitly (as sort_and_partition does)
        # instead of assuming it is already an attribute of ``pyarrow``.
        import pyarrow.compute as pac

        ret = pyarrow.concat_tables(non_empty, promote=True)
        ret = ret.take(pac.sort_indices(ret, sort_keys=key))
    return ret, ArrowBlockAccessor(ret).get_metadata(None)


Expand Down
1 change: 1 addition & 0 deletions python/ray/data/impl/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def sample_boundaries(blocks: BlockList[T], key: SortKeyT,
sample_bar.close()

samples = ray.get(sample_results)
samples = [s for s in samples if len(s) > 0]
sample_items = np.concatenate(samples)
sample_items.sort()
ret = [
Expand Down
25 changes: 25 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2956,6 +2956,31 @@ def assert_sorted(sorted_ds, expected_rows):
ds.sort(key=[("b", "descending")]), zip(reversed(a), reversed(b)))


def test_sort_arrow_with_empty_blocks(ray_start_regular):
    """Check that the Arrow sort helpers and ``sort`` accept empty blocks."""

    def empty_accessor():
        # A fresh accessor over a table with no columns and no rows.
        return BlockAccessor.for_block(pa.Table.from_pydict({}))

    # Sampling an empty block yields no rows.
    sampled = empty_accessor().sample(10, "A")
    assert sampled.num_rows == 0

    # Three boundaries give four partitions, all of them empty.
    partitions = empty_accessor().sort_and_partition(
        [1, 5, 10], "A", descending=False)
    assert len(partitions) == 4
    for part in partitions:
        assert part.num_rows == 0

    # Merging only empty blocks gives an empty block.
    merged = empty_accessor().merge_sorted_blocks(
        [pa.Table.from_pydict({})], "A", False)
    assert merged[0].num_rows == 0

    # End to end: the filter keeps a single row, so some of the three
    # blocks are left empty before the sort.
    items = [{"A": (x % 3), "B": x} for x in range(3)]
    ds = ray.data.from_items(items, parallelism=3)
    ds = ds.filter(lambda r: r["A"] == 0)
    sorted_rows = [row.as_pydict() for row in ds.sort("A").iter_rows()]
    assert sorted_rows == [{"A": 0, "B": 0}]


def test_dataset_retry_exceptions(ray_start_regular, local_path):
@ray.remote
class Counter:
Expand Down

0 comments on commit 47744d2

Please sign in to comment.