62 lines
1.5 KiB
Python
62 lines
1.5 KiB
Python
from contextlib import contextmanager
|
|
|
|
try:
    from torch._C import _nvtx
except ImportError:
    # The compiled NVTX bindings could not be imported. Install a stand-in
    # so this module still imports; any attempt to actually use NVTX then
    # fails loudly instead of silently doing nothing.
    class _NVTXStub(object):
        """Stand-in for the NVTX bindings whose entry points always raise."""

        @staticmethod
        def _fail(*_args, **_kwargs):
            # Accepts any call signature so every aliased entry point below
            # can share this single implementation.
            raise RuntimeError("NVTX functions not installed. Are you sure you have a CUDA build?")

        # The entry points that the wrappers in this module call.
        rangePushA = rangePop = markA = _fail

    _nvtx = _NVTXStub()  # type: ignore[assignment]
|
|
|
|
# Public API of this module: the names exported by a star-import.
# Note that `range` shadows the builtin of the same name when star-imported.
__all__ = ['range_push', 'range_pop', 'mark', 'range']
|
|
|
|
|
|
def range_push(msg):
    """
    Start a new NVTX range, pushing it onto the stack of nested ranges.

    Args:
        msg (string): ASCII message to associate with range

    Returns:
        The zero-based depth of the range that is started.
    """
    depth = _nvtx.rangePushA(msg)
    return depth
|
|
|
|
|
|
def range_pop():
    """
    End the innermost NVTX range, popping it off the stack of nested ranges.

    Returns:
        The zero-based depth of the range that is ended.
    """
    depth = _nvtx.rangePop()
    return depth
|
|
|
|
|
|
def mark(msg):
    """
    Record an instantaneous NVTX event (a marker with no duration).

    Args:
        msg (string): ASCII message to associate with the event.

    Returns:
        Whatever the underlying ``markA`` binding returns.
    """
    result = _nvtx.markA(msg)
    return result
|
|
|
|
|
|
@contextmanager
def range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    The range is popped even when the enclosed code raises, so the NVTX
    range stack stays balanced; the exception itself still propagates.

    Args:
        msg (string): message to associate with the range
        *args: positional arguments forwarded to ``msg.format()``
        **kwargs: keyword arguments forwarded to ``msg.format()``
    """
    # Format and push outside the try block: if either step fails, no range
    # was opened, so there is nothing to pop.
    range_push(msg.format(*args, **kwargs))
    try:
        yield
    finally:
        range_pop()
|