diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a0b9e300a94..20bbdc7ec69 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,9 @@ New Features - Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This includes ``datatree`` support, and removing slashes from dimension names. By `Miguel Jimenez-Urias `_. +- Improved support pandas Extension Arrays. (:issue:`9661`, :pull:`9671`) + By `Ilan Gold `_. + Breaking changes ~~~~~~~~~~~~~~~~ @@ -41,6 +44,12 @@ Breaking changes pydap 3.4 3.5.0 ===================== ========= ======= + +- Reductions with ``groupby_bins`` or those that involve :py:class:`xarray.groupers.BinGrouper` + now return objects indexed by :py:meth:`pandas.IntervalArray` objects, + instead of numpy object arrays containing tuples. This change enables interval-aware indexing of + such Xarray objects. (:pull:`9671`). By `Ilan Gold `_. + Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 9b80e154b95..f523f971725 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6917,7 +6917,7 @@ def groupby( [[nan, nan, nan], [ 3., 4., 5.]]]) Coordinates: - * x_bins (x_bins) object 16B (5, 15] (15, 25] + * x_bins (x_bins) interval[int64, right] 32B (5, 15] (15, 25] * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9bfad11994e..9d52f2e0776 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7059,6 +7059,8 @@ def to_pandas(self) -> pd.Series | pd.DataFrame: ) def _to_dataframe(self, ordered_dims: Mapping[Any, int]): + from xarray.core.extension_array import PandasExtensionArray + columns_in_order = [k for k in self.variables if k not in self.dims] non_extension_array_columns = [ k @@ -7070,20 +7072,41 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]): for k in columns_in_order if is_extension_array_dtype(self.variables[k].data) ] + extension_array_columns_different_index = [ + k + for k in extension_array_columns + if set(self.variables[k].dims) != set(ordered_dims.keys()) + ] + extension_array_columns_same_index = [ + k + for k in extension_array_columns + if k not in extension_array_columns_different_index + ] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) for k in non_extension_array_columns ] index = self.coords.to_index([*ordered_dims]) broadcasted_df = pd.DataFrame( - dict(zip(non_extension_array_columns, data, strict=True)), index=index + { + **dict(zip(non_extension_array_columns, data, strict=True)), + **{ + c: self.variables[c].data.array + for c in extension_array_columns_same_index + }, + }, + index=index, ) - for extension_array_column in extension_array_columns: + for extension_array_column in extension_array_columns_different_index: extension_array = self.variables[extension_array_column].data.array - index = self[self.variables[extension_array_column].dims[0]].data + index = self[ + self.variables[extension_array_column].dims[0] + ].coords.to_index() extension_array_df = pd.DataFrame( {extension_array_column: extension_array}, - index=self[self.variables[extension_array_column].dims[0]].data, + index=pd.Index(index.array) + if isinstance(index, PandasExtensionArray) + else index, ) extension_array_df.index.name = self.variables[extension_array_column].dims[ 0 @@ -9892,10 +9915,10 @@ def groupby( >>> from xarray.groupers import BinGrouper, UniqueGrouper >>> >>> ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() - Size: 128B + Size: 144B Dimensions: (y: 3, x_bins: 2, letters: 2) Coordinates: - * x_bins (x_bins) object 16B (5, 15] (15, 25] + * x_bins (x_bins) interval[int64, right] 32B (5, 15] (15, 25] * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y Data variables: diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 43829b4029f..e8006a4c8c3 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -102,7 +102,7 @@ def replace_duck_with_extension_array(args) -> list: return type(self)[type(res)](res) return res - def __array_ufunc__(ufunc, method, *inputs, **kwargs): + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) def __repr__(self): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9dc1a26b1f0..bc934132f1c 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -11,6 +11,7 @@ from xarray.core import formatting, nputils, utils from xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( CoordinateTransformIndexingAdapter, IndexSelResult, @@ -444,6 +445,8 @@ def safe_cast_to_index(array: Any) -> pd.Index: from xarray.core.variable import Variable from xarray.namedarray.pycompat import to_numpy + if isinstance(array, PandasExtensionArray): + array = pd.Index(array.array) if isinstance(array, pd.Index): index = array elif isinstance(array, DataArray | Variable): @@ -602,7 +605,11 @@ def __init__( self.dim = dim if coord_dtype is None: - coord_dtype = get_valid_numpy_dtype(index) + if pd.api.types.is_extension_array_dtype(index.dtype): + cast(pd.api.extensions.ExtensionDtype, index.dtype) + coord_dtype = index.dtype + else: + coord_dtype = get_valid_numpy_dtype(index) self.coord_dtype = coord_dtype def _replace(self, index, dim=None, coord_dtype=None): @@ -698,6 +705,8 @@ def concat( if not indexes: coord_dtype = None + elif len(set(idx.coord_dtype for idx in indexes)) == 1: + coord_dtype = indexes[0].coord_dtype else: coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes]) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 2999506d1de..aa56006eff3 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -10,14 +10,16 @@ from dataclasses import dataclass, field from datetime import timedelta from html import escape -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast, overload import numpy as np import pandas as pd +from numpy.typing import DTypeLike from packaging.version import Version from xarray.core import duck_array_ops from xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.extension_array import PandasExtensionArray from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -28,14 +30,13 @@ is_duck_array, is_duck_dask_array, is_scalar, + is_valid_numpy_dtype, to_0d_array, ) from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array if TYPE_CHECKING: - from numpy.typing import DTypeLike - from xarray.core.indexes import Index from xarray.core.types import Self from xarray.core.variable import Variable @@ -1744,27 +1745,43 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("_dtype", "array") array: pd.Index - _dtype: np.dtype + _dtype: np.dtype | pd.api.extensions.ExtensionDtype - def __init__(self, array: pd.Index, dtype: DTypeLike = None): + def __init__( + self, + array: pd.Index, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + ): from xarray.core.indexes import safe_cast_to_index self.array = safe_cast_to_index(array) if dtype is None: - self._dtype = get_valid_numpy_dtype(array) + if pd.api.types.is_extension_array_dtype(array.dtype): + cast(pd.api.extensions.ExtensionDtype, array.dtype) + self._dtype = array.dtype + else: + self._dtype = get_valid_numpy_dtype(array) + elif pd.api.types.is_extension_array_dtype(dtype): + self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype) else: - self._dtype = np.dtype(dtype) + self._dtype = np.dtype(cast(DTypeLike, dtype)) @property - def dtype(self) -> np.dtype: + def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype: # type: ignore[override] return self._dtype def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: np.typing.DTypeLike | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: - if dtype is None: - dtype = self.dtype + if dtype is None and is_valid_numpy_dtype(self.dtype): + dtype = cast(np.dtype, self.dtype) + else: + dtype = get_valid_numpy_dtype(self.array) array = self.array if isinstance(array, pd.PeriodIndex): with suppress(AttributeError): @@ -1776,14 +1793,18 @@ def __array__( else: return np.asarray(array.values, dtype=dtype) - def get_duck_array(self) -> np.ndarray: + def get_duck_array(self) -> np.ndarray | PandasExtensionArray: + # We return an PandasExtensionArray wrapper type that satisfies + # duck array protocols. This is what's needed for tests to pass. + if pd.api.types.is_extension_array_dtype(self.array): + return PandasExtensionArray(self.array.array) return np.asarray(self) @property def shape(self) -> _Shape: return (len(self.array),) - def _convert_scalar(self, item): + def _convert_scalar(self, item) -> np.ndarray: if item is pd.NaT: # work around the impossibility of casting NaT with asarray # note: it probably would be better in general to return @@ -1799,7 +1820,10 @@ def _convert_scalar(self, item): # numpy fails to convert pd.Timestamp to np.datetime64[ns] item = np.asarray(item.to_datetime64()) elif self.dtype != object: - item = np.asarray(item, dtype=self.dtype) + dtype = self.dtype + if pd.api.types.is_extension_array_dtype(dtype): + dtype = get_valid_numpy_dtype(self.array) + item = np.asarray(item, dtype=cast(np.dtype, dtype)) # as for numpy.ndarray indexing, we always want the result to be # a NumPy array. @@ -1902,6 +1926,12 @@ def copy(self, deep: bool = True) -> Self: array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype) + @property + def nbytes(self) -> int: + if pd.api.types.is_extension_array_dtype(self.dtype): + return self.array.nbytes + return cast(np.dtype, self.dtype).itemsize * len(self.array) + class PandasMultiIndexingAdapter(PandasIndexingAdapter): """Handles explicit indexing for a pandas.MultiIndex. @@ -1914,23 +1944,27 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter): __slots__ = ("_dtype", "adapter", "array", "level") array: pd.MultiIndex - _dtype: np.dtype + _dtype: np.dtype | pd.api.extensions.ExtensionDtype level: str | None def __init__( self, array: pd.MultiIndex, - dtype: DTypeLike = None, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, level: str | None = None, ): super().__init__(array, dtype) self.level = level def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: DTypeLike | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: if dtype is None: - dtype = self.dtype + dtype = cast(np.dtype, self.dtype) if self.level is not None: return np.asarray( self.array.get_level_values(self.level).values, dtype=dtype diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f59680dd7df..6d769842a69 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike -from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.compat.array_api_compat import to_like_array @@ -60,6 +59,7 @@ indexing.ExplicitlyIndexed, pd.Index, pd.api.extensions.ExtensionArray, + PandasExtensionArray, ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) @@ -192,7 +192,7 @@ def _maybe_wrap_data(data): if isinstance(data, pd.Index): return PandasIndexingAdapter(data) if isinstance(data, pd.api.extensions.ExtensionArray): - return PandasExtensionArray[type(data)](data) + return PandasExtensionArray(data) return data @@ -2593,11 +2593,6 @@ def chunk( # type: ignore[override] dask.array.from_array """ - if is_extension_array_dtype(self): - raise ValueError( - f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." - ) - if from_array_kwargs is None: from_array_kwargs = {} diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index cdf9eab5c8d..f9c1919201f 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -834,6 +834,7 @@ def chunk( if chunkmanager.is_chunked_array(data_old): data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type] else: + ndata: duckarray[Any, Any] if not isinstance(data_old, ExplicitlyIndexed): ndata = data_old else: diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index a4ed7eba1d0..ee49928aa01 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -524,7 +524,7 @@ def line( assert hueplt is not None ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) - if np.issubdtype(xplt.dtype, np.datetime64): + if isinstance(xplt.dtype, np.dtype) and np.issubdtype(xplt.dtype, np.datetime64): # type: ignore[redundant-expr] _set_concise_date(ax, axis="x") _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c1310bc7e1d..ed8c4178ed0 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4381,7 +4381,6 @@ def test_setitem_pandas(self) -> None: ds["x"] = np.arange(3) ds_copy = ds.copy() ds_copy["bar"] = ds["bar"].to_pandas() - assert_equal(ds, ds_copy) def test_setitem_auto_align(self) -> None: @@ -4972,6 +4971,16 @@ def test_to_and_from_dataframe(self) -> None: expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) + def test_from_dataframe_categorical_dtype_index(self) -> None: + cat = pd.CategoricalIndex(list("abcd")) + df = pd.DataFrame({"f": [0, 1, 2, 3]}, index=cat) + ds = df.to_xarray() + restored = ds.to_dataframe() + df.index.name = ( + "index" # restored gets the name because it has the coord with the name + ) + pd.testing.assert_frame_equal(df, restored) + def test_from_dataframe_categorical_index(self) -> None: cat = pd.CategoricalDtype( categories=["foo", "bar", "baz", "qux", "quux", "corge"] @@ -4996,7 +5005,7 @@ def test_from_dataframe_categorical_index_string_categories(self) -> None: ) ser = pd.Series(1, index=cat) ds = ser.to_xarray() - assert ds.coords.dtypes["index"] == np.dtype("O") + assert ds.coords.dtypes["index"] == ser.index.dtype @requires_sparse def test_from_dataframe_sparse(self) -> None: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 1c351f0ee62..52ab8c4d232 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -4,7 +4,7 @@ import operator import warnings from itertools import pairwise -from typing import Literal +from typing import Literal, cast from unittest import mock import numpy as np @@ -1118,7 +1118,8 @@ def test_groupby_math_nD_group() -> None: expected = da.isel(x=slice(30)) - expanded_mean expected["labels"] = expected.labels.broadcast_like(expected.labels2d) expected["num"] = expected.num.broadcast_like(expected.num2d) - expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data[idxr]) + # mean.num2d_bins.data is a pandas IntervalArray so needs to be put in `numpy` to allow indexing + expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data.to_numpy()[idxr]) actual = g - mean assert_identical(expected, actual) @@ -1680,13 +1681,9 @@ def test_groupby_bins( df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) # type: ignore[call-overload] expected_df = df.groupby("dim_0_bins", observed=True).sum() - # TODO: can't convert df with IntervalIndex to Xarray - expected = ( - expected_df.reset_index(drop=True) - .to_xarray() - .assign_coords(index=np.array(expected_df.index)) - .rename({"index": "dim_0_bins"})["a"] - ) + expected = expected_df.to_xarray().assign_coords( + dim_0_bins=cast(pd.CategoricalIndex, expected_df.index).categories + )["a"] with xr.set_options(use_flox=use_flox): gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs) diff --git a/xarray/tests/test_pandas_to_xarray.py b/xarray/tests/test_pandas_to_xarray.py index 590749bf548..111866541eb 100644 --- a/xarray/tests/test_pandas_to_xarray.py +++ b/xarray/tests/test_pandas_to_xarray.py @@ -107,14 +107,6 @@ def index_flat(request): return indices_dict[key].copy() -@pytest.fixture -def using_infer_string() -> bool: - """ - Fixture to check if infer string option is enabled. - """ - return pd.options.future.infer_string is True # type: ignore[union-attr] - - class TestDataFrameToXArray: @pytest.fixture def df(self): @@ -131,8 +123,7 @@ def df(self): } ) - @pytest.mark.xfail(reason="needs some work") - def test_to_xarray_index_types(self, index_flat, df, using_infer_string): + def test_to_xarray_index_types(self, index_flat, df): index = index_flat # MultiIndex is tested in test_to_xarray_with_multiindex if len(index) == 0: @@ -154,9 +145,6 @@ def test_to_xarray_index_types(self, index_flat, df, using_infer_string): # datetimes w/tz are preserved # column names are lost expected = df.copy() - expected["f"] = expected["f"].astype( - object if not using_infer_string else "str" - ) expected.columns.name = None tm.assert_frame_equal(result.to_dataframe(), expected) @@ -168,7 +156,7 @@ def test_to_xarray_empty(self, df): assert result.sizes["foo"] == 0 assert isinstance(result, Dataset) - def test_to_xarray_with_multiindex(self, df, using_infer_string): + def test_to_xarray_with_multiindex(self, df): from xarray import Dataset # MultiIndex @@ -183,9 +171,7 @@ def test_to_xarray_with_multiindex(self, df, using_infer_string): result = result.to_dataframe() expected = df.copy() - expected["f"] = expected["f"].astype( - object if not using_infer_string else "str" - ) + expected["f"] = expected["f"].astype(object) expected.columns.name = None tm.assert_frame_equal(result, expected) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 8569cb093e7..619dc1561ef 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -333,7 +333,7 @@ def test_pandas_period_index(self): v = self.cls(["x"], pd.period_range(start="2000", periods=20, freq="D")) v = v.load() # for dask-based Variable assert v[0] == pd.Period("2000", freq="D") - assert "Period('2000-01-01', 'D')" in repr(v) + assert "PeriodArray" in repr(v) @pytest.mark.parametrize("dtype", [float, int]) def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: @@ -656,7 +656,7 @@ def test_pandas_categorical_dtype(self): data = pd.Categorical(np.arange(10, dtype="int64")) v = self.cls("x", data) print(v) # should not error - assert v.dtype == "int64" + assert v.dtype == data.dtype def test_pandas_datetime64_with_tz(self): data = pd.date_range( @@ -667,9 +667,12 @@ def test_pandas_datetime64_with_tz(self): ) v = self.cls("x", data) print(v) # should not error - if "America/New_York" in str(data.dtype): - # pandas is new enough that it has datetime64 with timezone dtype - assert v.dtype == "object" + if v.dtype == np.dtype("O"): + import dask.array as da + + assert isinstance(v.data, da.Array) + else: + assert v.dtype == data.dtype def test_multiindex(self): idx = pd.MultiIndex.from_product([list("abc"), [0, 1]]) @@ -1592,14 +1595,6 @@ def test_pandas_categorical_dtype(self): print(v) # should not error assert pd.api.types.is_extension_array_dtype(v.dtype) - def test_pandas_categorical_no_chunk(self): - data = pd.Categorical(np.arange(10, dtype="int64")) - v = self.cls("x", data) - with pytest.raises( - ValueError, match=r".*was found to be a Pandas ExtensionArray.*" - ): - v.chunk((5,)) - def test_squeeze(self): v = Variable(["x", "y"], [[1]]) assert_identical(Variable([], 1), v.squeeze()) @@ -2412,10 +2407,17 @@ def test_multiindex(self): def test_pad(self, mode, xr_arg, np_arg): super().test_pad(mode, xr_arg, np_arg) + @pytest.mark.skip(reason="dask doesn't support extension arrays") + def test_pandas_period_index(self): + super().test_pandas_period_index() + + @pytest.mark.skip(reason="dask doesn't support extension arrays") + def test_pandas_datetime64_with_tz(self): + super().test_pandas_datetime64_with_tz() + + @pytest.mark.skip(reason="dask doesn't support extension arrays") def test_pandas_categorical_dtype(self): - data = pd.Categorical(np.arange(10, dtype="int64")) - with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"): - self.cls("x", data) + super().test_pandas_categorical_dtype() @requires_sparse @@ -3021,7 +3023,7 @@ def test_datetime_conversion(values, unit) -> None: # todo: check for redundancy (suggested per review) dims = ["time"] if isinstance(values, np.ndarray | pd.Index | pd.Series) else [] var = Variable(dims, values) - if var.dtype.kind == "M": + if var.dtype.kind == "M" and isinstance(var.dtype, np.dtype): assert var.dtype == np.dtype(f"datetime64[{unit}]") else: # The only case where a non-datetime64 dtype can occur currently is in @@ -3063,8 +3065,12 @@ def test_pandas_two_only_datetime_conversion_warnings( # todo: check for redundancy (suggested per review) var = Variable(["time"], data.astype(dtype)) # type: ignore[arg-type] - if var.dtype.kind == "M": + # we internally convert series to numpy representations to avoid too much nastiness with extension arrays + # when calling data.array e.g., with NumpyExtensionArrays + if isinstance(data, pd.Series): assert var.dtype == np.dtype("datetime64[s]") + elif var.dtype.kind == "M": + assert var.dtype == dtype else: # The only case where a non-datetime64 dtype can occur currently is in # the case that the variable is backed by a timezone-aware pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy