From 753535ce0890892e8aea3fd033cee01bd6f11081 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 May 2025 17:25:11 +0200 Subject: [PATCH 1/6] (fix): pandas extension array repr for int64[pyarrow] --- xarray/core/extension_array.py | 15 +++++++++++++-- xarray/core/formatting.py | 7 ++++++- xarray/tests/__init__.py | 12 ++++++++++++ xarray/tests/test_concat.py | 18 +++++++++++------- xarray/tests/test_dataset.py | 15 ++++++++++----- 5 files changed, 52 insertions(+), 15 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 096a427e425..b56af41f220 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -65,6 +65,17 @@ def __extension_duck_array__where( return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) +@implements(np.reshape) +def __extension_duck_array__reshape( + arr: T_ExtensionArray, shape: tuple +) -> T_ExtensionArray: + if (shape[0] == len(arr) and len(shape) == 1) or shape == (-1,): + return arr + raise NotImplementedError( + f"Cannot reshape 1d-only pandas extension array to: {shape}" + ) + + @dataclass(frozen=True) class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin): """NEP-18 compliant wrapper for pandas extension arrays. @@ -100,10 +111,10 @@ def replace_duck_with_extension_array(args) -> list: args = tuple(replace_duck_with_extension_array(args)) if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: - return func(*args, **kwargs) + raise KeyError("Function not registered for pandas extension arrays.") res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) if is_extension_array_dtype(res): - return type(self)[type(res)](res) + return PandasExtensionArray(res) return res def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 7aa333ffb2e..86fb147d382 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -19,6 +19,7 @@ from xarray.core.datatree_render import RenderDataTree from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype, ravel +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import MemoryCachedArray from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.treenode import group_subtrees @@ -176,6 +177,8 @@ def format_timedelta(t, timedelta_format=None): def format_item(x, timedelta_format=None, quote_strings=True): """Returns a succinct summary of an object as a string""" + if isinstance(x, PandasExtensionArray): + return f"{x.array[0]}" if isinstance(x, np.datetime64 | datetime): return format_timestamp(x) if isinstance(x, np.timedelta64 | timedelta): @@ -194,7 +197,9 @@ def format_items(x): """Returns a succinct summaries of all items in a sequence as strings""" x = to_duck_array(x) timedelta_format = "datetime" - if np.issubdtype(x.dtype, np.timedelta64): + if not isinstance(x, PandasExtensionArray) and np.issubdtype( + x.dtype, np.timedelta64 + ): x = astype(x, dtype="timedelta64[ns]") day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index b33192393f7..5b05a3fb108 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -363,6 +363,18 @@ def create_test_data( ) ), ) + obj["var5"] = ( + "dim1", + pd.arrays.IntervalArray([pd.Interval(0, 1)] * dim_sizes[0]), + ) + if has_pyarrow: + obj["var6"] = ( + "dim1", + pd.array( + rs.integers(1, 10, size=dim_sizes[0]).tolist(), + dtype="int64[pyarrow]", + ), + ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 49c6490d819..39ecfcf35c0 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -21,6 +21,7 @@ assert_equal, assert_identical, requires_dask, + requires_pyarrow, ) from xarray.tests.test_dataset import create_test_data @@ -154,19 +155,22 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) -def test_concat_categorical() -> None: +@pytest.mark.parametrize( + "var", ["var4", "var5", pytest.param("var6", marks=requires_pyarrow)] +) +def test_concat_extension_array(var) -> None: data1 = create_test_data(use_extension_array=True) data2 = create_test_data(use_extension_array=True) concatenated = concat([data1, data2], dim="dim1") - assert ( - concatenated["var4"] - == type(data2["var4"].variable.data)._concat_same_type( + assert pd.Series( + concatenated[var] + == type(data2[var].variable.data)._concat_same_type( [ - data1["var4"].variable.data, - data2["var4"].variable.data, + data1[var].variable.data, + data2[var].variable.data, ] ) - ).all() + ).all() # need to wrap in series because pyarrow bool does not support `all` def test_concat_missing_multiple_consecutive_var() -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index bacad96a213..c17f772474c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -58,6 +58,7 @@ create_test_data, has_cftime, has_dask, + has_pyarrow, raise_if_dask_computes, requires_bottleneck, requires_cftime, @@ -297,12 +298,16 @@ def test_repr(self) -> None: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a' + var4 (dim1) category 32B b c b a c a c a + var5 (dim1) interval[int64, right] 128B (0, 1] (0, 1] ... (0, 1] (0, 1]{} Attributes: - foo: bar""".format( - data["dim3"].dtype, - "ns", - ) + foo: bar""" + ).format( + data["dim3"].dtype, + "ns", + "\n var6 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" + if has_pyarrow + else "", ) actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) From a23de74a125cf767cbbd9b38c27deb2a7bb6cd20 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 May 2025 17:38:06 +0200 Subject: [PATCH 2/6] (fix): remove problematic intervalarray repr --- xarray/tests/__init__.py | 6 +----- xarray/tests/test_concat.py | 4 +--- xarray/tests/test_dataset.py | 10 +++++----- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 5b05a3fb108..fe76df75fa0 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -363,12 +363,8 @@ def create_test_data( ) ), ) - obj["var5"] = ( - "dim1", - pd.arrays.IntervalArray([pd.Interval(0, 1)] * dim_sizes[0]), - ) if has_pyarrow: - obj["var6"] = ( + obj["var5"] = ( "dim1", pd.array( rs.integers(1, 10, size=dim_sizes[0]).tolist(), diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 39ecfcf35c0..ed5aac4fe99 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -155,9 +155,7 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) -@pytest.mark.parametrize( - "var", ["var4", "var5", pytest.param("var6", marks=requires_pyarrow)] -) +@pytest.mark.parametrize("var", ["var4", pytest.param("var5", marks=requires_pyarrow)]) def test_concat_extension_array(var) -> None: data1 = create_test_data(use_extension_array=True) data2 = create_test_data(use_extension_array=True) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c17f772474c..22d46d5a4cb 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -298,14 +298,13 @@ def test_repr(self) -> None: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B b c b a c a c a - var5 (dim1) interval[int64, right] 128B (0, 1] (0, 1] ... (0, 1] (0, 1]{} + var4 (dim1) category 32B b c b a c a c a{} Attributes: foo: bar""" ).format( data["dim3"].dtype, "ns", - "\n var6 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" + "\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" if has_pyarrow else "", ) @@ -5801,7 +5800,7 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: def test_reduce_non_numeric(self) -> None: data1 = create_test_data(seed=44, use_extension_array=True) data2 = create_test_data(seed=44) - add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]} + add_vars = {"var6": ["dim1", "dim2"], "var7": ["dim1"]} for v, dims in sorted(add_vars.items()): size = tuple(data1.sizes[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) @@ -5811,10 +5810,11 @@ def test_reduce_non_numeric(self) -> None: "var4" not in data1.mean() and "var5" not in data1.mean() and "var6" not in data1.mean() + and "var7" not in data1.mean() ) assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) - assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2") + assert "var6" not in data1.mean(dim="dim2") and "var7" in data1.mean(dim="dim2") @pytest.mark.filterwarnings( "ignore:Once the behaviour of DataArray:DeprecationWarning" From ad2aa493bb269f07e3a54f521a990ed6f94b4647 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 May 2025 17:38:43 +0200 Subject: [PATCH 3/6] (fix): new categorical repr --- xarray/tests/test_dataarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e7acdcdd4f3..3640bf4df14 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3637,7 +3637,7 @@ def test_series_categorical_index(self) -> None: s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc"))) arr = DataArray(s) - assert "'a'" in repr(arr) # should not error + assert "a a b b" in repr(arr) # should not error @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("data", ["list", "array", True]) From f4d26546a29f034eed7e8d1e374ac6b65f570d2c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 May 2025 10:27:59 +0200 Subject: [PATCH 4/6] (fix): use f-string for repr test --- xarray/tests/test_dataset.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 22d46d5a4cb..7f152f8c506 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -284,29 +284,28 @@ def test_repr(self) -> None: data = create_test_data(seed=123, use_extension_array=True) data.attrs["foo"] = "bar" # need to insert str dtype at runtime to handle different endianness + var5 = ( + "\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" + if has_pyarrow + else "" + ) expected = dedent( - """\ + f"""\ Size: 2kB Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8) Coordinates: * dim2 (dim2) float64 72B 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 - * dim3 (dim3) {} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' - * time (time) datetime64[{}] 160B 2000-01-01 2000-01-02 ... 2000-01-20 + * dim3 (dim3) {data["dim3"].dtype} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' + * time (time) datetime64[ns] 160B 2000-01-01 2000-01-02 ... 2000-01-20 numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3 Dimensions without coordinates: dim1 Data variables: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 32B b c b a c a c a{} + var4 (dim1) category 32B b c b a c a c a{var5} Attributes: foo: bar""" - ).format( - data["dim3"].dtype, - "ns", - "\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1" - if has_pyarrow - else "", ) actual = "\n".join(x.rstrip() for x in repr(data).split("\n")) From f39cbee535d1590e2837c3afe24ed88507e0126a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 14 May 2025 10:28:20 +0200 Subject: [PATCH 5/6] (fix): comment in test --- xarray/tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7f152f8c506..f38b2798eb3 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5804,7 +5804,7 @@ def test_reduce_non_numeric(self) -> None: size = tuple(data1.sizes[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) data1[v] = (dims, data, {"foo": "variable"}) - # var4 is extension array categorical and should be dropped + # var4 and var5 are extension arrays and should be dropped assert ( "var4" not in data1.mean() and "var5" not in data1.mean() From b8dc0f87503b24b12adfc0747429c7c4119e16b4 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 29 May 2025 12:19:07 +0200 Subject: [PATCH 6/6] (chore): add comment --- xarray/core/formatting.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 86fb147d382..90a2e5d0d92 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -178,7 +178,10 @@ def format_timedelta(t, timedelta_format=None): def format_item(x, timedelta_format=None, quote_strings=True): """Returns a succinct summary of an object as a string""" if isinstance(x, PandasExtensionArray): - return f"{x.array[0]}" + # We want to bypass PandasExtensionArray's repr here + # because its __repr__ is PandasExtensionArray(array=[...]) + # and this function is only for single elements. + return str(x.array[0]) if isinstance(x, np.datetime64 | datetime): return format_timestamp(x) if isinstance(x, np.timedelta64 | timedelta): pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy