263 lines
7.3 KiB
Python
263 lines
7.3 KiB
Python
"""
|
|
Implementation of nlargest and nsmallest.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Hashable,
|
|
Sequence,
|
|
cast,
|
|
final,
|
|
)
|
|
|
|
import numpy as np
|
|
|
|
from pandas._libs import algos as libalgos
|
|
from pandas._typing import (
|
|
DtypeObj,
|
|
IndexLabel,
|
|
)
|
|
|
|
from pandas.core.dtypes.common import (
|
|
is_bool_dtype,
|
|
is_complex_dtype,
|
|
is_integer_dtype,
|
|
is_list_like,
|
|
is_numeric_dtype,
|
|
needs_i8_conversion,
|
|
)
|
|
from pandas.core.dtypes.dtypes import BaseMaskedDtype
|
|
|
|
if TYPE_CHECKING:
|
|
from pandas import (
|
|
DataFrame,
|
|
Series,
|
|
)
|
|
|
|
|
|
class SelectN:
|
|
def __init__(self, obj, n: int, keep: str) -> None:
|
|
self.obj = obj
|
|
self.n = n
|
|
self.keep = keep
|
|
|
|
if self.keep not in ("first", "last", "all"):
|
|
raise ValueError('keep must be either "first", "last" or "all"')
|
|
|
|
def compute(self, method: str) -> DataFrame | Series:
|
|
raise NotImplementedError
|
|
|
|
@final
|
|
def nlargest(self):
|
|
return self.compute("nlargest")
|
|
|
|
@final
|
|
def nsmallest(self):
|
|
return self.compute("nsmallest")
|
|
|
|
@final
|
|
@staticmethod
|
|
def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
|
|
"""
|
|
Helper function to determine if dtype is valid for
|
|
nsmallest/nlargest methods
|
|
"""
|
|
if is_numeric_dtype(dtype):
|
|
return not is_complex_dtype(dtype)
|
|
return needs_i8_conversion(dtype)
|
|
|
|
|
|
class SelectNSeries(SelectN):
|
|
"""
|
|
Implement n largest/smallest for Series
|
|
|
|
Parameters
|
|
----------
|
|
obj : Series
|
|
n : int
|
|
keep : {'first', 'last'}, default 'first'
|
|
|
|
Returns
|
|
-------
|
|
nordered : Series
|
|
"""
|
|
|
|
def compute(self, method: str) -> Series:
|
|
from pandas.core.reshape.concat import concat
|
|
|
|
n = self.n
|
|
dtype = self.obj.dtype
|
|
if not self.is_valid_dtype_n_method(dtype):
|
|
raise TypeError(f"Cannot use method '{method}' with dtype {dtype}")
|
|
|
|
if n <= 0:
|
|
return self.obj[[]]
|
|
|
|
dropped = self.obj.dropna()
|
|
nan_index = self.obj.drop(dropped.index)
|
|
|
|
# slow method
|
|
if n >= len(self.obj):
|
|
ascending = method == "nsmallest"
|
|
return self.obj.sort_values(ascending=ascending).head(n)
|
|
|
|
# fast method
|
|
new_dtype = dropped.dtype
|
|
|
|
# Similar to algorithms._ensure_data
|
|
arr = dropped._values
|
|
if needs_i8_conversion(arr.dtype):
|
|
arr = arr.view("i8")
|
|
elif isinstance(arr.dtype, BaseMaskedDtype):
|
|
arr = arr._data
|
|
else:
|
|
arr = np.asarray(arr)
|
|
if arr.dtype.kind == "b":
|
|
arr = arr.view(np.uint8)
|
|
|
|
if method == "nlargest":
|
|
arr = -arr
|
|
if is_integer_dtype(new_dtype):
|
|
# GH 21426: ensure reverse ordering at boundaries
|
|
arr -= 1
|
|
|
|
elif is_bool_dtype(new_dtype):
|
|
# GH 26154: ensure False is smaller than True
|
|
arr = 1 - (-arr)
|
|
|
|
if self.keep == "last":
|
|
arr = arr[::-1]
|
|
|
|
nbase = n
|
|
narr = len(arr)
|
|
n = min(n, narr)
|
|
|
|
# arr passed into kth_smallest must be contiguous. We copy
|
|
# here because kth_smallest will modify its input
|
|
kth_val = libalgos.kth_smallest(arr.copy(order="C"), n - 1)
|
|
(ns,) = np.nonzero(arr <= kth_val)
|
|
inds = ns[arr[ns].argsort(kind="mergesort")]
|
|
|
|
if self.keep != "all":
|
|
inds = inds[:n]
|
|
findex = nbase
|
|
else:
|
|
if len(inds) < nbase <= len(nan_index) + len(inds):
|
|
findex = len(nan_index) + len(inds)
|
|
else:
|
|
findex = len(inds)
|
|
|
|
if self.keep == "last":
|
|
# reverse indices
|
|
inds = narr - 1 - inds
|
|
|
|
return concat([dropped.iloc[inds], nan_index]).iloc[:findex]
|
|
|
|
|
|
class SelectNFrame(SelectN):
|
|
"""
|
|
Implement n largest/smallest for DataFrame
|
|
|
|
Parameters
|
|
----------
|
|
obj : DataFrame
|
|
n : int
|
|
keep : {'first', 'last'}, default 'first'
|
|
columns : list or str
|
|
|
|
Returns
|
|
-------
|
|
nordered : DataFrame
|
|
"""
|
|
|
|
def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel) -> None:
|
|
super().__init__(obj, n, keep)
|
|
if not is_list_like(columns) or isinstance(columns, tuple):
|
|
columns = [columns]
|
|
|
|
columns = cast(Sequence[Hashable], columns)
|
|
columns = list(columns)
|
|
self.columns = columns
|
|
|
|
def compute(self, method: str) -> DataFrame:
|
|
from pandas.core.api import Index
|
|
|
|
n = self.n
|
|
frame = self.obj
|
|
columns = self.columns
|
|
|
|
for column in columns:
|
|
dtype = frame[column].dtype
|
|
if not self.is_valid_dtype_n_method(dtype):
|
|
raise TypeError(
|
|
f"Column {repr(column)} has dtype {dtype}, "
|
|
f"cannot use method {repr(method)} with this dtype"
|
|
)
|
|
|
|
def get_indexer(current_indexer, other_indexer):
|
|
"""
|
|
Helper function to concat `current_indexer` and `other_indexer`
|
|
depending on `method`
|
|
"""
|
|
if method == "nsmallest":
|
|
return current_indexer.append(other_indexer)
|
|
else:
|
|
return other_indexer.append(current_indexer)
|
|
|
|
# Below we save and reset the index in case index contains duplicates
|
|
original_index = frame.index
|
|
cur_frame = frame = frame.reset_index(drop=True)
|
|
cur_n = n
|
|
indexer = Index([], dtype=np.int64)
|
|
|
|
for i, column in enumerate(columns):
|
|
# For each column we apply method to cur_frame[column].
|
|
# If it's the last column or if we have the number of
|
|
# results desired we are done.
|
|
# Otherwise there are duplicates of the largest/smallest
|
|
# value and we need to look at the rest of the columns
|
|
# to determine which of the rows with the largest/smallest
|
|
# value in the column to keep.
|
|
series = cur_frame[column]
|
|
is_last_column = len(columns) - 1 == i
|
|
values = getattr(series, method)(
|
|
cur_n, keep=self.keep if is_last_column else "all"
|
|
)
|
|
|
|
if is_last_column or len(values) <= cur_n:
|
|
indexer = get_indexer(indexer, values.index)
|
|
break
|
|
|
|
# Now find all values which are equal to
|
|
# the (nsmallest: largest)/(nlargest: smallest)
|
|
# from our series.
|
|
border_value = values == values[values.index[-1]]
|
|
|
|
# Some of these values are among the top-n
|
|
# some aren't.
|
|
unsafe_values = values[border_value]
|
|
|
|
# These values are definitely among the top-n
|
|
safe_values = values[~border_value]
|
|
indexer = get_indexer(indexer, safe_values.index)
|
|
|
|
# Go on and separate the unsafe_values on the remaining
|
|
# columns.
|
|
cur_frame = cur_frame.loc[unsafe_values.index]
|
|
cur_n = n - len(indexer)
|
|
|
|
frame = frame.take(indexer)
|
|
|
|
# Restore the index on frame
|
|
frame.index = original_index.take(indexer)
|
|
|
|
# If there is only one column, the frame is already sorted.
|
|
if len(columns) == 1:
|
|
return frame
|
|
|
|
ascending = method == "nsmallest"
|
|
|
|
return frame.sort_values(columns, ascending=ascending, kind="mergesort")
|