import dataclasses
import os
from typing import Any, List

import torch

from .utils import print_once


@dataclasses.dataclass
class ProfileMetrics:
    """Additive counters for a set of profiled operator events.

    Instances are filled in by ``Profiler.results()`` and can be summed
    (``+`` / ``+=``) or divided (``/``) field by field.
    """

    # Summed ``elapsed_us()`` of the counted events.
    microseconds: float = 0.0
    # Number of counted (top-level) operator events.
    operators: int = 0
    # Operators minus the number of regions they were grouped into.
    fusions: int = 0
    # Number of "TORCHDYNAMO" regions seen (graph calls).
    graphs: int = 0

    def __iadd__(self, other: "ProfileMetrics"):
        """Accumulate every counter of ``other`` into ``self`` in place."""
        self.microseconds += other.microseconds
        self.operators += other.operators
        self.fusions += other.fusions
        # ``graphs`` must be accumulated too; otherwise the "graph calls"
        # figure reported by ProfileResult is stale after ``+=``.
        self.graphs += other.graphs
        return self

    def __add__(self, other: "ProfileMetrics"):
        """Return a new ``ProfileMetrics`` holding the field-wise sum."""
        assert isinstance(other, ProfileMetrics)
        return ProfileMetrics(
            self.microseconds + other.microseconds,
            self.operators + other.operators,
            self.fusions + other.fusions,
            self.graphs + other.graphs,
        )

    def __truediv__(self, other):
        """Return the field-wise ratio ``self / other``.

        An ``int`` divisor is applied to every field.  Each divisor is
        clamped to a minimum of 1, so dividing by an all-zero metrics object
        does not raise.  ``graphs`` is not divided; the result keeps the
        default value for that field.
        """
        if isinstance(other, int):
            other = ProfileMetrics(other, other, other)
        return ProfileMetrics(
            self.microseconds / max(1, other.microseconds),
            self.operators / max(1, other.operators),
            self.fusions / max(1, other.fusions),
        )

    def __str__(self):
        # Uses percent formatting, so this is meant for ratio objects such as
        # the one returned by ProfileResult.percent().
        return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"

    def tocsv(self):
        """Return ``[operators, microseconds]`` as a list of CSV cells."""
        return [self.operators, self.microseconds]


class ProfileResult:
    """Captured-versus-total metrics plus the number of unique graphs."""

    def __init__(self, captured, total, unique_graphs):
        # A falsy ``captured`` / ``total`` (e.g. None) becomes empty metrics.
        self.captured: ProfileMetrics = captured or ProfileMetrics()
        self.total: ProfileMetrics = total or ProfileMetrics()
        self.unique_graphs: int = unique_graphs

    def __iadd__(self, other: "ProfileResult"):
        """Accumulate ``other`` into ``self`` in place."""
        self.captured += other.captured
        self.total += other.total
        self.unique_graphs += other.unique_graphs
        return self

    def percent(self):
        """Return ``captured / total`` as a ``ProfileMetrics`` of ratios."""
        return self.captured / self.total

    def __str__(self):
        return (
            f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
            f"{self.captured.operators:4}/{self.total.operators:4} = "
            + str(self.percent())
        )

    def tocsv(self):
        """Return the summary as a flat list of CSV cells.

        Order: unique graphs, graph calls, captured operators, total
        operators, then the cells of ``percent().tocsv()``.
        """
        return [
            self.unique_graphs,
            self.captured.graphs,
            self.captured.operators,
            self.total.operators,
        ] + self.percent().tocsv()


def should_print_missing():
    """Return True when the env var TORCHDYNAMO_PRINT_MISSING equals "1"."""
    return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"


def print_missing(stack):
    """Report the innermost frames of an operator that was not captured.

    Args:
        stack: sequence of stack-frame strings (``Profiler.results()`` passes
            a profiler event's ``stack`` attribute).

    Stacks that pass through ``/torch/autograd/profiler.py`` are skipped
    entirely.
    """
    if any("/torch/autograd/profiler.py" in x for x in stack):
        return
    # NOTE(review): the filter below and the separator were reconstructed
    # from a truncated source line (text between "<" and ">" had been
    # stripped); confirm the exact substrings against the upstream file.
    stack = [
        x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x)
    ]
    print_once("MISSING", " >> ".join(stack[-3:]))


class Profiler:
    """Holds a CPU ``torch.profiler.profile`` and summarises its events."""

    # Incremented by fx_insert_profiling() for every wrapped graph and reset
    # to 0 each time results() is called.
    unique_graphs = 0

    def __init__(self):
        self.prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            # Stacks are only recorded when missing ops are to be printed.
            with_stack=should_print_missing(),
        )

    def results(self):
        """Summarise the recorded events into a ``ProfileResult``.

        Events are visited in order of start time.  An event named
        "TORCHDYNAMO" opens a captured region; any other event that starts
        at or after the end of the previous counted event is counted as an
        operator, and is also counted as captured when it ends within the
        most recent captured region.  Remaining events are treated as nested
        calls and ignored.

        Side effect: resets ``Profiler.unique_graphs`` to 0.
        """
        captured_regions = 0
        captured_ops = 0
        captured_microseconds = 0
        total_ops = 0
        total_microseconds = 0

        last_op_end_time = -1
        captured_region_end_time = -1
        events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
        for e in events:
            if e.name == "TORCHDYNAMO":
                captured_region_end_time = e.time_range.end
                captured_regions += 1
                # ignore `handle = torch.zeros(1)` in record_function.__init__()
                total_ops -= 1
            elif e.time_range.start >= last_op_end_time:
                last_op_end_time = e.time_range.end
                if e.time_range.end <= captured_region_end_time:
                    captured_ops += 1
                    captured_microseconds += e.time_range.elapsed_us()
                elif should_print_missing():
                    print_missing(e.stack)
                total_ops += 1
                total_microseconds += e.time_range.elapsed_us()
            else:
                pass  # ops recursively called from other ops (ignored)

        unique_graphs = Profiler.unique_graphs
        Profiler.unique_graphs = 0
        # we counted one extra op that is part of the profiler setup code
        total_ops -= 1

        return ProfileResult(
            captured=ProfileMetrics(
                microseconds=captured_microseconds,
                operators=captured_ops,
                fusions=captured_ops - captured_regions,
                graphs=captured_regions,
            ),
            total=ProfileMetrics(
                microseconds=total_microseconds,
                operators=total_ops,
                fusions=total_ops - 1,
            ),
            unique_graphs=unique_graphs,
        )


def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]):
    """Wrap ``gm.forward`` so each call is recorded as a "TORCHDYNAMO" region.

    Args:
        gm: the graph module whose ``forward`` is wrapped.
        example_inputs: accepted but not used by this function.

    Returns:
        A callable that forwards its positional arguments to ``gm.forward``
        inside ``torch.profiler.record_function("TORCHDYNAMO")``.

    Side effect: increments ``Profiler.unique_graphs`` once per wrap (not
    once per call of the returned callable).
    """

    def _wrapped(*args):
        with torch.profiler.record_function("TORCHDYNAMO"):
            return gm.forward(*args)

    Profiler.unique_graphs += 1
    return _wrapped