# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import warnings

import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine import base_layer_utils
from keras.engine import base_layer_v1 as base_layer
from keras.legacy_tf_layers import variable_scope_shim
from keras.mixed_precision import policy
from keras.utils import tf_contextlib

# isort: off
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export

_KERAS_STYLE_SCOPE = False


@keras_export(
    v1=["keras.__internal__.legacy.layers.experimental.keras_style_scope"]
)
@tf_export(v1=["layers.experimental.keras_style_scope"])
@tf_contextlib.contextmanager
def keras_style_scope():
    """Use Keras-style variable management.

    All tf.layers and tf RNN cells created in this scope use Keras-style
    variable management. Creating such layers with a scope= argument is
    disallowed, and reuse=True is disallowed.

    The purpose of this scope is to allow users of existing layers to
    slowly transition to a Keras layers API without breaking existing
    functionality.

    One example of this is when using TensorFlow's RNN classes with Keras
    Models or Networks. Because Keras models do not properly set variable
    scopes, users of RNNs may either accidentally share scopes between two
    different models, or get errors about variables that already exist.

    Example:

    ```python
    class RNNModel(tf.keras.Model):

      def __init__(self, name):
        super(RNNModel, self).__init__(name=name)
        self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
          [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

      def call(self, input, state):
        return self.rnn(input, state)

    model_1 = RNNModel("model_1")
    model_2 = RNNModel("model_2")

    # OK
    output_1, next_state_1 = model_1(input, state)
    # Raises an error about trying to create an already existing variable.
    output_2, next_state_2 = model_2(input, state)
    ```

    The solution is to wrap the model construction and execution in a
    keras-style scope:

    ```python
    with keras_style_scope():
      model_1 = RNNModel("model_1")
      model_2 = RNNModel("model_2")

      # model_1 and model_2 are guaranteed to create their own variables.
      output_1, next_state_1 = model_1(input, state)
      output_2, next_state_2 = model_2(input, state)

      assert len(model_1.weights) > 0
      assert len(model_2.weights) > 0
      assert(model_1.weights != model_2.weights)
    ```

    Yields:
      A keras layer style scope.
    """
    global _KERAS_STYLE_SCOPE
    stack = _KERAS_STYLE_SCOPE
    _KERAS_STYLE_SCOPE = True
    try:
        yield
    finally:
        _KERAS_STYLE_SCOPE = stack


@keras_export(
    v1=["keras.__internal__.legacy.layers.experimental.set_keras_style"]
)
@tf_export(v1=["layers.experimental.set_keras_style"])
def set_keras_style():
    """Use Keras-style variable management.

    All tf.layers and tf RNN cells created after keras style has been enabled
    use Keras-style variable management. Creating such layers with a
    scope= argument is disallowed, and reuse=True is disallowed.

    The purpose of this function is to allow users of existing layers to
    slowly transition to the Keras layers API without breaking existing
    functionality.

    For more details, see the documentation for `keras_style_scope`.

    Note, once keras style has been set, it is set globally for the entire
    program and cannot be unset.

    Example:

    ```python
    set_keras_style()

    model_1 = RNNModel(name="model_1")
    model_2 = RNNModel(name="model_2")

    # model_1 and model_2 are guaranteed to create their own variables.
    output_1, next_state_1 = model_1(input, state)
    output_2, next_state_2 = model_2(input, state)

    assert len(model_1.weights) > 0
    assert len(model_2.weights) > 0
    assert(model_1.weights != model_2.weights)
    ```
    """
    global _KERAS_STYLE_SCOPE
    _KERAS_STYLE_SCOPE = True


def _is_in_keras_style_scope():
    global _KERAS_STYLE_SCOPE
    return _KERAS_STYLE_SCOPE
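

# Illustrative sketch (not part of the original module): `keras_style_scope`
# flips the module-level flag above only for the duration of the `with`
# block, while `set_keras_style` flips it for the rest of the program. The
# helper name below is hypothetical and exists only to demonstrate this.
def _example_keras_style_flag():
    """Hypothetical helper showing how `_KERAS_STYLE_SCOPE` behaves."""
    before = _is_in_keras_style_scope()
    with keras_style_scope():
        # Inside the scope, legacy layers use Keras-style variable management.
        assert _is_in_keras_style_scope()
    # On exit the previous value is restored.
    assert _is_in_keras_style_scope() == before

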
@keras_export(v1=["keras.__internal__.legacy.layers.Layer"])
@tf_export(v1=["layers.Layer"])
class Layer(base_layer.Layer):
    """Base layer class.

    It is considered legacy, and we recommend the use of
    `tf.keras.layers.Layer` instead.

    Args:
      trainable: Boolean, whether the layer's variables should be trainable.
      name: String name of the layer.
      dtype: Default dtype of the layer's weights (default of `None` means use
        the type of the first input).

    Read-only properties:
      name: The name of the layer (string).
      dtype: Default dtype of the layer's weights (default of `None` means use
        the type of the first input).
      trainable_variables: List of trainable variables.
      non_trainable_variables: List of non-trainable variables.
      variables: List of all variables of this layer, trainable and
        non-trainable.
      updates: List of update ops of this layer.
      losses: List of losses added by this layer.
      trainable_weights: List of variables to be included in backprop.
      non_trainable_weights: List of variables that should not be
        included in backprop.
      weights: The concatenation of the lists trainable_weights and
        non_trainable_weights (in this order).

    Mutable properties:
      trainable: Whether the layer should be trained (boolean).
      input_spec: Optional (list of) `InputSpec` object(s) specifying the
        constraints on inputs that can be accepted by the layer.
    """

    def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
        # For backwards compatibility, legacy layers do not use
        # `ResourceVariable` by default.
        self._use_resource_variables = False
        scope = kwargs.pop("_scope", None)
        self._reuse = kwargs.pop("_reuse", None)

        # Avoid an incorrect lint error
        self._trainable_weights = []
        self.built = False

        if dtype is None:
            # Indicates to infer dtype from inputs. When the V2 dtype behavior
            # is enabled, Keras layers default their dtype to floatx instead,
            # so we pass an "_infer" policy to keep the old V1 behavior.
            dtype = policy.Policy("_infer")

        if "autocast" not in kwargs:
            kwargs["autocast"] = False

        # Mark that legacy layers should not be instrumented as Keras usage
        self._disable_keras_instrumentation = True

        super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)

        if _is_in_keras_style_scope():
            if scope is not None:
                raise ValueError(
                    "scope argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(scope)
                )
            if self._reuse is not None:
                raise ValueError(
                    "reuse argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(self._reuse)
                )
            self._keras_style = True
        else:
            self._keras_style = False

        self._call_has_scope_arg = "scope" in self._call_spec.arg_names
        if scope:
            with tf.compat.v1.variable_scope(scope) as captured_scope:
                self._scope = captured_scope
        else:
            self._scope = None
        self._current_scope = None

    def apply(self, *args, **kwargs):
        return self(*args, **kwargs)

    # We no longer track graph in tf.layers layers. This property is only
    # kept to maintain API backward compatibility.
    @property
    def graph(self):
        warnings.warn(
            "`Layer.graph` is deprecated and "
            "will be removed in a future version. "
            "Please stop using this property because tf.layers layers no "
            "longer track their graph.",
            stacklevel=2,
        )
        if tf.executing_eagerly():
            raise RuntimeError(
                "Layer.graph not supported when executing eagerly."
            )
        return None

    def _init_set_name(self, name):
        # Determine layer name (non-unique).
        if isinstance(name, tf.compat.v1.VariableScope):
            base_name = name.name
            self._name, _ = self._make_unique_name()
        else:
            base_name = name
            self._name = name
        if not name:
            self._name, base_name = self._make_unique_name()
        self._base_name = base_name

    def _make_unique_name(
        self,
        name_uid_map=None,
        avoid_names=None,
        namespace="",
        zero_based=False,
    ):
        base_name = base_layer.to_snake_case(self.__class__.__name__)
        name = backend.unique_object_name(
            base_name,
            name_uid_map=name_uid_map,
            avoid_names=avoid_names,
            namespace=namespace,
            zero_based=zero_based,
        )
        return (name, base_name)

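    # Illustrative note (assumption, not from the original source): for a
    # hypothetical subclass named `MyDense`, `_make_unique_name()` would
    # typically derive names from the snake-cased class name, e.g.
    #
    #     base_layer.to_snake_case("MyDense")  # -> "my_dense"
    #
    # with `backend.unique_object_name` appending a numeric suffix when
    # needed so that multiple instances get distinct names.
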
    @property
    def scope_name(self):
        if not self._scope:
            raise ValueError(
                'No name available for layer scope because the layer "'
                + self._name
                + '" has not been used yet. The scope name '
                + "is determined the first time the layer instance is "
                + "called. You must therefore call the layer before "
                + "querying `scope_name`."
            )
        return self._scope.name

    def add_loss(self, losses, inputs=None):
        previous_losses_length = len(self._losses)
        previous_callable_losses_length = len(self._callable_losses)
        super().add_loss(losses, inputs=inputs)
        if not tf.executing_eagerly():
            # TODO(fchollet): deprecate collection below.
            new_losses = self._losses[previous_losses_length:]
            new_callable_losses = self._callable_losses[
                previous_callable_losses_length:
            ]
            for regularizer in new_callable_losses:
                loss_tensor = regularizer()
                if loss_tensor is not None:
                    new_losses.append(loss_tensor)
            _add_elements_to_collection(
                new_losses, tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES
            )

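    # Illustrative sketch (assumption, not part of the original source): in
    # graph mode, `add_loss` above also mirrors any newly added losses into
    # the `tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES` collection, so a
    # hypothetical caller could later fetch them with:
    #
    #     layer.add_loss(0.01 * tf.reduce_sum(tf.square(some_weight)))
    #     reg_losses = tf.compat.v1.get_collection(
    #         tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
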
    def _name_scope(self):
        """Determines op naming for the Layer."""
        if self._keras_style:
            return super()._name_scope()
        return self._current_scope.original_name_scope

    def _set_scope(self, scope=None):
        if self._scope is None:
            # If constructed with _scope=None, lazy setting of scope.
            if self._reuse:
                with tf.compat.v1.variable_scope(
                    scope if scope is not None else self._base_name
                ) as captured_scope:
                    self._scope = captured_scope
            else:
                with tf.compat.v1.variable_scope(
                    scope, default_name=self._base_name
                ) as captured_scope:
                    self._scope = captured_scope

    def add_weight(
        self,
        name,
        shape,
        dtype=None,
        initializer=None,
        regularizer=None,
        trainable=None,
        constraint=None,
        use_resource=None,
        synchronization=tf.VariableSynchronization.AUTO,
        aggregation=tf.compat.v1.VariableAggregation.NONE,
        partitioner=None,
        **kwargs
    ):
        """Adds a new variable to the layer, or gets an existing one;
        returns it.

        Args:
          name: variable name.
          shape: variable shape.
          dtype: The type of the variable. Defaults to `self.dtype` or
            `float32`.
          initializer: initializer instance (callable).
          regularizer: regularizer instance (callable).
          trainable: whether the variable should be part of the layer's
            "trainable_variables" (e.g. variables, biases)
            or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
            Note, if the current variable scope is marked as non-trainable
            then this parameter is ignored and any added variables are also
            marked as non-trainable. `trainable` defaults to `True` unless
            `synchronization` is set to `ON_READ`.
          constraint: constraint instance (callable).
          use_resource: Whether to use `ResourceVariable`.
          synchronization: Indicates when a distributed variable will be
            aggregated. Accepted values are constants defined in the class
            `tf.VariableSynchronization`. By default the synchronization is
            set to `AUTO` and the current `DistributionStrategy` chooses when
            to synchronize. If `synchronization` is set to `ON_READ`,
            `trainable` must not be set to `True`.
          aggregation: Indicates how a distributed variable will be aggregated.
            Accepted values are constants defined in the class
            `tf.VariableAggregation`.
          partitioner: (optional) partitioner instance (callable). If
            provided, when the requested variable is created it will be split
            into multiple partitions according to `partitioner`. In this case,
            an instance of `PartitionedVariable` is returned. Available
            partitioners include `tf.compat.v1.fixed_size_partitioner` and
            `tf.compat.v1.variable_axis_size_partitioner`. For more details,
            see the documentation of `tf.compat.v1.get_variable` and the
            "Variable Partitioners and Sharding" section of the API guide.
          **kwargs: Additional keyword arguments.

        Returns:
          The created variable. Usually either a `Variable` or
          `ResourceVariable` instance. If `partitioner` is not `None`, a
          `PartitionedVariable` instance is returned.

        Raises:
          RuntimeError: If called with partitioned variable regularization and
            eager execution is enabled.
          ValueError: When trainable has been set to True with synchronization
            set as `ON_READ`.
        """
        for kwarg in kwargs:
            if kwarg != "experimental_autocast":
                raise TypeError("Unknown keyword argument:", kwarg)
        if self._keras_style:
            return super().add_weight(
                name=name,
                shape=shape,
                dtype=dtype,
                initializer=initializer,
                regularizer=regularizer,
                trainable=trainable and self.trainable,
                constraint=constraint,
                use_resource=use_resource,
                synchronization=tf.VariableSynchronization.AUTO,
                aggregation=tf.compat.v1.VariableAggregation.NONE,
                partitioner=partitioner,
                **kwargs
            )

        if synchronization == tf.VariableSynchronization.ON_READ:
            if trainable:
                raise ValueError(
                    "Synchronization value can be set to "
                    "VariableSynchronization.ON_READ only for non-trainable "
                    "variables. You have specified trainable=True and "
                    "synchronization=VariableSynchronization.ON_READ."
                )
            else:
                # Set trainable to be false when variable is to be synced on
                # read.
                trainable = False
        elif trainable is None:
            trainable = True

        def _should_add_regularizer(variable, existing_variable_set):
            if base_layer_utils.is_split_variable(variable):
                for var in variable:
                    if var in existing_variable_set:
                        return False
                return True
            else:
                return variable not in existing_variable_set

        init_graph = None
        if not tf.executing_eagerly():
            default_graph = tf.compat.v1.get_default_graph()
            if default_graph.building_function:
                with tf.init_scope():
                    # Retrieve the variables from the graph into which
                    # variables will be lifted; if initialization ops will be
                    # lifted into the eager context, then there is nothing to
                    # retrieve, since variable collections are not supported
                    # when eager execution is enabled.
                    if not tf.executing_eagerly():
                        init_graph = tf.compat.v1.get_default_graph()
                        existing_variables = set(
                            tf.compat.v1.global_variables()
                        )
            else:
                # Initialization ops will not be lifted out of the default
                # graph.
                init_graph = default_graph
                existing_variables = set(tf.compat.v1.global_variables())

        if dtype is None:
            dtype = self.dtype or tf.float32

        self._set_scope(None)
        reuse = self.built or self._reuse
        prev_len_trainable = len(self._trainable_weights)
        with tf.compat.v1.variable_scope(
            self._scope, reuse=reuse, auxiliary_name_scope=False
        ) as scope:
            self._current_scope = scope
            with backend.name_scope(self._name_scope()):
                use_resource = (
                    use_resource
                    or self._use_resource_variables
                    or scope.use_resource
                )
                if initializer is None:
                    initializer = scope.initializer
                variable = super().add_weight(
                    name,
                    shape,
                    dtype=tf.as_dtype(dtype),
                    initializer=initializer,
                    trainable=trainable and self.trainable,
                    constraint=constraint,
                    partitioner=partitioner,
                    use_resource=use_resource,
                    synchronization=synchronization,
                    aggregation=aggregation,
                    getter=tf.compat.v1.get_variable,
                    **kwargs
                )

                if regularizer:
                    if (
                        tf.compat.v1.executing_eagerly_outside_functions()
                        or _should_add_regularizer(variable, existing_variables)
                    ):
                        self._handle_weight_regularization(
                            name, variable, regularizer
                        )
                    var_store = vs._get_default_variable_store()
                    # When the shim to get variable scope working in TF2 is
                    # used, we need to explicitly make the shim track the
                    # regularization losses as the collections will not be
                    # accessible.
                    if hasattr(var_store, "add_regularizer"):
                        var_store.add_regularizer(variable, regularizer)

                if init_graph is not None:
                    # Handle edge case where a custom getter has overridden
                    # `trainable`. There is one known occurrence of this, in
                    # unit test testBasicRNNCellNotTrainable in
                    # contrib.rnn.python.kernel_tests.core_rnn_cell_test
                    with init_graph.as_default():
                        trainable_variables = tf.compat.v1.trainable_variables()
                        if (
                            trainable
                            and self.trainable
                            and variable not in trainable_variables
                        ):
                            # A custom getter / variable scope overrode the
                            # trainable flag.
                            extra_trainable_vars = self._trainable_weights[
                                prev_len_trainable:
                            ]
                            self._trainable_weights = self._trainable_weights[
                                :prev_len_trainable
                            ]
                            self._non_trainable_weights += extra_trainable_vars
        return variable

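    # Illustrative sketch (assumption, not part of the original source): a
    # hypothetical subclass would typically call `add_weight` from its
    # `build` method, e.g.:
    #
    #     def build(self, input_shape):
    #         self.kernel = self.add_weight(
    #             "kernel",
    #             shape=[int(input_shape[-1]), self.units],
    #             initializer=tf.compat.v1.glorot_uniform_initializer(),
    #             regularizer=tf.keras.regularizers.l2(1e-4),
    #         )
    #
    # In graph mode the variable is created (or reused) through
    # `tf.compat.v1.get_variable` under the layer's captured variable scope.
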
    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

        Args:
          inputs: input tensor(s).
          *args: additional positional arguments to be passed to `self.call`.
          **kwargs: additional keyword arguments to be passed to `self.call`.
            **Note**: kwarg `scope` is reserved for use by the layer.

        Returns:
          Output tensor(s).

        Note:
          - If the layer's `call` method takes a `scope` keyword argument, this
            argument will be automatically set to the current variable scope.
          - If the layer's `call` method takes a `mask` argument (as some Keras
            layers do), its default value will be set to the mask generated
            for `inputs` by the previous layer (if `input` did come from
            a layer that generated a corresponding mask, i.e. if it came from
            a Keras layer with masking support).

        Raises:
          ValueError: if the layer's `call` method returns None (an invalid
            value).
        """
        scope = kwargs.pop("scope", None)

        if self._keras_style:
            if scope is not None:
                raise ValueError(
                    "scope argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(scope)
                )
            return super().__call__(inputs, *args, **kwargs)

        self._set_scope(scope)

        if self.built:
            try:
                # Some classes which inherit from Layer do not use its
                # constructor, so rather than initializing to None we check
                # for an AttributeError.
                scope_context_manager = self._always_reuse_variable_scope
            except AttributeError:
                scope_context_manager = None

            if scope_context_manager is None:
                # From this point we will always set reuse=True, so create a
                # "final" variable scope with this setting. We avoid
                # re-creating variable scopes after this point as an
                # optimization.
                scope_context_manager = tf.compat.v1.variable_scope(
                    self._scope, reuse=True, auxiliary_name_scope=False
                )

                # Do not cache variable scopes if Eager mode is enabled. If
                # Eager mode is enabled then we don't want to reuse scopes
                # because the cached scope might be from a FuncGraph or Eager
                # scope we are no longer in.
                if not tf.compat.v1.executing_eagerly_outside_functions():
                    self._always_reuse_variable_scope = scope_context_manager
        else:
            scope_context_manager = tf.compat.v1.variable_scope(
                self._scope, reuse=self._reuse, auxiliary_name_scope=False
            )

        with scope_context_manager as scope:
            self._current_scope = scope

            try:
                call_has_scope_arg = self._call_has_scope_arg
            except AttributeError:
                self._call_spec.arg_names = variable_scope_shim.fn_args(
                    self.call
                )
                self._call_has_scope_arg = "scope" in self._call_spec.arg_names
                call_has_scope_arg = self._call_has_scope_arg
            if call_has_scope_arg:
                kwargs["scope"] = scope

            # Actually call layer
            outputs = super().__call__(inputs, *args, **kwargs)

        if not tf.executing_eagerly():
            # Update global default collections.
            _add_elements_to_collection(
                self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
            )
        return outputs

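    # Illustrative sketch (assumption, not part of the original source): in
    # graph mode, calling the same layer instance twice reuses its variables,
    # because after the first call `__call__` re-enters the captured variable
    # scope with `reuse=True`:
    #
    #     dense = tf.compat.v1.layers.Dense(4)
    #     y1 = dense(x1)
    #     y2 = dense(x2)  # shares the same kernel/bias variables as y1
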
    def __deepcopy__(self, memo):
        no_copy = set(["_graph", "_thread_local", "_metrics_lock"])
        shallow_copy = set(["_scope", "_always_reuse_variable_scope"])
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in no_copy:
                setattr(result, k, v)
            elif k in shallow_copy:
                setattr(result, k, copy.copy(v))
            elif base_layer.is_tensor_or_tensor_list(v):
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

    def __setattr__(self, value, name):
        # By-pass the automatic dependency tracking performed by the parent
        # Layer.
        super(tf.__internal__.tracking.Trackable, self).__setattr__(
            value, name
        )

    @property
    def _is_legacy_layer(self):
        """Used by keras to check compatibility. This should not be
        overridden."""
        return True


def _add_elements_to_collection(elements, collection_list):
    if tf.executing_eagerly():
        raise RuntimeError(
            "Using collections from Layers not supported in Eager "
            "mode. Tried to add %s to %s" % (elements, collection_list)
        )
    elements = tf.nest.flatten(elements)
    collection_list = tf.nest.flatten(collection_list)
    for name in collection_list:
        collection = tf.compat.v1.get_collection_ref(name)
        collection_set = {id(e) for e in collection}
        for element in elements:
            if id(element) not in collection_set:
                collection.append(element)
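

# Illustrative sketch (not part of the original module): in graph mode the
# helper above appends elements to a graph collection while skipping
# duplicates. The function name below is hypothetical and exists only to
# demonstrate the behaviour.
def _example_add_elements_to_collection():
    """Hypothetical helper showing deduplicated collection updates."""
    with tf.Graph().as_default():
        update_op = tf.constant(0.0)
        # Passing the same op twice only adds it to the collection once.
        _add_elements_to_collection(
            [update_op, update_op], tf.compat.v1.GraphKeys.UPDATE_OPS
        )
        return tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)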