diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 30a1c588c61..cac5cb1ea8b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -204,6 +204,15 @@ Bug fixes (:pull:`10352`). By `Spencer Clark `_. - Avoid unsafe casts from float to unsigned int in CFMaskCoder (:issue:`9815`, :pull:`9964`). By ` Elliott Sales de Andrade `_. +- Fix attribute overwriting bug when decoding encoded + :py:class:`numpy.timedelta64` values from disk with a dtype attribute + (:issue:`10468`, :pull:`10469`). By `Spencer Clark + `_. +- Fix default ``"_FillValue"`` dtype coercion bug when encoding + :py:class:`numpy.timedelta64` values to an on-disk format that only supports + 32-bit integers (:issue:`10466`, :pull:`10469`). By `Spencer Clark + `_. + Performance ~~~~~~~~~~~ diff --git a/xarray/coding/times.py b/xarray/coding/times.py index d1cc36558fa..d6567ba4c61 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -1410,6 +1410,43 @@ def has_timedelta64_encoding_dtype(attrs_or_encoding: dict) -> bool: return isinstance(dtype, str) and dtype.startswith("timedelta64") +def resolve_time_unit_from_attrs_dtype( + attrs_dtype: str, name: T_Name +) -> PDDatetimeUnitOptions: + dtype = np.dtype(attrs_dtype) + resolution, _ = np.datetime_data(dtype) + resolution = cast(NPDatetimeUnitOptions, resolution) + if np.timedelta64(1, resolution) > np.timedelta64(1, "s"): + time_unit = cast(PDDatetimeUnitOptions, "s") + message = ( + f"Following pandas, xarray only supports decoding to timedelta64 " + f"values with a resolution of 's', 'ms', 'us', or 'ns'. Encoded " + f"values for variable {name!r} have a resolution of " + f"{resolution!r}. Attempting to decode to a resolution of 's'. " + f"Note, depending on the encoded values, this may lead to an " + f"OverflowError. Additionally, data will not be identically round " + f"tripped; xarray will choose an encoding dtype of " + f"'timedelta64[s]' when re-encoding." + ) + emit_user_level_warning(message) + elif np.timedelta64(1, resolution) < np.timedelta64(1, "ns"): + time_unit = cast(PDDatetimeUnitOptions, "ns") + message = ( + f"Following pandas, xarray only supports decoding to timedelta64 " + f"values with a resolution of 's', 'ms', 'us', or 'ns'. Encoded " + f"values for variable {name!r} have a resolution of " + f"{resolution!r}. Attempting to decode to a resolution of 'ns'. " + f"Note, depending on the encoded values, this may lead to loss of " + f"precision. Additionally, data will not be identically round " + f"tripped; xarray will choose an encoding dtype of " + f"'timedelta64[ns]' when re-encoding." + ) + emit_user_level_warning(message) + else: + time_unit = cast(PDDatetimeUnitOptions, resolution) + return time_unit + + class CFTimedeltaCoder(VariableCoder): """Coder for CF Timedelta coding. @@ -1430,7 +1467,7 @@ class CFTimedeltaCoder(VariableCoder): def __init__( self, - time_unit: PDDatetimeUnitOptions = "ns", + time_unit: PDDatetimeUnitOptions | None = None, decode_via_units: bool = True, decode_via_dtype: bool = True, ) -> None: @@ -1442,45 +1479,18 @@ def __init__( def encode(self, variable: Variable, name: T_Name = None) -> Variable: if np.issubdtype(variable.data.dtype, np.timedelta64): dims, data, attrs, encoding = unpack_for_encoding(variable) - has_timedelta_dtype = has_timedelta64_encoding_dtype(encoding) - if ("units" in encoding or "dtype" in encoding) and not has_timedelta_dtype: - dtype = encoding.get("dtype", None) - units = encoding.pop("units", None) + dtype = encoding.get("dtype", None) + units = encoding.pop("units", None) - # in the case of packed data we need to encode into - # float first, the correct dtype will be established - # via CFScaleOffsetCoder/CFMaskCoder - if "add_offset" in encoding or "scale_factor" in encoding: - dtype = data.dtype if data.dtype.kind == "f" else "float64" + # in the case of packed data we need to encode into + # float first, the correct dtype will be established + # via CFScaleOffsetCoder/CFMaskCoder + if "add_offset" in encoding or "scale_factor" in encoding: + dtype = data.dtype if data.dtype.kind == "f" else "float64" - else: - resolution, _ = np.datetime_data(variable.dtype) - dtype = np.int64 - attrs_dtype = f"timedelta64[{resolution}]" - units = _numpy_dtype_to_netcdf_timeunit(variable.dtype) - safe_setitem(attrs, "dtype", attrs_dtype, name=name) - # Remove dtype encoding if it exists to prevent it from - # interfering downstream in NonStringCoder. - encoding.pop("dtype", None) - - if any( - k in encoding for k in _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS - ): - raise ValueError( - f"Specifying 'add_offset' or 'scale_factor' is not " - f"supported when encoding the timedelta64 values of " - f"variable {name!r} with xarray's new default " - f"timedelta64 encoding approach. To encode {name!r} " - f"with xarray's previous timedelta64 encoding " - f"approach, which supports the 'add_offset' and " - f"'scale_factor' parameters, additionally set " - f"encoding['units'] to a unit of time, e.g. " - f"'seconds'. To proceed with encoding of {name!r} " - f"via xarray's new approach, remove any encoding " - f"entries for 'add_offset' or 'scale_factor'." - ) - if "_FillValue" not in encoding and "missing_value" not in encoding: - encoding["_FillValue"] = np.iinfo(np.int64).min + resolution, _ = np.datetime_data(variable.dtype) + attrs_dtype = f"timedelta64[{resolution}]" + safe_setitem(attrs, "dtype", attrs_dtype, name=name) data, units = encode_cf_timedelta(data, units, dtype) safe_setitem(attrs, "units", units, name=name) @@ -1499,54 +1509,13 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: ): dims, data, attrs, encoding = unpack_for_decoding(variable) units = pop_to(attrs, encoding, "units") - if is_dtype_decodable and self.decode_via_dtype: - if any( - k in encoding for k in _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS - ): - raise ValueError( - f"Decoding timedelta64 values via dtype is not " - f"supported when 'add_offset', or 'scale_factor' are " - f"present in encoding. Check the encoding parameters " - f"of variable {name!r}." - ) - dtype = pop_to(attrs, encoding, "dtype", name=name) - dtype = np.dtype(dtype) - resolution, _ = np.datetime_data(dtype) - resolution = cast(NPDatetimeUnitOptions, resolution) - if np.timedelta64(1, resolution) > np.timedelta64(1, "s"): - time_unit = cast(PDDatetimeUnitOptions, "s") - dtype = np.dtype("timedelta64[s]") - message = ( - f"Following pandas, xarray only supports decoding to " - f"timedelta64 values with a resolution of 's', 'ms', " - f"'us', or 'ns'. Encoded values for variable {name!r} " - f"have a resolution of {resolution!r}. Attempting to " - f"decode to a resolution of 's'. Note, depending on " - f"the encoded values, this may lead to an " - f"OverflowError. Additionally, data will not be " - f"identically round tripped; xarray will choose an " - f"encoding dtype of 'timedelta64[s]' when re-encoding." - ) - emit_user_level_warning(message) - elif np.timedelta64(1, resolution) < np.timedelta64(1, "ns"): - time_unit = cast(PDDatetimeUnitOptions, "ns") - dtype = np.dtype("timedelta64[ns]") - message = ( - f"Following pandas, xarray only supports decoding to " - f"timedelta64 values with a resolution of 's', 'ms', " - f"'us', or 'ns'. Encoded values for variable {name!r} " - f"have a resolution of {resolution!r}. Attempting to " - f"decode to a resolution of 'ns'. Note, depending on " - f"the encoded values, this may lead to loss of " - f"precision. Additionally, data will not be " - f"identically round tripped; xarray will choose an " - f"encoding dtype of 'timedelta64[ns]' " - f"when re-encoding." - ) - emit_user_level_warning(message) + if is_dtype_decodable: + attrs_dtype = attrs.pop("dtype") + if self.time_unit is None: + time_unit = resolve_time_unit_from_attrs_dtype(attrs_dtype, name) else: - time_unit = cast(PDDatetimeUnitOptions, resolution) - elif self.decode_via_units: + time_unit = self.time_unit + else: if self._emit_decode_timedelta_future_warning: emit_user_level_warning( "In a future version, xarray will not decode " @@ -1564,8 +1533,19 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: "'CFTimedeltaCoder' instance.", FutureWarning, ) - dtype = np.dtype(f"timedelta64[{self.time_unit}]") - time_unit = self.time_unit + if self.time_unit is None: + time_unit = cast(PDDatetimeUnitOptions, "ns") + else: + time_unit = self.time_unit + + # Handle edge case that decode_via_dtype=False and + # decode_via_units=True, and timedeltas were encoded with a + # dtype attribute. We need to remove the dtype attribute + # to prevent an error during round tripping. + if has_timedelta_dtype: + attrs.pop("dtype") + + dtype = np.dtype(f"timedelta64[{time_unit}]") transform = partial(decode_cf_timedelta, units=units, time_unit=time_unit) data = lazy_elemwise_func(data, transform, dtype=dtype) return Variable(dims, data, attrs, encoding, fastpath=True) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index f42d2c2c17f..2709e834e68 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -56,6 +56,7 @@ from xarray.conventions import encode_dataset_coordinates from xarray.core import indexing from xarray.core.options import set_options +from xarray.core.types import PDDatetimeUnitOptions from xarray.core.utils import module_available from xarray.namedarray.pycompat import array_type from xarray.tests import ( @@ -642,6 +643,16 @@ def test_roundtrip_timedelta_data(self) -> None: ) as actual: assert_identical(expected, actual) + def test_roundtrip_timedelta_data_via_dtype( + self, time_unit: PDDatetimeUnitOptions + ) -> None: + time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]).as_unit(time_unit) # type: ignore[arg-type, unused-ignore] + expected = Dataset( + {"td": ("td", time_deltas), "td0": time_deltas[0].to_numpy()} + ) + with self.roundtrip(expected) as actual: + assert_identical(expected, actual) + def test_roundtrip_float64_data(self) -> None: expected = Dataset({"x": ("y", np.array([1.0, 2.0, np.pi], dtype="float64"))}) with self.roundtrip(expected) as actual: diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 65caab1c709..af29716fec0 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -20,7 +20,6 @@ ) from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder from xarray.coding.times import ( - _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS, _encode_datetime_with_cftime, _netcdf_to_numpy_timeunit, _numpy_to_netcdf_timeunit, @@ -1824,8 +1823,9 @@ def test_encode_cf_timedelta_small_dtype_missing_value(use_dask) -> None: assert_equal(variable, decoded) -_DECODE_TIMEDELTA_TESTS = { +_DECODE_TIMEDELTA_VIA_UNITS_TESTS = { "default": (True, None, np.dtype("timedelta64[ns]"), True), + "decode_timedelta=True": (True, True, np.dtype("timedelta64[ns]"), False), "decode_timedelta=False": (True, False, np.dtype("int64"), False), "inherit-time_unit-from-decode_times": ( CFDatetimeCoder(time_unit="s"), @@ -1856,16 +1856,16 @@ def test_encode_cf_timedelta_small_dtype_missing_value(use_dask) -> None: @pytest.mark.parametrize( ("decode_times", "decode_timedelta", "expected_dtype", "warns"), - list(_DECODE_TIMEDELTA_TESTS.values()), - ids=list(_DECODE_TIMEDELTA_TESTS.keys()), + list(_DECODE_TIMEDELTA_VIA_UNITS_TESTS.values()), + ids=list(_DECODE_TIMEDELTA_VIA_UNITS_TESTS.keys()), ) -def test_decode_timedelta( +def test_decode_timedelta_via_units( decode_times, decode_timedelta, expected_dtype, warns ) -> None: timedeltas = pd.timedelta_range(0, freq="D", periods=3) - encoding = {"units": "days"} - var = Variable(["time"], timedeltas, encoding=encoding) - encoded = conventions.encode_cf_variable(var) + attrs = {"units": "days"} + var = Variable(["time"], timedeltas, encoding=attrs) + encoded = Variable(["time"], np.array([0, 1, 2]), attrs=attrs) if warns: with pytest.warns(FutureWarning, match="decode_timedelta"): decoded = conventions.decode_cf_variable( @@ -1885,6 +1885,57 @@ def test_decode_timedelta( assert decoded.dtype == expected_dtype +_DECODE_TIMEDELTA_VIA_DTYPE_TESTS = { + "default": (True, None, np.dtype("timedelta64[ns]")), + "decode_timedelta=False": (True, False, np.dtype("int64")), + "decode_timedelta=True": (True, True, np.dtype("timedelta64[ns]")), + "inherit-time_unit-from-decode_times": ( + CFDatetimeCoder(time_unit="s"), + None, + np.dtype("timedelta64[s]"), + ), + "set-time_unit-via-CFTimedeltaCoder-decode_times=True": ( + True, + CFTimedeltaCoder(time_unit="s"), + np.dtype("timedelta64[s]"), + ), + "set-time_unit-via-CFTimedeltaCoder-decode_times=False": ( + False, + CFTimedeltaCoder(time_unit="s"), + np.dtype("timedelta64[s]"), + ), + "override-time_unit-from-decode_times": ( + CFDatetimeCoder(time_unit="ns"), + CFTimedeltaCoder(time_unit="s"), + np.dtype("timedelta64[s]"), + ), +} + + +@pytest.mark.parametrize( + ("decode_times", "decode_timedelta", "expected_dtype"), + list(_DECODE_TIMEDELTA_VIA_DTYPE_TESTS.values()), + ids=list(_DECODE_TIMEDELTA_VIA_DTYPE_TESTS.keys()), +) +def test_decode_timedelta_via_dtype( + decode_times, decode_timedelta, expected_dtype +) -> None: + timedeltas = pd.timedelta_range(0, freq="D", periods=3) + encoding = {"units": "days"} + var = Variable(["time"], timedeltas, encoding=encoding) + encoded = conventions.encode_cf_variable(var) + assert encoded.attrs["dtype"] == "timedelta64[ns]" + assert encoded.attrs["units"] == encoding["units"] + decoded = conventions.decode_cf_variable( + "foo", encoded, decode_times=decode_times, decode_timedelta=decode_timedelta + ) + if decode_timedelta is False: + assert_equal(encoded, decoded) + else: + assert_equal(var, decoded) + assert decoded.dtype == expected_dtype + + def test_lazy_decode_timedelta_unexpected_dtype() -> None: attrs = {"units": "seconds"} encoded = Variable(["time"], [0, 0.5, 1], attrs=attrs) @@ -1940,7 +1991,12 @@ def test_duck_array_decode_times(calendar) -> None: def test_decode_timedelta_mask_and_scale( decode_timedelta: bool, mask_and_scale: bool ) -> None: - attrs = {"units": "nanoseconds", "_FillValue": np.int16(-1), "add_offset": 100000.0} + attrs = { + "dtype": "timedelta64[ns]", + "units": "nanoseconds", + "_FillValue": np.int16(-1), + "add_offset": 100000.0, + } encoded = Variable(["time"], np.array([0, -1, 1], "int16"), attrs=attrs) decoded = conventions.decode_cf_variable( "foo", encoded, mask_and_scale=mask_and_scale, decode_timedelta=decode_timedelta @@ -1958,19 +2014,17 @@ def test_decode_floating_point_timedelta_no_serialization_warning() -> None: decoded.load() -def test_literal_timedelta64_coding(time_unit: PDDatetimeUnitOptions) -> None: +def test_timedelta64_coding_via_dtype(time_unit: PDDatetimeUnitOptions) -> None: timedeltas = np.array([0, 1, "NaT"], dtype=f"timedelta64[{time_unit}]") variable = Variable(["time"], timedeltas) - expected_dtype = f"timedelta64[{time_unit}]" expected_units = _numpy_to_netcdf_timeunit(time_unit) encoded = conventions.encode_cf_variable(variable) - assert encoded.attrs["dtype"] == expected_dtype + assert encoded.attrs["dtype"] == f"timedelta64[{time_unit}]" assert encoded.attrs["units"] == expected_units - assert encoded.attrs["_FillValue"] == np.iinfo(np.int64).min decoded = conventions.decode_cf_variable("timedeltas", encoded) - assert decoded.encoding["dtype"] == expected_dtype + assert decoded.encoding["dtype"] == np.dtype("int64") assert decoded.encoding["units"] == expected_units assert_identical(decoded, variable) @@ -1981,7 +2035,7 @@ def test_literal_timedelta64_coding(time_unit: PDDatetimeUnitOptions) -> None: assert reencoded.dtype == encoded.dtype -def test_literal_timedelta_coding_non_pandas_coarse_resolution_warning() -> None: +def test_timedelta_coding_via_dtype_non_pandas_coarse_resolution_warning() -> None: attrs = {"dtype": "timedelta64[D]", "units": "days"} encoded = Variable(["time"], [0, 1, 2], attrs=attrs) with pytest.warns(UserWarning, match="xarray only supports"): @@ -1994,7 +2048,7 @@ def test_literal_timedelta_coding_non_pandas_coarse_resolution_warning() -> None @pytest.mark.xfail(reason="xarray does not recognize picoseconds as time-like") -def test_literal_timedelta_coding_non_pandas_fine_resolution_warning() -> None: +def test_timedelta_coding_via_dtype_non_pandas_fine_resolution_warning() -> None: attrs = {"dtype": "timedelta64[ps]", "units": "picoseconds"} encoded = Variable(["time"], [0, 1000, 2000], attrs=attrs) with pytest.warns(UserWarning, match="xarray only supports"): @@ -2006,17 +2060,16 @@ def test_literal_timedelta_coding_non_pandas_fine_resolution_warning() -> None: assert decoded.dtype == np.dtype("timedelta64[ns]") -@pytest.mark.parametrize("attribute", ["dtype", "units"]) -def test_literal_timedelta_decode_invalid_encoding(attribute) -> None: +def test_timedelta_decode_via_dtype_invalid_encoding() -> None: attrs = {"dtype": "timedelta64[s]", "units": "seconds"} - encoding = {attribute: "foo"} + encoding = {"units": "foo"} encoded = Variable(["time"], [0, 1, 2], attrs=attrs, encoding=encoding) with pytest.raises(ValueError, match="failed to prevent"): conventions.decode_cf_variable("timedeltas", encoded) @pytest.mark.parametrize("attribute", ["dtype", "units"]) -def test_literal_timedelta_encode_invalid_attribute(attribute) -> None: +def test_timedelta_encode_via_dtype_invalid_attribute(attribute) -> None: timedeltas = pd.timedelta_range(0, freq="D", periods=3) attrs = {attribute: "foo"} variable = Variable(["time"], timedeltas, attrs=attrs) @@ -2024,23 +2077,6 @@ def test_literal_timedelta_encode_invalid_attribute(attribute) -> None: conventions.encode_cf_variable(variable) -@pytest.mark.parametrize("invalid_key", _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS) -def test_literal_timedelta_encoding_invalid_key_error(invalid_key) -> None: - encoding = {invalid_key: 1.0} - timedeltas = pd.timedelta_range(0, freq="D", periods=3) - variable = Variable(["time"], timedeltas, encoding=encoding) - with pytest.raises(ValueError, match=invalid_key): - conventions.encode_cf_variable(variable) - - -@pytest.mark.parametrize("invalid_key", _INVALID_LITERAL_TIMEDELTA64_ENCODING_KEYS) -def test_literal_timedelta_decoding_invalid_key_error(invalid_key) -> None: - attrs = {invalid_key: 1.0, "dtype": "timedelta64[s]", "units": "seconds"} - variable = Variable(["time"], [0, 1, 2], attrs=attrs) - with pytest.raises(ValueError, match=invalid_key): - conventions.decode_cf_variable("foo", variable) - - @pytest.mark.parametrize( ("decode_via_units", "decode_via_dtype", "attrs", "expect_timedelta64"), [ @@ -2058,12 +2094,6 @@ def test_literal_timedelta_decoding_invalid_key_error(invalid_key) -> None: def test_timedelta_decoding_options( decode_via_units, decode_via_dtype, attrs, expect_timedelta64 ) -> None: - # Note with literal timedelta encoding, we always add a _FillValue, even - # if one is not present in the original encoding parameters, which is why - # we ensure one is defined here when "dtype" is present in attrs. - if "dtype" in attrs: - attrs["_FillValue"] = np.iinfo(np.int64).min - array = np.array([0, 1, 2], dtype=np.dtype("int64")) encoded = Variable(["time"], array, attrs=attrs) @@ -2083,7 +2113,11 @@ def test_timedelta_decoding_options( # Confirm we exactly roundtrip. reencoded = conventions.encode_cf_variable(decoded) - assert_identical(reencoded, encoded) + + expected = encoded.copy() + if "dtype" not in attrs and decode_via_units: + expected.attrs["dtype"] = "timedelta64[s]" + assert_identical(reencoded, expected) def test_timedelta_encoding_explicit_non_timedelta64_dtype() -> None: @@ -2093,20 +2127,21 @@ def test_timedelta_encoding_explicit_non_timedelta64_dtype() -> None: encoded = conventions.encode_cf_variable(variable) assert encoded.attrs["units"] == "days" + assert encoded.attrs["dtype"] == "timedelta64[ns]" assert encoded.dtype == np.dtype("int32") - with pytest.warns(FutureWarning, match="timedelta"): - decoded = conventions.decode_cf_variable("foo", encoded) + decoded = conventions.decode_cf_variable("foo", encoded) assert_identical(decoded, variable) reencoded = conventions.encode_cf_variable(decoded) assert_identical(reencoded, encoded) assert encoded.attrs["units"] == "days" + assert encoded.attrs["dtype"] == "timedelta64[ns]" assert encoded.dtype == np.dtype("int32") @pytest.mark.parametrize("mask_attribute", ["_FillValue", "missing_value"]) -def test_literal_timedelta64_coding_with_mask( +def test_timedelta64_coding_via_dtype_with_mask( time_unit: PDDatetimeUnitOptions, mask_attribute: str ) -> None: timedeltas = np.array([0, 1, "NaT"], dtype=f"timedelta64[{time_unit}]") @@ -2122,7 +2157,7 @@ def test_literal_timedelta64_coding_with_mask( assert encoded[-1] == mask decoded = conventions.decode_cf_variable("timedeltas", encoded) - assert decoded.encoding["dtype"] == expected_dtype + assert decoded.encoding["dtype"] == np.dtype("int64") assert decoded.encoding["units"] == expected_units assert decoded.encoding[mask_attribute] == mask assert np.isnat(decoded[-1]) @@ -2144,7 +2179,7 @@ def test_roundtrip_0size_timedelta(time_unit: PDDatetimeUnitOptions) -> None: assert encoded.dtype == encoding["dtype"] assert encoded.attrs["units"] == encoding["units"] decoded = conventions.decode_cf_variable("foo", encoded, decode_timedelta=True) - assert decoded.dtype == np.dtype("=m8[ns]") + assert decoded.dtype == np.dtype(f"=m8[{time_unit}]") with assert_no_warnings(): decoded.load() assert decoded.dtype == np.dtype("=m8[s]") pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy