order: combining different xarray variables followed by a reduction orders very inefficiently #11641
Description
Let's look at the following example:
```python
import xarray as xr
import dask.array as da

size = 50
ds = xr.Dataset(
    dict(
        u=(
            ["time", "j", "i"],
            da.random.random((size, 20, 20), chunks=(10, -1, -1)),
        ),
        v=(
            ["time", "j", "i"],
            da.random.random((size, 20, 20), chunks=(10, -1, -1)),
        ),
        w=(
            ["time", "j", "i"],
            da.random.random((size, 20, 20), chunks=(10, -1, -1)),
        ),
    )
)

ds["uv"] = ds.u * ds.v
ds["vw"] = ds.v * ds.w
ds = ds.fillna(199)
```
We are combining `u` and `v`, and then `v` and `w`. Without a reduction after that step, this generally works fine: the individual chunks of one array are independent of all other chunks, so we can process every data array chunk by chunk.
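One way to see that (a sketch poking at dask internals; exact key names and layers vary between versions): every task in the combined graph has only a handful of direct inputs, so any topological order can release input chunks as soon as the elementwise operation has run.

```python
# A quick check that there are no cross-chunk edges (dask internals,
# version-dependent):
from dask.base import collections_to_dsk
from dask.core import get_dependencies

dsk = dict(collections_to_dsk([ds.uv.data, ds.vw.data], optimize_graph=False))
fan_in = {key: len(get_dependencies(dsk, key)) for key in dsk}
print(max(fan_in.values()))  # stays small: every task touches O(1) chunks
```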
Adding a reduction after these cross-dependencies makes things go sideways. Add:

```python
ds = ds.count()
```
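For context, `count` reduces each variable over all of its dimensions; per dask-backed variable this amounts to a chunkwise mask followed by a full tree reduction, roughly (a sketch, not xarray's exact implementation):

```python
# Roughly what ``ds.count()`` builds for each variable (shown here on the
# pre-count dataset): a chunkwise non-NaN mask, then a tree of partial sums.
count_uv = (~da.isnan(ds.uv.data)).sum()
```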
The ordering algorithm eagerly processes a complete tree reduction for the first variable `uv` before touching anything from `vw`. This means that the data array `v` has been loaded into memory in its entirety by the time the first tree reduction finishes, and since we haven't started on `vw` yet, we cannot release a single chunk of `v`.
I am not sure what a good solution would look like here. Ideally, the ordering algorithm would know that the `v` chunks are a lot larger than the reduced chunks of the `uv` combination and would therefore prefer finishing `v` before starting a new chunk of `uv`. Alternatively, we could load `v` twice, i.e. drop the `v` chunks as soon as they have been folded into `uv`.
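For the record, the second option can be expressed today with dask's graph-manipulation tools; a minimal sketch, assuming we rebuild the dataset before the `fillna`/`count` steps (`v_clone` is just an illustrative name):

```python
from dask.graph_manipulation import clone
import xarray as xr

# ``clone`` duplicates ``v``'s tasks under new keys, so ``uv`` and ``vw``
# no longer share them. Each branch recomputes ``v`` independently and can
# release its chunks right away, trading extra compute for lower memory.
v_clone = xr.DataArray(clone(ds.v.data), dims=ds.v.dims)
ds["uv"] = ds.u * ds.v       # consumes the original ``v``
ds["vw"] = v_clone * ds.w    # consumes the duplicated ``v``
```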
This is the pattern that kills https://github.com/coiled/benchmarks/blob/main/tests/geospatial/workloads/atmospheric_circulation.py
Task graph:

```python
from dask.base import collections_to_dsk

dsk = collections_to_dsk([ds.uv.data, ds.vw.data], optimize_graph=True)
```
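Feeding that graph to `dask.order.order` (an internal, version-dependent API) makes the problem visible: sorted by priority, all of the `uv` reduction, including every load of `v`, is scheduled before the first `vw` task.

```python
from dask.order import order

priorities = order(dict(dsk))  # maps task key -> execution priority
for key in sorted(priorities, key=priorities.get):
    print(priorities[key], key)
```

`dask.visualize(ds.uv.data, ds.vw.data, color="order", optimize_graph=True)` renders the same information graphically (requires graphviz).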
cc @fjetter