Matrix multiplication inconsistent for dask and sparse: dask @ sparse works, sparse @ dask fails #9934

Open
brendan-m-murphy opened this issue Jan 9, 2025 · 5 comments

brendan-m-murphy commented Jan 9, 2025

What is your issue?

The order of the operands in matrix multiplication matters for dask and sparse arrays, but it probably shouldn't.

Here is an example:

import numpy as np
from sparse import COO
import xarray as xr

sparse_da = xr.DataArray(COO.from_numpy(np.arange(10)))
dask_da = xr.DataArray(np.arange(100).reshape(10, 10)).chunk({"dim_1": 5})

# this works:
dask_da @ sparse_da

# this raises a TypeError for "unsupported types"
sparse_da @ dask_da

# this works as expected
sparse_da @ dask_da.as_numpy()

In the workflow where this is used, the dask array is not chunked along the dimensions it shares with the sparse array, so it seems like sparse @ dask should be fine. Also, in this workflow, loading the dask array into memory or making the sparse array dense would use a very large amount of memory.
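The reasoning that sparse @ dask "should be fine" here can be sketched with plain NumPy: when the contracted dimension (dim_0 in the example) is not split into chunks, each column block of the result depends only on the matching column block of the matrix, so the blocks can be computed independently and concatenated. This is a minimal illustration, not dask's actual scheduling; blockwise_vecmat and col_chunk are hypothetical names, and a dense NumPy vector stands in for the sparse operand.

```python
import numpy as np


def blockwise_vecmat(v, a, col_chunk):
    """Compute v @ a one block of columns at a time (hypothetical helper).

    The contracted axis (axis 0 of `a`) is never split, so every block of
    the result is independent of the others and the pieces can simply be
    concatenated. Chunking only along dim_1 gives exactly this situation.
    """
    blocks = [
        v @ a[:, start:start + col_chunk]  # one "chunk" of columns
        for start in range(0, a.shape[1], col_chunk)
    ]
    return np.concatenate(blocks)


# Same shapes as the example above: a length-10 vector and a 10 x 10 matrix
# split into two column chunks of 5, mirroring .chunk({"dim_1": 5}).
v = np.arange(10)
a = np.arange(100).reshape(10, 10)
result = blockwise_vecmat(v, a, col_chunk=5)
```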

brendan-m-murphy added the "needs triage" label Jan 9, 2025

keewis (Collaborator) commented Jan 9, 2025

this is a side-effect of using opt_einsum: while

sparse_da @ dask_da

fails, this:

with xr.set_options(use_opt_einsum=False):
    sparse_da @ dask_da

works (but is slower). You could also avoid this issue by converting sparse_da to a single-chunk dask array using .chunk().

(not sure what to do to fix this in xarray, but I agree that this is annoying)

brendan-m-murphy (Author) commented

Thanks for the suggestions!

This is the ufunc we've been using instead: https://github.com/openghg/openghg_inversions/blob/sparse-xarray-fix/openghg_inversions/array_ops.py#L75. opt_einsum ends up calling sparse.tensordot, so I made this ufunc call sparse.tensordot directly. There's a bit of a "gotcha": sparse.tensordot leaves extra broadcast dimensions around, so they need to be removed.

I might just rewrite it to apply @ in the opposite order, or use your chunking suggestion. (Or remove the function and just inline this code.)

Maybe xarray could check for the case of sparse and dask and either swap the order or convert the sparse array into a single-chunk dask array? I guess opt-einsum could probably fix this too, but it seems like they just check whether sparse has a tensordot attribute and apply it, so they would also need to insert some ugly logic to check both operands and decide which tensordot to use.
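The order dependence described here can be modelled with a toy dispatcher that, like automatic backend selection, picks the implementation from the type of the first operand only. ToySparse, ToyDask, and contract are hypothetical stand-ins, not the real sparse, dask, or opt_einsum APIs; the toy sparse backend accepts only itself or NumPy arrays, mimicking the "unsupported types" TypeError from the report.

```python
import numpy as np


class ToySparse:
    """Stand-in for sparse.COO (hypothetical)."""


class ToyDask:
    """Stand-in for a dask array (hypothetical)."""


def toy_sparse_tensordot(a, b):
    # Mimics the reported behaviour: the sparse path only handles sparse
    # or NumPy operands and rejects anything else.
    if not isinstance(b, (ToySparse, np.ndarray)):
        raise TypeError(
            f"unsupported types: {type(a).__name__}, {type(b).__name__}"
        )
    return "computed by sparse backend"


def toy_dask_tensordot(a, b):
    # The toy dask path accepts any second operand.
    return "computed by dask backend"


BACKENDS = {ToySparse: toy_sparse_tensordot, ToyDask: toy_dask_tensordot}


def contract(a, b):
    """Choose the backend from the type of the first operand only."""
    return BACKENDS[type(a)](a, b)


print(contract(ToyDask(), ToySparse()))  # dask first: works
try:
    contract(ToySparse(), ToyDask())     # sparse first: fails
except TypeError as err:
    print("TypeError:", err)
```

Swapping the operands (or wrapping the sparse operand so its type is no longer the sparse one) changes which backend is picked, which is why either fix suggested above avoids the error in this model.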

brendan-m-murphy added a commit to openghg/openghg_inversions that referenced this issue Jan 9, 2025
Used chunking suggestion from pydata/xarray#9934
dcherian added the "upstream issue" label and removed the "needs triage" label Jan 9, 2025

dcherian (Contributor) commented Jan 9, 2025

Can you open an issue at pydata/sparse instead?

brendan-m-murphy (Author) commented

> Can you open an issue at pydata/sparse instead?

Sure, I can see what they say. I think they would need to change tensordot to accept dask arrays (or any duck array?). At least in this case, the non-sparse array just needs to support indexing (I think).
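To illustrate the guess that the non-sparse operand might only need indexing, here is a minimal sketch of a COO-style vector contracted with a dense matrix, where the dense operand is touched only through its shape, dtype, and dense[i] row lookups. coo_vec_dot is a hypothetical helper and not how sparse actually implements tensordot; coords and values play the role of a COO vector's nonzero indices and stored data.

```python
import numpy as np


def coo_vec_dot(coords, values, dense):
    """Contract a COO-style vector with axis 0 of `dense` (hypothetical).

    Only the nonzero entries are visited, and `dense` is accessed solely
    through .shape, .dtype and row indexing (dense[i]).
    """
    out = np.zeros(dense.shape[1], dtype=np.result_type(values.dtype, dense.dtype))
    for i, v in zip(coords, values):
        out += v * np.asarray(dense[i])  # materialize just one row at a time
    return out


dense_vec = np.arange(10)
coords = np.flatnonzero(dense_vec)  # indices of the nonzero entries
values = dense_vec[coords]          # the corresponding stored values
matrix = np.arange(100).reshape(10, 10)
result = coo_vec_dot(coords, values, matrix)
```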

brendan-m-murphy (Author) commented Jan 15, 2025

The conclusion of the issue I opened with sparse is that they cannot support matmul with dask arrays: they use numba to do the multiplication with ndarrays, so other array types cannot be substituted for numpy in this case.

Possibilities for fixing/improving this in xarray are:

  1. check if one of the operands in xr.dot is a dask array, and make it the first operand
  2. check for the case sparse @ dask and call .chunk() on the sparse operand
  3. add a note in the docstring for dot saying that which einsum implementation is used depends on the first operand, and mention this particular case
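Option 1 could look roughly like the following user-side wrapper. This is a minimal sketch, assuming that DataArray.chunks is None exactly when the data is not dask-backed (which is how xarray documents that property); dot_dask_first is a hypothetical name, not an xarray function. Because sorted is stable, operands that are both dask-backed (or both not) keep their relative order. Reordering leaves the values unchanged but may change the order of the output dimensions.

```python
import numpy as np
import xarray as xr


def dot_dask_first(*arrays, **kwargs):
    """Hypothetical wrapper: call xr.dot with dask-backed operands first.

    `da.chunks` is None for non-dask data, so the sort key is False for
    dask-backed arrays and True otherwise; sorted() is stable, so operands
    with the same key keep their original relative order.
    """
    ordered = sorted(arrays, key=lambda da: da.chunks is None)
    return xr.dot(*ordered, **kwargs)


# NumPy-backed check: nothing is chunked, so the order is unchanged and the
# shared dimension dim_0 is summed over, as with vec @ mat.
vec = xr.DataArray(np.arange(10))
mat = xr.DataArray(np.arange(100).reshape(10, 10))
result = dot_dask_first(vec, mat)
```

With dask and sparse installed, dot_dask_first(sparse_da, dask_da) should end up calling xr.dot(dask_da, sparse_da), the order reported as working above, assuming that behaviour holds for the arrays involved.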

1 or 3 seem like the best options to me (assuming I've got it right for option 3). I'm happy to put in a PR if either seems acceptable.

Edit: option 1 is probably also something that could be done in opt-einsum.
