Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/debugger/web_debugger.py
2023-06-19 00:49:18 +02:00

96 lines
3.2 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import weakref
from typing import Any, Dict, List, Optional, Tuple
from jax._src.debugger import cli_debugger
from jax._src.debugger import core as debugger_core
def _parse_web_pdb_version(version: str) -> Optional[Tuple[int, ...]]:
  """Parse the leading numeric release components of a version string.

  ``"1.5.7"`` -> ``(1, 5, 7)``; ``"1.6.0rc1"`` -> ``(1, 6, 0)``;
  ``"2.0.dev3"`` -> ``(2, 0)``. Parsing stops at the first component that does
  not start with an ASCII digit, or right after a component that has a
  non-numeric suffix.

  Returns:
    The parsed components, or ``None`` if the string has no leading numeric
    component at all.
  """
  components = []
  for piece in version.split("."):
    digits = ""
    for ch in piece:
      if not "0" <= ch <= "9":
        break
      digits += ch
    if not digits:
      break
    components.append(int(digits))
    if len(digits) != len(piece):
      # e.g. "0rc1": keep the number, stop at the pre-release marker.
      break
  return tuple(components) if components else None


web_pdb_version: Optional[Tuple[int, ...]] = None
try:
  import web_pdb  # pytype: disable=import-error
except Exception:  # pylint: disable=broad-except
  # web_pdb is an optional dependency. Any failure while importing it (not only
  # ImportError) disables the web debugger instead of breaking this module's
  # import. Unlike a bare ``except:``, KeyboardInterrupt/SystemExit propagate.
  WEB_PDB_ENABLED = False
else:
  WEB_PDB_ENABLED = True
  # Parsed outside the import guard so an unusual version string (e.g. a
  # pre-release suffix) no longer disables an otherwise importable web_pdb.
  web_pdb_version = _parse_web_pdb_version(
      str(getattr(web_pdb, "__version__", "")))

# Web consoles keyed by (host, port); WebDebugger reuses the console cached for
# its address. The annotation names web_pdb even when the import failed, which
# is safe only because ``from __future__ import annotations`` keeps it
# unevaluated.
_web_consoles: Dict[Tuple[str, int], web_pdb.WebConsole] = {}
class WebDebugger(cli_debugger.CliDebugger):
  """A web-based debugger.

  Runs the ``cli_debugger.CliDebugger`` command interface with both its input
  and its output bound to a ``web_pdb.WebConsole``. Consoles are cached per
  ``(host, port)`` in the module-level ``_web_consoles`` dict, so debugger
  instances created for the same address share one console.
  """

  prompt = '(jdb) '
  # NOTE(review): presumably consumed by the cmd.Cmd-style loop in CliDebugger.
  # With False, it reads commands through ``self.stdin`` (the web console)
  # instead of ``input()``. Confirm against CliDebugger.
  use_rawinput: bool = False

  def __init__(self, frames: List[debugger_core.DebuggerFrame],
               thread_id: Optional[int],
               completekey: str = "tab", host: str = "", port: int = 5555):
    """Attach a debugger for ``frames`` to the web console at (host, port).

    Args:
      frames: the captured frames to inspect.
      thread_id: forwarded unchanged to ``CliDebugger``.
      completekey: forwarded to ``CliDebugger``.
      host: passed to ``web_pdb.WebConsole``; part of the console cache key.
      port: passed to ``web_pdb.WebConsole``; part of the console cache key.
    """
    key = (host, port)
    if key not in _web_consoles:
      _web_consoles[key] = web_pdb.WebConsole(host, port, self)
    console = _web_consoles[key]
    # A cached console may still point at an earlier debugger, so retarget it
    # at this instance. A weak proxy means the console does not keep this
    # debugger alive by itself.
    console._debugger = weakref.proxy(self)  # pylint: disable=protected-access
    super().__init__(frames, thread_id, stdin=console, stdout=console,
                     completekey=completekey)

  def get_current_frame_data(self):
    """Build the info the web console needs to display the current frame.

    web_pdb releases older than 1.4.4 expect a different set of keys, so the
    shape of the returned dict depends on the detected ``web_pdb_version``.
    """
    frame = self.current_frame()
    filename = frame.filename
    lines = frame.source
    # ``offset`` appears to be 0-based (hence +1). None means no current line.
    current_line = None if frame.offset is None else frame.offset + 1
    listing = '\n'.join(lines)
    if web_pdb_version and web_pdb_version < (1, 4, 4):
      return {
          'filename': filename,
          'listing': listing,
          'curr_line': current_line,
          'total_lines': len(lines),
          'breaklist': [],
      }
    return {
        'dirname': os.path.dirname(os.path.abspath(filename)) + os.path.sep,
        'filename': os.path.basename(filename),
        'file_listing': listing,
        'current_line': current_line,
        'breakpoints': [],
        'globals': self.get_globals(),
        'locals': self.get_locals(),
    }

  @staticmethod
  def _format_namespace(namespace) -> str:
    """Render a name -> value mapping as newline-separated ``name = value`` lines.

    Entries are sorted by name. A value whose ``str()`` raises is rendered as a
    placeholder, so one failing ``__str__`` does not stop the rest of the
    frame's variables from being displayed.
    """
    rendered = []
    for name, value in sorted(namespace.items(), key=lambda item: item[0]):
      try:
        text = f"{value}"
      except Exception:  # pylint: disable=broad-except
        text = f"<unprintable {type(value).__name__} object>"
      rendered.append(f"{name} = {text}")
    return "\n".join(rendered)

  def get_globals(self):
    """Return the current frame's globals as sorted ``name = value`` lines."""
    return self._format_namespace(self.current_frame().globals)

  def get_locals(self):
    """Return the current frame's locals as sorted ``name = value`` lines."""
    return self._format_namespace(self.current_frame().locals)

  def run(self):
    """Run the command loop and return whatever ``cmdloop`` returns."""
    return self.cmdloop()
def run_debugger(frames: List[debugger_core.DebuggerFrame],
                 thread_id: Optional[int], **kwargs: Any):
  """Create a ``WebDebugger`` for ``frames`` and run its command loop.

  Args:
    frames: the captured frames to inspect.
    thread_id: forwarded unchanged to ``WebDebugger``.
    **kwargs: extra ``WebDebugger`` constructor arguments (for example
      ``completekey``, ``host`` or ``port``).
  """
  session = WebDebugger(frames, thread_id, **kwargs)
  # The loop's return value is intentionally discarded; this function
  # returns None.
  session.run()
# Offer the "web" debugger backend only when web_pdb was importable.
# NOTE(review): -2 is passed as register_debugger's third argument, presumably
# a selection priority. Confirm in debugger_core.
if WEB_PDB_ENABLED:
  debugger_core.register_debugger("web", run_debugger, -2)