From 4a595dff198edfc6163fd9bb6d2b3c095320ac2b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 16:22:12 +0200 Subject: [PATCH 01/16] Add coordinate transform classes from prototype --- xarray/core/coordinate_transform.py | 74 ++++++++++++++ xarray/core/indexes.py | 111 +++++++++++++++++++++ xarray/core/indexing.py | 145 ++++++++++++++++++++++++++++ 3 files changed, 330 insertions(+) create mode 100644 xarray/core/coordinate_transform.py diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py new file mode 100644 index 00000000000..1d4db3e9b7e --- /dev/null +++ b/xarray/core/coordinate_transform.py @@ -0,0 +1,74 @@ +from typing import Any, Iterable, Hashable, Mapping + +import numpy as np + + +class CoordinateTransform: + """Abstract coordinate transform with dimension & coordinate names.""" + + coord_names: tuple[Hashable, ...] + dims: tuple[str, ...] + dim_size: dict[str, int] + dtype: Any + + def __init__( + self, + coord_names: Iterable[Hashable], + dim_size: Mapping[str, int], + dtype: Any = np.dtype(np.float64), + ): + self.coord_names = tuple(coord_names) + self.dims = tuple(dim_size) + self.dim_size = dict(dim_size) + self.dtype = dtype + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + """Perform grid -> world coordinate transformation. + + Parameters + ---------- + dim_positions : dict + Grid location(s) along each dimension (axis). + + Returns + ------- + coord_labels : dict + World coordinate labels. + + """ + # TODO: cache the results in order to avoid re-computing + # all labels when accessing the values of each coordinate one at a time + raise NotImplementedError + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + """Perform world -> grid coordinate reverse transformation. + + Parameters + ---------- + labels : dict + World coordinate labels. + + Returns + ------- + dim_positions : dict + Grid relative location(s) along each dimension (axis). + + """ + raise NotImplementedError + + def equals(self, other: "CoordinateTransform") -> bool: + """Check equality with another CoordinateTransform of the same kind.""" + raise NotImplementedError + + def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: + """Returns all "world" coordinate labels.""" + if dims is None: + dims = self.dims + + positions = np.meshgrid( + *[np.arange(self.dim_size[d]) for d in dims], + indexing="ij", + ) + dim_positions = {dim: positions[i] for i, dim in enumerate(dims)} + + return self.forward(dim_positions) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5abc2129e3e..8d90c955bfe 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -10,6 +10,7 @@ import pandas as pd from xarray.core import formatting, nputils, utils +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexing import ( IndexSelResult, PandasIndexingAdapter, @@ -1372,6 +1373,116 @@ def rename(self, name_dict, dims_dict): ) +class CoordinateTransformIndex(Index): + """Xarray index abstract class for transformation between "pixel" + and "world" coordinates. + + """ + + transform: CoordinateTransform + + def __init__( + self, + transform: CoordinateTransform, + ): + self.transform = transform + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + new_variables = {} + + for name in self.transform.coord_names: + # copy attributes, if any + attrs: Mapping[Hashable, Any] | None + + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + else: + attrs = None + + data = CoordinateTransformIndexingAdapter(self.transform, name) + new_variables[name] = Variable(self.transform.dims, data, attrs=attrs) + + return new_variables + + def create_coordinates(self) -> Coordinates: + # TODO: move this in xarray.Index base class? + variables = self.create_variables() + indexes = {name: self for name in variables} + return xr.Coordinates(coords=variables, indexes=indexes) + + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: + # TODO: support returning a new index (e.g., possible to re-calculate the + # the transform or calculate another transform on a reduced dimension space) + return None + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + if method != "nearest": + raise ValueError( + "CoordinateTransformIndex only supports selection with method='nearest'" + ) + + labels_set = set(labels) + coord_names_set = set(self.transform.coord_names) + + missing_labels = coord_names_set - labels_set + if missing_labels: + raise ValueError( + f"missing labels for coordinate(s): {','.join(missing_labels)}." + ) + + label0_obj = next(iter(labels.values())) + dim_size0 = getattr(label0_obj, "sizes", None) + + is_xr_obj = [ + isinstance(label, (xr.DataArray, xr.Variable)) for label in labels.values() + ] + if not all(is_xr_obj): + raise TypeError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with either xarray.DataArray or xarray.Variable objects." + ) + dim_size = [getattr(label, "sizes", None) for label in labels.values()] + if any([ds != dim_size0 for ds in dim_size]): + raise ValueError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with xarray.DataArray or xarray.Variable objects of macthing dimensions." + ) + + coord_labels = { + name: labels[name].values for name in self.transform.coord_names + } + dim_positions = self.transform.reverse(coord_labels) + + results = {} + for dim, pos in dim_positions.items(): + if isinstance(label0_obj, Variable): + xr_pos = Variable(label.dims, idx) + else: + # dataarray + xr_pos = DataArray(idx, dims=label.dims) + results[dim] = idx + + return IndexSelResult(results) + + def equals(self, other: Self) -> bool: + return self.transform.equals(other.transform) + + def rename( + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + # TODO: maybe update self.transform coord_names, dim_size and dims attributes + return self + + def create_default_index_implicit( dim_variable: Variable, all_variables: Mapping | Iterable[Hashable] | None = None, diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 67912908a2b..35fd2597b85 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -15,6 +15,7 @@ import pandas as pd from xarray.core import duck_array_ops +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -1303,6 +1304,42 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) +def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray: + """Convert negative indices by their equivalent positive indices. + + Note: the resulting indices may still be out of bounds (< 0 or >= size). + + """ + return np.where(indices < 0, size + indices, indices) + + +def _check_bounds(indices, size): + """Check if the given indices are all within the array boundaries.""" + if np.any((indices < 0) | (indices >= size)): + raise IndexError("out of bounds index") + + +def _arrayize_outer_indexer(indexer: OuterIndexer, shape) -> OuterIndexer: + """Return a similar oindex with after replacing slices by arrays and + negative indices by their corresponding positive indices. + + Also check if array indices are within bounds. + + """ + new_key = [] + + for axis, value in enumerate(indexer.tuple): + size = shape[axis] + if isinstance(value, slice): + value = _expand_slice(value, size) + else: + value = _posify_indices(value, size) + _check_bounds(value, size) + new_key.append(value) + + return OuterIndexer(tuple(new_key)) + + def _arrayize_vectorized_indexer( indexer: VectorizedIndexer, shape: _Shape ) -> VectorizedIndexer: @@ -1921,3 +1958,111 @@ def copy(self, deep: bool = True) -> Self: # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype, self.level) + + +class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a CoordinateTransform to support explicit indexing and + lazy coordinate labels. + + """ + + _transform: CoordinateTransform + _coord_name: Hashable + _dims: tuple[str, ...] + + def __init__( + self, + transform: CoordinateTransform, + coord_name: Hashable, + dims: tuple[str] | None = None, + ): + self._transform = transform + self._coord_name = coord_name + self._dims = dims or transform.dims + + @property + def dtype(self) -> np.dtype: + return self._transform.dtype + + @property + def shape(self): + return tuple(self._transform.dim_size.values()) + + def get_duck_array(self) -> np.ndarray: + all_coords = self._transform.generate_coords(dims=self._dims) + return np.asarray(all_coords[self._coord_name]) + + def _oindex_get(self, indexer: OuterIndexer): + expanded_indexer_ = OuterIndexer(expanded_indexer(indexer.tuple, self.ndim)) + array_indexer = _arrayize_outer_indexer(expanded_indexer_, self.shape) + + positions = np.meshgrid(*array_indexer.tuple, indexing="ij") + dim_positions = { + dim: pos for dim, pos in zip(self._dims, positions, strict=False) + } + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]).squeeze() + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def _vindex_get(self, indexer: VectorizedIndexer): + expanded_indexer_ = VectorizedIndexer( + expanded_indexer(indexer.tuple, self.ndim) + ) + array_indexer = _arrayize_vectorized_indexer(expanded_indexer_, self.shape) + + dim_positions = {} + for i, (dim, pos) in enumerate( + zip(self._dims, array_indexer.tuple, strict=False) + ): + pos = _posify_indices(pos, self.shape[i]) + _check_bounds(pos, self.shape[i]) + dim_positions[dim] = pos + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def __getitem__(self, indexer: ExplicitIndexer): + # TODO: make it lazy (i.e., re-calculate and re-wrap the transform) when possible? + self._check_and_raise_if_non_basic_indexer(indexer) + + # also works with basic indexing + return self._oindex_get(indexer) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def transpose(self, order): + new_dims = tuple([self._dims[i] for i in order]) + return type(self)(self._transform, self._coord_name, new_dims) + + def __repr__(self: Any) -> str: + return f"{type(self).__name__}(transform={self._transform!r})" + + def _get_array_subset(self) -> np.ndarray: + threshold = max(100, OPTIONS["display_values_threshold"] + 2) + if self.size > threshold: + pos = threshold // 2 + indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)]) + subset = self.vindex[VectorizedIndexer((indices,) * self.ndim)] + else: + subset = self + + return np.asarray(subset) + + def _repr_inline_(self, max_width: int) -> str: + """Good to see some labels even for a lazy coordinate.""" + from xarray.core.formatting import format_array_flat + + return format_array_flat(self._get_array_subset(), max_width) From 0b545cf61cf192cd1037e2e1d312921a8ab5843c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:26:15 +0200 Subject: [PATCH 02/16] lint, public API and docstrings --- xarray/__init__.py | 2 ++ xarray/core/coordinate_transform.py | 10 ++++++--- xarray/core/indexes.py | 32 ++++++++++++++++++++--------- xarray/core/indexing.py | 7 ++++--- xarray/indexes/__init__.py | 9 ++++++-- 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index e3b7ec469e9..b49ab1848b7 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -30,6 +30,7 @@ where, ) from xarray.core.concat import concat +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -109,6 +110,7 @@ "CFTimeIndex", "Context", "Coordinates", + "CoordinateTransform", "DataArray", "Dataset", "DataTree", diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 1d4db3e9b7e..40043da46bc 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -1,4 +1,5 @@ -from typing import Any, Iterable, Hashable, Mapping +from collections.abc import Hashable, Iterable, Mapping +from typing import Any import numpy as np @@ -15,11 +16,14 @@ def __init__( self, coord_names: Iterable[Hashable], dim_size: Mapping[str, int], - dtype: Any = np.dtype(np.float64), + dtype: Any = None, ): self.coord_names = tuple(coord_names) self.dims = tuple(dim_size) self.dim_size = dict(dim_size) + + if dtype is None: + dtype = np.dtype(np.float64) self.dtype = dtype def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: @@ -61,7 +65,7 @@ def equals(self, other: "CoordinateTransform") -> bool: raise NotImplementedError def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: - """Returns all "world" coordinate labels.""" + """Compute all coordinate labels at once.""" if dims is None: dims = self.dims diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8d90c955bfe..e154b727fc5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -12,6 +12,7 @@ from xarray.core import formatting, nputils, utils from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexing import ( + CoordinateTransformIndexingAdapter, IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter, @@ -25,6 +26,7 @@ ) if TYPE_CHECKING: + from xarray.core.coordinate import Coordinates from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -1374,8 +1376,13 @@ def rename(self, name_dict, dims_dict): class CoordinateTransformIndex(Index): - """Xarray index abstract class for transformation between "pixel" - and "world" coordinates. + """Helper class for creating Xarray indexes based on coordinate transforms. + + - wraps a :py:class:`CoordinateTransform` instance + - takes care of creating the index (lazy) coordinates + - supports point-wise label-based selection + - supports exact alignment only, by comparing indexes based on their transform + (not on their explicit coordinate labels) """ @@ -1409,9 +1416,11 @@ def create_variables( def create_coordinates(self) -> Coordinates: # TODO: move this in xarray.Index base class? + from xarray.core.coordinates import Coordinates + variables = self.create_variables() indexes = {name: self for name in variables} - return xr.Coordinates(coords=variables, indexes=indexes) + return Coordinates(coords=variables, indexes=indexes) def isel( self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] @@ -1423,6 +1432,9 @@ def isel( def sel( self, labels: dict[Any, Any], method=None, tolerance=None ) -> IndexSelResult: + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + if method != "nearest": raise ValueError( "CoordinateTransformIndex only supports selection with method='nearest'" @@ -1433,15 +1445,14 @@ def sel( missing_labels = coord_names_set - labels_set if missing_labels: - raise ValueError( - f"missing labels for coordinate(s): {','.join(missing_labels)}." - ) + missing_labels_str = ",".join([f"{name}" for name in missing_labels]) + raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") label0_obj = next(iter(labels.values())) dim_size0 = getattr(label0_obj, "sizes", None) is_xr_obj = [ - isinstance(label, (xr.DataArray, xr.Variable)) for label in labels.values() + isinstance(label, DataArray | Variable) for label in labels.values() ] if not all(is_xr_obj): raise TypeError( @@ -1461,13 +1472,14 @@ def sel( dim_positions = self.transform.reverse(coord_labels) results = {} + dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): if isinstance(label0_obj, Variable): - xr_pos = Variable(label.dims, idx) + xr_pos = Variable(dims0, pos) else: # dataarray - xr_pos = DataArray(idx, dims=label.dims) - results[dim] = idx + xr_pos = DataArray(pos, dims=dims0) + results[dim] = xr_pos return IndexSelResult(results) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 35fd2597b85..047e16c240d 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1961,8 +1961,9 @@ def copy(self, deep: bool = True) -> Self: class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): - """Wrap a CoordinateTransform to support explicit indexing and - lazy coordinate labels. + """Wrap a CoordinateTransform as a lazy coordinate array. + + Supports explicit indexing (both outer and vectorized). """ @@ -2036,7 +2037,7 @@ def __getitem__(self, indexer: ExplicitIndexer): self._check_and_raise_if_non_basic_indexer(indexer) # also works with basic indexing - return self._oindex_get(indexer) + return self._oindex_get(OuterIndexer(indexer.tuple)) def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: raise TypeError( diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index b1bf7a1af11..e2857b8602b 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -3,6 +3,11 @@ """ -from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex +from xarray.core.indexes import ( + CoordinateTransformIndex, + Index, + PandasIndex, + PandasMultiIndex, +) -__all__ = ["Index", "PandasIndex", "PandasMultiIndex"] +__all__ = ["CoordinateTransformIndex", "Index", "PandasIndex", "PandasMultiIndex"] From 8af6614086f8ca181ec070859fca1e019663c837 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:30:52 +0200 Subject: [PATCH 03/16] missing import --- xarray/core/indexes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e154b727fc5..b56d1faf295 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1397,6 +1397,8 @@ def __init__( def create_variables( self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: + from xarray.core.variable import Variable + new_variables = {} for name in self.transform.coord_names: From e9a11ef6df072c4f61eea7ea7be00e12d7cee5da Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:48:25 +0200 Subject: [PATCH 04/16] sel: convert inverse transform results to ints --- xarray/core/indexes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b56d1faf295..ab725f86833 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1476,6 +1476,7 @@ def sel( results = {} dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): + pos = np.round(pos).astype("int") if isinstance(label0_obj, Variable): xr_pos = Variable(dims0, pos) else: From 0b3fd9ee751f64b9695609a601ae31b336c1e0a0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 22:13:40 +0200 Subject: [PATCH 05/16] sel: add todo note about rounding decimal pos --- xarray/core/indexes.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index ab725f86833..987039e1f87 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1476,6 +1476,9 @@ def sel( results = {} dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): + # TODO: rounding the decimal positions is not always the behavior we expect + # (there are different ways to represent implicit intervals) + # we should probably make this customizable. pos = np.round(pos).astype("int") if isinstance(label0_obj, Variable): xr_pos = Variable(dims0, pos) From acf1c478c68fcadcc6bfbdd4414bc97b8667383f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 26 Sep 2024 10:37:04 +0200 Subject: [PATCH 06/16] rename create_coordinates -> create_coords More consistent with the rest of Xarray API where `coords` is used everywhere. --- xarray/core/indexes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 987039e1f87..4f2bba20844 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1417,6 +1417,11 @@ def create_variables( return new_variables def create_coordinates(self) -> Coordinates: + # TODO: remove this alias before merging https://github.com/pydata/xarray/pull/9543! + # (we keep it there so it doesn't break the code of those who are experimenting with this) + return self.create_coords() + + def create_coords(self) -> Coordinates: # TODO: move this in xarray.Index base class? from xarray.core.coordinates import Coordinates From e101585e9fb30a3b73d6d37a7bc0be1607f991b1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 26 Sep 2024 10:46:13 +0200 Subject: [PATCH 07/16] add a Coordinates.from_transform convenient method --- xarray/core/coordinates.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index a6dec863aec..af622aaca8b 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -14,7 +14,9 @@ from xarray.core import formatting from xarray.core.alignment import Aligner +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexes import ( + CoordinateTransformIndex, Index, Indexes, PandasIndex, @@ -356,7 +358,7 @@ def _construct_direct( def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). - The returned coordinates can be directly assigned to a + The returned coordinate variables can be directly assigned to a :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the ``coords`` argument of their constructor. @@ -380,6 +382,28 @@ def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: return cls(coords=variables, indexes=indexes) + @classmethod + def from_transform(cls, transform: CoordinateTransform) -> Self: + """Wrap a coordinate transform as Xarray (lazy) coordinates. + + The returned coordinate variables can be directly assigned to a + :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the + ``coords`` argument of their constructor. + + Parameters + ---------- + transform : :py:class:`CoordinateTransform` + Xarray coordinate transform object. + + Returns + ------- + coords : Coordinates + A collection of Xarray indexed coordinates created from the transform. + + """ + index = CoordinateTransformIndex(transform) + return index.create_coords() + @property def _names(self) -> set[Hashable]: return self._data._coord_names From 09667c5da4e2de2f1db6896e3acce0205e3608e3 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Oct 2024 14:50:55 +0200 Subject: [PATCH 08/16] fix repr (extract subset values of any n-d array) --- xarray/core/indexing.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 047e16c240d..04677bb8d60 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2055,8 +2055,12 @@ def _get_array_subset(self) -> np.ndarray: threshold = max(100, OPTIONS["display_values_threshold"] + 2) if self.size > threshold: pos = threshold // 2 - indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)]) - subset = self.vindex[VectorizedIndexer((indices,) * self.ndim)] + flat_indices = np.concatenate( + [np.arange(0, pos), np.arange(self.size - pos, self.size)] + ) + subset = self.vindex[ + VectorizedIndexer(np.unravel_index(flat_indices, self.shape)) + ] else: subset = self From 4c7ce28884c0dd9af4caabb5297036a6d5644a9a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 12 Feb 2025 14:55:19 -0700 Subject: [PATCH 09/16] Apply suggestions from code review Co-authored-by: Max Jones <14077947+maxrjones@users.noreply.github.com> --- xarray/core/indexes.py | 2 +- xarray/core/indexing.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 53049f9e5a1..fb04b829737 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1472,7 +1472,7 @@ def sel( "with either xarray.DataArray or xarray.Variable objects." ) dim_size = [getattr(label, "sizes", None) for label in labels.values()] - if any([ds != dim_size0 for ds in dim_size]): + if any(ds != dim_size0 for ds in dim_size): raise ValueError( "CoordinateTransformIndex only supports advanced (point-wise) indexing " "with xarray.DataArray or xarray.Variable objects of macthing dimensions." diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 35d4fc52e8c..3b337612239 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2058,9 +2058,7 @@ def _oindex_get(self, indexer: OuterIndexer): array_indexer = _arrayize_outer_indexer(expanded_indexer_, self.shape) positions = np.meshgrid(*array_indexer.tuple, indexing="ij") - dim_positions = { - dim: pos for dim, pos in zip(self._dims, positions, strict=False) - } + dim_positions = dict(zip(self._dims, positions, strict=False)) result = self._transform.forward(dim_positions) return np.asarray(result[self._coord_name]).squeeze() From 5cfb1afa9bdb45817e0527cde82764b12586f6d8 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 09:11:31 +0100 Subject: [PATCH 10/16] remove specific create coordinates methods In favor of the more generic `Coordinates.from_xindex()`. --- xarray/core/coordinates.py | 24 ------------------------ xarray/core/indexes.py | 14 -------------- 2 files changed, 38 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 1bca543ca20..47773ddfbb6 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -14,9 +14,7 @@ from xarray.core import formatting from xarray.core.alignment import Aligner -from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexes import ( - CoordinateTransformIndex, Index, Indexes, PandasIndex, @@ -418,28 +416,6 @@ def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: return cls(coords=variables, indexes=indexes) - @classmethod - def from_transform(cls, transform: CoordinateTransform) -> Self: - """Wrap a coordinate transform as Xarray (lazy) coordinates. - - The returned coordinate variables can be directly assigned to a - :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the - ``coords`` argument of their constructor. - - Parameters - ---------- - transform : :py:class:`CoordinateTransform` - Xarray coordinate transform object. - - Returns - ------- - coords : Coordinates - A collection of Xarray indexed coordinates created from the transform. - - """ - index = CoordinateTransformIndex(transform) - return index.create_coords() - @property def _names(self) -> set[Hashable]: return self._data._coord_names diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index fb04b829737..833ec8bb926 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -26,7 +26,6 @@ ) if TYPE_CHECKING: - from xarray.core.coordinate import Coordinates from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -1421,19 +1420,6 @@ def create_variables( return new_variables - def create_coordinates(self) -> Coordinates: - # TODO: remove this alias before merging https://github.com/pydata/xarray/pull/9543! - # (we keep it there so it doesn't break the code of those who are experimenting with this) - return self.create_coords() - - def create_coords(self) -> Coordinates: - # TODO: move this in xarray.Index base class? - from xarray.core.coordinates import Coordinates - - variables = self.create_variables() - indexes = {name: self for name in variables} - return Coordinates(coords=variables, indexes=indexes) - def isel( self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] ) -> Self | None: From 632c71b103d4e659216f37d3adf0f5b8e8aad091 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 09:54:09 +0100 Subject: [PATCH 11/16] fix more typing issues --- xarray/core/coordinate_transform.py | 4 +++- xarray/core/indexes.py | 11 +++++------ xarray/core/indexing.py | 6 +++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 40043da46bc..52e7f13ca14 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -64,7 +64,9 @@ def equals(self, other: "CoordinateTransform") -> bool: """Check equality with another CoordinateTransform of the same kind.""" raise NotImplementedError - def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: + def generate_coords( + self, dims: tuple[str, ...] | None = None + ) -> dict[Hashable, Any]: """Compute all coordinate labels at once.""" if dims is None: dims = self.dims diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 833ec8bb926..240b4f178ec 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1447,7 +1447,7 @@ def sel( raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") label0_obj = next(iter(labels.values())) - dim_size0 = getattr(label0_obj, "sizes", None) + dim_size0 = getattr(label0_obj, "sizes", {}) is_xr_obj = [ isinstance(label, DataArray | Variable) for label in labels.values() @@ -1457,7 +1457,7 @@ def sel( "CoordinateTransformIndex only supports advanced (point-wise) indexing " "with either xarray.DataArray or xarray.Variable objects." ) - dim_size = [getattr(label, "sizes", None) for label in labels.values()] + dim_size = [getattr(label, "sizes", {}) for label in labels.values()] if any(ds != dim_size0 for ds in dim_size): raise ValueError( "CoordinateTransformIndex only supports advanced (point-wise) indexing " @@ -1469,7 +1469,7 @@ def sel( } dim_positions = self.transform.reverse(coord_labels) - results = {} + results: dict[str, Variable | DataArray] = {} dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): # TODO: rounding the decimal positions is not always the behavior we expect @@ -1477,11 +1477,10 @@ def sel( # we should probably make this customizable. pos = np.round(pos).astype("int") if isinstance(label0_obj, Variable): - xr_pos = Variable(dims0, pos) + results[dim] = Variable(dims0, pos) else: # dataarray - xr_pos = DataArray(pos, dims=dims0) - results[dim] = xr_pos + results[dim] = DataArray(pos, dims=dims0) return IndexSelResult(results) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 3b337612239..f379a932019 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1302,7 +1302,7 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) -def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray: +def _posify_indices(indices: Any, size: int) -> np.ndarray: """Convert negative indices by their equivalent positive indices. Note: the resulting indices may still be out of bounds (< 0 or >= size). @@ -1311,7 +1311,7 @@ def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray: return np.where(indices < 0, size + indices, indices) -def _check_bounds(indices, size): +def _check_bounds(indices: Any, size: int): """Check if the given indices are all within the array boundaries.""" if np.any((indices < 0) | (indices >= size)): raise IndexError("out of bounds index") @@ -2046,7 +2046,7 @@ def dtype(self) -> np.dtype: return self._transform.dtype @property - def shape(self): + def shape(self) -> tuple[int, ...]: return tuple(self._transform.dim_size.values()) def get_duck_array(self) -> np.ndarray: From ae8b318c7a136c35a784e96fd0b63225011b9bf0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 09:55:31 +0100 Subject: [PATCH 12/16] remove public imports: not ready yet for public use --- xarray/__init__.py | 2 -- xarray/indexes/__init__.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index 05cfecb2b8b..8af936ed27a 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -30,7 +30,6 @@ where, ) from xarray.core.concat import concat -from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -118,7 +117,6 @@ "CFTimeIndex", "Context", "Coordinates", - "CoordinateTransform", "DataArray", "DataTree", "Dataset", diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index e2857b8602b..9073cbc2ed4 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -4,10 +4,9 @@ """ from xarray.core.indexes import ( - CoordinateTransformIndex, Index, PandasIndex, PandasMultiIndex, ) -__all__ = ["CoordinateTransformIndex", "Index", "PandasIndex", "PandasMultiIndex"] +__all__ = ["Index", "PandasIndex", "PandasMultiIndex"] From 1c425e30fda0541d5571250b87c9c0b3b3736dac Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 10:43:38 +0100 Subject: [PATCH 13/16] add experimental notice in docstrings --- xarray/core/coordinate_transform.py | 6 +++++- xarray/core/indexes.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 52e7f13ca14..d9e09cea173 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -5,7 +5,11 @@ class CoordinateTransform: - """Abstract coordinate transform with dimension & coordinate names.""" + """Abstract coordinate transform with dimension & coordinate names. + + EXPERIMENTAL (not ready for public use yet). + + """ coord_names: tuple[Hashable, ...] dims: tuple[str, ...] diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 240b4f178ec..43e231e84d4 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1382,6 +1382,8 @@ def rename(self, name_dict, dims_dict): class CoordinateTransformIndex(Index): """Helper class for creating Xarray indexes based on coordinate transforms. + EXPERIMENTAL (not ready for public use yet). + - wraps a :py:class:`CoordinateTransform` instance - takes care of creating the index (lazy) coordinates - supports point-wise label-based selection From 952faa78fb96b7a486b503465edda732342517e7 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 11:36:08 +0100 Subject: [PATCH 14/16] add coordinate transform tests --- xarray/tests/test_coordinate_transform.py | 218 ++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 xarray/tests/test_coordinate_transform.py diff --git a/xarray/tests/test_coordinate_transform.py b/xarray/tests/test_coordinate_transform.py new file mode 100644 index 00000000000..26746657dbc --- /dev/null +++ b/xarray/tests/test_coordinate_transform.py @@ -0,0 +1,218 @@ +from collections.abc import Hashable +from typing import Any + +import numpy as np +import pytest + +import xarray as xr +from xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.indexes import CoordinateTransformIndex +from xarray.tests import assert_equal + + +class SimpleCoordinateTransform(CoordinateTransform): + """Simple uniform scale transform in a 2D space (x/y coordinates).""" + + def __init__(self, shape: tuple[int, int], scale: float, dtype: Any = None): + super().__init__(("x", "y"), {"x": shape[1], "y": shape[0]}, dtype=dtype) + + self.scale = scale + + # array dimensions in reverse order (y = rows, x = cols) + self.xy_dims = tuple(self.dims) + self.dims = (self.dims[1], self.dims[0]) + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + assert set(dim_positions) == set(self.dims) + return {dim: dim_positions[dim] * self.scale for dim in self.xy_dims} + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + return {dim: coord_labels[dim] / self.scale for dim in self.xy_dims} + + def equals(self, other: "CoordinateTransform") -> bool: + if not isinstance(other, SimpleCoordinateTransform): + return False + return self.scale == other.scale + + def __repr__(self) -> str: + return f"Scale({self.scale})" + + +def test_abstract_coordinate_transform() -> None: + tr = CoordinateTransform(["x"], {"x": 5}) + + with pytest.raises(NotImplementedError): + tr.forward({"x": [1, 2]}) + + with pytest.raises(NotImplementedError): + tr.reverse({"x": [3.0, 4.0]}) + + with pytest.raises(NotImplementedError): + tr.equals(CoordinateTransform(["x"], {"x": 5})) + + +def test_coordinate_transform_init() -> None: + tr = SimpleCoordinateTransform((4, 4), 2.0) + + assert tr.coord_names == ("x", "y") + # array dimensions in reverse order (y = rows, x = cols) + assert tr.dims == ("y", "x") + assert tr.dim_size == {"x": 4, "y": 4} + assert tr.dtype == np.dtype(np.float64) + + tr2 = SimpleCoordinateTransform((4, 4), 2.0, dtype=np.int64) + assert tr2.dtype == np.dtype(np.int64) + + +@pytest.mark.parametrize("dims", [None, ("y", "x")]) +def test_coordinate_transform_generate_coords(dims) -> None: + tr = SimpleCoordinateTransform((2, 2), 2.0) + + actual = tr.generate_coords(dims) + expected = {"x": [[0.0, 2.0], [0.0, 2.0]], "y": [[0.0, 0.0], [2.0, 2.0]]} + assert set(actual) == set(expected) + np.testing.assert_array_equal(actual["x"], expected["x"]) + np.testing.assert_array_equal(actual["y"], expected["y"]) + + +def create_coords(scale: float, shape: tuple[int, int]) -> xr.Coordinates: + """Create x/y Xarray coordinate variables from a simple coordinate transform.""" + tr = SimpleCoordinateTransform(shape, scale) + index = CoordinateTransformIndex(tr) + return xr.Coordinates.from_xindex(index) + + +def test_coordinate_transform_variable() -> None: + coords = create_coords(scale=2.0, shape=(2, 2)) + + assert coords["x"].dtype == np.dtype(np.float64) + assert coords["y"].dtype == np.dtype(np.float64) + assert coords["x"].shape == (2, 2) + assert coords["y"].shape == (2, 2) + + np.testing.assert_array_equal(np.array(coords["x"]), [[0.0, 2.0], [0.0, 2.0]]) + np.testing.assert_array_equal(np.array(coords["y"]), [[0.0, 0.0], [2.0, 2.0]]) + + def assert_repr(var: xr.Variable): + assert ( + repr(var._data) + == "CoordinateTransformIndexingAdapter(transform=Scale(2.0))" + ) + + assert_repr(coords["x"].variable) + assert_repr(coords["y"].variable) + + +def test_coordinate_transform_variable_repr_inline() -> None: + var = create_coords(scale=2.0, shape=(2, 2))["x"].variable + + actual = var._data._repr_inline_(70) # type: ignore[union-attr] + assert actual == "0.0 2.0 0.0 2.0" + + # truncated inline repr + var2 = create_coords(scale=2.0, shape=(10, 10))["x"].variable + + actual2 = var2._data._repr_inline_(70) # type: ignore[union-attr] + assert ( + actual2 == "0.0 2.0 4.0 6.0 8.0 10.0 12.0 ... 6.0 8.0 10.0 12.0 14.0 16.0 18.0" + ) + + +def test_coordinate_transform_variable_basic_outer_indexing() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + assert var[0, 0] == 0.0 + assert var[0, 1] == 2.0 + assert var[0, -1] == 6.0 + np.testing.assert_array_equal(var[:, 0:2], [[0.0, 2.0]] * 4) + + with pytest.raises(IndexError, match="out of bounds index"): + var[5] + + with pytest.raises(IndexError, match="out of bounds index"): + var[-5] + + +def test_coordinate_transform_variable_vectorized_indexing() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + actual = var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] + expected = xr.Variable("z", [0.0]) + assert_equal(actual, expected) + + with pytest.raises(IndexError, match="out of bounds index"): + var[{"x": xr.Variable("z", [5]), "y": xr.Variable("z", [5])}] + + +def test_coordinate_transform_setitem_error() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + # basic indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[0, 0] = 1.0 + + # outer indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[[0, 2], 0] = [1.0, 2.0] + + # vectorized indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] = 1.0 + + +def test_coordinate_transform_transpose() -> None: + coords = create_coords(scale=2.0, shape=(2, 2)) + + actual = coords["x"].transpose().values + expected = [[0.0, 0.0], [2.0, 2.0]] + np.testing.assert_array_equal(actual, expected) + + +def test_coordinate_transform_equals() -> None: + ds1 = create_coords(scale=2.0, shape=(2, 2)).to_dataset() + ds2 = create_coords(scale=2.0, shape=(2, 2)).to_dataset() + ds3 = create_coords(scale=4.0, shape=(2, 2)).to_dataset() + + # cannot use `assert_equal()` test utility function here yet + # (indexes invariant check are still based on IndexVariable, which + # doesn't work with coordinate transform index coordinate variables) + assert ds1.equals(ds2) + assert not ds1.equals(ds3) + + +def test_coordinate_transform_sel() -> None: + ds = create_coords(scale=2.0, shape=(4, 4)).to_dataset() + + data = [ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + ] + ds["data"] = (("y", "x"), data) + + actual = ds.sel( + x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5]), method="nearest" + ) + expected = ds.isel(x=xr.Variable("z", [0, 3]), y=xr.Variable("z", [0, 0])) + + # cannot use `assert_equal()` test utility function here yet + # (indexes invariant check are still based on IndexVariable, which + # doesn't work with coordinate transform index coordinate variables) + assert actual.equals(expected) + + with pytest.raises(ValueError, match=".*only supports selection.*nearest"): + ds.sel(x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5])) + + with pytest.raises(ValueError, match="missing labels for coordinate.*y"): + ds.sel(x=[0.5, 5.5], method="nearest") + + with pytest.raises(TypeError, match=".*only supports advanced.*indexing"): + ds.sel(x=[0.5, 5.5], y=[0.0, 0.5], method="nearest") + + with pytest.raises(ValueError, match=".*only supports advanced.*indexing"): + ds.sel( + x=xr.Variable("z", [0.5, 5.5]), + y=xr.Variable("z", [0.0, 0.5, 1.5]), + method="nearest", + ) From 03fdc90404a195dba8a39af5c549d93b0d2363cb Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 14:17:14 +0100 Subject: [PATCH 15/16] typing fixes --- xarray/core/indexing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index f379a932019..521abcdfddd 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2035,7 +2035,7 @@ def __init__( self, transform: CoordinateTransform, coord_name: Hashable, - dims: tuple[str] | None = None, + dims: tuple[str, ...] | None = None, ): self._transform = transform self._coord_name = coord_name @@ -2102,7 +2102,7 @@ def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: "setting values is not supported on coordinate transform arrays." ) - def transpose(self, order): + def transpose(self, order: Iterable[int]) -> Self: new_dims = tuple([self._dims[i] for i in order]) return type(self)(self._transform, self._coord_name, new_dims) From 406b03b7067604438295ffdf34c8a608ace69669 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 13 Feb 2025 14:18:51 +0100 Subject: [PATCH 16/16] update what's new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 43edc5ee33e..d9d4998d983 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,6 +28,8 @@ New Features By `Kai Mühlbauer `_. - support python 3.13 (no free-threading) (:issue:`9664`, :pull:`9681`) By `Justus Magin `_. +- Added experimental support for coordinate transforms (not ready for public use yet!) (:pull:`9543`) + By `Benoit Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~ pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy