Skip to content

Commit 3c07cce

Browse files
committed
add kwargs to map_over_datasets (similar to apply_ufunc), add test.
1 parent c252152 commit 3c07cce

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

xarray/core/datatree_mapping.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sys
44
from collections.abc import Callable, Mapping
5+
from functools import partial
56
from typing import TYPE_CHECKING, Any, cast, overload
67

78
from xarray.core.dataset import Dataset
@@ -31,7 +32,9 @@ def map_over_datasets(
3132

3233

3334
def map_over_datasets(
34-
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], *args: Any
35+
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
36+
*args: Any,
37+
kwargs: Mapping | None = None,
3538
) -> DataTree | tuple[DataTree, ...]:
3639
"""
3740
Applies a function to every dataset in one or more DataTree objects with
@@ -68,6 +71,8 @@ def map_over_datasets(
6871
*args : tuple, optional
6972
Positional arguments passed on to `func`. Any DataTree arguments will be
7073
converted to Dataset objects via `.dataset`.
74+
kwargs : dict, optional
75+
Optional keyword arguments passed directly on to call ``func``.
7176
7277
Returns
7378
-------
@@ -85,6 +90,12 @@ def map_over_datasets(
8590

8691
from xarray.core.datatree import DataTree
8792

93+
if kwargs is None:
94+
kwargs = {}
95+
96+
if kwargs:
97+
func = partial(func, **kwargs)
98+
8899
# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
89100
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
90101
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return

xarray/tests/test_datatree_mapping.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ def test_single_tree_arg_plus_arg(self, create_test_datatree):
5151
result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt)
5252
assert_equal(result_tree, expected)
5353

54+
def test_single_tree_arg_plus_kwarg(self, create_test_datatree):
55+
dt = create_test_datatree()
56+
expected = create_test_datatree(modify=lambda ds: (10.0 * ds))
57+
58+
def multiply_by_kwarg(ds, **kwargs):
59+
ds = ds * kwargs.pop("multiplier")
60+
return ds
61+
62+
result_tree = map_over_datasets(
63+
multiply_by_kwarg, dt, kwargs=dict(multiplier=10.0)
64+
)
65+
assert_equal(result_tree, expected)
66+
5467
def test_multiple_tree_args(self, create_test_datatree):
5568
dt1 = create_test_datatree()
5669
dt2 = create_test_datatree()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy