3RNN/Lib/site-packages/sklearn/neighbors/_quad_tree.pyx

610 lines
23 KiB
Cython
Raw Permalink Normal View History

2024-05-26 19:49:15 +02:00
# Author: Thomas Moreau <thomas.moreau.2010@gmail.com>
# Author: Olivier Grisel <olivier.grisel@ensta.fr>
from cpython cimport Py_INCREF, PyObject, PyTypeObject
from libc.math cimport fabsf
from libc.stdlib cimport free
from libc.string cimport memcpy
from libc.stdio cimport printf
from libc.stdint cimport SIZE_MAX
from ..tree._utils cimport safe_realloc
import numpy as np
cimport numpy as cnp
cnp.import_array()
cdef extern from "numpy/arrayobject.h":
object PyArray_NewFromDescr(PyTypeObject* subtype, cnp.dtype descr,
int nd, cnp.npy_intp* dims,
cnp.npy_intp* strides,
void* data, int flags, object obj)
int PyArray_SetBaseObject(cnp.ndarray arr, PyObject* obj)
# Build the corresponding numpy dtype for Cell.
# This works by casting `dummy` to an array of Cell of length 1, which numpy
# can construct a `dtype`-object for. See https://stackoverflow.com/q/62448946
# for a more detailed explanation.
cdef Cell dummy
CELL_DTYPE = np.asarray(<Cell[:1]>(&dummy)).dtype
assert CELL_DTYPE.itemsize == sizeof(Cell)
cdef class _QuadTree:
"""Array-based representation of a QuadTree.
This class is currently working for indexing 2D data (regular QuadTree) and
for indexing 3D data (OcTree). It is planned to split the 2 implementations
using `Cython.Tempita` to save some memory for QuadTree.
Note that this code is currently internally used only by the Barnes-Hut
method in `sklearn.manifold.TSNE`. It is planned to be refactored and
generalized in the future to be compatible with nearest neighbors API of
`sklearn.neighbors` with 2D and 3D data.
"""
def __cinit__(self, int n_dimensions, int verbose):
"""Constructor."""
# Parameters of the tree
self.n_dimensions = n_dimensions
self.verbose = verbose
self.n_cells_per_cell = <int> (2 ** self.n_dimensions)
# Inner structures
self.max_depth = 0
self.cell_count = 0
self.capacity = 0
self.n_points = 0
self.cells = NULL
def __dealloc__(self):
"""Destructor."""
# Free all inner structures
free(self.cells)
@property
def cumulative_size(self):
cdef Cell[:] cell_mem_view = self._get_cell_ndarray()
return cell_mem_view.base['cumulative_size'][:self.cell_count]
@property
def leafs(self):
cdef Cell[:] cell_mem_view = self._get_cell_ndarray()
return cell_mem_view.base['is_leaf'][:self.cell_count]
def build_tree(self, X):
"""Build a tree from an array of points X."""
cdef:
int i
float32_t[3] pt
float32_t[3] min_bounds, max_bounds
# validate X and prepare for query
# X = check_array(X, dtype=float32_t, order='C')
n_samples = X.shape[0]
capacity = 100
self._resize(capacity)
m = np.min(X, axis=0)
M = np.max(X, axis=0)
# Scale the maximum to get all points strictly in the tree bounding box
# The 3 bounds are for positive, negative and small values
M = np.maximum(M * (1. + 1e-3 * np.sign(M)), M + 1e-3)
for i in range(self.n_dimensions):
min_bounds[i] = m[i]
max_bounds[i] = M[i]
if self.verbose > 10:
printf("[QuadTree] bounding box axis %i : [%f, %f]\n",
i, min_bounds[i], max_bounds[i])
# Create the initial node with boundaries from the dataset
self._init_root(min_bounds, max_bounds)
for i in range(n_samples):
for j in range(self.n_dimensions):
pt[j] = X[i, j]
self.insert_point(pt, i)
# Shrink the cells array to reduce memory usage
self._resize(capacity=self.cell_count)
cdef int insert_point(self, float32_t[3] point, intp_t point_index,
intp_t cell_id=0) except -1 nogil:
"""Insert a point in the QuadTree."""
cdef int ax
cdef intp_t selected_child
cdef Cell* cell = &self.cells[cell_id]
cdef intp_t n_point = cell.cumulative_size
if self.verbose > 10:
printf("[QuadTree] Inserting depth %li\n", cell.depth)
# Assert that the point is in the right range
if DEBUGFLAG:
self._check_point_in_cell(point, cell)
# If the cell is an empty leaf, insert the point in it
if cell.cumulative_size == 0:
cell.cumulative_size = 1
self.n_points += 1
for i in range(self.n_dimensions):
cell.barycenter[i] = point[i]
cell.point_index = point_index
if self.verbose > 10:
printf("[QuadTree] inserted point %li in cell %li\n",
point_index, cell_id)
return cell_id
# If the cell is not a leaf, update cell internals and
# recurse in selected child
if not cell.is_leaf:
for ax in range(self.n_dimensions):
# barycenter update using a weighted mean
cell.barycenter[ax] = (
n_point * cell.barycenter[ax] + point[ax]) / (n_point + 1)
# Increase the size of the subtree starting from this cell
cell.cumulative_size += 1
# Insert child in the correct subtree
selected_child = self._select_child(point, cell)
if self.verbose > 49:
printf("[QuadTree] selected child %li\n", selected_child)
if selected_child == -1:
self.n_points += 1
return self._insert_point_in_new_child(point, cell, point_index)
return self.insert_point(point, point_index, selected_child)
# Finally, if the cell is a leaf with a point already inserted,
# split the cell in n_cells_per_cell if the point is not a duplicate.
# If it is a duplicate, increase the size of the leaf and return.
if self._is_duplicate(point, cell.barycenter):
if self.verbose > 10:
printf("[QuadTree] found a duplicate!\n")
cell.cumulative_size += 1
self.n_points += 1
return cell_id
# In a leaf, the barycenter correspond to the only point included
# in it.
self._insert_point_in_new_child(cell.barycenter, cell, cell.point_index,
cell.cumulative_size)
return self.insert_point(point, point_index, cell_id)
# XXX: This operation is not Thread safe
cdef intp_t _insert_point_in_new_child(
self, float32_t[3] point, Cell* cell, intp_t point_index, intp_t size=1
) noexcept nogil:
"""Create a child of cell which will contain point."""
# Local variable definition
cdef:
intp_t cell_id, cell_child_id, parent_id
float32_t[3] save_point
float32_t width
Cell* child
int i
# If the maximal capacity of the Tree have been reached, double the capacity
# We need to save the current cell id and the current point to retrieve them
# in case the reallocation
if self.cell_count + 1 > self.capacity:
parent_id = cell.cell_id
for i in range(self.n_dimensions):
save_point[i] = point[i]
self._resize(SIZE_MAX)
cell = &self.cells[parent_id]
point = save_point
# Get an empty cell and initialize it
cell_id = self.cell_count
self.cell_count += 1
child = &self.cells[cell_id]
self._init_cell(child, cell.cell_id, cell.depth + 1)
child.cell_id = cell_id
# Set the cell as an inner cell of the Tree
cell.is_leaf = False
cell.point_index = -1
# Set the correct boundary for the cell, store the point in the cell
# and compute its index in the children array.
cell_child_id = 0
for i in range(self.n_dimensions):
cell_child_id *= 2
if point[i] >= cell.center[i]:
cell_child_id += 1
child.min_bounds[i] = cell.center[i]
child.max_bounds[i] = cell.max_bounds[i]
else:
child.min_bounds[i] = cell.min_bounds[i]
child.max_bounds[i] = cell.center[i]
child.center[i] = (child.min_bounds[i] + child.max_bounds[i]) / 2.
width = child.max_bounds[i] - child.min_bounds[i]
child.barycenter[i] = point[i]
child.squared_max_width = max(child.squared_max_width, width*width)
# Store the point info and the size to account for duplicated points
child.point_index = point_index
child.cumulative_size = size
# Store the child cell in the correct place in children
cell.children[cell_child_id] = child.cell_id
if DEBUGFLAG:
# Assert that the point is in the right range
self._check_point_in_cell(point, child)
if self.verbose > 10:
printf("[QuadTree] inserted point %li in new child %li\n",
point_index, cell_id)
return cell_id
cdef bint _is_duplicate(self, float32_t[3] point1, float32_t[3] point2) noexcept nogil:
"""Check if the two given points are equals."""
cdef int i
cdef bint res = True
for i in range(self.n_dimensions):
# Use EPSILON to avoid numerical error that would overgrow the tree
res &= fabsf(point1[i] - point2[i]) <= EPSILON
return res
cdef intp_t _select_child(self, float32_t[3] point, Cell* cell) noexcept nogil:
"""Select the child of cell which contains the given query point."""
cdef:
int i
intp_t selected_child = 0
for i in range(self.n_dimensions):
# Select the correct child cell to insert the point by comparing
# it to the borders of the cells using precomputed center.
selected_child *= 2
if point[i] >= cell.center[i]:
selected_child += 1
return cell.children[selected_child]
cdef void _init_cell(self, Cell* cell, intp_t parent, intp_t depth) noexcept nogil:
"""Initialize a cell structure with some constants."""
cell.parent = parent
cell.is_leaf = True
cell.depth = depth
cell.squared_max_width = 0
cell.cumulative_size = 0
for i in range(self.n_cells_per_cell):
cell.children[i] = SIZE_MAX
cdef void _init_root(self, float32_t[3] min_bounds, float32_t[3] max_bounds
) noexcept nogil:
"""Initialize the root node with the given space boundaries"""
cdef:
int i
float32_t width
Cell* root = &self.cells[0]
self._init_cell(root, -1, 0)
for i in range(self.n_dimensions):
root.min_bounds[i] = min_bounds[i]
root.max_bounds[i] = max_bounds[i]
root.center[i] = (max_bounds[i] + min_bounds[i]) / 2.
width = max_bounds[i] - min_bounds[i]
root.squared_max_width = max(root.squared_max_width, width*width)
root.cell_id = 0
self.cell_count += 1
cdef int _check_point_in_cell(self, float32_t[3] point, Cell* cell
) except -1 nogil:
"""Check that the given point is in the cell boundaries."""
if self.verbose >= 50:
if self.n_dimensions == 3:
printf("[QuadTree] Checking point (%f, %f, %f) in cell %li "
"([%f/%f, %f/%f, %f/%f], size %li)\n",
point[0], point[1], point[2], cell.cell_id,
cell.min_bounds[0], cell.max_bounds[0], cell.min_bounds[1],
cell.max_bounds[1], cell.min_bounds[2], cell.max_bounds[2],
cell.cumulative_size)
else:
printf("[QuadTree] Checking point (%f, %f) in cell %li "
"([%f/%f, %f/%f], size %li)\n",
point[0], point[1], cell.cell_id, cell.min_bounds[0],
cell.max_bounds[0], cell.min_bounds[1],
cell.max_bounds[1], cell.cumulative_size)
for i in range(self.n_dimensions):
if (cell.min_bounds[i] > point[i] or
cell.max_bounds[i] <= point[i]):
with gil:
msg = "[QuadTree] InsertionError: point out of cell "
msg += "boundary.\nAxis %li: cell [%f, %f]; point %f\n"
msg %= i, cell.min_bounds[i], cell.max_bounds[i], point[i]
raise ValueError(msg)
def _check_coherence(self):
"""Check the coherence of the cells of the tree.
Check that the info stored in each cell is compatible with the info
stored in descendent and sibling cells. Raise a ValueError if this
fails.
"""
for cell in self.cells[:self.cell_count]:
# Check that the barycenter of inserted point is within the cell
# boundaries
self._check_point_in_cell(cell.barycenter, &cell)
if not cell.is_leaf:
# Compute the number of point in children and compare with
# its cummulative_size.
n_points = 0
for idx in range(self.n_cells_per_cell):
child_id = cell.children[idx]
if child_id != -1:
child = self.cells[child_id]
n_points += child.cumulative_size
assert child.cell_id == child_id, (
"Cell id not correctly initialized.")
if n_points != cell.cumulative_size:
raise ValueError(
"Cell {} is incoherent. Size={} but found {} points "
"in children. ({})"
.format(cell.cell_id, cell.cumulative_size,
n_points, cell.children))
# Make sure that the number of point in the tree correspond to the
# cumulative size in root cell.
if self.n_points != self.cells[0].cumulative_size:
raise ValueError(
"QuadTree is incoherent. Size={} but found {} points "
"in children."
.format(self.n_points, self.cells[0].cumulative_size))
cdef long summarize(self, float32_t[3] point, float32_t* results,
float squared_theta=.5, intp_t cell_id=0, long idx=0
) noexcept nogil:
"""Summarize the tree compared to a query point.
Input arguments
---------------
point : array (n_dimensions)
query point to construct the summary.
cell_id : integer, optional (default: 0)
current cell of the tree summarized. This should be set to 0 for
external calls.
idx : integer, optional (default: 0)
current index in the result array. This should be set to 0 for
external calls
squared_theta: float, optional (default: .5)
threshold to decide whether the node is sufficiently far
from the query point to be a good summary. The formula is such that
the node is a summary if
node_width^2 / dist_node_point^2 < squared_theta.
Note that the argument should be passed as theta^2 to avoid
computing square roots of the distances.
Output arguments
----------------
results : array (n_samples * (n_dimensions+2))
result will contain a summary of the tree information compared to
the query point:
- results[idx:idx+n_dimensions] contains the coordinate-wise
difference between the query point and the summary cell idx.
This is useful in t-SNE to compute the negative forces.
- result[idx+n_dimensions+1] contains the squared euclidean
distance to the summary cell idx.
- result[idx+n_dimensions+2] contains the number of point of the
tree contained in the summary cell idx.
Return
------
idx : integer
number of elements in the results array.
"""
cdef:
int i, idx_d = idx + self.n_dimensions
bint duplicate = True
Cell* cell = &self.cells[cell_id]
results[idx_d] = 0.
for i in range(self.n_dimensions):
results[idx + i] = point[i] - cell.barycenter[i]
results[idx_d] += results[idx + i] * results[idx + i]
duplicate &= fabsf(results[idx + i]) <= EPSILON
# Do not compute self interactions
if duplicate and cell.is_leaf:
return idx
# Check whether we can use this node as a summary
# It's a summary node if the angular size as measured from the point
# is relatively small (w.r.t. theta) or if it is a leaf node.
# If it can be summarized, we use the cell center of mass
# Otherwise, we go a higher level of resolution and into the leaves.
if cell.is_leaf or (
(cell.squared_max_width / results[idx_d]) < squared_theta):
results[idx_d + 1] = <float32_t> cell.cumulative_size
return idx + self.n_dimensions + 2
else:
# Recursively compute the summary in nodes
for c in range(self.n_cells_per_cell):
child_id = cell.children[c]
if child_id != -1:
idx = self.summarize(point, results, squared_theta,
child_id, idx)
return idx
def get_cell(self, point):
"""return the id of the cell containing the query point or raise
ValueError if the point is not in the tree
"""
cdef float32_t[3] query_pt
cdef int i
assert len(point) == self.n_dimensions, (
"Query point should be a point in dimension {}."
.format(self.n_dimensions))
for i in range(self.n_dimensions):
query_pt[i] = point[i]
return self._get_cell(query_pt, 0)
cdef int _get_cell(self, float32_t[3] point, intp_t cell_id=0
) except -1 nogil:
"""guts of get_cell.
Return the id of the cell containing the query point or raise ValueError
if the point is not in the tree"""
cdef:
intp_t selected_child
Cell* cell = &self.cells[cell_id]
if cell.is_leaf:
if self._is_duplicate(cell.barycenter, point):
if self.verbose > 99:
printf("[QuadTree] Found point in cell: %li\n",
cell.cell_id)
return cell_id
with gil:
raise ValueError("Query point not in the Tree.")
selected_child = self._select_child(point, cell)
if selected_child > 0:
if self.verbose > 99:
printf("[QuadTree] Selected_child: %li\n", selected_child)
return self._get_cell(point, selected_child)
with gil:
raise ValueError("Query point not in the Tree.")
# Pickling primitives
def __reduce__(self):
"""Reduce re-implementation, for pickling."""
return (_QuadTree, (self.n_dimensions, self.verbose), self.__getstate__())
def __getstate__(self):
"""Getstate re-implementation, for pickling."""
d = {}
# capacity is inferred during the __setstate__ using nodes
d["max_depth"] = self.max_depth
d["cell_count"] = self.cell_count
d["capacity"] = self.capacity
d["n_points"] = self.n_points
d["cells"] = self._get_cell_ndarray().base
return d
def __setstate__(self, d):
"""Setstate re-implementation, for unpickling."""
self.max_depth = d["max_depth"]
self.cell_count = d["cell_count"]
self.capacity = d["capacity"]
self.n_points = d["n_points"]
if 'cells' not in d:
raise ValueError('You have loaded Tree version which '
'cannot be imported')
cell_ndarray = d['cells']
if (cell_ndarray.ndim != 1 or
cell_ndarray.dtype != CELL_DTYPE or
not cell_ndarray.flags.c_contiguous):
raise ValueError('Did not recognise loaded array layout')
self.capacity = cell_ndarray.shape[0]
if self._resize_c(self.capacity) != 0:
raise MemoryError("resizing tree to %d" % self.capacity)
cdef Cell[:] cell_mem_view = cell_ndarray
memcpy(
pto=self.cells,
pfrom=&cell_mem_view[0],
size=self.capacity * sizeof(Cell),
)
# Array manipulation methods, to convert it to numpy or to resize
# self.cells array
cdef Cell[:] _get_cell_ndarray(self):
"""Wraps nodes as a NumPy struct array.
The array keeps a reference to this Tree, which manages the underlying
memory. Individual fields are publicly accessible as properties of the
Tree.
"""
cdef cnp.npy_intp shape[1]
shape[0] = <cnp.npy_intp> self.cell_count
cdef cnp.npy_intp strides[1]
strides[0] = sizeof(Cell)
cdef Cell[:] arr
Py_INCREF(CELL_DTYPE)
arr = PyArray_NewFromDescr(
subtype=<PyTypeObject *> np.ndarray,
descr=CELL_DTYPE,
nd=1,
dims=shape,
strides=strides,
data=<void*> self.cells,
flags=cnp.NPY_ARRAY_DEFAULT,
obj=None,
)
Py_INCREF(self)
if PyArray_SetBaseObject(arr.base, <PyObject*> self) < 0:
raise ValueError("Can't initialize array!")
return arr
cdef int _resize(self, intp_t capacity) except -1 nogil:
"""Resize all inner arrays to `capacity`, if `capacity` == -1, then
double the size of the inner arrays.
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
if self._resize_c(capacity) != 0:
# Acquire gil only if we need to raise
with gil:
raise MemoryError()
cdef int _resize_c(self, intp_t capacity=SIZE_MAX) except -1 nogil:
"""Guts of _resize
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
if capacity == self.capacity and self.cells != NULL:
return 0
if <size_t> capacity == SIZE_MAX:
if self.capacity == 0:
capacity = 9 # default initial value to min
else:
capacity = 2 * self.capacity
safe_realloc(&self.cells, capacity)
# if capacity smaller than cell_count, adjust the counter
if capacity < self.cell_count:
self.cell_count = capacity
self.capacity = capacity
return 0
def _py_summarize(self, float32_t[:] query_pt, float32_t[:, :] X, float angle):
# Used for testing summarize
cdef:
float32_t[:] summary
int n_samples
n_samples = X.shape[0]
summary = np.empty(4 * n_samples, dtype=np.float32)
idx = self.summarize(&query_pt[0], &summary[0], angle * angle)
return idx, summary