Inzynierka/Lib/site-packages/pandas/_libs/intervaltree.pxi.in

435 lines
15 KiB
Cython
Raw Normal View History

2023-06-02 12:51:02 +02:00
"""
Template for intervaltree
WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
"""
from pandas._libs.algos import is_monotonic
# Fused scalar types for query points.
#
# The node template picks one of these for the ``point`` argument of
# ``query`` through its ``fused_prefix`` variable: ``int_scalar_t`` for int64
# nodes, ``uint_scalar_t`` for uint64 nodes and ``scalar_t`` for float64
# nodes.  ``IntervalTree.get_indexer`` and ``get_indexer_non_unique`` take
# ``scalar_t[:]`` targets.

# int64 or float64 query point
ctypedef fused int_scalar_t:
    int64_t
    float64_t

# uint64 or float64 query point
ctypedef fused uint_scalar_t:
    uint64_t
    float64_t

# any of the scalar types above
ctypedef fused scalar_t:
    int_scalar_t
    uint_scalar_t
# ----------------------------------------------------------------------
# IntervalTree
# ----------------------------------------------------------------------
cdef class IntervalTree(IntervalMixin):
    """A centered interval tree

    Based off the algorithm described on Wikipedia:
    https://en.wikipedia.org/wiki/Interval_tree

    The tree is built once from arrays of left and right bounds and then
    answers point queries through ``get_indexer`` and
    ``get_indexer_non_unique``.

    we are emulating the IndexEngine interface
    """
    cdef readonly:
        # bounds of the intervals kept in the tree (NaN rows removed), cast
        # to the common ``dtype``
        ndarray left, right
        # root node; its class is looked up in NODE_CLASSES by (dtype, closed)
        IntervalNode root
        object dtype
        str closed
        # lazily filled caches, compared against None before first use
        object _is_overlapping, _left_sorter, _right_sorter
        # number of intervals dropped because their left bound is NaN
        Py_ssize_t _na_count

    def __init__(self, left, right, closed='right', leaf_size=100):
        """
        Parameters
        ----------
        left, right : np.ndarray[ndim=1]
            Left and right bounds for each interval. Intervals whose left
            bound is NaN are not stored in the tree; they are only counted
            in ``_na_count``. The positions reported by queries still refer
            to the original, unfiltered arrays.
        closed : {'left', 'right', 'both', 'neither'}, optional
            Whether the intervals are closed on the left-side, right-side, both
            or neither. Defaults to 'right'.
        leaf_size : int, optional
            Parameter that controls when the tree switches from creating nodes
            to brute-force search. Tune this parameter to optimize query
            performance.

        Raises
        ------
        ValueError
            If ``closed`` is not one of the four accepted options.
        """
        if closed not in ('left', 'right', 'both', 'neither'):
            # f-string rather than %-formatting: "%s" % closed raises
            # TypeError for tuple arguments instead of this ValueError
            raise ValueError(f"invalid option for 'closed': {closed}")

        left = np.asarray(left)
        right = np.asarray(right)
        self.dtype = np.result_type(left, right)
        self.left = np.asarray(left, dtype=self.dtype)
        self.right = np.asarray(right, dtype=self.dtype)

        # original positions, so results survive the NaN filtering below
        indices = np.arange(len(left), dtype='int64')

        self.closed = closed

        # GH 23352: ensure no nan in nodes
        mask = ~np.isnan(self.left)
        self._na_count = len(mask) - mask.sum()
        self.left = self.left[mask]
        self.right = self.right[mask]
        indices = indices[mask]

        node_cls = NODE_CLASSES[str(self.dtype), closed]
        self.root = node_cls(self.left, self.right, indices, leaf_size)

    @property
    def left_sorter(self) -> np.ndarray:
        """How to sort the left labels; this is used for binary search

        Computed on first access with ``np.lexsort`` (left bound is the
        primary key, right bound breaks ties) and cached.
        """
        if self._left_sorter is None:
            # lexsort treats the LAST key as the primary one
            values = [self.right, self.left]
            self._left_sorter = np.lexsort(values)
        return self._left_sorter

    @property
    def right_sorter(self) -> np.ndarray:
        """How to sort the right labels

        Computed on first access with ``np.argsort`` and cached.
        """
        if self._right_sorter is None:
            self._right_sorter = np.argsort(self.right)
        return self._right_sorter

    @property
    def is_overlapping(self) -> bool:
        """
        Determine if the IntervalTree contains overlapping intervals.
        Cached as self._is_overlapping.
        """
        if self._is_overlapping is not None:
            return self._is_overlapping

        # <= when both sides closed since endpoints can overlap
        op = le if self.closed == 'both' else lt

        # overlap if start of current interval < end of previous interval
        # (current and previous in terms of sorted order by left/start side)
        current = self.left[self.left_sorter[1:]]
        previous = self.right[self.left_sorter[:-1]]
        self._is_overlapping = bool(op(current, previous).any())

        return self._is_overlapping

    @property
    def is_monotonic_increasing(self) -> bool:
        """
        Return True if the IntervalTree is monotonic increasing (only equal or
        increasing values), else False
        """
        # any dropped NaN interval makes the tree non-monotonic
        if self._na_count > 0:
            return False

        # the intervals are in order exactly when sorting them is a no-op,
        # i.e. when the sorter itself is increasing
        sort_order = self.left_sorter
        return is_monotonic(sort_order, False)[0]

    def get_indexer(self, scalar_t[:] target) -> np.ndarray:
        """Return the positions corresponding to unique intervals that overlap
        with the given array of scalar targets.

        Returns
        -------
        np.ndarray[intp]
            One entry per target: the position of the single matching
            interval, or -1 when no interval matches.

        Raises
        ------
        KeyError
            If a target matches more than one interval.
        """

        # TODO: write get_indexer_intervals
        cdef:
            Py_ssize_t old_len
            Py_ssize_t i
            Int64Vector result

        result = Int64Vector()
        old_len = 0
        for i in range(len(target)):
            try:
                self.root.query(result, target[i])
            except OverflowError:
                # overflow -> no match, which is already handled below
                pass

            # the growth of ``result`` tells how many intervals matched
            if result.data.n == old_len:
                result.append(-1)
            elif result.data.n > old_len + 1:
                raise KeyError(
                    'indexer does not intersect a unique set of intervals')
            old_len = result.data.n
        return result.to_array().astype('intp')

    def get_indexer_non_unique(self, scalar_t[:] target):
        """Return the positions corresponding to intervals that overlap with
        the given array of scalar targets. Non-unique positions are repeated.

        Returns
        -------
        tuple of two np.ndarray[intp]
            The matching positions for every target in order (a single -1
            for a target without match), and the positions within ``target``
            of the targets that had no match.
        """
        cdef:
            Py_ssize_t old_len
            Py_ssize_t i
            Int64Vector result, missing

        result = Int64Vector()
        missing = Int64Vector()
        old_len = 0
        for i in range(len(target)):
            try:
                self.root.query(result, target[i])
            except OverflowError:
                # overflow -> no match, which is already handled below
                pass

            # nothing appended by the query -> record the miss
            if result.data.n == old_len:
                result.append(-1)
                missing.append(i)
            old_len = result.data.n
        return (result.to_array().astype('intp'),
                missing.to_array().astype('intp'))

    def __repr__(self) -> str:
        return (f'<IntervalTree[{self.dtype},{self.closed}]: '
                f'{self.root.n_elements} elements>')

    # compat with IndexEngine interface
    def clear_mapping(self) -> None:
        pass
cdef take(ndarray source, ndarray indices):
    """Pick the entries of ``source`` at positions ``indices`` (axis 0)

    Thin wrapper around the NumPy C-API ``PyArray_Take``, meant for 1D
    arrays.
    """
    picked = PyArray_Take(source, indices, 0)
    return picked
cdef sort_values_and_indices(all_values, all_indices, subset):
    """Select ``subset`` from both arrays and order the selection by value

    Returns the selected values sorted with ``PyArray_ArgSort`` together with
    the selected indices rearranged in the same order.
    """
    picked_indices = take(all_indices, subset)
    picked_values = take(all_values, subset)
    # one argsort on the values drives the ordering of both outputs
    order = PyArray_ArgSort(picked_values, 0, NPY_QUICKSORT)
    return take(picked_values, order), take(picked_indices, order)
# ----------------------------------------------------------------------
# Nodes
# ----------------------------------------------------------------------
@cython.internal
cdef class IntervalNode:
    """Common base of the dtype/closed specialised tree nodes

    Holds the bookkeeping shared by every node; the child nodes, pivot and
    center arrays read by the methods below are declared on the specialised
    subclasses.
    """
    cdef readonly:
        int64_t n_elements, n_center, leaf_size
        bint is_leaf_node

    def __repr__(self) -> str:
        # leaves have no pivot or children to report
        if self.is_leaf_node:
            return f"<{type(self).__name__}: {self.n_elements} elements (terminal)>"

        n_left = self.left_node.n_elements
        n_right = self.right_node.n_elements
        # whatever is in neither child overlaps the pivot
        n_center = self.n_elements - n_left - n_right
        return (
            f"<{type(self).__name__}: "
            f"pivot {self.pivot}, {self.n_elements} elements "
            f"({n_left} left, {n_right} right, {n_center} overlapping)>"
        )

    def counts(self):
        """
        Inspect counts on this node
        useful for debugging purposes

        A leaf reports its number of elements; any other node reports a
        nested tuple ``(n_overlapping, (left_counts, right_counts))``.
        """
        if self.is_leaf_node:
            return self.n_elements

        n_overlapping = len(self.center_left_values)
        left_counts = self.left_node.counts()
        right_counts = self.right_node.counts()
        return (n_overlapping, (left_counts, right_counts))
# we need specialized nodes and leaves to optimize for different dtype and
# closed values

{{py:

# (closed, cmp_left, cmp_right): a point lies inside an interval when
# `left cmp_left point cmp_right right` holds
closed_specs = [
    ('left', '<=', '<'),
    ('right', '<', '<='),
    ('both', '<=', '<='),
    ('neither', '<', '<'),
]

# strict comparison <-> non-strict comparison
converse = {'<=': '<', '<': '<='}

# prefix of the fused scalar type accepted by each dtype's query method
fused_prefixes = {'float64': '', 'int64': 'int_', 'uint64': 'uint_'}

# plain loops only: names defined in this block are not visible from nested
# scopes (comprehensions, functions) when the template engine executes it
nodes = []
for dtype in ['float64', 'int64', 'uint64']:
    for closed, cmp_left, cmp_right in closed_specs:
        nodes.append((dtype, dtype.title(),
                      closed, closed.title(),
                      cmp_left,
                      cmp_right,
                      converse[cmp_left],
                      converse[cmp_right],
                      fused_prefixes[dtype]))

}}
# maps (dtype name, closed) -> node class; filled in by the template loop
# below and read by IntervalTree.__init__
NODE_CLASSES = {}

{{for dtype, dtype_title, closed, closed_title, cmp_left, cmp_right,
      cmp_left_converse, cmp_right_converse, fused_prefix in nodes}}


@cython.internal
cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode):
    """Node of an IntervalTree, specialised for one dtype and closed side

    A node holding at most ``leaf_size`` intervals is a leaf that stores its
    intervals directly and is searched linearly. Any other node picks a
    pivot and categorizes intervals by those that fall to the left, those
    that fall to the right, and those that overlap with the pivot.
    """
    cdef readonly:
        # children; only assigned on non-leaf nodes
        {{dtype_title}}Closed{{closed_title}}IntervalNode left_node, right_node
        # center_*: bounds of the intervals overlapping the pivot, once
        # sorted by left bound and once by right bound (non-leaf only);
        # left/right: bounds of the stored intervals (leaf only)
        {{dtype}}_t[:] center_left_values, center_right_values, left, right
        # original interval positions matching the arrays above
        int64_t[:] center_left_indices, center_right_indices, indices
        # extent of all intervals under this node, used to prune queries
        {{dtype}}_t min_left, max_right
        {{dtype}}_t pivot

    def __init__(self,
                 ndarray[{{dtype}}_t, ndim=1] left,
                 ndarray[{{dtype}}_t, ndim=1] right,
                 ndarray[int64_t, ndim=1] indices,
                 int64_t leaf_size):
        """
        Parameters
        ----------
        left, right : ndarray[ndim=1]
            Bounds of the intervals placed under this node.
        indices : ndarray[int64, ndim=1]
            Position reported by ``query`` for each interval.
        leaf_size : int
            Largest number of intervals stored in a leaf; bigger inputs are
            split into child nodes.
        """

        self.n_elements = len(left)
        self.leaf_size = leaf_size

        # min_left and max_right are used to speed-up query by skipping
        # query on sub-nodes. If this node has size 0, query is cheap,
        # so these values don't matter.
        if left.size > 0:
            self.min_left = left.min()
            self.max_right = right.max()
        else:
            self.min_left = 0
            self.max_right = 0

        if self.n_elements <= leaf_size:
            # make this a terminal (leaf) node
            self.is_leaf_node = True
            self.left = left
            self.right = right
            self.indices = indices
            self.n_center = 0
        else:
            # calculate a pivot so we can create child nodes
            self.is_leaf_node = False
            # median of the interval midpoints; each bound is halved before
            # the sum (presumably to keep the sum from overflowing)
            self.pivot = np.median(left / 2 + right / 2)
            # NOTE(review): the nesting of the two clamps below under the
            # isinf branch is reconstructed -- confirm against upstream
            if np.isinf(self.pivot):
                # an infinite pivot is replaced by 0, then moved onto an
                # actual bound when 0 lies outside the range of the data
                self.pivot = cython.cast({{dtype}}_t, 0)
                if self.pivot > np.max(right):
                    self.pivot = np.max(left)
                if self.pivot < np.min(left):
                    self.pivot = np.min(right)

            left_set, right_set, center_set = self.classify_intervals(
                left, right)

            self.left_node = self.new_child_node(left, right,
                                                 indices, left_set)
            self.right_node = self.new_child_node(left, right,
                                                  indices, right_set)

            # keep the overlapping intervals twice: ordered by left bound
            # and ordered by right bound, so query can stop scanning early
            self.center_left_values, self.center_left_indices = \
                sort_values_and_indices(left, indices, center_set)
            self.center_right_values, self.center_right_indices = \
                sort_values_and_indices(right, indices, center_set)
            self.n_center = len(self.center_left_indices)

    @cython.wraparound(False)
    @cython.boundscheck(False)
    cdef classify_intervals(self, {{dtype}}_t[:] left, {{dtype}}_t[:] right):
        """Classify the given intervals based upon whether they fall to the
        left, right, or overlap with this node's pivot.

        Returns a 3-tuple of position arrays (left of the pivot, right of
        the pivot, overlapping the pivot); the positions index into the
        given ``left`` / ``right`` arrays.
        """
        cdef:
            Int64Vector left_ind, right_ind, overlapping_ind
            Py_ssize_t i

        left_ind = Int64Vector()
        right_ind = Int64Vector()
        overlapping_ind = Int64Vector()

        for i in range(self.n_elements):
            # the converse operators negate the containment test on each
            # side: the pivot is past the right bound ...
            if right[i] {{cmp_right_converse}} self.pivot:
                left_ind.append(i)
            # ... or before the left bound
            elif self.pivot {{cmp_left_converse}} left[i]:
                right_ind.append(i)
            # otherwise the interval contains the pivot
            else:
                overlapping_ind.append(i)

        return (left_ind.to_array(),
                right_ind.to_array(),
                overlapping_ind.to_array())

    cdef new_child_node(self,
                        ndarray[{{dtype}}_t, ndim=1] left,
                        ndarray[{{dtype}}_t, ndim=1] right,
                        ndarray[int64_t, ndim=1] indices,
                        ndarray[int64_t, ndim=1] subset):
        """Create a new child node.

        The child is built from the entries of ``left``, ``right`` and
        ``indices`` at the positions in ``subset`` and uses the same
        ``leaf_size`` as this node.
        """
        left = take(left, subset)
        right = take(right, subset)
        indices = take(indices, subset)
        return {{dtype_title}}Closed{{closed_title}}IntervalNode(
            left, right, indices, self.leaf_size)

    @cython.wraparound(False)
    @cython.boundscheck(False)
    @cython.initializedcheck(False)
    cpdef query(self, Int64Vector result, {{fused_prefix}}scalar_t point):
        """Recursively query this node and its sub-nodes for intervals that
        overlap with the query point.

        The position of every matching interval is appended to ``result``;
        nothing is returned.
        """
        cdef:
            int64_t[:] indices
            {{dtype}}_t[:] values
            Py_ssize_t i

        if self.is_leaf_node:
            # Once we get down to a certain size, it doesn't make sense to
            # continue the binary tree structure. Instead, we use linear
            # search.
            for i in range(self.n_elements):
                if self.left[i] {{cmp_left}} point {{cmp_right}} self.right[i]:
                    result.append(self.indices[i])
        else:
            # There are child nodes. Based on comparing our query to the pivot,
            # look at the center values, then go to the relevant child.
            if point < self.pivot:
                # center intervals contain the pivot, so only their left
                # bound can exclude a point below it; the bounds are sorted
                # ascending, so stop at the first one that fails
                values = self.center_left_values
                indices = self.center_left_indices
                for i in range(self.n_center):
                    if not values[i] {{cmp_left}} point:
                        break
                    result.append(indices[i])
                # skip the left subtree when the point is beyond its extent
                if point {{cmp_right}} self.left_node.max_right:
                    self.left_node.query(result, point)
            elif point > self.pivot:
                # mirror image: scan the right bounds from the largest down
                values = self.center_right_values
                indices = self.center_right_indices
                for i in range(self.n_center - 1, -1, -1):
                    if not point {{cmp_right}} values[i]:
                        break
                    result.append(indices[i])
                # skip the right subtree when the point is before its extent
                if self.right_node.min_left {{cmp_left}} point:
                    self.right_node.query(result, point)
            else:
                # point == pivot: every center interval matches and neither
                # child can hold a match.
                # NOTE(review): a NaN point also lands here because both
                # comparisons above are False -- confirm callers filter NaN
                result.extend(self.center_left_indices)


NODE_CLASSES['{{dtype}}',
             '{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode

{{endfor}}