#pragma once #include #include #include #include #include #include #include #include #include #include #include namespace at { // Thread local state contains values that are preserved across // thread boundaries (e.g. at::launch/JIT fork, autograd). // Note at::parallel_for doesn't preserve TLS across thread boundaries. class TORCH_API ThreadLocalState { public: // Saves the thread local variables' values and // returns them as a ThreadLocalState ThreadLocalState(); // set_grad_mode - force the value of the grad mode TLS in // the current state object. This is used for example in the // autograd engine. void set_grad_mode(bool enabled); // set_multithreading_enabled - force the value of the multithreadinmaximum // threads TLS in // the current state object. This is used for example in the // autograd engine. void set_multithreading_enabled(bool enabled); // Sets thread local variables in the current thread, // according to the thread boundary specified static void setThreadLocalState(const ThreadLocalState& state); private: c10::impl::LocalDispatchKeySet dispatch_key_; // ThreadLocalDebugInfo does not change after being created // with DebugInfoGuard std::shared_ptr debug_info_; // RecordFunction TLS RecordFunctionTLS rf_tls_; // TLS for out-of-tree functorch // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a // pointer (spoiler alert: it's due to the indirection) // This needs to be a shared_ptr instead of a unique_ptr because // ThreadLocalState is copy-able and does indeed get copied. Maybe we can // consider adding an explicit copy constructor for ThreadLocalState in the // future but I didn't want to add one just for this. std::shared_ptr functorch_tls_; // TLS for AutogradModes AutogradState autograd_tls_; // TLS for enable_torch_dispatch_mode c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_; // TLS for enable_python_dispatcher c10::impl::PyInterpreter* python_dispatcher_state_; // TLS for __torch_function__ (mode and disable_torch_function) at::impl::PythonTorchFunctionTLS python_torch_function_state_; // TLS for saved tensors default hooks at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; bool functionalization_reapply_views_state_; // TLS for arbitrary python objects that is registered via hooks at::impl::ThreadLocalPythonObjects saved_objects_; friend class ThreadLocalStateGuard; }; // Guard to set and reset the thread local state class TORCH_API ThreadLocalStateGuard { public: explicit ThreadLocalStateGuard(const ThreadLocalState& state) : prev_state_(ThreadLocalState()) { // set the given state across the thread boundary ThreadLocalState::setThreadLocalState(state); } ~ThreadLocalStateGuard() { // restore previously set variables ThreadLocalState::setThreadLocalState(prev_state_); } private: // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const ThreadLocalState prev_state_; }; template auto wrapPropagateTLSState(T callback) { return [tls_state = ThreadLocalState(), callback = std::move(callback)](auto&&... args) { ThreadLocalStateGuard g(tls_state); // Propagate value returned by callback(). return callback(std::forward(args)...); }; } } // namespace at