# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras SavedModel serialization.
|
|
|
|
TODO (kathywu): Move to layer_serialization.py. Some model-specific logic should
|
|
go to model_serialization.py.
|
|
"""
|
|
|
|
import functools
import threading
import weakref

import tensorflow.compat.v1.logging as logging
import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine import base_layer_utils
from keras.engine import input_spec
from keras.mixed_precision import autocast_variable
from keras.saving.legacy import saving_utils
from keras.saving.legacy.saved_model import constants
from keras.saving.legacy.saved_model import load as keras_load
from keras.saving.legacy.saved_model import serialized_attributes
from keras.saving.legacy.saved_model import utils
from keras.utils import layer_utils
from keras.utils import tf_contextlib
from keras.utils import tf_utils
from keras.utils import version_utils
from keras.utils.generic_utils import LazyLoader

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.

base_layer = LazyLoader("base_layer", globals(), "keras.engine.base_layer")
metrics = LazyLoader("metrics", globals(), "keras.metrics")
input_layer = LazyLoader("input_layer", globals(), "keras.engine.input_layer")
training_lib = LazyLoader("training_lib", globals(), "keras.engine.training")
sequential_lib = LazyLoader(
    "sequential_lib", globals(), "keras.engine.sequential"
)


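# For exposition (added comment, not part of the original module): a lazy
# loader defers the import until first attribute access, which is what breaks
# the keras/engine <-> keras/saving import cycle. A minimal stand-in, using
# only the standard library, could look like this hedged sketch:
#
#     import importlib
#
#     class MinimalLazyLoader:
#         def __init__(self, module_name):
#             self._module_name = module_name
#             self._module = None
#
#         def __getattr__(self, item):
#             if self._module is None:  # import on first attribute access
#                 self._module = importlib.import_module(self._module_name)
#             return getattr(self._module, item)

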
def should_skip_serialization(layer):
    """Skip serializing extra objects and functions if layer inputs aren't
    set."""
    saved_model_input_spec_set = (
        isinstance(layer, training_lib.Model)
        and layer._saved_model_inputs_spec is not None
    )
    if not layer.built and not saved_model_input_spec_set:
        logging.warning(
            "Skipping full serialization of Keras layer {}, because "
            "it is not built.".format(layer)
        )
        return True
    return False


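# For exposition (added comment, not part of the original module): an unbuilt
# layer is skipped because it has no defined input shape or weights yet.
# Assuming `tf.keras` is available, the behavior is roughly:
#
#     layer = tf.keras.layers.Dense(4)      # not built yet
#     should_skip_serialization(layer)      # -> True (logs a warning)
#     layer.build(input_shape=(None, 8))
#     should_skip_serialization(layer)      # -> False

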
def _filter_shards(variables):
    return [var for var in variables if not hasattr(var, "_sharded_container")]


def wrap_layer_objects(layer, serialization_cache):
    """Returns extra trackable objects to attach to the serialized layer.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all checkpointable objects from a
      SerializedAttributes object. See LayerAttributes and ModelAttributes for
      the entire list of objects.
    """
    # Wrap all regularization losses as tf.functions.
    # First, generate a list of all regularization losses in this layer and
    # its sublayers.
    all_losses = layer._callable_losses[:]
    for child_layer in utils.list_all_layers(layer):
        all_losses.extend(child_layer._callable_losses)
    # Next, wrap all loss functions as tf.functions. Use the serialization
    # cache to store already-wrapped functions.
    keras_loss_cache = serialization_cache.setdefault("keras_losses", {})
    wrapped_loss_functions = []
    for loss_fn in all_losses:
        if loss_fn in keras_loss_cache:
            wrapped_loss_functions.append(keras_loss_cache[loss_fn])
        else:
            wrapped_loss = _wrap_unconditional_loss(
                loss_fn, len(keras_loss_cache)
            )
            keras_loss_cache[loss_fn] = wrapped_loss
            wrapped_loss_functions.append(wrapped_loss)
    wrapped_layer_losses = [
        keras_loss_cache[fn] for fn in layer._callable_losses[:]
    ]

    layer_metrics = tf.__internal__.tracking.wrap(
        {m.name: m for m in layer._metrics}
    )

    # Avoid duplicate creation of shard Variables on loading.
    # `layer.variables` will return the shard Variables rather than the
    # ShardedVariables (b/224541446), but Keras loading will create new
    # ShardedVariables (and thus shard Variables) from Keras metadata if
    # needed. There's no need to also save the shard Variables here, so
    # filter them out.
    variables = _filter_shards(layer.variables)
    trainable_variables = _filter_shards(layer.trainable_variables)
    non_trainable_variables = _filter_shards(layer.non_trainable_variables)
    return dict(
        variables=tf.__internal__.tracking.wrap(variables),
        trainable_variables=tf.__internal__.tracking.wrap(trainable_variables),
        non_trainable_variables=tf.__internal__.tracking.wrap(
            non_trainable_variables
        ),
        layers=tf.__internal__.tracking.wrap(utils.list_all_layers(layer)),
        metrics=tf.__internal__.tracking.wrap(layer.metrics),
        regularization_losses=tf.__internal__.tracking.wrap(
            wrapped_loss_functions
        ),
        layer_regularization_losses=tf.__internal__.tracking.wrap(
            wrapped_layer_losses
        ),
        layer_metrics=layer_metrics,
    )


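# For exposition (hedged sketch, not part of the original module): these
# wrapped objects are attached to the SavedModel under the `keras_api`
# attribute, so after a low-level load they can be inspected directly:
#
#     loaded = tf.saved_model.load("/tmp/saved_model_dir")  # hypothetical path
#     loaded.keras_api.layer_metrics          # dict of metric objects
#     loaded.keras_api.regularization_losses  # wrapped loss tf.functions

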
def wrap_layer_functions(layer, serialization_cache):
    """Returns dict of wrapped layer call function and losses in tf.functions.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all Keras tf.functions to serialize. See
      LayerAttributes and ModelAttributes for the list of all attributes.
    """
    # Since Sequential models may be modified in place using model.add() or
    # model.pop(), don't use saved functions.
    if isinstance(layer, keras_load.RevivedLayer) and not isinstance(
        layer, sequential_lib.Sequential
    ):
        return {
            fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions
        }

    # Reset the losses of the layer and its children. The call function in
    # each child layer is replaced with tf.functions.
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
    original_losses = _reset_layer_losses(layer)

    # Wrap all the layer call and activity regularizer functions.

    # Use LayerCallCollection to ensure that all layer call functions
    # (__call__, call with losses) are traced with the same inputs.
    call_collection = LayerCallCollection(layer)
    call_fn_with_losses = call_collection.add_function(
        _wrap_call_and_conditional_losses(layer),
        f"{layer.name}_layer_call_and_return_conditional_losses",
        # If any of this layer's child layers use the training arg, the traced
        # call functions of this layer will have a training keyword argument.
        # If the original layer does not expect the training arg, then it will
        # have to be removed (by setting `match_layer_training_arg`).
        match_layer_training_arg=True,
    )
    call_fn = call_collection.add_function(
        _extract_outputs_from_fn(layer, call_fn_with_losses),
        f"{layer.name}_layer_call_fn",
        # Since `call_fn` wraps call_fn_with_losses and not the original call
        # function, `match_layer_training_arg` should be set to False.
        match_layer_training_arg=False,
    )

    fns = {
        "call_and_return_conditional_losses": call_fn_with_losses,
        "__call__": call_fn,
    }

    if layer._activity_regularizer is not None:
        fns["activity_regularizer_fn"] = _wrap_activity_regularizer(layer)
        fns[
            "call_and_return_all_conditional_losses"
        ] = call_collection.add_function(
            _append_activity_regularizer_loss(
                layer, call_fn_with_losses, fns["activity_regularizer_fn"]
            ),
            f"{layer.name}_layer_call_and_return_all_conditional_losses",
            match_layer_training_arg=False,
        )
    else:
        fns["activity_regularizer_fn"] = None
        fns["call_and_return_all_conditional_losses"] = call_fn_with_losses

    # Manually trigger traces before restoring the overwritten functions. The
    # functions are traced within the layer call context to ensure that layer
    # functions (e.g. add_loss) behave as though running in graph mode.
    with tracing_scope():
        call_collection.trace_with_input_signature()
        with base_layer_utils.call_context().enter(
            layer, inputs=None, build_graph=True, training=None, saving=True
        ):
            for fn in fns.values():
                if fn is not None and not isinstance(fn, LayerCall):
                    fn.get_concrete_function()

    # Restore the overwritten functions and losses.
    _restore_child_layer_functions(original_fns)
    _restore_layer_losses(original_losses)

    return fns


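# For exposition (hedged sketch, not part of the original module): the
# returned functions are serialized under `keras_api`, so a SavedModel loaded
# without Keras metadata can still run the layer call together with its
# conditional losses. Names below are illustrative:
#
#     loaded = tf.saved_model.load("/tmp/saved_model_dir")  # hypothetical path
#     x = tf.ones((1, 8))
#     outputs, losses = loaded.keras_api.call_and_return_conditional_losses(x)

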
def default_save_signature(layer):
    original_losses = _reset_layer_losses(layer)
    fn = saving_utils.trace_model_call(layer)
    _restore_layer_losses(original_losses)
    return fn


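# For exposition (hedged sketch, not part of the original module): this traced
# function is what the Keras saver can pass to `tf.saved_model.save` as the
# default serving signature when the user does not supply one, roughly:
#
#     signatures = default_save_signature(model)   # `model` is hypothetical
#     tf.saved_model.save(model, "/tmp/saved_model_dir", signatures=signatures)

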
def _replace_child_layer_functions(layer, serialization_cache):
    """Replaces functions in the children layers with wrapped tf.functions.

    This step allows functions from parent layers to reference the wrapped
    functions from their children layers instead of retracing the ops.

    This function also resets all losses stored in the layer. These are stored
    in the returned dictionary. Use `_restore_child_layer_functions` to
    restore the original attributes.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      Dictionary mapping layer objects -> original functions and losses:
        { Child layer 1: {
            'losses': Original losses,
            'call': Original call function,
            '_activity_regularizer': Original activity regularizer},
          Child layer 2: ...
        }
    """

    original_fns = {}

    def replace_layer_functions(child_layer, serialized_fns):
        """Replaces layer call and activity regularizer with wrapped
        functions."""
        original_fns[child_layer] = {
            "call": child_layer.call,
            "_activity_regularizer": child_layer._activity_regularizer,
        }
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            try:
                child_layer._activity_regularizer = serialized_fns.get(
                    "activity_regularizer_fn"
                )
            except AttributeError:
                # Some layers have an unsettable activity regularizer.
                pass
            child_layer.call = utils.use_wrapped_call(
                child_layer,
                serialized_fns["call_and_return_conditional_losses"],
                child_layer._call_spec,
                default_training_value=False,
            )

    def replace_metric_functions(child_layer, serialized_fns):
        """Replaces metric functions with wrapped functions."""
        original_fns[child_layer] = {
            "__call__": child_layer.__call__,
            "result": child_layer.result,
            "update_state": child_layer.update_state,
        }
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            child_layer.__call__ = serialized_fns["__call__"]
            child_layer.result = serialized_fns["result"]
            child_layer.update_state = serialized_fns["update_state"]

    for child_layer in utils.list_all_layers(layer):
        if isinstance(child_layer, input_layer.InputLayer):
            continue

        if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
            serialized_functions = child_layer._trackable_saved_model_saver._get_serialized_attributes(  # noqa: E501
                serialization_cache
            ).functions
        else:
            serialized_functions = serialization_cache[
                constants.KERAS_CACHE_KEY
            ][child_layer].functions
        if not serialized_functions:
            # This indicates either:
            #   - a circular dependency, which means the current layer's
            #     functions should be wrapped first.
            #   - the child layer's inputs are not defined, so its functions
            #     have not been wrapped. In this case, no replacement is
            #     necessary, so move on to the next child.
            continue

        if isinstance(child_layer, metrics.Metric):
            replace_metric_functions(child_layer, serialized_functions)
        else:
            replace_layer_functions(child_layer, serialized_functions)

    return original_fns


def _restore_child_layer_functions(original_fns):
    """Restores attributes replaced with `_replace_child_layer_functions`."""
    for child_layer, fns in original_fns.items():
        with utils.no_automatic_dependency_tracking_scope(child_layer):
            for fn_name, fn in fns.items():
                try:
                    setattr(child_layer, fn_name, fn)
                except AttributeError:
                    # In the case of _activity_regularizer, setting the
                    # attribute may be disallowed.
                    pass


def _reset_layer_losses(parent_layer):
    """Resets losses of layer and its sublayers, and returns original
    losses."""
    losses_dict = {}
    for layer in utils.list_all_layers_and_sublayers(parent_layer):
        losses_dict[layer] = {
            "losses": layer._losses[:],
            "eager_losses": layer._eager_losses[:],
        }
        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._losses = []
            layer._eager_losses = []
    return losses_dict


def _restore_layer_losses(losses_dict):
    for layer in losses_dict:
        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._losses = losses_dict[layer]["losses"]
            layer._eager_losses = losses_dict[layer]["eager_losses"]


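# For exposition (added comment): `_reset_layer_losses` and
# `_restore_layer_losses` implement a snapshot/clear/restore pattern around
# tracing, so losses created while tracing do not leak into the layer's state.
# This module uses the pair like so (see `default_save_signature` above and
# `layer_call_wrapper` below):
#
#     original_losses = _reset_layer_losses(layer)
#     ...  # trace functions that may call layer.add_loss(...)
#     _restore_layer_losses(original_losses)

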
class LayerTracingContext(threading.local):
    def __init__(self):
        super().__init__()
        self.enable_call_tracing = False
        self.trace_queue = []


_thread_local_data = LayerTracingContext()


@tf_contextlib.contextmanager
def tracing_scope():
    """Enables tracing scope."""
    # This enables the LayerCallCollection's tracing mechanism to trace all
    # call functions in the collection.
    previous_value = _thread_local_data.enable_call_tracing
    previous_queue = _thread_local_data.trace_queue
    try:
        _thread_local_data.enable_call_tracing = True
        _thread_local_data.trace_queue = []
        yield
    finally:
        # Run traces from the queue.
        while _thread_local_data.trace_queue:
            fn, args, kwargs, training = _thread_local_data.trace_queue.pop(0)
            if training is not None:
                with backend.deprecated_internal_learning_phase_scope(
                    training
                ):
                    fn.get_concrete_function(*args, **kwargs)
            else:
                fn.get_concrete_function(*args, **kwargs)
        _thread_local_data.trace_queue = previous_queue
        _thread_local_data.enable_call_tracing = previous_value


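# For exposition (added comment): traces queued via `add_trace_to_queue`
# while the scope is active are executed when the scope exits, as in
# `wrap_layer_functions` above:
#
#     with tracing_scope():
#         call_collection.trace_with_input_signature()
#         ...  # queued traces run here, on scope exit

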
def add_trace_to_queue(fn, args, kwargs, training=None):
    if tracing_enabled():
        _thread_local_data.trace_queue.append(
            (fn, args[:], kwargs.copy(), training)
        )


def tracing_enabled():
    """Whether to add extra traces to the queue."""
    return _thread_local_data.enable_call_tracing


class LayerCallCollection:
    """Groups wrapped layer call functions.

    This is used to ensure that all layer call functions are traced with the
    same inputs:
      - call
      - call_and_return_conditional_losses
      - call_and_return_all_conditional_losses
    """

    def __init__(self, layer):
        self.layer = layer

        self.layer_call_method = _get_layer_call_method(layer)
        self._expects_training_arg = utils.layer_uses_training_bool(layer)
        self._call_spec = layer._call_spec

        # Create a new call spec if the layer itself does not accept a
        # training arg, but one of its child layers does. When this layer's
        # call functions are traced, they will be traced with an added
        # `training` keyword argument.
        if not self.layer._expects_training_arg and self._expects_training_arg:
            arg_spec = utils.set_training_arg_spec(
                self._call_spec.full_argspec, False
            )
            self._call_spec = layer_utils.CallFunctionSpec(arg_spec)

        self._layer_inputs = self._get_layer_inputs(layer)
        self._functions = weakref.WeakValueDictionary()

        # Get the input argument name from the args.
        if self._call_spec.arg_names:
            self._input_arg_name = self._call_spec.arg_names[0]
        else:
            # The layer could be defined with only varargs, in which case use
            # a default name.
            self._input_arg_name = "inputs"

    def _get_layer_inputs(self, layer):
        """Inspects layer object and returns the inferred input signature.

        Args:
          layer: Layer object.

        Returns:
          List of possibly nested TensorSpecs of the layer call function
          inputs in the form of `(args, kwargs)`.
        """
        if (
            isinstance(layer.call, tf.__internal__.function.Function)
            and layer.call.input_signature is not None
        ):
            return layer.call.input_signature, {}
        elif isinstance(layer, training_lib.Model):
            return saving_utils.model_call_inputs(layer)
        elif (
            layer.input_spec is not None
            and layer._use_input_spec_as_call_signature
        ):

            def to_tensor_spec_or_none(x):
                spec = input_spec.to_tensor_spec(x, layer._compute_dtype)
                # If the shape is too general (e.g. multiple dimensions are
                # allowed), return None so that separate functions can be
                # generated for each inferred input signature.
                # TODO(b/134962016): currently partial signatures are not
                # supported.
                if spec.shape == tf.TensorShape(None):
                    return None, None
                return spec

            input_signature = [
                tf.nest.map_structure(to_tensor_spec_or_none, layer.input_spec)
            ]

            return input_signature, {}
        else:
            return None, None

    def add_trace(self, *args, **kwargs):
        """Traces all functions with the same args and kwargs.

        Args:
          *args: Positional args passed to the original function.
          **kwargs: Keyword args passed to the original function.
        """
        args = list(args)
        kwargs = kwargs.copy()

        for fn in self._functions.values():
            # TODO(kathywu): Replace arguments with broader shapes defined in
            # the input signature.
            if self._expects_training_arg:

                def trace_with_training(value, fn=fn):
                    nonlocal args, kwargs
                    (args, kwargs,) = self._call_spec.set_arg_value(
                        "training", value, args, kwargs, inputs_in_args=True
                    )
                    add_trace_to_queue(fn, args, kwargs, value)

                trace_with_training(True)
                trace_with_training(False)
            else:
                add_trace_to_queue(fn, args, kwargs)

    def training_arg_was_passed(self, args, kwargs):
        return self._call_spec.arg_was_passed(
            "training", args, kwargs, inputs_in_args=True
        )

    def get_training_arg_value(self, args, kwargs):
        try:
            return self._call_spec.get_arg_value(
                "training", args, kwargs, inputs_in_args=True
            )
        except KeyError:  # Training is not in args or kwargs.
            return None

    def get_input_arg_value(self, args, kwargs):
        return self._call_spec.get_arg_value(
            self._input_arg_name, args, kwargs, inputs_in_args=True
        )

    def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
        """Wraps call function with added training argument if necessary."""
        if not self.layer._expects_training_arg and self._expects_training_arg:
            # Add training arg to wrapper function.
            def wrap_with_training_arg(*args, **kwargs):
                if match_layer_training_arg:
                    # Remove the training value, since the original call_fn
                    # does not expect a training arg. Instead, the training
                    # value will be propagated using the call context created
                    # in LayerCall.
                    args = list(args)
                    kwargs = kwargs.copy()
                    (args, kwargs,) = self._call_spec.set_arg_value(
                        "training",
                        None,
                        args,
                        kwargs,
                        inputs_in_args=True,
                        pop_kwarg_if_none=True,
                    )
                return call_fn(*args, **kwargs)

            return tf.__internal__.decorator.make_decorator(
                target=call_fn,
                decorator_func=wrap_with_training_arg,
                decorator_argspec=self._call_spec.full_argspec,
            )

        return call_fn

    def add_function(self, call_fn, name, match_layer_training_arg):
        """Adds a layer call function to the collection.

        Args:
          call_fn: a Python function.
          name: Name of the call function.
          match_layer_training_arg: If True, removes the `training` argument
            from the function arguments when calling `call_fn`.

        Returns:
          LayerCall (tf.function)
        """
        fn = LayerCall(
            self,
            self._maybe_wrap_with_training_arg(
                call_fn, match_layer_training_arg
            ),
            name,
        )
        self._functions[name] = fn.wrapped_call
        return fn

    def trace_with_input_signature(self):
        """Traces with the layer/model's inferred input signature if
        possible."""
        if self._layer_inputs[0] is None:
            return

        args, kwargs = self._layer_inputs
        if self._expects_training_arg:
            args, kwargs = self._call_spec.set_arg_value(
                "training", False, args, kwargs, inputs_in_args=True
            )
        if None not in tf.nest.flatten([args, kwargs]):
            # Manually add traces for layers that have keyword arguments and
            # have a fully defined input signature.
            self.add_trace(*args, **kwargs)


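# For exposition (hedged sketch, not part of the original module): calling any
# function in a collection inside a tracing scope queues traces for all of its
# sibling functions with the same arguments, keeping their concrete-function
# signatures consistent. `layer` and `wrapped_fn` below are hypothetical:
#
#     call_collection = LayerCallCollection(layer)
#     call_fn = call_collection.add_function(
#         wrapped_fn, "my_layer_call_fn", match_layer_training_arg=False
#     )
#     with tracing_scope():
#         # Also queues traces for the sibling functions in the collection.
#         call_fn.get_concrete_function(tf.TensorSpec((None, 4)))

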
def _filtered_inputs(inputs):
    return list(filter(tf_utils.is_tensor_or_variable, tf.nest.flatten(inputs)))


def layer_call_wrapper(call_collection, method, name):
    """Ensures layer losses are kept the same, and runs method in call
    context."""

    # Create a wrapper that deals with losses and call context.
    def wrapper(*args, **kwargs):
        """Calls method within call context."""
        layer = call_collection.layer
        training = None
        inputs = _filtered_inputs([args, kwargs])

        if (args or kwargs) and call_collection.training_arg_was_passed(
            args, kwargs
        ):
            training = call_collection.get_training_arg_value(args, kwargs)

        original_losses = _reset_layer_losses(layer)
        with base_layer_utils.call_context().enter(
            layer,
            inputs=inputs,
            build_graph=False,
            training=training,
            saving=True,
        ):
            with autocast_variable.enable_auto_cast_variables(
                layer._compute_dtype_object
            ):
                ret = method(*args, **kwargs)
        _restore_layer_losses(original_losses)
        return ret

    # Rename to `name`, since tf.function doesn't have a name argument.
    # Without this, all functions returned by this method would be named
    # "call", which would be a nightmare to debug.
    fn = tf.__internal__.decorator.make_decorator(
        target=method, decorator_func=wrapper
    )
    fn.__name__ = name
    return fn


class LayerCall:
    """Function that triggers traces of other functions in the same
    collection."""

    def __init__(self, call_collection, call_fn, name):
        """Initializes a LayerCall object.

        Args:
          call_collection: a LayerCallCollection, which contains the other
            layer call functions (e.g. call_with_conditional_losses, call).
            These functions should be traced with the same arguments.
          call_fn: A call function.
          name: Name of the call function.
        """
        self.call_collection = call_collection
        self.wrapped_call = tf.function(
            layer_call_wrapper(call_collection, call_fn, name)
        )

    def _maybe_trace(self, args, kwargs):
        # Trigger traces of other call functions + extra training-arg traces.
        if tracing_enabled():
            self.call_collection.add_trace(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        self._maybe_trace(args, kwargs)
        return self.wrapped_call(*args, **kwargs)

    def get_concrete_function(self, *args, **kwargs):
        self._maybe_trace(args, kwargs)
        return self.wrapped_call.get_concrete_function(*args, **kwargs)


def _wrap_call_and_conditional_losses(layer):
    """Wraps call function that returns a tuple of (outputs, losses).

    The losses returned are conditional on the inputs passed to the call
    function. Unconditional losses (e.g. weight regularization) are wrapped
    separately.

    Args:
      layer: a Keras layer object

    Returns:
      python call function that returns outputs and conditional losses --
      excludes activity regularizer
    """
    # Create a function that generates both outputs and losses.
    layer_call = _get_layer_call_method(layer)

    def call_and_return_conditional_losses(*args, **kwargs):
        """Returns layer (call_output, conditional losses) tuple."""
        call_output = layer_call(*args, **kwargs)
        if version_utils.is_v1_layer_or_model(layer):
            conditional_losses = layer.get_losses_for(
                _filtered_inputs([args, kwargs])
            )
        else:
            conditional_losses = [
                l for l in layer.losses if not hasattr(l, "_unconditional_loss")
            ]
        return call_output, conditional_losses

    return _create_call_fn_decorator(layer, call_and_return_conditional_losses)


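# For exposition (added comment): the wrapped function returns an
# (outputs, losses) pair, where the losses depend on the inputs. Illustrative
# usage, with `layer` and `x` hypothetical:
#
#     fn = _wrap_call_and_conditional_losses(layer)
#     outputs, conditional_losses = fn(x)  # excludes the activity regularizer

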
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
    """Returns a function that returns only call function outputs."""
    if isinstance(layer, keras_load.RevivedLayer):
        return layer.keras_api.__call__

    def call(inputs, *args, **kwargs):
        return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]

    return _create_call_fn_decorator(layer, call)


def _append_activity_regularizer_loss(
    layer, call_fn_with_losses, activity_regularizer_fn
):
    """Appends activity regularizer loss to losses returned by the wrapped
    fn."""

    def fn(inputs, *args, **kwargs):
        outputs, losses = call_fn_with_losses(inputs, *args, **kwargs)
        losses.append(activity_regularizer_fn(outputs))
        return outputs, losses

    return _create_call_fn_decorator(layer, fn)


def _create_call_fn_decorator(layer, wrapped_call):
    call_fn = _get_layer_call_method(layer)
    fn, arg_spec = utils.maybe_add_training_arg(
        layer._call_spec,
        wrapped_call,
        layer._expects_training_arg,
        default_training_value=False,
    )
    return tf.__internal__.decorator.make_decorator(
        target=call_fn, decorator_func=fn, decorator_argspec=arg_spec
    )


def _wrap_unconditional_loss(loss_fn, index):
    """Wraps callable/unconditional loss, returning a serializable function."""
    # Extract the original loss function from the partial function.
    fn = loss_fn.args[0] if isinstance(loss_fn, functools.partial) else loss_fn
    if isinstance(fn, tf.__internal__.function.Function):
        return fn
    else:
        return tf.__internal__.function.Function(
            fn, f"loss_fn_{index}", input_signature=[]
        )


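# For exposition (hedged sketch, not part of the original module): an
# unconditional loss is a zero-argument callable, so it can be traced with an
# empty input signature:
#
#     kernel = tf.Variable(tf.ones((2, 2)))
#
#     def l2_loss():
#         return 0.01 * tf.reduce_sum(tf.square(kernel))
#
#     wrapped = _wrap_unconditional_loss(l2_loss, index=0)
#     wrapped()  # traced tf.function named "loss_fn_0"

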
def _wrap_activity_regularizer(layer):
    """Wraps the activity regularizer."""
    if isinstance(
        layer._activity_regularizer, tf.__internal__.function.Function
    ):
        return layer._activity_regularizer
    return tf.__internal__.function.Function(
        layer._activity_regularizer,
        f"{layer.name}_activity_regularizer",
        input_signature=[
            tf.TensorSpec(None, layer._compute_dtype or backend.floatx())
        ],
    )


def _get_layer_call_method(layer):
    if isinstance(layer.call, tf.__internal__.function.Function):
        return layer.call.python_function
    return layer.call