Skip to content

Commit a4ba2bc

Browse files
committed
Migrate to DaskIndexingAdapter
1 parent e2547f1 commit a4ba2bc

File tree

2 files changed

+48
-39
lines changed

2 files changed

+48
-39
lines changed

xarray/core/computation.py

Lines changed: 6 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -2126,18 +2126,6 @@ def to_floatable(x: DataArray) -> DataArray:
21262126
return to_floatable(data)
21272127

21282128

2129-
def _apply_vectorized_indexer(indices, coord):
2130-
from xarray.core.indexing import (
2131-
VectorizedIndexer,
2132-
apply_indexer,
2133-
as_indexable,
2134-
)
2135-
2136-
return apply_indexer(
2137-
as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))
2138-
)
2139-
2140-
21412129
def _calc_idxminmax(
21422130
*,
21432131
array,
@@ -2182,28 +2170,14 @@ def _calc_idxminmax(
21822170
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
21832171

21842172
# Handle chunked arrays (e.g. dask).
2173+
coord = array[dim]._variable.to_base_variable()
21852174
if is_chunked_array(array.data):
21862175
chunkmanager = get_chunked_array_type(array.data)
2187-
chunked_coord = chunkmanager.from_array(array[dim].data, chunks=((-1,),))
2188-
2189-
if indx.ndim == 0:
2190-
out = chunked_coord[indx.data]
2191-
else:
2192-
out = chunkmanager.map_blocks(
2193-
_apply_vectorized_indexer,
2194-
indx.data[..., np.newaxis],
2195-
chunked_coord,
2196-
chunks=indx.data.chunks,
2197-
drop_axis=-1,
2198-
dtype=chunked_coord.dtype,
2199-
)
2200-
res = indx.copy(data=out)
2201-
# we need to attach back the dim name
2202-
res.name = dim
2203-
else:
2204-
res = array[dim][(indx,)]
2205-
# The dim is gone but we need to remove the corresponding coordinate.
2206-
del res.coords[dim]
2176+
coord_array = chunkmanager.from_array(
2177+
array[dim].data, chunks=((array.sizes[dim],),)
2178+
)
2179+
coord = coord.copy(data=coord_array)
2180+
res = indx._replace(coord[(indx.variable,)]).rename(dim)
22072181

22082182
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
22092183
# Put the NaN values back in after removing them

xarray/core/indexing.py

Lines changed: 42 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
import enum
44
import functools
5+
import math
56
import operator
67
from collections import Counter, defaultdict
78
from collections.abc import Callable, Hashable, Iterable, Mapping
@@ -472,12 +473,12 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...
472473
for k in key:
473474
if isinstance(k, slice):
474475
k = as_integer_slice(k)
475-
elif is_duck_dask_array(k):
476-
raise ValueError(
477-
"Vectorized indexing with Dask arrays is not supported. "
478-
"Please pass a numpy array by calling ``.compute``. "
479-
"See https://github.com/dask/dask/issues/8958."
480-
)
476+
# elif is_duck_dask_array(k):
477+
# raise ValueError(
478+
# "Vectorized indexing with Dask arrays is not supported. "
479+
# "Please pass a numpy array by calling ``.compute``. "
480+
# "See https://github.com/dask/dask/issues/8958."
481+
# )
481482
elif is_duck_array(k):
482483
if not np.issubdtype(k.dtype, np.integer):
483484
raise TypeError(
@@ -1607,6 +1608,18 @@ def transpose(self, order):
16071608
return xp.permute_dims(self.array, order)
16081609

16091610

1611+
def _apply_vectorized_indexer_dask_wrapper(indices, coord):
1612+
from xarray.core.indexing import (
1613+
VectorizedIndexer,
1614+
apply_indexer,
1615+
as_indexable,
1616+
)
1617+
1618+
return apply_indexer(
1619+
as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))
1620+
)
1621+
1622+
16101623
class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
16111624
"""Wrap a dask array to support explicit indexing."""
16121625

@@ -1630,7 +1643,29 @@ def _oindex_get(self, indexer: OuterIndexer):
16301643
return value
16311644

16321645
def _vindex_get(self, indexer: VectorizedIndexer):
1633-
return self.array.vindex[indexer.tuple]
1646+
try:
1647+
return self.array.vindex[indexer.tuple]
1648+
except IndexError as e:
1649+
# TODO: upstream to dask
1650+
has_dask = any(is_duck_dask_array(i) for i in indexer.tuple)
1651+
if not has_dask or (has_dask and len(indexer.tuple) > 1):
1652+
raise e
1653+
if math.prod(self.array.numblocks) > 1 or self.array.ndim > 1:
1654+
raise e
1655+
(idxr,) = indexer.tuple
1656+
if idxr.ndim == 0:
1657+
return self.array[idxr.data]
1658+
else:
1659+
import dask.array
1660+
1661+
return dask.array.map_blocks(
1662+
_apply_vectorized_indexer_dask_wrapper,
1663+
idxr[..., np.newaxis],
1664+
self.array,
1665+
chunks=idxr.chunks,
1666+
drop_axis=-1,
1667+
dtype=self.array.dtype,
1668+
)
16341669

16351670
def __getitem__(self, indexer: ExplicitIndexer):
16361671
self._check_and_raise_if_non_basic_indexer(indexer)

0 commit comments

Comments (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy