Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/data/ops/padded_batch_op.py
2023-06-19 00:49:18 +02:00

263 lines
10 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The implementation of `tf.data.Dataset.padded_batch`."""
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_dataset_ops
def _padded_batch(input_dataset,
                  batch_size,
                  padded_shapes=None,
                  padding_values=None,
                  drop_remainder=False,
                  name=None):
  """See `tf.data.Dataset.padded_batch` for details."""
  if padded_shapes is None:
    # Default to the input's own (legacy) output shapes; this only works when
    # every component has a statically known rank.
    padded_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
    # A `tf.TensorShape` is only false if its *rank* is unknown.
    unknown_rank_components = [
        i for i, component_shape in enumerate(nest.flatten(padded_shapes))
        if not component_shape
    ]
    if unknown_rank_components:
      i = unknown_rank_components[0]
      raise ValueError(f"You must provide `padded_shapes` argument because "
                       f"component {i} has unknown rank.")
  return _PaddedBatchDataset(
      input_dataset,
      batch_size,
      padded_shapes,
      padding_values,
      drop_remainder,
      name=name)
def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
"""Returns `True` if `input_component_shape` can be padded to `padded_shape`.
Args:
padded_shape: A `tf.TensorShape`.
input_component_shape: A `tf.TensorShape`.
Returns:
`True` if `input_component_shape` can be padded to `padded_shape`, otherwise
`False`.
"""
if padded_shape.dims is None or input_component_shape.dims is None:
return True
if len(padded_shape.dims) != len(input_component_shape.dims):
return False
for padded_dim, input_dim in zip(padded_shape.dims,
input_component_shape.dims):
if (padded_dim.value is not None and input_dim.value is not None and
padded_dim.value < input_dim.value):
return False
return True
def _padded_shape_to_tensor(padded_shape, input_component_shape):
  """Converts `padded_shape` to a `tf.Tensor` representing that shape.

  Args:
    padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
      sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
    input_component_shape: A `tf.TensorShape`, with which `padded_shape` must be
      compatible.

  Returns:
    A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.

  Raises:
    ValueError: If `padded_shape` is not a shape or not compatible with
      `input_component_shape`.
    TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
  """
  try:
    # Try to convert the `padded_shape` to a `tf.TensorShape`
    padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
    # We will return the "canonical" tensor representation, which uses
    # `-1` in place of `None`.
    ret = ops.convert_to_tensor([
        dim if dim is not None else -1
        for dim in padded_shape_as_shape.as_list()
    ],
                                dtype=dtypes.int64)
  except (TypeError, ValueError) as e:
    # The argument was not trivially convertible to a
    # `tf.TensorShape`, so fall back on the conversion to tensor
    # machinery.
    ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
    # A shape tensor must be a rank-1 (vector) tensor. `from e` keeps the
    # original conversion failure in the exception chain for debugging.
    if ret.shape.dims is not None and len(ret.shape.dims) != 1:
      raise ValueError(
          f"Padded shape {padded_shape} must be a `tf.int64` vector tensor, "
          f"but its shape was {ret.shape}.") from e
    if ret.dtype != dtypes.int64:
      raise TypeError(
          f"Padded shape {padded_shape} must be a `tf.int64` vector "
          f"tensor, but its element type was {ret.dtype.name}.") from e
    # Recover a static `TensorShape` from the tensor (where possible) so the
    # compatibility check below can compare known dimensions.
    padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)

  # Each known padded dimension must be at least as large as the corresponding
  # known input dimension, and ranks must agree when both are known.
  if not _is_padded_shape_compatible_with(padded_shape_as_shape,
                                          input_component_shape):
    raise ValueError(f"The padded shape {padded_shape_as_shape} is not "
                     f"compatible with the shape {input_component_shape} of "
                     f"the corresponding input component.")

  return ret
def _padding_values_or_default(padding_values, input_dataset):
  """Returns padding values with None elements replaced with default values."""

  def _default_for(dtype):
    """Builds a zero-like default padding value for a component dtype."""
    base = dtype.base_dtype
    if base == dtypes.string:
      return ""
    if base == dtypes.variant:
      raise TypeError("Unable to create default padding value for a component "
                      "of type 'variant'.")
    if base == dtypes.bfloat16:
      # Special case `bfloat16` because it is not supported by NumPy.
      return constant_op.constant(0, dtype=dtypes.bfloat16)
    return np.zeros_like(dtype.as_numpy_dtype())

  defaults = nest.map_structure(
      _default_for, dataset_ops.get_legacy_output_types(input_dataset))

  def _pick(user_value, default_value):
    # Keep the caller's value unless it was left unspecified.
    if user_value is None:
      return default_value
    return user_value

  return nest.map_structure_up_to(padding_values, _pick, padding_values,
                                  defaults)
def _padding_value_to_tensor(value, output_type):
  """Converts the padding value to a tensor.

  Args:
    value: The padding value.
    output_type: Its expected dtype.

  Returns:
    A scalar `Tensor`.

  Raises:
    ValueError: if the padding value is not a scalar.
    TypeError: if the padding value's type does not match `output_type`.
  """
  tensor_value = ops.convert_to_tensor(value, name="padding_value")
  scalar = tensor_shape.TensorShape([])
  # Padding values are applied element-wise, so each one must be a scalar.
  if not tensor_value.shape.is_compatible_with(scalar):
    raise ValueError(f"Invalid `padding_values`. `padding_values` values "
                     f"should be scalars, but got {tensor_value.shape}.")
  # The dtype must match the corresponding component exactly.
  if tensor_value.dtype != output_type:
    raise TypeError(f"Invalid `padding_values`. `padding_values` values "
                    f"type {tensor_value.dtype} does not match type "
                    f"{output_type} of the corresponding input component.")
  return tensor_value
class _PaddedBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches and pads contiguous elements from its input."""

  def __init__(self,
               input_dataset,
               batch_size,
               padded_shapes,
               padding_values,
               drop_remainder,
               name=None):
    """See `Dataset.batch()` for details.

    Args:
      input_dataset: The dataset whose elements are batched and padded.
      batch_size: Number of consecutive elements per batch.
      padded_shapes: A (possibly nested) structure of shape-like objects, one
        per input component, giving the shape to pad each component to.
      padding_values: A structure of scalar padding values (or a single scalar
        to broadcast), with `None` entries already replaced by defaults via
        `_padding_values_or_default`.
      drop_remainder: Whether to drop a final partial batch.
      name: Optional name for the tf.data operation.

    Raises:
      TypeError: If any component of `input_dataset` is not a plain tensor
        (e.g. a nested dataset), or a padding value's dtype mismatches.
      ValueError: If a padded shape is incompatible with its component shape,
        or a padding value is not a scalar.
    """

    def check_types(component_spec):
      # Padding is only defined for dense tensor components; reject nested
      # datasets and any other composite element spec.
      if not isinstance(component_spec, tensor_spec.TensorSpec):
        if isinstance(component_spec, dataset_ops.DatasetSpec):
          raise TypeError(
              "`padded_batch` is not supported for datasets of datasets")
        raise TypeError(f"`padded_batch` is only supported for datasets that "
                        f"produce tensor elements but type spec of elements in "
                        f"the input dataset is not a subclass of TensorSpec: "
                        f"`{component_spec}`.")

    nest.map_structure(check_types, input_dataset.element_spec)
    # NOTE(review): the original assigned `self._input_dataset` twice (before
    # and after the type check); the first assignment was dead and is removed.
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    padding_values = _padding_values_or_default(padding_values, input_dataset)

    input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
    # Allow `padded_shapes` to be a partially-flattened version of the
    # element structure.
    flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)

    # Validate each padded shape against its component and convert it to the
    # canonical int64-vector tensor form (with -1 for unknown dims).
    flat_padded_shapes_as_tensors = []
    for input_component_shape, padded_shape in zip(
        nest.flatten(input_shapes), flat_padded_shapes):
      flat_padded_shapes_as_tensors.append(
          _padded_shape_to_tensor(padded_shape, input_component_shape))

    self._padded_shapes = nest.pack_sequence_as(input_shapes,
                                                flat_padded_shapes_as_tensors)

    # If padding_values is a single element and input_shapes is a structure,
    # "broadcast" padding_values to the same structure as input_shapes.
    if nest.is_nested(input_shapes) and not nest.is_nested(padding_values):
      padding_values = nest.map_structure(lambda _: padding_values,
                                          input_shapes)

    # Validate each padding value (scalar, matching dtype) and convert to a
    # tensor.
    self._padding_values = nest.map_structure_up_to(
        input_shapes, _padding_value_to_tensor, padding_values,
        dataset_ops.get_legacy_output_types(input_dataset))
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    def _padded_shape_to_batch_shape(s):
      # The leading batch dimension is statically known only when
      # `drop_remainder` is a constant True; otherwise it is None.
      return tensor_shape.TensorShape([
          tensor_util.constant_value(self._batch_size)
          if smart_cond.smart_constant_value(self._drop_remainder) else None
      ]).concatenate(tensor_util.constant_value_as_shape(s))

    output_shapes = nest.map_structure(_padded_shape_to_batch_shape,
                                       self._padded_shapes)
    self._structure = structure.convert_legacy_structure(
        dataset_ops.get_legacy_output_types(self._input_dataset), output_shapes,
        dataset_ops.get_legacy_output_classes(self._input_dataset))

    # `self._name` must be set before `self._metadata` is read — the metadata
    # is presumably derived from the name by the `DatasetV2` base; confirm.
    self._name = name
    # pylint: disable=protected-access
    variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
        input_dataset._variant_tensor,
        batch_size=self._batch_size,
        padded_shapes=[
            ops.convert_to_tensor(s, dtype=dtypes.int64)
            for s in nest.flatten(self._padded_shapes)
        ],
        padding_values=nest.flatten(self._padding_values),
        drop_remainder=self._drop_remainder,
        output_shapes=structure.get_flat_tensor_shapes(self._structure),
        metadata=self._metadata.SerializeToString())
    super().__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    # The structure computed in `__init__` from the padded/batched shapes.
    return self._structure