615 lines
24 KiB
Python
615 lines
24 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# =============================================================================
|
|
# pylint: disable=g-classes-have-attributes
|
|
"""Contains the base Layer class, from which all layers inherit."""
|
|
import copy
|
|
import warnings
|
|
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.keras import backend
|
|
from tensorflow.python.keras.engine import base_layer
|
|
from tensorflow.python.keras.engine import base_layer_utils
|
|
from tensorflow.python.keras.legacy_tf_layers import variable_scope_shim
|
|
from tensorflow.python.keras.mixed_precision import policy
|
|
from tensorflow.python.keras.utils import tf_contextlib
|
|
from tensorflow.python.ops import variable_scope as vs
|
|
from tensorflow.python.ops import variables as tf_variables
|
|
from tensorflow.python.trackable import base as trackable
|
|
from tensorflow.python.util import nest
|
|
from tensorflow.python.util.tf_export import keras_export
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
# Avoid breaking users who directly import this symbol from this file.
# TODO(fchollet): remove this.
InputSpec = base_layer.InputSpec  # pylint: disable=invalid-name

# Module-wide flag: True while Keras-style variable management is enabled.
# It is flipped by `keras_style_scope` / `set_keras_style` and read through
# `_is_in_keras_style_scope`.
_KERAS_STYLE_SCOPE = False
|
|
|
|
|
|
@keras_export(
    v1=['keras.__internal__.legacy.layers.experimental.keras_style_scope'])
@tf_export(v1=['layers.experimental.keras_style_scope'])
@tf_contextlib.contextmanager
def keras_style_scope():
  """Context manager that enables Keras-style variable management.

  Every tf.layers layer and tf RNN cell created inside this scope manages its
  variables the Keras way. While the scope is active, constructing such a
  layer with a `scope=` argument, or with `reuse=True`, is not allowed.

  The scope exists so that users of the existing layers can migrate to the
  Keras layers API gradually, without breaking what already works.

  A typical case is combining TensorFlow's RNN classes with Keras Models or
  Networks. Keras models do not set variable scopes properly, so RNN users can
  end up sharing scopes between two different models by accident, or hit
  errors about variables that already exist.

  Example:

  ```python
  class RNNModel(tf.keras.Model):

    def __init__(self, name):
      super(RNNModel, self).__init__(name=name)
      self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
        [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

    def call(self, input, state):
      return self.rnn(input, state)

  model_1 = RNNModel("model_1")
  model_2 = RNNModel("model_2")

  # OK
  output_1, next_state_1 = model_1(input, state)
  # Raises an error about trying to create an already existing variable.
  output_2, next_state_2 = model_2(input, state)
  ```

  Wrapping both the construction and the execution of the models in a
  keras-style scope resolves this:

  ```python
  with keras_style_scope():
    model_1 = RNNModel("model_1")
    model_2 = RNNModel("model_2")

    # model_1 and model_2 are guaranteed to create their own variables.
    output_1, next_state_1 = model_1(input, state)
    output_2, next_state_2 = model_2(input, state)

    assert len(model_1.weights) > 0
    assert len(model_2.weights) > 0
    assert(model_1.weights != model_2.weights)
  ```

  Yields:
    A keras layer style scope.
  """
  global _KERAS_STYLE_SCOPE
  # Remember the flag as it was on entry and switch Keras style on.
  previous_setting, _KERAS_STYLE_SCOPE = _KERAS_STYLE_SCOPE, True
  try:
    yield
  finally:
    # Put the entry value back, even if the body raised.
    _KERAS_STYLE_SCOPE = previous_setting
|
|
|
|
|
|
@keras_export(
    v1=['keras.__internal__.legacy.layers.experimental.set_keras_style'])
@tf_export(v1=['layers.experimental.set_keras_style'])
def set_keras_style():
  """Enables Keras-style variable management for the rest of the program.

  Every tf.layers layer and tf RNN cell created after Keras style has been
  enabled manages its variables the Keras way. From then on, constructing
  such a layer with a `scope=` argument, or with `reuse=True`, is not allowed.

  This function exists so that users of the existing layers can migrate to
  the Keras layers API gradually, without breaking what already works.

  See the documentation of `keras_style_scope` for more details.

  Note that the setting is global: once Keras style has been set it applies
  to the entire program and cannot be unset.

  Example:

  ```python
  set_keras_style()

  model_1 = RNNModel(name="model_1")
  model_2 = RNNModel(name="model_2")

  # model_1 and model_2 are guaranteed to create their own variables.
  output_1, next_state_1 = model_1(input, state)
  output_2, next_state_2 = model_2(input, state)

  assert len(model_1.weights) > 0
  assert len(model_2.weights) > 0
  assert(model_1.weights != model_2.weights)
  ```
  """
  # Flip the module-level flag; nothing in this function restores it.
  global _KERAS_STYLE_SCOPE
  _KERAS_STYLE_SCOPE = True
|
|
|
|
|
|
def _is_in_keras_style_scope():
  """Returns whether Keras-style variable management is currently enabled."""
  # Reading a module global needs no `global` declaration.
  return _KERAS_STYLE_SCOPE
|
|
|
|
|
|
@keras_export(v1=['keras.__internal__.legacy.layers.Layer'])
@tf_export(v1=['layers.Layer'])
class Layer(base_layer.Layer):
  """Base layer class.

  It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
  instead.

  Args:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: Default dtype of the layer's weights (default of `None` means use the
      type of the first input).

  Read-only properties:
    name: The name of the layer (string).
    dtype: Default dtype of the layer's weights (default of `None` means use the
      type of the first input).
    trainable_variables: List of trainable variables.
    non_trainable_variables: List of non-trainable variables.
    variables: List of all variables of this layer, trainable and
      non-trainable.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).

  Mutable properties:
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
  """

  def __init__(self, trainable=True, name=None, dtype=None,
               **kwargs):
    """Initializes the legacy layer.

    Args:
      trainable: Boolean, whether the layer's variables should be trainable.
      name: String name of the layer.
      dtype: Default dtype of the layer's weights. `None` selects the
        `'_infer'` policy.
      **kwargs: Forwarded to the parent constructor after the private
        `_scope` and `_reuse` entries have been removed. `autocast` is set
        to `False` unless the caller supplied it.

    Raises:
      ValueError: If Keras-style layers are enabled and `_scope` or `_reuse`
        is not `None`.
    """
    # For backwards compatibility, legacy layers do not use `ResourceVariable`
    # by default.
    self._use_resource_variables = False
    scope = kwargs.pop('_scope', None)
    self._reuse = kwargs.pop('_reuse', None)

    # Avoid an incorrect lint error
    self._trainable_weights = []
    self.built = False

    if dtype is None:
      # Indicates to infer dtype from inputs. When the V2 dtype behavior is
      # enabled, Keras layers default their dtype to floatx instead, so we pass
      # an "_infer" policy to keep the old V1 behavior.
      dtype = policy.Policy('_infer')

    if 'autocast' not in kwargs:
      kwargs['autocast'] = False

    # Mark that legacy layers should not be instrumented as Keras usage
    self._disable_keras_instrumentation = True

    super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
                                **kwargs)

    if _is_in_keras_style_scope():
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      if self._reuse is not None:
        raise ValueError(
            'reuse argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(self._reuse))
      self._keras_style = True
    else:
      self._keras_style = False

    self._call_has_scope_arg = 'scope' in self._call_fn_args
    if scope:
      # Enter the scope once only to capture the resulting VariableScope.
      with vs.variable_scope(scope) as captured_scope:
        self._scope = captured_scope
    else:
      # Left unset here; `_set_scope` fills it in lazily.
      self._scope = None
    self._current_scope = None

  # We no longer track graph in tf.layers layers. This property is only kept to
  # maintain API backward compatibility.
  @property
  def graph(self):
    """Deprecated; always warns, then returns `None` in graph mode.

    Raises:
      RuntimeError: If called while executing eagerly.
    """
    warnings.warn('`Layer.graph` is deprecated and '
                  'will be removed in a future version. '
                  'Please stop using this property because tf.layers layers no '
                  'longer track their graph.')
    if context.executing_eagerly():
      raise RuntimeError('Layer.graph not supported when executing eagerly.')
    return None

  def _init_set_name(self, name):
    """Sets `self._name` and `self._base_name` from `name`.

    Args:
      name: A string, a `VariableScope`, or a falsy value. For a
        `VariableScope` the base name is the scope's name and the layer name
        is generated; for a falsy value both are generated.
    """
    # Determine layer name (non-unique).
    if isinstance(name, vs.VariableScope):
      base_name = name.name
      self._name, _ = self._make_unique_name()
    else:
      base_name = name
      self._name = name
    if not name:
      self._name, base_name = self._make_unique_name()
    self._base_name = base_name

  def _make_unique_name(self, name_uid_map=None, avoid_names=None,
                        namespace='', zero_based=False):
    """Generates a layer name from the snake-cased class name.

    All arguments are passed through to `backend.unique_object_name`.

    Returns:
      A `(name, base_name)` tuple, where `base_name` is the snake-cased class
      name and `name` is the value returned by `backend.unique_object_name`.
    """
    base_name = base_layer.to_snake_case(self.__class__.__name__)
    name = backend.unique_object_name(
        base_name,
        name_uid_map=name_uid_map,
        avoid_names=avoid_names,
        namespace=namespace,
        zero_based=zero_based)
    return (name, base_name)

  @property
  def scope_name(self):
    """Name of the layer's captured variable scope.

    Raises:
      ValueError: If no scope has been captured yet.
    """
    if not self._scope:
      raise ValueError('No name available for layer scope because the layer "' +
                       self._name + '" has not been used yet. The scope name ' +
                       ' is determined the first time the layer instance is ' +
                       'called. You must therefore call the layer before ' +
                       'querying `scope_name`.')
    return self._scope.name

  def add_loss(self, losses, inputs=None):
    """Adds losses via the parent class and mirrors them into a collection.

    When not executing eagerly, the losses newly recorded by the parent call
    (plus the non-`None` results of newly recorded callable losses) are added
    to the `REGULARIZATION_LOSSES` graph collection.

    Args:
      losses: Passed through to the parent `add_loss`.
      inputs: Passed through to the parent `add_loss`.
    """
    # Snapshot lengths so only the entries added by this call are mirrored.
    previous_losses_length = len(self._losses)
    previous_callable_losses_length = len(self._callable_losses)
    super(Layer, self).add_loss(losses, inputs=inputs)
    if not context.executing_eagerly():
      # TODO(fchollet): deprecate collection below.
      new_losses = self._losses[previous_losses_length:]
      new_callable_losses = self._callable_losses[
          previous_callable_losses_length:]
      for regularizer in new_callable_losses:
        loss_tensor = regularizer()
        if loss_tensor is not None:
          new_losses.append(loss_tensor)
      _add_elements_to_collection(
          new_losses,
          ops.GraphKeys.REGULARIZATION_LOSSES)

  def _name_scope(self):  # pylint: disable=method-hidden
    """Determines op naming for the Layer."""
    if self._keras_style:
      return super(Layer, self)._name_scope()
    return self._current_scope.original_name_scope

  def _set_scope(self, scope=None):
    """Captures the layer's variable scope if none has been set yet.

    Does nothing when `self._scope` is already set.

    Args:
      scope: Optional scope to open. With `self._reuse` truthy, a `None`
        scope falls back to `self._base_name`; otherwise `self._base_name`
        is supplied as `default_name`.
    """
    if self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        with vs.variable_scope(
            scope if scope is not None else self._base_name) as captured_scope:
          self._scope = captured_scope
      else:
        with vs.variable_scope(
            scope, default_name=self._base_name) as captured_scope:
          self._scope = captured_scope

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 use_resource=None,
                 synchronization=vs.VariableSynchronization.AUTO,
                 aggregation=vs.VariableAggregation.NONE,
                 partitioner=None,
                 **kwargs):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    Args:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      constraint: constraint instance (callable).
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: (optional) partitioner instance (callable).  If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`.  In this case,
        an instance of `PartitionedVariable` is returned.  Available
        partitioners include `tf.compat.v1.fixed_size_partitioner` and
        `tf.compat.v1.variable_axis_size_partitioner`.  For more details, see
        the documentation of `tf.compat.v1.get_variable` and the  "Variable
        Partitioners and Sharding" section of the API guide.
      **kwargs: Additional keyword arguments. Only `experimental_autocast`
        is accepted.

    Returns:
      The created variable.  Usually either a `Variable` or `ResourceVariable`
      instance.  If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      TypeError: If a keyword argument other than `experimental_autocast` is
        passed.
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
    """
    for kwarg in kwargs:
      if kwarg != 'experimental_autocast':
        # Format into one string; passing two args to TypeError would render
        # the message as a tuple repr.
        raise TypeError('Unknown keyword argument: {}'.format(kwarg))
    if self._keras_style:
      # NOTE(review): the caller's `synchronization` and `aggregation`
      # arguments are not forwarded on this path; AUTO / NONE are always
      # used. Confirm whether that is intentional.
      return super(Layer, self).add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          trainable=trainable and self.trainable,
          constraint=constraint,
          use_resource=use_resource,
          synchronization=vs.VariableSynchronization.AUTO,
          aggregation=vs.VariableAggregation.NONE,
          partitioner=partitioner,
          **kwargs)

    if synchronization == vs.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    def _should_add_regularizer(variable, existing_variable_set):
      # A split variable qualifies only if none of its parts already existed.
      if base_layer_utils.is_split_variable(variable):
        for var in variable:
          if var in existing_variable_set:
            return False
        return True
      else:
        return variable not in existing_variable_set

    init_graph = None
    if not context.executing_eagerly():
      default_graph = ops.get_default_graph()
      if default_graph.building_function:
        with ops.init_scope():
          # Retrieve the variables from the graph into which variables
          # will be lifted; if initialization ops will be lifted into
          # the eager context, then there is nothing to retrieve, since variable
          # collections are not supported when eager execution is enabled.
          if not context.executing_eagerly():
            init_graph = ops.get_default_graph()
            existing_variables = set(tf_variables.global_variables())
      else:
        # Initialization ops will not be lifted out of the default graph.
        init_graph = default_graph
        existing_variables = set(tf_variables.global_variables())

    if dtype is None:
      dtype = self.dtype or dtypes.float32

    self._set_scope(None)
    reuse = self.built or self._reuse
    # Remember how many trainable weights existed before this call so the
    # ones added below can be identified afterwards.
    prev_len_trainable = len(self._trainable_weights)
    with vs.variable_scope(
        self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
      self._current_scope = scope
      with backend.name_scope(self._name_scope()):  # pylint: disable=not-callable
        use_resource = (use_resource or
                        self._use_resource_variables or
                        scope.use_resource)
        if initializer is None:
          initializer = scope.initializer
        variable = super(Layer, self).add_weight(
            name,
            shape,
            dtype=dtypes.as_dtype(dtype),
            initializer=initializer,
            trainable=trainable and self.trainable,
            constraint=constraint,
            partitioner=partitioner,
            use_resource=use_resource,
            synchronization=synchronization,
            aggregation=aggregation,
            getter=vs.get_variable,
            **kwargs)

        if regularizer:
          if (ops.executing_eagerly_outside_functions()
              or _should_add_regularizer(variable, existing_variables)):
            self._handle_weight_regularization(name, variable, regularizer)
            var_store = vs._get_default_variable_store()  # pylint: disable=protected-access
            # When the shim to get variable scope working in TF2 is used,
            # We need to explicitly make the shim track the regularization
            # losses as the collections will not be accessible.
            if hasattr(var_store, 'add_regularizer'):
              var_store.add_regularizer(variable, regularizer)

        if init_graph is not None:
          # Handle edge case where a custom getter has overridden `trainable`.
          # There is one known occurrence of this, in unit test
          # testBasicRNNCellNotTrainable in
          # contrib.rnn.python.kernel_tests.core_rnn_cell_test
          with init_graph.as_default():
            trainable_variables = tf_variables.trainable_variables()
          if (trainable and self.trainable and
              variable not in trainable_variables):
            # A custom getter / variable scope overrode the trainable flag.
            extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
            self._trainable_weights = self._trainable_weights[
                :prev_len_trainable]
            self._non_trainable_weights += extra_trainable_vars
    return variable

  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `input` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value),
        or if `scope` is passed while the layer is in Keras style.
    """
    scope = kwargs.pop('scope', None)

    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope  # pylint: disable=access-member-before-definition
      except AttributeError:
        scope_context_manager = None

      if scope_context_manager is None:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        scope_context_manager = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)

        # Do not cache variable scopes if Eager mode is enabled. If Eager mode
        # is enabled then we don't want to reuse scopes because the cached scope
        # might be from a FuncGraph or Eager scope we are no longer in.
        if not ops.executing_eagerly_outside_functions():
          self._always_reuse_variable_scope = scope_context_manager
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        # Same situation as above: the constructor may not have run, so
        # compute and cache the `call` signature information here.
        self._call_fn_args = variable_scope_shim.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs

  def __deepcopy__(self, memo):
    """Copies the layer attribute by attribute with per-attribute policies.

    Attributes named in `no_copy`, and tensor / tensor-list values, are shared
    by reference; those named in `shallow_copy` get `copy.copy`; everything
    else is deep-copied.

    Args:
      memo: The memo dictionary used by `copy.deepcopy`.

    Returns:
      The new instance, created with `cls.__new__` (no `__init__` call).
    """
    no_copy = {'_graph', '_thread_local', '_metrics_lock'}
    shallow_copy = {'_scope', '_always_reuse_variable_scope'}
    cls = self.__class__
    result = cls.__new__(cls)
    # Register the copy before recursing so reference cycles resolve to it.
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      if k in no_copy:
        setattr(result, k, v)
      elif k in shallow_copy:
        setattr(result, k, copy.copy(v))
      elif base_layer.is_tensor_or_tensor_list(v):
        setattr(result, k, v)
      else:
        setattr(result, k, copy.deepcopy(v, memo))
    return result

  def __setattr__(self, value, name):
    # By-pass the automatic dependency tracking performed by the parent Layer.
    # NOTE(review): the parameter names are swapped relative to the usual
    # `__setattr__(self, name, value)` convention. Both are forwarded
    # positionally in the same order, so `value` carries the attribute name
    # and `name` carries the value being assigned.
    super(trackable.Trackable, self).__setattr__(value, name)  # pylint: disable=bad-super-call

  @property
  def _is_legacy_layer(self):
    """Used by keras to check compatibility. This should not be overridden."""
    return True
|
|
|
|
|
|
def _add_elements_to_collection(elements, collection_list):
  """Appends `elements` to each named graph collection, skipping duplicates.

  Duplicates are detected by object identity (`id`), not equality. An element
  is skipped if it is already in the collection or has already been appended
  earlier in the same call.

  Args:
    elements: A (possibly nested) structure of objects to add; it is
      flattened with `nest.flatten` before use.
    collection_list: A collection name or (possibly nested) structure of
      collection names; it is flattened the same way.

  Raises:
    RuntimeError: If called while executing eagerly.
  """
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    seen_ids = {id(e) for e in collection}
    for element in elements:
      element_id = id(element)
      if element_id not in seen_ids:
        collection.append(element)
        # Record the addition so an object repeated within `elements` is not
        # appended a second time.
        seen_ids.add(element_id)
|