393 lines
13 KiB
Python
393 lines
13 KiB
Python
|
"""Indexing mixin for sparse array/matrix classes.
"""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from typing import TYPE_CHECKING
|
||
|
|
||
|
import numpy as np
|
||
|
from ._sputils import isintlike
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
import numpy.typing as npt
|
||
|
|
||
|
# Python and NumPy integer scalar types. IndexMixin checks normalized
# row/column indices against these to select its scalar-index code paths.
INT_TYPES: tuple[type, ...] = (int, np.integer)
|
||
|
|
||
|
|
||
|
def _broadcast_arrays(a, b):
|
||
|
"""
|
||
|
Same as np.broadcast_arrays(a, b) but old writeability rules.
|
||
|
|
||
|
NumPy >= 1.17.0 transitions broadcast_arrays to return
|
||
|
read-only arrays. Set writeability explicitly to avoid warnings.
|
||
|
Retain the old writeability rules, as our Cython code assumes
|
||
|
the old behavior.
|
||
|
"""
|
||
|
x, y = np.broadcast_arrays(a, b)
|
||
|
x.flags.writeable = a.flags.writeable
|
||
|
y.flags.writeable = b.flags.writeable
|
||
|
return x, y
|
||
|
|
||
|
|
||
|
class IndexMixin:
    """
    Common dispatching and validation logic for indexing.

    ``__getitem__`` and ``__setitem__`` normalize the key with
    ``_validate_indices`` and then route to format-specific hooks named
    ``_get_<row>X<col>`` / ``_set_<row>X<col>``. The default hooks at the
    end of this class raise ``NotImplementedError``.
    """

    def _raise_on_1d_array_slice(self):
        """Refuse indexing that would yield a 1-D result on a sparse array.

        1D sparse arrays are not currently supported. Every indexing path
        that would produce one calls this first. ``sparray`` instances get a
        ``NotImplementedError``; anything else returns without error.
        Remove this once 1D sparse arrays are implemented.
        """
        from scipy.sparse import sparray

        if not isinstance(self, sparray):
            return
        raise NotImplementedError(
            'We have not yet implemented 1D sparse slices; '
            'please index using explicit indices, e.g. `x[:, [0]]`'
        )

    def __getitem__(self, key):
        """Return the entries selected by ``key``.

        ``_validate_indices`` turns ``key`` into a (row, col) pair. Each part
        is an int, a slice, or a 1-/2-D integer index array. The combination
        of kinds selects one of the ``_get_*`` hooks. Two index arrays that
        do not form a column-vector/row-vector outer pair are broadcast
        together and used for inner (fancy) indexing.

        Raises
        ------
        IndexError
            If the combination of index kinds would give a result with more
            than two dimensions.
        """
        row, col = self._validate_indices(key)
        too_many_dims = 'index results in >2 dimensions'
        col_is_int = isinstance(col, INT_TYPES)
        col_is_slice = isinstance(col, slice)

        if isinstance(row, INT_TYPES):
            if col_is_int:
                return self._get_intXint(row, col)
            if col_is_slice:
                self._raise_on_1d_array_slice()
                return self._get_intXslice(row, col)
            if col.ndim == 1:
                self._raise_on_1d_array_slice()
                return self._get_intXarray(row, col)
            if col.ndim == 2:
                return self._get_intXarray(row, col)
            raise IndexError(too_many_dims)

        if isinstance(row, slice):
            if col_is_int:
                self._raise_on_1d_array_slice()
                return self._get_sliceXint(row, col)
            if col_is_slice:
                # x[:, :] is simply a copy.
                if row == slice(None) and col == row:
                    return self.copy()
                return self._get_sliceXslice(row, col)
            if col.ndim == 1:
                return self._get_sliceXarray(row, col)
            raise IndexError(too_many_dims)

        if row.ndim == 1:
            if col_is_int:
                self._raise_on_1d_array_slice()
                return self._get_arrayXint(row, col)
            if col_is_slice:
                return self._get_arrayXslice(row, col)
        else:
            # Here ``row`` is a 2-D index array.
            if col_is_int:
                return self._get_arrayXint(row, col)
            if col_is_slice:
                raise IndexError(too_many_dims)
            if row.shape[1] == 1 and (col.ndim == 1 or col.shape[0] == 1):
                # A column vector of rows with a row vector of columns is an
                # outer product of indices.
                return self._get_columnXarray(row[:, 0], col.ravel())

        # Only inner (fancy) indexing with two index arrays is left.
        row, col = _broadcast_arrays(row, col)
        if row.shape != col.shape:
            raise IndexError('number of row and column indices differ')
        if row.size == 0:
            # Empty selection: a new, empty instance of the same class.
            return self.__class__(np.atleast_2d(row).shape, dtype=self.dtype)
        return self._get_arrayXarray(row, col)

    def __setitem__(self, key, x):
        """Assign ``x`` to the entries selected by ``key``.

        A pair of integer indices assigns a single element. Otherwise:

        - A row slice becomes a column vector of row indices.
        - A column slice becomes a row vector of column indices, and a 1-D
          row index is reshaped into a column vector (outer indexing).
        - The index arrays are broadcast together.

        ``x`` is then handed to ``_set_arrayXarray_sparse`` when it is
        sparse, and to ``_set_arrayXarray`` otherwise.

        Raises
        ------
        ValueError
            If a non-scalar is assigned to a single element, or if a sparse
            ``x`` does not fit the indexed shape.
        """
        row, col = self._validate_indices(key)

        if isinstance(row, INT_TYPES) and isinstance(col, INT_TYPES):
            value = np.asarray(x, dtype=self.dtype)
            if value.size != 1:
                raise ValueError('Trying to assign a sequence to an item')
            self._set_intXint(row, col, value.flat[0])
            return

        if isinstance(row, slice):
            rows = np.arange(*row.indices(self.shape[0]))[:, None]
        else:
            rows = np.atleast_1d(row)

        if isinstance(col, slice):
            cols = np.arange(*col.indices(self.shape[1]))[None, :]
            if rows.ndim == 1:
                rows = rows[:, None]
        else:
            cols = np.atleast_1d(col)

        i, j = _broadcast_arrays(rows, cols)
        if i.shape != j.shape:
            raise IndexError('number of row and column indices differ')

        from ._base import issparse
        if not issparse(x):
            # Dense values: broadcast x to the index shape when the
            # squeezed shapes disagree, then reshape to match exactly.
            values = np.asarray(x, dtype=self.dtype)
            if values.squeeze().shape != i.squeeze().shape:
                values = np.broadcast_to(values, i.shape)
            if values.size == 0:
                return
            self._set_arrayXarray(i, j, values.reshape(i.shape))
            return

        if i.ndim == 1:
            # Inner indexing, so treat the index arrays as row vectors.
            i, j = i[None], j[None]
        # A length-1 axis of x may be stretched over a longer index axis.
        rows_fit = (x.shape[0] == i.shape[0]
                    or (x.shape[0] == 1 and i.shape[0] != 1))
        cols_fit = (x.shape[1] == i.shape[1]
                    or (x.shape[1] == 1 and i.shape[1] != 1))
        if not (rows_fit and cols_fit):
            raise ValueError('shape mismatch in assignment')
        if not (x.shape[0] and x.shape[1]):
            return
        coo = x.tocoo(copy=True)
        coo.sum_duplicates()
        self._set_arrayXarray_sparse(i, j, coo)

    def _validate_indices(self, key):
        """Normalize ``key`` into a validated ``(row, col)`` pair.

        A 2-D boolean ndarray or sparse key of the same shape as ``self`` is
        turned into the coordinates of its nonzero entries. Any other key is
        split by ``_unpack_index``. Each axis is then normalized as follows:

        - Integer-like values become in-range non-negative ints.
        - Boolean indices become integer positions (their length must match
          the axis).
        - Slices pass through unchanged.
        - Anything else goes through ``_asindices``.
        """
        from ._base import _spbase

        is_bool_mask = (isinstance(key, (_spbase, np.ndarray))
                        and key.ndim == 2 and key.dtype.kind == 'b')
        if is_bool_mask:
            if key.shape != self.shape:
                raise IndexError('boolean index shape does not match array shape')
            row, col = key.nonzero()
        else:
            row, col = _unpack_index(key)
        M, N = self.shape

        def _validate_bool_idx(idx, axis_size, axis_name):
            # A boolean index must cover the whole axis.
            if len(idx) != axis_size:
                raise IndexError(
                    f"boolean {axis_name} index has incorrect length: {len(idx)} "
                    f"instead of {axis_size}"
                )
            return _boolean_index_to_array(idx)

        def _validate_axis(idx, axis_size, axis_name, range_msg):
            if isintlike(idx):
                pos = int(idx)
                if not -axis_size <= pos < axis_size:
                    raise IndexError(range_msg % pos)
                return pos + axis_size if pos < 0 else pos
            bool_idx = _compatible_boolean_index(idx)
            if bool_idx is not None:
                return _validate_bool_idx(bool_idx, axis_size, axis_name)
            if isinstance(idx, slice):
                return idx
            return self._asindices(idx, axis_size)

        # Row is validated before column, as error precedence depends on it.
        return (_validate_axis(row, M, "row", 'row index (%d) out of range'),
                _validate_axis(col, N, "column", 'column index (%d) out of range'))

    def _asindices(self, idx, length):
        """Convert `idx` to a valid index array for an axis of size `length`.

        The result is a 1-D or 2-D array whose negative entries are wrapped
        into ``[0, length)``. The caller's array is never modified: a copy
        is made before wrapping when the array is the caller's own object or
        does not own its data.

        Subclasses that need special validation can override this method.
        """
        try:
            arr = np.asarray(idx)
        except (ValueError, TypeError, MemoryError) as e:
            raise IndexError('invalid index') from e

        if arr.ndim not in (1, 2):
            raise IndexError('Index dimension must be 1 or 2')
        if arr.size == 0:
            return arr

        # Bounds only need to be checked at the extremes.
        highest = arr.max()
        if highest >= length:
            raise IndexError('index (%d) out of range' % highest)

        lowest = arr.min()
        if lowest < 0:
            if lowest < -length:
                raise IndexError('index (%d) out of range' % lowest)
            if arr is idx or not arr.flags.owndata:
                arr = arr.copy()
            arr[arr < 0] += length
        return arr

    def _getrow(self, i):
        """Return a copy of row i of the matrix, as a (1 x n) row vector.

        Negative ``i`` counts from the end. Raises ``IndexError`` when ``i``
        lies outside ``[-m, m)``.
        """
        num_rows, _ = self.shape
        i = int(i)
        if not -num_rows <= i < num_rows:
            raise IndexError('index (%d) out of range' % i)
        return self._get_intXslice(i % num_rows, slice(None))

    def _getcol(self, i):
        """Return a copy of column i of the matrix, as a (m x 1) column vector.

        Negative ``i`` counts from the end. Raises ``IndexError`` when ``i``
        lies outside ``[-n, n)``.
        """
        _, num_cols = self.shape
        i = int(i)
        if not -num_cols <= i < num_cols:
            raise IndexError('index (%d) out of range' % i)
        return self._get_sliceXint(slice(None), i % num_cols)

    # ------------------------------------------------------------------
    # Format-specific hooks, named ``_get_<row>X<col>`` after the kind of
    # each normalized index (int, slice or array). ``_get_columnXarray`` is
    # the outer-indexing special case. Subclasses override what they
    # support.
    # ------------------------------------------------------------------

    def _get_intXint(self, row, col):
        raise NotImplementedError()

    def _get_intXarray(self, row, col):
        raise NotImplementedError()

    def _get_intXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXint(self, row, col):
        raise NotImplementedError()

    def _get_sliceXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXint(self, row, col):
        raise NotImplementedError()

    def _get_arrayXslice(self, row, col):
        raise NotImplementedError()

    def _get_columnXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXarray(self, row, col):
        raise NotImplementedError()

    def _set_intXint(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray_sparse(self, row, col, x):
        # Fallback: densify x, broadcast it against the row indices, and
        # reuse the dense setter.
        dense = np.asarray(x.toarray(), dtype=self.dtype)
        dense = _broadcast_arrays(dense, row)[0]
        self._set_arrayXarray(row, col, dense)
|
||
|
|
||
|
|
||
|
def _unpack_index(index) -> tuple[
    int | slice | npt.NDArray[np.bool_ | np.int_],
    int | slice | npt.NDArray[np.bool_ | np.int_]
]:
    """Parse index. Always return a tuple of the form (row, col).

    Valid types for row/col are an integer, a slice, an array of bools, or
    an array of integers. Handling by index form:

    - Any ``...`` is expanded first.
    - A non-tuple index or a 1-tuple selects rows, and ``col`` becomes
      ``slice(None)``.
    - A 0-/1-D boolean index is returned as the row index.
    - A 2-D boolean index is converted to the (row, col) coordinates of its
      True entries.

    Raises
    ------
    IndexError
        If the tuple has more than two entries, if a boolean index has more
        than two dimensions, or if either component is a sparse matrix.
    """
    # Parse any ellipses.
    index = _check_ellipsis(index)

    # Next, parse the tuple or object
    if isinstance(index, tuple):
        if len(index) == 2:
            row, col = index
        elif len(index) == 1:
            row, col = index[0], slice(None)
        else:
            raise IndexError('invalid number of indices')
    else:
        idx = _compatible_boolean_index(index)
        if idx is None:
            row, col = index, slice(None)
        elif idx.ndim < 2:
            return idx, slice(None)
        elif idx.ndim == 2:
            return idx.nonzero()
        else:
            # Without this branch a >2-D boolean index fell through with
            # ``row``/``col`` unbound and raised UnboundLocalError.
            raise IndexError('invalid index shape')

    # Next, check for validity and transform the index as needed.
    from ._base import issparse
    if issparse(row) or issparse(col):
        # Supporting sparse boolean indexing with both row and col does
        # not work because spmatrix.ndim is always 2.
        raise IndexError(
            'Indexing with sparse matrices is not supported '
            'except boolean indexing where matrix and index '
            'are equal shapes.')
    return row, col
|
||
|
|
||
|
|
||
|
def _check_ellipsis(index):
|
||
|
"""Process indices with Ellipsis. Returns modified index."""
|
||
|
if index is Ellipsis:
|
||
|
return (slice(None), slice(None))
|
||
|
|
||
|
if not isinstance(index, tuple):
|
||
|
return index
|
||
|
|
||
|
# Find any Ellipsis objects.
|
||
|
ellipsis_indices = [i for i, v in enumerate(index) if v is Ellipsis]
|
||
|
if not ellipsis_indices:
|
||
|
return index
|
||
|
if len(ellipsis_indices) > 1:
|
||
|
raise IndexError("an index can only have a single ellipsis ('...')")
|
||
|
|
||
|
# Replace the Ellipsis object with 0, 1, or 2 null-slices as needed.
|
||
|
i, = ellipsis_indices
|
||
|
num_slices = max(0, 3 - len(index))
|
||
|
return index[:i] + (slice(None),) * num_slices + index[i + 1:]
|
||
|
|
||
|
|
||
|
def _maybe_bool_ndarray(idx):
|
||
|
"""Returns a compatible array if elements are boolean.
|
||
|
"""
|
||
|
idx = np.asanyarray(idx)
|
||
|
if idx.dtype.kind == 'b':
|
||
|
return idx
|
||
|
return None
|
||
|
|
||
|
|
||
|
def _first_element_bool(idx, max_dim=2):
|
||
|
"""Returns True if first element of the incompatible
|
||
|
array type is boolean.
|
||
|
"""
|
||
|
if max_dim < 1:
|
||
|
return None
|
||
|
try:
|
||
|
first = next(iter(idx), None)
|
||
|
except TypeError:
|
||
|
return None
|
||
|
if isinstance(first, bool):
|
||
|
return True
|
||
|
return _first_element_bool(first, max_dim-1)
|
||
|
|
||
|
|
||
|
def _compatible_boolean_index(idx):
    """Return ``idx`` as a boolean array if it is a boolean index.

    The returned array can be converted to an integer index. Returns
    ``None`` if ``idx`` is not a boolean index.
    """
    # Objects with an ``ndim`` attribute are array-like; anything else must
    # look boolean on its first element before it is converted.
    if not (hasattr(idx, 'ndim') or _first_element_bool(idx)):
        return None
    return _maybe_bool_ndarray(idx)
|
||
|
|
||
|
|
||
|
def _boolean_index_to_array(idx):
|
||
|
if idx.ndim > 1:
|
||
|
raise IndexError('invalid index shape')
|
||
|
return np.where(idx)[0]
|