# Vendored from: tensorflow/python/data/ops/save_op.py
# (copied out of a site-packages install; web-viewer metadata removed)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of SaveDataset in Python."""
import os
from tensorflow.python.checkpoint import checkpoint as checkpoint_lib
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import lazy_loader
# TODO(b/238903802): Use TypeSpec serialization methods directly.
nested_structure_coder = lazy_loader.LazyLoader(
"nested_structure_coder", globals(),
"tensorflow.python.saved_model.nested_structure_coder")
def _save(input_dataset,
          path,
          compression=None,
          shard_func=None,
          checkpoint_args=None):
  """Implements the save function and checkpoint functionality.

  Args:
    input_dataset: The `tf.data.Dataset` to write to disk.
    path: Directory path to which the dataset should be saved.
    compression: Optional compression algorithm name passed through to the
      save op (e.g. "GZIP").
    shard_func: Optional function mapping a dataset element to a shard index.
      If `None`, the runtime picks the sharding.
    checkpoint_args: Optional kwargs forwarded to
      `tf.train.CheckpointManager` to make the save fault-tolerant. Must not
      contain the reserved key "checkpoint". Only honored when executing
      eagerly; otherwise the plain (non-checkpointed) save op is used.

  Raises:
    ValueError: If `checkpoint_args` contains the reserved key "checkpoint".
  """
  if context.executing_eagerly() and checkpoint_args:
    save_dataset = _SaveDataset(input_dataset, path, shard_func, compression)
    save_iterator = iter(save_dataset)
    if "checkpoint" in checkpoint_args:
      raise ValueError(
          "Invalid `checkpoint_args`. `checkpoint_args` are not allowed "
          "to include 'checkpoint'."
      )
    checkpoint = checkpoint_lib.Checkpoint(iterator=save_iterator)
    checkpoint_args["checkpoint"] = checkpoint
    manager = checkpoint_management.CheckpointManager(**checkpoint_args)
    # Resume from the most recent checkpoint if one exists (no-op otherwise).
    checkpoint.restore(manager.latest_checkpoint)
    # Driving the iterator performs the actual writing as a side effect; the
    # yielded elements are discarded. (No need for `enumerate` here.)
    for _ in save_iterator:
      if "step_counter" in checkpoint_args:
        checkpoint_args["step_counter"].assign_add(delta=1)
      manager.save(check_interval=True)
  else:
    dataset, shard_func, use_shard_func, path = set_save_dataset_attributes(
        input_dataset, shard_func, path)
    ged_ops.save_dataset(
        dataset._variant_tensor,  # pylint: disable=protected-access
        path=path,
        shard_func_other_args=shard_func.captured_inputs,
        compression=compression,
        shard_func=shard_func,
        use_shard_func=use_shard_func)
class _SaveDataset(dataset_ops.UnaryDataset):
  """A dataset that writes its input to `path` as it is iterated.

  Wraps `SaveDatasetV2`: elements of the input dataset pass through
  unchanged (same element spec), and iterating the dataset persists them
  to disk as a side effect. Used by `_save` for checkpointable saving.
  """

  def __init__(self, dataset, path, shard_func, compression):
    # Capture the spec before `set_save_dataset_attributes` rewraps things.
    self._element_spec = dataset.element_spec
    self._shard_func = shard_func
    dataset, shard_func, use_shard_func, path = set_save_dataset_attributes(
        dataset, shard_func, path)
    variant_tensor = ged_ops.save_dataset_v2(
        dataset._variant_tensor,  # pylint: disable=protected-access
        path=path,
        shard_func_other_args=shard_func.captured_inputs,
        shard_func=shard_func,
        use_shard_func=use_shard_func,
        compression=compression,
        output_types=structure.get_flat_tensor_types(dataset.element_spec),
        output_shapes=structure.get_flat_tensor_shapes(dataset.element_spec),
    )
    super().__init__(dataset, variant_tensor)

  def _functions(self):
    # Expose the shard function so dataset serialization/graph rewrites
    # can find it.
    return [self._shard_func]

  @property
  def element_spec(self):
    return self._element_spec
def set_save_dataset_attributes(dataset, shard_func, path):
  """Sets parameters for SaveDatasetOp and SaveDatasetV2Op.

  Wraps `shard_func` as a `StructuredFunctionWrapper`, serializes the
  dataset's element spec into the save directory, and converts `path`
  to a string tensor.

  Args:
    dataset: The dataset being saved.
    shard_func: Function mapping an element to a shard index, or `None`
      to let the runtime choose sharding.
    path: Directory in which the dataset will be saved.

  Returns:
    A tuple `(dataset, shard_func, use_shard_func, path)` where
    `shard_func` is the wrapped concrete function, `use_shard_func`
    indicates whether a user-supplied shard function is in effect, and
    `path` is a string tensor.
  """
  use_shard_func = shard_func is not None
  if not use_shard_func:
    # Placeholder; the op ignores it when use_shard_func is False.
    shard_func = lambda *x: None
  wrapped_func = structured_function.StructuredFunctionWrapper(
      shard_func,
      "save()",
      input_structure=dataset.element_spec,
      add_to_graph=False)
  # Persist the element spec alongside the data so `load` can restore it.
  spec_proto = nested_structure_coder.encode_structure(dataset.element_spec)
  gfile.MakeDirs(path)
  spec_path = os.path.join(path, dataset_ops.DATASET_SPEC_FILENAME)
  with gfile.GFile(spec_path, "wb") as f:
    f.write(spec_proto.SerializeToString())
  path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
  shard_func = wrapped_func.function
  shard_func.add_to_graph(ops.get_default_graph())
  dataset._apply_debug_options()  # pylint: disable=protected-access
  return dataset, shard_func, use_shard_func, path