303 lines
9.3 KiB
Python
303 lines
9.3 KiB
Python
# mypy: disable-error-code="type-arg"
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
NamedTuple,
|
|
Optional,
|
|
overload,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from _typeshed import Incomplete
|
|
from typing_extensions import Never, TypeAlias
|
|
|
|
import torch
|
|
from torch._classes import classes as classes
|
|
from torch._jit_internal import _qualified_name as _qualified_name
|
|
from torch.jit._builtins import _register_builtin as _register_builtin
|
|
from torch.jit._fuser import (
|
|
_graph_for as _graph_for,
|
|
_script_method_graph_for as _script_method_graph_for,
|
|
)
|
|
from torch.jit._monkeytype_config import (
|
|
JitTypeTraceConfig as JitTypeTraceConfig,
|
|
JitTypeTraceStore as JitTypeTraceStore,
|
|
monkeytype_trace as monkeytype_trace,
|
|
)
|
|
from torch.jit._recursive import (
|
|
_compile_and_register_class as _compile_and_register_class,
|
|
infer_methods_to_compile as infer_methods_to_compile,
|
|
ScriptMethodStub as ScriptMethodStub,
|
|
wrap_cpp_module as wrap_cpp_module,
|
|
)
|
|
from torch.jit._state import (
|
|
_enabled as _enabled,
|
|
_set_jit_function_cache as _set_jit_function_cache,
|
|
_set_jit_overload_cache as _set_jit_overload_cache,
|
|
_try_get_jit_cached_function as _try_get_jit_cached_function,
|
|
_try_get_jit_cached_overloads as _try_get_jit_cached_overloads,
|
|
)
|
|
from torch.jit.frontend import (
|
|
get_default_args as get_default_args,
|
|
get_jit_class_def as get_jit_class_def,
|
|
get_jit_def as get_jit_def,
|
|
)
|
|
from torch.nn import Module as Module
|
|
from torch.overrides import (
|
|
has_torch_function as has_torch_function,
|
|
has_torch_function_unary as has_torch_function_unary,
|
|
has_torch_function_variadic as has_torch_function_variadic,
|
|
)
|
|
from torch.package import (
|
|
PackageExporter as PackageExporter,
|
|
PackageImporter as PackageImporter,
|
|
)
|
|
from torch.utils import set_module as set_module
|
|
|
|
from ._serialization import validate_map_location as validate_map_location
|
|
|
|
ScriptFunction = torch._C.ScriptFunction
|
|
|
|
type_trace_db: JitTypeTraceStore
|
|
|
|
# Defined in torch/csrc/jit/python/script_init.cpp
|
|
ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
|
|
_ClassVar = TypeVar("_ClassVar", bound=type)
|
|
|
|
def _reduce(cls) -> None: ...
|
|
|
|
class Attribute(NamedTuple):
|
|
value: Incomplete
|
|
type: Incomplete
|
|
|
|
def _get_type_trace_db(): ...
|
|
def _get_function_from_type(cls, name): ...
|
|
def _is_new_style_class(cls): ...
|
|
|
|
class OrderedDictWrapper:
|
|
_c: Incomplete
|
|
def __init__(self, _c) -> None: ...
|
|
def keys(self): ...
|
|
def values(self): ...
|
|
def __len__(self) -> int: ...
|
|
def __delitem__(self, k) -> None: ...
|
|
def items(self): ...
|
|
def __setitem__(self, k, v) -> None: ...
|
|
def __contains__(self, k) -> bool: ...
|
|
def __getitem__(self, k): ...
|
|
|
|
class OrderedModuleDict(OrderedDictWrapper):
|
|
_python_modules: Incomplete
|
|
def __init__(self, module, python_dict) -> None: ...
|
|
def items(self): ...
|
|
def __contains__(self, k) -> bool: ...
|
|
def __setitem__(self, k, v) -> None: ...
|
|
def __getitem__(self, k): ...
|
|
|
|
class ScriptMeta(type):
|
|
def __init__(cls, name, bases, attrs) -> None: ...
|
|
|
|
class _CachedForward:
|
|
def __get__(self, obj, cls): ...
|
|
|
|
class ScriptWarning(Warning): ...
|
|
|
|
def script_method(fn): ...
|
|
|
|
class ConstMap:
|
|
const_mapping: Incomplete
|
|
def __init__(self, const_mapping) -> None: ...
|
|
def __getattr__(self, attr): ...
|
|
|
|
def unpackage_script_module(
|
|
importer: PackageImporter, script_module_id: str
|
|
) -> torch.nn.Module: ...
|
|
|
|
_magic_methods: Incomplete
|
|
|
|
class RecursiveScriptClass:
|
|
_c: Incomplete
|
|
_props: Incomplete
|
|
def __init__(self, cpp_class) -> None: ...
|
|
def __getattr__(self, attr): ...
|
|
def __setattr__(self, attr, value): ...
|
|
def forward_magic_method(self, method_name, *args, **kwargs): ...
|
|
def __getstate__(self) -> None: ...
|
|
def __iadd__(self, other): ...
|
|
|
|
def method_template(self, *args, **kwargs): ...
|
|
|
|
class ScriptModule(Module, metaclass=ScriptMeta):
|
|
__jit_unused_properties__: Incomplete
|
|
def __init__(self) -> None: ...
|
|
forward: Callable[..., Any]
|
|
def __getattr__(self, attr): ...
|
|
def __setattr__(self, attr, value): ...
|
|
def define(self, src): ...
|
|
def _replicate_for_data_parallel(self): ...
|
|
def __reduce_package__(self, exporter: PackageExporter): ...
|
|
# add __jit_unused_properties__
|
|
@property
|
|
def code(self) -> str: ...
|
|
@property
|
|
def code_with_constants(self) -> Tuple[str, ConstMap]: ...
|
|
@property
|
|
def graph(self) -> torch.Graph: ...
|
|
@property
|
|
def inlined_graph(self) -> torch.Graph: ...
|
|
@property
|
|
def original_name(self) -> str: ...
|
|
|
|
class RecursiveScriptModule(ScriptModule):
|
|
_disable_script_meta: bool
|
|
_c: Incomplete
|
|
def __init__(self, cpp_module) -> None: ...
|
|
@staticmethod
|
|
def _construct(cpp_module, init_fn): ...
|
|
@staticmethod
|
|
def _finalize_scriptmodule(script_module) -> None: ...
|
|
_concrete_type: Incomplete
|
|
_modules: Incomplete
|
|
_parameters: Incomplete
|
|
_buffers: Incomplete
|
|
__dict__: Incomplete
|
|
def _reconstruct(self, cpp_module) -> None: ...
|
|
def save(self, f, **kwargs): ...
|
|
def _save_for_lite_interpreter(self, *args, **kwargs): ...
|
|
def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ...
|
|
def save_to_buffer(self, *args, **kwargs): ...
|
|
def get_debug_state(self, *args, **kwargs): ...
|
|
def extra_repr(self): ...
|
|
def graph_for(self, *args, **kwargs): ...
|
|
def define(self, src) -> None: ...
|
|
def __getattr__(self, attr): ...
|
|
def __setattr__(self, attr, value): ...
|
|
def __copy__(self): ...
|
|
def __deepcopy__(self, memo): ...
|
|
def forward_magic_method(self, method_name, *args, **kwargs): ...
|
|
def __iter__(self): ...
|
|
def __getitem__(self, idx): ...
|
|
def __len__(self) -> int: ...
|
|
def __contains__(self, key) -> bool: ...
|
|
def __dir__(self): ...
|
|
def __bool__(self) -> bool: ...
|
|
def _replicate_for_data_parallel(self): ...
|
|
|
|
def _get_methods(cls): ...
|
|
|
|
_compiled_methods_allowlist: Incomplete
|
|
|
|
def _make_fail(name): ...
|
|
def call_prepare_scriptable_func_impl(obj, memo): ...
|
|
def call_prepare_scriptable_func(obj): ...
|
|
def create_script_dict(obj): ...
|
|
def create_script_list(obj, type_hint: Incomplete | None = ...): ...
|
|
@overload
|
|
def script(
|
|
obj: Type[Module],
|
|
optimize: Optional[bool] = None,
|
|
_frames_up: int = 0,
|
|
_rcb: Optional[ResolutionCallback] = None,
|
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
|
) -> Never: ...
|
|
@overload
|
|
def script( # type: ignore[misc]
|
|
obj: Dict,
|
|
optimize: Optional[bool] = None,
|
|
_frames_up: int = 0,
|
|
_rcb: Optional[ResolutionCallback] = None,
|
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
|
) -> torch.ScriptDict: ...
|
|
@overload
|
|
def script( # type: ignore[misc]
|
|
obj: List,
|
|
optimize: Optional[bool] = None,
|
|
_frames_up: int = 0,
|
|
_rcb: Optional[ResolutionCallback] = None,
|
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
|
) -> torch.ScriptList: ...
|
|
@overload
|
|
def script( # type: ignore[misc]
|
|
obj: Module,
|
|
optimize: Optional[bool] = None,
|
|
_frames_up: int = 0,
|
|
_rcb: Optional[ResolutionCallback] = None,
|
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
|
) -> RecursiveScriptModule: ...
|
|
@overload
|
|
def script( # type: ignore[misc]
|
|
obj: _ClassVar,
|
|
optimize: Optional[bool] = None,
|
|
_frames_up: int = 0,
|
|
_rcb: Optional[ResolutionCallback] = None,
|
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
|
) -> _ClassVar: ...
|
|
@overload
|
|
def script( # type: ignore[misc]
|
|
obj: Callable,
|
|
optimize: Optional[bool] = None,
|
|
_frames_up: int = 0,
|
|
_rcb: Optional[ResolutionCallback] = None,
|
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
|
) -> ScriptFunction: ...
|
|
@overload
|
|
def script(
|
|
obj: Any,
|
|
optimize: Optional[bool] = None,
|
|
_frames_up: int = 0,
|
|
_rcb: Optional[ResolutionCallback] = None,
|
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
|
) -> RecursiveScriptClass: ...
|
|
@overload
|
|
def script(
|
|
obj,
|
|
optimize: Incomplete | None = ...,
|
|
_frames_up: int = ...,
|
|
_rcb: Incomplete | None = ...,
|
|
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = ...,
|
|
): ...
|
|
def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
|
|
def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
|
|
def _get_overloads(obj): ...
|
|
def _check_directly_compile_overloaded(obj) -> None: ...
|
|
def interface(obj): ...
|
|
def _recursive_compile_class(obj, loc): ...
|
|
|
|
CompilationUnit: Incomplete
|
|
|
|
def pad(s: str, padding: int, offset: int = ..., char: str = ...): ...
|
|
|
|
class _ScriptProfileColumn:
|
|
header: Incomplete
|
|
alignment: Incomplete
|
|
offset: Incomplete
|
|
rows: Incomplete
|
|
def __init__(
|
|
self, header: str, alignment: int = ..., offset: int = ...
|
|
) -> None: ...
|
|
def add_row(self, lineno: int, value: Any): ...
|
|
def materialize(self): ...
|
|
|
|
class _ScriptProfileTable:
|
|
cols: Incomplete
|
|
source_range: Incomplete
|
|
def __init__(
|
|
self, cols: List[_ScriptProfileColumn], source_range: List[int]
|
|
) -> None: ...
|
|
def dump_string(self): ...
|
|
|
|
class _ScriptProfile:
|
|
profile: Incomplete
|
|
def __init__(self) -> None: ...
|
|
def enable(self) -> None: ...
|
|
def disable(self) -> None: ...
|
|
def dump_string(self) -> str: ...
|
|
def dump(self) -> None: ...
|
|
|
|
def _unwrap_optional(x): ...
|