diff --git a/xarray/coding/times.py b/xarray/coding/times.py index ad5e8653e2a..7fffa595d94 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -22,7 +22,7 @@ ) from xarray.core import indexing from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like -from xarray.core.duck_array_ops import asarray, ravel, reshape +from xarray.core.duck_array_ops import array_all, array_any, asarray, ravel, reshape from xarray.core.formatting import first_n_items, format_timestamp, last_item from xarray.core.pdcompat import default_precision_timestamp, timestamp_as_unit from xarray.core.utils import attempt_import, emit_user_level_warning @@ -676,7 +676,7 @@ def _infer_time_units_from_diff(unique_timedeltas) -> str: unit_timedelta = _unit_timedelta_numpy zero_timedelta = np.timedelta64(0, "ns") for time_unit in time_units: - if np.all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta): + if array_all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta): return time_unit return "seconds" @@ -939,7 +939,7 @@ def encode_datetime(d): def cast_to_int_if_safe(num) -> np.ndarray: int_num = np.asarray(num, dtype=np.int64) - if (num == int_num).all(): + if array_all(num == int_num): num = int_num return num @@ -961,7 +961,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray: cast_num = np.asarray(num, dtype=dtype) if np.issubdtype(dtype, np.integer): - if not (num == cast_num).all(): + if not array_all(num == cast_num): if np.issubdtype(num.dtype, np.floating): raise ValueError( f"Not possible to cast all encoded times from " @@ -979,7 +979,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray: "a larger integer dtype." ) else: - if np.isinf(cast_num).any(): + if array_any(np.isinf(cast_num)): raise OverflowError( f"Not possible to cast encoded times from {num.dtype!r} to " f"{dtype!r} without overflow. Consider removing the dtype " diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index c3f1598050a..45fdaee9768 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -16,8 +16,6 @@ import numpy as np import pandas as pd -from numpy import all as array_all # noqa: F401 -from numpy import any as array_any # noqa: F401 from numpy import ( # noqa: F401 isclose, isnat, @@ -319,7 +317,9 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): if lazy_equiv is None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") - return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + return bool( + array_all(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True)) + ) else: return lazy_equiv @@ -333,7 +333,7 @@ def array_equiv(arr1, arr2): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "In the future, 'NAT == x'") flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)) - return bool(flag_array.all()) + return bool(array_all(flag_array)) else: return lazy_equiv @@ -349,7 +349,7 @@ def array_notnull_equiv(arr1, arr2): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "In the future, 'NAT == x'") flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2) - return bool(flag_array.all()) + return bool(array_all(flag_array)) else: return lazy_equiv @@ -536,6 +536,16 @@ def f(values, axis=None, skipna=None, **kwargs): cumsum_1d.numeric_only = True +def array_all(array, axis=None, keepdims=False, **kwargs): + xp = get_array_namespace(array) + return xp.all(array, axis=axis, keepdims=keepdims, **kwargs) + + +def array_any(array, axis=None, keepdims=False, **kwargs): + xp = get_array_namespace(array) + return xp.any(array, axis=axis, keepdims=keepdims, **kwargs) + + _mean = _create_nan_agg_method("mean", invariant_0d=True) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index ab17fa85381..a6bacccbeef 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -18,7 +18,7 @@ from pandas.errors import OutOfBoundsDatetime from xarray.core.datatree_render import RenderDataTree -from xarray.core.duck_array_ops import array_equiv, astype +from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype from xarray.core.indexing import MemoryCachedArray from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.treenode import group_subtrees @@ -204,9 +204,9 @@ def format_items(x): day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part day_needed = day_part != np.timedelta64(0, "ns") - if np.logical_not(day_needed).all(): + if array_all(np.logical_not(day_needed)): timedelta_format = "time" - elif np.logical_not(time_needed).all(): + elif array_all(np.logical_not(time_needed)): timedelta_format = "date" formatted = [format_item(xi, timedelta_format) for xi in x] @@ -232,7 +232,7 @@ def format_array_flat(array, max_width: int): cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1 if (array.size > 2) and ( - (max_possibly_relevant < array.size) or (cum_len > max_width).any() + (max_possibly_relevant < array.size) or array_any(cum_len > max_width) ): padding = " ... " max_len = max(int(np.argmax(cum_len + len(padding) - 1 > max_width)), 2) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 4894cf02be2..17c60b6f663 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -45,7 +45,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): data = getattr(np, func)(value, axis=axis, **kwargs) # TODO This will evaluate dask arrays and might be costly. - if (valid_count == 0).any(): + if duck_array_ops.array_any(valid_count == 0): raise ValueError("All-NaN slice encountered") return data diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 269cb49a2c1..cd24091b18e 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -171,7 +171,7 @@ def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None: def _weight_check(w): # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670 - if duck_array_ops.isnull(w).any(): + if duck_array_ops.array_any(duck_array_ops.isnull(w)): raise ValueError( "`weights` cannot contain missing values. " "Missing values can be replaced by `weights.fillna(0)`." diff --git a/xarray/groupers.py b/xarray/groupers.py index dac4c4309de..32e5e712196 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -17,11 +17,10 @@ from numpy.typing import ArrayLike from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq -from xarray.core import duck_array_ops from xarray.core.computation import apply_ufunc from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.dataarray import DataArray -from xarray.core.duck_array_ops import isnull +from xarray.core.duck_array_ops import array_all, isnull from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper @@ -235,7 +234,7 @@ def _factorize_unique(self) -> EncodedGroups: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort) - if (codes_ == -1).all(): + if array_all(codes_ == -1): raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) @@ -347,7 +346,7 @@ def reset(self) -> Self: ) def __post_init__(self) -> None: - if duck_array_ops.isnull(self.bins).all(): + if array_all(isnull(self.bins)): raise ValueError("All bin edges are NaN.") def _cut(self, data): @@ -381,7 +380,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: f"Bin edges must be provided when grouping by chunked arrays. Received {self.bins=!r} instead" ) codes = self._factorize_lazy(group) - if not by_is_chunked and (codes == -1).all(): + if not by_is_chunked and array_all(codes == -1): raise ValueError( f"None of the data falls within bins with edges {self.bins!r}" ) @@ -547,7 +546,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray: # Copied from flox sorter = np.argsort(labels) - is_sorted = (sorter == np.arange(sorter.size)).all() + is_sorted = array_all(sorter == np.arange(sorter.size)) codes = np.searchsorted(labels, data, sorter=sorter) mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels)) # codes is the index in to the sorted array. diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 6d87537a523..8a2dba9261f 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -296,7 +296,7 @@ def assert_duckarray_equal(x, y, err_msg="", verbose=True): if (utils.is_duck_array(x) and utils.is_scalar(y)) or ( utils.is_scalar(x) and utils.is_duck_array(y) ): - equiv = (x == y).all() + equiv = duck_array_ops.array_all(x == y) else: equiv = duck_array_ops.array_equiv(x, y) assert equiv, _format_message(x, y, err_msg=err_msg, verbose=verbose)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: