# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""Run doctests for tensorflow."""
|
|
|
|
import doctest
|
|
import re
|
|
import textwrap
|
|
|
|
import numpy as np
|
|
|
|
|
|
class _FloatExtractor(object):
  """Splits a string into its float literals and the text around them.

  Only float-looking literals are extracted ("1.0", "1.", ".5", "1e3", and the
  real or imaginary part of a complex literal such as "300j"). Plain integers
  such as "12" are left in the text.

  For example:

  >>> text_parts, floats = _FloatExtractor()("Text 1.0 Text")
  >>> text_parts
  ['Text ', ' Text']
  >>> floats
  array([1.])
  """

  # Note: non-capturing groups "(?" are not returned in matched groups, or by
  # re.split. The pattern therefore has exactly ONE capturing group (the float
  # itself), which is what lets `__call__` rely on `re.split` returning
  # strictly alternating text / float items.
  _FLOAT_RE = re.compile(
      r"""
      (                          # Captures the float value.
        (?:
           [-+]|                 # Start with a sign is okay anywhere.
           (?:                   # Otherwise:
               ^|                # Start after the start of string
               (?<=[^\w.])       # Not after a word char, or a .
           )
        )
        (?:                      # Digits and exponent - something like:
          {digits_dot_maybe_digits}{exponent}?|   # "1.0" "1." "1.0e3", "1.e3"
          {dot_digits}{exponent}?|                # ".1" ".1e3"
          {digits}{exponent}|                     # "1e3"
          {digits}(?=j)                           # "300j"
        )
      )
      j?                         # Optional j for cplx numbers, not captured.
      (?=                        # Only accept the match if
        $|                       # * At the end of the string, or
        [^\w.]                   # * Next char is not a word char or "."
      )
      """.format(
          # Digits, a "." and optional more digits: "1.1".
          digits_dot_maybe_digits=r'(?:[0-9]+\.(?:[0-9]*))',
          # A "." with trailing digits ".23"
          dot_digits=r'(?:\.[0-9]+)',
          # digits: "12"
          digits=r'(?:[0-9]+)',
          # The exponent: An "e" or "E", optional sign, and at least one digit.
          # "e-123", "E+12", "e12"
          exponent=r'(?:[eE][-+]?[0-9]+)'),
      re.VERBOSE)

  def __call__(self, string):
    """Extracts floats from a string.

    >>> text_parts, floats = _FloatExtractor()("Text 1.0 Text")
    >>> text_parts
    ['Text ', ' Text']
    >>> floats
    array([1.])

    Args:
      string: the string to extract floats from.

    Returns:
      A `(text_parts, floats)` pair. `text_parts` is the list of text fragments
      found between the floats; it always holds exactly one more item than
      `floats` (fragments may be empty strings). `floats` is a 1-D `float64`
      `numpy.array` with the extracted values, in order of appearance.
    """
    # With a single capturing group, `split` yields
    # [text, float, text, float, ..., text]: even indices are the surrounding
    # text and odd indices are the captured float literals.
    pieces = self._FLOAT_RE.split(string)
    text_parts = pieces[::2]
    # An explicit dtype keeps the empty case a float array as well.
    floats = np.array([float(literal) for literal in pieces[1::2]],
                      dtype=np.float64)
    return text_parts, floats
|
|
|
|
|
|
class TfDoctestOutputChecker(doctest.OutputChecker, object):
  """Customizes how `want` and `got` are compared, see `check_output`.

  `check_output` records diagnostics on the instance, which
  `output_difference` reads afterwards to build the failure message:

  * `text_good`: result of the wildcard text comparison (`None` if the
    comparison was not reached).
  * `float_size_good`: whether `want` and `got` contain the same number of
    floats (`None` if that comparison was not reached).
  * `want_floats` / `got_floats`: the float arrays extracted from each side.
  """

  def __init__(self, *args, **kwargs):
    super(TfDoctestOutputChecker, self).__init__(*args, **kwargs)
    self.extract_floats = _FloatExtractor()
    self.text_good = None
    self.float_size_good = None
    # Populated by `check_output`; declared here so the attributes always
    # exist on a constructed checker.
    self.want_floats = None
    self.got_floats = None

  # Matches the tail of an object repr such as "at 0x7f3a2c>". Both lowercase
  # and uppercase hex digits are accepted so that either spelling of a pointer
  # is recognised.
  _ADDRESS_RE = re.compile(r'\bat 0x[0-9a-fA-F]*?>')
  # TODO(yashkatariya): Add other tensor's string substitutions too.
  # tf.RaggedTensor doesn't need one.
  # The "." in "tf.Tensor" is escaped so only a literal dot matches.
  _NUMPY_OUTPUT_RE = re.compile(r'<tf\.Tensor.*?numpy=(.*?)>', re.DOTALL)

  def _allclose(self, want, got, rtol=1e-3, atol=1e-3):
    """Returns True if `got` is within tolerance of `want` (`np.allclose`)."""
    return np.allclose(want, got, rtol=rtol, atol=atol)

  def _tf_tensor_numpy_output(self, string):
    """Replaces each `<tf.Tensor ... numpy=X>` in `string` with just `X`.

    Args:
      string: The text to rewrite.

    Returns:
      A `(modified_string, changed)` pair, where `changed` is True if at least
      one substitution altered the text.
    """
    modified_string = self._NUMPY_OUTPUT_RE.sub(r'\1', string)
    return modified_string, modified_string != string

  MESSAGE = textwrap.dedent("""\n
        #############################################################
        Check the documentation (https://www.tensorflow.org/community/contribute/docs_ref) on how to
        write testable docstrings.
        #############################################################""")

  def check_output(self, want, got, optionflags):
    """Compares the docstring output to the output gotten by running the code.

    Python addresses in the output are replaced with wildcards.

    Float values in the output compared as using `np.allclose`:

      * Float values are extracted from the text and replaced with wildcards.
      * The wildcard text is compared to the actual output.
      * The float values are compared using `np.allclose`.

    The method returns `True` if both the text comparison and the numeric
    comparison are successful.

    The numeric comparison will fail if either:

      * The wrong number of floats are found.
      * The float values are not within tolerance.

    Note: the wildcards are literal "..." strings handed to
    `doctest.OutputChecker.check_output`, which only treats them as wildcards
    when `doctest.ELLIPSIS` is set in `optionflags`.

    Args:
      want: The output in the docstring.
      got: The output generated after running the snippet.
      optionflags: Flags passed to the doctest.

    Returns:
      A bool, indicating if the check was successful or not.
    """
    # Forget the diagnostics of any previously checked example, so the early
    # returns below cannot leave stale values behind for `output_difference`.
    self.text_good = None
    self.float_size_good = None

    # If the docstring's output is empty and there is some output generated
    # after running the snippet, return True. This is because if the user
    # doesn't want to display output, respect that over what the doctest wants.
    if got and not want:
      return True

    if want is None:
      want = ''

    if want == got:
      return True

    # Replace python's addresses with ellipsis (`...`) since it can change on
    # each execution.
    want = self._ADDRESS_RE.sub('at ...>', want)

    # Replace tf.Tensor strings with only their numpy field values. `got` is
    # only rewritten when `want` actually contained a tensor repr.
    want, want_changed = self._tf_tensor_numpy_output(want)
    if want_changed:
      got, _ = self._tf_tensor_numpy_output(got)

    # Separate out the floats, and replace `want` with the wild-card version
    # "result=7.0" => "result=..."
    want_text_parts, self.want_floats = self.extract_floats(want)
    # numpy sometimes pads floats in arrays with spaces
    # got: [1.2345, 2.3456, 3.0   ] want: [1.2345, 2.3456, 3.0001]
    # And "normalize whitespace" only works when there's at least one space,
    # so strip them and let the wildcard handle it.
    want_text_wild = '...'.join(part.strip(' ') for part in want_text_parts)
    if '....' in want_text_wild:
      # If a float comes just after a period you'll end up four dots and the
      # first three count as the ellipsis. Replace it with three dots.
      want_text_wild = re.sub(r'\.\.\.\.+', '...', want_text_wild)

    # Find the floats in the string returned by the test
    _, self.got_floats = self.extract_floats(got)

    self.text_good = super(TfDoctestOutputChecker, self).check_output(
        want=want_text_wild, got=got, optionflags=optionflags)
    if not self.text_good:
      return False

    if self.want_floats.size == 0:
      # If there are no floats in the "want" string, ignore all the floats in
      # the result. "np.array([ ... ])" matches "np.array([ 1.0, 2.0 ])"
      return True

    self.float_size_good = (self.want_floats.size == self.got_floats.size)
    if not self.float_size_good:
      return False

    return self._allclose(self.want_floats, self.got_floats)

  def output_difference(self, example, got, optionflags):
    """Builds the failure report, appending tf_doctest-specific guidance.

    Args:
      example: The `doctest.Example` that failed.
      got: The output generated after running the snippet.
      optionflags: Flags passed to the doctest.

    Returns:
      The difference string produced by `doctest.OutputChecker`, computed on
      `got` with the extra guidance text appended to it.
    """
    report_parts = [got]

    # If some of the float output is hidden with `...`, `float_size_good`
    # will be False. This is because the floats extracted from the string are
    # converted into a 1-D numpy array. Hence hiding floats is not allowed
    # anymore.
    if self.text_good and not self.float_size_good:
      report_parts.append(
          "\n\nCAUTION: tf_doctest doesn't work if *some* of the "
          "*float output* is hidden with a \"...\".")

    report_parts.append(self.MESSAGE)
    got = '\n'.join(report_parts)
    return (super(TfDoctestOutputChecker,
                  self).output_difference(example, got, optionflags))
|