Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/jaxpr_util.py
2023-06-19 00:49:18 +02:00

209 lines
6.4 KiB
Python

# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for the Jaxpr IR."""
import collections
import gzip
import itertools
import json
import types
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple
from jax._src import core
from jax._src import util
from jax._src import source_info_util
from jax._src.lib import xla_client
# Rebind the builtins `map`/`zip` to JAX's `util.safe_map`/`util.safe_zip`
# for the rest of this module (presumably length-checking variants — see
# jax._src.util). The builtins remain available as `unsafe_map`/`unsafe_zip`.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
def all_eqns(jaxpr: core.Jaxpr):
  """Yields every ``(owning_jaxpr, eqn)`` pair reachable from ``jaxpr``.

  The equations of ``jaxpr`` itself come first. After them come the
  equations of each sub-jaxpr from ``core.subjaxprs``, visited recursively
  (depth-first, pre-order).
  """
  yield from ((jaxpr, eqn) for eqn in jaxpr.eqns)
  for nested in core.subjaxprs(jaxpr):
    yield from all_eqns(nested)
def collect_eqns(jaxpr: core.Jaxpr, key: Callable):
  """Groups the equations of ``jaxpr`` and its sub-jaxprs by ``key(eqn)``.

  Returns:
    A plain dict mapping each key to the list of equations with that key.
    Keys appear in first-seen order, and each list follows ``all_eqns`` order.
  """
  groups = {}
  for _, eqn in all_eqns(jaxpr):
    groups.setdefault(key(eqn), []).append(eqn)
  return groups
def histogram(jaxpr: core.Jaxpr, key: Callable,
              key_fmt: Callable = lambda x: x):
  """Counts the equations of ``jaxpr`` (and its sub-jaxprs) per group.

  Args:
    jaxpr: the jaxpr to scan.
    key: maps an equation to its grouping key.
    key_fmt: applied to each grouping key to build the returned dict's keys.

  Returns:
    A dict from formatted key to the number of equations in that group.
  """
  groups = collect_eqns(jaxpr, key)
  counts = {}
  for group_key, members in groups.items():
    counts[key_fmt(group_key)] = len(members)
  return counts
def primitives(jaxpr: core.Jaxpr):
  """Returns a histogram of primitive names used in ``jaxpr`` and sub-jaxprs."""
  def primitive_name(eqn):
    return eqn.primitive.name
  return histogram(jaxpr, primitive_name)
def primitives_by_source(jaxpr: core.Jaxpr):
  """Returns a histogram keyed by ``'<primitive> @ <source summary>'``."""
  def prim_and_source(eqn):
    where = source_info_util.summarize(eqn.source_info)
    return eqn.primitive.name, where

  def join_with_at(pair):
    return ' @ '.join(pair)

  return histogram(jaxpr, prim_and_source, join_with_at)
def primitives_by_shape(jaxpr: core.Jaxpr):
  """Returns a histogram keyed by ``'<primitive> :: <output avals>'``.

  Dropped outputs (``core.DropVar``) are rendered as ``'*'``. Every other
  output is rendered with ``aval.str_short()``, and the renderings are
  joined with single spaces.
  """
  def render_outputs(eqn):
    rendered = []
    for out in eqn.outvars:
      if isinstance(out, core.DropVar):
        rendered.append('*')
      else:
        rendered.append(out.aval.str_short())
    return ' '.join(rendered)

  def prim_and_shapes(eqn):
    return (eqn.primitive.name, render_outputs(eqn))

  return histogram(jaxpr, prim_and_shapes, ' :: '.join)
def source_locations(jaxpr: core.Jaxpr):
  """Returns a histogram of equations keyed by their summarized source info."""
  return histogram(
      jaxpr, lambda eqn: source_info_util.summarize(eqn.source_info))
# An equation, or None when a variable is bound or read outside any equation
# (a jaxpr's constvars/invars are defined with None; reads by its outvars
# are recorded as None).
MaybeEqn = Optional[core.JaxprEqn]
def var_defs_and_refs(jaxpr: core.Jaxpr):
  """Records each variable's defining equation and its readers.

  Variables in ``jaxpr.constvars`` and ``jaxpr.invars`` get the definition
  ``None``. A read by ``jaxpr.outvars`` is recorded as ``None``. Literal
  operands are not tracked, and ``core.DropVar`` outputs are never recorded.
  The function recurses into every sub-jaxpr from ``core.subjaxprs``.

  Returns:
    If ``jaxpr`` has no sub-jaxprs, a bare tuple ``(jaxpr, res)``.
    Otherwise, a list ``[(jaxpr, res), *subs]``, where each entry of ``subs``
    is this function's own (recursive) result for one sub-jaxpr. Those
    entries may therefore be tuples or further nested lists.

    ``res`` is a list of ``(var, defining_eqn, reading_eqns)`` triples in
    definition order.

    NOTE(review): the tuple-vs-list return shape is inconsistent. Callers
    must normalize it before iterating over ``(jaxpr, res)`` pairs.
  """
  defs: Dict[core.Var, MaybeEqn] = {}
  refs: Dict[core.Var, List[MaybeEqn]] = {}
  def read(a: core.Atom, eqn: MaybeEqn):
    # Literals have no definition. Every other operand must have been
    # written before it is read.
    if not isinstance(a, core.Literal):
      assert a in defs, a
      assert a in refs, a
      refs[a].append(eqn)
  def write(v: core.Var, eqn: MaybeEqn):
    # Each variable may be bound only once.
    assert v not in defs, v
    assert v not in refs, v
    if not isinstance(v, core.DropVar):
      defs[v] = eqn
      refs[v] = []
  for v in jaxpr.constvars:
    write(v, None)
  for v in jaxpr.invars:
    write(v, None)
  for eqn in jaxpr.eqns:
    for a in eqn.invars:
      read(a, eqn)
    for v in eqn.outvars:
      write(v, eqn)
  for a in jaxpr.outvars:
    read(a, None)
  res = [(v, defs[v], refs[v]) for v in defs]
  # `map` is the module-level rebound `util.safe_map`. It presumably returns
  # a list, so `if subs` below tests for "has sub-jaxprs" (TODO confirm).
  subs = map(var_defs_and_refs, core.subjaxprs(jaxpr))
  return [(jaxpr, res), *subs] if subs else (jaxpr, res)
def vars_by_fanout(jaxpr: core.Jaxpr):
  """Counts how many times each variable is read, per (sub-)jaxpr.

  Returns:
    A flat list of ``(jaxpr, hist)`` pairs. There is one pair for ``jaxpr``
    and one for every transitively nested sub-jaxpr. ``hist`` maps a
    ``'<var> <- <definition>'`` description to the number of times that
    variable is read.
  """
  def fmt_key(var, eqn):
    # `eqn` is None for both constvars and invars of a jaxpr.
    if eqn is None:
      return f'{var} <- invar'
    src = source_info_util.summarize(eqn.source_info)
    return f'{var} <- {eqn.primitive.name} @ {src}'

  def hist(reads):
    return {fmt_key(var, var_def): len(var_refs)
            for var, var_def, var_refs in reads}

  def flatten(result):
    # `var_defs_and_refs` returns a bare `(jaxpr, res)` tuple when there are
    # no sub-jaxprs. Otherwise it returns `[(jaxpr, res), *subs]`, and each
    # sub-result is recursively one of these two shapes. Iterating the raw
    # result as pairs used to crash on a jaxpr without sub-jaxprs and
    # mis-unpack deeper nesting. Normalize it into flat pairs here.
    if isinstance(result, tuple):
      return [result]
    head, *subs = result
    pairs = [head]
    for sub in subs:
      pairs.extend(flatten(sub))
    return pairs

  return [(j, hist(reads))
          for j, reads in flatten(var_defs_and_refs(jaxpr))]
def print_histogram(histogram: Dict[Any, int]):
  """Prints ``histogram`` as ``<count> <key>`` lines, largest count first.

  Counts are right-aligned to the width of the widest count. Ties on the
  count are ordered by key, in descending order. An empty histogram prints
  nothing.
  """
  if not histogram:
    # max() below would raise ValueError on an empty sequence.
    return
  count_width = max(len(str(v)) for v in histogram.values())
  count_fmt = '{:>' + str(count_width) + 'd}'
  pairs = [(v, k) for k, v in histogram.items()]
  for count, name in reversed(sorted(pairs)):
    print(count_fmt.format(count), name)
def _pprof_profile(
    profile: Dict[Tuple[Optional[xla_client.Traceback], core.Primitive], int]
) -> bytes:
  """Converts a profile into a compressed pprof protocol buffer.

  The input profile is a map from (traceback, primitive) pairs to counts.
  Each entry becomes one sample whose label records the primitive's name.
  A ``None`` traceback yields a sample with no locations. Otherwise, only
  frames whose file passes ``source_info_util.is_user_filename`` are kept.

  Returns:
    The gzip-compressed output of ``xla_client._xla.json_to_pprof_profile``
    applied to a JSON encoding of the profile.
  """
  # Each table hands out consecutive integer ids on first access, so every
  # table's insertion order matches its id order.
  s: DefaultDict[str, int]
  func: DefaultDict[types.CodeType, int]
  loc: DefaultDict[Tuple[types.CodeType, int], int]
  s = collections.defaultdict(itertools.count(1).__next__)
  func = collections.defaultdict(itertools.count(1).__next__)
  loc = collections.defaultdict(itertools.count(1).__next__)
  # The empty string is pinned to string-table id 0 (profile.proto requires
  # string_table[0] == ""). Counter-issued string ids therefore start at 1.
  s[""] = 0
  primitive_key = s["primitive"]
  samples = []
  for (tb, primitive), count in profile.items():
    if tb is None:
      frames = []
    else:
      # Presumably raw_frames() returns parallel (code objects, lasti)
      # sequences; `zip` (rebound to safe_zip in this module) pairs them up.
      # TODO confirm.
      raw_frames = zip(*tb.raw_frames())
      frames = [loc[(code, lasti)] for code, lasti in raw_frames
                if source_info_util.is_user_filename(code.co_filename)]  # type: ignore
    samples.append({
        "location_id": frames,
        "value": [count],
        "label": [{
            "key": primitive_key,
            "str": s[primitive.name]
        }]
    })
  # One location per distinct (code, lasti). The func[code] lookups here
  # are what assign function ids.
  locations = [
      {"id": loc_id,
       "line": [{"function_id": func[code],
                 "line": xla_client.Traceback.code_addr2line(code, lasti)}]}
      for (code, lasti), loc_id in loc.items()
  ]
  functions = [
      {"id": func_id,
       "name": s[code.co_name],
       "system_name": s[code.co_name],
       "filename": s[code.co_filename],
       "start_line": code.co_firstlineno}
      for code, func_id in func.items()
  ]
  sample_type = [{"type": s["equations"], "unit": s["count"]}]
  # This is the JSON encoding of a pprof profile protocol buffer. See:
  # https://github.com/google/pprof/blob/master/proto/profile.proto for a
  # description of the format.
  # The string table is snapshotted only here, after every s[...] lookup
  # above has had a chance to add an entry.
  json_profile = json.dumps({
      "string_table": list(s.keys()),
      "location": locations,
      "function": functions,
      "sample_type": sample_type,
      "sample": samples,
  })
  return gzip.compress(xla_client._xla.json_to_pprof_profile(json_profile))
def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes:
  """Builds a pprof profile mapping jaxpr equations to Python stack traces.

  Equations from ``jaxpr`` and all of its sub-jaxprs are counted per
  ``(traceback, primitive)`` pair. Viewing the result with pprof reveals
  which Python code is responsible for emitting the most jaxpr equations.

  Args:
    jaxpr: a Jaxpr.

  Returns:
    A gzip-compressed pprof Profile protocol buffer, suitable for passing to
    the pprof tool for visualization.
  """
  per_site = collections.Counter(
      (eqn.source_info.traceback, eqn.primitive)
      for _, eqn in all_eqns(jaxpr))
  return _pprof_profile(per_site)