75 lines
1.3 KiB
Python
75 lines
1.3 KiB
Python
"""
|
|
This file defines the types for type annotations.
|
|
|
|
These names aren't part of the module namespace, but they are used in the
|
|
annotations in the function signatures. The functions in the module are only
|
|
valid for inputs that match the given type annotations.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
__all__ = [
|
|
"Array",
|
|
"Device",
|
|
"Dtype",
|
|
"SupportsDLPack",
|
|
"SupportsBufferProtocol",
|
|
"PyCapsule",
|
|
]
|
|
|
|
import sys
|
|
from typing import (
|
|
Any,
|
|
Literal,
|
|
Sequence,
|
|
Type,
|
|
Union,
|
|
TYPE_CHECKING,
|
|
TypeVar,
|
|
Protocol,
|
|
)
|
|
|
|
from ._array_object import Array
|
|
from numpy import (
|
|
dtype,
|
|
int8,
|
|
int16,
|
|
int32,
|
|
int64,
|
|
uint8,
|
|
uint16,
|
|
uint32,
|
|
uint64,
|
|
float32,
|
|
float64,
|
|
)
|
|
|
|
_T_co = TypeVar("_T_co", covariant=True)
|
|
|
|
class NestedSequence(Protocol[_T_co]):
|
|
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
|
|
def __len__(self, /) -> int: ...
|
|
|
|
Device = Literal["cpu"]
|
|
if TYPE_CHECKING or sys.version_info >= (3, 9):
|
|
Dtype = dtype[Union[
|
|
int8,
|
|
int16,
|
|
int32,
|
|
int64,
|
|
uint8,
|
|
uint16,
|
|
uint32,
|
|
uint64,
|
|
float32,
|
|
float64,
|
|
]]
|
|
else:
|
|
Dtype = dtype
|
|
|
|
SupportsBufferProtocol = Any
|
|
PyCapsule = Any
|
|
|
|
class SupportsDLPack(Protocol):
|
|
def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
|