614 lines
24 KiB
Python
614 lines
24 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Base test class for checkpointing datasets."""
|
|
|
|
import os
|
|
|
|
import numpy as np
|
|
from tensorflow.python.checkpoint import checkpoint as tracking_util
|
|
from tensorflow.python.checkpoint import checkpoint_management
|
|
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
|
|
from tensorflow.python.data.ops import dataset_ops
|
|
from tensorflow.python.data.ops import options as options_lib
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import combinations
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import sparse_tensor
|
|
from tensorflow.python.ops import lookup_ops
|
|
from tensorflow.python.ops import variables
|
|
from tensorflow.python.ops.ragged import ragged_tensor_value
|
|
from tensorflow.python.platform import gfile
|
|
from tensorflow.python.platform import test
|
|
from tensorflow.python.training import saver as saver_lib
|
|
from tensorflow.python.util import nest
|
|
|
|
|
|
def remove_variants(get_next_op):
|
|
# TODO(b/72408568): Remove this once session.run can get variant tensors.
|
|
"""Remove variants from a nest structure, so sess.run will execute."""
|
|
|
|
def _remove_variant(x):
|
|
if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
|
|
return ()
|
|
else:
|
|
return x
|
|
|
|
return nest.map_structure(_remove_variant, get_next_op)
|
|
|
|
|
|
def default_test_combinations():
|
|
"""Returns the default test combinations for testing checkpointing."""
|
|
|
|
def disable_optimizations(ds_fn):
|
|
options = options_lib.Options()
|
|
options.experimental_optimization.apply_default_optimizations = False
|
|
|
|
def ds_fn_no_opt():
|
|
return ds_fn().with_options(options)
|
|
|
|
return ds_fn_no_opt
|
|
|
|
def verify_unused_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
|
|
obj.verify_unused_iterator(
|
|
ds_fn=disable_optimizations(ds_fn=ds_fn),
|
|
num_outputs=num_outputs,
|
|
sparse_tensors=sparse_tensors)
|
|
|
|
verify_unused_iterator_combination = combinations.combine(
|
|
verify_fn=combinations.NamedObject("verify_unused_iterator",
|
|
verify_unused_iterator))
|
|
|
|
def verify_fully_used_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
|
|
obj.verify_fully_used_iterator(
|
|
ds_fn=disable_optimizations(ds_fn=ds_fn),
|
|
num_outputs=num_outputs,
|
|
sparse_tensors=sparse_tensors)
|
|
|
|
verify_fully_used_iterator_combination = combinations.combine(
|
|
verify_fn=combinations.NamedObject("verify_fully_used_iterator",
|
|
verify_fully_used_iterator))
|
|
|
|
def verify_exhausted_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
|
|
obj.verify_exhausted_iterator(
|
|
ds_fn=disable_optimizations(ds_fn=ds_fn),
|
|
num_outputs=num_outputs,
|
|
sparse_tensors=sparse_tensors)
|
|
|
|
verify_exhausted_iterator_combination = combinations.combine(
|
|
verify_fn=combinations.NamedObject("verify_exhausted_iterator",
|
|
verify_exhausted_iterator))
|
|
|
|
def verify_multiple_breaks(obj, ds_fn, num_outputs, sparse_tensors=False):
|
|
obj.verify_multiple_breaks(
|
|
ds_fn=disable_optimizations(ds_fn=ds_fn),
|
|
num_outputs=num_outputs,
|
|
sparse_tensors=sparse_tensors)
|
|
|
|
verify_multiple_breaks_combination = combinations.combine(
|
|
verify_fn=combinations.NamedObject("verify_multiple_breaks",
|
|
verify_multiple_breaks))
|
|
|
|
def verify_reset_restored_iterator(obj,
|
|
ds_fn,
|
|
num_outputs,
|
|
sparse_tensors=False):
|
|
obj.verify_reset_restored_iterator(
|
|
ds_fn=disable_optimizations(ds_fn=ds_fn),
|
|
num_outputs=num_outputs,
|
|
sparse_tensors=sparse_tensors)
|
|
|
|
verify_reset_restored_iterator_combination = combinations.combine(
|
|
verify_fn=combinations.NamedObject("verify_reset_restored_iterator",
|
|
verify_reset_restored_iterator))
|
|
|
|
return (verify_unused_iterator_combination +
|
|
verify_fully_used_iterator_combination +
|
|
verify_exhausted_iterator_combination +
|
|
verify_multiple_breaks_combination +
|
|
verify_reset_restored_iterator_combination)
|
|
|
|
|
|
# TODO(b/72657739): Remove sparse_tensor argument, which is to test the
|
|
# (deprecated) saveable `SparseTensorSliceDataset`, once the API
|
|
# `from_sparse_tensor_slices()` and related tests are deleted.
|
|
class CheckpointTestBase(test.TestCase):
|
|
"""Base test class for checkpointing datasets."""
|
|
|
|
def tearDown(self):
|
|
self._delete_ckpt()
|
|
super(CheckpointTestBase, self).tearDown()
|
|
|
|
def verify_unused_iterator(self,
|
|
ds_fn,
|
|
num_outputs,
|
|
sparse_tensors=False,
|
|
verify_exhausted=True):
|
|
"""Verifies that saving and restoring an unused iterator works.
|
|
|
|
Args:
|
|
ds_fn: 0-argument function that returns a Dataset.
|
|
num_outputs: Total number of outputs expected from this Dataset.
|
|
sparse_tensors: Whether dataset is built from SparseTensor(s).
|
|
verify_exhausted: Whether to verify that the iterator has been exhausted
|
|
after producing `num_outputs` elements.
|
|
|
|
Raises:
|
|
AssertionError if any test fails.
|
|
"""
|
|
self.verify_run_with_breaks(
|
|
ds_fn, [0],
|
|
num_outputs,
|
|
sparse_tensors=sparse_tensors,
|
|
verify_exhausted=verify_exhausted)
|
|
|
|
def verify_fully_used_iterator(self,
|
|
ds_fn,
|
|
num_outputs,
|
|
sparse_tensors=False):
|
|
"""Verifies that saving and restoring a fully used iterator works.
|
|
|
|
Note that this only checks saving and restoring an iterator from which
|
|
`num_outputs` items have been produced but does not check for an
|
|
exhausted iterator, i.e., one from which an OutOfRange error has been
|
|
returned.
|
|
|
|
Args:
|
|
ds_fn: 0-argument function that returns a Dataset.
|
|
num_outputs: Total number of outputs expected from this Dataset.
|
|
sparse_tensors: Whether dataset is built from SparseTensor(s).
|
|
|
|
Raises:
|
|
AssertionError if test fails.
|
|
"""
|
|
self.verify_run_with_breaks(
|
|
ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
|
|
|
|
def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
|
|
"""Verifies that saving and restoring an exhausted iterator works.
|
|
|
|
An exhausted iterator is one which has returned an OutOfRange error.
|
|
|
|
Args:
|
|
ds_fn: 0-argument function that returns a Dataset.
|
|
num_outputs: Total number of outputs expected from this Dataset.
|
|
sparse_tensors: Whether dataset is built from SparseTensor(s).
|
|
|
|
Raises:
|
|
AssertionError if any test fails.
|
|
"""
|
|
self.gen_outputs(
|
|
ds_fn, [],
|
|
num_outputs,
|
|
verify_exhausted=True,
|
|
sparse_tensors=sparse_tensors)
|
|
actual = self.gen_outputs(
|
|
ds_fn, [],
|
|
0,
|
|
ckpt_saved=True,
|
|
verify_exhausted=True,
|
|
sparse_tensors=sparse_tensors)
|
|
self.assertLen(actual, 0)
|
|
|
|
def verify_multiple_breaks(self,
|
|
ds_fn,
|
|
num_outputs,
|
|
num_breaks=10,
|
|
sparse_tensors=False,
|
|
verify_exhausted=True):
|
|
"""Attempts to save/restore at multiple break points.
|
|
|
|
Args:
|
|
ds_fn: 0-argument function that returns a Dataset.
|
|
num_outputs: Total number of outputs expected from this Dataset.
|
|
num_breaks: The number of break points. These are uniformly spread in [0,
|
|
num_outputs] both inclusive.
|
|
sparse_tensors: Whether dataset is built from SparseTensor(s).
|
|
verify_exhausted: Whether to verify that the iterator has been exhausted
|
|
after producing `num_outputs` elements.
|
|
|
|
Raises:
|
|
AssertionError if any test fails.
|
|
"""
|
|
self.verify_run_with_breaks(
|
|
ds_fn,
|
|
self.gen_break_points(num_outputs, num_breaks),
|
|
num_outputs,
|
|
sparse_tensors=sparse_tensors,
|
|
verify_exhausted=verify_exhausted)
|
|
|
|
def verify_reset_restored_iterator(self,
|
|
ds_fn,
|
|
num_outputs,
|
|
break_point=None,
|
|
sparse_tensors=False,
|
|
verify_exhausted=True):
|
|
"""Attempts to re-initialize a restored iterator.
|
|
|
|
This is useful when restoring a training checkpoint during validation.
|
|
|
|
Args:
|
|
ds_fn: 0-argument function that returns a Dataset.
|
|
num_outputs: Total number of outputs expected from this Dataset.
|
|
break_point: Break point. Optional. Defaults to num_outputs/2.
|
|
sparse_tensors: Whether dataset is built from SparseTensor(s).
|
|
verify_exhausted: Whether to verify that the iterator has been exhausted
|
|
after producing `num_outputs` elements.
|
|
|
|
Raises:
|
|
AssertionError if any test fails.
|
|
"""
|
|
if context.executing_eagerly():
|
|
self.skipTest("Eager mode iteration do not support re-initialization.")
|
|
|
|
break_point = num_outputs // 2 if not break_point else break_point
|
|
|
|
# Collect ground truth containing all outputs.
|
|
expected = self.gen_outputs(
|
|
ds_fn, [],
|
|
num_outputs,
|
|
sparse_tensors=sparse_tensors,
|
|
verify_exhausted=verify_exhausted)
|
|
|
|
# Skip some items and save checkpoint.
|
|
self.gen_outputs(
|
|
ds_fn, [],
|
|
break_point,
|
|
sparse_tensors=sparse_tensors,
|
|
verify_exhausted=False)
|
|
|
|
actual = []
|
|
# Restore from checkpoint and then run init_op.
|
|
with ops.Graph().as_default() as g:
|
|
saver = self._import_meta_graph()
|
|
init_op, get_next_op = self._get_iterator_ops_from_collection(
|
|
ds_fn, sparse_tensors=sparse_tensors)
|
|
get_next_op = remove_variants(get_next_op)
|
|
with self.session(graph=g) as sess:
|
|
self._initialize(init_op, sess)
|
|
self._restore(saver, sess)
|
|
self._initialize(init_op, sess)
|
|
for _ in range(num_outputs):
|
|
actual.append(sess.run(get_next_op))
|
|
if verify_exhausted:
|
|
with self.assertRaises(errors.OutOfRangeError):
|
|
sess.run(get_next_op)
|
|
self.match(expected, actual)
|
|
|
|
def verify_error_on_save(self,
|
|
ds_fn,
|
|
num_outputs,
|
|
error,
|
|
break_point=None,
|
|
sparse_tensors=False):
|
|
"""Attempts to save a non-saveable iterator.
|
|
|
|
Args:
|
|
ds_fn: 0-argument function that returns a Dataset.
|
|
num_outputs: Total number of outputs expected from this Dataset.
|
|
error: Declared error when trying to save iterator.
|
|
break_point: Break point. Optional. Defaults to num_outputs/2.
|
|
sparse_tensors: Whether dataset is built from SparseTensor(s).
|
|
|
|
Raises:
|
|
AssertionError if any test fails.
|
|
"""
|
|
break_point = num_outputs // 2 if not break_point else break_point
|
|
if context.executing_eagerly():
|
|
iterator = iter(ds_fn())
|
|
ckpt = tracking_util.Checkpoint(iterator=iterator)
|
|
for _ in range(break_point):
|
|
next(iterator)
|
|
with self.assertRaises(error):
|
|
ckpt.save(self._ckpt_path())
|
|
else:
|
|
with ops.Graph().as_default() as g:
|
|
init_op, get_next_op, saver = self._build_graph(
|
|
ds_fn, sparse_tensors=sparse_tensors)
|
|
get_next_op = remove_variants(get_next_op)
|
|
with self.session(graph=g) as sess:
|
|
self._initialize(init_op, sess)
|
|
for _ in range(break_point):
|
|
sess.run(get_next_op)
|
|
with self.assertRaises(error):
|
|
self._save(sess, saver)
|
|
|
|
def verify_run_with_breaks(self,
|
|
ds_fn,
|
|
break_points,
|
|
num_outputs,
|
|
sparse_tensors=False,
|
|
verify_exhausted=True):
|
|
"""Verifies that ds_fn() produces the same outputs with and without breaks.
|
|
|
|
1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
|
|
*without* stopping at break points.
|
|
2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
|
|
with stopping at break points.
|
|
|
|
Deep matches outputs from 1 and 2.
|
|
|
|
Args:
|
|
ds_fn: 0-argument function that returns a Dataset.
|
|
break_points: A list of integers. For each `break_point` in
|
|
`break_points`, we produce outputs till `break_point` number of items
|
|
have been produced and then checkpoint the state. The current graph and
|
|
session are destroyed and a new graph and session are used to produce
|
|
outputs till next checkpoint or till `num_outputs` elements have been
|
|
produced. `break_point` must be <= `num_outputs`.
|
|
num_outputs: Total number of outputs expected from this Dataset.
|
|
sparse_tensors: Whether dataset is built from SparseTensor(s).
|
|
verify_exhausted: Whether to verify that the iterator has been exhausted
|
|
after producing `num_outputs` elements.
|
|
|
|
Raises:
|
|
AssertionError if any test fails.
|
|
"""
|
|
expected = self.gen_outputs(
|
|
ds_fn, [],
|
|
num_outputs,
|
|
sparse_tensors=sparse_tensors,
|
|
verify_exhausted=verify_exhausted)
|
|
|
|
actual = self.gen_outputs(
|
|
ds_fn,
|
|
break_points,
|
|
num_outputs,
|
|
sparse_tensors=sparse_tensors,
|
|
verify_exhausted=verify_exhausted)
|
|
|
|
self.match(expected, actual)
|
|
|
|
def gen_outputs(self,
|
|
ds_fn,
|
|
break_points,
|
|
num_outputs,
|
|
ckpt_saved=False,
|
|
sparse_tensors=False,
|
|
verify_exhausted=True,
|
|
save_checkpoint_at_end=True):
|
|
"""Generates elements from input dataset while stopping at break points.
|
|
|
|
Produces `num_outputs` outputs and saves the state of the iterator in the
|
|
Saver checkpoint.
|
|
|
|
Args:
|
|
ds_fn: 0-argument function that returns the dataset.
|
|
break_points: A list of integers. For each `break_point` in
|
|
`break_points`, we produce outputs till `break_point` number of items
|
|
have been produced and then checkpoint the state. The current graph and
|
|
session are destroyed and a new graph and session are used to produce
|
|
outputs till next checkpoint or till `num_outputs` elements have been
|
|
produced. `break_point` must be <= `num_outputs`.
|
|
num_outputs: The total number of outputs to produce from the iterator.
|
|
ckpt_saved: Whether a checkpoint already exists.
|
|
sparse_tensors: Whether dataset is built from SparseTensor(s).
|
|
verify_exhausted: Whether to verify that the iterator has been exhausted
|
|
after producing `num_outputs` elements.
|
|
save_checkpoint_at_end: Whether to save a checkpoint after producing all
|
|
outputs. If False, checkpoints are saved each break point but not at the
|
|
end. Note that checkpoints overwrite each other so there is always only
|
|
a single checkpoint available. Defaults to True.
|
|
|
|
Returns:
|
|
A list of `num_outputs` items.
|
|
"""
|
|
outputs = []
|
|
|
|
if context.executing_eagerly():
|
|
for i in range(len(break_points) + 1):
|
|
iterator = iter(ds_fn())
|
|
ckpt = tracking_util.Checkpoint(iterator=iterator)
|
|
if ckpt_saved:
|
|
ckpt_path = self._latest_ckpt()
|
|
ckpt.restore(ckpt_path)
|
|
start = break_points[i - 1] if i > 0 else 0
|
|
end = break_points[i] if i < len(break_points) else num_outputs
|
|
num_iters = end - start
|
|
for _ in range(num_iters):
|
|
outputs.append(self.evaluate(next(iterator)))
|
|
if i == len(break_points) and verify_exhausted:
|
|
with self.assertRaises(StopIteration):
|
|
next(iterator)
|
|
if save_checkpoint_at_end or i < len(break_points):
|
|
ckpt_path = ckpt.save(self._ckpt_path())
|
|
ckpt_saved = True
|
|
else:
|
|
def get_ops():
|
|
if ckpt_saved:
|
|
saver = self._import_meta_graph()
|
|
init_op, get_next_op = self._get_iterator_ops_from_collection(
|
|
ds_fn, sparse_tensors=sparse_tensors)
|
|
else:
|
|
init_op, get_next_op, saver = self._build_graph(
|
|
ds_fn, sparse_tensors=sparse_tensors)
|
|
return init_op, get_next_op, saver
|
|
|
|
for i in range(len(break_points) + 1):
|
|
with ops.Graph().as_default() as g:
|
|
init_op, get_next_op, saver = get_ops()
|
|
get_next_op = remove_variants(get_next_op)
|
|
with self.session(graph=g) as sess:
|
|
if ckpt_saved:
|
|
self._initialize(init_op, sess)
|
|
self._restore(saver, sess)
|
|
else:
|
|
self._initialize(init_op, sess)
|
|
start = break_points[i - 1] if i > 0 else 0
|
|
end = break_points[i] if i < len(break_points) else num_outputs
|
|
num_iters = end - start
|
|
for _ in range(num_iters):
|
|
outputs.append(sess.run(get_next_op))
|
|
if i == len(break_points) and verify_exhausted:
|
|
with self.assertRaises(errors.OutOfRangeError):
|
|
sess.run(get_next_op)
|
|
if save_checkpoint_at_end or i < len(break_points):
|
|
self._save(sess, saver)
|
|
ckpt_saved = True
|
|
|
|
return outputs
|
|
|
|
def match(self, expected, actual):
|
|
"""Matches nested structures.
|
|
|
|
Recursively matches shape and values of `expected` and `actual`.
|
|
Handles scalars, numpy arrays and other python sequence containers
|
|
e.g. list, dict, as well as SparseTensorValue and RaggedTensorValue.
|
|
|
|
Args:
|
|
expected: Nested structure 1.
|
|
actual: Nested structure 2.
|
|
|
|
Raises:
|
|
AssertionError if matching fails.
|
|
"""
|
|
if isinstance(expected, np.ndarray):
|
|
expected = expected.tolist()
|
|
if isinstance(actual, np.ndarray):
|
|
actual = actual.tolist()
|
|
self.assertEqual(type(expected), type(actual))
|
|
|
|
if nest.is_nested(expected):
|
|
self.assertEqual(len(expected), len(actual))
|
|
if isinstance(expected, dict):
|
|
for key1, key2 in zip(sorted(expected), sorted(actual)):
|
|
self.assertEqual(key1, key2)
|
|
self.match(expected[key1], actual[key2])
|
|
else:
|
|
for item1, item2 in zip(expected, actual):
|
|
self.match(item1, item2)
|
|
elif isinstance(expected, sparse_tensor.SparseTensorValue):
|
|
self.match((expected.indices, expected.values, expected.dense_shape),
|
|
(actual.indices, actual.values, actual.dense_shape))
|
|
elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
|
|
self.match((expected.values, expected.row_splits),
|
|
(actual.values, actual.row_splits))
|
|
else:
|
|
self.assertEqual(expected, actual)
|
|
|
|
def does_not_match(self, expected, actual):
|
|
with self.assertRaises(AssertionError):
|
|
self.match(expected, actual)
|
|
|
|
def gen_break_points(self, num_outputs, num_samples=10):
|
|
"""Generates `num_samples` unique break points in [0, num_outputs]."""
|
|
return np.unique(np.linspace(0, num_outputs, num_samples, dtype=int))
|
|
|
|
def _build_graph(self, ds_fn, sparse_tensors=False):
|
|
dataset = ds_fn()
|
|
iterator = dataset_ops.make_initializable_iterator(dataset)
|
|
external_state_policy = dataset.options().experimental_external_state_policy
|
|
saveable = contrib_iterator_ops.make_saveable_from_iterator(
|
|
iterator, external_state_policy=external_state_policy)
|
|
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
|
|
init_op = iterator.initializer
|
|
if sparse_tensors:
|
|
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
|
|
else:
|
|
get_next = iterator.get_next()
|
|
self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
|
|
sparse_tensors)
|
|
saver = saver_lib.Saver(allow_empty=True)
|
|
return init_op, get_next, saver
|
|
|
|
def _add_iterator_ops_to_collection(self,
|
|
init_op,
|
|
get_next,
|
|
ds_fn,
|
|
sparse_tensors=False):
|
|
ops.add_to_collection("iterator_ops", init_op)
|
|
# `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
|
|
# do not support tuples we flatten the tensors and restore the shape in
|
|
# `_get_iterator_ops_from_collection`.
|
|
if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
|
|
ops.add_to_collection("iterator_ops", get_next.indices)
|
|
ops.add_to_collection("iterator_ops", get_next.values)
|
|
ops.add_to_collection("iterator_ops", get_next.dense_shape)
|
|
return
|
|
|
|
get_next_list = nest.flatten(get_next)
|
|
for i, output_class in enumerate(
|
|
nest.flatten(self._get_output_classes(ds_fn))):
|
|
if output_class is sparse_tensor.SparseTensor:
|
|
ops.add_to_collection("iterator_ops", get_next_list[i].indices)
|
|
ops.add_to_collection("iterator_ops", get_next_list[i].values)
|
|
ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
|
|
else:
|
|
ops.add_to_collection("iterator_ops", get_next_list[i])
|
|
|
|
def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
|
|
all_ops = ops.get_collection("iterator_ops")
|
|
if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
|
|
init_op, indices, values, dense_shape = all_ops
|
|
return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
|
|
get_next_list = []
|
|
i = 1
|
|
for output_class in nest.flatten(self._get_output_classes(ds_fn)):
|
|
if output_class is sparse_tensor.SparseTensor:
|
|
indices, values, dense_shape = all_ops[i:i + 3]
|
|
i += 3
|
|
get_next_list.append(
|
|
sparse_tensor.SparseTensor(indices, values, dense_shape))
|
|
else:
|
|
get_next_list.append(all_ops[i])
|
|
i += 1
|
|
return all_ops[0], nest.pack_sequence_as(
|
|
self._get_output_types(ds_fn), get_next_list)
|
|
|
|
def _get_output_types(self, ds_fn):
|
|
assert not context.executing_eagerly()
|
|
with ops.Graph().as_default():
|
|
return dataset_ops.get_legacy_output_types(ds_fn())
|
|
|
|
def _get_output_shapes(self, ds_fn):
|
|
assert not context.executing_eagerly()
|
|
with ops.Graph().as_default():
|
|
return dataset_ops.get_legacy_output_shapes(ds_fn())
|
|
|
|
def _get_output_classes(self, ds_fn):
|
|
assert not context.executing_eagerly()
|
|
with ops.Graph().as_default():
|
|
return dataset_ops.get_legacy_output_classes(ds_fn())
|
|
|
|
def _ckpt_path(self):
|
|
return os.path.join(self.get_temp_dir(), "iterator")
|
|
|
|
def _latest_ckpt(self):
|
|
return checkpoint_management.latest_checkpoint(self.get_temp_dir())
|
|
|
|
def _save(self, sess, saver):
|
|
saver.save(sess, self._ckpt_path())
|
|
|
|
def _restore(self, saver, sess):
|
|
sess.run(lookup_ops.tables_initializer())
|
|
saver.restore(sess, self._latest_ckpt())
|
|
|
|
def _initialize(self, init_op, sess):
|
|
sess.run(variables.global_variables_initializer())
|
|
sess.run(lookup_ops.tables_initializer())
|
|
sess.run(init_op)
|
|
|
|
def _import_meta_graph(self):
|
|
meta_file_path = self._ckpt_path() + ".meta"
|
|
return saver_lib.import_meta_graph(meta_file_path)
|
|
|
|
def _delete_ckpt(self):
|
|
# Remove all checkpoint files.
|
|
prefix = self._ckpt_path()
|
|
pattern = prefix + "*"
|
|
files = gfile.Glob(pattern)
|
|
map(gfile.Remove, files)
|