diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 96330a64b68..48dc3b7627a 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -13,6 +13,7 @@
 from collections.abc import Callable
 from functools import partial
 from importlib import import_module
+from typing import Any
 
 import numpy as np
 import pandas as pd
@@ -27,6 +28,7 @@
 from xarray.compat import dask_array_compat, dask_array_ops
 from xarray.compat.array_api_compat import get_array_namespace
 from xarray.core import dtypes, nputils
+from xarray.core.extension_array import PandasExtensionArray
 from xarray.core.options import OPTIONS
 from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
 from xarray.namedarray.parallelcompat import get_chunked_array_type
@@ -143,6 +145,21 @@ def round(array):
 around: Callable = round
 
 
+def isna(data: Any) -> bool:
+    """Checks if data is literally np.nan or pd.NA.
+
+    Parameters
+    ----------
+    data
+        Any python object
+
+    Returns
+    -------
+        Whether or not the data is np.nan or pd.NA
+    """
+    return data is pd.NA or data is np.nan
+
+
 def isnull(data):
     data = asarray(data)
 
@@ -256,13 +273,20 @@ def as_shared_dtype(scalars_or_arrays, xp=None):
         extension_array_types = [
             x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
         ]
-        if len(extension_array_types) == len(scalars_or_arrays) and all(
+        non_nans = [x for x in scalars_or_arrays if not isna(x)]
+        if len(extension_array_types) == len(non_nans) and all(
             isinstance(x, type(extension_array_types[0])) for x in extension_array_types
         ):
-            return scalars_or_arrays
+            return [
+                x
+                if not isna(x)
+                else PandasExtensionArray(
+                    type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
+                )
+                for x in scalars_or_arrays
+            ]
         raise ValueError(
-            "Cannot cast arrays to shared type, found"
-            f" array types {[x.dtype for x in scalars_or_arrays]}"
+            f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
         )
 
     # Avoid calling array_type("cupy") repeatidely in the any check
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index af7db7294a8..c8df61142f6 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -8,7 +8,7 @@
 from copy import copy, deepcopy
 from io import StringIO
 from textwrap import dedent
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 import numpy as np
 import pandas as pd
@@ -1827,6 +1827,26 @@ def test_categorical_index_reindex(self) -> None:
         actual = ds.reindex(cat=["foo"])["cat"].values
         assert (actual == np.array(["foo"])).all()
 
+    @pytest.mark.parametrize("fill_value", [np.nan, pd.NA])
+    def test_extensionarray_negative_reindex(self, fill_value) -> None:
+        cat = pd.Categorical(
+            ["foo", "bar", "baz"],
+            categories=["foo", "bar", "baz", "qux", "quux", "corge"],
+        )
+        ds = xr.Dataset(
+            {"cat": ("index", cat)},
+            coords={"index": ("index", np.arange(3))},
+        )
+        reindexed_cat = cast(
+            pd.api.extensions.ExtensionArray,
+            (
+                ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"]
+                .to_pandas()
+                .values
+            ),
+        )
+        assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype))  # type: ignore[attr-defined]
+
     def test_extension_array_reindex_same(self) -> None:
         series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype())
         test = xr.Dataset({"test": series})
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: