diff --git a/doc/api/datatree.rst b/doc/api/datatree.rst
index 8501440b7d7..49abd501407 100644
--- a/doc/api/datatree.rst
+++ b/doc/api/datatree.rst
@@ -195,6 +195,7 @@ Index into all nodes in the subtree simultaneously.
 
    DataTree.isel
    DataTree.sel
+   DataTree.subset
 
 .. DataTree.drop_sel
 .. DataTree.drop_isel
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index bace038bb17..bd2f6fef525 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -13,6 +13,9 @@ v2025.07.2 (unreleased)
 New Features
 ~~~~~~~~~~~~
 
+- Added :py:meth:`~xarray.DataTree.subset` to index variables on all nodes of a datatree (:pull:`10400`)
+  By `Mathias Hauser `_.
+
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py
index afef2f20094..4cdca5dcec1 100644
--- a/xarray/core/datatree.py
+++ b/xarray/core/datatree.py
@@ -10,6 +10,7 @@
     Iterable,
     Iterator,
     Mapping,
+    Sequence,
 )
 from html import escape
 from typing import (
@@ -1014,6 +1015,38 @@ def __delitem__(self, key: str) -> None:
         else:
             raise KeyError(key)
 
+    def subset(
+        self, keys: str | Sequence[str], *, errors: ErrorOptions = "raise"
+    ) -> DataTree:
+        """Select data variables by name on every node of the tree.
+
+        Parameters
+        ----------
+        keys : str or sequence of str
+            Name(s) of the data variables to select. A single string is
+            treated as a one-element list.
+        errors : {"raise", "ignore"}, default: "raise"
+            If "raise", the same names are requested from every node, so a
+            ``KeyError`` is raised when a node lacks one of them. If
+            "ignore", names that are not data variables of a node are
+            skipped for that node.
+
+        Returns
+        -------
+        out : DataTree
+            Result of applying the selection to the dataset of each node.
+        """
+        # Always index with a list: a tuple is also a ``Sequence[str]``, but
+        # ``Dataset.__getitem__`` looks a hashable key up as one variable name.
+        names = [keys] if isinstance(keys, str) else list(keys)
+
+        def getitem(ds):
+            if errors == "ignore":
+                return ds[[name for name in names if name in ds.data_vars]]
+            return ds[names]
+
+        return map_over_datasets(getitem, self)
+
     @overload
     def update(self, other: Dataset) -> None: ...
 
diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py
index 2bf079a7cbd..836d61f304a 100644
--- a/xarray/tests/test_datatree.py
+++ b/xarray/tests/test_datatree.py
@@ -333,6 +333,29 @@ def test_getitem_dict_like_selection_access_to_dataset(self) -> None:
         assert_identical(results[{"temp": 1}], data[{"temp": 1}])  # type: ignore[index]
 
 
+def test_subset() -> None:
+    """Exercise ``DataTree.subset`` in both ``errors`` modes."""
+    both_vars = xr.Dataset(data_vars={"var1": ("x", [1, 2]), "var2": ("x", [0, 1])})
+    only_var1 = xr.Dataset(data_vars={"var1": ("x", [1, 2])})
+    tree = xr.DataTree.from_dict({"ds1": both_vars, "ds2": only_var1})
+    tree_var1 = xr.DataTree.from_dict({"ds1": both_vars[["var1"]], "ds2": only_var1})
+
+    # "var1" errors as map_over_datasets does not skip empty nodes; "var2"
+    # will still error if map_over_datasets will ever skip empty nodes
+    for missing in ("var1", "var2"):
+        with pytest.raises(KeyError, match=missing):
+            tree.subset(missing)
+
+    ignore_cases = [
+        ("var1", tree_var1),
+        (["var1"], tree_var1),
+        (["var1", "var2"], tree),
+    ]
+    for keys, expected in ignore_cases:
+        result = tree.subset(keys, errors="ignore")
+        xr.testing.assert_equal(result, expected)
+
+
 class TestUpdate:
     def test_update(self) -> None:
         dt = DataTree()

pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy