389 lines
10 KiB
Python
389 lines
10 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import textwrap
|
||
|
from typing import (
|
||
|
TYPE_CHECKING,
|
||
|
cast,
|
||
|
)
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from pandas._libs import (
|
||
|
NaT,
|
||
|
lib,
|
||
|
)
|
||
|
from pandas.errors import InvalidIndexError
|
||
|
|
||
|
from pandas.core.dtypes.cast import find_common_type
|
||
|
|
||
|
from pandas.core.algorithms import safe_sort
|
||
|
from pandas.core.indexes.base import (
|
||
|
Index,
|
||
|
_new_Index,
|
||
|
ensure_index,
|
||
|
ensure_index_from_sequences,
|
||
|
get_unanimous_names,
|
||
|
)
|
||
|
from pandas.core.indexes.category import CategoricalIndex
|
||
|
from pandas.core.indexes.datetimes import DatetimeIndex
|
||
|
from pandas.core.indexes.interval import IntervalIndex
|
||
|
from pandas.core.indexes.multi import MultiIndex
|
||
|
from pandas.core.indexes.period import PeriodIndex
|
||
|
from pandas.core.indexes.range import RangeIndex
|
||
|
from pandas.core.indexes.timedeltas import TimedeltaIndex
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from pandas._typing import Axis
|
||
|
# Warning text telling the user that the non-concatenation axis is being
# sorted and how to choose the behavior explicitly through ``sort``.
# The literal is written flush-left, so ``textwrap.dedent`` leaves it as is.
_sort_msg = textwrap.dedent(
    """\
Sorting because non-concatenation axis is not aligned. A future version
of pandas will change to not sort by default.

To accept the future behavior, pass 'sort=False'.

To retain the current behavior and silence the warning, pass 'sort=True'.
"""
)
|
||
|
|
||
|
|
||
|
# Names exported by ``from <this module> import *``: the index classes and
# helpers imported at the top of this file, plus the functions defined below.
__all__ = [
    "Index",
    "MultiIndex",
    "CategoricalIndex",
    "IntervalIndex",
    "RangeIndex",
    "InvalidIndexError",
    "TimedeltaIndex",
    "PeriodIndex",
    "DatetimeIndex",
    "_new_Index",
    "NaT",
    "ensure_index",
    "ensure_index_from_sequences",
    "get_objs_combined_axis",
    "union_indexes",
    "get_unanimous_names",
    "all_indexes_same",
    "default_index",
    "safe_sort_index",
]
|
||
|
|
||
|
|
||
|
def get_objs_combined_axis(
    objs,
    intersect: bool = False,
    axis: Axis = 0,
    sort: bool = True,
    copy: bool = False,
) -> Index:
    """
    Combine the indexes found along one axis of several objects.

    The index of every object is fetched with ``obj._get_axis(axis)`` and the
    collected indexes are passed to ``_get_combined_index``, which produces
    their intersection or their union depending on ``intersect``.

    Parameters
    ----------
    objs : list
        Series or DataFrame objects, may be mix of the two.
    intersect : bool, default False
        If True, calculate the intersection between indexes. Otherwise,
        calculate the union.
    axis : Axis, default 0
        Axis identifier forwarded unchanged to each object's ``_get_axis``.
    sort : bool, default True
        Whether the result index should come out sorted or not.
    copy : bool, default False
        If True, return a copy of the combined index.

    Returns
    -------
    Index
    """
    axis_indexes = []
    for obj in objs:
        axis_indexes.append(obj._get_axis(axis))

    return _get_combined_index(
        axis_indexes,
        intersect=intersect,
        sort=sort,
        copy=copy,
    )
|
||
|
|
||
|
|
||
|
def _get_distinct_objs(objs: list[Index]) -> list[Index]:
|
||
|
"""
|
||
|
Return a list with distinct elements of "objs" (different ids).
|
||
|
Preserves order.
|
||
|
"""
|
||
|
ids: set[int] = set()
|
||
|
res = []
|
||
|
for obj in objs:
|
||
|
if id(obj) not in ids:
|
||
|
ids.add(id(obj))
|
||
|
res.append(obj)
|
||
|
return res
|
||
|
|
||
|
|
||
|
def _get_combined_index(
    indexes: list[Index],
    intersect: bool = False,
    sort: bool = False,
    copy: bool = False,
) -> Index:
    """
    Return the union or intersection of indexes.

    Repeated references to the same object are removed first. With nothing
    left an empty ``Index`` is produced, and a single remaining index is used
    as is. Otherwise the indexes are intersected pairwise from left to right,
    or unioned via ``union_indexes``.

    Parameters
    ----------
    indexes : list of Index or list objects
        When intersect=True, do not accept list of lists.
    intersect : bool, default False
        If True, calculate the intersection between indexes. Otherwise,
        calculate the union.
    sort : bool, default False
        Whether the result index should come out sorted or not.
    copy : bool, default False
        If True, return a copy of the combined index.

    Returns
    -------
    Index
    """
    # TODO: handle index names!
    distinct = _get_distinct_objs(indexes)

    if not distinct:
        combined = Index([])
    elif len(distinct) == 1:
        combined = distinct[0]
    elif intersect:
        combined, *rest = distinct
        for other in rest:
            combined = combined.intersection(other)
    else:
        # The union is built unsorted; sorting (if requested) happens below.
        combined = ensure_index(union_indexes(distinct, sort=False))

    if sort:
        combined = safe_sort_index(combined)

    # GH 29879
    return combined.copy() if copy else combined
|
||
|
|
||
|
|
||
|
def safe_sort_index(index: Index) -> Index:
    """
    Return ``index`` in sorted order when it can be sorted.

    An index that is already monotonic increasing is handed back untouched,
    and so is one for which ``safe_sort`` raises ``TypeError``. When
    ``safe_sort`` yields a plain array, a new index is rebuilt from it with
    the original name(s) (and, for non-MultiIndex input, the original dtype).

    Parameters
    ----------
    index : an Index

    Returns
    -------
    Index
    """
    if index.is_monotonic_increasing:
        return index

    try:
        ordered = safe_sort(index)
    except TypeError:
        # Values could not be ordered: leave the index as it was.
        return index

    if isinstance(ordered, Index):
        return ordered

    ordered = cast(np.ndarray, ordered)
    if isinstance(index, MultiIndex):
        return MultiIndex.from_tuples(ordered, names=index.names)
    return Index(ordered, name=index.name, dtype=index.dtype)
|
||
|
|
||
|
|
||
|
def union_indexes(indexes, sort: bool | None = True) -> Index:
|
||
|
"""
|
||
|
Return the union of indexes.
|
||
|
|
||
|
The behavior of sort and names is not consistent.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
indexes : list of Index or list objects
|
||
|
sort : bool, default True
|
||
|
Whether the result index should come out sorted or not.
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
Index
|
||
|
"""
|
||
|
if len(indexes) == 0:
|
||
|
raise AssertionError("Must have at least 1 Index to union")
|
||
|
if len(indexes) == 1:
|
||
|
result = indexes[0]
|
||
|
if isinstance(result, list):
|
||
|
if not sort:
|
||
|
result = Index(result)
|
||
|
else:
|
||
|
result = Index(sorted(result))
|
||
|
return result
|
||
|
|
||
|
indexes, kind = _sanitize_and_check(indexes)
|
||
|
|
||
|
def _unique_indices(inds, dtype) -> Index:
|
||
|
"""
|
||
|
Concatenate indices and remove duplicates.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
inds : list of Index or list objects
|
||
|
dtype : dtype to set for the resulting Index
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
Index
|
||
|
"""
|
||
|
if all(isinstance(ind, Index) for ind in inds):
|
||
|
inds = [ind.astype(dtype, copy=False) for ind in inds]
|
||
|
result = inds[0].unique()
|
||
|
other = inds[1].append(inds[2:])
|
||
|
diff = other[result.get_indexer_for(other) == -1]
|
||
|
if len(diff):
|
||
|
result = result.append(diff.unique())
|
||
|
if sort:
|
||
|
result = result.sort_values()
|
||
|
return result
|
||
|
|
||
|
def conv(i):
|
||
|
if isinstance(i, Index):
|
||
|
i = i.tolist()
|
||
|
return i
|
||
|
|
||
|
return Index(
|
||
|
lib.fast_unique_multiple_list([conv(i) for i in inds], sort=sort),
|
||
|
dtype=dtype,
|
||
|
)
|
||
|
|
||
|
def _find_common_index_dtype(inds):
|
||
|
"""
|
||
|
Finds a common type for the indexes to pass through to resulting index.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
inds: list of Index or list objects
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
The common type or None if no indexes were given
|
||
|
"""
|
||
|
dtypes = [idx.dtype for idx in indexes if isinstance(idx, Index)]
|
||
|
if dtypes:
|
||
|
dtype = find_common_type(dtypes)
|
||
|
else:
|
||
|
dtype = None
|
||
|
|
||
|
return dtype
|
||
|
|
||
|
if kind == "special":
|
||
|
result = indexes[0]
|
||
|
|
||
|
dtis = [x for x in indexes if isinstance(x, DatetimeIndex)]
|
||
|
dti_tzs = [x for x in dtis if x.tz is not None]
|
||
|
if len(dti_tzs) not in [0, len(dtis)]:
|
||
|
# TODO: this behavior is not tested (so may not be desired),
|
||
|
# but is kept in order to keep behavior the same when
|
||
|
# deprecating union_many
|
||
|
# test_frame_from_dict_with_mixed_indexes
|
||
|
raise TypeError("Cannot join tz-naive with tz-aware DatetimeIndex")
|
||
|
|
||
|
if len(dtis) == len(indexes):
|
||
|
sort = True
|
||
|
result = indexes[0]
|
||
|
|
||
|
elif len(dtis) > 1:
|
||
|
# If we have mixed timezones, our casting behavior may depend on
|
||
|
# the order of indexes, which we don't want.
|
||
|
sort = False
|
||
|
|
||
|
# TODO: what about Categorical[dt64]?
|
||
|
# test_frame_from_dict_with_mixed_indexes
|
||
|
indexes = [x.astype(object, copy=False) for x in indexes]
|
||
|
result = indexes[0]
|
||
|
|
||
|
for other in indexes[1:]:
|
||
|
result = result.union(other, sort=None if sort else False)
|
||
|
return result
|
||
|
|
||
|
elif kind == "array":
|
||
|
dtype = _find_common_index_dtype(indexes)
|
||
|
index = indexes[0]
|
||
|
if not all(index.equals(other) for other in indexes[1:]):
|
||
|
index = _unique_indices(indexes, dtype)
|
||
|
|
||
|
name = get_unanimous_names(*indexes)[0]
|
||
|
if name != index.name:
|
||
|
index = index.rename(name)
|
||
|
return index
|
||
|
else: # kind='list'
|
||
|
dtype = _find_common_index_dtype(indexes)
|
||
|
return _unique_indices(indexes, dtype)
|
||
|
|
||
|
|
||
|
def _sanitize_and_check(indexes):
|
||
|
"""
|
||
|
Verify the type of indexes and convert lists to Index.
|
||
|
|
||
|
Cases:
|
||
|
|
||
|
- [list, list, ...]: Return ([list, list, ...], 'list')
|
||
|
- [list, Index, ...]: Return _sanitize_and_check([Index, Index, ...])
|
||
|
Lists are sorted and converted to Index.
|
||
|
- [Index, Index, ...]: Return ([Index, Index, ...], TYPE)
|
||
|
TYPE = 'special' if at least one special type, 'array' otherwise.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
indexes : list of Index or list objects
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
sanitized_indexes : list of Index or list objects
|
||
|
type : {'list', 'array', 'special'}
|
||
|
"""
|
||
|
kinds = list({type(index) for index in indexes})
|
||
|
|
||
|
if list in kinds:
|
||
|
if len(kinds) > 1:
|
||
|
indexes = [
|
||
|
Index(list(x)) if not isinstance(x, Index) else x for x in indexes
|
||
|
]
|
||
|
kinds.remove(list)
|
||
|
else:
|
||
|
return indexes, "list"
|
||
|
|
||
|
if len(kinds) > 1 or Index not in kinds:
|
||
|
return indexes, "special"
|
||
|
else:
|
||
|
return indexes, "array"
|
||
|
|
||
|
|
||
|
def all_indexes_same(indexes) -> bool:
    """
    Determine if all indexes contain the same elements.

    Every index after the first is compared to the first one with
    ``first.equals``.

    Parameters
    ----------
    indexes : iterable of Index objects
        May be any iterable, including a one-shot iterator.

    Returns
    -------
    bool
        True if all indexes contain the same elements, False otherwise.
        An empty iterable yields True.
    """
    itr = iter(indexes)
    try:
        first = next(itr)
    except StopIteration:
        # Nothing to compare: vacuously "all the same" instead of leaking
        # a bare StopIteration to the caller.
        return True
    return all(first.equals(index) for index in itr)
|
||
|
|
||
|
|
||
|
def default_index(n: int) -> RangeIndex:
    """
    Build an unnamed RangeIndex over ``range(n)``.

    Parameters
    ----------
    n : int
        Stop value passed to ``range``.

    Returns
    -------
    RangeIndex
    """
    return RangeIndex._simple_new(range(n), name=None)
|