r"""
This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels run on different streams.

It stores information on accesses to tensors to determine if they are synchronized
or not. When enabled in a Python program and a possible data race is detected, a
detailed warning will be printed and the program will exit.

It can be enabled either by importing this module and calling
:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
environment variable.
"""
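
# Typical usage (an illustrative sketch; the module path below assumes this file
# lives at ``torch.cuda._sanitizer``, as it does in PyTorch):
#
#     import torch.cuda._sanitizer as csan
#     csan.enable_cuda_sanitizer()
#
# or, without modifying the program, set the environment variable when launching it,
# e.g. ``TORCH_CUDA_SANITIZER=1 python script.py``.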

import enum
import functools
import inspect
import io
import logging
import sys
import textwrap
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar

import torch
import torch.utils._cuda_trace as cuda_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode


DEFAULT_STREAM_ID = 0

TK = TypeVar("TK")
TVa = TypeVar("TVa")
TVb = TypeVar("TVb")

DataPtr = int
StreamId = int
EventId = int
SeqNum = int

logger = logging.getLogger(__name__)


class AccessType(enum.Enum):
    READ = enum.auto()
    WRITE = enum.auto()

    def __str__(self):
        return "reading from" if self is AccessType.READ else "writing to"


@dataclass
class Access:
    r"""Stores information about a single access to a tensor by a kernel.

    Args:
        type: either AccessType.READ or AccessType.WRITE.
        seq_num: the sequential number of the kernel performing the access.
        stream: the stream id of the stream executing the kernel.
        operator: the schema of the launched kernel, which lists the
            arguments and return type.
        aliases: the arguments in the schema this access corresponds to.
        is_output: whether the tensor was an output of the kernel.
        stack_trace: the stack summary object captured during access.
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary


class SynchronizationError(Exception):
    """Base class for errors detected by CUDA Sanitizer."""

    pass


class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer."""

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        def format_access(access: Access):
            message.write(f"{access.operator}\n{access.type}")
            if access.aliases:
                message.write(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    message.write(", and to")
            if access.is_output:
                message.write(" the output")
            message.write(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )

        with io.StringIO() as message:
            message.write(
                textwrap.dedent(
                    f"""\
                    ============================
                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                    Access by stream {self.current_access.stream} during kernel:
                    """
                )
            )
            format_access(self.current_access)

            message.write(
                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
            )
            format_access(self.previous_access)

            if self.allocation_stack_trace:
                message.write(
                    "Tensor was allocated with stack trace:\n"
                    f"{''.join(self.allocation_stack_trace.format())}"
                )
            else:
                message.write("Trace for tensor allocation not found.")
            return message.getvalue()
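

# For reference, the report assembled above has roughly the following shape
# (pointer, stream ids and operator schemas are illustrative, not real output):
#
#     ============================
#     CSAN detected a possible data race on tensor with data pointer 139972
#     Access by stream 94 during kernel:
#     aten::mul.out(...)
#     writing to argument(s) self, out, and to the output
#     With stack trace:
#     ...
#     Previous access by stream 7 during kernel:
#     aten::rand(...)
#     writing to the output
#     With stack trace:
#     ...
#     Tensor was allocated with stack trace:
#     ...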


class CUDASanitizerErrors(Exception):
    """Wrapper class for errors reported by CUDA Sanitizer."""

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        return f"detected {len(self.errors)} errors"


@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write.
        write: the last write access to the tensor.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None


class _TensorsAccessed:
    def __init__(self):
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        if data_ptr not in self.accesses:
            logger.info(
                "Found tensor with pointer: %s, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        if data_ptr in self.accesses:
            logger.info(
                "Found duplicate tensor allocation in the trace for tensor with "
                "pointer: %s. Assuming the trace for tensor deallocation "
                "wasn't caught and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        return bool(self.accesses[data_ptr].reads)

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        return self.accesses[data_ptr].allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        return self.accesses[data_ptr].write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        return self.accesses[data_ptr].reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].write = access
        self.accesses[data_ptr].reads = []


class StreamSynchronizations:
    def __init__(self):
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        if stream not in self.current_sync_states:
            logger.info(
                "Found Stream with id: %s, but no matching stream "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                stream,
            )
            self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        if event not in self.recorded_sync_states:
            logger.info(
                "Found Event with id: %s, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        if event in self.recorded_sync_states:
            logger.info(
                "Found duplicate event creation in the trace for event with "
                "id: %s. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        if stream in self.current_sync_states:
            logger.info(
                "Found duplicate Stream creation in the trace for Stream with "
                "id: %s. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
                stream,
            )
        else:
            self.host_sync_state[stream] = 0
            self.current_sync_states[stream] = self.host_sync_state.copy()

    def create_event(self, event: EventId) -> None:
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        self._ensure_stream_exists(stream)
        self.current_sync_states[stream][stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        for stream, seq_num in other.items():
            state[stream] = max(state.get(stream, -1), seq_num)

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        self._state_wait_for_other(
            self.current_sync_states[stream], self.recorded_sync_states[event]
        )

    def all_streams_wait_for_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        for stream in self.current_sync_states.keys():
            self.stream_wait_for_event(stream, event)

        self._state_wait_for_other(
            self.host_sync_state, self.recorded_sync_states[event]
        )

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        self._ensure_stream_exists(stream)
        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.current_sync_states[stream])

        self._state_wait_for_other(
            self.host_sync_state, self.current_sync_states[stream]
        )

    def sync_all_streams(self) -> None:
        for stream, state in self.current_sync_states.items():
            self.host_sync_state[stream] = state[stream]

        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)
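

# A minimal sketch of how the bookkeeping above establishes a happens-before
# relation (the stream and event ids are made up for illustration):
#
#     syncs = StreamSynchronizations()
#     syncs.create_stream(1)
#     syncs.update_seq_num(1, 5)           # kernel #5 was launched on stream 1
#     syncs.create_event(100)
#     syncs.record_state(100, 1)           # event 100 snapshots stream 1's state
#     syncs.stream_wait_for_event(0, 100)  # the default stream waits for event 100
#     assert syncs.is_ordered_after(0, 5, 1)  # kernel #5 now precedes new work on stream 0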


class EventHandler:
    """Analyzes CSAN trace for synchronization errors.

    Stores information on each stream's synchronizations with other streams as well
    as tensor accesses to determine whether a given kernel launch might cause a
    data race.
    """

    def __init__(self):
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            if previous_access is None:
                return
            if not self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            ):
                error_list.append(
                    UnsynchronizedAccessError(
                        data_ptr,
                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
                        current_access,
                        previous_access,
                    )
                )

        error_list: List[SynchronizationError] = []
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()
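
        # Conflict rules applied below: a read conflicts only with the last write to
        # the same data pointer, while a write conflicts with every read performed
        # since that write (or with the write itself if there were no such reads).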
        for data_ptr in read_only:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.READ,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            check_conflict(
                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
            )
            self.tensors_accessed.add_read(data_ptr, current_access)

        for data_ptr in read_write:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.WRITE,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
                for previous_access in self.tensors_accessed.get_reads(data_ptr):
                    check_conflict(data_ptr, current_access, previous_access)
            else:
                check_conflict(
                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
                )
            self.tensors_accessed.set_write(data_ptr, current_access)

        return error_list

    def _handle_event_creation(self, event: EventId) -> None:
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(
            data_ptr,
            stack_trace,
        )

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        self.syncs.all_streams_wait_for_event(event)


def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    for arg, value in a.items():
        if arg in b:
            yield arg, value, b[arg]


def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    schema_args = schema.arguments[: len(args)]
    schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}

    yield from zip(schema_args, args)

    for _, argument, value in zip_by_key(schema_kwargs, kwargs):
        yield (argument, value)
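

# Example of the pairing performed by zip_arguments (the schema is hypothetical):
# for a schema with arguments ``(self, other, alpha)`` called as ``op(t1, t2, alpha=2)``,
# the first two schema arguments are zipped positionally with ``(t1, t2)`` and the
# remaining ones are matched to kwargs by name, yielding ``(self, t1)``, ``(other, t2)``
# and ``(alpha, 2)``. Keyword arguments that do not appear in the schema are skipped
# by zip_by_key.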


class ArgumentHandler:
    def __init__(self):
        self.dataptrs_read: Set[DataPtr] = set()
        self.dataptrs_written: Set[DataPtr] = set()
        self.tensor_aliases: Dict[DataPtr, List[str]] = dict()
        self.outputs: Set[DataPtr] = set()

    def _handle_argument(
        self,
        value: Any,
        is_write: bool,
        name: Optional[str] = None,
        is_output: bool = False,
    ) -> None:
        if isinstance(value, torch.Tensor) and value.is_cuda:
            data_ptr = value.data_ptr()
            if is_write:
                self.dataptrs_written.add(data_ptr)
            else:
                self.dataptrs_read.add(data_ptr)

            self.tensor_aliases.setdefault(data_ptr, [])
            if name is not None:
                self.tensor_aliases[data_ptr].append(name)
            if is_output:
                self.outputs.add(data_ptr)

    def parse_inputs(
        self,
        schema: torch.FunctionSchema,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        for argument, value in zip_arguments(schema, args, kwargs):
            is_write = argument.alias_info is not None and argument.alias_info.is_write
            pytree.tree_map_(
                functools.partial(
                    self._handle_argument, is_write=is_write, name=argument.name
                ),
                value,
            )

    def parse_outputs(self, outputs: Any) -> None:
        pytree.tree_map_(
            functools.partial(self._handle_argument, is_write=True, is_output=True),
            outputs,
        )
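

# Note on the read/write classification above (descriptive, based on how PyTorch
# operator schemas annotate mutation): an argument is treated as a write when its
# schema alias_info is marked writable, e.g. the ``self`` argument of in-place ops
# such as ``aten::add_`` or the ``out`` argument of ``*.out`` overloads; every other
# CUDA tensor argument is treated as a read. All outputs are additionally recorded
# as writes by parse_outputs.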


class CUDASanitizerDispatchMode(TorchDispatchMode):
    def __init__(self):
        self.event_handler = EventHandler()
        torch._C._activate_cuda_trace()
        cuda_trace.register_callback_for_cuda_event_creation(
            self.event_handler._handle_event_creation
        )
        cuda_trace.register_callback_for_cuda_event_deletion(
            self.event_handler._handle_event_deletion
        )
        cuda_trace.register_callback_for_cuda_event_record(
            self.event_handler._handle_event_record
        )
        cuda_trace.register_callback_for_cuda_event_wait(
            self.event_handler._handle_event_wait
        )
        cuda_trace.register_callback_for_cuda_memory_allocation(
            self.event_handler._handle_memory_allocation
        )
        cuda_trace.register_callback_for_cuda_memory_deallocation(
            self.event_handler._handle_memory_deallocation
        )
        cuda_trace.register_callback_for_cuda_stream_creation(
            self.event_handler._handle_stream_creation
        )
        cuda_trace.register_callback_for_cuda_device_synchronization(
            self.event_handler._handle_device_synchronization
        )
        cuda_trace.register_callback_for_cuda_stream_synchronization(
            self.event_handler._handle_stream_synchronization
        )
        cuda_trace.register_callback_for_cuda_event_synchronization(
            self.event_handler._handle_event_synchronization
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
            argument_handler.dataptrs_written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if errors:
            for error in errors:
                print(error, file=sys.stderr)
            raise CUDASanitizerErrors(errors)

        return outputs


class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self):
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        if self.enabled:
            self.dispatch.__exit__(None, None, None)


def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer.

    The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
    for synchronization errors. All data races found will be printed to the standard
    error output along with stack traces of suspected causes. For best results, the
    sanitizer should be enabled at the very beginning of the program.
    """
    cuda_sanitizer.enable()
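

# A minimal sketch of the kind of program the sanitizer is meant to flag
# (illustrative only): a tensor is produced on the default stream and then
# written by a kernel on a side stream without any synchronization in between.
#
#     import torch
#     import torch.cuda._sanitizer as csan
#
#     csan.enable_cuda_sanitizer()
#
#     a = torch.rand(4, 2, device="cuda")          # kernel on the default stream
#     with torch.cuda.stream(torch.cuda.Stream()):
#         torch.mul(a, 5, out=a)                   # unsynchronized write on a side stream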


cuda_sanitizer = CUDASanitizer()