Skip to content

(fix): allow upcasting of nans in as_shared_dtype for extension arrays #10292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 8, 2025
Merged
Next Next commit
(fix): allow upcasting of nans in as_shared_dtype for extension arrays
  • Loading branch information
ilan-gold committed May 7, 2025
commit a9ef54b185a899c74e6bbcc5d14b871dea0b7b1b
18 changes: 13 additions & 5 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from xarray.compat import dask_array_compat, dask_array_ops
from xarray.compat.array_api_compat import get_array_namespace
from xarray.core import dtypes, nputils
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.namedarray.parallelcompat import get_chunked_array_type
Expand Down Expand Up @@ -256,13 +257,20 @@ def as_shared_dtype(scalars_or_arrays, xp=None):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
if len(extension_array_types) == len(scalars_or_arrays) and all(
non_nans = [x for x in scalars_or_arrays if not (x is pd.NA or x is np.nan)]
if len(extension_array_types) == len(non_nans) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return scalars_or_arrays
return [
x
if not (x is pd.NA or x is np.nan)
else PandasExtensionArray(
type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
)
for x in scalars_or_arrays
]
raise ValueError(
"Cannot cast arrays to shared type, found"
f" array types {[x.dtype for x in scalars_or_arrays]}"
f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
)

# Avoid calling array_type("cupy") repeatidely in the any check
Expand Down Expand Up @@ -383,7 +391,7 @@ def where(condition, x, y):
condition = asarray(condition, dtype=dtype, xp=xp)
else:
condition = astype(condition, dtype=dtype, xp=xp)

print(as_shared_dtype([x, y], xp=xp))
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))


Expand Down
17 changes: 17 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,23 @@ def test_categorical_reindex(self) -> None:
actual = ds.reindex(cat=["foo"])["cat"].values
assert (actual == np.array(["foo"])).all()

@pytest.mark.parametrize("fill_value", [np.nan, pd.NA])
def test_extensionarray_negative_reindex(self, fill_value) -> None:
cat = pd.Categorical(
["foo", "bar", "baz"],
categories=["foo", "bar", "baz", "qux", "quux", "corge"],
)
ds = xr.Dataset(
{"cat": ("index", cat)},
coords={"index": ("index", np.arange(3))},
)
reindexed_cat = (
ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"]
.to_pandas()
.values
)
assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype))

def test_extension_array_reindex_same(self) -> None:
series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype())
test = xr.Dataset({"test": series})
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy