Skip to content

Flexible coordinate transform #9543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix more typing issues
  • Loading branch information
benbovy committed Feb 13, 2025
commit 632c71b103d4e659216f37d3adf0f5b8e8aad091
4 changes: 3 additions & 1 deletion xarray/core/coordinate_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def equals(self, other: "CoordinateTransform") -> bool:
"""Check equality with another CoordinateTransform of the same kind."""
raise NotImplementedError

def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]:
def generate_coords(
self, dims: tuple[str, ...] | None = None
) -> dict[Hashable, Any]:
"""Compute all coordinate labels at once."""
if dims is None:
dims = self.dims
Expand Down
11 changes: 5 additions & 6 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@ def sel(
raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.")

label0_obj = next(iter(labels.values()))
dim_size0 = getattr(label0_obj, "sizes", None)
dim_size0 = getattr(label0_obj, "sizes", {})

is_xr_obj = [
isinstance(label, DataArray | Variable) for label in labels.values()
Expand All @@ -1457,7 +1457,7 @@ def sel(
"CoordinateTransformIndex only supports advanced (point-wise) indexing "
"with either xarray.DataArray or xarray.Variable objects."
)
dim_size = [getattr(label, "sizes", None) for label in labels.values()]
dim_size = [getattr(label, "sizes", {}) for label in labels.values()]
if any(ds != dim_size0 for ds in dim_size):
raise ValueError(
"CoordinateTransformIndex only supports advanced (point-wise) indexing "
Expand All @@ -1469,19 +1469,18 @@ def sel(
}
dim_positions = self.transform.reverse(coord_labels)

results = {}
results: dict[str, Variable | DataArray] = {}
dims0 = tuple(dim_size0)
for dim, pos in dim_positions.items():
# TODO: rounding the decimal positions is not always the behavior we expect
# (there are different ways to represent implicit intervals)
# we should probably make this customizable.
pos = np.round(pos).astype("int")
if isinstance(label0_obj, Variable):
xr_pos = Variable(dims0, pos)
results[dim] = Variable(dims0, pos)
else:
# dataarray
xr_pos = DataArray(pos, dims=dims0)
results[dim] = xr_pos
results[dim] = DataArray(pos, dims=dims0)

return IndexSelResult(results)

Expand Down
6 changes: 3 additions & 3 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,7 @@ def _decompose_outer_indexer(
return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer)))


def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray:
def _posify_indices(indices: Any, size: int) -> np.ndarray:
"""Convert negative indices by their equivalent positive indices.

Note: the resulting indices may still be out of bounds (< 0 or >= size).
Expand All @@ -1311,7 +1311,7 @@ def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray:
return np.where(indices < 0, size + indices, indices)


def _check_bounds(indices, size):
def _check_bounds(indices: Any, size: int):
"""Check if the given indices are all within the array boundaries."""
if np.any((indices < 0) | (indices >= size)):
raise IndexError("out of bounds index")
Expand Down Expand Up @@ -2046,7 +2046,7 @@ def dtype(self) -> np.dtype:
return self._transform.dtype

@property
def shape(self):
def shape(self) -> tuple[int, ...]:
return tuple(self._transform.dim_size.values())

def get_duck_array(self) -> np.ndarray:
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy