diff --git a/doc/release/upcoming_changes/28516.c_api.rst b/doc/release/upcoming_changes/28516.c_api.rst new file mode 100644 index 000000000000..ec4cf0aa2d6b --- /dev/null +++ b/doc/release/upcoming_changes/28516.c_api.rst @@ -0,0 +1,17 @@ +New comparison and null handling enums for sorting in dtype API +--------------------------------------------------------------- + +Using the new `NPY_DT_sort_compare` slot, user-defined dtypes can +now specify how to compare elements during sorting operations. +The sort compare function should return a member of the +`NPY_COMPARE_RESULT` enum to indicate the result of the comparison, +including support for unordered comparisons. + +The sorting context of type `PyArrayMethod_SortContext` is passed +to the `NPY_DT_get_sort_function` and `NPY_DT_get_argsort_function` +functions and contains a boolean `descending` flag and a +`nan_position` of type `NPY_SORT_NAN_POSITION`, which can be used to +control the behavior of sorting with respect to NaN values. +Currently, sorts are always ascending and nulls are always sorted last, +but this must be checked in the context passed to the sort function +to allow for future features. \ No newline at end of file diff --git a/doc/release/upcoming_changes/28516.new_feature.rst b/doc/release/upcoming_changes/28516.new_feature.rst new file mode 100644 index 000000000000..4117a9eed988 --- /dev/null +++ b/doc/release/upcoming_changes/28516.new_feature.rst @@ -0,0 +1,11 @@ +New sorting function slots `NPY_DT_get_sort_function`, `NPY_DT_get_argsort_function` for dtype API +--------------------------------------------------------------------------------------------------- + +User-defined dtypes can now provide specific sorting functions for use with NumPy's sort methods. +The new slots `NPY_DT_get_sort_function` and `NPY_DT_get_argsort_function` should be functions that +return function pointers implementing sorting functionality for the dtype, while considering the +sort-kind and order. The old arrfunc slots ``NPY_DT_PyArray_ArrFuncs_sort`` and +``NPY_DT_PyArray_ArrFuncs_argsort`` may be deprecated in the future. + +Additionally, the new `NPY_DT_sort_compare` slot can be used to provide a comparison function for +sorting, which will replace the default comparison function for the dtype in sorting functions. \ No newline at end of file diff --git a/doc/source/reference/c-api/array.rst b/doc/source/reference/c-api/array.rst index 02db78ebb2b1..aedade624f04 100644 --- a/doc/source/reference/c-api/array.rst +++ b/doc/source/reference/c-api/array.rst @@ -1873,6 +1873,36 @@ described below. pointer. Currently this is used for zero-filling and clearing arrays storing embedded references. +.. c:type:: int (PyArray_SortFuncWithContext)( \ + PyArrayMethod_SortContext *context, void *data, \ + npy_intp num, NpyAuxData *auxdata) + + A function to sort a buffer of data. The *data* is a pointer to the + beginning of the contiguous buffer containing *num* elements. A function + of this type is returned by the `get_sort_function` function in the DType + slots, where *context* is passed in containing the descriptor for the + array. Returns 0 on success, -1 on failure. + +.. c:type:: int (PyArray_ArgSortFuncWithContext)( \ + PyArrayMethod_SortContext *context, void *data, \ + npy_intp *tosort, npy_intp num, NpyAuxData *auxdata) + + A function to arg-sort a buffer of data. The *data* is a pointer to the + beginning of the buffer containing *num* elements. The *tosort* is a + pointer to an array of indices that will be filled in with the + indices of the sorted elements. A function of this type is returned by + the `get_argsort_function` function in the DType slots, where + *context* is passed in containing the descriptor for the array. + Returns 0 on success, -1 on failure. + +.. c:type:: NPY_COMPARE_RESULT (PyArray_SortCompareFunc) ( \ + const void *a, const void *b, PyArray_Descr *descr) + + A function to compare two elements of an array for sorting. The *a* and *b* + pointers point to the elements to compare, and *descr* is the descriptor for + the array. Returns a value of type :c:type:`NPY_COMPARE_RESULT` indicating + the result of the comparison, including whether each element is unordered. + API Functions and Typedefs ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -3521,6 +3551,40 @@ member of ``PyArrayDTypeMeta_Spec`` struct. force newly created arrays to have a newly created descriptor instance, no matter what input descriptor is provided by a user. +.. c:macro:: NPY_DT_get_sort_function + +.. c:type:: int *(PyArrayDTypeMeta_GetSortFunction)(PyArray_Descr *, \ + npy_intp sort_kind, PyArray_SortFuncWithContext **out_sort, \ + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *out_flags) + + .. versionadded:: 2.4 + + If defined, sets a custom sorting function for the DType for each of + the sort kinds numpy implements. Currently, sorts are always descending + and always use nulls to the end, and this must be checked in the + implementation. Returns 0 on success. + +.. c:macro:: NPY_DT_get_argsort_function + +.. c:type:: int *(PyArrayDTypeMeta_GetArgSortFunction)(PyArray_Descr *, \ + npy_intp sort_kind, PyArray_ArgSortFuncWithContext **out_argsort, \ + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *out_flags) + + .. versionadded:: 2.4 + + If defined, sets a custom argsorting function for the DType for each of + the sort kinds numpy implements. Currently, sorts are always descending + and always use nulls to the end, and this must be checked in the + implementation. Returns 0 on success. + +.. c:macro:: NPY_DT_sort_compare + + .. versionadded:: 2.4 + + If defined, sets a custom comparison function for the DType for use in + sorting, which will replace `NPY_DT_PyArray_ArrFuncs_compare`. Implements + ``PyArray_CompareFunc``. + PyArray_ArrFuncs slots ^^^^^^^^^^^^^^^^^^^^^^ @@ -3547,6 +3611,8 @@ DType API slots but for now we have exposed the legacy .. c:macro:: NPY_DT_PyArray_ArrFuncs_compare Computes a comparison for `numpy.sort`, implements ``PyArray_CompareFunc``. + If `NPY_DT_sort_compare` is defined, it will be used instead. This slot may + be deprecated in the future. .. c:macro:: NPY_DT_PyArray_ArrFuncs_argmax @@ -3590,13 +3656,17 @@ DType API slots but for now we have exposed the legacy An array of PyArray_SortFunc of length ``NPY_NSORTS``. If set, allows defining custom sorting implementations for each of the sorting - algorithms numpy implements. + algorithms numpy implements. If `NPY_DT_get_sort_function` is + defined, it will be used instead. This slot may be deprecated in the + future. .. c:macro:: NPY_DT_PyArray_ArrFuncs_argsort An array of PyArray_ArgSortFunc of length ``NPY_NSORTS``. If set, allows defining custom argsorting implementations for each of the - sorting algorithms numpy implements. + sorting algorithms numpy implements. If `NPY_DT_get_argsort_function` + is defined, it will be used instead. This slot may be deprecated in + the future. Macros and Static Inline Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -4340,6 +4410,40 @@ Enumerated Types :c:data:`NPY_STABLESORT` are aliased to each other and may refer to one of several stable sorting algorithms depending on the data type. +.. c:enum:: NPY_SORT_NAN_POSITION + + An enum used to indicate the position of NaN values in sorting. + + .. c:enumerator:: NPY_SORT_NAN_TO_START + + Indicates that NaN values should be sorted to the start. + + .. c:enumerator:: NPY_SORT_NAN_TO_END + + Indicates that NaN values should be sorted to the end. + +.. c:enum:: NPY_COMPARE_RESULT + + An enum used to indicate the result of a comparison operation. + The unordered comparisons are used to indicate that the + comparison is not well-defined for one or both of the operands, + such as when comparing NaN values. + + .. c:enumerator:: NPY_LESS + + .. c:enumerator:: NPY_EQUAL + + .. c:enumerator:: NPY_GREATER + + .. c:enumerator:: NPY_UNORDERED_LEFT + + .. c:enumerator:: NPY_UNORDERED_RIGHT + + .. c:enumerator:: NPY_UNORDERED_BOTH + + .. c:enumerator:: NPY_COMPARE_ERROR + + Indicates that an error occurred during the comparison operation. .. c:enum:: NPY_SCALARKIND diff --git a/doc/source/reference/c-api/types-and-structures.rst b/doc/source/reference/c-api/types-and-structures.rst index 3f16b5f4dbc4..1b24fe854943 100644 --- a/doc/source/reference/c-api/types-and-structures.rst +++ b/doc/source/reference/c-api/types-and-structures.rst @@ -792,6 +792,41 @@ PyArrayMethod_Context and PyArrayMethod_Spec An array of slots for the method. Slot IDs must be one of the values below. +.. _arraymethod-sort-context: + +PyArrayMethod_SortContext +------------------------- + +.. c:type:: PyArrayMethod_SortContext + + A struct used to provide context for sorting methods. + + .. code-block:: c + + typedef struct { + PyArray_Descr *descriptor; + PyArray_SortCompareFunc *compare; + npy_bool descending; + NPY_SORT_NAN_POSITION nan_position; + } PyArrayMethod_SortContext + + .. c:member:: PyArray_Descr *descriptor + + The descriptor for the data type being sorted. + + .. c:member:: PyArray_SortCompareFunc *compare + + A pointer to the comparison function used for sorting. This function + can be NULL if the sort is not based on a comparison function. + + .. c:member:: npy_bool descending + + A flag indicating whether the sort is descending. + + .. c:member:: NPY_SORT_NAN_POSITION nan_position + + The position of NaN values in the sort order. + .. _dtypemeta: PyArray_DTypeMeta and PyArrayDTypeMeta_Spec diff --git a/numpy/_core/include/numpy/dtype_api.h b/numpy/_core/include/numpy/dtype_api.h index b37c9fbb6821..d9ecc7321289 100644 --- a/numpy/_core/include/numpy/dtype_api.h +++ b/numpy/_core/include/numpy/dtype_api.h @@ -368,6 +368,12 @@ typedef int (PyArrayMethod_PromoterFunction)(PyObject *ufunc, #define NPY_DT_get_fill_zero_loop 10 #define NPY_DT_finalize_descr 11 +#if NPY_API_VERSION >= NPY_2_4_API_VERSION +#define NPY_DT_get_sort_function 12 +#define NPY_DT_get_argsort_function 13 +#define NPY_DT_sort_compare 14 +#endif + // These PyArray_ArrFunc slots will be deprecated and replaced eventually // getitem and setitem can be defined as a performance optimization; // by default the user dtypes call `legacy_getitem_using_DType` and @@ -477,4 +483,45 @@ typedef PyArray_Descr *(PyArrayDTypeMeta_FinalizeDescriptor)(PyArray_Descr *dtyp typedef int(PyArrayDTypeMeta_SetItem)(PyArray_Descr *, PyObject *, char *); typedef PyObject *(PyArrayDTypeMeta_GetItem)(PyArray_Descr *, char *); +typedef enum { + NPY_LESS = -1, + NPY_EQUAL = 0, + NPY_GREATER = 1, + NPY_UNORDERED_LEFT = 2, + NPY_UNORDERED_RIGHT = 3, + NPY_UNORDERED_BOTH = 4, + NPY_COMPARE_ERROR = 5, +} NPY_COMPARE_RESULT; + +typedef struct PyArrayMethod_SortContext_tag PyArrayMethod_SortContext; + +typedef NPY_COMPARE_RESULT (PyArray_SortCompareFunc)( + const void *a, const void *b, PyArray_Descr *descr); + +typedef enum { + NPY_SORT_NAN_TO_START = 0, + NPY_SORT_NAN_TO_END = 1, +} NPY_SORT_NAN_POSITION; + +struct PyArrayMethod_SortContext_tag { + PyArray_Descr *descriptor; + PyArray_SortCompareFunc *compare; + npy_bool descending; + NPY_SORT_NAN_POSITION nan_position; +}; + +typedef int (PyArray_SortFuncWithContext)(PyArrayMethod_SortContext *, + void *, npy_intp, + NpyAuxData *); +typedef int (PyArray_ArgSortFuncWithContext)(PyArrayMethod_SortContext *, + void *, npy_intp *, npy_intp, + NpyAuxData *); + +typedef int *(PyArrayDTypeMeta_GetSortFunction)(PyArray_Descr *, + npy_intp, PyArray_SortFuncWithContext **, NpyAuxData **, + NPY_ARRAYMETHOD_FLAGS *); +typedef int *(PyArrayDTypeMeta_GetArgSortFunction)(PyArray_Descr *, + npy_intp, PyArray_ArgSortFuncWithContext **, NpyAuxData **, + NPY_ARRAYMETHOD_FLAGS *); + #endif /* NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_ */ diff --git a/numpy/_core/include/numpy/numpyconfig.h b/numpy/_core/include/numpy/numpyconfig.h index 52d7e2b5d7d7..c110baf9d379 100644 --- a/numpy/_core/include/numpy/numpyconfig.h +++ b/numpy/_core/include/numpy/numpyconfig.h @@ -84,6 +84,7 @@ #define NPY_2_1_API_VERSION 0x00000013 #define NPY_2_2_API_VERSION 0x00000013 #define NPY_2_3_API_VERSION 0x00000014 +#define NPY_2_4_API_VERSION 0x00000015 /* diff --git a/numpy/_core/src/common/npy_sort.h.src b/numpy/_core/src/common/npy_sort.h.src index d6e4357225a8..d14f247c00ce 100644 --- a/numpy/_core/src/common/npy_sort.h.src +++ b/numpy/_core/src/common/npy_sort.h.src @@ -5,6 +5,7 @@ #include #include #include +#include #define NPY_ENOMEM 1 #define NPY_ECOMP 2 @@ -97,6 +98,22 @@ NPY_NO_EXPORT int atimsort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void * ***************************************************************************** */ +NPY_NO_EXPORT int npy_quicksort_with_context(PyArrayMethod_SortContext *context, void *vec, + npy_intp cnt, NpyAuxData *auxdata); +NPY_NO_EXPORT int npy_heapsort_with_context(PyArrayMethod_SortContext *context, void *vec, + npy_intp cnt, NpyAuxData *auxdata); +NPY_NO_EXPORT int npy_mergesort_with_context(PyArrayMethod_SortContext *context, void *vec, + npy_intp cnt, NpyAuxData *auxdata); +NPY_NO_EXPORT int npy_timsort_with_context(PyArrayMethod_SortContext *context, void *vec, + npy_intp cnt, NpyAuxData *auxdata); +NPY_NO_EXPORT int npy_aquicksort_with_context(PyArrayMethod_SortContext *context, void *vec, + npy_intp *ind, npy_intp cnt, NpyAuxData *auxdata); +NPY_NO_EXPORT int npy_aheapsort_with_context(PyArrayMethod_SortContext *context, void *vec, + npy_intp *ind, npy_intp cnt, NpyAuxData *auxdata); +NPY_NO_EXPORT int npy_amergesort_with_context(PyArrayMethod_SortContext *context, void *vec, + npy_intp *ind, npy_intp cnt, NpyAuxData *auxdata); +NPY_NO_EXPORT int npy_atimsort_with_context(PyArrayMethod_SortContext *context, void *vec, + npy_intp *ind, npy_intp cnt, NpyAuxData *auxdata); NPY_NO_EXPORT int npy_quicksort(void *vec, npy_intp cnt, void *arr); NPY_NO_EXPORT int npy_heapsort(void *vec, npy_intp cnt, void *arr); @@ -107,6 +124,15 @@ NPY_NO_EXPORT int npy_aheapsort(void *vec, npy_intp *ind, npy_intp cnt, void *ar NPY_NO_EXPORT int npy_amergesort(void *vec, npy_intp *ind, npy_intp cnt, void *arr); NPY_NO_EXPORT int npy_atimsort(void *vec, npy_intp *ind, npy_intp cnt, void *arr); +NPY_NO_EXPORT int npy_quicksort_impl(void *vec, npy_intp cnt, void *arr, PyArrayMethod_SortContext *context); +NPY_NO_EXPORT int npy_heapsort_impl(void *vec, npy_intp cnt, void *arr, PyArrayMethod_SortContext *context); +NPY_NO_EXPORT int npy_mergesort_impl(void *vec, npy_intp cnt, void *arr, PyArrayMethod_SortContext *context); +NPY_NO_EXPORT int npy_timsort_impl(void *vec, npy_intp cnt, void *arr, PyArrayMethod_SortContext *context); +NPY_NO_EXPORT int npy_aquicksort_impl(void *vec, npy_intp *ind, npy_intp cnt, void *arr, PyArrayMethod_SortContext *context); +NPY_NO_EXPORT int npy_aheapsort_impl(void *vec, npy_intp *ind, npy_intp cnt, void *arr, PyArrayMethod_SortContext *context); +NPY_NO_EXPORT int npy_amergesort_impl(void *vec, npy_intp *ind, npy_intp cnt, void *arr, PyArrayMethod_SortContext *context); +NPY_NO_EXPORT int npy_atimsort_impl(void *vec, npy_intp *ind, npy_intp cnt, void *arr, PyArrayMethod_SortContext *context); + #ifdef __cplusplus } #endif diff --git a/numpy/_core/src/multiarray/dtypemeta.c b/numpy/_core/src/multiarray/dtypemeta.c index 0b1b0fb39192..0f01ad8c8222 100644 --- a/numpy/_core/src/multiarray/dtypemeta.c +++ b/numpy/_core/src/multiarray/dtypemeta.c @@ -194,6 +194,9 @@ dtypemeta_initialize_struct_from_spec( NPY_DT_SLOTS(DType)->getitem = NULL; NPY_DT_SLOTS(DType)->get_clear_loop = NULL; NPY_DT_SLOTS(DType)->get_fill_zero_loop = NULL; + NPY_DT_SLOTS(DType)->get_sort_function = NULL; + NPY_DT_SLOTS(DType)->get_argsort_function = NULL; + NPY_DT_SLOTS(DType)->sort_compare = NULL; NPY_DT_SLOTS(DType)->finalize_descr = NULL; NPY_DT_SLOTS(DType)->f = default_funcs; @@ -1230,6 +1233,15 @@ dtypemeta_wrap_legacy_descriptor( dtype_class->flags |= NPY_DT_NUMERIC; } + if (dt_slots->sort_compare == NULL) { + if (!NPY_DT_is_legacy(dtype_class) && !NPY_DT_is_user_defined(dtype_class)) { + PyErr_SetString(PyExc_RuntimeError, + "DType has no sort_compare function."); + Py_DECREF(dtype_class); + return -1; + } + } + if (_PyArray_MapPyTypeToDType(dtype_class, descr->typeobj, PyTypeNum_ISUSERDEF(dtype_class->type_num)) < 0) { Py_DECREF(dtype_class); diff --git a/numpy/_core/src/multiarray/dtypemeta.h b/numpy/_core/src/multiarray/dtypemeta.h index a8b78e3f7518..ecc87bb92124 100644 --- a/numpy/_core/src/multiarray/dtypemeta.h +++ b/numpy/_core/src/multiarray/dtypemeta.h @@ -67,6 +67,12 @@ typedef struct { * parameters, if any, as the operand dtype. */ PyArrayDTypeMeta_FinalizeDescriptor *finalize_descr; + + /* DType sorting methods. */ + PyArrayDTypeMeta_GetSortFunction *get_sort_function; + PyArrayDTypeMeta_GetArgSortFunction *get_argsort_function; + PyArray_SortCompareFunc *sort_compare; + /* * The casting implementation (ArrayMethod) to convert between two * instances of this DType, stored explicitly for fast access: @@ -89,7 +95,11 @@ typedef struct { // This must be updated if new slots before within_dtype_castingimpl // are added +#if NPY_API_VERSION >= NPY_2_4_API_VERSION +#define NPY_NUM_DTYPE_SLOTS 14 +#else #define NPY_NUM_DTYPE_SLOTS 11 +#endif #define NPY_NUM_DTYPE_PYARRAY_ARRFUNCS_SLOTS 22 #define NPY_DT_MAX_ARRFUNCS_SLOT \ NPY_NUM_DTYPE_PYARRAY_ARRFUNCS_SLOTS + _NPY_DT_ARRFUNCS_OFFSET @@ -291,6 +301,43 @@ PyArray_SETITEM(PyArrayObject *arr, char *itemptr, PyObject *v) Py_XSETREF(descr, _new_); \ } while(0) +static inline int +PyArray_GetSortFunction(PyArray_Descr *descr, + NPY_SORTKIND which, PyArray_SortFuncWithContext **out_sort, + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *out_flags) +{ + if (NPY_DT_SLOTS(NPY_DTYPE(descr))->get_sort_function == NULL) { + return -1; + } + + if (NPY_DT_SLOTS(NPY_DTYPE(descr))->get_sort_function( + descr, which, out_sort, out_auxdata, out_flags) == NULL) { + return -1; + } + return 0; +} + +static inline int +PyArray_GetArgSortFunction(PyArray_Descr *descr, + NPY_SORTKIND which, PyArray_ArgSortFuncWithContext **out_argsort, + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *out_flags) +{ + if (NPY_DT_SLOTS(NPY_DTYPE(descr))->get_argsort_function == NULL) { + return -1; + } + + if (NPY_DT_SLOTS(NPY_DTYPE(descr))->get_argsort_function( + descr, which, out_argsort, out_auxdata, out_flags) == NULL) { + return -1; + } + return 0; +} + +static inline PyArray_SortCompareFunc * +PyArray_GetSortCompareFunction(PyArray_Descr *descr) +{ + return NPY_DT_SLOTS(NPY_DTYPE(descr))->sort_compare; +} // Get the pointer to the PyArray_DTypeMeta for the type associated with the typenum. static inline PyArray_DTypeMeta * diff --git a/numpy/_core/src/multiarray/item_selection.c b/numpy/_core/src/multiarray/item_selection.c index d2db10633810..71539ec82bb6 100644 --- a/numpy/_core/src/multiarray/item_selection.c +++ b/numpy/_core/src/multiarray/item_selection.c @@ -1191,7 +1191,8 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, * over all but the desired sorting axis. */ static int -_new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort, +_new_sortlike(PyArrayObject *op, int axis, PyArray_SortFuncWithContext *sort, + PyArray_SortFunc *sort_with_array, NpyAuxData *auxdata, PyArray_PartitionFunc *part, npy_intp const *kth, npy_intp nkth) { npy_intp N = PyArray_DIM(op, axis); @@ -1215,6 +1216,12 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort, NPY_cast_info to_cast_info = {.func = NULL}; NPY_cast_info from_cast_info = {.func = NULL}; + PyArrayMethod_SortContext context = { + .descriptor = descr, + .descending = NPY_FALSE, + .nan_position = NPY_SORT_NAN_TO_END, + }; + NPY_BEGIN_THREADS_DEF; /* Check if there is any sorting to do */ @@ -1293,7 +1300,14 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort, */ if (part == NULL) { - ret = sort(bufptr, N, op); + if (sort != NULL) { + context.compare = PyArray_GetSortCompareFunction(descr); + + ret = sort(&context, bufptr, N, auxdata); + } + else { + ret = sort_with_array(bufptr, N, op); + } if (needs_api && PyErr_Occurred()) { ret = -1; } @@ -1358,8 +1372,9 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort, } static PyObject* -_new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort, - PyArray_ArgPartitionFunc *argpart, +_new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFuncWithContext *argsort, + PyArray_ArgSortFunc *argsort_with_array, + NpyAuxData *auxdata, PyArray_ArgPartitionFunc *argpart, npy_intp const *kth, npy_intp nkth) { npy_intp N = PyArray_DIM(op, axis); @@ -1388,6 +1403,12 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort, NPY_ARRAYMETHOD_FLAGS transfer_flags; NPY_cast_info cast_info = {.func = NULL}; + PyArrayMethod_SortContext context = { + .descriptor = descr, + .descending = NPY_FALSE, + .nan_position = NPY_SORT_NAN_TO_END, + }; + NPY_BEGIN_THREADS_DEF; PyObject *mem_handler = PyDataMem_GetHandler(); @@ -1483,8 +1504,15 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort, } if (argpart == NULL) { - ret = argsort(valptr, idxptr, N, op); - /* Object comparisons may raise an exception */ + if (argsort != NULL) { + context.compare = PyArray_GetSortCompareFunction(descr); + + ret = argsort(&context, valptr, idxptr, N, auxdata); + } + else { + ret = argsort_with_array(valptr, idxptr, N, op); + } + /* Object comparisons may raise an exception in Python 3 */ if (needs_api && PyErr_Occurred()) { ret = -1; } @@ -1554,7 +1582,12 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort, NPY_NO_EXPORT int PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which) { - PyArray_SortFunc *sort = NULL; + PyArray_SortFuncWithContext *sort = NULL; + PyArray_SortFunc *sort_with_array = NULL; + + NpyAuxData *auxdata = NULL; + NPY_ARRAYMETHOD_FLAGS flags = 0; + int n = PyArray_NDIM(op); if (check_and_adjust_axis(&axis, n) < 0) { @@ -1570,20 +1603,39 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which) return -1; } - sort = PyDataType_GetArrFuncs(PyArray_DESCR(op))->sort[which]; + if (PyArray_GetSortFunction(PyArray_DESCR(op), which, &sort, &auxdata, &flags) < 0) { + sort_with_array = PyDataType_GetArrFuncs(PyArray_DESCR(op))->sort[which]; + } + + if (sort_with_array == NULL) { + if (PyArray_GetSortCompareFunction(PyArray_DESCR(op)) != NULL) { + switch (which) { + default: + case NPY_QUICKSORT: + sort = npy_quicksort_with_context; + break; + case NPY_HEAPSORT: + sort = npy_heapsort_with_context; + break; + case NPY_STABLESORT: + sort = npy_timsort_with_context; + break; + } + } + } if (sort == NULL) { - if (PyDataType_GetArrFuncs(PyArray_DESCR(op))->compare) { + if (PyDataType_GetArrFuncs(PyArray_DESCR(op))->compare != NULL) { switch (which) { default: case NPY_QUICKSORT: - sort = npy_quicksort; + sort_with_array = npy_quicksort; break; case NPY_HEAPSORT: - sort = npy_heapsort; + sort_with_array = npy_heapsort; break; case NPY_STABLESORT: - sort = npy_timsort; + sort_with_array = npy_timsort; break; } } @@ -1594,7 +1646,7 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which) } } - return _new_sortlike(op, axis, sort, NULL, NULL, 0); + return _new_sortlike(op, axis, sort, sort_with_array, auxdata, NULL, NULL, 0); } @@ -1702,7 +1754,7 @@ PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis, return -1; } - ret = _new_sortlike(op, axis, sort, part, + ret = _new_sortlike(op, axis, NULL, sort, NULL, part, PyArray_DATA(kthrvl), PyArray_SIZE(kthrvl)); Py_DECREF(kthrvl); @@ -1718,23 +1770,46 @@ NPY_NO_EXPORT PyObject * PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which) { PyArrayObject *op2; - PyArray_ArgSortFunc *argsort = NULL; + PyArray_ArgSortFuncWithContext *argsort = NULL; + PyArray_ArgSortFunc *argsort_with_array = NULL; PyObject *ret; - argsort = PyDataType_GetArrFuncs(PyArray_DESCR(op))->argsort[which]; + NpyAuxData *auxdata = NULL; + NPY_ARRAYMETHOD_FLAGS flags = 0; + + if (PyArray_GetArgSortFunction(PyArray_DESCR(op), which, &argsort, &auxdata, &flags) < 0) { + argsort_with_array = PyDataType_GetArrFuncs(PyArray_DESCR(op))->argsort[which]; + } + + if (argsort_with_array == NULL) { + if (PyArray_GetSortCompareFunction(PyArray_DESCR(op)) != NULL) { + switch (which) { + default: + case NPY_QUICKSORT: + argsort = npy_aquicksort_with_context; + break; + case NPY_HEAPSORT: + argsort = npy_aheapsort_with_context; + break; + case NPY_STABLESORT: + argsort = npy_atimsort_with_context; + break; + } + } + } if (argsort == NULL) { - if (PyDataType_GetArrFuncs(PyArray_DESCR(op))->compare) { + if (PyDataType_GetArrFuncs(PyArray_DESCR(op))->compare != NULL) { switch (which) { default: case NPY_QUICKSORT: - argsort = npy_aquicksort; + argsort_with_array = npy_aquicksort; break; case NPY_HEAPSORT: - argsort = npy_aheapsort; + argsort_with_array = npy_aheapsort; break; case NPY_STABLESORT: - argsort = npy_atimsort; + argsort_with_array = npy_atimsort; break; } } @@ -1750,7 +1825,7 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which) return NULL; } - ret = _new_argsortlike(op2, axis, argsort, NULL, NULL, 0); + ret = _new_argsortlike(op2, axis, argsort, argsort_with_array, auxdata, NULL, NULL, 0); Py_DECREF(op2); return ret; @@ -1804,7 +1879,7 @@ PyArray_ArgPartition(PyArrayObject *op, PyArrayObject *ktharray, int axis, return NULL; } - ret = _new_argsortlike(op2, axis, argsort, argpart, + ret = _new_argsortlike(op2, axis, NULL, argsort, NULL, argpart, PyArray_DATA(kthrvl), PyArray_SIZE(kthrvl)); Py_DECREF(kthrvl); diff --git a/numpy/_core/src/multiarray/stringdtype/dtype.c b/numpy/_core/src/multiarray/stringdtype/dtype.c index a06e7a1ed1b6..cc57df4cbd97 100644 --- a/numpy/_core/src/multiarray/stringdtype/dtype.c +++ b/numpy/_core/src/multiarray/stringdtype/dtype.c @@ -9,6 +9,7 @@ #include "numpy/arrayobject.h" #include "numpy/ndarraytypes.h" #include "numpy/npy_math.h" +#include "npy_sort.h" #include "static_string.h" #include "dtypemeta.h" @@ -516,6 +517,133 @@ _compare(void *a, void *b, PyArray_StringDTypeObject *descr_a, return NpyString_cmp(&s_a, &s_b); } +#if NPY_API_VERSION >= NPY_2_4_API_VERSION +static NPY_COMPARE_RESULT +stringdtype_sort_compare(void *a, void *b, PyArray_Descr *descr) { + PyArray_StringDTypeObject *string_descr = (PyArray_StringDTypeObject *)descr; + int dist = _compare(a, b, string_descr, string_descr); + + if (dist < 0) { + return NPY_LESS; + } + else if (dist > 0) { + return NPY_GREATER; + } + else { + return NPY_EQUAL; + } +} + +int +_stringdtype_sort(PyArrayMethod_SortContext *context, void *start, npy_intp num, + NpyAuxData *auxdata, PyArray_SortFuncWithContext *sort) { + PyArray_StringDTypeObject *descr = (PyArray_StringDTypeObject *)context->descriptor; + + NpyString_acquire_allocator(descr); + int result = sort(context, start, num, auxdata); + NpyString_release_allocator(descr->allocator); + + return result; +} + +int +_stringdtype_quicksort(PyArrayMethod_SortContext *context, void *start, npy_intp num, + NpyAuxData *auxdata) { + return _stringdtype_sort(context, start, num, auxdata, + &npy_quicksort_with_context); +} + +int +_stringdtype_heapsort(PyArrayMethod_SortContext *context, void *start, npy_intp num, + NpyAuxData *auxdata) { + return _stringdtype_sort(context, start, num, auxdata, + &npy_heapsort_with_context); +} + +int +_stringdtype_timsort(PyArrayMethod_SortContext *context, void *start, npy_intp num, + NpyAuxData *auxdata) { + return _stringdtype_sort(context, start, num, auxdata, + &npy_timsort_with_context); +} + +int +stringdtype_get_sort_function(PyArray_Descr *descr, + NPY_SORTKIND sort_kind, PyArray_SortFuncWithContext **out_sort, + NpyAuxData **NPY_UNUSED(out_auxdata), NPY_ARRAYMETHOD_FLAGS *out_flags) { + + switch (sort_kind) { + default: + case NPY_QUICKSORT: + *out_sort = &_stringdtype_quicksort; + break; + case NPY_HEAPSORT: + *out_sort = &_stringdtype_heapsort; + break; + case NPY_STABLESORT: + *out_sort = &_stringdtype_timsort; + break; + } + *out_flags = NPY_METH_REQUIRES_PYAPI; + return 0; +} + +int +_stringdtype_argsort(PyArrayMethod_SortContext *context, void *vv, npy_intp *tosort, + npy_intp num, NpyAuxData *auxdata, PyArray_ArgSortFuncWithContext *argsort) { + PyArray_StringDTypeObject *descr = (PyArray_StringDTypeObject *)context->descriptor; + + NpyString_acquire_allocator(descr); + int result = argsort(context, vv, tosort, num, auxdata); + NpyString_release_allocator(descr->allocator); + + return result; +} + +int +_stringdtype_aquicksort(PyArrayMethod_SortContext *context, void *vv, npy_intp *tosort, + npy_intp n, NpyAuxData *auxdata) { + return _stringdtype_argsort(context, vv, tosort, n, auxdata, + &npy_aquicksort_with_context); +} + +int +_stringdtype_aheapsort(PyArrayMethod_SortContext *context, void *vv, npy_intp *tosort, + npy_intp n, NpyAuxData *auxdata) { + return _stringdtype_argsort(context, vv, tosort, n, auxdata, + &npy_aheapsort_with_context); +} + +int +_stringdtype_atimsort(PyArrayMethod_SortContext *context, void *vv, npy_intp *tosort, + npy_intp n, NpyAuxData *auxdata) { + return _stringdtype_argsort(context, vv, tosort, n, auxdata, + &npy_atimsort_with_context); +} + +int +stringdtype_get_argsort_function(PyArray_Descr *descr, + NPY_SORTKIND sort_kind, PyArray_ArgSortFuncWithContext **out_argsort, + NpyAuxData **NPY_UNUSED(out_auxdata), NPY_ARRAYMETHOD_FLAGS *out_flags) { + + switch (sort_kind) { + default: + case NPY_QUICKSORT: + *out_argsort = &_stringdtype_aquicksort; + break; + case NPY_HEAPSORT: + *out_argsort = &_stringdtype_aheapsort; + break; + case NPY_STABLESORT: + *out_argsort = &_stringdtype_atimsort; + break; + } + *out_flags = NPY_METH_REQUIRES_PYAPI; + + return 0; +} +#endif // NPY_API_VERSION >= NPY_2_4_API_VERSION + // PyArray_ArgFunc // The max element is the one with the highest unicode code point. int @@ -656,6 +784,11 @@ static PyType_Slot PyArray_StringDType_Slots[] = { &string_discover_descriptor_from_pyobject}, {NPY_DT_setitem, &stringdtype_setitem}, {NPY_DT_getitem, &stringdtype_getitem}, +#if NPY_API_VERSION >= NPY_2_4_API_VERSION + {NPY_DT_sort_compare, &stringdtype_sort_compare}, + {NPY_DT_get_sort_function, &stringdtype_get_sort_function}, + {NPY_DT_get_argsort_function, &stringdtype_get_argsort_function}, +#endif {NPY_DT_ensure_canonical, &stringdtype_ensure_canonical}, {NPY_DT_PyArray_ArrFuncs_nonzero, &nonzero}, {NPY_DT_PyArray_ArrFuncs_compare, &compare}, diff --git a/numpy/_core/src/npysort/heapsort.cpp b/numpy/_core/src/npysort/heapsort.cpp index 492cd47262d8..75edd4ae3b0a 100644 --- a/numpy/_core/src/npysort/heapsort.cpp +++ b/numpy/_core/src/npysort/heapsort.cpp @@ -49,12 +49,40 @@ ***************************************************************************** */ +NPY_NO_EXPORT int +npy_heapsort_with_context(PyArrayMethod_SortContext *context, void *start, npy_intp num, + NpyAuxData *auxdata) +{ + return npy_heapsort_impl(start, num, NULL, context); +} + +NPY_NO_EXPORT int +npy_aheapsort_with_context(PyArrayMethod_SortContext *context, void *vv, npy_intp *tosort, + npy_intp num, NpyAuxData *auxdata) +{ + return npy_aheapsort_impl(vv, tosort, num, NULL, context); +} + NPY_NO_EXPORT int npy_heapsort(void *start, npy_intp num, void *varr) { - PyArrayObject *arr = (PyArrayObject *)varr; - npy_intp elsize = PyArray_ITEMSIZE(arr); - PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + return npy_heapsort_impl(start, num, varr, NULL); +} + +NPY_NO_EXPORT int +npy_aheapsort(void *vv, npy_intp *tosort, npy_intp n, void *varr) +{ + return npy_aheapsort_impl(vv, tosort, n, varr, NULL); +} + +NPY_NO_EXPORT int +npy_heapsort_impl(void *start, npy_intp num, void *varr, PyArrayMethod_SortContext *context) +{ + void *arr; + npy_intp elsize; + PyArray_CompareFunc *cmp; + fill_sort_data_from_arr_or_context(varr, context, &arr, &elsize, &cmp); + if (elsize == 0) { return 0; /* no need for sorting elements of no size */ } @@ -111,12 +139,14 @@ npy_heapsort(void *start, npy_intp num, void *varr) } NPY_NO_EXPORT int -npy_aheapsort(void *vv, npy_intp *tosort, npy_intp n, void *varr) +npy_aheapsort_impl(void *vv, npy_intp *tosort, npy_intp n, void *varr, PyArrayMethod_SortContext *context) { char *v = (char *)vv; - PyArrayObject *arr = (PyArrayObject *)varr; - npy_intp elsize = PyArray_ITEMSIZE(arr); - PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + void *arr; + npy_intp elsize; + PyArray_CompareFunc *cmp; + fill_sort_data_from_arr_or_context(varr, context, &arr, &elsize, &cmp); + npy_intp *a, i, j, l, tmp; /* The array needs to be offset by one for heapsort indexing */ diff --git a/numpy/_core/src/npysort/mergesort.cpp b/numpy/_core/src/npysort/mergesort.cpp index 2fac0ccfafcd..bbda4abd586d 100644 --- a/numpy/_core/src/npysort/mergesort.cpp +++ b/numpy/_core/src/npysort/mergesort.cpp @@ -335,9 +335,35 @@ string_amergesort_(type *v, npy_intp *tosort, npy_intp num, void *varr) ***************************************************************************** */ +NPY_NO_EXPORT int +npy_mergesort_with_context(PyArrayMethod_SortContext *context, void *start, npy_intp num, + NpyAuxData *auxdata) +{ + return npy_mergesort_impl(start, num, NULL, context); +} + +NPY_NO_EXPORT int +npy_amergesort_with_context(PyArrayMethod_SortContext *context, void *vv, npy_intp *tosort, + npy_intp num, NpyAuxData *auxdata) +{ + return npy_amergesort_impl(vv, tosort, num, NULL, context); +} + +NPY_NO_EXPORT int +npy_mergesort(void *start, npy_intp num, void *varr) +{ + return npy_mergesort_impl(start, num, varr, NULL); +} + +NPY_NO_EXPORT int +npy_amergesort(void *vv, npy_intp *tosort, npy_intp num, void *varr) +{ + return npy_amergesort_impl(vv, tosort, num, varr, NULL); +} + static void npy_mergesort0(char *pl, char *pr, char *pw, char *vp, npy_intp elsize, - PyArray_CompareFunc *cmp, PyArrayObject *arr) + PyArray_CompareFunc *cmp, void *arr) { char *pi, *pj, *pk, *pm; @@ -381,11 +407,12 @@ npy_mergesort0(char *pl, char *pr, char *pw, char *vp, npy_intp elsize, } NPY_NO_EXPORT int -npy_mergesort(void *start, npy_intp num, void *varr) +npy_mergesort_impl(void *start, npy_intp num, void *varr, PyArrayMethod_SortContext *context) { - PyArrayObject *arr = (PyArrayObject *)varr; - npy_intp elsize = PyArray_ITEMSIZE(arr); - PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + void *arr; + npy_intp elsize; + PyArray_CompareFunc *cmp; + fill_sort_data_from_arr_or_context(varr, context, &arr, &elsize, &cmp); char *pl = (char *)start; char *pr = pl + num * elsize; char *pw; @@ -413,7 +440,7 @@ npy_mergesort(void *start, npy_intp num, void *varr) static void npy_amergesort0(npy_intp *pl, npy_intp *pr, char *v, npy_intp *pw, - npy_intp elsize, PyArray_CompareFunc *cmp, PyArrayObject *arr) + npy_intp elsize, PyArray_CompareFunc *cmp, void *arr) { char *vp; npy_intp vi, *pi, *pj, *pk, *pm; @@ -457,11 +484,12 @@ npy_amergesort0(npy_intp *pl, npy_intp *pr, char *v, npy_intp *pw, } NPY_NO_EXPORT int -npy_amergesort(void *v, npy_intp *tosort, npy_intp num, void *varr) +npy_amergesort_impl(void *v, npy_intp *tosort, npy_intp num, void *varr, PyArrayMethod_SortContext *context) { - PyArrayObject *arr = (PyArrayObject *)varr; - npy_intp elsize = PyArray_ITEMSIZE(arr); - PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + void *arr; + npy_intp elsize; + PyArray_CompareFunc *cmp; + fill_sort_data_from_arr_or_context(varr, context, &arr, &elsize, &cmp); npy_intp *pl, *pr, *pw; /* Items that have zero size don't make sense to sort */ diff --git a/numpy/_core/src/npysort/npysort_common.h b/numpy/_core/src/npysort/npysort_common.h index 0680ae52afe3..16e889e7d310 100644 --- a/numpy/_core/src/npysort/npysort_common.h +++ b/numpy/_core/src/npysort/npysort_common.h @@ -10,6 +10,98 @@ extern "C" { #endif + +/* + ***************************************************************************** + ** NEW SORTFUNC HANDLERS ** + ***************************************************************************** + */ + +static inline int +compare_result_to_int(NPY_COMPARE_RESULT result, NPY_SORT_NAN_POSITION nan_position) +{ + if (result == NPY_LESS) { + return -1; + } + else if (result == NPY_GREATER) { + return 1; + } + else if (result == NPY_EQUAL) { + return 0; + } + else { + if (nan_position == NPY_SORT_NAN_TO_END) { + if (result == NPY_UNORDERED_LEFT) { + return -1; + } + else if (result == NPY_UNORDERED_RIGHT) { + return 1; + } + else if (result == NPY_UNORDERED_BOTH) { + return 0; + } + } + else if (nan_position == NPY_SORT_NAN_TO_START) { + if (result == NPY_UNORDERED_LEFT) { + return 1; + } + else if (result == NPY_UNORDERED_RIGHT) { + return -1; + } + else if (result == NPY_UNORDERED_BOTH) { + return 0; + } + } + } + + /* This should never happen, but just in case */ + PyErr_SetString(PyExc_RuntimeError, "Unexpected comparison result in sort function"); + return NPY_MIN_INT; /* Indicate an error */ +} + +static inline int +compare_from_context(const void *a, const void *b, void *context) +{ + PyArrayMethod_SortContext *sort_context = (PyArrayMethod_SortContext *)context; + PyArray_SortCompareFunc *cmp = sort_context->compare; + + int descending = sort_context->descending; + NPY_SORT_NAN_POSITION nan_position = sort_context->nan_position; + + NPY_COMPARE_RESULT result = cmp(a, b, sort_context->descriptor); + + if (result == NPY_COMPARE_ERROR) { + PyErr_SetString(PyExc_RuntimeError, "Unexpected comparison result in sort function"); + return NPY_MIN_INT; /* Indicate an error */ + } + + int cmp_result = compare_result_to_int(result, nan_position); + + if (descending) { + cmp_result = -cmp_result; + } + + return cmp_result; +} + +static inline void +fill_sort_data_from_arr_or_context(void *array, PyArrayMethod_SortContext *context, + void **out_arr_or_context, npy_intp *elsize, + PyArray_CompareFunc **out_cmp) +{ + if (context != NULL) { + *out_arr_or_context = (void *)context; + *elsize = PyDataType_ELSIZE(context->descriptor); + *out_cmp = &compare_from_context; + } + else { + PyArrayObject *arr = (PyArrayObject *)array; + *out_arr_or_context = (void *)arr; + *elsize = PyArray_ITEMSIZE(arr); + *out_cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + } +} + /* ***************************************************************************** ** SWAP MACROS ** diff --git a/numpy/_core/src/npysort/quicksort.cpp b/numpy/_core/src/npysort/quicksort.cpp index ddf4fce0c28b..259b36583189 100644 --- a/numpy/_core/src/npysort/quicksort.cpp +++ b/numpy/_core/src/npysort/quicksort.cpp @@ -505,12 +505,40 @@ string_aquicksort_(type *vv, npy_intp *tosort, npy_intp num, void *varr) ***************************************************************************** */ +NPY_NO_EXPORT int +npy_quicksort_with_context(PyArrayMethod_SortContext *context, void *start, npy_intp num, + NpyAuxData *auxdata) +{ + return npy_quicksort_impl(start, num, NULL, context); +} + +NPY_NO_EXPORT int +npy_aquicksort_with_context(PyArrayMethod_SortContext *context, void *vv, npy_intp *tosort, + npy_intp num, NpyAuxData *auxdata) +{ + return npy_aquicksort_impl(vv, tosort, num, NULL, context); +} + NPY_NO_EXPORT int npy_quicksort(void *start, npy_intp num, void *varr) { - PyArrayObject *arr = (PyArrayObject *)varr; - npy_intp elsize = PyArray_ITEMSIZE(arr); - PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + return npy_quicksort_impl(start, num, varr, NULL); +} + +NPY_NO_EXPORT int +npy_aquicksort(void *vv, npy_intp *tosort, npy_intp num, void *varr) +{ + return npy_aquicksort_impl(vv, tosort, num, varr, NULL); +} + +NPY_NO_EXPORT int +npy_quicksort_impl(void *start, npy_intp num, void *varr, PyArrayMethod_SortContext *context) +{ + void *arr; + npy_intp elsize; + PyArray_CompareFunc *cmp; + fill_sort_data_from_arr_or_context(varr, context, &arr, &elsize, &cmp); + char *vp; char *pl = (char *)start; char *pr = pl + (num - 1) * elsize; @@ -606,16 +634,19 @@ npy_quicksort(void *start, npy_intp num, void *varr) } free(vp); + return 0; } NPY_NO_EXPORT int -npy_aquicksort(void *vv, npy_intp *tosort, npy_intp num, void *varr) +npy_aquicksort_impl(void *vv, npy_intp *tosort, npy_intp num, void *varr, PyArrayMethod_SortContext *context) { char *v = (char *)vv; - PyArrayObject *arr = (PyArrayObject *)varr; - npy_intp elsize = PyArray_ITEMSIZE(arr); - PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + void *arr; + npy_intp elsize; + PyArray_CompareFunc *cmp; + fill_sort_data_from_arr_or_context(varr, context, &arr, &elsize, &cmp); + char *vp; npy_intp *pl = tosort; npy_intp *pr = tosort + num - 1; diff --git a/numpy/_core/src/npysort/timsort.cpp b/numpy/_core/src/npysort/timsort.cpp index 0f0f5721e7cf..4b1da96e2146 100644 --- a/numpy/_core/src/npysort/timsort.cpp +++ b/numpy/_core/src/npysort/timsort.cpp @@ -1851,6 +1851,32 @@ string_atimsort_(void *start, npy_intp *tosort, npy_intp num, void *varr) ***************************************************************************** */ +NPY_NO_EXPORT int +npy_timsort_with_context(PyArrayMethod_SortContext *context, void *start, npy_intp num, + NpyAuxData *auxdata) +{ + return npy_timsort_impl(start, num, NULL, context); +} + +NPY_NO_EXPORT int +npy_atimsort_with_context(PyArrayMethod_SortContext *context, void *vv, npy_intp *tosort, + npy_intp num, NpyAuxData *auxdata) +{ + return npy_atimsort_impl(vv, tosort, num, NULL, context); +} + +NPY_NO_EXPORT int +npy_timsort(void *start, npy_intp num, void *varr) +{ + return npy_timsort_impl(start, num, varr, NULL); +} + +NPY_NO_EXPORT int +npy_atimsort(void *start, npy_intp *tosort, npy_intp num, void *varr) +{ + return npy_atimsort_impl(start, tosort, num, varr, NULL); +} + typedef struct { char *pw; npy_intp size; @@ -1878,7 +1904,7 @@ resize_buffer_char(buffer_char *buffer, npy_intp new_size) static npy_intp npy_count_run(char *arr, npy_intp l, npy_intp num, npy_intp minrun, char *vp, - size_t len, PyArray_CompareFunc *cmp, PyArrayObject *py_arr) + size_t len, PyArray_CompareFunc *cmp, void *py_arr) { npy_intp sz; char *pl, *pi, *pj, *pr; @@ -1939,7 +1965,7 @@ npy_count_run(char *arr, npy_intp l, npy_intp num, npy_intp minrun, char *vp, static npy_intp npy_gallop_right(const char *arr, const npy_intp size, const char *key, - size_t len, PyArray_CompareFunc *cmp, PyArrayObject *py_arr) + size_t len, PyArray_CompareFunc *cmp, void *py_arr) { npy_intp last_ofs, ofs, m; @@ -1984,7 +2010,7 @@ npy_gallop_right(const char *arr, const npy_intp size, const char *key, static npy_intp npy_gallop_left(const char *arr, const npy_intp size, const char *key, - size_t len, PyArray_CompareFunc *cmp, PyArrayObject *py_arr) + size_t len, PyArray_CompareFunc *cmp, void *py_arr) { npy_intp last_ofs, ofs, l, m, r; @@ -2031,7 +2057,7 @@ npy_gallop_left(const char *arr, const npy_intp size, const char *key, static void npy_merge_left(char *p1, npy_intp l1, char *p2, npy_intp l2, char *p3, - size_t len, PyArray_CompareFunc *cmp, PyArrayObject *py_arr) + size_t len, PyArray_CompareFunc *cmp, void *py_arr) { char *end = p2 + l2 * len; memcpy(p3, p1, sizeof(char) * l1 * len); @@ -2060,7 +2086,7 @@ npy_merge_left(char *p1, npy_intp l1, char *p2, npy_intp l2, char *p3, static void npy_merge_right(char *p1, npy_intp l1, char *p2, npy_intp l2, char *p3, - size_t len, PyArray_CompareFunc *cmp, PyArrayObject *py_arr) + size_t len, PyArray_CompareFunc *cmp, void *py_arr) { npy_intp ofs; char *start = p1 - len; @@ -2095,7 +2121,7 @@ npy_merge_right(char *p1, npy_intp l1, char *p2, npy_intp l2, char *p3, static int npy_merge_at(char *arr, const run *stack, const npy_intp at, buffer_char *buffer, size_t len, PyArray_CompareFunc *cmp, - PyArrayObject *py_arr) + void *py_arr) { int ret; npy_intp s1, l1, s2, l2, k; @@ -2145,7 +2171,7 @@ npy_merge_at(char *arr, const run *stack, const npy_intp at, static int npy_try_collapse(char *arr, run *stack, npy_intp *stack_ptr, buffer_char *buffer, size_t len, PyArray_CompareFunc *cmp, - PyArrayObject *py_arr) + void *py_arr) { int ret; npy_intp A, B, C, top; @@ -2205,7 +2231,7 @@ npy_try_collapse(char *arr, run *stack, npy_intp *stack_ptr, static int npy_force_collapse(char *arr, run *stack, npy_intp *stack_ptr, buffer_char *buffer, size_t len, PyArray_CompareFunc *cmp, - PyArrayObject *py_arr) + void *py_arr) { int ret; npy_intp top = *stack_ptr; @@ -2246,11 +2272,13 @@ npy_force_collapse(char *arr, run *stack, npy_intp *stack_ptr, } NPY_NO_EXPORT int -npy_timsort(void *start, npy_intp num, void *varr) +npy_timsort_impl(void *start, npy_intp num, void *varr, PyArrayMethod_SortContext *context) { - PyArrayObject *arr = reinterpret_cast(varr); - size_t len = PyArray_ITEMSIZE(arr); - PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + void *arr; + npy_intp len; + PyArray_CompareFunc *cmp; + fill_sort_data_from_arr_or_context(varr, context, &arr, &len, &cmp); + int ret; npy_intp l, n, stack_ptr, minrun; run stack[TIMSORT_STACK_SIZE]; @@ -2313,7 +2341,7 @@ npy_timsort(void *start, npy_intp num, void *varr) static npy_intp npy_acount_run(char *arr, npy_intp *tosort, npy_intp l, npy_intp num, npy_intp minrun, size_t len, PyArray_CompareFunc *cmp, - PyArrayObject *py_arr) + void *py_arr) { npy_intp sz; npy_intp vi; @@ -2379,7 +2407,7 @@ npy_acount_run(char *arr, npy_intp *tosort, npy_intp l, npy_intp num, static npy_intp npy_agallop_left(const char *arr, const npy_intp *tosort, const npy_intp size, const char *key, size_t len, PyArray_CompareFunc *cmp, - PyArrayObject *py_arr) + void *py_arr) { npy_intp last_ofs, ofs, l, m, r; @@ -2428,7 +2456,7 @@ npy_agallop_left(const char *arr, const npy_intp *tosort, const npy_intp size, static npy_intp npy_agallop_right(const char *arr, const npy_intp *tosort, const npy_intp size, const char *key, size_t len, PyArray_CompareFunc *cmp, - PyArrayObject *py_arr) + void *py_arr) { npy_intp last_ofs, ofs, m; @@ -2474,7 +2502,7 @@ npy_agallop_right(const char *arr, const npy_intp *tosort, const npy_intp size, static void npy_amerge_left(char *arr, npy_intp *p1, npy_intp l1, npy_intp *p2, npy_intp l2, npy_intp *p3, size_t len, - PyArray_CompareFunc *cmp, PyArrayObject *py_arr) + PyArray_CompareFunc *cmp, void *py_arr) { npy_intp *end = p2 + l2; memcpy(p3, p1, sizeof(npy_intp) * l1); @@ -2498,7 +2526,7 @@ npy_amerge_left(char *arr, npy_intp *p1, npy_intp l1, npy_intp *p2, static void npy_amerge_right(char *arr, npy_intp *p1, npy_intp l1, npy_intp *p2, npy_intp l2, npy_intp *p3, size_t len, - PyArray_CompareFunc *cmp, PyArrayObject *py_arr) + PyArray_CompareFunc *cmp, void *py_arr) { npy_intp ofs; npy_intp *start = p1 - 1; @@ -2527,7 +2555,7 @@ npy_amerge_right(char *arr, npy_intp *p1, npy_intp l1, npy_intp *p2, static int npy_amerge_at(char *arr, npy_intp *tosort, const run *stack, const npy_intp at, buffer_intp *buffer, size_t len, PyArray_CompareFunc *cmp, - PyArrayObject *py_arr) + void *py_arr) { int ret; npy_intp s1, l1, s2, l2, k; @@ -2577,7 +2605,7 @@ npy_amerge_at(char *arr, npy_intp *tosort, const run *stack, const npy_intp at, static int npy_atry_collapse(char *arr, npy_intp *tosort, run *stack, npy_intp *stack_ptr, buffer_intp *buffer, size_t len, PyArray_CompareFunc *cmp, - PyArrayObject *py_arr) + void *py_arr) { int ret; npy_intp A, B, C, top; @@ -2638,7 +2666,7 @@ npy_atry_collapse(char *arr, npy_intp *tosort, run *stack, npy_intp *stack_ptr, static int npy_aforce_collapse(char *arr, npy_intp *tosort, run *stack, npy_intp *stack_ptr, buffer_intp *buffer, size_t len, - PyArray_CompareFunc *cmp, PyArrayObject *py_arr) + PyArray_CompareFunc *cmp, void *py_arr) { int ret; npy_intp top = *stack_ptr; @@ -2682,11 +2710,13 @@ npy_aforce_collapse(char *arr, npy_intp *tosort, run *stack, } NPY_NO_EXPORT int -npy_atimsort(void *start, npy_intp *tosort, npy_intp num, void *varr) +npy_atimsort_impl(void *start, npy_intp *tosort, npy_intp num, void *varr, PyArrayMethod_SortContext *context) { - PyArrayObject *arr = reinterpret_cast(varr); - size_t len = PyArray_ITEMSIZE(arr); - PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare; + void *arr; + npy_intp len; + PyArray_CompareFunc *cmp; + fill_sort_data_from_arr_or_context(varr, context, &arr, &len, &cmp); + int ret; npy_intp l, n, stack_ptr, minrun; run stack[TIMSORT_STACK_SIZE]; pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy