Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/data/ops/structured_function.py

306 lines
12 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for managing tf.data user-defined functions."""
import warnings
from tensorflow.python.data.ops import debug_mode
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import script_ops
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import variable_utils
# AutoGraph modules are bound through `lazy_loader.LazyLoader` rather than a
# direct import; each loader is given its local name, this module's globals,
# and the fully-qualified module it stands in for.
autograph = lazy_loader.LazyLoader(
    "autograph",
    globals(),
    "tensorflow.python.autograph.impl.api",
)
# TODO(mdan): Create a public API for this.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx",
    globals(),
    "tensorflow.python.autograph.core.ag_ctx",
)
def _should_pack(arg):
"""Determines whether the caller needs to pack the argument in a tuple.
If user-defined function returns a list of tensors, `nest.flatten()` and
`ops.convert_to_tensor()` and would conspire to attempt to stack those tensors
into a single tensor because the tf.data version of `nest.flatten()` does
not recurse into lists. Since it is more likely that the list arose from
returning the result of an operation (such as `tf.numpy_function()`) that
returns a list of not-necessarily-stackable tensors, we treat the returned
value as a `tuple` instead. A user wishing to pack the return value into a
single tensor can use an explicit `tf.stack()` before returning.
Args:
arg: argument to check
Returns:
Indication of whether the caller needs to pack the argument in a tuple.
"""
return isinstance(arg, list)
def _should_unpack(arg):
"""Determines whether the caller needs to unpack the argument from a tuple.
Args:
arg: argument to check
Returns:
Indication of whether the caller needs to unpack the argument from a tuple.
"""
return type(arg) is tuple # pylint: disable=unidiomatic-typecheck
class StructuredFunctionWrapper():
  """A function wrapper that supports structured arguments and return values.

  The input structure of the wrapped function comes from one of three
  sources: `input_structure`, `dataset.element_spec`, or the legacy
  `input_classes`/`input_shapes`/`input_types` triple. Construction builds the
  traced function (exposed as `function`). The structure of the value returned
  by `func` during tracing is recorded and exposed as `output_structure`.
  """

  def __init__(self,
               func,
               transformation_name,
               dataset=None,
               input_classes=None,
               input_shapes=None,
               input_types=None,
               input_structure=None,
               add_to_graph=True,
               use_legacy_function=False,
               defun_kwargs=None):
    """Creates a new `StructuredFunctionWrapper` for the given function.

    Args:
      func: A function from a (nested) structure to another (nested) structure.
      transformation_name: Human-readable name of the transformation in which
        this function is being instantiated, for error messages.
      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
        dataset will be assumed as the structure for `func` arguments; otherwise
        `input_classes`, `input_shapes`, and `input_types` must be defined.
      input_classes: (Optional.) A (nested) structure of `type`. If given, this
        argument defines the Python types for `func` arguments.
      input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. If
        given, this argument defines the shapes and structure for `func`
        arguments.
      input_types: (Optional.) A (nested) structure of `tf.DType`. If given,
        this argument defines the element types and structure for `func`
        arguments.
      input_structure: (Optional.) A `Structure` object. If given, this argument
        defines the element types and structure for `func` arguments.
      add_to_graph: (Optional.) If `True`, the function will be added to the
        default graph, if it exists. Ignored when executing eagerly; always
        treated as `True` when `use_legacy_function` is set.
      use_legacy_function: (Optional.) A boolean that determines whether the
        function is created using `tensorflow.python.eager.function`
        (`defun_with_attributes`, the default behavior) or
        `tensorflow.python.framework.function.Defun` (legacy behavior).
      defun_kwargs: (Optional.) A dictionary mapping string argument names to
        values. If supplied, its entries are forwarded to the function
        decorator: as keyword arguments to `Defun` on the legacy path, and as
        `attributes` otherwise. The dictionary is copied and is not modified.

    Raises:
      ValueError: If an invalid combination of `dataset`, `input_structure`,
        `input_classes`, `input_shapes`, and `input_types` is passed.
      TypeError: If `func` returns a value for which
        `structure.type_spec_from_value` fails; raised while `func` is traced.
    """
    # pylint: disable=protected-access
    if input_structure is None:
      if dataset is None:
        if input_classes is None or input_shapes is None or input_types is None:
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = structure.convert_legacy_structure(
            input_types, input_shapes, input_classes)
      else:
        if not (input_classes is None and input_shapes is None and
                input_types is None):
          raise ValueError("Either `dataset`, `input_structure` or all of "
                           "`input_classes`, `input_shapes`, and `input_types` "
                           "must be specified.")
        self._input_structure = dataset.element_spec
    else:
      if not (dataset is None and input_classes is None and
              input_shapes is None and input_types is None):
        raise ValueError("Either `dataset`, `input_structure`, or all of "
                         "`input_classes`, `input_shapes`, and `input_types` "
                         "must be specified.")
      self._input_structure = input_structure

    self._func = func

    # Work on a private copy: `func_name` (and `_tf_data_function`) are added
    # below, and those entries must not leak into a dict owned by the caller.
    defun_kwargs = {} if defun_kwargs is None else dict(defun_kwargs)

    # Drop the last two characters of the transformation name (presumably a
    # trailing "()" as in "Dataset.map()" -- confirm against callers) and
    # replace dots so the result can be embedded in a function name.
    readable_transformation_name = transformation_name.replace(
        ".", "_")[:-2] if len(transformation_name) > 2 else ""

    func_name = "_".join(
        [readable_transformation_name,
         function_utils.get_func_name(func)])
    # Sanitize function name to remove symbols that interfere with graph
    # construction.
    for symbol in ["<", ">", "\\", "'", " "]:
      func_name = func_name.replace(symbol, "")

    # AutoGraph control status at construction time; `wrapper_helper` hands it
    # to `autograph.tf_convert` when the user function is traced.
    ag_ctx = autograph_ctx.control_status_ctx()

    def wrapper_helper(*args):
      """Wrapper for passing nested structures to and from tf.data functions."""
      nested_args = structure.from_compatible_tensor_list(
          self._input_structure, args)
      if not _should_unpack(nested_args):
        nested_args = (nested_args,)
      ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
      ret = variable_utils.convert_variables_to_tensors(ret)
      if _should_pack(ret):
        ret = tuple(ret)

      # Record the output structure as a side effect of tracing; the tracing
      # helpers below read `self._output_structure` right after this call.
      try:
        self._output_structure = structure.type_spec_from_value(ret)
      except (ValueError, TypeError) as e:
        raise TypeError(f"Unsupported return value from function passed to "
                        f"{transformation_name}: {ret}.") from e
      return ret

    def trace_legacy_function(defun_kwargs):
      """Returns a factory producing a legacy `Defun`-based function."""

      @function.Defun(*structure.get_flat_tensor_types(self._input_structure),
                      **defun_kwargs)
      def wrapped_fn(*args):
        ret = wrapper_helper(*args)
        return structure.to_tensor_list(self._output_structure, ret)

      return lambda: wrapped_fn

    def trace_py_function(defun_kwargs):
      """Returns a factory for a function that runs `func` eagerly (debug mode)."""
      # First we trace the function to infer the output structure.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def unused(*args):  # pylint: disable=missing-docstring,unused-variable
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      _ = unused.get_concrete_function()

      def py_function_wrapper(*args):
        # Calls the raw user function (no AutoGraph conversion), since it runs
        # inside `eager_py_func` below.
        nested_args = structure.from_compatible_tensor_list(
            self._input_structure, args)
        if not _should_unpack(nested_args):
          nested_args = (nested_args,)
        ret = self._func(*nested_args)
        if _should_pack(ret):
          ret = tuple(ret)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      # Next we trace the function wrapped in `eager_py_func` to force eager
      # execution.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        return script_ops.eager_py_func(
            py_function_wrapper, args,
            structure.get_flat_tensor_types(self._output_structure))

      return wrapped_fn.get_concrete_function

    def trace_tf_function(defun_kwargs):
      """Returns a factory for the default graph-traced function."""
      # Note: wrapper_helper will apply autograph based on context.
      @eager_function.defun_with_attributes(
          input_signature=structure.get_flat_tensor_specs(
              self._input_structure),
          autograph=False,
          attributes=defun_kwargs)
      def wrapped_fn(*args):  # pylint: disable=missing-docstring
        ret = wrapper_helper(*args)
        ret = structure.to_tensor_list(self._output_structure, ret)
        return [ops.convert_to_tensor(t) for t in ret]

      return wrapped_fn.get_concrete_function

    if use_legacy_function:
      # Legacy functions get a unique suffix on their name.
      defun_kwargs["func_name"] = func_name + "_" + str(ops.uid())
      fn_factory = trace_legacy_function(defun_kwargs)
    else:
      defun_kwargs["func_name"] = func_name
      defun_kwargs["_tf_data_function"] = True
      if debug_mode.DEBUG_MODE:
        fn_factory = trace_py_function(defun_kwargs)
      else:
        if def_function.functions_run_eagerly():
          warnings.warn(
              "Even though the `tf.config.experimental_run_functions_eagerly` "
              "option is set, this option does not apply to tf.data functions. "
              "To force eager execution of tf.data functions, please use "
              "`tf.data.experimental.enable_debug_mode()`.")
        fn_factory = trace_tf_function(defun_kwargs)

    self._function = fn_factory()
    # There is no graph to add in eager mode.
    add_to_graph &= not context.executing_eagerly()
    # There are some lifetime issues when a legacy function is not added to an
    # outliving graph. It's already deprecated so de-prioritizing the fix.
    add_to_graph |= use_legacy_function
    if add_to_graph:
      self._function.add_to_graph(ops.get_default_graph())

    if not use_legacy_function:
      # If the function graph shares the outer graph's seed and a random op
      # used it, warn that unseeded random ops may depend on the outer seed.
      outer_graph_seed = ops.get_default_graph().seed
      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
        if self._function.graph._seed_used:
          warnings.warn(
              "Seed %s from outer graph might be getting used by function %s, "
              "if the random op has not been provided any seed. Explicitly set "
              "the seed in the function if this is not the intended behavior." %
              (outer_graph_seed, func_name),
              stacklevel=4)

  @property
  def output_structure(self):
    """The structure of the value returned by `func`, recorded while tracing."""
    return self._output_structure

  @property
  def output_classes(self):
    """Legacy per-component output classes derived from `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_shapes(self):
    """Legacy per-component output shapes derived from `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def output_types(self):
    """Legacy per-component output types derived from `output_structure`."""
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._output_structure)

  @property
  def function(self):
    """The traced function object built during construction."""
    return self._function