3RNN/Lib/site-packages/tensorflow/tools/docs/tf_doctest_lib.py
2024-05-26 19:49:15 +02:00

225 lines
8.1 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run doctests for tensorflow."""
import doctest
import re
import textwrap
import numpy as np
class _FloatExtractor(object):
"""Class for extracting floats from a string.
For example:
>>> text_parts, floats = _FloatExtractor()("Text 1.0 Text")
>>> text_parts
["Text ", " Text"]
>>> floats
np.array([1.0])
"""
# Note: non-capturing groups "(?" are not returned in matched groups, or by
# re.split.
_FLOAT_RE = re.compile(
r"""
( # Captures the float value.
(?:
[-+]| # Start with a sign is okay anywhere.
(?: # Otherwise:
^| # Start after the start of string
(?<=[^\w.]) # Not after a word char, or a .
)
)
(?: # Digits and exponent - something like:
{digits_dot_maybe_digits}{exponent}?| # "1.0" "1." "1.0e3", "1.e3"
{dot_digits}{exponent}?| # ".1" ".1e3"
{digits}{exponent}| # "1e3"
{digits}(?=j) # "300j"
)
)
j? # Optional j for cplx numbers, not captured.
(?= # Only accept the match if
$| # * At the end of the string, or
[^\w.] # * Next char is not a word char or "."
)
""".format(
# Digits, a "." and optional more digits: "1.1".
digits_dot_maybe_digits=r'(?:[0-9]+\.(?:[0-9]*))',
# A "." with trailing digits ".23"
dot_digits=r'(?:\.[0-9]+)',
# digits: "12"
digits=r'(?:[0-9]+)',
# The exponent: An "e" or "E", optional sign, and at least one digit.
# "e-123", "E+12", "e12"
exponent=r'(?:[eE][-+]?[0-9]+)'),
re.VERBOSE)
def __call__(self, string):
"""Extracts floats from a string.
>>> text_parts, floats = _FloatExtractor()("Text 1.0 Text")
>>> text_parts
["Text ", " Text"]
>>> floats
np.array([1.0])
Args:
string: the string to extract floats from.
Returns:
A (string, array) pair, where `string` has each float replaced by "..."
and `array` is a `float32` `numpy.array` containing the extracted floats.
"""
texts = []
floats = []
for i, part in enumerate(self._FLOAT_RE.split(string)):
if i % 2 == 0:
texts.append(part)
else:
floats.append(float(part))
return texts, np.array(floats)
class TfDoctestOutputChecker(doctest.OutputChecker, object):
"""Customizes how `want` and `got` are compared, see `check_output`."""
def __init__(self, *args, **kwargs):
super(TfDoctestOutputChecker, self).__init__(*args, **kwargs)
self.extract_floats = _FloatExtractor()
self.text_good = None
self.float_size_good = None
_ADDRESS_RE = re.compile(r'\bat 0x[0-9a-f]*?>')
# TODO(yashkatariya): Add other tensor's string substitutions too.
# tf.RaggedTensor doesn't need one.
_NUMPY_OUTPUT_RE = re.compile(r'<tf.Tensor.*?numpy=(.*?)>', re.DOTALL)
def _allclose(self, want, got, rtol=1e-3, atol=1e-3):
return np.allclose(want, got, rtol=rtol, atol=atol)
def _tf_tensor_numpy_output(self, string):
modified_string = self._NUMPY_OUTPUT_RE.sub(r'\1', string)
return modified_string, modified_string != string
MESSAGE = textwrap.dedent("""\n
#############################################################
Check the documentation (https://www.tensorflow.org/community/contribute/docs_ref) on how to
write testable docstrings.
#############################################################""")
def check_output(self, want, got, optionflags):
"""Compares the docstring output to the output gotten by running the code.
Python addresses in the output are replaced with wildcards.
Float values in the output compared as using `np.allclose`:
* Float values are extracted from the text and replaced with wildcards.
* The wildcard text is compared to the actual output.
* The float values are compared using `np.allclose`.
The method returns `True` if both the text comparison and the numeric
comparison are successful.
The numeric comparison will fail if either:
* The wrong number of floats are found.
* The float values are not within tolerence.
Args:
want: The output in the docstring.
got: The output generated after running the snippet.
optionflags: Flags passed to the doctest.
Returns:
A bool, indicating if the check was successful or not.
"""
# If the docstring's output is empty and there is some output generated
# after running the snippet, return True. This is because if the user
# doesn't want to display output, respect that over what the doctest wants.
if got and not want:
return True
if want is None:
want = ''
if want == got:
return True
# Replace python's addresses with ellipsis (`...`) since it can change on
# each execution.
want = self._ADDRESS_RE.sub('at ...>', want)
# Replace tf.Tensor strings with only their numpy field values.
want, want_changed = self._tf_tensor_numpy_output(want)
if want_changed:
got, _ = self._tf_tensor_numpy_output(got)
# Separate out the floats, and replace `want` with the wild-card version
# "result=7.0" => "result=..."
want_text_parts, self.want_floats = self.extract_floats(want)
# numpy sometimes pads floats in arrays with spaces
# got: [1.2345, 2.3456, 3.0 ] want: [1.2345, 2.3456, 3.0001]
# And "normalize whitespace" only works when there's at least one space,
# so strip them and let the wildcard handle it.
want_text_parts = [part.strip(' ') for part in want_text_parts]
want_text_wild = '...'.join(want_text_parts)
if '....' in want_text_wild:
# If a float comes just after a period you'll end up four dots and the
# first three count as the ellipsis. Replace it with three dots.
want_text_wild = re.sub(r'\.\.\.\.+', '...', want_text_wild)
# Find the floats in the string returned by the test
_, self.got_floats = self.extract_floats(got)
self.text_good = super(TfDoctestOutputChecker, self).check_output(
want=want_text_wild, got=got, optionflags=optionflags)
if not self.text_good:
return False
if self.want_floats.size == 0:
# If there are no floats in the "want" string, ignore all the floats in
# the result. "np.array([ ... ])" matches "np.array([ 1.0, 2.0 ])"
return True
self.float_size_good = (self.want_floats.size == self.got_floats.size)
if self.float_size_good:
return self._allclose(self.want_floats, self.got_floats)
else:
return False
def output_difference(self, example, got, optionflags):
got = [got]
# If the some of the float output is hidden with `...`, `float_size_good`
# will be False. This is because the floats extracted from the string is
# converted into a 1-D numpy array. Hence hidding floats is not allowed
# anymore.
if self.text_good:
if not self.float_size_good:
got.append("\n\nCAUTION: tf_doctest doesn't work if *some* of the "
"*float output* is hidden with a \"...\".")
got.append(self.MESSAGE)
got = '\n'.join(got)
return (super(TfDoctestOutputChecker,
self).output_difference(example, got, optionflags))