"""JIT-related state. This module stores various pieces of Python-global state relating to the JIT. This is not intended to be imported directly; please the exposed functionalities in `torch.jit`. """ import os import weakref from typing import Any, Dict, Type import torch class EnabledProxy: """Stores whether the JIT is enabled or not. This is just a wrapper for a bool, so that we get reference semantics """ def __init__(self): self.enabled = self.parse_env( "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED" ) def parse_env(self, name, default, true_message, false_message): value = os.environ.get(name) if value is None: return default if value.lower() in {"1", "true", "yes"}: return True elif value.lower() in {"0", "false", "no"}: return False if value == "1v": print(true_message) return True elif value == "0v": print(false_message) return False raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.") def __bool__(self): return self.enabled _enabled = EnabledProxy() def disable(): _enabled.enabled = False def enable(): _enabled.enabled = True # The Python CompilationUnit. All functions and modules defined in Python will # live in here. It's defined in Python because doing in cpp creates static # destruction order issues. _python_cu = torch._C.CompilationUnit() # python class => ScriptClass mapping _script_classes: Dict[Type[Any], Type[Any]] = {} _name_to_pyclass: Dict[str, Type[Any]] = {} def _add_script_class(python_class, script_class): _script_classes[python_class] = script_class _name_to_pyclass[script_class.qualified_name()] = python_class def _get_script_class(python_class): override = getattr(python_class, "_jit_override_qualname", None) if override is not None: python_class = _get_python_class(override) return _script_classes.get(python_class, None) def _get_python_class(qualified_name): return _name_to_pyclass.get(qualified_name, None) def _clear_class_state(): _script_classes.clear() _name_to_pyclass.clear() # Caching: we currently cache compilation of free functions and overloaded functions. # To cache free functions we hold a weak ref to the function object and # map to the compiled fn's qualified name. # To cache overloaded functions we hold a weak ref to the function obj and # map to all of its overloaded compiled fns. # In the future we could consider caching more types of objects so that # aliasing is preserved across separate compilations of the same object. _jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() _jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() def _try_get_jit_cached_overloads(key): qual_names = _jit_function_overload_caching.get(key, None) if qual_names: return [_python_cu.find_function(qual_name) for qual_name in qual_names] else: return None def _set_jit_overload_cache(key, compiled_fns): _jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns] def _try_get_jit_cached_function(key): if getattr(key, "__disable_jit_function_caching__", False) is True: return None qual_name = _jit_caching_layer.get(key, None) if qual_name: return _python_cu.find_function(qual_name) else: return None def _set_jit_function_cache(key, value): # only free functions currently supported assert isinstance(value, torch.jit.ScriptFunction) _jit_caching_layer[key] = value.qualified_name