133 lines
5.4 KiB
C
133 lines
5.4 KiB
C
|
#pragma once
|
||
|
|
||
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
||
|
#include <c10/macros/Export.h>
|
||
|
#include <c10/macros/Macros.h>
|
||
|
|
||
|
// NOTE [Tracing Mode Switches]
|
||
|
//
|
||
|
// Historically, tracing function was controlled by two switches:
|
||
|
//
|
||
|
// - `AutoDispatchBelowADInplaceOrView` guard
|
||
|
//
|
||
|
// Tracing function used to be script-generated inside `VariableType_*.cpp`
|
||
|
// kernels, sharing the same `Autograd` dispatch key with autograd function.
|
||
|
// Therefore, before tracing function was moved out of VariableType,
|
||
|
// `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a
|
||
|
// side effect of disabling `Autograd` dispatching.
|
||
|
//
|
||
|
// - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
|
||
|
//
|
||
|
// It stores tracing data in a `TracingState` object in TLS. If the
|
||
|
// `TracingState` object in TLS is `null`, then tracing is paused.
|
||
|
//
|
||
|
// The `TracingState` object is created in `tracer::trace()` - the main
|
||
|
// entrance of tracing function. It's temporarily set to `null` inside
|
||
|
// generated VariableType (now TraceType) to bypass tracing for intermediate
|
||
|
// ops (ops being called by other ops). After the intermediate op call
|
||
|
// finishes, it's set back to the original `TracingState` object.
|
||
|
//
|
||
|
// The `TracingState` object in TLS can also be read/written via its Python
|
||
|
// binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs,
|
||
|
// which are also exposed as `TORCH_API`.
|
||
|
//
|
||
|
// Two new switches were introduced since tracing function was moved out of
|
||
|
// VariableType:
|
||
|
//
|
||
|
// - `tracer::impl::set_dispatch_enabled()` API
|
||
|
//
|
||
|
// Unlike the special `Autograd` dispatch key which is included in dispatch
|
||
|
// key set by default, `Tracer` dispatch key is off by default. The
|
||
|
// dispatching switch can be toggled via this new API.
|
||
|
//
|
||
|
// - `tracer::impl::NoTracerDispatchMode` guard
|
||
|
//
|
||
|
// It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView`
|
||
|
// after tracing was moved out of VariableType.
|
||
|
//
|
||
|
// Before tracing function was moved out of VariableType, tracing was enabled
|
||
|
// when the following conditions are satisfied:
|
||
|
//
|
||
|
// 1) `TracingState` object in TLS != null;
|
||
|
// - Either inside the execution scope of `tracer::trace()`, or
|
||
|
// - Eagerly called `setTracingState()` with non-null object.
|
||
|
// 2) Not inside `AutoDispatchBelowADInplaceOrView` scope;
|
||
|
//
|
||
|
// After:
|
||
|
//
|
||
|
// 1) `TracingState` object in TLS != null;
|
||
|
// 2) Has called `tracer::impl::set_dispatch_enabled(true)`;
|
||
|
// 3) Not inside `tracer::impl::NoTracerDispatchMode` scope;
|
||
|
//
|
||
|
// [TODOs]
|
||
|
//
|
||
|
// - `setTracingState()` v.s. `tracer::impl::set_dispatch_enabled()`
|
||
|
//
|
||
|
// Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
|
||
|
// to keep the semantics exactly the same as before - it's confusing to keep
|
||
|
// both switches, though. We should consider simplifying/limiting the exposed
|
||
|
// `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
|
||
|
// these two can be unified.
|
||
|
//
|
||
|
// - `AutoDispatchBelowADInplaceOrView` v.s.
|
||
|
// `tracer::impl::NoTracerDispatchMode`
|
||
|
//
|
||
|
// We don't need to always set both guards together to keep semantics
|
||
|
// unchanged. For the following use cases of `AutoDispatchBelowADInplaceOrView`
|
||
|
// we don't need to set the new tracer guard:
|
||
|
//
|
||
|
// * Script-generated VariableType kernels. The guard is not necessary as
|
||
|
// tracing is already disabled explicitly by `setTracingState(null)` in
|
||
|
// generated TraceType kernels - we could keep it as is or use the new guard
|
||
|
// instead.
|
||
|
//
|
||
|
// * Custom ops. Will be handled by fallback kernel for `Tracer`.
|
||
|
//
|
||
|
// * Functions that are not likely to be called in tracing context (no python
|
||
|
// binding / not an operator), e.g.: all mobile forward() wrappers, test
|
||
|
// binaries, etc.
|
||
|
//
|
||
|
// * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
|
||
|
// It's not necessary as tracing is off by default.
|
||
|
//
|
||
|
// For the rest of the cases we might need to have both:
|
||
|
//
|
||
|
// * Functions that might be reachable from eager mode python (especially
|
||
|
// factory methods), e.g.:
|
||
|
// `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
|
||
|
// Without the new guard it will add `aten::empty` to the traced graph.
|
||
|
//
|
||
|
// * Some manually maintained functions, e.g.:
|
||
|
// `torch/csrc/autograd/VariableTypeManual.cpp`.
|
||
|
// Set the new guard if it's not obvious whether `setTracingState(null)`
|
||
|
// has been called before it reaches the `AutoDispatchBelowADInplaceOrView`
|
||
|
// guard.
|
||
|
//
|
||
|
// We might need to tweak the usage of the new guard to optimize/fix things.
|
||
|
// It should only affect the correctness of tracing function, because the
|
||
|
// guard is essentially no-op when the master `setTracingState()` switch is
|
||
|
// off.
|
||
|
|
||
|
// TODO: move this from `at::` to `torch::jit::` after
|
||
|
// `aten/src/ATen/cpp_custom_type_hack.h` is removed.
|
||
|
|
||
|
namespace at::tracer::impl {
|
||
|
|
||
|
static inline bool is_dispatch_enabled() {
|
||
|
return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
|
||
|
!c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
|
||
|
}
|
||
|
|
||
|
static inline void set_dispatch_enabled(bool enabled) {
|
||
|
TORCH_INTERNAL_ASSERT(
|
||
|
!c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
|
||
|
"Cannot enable tracing within the scope of NoTracerDispatchMode!");
|
||
|
c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
|
||
|
}
|
||
|
|
||
|
struct NoTracerDispatchMode {
|
||
|
c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
|
||
|
};
|
||
|
|
||
|
} // namespace at::tracer::impl
|