From fd6b339c3f51e83d5f9deb12e837f822cd51a2f7 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 11:03:25 -0500 Subject: [PATCH 01/24] lots more duck array compat, plus tests --- xarray/core/array_api_compat.py | 28 ++ xarray/core/common.py | 5 +- xarray/core/computation.py | 8 +- xarray/core/dataset.py | 12 +- xarray/core/duck_array_ops.py | 122 +++++--- xarray/core/nanops.py | 2 +- xarray/core/nputils.py | 6 + xarray/core/rolling.py | 3 + xarray/core/variable.py | 39 ++- xarray/tests/test_duck_array_wrapping.py | 371 +++++++++++++++++++++++ xarray/tests/test_strategies.py | 19 +- xarray/tests/test_variable.py | 39 +-- 12 files changed, 563 insertions(+), 91 deletions(-) create mode 100644 xarray/tests/test_duck_array_wrapping.py diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index da072de5b69..28d671cc349 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -1,5 +1,7 @@ import numpy as np +from xarray.namedarray.pycompat import array_type + def is_weak_scalar_type(t): return isinstance(t, bool | int | float | complex | str | bytes) @@ -42,3 +44,29 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: return xp.result_type(*arrays_and_dtypes) else: return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) + + +def get_array_namespace(*values): + def _get_single_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + elif isinstance(x, array_type("cupy")): + # special case cupy for now + import cupy as cp + + return cp + else: + return np + + namespaces = {_get_single_namespace(t) for t in values} + non_numpy = namespaces - {np} + + if len(non_numpy) > 1: + names = [module.__name__ for module in non_numpy] + raise TypeError(f"Mixed array types {names} are not supported.") + elif non_numpy: + [xp] = non_numpy + else: + xp = np + + return xp diff --git a/xarray/core/common.py b/xarray/core/common.py index 6f788f408d0..8aaa153c1a8 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -496,7 +496,7 @@ def clip( keep_attrs = _get_keep_attrs(default=True) return apply_ufunc( - np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" + duck_array_ops.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" ) def get_index(self, key: Hashable) -> pd.Index: @@ -1760,7 +1760,8 @@ def _full_like_variable( **from_array_kwargs, ) else: - data = np.full_like(other.data, fill_value, dtype=dtype) + xp = duck_array_ops.get_array_namespace(other.data) + data = xp.full_like(other.data, fill_value, dtype=dtype) return Variable(dims=other.dims, data=data, attrs=other.attrs) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index b15ed7f3f34..0bfe21642f7 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -34,7 +34,7 @@ from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type -from xarray.namedarray.pycompat import is_chunked_array +from xarray.namedarray.pycompat import is_chunked_array, to_numpy from xarray.util.deprecation_helpers import deprecate_dims if TYPE_CHECKING: @@ -1702,7 +1702,7 @@ def cross( ) c = apply_ufunc( - np.cross, + duck_array_ops.cross, a, b, input_core_dims=[[dim], [dim]], @@ -2174,9 +2174,13 @@ def _calc_idxminmax( # we need to attach back the dim name res.name = dim else: + indx.data = to_numpy(indx.data) res = array[dim][(indx,)] # The dim is gone but we need to remove the corresponding coordinate. del res.coords[dim] + # Cast to array namespace + xp = duck_array_ops.get_array_namespace(array.data) + res.data = xp.asarray(res.data) if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index cc34a8cc04b..6e5a8e163b8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -127,7 +127,7 @@ calculate_dimensions, ) from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager -from xarray.namedarray.pycompat import array_type, is_chunked_array +from xarray.namedarray.pycompat import array_type, is_chunked_array, to_numpy from xarray.plot.accessor import DatasetPlotAccessor from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims @@ -6564,7 +6564,7 @@ def dropna( array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) + count += to_numpy(array.count(dims).data) size += math.prod([self.sizes[d] for d in dims]) if thresh is not None: @@ -8678,16 +8678,20 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): coord_names.add(k) else: if k in self.data_vars and dim in v.dims: + # cast coord data to duck array if needed + coord_data = duck_array_ops.get_array_namespace(v.data).asarray( + coord_var.data + ) if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: integ = duck_array_ops.cumulative_trapezoid( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = v.dims else: integ = duck_array_ops.trapz( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = list(v.dims) v_dims.remove(dim) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 77e62e4c71e..d67f8d17207 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -18,23 +18,17 @@ import pandas as pd from numpy import all as array_all # noqa: F401 from numpy import any as array_any # noqa: F401 -from numpy import concatenate as _concatenate from numpy import ( # noqa: F401 - full_like, - gradient, isclose, - isin, isnat, take, - tensordot, - transpose, unravel_index, ) -from numpy.lib.stride_tricks import sliding_window_view # noqa: F401 from packaging.version import Version from pandas.api.types import is_extension_array_dtype from xarray.core import dask_array_ops, dtypes, nputils +from xarray.core.array_api_compat import get_array_namespace from xarray.core.options import OPTIONS from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available from xarray.namedarray import pycompat @@ -55,28 +49,6 @@ dask_available = module_available("dask") -def get_array_namespace(*values): - def _get_array_namespace(x): - if hasattr(x, "__array_namespace__"): - return x.__array_namespace__() - else: - return np - - namespaces = {_get_array_namespace(t) for t in values} - non_numpy = namespaces - {np} - - if len(non_numpy) > 1: - raise TypeError( - "cannot deal with more than one type supporting the array API at the same time" - ) - elif non_numpy: - [xp] = non_numpy - else: - xp = np - - return xp - - def einsum(*args, **kwargs): from xarray.core.options import OPTIONS @@ -85,7 +57,23 @@ def einsum(*args, **kwargs): return opt_einsum.contract(*args, **kwargs) else: - return np.einsum(*args, **kwargs) + xp = get_array_namespace(*args) + return xp.einsum(*args, **kwargs) + + +def tensordot(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.tensordot(*args, **kwargs) + + +def cross(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.cross(*args, **kwargs) + + +def gradient(f, *varargs, axis=None, edge_order=1): + xp = get_array_namespace(f) + return xp.gradient(f, *varargs, axis=axis, edge_order=edge_order) def _dask_or_eager_func( @@ -153,7 +141,7 @@ def isnull(data): ) ): # these types cannot represent missing values - return full_like(data, dtype=bool, fill_value=False) + return full_like(data, dtype=xp.bool, fill_value=False) else: # at this point, array should have dtype=object if isinstance(data, np.ndarray) or is_extension_array_dtype(data): @@ -200,11 +188,23 @@ def cumulative_trapezoid(y, x, axis): # Pad so that 'axis' has same length in result as it did in y pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] - integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0) + + xp = get_array_namespace(y, x) + integrand = xp.pad(integrand, pads, mode="constant", constant_values=0.0) return cumsum(integrand, axis=axis, skipna=False) +def full_like(a, fill_value, **kwargs): + xp = get_array_namespace(a) + return xp.full_like(a, fill_value, **kwargs) + + +def empty_like(a, **kwargs): + xp = get_array_namespace(a) + return xp.empty_like(a, **kwargs) + + def astype(data, dtype, **kwargs): if hasattr(data, "__array_namespace__"): xp = get_array_namespace(data) @@ -335,7 +335,8 @@ def array_notnull_equiv(arr1, arr2): def count(data, axis=None): """Count the number of non-NA in this array along the given axis or axes""" - return np.sum(np.logical_not(isnull(data)), axis=axis) + xp = get_array_namespace(data) + return xp.sum(xp.logical_not(isnull(data)), axis=axis) def sum_where(data, axis=None, dtype=None, where=None): @@ -350,7 +351,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" - xp = get_array_namespace(condition) + xp = get_array_namespace(condition, x, y) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) @@ -367,15 +368,25 @@ def fillna(data, other): return where(notnull(data), data, other) +def logical_not(data): + xp = get_array_namespace(data) + return xp.logical_not(data) + + +def clip(data, min=None, max=None): + xp = get_array_namespace(data) + return xp.clip(data, min, max) + + def concatenate(arrays, axis=0): """concatenate() with better dtype promotion rules.""" - # TODO: remove the additional check once `numpy` adds `concat` to its array namespace - if hasattr(arrays[0], "__array_namespace__") and not isinstance( - arrays[0], np.ndarray - ): - xp = get_array_namespace(arrays[0]) + # TODO: `concat` is the xp compliant name, but fallback to concatenate for + # older numpy and for cupy + xp = get_array_namespace(*arrays) + if hasattr(xp, "concat"): return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) - return _concatenate(as_shared_dtype(arrays), axis=axis) + else: + return xp.concatenate(as_shared_dtype(arrays, xp=xp), axis=axis) def stack(arrays, axis=0): @@ -393,6 +404,32 @@ def ravel(array): return reshape(array, (-1,)) +def transpose(array, axes=None): + xp = get_array_namespace(array) + return xp.transpose(array, axes) + + +def moveaxis(array, source, destination): + xp = get_array_namespace(array) + return xp.moveaxis(array, source, destination) + + +def pad(array, pad_width, **kwargs): + xp = get_array_namespace(array) + return xp.pad(array, pad_width, **kwargs) + + +def sliding_window_view(array, window_shape, axis=None): + # TODO: some array libraries don't support this, implement an alternative? + xp = get_array_namespace(array) + return xp.lib.stride_tricks.sliding_window_view(array, window_shape, axis=axis) + + +def quantile(array, q, axis=None, **kwargs): + xp = get_array_namespace(array) + return xp.quantile(array, q, axis=axis, **kwargs) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -734,6 +771,11 @@ def last(values, axis, skipna=None): return take(values, -1, axis=axis) +def isin(element, test_elements, **kwargs): + xp = get_array_namespace(element, test_elements) + return xp.isin(element, test_elements, **kwargs) + + def least_squares(lhs, rhs, rcond=None, skipna=False): """Return the coefficients and residuals of a least-squares fit.""" if is_duck_dask_array(rhs): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 7fbb63068c0..4894cf02be2 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -128,7 +128,7 @@ def nanmean(a, axis=None, dtype=None, out=None): "ignore", r"Mean of empty slice", category=RuntimeWarning ) - return np.nanmean(a, axis=axis, dtype=dtype) + return nputils.nanmean(a, axis=axis, dtype=dtype) def nanmedian(a, axis=None, out=None): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index cd20dbccd87..24d6b1dda72 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -7,6 +7,7 @@ import pandas as pd from packaging.version import Version +from xarray.core.array_api_compat import get_array_namespace from xarray.core.utils import is_duck_array, module_available from xarray.namedarray import pycompat @@ -179,6 +180,11 @@ def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype") bn_func = getattr(bn, name, None) + xp = get_array_namespace(values) + if xp is not np: + func = getattr(xp, name, None) + if func is not None: + return func(values, axis=axis, **kwargs) if ( module_available("numbagg") and pycompat.mod_version("numbagg") >= Version("0.5.0") diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 781550207ff..dfeab5e409e 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -639,6 +639,7 @@ def _array_reduce( ) del kwargs["dim"] + xp = duck_array_ops.get_array_namespace(self.obj.data) if ( OPTIONS["use_numbagg"] and module_available("numbagg") @@ -654,6 +655,7 @@ def _array_reduce( # TODO: we could also allow this, probably as part of a refactoring of this # module, so we can use the machinery in `self.reduce`. and self.ndim == 1 + and xp is np ): import numbagg @@ -676,6 +678,7 @@ def _array_reduce( or module_available("dask", "2024.11.0") ) and self.ndim == 1 + and xp is np ): return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs diff --git a/xarray/core/variable.py b/xarray/core/variable.py index a6ea44b1ee5..f4db3fa6b1d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -822,7 +822,7 @@ def __getitem__(self, key) -> Self: data = indexing.apply_indexer(indexable, indexer) if new_order: - data = np.moveaxis(data, range(len(new_order)), new_order) + data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) def _finalize_indexing_result(self, dims, data) -> Self: @@ -860,12 +860,17 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed - data = duck_array_ops.where(np.logical_not(mask), data, fill_value) + # cast mask to any duck array type + if not is_duck_dask_array(mask): + mask = duck_array_ops.get_array_namespace(data).asarray(mask) + data = duck_array_ops.where( + duck_array_ops.logical_not(mask), data, fill_value + ) else: # array cannot be indexed along dimensions of size 0, so just # build the mask directly instead. mask = indexing.create_mask(indexer, self.shape) - data = np.broadcast_to(fill_value, getattr(mask, "shape", ())) + data = duck_array_ops.broadcast_to(fill_value, getattr(mask, "shape", ())) if new_order: data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) @@ -896,7 +901,7 @@ def __setitem__(self, key, value): if new_order: value = duck_array_ops.asarray(value) value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)] - value = np.moveaxis(value, new_order, range(len(new_order))) + value = duck_array_ops.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) indexing.set_with_indexer(indexable, index_tuple, value) @@ -1098,7 +1103,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): dim_pad = (width, 0) if count >= 0 else (0, width) pads = [(0, 0) if d != dim else dim_pad for d in self.dims] - data = np.pad( + data = duck_array_ops.pad( duck_array_ops.astype(trimmed_data, dtype), pads, mode="constant", @@ -1244,7 +1249,7 @@ def pad( if reflect_type is not None: pad_option_kwargs["reflect_type"] = reflect_type - array = np.pad( + array = duck_array_ops.pad( duck_array_ops.astype(self.data, dtype, copy=False), pad_width_by_index, mode=mode, @@ -1533,14 +1538,16 @@ def _unstack_once( if is_missing_values: dtype, fill_value = dtypes.maybe_promote(self.dtype) - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial( + duck_array_ops.full_like, fill_value=fill_value + ) else: dtype = self.dtype fill_value = dtypes.get_fill_value(dtype) - create_template = np.empty_like + create_template = duck_array_ops.empty_like else: dtype = self.dtype - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial(duck_array_ops.full_like, fill_value=fill_value) if sparse: # unstacking a dense multitindexed array to a sparse array @@ -1630,7 +1637,8 @@ def clip(self, min=None, max=None): """ from xarray.core.computation import apply_ufunc - return apply_ufunc(np.clip, self, min, max, dask="allowed") + xp = duck_array_ops.get_array_namespace(self.data) + return apply_ufunc(xp.clip, self, min, max, dask="parallelized") def reduce( # type: ignore[override] self, @@ -1923,13 +1931,15 @@ def quantile( if skipna or (skipna is None and self.dtype.kind in "cfO"): _quantile_func = nputils.nanquantile else: - _quantile_func = np.quantile + _quantile_func = duck_array_ops.quantile if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) + xp = duck_array_ops.get_array_namespace(self.data) + scalar = utils.is_scalar(q) - q = np.atleast_1d(np.asarray(q, dtype=np.float64)) + q = xp.atleast_1d(xp.asarray(q, dtype=float)) if dim is None: dim = self.dims @@ -1939,9 +1949,10 @@ def quantile( def _wrapper(npa, **kwargs): # move quantile axis to end. required for apply_ufunc - return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1) + return xp.moveaxis(_quantile_func(npa, **kwargs), 0, -1) - axis = np.arange(-1, -1 * len(dim) - 1, -1) + # jax requires hashable + axis = tuple(range(-1, -1 * len(dim) - 1, -1)) kwargs = {"q": q, "axis": axis, "method": method} diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py new file mode 100644 index 00000000000..1fc8ece98dd --- /dev/null +++ b/xarray/tests/test_duck_array_wrapping.py @@ -0,0 +1,371 @@ +import numpy as np +import pandas as pd +import pytest + +import xarray as xr + +# TODO: how to test these in CI? +jnp = pytest.importorskip("jax.numpy") +cp = pytest.importorskip("cupy") + + +def get_test_dataarray(xp): + return xr.DataArray( + xp.asarray([[1, 2, 3, np.nan, 5]]), + dims=["y", "x"], + coords={"y": [1], "x": np.arange(5)}, + name="foo", + ) + + +@pytest.mark.parametrize("xp", [cp, jnp]) +class TestTopLevelMethods: + @pytest.fixture(autouse=True) + def setUp(self, xp): + self.xp = xp + self.Array = xp.ndarray + self.x1 = get_test_dataarray(xp) + self.x2 = get_test_dataarray(xp).assign_coords(x=np.arange(2, 7)) + + def test_apply_ufunc(self): + func = lambda x: x + 1 + result = xr.apply_ufunc(func, self.x1) + assert isinstance(result.data, self.Array) + + def test_align(self): + result = xr.align(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_broadcast(self): + result = xr.broadcast(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_concat(self): + result = xr.concat([self.x1, self.x2], dim="x") + assert isinstance(result.data, self.Array) + + def test_merge(self): + result = xr.merge([self.x1, self.x2], compat="override") + assert isinstance(result.foo.data, self.Array) + + def test_where(self): + x1, x2 = xr.align(self.x1, self.x2, join="inner") + result = xr.where(x1 > 2, x1, x2) + assert isinstance(result.data, self.Array) + + def test_full_like(self): + result = xr.full_like(self.x1, 0) + assert isinstance(result.data, self.Array) + + def test_cov(self): + result = xr.cov(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_corr(self): + result = xr.corr(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_cross(self): + x1, x2 = xr.align(self.x1.squeeze(), self.x2.squeeze(), join="inner") + result = xr.cross(x1, x2, dim="x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = xr.dot(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_map_blocks(self): + result = xr.map_blocks(lambda x: x + 1, self.x1) + assert isinstance(result.data, self.Array) + + +@pytest.mark.parametrize("xp", [cp, jnp]) +class TestDataArrayMethods: + @pytest.fixture(autouse=True) + def setUp(self, xp): + self.xp = xp + self.Array = xp.ndarray + self.x = get_test_dataarray(xp) + + def test_loc(self): + result = self.x.loc[{"x": slice(1, 3)}] + assert isinstance(result.data, self.Array) + + def test_isel(self): + result = self.x.isel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_sel(self): + result = self.x.sel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_squeeze(self): + result = self.x.squeeze("y") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interp is not namespace aware") + def test_interp(self): + result = self.x.interp(x=2.5) + assert isinstance(result.data, self.Array) + + def test_isnull(self): + result = self.x.isnull() + assert isinstance(result.data, self.Array) + + def test_notnull(self): + result = self.x.notnull() + assert isinstance(result.data, self.Array) + + def test_count(self): + result = self.x.count() + assert isinstance(result.data, self.Array) + + def test_dropna(self): + result = self.x.dropna(dim="x") + assert isinstance(result.data, self.Array) + + def test_fillna(self): + result = self.x.fillna(0) + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="ffill is not namespace aware") + def test_ffill(self): + result = self.x.ffill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="bfill is not namespace aware") + def test_bfill(self): + result = self.x.bfill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interpolate_na is not namespace aware") + def test_interpolate_na(self): + result = self.x.interpolate_na() + assert isinstance(result.data, self.Array) + + def test_where(self): + result = self.x.where(self.x > 2) + assert isinstance(result.data, self.Array) + + def test_isin(self): + result = self.x.isin(self.xp.asarray([1])) + assert isinstance(result.data, self.Array) + + def test_groupby(self): + result = self.x.groupby("x").mean() + assert isinstance(result.data, self.Array) + + def test_rolling(self): + if self.xp is jnp: + pytest.xfail("no sliding_window_view in jax") + result = self.x.rolling(x=3).mean() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rolling_exp is not namespace aware") + def test_rolling_exp(self): + result = self.x.rolling_exp(x=3).mean() + assert isinstance(result.data, self.Array) + + def test_weighted(self): + result = self.x.weighted(self.x.fillna(0)).mean() + assert isinstance(result.data, self.Array) + + def test_coarsen(self): + result = self.x.coarsen(x=2, boundary="pad").mean() + assert isinstance(result.data, self.Array) + + def test_resample(self): + time_coord = pd.date_range("2000-01-01", periods=5) + result = self.x.assign_coords(x=time_coord).resample(x="D").mean() + assert isinstance(result.data, self.Array) + + def test_diff(self): + result = self.x.diff("x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = self.x.dot(self.x) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_quantile(self, skipna): + if self.xp is cp and skipna: + pytest.xfail("no nanquantile in cupy") + result = self.x.quantile(0.5, skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_differentiate(self): + if self.xp is jnp: + pytest.xfail("edge_order kwarg") + result = self.x.differentiate("x") + assert isinstance(result.data, self.Array) + + def test_integrate(self): + result = self.x.integrate("x") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="polyfit is not namespace aware") + def test_polyfit(self): + result = self.x.polyfit("x", 1) + assert isinstance(result.polyfit_coefficients.data, self.Array) + + def test_map_blocks(self): + result = self.x.map_blocks(lambda x: x + 1) + assert isinstance(result.data, self.Array) + + def test_all(self): + result = self.x.all(dim="x") + assert isinstance(result.data, self.Array) + + def test_any(self): + result = self.x.any(dim="x") + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmax(self, skipna): + result = self.x.argmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmin(self, skipna): + result = self.x.argmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmax(self, skipna): + result = self.x.idxmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmin(self, skipna): + result = self.x.idxmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_max(self, skipna): + result = self.x.max(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_min(self, skipna): + result = self.x.min(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_mean(self, skipna): + result = self.x.mean(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_median(self, skipna): + result = self.x.median(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_prod(self, skipna): + result = self.x.prod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_sum(self, skipna): + result = self.x.sum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_std(self, skipna): + if self.xp is cp and not skipna: + pytest.xfail("ddof/correction kwarg mismatch") + result = self.x.std(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_var(self, skipna): + if self.xp is cp and not skipna: + pytest.xfail("ddof/correction kwarg mismatch") + result = self.x.var(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumsum(self, skipna): + result = self.x.cumsum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumprod(self, skipna): + result = self.x.cumprod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_argsort(self): + result = self.x.argsort() + assert isinstance(result.data, self.Array) + + def test_clip(self): + result = self.x.clip(min=2.0, max=4.0) + assert isinstance(result.data, self.Array) + + def test_conj(self): + result = self.x.conj() + assert isinstance(result.data, self.Array) + + def test_conjugate(self): + result = self.x.conjugate() + assert isinstance(result.data, self.Array) + + def test_imag(self): + result = self.x.imag + assert isinstance(result.data, self.Array) + + def test_searchsorted(self): + result = self.x.squeeze().searchsorted(self.xp.asarray(3)) + assert isinstance(result, self.Array) + + def test_round(self): + result = self.x.round() + assert isinstance(result.data, self.Array) + + def test_real(self): + result = self.x.real + assert isinstance(result.data, self.Array) + + def test_T(self): + result = self.x.T + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rank is not namespace aware") + def test_rank(self): + result = self.x.rank() + assert isinstance(result.data, self.Array) + + def test_transpose(self): + result = self.x.transpose() + assert isinstance(result.data, self.Array) + + def test_stack(self): + result = self.x.stack(z=("x", "y")) + assert isinstance(result.data, self.Array) + + def test_unstack(self): + result = self.x.stack(z=("x", "y")).unstack("z") + assert isinstance(result.data, self.Array) + + def test_shift(self): + result = self.x.shift(x=1) + assert isinstance(result.data, self.Array) + + def test_roll(self): + result = self.x.roll(x=1) + assert isinstance(result.data, self.Array) + + def test_pad(self): + result = self.x.pad(x=1) + assert isinstance(result.data, self.Array) + + def test_sortby(self): + result = self.x.sortby("x") + assert isinstance(result.data, self.Array) + + def test_broadcast_like(self): + result = self.x.broadcast_like(self.x) + assert isinstance(result.data, self.Array) diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py index 798f5f732d1..48819333ca2 100644 --- a/xarray/tests/test_strategies.py +++ b/xarray/tests/test_strategies.py @@ -13,6 +13,7 @@ from hypothesis import given from hypothesis.extra.array_api import make_strategies_namespace +from xarray.core.options import set_options from xarray.core.variable import Variable from xarray.testing.strategies import ( attrs, @@ -267,14 +268,14 @@ def test_mean(self, data, var): Test that given a Variable of at least one dimension, the mean of the Variable is always equal to the mean of the underlying array. """ + with set_options(use_numbagg=False): + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) - # specify arbitrary reduction along at least one dimension - reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) - # create expected result (using nanmean because arrays with Nans will be generated) - reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) - expected = np.nanmean(var.data, axis=reduction_axes) - - # assert property is always satisfied - result = var.mean(dim=reduction_dims).data - npt.assert_equal(expected, result) + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 9c6f50037d3..1461489e731 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1978,26 +1978,27 @@ def test_reduce_funcs(self): def test_reduce_keepdims(self): v = Variable(["x", "y"], self.d) - assert_identical( - v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) - ) - assert_identical( - v.mean(dim="x", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), - ) - assert_identical( - v.mean(dim="y", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), - ) - assert_identical( - v.mean(dim=["y", "x"], keepdims=True), - Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), - ) + with set_options(use_numbagg=False): + assert_identical( + v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) + ) + assert_identical( + v.mean(dim="x", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), + ) + assert_identical( + v.mean(dim="y", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), + ) + assert_identical( + v.mean(dim=["y", "x"], keepdims=True), + Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), + ) - v = Variable([], 1.0) - assert_identical( - v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) - ) + v = Variable([], 1.0) + assert_identical( + v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) + ) @requires_dask def test_reduce_keepdims_dask(self): From f7866ce78bd71e604a8e05d14cc8080dd7dd9ec4 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 15:15:27 -0500 Subject: [PATCH 02/24] merge sliding_window_view --- xarray/core/duck_array_ops.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index cca07a44f52..f994fec7ae8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -122,21 +122,20 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): "masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" ) -# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), -# so we need to hand-code this. -sliding_window_view = _dask_or_eager_func( - "sliding_window_view", - eager_module=np.lib.stride_tricks, - dask_module=dask_array_compat, - dask_only_kwargs=("automatic_rechunk",), - numpy_only_kwargs=("subok", "writeable"), -) - -# def sliding_window_view(array, window_shape, axis=None): -# # TODO: some array libraries don't support this, implement an alternative? -# xp = get_array_namespace(array) -# return xp.lib.stride_tricks.sliding_window_view(array, window_shape, axis=axis) +def sliding_window_view(array, window_shape, axis=None, **kwargs): + # TODO: some libraries (e.g. jax) don't have this, implement an alternative? + xp = get_array_namespace(array) + # sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), + # so we need to hand-code this. + func = _dask_or_eager_func( + "sliding_window_view", + eager_module=xp.lib.stride_tricks, + dask_module=dask_array_compat, + dask_only_kwargs=("automatic_rechunk",), + numpy_only_kwargs=("subok", "writeable"), + ) + return func(array, window_shape, axis=axis, **kwargs) def round(array): From 90037fe8e883cef727f8a4893541f94a46c583cc Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 15:25:57 -0500 Subject: [PATCH 03/24] namespaces constant --- xarray/tests/test_duck_array_wrapping.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index 1fc8ece98dd..c58c62bf84b 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -8,6 +8,8 @@ jnp = pytest.importorskip("jax.numpy") cp = pytest.importorskip("cupy") +NAMESPACES = [cp, jnp] + def get_test_dataarray(xp): return xr.DataArray( @@ -18,7 +20,7 @@ def get_test_dataarray(xp): ) -@pytest.mark.parametrize("xp", [cp, jnp]) +@pytest.mark.parametrize("xp", NAMESPACES) class TestTopLevelMethods: @pytest.fixture(autouse=True) def setUp(self, xp): @@ -81,7 +83,7 @@ def test_map_blocks(self): assert isinstance(result.data, self.Array) -@pytest.mark.parametrize("xp", [cp, jnp]) +@pytest.mark.parametrize("xp", NAMESPACES) class TestDataArrayMethods: @pytest.fixture(autouse=True) def setUp(self, xp): From 5ba1a2f81280007b1e6ac1089aa9934e76c5c83e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 15:32:03 -0500 Subject: [PATCH 04/24] revert dask allowed --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 1597e4bbe66..dd67b290cf2 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1638,7 +1638,7 @@ def clip(self, min=None, max=None): from xarray.core.computation import apply_ufunc xp = duck_array_ops.get_array_namespace(self.data) - return apply_ufunc(xp.clip, self, min, max, dask="parallelized") + return apply_ufunc(xp.clip, self, min, max, dask="allowed") def reduce( # type: ignore[override] self, From 6225ae3a70d785047259decda130513d1e5df03a Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Tue, 19 Nov 2024 15:20:12 -0500 Subject: [PATCH 05/24] fix up some tests --- xarray/core/dataset.py | 9 ++++++--- xarray/core/duck_array_ops.py | 3 ++- xarray/core/nputils.py | 3 +++ xarray/core/variable.py | 6 +++--- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6e5a8e163b8..b7ecacf98b3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8679,9 +8679,12 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): else: if k in self.data_vars and dim in v.dims: # cast coord data to duck array if needed - coord_data = duck_array_ops.get_array_namespace(v.data).asarray( - coord_var.data - ) + if isinstance(v.data, array_type("cupy")): + coord_data = duck_array_ops.get_array_namespace(v.data).asarray( + coord_var.data + ) + else: + coord_data = coord_var.data if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index f994fec7ae8..746004c630d 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -168,7 +168,8 @@ def isnull(data): ) ): # these types cannot represent missing values - return full_like(data, dtype=xp.bool, fill_value=False) + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + return full_like(data, dtype=dtype, fill_value=False) else: # at this point, array should have dtype=object if isinstance(data, np.ndarray) or is_extension_array_dtype(data): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 24d6b1dda72..b5f399debab 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -236,6 +236,9 @@ def f(values, axis=None, **kwargs): # bottleneck does not take care dtype, min_count kwargs.pop("dtype", None) result = bn_func(values, axis=axis, **kwargs) + # bottleneck returns python scalars for reduction over all axes + if isinstance(result, float): + result = np.float64(result) else: result = getattr(npmodule, name)(values, axis=axis, **kwargs) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index dd67b290cf2..dd6feb4f07d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1936,10 +1936,8 @@ def quantile( if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) - xp = duck_array_ops.get_array_namespace(self.data) - scalar = utils.is_scalar(q) - q = xp.atleast_1d(xp.asarray(q, dtype=float)) + q = np.atleast_1d(np.asarray(q, dtype=np.float64)) if dim is None: dim = self.dims @@ -1947,6 +1945,8 @@ def quantile( if utils.is_scalar(dim): dim = [dim] + xp = duck_array_ops.get_array_namespace(self.data) + def _wrapper(npa, **kwargs): # move quantile axis to end. required for apply_ufunc return xp.moveaxis(_quantile_func(npa, **kwargs), 0, -1) From e2911c2810a628a5b1a7becec103c2c8528d1c37 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Tue, 19 Nov 2024 16:16:19 -0500 Subject: [PATCH 06/24] backwards compat sparse mask --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index dd6feb4f07d..833b15d7993 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -861,7 +861,7 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed # cast mask to any duck array type - if not is_duck_dask_array(mask): + if type(mask) is not type(data): mask = duck_array_ops.get_array_namespace(data).asarray(mask) data = duck_array_ops.where( duck_array_ops.logical_not(mask), data, fill_value From 2ac37f9769236225d6d6692e2e2ce60414d5e9d0 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 22:13:22 -0500 Subject: [PATCH 07/24] add as_array methods --- xarray/core/dataarray.py | 22 ++++++++++++++++++++++ xarray/core/dataset.py | 26 ++++++++++++++++++++++++++ xarray/namedarray/core.py | 4 ++++ xarray/tests/test_dataarray.py | 13 +++++++++++++ xarray/tests/test_dataset.py | 15 +++++++++++++++ 5 files changed, 80 insertions(+) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 52ce2463d51..2b19863e35e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -842,6 +842,28 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) + def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + """ + Coerces wrapped data into a specific array type. + + `asarray` should output an object that supports the Array API Standard. + This method does not convert index coordinates, which can't generally be + represented as arbitrary array types. + + Parameters + ---------- + asarray : Callable + Function that converts an array-like object to the desired array type. + For example, `cupy.asarray`, `jax.numpy.asarray`, or `sparse.COO.from_numpy`. + **kwargs : dict + Additional keyword arguments passed to the `asarray` function. + + Returns + ------- + DataArray + """ + return self._replace(self.variable.as_array(asarray, **kwargs)) + @property def _in_memory(self) -> bool: return self.variable._in_memory diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b7ecacf98b3..700d733a543 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1434,6 +1434,32 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) + def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + """ + Converts wrapped data into a specific array type. + + `asarray` should output an object that supports the Array API Standard. + This method does not convert index coordinates, which can't generally be + represented as arbitrary array types. + + Parameters + ---------- + asarray : Callable + Function that converts an array-like object to the desired array type. + For example, `cupy.asarray`, `jax.numpy.asarray`, or `sparse.COO.from_numpy`. + **kwargs : dict + Additional keyword arguments passed to the `asarray` function. + + Returns + ------- + Dataset + """ + array_variables = { + k: v.as_array(asarray, **kwargs) if k not in self._indexes else v + for k, v in self.variables.items() + } + return self._replace(variables=array_variables) + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 98d96c73e91..8ae17ebce13 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -860,6 +860,10 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) + def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + """Coerces wrapped data into a specific array type, returning a Variable.""" + return self._replace(data=asarray(self._data, **kwargs)) + def reduce( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b5ecc9517d9..aa6cb5e1721 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -39,6 +39,7 @@ from xarray.core.utils import is_scalar from xarray.testing import _assert_internal_invariants from xarray.tests import ( + DuckArrayWrapper, InaccessibleArray, ReturnItem, assert_allclose, @@ -7165,6 +7166,18 @@ def test_from_pint_wrapping_dask(self) -> None: np.testing.assert_equal(da.to_numpy(), arr) +def test_as_array() -> None: + da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}) + + def as_duck_array(arr): + return DuckArrayWrapper(arr) + + result = da.as_array(as_duck_array) + + assert isinstance(result.data, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + + class TestStackEllipsis: # https://github.com/pydata/xarray/issues/6051 def test_result_as_expected(self) -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index be82655515d..bbfc2df3fd7 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7639,6 +7639,21 @@ def test_from_pint_wrapping_dask(self) -> None: assert_identical(result, expected) +def test_as_array() -> None: + ds = xr.Dataset( + {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} + ) + + def as_duck_array(arr): + return DuckArrayWrapper(arr) + + result = ds.as_array(as_duck_array) + + assert isinstance(result.a.data, DuckArrayWrapper) + assert isinstance(result.lat.data, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + + def test_string_keys_typing() -> None: """Tests that string keys to `variables` are permitted by mypy""" From 1cc344ba46164fd2294c2d6fdaf8bc1e7c273afd Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 22:42:54 -0500 Subject: [PATCH 08/24] to_like_array helper --- xarray/core/array_api_compat.py | 6 ++++++ xarray/core/computation.py | 5 ++--- xarray/core/dataset.py | 9 ++------- xarray/core/variable.py | 5 ++--- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index 28d671cc349..1845d6eddcc 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -70,3 +70,9 @@ def _get_single_namespace(x): xp = np return xp + + +def to_like_array(array, like): + # Mostly for cupy compatibility, because cupy binary ops require all cupy arrays + xp = get_array_namespace(like) + return xp.asarray(array) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0bfe21642f7..0945f4638f6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -24,6 +24,7 @@ from xarray.core import dtypes, duck_array_ops, utils from xarray.core.alignment import align, deep_align +from xarray.core.array_api_compat import to_like_array from xarray.core.common import zeros_like from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.formatting import limit_lines @@ -2178,9 +2179,7 @@ def _calc_idxminmax( res = array[dim][(indx,)] # The dim is gone but we need to remove the corresponding coordinate. del res.coords[dim] - # Cast to array namespace - xp = duck_array_ops.get_array_namespace(array.data) - res.data = xp.asarray(res.data) + res.data = to_like_array(res.data, array.data) if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 700d733a543..eded1d89d05 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -55,6 +55,7 @@ align, ) from xarray.core.arithmetic import DatasetArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import ( DataWithCoords, _contains_datetime_like_objects, @@ -8704,13 +8705,7 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): coord_names.add(k) else: if k in self.data_vars and dim in v.dims: - # cast coord data to duck array if needed - if isinstance(v.data, array_type("cupy")): - coord_data = duck_array_ops.get_array_namespace(v.data).asarray( - coord_var.data - ) - else: - coord_data = coord_var.data + coord_data = to_like_array(coord_var.data, like=v.data) if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 833b15d7993..a472e809876 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -19,6 +19,7 @@ import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import AbstractArray from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( @@ -860,9 +861,7 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed - # cast mask to any duck array type - if type(mask) is not type(data): - mask = duck_array_ops.get_array_namespace(data).asarray(mask) + mask = to_like_array(mask, data) data = duck_array_ops.where( duck_array_ops.logical_not(mask), data, fill_value ) From 372439ce144f97fdf1e2fdee1e0d5c5f05ae919e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 22:58:45 -0500 Subject: [PATCH 09/24] only cast non-numpy --- xarray/core/array_api_compat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index 1845d6eddcc..e7424325de8 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -75,4 +75,7 @@ def _get_single_namespace(x): def to_like_array(array, like): # Mostly for cupy compatibility, because cupy binary ops require all cupy arrays xp = get_array_namespace(like) - return xp.asarray(array) + if xp is not np: + return xp.asarray(array) + # avoid casting things like pint quantities to numpy arrays + return array From 0eef2cbe2d0cae5fd80c8d3e510e54e1d2978df3 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 23:27:43 -0500 Subject: [PATCH 10/24] better idxminmax approach --- xarray/core/computation.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0945f4638f6..6e233425e95 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -35,7 +35,7 @@ from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type -from xarray.namedarray.pycompat import is_chunked_array, to_numpy +from xarray.namedarray.pycompat import is_chunked_array from xarray.util.deprecation_helpers import deprecate_dims if TYPE_CHECKING: @@ -2171,15 +2171,14 @@ def _calc_idxminmax( chunks = dict(zip(array.dims, array.chunks, strict=True)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) data = dask_coord[duck_array_ops.ravel(indx.data)] - res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) - # we need to attach back the dim name - res.name = dim else: - indx.data = to_numpy(indx.data) - res = array[dim][(indx,)] - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] - res.data = to_like_array(res.data, array.data) + arr_coord = to_like_array(array[dim].data, array.data) + data = arr_coord[duck_array_ops.ravel(indx.data)] + + # rebuild like the argmin/max output, and rename as the dim name + data = duck_array_ops.reshape(data, indx.shape) + res = indx.copy(data=data) + res.name = dim if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them From 6739504fc7a7f2f4646ff7c06aa7b2653840c00d Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 23:58:38 -0500 Subject: [PATCH 11/24] fix mypy --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- xarray/namedarray/core.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2f8a6ce620b..ff9880bf2da 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -842,7 +842,7 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) - def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + def as_array(self, asarray: Callable, **kwargs) -> Self: """ Coerces wrapped data into a specific array type. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index eded1d89d05..3711045f8c9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1435,7 +1435,7 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + def as_array(self, asarray: Callable, **kwargs) -> Self: """ Converts wrapped data into a specific array type. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 8ae17ebce13..e80a15fdc3f 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -860,7 +860,11 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + def as_array( + self, + asarray: Callable[[duckarray[Any, _DType_co]], duckarray[Any, _DType_co]], + **kwargs: Any, + ) -> Self: """Coerces wrapped data into a specific array type, returning a Variable.""" return self._replace(data=asarray(self._data, **kwargs)) From 9e6d6f8155a467f1c941de8496255bcca2b4ddbf Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 21 Nov 2024 09:27:02 -0500 Subject: [PATCH 12/24] naming, add is_array_type --- xarray/core/dataarray.py | 22 +++++++++++++++++++--- xarray/core/dataset.py | 26 +++++++++++++++++++++++--- xarray/namedarray/core.py | 31 +++++++++++++++++++++++++++++-- xarray/tests/test_dataarray.py | 7 +++++-- xarray/tests/test_dataset.py | 7 +++++-- 5 files changed, 81 insertions(+), 12 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ff9880bf2da..7796904d897 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -842,7 +842,7 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) - def as_array(self, asarray: Callable, **kwargs) -> Self: + def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ Coerces wrapped data into a specific array type. @@ -854,7 +854,8 @@ def as_array(self, asarray: Callable, **kwargs) -> Self: ---------- asarray : Callable Function that converts an array-like object to the desired array type. - For example, `cupy.asarray`, `jax.numpy.asarray`, or `sparse.COO.from_numpy`. + For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, + or any `from_dlpack` method. **kwargs : dict Additional keyword arguments passed to the `asarray` function. @@ -862,7 +863,22 @@ def as_array(self, asarray: Callable, **kwargs) -> Self: ------- DataArray """ - return self._replace(self.variable.as_array(asarray, **kwargs)) + return self._replace(self.variable.as_array_type(asarray, **kwargs)) + + def is_array_type(self, array_type: type) -> bool: + """ + Check if the wrapped data is of a specific array type. + + Parameters + ---------- + array_type : type + The array type to check for. + + Returns + ------- + bool + """ + return self.variable.is_array_type(array_type) @property def _in_memory(self) -> bool: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3711045f8c9..32ea1b98308 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1435,7 +1435,7 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def as_array(self, asarray: Callable, **kwargs) -> Self: + def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ Converts wrapped data into a specific array type. @@ -1447,7 +1447,8 @@ def as_array(self, asarray: Callable, **kwargs) -> Self: ---------- asarray : Callable Function that converts an array-like object to the desired array type. - For example, `cupy.asarray`, `jax.numpy.asarray`, or `sparse.COO.from_numpy`. + For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, + or any `from_dlpack` method. **kwargs : dict Additional keyword arguments passed to the `asarray` function. @@ -1456,11 +1457,30 @@ def as_array(self, asarray: Callable, **kwargs) -> Self: Dataset """ array_variables = { - k: v.as_array(asarray, **kwargs) if k not in self._indexes else v + k: v.as_array_type(asarray, **kwargs) if k not in self._indexes else v for k, v in self.variables.items() } return self._replace(variables=array_variables) + def is_array_type(self, array_type: type) -> bool: + """ + Check if all data variables and non-index coordinates are of a specific array type. + + Parameters + ---------- + array_type : type + The array type to check for. + + Returns + ------- + bool + """ + return all( + v.is_array_type(array_type) + for k, v in self.variables.items() + if k not in self._indexes + ) + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index e80a15fdc3f..2558ecec9c7 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -860,14 +860,41 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def as_array( + def as_array_type( self, asarray: Callable[[duckarray[Any, _DType_co]], duckarray[Any, _DType_co]], **kwargs: Any, ) -> Self: - """Coerces wrapped data into a specific array type, returning a Variable.""" + """Converts wrapped data into a specific array type. + + Parameters + ---------- + asarray : callable + Function that converts the data into a specific array type. + **kwargs : dict + Additional keyword arguments passed on to `asarray`. + + Returns + ------- + array : NamedArray + Array with the same data, but converted into a specific array type + """ return self._replace(data=asarray(self._data, **kwargs)) + def is_array_type(self, array_type: type) -> bool: + """Check if the data is an instance of a specific array type. + + Parameters + ---------- + array_type : type + Array type to check against. + + Returns + ------- + is_array_type : bool + """ + return isinstance(self._data, array_type) + def reduce( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 6e1efe85185..8bc63f3bf4b 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -7166,16 +7166,19 @@ def test_from_pint_wrapping_dask(self) -> None: np.testing.assert_equal(da.to_numpy(), arr) -def test_as_array() -> None: +def test_as_array_type_is_array_type() -> None: da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}) + assert da.is_array_type(np.ndarray) + def as_duck_array(arr): return DuckArrayWrapper(arr) - result = da.as_array(as_duck_array) + result = da.as_array_type(as_duck_array) assert isinstance(result.data, DuckArrayWrapper) assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(DuckArrayWrapper) class TestStackEllipsis: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 13917e28225..edca2a02c93 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7639,19 +7639,22 @@ def test_from_pint_wrapping_dask(self) -> None: assert_identical(result, expected) -def test_as_array() -> None: +def test_as_array_type_is_array_type() -> None: ds = xr.Dataset( {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} ) + # lat is a PandasIndex here + assert ds.drop_vars("lat").is_array_type(np.ndarray) def as_duck_array(arr): return DuckArrayWrapper(arr) - result = ds.as_array(as_duck_array) + result = ds.as_array_type(as_duck_array) assert isinstance(result.a.data, DuckArrayWrapper) assert isinstance(result.lat.data, DuckArrayWrapper) assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(DuckArrayWrapper) def test_string_keys_typing() -> None: From e72101155ed172a217b90d5fc36c95a495545049 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 21 Nov 2024 09:31:45 -0500 Subject: [PATCH 13/24] add public doc and whats new --- doc/api.rst | 4 ++++ doc/whats-new.rst | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index 0c30ddc4c20..e5517eaf07e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -117,6 +117,8 @@ Dataset contents Dataset.convert_calendar Dataset.interp_calendar Dataset.get_index + Dataset.as_array_type + Dataset.is_array_type Comparisons ----------- @@ -315,6 +317,8 @@ DataArray contents DataArray.get_index DataArray.astype DataArray.item + DataArray.as_array_type + DataArray.is_array_type Indexing -------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3a04467d483..8084cc17780 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,6 +64,10 @@ New Features underlying array's backend. Provides better support for certain wrapped array types like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`). By `Sam Levang `_. +- Make more xarray methods fully compatible with duck array types, and introduce new + ``as_array_type`` and ``is_array_type`` methods for converting wrapped data to other + duck array types. (:issue:`7848`, :pull:`9798`). + By `Sam Levang `_. Breaking changes ~~~~~~~~~~~~~~~~ From 1fe41316b2063db6b131828c0fcd26e1d3926abc Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 21 Nov 2024 10:19:28 -0500 Subject: [PATCH 14/24] update comments --- xarray/core/array_api_compat.py | 3 ++- xarray/core/dataarray.py | 2 +- xarray/tests/test_duck_array_wrapping.py | 27 ++++++++++++------------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index e7424325de8..e1e5d5c5bdc 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -51,7 +51,8 @@ def _get_single_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() elif isinstance(x, array_type("cupy")): - # special case cupy for now + # cupy is fully compliant from xarray's perspective, but will not expose + # __array_namespace__ until at least v14. Special case it for now import cupy as cp return cp diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7796904d897..ead13663cb8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -844,7 +844,7 @@ def as_numpy(self) -> Self: def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ - Coerces wrapped data into a specific array type. + Converts wrapped data into a specific array type. `asarray` should output an object that supports the Array API Standard. This method does not convert index coordinates, which can't generally be diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index c58c62bf84b..ffaa1440e1a 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -107,8 +107,9 @@ def test_squeeze(self): result = self.x.squeeze("y") assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="interp is not namespace aware") + @pytest.mark.xfail(reason="interp uses numpy and scipy") def test_interp(self): + # TODO: some cases could be made to work result = self.x.interp(x=2.5) assert isinstance(result.data, self.Array) @@ -132,17 +133,17 @@ def test_fillna(self): result = self.x.fillna(0) assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="ffill is not namespace aware") + @pytest.mark.xfail(reason="ffill uses bottleneck or numbagg") def test_ffill(self): result = self.x.ffill() assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="bfill is not namespace aware") + @pytest.mark.xfail(reason="bfill uses bottleneck or numbagg") def test_bfill(self): result = self.x.bfill() assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="interpolate_na is not namespace aware") + @pytest.mark.xfail(reason="interpolate_na uses numpy and scipy") def test_interpolate_na(self): result = self.x.interpolate_na() assert isinstance(result.data, self.Array) @@ -165,7 +166,7 @@ def test_rolling(self): result = self.x.rolling(x=3).mean() assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="rolling_exp is not namespace aware") + @pytest.mark.xfail(reason="rolling_exp uses numbagg") def test_rolling_exp(self): result = self.x.rolling_exp(x=3).mean() assert isinstance(result.data, self.Array) @@ -199,17 +200,18 @@ def test_quantile(self, skipna): assert isinstance(result.data, self.Array) def test_differentiate(self): - if self.xp is jnp: - pytest.xfail("edge_order kwarg") - result = self.x.differentiate("x") + # edge_order is not implemented in jax, and only supports passing None + edge_order = None if self.xp is jnp else 1 + result = self.x.differentiate("x", edge_order=edge_order) assert isinstance(result.data, self.Array) def test_integrate(self): result = self.x.integrate("x") assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="polyfit is not namespace aware") + @pytest.mark.xfail(reason="polyfit uses numpy linalg") def test_polyfit(self): + # TODO: this could work, there are just a lot of different linalg calls result = self.x.polyfit("x", 1) assert isinstance(result.polyfit_coefficients.data, self.Array) @@ -277,15 +279,11 @@ def test_sum(self, skipna): @pytest.mark.parametrize("skipna", [True, False]) def test_std(self, skipna): - if self.xp is cp and not skipna: - pytest.xfail("ddof/correction kwarg mismatch") result = self.x.std(dim="x", skipna=skipna) assert isinstance(result.data, self.Array) @pytest.mark.parametrize("skipna", [True, False]) def test_var(self, skipna): - if self.xp is cp and not skipna: - pytest.xfail("ddof/correction kwarg mismatch") result = self.x.var(dim="x", skipna=skipna) assert isinstance(result.data, self.Array) @@ -335,8 +333,9 @@ def test_T(self): result = self.x.T assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="rank is not namespace aware") + @pytest.mark.xfail(reason="rank uses bottleneck") def test_rank(self): + # TODO: scipy has rankdata, as does jax, so this can work result = self.x.rank() assert isinstance(result.data, self.Array) From 205c1995703d0f259cfce907337558c6256c0a43 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 21 Nov 2024 10:57:26 -0500 Subject: [PATCH 15/24] add support for chunked arrays in as_array_type --- xarray/core/dataarray.py | 2 ++ xarray/core/dataset.py | 2 ++ xarray/namedarray/core.py | 14 +++++++++++--- xarray/tests/test_dataarray.py | 19 +++++++++++++++---- xarray/tests/test_dataset.py | 24 ++++++++++++++++++++---- 5 files changed, 50 insertions(+), 11 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ead13663cb8..021e9d85474 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -846,6 +846,8 @@ def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ Converts wrapped data into a specific array type. + If the data is a chunked array, the conversion is applied to each block. + `asarray` should output an object that supports the Array API Standard. This method does not convert index coordinates, which can't generally be represented as arbitrary array types. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 32ea1b98308..038d503e682 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1439,6 +1439,8 @@ def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ Converts wrapped data into a specific array type. + If the data is a chunked array, the conversion is applied to each block. + `asarray` should output an object that supports the Array API Standard. This method does not convert index coordinates, which can't generally be represented as arbitrary array types. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 2558ecec9c7..ab4b3bc1820 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -40,8 +40,8 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.parallelcompat import guess_chunkmanager -from xarray.namedarray.pycompat import to_numpy +from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager +from xarray.namedarray.pycompat import is_chunked_array, to_numpy from xarray.namedarray.utils import ( either_dict_or_kwargs, infix_dims, @@ -867,6 +867,8 @@ def as_array_type( ) -> Self: """Converts wrapped data into a specific array type. + If the data is a chunked array, the conversion is applied to each block. + Parameters ---------- asarray : callable @@ -879,7 +881,13 @@ def as_array_type( array : NamedArray Array with the same data, but converted into a specific array type """ - return self._replace(data=asarray(self._data, **kwargs)) + if is_chunked_array(self._data): + chunkmanager = get_chunked_array_type(self._data) + new_data = chunkmanager.map_blocks(asarray, self._data, **kwargs) + else: + new_data = asarray(self._data, **kwargs) + + return self._replace(data=new_data) def is_array_type(self, array_type: type) -> bool: """Check if the data is an instance of a specific array type. diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8bc63f3bf4b..b4af9d37e35 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -7171,16 +7171,27 @@ def test_as_array_type_is_array_type() -> None: assert da.is_array_type(np.ndarray) - def as_duck_array(arr): - return DuckArrayWrapper(arr) - - result = da.as_array_type(as_duck_array) + result = da.as_array_type(lambda x: DuckArrayWrapper(x)) assert isinstance(result.data, DuckArrayWrapper) assert isinstance(result.x.data, np.ndarray) assert result.is_array_type(DuckArrayWrapper) +@requires_dask +def test_as_array_type_dask() -> None: + import dask.array + + da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}).chunk() + + result = da.as_array_type(lambda x: DuckArrayWrapper(x)) + + assert isinstance(result.data, dask.array.Array) + assert isinstance(result.data._meta, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(dask.array.Array) + + class TestStackEllipsis: # https://github.com/pydata/xarray/issues/6051 def test_result_as_expected(self) -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index edca2a02c93..b8dbcabf3ce 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7646,10 +7646,7 @@ def test_as_array_type_is_array_type() -> None: # lat is a PandasIndex here assert ds.drop_vars("lat").is_array_type(np.ndarray) - def as_duck_array(arr): - return DuckArrayWrapper(arr) - - result = ds.as_array_type(as_duck_array) + result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) assert isinstance(result.a.data, DuckArrayWrapper) assert isinstance(result.lat.data, DuckArrayWrapper) @@ -7657,6 +7654,25 @@ def as_duck_array(arr): assert result.is_array_type(DuckArrayWrapper) +@requires_dask +def test_as_array_type_dask() -> None: + import dask.array + + ds = xr.Dataset( + {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} + ).chunk() + + assert ds.is_array_type(dask.array.Array) + + result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) + + assert isinstance(result.a.data, dask.array.Array) + assert isinstance(result.a.data._meta, DuckArrayWrapper) + assert isinstance(result.lat.data, dask.array.Array) + assert isinstance(result.lat.data._meta, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + + def test_string_keys_typing() -> None: """Tests that string keys to `variables` are permitted by mypy""" From c8d4e5ec713358f05a0def3789b38f778e346ad5 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 22 Nov 2024 14:19:25 -0500 Subject: [PATCH 16/24] revert array_type methods --- doc/api.rst | 4 --- doc/whats-new.rst | 5 ++-- xarray/core/dataarray.py | 40 ---------------------------- xarray/core/dataset.py | 48 ---------------------------------- xarray/namedarray/core.py | 47 ++------------------------------- xarray/tests/test_dataarray.py | 27 ------------------- xarray/tests/test_dataset.py | 34 ------------------------ 7 files changed, 4 insertions(+), 201 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 7a596fdaa2d..85ef46ca6ba 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -117,8 +117,6 @@ Dataset contents Dataset.convert_calendar Dataset.interp_calendar Dataset.get_index - Dataset.as_array_type - Dataset.is_array_type Comparisons ----------- @@ -317,8 +315,6 @@ DataArray contents DataArray.get_index DataArray.astype DataArray.item - DataArray.as_array_type - DataArray.is_array_type Indexing -------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8084cc17780..3801075a310 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,9 +64,8 @@ New Features underlying array's backend. Provides better support for certain wrapped array types like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`). By `Sam Levang `_. -- Make more xarray methods fully compatible with duck array types, and introduce new - ``as_array_type`` and ``is_array_type`` methods for converting wrapped data to other - duck array types. (:issue:`7848`, :pull:`9798`). +- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized + duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). By `Sam Levang `_. Breaking changes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e45aaac5836..eae11c0c491 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -844,46 +844,6 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) - def as_array_type(self, asarray: Callable, **kwargs) -> Self: - """ - Converts wrapped data into a specific array type. - - If the data is a chunked array, the conversion is applied to each block. - - `asarray` should output an object that supports the Array API Standard. - This method does not convert index coordinates, which can't generally be - represented as arbitrary array types. - - Parameters - ---------- - asarray : Callable - Function that converts an array-like object to the desired array type. - For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, - or any `from_dlpack` method. - **kwargs : dict - Additional keyword arguments passed to the `asarray` function. - - Returns - ------- - DataArray - """ - return self._replace(self.variable.as_array_type(asarray, **kwargs)) - - def is_array_type(self, array_type: type) -> bool: - """ - Check if the wrapped data is of a specific array type. - - Parameters - ---------- - array_type : type - The array type to check for. - - Returns - ------- - bool - """ - return self.variable.is_array_type(array_type) - @property def _in_memory(self) -> bool: return self.variable._in_memory diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dd2df1c77c1..b305e4b51de 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1437,54 +1437,6 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def as_array_type(self, asarray: Callable, **kwargs) -> Self: - """ - Converts wrapped data into a specific array type. - - If the data is a chunked array, the conversion is applied to each block. - - `asarray` should output an object that supports the Array API Standard. - This method does not convert index coordinates, which can't generally be - represented as arbitrary array types. - - Parameters - ---------- - asarray : Callable - Function that converts an array-like object to the desired array type. - For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, - or any `from_dlpack` method. - **kwargs : dict - Additional keyword arguments passed to the `asarray` function. - - Returns - ------- - Dataset - """ - array_variables = { - k: v.as_array_type(asarray, **kwargs) if k not in self._indexes else v - for k, v in self.variables.items() - } - return self._replace(variables=array_variables) - - def is_array_type(self, array_type: type) -> bool: - """ - Check if all data variables and non-index coordinates are of a specific array type. - - Parameters - ---------- - array_type : type - The array type to check for. - - Returns - ------- - bool - """ - return all( - v.is_array_type(array_type) - for k, v in self.variables.items() - if k not in self._indexes - ) - def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index ab4b3bc1820..98d96c73e91 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -40,8 +40,8 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager -from xarray.namedarray.pycompat import is_chunked_array, to_numpy +from xarray.namedarray.parallelcompat import guess_chunkmanager +from xarray.namedarray.pycompat import to_numpy from xarray.namedarray.utils import ( either_dict_or_kwargs, infix_dims, @@ -860,49 +860,6 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def as_array_type( - self, - asarray: Callable[[duckarray[Any, _DType_co]], duckarray[Any, _DType_co]], - **kwargs: Any, - ) -> Self: - """Converts wrapped data into a specific array type. - - If the data is a chunked array, the conversion is applied to each block. - - Parameters - ---------- - asarray : callable - Function that converts the data into a specific array type. - **kwargs : dict - Additional keyword arguments passed on to `asarray`. - - Returns - ------- - array : NamedArray - Array with the same data, but converted into a specific array type - """ - if is_chunked_array(self._data): - chunkmanager = get_chunked_array_type(self._data) - new_data = chunkmanager.map_blocks(asarray, self._data, **kwargs) - else: - new_data = asarray(self._data, **kwargs) - - return self._replace(data=new_data) - - def is_array_type(self, array_type: type) -> bool: - """Check if the data is an instance of a specific array type. - - Parameters - ---------- - array_type : type - Array type to check against. - - Returns - ------- - is_array_type : bool - """ - return isinstance(self._data, array_type) - def reduce( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b4af9d37e35..c8b438948de 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -39,7 +39,6 @@ from xarray.core.utils import is_scalar from xarray.testing import _assert_internal_invariants from xarray.tests import ( - DuckArrayWrapper, InaccessibleArray, ReturnItem, assert_allclose, @@ -7166,32 +7165,6 @@ def test_from_pint_wrapping_dask(self) -> None: np.testing.assert_equal(da.to_numpy(), arr) -def test_as_array_type_is_array_type() -> None: - da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}) - - assert da.is_array_type(np.ndarray) - - result = da.as_array_type(lambda x: DuckArrayWrapper(x)) - - assert isinstance(result.data, DuckArrayWrapper) - assert isinstance(result.x.data, np.ndarray) - assert result.is_array_type(DuckArrayWrapper) - - -@requires_dask -def test_as_array_type_dask() -> None: - import dask.array - - da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}).chunk() - - result = da.as_array_type(lambda x: DuckArrayWrapper(x)) - - assert isinstance(result.data, dask.array.Array) - assert isinstance(result.data._meta, DuckArrayWrapper) - assert isinstance(result.x.data, np.ndarray) - assert result.is_array_type(dask.array.Array) - - class TestStackEllipsis: # https://github.com/pydata/xarray/issues/6051 def test_result_as_expected(self) -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b8dbcabf3ce..67d38aac0fe 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7639,40 +7639,6 @@ def test_from_pint_wrapping_dask(self) -> None: assert_identical(result, expected) -def test_as_array_type_is_array_type() -> None: - ds = xr.Dataset( - {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} - ) - # lat is a PandasIndex here - assert ds.drop_vars("lat").is_array_type(np.ndarray) - - result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) - - assert isinstance(result.a.data, DuckArrayWrapper) - assert isinstance(result.lat.data, DuckArrayWrapper) - assert isinstance(result.x.data, np.ndarray) - assert result.is_array_type(DuckArrayWrapper) - - -@requires_dask -def test_as_array_type_dask() -> None: - import dask.array - - ds = xr.Dataset( - {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} - ).chunk() - - assert ds.is_array_type(dask.array.Array) - - result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) - - assert isinstance(result.a.data, dask.array.Array) - assert isinstance(result.a.data._meta, DuckArrayWrapper) - assert isinstance(result.lat.data, dask.array.Array) - assert isinstance(result.lat.data._meta, DuckArrayWrapper) - assert isinstance(result.x.data, np.ndarray) - - def test_string_keys_typing() -> None: """Tests that string keys to `variables` are permitted by mypy""" From f306768fe78d4751e6f264ff992dff09e20453a8 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 22 Nov 2024 14:20:21 -0500 Subject: [PATCH 17/24] fix up whats new --- doc/whats-new.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ccaac9e7263..e1fb12269ed 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,11 +64,11 @@ New Features underlying array's backend. Provides better support for certain wrapped array types like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`). By `Sam Levang `_. +- Speed up loading of large zarr stores using dask arrays. (:issue:`8902`) + By `Deepak Cherian `_. - Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). By `Sam Levang `_. -- Speed up loading of large zarr stores using dask arrays. (:issue:`8902`) - By `Deepak Cherian `_. Breaking changes ~~~~~~~~~~~~~~~~ From 18ebdcdb29bda39395d254be4f7cb3c3f88b6e16 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 22 Nov 2024 17:22:06 -0500 Subject: [PATCH 18/24] comment about bool_ --- xarray/core/duck_array_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 59f3da2c8f7..7e7333fd8ea 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -166,6 +166,7 @@ def isnull(data): ) ): # these types cannot represent missing values + # bool_ is for backwards compat with numpy<2, and cupy dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool return full_like(data, dtype=dtype, fill_value=False) else: From 121af9e5b1a12d2759a2f846dd7207afa3100bcb Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 23 Nov 2024 11:09:25 -0500 Subject: [PATCH 19/24] add jax to complete ci envs --- ci/requirements/environment-3.13.yml | 2 ++ ci/requirements/environment-windows-3.13.yml | 2 ++ ci/requirements/environment-windows.yml | 2 ++ ci/requirements/environment.yml | 2 ++ 4 files changed, 8 insertions(+) diff --git a/ci/requirements/environment-3.13.yml b/ci/requirements/environment-3.13.yml index dbb446f4454..937cb013711 100644 --- a/ci/requirements/environment-3.13.yml +++ b/ci/requirements/environment-3.13.yml @@ -47,3 +47,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment-windows-3.13.yml b/ci/requirements/environment-windows-3.13.yml index 448e3f70c0c..0d32fd13a96 100644 --- a/ci/requirements/environment-windows-3.13.yml +++ b/ci/requirements/environment-windows-3.13.yml @@ -42,3 +42,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 3b2e6dc62e6..a9a53d0c1b1 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -42,3 +42,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 43938880592..364ae03666f 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -49,3 +49,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present From 472ae7e7e1fc499adc3598511e96475e8d7ab045 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 23 Nov 2024 11:10:17 -0500 Subject: [PATCH 20/24] add pint and sparse to tests --- xarray/core/common.py | 3 +- xarray/tests/test_duck_array_wrapping.py | 154 +++++++++++++++++------ 2 files changed, 118 insertions(+), 39 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 8aaa153c1a8..32135996d3c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1760,8 +1760,7 @@ def _full_like_variable( **from_array_kwargs, ) else: - xp = duck_array_ops.get_array_namespace(other.data) - data = xp.full_like(other.data, fill_value, dtype=dtype) + data = duck_array_ops.full_like(other.data, fill_value, dtype=dtype) return Variable(dims=other.dims, data=data, attrs=other.attrs) diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index ffaa1440e1a..05c0ab68bea 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -4,30 +4,107 @@ import xarray as xr -# TODO: how to test these in CI? -jnp = pytest.importorskip("jax.numpy") -cp = pytest.importorskip("cupy") - -NAMESPACES = [cp, jnp] - - -def get_test_dataarray(xp): - return xr.DataArray( - xp.asarray([[1, 2, 3, np.nan, 5]]), - dims=["y", "x"], - coords={"y": [1], "x": np.arange(5)}, - name="foo", - ) - - -@pytest.mark.parametrize("xp", NAMESPACES) -class TestTopLevelMethods: +# Don't run cupy in CI because it requires a GPU +NAMESPACE_ARRAYS = { + "jax.numpy": { + "array": "ndarray", + "constructor": "asarray", + "xfails": { + "rolling": "no sliding_window_view", + "rolling_mean": "no sliding_window_view", + }, + }, + "cupy": { + "array": "ndarray", + "constructor": "asarray", + "xfails": {"quantile": "no nanquantile"}, + }, + "pint": { + "array": "Quantity", + "constructor": "Quantity", + "xfails": { + "all": "returns a bool", + "any": "returns a bool", + "argmax": "returns an int", + "argmin": "returns an int", + "argsort": "returns an int", + "count": "returns an int", + "dot": "no tensordot", + "full_like": "should work, see: https://github.com/hgrecco/pint/pull/1669", + "idxmax": "returns the coordinate", + "idxmin": "returns the coordinate", + "isin": "returns a bool", + "isnull": "returns a bool", + "notnull": "returns a bool", + "rolling_mean": "no dispatch for numbagg/bottleneck", + "searchsorted": "returns an int", + "weighted": "no tensordot", + }, + }, + "sparse": { + "array": "COO", + "constructor": "COO", + "xfails": { + "cov": "dense output", + "corr": "no nanstd", + "cross": "no cross", + "count": "dense output", + "isin": "no isin", + "rolling": "no sliding_window_view", + "rolling_mean": "no sliding_window_view", + "weighted": "fill_value error", + "coarsen": "pad constant_values must be fill_value", + "quantile": "no non skipping version", + "differentiate": "no gradient", + "argmax": "no nan skipping version", + "argmin": "no nan skipping version", + "idxmax": "no nan skipping version", + "idxmin": "no nan skipping version", + "median": "no nan skipping version", + "std": "no nan skipping version", + "var": "no nan skipping version", + "cumsum": "no cumsum", + "cumprod": "no cumprod", + "argsort": "no argsort", + "conjugate": "no conjugate", + "searchsorted": "no searchsorted", + "shift": "pad constant_values must be fill_value", + "pad": "pad constant_values must be fill_value", + }, + }, +} + + +class _BaseTest: + def setup_for_test(self, request, namespace): + self.namespace = namespace + self.xp = pytest.importorskip(namespace) + self.Array = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["array"]) + self.constructor = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["constructor"]) + xarray_method = request.node.name.split("test_")[1].split("[")[0] + if xarray_method in NAMESPACE_ARRAYS[namespace]["xfails"]: + reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method] + pytest.xfail(f"xfail for {self.namespace}: {reason}") + + def get_test_dataarray(self): + data = np.asarray([[1, 2, 3, np.nan, 5]]) + x = np.arange(5) + data = self.constructor(data) + return xr.DataArray( + data, + dims=["y", "x"], + coords={"y": [1], "x": x}, + name="foo", + ) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestTopLevelMethods(_BaseTest): @pytest.fixture(autouse=True) - def setUp(self, xp): - self.xp = xp - self.Array = xp.ndarray - self.x1 = get_test_dataarray(xp) - self.x2 = get_test_dataarray(xp).assign_coords(x=np.arange(2, 7)) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x1 = self.get_test_dataarray() + self.x2 = self.get_test_dataarray().assign_coords(x=np.arange(2, 7)) def test_apply_ufunc(self): func = lambda x: x + 1 @@ -83,13 +160,12 @@ def test_map_blocks(self): assert isinstance(result.data, self.Array) -@pytest.mark.parametrize("xp", NAMESPACES) -class TestDataArrayMethods: +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestDataArrayMethods(_BaseTest): @pytest.fixture(autouse=True) - def setUp(self, xp): - self.xp = xp - self.Array = xp.ndarray - self.x = get_test_dataarray(xp) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x = self.get_test_dataarray() def test_loc(self): result = self.x.loc[{"x": slice(1, 3)}] @@ -153,7 +229,8 @@ def test_where(self): assert isinstance(result.data, self.Array) def test_isin(self): - result = self.x.isin(self.xp.asarray([1])) + test_elements = self.constructor(np.asarray([1])) + result = self.x.isin(test_elements) assert isinstance(result.data, self.Array) def test_groupby(self): @@ -161,9 +238,13 @@ def test_groupby(self): assert isinstance(result.data, self.Array) def test_rolling(self): - if self.xp is jnp: - pytest.xfail("no sliding_window_view in jax") - result = self.x.rolling(x=3).mean() + result = self.x.rolling(x=3) + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_rolling_mean(self, skipna): + result = self.x.rolling(x=3).mean(skipna=skipna) assert isinstance(result.data, self.Array) @pytest.mark.xfail(reason="rolling_exp uses numbagg") @@ -194,14 +275,12 @@ def test_dot(self): @pytest.mark.parametrize("skipna", [True, False]) def test_quantile(self, skipna): - if self.xp is cp and skipna: - pytest.xfail("no nanquantile in cupy") result = self.x.quantile(0.5, skipna=skipna) assert isinstance(result.data, self.Array) def test_differentiate(self): # edge_order is not implemented in jax, and only supports passing None - edge_order = None if self.xp is jnp else 1 + edge_order = None if self.namespace == "jax.numpy" else 1 result = self.x.differentiate("x", edge_order=edge_order) assert isinstance(result.data, self.Array) @@ -318,7 +397,8 @@ def test_imag(self): assert isinstance(result.data, self.Array) def test_searchsorted(self): - result = self.x.squeeze().searchsorted(self.xp.asarray(3)) + v = self.constructor(np.asarray([3])) + result = self.x.squeeze().searchsorted(v) assert isinstance(result, self.Array) def test_round(self): From 5aa4a392b314544750ce0395395492b02dbddae6 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 23 Nov 2024 11:19:25 -0500 Subject: [PATCH 21/24] remove from windows --- ci/requirements/environment-windows-3.13.yml | 2 -- ci/requirements/environment-windows.yml | 2 -- 2 files changed, 4 deletions(-) diff --git a/ci/requirements/environment-windows-3.13.yml b/ci/requirements/environment-windows-3.13.yml index 0d32fd13a96..448e3f70c0c 100644 --- a/ci/requirements/environment-windows-3.13.yml +++ b/ci/requirements/environment-windows-3.13.yml @@ -42,5 +42,3 @@ dependencies: - toolz - typing_extensions - zarr - - pip: - - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index a9a53d0c1b1..3b2e6dc62e6 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -42,5 +42,3 @@ dependencies: - toolz - typing_extensions - zarr - - pip: - - jax # no way to get cpu-only jaxlib from conda if gpu is present From 390df6f7715b46d557cb64fd32a21e6567e64e21 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 23 Nov 2024 12:40:31 -0500 Subject: [PATCH 22/24] mypy, xfail one more sparse --- xarray/tests/test_duck_array_wrapping.py | 31 ++++++++++++++++-------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index 05c0ab68bea..63413aba1a3 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -7,21 +7,27 @@ # Don't run cupy in CI because it requires a GPU NAMESPACE_ARRAYS = { "jax.numpy": { - "array": "ndarray", - "constructor": "asarray", + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, "xfails": { "rolling": "no sliding_window_view", "rolling_mean": "no sliding_window_view", }, }, "cupy": { - "array": "ndarray", - "constructor": "asarray", + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, "xfails": {"quantile": "no nanquantile"}, }, "pint": { - "array": "Quantity", - "constructor": "Quantity", + "attrs": { + "array": "Quantity", + "constructor": "Quantity", + }, "xfails": { "all": "returns a bool", "any": "returns a bool", @@ -42,13 +48,16 @@ }, }, "sparse": { - "array": "COO", - "constructor": "COO", + "attrs": { + "array": "COO", + "constructor": "COO", + }, "xfails": { "cov": "dense output", "corr": "no nanstd", "cross": "no cross", "count": "dense output", + "dot": "fails on some platforms/versions", "isin": "no isin", "rolling": "no sliding_window_view", "rolling_mean": "no sliding_window_view", @@ -79,8 +88,10 @@ class _BaseTest: def setup_for_test(self, request, namespace): self.namespace = namespace self.xp = pytest.importorskip(namespace) - self.Array = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["array"]) - self.constructor = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["constructor"]) + self.Array = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["array"]) + self.constructor = getattr( + self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["constructor"] + ) xarray_method = request.node.name.split("test_")[1].split("[")[0] if xarray_method in NAMESPACE_ARRAYS[namespace]["xfails"]: reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method] From f6074d2fa3b9c2d3900cda25c4cb322f2c698bd1 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 25 Nov 2024 10:01:59 -0500 Subject: [PATCH 23/24] add dask and a few other methods --- xarray/tests/test_duck_array_wrapping.py | 73 +++++++++++++++++++----- 1 file changed, 60 insertions(+), 13 deletions(-) diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index 63413aba1a3..59928dce370 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -6,22 +6,35 @@ # Don't run cupy in CI because it requires a GPU NAMESPACE_ARRAYS = { - "jax.numpy": { + "cupy": { "attrs": { "array": "ndarray", "constructor": "asarray", }, + "xfails": {"quantile": "no nanquantile"}, + }, + "dask.array": { + "attrs": { + "array": "Array", + "constructor": "from_array", + }, "xfails": { - "rolling": "no sliding_window_view", - "rolling_mean": "no sliding_window_view", + "argsort": "no argsort", + "conjugate": "conj but no conjugate", + "searchsorted": "dask.array.searchsorted but no Array.searchsorted", }, }, - "cupy": { + "jax.numpy": { "attrs": { "array": "ndarray", "constructor": "asarray", }, - "xfails": {"quantile": "no nanquantile"}, + "xfails": { + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + }, }, "pint": { "attrs": { @@ -42,7 +55,8 @@ "isin": "returns a bool", "isnull": "returns a bool", "notnull": "returns a bool", - "rolling_mean": "no dispatch for numbagg/bottleneck", + "rolling_reduce": "no dispatch for numbagg/bottleneck", + "cumulative_reduce": "no dispatch for numbagg/bottleneck", "searchsorted": "returns an int", "weighted": "no tensordot", }, @@ -59,8 +73,12 @@ "count": "dense output", "dot": "fails on some platforms/versions", "isin": "no isin", - "rolling": "no sliding_window_view", - "rolling_mean": "no sliding_window_view", + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + "coarsen_construct": "pad constant_values must be fill_value", + "coarsen_reduce": "pad constant_values must be fill_value", "weighted": "fill_value error", "coarsen": "pad constant_values must be fill_value", "quantile": "no non skipping version", @@ -119,7 +137,7 @@ def setUp(self, request, namespace): def test_apply_ufunc(self): func = lambda x: x + 1 - result = xr.apply_ufunc(func, self.x1) + result = xr.apply_ufunc(func, self.x1, dask="parallelized") assert isinstance(result.data, self.Array) def test_align(self): @@ -248,26 +266,51 @@ def test_groupby(self): result = self.x.groupby("x").mean() assert isinstance(result.data, self.Array) - def test_rolling(self): + def test_groupby_bins(self): + result = self.x.groupby_bins("x", bins=[0, 2, 4, 6]).mean() + assert isinstance(result.data, self.Array) + + def test_rolling_iter(self): result = self.x.rolling(x=3) elem = next(iter(result))[1] assert isinstance(elem.data, self.Array) + def test_rolling_construct(self): + result = self.x.rolling(x=3).construct(x="window") + assert isinstance(result.data, self.Array) + @pytest.mark.parametrize("skipna", [True, False]) - def test_rolling_mean(self, skipna): + def test_rolling_reduce(self, skipna): result = self.x.rolling(x=3).mean(skipna=skipna) assert isinstance(result.data, self.Array) @pytest.mark.xfail(reason="rolling_exp uses numbagg") - def test_rolling_exp(self): + def test_rolling_exp_reduce(self): result = self.x.rolling_exp(x=3).mean() assert isinstance(result.data, self.Array) + def test_cumulative_iter(self): + result = self.x.cumulative("x") + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + def test_cumulative_construct(self): + result = self.x.cumulative("x").construct(x="window") + assert isinstance(result.data, self.Array) + + def test_cumulative_reduce(self): + result = self.x.cumulative("x").sum() + assert isinstance(result.data, self.Array) + def test_weighted(self): result = self.x.weighted(self.x.fillna(0)).mean() assert isinstance(result.data, self.Array) - def test_coarsen(self): + def test_coarsen_construct(self): + result = self.x.coarsen(x=2, boundary="pad").construct(x=["a", "b"]) + assert isinstance(result.data, self.Array) + + def test_coarsen_reduce(self): result = self.x.coarsen(x=2, boundary="pad").mean() assert isinstance(result.data, self.Array) @@ -391,6 +434,10 @@ def test_argsort(self): result = self.x.argsort() assert isinstance(result.data, self.Array) + def test_astype(self): + result = self.x.astype(int) + assert isinstance(result.data, self.Array) + def test_clip(self): result = self.x.clip(min=2.0, max=4.0) assert isinstance(result.data, self.Array) From bfd6aebbb0cb2f91f86a71a82e7593ae2b9365e3 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 25 Nov 2024 10:03:45 -0500 Subject: [PATCH 24/24] move whats new --- doc/whats-new.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ab1cba8d9a6..906fd0a25b2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v.2024.11.1 (unreleased) New Features ~~~~~~~~~~~~ +- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized + duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). + By `Sam Levang `_. Breaking changes @@ -85,9 +88,6 @@ New Features By `Sam Levang `_. - Speed up loading of large zarr stores using dask arrays. (:issue:`8902`) By `Deepak Cherian `_. -- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized - duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). - By `Sam Levang `_. Breaking Changes ~~~~~~~~~~~~~~~~ pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy