
Commit 4174aa1

Preserve label ordering for multi-variable GroupBy (#10151)
* Preserve label ordering for multi-variable GroupBy
* fix mypy
1 parent fd7c765 commit 4174aa1
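
For orientation, a minimal sketch of the behavior this commit targets, adapted from the new test added below (the array shape, labels, and names all mirror that test):

import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

# group by two variables with deliberately unsorted labels
b = xr.DataArray(
    np.random.default_rng(0).random((2, 3, 4)),
    coords={"x": [0, 1], "y": [0, 1, 2]},
    dims=["x", "y", "z"],
)
result = b.groupby(
    x=UniqueGrouper(labels=[1, 0]), y=UniqueGrouper(labels=[2, 0, 1])
).sum()
# with this change, the result keeps the requested label order:
# result["x"] -> [1, 0], result["y"] -> [2, 0, 1]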

File tree

3 files changed (+61 -15 lines changed):
  xarray/core/groupby.py
  xarray/groupers.py
  xarray/tests/test_groupby.py

xarray/core/groupby.py

Lines changed: 22 additions & 5 deletions
@@ -534,6 +534,11 @@ def factorize(self) -> EncodedGroups:
             list(grouper.full_index.values for grouper in groupers),
             names=tuple(grouper.name for grouper in groupers),
         )
+        if not full_index.is_unique:
+            raise ValueError(
+                "The output index for the GroupBy is non-unique. "
+                "This is a bug in the Grouper provided."
+            )
         # This will be unused when grouping by dask arrays, so skip..
         if not is_chunked_array(_flatcodes):
             # Constructing an index from the product is wrong when there are missing groups
@@ -942,17 +947,29 @@ def _binary_op(self, other, f, reflexive=False):
     def _restore_dim_order(self, stacked):
         raise NotImplementedError
 
-    def _maybe_restore_empty_groups(self, combined):
-        """Our index contained empty groups (e.g., from a resampling or binning). If we
+    def _maybe_reindex(self, combined):
+        """Reindexing is needed in two cases:
+        1. Our index contained empty groups (e.g., from a resampling or binning). If we
         reduced on that dimension, we want to restore the full index.
+
+        2. We use a MultiIndex for multi-variable GroupBy.
+        The MultiIndex stores each level's labels in sorted order
+        which are then assigned on unstacking. So we need to restore
+        the correct order here.
         """
         has_missing_groups = (
             self.encoded.unique_coord.size != self.encoded.full_index.size
         )
         indexers = {}
         for grouper in self.groupers:
-            if has_missing_groups and grouper.name in combined._indexes:
+            index = combined._indexes.get(grouper.name, None)
+            if has_missing_groups and index is not None:
                 indexers[grouper.name] = grouper.full_index
+            elif len(self.groupers) > 1:
+                if not isinstance(
+                    grouper.full_index, pd.RangeIndex
+                ) and not index.index.equals(grouper.full_index):
+                    indexers[grouper.name] = grouper.full_index
         if indexers:
             combined = combined.reindex(**indexers)
         return combined
@@ -1595,7 +1612,7 @@ def _combine(self, applied, shortcut=False):
         if dim not in applied_example.dims:
             combined = combined.assign_coords(self.encoded.coords)
         combined = self._maybe_unstack(combined)
-        combined = self._maybe_restore_empty_groups(combined)
+        combined = self._maybe_reindex(combined)
         return combined
 
     def reduce(
@@ -1751,7 +1768,7 @@ def _combine(self, applied):
         if dim not in applied_example.dims:
            combined = combined.assign_coords(self.encoded.coords)
         combined = self._maybe_unstack(combined)
-        combined = self._maybe_restore_empty_groups(combined)
+        combined = self._maybe_reindex(combined)
         return combined
 
     def reduce(
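
As a rough sketch of the reindex step that _maybe_reindex relies on (the array and label order below are invented for illustration): when the combined result comes back with labels in sorted order, reindexing to the grouper's full_index restores the order the user asked for.

import xarray as xr

# hypothetical combined result whose "x" labels came back sorted
combined = xr.DataArray([10.0, 20.0], coords={"x": [0, 1]}, dims="x")
# grouper.full_index carries the requested order; reindex restores it
restored = combined.reindex(x=[1, 0])
# restored["x"] -> [1, 0]; restored.values -> [20.0, 10.0]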

xarray/groupers.py

Lines changed: 1 addition & 1 deletion
@@ -521,7 +521,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
         counts = grouped.count()
         # This way we generate codes for the final output index: full_index.
         # So for _flox_reduce we avoid one reindex and copy by avoiding
-        # _maybe_restore_empty_groups
+        # _maybe_reindex
         codes = np.repeat(np.arange(len(first_items)), counts)
         return first_items, codes
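
A small illustration of the np.repeat line in the context above, with made-up per-group counts: each group's positional code is repeated by its count, so the codes already line up with full_index and no later reindex is needed.

import numpy as np

counts = np.array([2, 0, 3])  # hypothetical per-group counts
codes = np.repeat(np.arange(len(counts)), counts)
# codes -> array([0, 0, 2, 2, 2]); the empty group (index 1) never appears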

xarray/tests/test_groupby.py

Lines changed: 38 additions & 9 deletions
@@ -154,7 +154,7 @@ def test_multi_index_groupby_sum() -> None:
 
 
 @requires_pandas_ge_2_2
-def test_multi_index_propagation():
+def test_multi_index_propagation() -> None:
     # regression test for GH9648
     times = pd.date_range("2023-01-01", periods=4)
     locations = ["A", "B"]
@@ -2291,7 +2291,7 @@ def test_resample_origin(self) -> None:
         times = pd.date_range("2000-01-01T02:03:01", freq="6h", periods=10)
         array = DataArray(np.arange(10), [("time", times)])
 
-        origin = "start"
+        origin: Literal["start"] = "start"
         actual = array.resample(time="24h", origin=origin).mean()
         expected = DataArray(array.to_series().resample("24h", origin=origin).mean())
         assert_identical(expected, actual)
@@ -2696,7 +2696,7 @@ def test_default_flox_method() -> None:
 
 @requires_cftime
 @pytest.mark.filterwarnings("ignore")
-def test_cftime_resample_gh_9108():
+def test_cftime_resample_gh_9108() -> None:
     import cftime
 
     ds = Dataset(
@@ -3046,7 +3046,7 @@ def test_gappy_resample_reductions(reduction):
     assert_identical(expected, actual)
 
 
-def test_groupby_transpose():
+def test_groupby_transpose() -> None:
     # GH5361
     data = xr.DataArray(
         np.random.randn(4, 2),
@@ -3106,7 +3106,7 @@ def test_lazy_grouping(grouper, expect_index):
 
 
 @requires_dask
-def test_lazy_grouping_errors():
+def test_lazy_grouping_errors() -> None:
     import dask.array
 
     data = DataArray(
@@ -3132,15 +3132,15 @@ def test_lazy_grouping_errors():
 
 
 @requires_dask
-def test_lazy_int_bins_error():
+def test_lazy_int_bins_error() -> None:
     import dask.array
 
     with pytest.raises(ValueError, match="Bin edges must be provided"):
         with raise_if_dask_computes():
             _ = BinGrouper(bins=4).factorize(DataArray(dask.array.arange(3)))
 
 
-def test_time_grouping_seasons_specified():
+def test_time_grouping_seasons_specified() -> None:
     time = xr.date_range("2001-01-01", "2002-01-01", freq="D")
     ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)})
     labels = ["DJF", "MAM", "JJA", "SON"]
@@ -3149,7 +3149,36 @@ def test_time_grouping_seasons_specified():
     assert_identical(actual, expected.reindex(season=labels))
 
 
-def test_groupby_multiple_bin_grouper_missing_groups():
+def test_multiple_grouper_unsorted_order() -> None:
+    time = xr.date_range("2001-01-01", "2003-01-01", freq="MS")
+    ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)})
+    labels = ["DJF", "MAM", "JJA", "SON"]
+    actual = ds.groupby(
+        {
+            "time.season": UniqueGrouper(labels=labels),
+            "time.year": UniqueGrouper(labels=[2002, 2001]),
+        }
+    ).sum()
+    expected = (
+        ds.groupby({"time.season": UniqueGrouper(), "time.year": UniqueGrouper()})
+        .sum()
+        .reindex(season=labels, year=[2002, 2001])
+    )
+    assert_identical(actual, expected.reindex(season=labels))
+
+    b = xr.DataArray(
+        np.random.default_rng(0).random((2, 3, 4)),
+        coords={"x": [0, 1], "y": [0, 1, 2]},
+        dims=["x", "y", "z"],
+    )
+    actual2 = b.groupby(
+        x=UniqueGrouper(labels=[1, 0]), y=UniqueGrouper(labels=[2, 0, 1])
+    ).sum()
+    expected2 = b.reindex(x=[1, 0], y=[2, 0, 1]).transpose("z", ...)
+    assert_identical(actual2, expected2)
+
+
+def test_groupby_multiple_bin_grouper_missing_groups() -> None:
     from numpy import nan
 
     ds = xr.Dataset(
@@ -3226,7 +3255,7 @@ def test_shuffle_by(chunks, expected_chunks):
 
 
 @requires_dask
-def test_groupby_dask_eager_load_warnings():
+def test_groupby_dask_eager_load_warnings() -> None:
     ds = xr.Dataset(
         {"foo": (("z"), np.arange(12))},
         coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))},
