From 6fe1bc46f59d51d886ef93bf0ab72321fcfe42cc Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Wed, 30 Apr 2025 12:07:53 -0400 Subject: [PATCH 1/4] Add a test to showcase the differences in the set_dims behavior xref: https://github.com/pydata/xarray/issues/9462 --- xarray/tests/test_variable.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 619dc1561ef..d1afa10c5fa 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1653,6 +1653,66 @@ def test_set_dims_object_dtype(self): expected = Variable(["x"], exp_values) assert_identical(actual, expected) + def test_set_dims_without_broadcast(self): + class ArrayWithoutBroadcastTo(NDArrayMixin, indexing.ExplicitlyIndexed): + def __init__(self, array): + self.array = array + + # Broadcasting with __getitem__ is "easier" to implement + # especially for dims of 1 + def __getitem__(self, key): + return self.array[key] + + def __array_function__(self, *args, **kwargs): + raise NotImplementedError( + "Not we don't want to use broadcast_to here " + "https://github.com/pydata/xarray/issues/9462" + ) + + arr = ArrayWithoutBroadcastTo(np.zeros((3, 4))) + # We should be able to add a new axis without broadcasting + assert arr[np.newaxis, :, :].shape == (1, 3, 4) + with pytest.raises(NotImplementedError): + np.broadcast_to(arr, (1, 3, 4)) + + v = Variable(["x", "y"], arr) + v_expanded = v.set_dims(["z", "x", "y"]) + assert v_expanded.dims == ("z", "x", "y") + assert v_expanded.shape == (1, 3, 4) + + # Explicitly asking for a shape of 1 triggers a different + # codepath in set_dims + # https://github.com/pydata/xarray/issues/9462 + v_expanded = v.set_dims(["z", "x", "y"], shape=(1, 3, 4)) + assert v_expanded.dims == ("z", "x", "y") + assert v_expanded.shape == (1, 3, 4) + + v_expanded = v.set_dims(["x", "z", "y"], shape=(3, 1, 4)) + assert v_expanded.dims == ("x", "z", "y") + assert 
v_expanded.shape == (3, 1, 4) + + v_expanded = v.set_dims(["x", "y", "z"], shape=(3, 4, 1)) + assert v_expanded.dims == ("x", "y", "z") + assert v_expanded.shape == (3, 4, 1) + + v_expanded = v.set_dims({"z": 1, "x": 3, "y": 4}) + assert v_expanded.dims == ("z", "x", "y") + assert v_expanded.shape == (1, 3, 4) + + v_expanded = v.set_dims({"x": 3, "z": 1, "y": 4}) + assert v_expanded.dims == ("x", "z", "y") + assert v_expanded.shape == (3, 1, 4) + + v_expanded = v.set_dims({"x": 3, "y": 4, "z": 1}) + assert v_expanded.dims == ("x", "y", "z") + assert v_expanded.shape == (3, 4, 1) + + with pytest.raises(NotImplementedError): + v.set_dims({"z": 2, "x": 3, "y": 4}) + + with pytest.raises(NotImplementedError): + v.set_dims(["z", "x", "y"], shape=(2, 3, 4)) + def test_stack(self): v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"}) actual = v.stack(z=("x", "y")) From fb67fc273a1fe6ec0ae90882daefbc7307311621 Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Wed, 30 Apr 2025 12:08:18 -0400 Subject: [PATCH 2/4] Expand on the implementation of set_dims to make the trivial case easier --- xarray/core/variable.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index b8b33997780..7ca22cd98e5 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1355,7 +1355,7 @@ def set_dims(self, dim, shape=None): dim = [dim] if shape is None and is_dict_like(dim): - shape = dim.values() + shape = tuple(dim.values()) missing_dims = set(self.dims) - set(dim) if missing_dims: @@ -1371,13 +1371,18 @@ def set_dims(self, dim, shape=None): # don't use broadcast_to unless necessary so the result remains # writeable if possible expanded_data = self.data - elif shape is not None: + elif shape is None or all( + s == 1 for s, e in zip(shape, dim, strict=True) if e not in self_dims + ): + # "Trivial" broadcasting, i.e. 
simply inserting a new dimension + # This is typically easier for duck arrays to implement + # than the full "broadcast_to" semantics + indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) + expanded_data = self.data[indexer] + else: # elif shape is not None: dims_map = dict(zip(dim, shape, strict=True)) tmp_shape = tuple(dims_map[d] for d in expanded_dims) expanded_data = duck_array_ops.broadcast_to(self._data, tmp_shape) - else: - indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) - expanded_data = self.data[indexer] expanded_var = Variable( expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True From db27ee019e386121529b273ab8d8c81b2d4be1b9 Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Wed, 30 Apr 2025 12:08:22 -0400 Subject: [PATCH 3/4] Add release note for https://github.com/pydata/xarray/pull/10277 --- doc/whats-new.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 76fb5d42aa9..f85378d7176 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -97,6 +97,11 @@ Internal Changes ~~~~~~~~~~~~~~~~ - Avoid stacking when grouping by a chunked array. This can be a large performance improvement. By `Deepak Cherian <https://github.com/dcherian>`_. +- The implementation of ``Variable.set_dims`` has changed to use array indexing syntax + instead of ``np.broadcast_to`` to perform dimension expansions where + all new dimensions have a size of 1. This should improve compatibility with + duck arrays that do not support broadcasting (:issue:`9462`, :pull:`10277`). + By `Mark Harfouche <https://github.com/hmaarrfk>`_. .. 
_whats-new.2025.03.1: From a4ee4e54f97bd4ae5645abfe8bacbfae4edff40c Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Thu, 1 May 2025 10:04:39 -0400 Subject: [PATCH 4/4] Update test_variable.py --- xarray/tests/test_variable.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index d1afa10c5fa..1e7c32dec1e 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1680,6 +1680,14 @@ def __array_function__(self, *args, **kwargs): assert v_expanded.dims == ("z", "x", "y") assert v_expanded.shape == (1, 3, 4) + v_expanded = v.set_dims(["x", "z", "y"]) + assert v_expanded.dims == ("x", "z", "y") + assert v_expanded.shape == (3, 1, 4) + + v_expanded = v.set_dims(["x", "y", "z"]) + assert v_expanded.dims == ("x", "y", "z") + assert v_expanded.shape == (3, 4, 1) + # Explicitly asking for a shape of 1 triggers a different # codepath in set_dims # https://github.com/pydata/xarray/issues/9462 pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy