Skip to content

Flexible coordinate transform #9543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
lint, public API and docstrings
  • Loading branch information
benbovy committed Sep 24, 2024
commit 0b545cf61cf192cd1037e2e1d312921a8ab5843c
2 changes: 2 additions & 0 deletions xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
where,
)
from xarray.core.concat import concat
from xarray.core.coordinate_transform import CoordinateTransform
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
Expand Down Expand Up @@ -109,6 +110,7 @@
"CFTimeIndex",
"Context",
"Coordinates",
"CoordinateTransform",
"DataArray",
"Dataset",
"DataTree",
Expand Down
10 changes: 7 additions & 3 deletions xarray/core/coordinate_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Iterable, Hashable, Mapping
from collections.abc import Hashable, Iterable, Mapping
from typing import Any

import numpy as np

Expand All @@ -15,11 +16,14 @@ def __init__(
self,
coord_names: Iterable[Hashable],
dim_size: Mapping[str, int],
dtype: Any = np.dtype(np.float64),
dtype: Any = None,
):
self.coord_names = tuple(coord_names)
self.dims = tuple(dim_size)
self.dim_size = dict(dim_size)

if dtype is None:
dtype = np.dtype(np.float64)
self.dtype = dtype

def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
Expand Down Expand Up @@ -61,7 +65,7 @@ def equals(self, other: "CoordinateTransform") -> bool:
raise NotImplementedError

def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]:
"""Returns all "world" coordinate labels."""
"""Compute all coordinate labels at once."""
if dims is None:
dims = self.dims

Expand Down
32 changes: 22 additions & 10 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from xarray.core import formatting, nputils, utils
from xarray.core.coordinate_transform import CoordinateTransform
from xarray.core.indexing import (
CoordinateTransformIndexingAdapter,
IndexSelResult,
PandasIndexingAdapter,
PandasMultiIndexingAdapter,
Expand All @@ -25,6 +26,7 @@
)

if TYPE_CHECKING:
from xarray.core.coordinate import Coordinates
from xarray.core.types import ErrorOptions, JoinOptions, Self
from xarray.core.variable import Variable

Expand Down Expand Up @@ -1374,8 +1376,13 @@ def rename(self, name_dict, dims_dict):


class CoordinateTransformIndex(Index):
"""Xarray index abstract class for transformation between "pixel"
and "world" coordinates.
"""Helper class for creating Xarray indexes based on coordinate transforms.

- wraps a :py:class:`CoordinateTransform` instance
- takes care of creating the index (lazy) coordinates
- supports point-wise label-based selection
- supports exact alignment only, by comparing indexes based on their transform
(not on their explicit coordinate labels)

"""

Expand Down Expand Up @@ -1409,9 +1416,11 @@ def create_variables(

def create_coordinates(self) -> Coordinates:
# TODO: move this in xarray.Index base class?
from xarray.core.coordinates import Coordinates

variables = self.create_variables()
indexes = {name: self for name in variables}
return xr.Coordinates(coords=variables, indexes=indexes)
return Coordinates(coords=variables, indexes=indexes)

def isel(
self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
Expand All @@ -1423,6 +1432,9 @@ def isel(
def sel(
self, labels: dict[Any, Any], method=None, tolerance=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How hard would it be to support tolerance in some form? This is a common and useful form of error checking.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be pretty tricky to support it here; it is probably better to handle it on a case-by-case basis.

For basic transformations I guess it could be possible to calculate a single, uniform tolerance value in decimal array index units and validate the selected elements using those units (cheap). In other cases we would need to compute the forward transformation of the extracted array indices and then validate the selected elements based on distances in physical units (more expensive).

Also, there may be cases where the coordinates of the same transform object don’t all have the same physical units (e.g., coordinates in both degrees and radians in an Astropy WCS object). Unless we forbid that in xarray.CoordinateTransform, it doesn’t make much sense to pass a single tolerance value. Passing a dictionary tolerance={coord_name: value} doesn’t look very nice either IMO. A {unit: value} dict looks better, but adding explicit support for units here might be opening a can of worms.

) -> IndexSelResult:
from xarray.core.dataarray import DataArray
from xarray.core.variable import Variable

if method != "nearest":
raise ValueError(
"CoordinateTransformIndex only supports selection with method='nearest'"
Expand All @@ -1433,15 +1445,14 @@ def sel(

missing_labels = coord_names_set - labels_set
if missing_labels:
raise ValueError(
f"missing labels for coordinate(s): {','.join(missing_labels)}."
)
missing_labels_str = ",".join([f"{name}" for name in missing_labels])
raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.")

label0_obj = next(iter(labels.values()))
dim_size0 = getattr(label0_obj, "sizes", None)

is_xr_obj = [
isinstance(label, (xr.DataArray, xr.Variable)) for label in labels.values()
isinstance(label, DataArray | Variable) for label in labels.values()
]
if not all(is_xr_obj):
raise TypeError(
Expand All @@ -1461,13 +1472,14 @@ def sel(
dim_positions = self.transform.reverse(coord_labels)

results = {}
dims0 = tuple(dim_size0)
for dim, pos in dim_positions.items():
if isinstance(label0_obj, Variable):
xr_pos = Variable(label.dims, idx)
xr_pos = Variable(dims0, pos)
else:
# dataarray
xr_pos = DataArray(idx, dims=label.dims)
results[dim] = idx
xr_pos = DataArray(pos, dims=dims0)
results[dim] = xr_pos

return IndexSelResult(results)

Expand Down
7 changes: 4 additions & 3 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,8 +1961,9 @@ def copy(self, deep: bool = True) -> Self:


class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a CoordinateTransform to support explicit indexing and
lazy coordinate labels.
"""Wrap a CoordinateTransform as a lazy coordinate array.

Supports explicit indexing (both outer and vectorized).

"""

Expand Down Expand Up @@ -2036,7 +2037,7 @@ def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)

# also works with basic indexing
return self._oindex_get(indexer)
return self._oindex_get(OuterIndexer(indexer.tuple))

def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
raise TypeError(
Expand Down
9 changes: 7 additions & 2 deletions xarray/indexes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

"""

from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex
from xarray.core.indexes import (
CoordinateTransformIndex,
Index,
PandasIndex,
PandasMultiIndex,
)

__all__ = ["Index", "PandasIndex", "PandasMultiIndex"]
__all__ = ["CoordinateTransformIndex", "Index", "PandasIndex", "PandasMultiIndex"]
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy