# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions shared between SavedModel saving/loading implementations."""

import itertools
import threading
import types

from tensorflow.python.eager import context
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.util import tf_decorator


# pylint:disable=g-inconsistent-quotes
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
# pylint:enable=g-inconsistent-quotes


def use_wrapped_call(layer, call_fn, default_training_value=None,
                     return_method=False):
  """Creates fn that adds the losses returned by call_fn and returns the outputs.

  Args:
    layer: A Keras layer object.
    call_fn: tf.function that takes layer inputs (and possibly a training arg),
      and returns a tuple of (outputs, list of losses).
    default_training_value: Default value of the training kwarg. If `None`, the
      default is `K.learning_phase()`.
    return_method: Whether to return a method bound to the layer.

  Returns:
    function that calls call_fn and returns the outputs. Losses returned by
    call_fn are added to the layer losses.
  """
  expects_training_arg = layer_uses_training_bool(layer)
  if hasattr(call_fn, 'original_layer_call'):  # call_fn is a LayerCall object
    original_call = call_fn.original_layer_call
    # In Python 3, callable objects are not compatible with inspect.getargspec
    call_fn = call_fn.__call__
  else:
    original_call = call_fn
  fn, arg_spec = maybe_add_training_arg(
      original_call, call_fn, expects_training_arg, default_training_value)

  def return_outputs_and_add_losses(*args, **kwargs):
    """Returns the outputs from the layer call function, and adds the losses."""
    if return_method:
      args = args[1:]

    outputs, losses = fn(*args, **kwargs)
    layer.add_loss(losses, inputs=True)

    # TODO(kathywu): This is a temporary hack. When a network of layers is
    # revived from SavedModel, only the top-level layer will have losses. This
    # causes issues in eager mode because the child layers may have graph
    # losses (thus model.losses returns a mix of Eager and graph tensors). To
    # fix this, whenever eager losses are added to one layer, add eager losses
    # to all child layers. This causes `.losses` to only return eager losses.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      for i in layer._flatten_layers():
        if i is not layer:
          i._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
    # pylint: enable=protected-access
    return outputs

  decorated = tf_decorator.make_decorator(
      target=call_fn,
      decorator_func=return_outputs_and_add_losses,
      decorator_argspec=arg_spec)

  if return_method:
    return types.MethodType(decorated, layer)
  else:
    return decorated


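# Example usage of `use_wrapped_call` (a hedged sketch, not part of the
# tested library surface; `layer`, `inputs`, and `serialized_call_fn` are
# hypothetical -- the latter stands for a tf.function that returns an
# `(outputs, losses)` tuple, as the `call_fn` contract above requires):
#
#   wrapped = use_wrapped_call(layer, serialized_call_fn)
#   outputs = wrapped(inputs, training=True)
#   # Losses returned by `serialized_call_fn` now appear in `layer.losses`.

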
def layer_uses_training_bool(layer):
  """Returns whether this layer or any of its children uses the training arg."""
  if layer._expects_training_arg:  # pylint: disable=protected-access
    return True
  visited = {layer}
  to_visit = list_all_layers(layer)
  while to_visit:
    layer = to_visit.pop()
    if layer in visited:
      continue
    if getattr(layer, '_expects_training_arg', True):
      return True
    visited.add(layer)
    to_visit.extend(list_all_layers(layer))
  return False


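# For instance (a sketch with hypothetical layers): a wrapper layer whose own
# `call` signature has no `training` argument still uses the training bool if
# any nested child layer expects one, e.g. a `Dropout` child:
#
#   layer_uses_training_bool(wrapper_with_dropout_child)  # -> True

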
def list_all_layers(obj):
  if isinstance(obj, training_lib.Model):
    # Handle special case of Sequential, which doesn't return
    # the `Input` layer.
    return obj.layers
  else:
    return list(obj._flatten_layers(include_self=False, recursive=False))  # pylint: disable=protected-access


def list_all_layers_and_sublayers(obj):
  s = set([obj])
  s.update(itertools.chain.from_iterable(
      list_all_layers_and_sublayers(layer) for layer in list_all_layers(obj)))
  return s


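# For example (a sketch; the model below is hypothetical):
#
#   model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
#   list_all_layers(model)                # -> [<Dense layer>]
#   list_all_layers_and_sublayers(model)  # -> {model, <Dense layer>} (a set)

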
def maybe_add_training_arg(
    original_call, wrapped_call, expects_training_arg, default_training_value):
  """Decorates `call` and optionally adds a training argument.

  If a layer expects a training argument, this function ensures that 'training'
  is present in the layer args or kwonly args, with the default training value.

  Args:
    original_call: Original call function.
    wrapped_call: Wrapped call function.
    expects_training_arg: Whether to include a 'training' argument.
    default_training_value: Default value of the training kwarg to include in
      the arg spec. If `None`, the default is `K.learning_phase()`.

  Returns:
    Tuple of (
      function that calls `wrapped_call` and sets the training arg,
      Argspec of returned function or `None` if the argspec is unchanged)
  """
  if not expects_training_arg:
    return wrapped_call, None
  def wrap_with_training_arg(*args, **kwargs):
    """Wraps the `wrapped_call` function, and sets the training argument."""
    training_arg_index = get_training_arg_index(original_call)
    training = get_training_arg(training_arg_index, args, kwargs)
    if training is None:
      training = default_training_value or K.learning_phase()

    args = list(args)
    kwargs = kwargs.copy()

    def replace_training_and_call(training):
      set_training_arg(training, training_arg_index, args, kwargs)
      return wrapped_call(*args, **kwargs)

    return control_flow_util.smart_cond(
        training, lambda: replace_training_and_call(True),
        lambda: replace_training_and_call(False))

  # Create arg spec for decorated function. If 'training' is not defined in the
  # args of the original arg spec, then add it to kwonlyargs.
  arg_spec = tf_inspect.getfullargspec(original_call)
  defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []

  kwonlyargs = arg_spec.kwonlyargs
  kwonlydefaults = arg_spec.kwonlydefaults or {}
  # Add training arg if it does not exist, or set the default training value.
  if 'training' not in arg_spec.args:
    kwonlyargs.append('training')
    kwonlydefaults['training'] = default_training_value
  else:
    index = arg_spec.args.index('training')
    training_default_index = len(arg_spec.args) - index
    if (arg_spec.defaults and
        len(arg_spec.defaults) >= training_default_index and
        defaults[-training_default_index] is None):
      defaults[-training_default_index] = default_training_value

  decorator_argspec = tf_inspect.FullArgSpec(
      args=arg_spec.args,
      varargs=arg_spec.varargs,
      varkw=arg_spec.varkw,
      defaults=defaults,
      kwonlyargs=kwonlyargs,
      kwonlydefaults=kwonlydefaults,
      annotations=arg_spec.annotations)
  return wrap_with_training_arg, decorator_argspec


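# For example (a hedged sketch; `original_call` and `wrapped_call` are
# hypothetical and share a signature like `call(inputs)`): because 'training'
# is absent from the positional args, the returned argspec gains a
# keyword-only 'training' argument with the supplied default:
#
#   fn, spec = maybe_add_training_arg(original_call, wrapped_call, True, False)
#   # 'training' in spec.kwonlyargs    -> True
#   # spec.kwonlydefaults['training']  -> False

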
def get_training_arg_index(call_fn):
  """Returns the index of 'training' in the layer call function arguments.

  Args:
    call_fn: Call function.

  Returns:
    - n: index of 'training' in the call function arguments.
    - -1: if 'training' is not found in the arguments, but layer.call accepts
      variable keyword arguments.
    - None: if layer doesn't expect a training argument.
  """
  argspec = tf_inspect.getfullargspec(call_fn)
  if argspec.varargs:
    # When there are variable args, training must be a keyword arg.
    if 'training' in argspec.kwonlyargs or argspec.varkw:
      return -1
    return None
  else:
    # Try to find 'training' in the list of args or kwargs.
    arg_list = argspec.args
    if tf_inspect.ismethod(call_fn):
      arg_list = arg_list[1:]

    if 'training' in arg_list:
      return arg_list.index('training')
    elif 'training' in argspec.kwonlyargs or argspec.varkw:
      return -1
    return None


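# For example (hypothetical call signatures, shown as a sketch; for bound
# methods, `self` is stripped before the index is computed):
#
#   def call(inputs, training=None): ...  # -> 1
#   def call(inputs, **kwargs): ...       # -> -1 (training must be a kwarg)
#   def call(inputs): ...                 # -> None (no training expected)

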
def set_training_arg(training, index, args, kwargs):
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    kwargs['training'] = training
  else:
    args[index] = training
  return args, kwargs


def get_training_arg(index, args, kwargs):
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    return kwargs.get('training', None)
  else:
    return args[index]


def remove_training_arg(index, args, kwargs):
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    kwargs.pop('training', None)
  else:
    args.pop(index)


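# For example (a sketch; `args` must be a list because these helpers mutate
# their inputs in place when the index is valid):
#
#   args, kwargs = ['x'], {}
#   set_training_arg(True, 1, args, kwargs)  # index 1 is out of range, so
#                                            # kwargs becomes {'training': True}
#   get_training_arg(1, args, kwargs)        # -> True (read from kwargs)
#   remove_training_arg(1, args, kwargs)     # pops 'training' from kwargs

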
class SaveOptionsContext(threading.local):

  def __init__(self):
    super(SaveOptionsContext, self).__init__()
    self.save_traces = True


_save_options_context = SaveOptionsContext()


@tf_contextlib.contextmanager
def keras_option_scope(save_traces):
  previous_value = _save_options_context.save_traces
  try:
    _save_options_context.save_traces = save_traces
    yield
  finally:
    _save_options_context.save_traces = previous_value


def should_save_traces():
  """Whether to trace layer functions; can be disabled via the `save_traces` arg."""
  return _save_options_context.save_traces


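# For example (a sketch): tracing can be turned off for the duration of a
# save by entering the scope; the previous value is restored on exit:
#
#   with keras_option_scope(save_traces=False):
#     should_save_traces()  # -> False inside the scope
#   should_save_traces()    # -> True again

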
@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from AutoTrackable automatically create dependencies
  on trackable objects through attribute assignments, and wrap data structures
  (lists or dicts) with trackable classes. This scope may be used to
  temporarily disable this behavior. It works similarly to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    a scope in which the object doesn't track dependencies.
  """
  previous_value = getattr(obj, '_setattr_tracking', True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    obj._setattr_tracking = previous_value  # pylint: disable=protected-access