diff --git a/ci/requirements/environment-3.13.yml b/ci/requirements/environment-3.13.yml index dbb446f4454..937cb013711 100644 --- a/ci/requirements/environment-3.13.yml +++ b/ci/requirements/environment-3.13.yml @@ -47,3 +47,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 43938880592..364ae03666f 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -49,3 +49,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0da34df2c1a..906fd0a25b2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v.2024.11.1 (unreleased) New Features ~~~~~~~~~~~~ +- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized + duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). + By `Sam Levang <https://github.com/slevang>`_. Breaking changes diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index da072de5b69..e1e5d5c5bdc 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -1,5 +1,7 @@ import numpy as np +from xarray.namedarray.pycompat import array_type + def is_weak_scalar_type(t): return isinstance(t, bool | int | float | complex | str | bytes) @@ -42,3 +44,39 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: return xp.result_type(*arrays_and_dtypes) else: return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) + + +def get_array_namespace(*values): + def _get_single_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + elif isinstance(x, array_type("cupy")): + # cupy is fully compliant from xarray's perspective, but will not expose + # __array_namespace__ until at least v14.
Special case it for now + import cupy as cp + + return cp + else: + return np + + namespaces = {_get_single_namespace(t) for t in values} + non_numpy = namespaces - {np} + + if len(non_numpy) > 1: + names = [module.__name__ for module in non_numpy] + raise TypeError(f"Mixed array types {names} are not supported.") + elif non_numpy: + [xp] = non_numpy + else: + xp = np + + return xp + + +def to_like_array(array, like): + # Mostly for cupy compatibility, because cupy binary ops require all cupy arrays + xp = get_array_namespace(like) + if xp is not np: + return xp.asarray(array) + # avoid casting things like pint quantities to numpy arrays + return array diff --git a/xarray/core/common.py b/xarray/core/common.py index 6f788f408d0..32135996d3c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -496,7 +496,7 @@ def clip( keep_attrs = _get_keep_attrs(default=True) return apply_ufunc( - np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" + duck_array_ops.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" ) def get_index(self, key: Hashable) -> pd.Index: @@ -1760,7 +1760,7 @@ def _full_like_variable( **from_array_kwargs, ) else: - data = np.full_like(other.data, fill_value, dtype=dtype) + data = duck_array_ops.full_like(other.data, fill_value, dtype=dtype) return Variable(dims=other.dims, data=data, attrs=other.attrs) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index b15ed7f3f34..6e233425e95 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -24,6 +24,7 @@ from xarray.core import dtypes, duck_array_ops, utils from xarray.core.alignment import align, deep_align +from xarray.core.array_api_compat import to_like_array from xarray.core.common import zeros_like from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.formatting import limit_lines @@ -1702,7 +1703,7 @@ def cross( ) c = apply_ufunc( - np.cross, + duck_array_ops.cross, a, b, input_core_dims=[[dim], [dim]], @@ 
-2170,13 +2171,14 @@ def _calc_idxminmax( chunks = dict(zip(array.dims, array.chunks, strict=True)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) data = dask_coord[duck_array_ops.ravel(indx.data)] - res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) - # we need to attach back the dim name - res.name = dim else: - res = array[dim][(indx,)] - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] + arr_coord = to_like_array(array[dim].data, array.data) + data = arr_coord[duck_array_ops.ravel(indx.data)] + + # rebuild like the argmin/max output, and rename as the dim name + data = duck_array_ops.reshape(data, indx.shape) + res = indx.copy(data=data) + res.name = dim if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e80ce5fa64a..ce8f93a37e5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -55,6 +55,7 @@ align, ) from xarray.core.arithmetic import DatasetArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import ( DataWithCoords, _contains_datetime_like_objects, @@ -127,7 +128,7 @@ calculate_dimensions, ) from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager -from xarray.namedarray.pycompat import array_type, is_chunked_array +from xarray.namedarray.pycompat import array_type, is_chunked_array, to_numpy from xarray.plot.accessor import DatasetPlotAccessor from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims @@ -6622,7 +6623,7 @@ def dropna( array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) + count += to_numpy(array.count(dims).data) size += math.prod([self.sizes[d] for d in dims]) if thresh is not None: @@ -8736,16 +8737,17 @@ def _integrate_one(self, coord, 
datetime_unit=None, cumulative=False): coord_names.add(k) else: if k in self.data_vars and dim in v.dims: + coord_data = to_like_array(coord_var.data, like=v.data) if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: integ = duck_array_ops.cumulative_trapezoid( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = v.dims else: integ = duck_array_ops.trapz( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = list(v.dims) v_dims.remove(dim) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 0b915166279..7e7333fd8ea 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -18,21 +18,16 @@ import pandas as pd from numpy import all as array_all # noqa: F401 from numpy import any as array_any # noqa: F401 -from numpy import concatenate as _concatenate from numpy import ( # noqa: F401 - full_like, - gradient, isclose, - isin, isnat, take, - tensordot, - transpose, unravel_index, ) from pandas.api.types import is_extension_array_dtype from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils +from xarray.core.array_api_compat import get_array_namespace from xarray.core.options import OPTIONS from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available from xarray.namedarray.parallelcompat import get_chunked_array_type @@ -52,28 +47,6 @@ dask_available = module_available("dask") -def get_array_namespace(*values): - def _get_array_namespace(x): - if hasattr(x, "__array_namespace__"): - return x.__array_namespace__() - else: - return np - - namespaces = {_get_array_namespace(t) for t in values} - non_numpy = namespaces - {np} - - if len(non_numpy) > 1: - raise TypeError( - "cannot deal with more than one type supporting the array API at the same time" - ) - elif non_numpy: - [xp] = non_numpy - else: - xp = np - - 
return xp - - def einsum(*args, **kwargs): from xarray.core.options import OPTIONS @@ -82,7 +55,23 @@ def einsum(*args, **kwargs): return opt_einsum.contract(*args, **kwargs) else: - return np.einsum(*args, **kwargs) + xp = get_array_namespace(*args) + return xp.einsum(*args, **kwargs) + + +def tensordot(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.tensordot(*args, **kwargs) + + +def cross(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.cross(*args, **kwargs) + + +def gradient(f, *varargs, axis=None, edge_order=1): + xp = get_array_namespace(f) + return xp.gradient(f, *varargs, axis=axis, edge_order=edge_order) def _dask_or_eager_func( @@ -131,15 +120,20 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): "masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" ) -# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), -# so we need to hand-code this. -sliding_window_view = _dask_or_eager_func( - "sliding_window_view", - eager_module=np.lib.stride_tricks, - dask_module=dask_array_compat, - dask_only_kwargs=("automatic_rechunk",), - numpy_only_kwargs=("subok", "writeable"), -) + +def sliding_window_view(array, window_shape, axis=None, **kwargs): + # TODO: some libraries (e.g. jax) don't have this, implement an alternative? + xp = get_array_namespace(array) + # sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), + # so we need to hand-code this. 
+ func = _dask_or_eager_func( + "sliding_window_view", + eager_module=xp.lib.stride_tricks, + dask_module=dask_array_compat, + dask_only_kwargs=("automatic_rechunk",), + numpy_only_kwargs=("subok", "writeable"), + ) + return func(array, window_shape, axis=axis, **kwargs) def round(array): @@ -172,7 +166,9 @@ def isnull(data): ) ): # these types cannot represent missing values - return full_like(data, dtype=bool, fill_value=False) + # bool_ is for backwards compat with numpy<2, and cupy + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + return full_like(data, dtype=dtype, fill_value=False) else: # at this point, array should have dtype=object if isinstance(data, np.ndarray) or is_extension_array_dtype(data): @@ -213,11 +209,23 @@ def cumulative_trapezoid(y, x, axis): # Pad so that 'axis' has same length in result as it did in y pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] - integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0) + + xp = get_array_namespace(y, x) + integrand = xp.pad(integrand, pads, mode="constant", constant_values=0.0) return cumsum(integrand, axis=axis, skipna=False) +def full_like(a, fill_value, **kwargs): + xp = get_array_namespace(a) + return xp.full_like(a, fill_value, **kwargs) + + +def empty_like(a, **kwargs): + xp = get_array_namespace(a) + return xp.empty_like(a, **kwargs) + + def astype(data, dtype, **kwargs): if hasattr(data, "__array_namespace__"): xp = get_array_namespace(data) @@ -348,7 +356,8 @@ def array_notnull_equiv(arr1, arr2): def count(data, axis=None): """Count the number of non-NA in this array along the given axis or axes""" - return np.sum(np.logical_not(isnull(data)), axis=axis) + xp = get_array_namespace(data) + return xp.sum(xp.logical_not(isnull(data)), axis=axis) def sum_where(data, axis=None, dtype=None, where=None): @@ -363,7 +372,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion 
rules.""" - xp = get_array_namespace(condition) + xp = get_array_namespace(condition, x, y) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) @@ -380,15 +389,25 @@ def fillna(data, other): return where(notnull(data), data, other) +def logical_not(data): + xp = get_array_namespace(data) + return xp.logical_not(data) + + +def clip(data, min=None, max=None): + xp = get_array_namespace(data) + return xp.clip(data, min, max) + + def concatenate(arrays, axis=0): """concatenate() with better dtype promotion rules.""" - # TODO: remove the additional check once `numpy` adds `concat` to its array namespace - if hasattr(arrays[0], "__array_namespace__") and not isinstance( - arrays[0], np.ndarray - ): - xp = get_array_namespace(arrays[0]) + # TODO: `concat` is the xp compliant name, but fallback to concatenate for + # older numpy and for cupy + xp = get_array_namespace(*arrays) + if hasattr(xp, "concat"): return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) - return _concatenate(as_shared_dtype(arrays), axis=axis) + else: + return xp.concatenate(as_shared_dtype(arrays, xp=xp), axis=axis) def stack(arrays, axis=0): @@ -406,6 +425,26 @@ def ravel(array): return reshape(array, (-1,)) +def transpose(array, axes=None): + xp = get_array_namespace(array) + return xp.transpose(array, axes) + + +def moveaxis(array, source, destination): + xp = get_array_namespace(array) + return xp.moveaxis(array, source, destination) + + +def pad(array, pad_width, **kwargs): + xp = get_array_namespace(array) + return xp.pad(array, pad_width, **kwargs) + + +def quantile(array, q, axis=None, **kwargs): + xp = get_array_namespace(array) + return xp.quantile(array, q, axis=axis, **kwargs) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -747,6 +786,11 @@ def last(values, axis, skipna=None): return take(values, -1, axis=axis) +def isin(element, test_elements, **kwargs): + xp = get_array_namespace(element, test_elements) + return xp.isin(element, 
test_elements, **kwargs) + + def least_squares(lhs, rhs, rcond=None, skipna=False): """Return the coefficients and residuals of a least-squares fit.""" if is_duck_dask_array(rhs): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 7fbb63068c0..4894cf02be2 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -128,7 +128,7 @@ def nanmean(a, axis=None, dtype=None, out=None): "ignore", r"Mean of empty slice", category=RuntimeWarning ) - return np.nanmean(a, axis=axis, dtype=dtype) + return nputils.nanmean(a, axis=axis, dtype=dtype) def nanmedian(a, axis=None, out=None): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index bf5dfa1bc32..3211ab296e6 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -7,6 +7,7 @@ import pandas as pd from packaging.version import Version +from xarray.core.array_api_compat import get_array_namespace from xarray.core.utils import is_duck_array, module_available from xarray.namedarray import pycompat @@ -179,6 +180,11 @@ def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype") bn_func = getattr(bn, name, None) + xp = get_array_namespace(values) + if xp is not np: + func = getattr(xp, name, None) + if func is not None: + return func(values, axis=axis, **kwargs) if ( module_available("numbagg") and OPTIONS["use_numbagg"] @@ -229,6 +235,9 @@ def f(values, axis=None, **kwargs): # bottleneck does not take care dtype, min_count kwargs.pop("dtype", None) result = bn_func(values, axis=axis, **kwargs) + # bottleneck returns python scalars for reduction over all axes + if isinstance(result, float): + result = np.float64(result) else: result = getattr(npmodule, name)(values, axis=axis, **kwargs) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index cb16c3723ca..fde87841d32 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -708,6 +708,7 @@ def _array_reduce( ) del kwargs["dim"] + xp = duck_array_ops.get_array_namespace(self.obj.data) if ( OPTIONS["use_numbagg"] 
and module_available("numbagg") @@ -722,6 +723,7 @@ def _array_reduce( # TODO: we could also allow this, probably as part of a refactoring of this # module, so we can use the machinery in `self.reduce`. and self.ndim == 1 + and xp is np ): import numbagg @@ -744,6 +746,7 @@ def _array_reduce( or module_available("dask", "2024.11.0") ) and self.ndim == 1 + and xp is np ): return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9f660d0878a..07113d66b5b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -19,6 +19,7 @@ import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import AbstractArray from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( @@ -828,7 +829,7 @@ def __getitem__(self, key) -> Self: data = indexing.apply_indexer(indexable, indexer) if new_order: - data = np.moveaxis(data, range(len(new_order)), new_order) + data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) def _finalize_indexing_result(self, dims, data) -> Self: @@ -866,12 +867,15 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed - data = duck_array_ops.where(np.logical_not(mask), data, fill_value) + mask = to_like_array(mask, data) + data = duck_array_ops.where( + duck_array_ops.logical_not(mask), data, fill_value + ) else: # array cannot be indexed along dimensions of size 0, so just # build the mask directly instead. 
mask = indexing.create_mask(indexer, self.shape) - data = np.broadcast_to(fill_value, getattr(mask, "shape", ())) + data = duck_array_ops.broadcast_to(fill_value, getattr(mask, "shape", ())) if new_order: data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) @@ -902,7 +906,7 @@ def __setitem__(self, key, value): if new_order: value = duck_array_ops.asarray(value) value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)] - value = np.moveaxis(value, new_order, range(len(new_order))) + value = duck_array_ops.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) indexing.set_with_indexer(indexable, index_tuple, value) @@ -1122,7 +1126,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): dim_pad = (width, 0) if count >= 0 else (0, width) pads = [(0, 0) if d != dim else dim_pad for d in self.dims] - data = np.pad( + data = duck_array_ops.pad( duck_array_ops.astype(trimmed_data, dtype), pads, mode="constant", @@ -1268,7 +1272,7 @@ def pad( if reflect_type is not None: pad_option_kwargs["reflect_type"] = reflect_type - array = np.pad( + array = duck_array_ops.pad( duck_array_ops.astype(self.data, dtype, copy=False), pad_width_by_index, mode=mode, @@ -1557,14 +1561,16 @@ def _unstack_once( if is_missing_values: dtype, fill_value = dtypes.maybe_promote(self.dtype) - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial( + duck_array_ops.full_like, fill_value=fill_value + ) else: dtype = self.dtype fill_value = dtypes.get_fill_value(dtype) - create_template = np.empty_like + create_template = duck_array_ops.empty_like else: dtype = self.dtype - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial(duck_array_ops.full_like, fill_value=fill_value) if sparse: # unstacking a dense multitindexed array to a sparse array @@ -1654,7 +1660,8 @@ def clip(self, min=None, max=None): """ from xarray.core.computation import apply_ufunc - 
return apply_ufunc(np.clip, self, min, max, dask="allowed") + xp = duck_array_ops.get_array_namespace(self.data) + return apply_ufunc(xp.clip, self, min, max, dask="allowed") def reduce( # type: ignore[override] self, @@ -1947,7 +1954,7 @@ def quantile( if skipna or (skipna is None and self.dtype.kind in "cfO"): _quantile_func = nputils.nanquantile else: - _quantile_func = np.quantile + _quantile_func = duck_array_ops.quantile if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) @@ -1961,11 +1968,14 @@ def quantile( if utils.is_scalar(dim): dim = [dim] + xp = duck_array_ops.get_array_namespace(self.data) + def _wrapper(npa, **kwargs): # move quantile axis to end. required for apply_ufunc - return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1) + return xp.moveaxis(_quantile_func(npa, **kwargs), 0, -1) - axis = np.arange(-1, -1 * len(dim) - 1, -1) + # jax requires hashable + axis = tuple(range(-1, -1 * len(dim) - 1, -1)) kwargs = {"q": q, "axis": axis, "method": method} diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py new file mode 100644 index 00000000000..59928dce370 --- /dev/null +++ b/xarray/tests/test_duck_array_wrapping.py @@ -0,0 +1,510 @@ +import numpy as np +import pandas as pd +import pytest + +import xarray as xr + +# Don't run cupy in CI because it requires a GPU +NAMESPACE_ARRAYS = { + "cupy": { + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, + "xfails": {"quantile": "no nanquantile"}, + }, + "dask.array": { + "attrs": { + "array": "Array", + "constructor": "from_array", + }, + "xfails": { + "argsort": "no argsort", + "conjugate": "conj but no conjugate", + "searchsorted": "dask.array.searchsorted but no Array.searchsorted", + }, + }, + "jax.numpy": { + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, + "xfails": { + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no 
sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + }, + }, + "pint": { + "attrs": { + "array": "Quantity", + "constructor": "Quantity", + }, + "xfails": { + "all": "returns a bool", + "any": "returns a bool", + "argmax": "returns an int", + "argmin": "returns an int", + "argsort": "returns an int", + "count": "returns an int", + "dot": "no tensordot", + "full_like": "should work, see: https://github.com/hgrecco/pint/pull/1669", + "idxmax": "returns the coordinate", + "idxmin": "returns the coordinate", + "isin": "returns a bool", + "isnull": "returns a bool", + "notnull": "returns a bool", + "rolling_reduce": "no dispatch for numbagg/bottleneck", + "cumulative_reduce": "no dispatch for numbagg/bottleneck", + "searchsorted": "returns an int", + "weighted": "no tensordot", + }, + }, + "sparse": { + "attrs": { + "array": "COO", + "constructor": "COO", + }, + "xfails": { + "cov": "dense output", + "corr": "no nanstd", + "cross": "no cross", + "count": "dense output", + "dot": "fails on some platforms/versions", + "isin": "no isin", + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + "coarsen_construct": "pad constant_values must be fill_value", + "coarsen_reduce": "pad constant_values must be fill_value", + "weighted": "fill_value error", + "coarsen": "pad constant_values must be fill_value", + "quantile": "no non skipping version", + "differentiate": "no gradient", + "argmax": "no nan skipping version", + "argmin": "no nan skipping version", + "idxmax": "no nan skipping version", + "idxmin": "no nan skipping version", + "median": "no nan skipping version", + "std": "no nan skipping version", + "var": "no nan skipping version", + "cumsum": "no cumsum", + "cumprod": "no cumprod", + "argsort": "no argsort", + "conjugate": "no conjugate", + "searchsorted": "no searchsorted", + "shift": "pad constant_values 
must be fill_value", + "pad": "pad constant_values must be fill_value", + }, + }, +} + + +class _BaseTest: + def setup_for_test(self, request, namespace): + self.namespace = namespace + self.xp = pytest.importorskip(namespace) + self.Array = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["array"]) + self.constructor = getattr( + self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["constructor"] + ) + xarray_method = request.node.name.split("test_")[1].split("[")[0] + if xarray_method in NAMESPACE_ARRAYS[namespace]["xfails"]: + reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method] + pytest.xfail(f"xfail for {self.namespace}: {reason}") + + def get_test_dataarray(self): + data = np.asarray([[1, 2, 3, np.nan, 5]]) + x = np.arange(5) + data = self.constructor(data) + return xr.DataArray( + data, + dims=["y", "x"], + coords={"y": [1], "x": x}, + name="foo", + ) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestTopLevelMethods(_BaseTest): + @pytest.fixture(autouse=True) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x1 = self.get_test_dataarray() + self.x2 = self.get_test_dataarray().assign_coords(x=np.arange(2, 7)) + + def test_apply_ufunc(self): + func = lambda x: x + 1 + result = xr.apply_ufunc(func, self.x1, dask="parallelized") + assert isinstance(result.data, self.Array) + + def test_align(self): + result = xr.align(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_broadcast(self): + result = xr.broadcast(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_concat(self): + result = xr.concat([self.x1, self.x2], dim="x") + assert isinstance(result.data, self.Array) + + def test_merge(self): + result = xr.merge([self.x1, self.x2], compat="override") + assert isinstance(result.foo.data, self.Array) + + def test_where(self): + x1, x2 = 
xr.align(self.x1, self.x2, join="inner") + result = xr.where(x1 > 2, x1, x2) + assert isinstance(result.data, self.Array) + + def test_full_like(self): + result = xr.full_like(self.x1, 0) + assert isinstance(result.data, self.Array) + + def test_cov(self): + result = xr.cov(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_corr(self): + result = xr.corr(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_cross(self): + x1, x2 = xr.align(self.x1.squeeze(), self.x2.squeeze(), join="inner") + result = xr.cross(x1, x2, dim="x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = xr.dot(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_map_blocks(self): + result = xr.map_blocks(lambda x: x + 1, self.x1) + assert isinstance(result.data, self.Array) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestDataArrayMethods(_BaseTest): + @pytest.fixture(autouse=True) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x = self.get_test_dataarray() + + def test_loc(self): + result = self.x.loc[{"x": slice(1, 3)}] + assert isinstance(result.data, self.Array) + + def test_isel(self): + result = self.x.isel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_sel(self): + result = self.x.sel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_squeeze(self): + result = self.x.squeeze("y") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interp uses numpy and scipy") + def test_interp(self): + # TODO: some cases could be made to work + result = self.x.interp(x=2.5) + assert isinstance(result.data, self.Array) + + def test_isnull(self): + result = self.x.isnull() + assert isinstance(result.data, self.Array) + + def test_notnull(self): + result = self.x.notnull() + assert isinstance(result.data, self.Array) + + def test_count(self): + result = self.x.count() + 
assert isinstance(result.data, self.Array) + + def test_dropna(self): + result = self.x.dropna(dim="x") + assert isinstance(result.data, self.Array) + + def test_fillna(self): + result = self.x.fillna(0) + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="ffill uses bottleneck or numbagg") + def test_ffill(self): + result = self.x.ffill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="bfill uses bottleneck or numbagg") + def test_bfill(self): + result = self.x.bfill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interpolate_na uses numpy and scipy") + def test_interpolate_na(self): + result = self.x.interpolate_na() + assert isinstance(result.data, self.Array) + + def test_where(self): + result = self.x.where(self.x > 2) + assert isinstance(result.data, self.Array) + + def test_isin(self): + test_elements = self.constructor(np.asarray([1])) + result = self.x.isin(test_elements) + assert isinstance(result.data, self.Array) + + def test_groupby(self): + result = self.x.groupby("x").mean() + assert isinstance(result.data, self.Array) + + def test_groupby_bins(self): + result = self.x.groupby_bins("x", bins=[0, 2, 4, 6]).mean() + assert isinstance(result.data, self.Array) + + def test_rolling_iter(self): + result = self.x.rolling(x=3) + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + def test_rolling_construct(self): + result = self.x.rolling(x=3).construct(x="window") + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_rolling_reduce(self, skipna): + result = self.x.rolling(x=3).mean(skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rolling_exp uses numbagg") + def test_rolling_exp_reduce(self): + result = self.x.rolling_exp(x=3).mean() + assert isinstance(result.data, self.Array) + + def test_cumulative_iter(self): + result = self.x.cumulative("x") + elem = 
next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + def test_cumulative_construct(self): + result = self.x.cumulative("x").construct(x="window") + assert isinstance(result.data, self.Array) + + def test_cumulative_reduce(self): + result = self.x.cumulative("x").sum() + assert isinstance(result.data, self.Array) + + def test_weighted(self): + result = self.x.weighted(self.x.fillna(0)).mean() + assert isinstance(result.data, self.Array) + + def test_coarsen_construct(self): + result = self.x.coarsen(x=2, boundary="pad").construct(x=["a", "b"]) + assert isinstance(result.data, self.Array) + + def test_coarsen_reduce(self): + result = self.x.coarsen(x=2, boundary="pad").mean() + assert isinstance(result.data, self.Array) + + def test_resample(self): + time_coord = pd.date_range("2000-01-01", periods=5) + result = self.x.assign_coords(x=time_coord).resample(x="D").mean() + assert isinstance(result.data, self.Array) + + def test_diff(self): + result = self.x.diff("x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = self.x.dot(self.x) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_quantile(self, skipna): + result = self.x.quantile(0.5, skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_differentiate(self): + # edge_order is not implemented in jax, and only supports passing None + edge_order = None if self.namespace == "jax.numpy" else 1 + result = self.x.differentiate("x", edge_order=edge_order) + assert isinstance(result.data, self.Array) + + def test_integrate(self): + result = self.x.integrate("x") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="polyfit uses numpy linalg") + def test_polyfit(self): + # TODO: this could work, there are just a lot of different linalg calls + result = self.x.polyfit("x", 1) + assert isinstance(result.polyfit_coefficients.data, self.Array) + + def test_map_blocks(self): + result = 
self.x.map_blocks(lambda x: x + 1) + assert isinstance(result.data, self.Array) + + def test_all(self): + result = self.x.all(dim="x") + assert isinstance(result.data, self.Array) + + def test_any(self): + result = self.x.any(dim="x") + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmax(self, skipna): + result = self.x.argmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmin(self, skipna): + result = self.x.argmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmax(self, skipna): + result = self.x.idxmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmin(self, skipna): + result = self.x.idxmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_max(self, skipna): + result = self.x.max(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_min(self, skipna): + result = self.x.min(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_mean(self, skipna): + result = self.x.mean(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_median(self, skipna): + result = self.x.median(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_prod(self, skipna): + result = self.x.prod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_sum(self, skipna): + result = 
self.x.sum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_std(self, skipna): + result = self.x.std(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_var(self, skipna): + result = self.x.var(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumsum(self, skipna): + result = self.x.cumsum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumprod(self, skipna): + result = self.x.cumprod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_argsort(self): + result = self.x.argsort() + assert isinstance(result.data, self.Array) + + def test_astype(self): + result = self.x.astype(int) + assert isinstance(result.data, self.Array) + + def test_clip(self): + result = self.x.clip(min=2.0, max=4.0) + assert isinstance(result.data, self.Array) + + def test_conj(self): + result = self.x.conj() + assert isinstance(result.data, self.Array) + + def test_conjugate(self): + result = self.x.conjugate() + assert isinstance(result.data, self.Array) + + def test_imag(self): + result = self.x.imag + assert isinstance(result.data, self.Array) + + def test_searchsorted(self): + v = self.constructor(np.asarray([3])) + result = self.x.squeeze().searchsorted(v) + assert isinstance(result, self.Array) + + def test_round(self): + result = self.x.round() + assert isinstance(result.data, self.Array) + + def test_real(self): + result = self.x.real + assert isinstance(result.data, self.Array) + + def test_T(self): + result = self.x.T + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rank uses bottleneck") + def test_rank(self): + # TODO: scipy has rankdata, as does jax, so this can work + result = self.x.rank() + 
assert isinstance(result.data, self.Array) + + def test_transpose(self): + result = self.x.transpose() + assert isinstance(result.data, self.Array) + + def test_stack(self): + result = self.x.stack(z=("x", "y")) + assert isinstance(result.data, self.Array) + + def test_unstack(self): + result = self.x.stack(z=("x", "y")).unstack("z") + assert isinstance(result.data, self.Array) + + def test_shift(self): + result = self.x.shift(x=1) + assert isinstance(result.data, self.Array) + + def test_roll(self): + result = self.x.roll(x=1) + assert isinstance(result.data, self.Array) + + def test_pad(self): + result = self.x.pad(x=1) + assert isinstance(result.data, self.Array) + + def test_sortby(self): + result = self.x.sortby("x") + assert isinstance(result.data, self.Array) + + def test_broadcast_like(self): + result = self.x.broadcast_like(self.x) + assert isinstance(result.data, self.Array) diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py index 798f5f732d1..48819333ca2 100644 --- a/xarray/tests/test_strategies.py +++ b/xarray/tests/test_strategies.py @@ -13,6 +13,7 @@ from hypothesis import given from hypothesis.extra.array_api import make_strategies_namespace +from xarray.core.options import set_options from xarray.core.variable import Variable from xarray.testing.strategies import ( attrs, @@ -267,14 +268,14 @@ def test_mean(self, data, var): Test that given a Variable of at least one dimension, the mean of the Variable is always equal to the mean of the underlying array. 
""" + with set_options(use_numbagg=False): + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) - # specify arbitrary reduction along at least one dimension - reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) - # create expected result (using nanmean because arrays with Nans will be generated) - reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) - expected = np.nanmean(var.data, axis=reduction_axes) - - # assert property is always satisfied - result = var.mean(dim=reduction_dims).data - npt.assert_equal(expected, result) + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 9c6f50037d3..1461489e731 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1978,26 +1978,27 @@ def test_reduce_funcs(self): def test_reduce_keepdims(self): v = Variable(["x", "y"], self.d) - assert_identical( - v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) - ) - assert_identical( - v.mean(dim="x", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), - ) - assert_identical( - v.mean(dim="y", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), - ) - assert_identical( - v.mean(dim=["y", "x"], keepdims=True), - Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), - ) + with set_options(use_numbagg=False): + assert_identical( + v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) + ) + assert_identical( + v.mean(dim="x", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), + 
) + assert_identical( + v.mean(dim="y", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), + ) + assert_identical( + v.mean(dim=["y", "x"], keepdims=True), + Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), + ) - v = Variable([], 1.0) - assert_identical( - v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) - ) + v = Variable([], 1.0) + assert_identical( + v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) + ) @requires_dask def test_reduce_keepdims_dask(self): pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy