#pragma once #include namespace at::impl { // VmapMode contains a thread local count of how many nested vmaps // we are currently inside. That number is known as the `vmap level`. // VmapMode is used in the implementation of the Python `torch.vmap` API. // // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet. struct TORCH_API VmapMode { // Returns the vmap level, aka the count of how many nested vmaps we're in. static int64_t current_vmap_level(); // Increment the count of nested vmaps. If this causes the vmap level to be // greater than 0, then it enables DispatchKey::VmapMode on all tensors. static int64_t increment_nesting(); // Decrements the count of nested vmaps. If this causes the vmap level to be // equal to 0, then it disables DispatchKey::VmapMode on all tensors. static int64_t decrement_nesting(); }; } // namespace at::impl