# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base test class for checkpointing datasets."""
import os
import numpy as np
from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest
def remove_variants(get_next_op):
# TODO(b/72408568): Remove this once session.run can get variant tensors.
"""Remove variants from a nest structure, so sess.run will execute."""
def _remove_variant(x):
if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
return ()
else:
return x
return nest.map_structure(_remove_variant, get_next_op)
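# A minimal illustration (not part of the original file) of what the helper
# does: every `variant`-dtype tensor in the structure is replaced with an empty
# tuple, and all other leaves pass through unchanged. The names below are
# purely illustrative:
#
#   structure = {"data": int32_tensor, "handle": variant_tensor}
#   remove_variants(structure)  # -> {"data": int32_tensor, "handle": ()}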
def default_test_combinations():
"""Returns the default test combinations for testing checkpointing."""
def disable_optimizations(ds_fn):
options = options_lib.Options()
options.experimental_optimization.apply_default_optimizations = False
def ds_fn_no_opt():
return ds_fn().with_options(options)
return ds_fn_no_opt
def verify_unused_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
obj.verify_unused_iterator(
ds_fn=disable_optimizations(ds_fn=ds_fn),
num_outputs=num_outputs,
sparse_tensors=sparse_tensors)
verify_unused_iterator_combination = combinations.combine(
verify_fn=combinations.NamedObject("verify_unused_iterator",
verify_unused_iterator))
def verify_fully_used_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
obj.verify_fully_used_iterator(
ds_fn=disable_optimizations(ds_fn=ds_fn),
num_outputs=num_outputs,
sparse_tensors=sparse_tensors)
verify_fully_used_iterator_combination = combinations.combine(
verify_fn=combinations.NamedObject("verify_fully_used_iterator",
verify_fully_used_iterator))
def verify_exhausted_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
obj.verify_exhausted_iterator(
ds_fn=disable_optimizations(ds_fn=ds_fn),
num_outputs=num_outputs,
sparse_tensors=sparse_tensors)
verify_exhausted_iterator_combination = combinations.combine(
verify_fn=combinations.NamedObject("verify_exhausted_iterator",
verify_exhausted_iterator))
def verify_multiple_breaks(obj, ds_fn, num_outputs, sparse_tensors=False):
obj.verify_multiple_breaks(
ds_fn=disable_optimizations(ds_fn=ds_fn),
num_outputs=num_outputs,
sparse_tensors=sparse_tensors)
verify_multiple_breaks_combination = combinations.combine(
verify_fn=combinations.NamedObject("verify_multiple_breaks",
verify_multiple_breaks))
def verify_reset_restored_iterator(obj,
ds_fn,
num_outputs,
sparse_tensors=False):
obj.verify_reset_restored_iterator(
ds_fn=disable_optimizations(ds_fn=ds_fn),
num_outputs=num_outputs,
sparse_tensors=sparse_tensors)
verify_reset_restored_iterator_combination = combinations.combine(
verify_fn=combinations.NamedObject("verify_reset_restored_iterator",
verify_reset_restored_iterator))
return (verify_unused_iterator_combination +
verify_fully_used_iterator_combination +
verify_exhausted_iterator_combination +
verify_multiple_breaks_combination +
verify_reset_restored_iterator_combination)
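# A hedged sketch (not defined in this file) of how these combinations are
# typically consumed: a concrete checkpoint test multiplies them with the
# dataset test-base combinations and receives the chosen `verify_fn` as a test
# parameter. The class name, the dataset, and the `test_base`/`parameterized`
# references below are illustrative assumptions, not part of this module:
#
#   class RangeCheckpointTest(test_base.DatasetTestBase,
#                             CheckpointTestBase,
#                             parameterized.TestCase):
#
#     @combinations.generate(
#         combinations.times(test_base.default_test_combinations(),
#                            default_test_combinations()))
#     def test(self, verify_fn):
#       verify_fn(self, lambda: dataset_ops.Dataset.range(10), num_outputs=10)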
# TODO(b/72657739): Remove sparse_tensor argument, which is to test the
# (deprecated) saveable `SparseTensorSliceDataset`, once the API
# `from_sparse_tensor_slices()` and related tests are deleted.
class CheckpointTestBase(test.TestCase):
"""Base test class for checkpointing datasets."""
def tearDown(self):
self._delete_ckpt()
super(CheckpointTestBase, self).tearDown()
def verify_unused_iterator(self,
ds_fn,
num_outputs,
sparse_tensors=False,
verify_exhausted=True):
"""Verifies that saving and restoring an unused iterator works.
Args:
ds_fn: 0-argument function that returns a Dataset.
num_outputs: Total number of outputs expected from this Dataset.
sparse_tensors: Whether dataset is built from SparseTensor(s).
verify_exhausted: Whether to verify that the iterator has been exhausted
after producing `num_outputs` elements.
Raises:
AssertionError if any test fails.
"""
self.verify_run_with_breaks(
ds_fn, [0],
num_outputs,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
def verify_fully_used_iterator(self,
ds_fn,
num_outputs,
sparse_tensors=False):
"""Verifies that saving and restoring a fully used iterator works.
Note that this only checks saving and restoring an iterator from which
`num_outputs` items have been produced; it does not check an exhausted
iterator, i.e., one that has already raised an OutOfRange error.
Args:
ds_fn: 0-argument function that returns a Dataset.
num_outputs: Total number of outputs expected from this Dataset.
sparse_tensors: Whether dataset is built from SparseTensor(s).
Raises:
AssertionError if test fails.
"""
self.verify_run_with_breaks(
ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
"""Verifies that saving and restoring an exhausted iterator works.
An exhausted iterator is one that has already raised an OutOfRange error.
Args:
ds_fn: 0-argument function that returns a Dataset.
num_outputs: Total number of outputs expected from this Dataset.
sparse_tensors: Whether dataset is built from SparseTensor(s).
Raises:
AssertionError if any test fails.
"""
self.gen_outputs(
ds_fn, [],
num_outputs,
verify_exhausted=True,
sparse_tensors=sparse_tensors)
actual = self.gen_outputs(
ds_fn, [],
0,
ckpt_saved=True,
verify_exhausted=True,
sparse_tensors=sparse_tensors)
self.assertLen(actual, 0)
def verify_multiple_breaks(self,
ds_fn,
num_outputs,
num_breaks=10,
sparse_tensors=False,
verify_exhausted=True):
"""Attempts to save/restore at multiple break points.
Args:
ds_fn: 0-argument function that returns a Dataset.
num_outputs: Total number of outputs expected from this Dataset.
num_breaks: The number of break points. These are spread uniformly over
[0, num_outputs], with both endpoints inclusive.
sparse_tensors: Whether dataset is built from SparseTensor(s).
verify_exhausted: Whether to verify that the iterator has been exhausted
after producing `num_outputs` elements.
Raises:
AssertionError if any test fails.
"""
self.verify_run_with_breaks(
ds_fn,
self.gen_break_points(num_outputs, num_breaks),
num_outputs,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
def verify_reset_restored_iterator(self,
ds_fn,
num_outputs,
break_point=None,
sparse_tensors=False,
verify_exhausted=True):
"""Attempts to re-initialize a restored iterator.
This is useful when restoring a training checkpoint during validation.
Args:
ds_fn: 0-argument function that returns a Dataset.
num_outputs: Total number of outputs expected from this Dataset.
break_point: Break point. Optional. Defaults to `num_outputs // 2`.
sparse_tensors: Whether dataset is built from SparseTensor(s).
verify_exhausted: Whether to verify that the iterator has been exhausted
after producing `num_outputs` elements.
Raises:
AssertionError if any test fails.
"""
if context.executing_eagerly():
self.skipTest("Eager mode iteration do not support re-initialization.")
break_point = num_outputs // 2 if not break_point else break_point
# Collect ground truth containing all outputs.
expected = self.gen_outputs(
ds_fn, [],
num_outputs,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
# Skip some items and save checkpoint.
self.gen_outputs(
ds_fn, [],
break_point,
sparse_tensors=sparse_tensors,
verify_exhausted=False)
actual = []
# Restore from checkpoint and then run init_op.
with ops.Graph().as_default() as g:
saver = self._import_meta_graph()
init_op, get_next_op = self._get_iterator_ops_from_collection(
ds_fn, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
with self.session(graph=g) as sess:
self._initialize(init_op, sess)
self._restore(saver, sess)
self._initialize(init_op, sess)
for _ in range(num_outputs):
actual.append(sess.run(get_next_op))
if verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self.match(expected, actual)
def verify_error_on_save(self,
ds_fn,
num_outputs,
error,
break_point=None,
sparse_tensors=False):
"""Attempts to save a non-saveable iterator.
Args:
ds_fn: 0-argument function that returns a Dataset.
num_outputs: Total number of outputs expected from this Dataset.
error: The error expected to be raised when trying to save the iterator.
break_point: Break point. Optional. Defaults to `num_outputs // 2`.
sparse_tensors: Whether dataset is built from SparseTensor(s).
Raises:
AssertionError if any test fails.
"""
break_point = num_outputs // 2 if not break_point else break_point
if context.executing_eagerly():
iterator = iter(ds_fn())
ckpt = tracking_util.Checkpoint(iterator=iterator)
for _ in range(break_point):
next(iterator)
with self.assertRaises(error):
ckpt.save(self._ckpt_path())
else:
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
with self.session(graph=g) as sess:
self._initialize(init_op, sess)
for _ in range(break_point):
sess.run(get_next_op)
with self.assertRaises(error):
self._save(sess, saver)
def verify_run_with_breaks(self,
ds_fn,
break_points,
num_outputs,
sparse_tensors=False,
verify_exhausted=True):
"""Verifies that ds_fn() produces the same outputs with and without breaks.
1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
*without* stopping at break points.
2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it,
stopping at each of the break points along the way.
3. Deep-matches the outputs from steps 1 and 2.
Args:
ds_fn: 0-argument function that returns a Dataset.
break_points: A list of integers. For each `break_point` in
`break_points`, outputs are produced until `break_point` items have been
produced, and the iterator state is then checkpointed. The current graph
and session are destroyed, and a new graph and session produce outputs
until the next break point, or until `num_outputs` elements have been
produced. Each `break_point` must be <= `num_outputs`.
num_outputs: Total number of outputs expected from this Dataset.
sparse_tensors: Whether dataset is built from SparseTensor(s).
verify_exhausted: Whether to verify that the iterator has been exhausted
after producing `num_outputs` elements.
Raises:
AssertionError if any test fails.
"""
expected = self.gen_outputs(
ds_fn, [],
num_outputs,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
actual = self.gen_outputs(
ds_fn,
break_points,
num_outputs,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
self.match(expected, actual)
def gen_outputs(self,
ds_fn,
break_points,
num_outputs,
ckpt_saved=False,
sparse_tensors=False,
verify_exhausted=True,
save_checkpoint_at_end=True):
"""Generates elements from input dataset while stopping at break points.
Produces `num_outputs` outputs and saves the state of the iterator in the
Saver checkpoint.
Args:
ds_fn: 0-argument function that returns the dataset.
break_points: A list of integers. For each `break_point` in
`break_points`, outputs are produced until `break_point` items have been
produced, and the iterator state is then checkpointed. The current graph
and session are destroyed, and a new graph and session produce outputs
until the next break point, or until `num_outputs` elements have been
produced. Each `break_point` must be <= `num_outputs`.
num_outputs: The total number of outputs to produce from the iterator.
ckpt_saved: Whether a checkpoint already exists.
sparse_tensors: Whether dataset is built from SparseTensor(s).
verify_exhausted: Whether to verify that the iterator has been exhausted
after producing `num_outputs` elements.
save_checkpoint_at_end: Whether to save a checkpoint after producing all
outputs. If False, a checkpoint is saved at each break point but not at
the end. Note that checkpoints overwrite each other, so there is always
only a single checkpoint available. Defaults to True.
Returns:
A list of `num_outputs` items.
"""
outputs = []
if context.executing_eagerly():
for i in range(len(break_points) + 1):
iterator = iter(ds_fn())
ckpt = tracking_util.Checkpoint(iterator=iterator)
if ckpt_saved:
ckpt_path = self._latest_ckpt()
ckpt.restore(ckpt_path)
start = break_points[i - 1] if i > 0 else 0
end = break_points[i] if i < len(break_points) else num_outputs
num_iters = end - start
for _ in range(num_iters):
outputs.append(self.evaluate(next(iterator)))
if i == len(break_points) and verify_exhausted:
with self.assertRaises(StopIteration):
next(iterator)
if save_checkpoint_at_end or i < len(break_points):
ckpt_path = ckpt.save(self._ckpt_path())
ckpt_saved = True
else:
def get_ops():
if ckpt_saved:
saver = self._import_meta_graph()
init_op, get_next_op = self._get_iterator_ops_from_collection(
ds_fn, sparse_tensors=sparse_tensors)
else:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
return init_op, get_next_op, saver
for i in range(len(break_points) + 1):
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = get_ops()
get_next_op = remove_variants(get_next_op)
with self.session(graph=g) as sess:
if ckpt_saved:
self._initialize(init_op, sess)
self._restore(saver, sess)
else:
self._initialize(init_op, sess)
start = break_points[i - 1] if i > 0 else 0
end = break_points[i] if i < len(break_points) else num_outputs
num_iters = end - start
for _ in range(num_iters):
outputs.append(sess.run(get_next_op))
if i == len(break_points) and verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
if save_checkpoint_at_end or i < len(break_points):
self._save(sess, saver)
ckpt_saved = True
return outputs
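# A concrete reading of the break-point logic above, with illustrative numbers:
#
#   outputs = self.gen_outputs(ds_fn, break_points=[2, 5], num_outputs=8)
#
# runs the loop three times, producing elements [0, 2), then [2, 5), then
# [5, 8). After each of the first two segments the iterator state is saved and
# the graph/session (or eager iterator) is discarded, so every later segment
# starts from a checkpoint restore. The final segment also saves a checkpoint
# unless `save_checkpoint_at_end` is False.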
def match(self, expected, actual):
"""Matches nested structures.
Recursively matches shape and values of `expected` and `actual`.
Handles scalars, numpy arrays and other python sequence containers
e.g. list, dict, as well as SparseTensorValue and RaggedTensorValue.
Args:
expected: Nested structure 1.
actual: Nested structure 2.
Raises:
AssertionError if matching fails.
"""
if isinstance(expected, np.ndarray):
expected = expected.tolist()
if isinstance(actual, np.ndarray):
actual = actual.tolist()
self.assertEqual(type(expected), type(actual))
if nest.is_nested(expected):
self.assertEqual(len(expected), len(actual))
if isinstance(expected, dict):
for key1, key2 in zip(sorted(expected), sorted(actual)):
self.assertEqual(key1, key2)
self.match(expected[key1], actual[key2])
else:
for item1, item2 in zip(expected, actual):
self.match(item1, item2)
elif isinstance(expected, sparse_tensor.SparseTensorValue):
self.match((expected.indices, expected.values, expected.dense_shape),
(actual.indices, actual.values, actual.dense_shape))
elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
self.match((expected.values, expected.row_splits),
(actual.values, actual.row_splits))
else:
self.assertEqual(expected, actual)
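# Illustrative examples (not from the original file) of the matching rules:
#
#   self.match({"a": np.array([1, 2]), "b": 3},
#              {"a": np.array([1, 2]), "b": 3})   # passes
#   self.match([1, [2, 3]], [1, [2, 4]])          # raises AssertionError
#
# numpy arrays are converted to lists before comparison, so shape differences
# surface as ordinary length or value mismatches.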
def does_not_match(self, expected, actual):
with self.assertRaises(AssertionError):
self.match(expected, actual)
def gen_break_points(self, num_outputs, num_samples=10):
"""Generates `num_samples` unique break points in [0, num_outputs]."""
return np.unique(np.linspace(0, num_outputs, num_samples, dtype=int))
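# For example, values implied by the formula above (illustrative only):
#
#   self.gen_break_points(10, 5)  # -> array([0, 2, 5, 7, 10])
#   self.gen_break_points(4)      # -> array([0, 1, 2, 3, 4]); duplicates among
#                                 #    the default 10 samples collapse under
#                                 #    np.unique, hence "up to" num_samples.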
def _build_graph(self, ds_fn, sparse_tensors=False):
dataset = ds_fn()
iterator = dataset_ops.make_initializable_iterator(dataset)
external_state_policy = dataset.options().experimental_external_state_policy
saveable = contrib_iterator_ops.make_saveable_from_iterator(
iterator, external_state_policy=external_state_policy)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
init_op = iterator.initializer
if sparse_tensors:
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
else:
get_next = iterator.get_next()
self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
sparse_tensors)
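# With no explicit `var_list`, `Saver(allow_empty=True)` also collects the
# SAVEABLE_OBJECTS entries added above, which is how the iterator state ends up
# in the checkpoint even when the dataset defines no variables; `allow_empty`
# keeps construction from failing if there is nothing else to save.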
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
def _add_iterator_ops_to_collection(self,
init_op,
get_next,
ds_fn,
sparse_tensors=False):
ops.add_to_collection("iterator_ops", init_op)
# `get_next` may be a tuple, e.g. in TensorSliceDataset. Since collections
# do not support tuples, we flatten the tensors here and restore the
# structure in `_get_iterator_ops_from_collection`.
if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
ops.add_to_collection("iterator_ops", get_next.indices)
ops.add_to_collection("iterator_ops", get_next.values)
ops.add_to_collection("iterator_ops", get_next.dense_shape)
return
get_next_list = nest.flatten(get_next)
for i, output_class in enumerate(
nest.flatten(self._get_output_classes(ds_fn))):
if output_class is sparse_tensor.SparseTensor:
ops.add_to_collection("iterator_ops", get_next_list[i].indices)
ops.add_to_collection("iterator_ops", get_next_list[i].values)
ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
else:
ops.add_to_collection("iterator_ops", get_next_list[i])
def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
all_ops = ops.get_collection("iterator_ops")
if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
init_op, indices, values, dense_shape = all_ops
return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
get_next_list = []
i = 1
for output_class in nest.flatten(self._get_output_classes(ds_fn)):
if output_class is sparse_tensor.SparseTensor:
indices, values, dense_shape = all_ops[i:i + 3]
i += 3
get_next_list.append(
sparse_tensor.SparseTensor(indices, values, dense_shape))
else:
get_next_list.append(all_ops[i])
i += 1
return all_ops[0], nest.pack_sequence_as(
self._get_output_types(ds_fn), get_next_list)
def _get_output_types(self, ds_fn):
assert not context.executing_eagerly()
with ops.Graph().as_default():
return dataset_ops.get_legacy_output_types(ds_fn())
def _get_output_shapes(self, ds_fn):
assert not context.executing_eagerly()
with ops.Graph().as_default():
return dataset_ops.get_legacy_output_shapes(ds_fn())
def _get_output_classes(self, ds_fn):
assert not context.executing_eagerly()
with ops.Graph().as_default():
return dataset_ops.get_legacy_output_classes(ds_fn())
def _ckpt_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _latest_ckpt(self):
return checkpoint_management.latest_checkpoint(self.get_temp_dir())
def _save(self, sess, saver):
saver.save(sess, self._ckpt_path())
def _restore(self, saver, sess):
sess.run(lookup_ops.tables_initializer())
saver.restore(sess, self._latest_ckpt())
def _initialize(self, init_op, sess):
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
sess.run(init_op)
def _import_meta_graph(self):
meta_file_path = self._ckpt_path() + ".meta"
return saver_lib.import_meta_graph(meta_file_path)
def _delete_ckpt(self):
# Remove all checkpoint files.
prefix = self._ckpt_path()
pattern = prefix + "*"
files = gfile.Glob(pattern)
# `map` is lazy in Python 3, so iterate explicitly to actually remove the files.
for f in files:
  gfile.Remove(f)
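# A rough sketch (based on standard Saver/Checkpoint behavior, not asserted by
# this file) of what the helpers above leave under `get_temp_dir()`: graph-mode
# `_save` writes "iterator.index", "iterator.data-00000-of-00001" and
# "iterator.meta", while eager `ckpt.save` writes numbered variants such as
# "iterator-1.index"; the `prefix + "*"` glob above matches these because they
# all share the "iterator" prefix returned by `_ckpt_path`.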