Skip to content

Add NDPointIndex (KDTree) #10478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
rename TreeIndex -> NDPointIndex
  • Loading branch information
benbovy committed Jul 2, 2025
commit 74affdcbd8359b7e1b33d35fb58656623d3922f9
2 changes: 1 addition & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1577,7 +1577,7 @@ Custom Indexes
CFTimeIndex
indexes.RangeIndex
indexes.CoordinateTransformIndex
indexes.TreeIndex
indexes.NDPointIndex

Creating custom indexes
-----------------------
Expand Down
2 changes: 1 addition & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ New Features
~~~~~~~~~~~~
- Expose :py:class:`~xarray.indexes.RangeIndex`, and :py:class:`~xarray.indexes.CoordinateTransformIndex` as public api
under the ``xarray.indexes`` namespace. By `Deepak Cherian <https://github.com/dcherian>`_.
- New :py:class:`xarray.indexes.TreeIndex`, which by default uses :py:class:`scipy.spatial.KDTree` under the hood for
- New :py:class:`xarray.indexes.NDPointIndex`, which by default uses :py:class:`scipy.spatial.KDTree` under the hood for
the selection of irregular, n-dimensional data (:pull:`10478`).
By `Benoit Bovy <https://github.com/benbovy>`_.

Expand Down
4 changes: 2 additions & 2 deletions xarray/indexes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
PandasIndex,
PandasMultiIndex,
)
from xarray.indexes.nd_point_index import NDPointIndex
from xarray.indexes.range_index import RangeIndex
from xarray.indexes.tree_index import TreeIndex

__all__ = [
"CoordinateTransform",
"CoordinateTransformIndex",
"Index",
"NDPointIndex",
"PandasIndex",
"PandasMultiIndex",
"RangeIndex",
"TreeIndex",
]
24 changes: 13 additions & 11 deletions xarray/indexes/tree_index.py → xarray/indexes/nd_point_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class TreeAdapter(abc.ABC):
"""Lightweight adapter abstract class for plugging in 3rd-party structures
like :py:class:`scipy.spatial.KDTree` or :py:class:`sklearn.neighbors.KDTree`
into :py:class:`~xarray.indexes.TreeIndex`.
into :py:class:`~xarray.indexes.NDPointIndex`.

"""

Expand Down Expand Up @@ -69,7 +69,7 @@ def equals(self, other: Self) -> bool:


class ScipyKDTreeAdapter(TreeAdapter):
""":py:class:`scipy.spatial.KDTree` adapter for :py:class:`~xarray.indexes.TreeIndex`."""
""":py:class:`scipy.spatial.KDTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`."""

_kdtree: KDTree

Expand Down Expand Up @@ -97,7 +97,7 @@ def get_points(coords: Iterable[Variable | Any]) -> np.ndarray:
T_TreeAdapter = TypeVar("T_TreeAdapter", bound=TreeAdapter)


class TreeIndex(Index, Generic[T_TreeAdapter]):
class NDPointIndex(Index, Generic[T_TreeAdapter]):
"""Xarray index for irregular, n-dimensional data.

This index may be associated with a set of coordinate variables representing
Expand All @@ -111,7 +111,7 @@ class TreeIndex(Index, Generic[T_TreeAdapter]):
By default, this index relies on :py:class:`scipy.spatial.KDTree` for fast
lookup.

Do not use :py:meth:`~xarray.indexes.TreeIndex.__init__` directly. Instead
Do not use :py:meth:`~xarray.indexes.NDPointIndex.__init__` directly. Instead
use :py:meth:`~xarray.Dataset.set_xindex` or
:py:meth:`~xarray.DataArray.set_xindex` to create and set the index from
existing coordinates (see the example below).
Expand All @@ -134,9 +134,9 @@ class TreeIndex(Index, Generic[T_TreeAdapter]):
Data variables:
*empty*

Create a TreeIndex from the "xx" and "yy" coordinate variables:
Create a NDPointIndex from the "xx" and "yy" coordinate variables:

>>> ds = ds.set_xindex(("xx", "yy"), xr.indexes.TreeIndex)
>>> ds = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex)
>>> ds
<xarray.Dataset> Size: 64B
Dimensions: (y: 2, x: 2)
Expand All @@ -147,7 +147,7 @@ class TreeIndex(Index, Generic[T_TreeAdapter]):
Data variables:
*empty*
Indexes:
┌ xx TreeIndex
┌ xx NDPointIndex
└ yy

Point-wise (nearest-neighbor) data selection using Xarray's advanced
Expand Down Expand Up @@ -268,7 +268,7 @@ def create_variables(
if variables is not None:
for var in variables.values():
# might need to update variable dimensions from the index object
# returned from TreeIndex.rename()
# returned from NDPointIndex.rename()
if var.dims != self._dims:
var.dims = self._dims
return dict(**variables)
Expand All @@ -278,7 +278,7 @@ def create_variables(
def equals(
self, other: Index, *, exclude: frozenset[Hashable] | None = None
) -> bool:
if not isinstance(other, TreeIndex):
if not isinstance(other, NDPointIndex):
return False
if type(self._tree_obj) is not type(other._tree_obj):
return False
Expand Down Expand Up @@ -311,7 +311,9 @@ def sel(
self, labels: dict[Any, Any], method=None, tolerance=None
) -> IndexSelResult:
if method != "nearest":
raise ValueError("TreeIndex only supports selection with method='nearest'")
raise ValueError(
"NDPointIndex only supports selection with method='nearest'"
)

missing_labels = set(self._coord_names) - set(labels)
if missing_labels:
Expand All @@ -330,7 +332,7 @@ def sel(
xr_labels[name] = Variable(self._dims, lbl)
else:
raise ValueError(
"invalid label value. TreeIndex only supports advanced (point-wise) indexing "
"invalid label value. NDPointIndex only supports advanced (point-wise) indexing "
"with the following label value kinds:\n"
"- xarray.DataArray or xarray.Variable objects\n"
"- scalar values\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
import pytest

import xarray as xr
from xarray.indexes import TreeIndex
from xarray.indexes import NDPointIndex
from xarray.tests import assert_identical

pytest.importorskip("scipy")


def test_tree_index_init() -> None:
from xarray.indexes.tree_index import ScipyKDTreeAdapter
from xarray.indexes.nd_point_index import ScipyKDTreeAdapter

xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)})

ds_indexed1 = ds.set_xindex(("xx", "yy"), TreeIndex)
ds_indexed1 = ds.set_xindex(("xx", "yy"), NDPointIndex)
assert "xx" in ds_indexed1.xindexes
assert "yy" in ds_indexed1.xindexes
assert isinstance(ds_indexed1.xindexes["xx"], TreeIndex)
assert isinstance(ds_indexed1.xindexes["xx"], NDPointIndex)
assert ds_indexed1.xindexes["xx"] is ds_indexed1.xindexes["yy"]

ds_indexed2 = ds.set_xindex(
("xx", "yy"), TreeIndex, tree_adapter_cls=ScipyKDTreeAdapter
("xx", "yy"), NDPointIndex, tree_adapter_cls=ScipyKDTreeAdapter
)
assert ds_indexed1.xindexes["xx"].equals(ds_indexed2.xindexes["yy"])

Expand All @@ -31,18 +31,18 @@ def test_tree_index_init_errors() -> None:
ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)})

with pytest.raises(ValueError, match="number of variables"):
ds.set_xindex("xx", TreeIndex)
ds.set_xindex("xx", NDPointIndex)

ds2 = ds.assign_coords(yy=(("u", "v"), [[3.0, 3.0], [4.0, 4.0]]))

with pytest.raises(ValueError, match="same dimensions"):
ds2.set_xindex(("xx", "yy"), TreeIndex)
ds2.set_xindex(("xx", "yy"), NDPointIndex)


def test_tree_index_sel() -> None:
xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex(
("xx", "yy"), TreeIndex
("xx", "yy"), NDPointIndex
)

# 1-dimensional labels
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_tree_index_sel() -> None:
def test_tree_index_sel_errors() -> None:
xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex(
("xx", "yy"), TreeIndex
("xx", "yy"), NDPointIndex
)

with pytest.raises(ValueError, match="method='nearest'"):
Expand Down Expand Up @@ -133,17 +133,17 @@ def test_tree_index_equals() -> None:
xx1, yy1 = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds1 = xr.Dataset(
coords={"xx": (("y", "x"), xx1), "yy": (("y", "x"), yy1)}
).set_xindex(("xx", "yy"), TreeIndex)
).set_xindex(("xx", "yy"), NDPointIndex)

xx2, yy2 = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds2 = xr.Dataset(
coords={"xx": (("y", "x"), xx2), "yy": (("y", "x"), yy2)}
).set_xindex(("xx", "yy"), TreeIndex)
).set_xindex(("xx", "yy"), NDPointIndex)

xx3, yy3 = np.meshgrid([10.0, 20.0], [30.0, 40.0])
ds3 = xr.Dataset(
coords={"xx": (("y", "x"), xx3), "yy": (("y", "x"), yy3)}
).set_xindex(("xx", "yy"), TreeIndex)
).set_xindex(("xx", "yy"), NDPointIndex)

assert ds1.xindexes["xx"].equals(ds2.xindexes["xx"])
assert not ds1.xindexes["xx"].equals(ds3.xindexes["xx"])
Expand All @@ -152,12 +152,12 @@ def test_tree_index_equals() -> None:
def test_tree_index_rename() -> None:
xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0])
ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex(
("xx", "yy"), TreeIndex
("xx", "yy"), NDPointIndex
)

ds_renamed = ds.rename_dims(y="u").rename_vars(yy="uu")
assert "uu" in ds_renamed.xindexes
assert isinstance(ds_renamed.xindexes["uu"], TreeIndex)
assert isinstance(ds_renamed.xindexes["uu"], NDPointIndex)
assert ds_renamed.xindexes["xx"] is ds_renamed.xindexes["uu"]

# check via sel() that uses coord names and dims under the hood
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy