3RNN/Lib/site-packages/optree/utils.py
2024-05-26 19:49:15 +02:00

116 lines
3.4 KiB
Python

# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for OpTree."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, overload
if TYPE_CHECKING:
from optree.typing import S, T, U
def total_order_sorted(
    iterable: Iterable[T],
    *,
    key: Callable[[T], Any] | None = None,
    reverse: bool = False,
) -> list[T]:
    """Return the items of ``iterable`` as a list sorted under a best-effort total order.

    First the items (or ``key(item)`` when ``key`` is given) are sorted directly. If that
    raises :class:`TypeError` (e.g. ``int`` mixed with ``str``), the sort is retried with each
    sort key prefixed by the fully qualified class name of the value being compared, so values
    are grouped by type and ordered within each type. If even that fails (e.g. values of one
    type that do not support ``<``), the items are returned in their original iteration order.

    Args:
        iterable: The items to sort. It is consumed exactly once.
        key: Optional one-argument function that extracts the comparison value of each item.
        reverse: Sort in descending order when true. This does not apply to the
            original-order fallback.

    Returns:
        A new list with the items.
    """
    items = list(iterable)
    try:
        # Fast path: everything is mutually comparable.
        return sorted(items, key=key, reverse=reverse)  # type: ignore[type-var,arg-type]
    except TypeError:

        def tagged(item: T) -> tuple[str, Any]:
            # Prefix the comparison value with '<module>.<qualname>' of its class so that
            # values of different types are compared by type name first.
            value = item if key is None else key(item)
            cls = value.__class__
            return f'{cls.__module__}.{cls.__qualname__}', value

        try:
            return sorted(items, key=tagged, reverse=reverse)
        except TypeError:
            # Values of the same type are still not orderable; keep the input order.
            return items
@overload
def safe_zip(
    __iter1: Iterable[T],
) -> zip[tuple[T]]:  # pragma: no cover
    ...


@overload
def safe_zip(
    __iter1: Iterable[T],
    __iter2: Iterable[S],
) -> zip[tuple[T, S]]:  # pragma: no cover
    ...


@overload
def safe_zip(
    __iter1: Iterable[T],
    __iter2: Iterable[S],
    __iter3: Iterable[U],
) -> zip[tuple[T, S, U]]:  # pragma: no cover
    ...


@overload
def safe_zip(
    __iter1: Iterable[Any],
    __iter2: Iterable[Any],
    __iter3: Iterable[Any],
    __iter4: Iterable[Any],
    *__iters: Iterable[Any],
) -> zip[tuple[Any, ...]]:  # pragma: no cover
    ...


def safe_zip(*args: Iterable[Any]) -> zip[tuple[Any, ...]]:
    """Zip the arguments together, requiring that they all have the same length.

    Arguments that are not already :class:`Sequence` instances are materialized into lists,
    so their lengths can be checked before any pairing happens.

    Raises:
        ValueError: If the arguments do not all have the same length.
    """
    sequences = []
    for candidate in args:
        if not isinstance(candidate, Sequence):
            candidate = list(candidate)
        sequences.append(candidate)
    lengths = [len(seq) for seq in sequences]
    # With zero arguments the generator is empty, so lengths[0] is never evaluated.
    if any(length != lengths[0] for length in lengths):
        raise ValueError(f'length mismatch: {lengths}')
    return zip(*sequences)
def unzip2(xys: Iterable[tuple[T, S]]) -> tuple[tuple[T, ...], tuple[S, ...]]:
    """Split an iterable of 2-tuples into a tuple of first items and a tuple of second items.

    Always returns a pair of tuples, even for empty input.

    Raises:
        ValueError: If some element does not unpack into exactly two values.
    """
    # zip(*xys) is deliberately avoided: it is lazy, too permissive about its inputs, and does
    # not always produce two outputs (e.g. tuple(zip(*{}.items())) == ()).
    # Consume the input once, checking that every element unpacks into exactly two values.
    pairs = [(first, second) for first, second in xys]
    firsts = tuple(first for first, _ in pairs)
    seconds = tuple(second for _, second in pairs)
    return firsts, seconds