diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 8bf9c68b727..2dca38538e1 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +from functools import partial from xarray.core import dtypes, nputils @@ -75,6 +76,47 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return coeffs, residuals +def _fill_with_last_one(a, b): + import numpy as np + + # cumreduction apply the push func over all the blocks first so, + # the only missing part is filling the missing values using the + # last data of the previous chunk + return np.where(np.isnan(b), a, b) + + +def _dtype_push(a, axis, dtype=None): + from xarray.core.duck_array_ops import _push + + # Not sure why the blelloch algorithm force to receive a dtype + return _push(a, axis=axis) + + +def _reset_cumsum(a, axis, dtype=None): + import numpy as np + + cumsum = np.cumsum(a, axis=axis) + reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis) + return cumsum - reset_points + + +def _last_reset_cumsum(a, axis, keepdims=None): + import numpy as np + + # Take the last cumulative sum taking into account the reset + # This is useful for blelloch method + return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1]) + + +def _combine_reset_cumsum(a, b, axis): + import numpy as np + + # It is going to sum the previous result until the first + # non nan value + bitmask = np.cumprod(b != 0, axis=axis) + return np.where(bitmask, b + a, b) + + def push(array, n, axis, method="blelloch"): """ Dask-aware bottleneck.push @@ -91,16 +133,6 @@ def push(array, n, axis, method="blelloch"): # TODO: Replace all this function # once https://github.com/pydata/xarray/issues/9229 being implemented - def _fill_with_last_one(a, b): - # cumreduction apply the push func over all the blocks first so, - # the only missing part is filling the missing values using the - # last data of the previous chunk - return np.where(np.isnan(b), a, b) - - def _dtype_push(a, axis, dtype=None): - # Not sure why the blelloch algorithm force to receive a dtype - return _push(a, axis=axis) - pushed_array = da.reductions.cumreduction( func=_dtype_push, binop=_fill_with_last_one, @@ -113,26 +145,9 @@ def _dtype_push(a, axis, dtype=None): ) if n is not None and 0 < n < array.shape[axis] - 1: - - def _reset_cumsum(a, axis, dtype=None): - cumsum = np.cumsum(a, axis=axis) - reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis) - return cumsum - reset_points - - def _last_reset_cumsum(a, axis, keepdims=None): - # Take the last cumulative sum taking into account the reset - # This is useful for blelloch method - return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1]) - - def _combine_reset_cumsum(a, b): - # It is going to sum the previous result until the first - # non nan value - bitmask = np.cumprod(b != 0, axis=axis) - return np.where(bitmask, b + a, b) - valid_positions = da.reductions.cumreduction( func=_reset_cumsum, - binop=_combine_reset_cumsum, + binop=partial(_combine_reset_cumsum, axis=axis), ident=0, x=da.isnan(array, dtype=int), axis=axis,
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: