Skip to content

(fix): allow upcasting of nans in as_shared_dtype for extension arrays #10292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 8, 2025
Merged
16 changes: 12 additions & 4 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from xarray.compat import dask_array_compat, dask_array_ops
from xarray.compat.array_api_compat import get_array_namespace
from xarray.core import dtypes, nputils
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.namedarray.parallelcompat import get_chunked_array_type
Expand Down Expand Up @@ -234,7 +235,7 @@

if xp == np:
# numpy currently doesn't have an astype:
return data.astype(dtype, **kwargs)

Check warning on line 238 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 238 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 238 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 238 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 238 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 238 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast
return xp.astype(data, dtype, **kwargs)


Expand All @@ -256,13 +257,20 @@
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
if len(extension_array_types) == len(scalars_or_arrays) and all(
non_nans = [x for x in scalars_or_arrays if not (x is pd.NA or x is np.nan)]
if len(extension_array_types) == len(non_nans) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return scalars_or_arrays
return [
x
if not (x is pd.NA or x is np.nan)
else PandasExtensionArray(
type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
)
for x in scalars_or_arrays
]
raise ValueError(
"Cannot cast arrays to shared type, found"
f" array types {[x.dtype for x in scalars_or_arrays]}"
f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
)

# Avoid calling array_type("cupy") repeatedly in the any check
Expand Down
22 changes: 21 additions & 1 deletion xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from copy import copy, deepcopy
from io import StringIO
from textwrap import dedent
from typing import Any, Literal
from typing import Any, Literal, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1826,6 +1826,26 @@ def test_categorical_reindex(self) -> None:
actual = ds.reindex(cat=["foo"])["cat"].values
assert (actual == np.array(["foo"])).all()

@pytest.mark.parametrize("fill_value", [np.nan, pd.NA])
def test_extensionarray_negative_reindex(self, fill_value) -> None:
    """Reindex a categorical variable onto a label that is absent from the index.

    The position for the absent label is expected to hold ``pd.NA`` for
    both parametrized fill values, with the categorical dtype kept.
    """
    categories = ["foo", "bar", "baz", "qux", "quux", "corge"]
    cat = pd.Categorical(["foo", "bar", "baz"], categories=categories)
    ds = xr.Dataset(
        {"cat": ("index", cat)},
        coords={"index": ("index", np.arange(3))},
    )

    # The "index" coordinate holds 0, 1, 2, so -1 has no match and must be filled.
    reindexed = ds.reindex(index=[-1, 1, 1], fill_value=fill_value)
    reindexed_cat = cast(
        pd.api.extensions.ExtensionArray,
        reindexed["cat"].to_pandas().values,
    )

    expected = pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype)
    assert reindexed_cat.equals(expected)  # type: ignore[attr-defined]

def test_extension_array_reindex_same(self) -> None:
series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype())
test = xr.Dataset({"test": series})
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy