
Commit e28f171

kmuehlbauer, spencerkclark and dcherian authored
fix mean for datetime-like using the respective time resolution unit (#9977)
* fix mean for datetime-like by using the respective dtype time resolution unit, adapting tests
* fix mypy
* add PR to existing entry for non-nanosecond datetimes
* Update xarray/core/duck_array_ops.py
* cast to "int64" in calculation of datetime-like mean
* Apply suggestions from code review
* Apply suggestions from code review

Co-authored-by: Spencer Clark <spencerkclark@gmail.com>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent e432479 commit e28f171
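
For context, a minimal sketch of the user-facing behavior this change targets (illustrative only, not taken from the PR): computing the mean of a DataArray backed by non-nanosecond datetimes should now be carried out in the array's own time resolution rather than assuming nanoseconds.

    import numpy as np
    import xarray as xr

    # Illustrative only: second-resolution datetimes, supported since the
    # non-nanosecond datetime work referenced in the changelog entry below.
    da = xr.DataArray(
        np.array(["2010-01-01", "2010-01-03"], dtype="datetime64[s]"),
        dims="time",
    )

    # With this fix the reduction respects the array's own resolution unit;
    # the expected result is 2010-01-02.
    print(da.mean())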

File tree

3 files changed (+51 −38 lines)


doc/whats-new.rst

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ eventually be deprecated.
 
 New Features
 ~~~~~~~~~~~~
-- Relax nanosecond datetime restriction in CF time decoding (:issue:`7493`, :pull:`9618`).
+- Relax nanosecond datetime restriction in CF time decoding (:issue:`7493`, :pull:`9618`, :pull:`9977`).
   By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_ and `Spencer Clark <https://github.com/spencerkclark>`_.
 - Enable the ``compute=False`` option in :py:meth:`DataTree.to_zarr`. (:pull:`9958`).
   By `Sam Levang <https://github.com/slevang>`_.

xarray/core/duck_array_ops.py

Lines changed: 10 additions & 13 deletions
@@ -662,16 +662,10 @@ def _to_pytimedelta(array, unit="us"):
 
 
 def np_timedelta64_to_float(array, datetime_unit):
-    """Convert numpy.timedelta64 to float.
-
-    Notes
-    -----
-    The array is first converted to microseconds, which is less likely to
-    cause overflow errors.
-    """
-    array = array.astype("timedelta64[ns]").astype(np.float64)
-    conversion_factor = np.timedelta64(1, "ns") / np.timedelta64(1, datetime_unit)
-    return conversion_factor * array
+    """Convert numpy.timedelta64 to float, possibly at a loss of resolution."""
+    unit, _ = np.datetime_data(array.dtype)
+    conversion_factor = np.timedelta64(1, unit) / np.timedelta64(1, datetime_unit)
+    return conversion_factor * array.astype(np.float64)
 
 
 def pd_timedelta_to_float(value, datetime_unit):
@@ -715,12 +709,15 @@ def mean(array, axis=None, skipna=None, **kwargs):
     if dtypes.is_datetime_like(array.dtype):
         offset = _datetime_nanmin(array)
 
-        # xarray always uses np.datetime64[ns] for np.datetime64 data
-        dtype = "timedelta64[ns]"
+        # From version 2025.01.2 xarray uses np.datetime64[unit], where unit
+        # is one of "s", "ms", "us", "ns".
+        # To not have to worry about the resolution, we just convert the output
+        # to "timedelta64" (without unit) and let the dtype of offset take precedence.
+        # This is fully backwards compatible with datetime64[ns].
         return (
             _mean(
                 datetime_to_numeric(array, offset), axis=axis, skipna=skipna, **kwargs
-            ).astype(dtype)
+            ).astype("timedelta64")
             + offset
         )
     elif _contains_cftime_datetimes(array):
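
To make the new conversion path concrete, here is a small standalone sketch (not part of the diff) of how the updated np_timedelta64_to_float derives its scaling factor from the array's own dtype unit instead of first casting everything to nanoseconds:

    import numpy as np

    # Standalone illustration of the conversion logic in the hunk above.
    td = np.array([1, 2, 3], dtype="timedelta64[s]")

    unit, _ = np.datetime_data(td.dtype)  # ("s", 1): the array's own unit
    conversion_factor = np.timedelta64(1, unit) / np.timedelta64(1, "ms")

    print(conversion_factor * td.astype(np.float64))  # [1000. 2000. 3000.]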

xarray/tests/test_duck_array_ops.py

Lines changed: 40 additions & 24 deletions
@@ -9,6 +9,7 @@
 from numpy import array, nan
 
 from xarray import DataArray, Dataset, cftime_range, concat
+from xarray.coding.times import _NS_PER_TIME_DELTA
 from xarray.core import dtypes, duck_array_ops
 from xarray.core.duck_array_ops import (
     array_notnull_equiv,
@@ -28,6 +29,7 @@
     where,
 )
 from xarray.core.extension_array import PandasExtensionArray
+from xarray.core.types import NPDatetimeUnitOptions, PDDatetimeUnitOptions
 from xarray.namedarray.pycompat import array_type
 from xarray.testing import assert_allclose, assert_equal, assert_identical
 from xarray.tests import (
@@ -411,10 +413,11 @@ def assert_dask_array(da, dask):
 @arm_xfail
 @pytest.mark.filterwarnings("ignore:All-NaN .* encountered:RuntimeWarning")
 @pytest.mark.parametrize("dask", [False, True] if has_dask else [False])
-def test_datetime_mean(dask: bool) -> None:
+def test_datetime_mean(dask: bool, time_unit: PDDatetimeUnitOptions) -> None:
     # Note: only testing numpy, as dask is broken upstream
+    dtype = f"M8[{time_unit}]"
     da = DataArray(
-        np.array(["2010-01-01", "NaT", "2010-01-03", "NaT", "NaT"], dtype="M8[ns]"),
+        np.array(["2010-01-01", "NaT", "2010-01-03", "NaT", "NaT"], dtype=dtype),
         dims=["time"],
     )
     if dask:
@@ -846,11 +849,11 @@ def test_multiple_dims(dtype, dask, skipna, func):
 
 
 @pytest.mark.parametrize("dask", [True, False])
-def test_datetime_to_numeric_datetime64(dask):
+def test_datetime_to_numeric_datetime64(dask, time_unit: PDDatetimeUnitOptions):
     if dask and not has_dask:
         pytest.skip("requires dask")
 
-    times = pd.date_range("2000", periods=5, freq="7D").values
+    times = pd.date_range("2000", periods=5, freq="7D").as_unit(time_unit).values
     if dask:
         import dask.array
 
@@ -874,8 +877,8 @@ def test_datetime_to_numeric_datetime64(dask):
     result = duck_array_ops.datetime_to_numeric(
         times, datetime_unit="h", dtype=dtype
     )
-    expected = 24 * np.arange(0, 35, 7).astype(dtype)
-    np.testing.assert_array_equal(result, expected)
+    expected2 = 24 * np.arange(0, 35, 7).astype(dtype)
+    np.testing.assert_array_equal(result, expected2)
 
 
 @requires_cftime
@@ -923,15 +926,18 @@ def test_datetime_to_numeric_cftime(dask):
 
 
 @requires_cftime
-def test_datetime_to_numeric_potential_overflow():
+def test_datetime_to_numeric_potential_overflow(time_unit: PDDatetimeUnitOptions):
     import cftime
 
-    times = pd.date_range("2000", periods=5, freq="7D").values.astype("datetime64[us]")
+    if time_unit == "ns":
+        pytest.skip("out-of-bounds datetime64 overflow")
+    dtype = f"M8[{time_unit}]"
+    times = pd.date_range("2000", periods=5, freq="7D").values.astype(dtype)
     cftimes = cftime_range(
         "2000", periods=5, freq="7D", calendar="proleptic_gregorian"
     ).values
 
-    offset = np.datetime64("0001-01-01")
+    offset = np.datetime64("0001-01-01", time_unit)
     cfoffset = cftime.DatetimeProlepticGregorian(1, 1, 1)
 
     result = duck_array_ops.datetime_to_numeric(
@@ -957,35 +963,45 @@ def test_py_timedelta_to_float():
     assert py_timedelta_to_float(dt.timedelta(days=1e6), "D") == 1e6
 
 
-@pytest.mark.parametrize(
-    "td, expected",
-    ([np.timedelta64(1, "D"), 86400 * 1e9], [np.timedelta64(1, "ns"), 1.0]),
-)
-def test_np_timedelta64_to_float(td, expected):
-    out = np_timedelta64_to_float(td, datetime_unit="ns")
+@pytest.mark.parametrize("np_dt_unit", ["D", "h", "m", "s", "ms", "us", "ns"])
+def test_np_timedelta64_to_float(
+    np_dt_unit: NPDatetimeUnitOptions, time_unit: PDDatetimeUnitOptions
+):
+    # tests any combination of source np.timedelta64 (NPDatetimeUnitOptions) with
+    # np_timedelta_to_float with dedicated target unit (PDDatetimeUnitOptions)
+    td = np.timedelta64(1, np_dt_unit)
+    expected = _NS_PER_TIME_DELTA[np_dt_unit] / _NS_PER_TIME_DELTA[time_unit]
+
+    out = np_timedelta64_to_float(td, datetime_unit=time_unit)
     np.testing.assert_allclose(out, expected)
     assert isinstance(out, float)
 
-    out = np_timedelta64_to_float(np.atleast_1d(td), datetime_unit="ns")
+    out = np_timedelta64_to_float(np.atleast_1d(td), datetime_unit=time_unit)
     np.testing.assert_allclose(out, expected)
 
 
-@pytest.mark.parametrize(
-    "td, expected", ([pd.Timedelta(1, "D"), 86400 * 1e9], [pd.Timedelta(1, "ns"), 1.0])
-)
-def test_pd_timedelta_to_float(td, expected):
-    out = pd_timedelta_to_float(td, datetime_unit="ns")
+@pytest.mark.parametrize("np_dt_unit", ["D", "h", "m", "s", "ms", "us", "ns"])
+def test_pd_timedelta_to_float(
+    np_dt_unit: NPDatetimeUnitOptions, time_unit: PDDatetimeUnitOptions
+):
+    # tests any combination of source pd.Timedelta (NPDatetimeUnitOptions) with
+    # np_timedelta_to_float with dedicated target unit (PDDatetimeUnitOptions)
+    td = pd.Timedelta(1, np_dt_unit)
+    expected = _NS_PER_TIME_DELTA[np_dt_unit] / _NS_PER_TIME_DELTA[time_unit]
+
+    out = pd_timedelta_to_float(td, datetime_unit=time_unit)
     np.testing.assert_allclose(out, expected)
     assert isinstance(out, float)
 
 
 @pytest.mark.parametrize(
     "td", [dt.timedelta(days=1), np.timedelta64(1, "D"), pd.Timedelta(1, "D"), "1 day"]
 )
-def test_timedelta_to_numeric(td):
+def test_timedelta_to_numeric(td, time_unit: PDDatetimeUnitOptions):
     # Scalar input
-    out = timedelta_to_numeric(td, "ns")
-    np.testing.assert_allclose(out, 86400 * 1e9)
+    out = timedelta_to_numeric(td, time_unit)
+    expected = _NS_PER_TIME_DELTA["D"] / _NS_PER_TIME_DELTA[time_unit]
+    np.testing.assert_allclose(out, expected)
     assert isinstance(out, float)
 
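
The modified tests rely on a `time_unit` fixture that is not defined in this diff; it presumably comes from the test suite's shared configuration. A minimal sketch of such a fixture, assuming it simply parametrizes over the pandas-supported datetime resolutions:

    import pytest

    # Hypothetical stand-in for the shared fixture the tests above rely on;
    # the real definition lives elsewhere in xarray's test suite.
    @pytest.fixture(params=["s", "ms", "us", "ns"])
    def time_unit(request):
        return request.param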

0 commit comments
