2
2
3
3
import enum
4
4
import functools
5
+ import math
5
6
import operator
6
7
from collections import Counter , defaultdict
7
8
from collections .abc import Callable , Hashable , Iterable , Mapping
@@ -472,12 +473,12 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...
472
473
for k in key :
473
474
if isinstance (k , slice ):
474
475
k = as_integer_slice (k )
475
- elif is_duck_dask_array (k ):
476
- raise ValueError (
477
- "Vectorized indexing with Dask arrays is not supported. "
478
- "Please pass a numpy array by calling ``.compute``. "
479
- "See https://github.com/dask/dask/issues/8958."
480
- )
476
+ # elif is_duck_dask_array(k):
477
+ # raise ValueError(
478
+ # "Vectorized indexing with Dask arrays is not supported. "
479
+ # "Please pass a numpy array by calling ``.compute``. "
480
+ # "See https://github.com/dask/dask/issues/8958."
481
+ # )
481
482
elif is_duck_array (k ):
482
483
if not np .issubdtype (k .dtype , np .integer ):
483
484
raise TypeError (
@@ -1607,6 +1608,18 @@ def transpose(self, order):
1607
1608
return xp .permute_dims (self .array , order )
1608
1609
1609
1610
1611
+ def _apply_vectorized_indexer_dask_wrapper (indices , coord ):
1612
+ from xarray .core .indexing import (
1613
+ VectorizedIndexer ,
1614
+ apply_indexer ,
1615
+ as_indexable ,
1616
+ )
1617
+
1618
+ return apply_indexer (
1619
+ as_indexable (coord ), VectorizedIndexer ((indices .squeeze (axis = - 1 ),))
1620
+ )
1621
+
1622
+
1610
1623
class DaskIndexingAdapter (ExplicitlyIndexedNDArrayMixin ):
1611
1624
"""Wrap a dask array to support explicit indexing."""
1612
1625
@@ -1630,7 +1643,29 @@ def _oindex_get(self, indexer: OuterIndexer):
1630
1643
return value
1631
1644
1632
1645
def _vindex_get (self , indexer : VectorizedIndexer ):
1633
- return self .array .vindex [indexer .tuple ]
1646
+ try :
1647
+ return self .array .vindex [indexer .tuple ]
1648
+ except IndexError as e :
1649
+ # TODO: upstream to dask
1650
+ has_dask = any (is_duck_dask_array (i ) for i in indexer .tuple )
1651
+ if not has_dask or (has_dask and len (indexer .tuple ) > 1 ):
1652
+ raise e
1653
+ if math .prod (self .array .numblocks ) > 1 or self .array .ndim > 1 :
1654
+ raise e
1655
+ (idxr ,) = indexer .tuple
1656
+ if idxr .ndim == 0 :
1657
+ return self .array [idxr .data ]
1658
+ else :
1659
+ import dask .array
1660
+
1661
+ return dask .array .map_blocks (
1662
+ _apply_vectorized_indexer_dask_wrapper ,
1663
+ idxr [..., np .newaxis ],
1664
+ self .array ,
1665
+ chunks = idxr .chunks ,
1666
+ drop_axis = - 1 ,
1667
+ dtype = self .array .dtype ,
1668
+ )
1634
1669
1635
1670
def __getitem__ (self , indexer : ExplicitIndexer ):
1636
1671
self ._check_and_raise_if_non_basic_indexer (indexer )
0 commit comments