diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 43edc5ee33e..d9d4998d983 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,6 +28,8 @@ New Features By `Kai Mühlbauer `_. - support python 3.13 (no free-threading) (:issue:`9664`, :pull:`9681`) By `Justus Magin `_. +- Added experimental support for coordinate transforms (not ready for public use yet!) (:pull:`9543`) + By `Benoit Bovy `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py new file mode 100644 index 00000000000..d9e09cea173 --- /dev/null +++ b/xarray/core/coordinate_transform.py @@ -0,0 +1,84 @@ +from collections.abc import Hashable, Iterable, Mapping +from typing import Any + +import numpy as np + + +class CoordinateTransform: + """Abstract coordinate transform with dimension & coordinate names. + + EXPERIMENTAL (not ready for public use yet). + + """ + + coord_names: tuple[Hashable, ...] + dims: tuple[str, ...] + dim_size: dict[str, int] + dtype: Any + + def __init__( + self, + coord_names: Iterable[Hashable], + dim_size: Mapping[str, int], + dtype: Any = None, + ): + self.coord_names = tuple(coord_names) + self.dims = tuple(dim_size) + self.dim_size = dict(dim_size) + + if dtype is None: + dtype = np.dtype(np.float64) + self.dtype = dtype + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + """Perform grid -> world coordinate transformation. + + Parameters + ---------- + dim_positions : dict + Grid location(s) along each dimension (axis). + + Returns + ------- + coord_labels : dict + World coordinate labels. + + """ + # TODO: cache the results in order to avoid re-computing + # all labels when accessing the values of each coordinate one at a time + raise NotImplementedError + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + """Perform world -> grid coordinate reverse transformation. + + Parameters + ---------- + coord_labels : dict + World coordinate labels. 
+ + Returns + ------- + dim_positions : dict + Grid relative location(s) along each dimension (axis). + + """ + raise NotImplementedError + + def equals(self, other: "CoordinateTransform") -> bool: + """Check equality with another CoordinateTransform of the same kind.""" + raise NotImplementedError + + def generate_coords( + self, dims: tuple[str, ...] | None = None + ) -> dict[Hashable, Any]: + """Compute all coordinate labels at once.""" + if dims is None: + dims = self.dims + + positions = np.meshgrid( + *[np.arange(self.dim_size[d]) for d in dims], + indexing="ij", + ) + dim_positions = {dim: positions[i] for i, dim in enumerate(dims)} + + return self.forward(dim_positions) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index a9ceeb08b96..47773ddfbb6 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -392,7 +392,7 @@ def from_xindex(cls, index: Index) -> Self: def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). - The returned coordinates can be directly assigned to a + The returned coordinate variables can be directly assigned to a :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the ``coords`` argument of their constructor. diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index fbaef9729e3..43e231e84d4 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -10,7 +10,9 @@ import pandas as pd from xarray.core import formatting, nputils, utils +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexing import ( + CoordinateTransformIndexingAdapter, IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter, @@ -1377,6 +1379,125 @@ def rename(self, name_dict, dims_dict): ) +class CoordinateTransformIndex(Index): + """Helper class for creating Xarray indexes based on coordinate transforms. + + EXPERIMENTAL (not ready for public use yet). 
+ + - wraps a :py:class:`CoordinateTransform` instance + - takes care of creating the index (lazy) coordinates + - supports point-wise label-based selection + - supports exact alignment only, by comparing indexes based on their transform + (not on their explicit coordinate labels) + + """ + + transform: CoordinateTransform + + def __init__( + self, + transform: CoordinateTransform, + ): + self.transform = transform + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + from xarray.core.variable import Variable + + new_variables = {} + + for name in self.transform.coord_names: + # copy attributes, if any + attrs: Mapping[Hashable, Any] | None + + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + else: + attrs = None + + data = CoordinateTransformIndexingAdapter(self.transform, name) + new_variables[name] = Variable(self.transform.dims, data, attrs=attrs) + + return new_variables + + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: + # TODO: support returning a new index (e.g., possible to re-calculate the + # transform or calculate another transform on a reduced dimension space) + return None + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + + if method != "nearest": + raise ValueError( + "CoordinateTransformIndex only supports selection with method='nearest'" + ) + + labels_set = set(labels) + coord_names_set = set(self.transform.coord_names) + + missing_labels = coord_names_set - labels_set + if missing_labels: + missing_labels_str = ",".join([f"{name}" for name in missing_labels]) + raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") + + label0_obj = next(iter(labels.values())) + dim_size0 = getattr(label0_obj, "sizes", {}) + + is_xr_obj = [ + isinstance(label, 
DataArray | Variable) for label in labels.values() + ] + if not all(is_xr_obj): + raise TypeError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with either xarray.DataArray or xarray.Variable objects." + ) + dim_size = [getattr(label, "sizes", {}) for label in labels.values()] + if any(ds != dim_size0 for ds in dim_size): + raise ValueError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with xarray.DataArray or xarray.Variable objects of macthing dimensions." + ) + + coord_labels = { + name: labels[name].values for name in self.transform.coord_names + } + dim_positions = self.transform.reverse(coord_labels) + + results: dict[str, Variable | DataArray] = {} + dims0 = tuple(dim_size0) + for dim, pos in dim_positions.items(): + # TODO: rounding the decimal positions is not always the behavior we expect + # (there are different ways to represent implicit intervals) + # we should probably make this customizable. + pos = np.round(pos).astype("int") + if isinstance(label0_obj, Variable): + results[dim] = Variable(dims0, pos) + else: + # dataarray + results[dim] = DataArray(pos, dims=dims0) + + return IndexSelResult(results) + + def equals(self, other: Self) -> bool: + return self.transform.equals(other.transform) + + def rename( + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + # TODO: maybe update self.transform coord_names, dim_size and dims attributes + return self + + def create_default_index_implicit( dim_variable: Variable, all_variables: Mapping | Iterable[Hashable] | None = None, diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index cf9d3885f08..521abcdfddd 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -17,6 +17,7 @@ from packaging.version import Version from xarray.core import duck_array_ops +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.nputils import NumpyVIndexAdapter from 
xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -1301,6 +1302,42 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) +def _posify_indices(indices: Any, size: int) -> np.ndarray: + """Replace negative indices by their equivalent positive indices. + + Note: the resulting indices may still be out of bounds (< 0 or >= size). + + """ + return np.where(indices < 0, size + indices, indices) + + +def _check_bounds(indices: Any, size: int): + """Check if the given indices are all within the array boundaries.""" + if np.any((indices < 0) | (indices >= size)): + raise IndexError("out of bounds index") + + +def _arrayize_outer_indexer(indexer: OuterIndexer, shape) -> OuterIndexer: + """Return a similar outer indexer after replacing slices by arrays and + negative indices by their corresponding positive indices. + + Also check if array indices are within bounds. + + """ + new_key = [] + + for axis, value in enumerate(indexer.tuple): + size = shape[axis] + if isinstance(value, slice): + value = _expand_slice(value, size) + else: + value = _posify_indices(value, size) + _check_bounds(value, size) + new_key.append(value) + + return OuterIndexer(tuple(new_key)) + + def _arrayize_vectorized_indexer( indexer: VectorizedIndexer, shape: _Shape ) -> VectorizedIndexer: @@ -1981,3 +2018,114 @@ def copy(self, deep: bool = True) -> Self: # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype, self.level) + + +class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a CoordinateTransform as a lazy coordinate array. + + Supports explicit indexing (both outer and vectorized). + + """ + + _transform: CoordinateTransform + _coord_name: Hashable + _dims: tuple[str, ...] + + def __init__( + self, + transform: CoordinateTransform, + coord_name: Hashable, + dims: tuple[str, ...] 
| None = None, + ): + self._transform = transform + self._coord_name = coord_name + self._dims = dims or transform.dims + + @property + def dtype(self) -> np.dtype: + return self._transform.dtype + + @property + def shape(self) -> tuple[int, ...]: + return tuple(self._transform.dim_size.values()) + + def get_duck_array(self) -> np.ndarray: + all_coords = self._transform.generate_coords(dims=self._dims) + return np.asarray(all_coords[self._coord_name]) + + def _oindex_get(self, indexer: OuterIndexer): + expanded_indexer_ = OuterIndexer(expanded_indexer(indexer.tuple, self.ndim)) + array_indexer = _arrayize_outer_indexer(expanded_indexer_, self.shape) + + positions = np.meshgrid(*array_indexer.tuple, indexing="ij") + dim_positions = dict(zip(self._dims, positions, strict=False)) + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]).squeeze() + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def _vindex_get(self, indexer: VectorizedIndexer): + expanded_indexer_ = VectorizedIndexer( + expanded_indexer(indexer.tuple, self.ndim) + ) + array_indexer = _arrayize_vectorized_indexer(expanded_indexer_, self.shape) + + dim_positions = {} + for i, (dim, pos) in enumerate( + zip(self._dims, array_indexer.tuple, strict=False) + ): + pos = _posify_indices(pos, self.shape[i]) + _check_bounds(pos, self.shape[i]) + dim_positions[dim] = pos + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def __getitem__(self, indexer: ExplicitIndexer): + # TODO: make it lazy (i.e., re-calculate and re-wrap the transform) when possible? 
+ self._check_and_raise_if_non_basic_indexer(indexer) + + # also works with basic indexing + return self._oindex_get(OuterIndexer(indexer.tuple)) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def transpose(self, order: Iterable[int]) -> Self: + new_dims = tuple([self._dims[i] for i in order]) + return type(self)(self._transform, self._coord_name, new_dims) + + def __repr__(self: Any) -> str: + return f"{type(self).__name__}(transform={self._transform!r})" + + def _get_array_subset(self) -> np.ndarray: + threshold = max(100, OPTIONS["display_values_threshold"] + 2) + if self.size > threshold: + pos = threshold // 2 + flat_indices = np.concatenate( + [np.arange(0, pos), np.arange(self.size - pos, self.size)] + ) + subset = self.vindex[ + VectorizedIndexer(np.unravel_index(flat_indices, self.shape)) + ] + else: + subset = self + + return np.asarray(subset) + + def _repr_inline_(self, max_width: int) -> str: + """Good to see some labels even for a lazy coordinate.""" + from xarray.core.formatting import format_array_flat + + return format_array_flat(self._get_array_subset(), max_width) diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index b1bf7a1af11..9073cbc2ed4 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -3,6 +3,10 @@ """ -from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex +from xarray.core.indexes import ( + Index, + PandasIndex, + PandasMultiIndex, +) __all__ = ["Index", "PandasIndex", "PandasMultiIndex"] diff --git a/xarray/tests/test_coordinate_transform.py b/xarray/tests/test_coordinate_transform.py new file mode 100644 index 00000000000..26746657dbc --- /dev/null +++ b/xarray/tests/test_coordinate_transform.py @@ -0,0 +1,218 @@ +from collections.abc import Hashable +from typing import Any + +import numpy as np +import pytest + +import xarray as xr +from 
xarray.core.coordinate_transform import CoordinateTransform +from xarray.core.indexes import CoordinateTransformIndex +from xarray.tests import assert_equal + + +class SimpleCoordinateTransform(CoordinateTransform): + """Simple uniform scale transform in a 2D space (x/y coordinates).""" + + def __init__(self, shape: tuple[int, int], scale: float, dtype: Any = None): + super().__init__(("x", "y"), {"x": shape[1], "y": shape[0]}, dtype=dtype) + + self.scale = scale + + # array dimensions in reverse order (y = rows, x = cols) + self.xy_dims = tuple(self.dims) + self.dims = (self.dims[1], self.dims[0]) + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + assert set(dim_positions) == set(self.dims) + return {dim: dim_positions[dim] * self.scale for dim in self.xy_dims} + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + return {dim: coord_labels[dim] / self.scale for dim in self.xy_dims} + + def equals(self, other: "CoordinateTransform") -> bool: + if not isinstance(other, SimpleCoordinateTransform): + return False + return self.scale == other.scale + + def __repr__(self) -> str: + return f"Scale({self.scale})" + + +def test_abstract_coordinate_transform() -> None: + tr = CoordinateTransform(["x"], {"x": 5}) + + with pytest.raises(NotImplementedError): + tr.forward({"x": [1, 2]}) + + with pytest.raises(NotImplementedError): + tr.reverse({"x": [3.0, 4.0]}) + + with pytest.raises(NotImplementedError): + tr.equals(CoordinateTransform(["x"], {"x": 5})) + + +def test_coordinate_transform_init() -> None: + tr = SimpleCoordinateTransform((4, 4), 2.0) + + assert tr.coord_names == ("x", "y") + # array dimensions in reverse order (y = rows, x = cols) + assert tr.dims == ("y", "x") + assert tr.dim_size == {"x": 4, "y": 4} + assert tr.dtype == np.dtype(np.float64) + + tr2 = SimpleCoordinateTransform((4, 4), 2.0, dtype=np.int64) + assert tr2.dtype == np.dtype(np.int64) + + +@pytest.mark.parametrize("dims", [None, ("y", "x")]) +def 
test_coordinate_transform_generate_coords(dims) -> None: + tr = SimpleCoordinateTransform((2, 2), 2.0) + + actual = tr.generate_coords(dims) + expected = {"x": [[0.0, 2.0], [0.0, 2.0]], "y": [[0.0, 0.0], [2.0, 2.0]]} + assert set(actual) == set(expected) + np.testing.assert_array_equal(actual["x"], expected["x"]) + np.testing.assert_array_equal(actual["y"], expected["y"]) + + +def create_coords(scale: float, shape: tuple[int, int]) -> xr.Coordinates: + """Create x/y Xarray coordinate variables from a simple coordinate transform.""" + tr = SimpleCoordinateTransform(shape, scale) + index = CoordinateTransformIndex(tr) + return xr.Coordinates.from_xindex(index) + + +def test_coordinate_transform_variable() -> None: + coords = create_coords(scale=2.0, shape=(2, 2)) + + assert coords["x"].dtype == np.dtype(np.float64) + assert coords["y"].dtype == np.dtype(np.float64) + assert coords["x"].shape == (2, 2) + assert coords["y"].shape == (2, 2) + + np.testing.assert_array_equal(np.array(coords["x"]), [[0.0, 2.0], [0.0, 2.0]]) + np.testing.assert_array_equal(np.array(coords["y"]), [[0.0, 0.0], [2.0, 2.0]]) + + def assert_repr(var: xr.Variable): + assert ( + repr(var._data) + == "CoordinateTransformIndexingAdapter(transform=Scale(2.0))" + ) + + assert_repr(coords["x"].variable) + assert_repr(coords["y"].variable) + + +def test_coordinate_transform_variable_repr_inline() -> None: + var = create_coords(scale=2.0, shape=(2, 2))["x"].variable + + actual = var._data._repr_inline_(70) # type: ignore[union-attr] + assert actual == "0.0 2.0 0.0 2.0" + + # truncated inline repr + var2 = create_coords(scale=2.0, shape=(10, 10))["x"].variable + + actual2 = var2._data._repr_inline_(70) # type: ignore[union-attr] + assert ( + actual2 == "0.0 2.0 4.0 6.0 8.0 10.0 12.0 ... 
6.0 8.0 10.0 12.0 14.0 16.0 18.0" + ) + + +def test_coordinate_transform_variable_basic_outer_indexing() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + assert var[0, 0] == 0.0 + assert var[0, 1] == 2.0 + assert var[0, -1] == 6.0 + np.testing.assert_array_equal(var[:, 0:2], [[0.0, 2.0]] * 4) + + with pytest.raises(IndexError, match="out of bounds index"): + var[5] + + with pytest.raises(IndexError, match="out of bounds index"): + var[-5] + + +def test_coordinate_transform_variable_vectorized_indexing() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + actual = var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] + expected = xr.Variable("z", [0.0]) + assert_equal(actual, expected) + + with pytest.raises(IndexError, match="out of bounds index"): + var[{"x": xr.Variable("z", [5]), "y": xr.Variable("z", [5])}] + + +def test_coordinate_transform_setitem_error() -> None: + var = create_coords(scale=2.0, shape=(4, 4))["x"].variable + + # basic indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[0, 0] = 1.0 + + # outer indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[[0, 2], 0] = [1.0, 2.0] + + # vectorized indexing + with pytest.raises(TypeError, match="setting values is not supported"): + var[{"x": xr.Variable("z", [0]), "y": xr.Variable("z", [0])}] = 1.0 + + +def test_coordinate_transform_transpose() -> None: + coords = create_coords(scale=2.0, shape=(2, 2)) + + actual = coords["x"].transpose().values + expected = [[0.0, 0.0], [2.0, 2.0]] + np.testing.assert_array_equal(actual, expected) + + +def test_coordinate_transform_equals() -> None: + ds1 = create_coords(scale=2.0, shape=(2, 2)).to_dataset() + ds2 = create_coords(scale=2.0, shape=(2, 2)).to_dataset() + ds3 = create_coords(scale=4.0, shape=(2, 2)).to_dataset() + + # cannot use `assert_equal()` test utility function here yet + # (indexes invariant check are still based on 
IndexVariable, which + # doesn't work with coordinate transform index coordinate variables) + assert ds1.equals(ds2) + assert not ds1.equals(ds3) + + +def test_coordinate_transform_sel() -> None: + ds = create_coords(scale=2.0, shape=(4, 4)).to_dataset() + + data = [ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + ] + ds["data"] = (("y", "x"), data) + + actual = ds.sel( + x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5]), method="nearest" + ) + expected = ds.isel(x=xr.Variable("z", [0, 3]), y=xr.Variable("z", [0, 0])) + + # cannot use `assert_equal()` test utility function here yet + # (indexes invariant check are still based on IndexVariable, which + # doesn't work with coordinate transform index coordinate variables) + assert actual.equals(expected) + + with pytest.raises(ValueError, match=".*only supports selection.*nearest"): + ds.sel(x=xr.Variable("z", [0.5, 5.5]), y=xr.Variable("z", [0.0, 0.5])) + + with pytest.raises(ValueError, match="missing labels for coordinate.*y"): + ds.sel(x=[0.5, 5.5], method="nearest") + + with pytest.raises(TypeError, match=".*only supports advanced.*indexing"): + ds.sel(x=[0.5, 5.5], y=[0.0, 0.5], method="nearest") + + with pytest.raises(ValueError, match=".*only supports advanced.*indexing"): + ds.sel( + x=xr.Variable("z", [0.5, 5.5]), + y=xr.Variable("z", [0.0, 0.5, 1.5]), + method="nearest", + ) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy