Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/experimental/jax2tf/tests/models_test_main.py
2023-06-19 00:49:18 +02:00

261 lines
8.7 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Tests for converting Flax models using jax2tf.
Note: we do not use actual absl tests for this since we do not want this to run
as part of our CI. Instead, we mainly use these tests to generate the markdown
file.
Run this as follows:
```
# Test all converters on all suites and regenerate the table.
python3 models_test_main.py
# Test only CNN and Resnet50 with the TFjs converter and don't write Markdown.
python3 models_test_main.py \
--examples=flax/cnn,flax/resnet50 \
--converters=jax2tf_noxla_tfjs \
--write_markdown=False
# Execute all CNN tests (including the tests for shape polymorphism):
python3 models_test_main.py --example_prefix=flax/cnn
# Execute only a specific CNN test for shape polymorphism (NOTE: quotes and
# commas are stripped from the example names).
python3 models_test_main.py --examples="flax/cnn_[(b ...)]"
```
"""
from absl import app
from absl import flags
import functools
import numpy as np
import os
import re
import jax
from jax._src import dtypes
import jax.numpy as jnp
import datetime
from typing import Dict, List, Sequence, Tuple
import tensorflow as tf
from jax.experimental.jax2tf.shape_poly import InconclusiveDimensionOperation
from jax.experimental.jax2tf.tests.model_harness import ALL_HARNESSES
from jax.experimental.jax2tf.tests.converters import ALL_CONVERTERS
# Command-line flags; see the module docstring for example invocations.

# Names of the converters to run; defaults to every entry of ALL_CONVERTERS.
flags.DEFINE_list(
    "converters",
    [converter.name for converter in ALL_CONVERTERS],
    "Which converters to test.",
)

# Exact example names to run; an empty list selects all examples.
flags.DEFINE_list(
    "examples",
    [],
    "List of examples to test, e.g.: 'flax/mnist,flax/seq2seq'. "
    "If empty, will test all examples.",
)

# Name prefix used to select a family of examples.
flags.DEFINE_string(
    "example_prefix",
    "",
    "Prefix for filtering tests. For instance 'flax/mnist' "
    "will test all examples starting with 'flax/mnist' "
    "(including all polymorphic tests).",
)

# Whether the Markdown report is written in addition to printing to stdout.
flags.DEFINE_bool(
    "write_markdown",
    True,
    "If true, write results as Markdown. Otherwise, only output to stdout.",
)

# Whether caught conversion/comparison errors are re-raised.
flags.DEFINE_bool(
    "fail_on_error",
    False,
    "If true, exit with an error when a conversion fails. Useful for "
    "debugging because it will show the entire stack trace.",
)

FLAGS = flags.FLAGS
def _write_markdown(results: Dict[str, List[Tuple[str, str,]]]) -> None:
  """Renders the conversion results into the Markdown results file.

  The output is produced by filling the `{{generation_date}}`, `{{table}}` and
  `{{errors}}` placeholders of `../g3doc/convert_models_results.md.template`
  and writing the result to `../g3doc/convert_models_results.md`. When the
  `BUILD_WORKING_DIRECTORY` environment variable is set, the process first
  changes directory to the parent of the directory it names.

  Args:
    results: maps an example name to a list of `(converter_name, error_msg)`
      pairs, ordered like `FLAGS.converters`. An empty `error_msg` marks a
      successful conversion.
  """
  converter_names = FLAGS.converters

  def _anchor_for(title):
    # Build the in-page link target for a heading: lowercase, spaces turned
    # into dashes, and every other character outside [a-zA-Z0-9_-] dropped.
    slug = title.lower().replace(" ", "-")
    return "#" + re.sub(r"[^a-zA-Z0-9_-]+", "", slug)

  # Summary table: header row followed by the separator row.
  header_cells = " ".join(f"{name} |" for name in converter_names)
  summary_rows = [
      "| Example | " + header_cells,
      "|" + " --- |" * (len(converter_names) + 1),
  ]
  failure_sections = []

  for example_name, converter_results in results.items():
    # The table columns follow `FLAGS.converters`, so the per-example results
    # must be listed in exactly that order.
    assert [entry[0] for entry in converter_results] == converter_names

    cells = []
    details = []
    for converter_name, error_msg in converter_results:
      if error_msg:
        error_header = (f"Example: `{example_name}` | "
                        f"Converter: `{converter_name}`")
        details += [
            "### " + error_header,
            f"```\n{error_msg}\n```",
            "[Back to top](#summary-table)\n",
        ]
        # A failing cell links to the detailed error section further down.
        cells.append(f" [NO]({_anchor_for(error_header)}) |")
      else:
        cells.append(" YES |")

    summary_rows.append(f"| `{example_name}` |" + "".join(cells))
    if details:
      failure_sections.append(f"## `{example_name}`")
      failure_sections.extend(details)

  output_path = os.path.join("../g3doc", "convert_models_results.md")
  template_path = output_path + ".template"

  # The relative paths above are resolved after this optional chdir.
  workdir_var = "BUILD_WORKING_DIRECTORY"
  if workdir_var in os.environ:
    os.chdir(os.path.dirname(os.environ[workdir_var]))

  with tf.io.gfile.GFile(template_path, "r") as f_in, \
       tf.io.gfile.GFile(output_path, "w") as f_out:
    rendered = "".join(f_in.readlines())
    # Placeholders are substituted one after another, in this order.
    substitutions = (
        ("{{generation_date}}", str(datetime.date.today())),
        ("{{table}}", "\n".join(summary_rows)),
        ("{{errors}}", "\n".join(failure_sections)),
    )
    for placeholder, value in substitutions:
      rendered = rendered.replace(placeholder, value)
    f_out.write(rendered)

  print("Written converter results to", output_path)
def _crop_convert_error(msg: str) -> str:
msg = msg.replace("\\n", "\n").replace("\\t", "\t")
tflite_msg = "Some ops are not supported by the native TFLite runtime"
if tflite_msg in msg:
# Filters out only the missing ops.
missing_ops = msg.split("\nDetails:\n")[1].split("\n\n")[0]
return tflite_msg + "\n" + missing_ops
tfjs_msg = "Some ops in the model are custom ops"
if tfjs_msg in msg:
return msg[msg.index(tfjs_msg):]
return msg
def _get_random_data(x: jax.Array) -> np.ndarray:
dtype = dtypes.canonicalize_dtype(x.dtype)
if np.issubdtype(dtype, np.integer):
return np.random.randint(0, 100, size=x.shape, dtype=dtype)
elif np.issubdtype(dtype, np.floating):
return np.array(np.random.uniform(size=x.shape), dtype=dtype)
elif dtype == bool:
return np.random.choice(a=[False, True], size=x.shape)
else:
raise ValueError(f"Unsupported dtype for numerical comparison: {dtype}")
def test_converters():
  """Tests all selected converters on all selected examples and writes output.

  Converters are the entries of `ALL_CONVERTERS` whose name is listed in
  `FLAGS.converters`; examples are the entries of `ALL_HARNESSES` that match
  `FLAGS.examples` and `FLAGS.example_prefix`. For every (example, converter)
  pair the model is converted and, if the converter sets `compare_numerics`,
  the converted function is run on random inputs and compared against the JAX
  result using the harness' `rtol`. Caught errors are recorded as text (and
  re-raised when `FLAGS.fail_on_error` is set). Finally the results are written
  to Markdown if `FLAGS.write_markdown` is set.

  Exits the process if no converter or no example is selected.
  """
  # Local import keeps this block self-contained; `sys.exit` replaces the
  # `exit` helper that the `site` module only provides for interactive use.
  import sys

  def _exit_if_empty(x, name):
    if not x:
      print(f"WARNING: No {name} found to test! Exiting")
      sys.exit()

  def _maybe_reraise(e):
    if FLAGS.fail_on_error:
      raise e

  def _format(e):
    # InvalidArgumentError output is shown differently.
    msg = str(e) if repr(e).strip() == "InvalidArgumentError()" else repr(e)
    # NOTE(review): this collapses blank lines; `_crop_convert_error` instead
    # unescapes "\\n". Confirm which replacement is intended here.
    msg = msg.replace("\n\n", "\n")  # Output newlines correctly.
    return msg

  # NOTE: the selection keeps the order of ALL_CONVERTERS, not the order given
  # on the command line, while `_write_markdown` asserts that the results are
  # ordered like `FLAGS.converters`.
  converters = list(
      filter(lambda x: x.name in FLAGS.converters, ALL_CONVERTERS))
  _exit_if_empty(converters, "converters")

  harnesses_to_test = {
      name: fn for name, fn in ALL_HARNESSES.items()
      if (not FLAGS.examples or name in FLAGS.examples) and
         (not FLAGS.example_prefix or name.startswith(FLAGS.example_prefix))
  }
  _exit_if_empty(harnesses_to_test, "harness")

  results = {}
  for name, harness_fn in harnesses_to_test.items():
    harness = harness_fn()  # This will create the variables.
    np_assert_allclose = functools.partial(
        np.testing.assert_allclose, rtol=harness.rtol)
    converter_results = []
    for converter in converters:
      print(f"===== Testing example {harness.name}, "
            f"Converter {converter.name}")
      error_msg = ""  # Stays empty when conversion and comparison succeed.
      try:
        apply_tf = converter.convert_fn(harness)
        print("=== Conversion OK!")
      except Exception as e:  # pylint: disable=broad-except
        error_msg = "Conversion error\n" + _crop_convert_error(repr(e))
        print(f"=== {error_msg}")
        _maybe_reraise(e)
      else:
        if not converter.compare_numerics:
          print("=== Skipping numerical comparison.")
        else:
          xs = [jax.tree_util.tree_map(_get_random_data, x)
                for x in harness.inputs]
          jax_result = harness.apply_with_vars(*xs)
          try:
            tf_result = apply_tf(*xs)
            # Use `jax.tree_util.tree_map`, as above, rather than the
            # deprecated top-level `jax.tree_map` alias.
            jax.tree_util.tree_map(np_assert_allclose, jax_result, tf_result)
            print("=== Numerical comparison OK!")
          except AssertionError as e:
            error_msg = "Numerical comparison error:\n" + _format(e)
            print("===", error_msg)
            _maybe_reraise(e)
          # Catch remaining but expected errors (we still want to fail on
          # unexpected errors).
          except (
              # TFLite runtime error
              RuntimeError,
              # TF error
              tf.errors.InvalidArgumentError,
              # jax2tf poly shape errors.
              InconclusiveDimensionOperation,
              TypeError,
              IndexError,
              ValueError) as e:
            error_msg = _format(e)
            print("=== ", error_msg)
            _maybe_reraise(e)

      converter_results.append((converter.name, error_msg))
    results[harness.name] = converter_results

  if FLAGS.write_markdown:
    _write_markdown(results)
  else:
    print("=== NOT writing results to Markdown.")
def main(argv: Sequence[str]) -> None:
  """Entry point passed to `app.run`: runs the converter tests.

  Args:
    argv: the command-line arguments handed over by `app.run`; anything beyond
      the first element is rejected.

  Raises:
    app.UsageError: if `argv` holds more than one element.
  """
  if len(argv) <= 1:
    test_converters()
    return
  raise app.UsageError("Too many command-line arguments.")
# Run through absl's `app.run` only when this file is executed as a script.
if __name__ == "__main__":
  app.run(main)