211 lines
7.3 KiB
Python
211 lines
7.3 KiB
Python
# Copyright 2022 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import inspect
|
|
import threading
|
|
from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
|
|
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax import tree_util
|
|
from jax._src import core
|
|
from jax._src import debugging
|
|
from jax._src import traceback_util
|
|
from jax._src import util
|
|
|
|
|
|
@tree_util.register_pytree_node_class
class _DictWrapper:
  """Pytree node holding a mapping as parallel lists of keys and values.

  The values are exposed as the pytree children and the keys are carried as
  auxiliary data. Keys are kept in the order they were supplied and are never
  compared with one another.
  """
  # Annotations name the attributes that `__init__` actually sets.
  _keys: list[Hashable]
  _values: list[Any]

  def __init__(self, keys, values):
    self._keys = keys
    self._values = values

  def to_dict(self):
    """Returns a plain dict pairing each stored key with its value."""
    return dict(zip(self._keys, self._values))

  def tree_flatten(self):
    """Returns `(children, aux_data)`: the values first, then the keys."""
    return self._values, self._keys

  @classmethod
  def tree_unflatten(cls, keys, values):
    """Rebuilds a wrapper from aux data (`keys`) and children (`values`)."""
    return cls(keys, values)
|
|
|
|
|
|
class _CantFlatten:
  """Type of the placeholder used for values that could not be flattened."""

  def __repr__(self):
    return "<cant_flatten>"


# Module-wide placeholder instance substituted for unflattenable values.
cant_flatten = _CantFlatten()
|
|
|
|
def _safe_flatten_dict(dct: dict[Any, Any]
                       ) -> tuple[list[Any], tree_util.PyTreeDef]:
  """Flattens `dct` as a pytree, tolerating values that cannot be flattened.

  Each value is probed with `tree_util.tree_leaves`; a value whose flattening
  raises is replaced by the `cant_flatten` sentinel. The keys and the
  (possibly substituted) values are then wrapped in a `_DictWrapper` and
  flattened.

  Args:
    dct: The mapping to flatten.

  Returns:
    The `(leaves, treedef)` pair produced by `tree_util.tree_flatten` on the
    wrapped keys and values.
  """
  # We avoid comparison between keys by just using the original order
  keys, values = [], []
  for key, value in dct.items():
    try:
      tree_util.tree_leaves(value)
    except Exception:  # pylint: disable=broad-except
      # If flattening fails, we substitute a sentinel object. Only `Exception`
      # is caught so that KeyboardInterrupt/SystemExit still propagate.
      value = cant_flatten
    keys.append(key)
    values.append(value)
  return tree_util.tree_flatten(_DictWrapper(keys, values))
|
|
|
|
|
|
@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class DebuggerFrame:
  """Encapsulates Python frame information.

  The class is registered as a pytree. When flattened, the local and global
  values that are tracers or arrays become the leaves; every other value and
  the remaining fields travel as auxiliary data.
  """
  # Name of the file the frame's code comes from.
  filename: str
  # The frame's local variables, keyed by name.
  locals: Dict[str, Any]
  # Global variables, keyed by name (`from_frameinfo` always passes `{}`).
  globals: Dict[str, Any]
  # Context lines as reported by `inspect.FrameInfo.code_context`, which is a
  # list of source lines or None.
  code_context: Optional[List[str]]
  # Source text split on newlines; empty when it could not be retrieved.
  source: List[str]
  # 1-indexed line number currently executing in the frame.
  lineno: int
  # `lineno` expressed relative to the start of `source`, or None when the
  # source could not be retrieved.
  offset: Optional[int]

  def tree_flatten(self):
    """Splits the frame into array-like leaves and static auxiliary data."""
    flat_locals, locals_tree = _safe_flatten_dict(self.locals)
    flat_globals, globals_tree = _safe_flatten_dict(self.globals)
    flat_vars = flat_locals + flat_globals
    # Only tracers and arrays are treated as leaves; anything else is kept
    # verbatim in the auxiliary data.
    is_valid = [
        isinstance(l, (core.Tracer, jax.Array, np.ndarray))
        for l in flat_vars
    ]
    invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
    return valid_vars, (is_valid, invalid_vars, locals_tree, globals_tree,
                        len(flat_locals), self.filename, self.code_context,
                        self.source, self.lineno, self.offset)

  @classmethod
  def tree_unflatten(cls, info, valid_vars):
    """Rebuilds a frame from the auxiliary data `info` and its leaves."""
    (is_valid, invalid_vars, locals_tree, globals_tree, num_locals, filename,
     code_context, source, lineno, offset) = info
    # Interleave leaves and static values back into their original order, then
    # split the combined list into its locals and globals portions.
    flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars)
    flat_locals, flat_globals = util.split_list(flat_vars, [num_locals])
    locals_ = tree_util.tree_unflatten(locals_tree, flat_locals).to_dict()
    globals_ = tree_util.tree_unflatten(globals_tree, flat_globals).to_dict()
    return cls(filename, locals_, globals_, code_context, source, lineno,
               offset)

  @classmethod
  def from_frameinfo(cls, frame_info) -> DebuggerFrame:
    """Builds a `DebuggerFrame` from an `inspect.FrameInfo`-like record.

    Args:
      frame_info: Object exposing `frame`, `filename`, `lineno` and
        `code_context`, such as the entries returned by `inspect.stack()`.

    Returns:
      A frame whose `globals` is empty. If the source cannot be retrieved
      (`OSError`), `source` is `[]` and `offset` is `None`.
    """
    try:
      # One `getsourcelines` call provides both the text and its starting
      # line; `inspect.getsource` is simply the concatenation of these lines,
      # so calling it as well would only repeat the lookup.
      lines, start = inspect.getsourcelines(frame_info.frame)
      source = "".join(lines).split("\n")
      # Line numbers begin at 1 but offsets begin at 0. `inspect.getsource` will
      # return a partial view of the file and a `start` indicating the line
      # number that the source code starts at. However, it's possible that
      # `start` is 0, indicating that we are at the beginning of the file. In
      # this case, `offset` is just the `lineno - 1`. If `start` is nonzero,
      # then we subtract it off from the `lineno` and don't need to subtract 1
      # since both start and lineno are 1-indexed.
      offset = frame_info.lineno - max(start, 1)
    except OSError:
      source = []
      offset = None
    return cls(
        filename=frame_info.filename,
        locals=frame_info.frame.f_locals,
        globals={},
        code_context=frame_info.code_context,
        source=source,
        lineno=frame_info.lineno,
        offset=offset)
|
|
|
|
|
|
class Debugger(Protocol):
  """Structural type for debugger backends.

  Any callable accepting a list of frames, an optional thread id and arbitrary
  keyword arguments satisfies this protocol.
  """

  def __call__(
      self,
      frames: List[DebuggerFrame],
      thread_id: Optional[int],
      **kwargs: Any,
  ) -> None:
    """Runs the debugger on `frames`.

    Args:
      frames: The stack frames made available for inspection.
      thread_id: Identifier of the calling thread, or None.
      **kwargs: Backend-specific options.
    """
    ...
|
|
# Maps each registered debugger name to its `(priority, debugger)` pair.
_debugger_registry: Dict[str, Tuple[int, Debugger]] = {}
|
|
|
|
|
|
def get_debugger(backend: Optional[str] = None) -> Debugger:
  """Looks up a registered debugger.

  Args:
    backend: Name of the debugger to fetch. If it is None or not registered,
      the registered debugger with the highest priority is used instead.

  Returns:
    The selected debugger callable.

  Raises:
    ValueError: If no debugger has been registered.
  """
  named = None if backend is None else _debugger_registry.get(backend)
  if named is not None:
    _, chosen = named
    return chosen
  # `max` returns the first entry with the top priority, i.e. the earliest
  # registered one among ties.
  top = max(_debugger_registry.values(), key=lambda entry: entry[0],
            default=None)
  if top is None:
    raise ValueError("No debuggers registered!")
  return top[1]
|
|
|
|
|
|
def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
  """Adds `debugger` to the registry under `name`.

  Args:
    name: Unique name for the debugger.
    debugger: The debugger callable to register.
    priority: Rank used when choosing a default debugger; the highest wins.

  Raises:
    ValueError: If a debugger called `name` is already registered.
  """
  if name not in _debugger_registry:
    _debugger_registry[name] = (priority, debugger)
    return
  raise ValueError(f"Debugger with name \"{name}\" already registered.")
|
|
|
|
|
|
# Held while a debugger runs, since breakpoint callbacks can be invoked from
# multiple threads at the same time.
debug_lock = threading.Lock()
|
|
|
|
|
|
def breakpoint(*, backend: Optional[str] = None, filter_frames: bool = True,
               num_frames: Optional[int] = None, ordered: bool = False,
               **kwargs):  # pylint: disable=redefined-builtin
  """Enters a breakpoint at a point in a program.

  Args:
    backend: The debugger backend to use. By default, picks the highest priority
      debugger and in the absence of other registered debuggers, falls back to
      the CLI debugger.
    filter_frames: Whether or not to filter out JAX-internal stack frames from
      the traceback. Since some libraries, like Flax, also make use of JAX's
      stack frame filtering system, this option can also affect whether stack
      frames from libraries are filtered.
    num_frames: The number of frames above the current stack frame to make
      available for inspection in the interactive debugger. When
      ``filter_frames`` is set, the limit applies to the frames that remain
      after filtering.
    ordered: A keyword only argument used to indicate whether or not the
      staged out computation will enforce ordering of this breakpoint with
      respect to other ordered debug callbacks.
    **kwargs: Additional keyword arguments forwarded to the debugger.

  Returns:
    None.
  """
  # Throw out first frame corresponding to this function
  frame_infos = inspect.stack()[1:]
  # Filter out internal frames
  if filter_frames:
    frame_infos = [
        frame_info for frame_info in frame_infos
        if traceback_util.include_frame(frame_info.frame)
    ]
  # Truncate only after filtering, so that filtered-out frames do not count
  # against `num_frames`. Truncating before building `DebuggerFrame`s avoids
  # fetching source for frames that will be discarded.
  if num_frames is not None:
    frame_infos = frame_infos[:num_frames]
  frames = [
      DebuggerFrame.from_frameinfo(frame_info) for frame_info in frame_infos
  ]
  flat_args, frames_tree = tree_util.tree_flatten(frames)

  def _breakpoint_callback(*flat_args):
    # Reassemble the frames from the leaf values delivered to the callback.
    frames = tree_util.tree_unflatten(frames_tree, flat_args)
    # A thread id is only reported when running off the main thread.
    thread_id = None
    if threading.current_thread() is not threading.main_thread():
      thread_id = threading.get_ident()
    debugger = get_debugger(backend=backend)
    # Lock here because this could be called from multiple threads at the same
    # time.
    with debug_lock:
      debugger(frames, thread_id, **kwargs)

  debugging.debug_callback(_breakpoint_callback, *flat_args, ordered=ordered)
|