# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for tf.data writers."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.TFRecordWriter")
@deprecation.deprecated(
    None, "To write TFRecords to disk, use `tf.io.TFRecordWriter`. To save "
    "and load the contents of a dataset, use `tf.data.experimental.save` "
    "and `tf.data.experimental.load`")
class TFRecordWriter:
"""Writes a dataset to a TFRecord file.
The elements of the dataset must be scalar strings. To serialize dataset
elements as strings, you can use the `tf.io.serialize_tensor` function.
```python
dataset = tf.data.Dataset.range(3)
dataset = dataset.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
writer.write(dataset)
```
To read back the elements, use `TFRecordDataset`.
```python
dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
```
To shard a `dataset` across multiple TFRecord files:
```python
dataset = ... # dataset to be written
def reduce_func(key, dataset):
filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(dataset.map(lambda _, x: x))
return tf.data.Dataset.from_tensors(filename)
dataset = dataset.enumerate()
dataset = dataset.apply(tf.data.experimental.group_by_window(
lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
))
# Iterate through the dataset to trigger data writing.
for _ in dataset:
pass
```
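
  Note that `TFRecordWriter` is deprecated. To persist and restore the full
  contents of a dataset, the deprecation notice above points to
  `tf.data.experimental.save` and `tf.data.experimental.load`; a minimal
  sketch, assuming a TensorFlow version where both APIs are available:

  ```python
  dataset = tf.data.Dataset.range(3)
  tf.data.experimental.save(dataset, "/path/to/saved_dataset")
  restored = tf.data.experimental.load("/path/to/saved_dataset")
  ```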
"""
  def __init__(self, filename, compression_type=None):
    """Initializes a `TFRecordWriter`.

    Args:
      filename: a string path indicating where to write the TFRecord data.
      compression_type: (Optional.) a string indicating what type of
        compression to use when writing the file. See
        `tf.io.TFRecordCompressionType` for what types of compression are
        available. Defaults to `None`.
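
    For example, to write a GZIP-compressed file (a sketch; `"GZIP"` and
    `"ZLIB"` are the compression types TFRecord readers and writers accept
    besides the empty default):

    ```python
    writer = tf.data.experimental.TFRecordWriter(
        "/path/to/file.tfrecord", compression_type="GZIP")
    ```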
"""
self._filename = ops.convert_to_tensor(
filename, dtypes.string, name="filename")
self._compression_type = convert.optional_param_to_tensor(
"compression_type",
compression_type,
argument_default="",
argument_dtype=dtypes.string)

  def write(self, dataset):
    """Writes a dataset to a TFRecord file.

    An operation that writes the content of the specified dataset to the file
    specified in the constructor.

    If the file exists, it will be overwritten.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file.

    Returns:
      In graph mode, this returns an operation which when executed performs the
      write. In eager mode, the write is performed by the method itself and
      there is no return value.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset`.
      TypeError: if the elements produced by the dataset are not scalar strings.
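
    As a sketch of the graph-mode behavior described above (assuming
    `tf.compat.v1`-style graph execution), the returned operation must be run
    explicitly:

    ```python
    dataset = tf.data.Dataset.range(3).map(tf.io.serialize_tensor)
    writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
    write_op = writer.write(dataset)  # builds the op; nothing is written yet
    with tf.compat.v1.Session() as sess:
      sess.run(write_op)  # performs the write
    ```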
"""
if not isinstance(dataset, dataset_ops.DatasetV2):
raise TypeError(
f"Invalid `dataset.` Expected a `tf.data.Dataset` object but got "
f"{type(dataset)}."
)
if not dataset_ops.get_structure(dataset).is_compatible_with(
tensor_spec.TensorSpec([], dtypes.string)):
raise TypeError(
f"Invalid `dataset`. Expected a`dataset` that produces scalar "
f"`tf.string` elements, but got a dataset which produces elements "
f"with shapes {dataset_ops.get_legacy_output_shapes(dataset)} and "
f"types {dataset_ops.get_legacy_output_types(dataset)}.")
# pylint: disable=protected-access
dataset = dataset._apply_debug_options()
return gen_experimental_dataset_ops.dataset_to_tf_record(
dataset._variant_tensor, self._filename, self._compression_type)