From 7e3a6a4370e2fcb4a4b13ebe86ac08a92a7ed3cd Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 28 Jun 2024 11:00:08 -0400 Subject: [PATCH 01/33] Add SeasonGrouper, SeasonResampler These two groupers allow defining custom seasons, and dropping incomplete seasons from the output. Both cases are treated by adjusting the factorization -- conversion from group labels to integer codes -- appropriately. --- doc/api.rst | 2 + properties/test_properties.py | 31 +++++ xarray/core/toolzcompat.py | 56 +++++++++ xarray/groupers.py | 220 ++++++++++++++++++++++++++++++++++ xarray/tests/test_groupby.py | 20 ++++ 5 files changed, 329 insertions(+) create mode 100644 xarray/core/toolzcompat.py diff --git a/doc/api.rst b/doc/api.rst index 63427447d53..ddba586fc83 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1139,6 +1139,8 @@ Grouper Objects groupers.BinGrouper groupers.UniqueGrouper groupers.TimeResampler + groupers.SeasonGrouper + groupers.SeasonResampler Rolling objects diff --git a/properties/test_properties.py b/properties/test_properties.py index fc0a1955539..859f9d4e500 100644 --- a/properties/test_properties.py +++ b/properties/test_properties.py @@ -2,10 +2,12 @@ pytest.importorskip("hypothesis") +import hypothesis.strategies as st from hypothesis import given import xarray as xr import xarray.testing.strategies as xrst +from xarray.groupers import season_to_month_tuple @given(attrs=xrst.simple_attrs) @@ -15,3 +17,32 @@ def test_assert_identical(attrs): ds = xr.Dataset(attrs=attrs) xr.testing.assert_identical(ds, ds.copy(deep=True)) + + +@given( + roll=st.integers(min_value=0, max_value=12), + breaks=st.lists( + st.integers(min_value=0, max_value=11), min_size=1, max_size=12, unique=True + ), +) +def test_property_season_month_tuple(roll, breaks): + chars = list("JFMAMJJASOND") + months = tuple(range(1, 13)) + + rolled_chars = chars[roll:] + chars[:roll] + rolled_months = months[roll:] + months[:roll] + breaks = sorted(breaks) + if breaks[0] != 0: + breaks = [0] + breaks + if breaks[-1] != 12: + breaks = breaks + [12] + seasons = tuple( + "".join(rolled_chars[start:stop]) + for start, stop in zip(breaks[:-1], breaks[1:], strict=False) + ) + actual = season_to_month_tuple(seasons) + expected = tuple( + rolled_months[start:stop] + for start, stop in zip(breaks[:-1], breaks[1:], strict=False) + ) + assert expected == actual diff --git a/xarray/core/toolzcompat.py b/xarray/core/toolzcompat.py new file mode 100644 index 00000000000..4632419a845 --- /dev/null +++ b/xarray/core/toolzcompat.py @@ -0,0 +1,56 @@ +# This file contains functions copied from the toolz library in accordance +# with its license. The original copyright notice is duplicated below. + +# Copyright (c) 2013 Matthew Rocklin + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# a. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# b. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# c. Neither the name of toolz nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+ + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. + + +def sliding_window(n, seq): + """A sequence of overlapping subsequences + + >>> list(sliding_window(2, [1, 2, 3, 4])) + [(1, 2), (2, 3), (3, 4)] + + This function creates a sliding window suitable for transformations like + sliding means / smoothing + + >>> mean = lambda seq: float(sum(seq)) / len(seq) + >>> list(map(mean, sliding_window(2, [1, 2, 3, 4]))) + [1.5, 2.5, 3.5] + """ + import collections + import itertools + + return zip( + *( + collections.deque(itertools.islice(it, i), 0) or it + for i, it in enumerate(itertools.tee(seq, n)) + ), + strict=False, + ) diff --git a/xarray/groupers.py b/xarray/groupers.py index 89b189e582e..fcbadeea2c8 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,7 +7,9 @@ from __future__ import annotations import datetime +import itertools from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from itertools import pairwise from typing import TYPE_CHECKING, Any, Literal, cast @@ -20,11 +22,15 @@ from xarray.core import duck_array_ops from xarray.core.computation import apply_ufunc from xarray.core.coordinates import Coordinates, _coordinates_from_variable +from xarray.core.coordinates import Coordinates +from xarray.core.common import _contains_datetime_like_objects +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.duck_array_ops import isnull from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper +from xarray.core.toolzcompat import sliding_window from xarray.core.types import ( Bins, DatetimeLike, @@ -553,3 +559,217 @@ def unique_value_groups( if isinstance(values, pd.MultiIndex): values.names = ar.names return values, inverse + + +def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]: + initials = "JFMAMJJASOND" + starts = dict( + ("".join(s), i + 1) + for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=False) + ) + result: list[tuple[int, ...]] = [] + for i, season in enumerate(seasons): + if len(season) == 1: + if i < len(seasons) - 1: + suffix = seasons[i + 1][0] + else: + suffix = seasons[0][0] + else: + suffix = season[1] + + start = starts[season[0] + suffix] + + month_append = [] + for i in range(len(season[1:])): + elem = start + i + 1 + month_append.append(elem - 12 * (elem > 12)) + result.append((start,) + tuple(month_append)) + return tuple(result) + + +@dataclass +class SeasonGrouper(Grouper): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: sequence of str + List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc. 
+ + Examples + -------- + >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) + >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"]) + """ + + seasons: Sequence[str] + season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + # drop_incomplete: bool = field(default=True) # TODO + + def __post_init__(self) -> None: + self.season_inds = season_to_month_tuple(self.seasons) + + def factorize(self, group: T_Group) -> EncodedGroups: + if TYPE_CHECKING: + assert not isinstance(group, _DummyGroup) + if not _contains_datetime_like_objects(group.variable): + raise ValueError( + "SeasonGrouper can only be used to group by datetime-like arrays." + ) + + seasons = self.seasons + season_inds = self.season_inds + + months = group.dt.month + codes_ = np.full(group.shape, -1) + group_indices: list[list[int]] = [[]] * len(seasons) + + index = np.arange(group.size) + for idx, season_tuple in enumerate(season_inds): + mask = months.isin(season_tuple) + codes_[mask] = idx + group_indices[idx] = index[mask] + + if np.all(codes_ == -1): + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" + ) + codes = group.copy(data=codes_, deep=False).rename("season") + unique_coord = Variable("season", seasons, attrs=group.attrs) + full_index = pd.Index(seasons) + return EncodedGroups( + codes=codes, + group_indices=tuple(group_indices), + unique_coord=unique_coord, + full_index=full_index, + ) + + +@dataclass +class SeasonResampler(Resampler): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: Sequence[str] + An ordered list of seasons. + drop_incomplete: bool + Whether to drop seasons that are not completely included in the data. + For example, if a time series starts in Jan-2001, and seasons includes `"DJF"` + then observations from Jan-2001, and Feb-2001 are ignored in the grouping + since Dec-2000 isn't present. + + Examples + -------- + >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"]) + >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"]) + """ + + seasons: Sequence[str] + drop_incomplete: bool = field(default=True) + season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False) + + def __post_init__(self): + self.season_inds = season_to_month_tuple(self.seasons) + self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False)) + + def factorize(self, group): + if group.ndim != 1: + raise ValueError( + "SeasonResampler can only be used to resample by 1D arrays." + ) + if not _contains_datetime_like_objects(group.variable): + raise ValueError( + "SeasonResampler can only be used to group by datetime-like arrays." 
+ ) + + seasons = self.seasons + season_inds = self.season_inds + season_tuples = self.season_tuples + + nstr = max(len(s) for s in seasons) + year = group.dt.year.astype(int) + month = group.dt.month.astype(int) + season_label = np.full(group.shape, "", dtype=f"U{nstr}") + + # offset years for seasons with December and January + for season_str, season_ind in zip(seasons, season_inds, strict=False): + season_label[month.isin(season_ind)] = season_str + if "DJ" in season_str: + after_dec = season_ind[season_str.index("D") + 1 :] + year[month.isin(after_dec)] -= 1 + + frame = pd.DataFrame( + data={"index": np.arange(group.size), "month": month}, + index=pd.MultiIndex.from_arrays( + [year.data, season_label], names=["year", "season"] + ), + ) + + series = frame["index"] + g = series.groupby(["year", "season"], sort=False) + first_items = g.first() + counts = g.count() + + # these are the seasons that are present + unique_coord = pd.DatetimeIndex( + [ + pd.Timestamp(year=year, month=season_tuples[season][0], day=1) + for year, season in first_items.index + ] + ) + + sbins = first_items.values.astype(int) + group_indices = [ + slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False) + ] + group_indices += [slice(sbins[-1], None)] + + # Make sure the first and last timestamps + # are for the correct months,if not we have incomplete seasons + unique_codes = np.arange(len(unique_coord)) + if self.drop_incomplete: + for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False): + stamp_year, stamp_season = frame.index[idx] + code = seasons.index(stamp_season) + stamp_month = season_inds[code][idx] + if stamp_month != month[idx].item(): + # we have an incomplete season! + group_indices = group_indices[slicer] + unique_coord = unique_coord[slicer] + if idx == 0: + unique_codes -= 1 + unique_codes[idx] = -1 + + # all years and seasons + complete_index = pd.DatetimeIndex( + # This sorted call is a hack. It's hard to figure out how + # to start the iteration + sorted( + [ + pd.Timestamp(f"{y}-{m}-01") + for y, m in itertools.product( + range(year[0].item(), year[-1].item() + 1), + [s[0] for s in season_inds], + ) + ] + ) + ) + # only keep that included in data + range_ = complete_index.get_indexer(unique_coord[[0, -1]]) + full_index = complete_index[slice(range_[0], range_[-1] + 1)] + # check that there are no "missing" seasons in the middle + # print(full_index, unique_coord) + if not full_index.equals(unique_coord): + raise ValueError("Are there seasons missing in the middle of the dataset?") + + codes = group.copy(data=np.repeat(unique_codes, counts), deep=False) + unique_coord_var = Variable(group.name, unique_coord, group.attrs) + + return EncodedGroups( + codes=codes, + group_indices=group_indices, + unique_coord=unique_coord_var, + full_index=full_index, + ) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3fc7fcac132..ff33b073ddb 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -22,6 +22,7 @@ Grouper, TimeResampler, UniqueGrouper, + season_to_month_tuple, ) from xarray.namedarray.pycompat import is_chunked_array from xarray.tests import ( @@ -3147,3 +3148,22 @@ def test_groupby_dask_eager_load_warnings(): # 2. grouped-reduce on unique coords is identical to array # 3. group_over == groupby-reduce along other dimensions # 4. 
result is equivalent for transposed input +def test_season_to_month_tuple(): + assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == ( + (1, 2), + (3, 4, 5), + (6, 7, 8, 9), + (10, 11, 12), + ) + assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == ( + (12, 1, 2, 3), + (4, 5), + (6, 7, 8, 9), + (10, 11), + ) + + +# Possible property tests +# 1. lambda x: x +# 2. grouped-reduce on unique coords is identical to array +# 3. group_over == groupby-reduce along other dimensions From 879b49604aa3451f767c31e7ca8a5acaf02306bf Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 20 Sep 2024 17:46:48 -0600 Subject: [PATCH 02/33] Allow sliding seasons --- xarray/core/groupby.py | 13 ++-- xarray/groupers.py | 111 +++++++++++++++++++++++++++-------- xarray/tests/test_groupby.py | 4 +- 3 files changed, 95 insertions(+), 33 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 7a32cd7b1db..f13cda08e30 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -253,6 +253,8 @@ def _ensure_1d( from xarray.core.dataarray import DataArray if isinstance(group, DataArray): + for dim in set(group.dims) - set(obj.dims): + obj = obj.expand_dims(dim) # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) @@ -750,7 +752,7 @@ def __repr__(self) -> str: for grouper in self.groupers: coord = grouper.unique_coord labels = ", ".join(format_array_flat(coord, 30).split()) - text += f"\n {grouper.name!r}: {coord.size}/{grouper.full_index.size} groups present with labels {labels}" + text += f"\n {grouper.name!r}: {type(grouper.grouper).__name__}({grouper.group.name!r}), {coord.size} groups with labels {labels}" return text + ">" def _iter_grouped(self) -> Iterator[T_Xarray]: @@ -974,7 +976,7 @@ def _flox_reduce( parsed_dim_list = list() # preserve order for dim_ in itertools.chain( - *(grouper.group.dims for grouper in self.groupers) + *(grouper.codes.dims for grouper in self.groupers) ): if dim_ not in parsed_dim_list: parsed_dim_list.append(dim_) @@ -988,7 +990,7 @@ def _flox_reduce( # Better to control it here than in flox. for grouper in self.groupers: if any( - d not in grouper.group.dims and d not in obj.dims for d in parsed_dim + d not in grouper.codes.dims and d not in obj.dims for d in parsed_dim ): raise ValueError(f"cannot reduce over dimensions {dim}.") @@ -1232,9 +1234,6 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - if dim is None: - dim = (self._group_dim,) - # Dataset.quantile does this, do it for flox to ensure same output. 
q = np.asarray(q, dtype=np.float64) @@ -1253,7 +1252,7 @@ def quantile( self._obj.__class__.quantile, shortcut=False, q=q, - dim=dim, + dim=dim or self._group_dim, method=method, keep_attrs=keep_attrs, skipna=skipna, diff --git a/xarray/groupers.py b/xarray/groupers.py index fcbadeea2c8..0a101880077 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -9,9 +9,11 @@ import datetime import itertools from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from itertools import pairwise +from itertools import chain from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np @@ -75,9 +77,9 @@ class EncodedGroups: codes: DataArray full_index: pd.Index - group_indices: GroupIndices - unique_coord: Variable | _DummyGroup - coords: Coordinates + group_indices: GroupIndices = field(init=False, repr=False) + unique_coord: Variable | _DummyGroup = field(init=False, repr=False) + coords: Coordinates = field(init=False, repr=False) def __init__( self, @@ -587,6 +589,55 @@ def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...] return tuple(result) +def inds_to_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]: + inits = "JFMAMJJASOND" + return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints) + + +@dataclass +class SeasonsGroup: + seasons: tuple[str, ...] + inds: tuple[tuple[int, ...], ...] + codes: Sequence[int] + + +def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]: + """ + Iterates though a list of seasons e.g. ["DJF", "FMA", ...], + and splits that into multiple sequences of non-overlapping seasons. + """ + sinds = season_to_month_tuple(seasons) + grouped = defaultdict(list) + codes = defaultdict(list) + seen: set[tuple[int, ...]] = set() + idx = 0 + # This is quadratic, but the length of seasons is at most 12 + for i, current in enumerate(sinds): + # Start with a group + if current not in seen: + grouped[idx].append(current) + codes[idx].append(i) + seen.add(current) + + # Loop through remaining groups, and look for overlaps + for j, second in enumerate(sinds[i:]): + if not (set(chain(*grouped[idx])) & set(second)): + if second not in seen: + grouped[idx].append(second) + codes[idx].append(j + i) + seen.add(second) + if len(seen) == len(seasons): + break + # found all non-overlapping groups for this row, increment and start over + idx += 1 + + grouped_ints = tuple(tuple(idx) for idx in grouped.values() if idx) + return [ + SeasonsGroup(seasons=inds_to_string(inds), inds=inds, codes=codes) + for inds, codes in zip(grouped_ints, codes.values(), strict=False) + ] + + @dataclass class SeasonGrouper(Grouper): """Allows grouping using a custom definition of seasons. 
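(A sketch, not part of the patch: what ``find_independent_seasons`` above is
expected to return for the overlapping-season case shown in the docstrings,
assuming the helper is importable from ``xarray.groupers`` on this branch.
Each ``SeasonsGroup`` collects mutually non-overlapping seasons together with
their positions -- ``codes`` -- in the caller's original ordering.)

    from xarray.groupers import find_independent_seasons

    find_independent_seasons(["DJFM", "MAMJ", "JJAS", "SOND"])
    # [SeasonsGroup(seasons=('DJFM', 'JJAS'), inds=((12, 1, 2, 3), (6, 7, 8, 9)), codes=[0, 2]),
    #  SeasonsGroup(seasons=('MAMJ', 'SOND'), inds=((3, 4, 5, 6), (9, 10, 11, 12)), codes=[1, 3])]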
@@ -599,16 +650,20 @@ class SeasonGrouper(Grouper): Examples -------- >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) - >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"]) + SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND']) + + The ordering is preserved + >>> SeasonGrouper(["MAM", "JJAS", "OND", "JF"]) + SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF']) + + Overlapping seasons are allowed + >>> SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"]) + SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND']) """ seasons: Sequence[str] - season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) # drop_incomplete: bool = field(default=True) # TODO - def __post_init__(self) -> None: - self.season_inds = season_to_month_tuple(self.seasons) - def factorize(self, group: T_Group) -> EncodedGroups: if TYPE_CHECKING: assert not isinstance(group, _DummyGroup) @@ -616,27 +671,32 @@ def factorize(self, group: T_Group) -> EncodedGroups: raise ValueError( "SeasonGrouper can only be used to group by datetime-like arrays." ) - - seasons = self.seasons - season_inds = self.season_inds - - months = group.dt.month - codes_ = np.full(group.shape, -1) - group_indices: list[list[int]] = [[]] * len(seasons) - - index = np.arange(group.size) - for idx, season_tuple in enumerate(season_inds): - mask = months.isin(season_tuple) - codes_[mask] = idx - group_indices[idx] = index[mask] + months = group.dt.month.data + seasons_groups = find_independent_seasons(self.seasons) + codes_ = np.full((len(seasons_groups),) + group.shape, -1, dtype=np.int8) + group_indices: list[list[int]] = [[]] * len(self.seasons) + for axis_index, seasgroup in enumerate(seasons_groups): + for season_tuple, code in zip( + seasgroup.inds, seasgroup.codes, strict=False + ): + mask = np.isin(months, season_tuple) + codes_[axis_index, mask] = code + (indices,) = mask.nonzero() + group_indices[code] = indices.tolist() if np.all(codes_ == -1): raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" 
) - codes = group.copy(data=codes_, deep=False).rename("season") - unique_coord = Variable("season", seasons, attrs=group.attrs) - full_index = pd.Index(seasons) + needs_dummy_dim = len(seasons_groups) > 1 + codes = DataArray( + dims=(("__season_dim__",) if needs_dummy_dim else tuple()) + group.dims, + data=codes_ if needs_dummy_dim else codes_.squeeze(), + attrs=group.attrs, + name="season", + ) + unique_coord = Variable("season", self.seasons, attrs=group.attrs) + full_index = pd.Index(self.seasons) return EncodedGroups( codes=codes, group_indices=tuple(group_indices), @@ -662,7 +722,10 @@ class SeasonResampler(Resampler): Examples -------- >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"]) + SeasonResampler(seasons=['JF', 'MAM', 'JJAS', 'OND'], drop_incomplete=True) + >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"]) + SeasonResampler(seasons=['DJFM', 'AM', 'JJA', 'SON'], drop_incomplete=True) """ seasons: Sequence[str] diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ff33b073ddb..b8b8aa45bcb 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -611,7 +611,7 @@ def test_groupby_repr(obj, dim) -> None: N = len(np.unique(obj[dim])) expected = f"<{obj.__class__.__name__}GroupBy" expected += f", grouped over 1 grouper(s), {N} groups in total:" - expected += f"\n {dim!r}: {N}/{N} groups present with labels " + expected += f"\n {dim!r}: UniqueGrouper({dim!r}), {N} groups with labels " if dim == "x": expected += "1, 2, 3, 4, 5>" elif dim == "y": @@ -628,7 +628,7 @@ def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"<{obj.__class__.__name__}GroupBy" expected += ", grouped over 1 grouper(s), 12 groups in total:\n" - expected += " 'month': 12/12 groups present with labels " + expected += " 'month': UniqueGrouper('month'), 12 groups with labels " expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>" assert actual == expected From 8268c468537bb04d490209aac7589e04606311d7 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 21 Sep 2024 20:59:27 -0600 Subject: [PATCH 03/33] cftime support --- xarray/groupers.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 0a101880077..39e8b2db3fe 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -21,14 +21,21 @@ from numpy.typing import ArrayLike from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq +from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops from xarray.core.computation import apply_ufunc from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.coordinates import Coordinates from xarray.core.common import _contains_datetime_like_objects +from xarray.core.common import _contains_datetime_like_objects +from xarray.core.common import ( + _contains_cftime_datetimes, + _contains_datetime_like_objects, +) from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.duck_array_ops import isnull +from xarray.core.formatting import first_n_items from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper @@ -567,7 +574,7 @@ def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...] 
initials = "JFMAMJJASOND" starts = dict( ("".join(s), i + 1) - for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=False) + for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=True) ) result: list[tuple[int, ...]] = [] for i, season in enumerate(seasons): @@ -757,7 +764,7 @@ def factorize(self, group): season_label = np.full(group.shape, "", dtype=f"U{nstr}") # offset years for seasons with December and January - for season_str, season_ind in zip(seasons, season_inds, strict=False): + for season_str, season_ind in zip(seasons, season_inds, strict=True): season_label[month.isin(season_ind)] = season_str if "DJ" in season_str: after_dec = season_ind[season_str.index("D") + 1 :] @@ -775,10 +782,17 @@ def factorize(self, group): first_items = g.first() counts = g.count() + if _contains_cftime_datetimes(group.data): + index_class = CFTimeIndex + datetime_class = type(first_n_items(group.data, 1).item()) + else: + index_class = pd.DatetimeIndex + datetime_class = datetime.datetime + # these are the seasons that are present - unique_coord = pd.DatetimeIndex( + unique_coord = index_class( [ - pd.Timestamp(year=year, month=season_tuples[season][0], day=1) + datetime_class(year=year, month=season_tuples[season][0], day=1) for year, season in first_items.index ] ) @@ -806,12 +820,12 @@ def factorize(self, group): unique_codes[idx] = -1 # all years and seasons - complete_index = pd.DatetimeIndex( + complete_index = index_class( # This sorted call is a hack. It's hard to figure out how # to start the iteration sorted( [ - pd.Timestamp(f"{y}-{m}-01") + datetime_class(year=y, month=m, day=1) for y, m in itertools.product( range(year[0].item(), year[-1].item() + 1), [s[0] for s in season_inds], From 31cc519dda1aef97219889bc374c3a9838ce7c6a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 21 Sep 2024 21:22:23 -0600 Subject: [PATCH 04/33] Add skeleton tests --- xarray/tests/test_groupby.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b8b8aa45bcb..54cd8075328 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -12,7 +12,7 @@ from packaging.version import Version import xarray as xr -from xarray import DataArray, Dataset, Variable +from xarray import DataArray, Dataset, Variable, cftime_range from xarray.core.alignment import broadcast from xarray.core.groupby import _consolidate_slices from xarray.core.types import InterpOptions, ResampleCompatible @@ -20,6 +20,7 @@ BinGrouper, EncodedGroups, Grouper, + SeasonResampler, TimeResampler, UniqueGrouper, season_to_month_tuple, @@ -3163,6 +3164,24 @@ def test_season_to_month_tuple(): ) +def test_season_resampler(): + time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day") + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + + # through resample + da.resample(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).sum() + + # through groupby + da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).sum() + + # skip september + da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum() + + # overlapping + with pytest.raises(ValueError): + da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum() + + # Possible property tests # 1. lambda x: x # 2. 
grouped-reduce on unique coords is identical to array From 96ae241479f561d325345c53c3d0c0126e24c688 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 21 Sep 2024 21:22:33 -0600 Subject: [PATCH 05/33] Support "subsampled" seasons --- xarray/groupers.py | 33 ++++++++++++++++++++++++++------- xarray/tests/test_groupby.py | 3 +++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 39e8b2db3fe..66740baf7d2 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,7 +7,9 @@ from __future__ import annotations import datetime +import functools import itertools +import operator from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Mapping, Sequence @@ -742,7 +744,12 @@ class SeasonResampler(Resampler): def __post_init__(self): self.season_inds = season_to_month_tuple(self.seasons) - self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False)) + all_inds = functools.reduce(operator.add, self.season_inds) + if len(all_inds) > len(set(all_inds)): + raise ValueError( + f"Overlapping seasons are not allowed. Received {self.seasons!r}" + ) + self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True)) def factorize(self, group): if group.ndim != 1: @@ -768,12 +775,22 @@ def factorize(self, group): season_label[month.isin(season_ind)] = season_str if "DJ" in season_str: after_dec = season_ind[season_str.index("D") + 1 :] + # important this is assuming non-overlapping seasons year[month.isin(after_dec)] -= 1 + # Allow users to skip one or more months? + # present_seasons is a mask that is True for months that are requestsed in the output + present_seasons = season_label != "" + if present_seasons.all(): + present_seasons = slice(None) frame = pd.DataFrame( - data={"index": np.arange(group.size), "month": month}, + data={ + "index": np.arange(group[present_seasons].size), + "month": month[present_seasons], + }, index=pd.MultiIndex.from_arrays( - [year.data, season_label], names=["year", "season"] + [year.data[present_seasons], season_label[present_seasons]], + names=["year", "season"], ), ) @@ -799,7 +816,7 @@ def factorize(self, group): sbins = first_items.values.astype(int) group_indices = [ - slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False) + slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True) ] group_indices += [slice(sbins[-1], None)] @@ -807,11 +824,11 @@ def factorize(self, group): # are for the correct months,if not we have incomplete seasons unique_codes = np.arange(len(unique_coord)) if self.drop_incomplete: - for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False): + for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=True): stamp_year, stamp_season = frame.index[idx] code = seasons.index(stamp_season) stamp_month = season_inds[code][idx] - if stamp_month != month[idx].item(): + if stamp_month != month[present_seasons][idx].item(): # we have an incomplete season! 
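                     # drop this end's slice from group_indices and its label
                     # from unique_coord, and mark the season's code as -1
                     # (shifting the remaining codes down when the start is
                     # trimmed) so its observations fall out of the result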
group_indices = group_indices[slicer] unique_coord = unique_coord[slicer] @@ -841,7 +858,9 @@ def factorize(self, group): if not full_index.equals(unique_coord): raise ValueError("Are there seasons missing in the middle of the dataset?") - codes = group.copy(data=np.repeat(unique_codes, counts), deep=False) + final_codes = np.full(group.data.size, -1) + final_codes[present_seasons] = np.repeat(unique_codes, counts) + codes = group.copy(data=final_codes, deep=False) unique_coord_var = Variable(group.name, unique_coord, group.attrs) return EncodedGroups( diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 54cd8075328..79ad7aaa875 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3177,6 +3177,9 @@ def test_season_resampler(): # skip september da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum() + # "subsampling" + da.groupby(time=SeasonResampler(["JJAS"])).sum() + # overlapping with pytest.raises(ValueError): da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum() From 77dc5e0d36cf49ab9d9b14a92b71c439075ffc7d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 21 Sep 2024 21:47:15 -0600 Subject: [PATCH 06/33] small edits --- properties/test_properties.py | 8 +- xarray/groupers.py | 123 ++++++++++++++++------------ xarray/tests/test_groupby.py | 147 ++++++++++++++++++++++++++-------- 3 files changed, 190 insertions(+), 88 deletions(-) diff --git a/properties/test_properties.py b/properties/test_properties.py index 859f9d4e500..24de8049f58 100644 --- a/properties/test_properties.py +++ b/properties/test_properties.py @@ -1,3 +1,5 @@ +import itertools + import pytest pytest.importorskip("hypothesis") @@ -37,12 +39,10 @@ def test_property_season_month_tuple(roll, breaks): if breaks[-1] != 12: breaks = breaks + [12] seasons = tuple( - "".join(rolled_chars[start:stop]) - for start, stop in zip(breaks[:-1], breaks[1:], strict=False) + "".join(rolled_chars[start:stop]) for start, stop in itertools.pairwise(breaks) ) actual = season_to_month_tuple(seasons) expected = tuple( - rolled_months[start:stop] - for start, stop in zip(breaks[:-1], breaks[1:], strict=False) + rolled_months[start:stop] for start, stop in itertools.pairwise(breaks) ) assert expected == actual diff --git a/xarray/groupers.py b/xarray/groupers.py index 66740baf7d2..44a2ad6729f 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -14,8 +14,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from itertools import pairwise -from itertools import chain +from itertools import chain, pairwise from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np @@ -25,16 +24,12 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops -from xarray.core.computation import apply_ufunc -from xarray.core.coordinates import Coordinates, _coordinates_from_variable -from xarray.core.coordinates import Coordinates -from xarray.core.common import _contains_datetime_like_objects -from xarray.core.common import _contains_datetime_like_objects from xarray.core.common import ( _contains_cftime_datetimes, _contains_datetime_like_objects, ) -from xarray.core.coordinates import Coordinates +from xarray.core.computation import apply_ufunc +from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.dataarray import DataArray from 
xarray.core.duck_array_ops import isnull from xarray.core.formatting import first_n_items @@ -751,14 +746,16 @@ def __post_init__(self): ) self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True)) - def factorize(self, group): + def factorize(self, group: T_Group) -> EncodedGroups: if group.ndim != 1: raise ValueError( "SeasonResampler can only be used to resample by 1D arrays." ) - if not _contains_datetime_like_objects(group.variable): + if not isinstance(group, DataArray) or not _contains_datetime_like_objects( + group.variable + ): raise ValueError( - "SeasonResampler can only be used to group by datetime-like arrays." + "SeasonResampler can only be used to group by datetime-like DataArrays." ) seasons = self.seasons @@ -775,13 +772,14 @@ def factorize(self, group): season_label[month.isin(season_ind)] = season_str if "DJ" in season_str: after_dec = season_ind[season_str.index("D") + 1 :] - # important this is assuming non-overlapping seasons + # important: this is assuming non-overlapping seasons year[month.isin(after_dec)] -= 1 # Allow users to skip one or more months? - # present_seasons is a mask that is True for months that are requestsed in the output + # present_seasons is a mask that is True for months that are requested in the output present_seasons = season_label != "" if present_seasons.all(): + # avoid copies if we can. present_seasons = slice(None) frame = pd.DataFrame( data={ @@ -794,10 +792,13 @@ def factorize(self, group): ), ) - series = frame["index"] - g = series.groupby(["year", "season"], sort=False) - first_items = g.first() - counts = g.count() + agged = ( + frame["index"] + .groupby(["year", "season"], sort=False) + .agg(["first", "count"]) + ) + first_items = agged["first"] + counts = agged["count"] if _contains_cftime_datetimes(group.data): index_class = CFTimeIndex @@ -814,32 +815,18 @@ def factorize(self, group): ] ) - sbins = first_items.values.astype(int) - group_indices = [ - slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True) - ] - group_indices += [slice(sbins[-1], None)] - - # Make sure the first and last timestamps - # are for the correct months,if not we have incomplete seasons - unique_codes = np.arange(len(unique_coord)) - if self.drop_incomplete: - for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=True): - stamp_year, stamp_season = frame.index[idx] - code = seasons.index(stamp_season) - stamp_month = season_inds[code][idx] - if stamp_month != month[present_seasons][idx].item(): - # we have an incomplete season! - group_indices = group_indices[slicer] - unique_coord = unique_coord[slicer] - if idx == 0: - unique_codes -= 1 - unique_codes[idx] = -1 - - # all years and seasons + # sbins = first_items.values.astype(int) + # group_indices = [ + # slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True) + # ] + # group_indices += [slice(sbins[-1], None)] + + # This sorted call is a hack. It's hard to figure out how + # to start the iteration for arbitrary season ordering + # for example "DJF" as first entry or last entry + # So we construct the largest possible index and slice it to the + # range present in the data. complete_index = index_class( - # This sorted call is a hack. 
It's hard to figure out how - # to start the iteration sorted( [ datetime_class(year=y, month=m, day=1) @@ -850,22 +837,56 @@ def factorize(self, group): ] ) ) - # only keep that included in data - range_ = complete_index.get_indexer(unique_coord[[0, -1]]) - full_index = complete_index[slice(range_[0], range_[-1] + 1)] + + # all years and seasons + def get_label(year, season): + month = season_tuples[season][0] + return f"{year}-{month}-01" + + unique_codes = np.arange(len(unique_coord)) + first_valid_season = season_label[0] + last_valid_season = season_label[-1] + first_year, last_year = year.data[[0, -1]] + if self.drop_incomplete: + if month.data[0] != season_tuples[first_valid_season][0]: + if "DJ" in first_valid_season: + first_year += 1 + first_valid_season = seasons[ + (seasons.index(first_valid_season) + 1) % len(seasons) + ] + # group_indices = group_indices[slice(1, None)] + unique_codes -= 1 + + if month.data[-1] != season_tuples[last_valid_season][-1]: + last_valid_season = seasons[seasons.index(last_valid_season) - 1] + if "DJ" in last_valid_season: + last_year -= 1 + # group_indices = group_indices[slice(-1)] + unique_codes[-1] = -1 + + first_label = get_label(first_year, first_valid_season) + last_label = get_label(last_year, last_valid_season) + + slicer = complete_index.slice_indexer(first_label, last_label) + full_index = complete_index[slicer] + # TODO: group must be sorted + # codes = np.searchsorted(edges, group.data, side="left") + # codes -= 1 + # codes[~present_seasons | group.data >= edges[-1]] = -1 + # codes[isnull(group.data)] = -1 + # import ipdb; ipdb.set_trace() # check that there are no "missing" seasons in the middle - # print(full_index, unique_coord) - if not full_index.equals(unique_coord): - raise ValueError("Are there seasons missing in the middle of the dataset?") + # if not full_index.equals(unique_coord): + # raise ValueError("Are there seasons missing in the middle of the dataset?") final_codes = np.full(group.data.size, -1) final_codes[present_seasons] = np.repeat(unique_codes, counts) codes = group.copy(data=final_codes, deep=False) - unique_coord_var = Variable(group.name, unique_coord, group.attrs) + # unique_coord_var = Variable(group.name, unique_coord, group.attrs) return EncodedGroups( codes=codes, - group_indices=group_indices, - unique_coord=unique_coord_var, + # group_indices=group_indices, + # unique_coord=unique_coord_var, full_index=full_index, ) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 79ad7aaa875..ea1b9c4e6b7 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -12,7 +12,7 @@ from packaging.version import Version import xarray as xr -from xarray import DataArray, Dataset, Variable, cftime_range +from xarray import DataArray, Dataset, Variable, cftime_range, date_range from xarray.core.alignment import broadcast from xarray.core.groupby import _consolidate_slices from xarray.core.types import InterpOptions, ResampleCompatible @@ -20,6 +20,7 @@ BinGrouper, EncodedGroups, Grouper, + SeasonGrouper, SeasonResampler, TimeResampler, UniqueGrouper, @@ -44,6 +45,7 @@ requires_pandas_ge_2_2, requires_scipy, ) +from xarray.tests.test_coding_times import _ALL_CALENDARS @pytest.fixture @@ -3144,48 +3146,127 @@ def test_groupby_dask_eager_load_warnings(): ds.groupby_bins("x", bins=[1, 2, 3], eagerly_compute_group=False) -# TODO: Possible property tests to add to this module -# 1. lambda x: x -# 2. grouped-reduce on unique coords is identical to array -# 3. 
group_over == groupby-reduce along other dimensions -# 4. result is equivalent for transposed input -def test_season_to_month_tuple(): - assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == ( - (1, 2), - (3, 4, 5), - (6, 7, 8, 9), - (10, 11, 12), - ) - assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == ( - (12, 1, 2, 3), - (4, 5), - (6, 7, 8, 9), - (10, 11), +class TestSeasonGrouperAndResampler: + def test_season_to_month_tuple(self): + assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == ( + (1, 2), + (3, 4, 5), + (6, 7, 8, 9), + (10, 11, 12), + ) + assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == ( + (12, 1, 2, 3), + (4, 5), + (6, 7, 8, 9), + (10, 11), + ) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_simple(self, calendar) -> None: + time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + expected = da.groupby("time.season").mean() + # note season order matches expected + actual = da.groupby( + time=SeasonGrouper( + ["DJF", "JJA", "MAM", "SON"], # drop_incomplete=False + ) + ).mean() + assert_identical(expected, actual) + + # TODO: drop_incomplete + @requires_cftime + @pytest.mark.parametrize("drop_incomplete", [True, False]) + @pytest.mark.parametrize( + "seasons", + [ + pytest.param(["DJF", "MAM", "JJA", "SON"], id="standard"), + pytest.param(["MAM", "JJA", "SON", "DJF"], id="standard-diff-order"), + pytest.param(["JFM", "AMJ", "JAS", "OND"], id="december-same-year"), + pytest.param(["DJF", "MAM", "JJA", "ON"], id="skip-september"), + pytest.param(["JJAS"], id="jjas-only"), + pytest.param(["MAM", "JJA", "SON", "DJF"], id="different-order"), + pytest.param(["JJA", "MAM", "SON", "DJF"], id="out-of-order"), + ], ) + def test_season_resampler(self, seasons: list[str], drop_incomplete: bool) -> None: + calendar = "standard" + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + counts = da.resample(time="ME").count() + + seasons_as_ints = season_to_month_tuple(seasons) + month = counts.time.dt.month.data + year = counts.time.dt.year.data + for season, as_ints in zip(seasons, seasons_as_ints, strict=True): + if "DJ" in season: + for imonth in as_ints[season.index("D") + 1 :]: + year[month == imonth] -= 1 + counts["time"] = ( + "time", + [pd.Timestamp(f"{y}-{m}-01") for y, m in zip(year, month, strict=True)], + ) + counts = counts.convert_calendar(calendar, "time", align_on="date") + + expected_vals = [] + expected_time = [] + for year in [2001, 2002]: + for season, as_ints in zip(seasons, seasons_as_ints, strict=True): + out_year = year + if "DJ" in season: + out_year = year - 1 + available = [ + counts.sel(time=f"{out_year}-{month:02d}").data for month in as_ints + ] + if any(len(a) == 0 for a in available) and drop_incomplete: + continue + output_label = pd.Timestamp(f"{out_year}-{as_ints[0]:02d}-01") + expected_time.append(output_label) + # use concatenate to handle empty array when dec value does not exist + expected_vals.append(np.concatenate(available).sum()) + expected = xr.DataArray( + expected_vals, dims="time", coords={"time": expected_time} + ).convert_calendar(calendar, align_on="date") + rs = SeasonResampler(seasons, drop_incomplete=drop_incomplete) + # through resample + actual = da.resample(time=rs).sum() + assert_identical(actual, expected) -def test_season_resampler(): - time = cftime_range("2001-01-01", 
"2002-12-30", freq="D", calendar="360_day") - da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + def test_season_resampler_errors(self): + time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day") + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) - # through resample - da.resample(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).sum() + # non-datetime array + with pytest.raises(ValueError): + DataArray(np.ones(5), dims="time").groupby(time=SeasonResampler(["DJF"])) - # through groupby - da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).sum() + # ndim > 1 array + with pytest.raises(ValueError): + DataArray( + np.ones((5, 5)), dims=("t", "x"), coords={"x": np.arange(5)} + ).groupby(x=SeasonResampler(["DJF"])) - # skip september - da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum() + # overlapping seasons + with pytest.raises(ValueError): + da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum() - # "subsampling" - da.groupby(time=SeasonResampler(["JJAS"])).sum() + @requires_cftime + def test_season_resampler_groupby_identical(self): + time = date_range("2001-01-01", "2002-12-30", freq="D") + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) - # overlapping - with pytest.raises(ValueError): - da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum() + # through resample + resampler = SeasonResampler(["DJF", "MAM", "JJA", "SON"]) + rs = da.resample(time=resampler).sum() + # through groupby + gb = da.groupby(time=resampler).sum() + assert_identical(rs, gb) -# Possible property tests + +# TODO: Possible property tests to add to this module # 1. lambda x: x # 2. grouped-reduce on unique coords is identical to array # 3. group_over == groupby-reduce along other dimensions +# 4. 
result is equivalent for transposed input From d68b1e4c21ec9e622ab5a33271e114a86122ce11 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 12 Nov 2024 13:43:35 -0700 Subject: [PATCH 07/33] Add reset --- xarray/groupers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 44a2ad6729f..6d4036d196a 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -708,6 +708,9 @@ def factorize(self, group: T_Group) -> EncodedGroups: full_index=full_index, ) + def reset(self) -> Self: + return type(self)(self.seasons) + @dataclass class SeasonResampler(Resampler): @@ -733,7 +736,7 @@ class SeasonResampler(Resampler): """ seasons: Sequence[str] - drop_incomplete: bool = field(default=True) + drop_incomplete: bool = field(default=True, kw_only=True) season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False) @@ -890,3 +893,6 @@ def get_label(year, season): # unique_coord=unique_coord_var, full_index=full_index, ) + + def reset(self) -> Self: + return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete) From 1b7a9fcdba89fe88c195f65c12f62806673c3d79 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 13 Nov 2024 20:35:36 -0700 Subject: [PATCH 08/33] Fix tests --- xarray/groupers.py | 13 ++++++++----- xarray/tests/test_groupby.py | 16 ++++++++++++---- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 6d4036d196a..523a150e6b1 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -15,7 +15,7 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from itertools import chain, pairwise -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, Self, cast import numpy as np import pandas as pd @@ -847,11 +847,11 @@ def get_label(year, season): return f"{year}-{month}-01" unique_codes = np.arange(len(unique_coord)) - first_valid_season = season_label[0] - last_valid_season = season_label[-1] + valid_season_mask = season_label != "" + first_valid_season, last_valid_season = season_label[valid_season_mask][[0, -1]] first_year, last_year = year.data[[0, -1]] if self.drop_incomplete: - if month.data[0] != season_tuples[first_valid_season][0]: + if month.data[valid_season_mask][0] != season_tuples[first_valid_season][0]: if "DJ" in first_valid_season: first_year += 1 first_valid_season = seasons[ @@ -860,7 +860,10 @@ def get_label(year, season): # group_indices = group_indices[slice(1, None)] unique_codes -= 1 - if month.data[-1] != season_tuples[last_valid_season][-1]: + if ( + month.data[valid_season_mask][-1] + != season_tuples[last_valid_season][-1] + ): last_valid_season = seasons[seasons.index(last_valid_season) - 1] if "DJ" in last_valid_season: last_year -= 1 diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ea1b9c4e6b7..730b7a2a2d8 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3210,11 +3210,14 @@ def test_season_resampler(self, seasons: list[str], drop_incomplete: bool) -> No expected_vals = [] expected_time = [] - for year in [2001, 2002]: + for year in [2001, 2002, 2003]: for season, as_ints in zip(seasons, seasons_as_ints, strict=True): out_year = year if "DJ" in season: out_year = year - 1 + if out_year == 2003: + # this is a dummy year added to make sure we cover 2002-DJF + continue available = [ 
counts.sel(time=f"{out_year}-{month:02d}").data for month in as_ints ] @@ -3225,9 +3228,14 @@ def test_season_resampler(self, seasons: list[str], drop_incomplete: bool) -> No # use concatenate to handle empty array when dec value does not exist expected_vals.append(np.concatenate(available).sum()) - expected = xr.DataArray( - expected_vals, dims="time", coords={"time": expected_time} - ).convert_calendar(calendar, align_on="date") + expected = ( + # we construct expected in the standard calendar + xr.DataArray(expected_vals, dims="time", coords={"time": expected_time}) + # and then convert to the expected calendar, + .convert_calendar(calendar, align_on="date") + # and finally sort since DJF will be out-of-order + .sortby("time") + ) rs = SeasonResampler(seasons, drop_incomplete=drop_incomplete) # through resample actual = da.resample(time=rs).sum() From be5f9337fede955ab8cbc9fdaeb077e3cb643549 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 14 Nov 2024 10:42:03 -0700 Subject: [PATCH 09/33] Raise if seasons are not sorted for resampling --- xarray/groupers.py | 32 ++++++++++++++++++++++++++++++++ xarray/tests/test_groupby.py | 9 ++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 523a150e6b1..3cf67db19b3 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -598,6 +598,32 @@ def inds_to_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]: return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints) +def is_sorted_periodic(lst): + n = len(lst) + + # Find the wraparound point where the list decreases + wrap_point = -1 + for i in range(1, n): + if lst[i] < lst[i - 1]: + wrap_point = i + break + + # If no wraparound point is found, the list is already sorted + if wrap_point == -1: + return True + + # Check if both parts around the wrap point are sorted + for i in range(1, wrap_point): + if lst[i] < lst[i - 1]: + return False + for i in range(wrap_point + 1, n): + if lst[i] < lst[i - 1]: + return False + + # Check wraparound condition + return lst[-1] <= lst[0] + + @dataclass class SeasonsGroup: seasons: tuple[str, ...] @@ -749,6 +775,12 @@ def __post_init__(self): ) self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True)) + if not is_sorted_periodic(list(itertools.chain(*self.season_inds))): + raise ValueError( + "Resampling is only supported with sorted seasons. " + f"Provided seasons {self.seasons!r} are not sorted." 
+ ) + def factorize(self, group: T_Group) -> EncodedGroups: if group.ndim != 1: raise ValueError( diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 730b7a2a2d8..6f0e0503851 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3174,6 +3174,14 @@ def test_season_grouper_simple(self, calendar) -> None: ).mean() assert_identical(expected, actual) + @pytest.mark.parametrize("seasons", [["JJA", "MAM", "SON", "DJF"]]) + def test_season_resampling_raises_unsorted_seasons(self, seasons): + calendar = "standard" + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + with pytest.raises(ValueError, match="sort"): + da.resample(time=SeasonResampler(seasons)) + # TODO: drop_incomplete @requires_cftime @pytest.mark.parametrize("drop_incomplete", [True, False]) @@ -3186,7 +3194,6 @@ def test_season_grouper_simple(self, calendar) -> None: pytest.param(["DJF", "MAM", "JJA", "ON"], id="skip-september"), pytest.param(["JJAS"], id="jjas-only"), pytest.param(["MAM", "JJA", "SON", "DJF"], id="different-order"), - pytest.param(["JJA", "MAM", "SON", "DJF"], id="out-of-order"), ], ) def test_season_resampler(self, seasons: list[str], drop_incomplete: bool) -> None: From bd21b48d8c5a239e96ffea3e4faa2accbd259bf9 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 14 Nov 2024 12:21:19 -0700 Subject: [PATCH 10/33] fix Self import --- xarray/groupers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 3cf67db19b3..3c657884634 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -15,7 +15,7 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from itertools import chain, pairwise -from typing import TYPE_CHECKING, Any, Literal, Self, cast +from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np import pandas as pd @@ -42,6 +42,7 @@ DatetimeLike, GroupIndices, ResampleCompatible, + Self, SideOptions, ) from xarray.core.variable import Variable From 09640b78c1c43eeda117361a920c743732bb5970 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 14 Nov 2024 12:31:10 -0700 Subject: [PATCH 11/33] Redo calendar fixtures --- xarray/tests/__init__.py | 47 +++++++++++++++++++++++++++++-- xarray/tests/test_accessor_dt.py | 25 ++++++---------- xarray/tests/test_cftimeindex.py | 10 +++---- xarray/tests/test_coding_times.py | 40 ++++++-------------------- 4 files changed, 66 insertions(+), 56 deletions(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 5ed334e61dd..349389a6448 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -16,6 +16,7 @@ import xarray.testing from xarray import Dataset +from xarray.coding.times import _STANDARD_CALENDARS as _STANDARD_CALENDARS_UNSORTED from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 from xarray.core.extension_array import PandasExtensionArray from xarray.core.options import set_options @@ -353,12 +354,52 @@ def create_test_data( _CFTIME_CALENDARS = [ + pytest.param( + cal, marks=pytest.mark.skipif(not has_cftime, reason="requires cftime") + ) + for cal in sorted( + [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", + "standard", + ] + ) +] + +_STANDARD_CALENDAR_NAMES = sorted(_STANDARD_CALENDARS_UNSORTED) +_NON_STANDARD_CALENDAR_NAMES = { + "noleap", "365_day", "360_day", "julian", "all_leap", 
"366_day", - "gregorian", - "proleptic_gregorian", - "standard", +} +_NON_STANDARD_CALENDARS = [ + pytest.param( + cal, marks=pytest.mark.skipif(not has_cftime, reason="requires cftime") + ) + for cal in sorted(_NON_STANDARD_CALENDAR_NAMES) ] +_STANDARD_CALENDARS = [pytest.param(_) for _ in _STANDARD_CALENDAR_NAMES] +_ALL_CALENDARS = sorted(_STANDARD_CALENDARS + _NON_STANDARD_CALENDARS) + + +def _all_cftime_date_types(): + import cftime + + return { + "noleap": cftime.DatetimeNoLeap, + "365_day": cftime.DatetimeNoLeap, + "360_day": cftime.Datetime360Day, + "julian": cftime.DatetimeJulian, + "all_leap": cftime.DatetimeAllLeap, + "366_day": cftime.DatetimeAllLeap, + "gregorian": cftime.DatetimeGregorian, + "proleptic_gregorian": cftime.DatetimeProlepticGregorian, + } diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 64309966103..0d51a292be1 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -6,6 +6,8 @@ import xarray as xr from xarray.tests import ( + _CFTIME_CALENDARS, + _all_cftime_date_types, assert_allclose, assert_array_equal, assert_chunks_equal, @@ -390,15 +392,6 @@ def test_dask_accessor_method(self, method, parameters) -> None: assert_equal(actual.compute(), expected.compute()) -_CFTIME_CALENDARS = [ - "365_day", - "360_day", - "julian", - "all_leap", - "366_day", - "gregorian", - "proleptic_gregorian", -] _NT = 100 @@ -407,6 +400,13 @@ def calendar(request): return request.param +@pytest.fixture() +def cftime_date_type(calendar): + if calendar == "standard": + calendar = "proleptic_gregorian" + return _all_cftime_date_types()[calendar] + + @pytest.fixture() def times(calendar): import cftime @@ -571,13 +571,6 @@ def test_dask_field_access(times_3d, data, field) -> None: assert_equal(result.compute(), expected) -@pytest.fixture() -def cftime_date_type(calendar): - from xarray.tests.test_coding_times import _all_cftime_date_types - - return _all_cftime_date_types()[calendar] - - @requires_cftime def test_seasons(cftime_date_type) -> None: dates = xr.DataArray( diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 116487e2bcf..03ea5e544ed 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -18,16 +18,14 @@ parse_iso8601_like, ) from xarray.tests import ( + _ALL_CALENDARS, + _NON_STANDARD_CALENDAR_NAMES, + _all_cftime_date_types, assert_array_equal, assert_identical, has_cftime, requires_cftime, ) -from xarray.tests.test_coding_times import ( - _ALL_CALENDARS, - _NON_STANDARD_CALENDARS, - _all_cftime_date_types, -) # cftime 1.5.2 renames "gregorian" to "standard" standard_or_gregorian = "" @@ -1161,7 +1159,7 @@ def test_to_datetimeindex(calendar, unsafe): index = xr.cftime_range("2000", periods=5, calendar=calendar) expected = pd.date_range("2000", periods=5) - if calendar in _NON_STANDARD_CALENDARS and not unsafe: + if calendar in _NON_STANDARD_CALENDAR_NAMES and not unsafe: with pytest.warns(RuntimeWarning, match="non-standard"): result = index.to_datetimeindex() else: diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 9a51ca40d07..90c5724782e 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -19,7 +19,6 @@ date_range, decode_cf, ) -from xarray.coding.times import _STANDARD_CALENDARS as _STANDARD_CALENDARS_UNSORTED from xarray.coding.times import ( CFDatetimeCoder, _encode_datetime_with_cftime, @@ -42,7 +41,12 @@ from xarray.core.utils import is_duck_dask_array from 
xarray.testing import assert_equal, assert_identical from xarray.tests import ( + _ALL_CALENDARS, + _NON_STANDARD_CALENDARS, + _STANDARD_CALENDAR_NAMES, + _STANDARD_CALENDARS, FirstElementAccessibleArray, + _all_cftime_date_types, arm_xfail, assert_array_equal, assert_duckarray_allclose, @@ -53,17 +57,6 @@ requires_dask, ) -_NON_STANDARD_CALENDARS_SET = { - "noleap", - "365_day", - "360_day", - "julian", - "all_leap", - "366_day", -} -_STANDARD_CALENDARS = sorted(_STANDARD_CALENDARS_UNSORTED) -_ALL_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET.union(_STANDARD_CALENDARS)) -_NON_STANDARD_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET) _CF_DATETIME_NUM_DATES_UNITS = [ (np.arange(10), "days since 2000-01-01"), (np.arange(10).astype("float64"), "days since 2000-01-01"), @@ -99,26 +92,11 @@ _CF_DATETIME_TESTS = [ num_dates_units + (calendar,) for num_dates_units, calendar in product( - _CF_DATETIME_NUM_DATES_UNITS, _STANDARD_CALENDARS + _CF_DATETIME_NUM_DATES_UNITS, _STANDARD_CALENDAR_NAMES ) ] -def _all_cftime_date_types(): - import cftime - - return { - "noleap": cftime.DatetimeNoLeap, - "365_day": cftime.DatetimeNoLeap, - "360_day": cftime.Datetime360Day, - "julian": cftime.DatetimeJulian, - "all_leap": cftime.DatetimeAllLeap, - "366_day": cftime.DatetimeAllLeap, - "gregorian": cftime.DatetimeGregorian, - "proleptic_gregorian": cftime.DatetimeProlepticGregorian, - } - - @requires_cftime @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") @pytest.mark.filterwarnings("ignore:Times can't be serialized faithfully") @@ -666,13 +644,13 @@ def test_decode_cf(calendar) -> None: ds[v].attrs["units"] = "days since 2001-01-01" ds[v].attrs["calendar"] = calendar - if not has_cftime and calendar not in _STANDARD_CALENDARS: + if not has_cftime and calendar not in _STANDARD_CALENDAR_NAMES: with pytest.raises(ValueError): ds = decode_cf(ds) else: ds = decode_cf(ds) - if calendar not in _STANDARD_CALENDARS: + if calendar not in _STANDARD_CALENDAR_NAMES: assert ds.test.dtype == np.dtype("O") else: assert ds.test.dtype == np.dtype("M8[ns]") @@ -1006,7 +984,7 @@ def test_decode_ambiguous_time_warns(calendar) -> None: # we don't decode non-standard calendards with # pandas so expect no warning will be emitted - is_standard_calendar = calendar in _STANDARD_CALENDARS + is_standard_calendar = calendar in _STANDARD_CALENDAR_NAMES dates = [1, 2, 3] units = "days since 1-1-1" From 8773faf5176cc53819bb2a267cd6f80728db7d14 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 14 Nov 2024 12:31:37 -0700 Subject: [PATCH 12/33] fix test --- xarray/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 6f0e0503851..a8fb905100a 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -28,6 +28,7 @@ ) from xarray.namedarray.pycompat import is_chunked_array from xarray.tests import ( + _ALL_CALENDARS, InaccessibleArray, assert_allclose, assert_equal, @@ -45,7 +46,6 @@ requires_pandas_ge_2_2, requires_scipy, ) -from xarray.tests.test_coding_times import _ALL_CALENDARS @pytest.fixture From 879af5950049ef791388e2ec7d672e7dbe107387 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 14 Nov 2024 20:58:56 -0700 Subject: [PATCH 13/33] cftime tests --- xarray/groupers.py | 4 ++-- xarray/tests/test_groupby.py | 26 +++++++++++++++++++++----- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 3c657884634..3a1ea1aebea 100644 --- 
a/xarray/groupers.py +++ b/xarray/groupers.py @@ -876,8 +876,8 @@ def factorize(self, group: T_Group) -> EncodedGroups: # all years and seasons def get_label(year, season): - month = season_tuples[season][0] - return f"{year}-{month}-01" + month, *_ = season_tuples[season] + return f"{year}-{month:02d}-01" unique_codes = np.arange(len(unique_coord)) valid_season_mask = season_label != "" diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index a8fb905100a..ca27b3643d6 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3182,13 +3182,21 @@ def test_season_resampling_raises_unsorted_seasons(self, seasons): with pytest.raises(ValueError, match="sort"): da.resample(time=SeasonResampler(seasons)) - # TODO: drop_incomplete - @requires_cftime + @pytest.mark.parametrize( + "use_cftime", + [ + pytest.param( + True, marks=pytest.mark.skipif(not has_cftime, reason="no cftime") + ), + False, + ], + ) @pytest.mark.parametrize("drop_incomplete", [True, False]) @pytest.mark.parametrize( "seasons", [ pytest.param(["DJF", "MAM", "JJA", "SON"], id="standard"), + pytest.param(["NDJ", "FMA", "MJJ", "ASO"], id="nov-first"), pytest.param(["MAM", "JJA", "SON", "DJF"], id="standard-diff-order"), pytest.param(["JFM", "AMJ", "JAS", "OND"], id="december-same-year"), pytest.param(["DJF", "MAM", "JJA", "ON"], id="skip-september"), @@ -3196,9 +3204,17 @@ def test_season_resampling_raises_unsorted_seasons(self, seasons): pytest.param(["MAM", "JJA", "SON", "DJF"], id="different-order"), ], ) - def test_season_resampler(self, seasons: list[str], drop_incomplete: bool) -> None: + def test_season_resampler( + self, seasons: list[str], drop_incomplete: bool, use_cftime: bool + ) -> None: calendar = "standard" - time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + time = date_range( + "2001-01-01", + "2002-12-30", + freq="D", + calendar=calendar, + use_cftime=use_cftime, + ) da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) counts = da.resample(time="ME").count() @@ -3239,7 +3255,7 @@ def test_season_resampler(self, seasons: list[str], drop_incomplete: bool) -> No # we construct expected in the standard calendar xr.DataArray(expected_vals, dims="time", coords={"time": expected_time}) # and then convert to the expected calendar, - .convert_calendar(calendar, align_on="date") + .convert_calendar(calendar, align_on="date", use_cftime=use_cftime) # and finally sort since DJF will be out-of-order .sortby("time") ) From 2ca67daa536dcb0ca758e28e19126d1594871648 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Nov 2024 21:06:55 -0700 Subject: [PATCH 14/33] Fix doctest --- xarray/core/dataarray.py | 6 +++--- xarray/core/dataset.py | 6 +++--- xarray/core/groupby.py | 5 ++++- xarray/tests/test_groupby.py | 4 ++-- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 52ce2463d51..7320daa11a5 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6850,7 +6850,7 @@ def groupby( >>> da.groupby("letters") <DataArrayGroupBy, grouped over 1 grouper(s), 2 groups in total: - 'letters': UniqueGrouper('letters'), 2 groups with labels 'a', 'b'> + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b'> Execute a reduction @@ -6866,8 +6866,8 @@ >>> da.groupby(["letters", "x"]) <DataArrayGroupBy, grouped over 2 grouper(s), 8 groups in total: - 'letters': UniqueGrouper('letters'), 2 groups with labels 'a', 'b' - 'x': UniqueGrouper('x'), 4 groups with labels 10, 20, 30, 40> + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b' + 'x': UniqueGrouper('x'), 4/4 groups with labels 10, 20, 30, 40> Use Grouper objects to express more complicated GroupBy operations diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a09a857e331..be342fcddbb
100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10444,7 +10444,7 @@ def groupby( >>> ds.groupby("letters") <DatasetGroupBy, grouped over 1 grouper(s), 2 groups in total: - 'letters': UniqueGrouper('letters'), 2 groups with labels 'a', 'b'> + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b'> Execute a reduction @@ -10461,8 +10461,8 @@ >>> ds.groupby(["letters", "x"]) <DatasetGroupBy, grouped over 2 grouper(s), 8 groups in total: - 'letters': UniqueGrouper('letters'), 2 groups with labels 'a', 'b' - 'x': UniqueGrouper('x'), 4 groups with labels 10, 20, 30, 40> + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b' + 'x': UniqueGrouper('x'), 4/4 groups with labels 10, 20, 30, 40> Use Grouper objects to express more complicated GroupBy operations diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index f13cda08e30..74f954bc308 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -752,7 +752,10 @@ def __repr__(self) -> str: for grouper in self.groupers: coord = grouper.unique_coord labels = ", ".join(format_array_flat(coord, 30).split()) - text += f"\n {grouper.name!r}: {type(grouper.grouper).__name__}({grouper.group.name!r}), {coord.size} groups with labels {labels}" + text += ( + f"\n {grouper.name!r}: {type(grouper.grouper).__name__}({grouper.group.name!r}), " + f"{coord.size}/{grouper.full_index.size} groups with labels {labels}" + ) return text + ">" def _iter_grouped(self) -> Iterator[T_Xarray]: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ca27b3643d6..e9ec82e25d4 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -614,7 +614,7 @@ def test_groupby_repr(obj, dim) -> None: N = len(np.unique(obj[dim])) expected = f"<{obj.__class__.__name__}GroupBy" expected += f", grouped over 1 grouper(s), {N} groups in total:" - expected += f"\n {dim!r}: UniqueGrouper({dim!r}), {N} groups with labels " + expected += f"\n {dim!r}: UniqueGrouper({dim!r}), {N}/{N} groups with labels " if dim == "x": expected += "1, 2, 3, 4, 5>" elif dim == "y": @@ -631,7 +631,7 @@ def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"<{obj.__class__.__name__}GroupBy" expected += ", grouped over 1 grouper(s), 12 groups in total:\n" - expected += " 'month': UniqueGrouper('month'), 12 groups with labels " + expected += " 'month': UniqueGrouper('month'), 12/12 groups with labels " expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>" assert actual == expected From f5191e5084d49d96cac67d89ae028241f6af19cd Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Nov 2024 21:08:57 -0700 Subject: [PATCH 15/33] typing --- xarray/groupers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/groupers.py b/xarray/groupers.py index 3a1ea1aebea..60c7575ab42 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -836,6 +836,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: first_items = agged["first"] counts = agged["count"] + index_class: CFTimeIndex | pd.DatetimeIndex if _contains_cftime_datetimes(group.data): index_class = CFTimeIndex datetime_class = type(first_n_items(group.data, 1).item()) From 2512d53231538b6ecbae5bae47fd70be4517b474 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Nov 2024 21:12:47 -0700 Subject: [PATCH 16/33] fix test --- xarray/tests/test_groupby.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e9ec82e25d4..5e1a4401d0d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3163,7 +3163,7 @@ def test_season_to_month_tuple(self): @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_season_grouper_simple(self, calendar) -> None: - time = cftime_range("2001-01-01",
"2002-12-30", freq="D", calendar=calendar) + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) expected = da.groupby("time.season").mean() # note season order matches expected @@ -3229,7 +3229,8 @@ def test_season_resampler( "time", [pd.Timestamp(f"{y}-{m}-01") for y, m in zip(year, month, strict=True)], ) - counts = counts.convert_calendar(calendar, "time", align_on="date") + if has_cftime: + counts = counts.convert_calendar(calendar, "time", align_on="date") expected_vals = [] expected_time = [] @@ -3254,16 +3255,21 @@ def test_season_resampler( expected = ( # we construct expected in the standard calendar xr.DataArray(expected_vals, dims="time", coords={"time": expected_time}) - # and then convert to the expected calendar, - .convert_calendar(calendar, align_on="date", use_cftime=use_cftime) - # and finally sort since DJF will be out-of-order - .sortby("time") ) + if has_cftime: + # and then convert to the expected calendar, + expected = expected.convert_calendar( + calendar, align_on="date", use_cftime=use_cftime + ) + # and finally sort since DJF will be out-of-order + expected = expected.sortby("time") + rs = SeasonResampler(seasons, drop_incomplete=drop_incomplete) # through resample actual = da.resample(time=rs).sum() assert_identical(actual, expected) + @requires_cftime def test_season_resampler_errors(self): time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day") da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) From b38553251fb1615c7ddee8699797223f52f89fd7 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Wed, 20 Nov 2024 09:48:05 -0800 Subject: [PATCH 17/33] Add tests for SeasonGrouper API (PR #9524) (#40) * Add tests for SeasonalGrouper API * Add more tests --- xarray/tests/test_groupby.py | 228 +++++++++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 5e1a4401d0d..1b1e9a00bc1 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3161,6 +3161,234 @@ def test_season_to_month_tuple(self): (10, 11), ) + def test_season_grouper_raises_error_if_months_are_not_valid_or_not_continuous( + self, + ): + calendar = "standard" + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + + with pytest.raises(KeyError, match="IN"): + da.groupby(time=SeasonGrouper(["INVALID_SEASON"])) + + with pytest.raises(KeyError, match="MD"): + da.groupby(time=SeasonGrouper(["MDF"])) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_with_months_spanning_calendar_year_using_same_year( + self, calendar + ): + time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + data = np.array( + [ + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 1.1, + 1.35, + 1.6, + 1.85, + 1.2, + 1.45, + 1.7, + 1.95, + 1.05, + 1.3, + 1.55, + 1.8, + 1.15, + 1.4, + 1.65, + 1.9, + 1.25, + 1.5, + 1.75, + ] + ) + da = DataArray(data, dims="time", coords={"time": time}) + da["year"] = da.time.dt.year + + actual = da.groupby( + year=UniqueGrouper(), time=SeasonGrouper(["NDJFM", "AMJ"]) + ).mean() + + # Expected if the same year "ND" is used for seasonal grouping + expected = xr.DataArray( + data=np.array([[1.38, 1.616667], [1.51, 1.5]]), + dims=["year", "season"], + coords={ + "year": [ + 2001, + 2002, + ], + "season": ["NDJFM", "AMJ"], + }, + ) + + 
assert_allclose(expected, actual) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_with_partial_years(self, calendar): + time = cftime_range("2001-01-01", "2002-06-30", freq="MS", calendar=calendar) + data = np.array( + [ + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 1.1, + 1.35, + 1.6, + 1.85, + 1.2, + 1.45, + 1.7, + 1.95, + 1.05, + 1.3, + 1.55, + 1.8, + 1.15, + ] + ) + da = DataArray(data, dims="time", coords={"time": time}) + da["year"] = da.time.dt.year + + actual = da.groupby( + year=UniqueGrouper(), time=SeasonGrouper(["NDJFM", "AMJ"]) + ).mean() + + # Expected if partial years are handled correctly + expected = xr.DataArray( + data=np.array([[1.38, 1.616667], [1.43333333, 1.5]]), + dims=["year", "season"], + coords={ + "year": [ + 2001, + 2002, + ], + "season": ["NDJFM", "AMJ"], + }, + ) + + assert_allclose(expected, actual) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_with_single_month_seasons(self, calendar): + time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + data = np.array( + [ + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 1.1, + 1.35, + 1.6, + 1.85, + 1.2, + 1.45, + 1.7, + 1.95, + 1.05, + 1.3, + 1.55, + 1.8, + 1.15, + 1.4, + 1.65, + 1.9, + 1.25, + 1.5, + 1.75, + ] + ) + da = DataArray(data, dims="time", coords={"time": time}) + da["year"] = da.time.dt.year + + actual = da.groupby( + year=UniqueGrouper(), + time=SeasonGrouper( + ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"] + ), + ).mean() + + # Expected if single month seasons are handled correctly + expected = xr.DataArray( + data=np.array( + [ + [1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7], + [1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75], + ] + ), + dims=["year", "season"], + coords={ + "year": [ + 2001, + 2002, + ], + "season": ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"], + }, + ) + + assert_allclose(expected, actual) + + @pytest.mark.xfail + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( + self, calendar + ): + # NOTE: This feature is not implemented yet. Maybe it can be a + # parameter to the `SeasonGrouper` API (e.g. `use_previous_year=True`). 
+ time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + data = np.array( + [ + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 1.1, + 1.35, + 1.6, + 1.85, + 1.2, + 1.45, + 1.7, + 1.95, + 1.05, + 1.3, + 1.55, + 1.8, + 1.15, + 1.4, + 1.65, + 1.9, + 1.25, + 1.5, + 1.75, + ] + ) + da = DataArray(data, dims="time", coords={"time": time}) + da["year"] = da.time.dt.year + + actual = da.groupby( + year=UniqueGrouper(), time=SeasonGrouper(["NDJFM", "AMJ"]) + ).mean() + + # Expected if the previous "ND" is used for seasonal grouping + expected = xr.DataArray( + data=np.array([[1.25, 1.616667], [1.49, 1.5], [1.625]]), + dims=["year", "season"], + coords={"year": [2001, 2002, 2003], "season": ["NDJFM", "AMJ"]}, + ) + + assert_allclose(expected, actual) + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_season_grouper_simple(self, calendar) -> None: time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) From a21952acd2e6076564633127bc81f05c7842d4db Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 21 Nov 2024 08:58:39 -0700 Subject: [PATCH 18/33] try fixing test --- xarray/tests/test_groupby.py | 67 +++++++----------------------------- 1 file changed, 13 insertions(+), 54 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 1b1e9a00bc1..b454755e398 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3179,34 +3179,15 @@ def test_season_grouper_with_months_spanning_calendar_year_using_same_year( self, calendar ): time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + # fmt: off data = np.array( [ - 1.0, - 1.25, - 1.5, - 1.75, - 2.0, - 1.1, - 1.35, - 1.6, - 1.85, - 1.2, - 1.45, - 1.7, - 1.95, - 1.05, - 1.3, - 1.55, - 1.8, - 1.15, - 1.4, - 1.65, - 1.9, - 1.25, - 1.5, - 1.75, + 1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7, + 1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75, ] + ) + # fmt: on da = DataArray(data, dims="time", coords={"time": time}) da["year"] = da.time.dt.year @@ -3337,7 +3318,6 @@ def test_season_grouper_with_single_month_seasons(self, calendar): assert_allclose(expected, actual) - @pytest.mark.xfail @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( self, calendar @@ -3345,46 +3325,25 @@ def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( # NOTE: This feature is not implemented yet. Maybe it can be a # parameter to the `SeasonGrouper` API (e.g. `use_previous_year=True`). 
time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + # fmt: off data = np.array( [ - 1.0, - 1.25, - 1.5, - 1.75, - 2.0, - 1.1, - 1.35, - 1.6, - 1.85, - 1.2, - 1.45, - 1.7, - 1.95, - 1.05, - 1.3, - 1.55, - 1.8, - 1.15, - 1.4, - 1.65, - 1.9, - 1.25, - 1.5, - 1.75, + 1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7, + 1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75, ] ) + # fmt: on da = DataArray(data, dims="time", coords={"time": time}) da["year"] = da.time.dt.year - actual = da.groupby( - year=UniqueGrouper(), time=SeasonGrouper(["NDJFM", "AMJ"]) - ).mean() + gb = da.groupby(year=UniqueGrouper(), time=SeasonGrouper(["NDJFM", "AMJ"])) + actual = gb.mean() # Expected if the previous "ND" is used for seasonal grouping expected = xr.DataArray( - data=np.array([[1.25, 1.616667], [1.49, 1.5], [1.625]]), + data=np.array([[1.25, 1.616667], [1.49, 1.5], [1.625, np.nan]]), dims=["year", "season"], - coords={"year": [2001, 2002, 2003], "season": ["NDJFM", "AMJ"]}, + coords={"year": [2000, 2001, 2002], "season": ["NDJFM", "AMJ"]}, ) assert_allclose(expected, actual) From bc867514216f816a821e5708ba73b0a749c2cdec Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 8 Jan 2025 10:37:39 -0700 Subject: [PATCH 19/33] lint --- xarray/tests/test_groupby.py | 33 +++++---------------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 9715ffe9ace..6bdb1f9f26e 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3376,34 +3376,14 @@ def test_season_grouper_with_partial_years(self, calendar): @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_season_grouper_with_single_month_seasons(self, calendar): time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + # fmt: off data = np.array( [ - 1.0, - 1.25, - 1.5, - 1.75, - 2.0, - 1.1, - 1.35, - 1.6, - 1.85, - 1.2, - 1.45, - 1.7, - 1.95, - 1.05, - 1.3, - 1.55, - 1.8, - 1.15, - 1.4, - 1.65, - 1.9, - 1.25, - 1.5, - 1.75, + 1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7, + 1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75, ] ) + # fmt: on da = DataArray(data, dims="time", coords={"time": time}) da["year"] = da.time.dt.year @@ -3424,10 +3404,7 @@ def test_season_grouper_with_single_month_seasons(self, calendar): ), dims=["year", "season"], coords={ - "year": [ - 2001, - 2002, - ], + "year": [2001, 2002], "season": ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"], }, ) From 64c99c500a549646afbeff07025b71f3377809a0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 18 Mar 2025 21:04:48 -0600 Subject: [PATCH 20/33] format --- xarray/tests/test_groupby.py | 34 ++++++++++------------------------ 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index af45e01ef47..01cc9278710 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -13,7 +13,7 @@ from packaging.version import Version import xarray as xr -from xarray import DataArray, Dataset, Variable, cftime_range, date_range +from xarray import DataArray, Dataset, Variable, date_range from xarray.core.groupby import _consolidate_slices from xarray.core.types import InterpOptions, ResampleCompatible from xarray.groupers import ( @@ -3294,7 +3294,7 @@ def test_season_grouper_raises_error_if_months_are_not_valid_or_not_continuous( def 
test_season_grouper_with_months_spanning_calendar_year_using_same_year( self, calendar ): - time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) # fmt: off data = np.array( [ @@ -3328,29 +3328,15 @@ def test_season_grouper_with_months_spanning_calendar_year_using_same_year( @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_season_grouper_with_partial_years(self, calendar): - time = cftime_range("2001-01-01", "2002-06-30", freq="MS", calendar=calendar) + time = date_range("2001-01-01", "2002-06-30", freq="MS", calendar=calendar) + # fmt: off data = np.array( [ - 1.0, - 1.25, - 1.5, - 1.75, - 2.0, - 1.1, - 1.35, - 1.6, - 1.85, - 1.2, - 1.45, - 1.7, - 1.95, - 1.05, - 1.3, - 1.55, - 1.8, - 1.15, + 1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7, + 1.95, 1.05, 1.3, 1.55, 1.8, 1.15, ] ) + # fmt: on da = DataArray(data, dims="time", coords={"time": time}) da["year"] = da.time.dt.year @@ -3375,7 +3361,7 @@ def test_season_grouper_with_partial_years(self, calendar): @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_season_grouper_with_single_month_seasons(self, calendar): - time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) # fmt: off data = np.array( [ @@ -3417,7 +3403,7 @@ def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( ): # NOTE: This feature is not implemented yet. Maybe it can be a # parameter to the `SeasonGrouper` API (e.g. `use_previous_year=True`). - time = cftime_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) + time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) # fmt: off data = np.array( [ @@ -3551,7 +3537,7 @@ def test_season_resampler( @requires_cftime def test_season_resampler_errors(self): - time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day") + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day") da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) # non-datetime array From 594f2850e71b13e7d023e733ba3c3daf595fdb0d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 18 Mar 2025 21:20:57 -0600 Subject: [PATCH 21/33] fix test --- xarray/tests/test_groupby.py | 45 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 01cc9278710..2615522eb1f 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3315,13 +3315,7 @@ def test_season_grouper_with_months_spanning_calendar_year_using_same_year( expected = xr.DataArray( data=np.array([[1.38, 1.616667], [1.51, 1.5]]), dims=["year", "season"], - coords={ - "year": [ - 2001, - 2002, - ], - "season": ["NDJFM", "AMJ"], - }, + coords={"year": [2001, 2002], "season": ["NDJFM", "AMJ"]}, ) assert_allclose(expected, actual) @@ -3348,13 +3342,7 @@ def test_season_grouper_with_partial_years(self, calendar): expected = xr.DataArray( data=np.array([[1.38, 1.616667], [1.43333333, 1.5]]), dims=["year", "season"], - coords={ - "year": [ - 2001, - 2002, - ], - "season": ["NDJFM", "AMJ"], - }, + coords={"year": [2001, 2002], "season": ["NDJFM", "AMJ"]}, ) assert_allclose(expected, actual) @@ -3401,8 +3389,6 @@ def test_season_grouper_with_single_month_seasons(self, calendar): def 
test_season_grouper_with_months_spanning_calendar_year_using_previous_year( self, calendar ): - # NOTE: This feature is not implemented yet. Maybe it can be a - # parameter to the `SeasonGrouper` API (e.g. `use_previous_year=True`). time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) # fmt: off data = np.array( @@ -3413,18 +3399,33 @@ def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( ) # fmt: on da = DataArray(data, dims="time", coords={"time": time}) - da["year"] = da.time.dt.year - gb = da.groupby(year=UniqueGrouper(), time=SeasonGrouper(["NDJFM", "AMJ"])) + gb = da.resample(time=SeasonResampler(["NDJFM", "AMJ"], drop_incomplete=False)) actual = gb.mean() + new_time = ( + xr.DataArray( + dims="time", + data=pd.DatetimeIndex( + [ + "2000-11-01", + "2001-04-01", + "2001-11-01", + "2002-04-01", + "2002-11-01", + ] + ), + ) + .convert_calendar(calendar=calendar, align_on="date") + .time.variable + ) + # Expected if the previous "ND" is used for seasonal grouping expected = xr.DataArray( - data=np.array([[1.25, 1.616667], [1.49, 1.5], [1.625, np.nan]]), - dims=["year", "season"], - coords={"year": [2000, 2001, 2002], "season": ["NDJFM", "AMJ"]}, + data=np.array([1.25, 1.616667, 1.49, 1.5, 1.625]), + dims="time", + coords={"time": new_time}, ) - assert_allclose(expected, actual) @pytest.mark.parametrize("calendar", _ALL_CALENDARS) From 1313ab998db4e3ed2447062dd7049fef5a0f780c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 18 Mar 2025 21:26:50 -0600 Subject: [PATCH 22/33] cleanup --- xarray/groupers.py | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index c9ae924ea9b..c42b1d223d1 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -705,6 +705,7 @@ class SeasonGrouper(Grouper): ---------- seasons: sequence of str List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc. + Overlapping seasons are allowed (e.g. ``["DJFM", "MAMJ", "JJAS", "SOND"]``) Examples -------- @@ -880,12 +881,6 @@ def factorize(self, group: T_Group) -> EncodedGroups: ] ) - # sbins = first_items.values.astype(int) - # group_indices = [ - # slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True) - # ] - # group_indices += [slice(sbins[-1], None)] - # This sorted call is a hack. 
It's hard to figure out how # to start the iteration for arbitrary season ordering # for example "DJF" as first entry or last entry @@ -919,7 +914,6 @@ def get_label(year, season): first_valid_season = seasons[ (seasons.index(first_valid_season) + 1) % len(seasons) ] - # group_indices = group_indices[slice(1, None)] unique_codes -= 1 if ( @@ -929,7 +923,6 @@ def get_label(year, season): last_valid_season = seasons[seasons.index(last_valid_season) - 1] if "DJ" in last_valid_season: last_year -= 1 - # group_indices = group_indices[slice(-1)] unique_codes[-1] = -1 first_label = get_label(first_year, first_valid_season) @@ -937,27 +930,12 @@ def get_label(year, season): slicer = complete_index.slice_indexer(first_label, last_label) full_index = complete_index[slicer] - # TODO: group must be sorted - # codes = np.searchsorted(edges, group.data, side="left") - # codes -= 1 - # codes[~present_seasons | group.data >= edges[-1]] = -1 - # codes[isnull(group.data)] = -1 - # import ipdb; ipdb.set_trace() - # check that there are no "missing" seasons in the middle - # if not full_index.equals(unique_coord): - # raise ValueError("Are there seasons missing in the middle of the dataset?") final_codes = np.full(group.data.size, -1) final_codes[present_seasons] = np.repeat(unique_codes, counts) codes = group.copy(data=final_codes, deep=False) - # unique_coord_var = Variable(group.name, unique_coord, group.attrs) - return EncodedGroups( - codes=codes, - # group_indices=group_indices, - # unique_coord=unique_coord_var, - full_index=full_index, - ) + return EncodedGroups(codes=codes, full_index=full_index) def reset(self) -> Self: return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete) From 32d9ed05c9a084b203835fa9e3d04aa5581a9dfc Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 18 Mar 2025 21:28:07 -0600 Subject: [PATCH 23/33] more cleanup --- xarray/{core => compat}/toolzcompat.py | 0 xarray/groupers.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename xarray/{core => compat}/toolzcompat.py (100%) diff --git a/xarray/core/toolzcompat.py b/xarray/compat/toolzcompat.py similarity index 100% rename from xarray/core/toolzcompat.py rename to xarray/compat/toolzcompat.py diff --git a/xarray/groupers.py b/xarray/groupers.py index c42b1d223d1..26a811d84a8 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -23,6 +23,7 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.coding.cftimeindex import CFTimeIndex +from xarray.compat.toolzcompat import sliding_window from xarray.computation.computation import apply_ufunc from xarray.core.common import ( _contains_cftime_datetimes, @@ -35,7 +36,6 @@ from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper -from xarray.core.toolzcompat import sliding_window from xarray.core.types import ( Bins, DatetimeLike, @@ -865,7 +865,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: first_items = agged["first"] counts = agged["count"] - index_class: CFTimeIndex | pd.DatetimeIndex + index_class: type[CFTimeIndex] | type[pd.DatetimeIndex] if _contains_cftime_datetimes(group.data): index_class = CFTimeIndex datetime_class = type(first_n_items(group.data, 1).item()) From b068e9427d697805c75e99dc4efe7bd5e9bba4db Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 18 Mar 2025 21:42:41 -0600 Subject: [PATCH 24/33] fix --- xarray/tests/test_coding_times.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index fe7781cf671..e4541bad7e6 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1932,7 +1932,7 @@ def test_duck_array_decode_times(calendar) -> None: decoded = conventions.decode_cf_variable( "foo", var, decode_times=CFDatetimeCoder(use_cftime=None) ) - if calendar not in _STANDARD_CALENDARS: + if calendar not in _STANDARD_CALENDAR_NAMES: assert decoded.dtype == np.dtype("O") else: assert decoded.dtype == np.dtype("=M8[ns]") From 862cf2ada4bc063ca0c9280c3c2555311fde700c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 19 Mar 2025 20:08:24 -0600 Subject: [PATCH 25/33] Fix automatic inference of unique_coord --- xarray/groupers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index c2beadc916a..5eedb92c007 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -118,7 +118,10 @@ def __init__( self.group_indices = group_indices if unique_coord is None: - unique_values = full_index[np.unique(codes)] + unique_codes = np.sort(pd.unique(codes.data)) + # Skip the -1 sentinel + unique_codes = unique_codes[unique_codes >= 0] + unique_values = full_index[unique_codes] self.unique_coord = Variable( dims=codes.name, data=unique_values, attrs=codes.attrs ) From f3f7d52f12e5917feb259b8df4e869157bcdf474 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 19 Mar 2025 20:10:16 -0600 Subject: [PATCH 26/33] Squashed commit of the following: commit 583a3d2732ff014d4f20d56f3f9238f7fc9faefc Author: Deepak Cherian Date: Wed Mar 19 12:55:54 2025 -0600 fix mypy commit 699c3b82c3d0674796097846929dac5918883b38 Author: Deepak Cherian Date: Wed Mar 19 09:30:38 2025 -0600 Preserve label ordering for multi-variable GroupBy --- xarray/core/groupby.py | 27 +++++++++++++++++---- xarray/groupers.py | 2 +- xarray/tests/test_groupby.py | 47 +++++++++++++++++++++++++++++------- 3 files changed, 61 insertions(+), 15 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ab4227daadf..1c01e6fe7b7 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -536,6 +536,11 @@ def factorize(self) -> EncodedGroups: list(grouper.full_index.values for grouper in groupers), names=tuple(grouper.name for grouper in groupers), ) + if not full_index.is_unique: + raise ValueError( + "The output index for the GroupBy is non-unique. " + "This is a bug in the Grouper provided." + ) # This will be unused when grouping by dask arrays, so skip.. if not is_chunked_array(_flatcodes): # Constructing an index from the product is wrong when there are missing groups @@ -947,17 +952,29 @@ def _binary_op(self, other, f, reflexive=False): def _restore_dim_order(self, stacked): raise NotImplementedError - def _maybe_restore_empty_groups(self, combined): - """Our index contained empty groups (e.g., from a resampling or binning). If we + def _maybe_reindex(self, combined): + """Reindexing is needed in two cases: + 1. Our index contained empty groups (e.g., from a resampling or binning). If we reduced on that dimension, we want to restore the full index. + + 2. We use a MultiIndex for multi-variable GroupBy. + The MultiIndex stores each level's labels in sorted order + which are then assigned on unstacking. So we need to restore + the correct order here. 
""" has_missing_groups = ( self.encoded.unique_coord.size != self.encoded.full_index.size ) indexers = {} for grouper in self.groupers: - if has_missing_groups and grouper.name in combined._indexes: + index = combined._indexes.get(grouper.name, None) + if has_missing_groups and index is not None: indexers[grouper.name] = grouper.full_index + elif len(self.groupers) > 1: + if not isinstance( + grouper.full_index, pd.RangeIndex + ) and not index.index.equals(grouper.full_index): + indexers[grouper.name] = grouper.full_index if indexers: combined = combined.reindex(**indexers) return combined @@ -1597,7 +1614,7 @@ def _combine(self, applied, shortcut=False): if dim not in applied_example.dims: combined = combined.assign_coords(self.encoded.coords) combined = self._maybe_unstack(combined) - combined = self._maybe_restore_empty_groups(combined) + combined = self._maybe_reindex(combined) return combined def reduce( @@ -1753,7 +1770,7 @@ def _combine(self, applied): if dim not in applied_example.dims: combined = combined.assign_coords(self.encoded.coords) combined = self._maybe_unstack(combined) - combined = self._maybe_restore_empty_groups(combined) + combined = self._maybe_reindex(combined) return combined def reduce( diff --git a/xarray/groupers.py b/xarray/groupers.py index 5eedb92c007..7e32d67fe6e 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -536,7 +536,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: counts = grouped.count() # This way we generate codes for the final output index: full_index. # So for _flox_reduce we avoid one reindex and copy by avoiding - # _maybe_restore_empty_groups + # _maybe_reindex codes = np.repeat(np.arange(len(first_items)), counts) return first_items, codes diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 0d4a3a6774a..4d53449e4be 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -158,7 +158,7 @@ def test_multi_index_groupby_sum() -> None: @requires_pandas_ge_2_2 -def test_multi_index_propagation(): +def test_multi_index_propagation() -> None: # regression test for GH9648 times = pd.date_range("2023-01-01", periods=4) locations = ["A", "B"] @@ -2295,7 +2295,7 @@ def test_resample_origin(self) -> None: times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10) array = DataArray(np.arange(10), [("time", times)]) - origin = "start" + origin: Literal["start"] = "start" actual = array.resample(time="24h", origin=origin).mean() expected = DataArray(array.to_series().resample("24h", origin=origin).mean()) assert_identical(expected, actual) @@ -2700,7 +2700,7 @@ def test_default_flox_method() -> None: @requires_cftime @pytest.mark.filterwarnings("ignore") -def test_cftime_resample_gh_9108(): +def test_cftime_resample_gh_9108() -> None: import cftime ds = Dataset( @@ -3050,7 +3050,7 @@ def test_gappy_resample_reductions(reduction): assert_identical(expected, actual) -def test_groupby_transpose(): +def test_groupby_transpose() -> None: # GH5361 data = xr.DataArray( np.random.randn(4, 2), @@ -3110,7 +3110,7 @@ def test_lazy_grouping(grouper, expect_index): @requires_dask -def test_lazy_grouping_errors(): +def test_lazy_grouping_errors() -> None: import dask.array data = DataArray( @@ -3136,7 +3136,7 @@ def test_lazy_grouping_errors(): @requires_dask -def test_lazy_int_bins_error(): +def test_lazy_int_bins_error() -> None: import dask.array with pytest.raises(ValueError, match="Bin edges must be provided"): @@ -3144,7 +3144,7 @@ def test_lazy_int_bins_error(): _ = 
BinGrouper(bins=4).factorize(DataArray(dask.array.arange(3))) -def test_time_grouping_seasons_specified(): +def test_time_grouping_seasons_specified() -> None: time = xr.date_range("2001-01-01", "2002-01-01", freq="D") ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)}) labels = ["DJF", "MAM", "JJA", "SON"] @@ -3153,7 +3153,36 @@ def test_time_grouping_seasons_specified(): assert_identical(actual, expected.reindex(season=labels)) -def test_groupby_multiple_bin_grouper_missing_groups(): +def test_multiple_grouper_unsorted_order() -> None: + time = xr.date_range("2001-01-01", "2003-01-01", freq="MS") + ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)}) + labels = ["DJF", "MAM", "JJA", "SON"] + actual = ds.groupby( + { + "time.season": UniqueGrouper(labels=labels), + "time.year": UniqueGrouper(labels=[2002, 2001]), + } + ).sum() + expected = ( + ds.groupby({"time.season": UniqueGrouper(), "time.year": UniqueGrouper()}) + .sum() + .reindex(season=labels, year=[2002, 2001]) + ) + assert_identical(actual, expected.reindex(season=labels)) + + b = xr.DataArray( + np.random.default_rng(0).random((2, 3, 4)), + coords={"x": [0, 1], "y": [0, 1, 2]}, + dims=["x", "y", "z"], + ) + actual2 = b.groupby( + x=UniqueGrouper(labels=[1, 0]), y=UniqueGrouper(labels=[2, 0, 1]) + ).sum() + expected2 = b.reindex(x=[1, 0], y=[2, 0, 1]).transpose("z", ...) + assert_identical(actual2, expected2) + + +def test_groupby_multiple_bin_grouper_missing_groups() -> None: from numpy import nan ds = xr.Dataset( @@ -3230,7 +3259,7 @@ def test_shuffle_by(chunks, expected_chunks): @requires_dask -def test_groupby_dask_eager_load_warnings(): +def test_groupby_dask_eager_load_warnings() -> None: ds = xr.Dataset( {"foo": (("z"), np.arange(12))}, coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))}, From 85d9217772e838836607ac39761995afede21a9e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 19 Mar 2025 21:24:50 -0600 Subject: [PATCH 27/33] cleanup --- properties/test_properties.py | 19 ++++++++++-- xarray/groupers.py | 28 +++++++++++++++--- xarray/tests/test_groupby.py | 55 +++++++++++++++++------------------ 3 files changed, 67 insertions(+), 35 deletions(-) diff --git a/properties/test_properties.py b/properties/test_properties.py index 24de8049f58..2ae91a15801 100644 --- a/properties/test_properties.py +++ b/properties/test_properties.py @@ -5,11 +5,11 @@ pytest.importorskip("hypothesis") import hypothesis.strategies as st -from hypothesis import given +from hypothesis import given, note import xarray as xr import xarray.testing.strategies as xrst -from xarray.groupers import season_to_month_tuple +from xarray.groupers import find_independent_seasons, season_to_month_tuple @given(attrs=xrst.simple_attrs) @@ -46,3 +46,18 @@ def test_property_season_month_tuple(roll, breaks): rolled_months[start:stop] for start, stop in itertools.pairwise(breaks) ) assert expected == actual + + +@given(data=st.data(), nmonths=st.integers(min_value=1, max_value=11)) +def test_property_find_independent_seasons(data, nmonths): + chars = "JFMAMJJASOND" + # if stride > nmonths, then we can't infer season order + stride = data.draw(st.integers(min_value=1, max_value=nmonths)) + chars = chars + chars[:nmonths] + seasons = [list(chars[i : i + nmonths]) for i in range(0, 12, stride)] + note(seasons) + groups = find_independent_seasons(seasons) + for group in groups: + inds = tuple(itertools.chain(*group.inds)) + assert len(inds) == len(set(inds)) + assert len(group.codes) == 
len(set(group.codes)) diff --git a/xarray/groupers.py b/xarray/groupers.py index 7e32d67fe6e..06b0131650c 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -604,6 +604,14 @@ def unique_value_groups( def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]: + """ + >>> season_to_month_tuple(["DJF", "MAM", "JJA", "SON"]) + ((12, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)) + >>> season_to_month_tuple(["DJFM", "MAMJ", "JJAS", "SOND"]) + ((12, 1, 2, 3), (3, 4, 5, 6), (6, 7, 8, 9), (9, 10, 11, 12)) + >>> season_to_month_tuple(["DJFM", "SOND"]) + ((12, 1, 2, 3), (9, 10, 11, 12)) + """ initials = "JFMAMJJASOND" starts = dict( ("".join(s), i + 1) @@ -629,7 +637,7 @@ def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...] return tuple(result) -def inds_to_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]: +def inds_to_season_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]: inits = "JFMAMJJASOND" return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints) @@ -660,10 +668,13 @@ def is_sorted_periodic(lst): return lst[-1] <= lst[0] -@dataclass +@dataclass(kw_only=True, frozen=True) class SeasonsGroup: seasons: tuple[str, ...] + # tuple[integer months] corresponding to each season inds: tuple[tuple[int, ...], ...] + # integer code for each season, this is not simply range(len(seasons)) + # when the seasons have overlaps codes: Sequence[int] @@ -671,13 +682,22 @@ def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]: """ Iterates though a list of seasons e.g. ["DJF", "FMA", ...], and splits that into multiple sequences of non-overlapping seasons. + + >>> find_independent_seasons( + ... ["DJF", "FMA", "AMJ", "JJA", "ASO", "OND"] + ... ) # doctest: +NORMALIZE_WHITESPACE + [SeasonsGroup(seasons=('DJF', 'AMJ', 'ASO'), inds=((12, 1, 2), (4, 5, 6), (8, 9, 10)), codes=[0, 2, 4]), + SeasonsGroup(seasons=('FMA', 'JJA', 'OND'), inds=((2, 3, 4), (6, 7, 8), (10, 11, 12)), codes=[1, 3, 5])] + + >>> find_independent_seasons(["DJF", "MAM", "JJA", "SON"]) + [SeasonsGroup(seasons=('DJF', 'MAM', 'JJA', 'SON'), inds=((12, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)), codes=[0, 1, 2, 3])] """ season_inds = season_to_month_tuple(seasons) grouped = defaultdict(list) codes = defaultdict(list) seen: set[tuple[int, ...]] = set() idx = 0 - # This is quadratic, but the length of seasons is at most 12 + # This is quadratic, but the number of seasons is at most 12 for i, current in enumerate(season_inds): # Start with a group if current not in seen: @@ -699,7 +719,7 @@ def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]: grouped_ints = tuple(tuple(idx) for idx in grouped.values() if idx) return [ - SeasonsGroup(seasons=inds_to_string(inds), inds=inds, codes=codes) + SeasonsGroup(seasons=inds_to_season_string(inds), inds=inds, codes=codes) for inds, codes in zip(grouped_ints, codes.values(), strict=False) ] diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 4d53449e4be..bcc2a55c5b1 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3378,7 +3378,7 @@ def test_season_grouper_with_partial_years(self, calendar): assert_allclose(expected, actual) - @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + @pytest.mark.parametrize("calendar", ["standard"]) def test_season_grouper_with_single_month_seasons(self, calendar): time = date_range("2001-01-01", "2002-12-30", freq="MS", calendar=calendar) # fmt: off @@ -3392,29 +3392,32 @@ def 
test_season_grouper_with_single_month_seasons(self, calendar): da = DataArray(data, dims="time", coords={"time": time}) da["year"] = da.time.dt.year - actual = da.groupby( - year=UniqueGrouper(), - time=SeasonGrouper( - ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"] - ), - ).mean() + # TODO: Consider supporting this if needed + # It does not work without flox, because the group labels are not unique, + # and so the stack/unstack approach does not work. + with pytest.raises(ValueError): + da.groupby( + year=UniqueGrouper(), + time=SeasonGrouper( + ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"] + ), + ).mean() # Expected if single month seasons are handled correctly - expected = xr.DataArray( - data=np.array( - [ - [1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7], - [1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75], - ] - ), - dims=["year", "season"], - coords={ - "year": [2001, 2002], - "season": ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"], - }, - ) - - assert_allclose(expected, actual) + # expected = xr.DataArray( + # data=np.array( + # [ + # [1.0, 1.25, 1.5, 1.75, 2.0, 1.1, 1.35, 1.6, 1.85, 1.2, 1.45, 1.7], + # [1.95, 1.05, 1.3, 1.55, 1.8, 1.15, 1.4, 1.65, 1.9, 1.25, 1.5, 1.75], + # ] + # ), + # dims=["year", "season"], + # coords={ + # "year": [2001, 2002], + # "season": ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"], + # }, + # ) + # assert_allclose(expected, actual) @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( @@ -3481,13 +3484,7 @@ def test_season_resampling_raises_unsorted_seasons(self, seasons): da.resample(time=SeasonResampler(seasons)) @pytest.mark.parametrize( - "use_cftime", - [ - pytest.param( - True, marks=pytest.mark.skipif(not has_cftime, reason="no cftime") - ), - False, - ], + "use_cftime", [pytest.param(True, marks=requires_cftime), False] ) @pytest.mark.parametrize("drop_incomplete", [True, False]) @pytest.mark.parametrize( From de26f38b2a51b4642c3d74127560bcb14ef32ca8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 19 Mar 2025 22:36:07 -0600 Subject: [PATCH 28/33] Fix --- xarray/tests/__init__.py | 5 ++++- xarray/tests/test_groupby.py | 31 ++++++++++++++++--------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 31024d72e60..37a7509b820 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -384,7 +384,10 @@ def create_test_data( pytest.param(cal, marks=requires_cftime) for cal in sorted(_NON_STANDARD_CALENDAR_NAMES) ] -_STANDARD_CALENDARS = [pytest.param(cal) for cal in _STANDARD_CALENDAR_NAMES] +_STANDARD_CALENDARS = [ + pytest.param(cal, marks=requires_cftime if cal != "standard" else ()) + for cal in _STANDARD_CALENDAR_NAMES +] _ALL_CALENDARS = sorted(_STANDARD_CALENDARS + _NON_STANDARD_CALENDARS) _CFTIME_CALENDARS = [ pytest.param(*p.values, marks=requires_cftime) for p in _ALL_CALENDARS diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index bcc2a55c5b1..46400259662 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3437,22 +3437,23 @@ def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( gb = da.resample(time=SeasonResampler(["NDJFM", "AMJ"], drop_incomplete=False)) actual = gb.mean() - new_time = ( - xr.DataArray( - dims="time", - data=pd.DatetimeIndex( - [ - "2000-11-01", - "2001-04-01", - 
"2001-11-01", - "2002-04-01", - "2002-11-01", - ] - ), - ) - .convert_calendar(calendar=calendar, align_on="date") - .time.variable + new_time_da = xr.DataArray( + dims="time", + data=pd.DatetimeIndex( + [ + "2000-11-01", + "2001-04-01", + "2001-11-01", + "2002-04-01", + "2002-11-01", + ] + ), ) + if calendar != "standard": + new_time_da = new_time_da.convert_calendar( + calendar=calendar, align_on="date" + ) + new_time = new_time_da.time.variable # Expected if the previous "ND" is used for seasonal grouping expected = xr.DataArray( From fc7297a736924e63d102153bec0270b09d071c39 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 19 Mar 2025 23:09:06 -0600 Subject: [PATCH 29/33] fix docstring --- xarray/groupers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/groupers.py b/xarray/groupers.py index 06b0131650c..94eb863ac51 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -740,10 +740,12 @@ class SeasonGrouper(Grouper): SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND']) The ordering is preserved + >>> SeasonGrouper(["MAM", "JJAS", "OND", "JF"]) SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF']) Overlapping seasons are allowed + >>> SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"]) SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND']) """ From 861da6ca6376828d310f8d249b8a63038cfcdc0a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 25 Mar 2025 19:30:19 -0600 Subject: [PATCH 30/33] cleanup --- xarray/tests/test_groupby.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 46400259662..0be97b3e439 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3437,18 +3437,16 @@ def test_season_grouper_with_months_spanning_calendar_year_using_previous_year( gb = da.resample(time=SeasonResampler(["NDJFM", "AMJ"], drop_incomplete=False)) actual = gb.mean() + # fmt: off new_time_da = xr.DataArray( dims="time", data=pd.DatetimeIndex( [ - "2000-11-01", - "2001-04-01", - "2001-11-01", - "2002-04-01", - "2002-11-01", + "2000-11-01", "2001-04-01", "2001-11-01", "2002-04-01", "2002-11-01" ] ), ) + # fmt: on if calendar != "standard": new_time_da = new_time_da.convert_calendar( calendar=calendar, align_on="date" @@ -3497,7 +3495,6 @@ def test_season_resampling_raises_unsorted_seasons(self, seasons): pytest.param(["JFM", "AMJ", "JAS", "OND"], id="december-same-year"), pytest.param(["DJF", "MAM", "JJA", "ON"], id="skip-september"), pytest.param(["JJAS"], id="jjas-only"), - pytest.param(["MAM", "JJA", "SON", "DJF"], id="different-order"), ], ) def test_season_resampler( From 7406458f848368c1540f22fcf8be3ff1d3d2e746 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 25 Mar 2025 20:03:26 -0600 Subject: [PATCH 31/33] Avoid silly sphinx complete rebuilds --- doc/conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/conf.py b/doc/conf.py index d4328dbf1b0..43afc1253e5 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -187,6 +187,8 @@ "pd.NaT": "~pandas.NaT", } +autodoc_type_aliases = napoleon_type_aliases # Keep both in sync + # mermaid config mermaid_version = "10.9.1" From 6297c1c35f640a562c8cf6a56dc4a3744ba8387c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 25 Mar 2025 19:34:43 -0600 Subject: [PATCH 32/33] Add docs --- doc/user-guide/groupby.rst | 8 ++ doc/user-guide/time-series.rst | 132 +++++++++++++++++++++++++-------- xarray/groupers.py | 1 + 3 files changed, 110 insertions(+), 31 deletions(-) diff --git a/doc/user-guide/groupby.rst 
b/doc/user-guide/groupby.rst index 7cb4e883347..673e23d75ac 100644 --- a/doc/user-guide/groupby.rst +++ b/doc/user-guide/groupby.rst @@ -332,6 +332,14 @@ Different groupers can be combined to construct sophisticated GroupBy operations ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() +Time Grouping and Resampling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. seealso:: + + See :ref:`resampling`. + + Shuffling ~~~~~~~~~ diff --git a/doc/user-guide/time-series.rst b/doc/user-guide/time-series.rst index d131ae74b9f..cb3e94e3645 100644 --- a/doc/user-guide/time-series.rst +++ b/doc/user-guide/time-series.rst @@ -1,3 +1,5 @@ +.. currentmodule:: xarray + .. _time-series: ================ @@ -21,12 +23,12 @@ core functionality. Creating datetime64 data ------------------------ -Xarray uses the numpy dtypes ``datetime64[unit]`` and ``timedelta64[unit]`` -(where unit is one of ``"s"``, ``"ms"``, ``"us"`` and ``"ns"``) to represent datetime +Xarray uses the numpy dtypes :py:class:`numpy.datetime64` and :py:class:`numpy.timedelta64` +with specified units (one of ``"s"``, ``"ms"``, ``"us"`` and ``"ns"``) to represent datetime data, which offer vectorized operations with numpy and smooth integration with pandas. -To convert to or create regular arrays of ``datetime64`` data, we recommend -using :py:func:`pandas.to_datetime` and :py:func:`pandas.date_range`: +To convert to or create regular arrays of :py:class:`numpy.datetime64` data, we recommend +using :py:func:`pandas.to_datetime`, :py:class:`pandas.DatetimeIndex`, or :py:func:`xarray.date_range`: .. ipython:: python @@ -34,13 +36,6 @@ using :py:func:`pandas.to_datetime` and :py:func:`pandas.date_range`: pd.DatetimeIndex( ["2000-01-01 00:00:00", "2000-02-02 00:00:00"], dtype="datetime64[s]" ) - pd.date_range("2000-01-01", periods=365) - pd.date_range("2000-01-01", periods=365, unit="s") - -It is also possible to use corresponding :py:func:`xarray.date_range`: - -.. ipython:: python - xr.date_range("2000-01-01", periods=365) xr.date_range("2000-01-01", periods=365, unit="s") @@ -81,7 +76,7 @@ attribute like ``'days since 2000-01-01'``). You can manual decode arrays in this form by passing a dataset to -:py:func:`~xarray.decode_cf`: +:py:func:`decode_cf`: .. ipython:: python @@ -93,8 +88,8 @@ You can manual decode arrays in this form by passing a dataset to coder = xr.coders.CFDatetimeCoder(time_unit="s") xr.decode_cf(ds, decode_times=coder) -From xarray 2025.01.2 the resolution of the dates can be one of ``"s"``, ``"ms"``, ``"us"`` or ``"ns"``. One limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262, which gets increased significantly with lower resolutions. When a store contains dates outside of these bounds (or dates < `1582-10-15`_ with a Gregorian, also known as standard, calendar), dates will be returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`~xarray.CFTimeIndex` will be used for indexing. -:py:class:`~xarray.CFTimeIndex` enables most of the indexing functionality of a :py:class:`pandas.DatetimeIndex`. +From xarray 2025.01.2 the resolution of the dates can be one of ``"s"``, ``"ms"``, ``"us"`` or ``"ns"``. One limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262, which gets increased significantly with lower resolutions. 
When a store contains dates outside of these bounds (or dates < `1582-10-15`_ with a Gregorian, also known as standard, calendar), dates will be returned as arrays of :py:class:`cftime.datetime` objects and a :py:class:`CFTimeIndex` will be used for indexing. +:py:class:`CFTimeIndex` enables most of the indexing functionality of a :py:class:`pandas.DatetimeIndex`. See :ref:`CFTimeIndex` for more information. Datetime indexing @@ -205,35 +200,37 @@ You can also search for multiple months (in this case January through March), us Resampling and grouped operations --------------------------------- -Datetime components couple particularly well with grouped operations (see -:ref:`groupby`) for analyzing features that repeat over time. Here's how to -calculate the mean by time of day: + +.. seealso:: + + For more generic documentation on grouping, see :ref:`groupby`. + + +Datetime components couple particularly well with grouped operations for analyzing features that repeat over time. +Here's how to calculate the mean by time of day: .. ipython:: python - :okwarning: ds.groupby("time.hour").mean() For upsampling or downsampling temporal resolutions, xarray offers a -:py:meth:`~xarray.Dataset.resample` method building on the core functionality +:py:meth:`Dataset.resample` method building on the core functionality offered by the pandas method of the same name. Resample uses essentially the -same api as ``resample`` `in pandas`_. +same api as :py:meth:`pandas.DataFrame.resample` `in pandas`_. .. _in pandas: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#up-and-downsampling For example, we can downsample our dataset from hourly to 6-hourly: .. ipython:: python - :okwarning: ds.resample(time="6h") -This will create a specialized ``Resample`` object which saves information -necessary for resampling. All of the reduction methods which work with -``Resample`` objects can also be used for resampling: +This will create a specialized :py:class:`~xarray.core.resample.DatasetResample` or :py:class:`~xarray.core.resample.DataArrayResample` +object which saves information necessary for resampling. All of the reduction methods which work with +:py:class:`Dataset` or :py:class:`DataArray` objects can also be used for resampling: .. ipython:: python - :okwarning: ds.resample(time="6h").mean() @@ -252,7 +249,7 @@ by specifying the ``dim`` keyword argument ds.resample(time="6h").mean(dim=["time", "latitude", "longitude"]) For upsampling, xarray provides six methods: ``asfreq``, ``ffill``, ``bfill``, ``pad``, -``nearest`` and ``interpolate``. ``interpolate`` extends ``scipy.interpolate.interp1d`` +``nearest`` and ``interpolate``. ``interpolate`` extends :py:func:`scipy.interpolate.interp1d` and supports all of its schemes. All of these resampling operations work on both Dataset and DataArray objects with an arbitrary number of dimensions. @@ -266,9 +263,7 @@ Data that has indices outside of the given ``tolerance`` are set to ``NaN``. It is often desirable to center the time values after a resampling operation. That can be accomplished by updating the resampled dataset time coordinate values -using time offset arithmetic via the `pandas.tseries.frequencies.to_offset`_ function. - -.. _pandas.tseries.frequencies.to_offset: https://pandas.pydata.org/docs/reference/api/pandas.tseries.frequencies.to_offset.html +using time offset arithmetic via the :py:func:`pandas.tseries.frequencies.to_offset` function. .. 
@@ -277,5 +272,80 @@ using `pandas.tseries.frequencies.to_offset`_ fun
     resampled_ds["time"] = resampled_ds.get_index("time") + offset
     resampled_ds
 
-For more examples of using grouped operations on a time dimension, see
-:doc:`../examples/weather-data`.
+
+.. seealso::
+
+   For more examples of using grouped operations on a time dimension, see :doc:`../examples/weather-data`.
+
+
+Handling Seasons
+~~~~~~~~~~~~~~~~
+
+Two extremely common time series operations are to group by seasons and to resample to a seasonal frequency.
+Xarray has historically supported simple versions of these computations: ``.groupby("time.season")``,
+where the seasons are fixed to DJF, MAM, JJA, and SON, and resampling to a seasonal frequency with
+pandas syntax, e.g. ``.resample(time="QS-DEC")``.
+
+Quite commonly one wants more flexibility in defining seasons. For these use cases, Xarray provides
+:py:class:`groupers.SeasonGrouper` and :py:class:`groupers.SeasonResampler`.
+
+
+.. currentmodule:: xarray.groupers
+
+.. ipython:: python
+
+    from xarray.groupers import SeasonGrouper
+
+    ds.groupby(time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"])).mean()
+
+
+Note how the seasons are in the specified order, unlike ``.groupby("time.season")``, where the
+seasons are sorted alphabetically:
+
+.. ipython:: python
+
+    ds.groupby("time.season").mean()
+
+
+:py:class:`SeasonGrouper` supports overlapping seasons:
+
+.. ipython:: python
+
+    ds.groupby(time=SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"])).mean()
+
+
+Skipping months is allowed:
+
+.. ipython:: python
+
+    ds.groupby(time=SeasonGrouper(["JJAS"])).mean()
+
+
+Use :py:class:`SeasonResampler` to resample over custom seasons:
+
+.. ipython:: python
+
+    from xarray.groupers import SeasonResampler
+
+    ds.resample(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).mean()
+
+
+:py:class:`SeasonResampler` correctly assigns years to seasons that span the end of the year (e.g. DJF).
+By default :py:class:`SeasonResampler` will skip any season that is incomplete (e.g. the first DJF
+season for a time series that starts in January). Pass ``drop_incomplete=False`` to
+:py:class:`SeasonResampler` to disable this behaviour:
+
+.. ipython:: python
+
+    from xarray.groupers import SeasonResampler
+
+    ds.resample(
+        time=SeasonResampler(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
+    ).mean()
+
+
+Seasons need not be of the same length:
+
+.. ipython:: python
+
+    ds.resample(time=SeasonResampler(["JF", "MAM", "JJAS", "OND"])).mean()
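The ``drop_incomplete`` behaviour documented above can be checked with a small sketch; the monthly toy series below is an assumption for illustration:

    import numpy as np
    import pandas as pd
    import xarray as xr
    from xarray.groupers import SeasonResampler

    # Hypothetical monthly series starting in January, so the first DJF
    # season is incomplete: its December falls before the series begins.
    time = pd.date_range("2001-01-01", "2002-12-31", freq="MS")
    da = xr.DataArray(np.ones(time.size), coords={"time": time}, dims="time")

    seasons = ["DJF", "MAM", "JJA", "SON"]
    # Default drop_incomplete=True: the leading partial DJF chunk is dropped.
    da.resample(time=SeasonResampler(seasons)).sum()
    # Keep it instead: the first DJF bin then covers only Jan-Feb 2001.
    da.resample(time=SeasonResampler(seasons, drop_incomplete=False)).sum()
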
diff --git a/xarray/groupers.py b/xarray/groupers.py
index 408bf29d096..0551b02ae91 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -643,6 +643,7 @@ def inds_to_season_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...
 def is_sorted_periodic(lst):
+    """Used to verify that seasons provided to SeasonResampler are in order."""
     n = len(lst)
 
     # Find the wraparound point where the list decreases

From 3ee3fde558ee716b35356484eb703c5e0d299305 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 6 May 2025 20:47:39 +0000
Subject: [PATCH 33/33] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 xarray/groupers.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/xarray/groupers.py b/xarray/groupers.py
index 1382ef218f2..203acbe37c5 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -30,7 +30,6 @@
     _contains_datetime_like_objects,
 )
 from xarray.core.coordinates import Coordinates, coordinates_from_variable
-from xarray.core.coordinates import Coordinates, coordinates_from_variable
 from xarray.core.dataarray import DataArray
 from xarray.core.duck_array_ops import array_all, isnull
 from xarray.core.formatting import first_n_items
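For reference, the periodic-sortedness check described by the docstring added above can be sketched as follows; this is a reconstruction for illustration and not necessarily the patch's exact implementation:

    def is_sorted_periodic(lst: list[int]) -> bool:
        """Check that a sequence is sorted up to one cyclic wraparound, e.g.
        season start months [12, 3, 6, 9] rotate to the sorted [3, 6, 9, 12]."""
        n = len(lst)
        # Find the wraparound points where the sequence decreases.
        breaks = [i for i in range(1, n) if lst[i] < lst[i - 1]]
        if not breaks:
            return True  # already sorted; no rotation needed
        if len(breaks) > 1:
            return False  # two decreases cannot come from a single rotation
        # Rotating at the single break must yield a fully sorted sequence.
        k = breaks[0]
        rotated = lst[k:] + lst[:k]
        return all(rotated[i] <= rotated[i + 1] for i in range(n - 1))

    assert is_sorted_periodic([12, 3, 6, 9])  # DJF, MAM, JJA, SON starts
    assert not is_sorted_periodic([3, 12, 6, 9])  # out of cyclic order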
