32 lines
805 B
Python
32 lines
805 B
Python
|
import torch.fx as fx
|
||
|
|
||
|
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """Make ``gm`` stop in pdb whenever its generated code runs.

    A ``pdb.set_trace()`` call is prepended to the body of the python code
    generated for ``gm``. Any body transformer already registered on the
    graph is still applied first, so its output is preserved beneath the
    breakpoint.

    Args:
        gm: graph module that receives the breakpoint. It is recompiled here
            so the breakpoint takes effect.

    Returns:
        The same ``gm`` object, now carrying the breakpoint in its code.
    """
    breakpoint_line = "import pdb; pdb.set_trace()\n"

    def _make_transformer(previous):
        # `previous` is whatever body transformer was registered before us
        # (may be None); wrap it so both transformations take effect.
        def _transform(body):
            lines = previous(body) if previous else body
            return [breakpoint_line, *lines]

        return _transform

    # The transformer is registered via the `on_generate_code` context manager
    # and is in effect for the recompile below.
    with gm.graph.on_generate_code(make_transformer=_make_transformer):
        gm.recompile()
    return gm
|