diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index faec5ded04e..262c023059a 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -224,14 +224,17 @@ def empty_like(a, **kwargs): return xp.empty_like(a, **kwargs) -def astype(data, dtype, **kwargs): - if hasattr(data, "__array_namespace__"): +def astype(data, dtype, *, xp=None, **kwargs): + if not hasattr(data, "__array_namespace__") and xp is None: + return data.astype(dtype, **kwargs) + + if xp is None: xp = get_array_namespace(data) - if xp == np: - # numpy currently doesn't have a astype: - return data.astype(dtype, **kwargs) - return xp.astype(data, dtype, **kwargs) - return data.astype(dtype, **kwargs) + + if xp == np: + # numpy currently doesn't have a astype: + return data.astype(dtype, **kwargs) + return xp.astype(data, dtype, **kwargs) def asarray(data, xp=np, dtype=None): @@ -373,6 +376,13 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) + + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + if not is_duck_array(condition): + condition = asarray(condition, dtype=dtype, xp=xp) + else: + condition = astype(condition, dtype=dtype, xp=xp) + return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b28ba390a9f..6a3ce156ce6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -524,7 +524,7 @@ def factorize(self) -> EncodedGroups: # Restore these after the raveling broadcasted_masks = broadcast(*masks) mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type] - _flatcodes = where(mask, -1, _flatcodes) + _flatcodes = where(mask.data, -1, _flatcodes) full_index = pd.MultiIndex.from_product( (grouper.full_index.values for grouper in groupers), diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index c273260d7dd..022d2e3750e 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -139,7 +139,6 @@ def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) -@pytest.mark.skip def test_where() -> None: np_arr = xr.DataArray(np.array([1, 0]), dims="x") xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: