120 lines
4.4 KiB
Python
120 lines
4.4 KiB
Python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Autograph specifc overrides for tf.data.ops."""
|
|
import functools
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.python.autograph.operators import control_flow
|
|
from tensorflow.python.autograph.operators import py_builtins
|
|
from tensorflow.python.data.ops import iterator_ops
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_spec
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.util import nest
|
|
|
|
|
|
# TODO(mdan): These checks should be easier. Fix the nest API.
|
|
def _verify_spec_compatible(input_name, spec_name, input_, spec):
  """Verifies that a symbol has a type compatible with a given spec.

  Here, compatibility is viewed in the general TensorFlow sense: that the dtypes
  are the same after implicit conversion, if both are tensors.

  This verifier ensures consistent treatment of types across AutoGraph.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify.
    spec: TypeSpec that `input_` must be compatible with. Must be a
      `tensor_spec.TensorSpec` (enforced by an assertion).

  Raises:
    ValueError: if `input_` is None.
    TypeError: if the dtype of `input_` (after converting Python scalars,
      strings and NumPy arrays to tensors) differs from `spec.dtype`, or if
      `input_` has no `dtype` attribute at all.
  """
  assert isinstance(spec, tensor_spec.TensorSpec)
  # Check the parameter `input_`, not the `input` builtin (which is never
  # None and would make this guard unreachable).
  if input_ is None:
    # TODO(mdan): raise from None when switching to Py3.
    raise ValueError("{} cannot be None".format(input_name))

  # TODO(mdan): Use TensorCompatible when ready.
  # Plain Python values and NumPy arrays carry no TensorFlow dtype of their
  # own, so convert them first and compare the resulting tensor's dtype.
  if isinstance(input_, (bool, int, float, str, np.ndarray)):
    input_ = ops.convert_to_tensor_v2(input_)

  # Objects without a `dtype` attribute are reported as "no dtype" below.
  input_dtype = getattr(input_, "dtype", None)

  if input_dtype != spec.dtype:
    input_dtype_str = "no dtype" if input_dtype is None else str(input_dtype)

    raise TypeError(
        "{} must have the same dtype as {}. Expected {}, got {}".format(
            input_name, spec_name, spec.dtype, input_dtype_str
        )
    )
|
|
|
|
|
|
def _verify_structure_compatible(input_name, spec_name, input_, spec):
  """Verifies that a possibly-structured symbol is compatible with another.

  See _verify_spec_compatible for a more concrete meaning of "compatible".
  Unlike _verify_spec_compatible, which handles singular Tensor-spec objects,
  this function can process structures recognized by tf.nest: it first checks
  that `input_` and `spec` share the same structure (with composites
  expanded), then applies _verify_spec_compatible across the two structures.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify. May, but doesn't need to, be a structure.
    spec: Any, value that `input_` must be compatible with. May, but doesn't
      need to, be a structure.

  Raises:
    TypeError: if `input_` and `spec` do not have the same structure. Any
      error raised by _verify_spec_compatible for an individual element
      propagates unchanged.
  """

  def _check_element(element, element_spec):
    # Binds the caller-supplied names so per-element errors mention them.
    return _verify_spec_compatible(
        input_name, spec_name, element, element_spec
    )

  try:
    nest.assert_same_structure(input_, spec, expand_composites=True)
  except (ValueError, TypeError) as structure_error:
    # Normalize both failure types into one TypeError that names the
    # offending arguments, keeping the original error as the cause.
    message = "{} must have the same element structure as {}.\n\n{}".format(
        input_name, spec_name, str(structure_error)
    )
    raise TypeError(message) from structure_error

  nest.map_structure(_check_element, input_, spec)
|
|
|
|
|
|
def _next_tf_iterator(iterator, default=py_builtins.UNSPECIFIED):
  """Implements `next(iterator[, default])` for tf.data iterators.

  Args:
    iterator: The iterator to advance. It must support `next()`, and, when a
      default is given, `get_next_as_optional()` and `element_spec`.
    default: Optional value to produce when the iterator has no more
      elements. It must be compatible with `iterator.element_spec` (see
      _verify_structure_compatible).

  Returns:
    `next(iterator)` when no default is given. Otherwise the result of a
    `cond` that yields the next element if one is available and `default`
    if not.

  Raises:
    Whatever _verify_structure_compatible raises when `default` does not
    match the iterator's element spec.
  """
  no_default_given = default is py_builtins.UNSPECIFIED
  if no_default_given:
    # Without a default, fall back to the "normal" behavior which raises
    # a runtime exception.
    return next(iterator)

  maybe_element = iterator.get_next_as_optional()
  _verify_structure_compatible(
      "the default argument", "the iterate", default, iterator.element_spec
  )

  def _fallback():
    return default

  return control_flow_ops.cond(
      maybe_element.has_value(), maybe_element.get_value, _fallback
  )
|
|
|
|
|
|
def register_overrides():
  """Registers the tf.data iterator overrides with AutoGraph.

  Associates `iterator_ops.OwnedIterator` with `_next_tf_iterator` in the
  `next` registry and with `control_flow._tf_iterator_for_stmt` in the
  for-loop registry.
  """
  iterator_type = iterator_ops.OwnedIterator

  py_builtins.next_registry.register(iterator_type, _next_tf_iterator)

  for_stmt_impl = control_flow._tf_iterator_for_stmt  # pylint: disable=protected-access
  control_flow.for_loop_registry.register(iterator_type, for_stmt_impl)
|