From 57155abfbfd311c17acfd4ced66604e018ce6ebf Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 2 Mar 2025 19:03:39 +0100 Subject: [PATCH 01/12] explicitly cast the dtype of `condition` to `bool` --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index faec5ded04e..7ea283b87c2 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -373,7 +373,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) - return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) + return xp.where(xp.astype(condition, xp.bool), *as_shared_dtype([x, y], xp=xp)) def where_method(data, cond, other=dtypes.NA): From 84a2c84de488b23a0ff24a43ce9671012e0aed9b Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 21:37:47 +0100 Subject: [PATCH 02/12] cast `condition` to bool in every case for `where` --- xarray/core/duck_array_ops.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 7ea283b87c2..6ddf664faf9 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -224,14 +224,17 @@ def empty_like(a, **kwargs): return xp.empty_like(a, **kwargs) -def astype(data, dtype, **kwargs): - if hasattr(data, "__array_namespace__"): +def astype(data, dtype, *, xp=None, **kwargs): + if not hasattr(data, "__array_namespace__") and xp is None: + return data.astype(dtype, **kwargs) + + if xp is None: xp = get_array_namespace(data) - if xp == np: - # numpy currently doesn't have a astype: - return data.astype(dtype, **kwargs) - return xp.astype(data, dtype, **kwargs) - return data.astype(dtype, **kwargs) + + if xp == np: + # numpy currently doesn't have a astype: + return data.astype(dtype, **kwargs) + return xp.astype(data, dtype, **kwargs) def asarray(data, xp=np, dtype=None): @@ -373,7 +376,13 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) - return xp.where(xp.astype(condition, xp.bool), *as_shared_dtype([x, y], xp=xp)) + + if not is_duck_array(condition): + condition = asarray(condition, dtype=xp.bool, xp=xp) + else: + condition = astype(condition, dtype=xp.bool, xp=xp) + + return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) def where_method(data, cond, other=dtypes.NA): From d4ad871bda075debd14a83af8b1c41dccb55d10f Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 21:38:18 +0100 Subject: [PATCH 03/12] don't pass a `DataArray` to `where` --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b28ba390a9f..6a3ce156ce6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -524,7 +524,7 @@ def factorize(self) -> EncodedGroups: # Restore these after the raveling broadcasted_masks = broadcast(*masks) mask = functools.reduce(np.logical_or, broadcasted_masks) # type: ignore[arg-type] - _flatcodes = where(mask, -1, _flatcodes) + _flatcodes = where(mask.data, -1, _flatcodes) full_index = pd.MultiIndex.from_product( (grouper.full_index.values for grouper in groupers), From 56b254df13f7573e1a9d3fdfcac97a6c26108e5f Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 21:51:05 +0100 Subject: [PATCH 04/12] use strings to specify the dtype for backwards compat --- xarray/core/duck_array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6ddf664faf9..6cff20b5010 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -378,9 +378,9 @@ def where(condition, x, y): xp = get_array_namespace(condition, x, y) if not is_duck_array(condition): - condition = asarray(condition, dtype=xp.bool, xp=xp) + condition = asarray(condition, dtype="bool", xp=xp) else: - condition = astype(condition, dtype=xp.bool, xp=xp) + condition = astype(condition, dtype="bool", xp=xp) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) From ea3d4f783ce198fd42363b2c8262a2e237cd8eab Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 22:52:13 +0100 Subject: [PATCH 05/12] revert the strings and instead ignore the warning --- pyproject.toml | 3 +++ xarray/core/duck_array_ops.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 817fda6c328..9029ca6b482 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -358,6 +358,9 @@ filterwarnings = [ "default:the `pandas.MultiIndex` object:FutureWarning:xarray.tests.test_variable", "default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", + # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API + # TODO: remove once we can drop numpy<2 + "ignore:In the future `np.bool` will be defined as the corresponding NumpPy scalar:FutureWarning", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6cff20b5010..6ddf664faf9 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -378,9 +378,9 @@ def where(condition, x, y): xp = get_array_namespace(condition, x, y) if not is_duck_array(condition): - condition = asarray(condition, dtype="bool", xp=xp) + condition = asarray(condition, dtype=xp.bool, xp=xp) else: - condition = astype(condition, dtype="bool", xp=xp) + condition = astype(condition, dtype=xp.bool, xp=xp) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) From 8fbfa06344d25bf014d106782b24b677abe87445 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 23:15:37 +0100 Subject: [PATCH 06/12] typo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9029ca6b482..a7b0b37a759 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,7 +360,7 @@ filterwarnings = [ "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API # TODO: remove once we can drop numpy<2 - "ignore:In the future `np.bool` will be defined as the corresponding NumpPy scalar:FutureWarning", + "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols From 00b9720eda7b351f4e596dc4d4e0e87a27ddae3d Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 23:15:50 +0100 Subject: [PATCH 07/12] restrict to just numpy --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a7b0b37a759..1afcc8fef14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,7 +360,7 @@ filterwarnings = [ "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API # TODO: remove once we can drop numpy<2 - "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning", + "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning:numpy.*", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols From 58ecc32f68f9423c9db7bb2184b0e5843792a663 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 4 Mar 2025 23:23:53 +0100 Subject: [PATCH 08/12] unrestrict --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1afcc8fef14..a7b0b37a759 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,7 +360,7 @@ filterwarnings = [ "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API # TODO: remove once we can drop numpy<2 - "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning:numpy.*", + "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols From afce80b39ea8013e08a707c889c314362a3e9ed1 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 12 Mar 2025 18:12:42 +0100 Subject: [PATCH 09/12] fall back to `xp.bool_` if `xp.bool` doesn't exist --- xarray/core/duck_array_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6ddf664faf9..30531dafc9c 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -377,10 +377,11 @@ def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) + dtype = xp.bool if hasattr(xp, "bool") else xp.bool_ if not is_duck_array(condition): - condition = asarray(condition, dtype=xp.bool, xp=xp) + condition = asarray(condition, dtype=dtype, xp=xp) else: - condition = astype(condition, dtype=xp.bool, xp=xp) + condition = astype(condition, dtype=dtype, xp=xp) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) From 8a3e0d26b0b1e8b61cea384d1328838972c9df37 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 12 Mar 2025 18:18:35 +0100 Subject: [PATCH 10/12] unskip the `where` test --- xarray/tests/test_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index c273260d7dd..022d2e3750e 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -139,7 +139,6 @@ def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) -@pytest.mark.skip def test_where() -> None: np_arr = xr.DataArray(np.array([1, 0]), dims="x") xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x") From 3e48223733fd84828d7cdd5782fb9543286d4cb9 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 12 Mar 2025 18:20:07 +0100 Subject: [PATCH 11/12] reverse to avoid warnings --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 30531dafc9c..262c023059a 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -377,7 +377,7 @@ def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition, x, y) - dtype = xp.bool if hasattr(xp, "bool") else xp.bool_ + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool if not is_duck_array(condition): condition = asarray(condition, dtype=dtype, xp=xp) else: From 420d1c013fcc1b8ed714f8311a8372d98c30ff8e Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Wed, 12 Mar 2025 18:51:02 +0100 Subject: [PATCH 12/12] remove the outdated ignore --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f5d6b004ae3..85c9183b30e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -344,9 +344,6 @@ filterwarnings = [ "default:the `pandas.MultiIndex` object:FutureWarning:xarray.tests.test_variable", "default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", "default:Duplicate dimension names present:UserWarning:xarray.namedarray.core", - # TODO: numpy.bool was deprecated in older versions of numpy, but is in the Array API - # TODO: remove once we can drop numpy<2 - "ignore:In the future `np.bool` will be defined as the corresponding NumPy scalar:FutureWarning", # TODO: this is raised for vlen-utf8, consolidated metadata, U1 dtype "ignore:is currently not part .* the Zarr version 3 specification.", # TODO: remove once we know how to deal with a changed signature in protocols pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy