304 lines
11 KiB
Python
304 lines
11 KiB
Python
|
import itertools
|
||
|
from warnings import warn
|
||
|
|
||
|
import torch
|
||
|
import torch.cuda
|
||
|
|
||
|
from torch.autograd import (
|
||
|
_disable_profiler_legacy,
|
||
|
_enable_profiler_legacy,
|
||
|
DeviceType,
|
||
|
ProfilerConfig,
|
||
|
ProfilerState,
|
||
|
)
|
||
|
from torch.autograd.profiler_util import (
|
||
|
_filter_name,
|
||
|
_filter_stack_entry,
|
||
|
_rewrite_name,
|
||
|
EventList,
|
||
|
FunctionEvent,
|
||
|
MEMORY_EVENT_NAME,
|
||
|
)
|
||
|
|
||
|
__all__ = ["profile"]
|
||
|
|
||
|
|
||
|
class profile:
    """DEPRECATED: use torch.profiler instead.

    Context manager that drives the legacy autograd profiler. Entering it
    enables profiling with the configuration built by :meth:`config`;
    leaving it disables profiling, parses the collected records and stores
    them as an ``EventList`` in ``self.function_events``.

    Args:
        enabled: When false, ``__enter__`` and ``__exit__`` do nothing and no
            results are ever collected.
        use_cuda: Profile with ``ProfilerState.CUDA`` and synchronize CUDA
            before the profiler is stopped. Falls back to CPU profiling,
            with a warning, when CUDA is not available.
        record_shapes: Forwarded to ``ProfilerConfig``; forced on when
            ``with_flops`` is set.
        with_flops: Forwarded to ``ProfilerConfig`` and ``EventList``.
        profile_memory: Forwarded to ``ProfilerConfig`` and ``EventList``.
        with_stack: Forwarded to ``ProfilerConfig``; must be set for
            :meth:`export_stacks`.
        with_modules: Forwarded to ``ProfilerConfig``.
    """

    def __init__(
        self,
        enabled=True,
        *,
        use_cuda=False,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
    ):
        self.enabled: bool = enabled

        # Initialise every attribute even for a disabled profiler so that
        # repr()/str() and the result accessors fail with the intended
        # "unfinished" message / RuntimeError rather than AttributeError.
        self.function_events = None
        self.entered = False
        self.use_cuda = use_cuda
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules
        # Placeholder; only consulted by config(), which a disabled
        # profiler never calls from __enter__.
        self.profiler_kind = ProfilerState.CPU

        if not self.enabled:
            # A disabled profiler never starts a trace, so skip the CUDA
            # availability probe (and its warning).
            return

        if self.use_cuda and not torch.cuda.is_available():
            warn("CUDA is not available, disabling CUDA profiling")
            self.use_cuda = False

        if self.use_cuda:
            self.profiler_kind = ProfilerState.CUDA

    def config(self):
        """Build the ``ProfilerConfig`` passed to the legacy profiler."""
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            # avoid exposing _ExperimentalConfig in this legacy public API
            torch._C._profiler._ExperimentalConfig(),
        )

    def __enter__(self):
        """Start profiling and return ``self`` (``None`` when disabled).

        Raises:
            RuntimeError: If this instance has already been entered.
        """
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self.entered = True
        self._start_trace()
        return self

    def _start_trace(self):
        """Enable the legacy profiler with the current configuration."""
        _enable_profiler_legacy(self.config())

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop profiling and populate ``self.function_events``.

        Returns ``False`` so that exceptions raised inside the ``with`` body
        are not suppressed (``None`` when the profiler is disabled).
        """
        if not self.enabled:
            return
        if self.use_cuda:
            torch.cuda.synchronize()

        records = _disable_profiler_legacy()
        parsed_results = _parse_legacy_records(records)
        self.function_events = EventList(
            parsed_results,
            use_cuda=self.use_cuda,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        self.function_events._build_tree()
        return False

    def __repr__(self):
        if self.function_events is None:
            return "<unfinished profiler_legacy.profile>"
        return repr(self.function_events)

    def __str__(self):
        if self.function_events is None:
            return "<unfinished profile.profiler_legacy.profile>"
        return str(self.function_events)

    def _check_finish(self):
        """Raise ``RuntimeError`` if no profiling results are available."""
        if self.function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    # The delegating methods reuse the docstrings of their EventList
    # counterparts.
    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.export_chrome_trace(path)

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        """Delegate to ``EventList.export_stacks``; needs ``with_stack=True``."""
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self.function_events.export_stacks(path, metric)

    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Return CPU time as the sum of self times across all events."""
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.self_cpu_time_total
|
||
|
|
||
|
|
||
|
def _parse_legacy_records(thread_records):
    """Convert raw legacy profiler records into a sorted list of events.

    Args:
        thread_records: An iterable of per-thread record lists. It is walked
            twice (once to find the start marker, once to build events), so
            it must be re-iterable.

    Returns:
        A list of ``FunctionEvent`` objects, one per matched push/pop pair
        plus one per memory allocation seen while no range was open, ordered
        by start time and then by descending end time.

    Raises:
        AssertionError: If no local ``"__start_profile"`` record exists, or
            if a pop record has no matching push.
    """

    def _get_record_key(record):
        """Return a tuple for correlating start and end records in `_parse_legacy_records`."""
        return (record.handle(), record.node_id())

    functions = []

    # '__start_profile' is not guaranteed to be first, so we must find it here
    start_record = None
    for record in itertools.chain.from_iterable(thread_records):
        if record.name() == "__start_profile":
            start_record = record
            break

    assert start_record is not None and not start_record.is_remote()

    for thread_record_list in thread_records:
        # accumulated memory allocations per handle
        cpu_memory_allocs = {}
        cuda_memory_allocs = {}
        # ranges per handle
        range_starts = {}

        filtered_handles = set()
        prev_record = None
        for record in thread_record_list:
            record_key = _get_record_key(record)
            # Once a key is filtered, every later record carrying the same
            # key (e.g. the matching pop) is dropped as well.
            if _filter_name(record.name()) or record_key in filtered_handles:
                filtered_handles.add(record_key)
                continue

            kind = record.kind()
            if kind == "push":
                # workaround to reduce double logging from operator
                # wrappers and redispatch
                if prev_record is not None:
                    duplicate = (
                        prev_record.name() == record.name()
                        and prev_record.kind() == kind
                        and prev_record.node_id() == record.node_id()
                    )
                    if duplicate:
                        filtered_handles.add(record_key)
                        continue

                range_starts[record_key] = record
                cpu_memory_allocs[record_key] = 0
                cuda_memory_allocs[record_key] = 0
            elif kind == "pop":
                assert record_key in range_starts, (
                    f"Expected record with key {record_key} to exist in range_starts.\n"
                    "This means that the pop event did not have a corresponding push."
                )

                start = range_starts[record_key]

                cpu_memory_usage = cpu_memory_allocs[record_key]
                cuda_memory_usage = cuda_memory_allocs[record_key]
                is_async = start.is_async() or (start.thread_id() != record.thread_id())
                is_remote_event = record.is_remote()
                start_flops = start.flops()

                fe = FunctionEvent(
                    id=record.handle(),
                    node_id=record.node_id(),
                    name=_rewrite_name(name=start.name(), with_wildcard=True),
                    trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
                    thread=start.thread_id(),
                    # Timestamps are expressed relative to the start marker.
                    start_us=start_record.cpu_elapsed_us(start),
                    end_us=start_record.cpu_elapsed_us(record),
                    fwd_thread=start.fwd_thread_id(),
                    input_shapes=start.shapes(),
                    stack=[
                        entry for entry in start.stack() if _filter_stack_entry(entry)
                    ],
                    scope=start.scope(),
                    cpu_memory_usage=cpu_memory_usage,
                    cuda_memory_usage=cuda_memory_usage,
                    is_async=is_async,
                    is_remote=is_remote_event,
                    sequence_nr=start.sequence_nr(),
                    device_type=DeviceType.CPU,
                    is_legacy=True,
                    flops=start_flops,
                )
                # note: async events have only cpu total time
                if not is_async and start.has_cuda():
                    duration = start.cuda_elapsed_us(record)
                    if duration > 0:
                        fe.append_kernel(start.name(), start.device(), duration)
                functions.append(fe)
                del range_starts[record_key]
                del cpu_memory_allocs[record_key]
                del cuda_memory_allocs[record_key]
            elif kind == "memory_alloc":
                num_open_handles_cpu = len(cpu_memory_allocs)
                num_open_handles_cuda = len(cuda_memory_allocs)
                assert num_open_handles_cpu == num_open_handles_cuda
                # Charge the allocation to every range that is currently open.
                for handle in cpu_memory_allocs:
                    cpu_memory_allocs[handle] += record.cpu_memory_usage()
                for handle in cuda_memory_allocs:
                    cuda_memory_allocs[handle] += record.cuda_memory_usage()
                if num_open_handles_cpu == 0:
                    # output event as a top-level memory event
                    fe = FunctionEvent(
                        id=0,
                        name=MEMORY_EVENT_NAME,
                        trace_name=None,
                        thread=0,
                        start_us=0,
                        end_us=0,
                        stack=[],
                        cpu_memory_usage=record.cpu_memory_usage(),
                        cuda_memory_usage=record.cuda_memory_usage(),
                        is_legacy=True,
                    )
                    functions.append(fe)
            prev_record = record

    # Sort functions by start time ascending, then by end time descending.
    # This ensures that--in the case of nested events which
    # have the same start time (which may happen due to the
    # granularity of the given clock tick)--we always show
    # the outermost nested call first. This adds stability
    # in how FunctionEvents appear
    functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
    return functions
|