# 3RNN/Lib/site-packages/pandas/core/indexes/api.py
# 2024-05-26 19:49:15 +02:00
#
# 389 lines
# 10 KiB
# Python

from __future__ import annotations
import textwrap
from typing import (
TYPE_CHECKING,
cast,
)
import numpy as np
from pandas._libs import (
NaT,
lib,
)
from pandas.errors import InvalidIndexError
from pandas.core.dtypes.cast import find_common_type
from pandas.core.algorithms import safe_sort
from pandas.core.indexes.base import (
Index,
_new_Index,
ensure_index,
ensure_index_from_sequences,
get_unanimous_names,
)
from pandas.core.indexes.category import CategoricalIndex
from pandas.core.indexes.datetimes import DatetimeIndex
from pandas.core.indexes.interval import IntervalIndex
from pandas.core.indexes.multi import MultiIndex
from pandas.core.indexes.period import PeriodIndex
from pandas.core.indexes.range import RangeIndex
from pandas.core.indexes.timedeltas import TimedeltaIndex
if TYPE_CHECKING:
from pandas._typing import Axis
# Warning text telling users that a non-aligned, non-concatenation axis is
# being sorted and how to opt in/out via the ``sort`` keyword.
# NOTE(review): not referenced anywhere in this module; presumably imported by
# callers elsewhere — confirm before changing or removing.
_sort_msg = textwrap.dedent(
    """\
Sorting because non-concatenation axis is not aligned. A future version
of pandas will change to not sort by default.
To accept the future behavior, pass 'sort=False'.
To retain the current behavior and silence the warning, pass 'sort=True'.
"""
)
# Names this module re-exports as its public API (also what
# ``from ... import *`` picks up).
__all__ = [
    "Index",
    "MultiIndex",
    "CategoricalIndex",
    "IntervalIndex",
    "RangeIndex",
    "InvalidIndexError",
    "TimedeltaIndex",
    "PeriodIndex",
    "DatetimeIndex",
    "_new_Index",
    "NaT",
    "ensure_index",
    "ensure_index_from_sequences",
    "get_objs_combined_axis",
    "union_indexes",
    "get_unanimous_names",
    "all_indexes_same",
    "default_index",
    "safe_sort_index",
]
def get_objs_combined_axis(
    objs,
    intersect: bool = False,
    axis: Axis = 0,
    sort: bool = True,
    copy: bool = False,
) -> Index:
    """
    Extract combined index: return intersection or union (depending on the
    value of "intersect") of indexes on given axis.

    Parameters
    ----------
    objs : list
        Series or DataFrame objects, may be mix of the two. Every element
        must provide ``_get_axis``; plain arrays are not supported.
    intersect : bool, default False
        If True, calculate the intersection between indexes. Otherwise,
        calculate the union.
    axis : {0 or 'index', 1 or 'columns'}, default 0
        The axis to extract indexes from.
    sort : bool, default True
        Whether the result index should come out sorted or not.
    copy : bool, default False
        If True, return a copy of the combined index.

    Returns
    -------
    Index
        The combined index; an empty ``objs`` yields an empty Index.
    """
    obj_idxes = [obj._get_axis(axis) for obj in objs]
    return _get_combined_index(obj_idxes, intersect=intersect, sort=sort, copy=copy)
def _get_distinct_objs(objs: list[Index]) -> list[Index]:
"""
Return a list with distinct elements of "objs" (different ids).
Preserves order.
"""
ids: set[int] = set()
res = []
for obj in objs:
if id(obj) not in ids:
ids.add(id(obj))
res.append(obj)
return res
def _get_combined_index(
    indexes: list[Index],
    intersect: bool = False,
    sort: bool = False,
    copy: bool = False,
) -> Index:
    """
    Return the union or intersection of indexes.

    Parameters
    ----------
    indexes : list of Index or list objects
        The same object passed more than once is only considered once.
    intersect : bool, default False
        If True, calculate the intersection between indexes. Otherwise,
        calculate the union.
    sort : bool, default False
        Whether the result index should come out sorted or not.
    copy : bool, default False
        If True, return a copy of the combined index.

    Returns
    -------
    Index
        Always an ``Index`` (list inputs are passed through ``ensure_index``);
        an empty ``indexes`` yields an empty ``Index``.
    """
    # TODO: handle index names!
    indexes = _get_distinct_objs(indexes)
    if len(indexes) == 0:
        index = Index([])
    elif len(indexes) == 1:
        index = indexes[0]
    elif intersect:
        # Coerce the seed so a list first element still has ``.intersection``.
        index = ensure_index(indexes[0])
        for other in indexes[1:]:
            index = index.intersection(other)
    else:
        index = union_indexes(indexes, sort=False)

    # Normalize every branch (notably a single raw list) to an Index so the
    # sort/copy steps below work and the ``-> Index`` contract holds.
    index = ensure_index(index)

    if sort:
        index = safe_sort_index(index)
    # GH 29879
    if copy:
        index = index.copy()
    return index
def safe_sort_index(index: Index) -> Index:
    """
    Return ``index`` in sorted order, keeping its dtype and name(s).

    An index that is already monotonically increasing is returned as-is.
    When ``safe_sort`` raises ``TypeError`` the original index is returned
    unchanged.

    Parameters
    ----------
    index : Index
        The index to sort (a MultiIndex keeps its level names).

    Returns
    -------
    Index
    """
    if index.is_monotonic_increasing:
        return index

    try:
        sorted_values = safe_sort(index)
    except TypeError:
        # Values cannot be ordered; fall back to the original order.
        return index

    if isinstance(sorted_values, Index):
        return sorted_values

    # Otherwise an ndarray came back: rebuild an index of the original kind.
    sorted_values = cast(np.ndarray, sorted_values)
    if isinstance(index, MultiIndex):
        return MultiIndex.from_tuples(sorted_values, names=index.names)
    return Index(sorted_values, name=index.name, dtype=index.dtype)
def union_indexes(indexes, sort: bool | None = True) -> Index:
    """
    Return the union of indexes.

    The behavior of sort and names is not consistent.

    Parameters
    ----------
    indexes : list of Index or list objects
        Must contain at least one element. A single element is returned
        directly (a list is first converted to an Index).
    sort : bool or None, default True
        Whether the result index should come out sorted or not. Falsy values
        (``False`` / ``None``) skip sorting, except that when every input is a
        DatetimeIndex sorting is forced on.

    Returns
    -------
    Index

    Raises
    ------
    AssertionError
        If ``indexes`` is empty.
    TypeError
        If tz-naive and tz-aware DatetimeIndex objects are mixed.
    """
    if len(indexes) == 0:
        raise AssertionError("Must have at least 1 Index to union")
    if len(indexes) == 1:
        result = indexes[0]
        if isinstance(result, list):
            if not sort:
                result = Index(result)
            else:
                result = Index(sorted(result))
        return result

    indexes, kind = _sanitize_and_check(indexes)

    def _unique_indices(inds, dtype) -> Index:
        """
        Concatenate indices and remove duplicates.

        Parameters
        ----------
        inds : list of Index or list objects
            At least two elements (guaranteed by the caller above).
        dtype : dtype to set for the resulting Index

        Returns
        -------
        Index
        """
        if all(isinstance(ind, Index) for ind in inds):
            inds = [ind.astype(dtype, copy=False) for ind in inds]
            result = inds[0].unique()
            other = inds[1].append(inds[2:])
            # Keep only values of the later indexes not already in ``result``.
            diff = other[result.get_indexer_for(other) == -1]
            if len(diff):
                result = result.append(diff.unique())
            if sort:
                result = result.sort_values()
            return result

        def conv(i):
            if isinstance(i, Index):
                i = i.tolist()
            return i

        return Index(
            lib.fast_unique_multiple_list([conv(i) for i in inds], sort=sort),
            dtype=dtype,
        )

    def _find_common_index_dtype(inds):
        """
        Finds a common type for the indexes to pass through to resulting index.

        Parameters
        ----------
        inds: list of Index or list objects

        Returns
        -------
        The common type or None if no indexes were given
        """
        # Use the argument (not the enclosing ``indexes``) so the helper
        # honors its own contract.
        dtypes = [idx.dtype for idx in inds if isinstance(idx, Index)]
        if dtypes:
            return find_common_type(dtypes)
        return None

    if kind == "special":
        result = indexes[0]

        dtis = [x for x in indexes if isinstance(x, DatetimeIndex)]
        dti_tzs = [x for x in dtis if x.tz is not None]
        if len(dti_tzs) not in [0, len(dtis)]:
            # TODO: this behavior is not tested (so may not be desired),
            # but is kept in order to keep behavior the same when
            # deprecating union_many
            # test_frame_from_dict_with_mixed_indexes
            raise TypeError("Cannot join tz-naive with tz-aware DatetimeIndex")

        if len(dtis) == len(indexes):
            sort = True
        elif len(dtis) > 1:
            # If we have mixed timezones, our casting behavior may depend on
            # the order of indexes, which we don't want.
            sort = False
            # TODO: what about Categorical[dt64]?
            # test_frame_from_dict_with_mixed_indexes
            indexes = [x.astype(object, copy=False) for x in indexes]
            result = indexes[0]

        for other in indexes[1:]:
            result = result.union(other, sort=None if sort else False)
        return result
    elif kind == "array":
        dtype = _find_common_index_dtype(indexes)
        index = indexes[0]
        if not all(index.equals(other) for other in indexes[1:]):
            index = _unique_indices(indexes, dtype)
        # Keep a name only if every input agrees on it.
        name = get_unanimous_names(*indexes)[0]
        if name != index.name:
            index = index.rename(name)
        return index
    else:  # kind='list'
        dtype = _find_common_index_dtype(indexes)
        return _unique_indices(indexes, dtype)
def _sanitize_and_check(indexes):
"""
Verify the type of indexes and convert lists to Index.
Cases:
- [list, list, ...]: Return ([list, list, ...], 'list')
- [list, Index, ...]: Return _sanitize_and_check([Index, Index, ...])
Lists are sorted and converted to Index.
- [Index, Index, ...]: Return ([Index, Index, ...], TYPE)
TYPE = 'special' if at least one special type, 'array' otherwise.
Parameters
----------
indexes : list of Index or list objects
Returns
-------
sanitized_indexes : list of Index or list objects
type : {'list', 'array', 'special'}
"""
kinds = list({type(index) for index in indexes})
if list in kinds:
if len(kinds) > 1:
indexes = [
Index(list(x)) if not isinstance(x, Index) else x for x in indexes
]
kinds.remove(list)
else:
return indexes, "list"
if len(kinds) > 1 or Index not in kinds:
return indexes, "special"
else:
return indexes, "array"
def all_indexes_same(indexes) -> bool:
    """
    Determine if all indexes contain the same elements.

    Comparison is done with ``Index.equals`` against the first index.

    Parameters
    ----------
    indexes : iterable of Index objects
        May be any iterable (consumed once); may be empty.

    Returns
    -------
    bool
        True if all indexes contain the same elements, False otherwise.
        An empty iterable is vacuously True.
    """
    itr = iter(indexes)
    try:
        first = next(itr)
    except StopIteration:
        # Avoid leaking a bare StopIteration (which becomes RuntimeError
        # inside generators); nothing to compare means "all the same".
        return True
    return all(first.equals(index) for index in itr)
def default_index(n: int) -> RangeIndex:
    """
    Build the default positional index ``0, 1, ..., n - 1``.

    Returns an unnamed RangeIndex constructed directly from ``range(n)``.
    """
    return RangeIndex._simple_new(range(n), name=None)