3RNN/Lib/site-packages/tensorflow/python/framework/test_combinations.py
2024-05-26 19:49:15 +02:00

460 lines
16 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Facilities for creating multiple test combinations.
Here is a simple example for testing various optimizers in Eager and Graph:

  class AdditionExample(test.TestCase, parameterized.TestCase):

    @combinations.generate(
        combinations.combine(mode=["graph", "eager"],
                             optimizer=[AdamOptimizer(),
                                        GradientDescentOptimizer()]))
    def testOptimizer(self, optimizer):
      ... f(optimizer)...
This will run `testOptimizer` 4 times with the specified optimizers: 2 in
Eager and 2 in Graph mode.
The test is going to accept the same parameters as the ones used in `combine()`.
The parameters need to match by name between the `combine()` call and the test
signature. It is necessary to accept all parameters. See `OptionalParameter`
for a way to implement optional parameters.
The `combine()` function creates the cross product of several options, and the
`times()` function creates the product of N `combine()`-ed results.
The execution of generated tests can be customized in a number of ways:
- The test can be skipped if it is not running in the correct environment.
- The arguments that are passed to the test can be additionally transformed.
- The test can be run with specific Python context managers.
These behaviors can be customized by providing instances of `TestCombination` to
`generate()`.
"""
from collections import OrderedDict
import contextlib
import re
import types
import unittest
from absl.testing import parameterized
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@tf_export("__internal__.test.combinations.TestCombination", v1=[])
class TestCombination:
  """Customizes the behavior of `generate()` and the tests that it executes.

  For each generated test combination the hooks below are used in order:

  1. `should_execute_combination` decides whether the combination is executed
     in the current environment at all.
  2. If it is executed, the arguments of all combined parameters are
     validated. Arguments that need special handling are processed by the
     `ParameterModifier` instances returned from `parameter_modifiers`.
  3. The test runs with the context managers returned from `context_managers`
     installed around it.

  This base class never skips, modifies no parameters and installs no context
  managers; subclasses override the hooks they need.
  """

  def should_execute_combination(self, kwargs):
    """Decides whether the combination of test arguments should be executed.

    A combination can be skipped when the environment does not satisfy its
    dependencies.

    Args:
      kwargs: Arguments that are passed to the test combination.

    Returns:
      A `(should_execute, reason)` tuple. `should_execute` is False when the
      test should be skipped, in which case `reason` is a textual description
      of why. When the test is going to be executed, `reason` is `None`.
    """
    del kwargs  # Not needed by the default implementation.
    return True, None

  def parameter_modifiers(self):
    """Returns the `ParameterModifier` instances that customize arguments.

    The default implementation returns an empty list.
    """
    return []

  def context_managers(self, kwargs):
    """Returns the context managers to run the test combination under.

    The test combination runs under the context managers returned by all of
    the `TestCombination` instances given to `generate()`.

    Args:
      kwargs: Arguments and their values that are passed to the test
        combination.

    Returns:
      A list of instantiated context managers. The default is an empty list.
    """
    del kwargs  # Not needed by the default implementation.
    return []
@tf_export("__internal__.test.combinations.ParameterModifier", v1=[])
class ParameterModifier:
  """Customizes how a particular parameter is passed to a test.

  Subclasses override `modified_arguments()` to adjust the parameters they
  care about, e.g. to change the value of a parameter or to keep it from being
  passed to the test method.

  The example below replaces every negative argument with zero before it
  reaches the test:

  ```
  class NonNegativeParameterModifier(ParameterModifier):

    def modified_arguments(self, kwargs, requested_parameters):
      updates = {}
      for name, value in kwargs.items():
        if value < 0:
          updates[name] = 0
      return updates
  ```
  """

  # Sentinel: mapping an argument name to this value in the dict returned by
  # `modified_arguments()` removes that argument instead of replacing it.
  DO_NOT_PASS_TO_THE_TEST = object()

  def __init__(self, parameter_name=None):
    """Constructs a parameter modifier, optionally tied to one parameter.

    Args:
      parameter_name: A `ParameterModifier` instance may operate on a class of
        parameters or on a parameter with a particular name. Only
        `ParameterModifier` instances that are of a unique type or were
        initialized with a unique `parameter_name` will be executed.
        See `__eq__` and `__hash__`.
    """
    self._parameter_name = parameter_name

  def modified_arguments(self, kwargs, requested_parameters):
    """Computes replacements for user-provided arguments before the test runs.

    Args:
      kwargs: The combined arguments for the test.
      requested_parameters: The parameter names defined in the signature of
        the test method.

    Returns:
      A dictionary of updates to `kwargs`. Keys whose value is
      `ParameterModifier.DO_NOT_PASS_TO_THE_TEST` are removed and not passed
      to the test. The default implementation requests no updates.
    """
    del kwargs, requested_parameters  # Unused by the default implementation.
    return {}

  def __eq__(self, other):
    """Two modifiers are equal when they share a type and `parameter_name`."""
    if self is other:
      return True
    if type(self) is not type(other):
      return False
    return self._parameter_name == other._parameter_name

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    """Hashes by `parameter_name` when it is set, otherwise by the class."""
    # Equal modifiers share both their type and their parameter name, so they
    # always produce the same hash.
    return (hash(self._parameter_name) if self._parameter_name
            else id(self.__class__))
@tf_export("__internal__.test.combinations.OptionalParameter", v1=[])
class OptionalParameter(ParameterModifier):
  """A parameter that is optional in `combine()` and in the test signature.

  `OptionalParameter` is usually returned from a `TestCombination`'s
  `parameter_modifiers()`. A `TestCombination` may consume a parameter (for
  example to build a context from its value); marking it optional means a test
  method whose signature does not list it still works, because the argument is
  then not passed to the test.

  Sample usage:

  ```
  class EagerGraphCombination(TestCombination):

    def context_managers(self, kwargs):
      mode = kwargs.pop("mode", None)
      if mode is None:
        return []
      elif mode == "eager":
        return [context.eager_mode()]
      elif mode == "graph":
        return [ops.Graph().as_default(), context.graph_mode()]
      else:
        raise ValueError(
            "'mode' has to be either 'eager' or 'graph', got {}".format(mode))

    def parameter_modifiers(self):
      return [test_combinations.OptionalParameter("mode")]
  ```

  When the test case is generated, the param "mode" is not passed to the test
  method unless the method asks for it, since it is consumed by
  `EagerGraphCombination`.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    # A test that does not declare the parameter must not receive it; a test
    # that declares it gets the argument unchanged.
    if self._parameter_name not in requested_parameters:
      return {self._parameter_name: ParameterModifier.DO_NOT_PASS_TO_THE_TEST}
    return {}
def generate(combinations, test_combinations=()):
  """A decorator for generating combinations of a test method or a test class.

  Parameters of the test method must match by name to get the corresponding
  value of the combination. Tests must accept all parameters that are passed
  other than the ones that are `OptionalParameter`.

  When applied to a class, every function attribute whose name starts with
  `unittest.TestLoader.testMethodPrefix` is replaced by one generated test per
  combination.

  Args:
    combinations: a list of dictionaries created using combine() and times().
      Each one must be an `OrderedDict`.
    test_combinations: a tuple of `TestCombination` instances that customize
      the execution of generated tests.

  Returns:
    a decorator that will cause the test method or the test class to be run
    under the specified conditions.

  Raises:
    ValueError: if any parameters were not accepted by the test method. This
      is raised when a generated test runs, not at decoration time.
  """

  def _testcase_name(combination):
    # One "_<key>_<value>" piece per item, keeping only alphanumeric
    # characters so that the names are usable with --test_filter.
    pieces = []
    for position, (key, value) in enumerate(combination.items()):
      key_part = "".join(filter(str.isalnum, key))
      value_part = "".join(filter(str.isalnum, _get_name(value, position)))
      pieces.append("_{}_{}".format(key_part, value_part))
    return "_test{}".format("".join(pieces))

  def decorator(test_method_or_class):
    """The decorator to be returned."""
    named_combinations = []
    for combination in combinations:
      # `combine()` and `times()` produce OrderedDicts so that the key order,
      # and therefore the generated test names, are stable.
      assert isinstance(combination, OrderedDict)
      named_combinations.append(
          OrderedDict(
              list(combination.items()) +
              [("testcase_name", _testcase_name(combination))]))

    if not isinstance(test_method_or_class, type):
      augmented_method = _augment_with_special_arguments(
          test_method_or_class, test_combinations=test_combinations)
      return parameterized.named_parameters(*named_combinations)(
          augmented_method)

    class_object = test_method_or_class
    test_method_ids = {}
    class_object._test_method_ids = test_method_ids
    test_prefix = unittest.TestLoader.testMethodPrefix
    # Iterate over a snapshot: the loop deletes and adds class attributes.
    for attr_name, attr_value in list(class_object.__dict__.items()):
      is_test_function = (attr_name.startswith(test_prefix) and
                          isinstance(attr_value, types.FunctionType))
      if not is_test_function:
        continue
      delattr(class_object, attr_name)
      generated_methods = {}
      parameterized._update_class_dict_for_param_test_case(
          class_object.__name__, generated_methods, test_method_ids, attr_name,
          parameterized._ParameterizedTestIter(
              _augment_with_special_arguments(
                  attr_value, test_combinations=test_combinations),
              named_combinations, parameterized._NAMED, attr_name))
      for method_name, method in generated_methods.items():
        setattr(class_object, method_name, method)
    return class_object

  return decorator
def _augment_with_special_arguments(test_method, test_combinations):
def decorated(self, **kwargs):
"""A wrapped test method that can treat some arguments in a special way."""
original_kwargs = kwargs.copy()
# Skip combinations that are going to be executed in a different testing
# environment.
reasons_to_skip = []
for combination in test_combinations:
should_execute, reason = combination.should_execute_combination(
original_kwargs.copy())
if not should_execute:
reasons_to_skip.append(" - " + reason)
if reasons_to_skip:
self.skipTest("\n".join(reasons_to_skip))
customized_parameters = []
for combination in test_combinations:
customized_parameters.extend(combination.parameter_modifiers())
customized_parameters = set(customized_parameters)
# The function for running the test under the total set of
# `context_managers`:
def execute_test_method():
requested_parameters = tf_inspect.getfullargspec(test_method).args
for customized_parameter in customized_parameters:
for argument, value in customized_parameter.modified_arguments(
original_kwargs.copy(), requested_parameters).items():
if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
kwargs.pop(argument, None)
else:
kwargs[argument] = value
omitted_arguments = set(requested_parameters).difference(
set(list(kwargs.keys()) + ["self"]))
if omitted_arguments:
raise ValueError("The test requires parameters whose arguments "
"were not passed: {} .".format(omitted_arguments))
missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
set(requested_parameters))
if missing_arguments:
raise ValueError("The test does not take parameters that were passed "
": {} .".format(missing_arguments))
kwargs_to_pass = {}
for parameter in requested_parameters:
if parameter == "self":
kwargs_to_pass[parameter] = self
else:
kwargs_to_pass[parameter] = kwargs[parameter]
test_method(**kwargs_to_pass)
# Install `context_managers` before running the test:
context_managers = []
for combination in test_combinations:
for manager in combination.context_managers(
original_kwargs.copy()):
context_managers.append(manager)
if hasattr(contextlib, "nested"): # Python 2
# TODO(isaprykin): Switch to ExitStack when contextlib2 is available.
with contextlib.nested(*context_managers):
execute_test_method()
else: # Python 3
with contextlib.ExitStack() as context_stack:
for manager in context_managers:
context_stack.enter_context(manager)
execute_test_method()
return decorated
@tf_export("__internal__.test.combinations.combine", v1=[])
def combine(**kwargs):
  """Generate combinations based on its keyword arguments.

  Two sets of returned combinations can be concatenated using +. Their product
  can be computed using `times()`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]`
      or `option=the_only_possibility`. Only `list` values are expanded; any
      other value (a tuple included) is treated as a single possibility.

  Returns:
    a list of `OrderedDict`s, one per combination. Keys are the keyword
    argument names in sorted order, and each key maps to one of the
    corresponding keyword argument values. The alphabetically first option
    varies slowest. Without keyword arguments the result is a list holding a
    single empty `OrderedDict`.
  """
  # Build the cross product one option at a time, in sorted key order. Each
  # partial combination is a list of (key, value) pairs; extending every
  # existing partial combination with every value of the next option keeps
  # earlier options varying slowest.
  partial_combinations = [[]]
  for key in sorted(kwargs):
    values = kwargs[key]
    if not isinstance(values, list):
      values = [values]
    partial_combinations = [
        pairs + [(key, value)]
        for pairs in partial_combinations
        for value in values
    ]
  return [OrderedDict(pairs) for pairs in partial_combinations]
@tf_export("__internal__.test.combinations.times", v1=[])
def times(*combined):
  """Generate a product of N sets of combinations.

  times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4])

  Args:
    *combined: N lists of dictionaries that specify combinations.

  Returns:
    a list of dictionaries for each combination. With a single input, that
    input list itself is returned.

  Raises:
    ValueError: if no inputs are given, or if some of the inputs have
      overlapping keys.
  """
  # An explicit check instead of `assert`: assertions are stripped under -O,
  # which would turn this into an IndexError below.
  if not combined:
    raise ValueError("times() requires at least one list of combinations.")
  if len(combined) == 1:
    return combined[0]

  first = combined[0]
  rest_combined = times(*combined[1:])
  combined_results = []
  for a in first:
    for b in rest_combined:
      shared_keys = a.keys() & b.keys()
      if shared_keys:
        raise ValueError(
            "Keys need to not overlap: {} vs {} (shared: {})".format(
                list(a.keys()), list(b.keys()),
                sorted(shared_keys, key=str)))
      combined_results.append(OrderedDict(list(a.items()) + list(b.items())))
  return combined_results
@tf_export("__internal__.test.combinations.NamedObject", v1=[])
class NamedObject:
  """A class that translates an object into a good test name.

  `repr()` (and therefore `str()`) of a `NamedObject` is the given `name`, so
  that string is what appears in generated test names. Attribute access, calls
  and iteration are forwarded to the wrapped object.
  """

  def __init__(self, name, obj):
    self._name = name
    self._obj = obj

  def __getattr__(self, name):
    # `__getattr__` only runs when normal attribute lookup fails. If the
    # wrapper's own fields are missing, report them as missing instead of
    # recursing forever through `self._obj`. This happens for an instance
    # that `copy` or `pickle` created without calling `__init__`.
    if name in ("_name", "_obj"):
      raise AttributeError(
          "{!r} object has no attribute {!r}".format(type(self).__name__, name))
    return getattr(self._obj, name)

  def __call__(self, *args, **kwargs):
    return self._obj(*args, **kwargs)

  def __iter__(self):
    return self._obj.__iter__()

  def __repr__(self):
    return self._name
def _get_name(value, index):
return re.sub("0[xX][0-9a-fA-F]+", str(index), str(value))