Skip to content

Commit 97fb90b

Browse files
authored
(fix): disallow NumpyExtensionArray (#10334)
1 parent 60bc816 commit 97fb90b

File tree

5 files changed

+66
-10
lines changed

5 files changed

+66
-10
lines changed

properties/test_pandas_roundtrip.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import hypothesis.extra.pandas as pdst # isort:skip
1616
import hypothesis.strategies as st # isort:skip
1717
from hypothesis import given # isort:skip
18+
from xarray.tests import has_pyarrow
1819

1920
numeric_dtypes = st.one_of(
2021
npst.unsigned_integer_dtypes(endianness="="),
@@ -134,10 +135,39 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
134135
xr.testing.assert_identical(dataset, roundtripped.to_xarray())
135136

136137

137-
def test_roundtrip_1d_pandas_extension_array() -> None:
138-
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
139-
arr = xr.Dataset.from_dataframe(df)["cat"]
138+
@pytest.mark.parametrize(
139+
"extension_array",
140+
[
141+
pd.Categorical(["a", "b", "c"]),
142+
pd.array(["a", "b", "c"], dtype="string"),
143+
pd.arrays.IntervalArray(
144+
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
145+
),
146+
pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])),
147+
pd.arrays.DatetimeArray._from_sequence(
148+
pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D")
149+
),
150+
np.array([1, 2, 3], dtype="int64"),
151+
]
152+
+ ([pd.array([1, 2, 3], dtype="int64[pyarrow]")] if has_pyarrow else []),
153+
ids=["cat", "string", "interval", "timedelta", "datetime", "numpy"]
154+
+ (["pyarrow"] if has_pyarrow else []),
155+
)
156+
@pytest.mark.parametrize("is_index", [True, False])
157+
def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None:
158+
df = pd.DataFrame({"arr": extension_array})
159+
if is_index:
160+
df = df.set_index("arr")
161+
arr = xr.Dataset.from_dataframe(df)["arr"]
140162
roundtripped = arr.to_pandas()
141-
assert (df["cat"] == roundtripped).all()
142-
assert df["cat"].dtype == roundtripped.dtype
143-
xr.testing.assert_identical(arr, roundtripped.to_xarray())
163+
df_arr_to_test = df.index if is_index else df["arr"]
164+
assert (df_arr_to_test == roundtripped).all()
165+
# `NumpyExtensionArray` types are not roundtripped, including `StringArray` which subtypes.
166+
if isinstance(extension_array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined]
167+
assert isinstance(arr.data, np.ndarray)
168+
else:
169+
assert (
170+
df_arr_to_test.dtype
171+
== (roundtripped.index if is_index else roundtripped).dtype
172+
)
173+
xr.testing.assert_identical(arr, roundtripped.to_xarray())

xarray/core/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
parse_dims_as_set,
100100
)
101101
from xarray.core.variable import (
102+
UNSUPPORTED_EXTENSION_ARRAY_TYPES,
102103
IndexVariable,
103104
Variable,
104105
as_variable,
@@ -7281,7 +7282,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
72817282
extension_arrays = []
72827283
for k, v in dataframe.items():
72837284
if not is_extension_array_dtype(v) or isinstance(
7284-
v.array, pd.arrays.DatetimeArray | pd.arrays.TimedeltaArray
7285+
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
72857286
):
72867287
arrays.append((k, np.asarray(v)))
72877288
else:

xarray/core/extension_array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
9393
def __post_init__(self):
9494
if not isinstance(self.array, pd.api.extensions.ExtensionArray):
9595
raise TypeError(f"{self.array} is not an pandas ExtensionArray.")
96+
# This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
97+
# we do support extension arrays from datetime, for example, that need
98+
# duck array support internally via this class.
99+
if isinstance(self.array, pd.arrays.NumpyExtensionArray):
100+
raise TypeError(
101+
"`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
102+
)
96103

97104
def __array_function__(self, func, types, args, kwargs):
98105
def replace_duck_with_extension_array(args) -> list:

xarray/core/indexing.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,8 +1802,12 @@ def __array__(
18021802

18031803
def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
18041804
# We return an PandasExtensionArray wrapper type that satisfies
1805-
# duck array protocols. This is what's needed for tests to pass.
1806-
if pd.api.types.is_extension_array_dtype(self.array):
1805+
# duck array protocols.
1806+
# `NumpyExtensionArray` is excluded
1807+
if pd.api.types.is_extension_array_dtype(self.array) and not isinstance(
1808+
self.array.array,
1809+
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
1810+
):
18071811
from xarray.core.extension_array import PandasExtensionArray
18081812

18091813
return PandasExtensionArray(self.array.array)

xarray/core/variable.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@
6363
)
6464
# https://github.com/python/mypy/issues/224
6565
BASIC_INDEXING_TYPES = integer_types + (slice,)
66+
UNSUPPORTED_EXTENSION_ARRAY_TYPES = (
67+
pd.arrays.DatetimeArray,
68+
pd.arrays.TimedeltaArray,
69+
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
70+
)
6671

6772
if TYPE_CHECKING:
6873
from xarray.core.types import (
@@ -190,6 +195,8 @@ def _maybe_wrap_data(data):
190195
"""
191196
if isinstance(data, pd.Index):
192197
return PandasIndexingAdapter(data)
198+
if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES):
199+
return data.to_numpy()
193200
if isinstance(data, pd.api.extensions.ExtensionArray):
194201
return PandasExtensionArray(data)
195202
return data
@@ -251,7 +258,14 @@ def convert_non_numpy_type(data):
251258

252259
# we don't want nested self-described arrays
253260
if isinstance(data, pd.Series | pd.DataFrame):
254-
pandas_data = data.values
261+
if (
262+
isinstance(data, pd.Series)
263+
and pd.api.types.is_extension_array_dtype(data)
264+
and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES)
265+
):
266+
pandas_data = data.array
267+
else:
268+
pandas_data = data.values # type: ignore[assignment]
255269
if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
256270
return convert_non_numpy_type(pandas_data)
257271
else:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy