Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/debugger/colab_debugger.py
2023-06-19 00:49:18 +02:00

261 lines
7.7 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for Colab-specific debugger."""
from __future__ import annotations
import html
import inspect
import traceback
from typing import List
import uuid
from jax._src.debugger import colab_lib
from jax._src.debugger import core as debugger_core
from jax._src.debugger import cli_debugger
# pylint: disable=g-import-not-at-top
# pytype: disable=import-error
if colab_lib.IS_COLAB_ENABLED:
  from google.colab import output
try:
  import pygments
  # `import pygments` alone does not load these submodules; import them
  # explicitly so that `pygments.lexers` / `pygments.formatters` can be used
  # as attributes without relying on another module having loaded them.
  import pygments.formatters
  import pygments.lexers
  IS_PYGMENTS_ENABLED = True
except ImportError:
  IS_PYGMENTS_ENABLED = False
# pytype: enable=import-error
# pylint: enable=g-import-not-at-top
class CodeViewer(colab_lib.DynamicDOMElement):
  """A mutable DOM element that displays code as HTML.

  Code is syntax-highlighted with pygments when it is installed; otherwise it
  is shown unhighlighted as HTML-escaped text.
  """

  def __init__(self, code_: str, highlights: List[int], linenostart: int = 1):
    """Initializes the viewer.

    Args:
      code_: The source text to display.
      highlights: Line numbers to highlight, counted from 1 at the first line
        of `code_` (pygments' `hl_lines` convention, independent of
        `linenostart`).
      linenostart: The line number displayed next to the first line.
    """
    super().__init__()
    self._code = code_
    self._highlights = highlights
    self._view = colab_lib.dynamic(colab_lib.div())
    self._linenostart = linenostart

  def render(self):
    self.update_code(
        self._code, self._highlights, linenostart=self._linenostart)

  def clear(self):
    self._view.clear()

  def append(self, child):
    raise NotImplementedError

  def update(self, elem):
    self._view.update(elem)

  def _highlight_code(self, code: str, highlights, linenostart: int):
    """Renders `code` as HTML.

    Args:
      code: The source text to render.
      highlights: Line numbers (1 = first line of `code`) to highlight.
      linenostart: The line number displayed next to the first line.

    Returns:
      A `(html, css)` tuple. `css` holds the stylesheet for the pygments
      markup, or is an empty string when pygments is unavailable.
    """
    if not IS_PYGMENTS_ENABLED:
      # No highlighter: show the raw code. It is escaped because the result
      # is embedded in the page the same way as pygments' HTML output.
      return f"<pre>{html.escape(code)}</pre>", ""
    # Imported here because `import pygments` alone does not load these
    # submodules.
    from pygments import formatters as pygments_formatters  # pylint: disable=g-import-not-at-top
    from pygments import lexers as pygments_lexers  # pylint: disable=g-import-not-at-top
    # Match the notebook theme so the highlighted code stays readable.
    is_dark_mode = output.eval_js(
        'document.documentElement.matches("[theme=dark]");')
    code_style = "monokai" if is_dark_mode else "default"
    hl_color = "#4e56b7" if is_dark_mode else "#fff7c1"
    lexer = pygments_lexers.get_lexer_by_name("python")
    formatter = pygments_formatters.HtmlFormatter(
        full=False,
        hl_lines=highlights,
        linenos=True,
        linenostart=linenostart,
        style=code_style)
    formatter.style.highlight_color = hl_color
    css_ = formatter.get_style_defs()
    return pygments.highlight(code, lexer, formatter), css_

  def update_code(self, code_, highlights, *, linenostart: int = 1):
    """Updates the code viewer to use new code.

    The new highlights and starting line number are stored along with the
    code so that a later `render()` redraws this same view.

    Args:
      code_: The source text to display.
      highlights: Line numbers (1 = first line of `code_`) to highlight; the
        first one is also used to scroll the view.
      linenostart: The line number displayed next to the first line.
    """
    self._code = code_
    self._highlights = highlights
    self._linenostart = linenostart
    self._view.clear()
    code_html, css_ = self._highlight_code(self._code, highlights, linenostart)
    uuid_ = uuid.uuid4()
    code_div = colab_lib.div(
        colab_lib.css(css_),
        code_html,
        id=f"code-{uuid_}",
        style=colab_lib.style({
            "max-height": "500px",
            "overflow-y": "scroll",
            "background-color": "var(--colab-border-color)",
            "padding": "5px 5px 5px 5px",
        }))
    # Fraction of the scrollable height at which the first highlighted line
    # roughly sits.
    if highlights:
      percent_scroll = highlights[0] / len(self._code.split("\n"))
    else:
      percent_scroll = 0.
    self.update(code_div)
    # Scroll to where the line is
    output.eval_js("""
console.log("{id}")
var elem = document.getElementById("{id}")
var maxScrollPosition = elem.scrollHeight - elem.clientHeight;
elem.scrollTop = maxScrollPosition * {percent_scroll}
""".format(id=f"code-{uuid_}", percent_scroll=percent_scroll))
class FramePreview(colab_lib.DynamicDOMElement):
  """Displays information about a stack frame.

  Shows a header with the frame's `filename(lineno)` followed by a
  `CodeViewer` of the frame's source with the current line highlighted.
  """

  def __init__(self, frame):
    super().__init__()
    self._header = colab_lib.dynamic(
        colab_lib.div(colab_lib.pre(colab_lib.code(""))))
    self._code_view = CodeViewer("", highlights=[])
    self.frame = frame
    # Maps a filename to the full text read from disk, so that moving up and
    # down the stack does not re-read the same file.
    self._file_cache = {}

  def clear(self):
    self._header.clear()
    self._code_view.clear()

  def append(self, child):
    raise NotImplementedError

  def update(self, elem):
    raise NotImplementedError

  def _read_file(self, filename):
    """Returns the (cached) text of `filename`, or None if it can't be read."""
    if filename not in self._file_cache:
      try:
        # Python source files default to UTF-8 (PEP 3120).
        with open(filename, "r", encoding="utf-8") as fp:
          self._file_cache[filename] = fp.read()
      except (OSError, UnicodeDecodeError):
        return None
    return self._file_cache[filename]

  def update_frame(self, frame):
    """Updates the frame viewer to use a new frame.

    If the frame's filename has a Python module suffix (per
    `inspect.getmodulename`) and the file can be read, the whole file is
    shown with the frame's line highlighted. Otherwise the snippet stored on
    the frame (`frame.source`, where `frame.offset` locates the current line)
    is shown.

    Args:
      frame: The debugger frame to display.
    """
    self.frame = frame
    lineno = self.frame.lineno or None
    filename = self.frame.filename.strip()
    file_text = (
        self._read_file(filename) if inspect.getmodulename(filename) else None)
    if file_text is not None:
      # Whole file: the highlighted line is the absolute line number.
      source = file_text
      highlight = lineno
      linenostart = 1
    else:
      source = "\n".join(frame.source)
      # Capped at len - 1; presumably `frame.source` ends with an empty
      # trailing entry -- TODO confirm against DebuggerFrame.
      highlight = min(frame.offset + 1, len(frame.source) - 1)
      # Number the snippet so its lines line up with the real file.
      linenostart = 1 if lineno is None else lineno - frame.offset
    self._header.clear()
    self._header.update(
        colab_lib.div(
            colab_lib.pre(colab_lib.code(f"{html.escape(filename)}({lineno})")),
            style=colab_lib.style({
                "padding": "5px 5px 5px 5px",
                "background-color": "var(--colab-highlighted-surface-color)",
            })))
    # With no known line there is nothing to highlight.
    highlights = [] if highlight is None else [highlight]
    self._code_view.update_code(source, highlights, linenostart=linenostart)

  def render(self):
    self.update_frame(self.frame)
class DebuggerView(colab_lib.DynamicDOMElement):
  """Top-level widget of the Colab debugger.

  Stacks a "Breakpoint" banner, a preview of the current frame, and a log of
  the session. It also doubles as a minimal file-like object: `write` appends
  to the log and `readline` reads a line with `input()`.

  `log_color` is accepted but currently unused.
  """

  def __init__(self, frame, *, log_color=""):
    super().__init__()
    self._interaction_log = colab_lib.dynamic(colab_lib.div())
    self._frame_preview = FramePreview(frame)
    label = colab_lib.span("Breakpoint")
    banner_style = colab_lib.style({
        "background-color": "var(--colab-secondary-surface-color)",
        "color": "var(--colab-primary-text-color)",
        "padding": "5px 5px 5px 5px",
        "font-weight": "bold",
    })
    self._header = colab_lib.dynamic(colab_lib.div(label, style=banner_style))

  def render(self):
    # Sections are drawn top to bottom in this order.
    for section in (self._header, self._frame_preview, self._interaction_log):
      section.render()

  def append(self, child):
    raise NotImplementedError

  def update(self, elem):
    raise NotImplementedError

  def clear(self):
    for section in (self._header, self._interaction_log, self._frame_preview):
      section.clear()

  def update_frame(self, frame):
    self._frame_preview.update_frame(frame)

  # File-like interface.

  def write(self, text):
    self._interaction_log.append(colab_lib.pre(text))

  def read(self):
    raise NotImplementedError()

  def readline(self):
    # Tag the prompt so it can be removed from the output once answered.
    with output.use_tags(["stdin"]):
      line = input()
    output.clear(output_tags=["stdin"])
    return f"{line}\n"

  def isatty(self):
    return True

  def flush(self):
    pass
class ColabDebugger(cli_debugger.CliDebugger):
  """A JAX debugger for a Colab environment.

  Reuses `CliDebugger`'s command handling but sends its input and output
  through a `DebuggerView`, and refreshes the frame preview whenever the
  selected frame moves up or down the stack.
  """

  def __init__(self,
               frames: List[debugger_core.DebuggerFrame],
               thread_id: int):
    super().__init__(frames, thread_id)
    view = DebuggerView(self.current_frame())
    self._debugger_view = view
    self.stdout = view  # type: ignore
    self.stdin = view  # type: ignore

  def _show_current_frame(self):
    # Push the currently selected frame into the preview.
    self._debugger_view.update_frame(self.current_frame())

  # NOTE: `do_*` methods deliberately carry no docstrings; `cmd` would show
  # them as `help` text.
  def do_up(self, arg):
    super().do_up(arg)
    self._show_current_frame()
    return False

  def do_down(self, arg):
    super().do_down(arg)
    self._show_current_frame()
    return False

  def run(self):
    self._debugger_view.render()
    # Re-enter the command loop for as long as it returns a truthy value.
    while self.cmdloop():
      pass
def _run_debugger(frames, thread_id, **kwargs):
  """Runs a `ColabDebugger` session over `frames`.

  Any exception raised while constructing or running the debugger is
  printed with its traceback instead of being propagated.
  """
  try:
    session = ColabDebugger(frames, thread_id, **kwargs)
    session.run()
  except Exception:  # pylint: disable=broad-except
    traceback.print_exc()
# Register the Colab front-end only inside Colab: the views above rely on
# `google.colab.output`, which is imported only when IS_COLAB_ENABLED.
# NOTE(review): the trailing `1` is presumably the registration priority --
# confirm against `debugger_core.register_debugger`.
if colab_lib.IS_COLAB_ENABLED:
  debugger_core.register_debugger("colab", _run_debugger, 1)