94 lines
3.2 KiB
Cython
94 lines
3.2 KiB
Cython
from cython cimport floating
|
|
|
|
cdef inline void dual_swap(
    floating* darr,
    intp_t *iarr,
    intp_t a,
    intp_t b,
) noexcept nogil:
    """Exchange positions a and b in darr, and the same positions in iarr.

    Both arrays receive the identical permutation, so each entry of iarr keeps
    following the darr value it was paired with. a == b is a no-op.
    """
    darr[a], darr[b] = darr[b], darr[a]
    iarr[a], iarr[b] = iarr[b], iarr[a]
|
|
|
|
|
|
cdef int simultaneous_sort(
    floating* values,
    intp_t* indices,
    intp_t size,
) noexcept nogil:
    """
    Perform a quicksort on the values array to sort it in ascending order.

    This simultaneously performs the swaps on both the values and the indices
    arrays, so each index keeps following the value it was paired with.

    The numpy equivalent (up to the relative order of tied values) is:

        def simultaneous_sort(dist, idx):
            i = np.argsort(dist)
            return dist[i], idx[i]

    Notes
    -----
    Arrays are manipulated via a pointer to their first element and their size
    so as to ease the processing of dynamically allocated buffers.

    After each partition step only the smaller side is sorted recursively;
    the larger side is handled by the enclosing loop. This keeps the recursion
    depth in O(log(size)) even when partitioning is maximally unbalanced (for
    instance when most values are equal: values equal to the pivot all end up
    on its right). Recursing on both sides could instead nest O(size) calls
    and exhaust the native stack. Because the two sides are disjoint, the
    final ordering is the same as when recursing on both sides.

    Always returns 0.
    """
    # TODO: In order to support discrete distance metrics, we need to have a
    # simultaneous sort which breaks ties on indices when distances are identical.
    # The best might be using a std::stable_sort and a Comparator which might need
    # an Array of Structures (AoS) instead of the Structure of Arrays (SoA)
    # currently used.
    cdef:
        intp_t pivot_idx, i, store_idx, left_size, right_size
        floating pivot_val

    # Partition until what remains is small enough to be sorted directly.
    while size > 3:
        # Determine the pivot using the median-of-three rule.
        # The smallest of the three is moved to the beginning of the array,
        # the middle (the pivot value) is moved to the end, and the largest
        # is moved to the pivot index.
        pivot_idx = size // 2
        if values[0] > values[size - 1]:
            dual_swap(values, indices, 0, size - 1)
        if values[size - 1] > values[pivot_idx]:
            dual_swap(values, indices, size - 1, pivot_idx)
        if values[0] > values[size - 1]:
            dual_swap(values, indices, 0, size - 1)
        pivot_val = values[size - 1]

        # Partition indices about pivot. At the end of this operation,
        # pivot_idx will contain the pivot value, everything to the left
        # will be strictly smaller, and everything to the right will be
        # greater than or equal to it.
        store_idx = 0
        for i in range(size - 1):
            if values[i] < pivot_val:
                dual_swap(values, indices, i, store_idx)
                store_idx += 1
        dual_swap(values, indices, store_idx, size - 1)
        pivot_idx = store_idx

        left_size = pivot_idx
        right_size = size - pivot_idx - 1

        # Recurse on the smaller side, then continue the loop on the larger
        # one (instead of recursing on it) to bound the recursion depth.
        if left_size < right_size:
            if left_size > 1:
                simultaneous_sort(values, indices, left_size)
            values = values + pivot_idx + 1
            indices = indices + pivot_idx + 1
            size = right_size
        else:
            if right_size > 1:
                simultaneous_sort(values + pivot_idx + 1,
                                  indices + pivot_idx + 1,
                                  right_size)
            size = left_size

    # In the small-array case (size <= 3), do things efficiently.
    # Sizes 0 and 1 are already sorted.
    if size == 2:
        if values[0] > values[1]:
            dual_swap(values, indices, 0, 1)
    elif size == 3:
        if values[0] > values[1]:
            dual_swap(values, indices, 0, 1)
        if values[1] > values[2]:
            dual_swap(values, indices, 1, 2)
        if values[0] > values[1]:
            dual_swap(values, indices, 0, 1)
    return 0
|