
Commit cd6065e

TomNicholas, dcherian, and Illviljan authored
Rely on NEP-18 to dispatch to dask in duck_array_ops (#5571)
* basic test for the mean
* minimum to get mean working
* don't even need to call dask specifically
* remove reference to dask when dispatching to modules
* fixed special case of pandas vs dask isnull
* removed _dask_or_eager_func completely
* noqa
* pre-commit
* what's new
* linting
* properly import dask for test
* fix iris conversion error by rolling back treatment of np.ma.masked_invalid
* linting
* Update xarray/core/duck_array_ops.py
  Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
* Update xarray/core/duck_array_ops.py
  Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
* Update xarray/core/duck_array_ops.py
  Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
1 parent 4340909 commit cd6065e
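
The whole change relies on NEP-18 (the `__array_function__` protocol): calling a plain numpy function on a dask array dispatches to dask's implementation automatically, so xarray no longer has to look up `dask.array.<name>` by hand. A minimal sketch of that behaviour, assuming dask is installed (illustrative only, not part of the commit):

import dask.array as da
import numpy as np

x = da.ones((1000, 1000), chunks=(100, 100))

# NEP-18: np.mean sees an object defining __array_function__ and hands the
# call to dask.array.mean, so the result stays a lazy dask array.
result = np.mean(x)
print(type(result))      # <class 'dask.array.core.Array'>
print(result.compute())  # 1.0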

File tree

  doc/whats-new.rst
  xarray/core/duck_array_ops.py
  xarray/core/nanops.py
  xarray/tests/test_units.py
  xarray/ufuncs.py

5 files changed: +91 -87 lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions

@@ -206,6 +206,10 @@ Internal Changes
   pandas-specific implementation into ``PandasIndex.query()`` and
   ``PandasMultiIndex.query()`` (:pull:`5322`).
   By `Benoit Bovy <https://github.com/benbovy>`_.
+- Refactor `xarray.core.duck_array_ops` to no longer special-case dispatching to
+  dask versions of functions when acting on dask arrays, instead relying numpy
+  and dask's adherence to NEP-18 to dispatch automatically. (:pull:`5571`)
+  By `Tom Nicholas <https://github.com/TomNicholas>`_.

 .. _whats-new.0.18.2:

xarray/core/duck_array_ops.py

Lines changed: 56 additions & 59 deletions

@@ -11,6 +11,15 @@

 import numpy as np
 import pandas as pd
+from numpy import all as array_all  # noqa
+from numpy import any as array_any  # noqa
+from numpy import zeros_like  # noqa
+from numpy import around, broadcast_to  # noqa
+from numpy import concatenate as _concatenate
+from numpy import einsum, isclose, isin, isnan, isnat, pad  # noqa
+from numpy import stack as _stack
+from numpy import take, tensordot, transpose, unravel_index  # noqa
+from numpy import where as _where

 from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils
 from .nputils import nanfirst, nanlast
@@ -34,31 +43,15 @@ def _dask_or_eager_func(
     name,
     eager_module=np,
     dask_module=dask_array,
-    list_of_args=False,
-    array_args=slice(1),
-    requires_dask=None,
 ):
     """Create a function that dispatches to dask for dask array inputs."""
-    if dask_module is not None:
-
-        def f(*args, **kwargs):
-            if list_of_args:
-                dispatch_args = args[0]
-            else:
-                dispatch_args = args[array_args]
-            if any(is_duck_dask_array(a) for a in dispatch_args):
-                try:
-                    wrapped = getattr(dask_module, name)
-                except AttributeError as e:
-                    raise AttributeError(f"{e}: requires dask >={requires_dask}")
-            else:
-                wrapped = getattr(eager_module, name)
-            return wrapped(*args, **kwargs)

-    else:
-
-        def f(*args, **kwargs):
-            return getattr(eager_module, name)(*args, **kwargs)
+    def f(*args, **kwargs):
+        if any(is_duck_dask_array(a) for a in args):
+            wrapped = getattr(dask_module, name)
+        else:
+            wrapped = getattr(eager_module, name)
+        return wrapped(*args, **kwargs)

     return f

@@ -72,16 +65,40 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
         raise NotImplementedError(msg % func_name)


-around = _dask_or_eager_func("around")
-isclose = _dask_or_eager_func("isclose")
-
+# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
+pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module=dask_array)

-isnat = np.isnat
-isnan = _dask_or_eager_func("isnan")
-zeros_like = _dask_or_eager_func("zeros_like")
-
-
-pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd)
+# np.around has failing doctests, overwrite it so they pass:
+# https://github.com/numpy/numpy/issues/19759
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    "array([0., 2.])",
+    "array([0., 2.])",
+)
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    "array([0., 2.])",
+    "array([0., 2.])",
+)
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    "array([0.4, 1.6])",
+    "array([0.4, 1.6])",
+)
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    "array([0., 2., 2., 4., 4.])",
+    "array([0., 2., 2., 4., 4.])",
+)
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    (
+        ' .. [2] "How Futile are Mindless Assessments of\n'
+        '    Roundoff in Floating-Point Computation?", William Kahan,\n'
+        "    https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf\n"
+    ),
+    "",
+)


 def isnull(data):
@@ -114,21 +131,10 @@ def notnull(data):
     return ~isnull(data)


-transpose = _dask_or_eager_func("transpose")
-_where = _dask_or_eager_func("where", array_args=slice(3))
-isin = _dask_or_eager_func("isin", array_args=slice(2))
-take = _dask_or_eager_func("take")
-broadcast_to = _dask_or_eager_func("broadcast_to")
-pad = _dask_or_eager_func("pad", dask_module=dask_array_compat)
-
-_concatenate = _dask_or_eager_func("concatenate", list_of_args=True)
-_stack = _dask_or_eager_func("stack", list_of_args=True)
-
-array_all = _dask_or_eager_func("all")
-array_any = _dask_or_eager_func("any")
-
-tensordot = _dask_or_eager_func("tensordot", array_args=slice(2))
-einsum = _dask_or_eager_func("einsum", array_args=slice(1, None))
+# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
+masked_invalid = _dask_or_eager_func(
+    "masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None)
+)


 def gradient(x, coord, axis, edge_order):
@@ -166,11 +172,6 @@ def cumulative_trapezoid(y, x, axis):
     return cumsum(integrand, axis=axis, skipna=False)


-masked_invalid = _dask_or_eager_func(
-    "masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None)
-)
-
-
 def astype(data, dtype, **kwargs):
     if (
         isinstance(data, sparse_array_type)
@@ -317,9 +318,7 @@ def _ignore_warnings_if(condition):
         yield


-def _create_nan_agg_method(
-    name, dask_module=dask_array, coerce_strings=False, invariant_0d=False
-):
+def _create_nan_agg_method(name, coerce_strings=False, invariant_0d=False):
     from . import nanops

     def f(values, axis=None, skipna=None, **kwargs):
@@ -344,7 +343,8 @@ def f(values, axis=None, skipna=None, **kwargs):
         else:
             if name in ["sum", "prod"]:
                 kwargs.pop("min_count", None)
-            func = _dask_or_eager_func(name, dask_module=dask_module)
+
+            func = getattr(np, name)

             try:
                 with warnings.catch_warnings():
@@ -378,9 +378,7 @@ def f(values, axis=None, skipna=None, **kwargs):
 std.numeric_only = True
 var = _create_nan_agg_method("var")
 var.numeric_only = True
-median = _create_nan_agg_method(
-    "median", dask_module=dask_array_compat, invariant_0d=True
-)
+median = _create_nan_agg_method("median", invariant_0d=True)
 median.numeric_only = True
 prod = _create_nan_agg_method("prod", invariant_0d=True)
 prod.numeric_only = True
@@ -389,7 +387,6 @@ def f(values, axis=None, skipna=None, **kwargs):
 cumprod_1d.numeric_only = True
 cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True)
 cumsum_1d.numeric_only = True
-unravel_index = _dask_or_eager_func("unravel_index")


 _mean = _create_nan_agg_method("mean", invariant_0d=True)
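
The slimmed-down `_dask_or_eager_func` survives only for cases NEP-18 cannot cover, such as `pd.isnull`, which is not a numpy function and therefore never triggers `__array_function__`. A rough, self-contained sketch of what the simplified helper does; the `is_duck_dask_array` stand-in below is an assumption, condensed from xarray's real check:

import numpy as np
import pandas as pd

try:
    import dask.array as dask_array
except ImportError:
    dask_array = None


def is_duck_dask_array(x):
    # Simplified stand-in for xarray's real duck-typed check.
    return dask_array is not None and isinstance(x, dask_array.Array)


def _dask_or_eager_func(name, eager_module=np, dask_module=dask_array):
    """Create a function that dispatches to dask for dask array inputs."""

    def f(*args, **kwargs):
        if any(is_duck_dask_array(a) for a in args):
            wrapped = getattr(dask_module, name)
        else:
            wrapped = getattr(eager_module, name)
        return wrapped(*args, **kwargs)

    return f


# pandas won't dispatch to dask.array.isnull via NEP-18, hence the special case.
pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module=dask_array)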

xarray/core/nanops.py

Lines changed: 11 additions & 26 deletions

@@ -3,14 +3,7 @@
 import numpy as np

 from . import dtypes, nputils, utils
-from .duck_array_ops import (
-    _dask_or_eager_func,
-    count,
-    fillna,
-    isnull,
-    where,
-    where_method,
-)
+from .duck_array_ops import count, fillna, isnull, where, where_method
 from .pycompat import dask_array_type

 try:
@@ -53,7 +46,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
     """
     valid_count = count(value, axis=axis)
     value = fillna(value, fill_value)
-    data = _dask_or_eager_func(func)(value, axis=axis, **kwargs)
+    data = getattr(np, func)(value, axis=axis, **kwargs)

     # TODO This will evaluate dask arrays and might be costly.
     if (valid_count == 0).any():
@@ -111,7 +104,7 @@ def nanargmax(a, axis=None):

 def nansum(a, axis=None, dtype=None, out=None, min_count=None):
     a, mask = _replace_nan(a, 0)
-    result = _dask_or_eager_func("sum")(a, axis=axis, dtype=dtype)
+    result = np.sum(a, axis=axis, dtype=dtype)
     if min_count is not None:
         return _maybe_null_out(result, axis, mask, min_count)
     else:
@@ -120,7 +113,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None):

 def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
     """In house nanmean. ddof argument will be used in _nanvar method"""
-    from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method
+    from .duck_array_ops import count, fillna, where_method

     valid_count = count(value, axis=axis)
     value = fillna(value, 0)
@@ -129,7 +122,7 @@ def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
     if dtype is None and value.dtype.kind == "O":
         dtype = value.dtype if value.dtype.kind in ["cf"] else float

-    data = _dask_or_eager_func("sum")(value, axis=axis, dtype=dtype, **kwargs)
+    data = np.sum(value, axis=axis, dtype=dtype, **kwargs)
     data = data / (valid_count - ddof)
     return where_method(data, valid_count != 0)

@@ -155,7 +148,7 @@ def nanmedian(a, axis=None, out=None):
     # possibly blow memory
     if axis is not None and len(np.atleast_1d(axis)) == a.ndim:
         axis = None
-    return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis)
+    return nputils.nanmedian(a, axis=axis)


 def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs):
@@ -170,33 +163,25 @@ def nanvar(a, axis=None, dtype=None, out=None, ddof=0):
     if a.dtype.kind == "O":
         return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof)

-    return _dask_or_eager_func("nanvar", eager_module=nputils)(
-        a, axis=axis, dtype=dtype, ddof=ddof
-    )
+    return nputils.nanvar(a, axis=axis, dtype=dtype, ddof=ddof)


 def nanstd(a, axis=None, dtype=None, out=None, ddof=0):
-    return _dask_or_eager_func("nanstd", eager_module=nputils)(
-        a, axis=axis, dtype=dtype, ddof=ddof
-    )
+    return nputils.nanstd(a, axis=axis, dtype=dtype, ddof=ddof)


 def nanprod(a, axis=None, dtype=None, out=None, min_count=None):
     a, mask = _replace_nan(a, 1)
-    result = _dask_or_eager_func("nanprod")(a, axis=axis, dtype=dtype, out=out)
+    result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out)
     if min_count is not None:
         return _maybe_null_out(result, axis, mask, min_count)
     else:
         return result


 def nancumsum(a, axis=None, dtype=None, out=None):
-    return _dask_or_eager_func("nancumsum", eager_module=nputils)(
-        a, axis=axis, dtype=dtype
-    )
+    return nputils.nancumsum(a, axis=axis, dtype=dtype)


 def nancumprod(a, axis=None, dtype=None, out=None):
-    return _dask_or_eager_func("nancumprod", eager_module=nputils)(
-        a, axis=axis, dtype=dtype
-    )
+    return nputils.nancumprod(a, axis=axis, dtype=dtype)
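
Throughout `nanops.py` the string-based indirection (`_dask_or_eager_func("sum")` and friends) is replaced with direct calls such as `np.sum`, `getattr(np, func)` and `nputils.*`, which is safe for dask inputs precisely because the numpy functions dispatch via NEP-18. A small check of that assumption (illustrative, not taken from the commit):

import dask.array as da
import numpy as np

x = da.from_array(np.array([1.0, np.nan, 3.0]), chunks=2)

# Both calls go through __array_function__ and stay lazy on dask input.
total = np.sum(x)                     # a lazy dask array, not a number
argmax = getattr(np, "nanargmax")(x)  # same dispatch style as _nan_argminmax_object

print(total.compute())   # nan (plain sum propagates NaN)
print(argmax.compute())  # 2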

xarray/tests/test_units.py

Lines changed: 19 additions & 0 deletions

@@ -13,6 +13,7 @@
     assert_duckarray_allclose,
     assert_equal,
     assert_identical,
+    requires_dask,
     requires_matplotlib,
 )
 from .test_plot import PlotTestCase
@@ -5579,6 +5580,24 @@ def test_merge(self, variant, unit, error, dtype):
         assert_equal(expected, actual)


+@requires_dask
+class TestPintWrappingDask:
+    def test_duck_array_ops(self):
+        import dask.array
+
+        d = dask.array.array([1, 2, 3])
+        q = pint.Quantity(d, units="m")
+        da = xr.DataArray(q, dims="x")
+
+        actual = da.mean().compute()
+        actual.name = None
+        expected = xr.DataArray(pint.Quantity(np.array(2.0), units="m"))
+
+        assert_units_equal(expected, actual)
+        # Don't use isinstance b/c we don't want to allow subclasses through
+        assert type(expected.data) == type(actual.data)  # noqa
+
+
 @requires_matplotlib
 class TestPlots(PlotTestCase):
     def test_units_in_line_plot_labels(self):

xarray/ufuncs.py

Lines changed: 1 addition & 2 deletions

@@ -20,7 +20,6 @@

 from .core.dataarray import DataArray as _DataArray
 from .core.dataset import Dataset as _Dataset
-from .core.duck_array_ops import _dask_or_eager_func
 from .core.groupby import GroupBy as _GroupBy
 from .core.pycompat import dask_array_type as _dask_array_type
 from .core.variable import Variable as _Variable
@@ -71,7 +70,7 @@ def __call__(self, *args, **kwargs):
            new_args = tuple(reversed(args))

        if res is _UNDEFINED:
-           f = _dask_or_eager_func(self._name, array_args=slice(len(args)))
+           f = getattr(_np, self._name)
            res = f(*new_args, **kwargs)
        if res is NotImplemented:
            raise TypeError(
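
For reference, the NEP-18 machinery that makes the plain `getattr(_np, self._name)` fallback work is the `__array_function__` protocol: any array-like that defines it intercepts top-level numpy calls. A minimal, hypothetical duck array illustrating the hook (not related to dask or xarray internals):

import numpy as np


class LoggingArray:
    """Toy duck array that forwards numpy functions to a wrapped ndarray."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # NEP-18 hook: numpy calls this instead of its own implementation.
        print(f"dispatching {func.__name__} via __array_function__")
        unwrapped = tuple(a.data if isinstance(a, LoggingArray) else a for a in args)
        return func(*unwrapped, **kwargs)


arr = LoggingArray([1, 2, 3])
print(np.mean(arr))  # prints the dispatch message, then 2.0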

0 commit comments
