Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/ops/parallel_for/test_util.py
2023-06-19 00:49:18 +02:00

64 lines
2.5 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utility."""
import numpy as np
from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
class PForTestCase(test.TestCase):
  """Base class for test cases that compare two sets of evaluated targets."""

  def _run_targets(self, targets1, targets2=None, run_init=True):
    """Evaluates `targets1` followed by `targets2` in one `evaluate` call.

    Args:
      targets1: Nested structure of targets; flattened with `nest.flatten`.
      targets2: Optional second nested structure. When it flattens to a
        non-empty list, that list must have the same length as the flattened
        `targets1`.
      run_init: If True, `variables.global_variables_initializer()` is
        evaluated before the targets.

    Returns:
      The result of `self.evaluate` on the flattened `targets1` concatenated
      with the flattened `targets2`.

    Raises:
      AssertionError: (via `self.fail`) if `targets2` is non-empty and its
        flattened length differs from that of `targets1`.
    """
    targets1 = nest.flatten(targets1)
    targets2 = [] if targets2 is None else nest.flatten(targets2)
    # Report the mismatch through the test framework instead of a bare
    # `assert`, which is removed when Python runs with -O.
    if targets2 and len(targets1) != len(targets2):
      self.fail(
          "targets1 and targets2 must flatten to the same number of "
          "elements, got %d and %d." % (len(targets1), len(targets2)))
    if run_init:
      init = variables.global_variables_initializer()
      self.evaluate(init)
    return self.evaluate(targets1 + targets2)

  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    """Evaluates both target structures and checks them element by element.

    The evaluated outputs are flattened and split into two halves; the i-th
    value of the first half is compared with the i-th value of the second.

    Args:
      targets1: Nested structure of targets producing the first half.
      targets2: Nested structure of targets producing the second half.
      rtol: Relative tolerance forwarded to `assertAllClose`.
      atol: Absolute tolerance forwarded to `assertAllClose`.
    """
    outputs = self._run_targets(targets1, targets2)
    outputs = nest.flatten(outputs)  # flatten SparseTensorValues
    n = len(outputs) // 2
    # With an odd count the halves cannot be paired one-to-one, and the
    # trailing value would otherwise go unchecked without any failure.
    self.assertEqual(
        2 * n, len(outputs),
        "Expected an even number of evaluated outputs so that they can be "
        "compared pairwise, got %d." % len(outputs))
    for from_first, from_second in zip(outputs[:n], outputs[n:]):
      # Object-dtype values are compared exactly; everything else is compared
      # within the given tolerances.
      if from_second.dtype != np.object_:
        self.assertAllClose(from_second, from_first, rtol=rtol, atol=atol)
      else:
        self.assertAllEqual(from_second, from_first)

  def _test_loop_fn(self,
                    loop_fn,
                    iters,
                    parallel_iterations=None,
                    fallback_to_while_loop=False,
                    rtol=1e-4,
                    atol=1e-5):
    """Checks that `pfor` and `for_loop` agree on the output of `loop_fn`.

    Args:
      loop_fn: Loop body passed to both `pfor` and `for_loop`.
      iters: Iteration count passed to both `pfor` and `for_loop`.
      parallel_iterations: Forwarded to both `pfor` and `for_loop`.
      fallback_to_while_loop: Forwarded to `pfor` only.
      rtol: Relative tolerance used when comparing the two results.
      atol: Absolute tolerance used when comparing the two results.
    """
    t1 = pfor_control_flow_ops.pfor(
        loop_fn,
        iters=iters,
        fallback_to_while_loop=fallback_to_while_loop,
        parallel_iterations=parallel_iterations)
    # `for_loop` needs the output dtypes up front; take them from the
    # structure that `pfor` returned.
    loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
    t2 = pfor_control_flow_ops.for_loop(
        loop_fn,
        loop_fn_dtypes,
        iters=iters,
        parallel_iterations=parallel_iterations)
    self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol)