# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distributed saving of a dataset to disk."""

from tensorflow.core.protobuf import snapshot_pb2
from tensorflow.python.eager import context
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util import lazy_loader

# TODO(b/176933539): Use regular import.
# TODO(b/238903802): Use TypeSpec serialization methods directly.
nested_structure_coder = lazy_loader.LazyLoader(
    "nested_structure_coder", globals(),
    "tensorflow.python.saved_model.nested_structure_coder")


# TODO(b/250921378): Add example to docstring and export to TF API.
def distributed_save(dataset,
                     directory,
                     dispatcher_address,
                     compression="AUTO"):
"""Initiates the process of distributedly saving a dataset to disk.

  Args:
    dataset: The `tf.data.Dataset` to save.
    directory: A string indicating the directory to which to save `dataset`.
    dispatcher_address: A string indicating the address of the dispatcher for
      the tf.data service instance used to save `dataset`.
    compression: (Optional.) A string indicating whether and how to compress
      the `dataset` materialization. If `"AUTO"`, the tf.data runtime decides
      which algorithm to use. If `"GZIP"` or `"SNAPPY"`, that specific
      algorithm is used. If `None`, the `dataset` materialization is not
      compressed.

  Returns:
    `None`.

  Raises:
    RuntimeError: If not in eager mode.
    ValueError: If `dispatcher_address` is not a string or is empty.
  """
  if not context.executing_eagerly():
    raise RuntimeError("`distributed_save` must be called in eager mode.")
  if not isinstance(dispatcher_address, str):
    raise ValueError("`dispatcher_address` must be a string, but is a "
                     f"{type(dispatcher_address)} ({dispatcher_address}).")
  if not dispatcher_address:
    raise ValueError("`dispatcher_address` must not be empty.")
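
  # The snapshot metadata below records the dataset's `element_spec`
  # (serialized via `nested_structure_coder.encode_structure`) together with
  # the requested compression, so readers can later reconstruct the elements.
  # The `distributed_save` op then hands the dataset's variant tensor and the
  # serialized metadata to the tf.data service dispatcher at
  # `dispatcher_address`, which initiates the distributed write to `directory`.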
  metadata = snapshot_pb2.DistributedSnapshotMetadata(
      element_spec=nested_structure_coder.encode_structure(
          dataset.element_spec).SerializeToString(),
      compression=compression,
  )
  gen_experimental_dataset_ops.distributed_save(
      dataset._variant_tensor,  # pylint: disable=protected-access
      directory=directory,
      address=dispatcher_address,
      metadata=metadata.SerializeToString(),
  )