# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Contains a shim to allow using TF1 get_variable code in TF2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import functools
import tensorflow.compat.v2 as tf
from keras.engine import base_layer
from keras.utils import layer_utils
from keras.utils import tf_inspect
# isort: off
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
def as_shape(shape):
"""Converts the given object to a TensorShape."""
if isinstance(shape, tf.TensorShape):
return shape
else:
return tf.TensorShape(shape)
def _is_callable_object(obj):
return hasattr(obj, "__call__") and tf_inspect.ismethod(obj.__call__)
def _has_kwargs(fn):
"""Returns whether the passed callable has **kwargs in its signature.
Args:
fn: Function, or function-like object (e.g., result of
`functools.partial`).
Returns:
`bool`: if `fn` has **kwargs in its signature.
Raises:
`TypeError`: If fn is not a Function, or function-like object.
"""
if isinstance(fn, functools.partial):
fn = fn.func
elif _is_callable_object(fn):
fn = fn.__call__
elif not callable(fn):
raise TypeError(
f"fn should be a function-like object, but is of type {type(fn)}."
)
return tf_inspect.getfullargspec(fn).varkw is not None
def fn_args(fn):
"""Get argument names for function-like object.
Args:
fn: Function, or function-like object (e.g., result of
`functools.partial`).
Returns:
`tuple` of string argument names.
Raises:
ValueError: if partial function has positionally bound arguments
"""
if isinstance(fn, functools.partial):
args = fn_args(fn.func)
args = [a for a in args[len(fn.args) :] if a not in (fn.keywords or [])]
else:
if hasattr(fn, "__call__") and tf_inspect.ismethod(fn.__call__):
fn = fn.__call__
args = tf_inspect.getfullargspec(fn).args
if _is_bound_method(fn) and args:
# If it's a bound method, it may or may not have a self/cls first
# argument; for example, self could be captured in *args.
# If it does have a positional argument, it is self/cls.
args.pop(0)
return tuple(args)
def _is_bound_method(fn):
_, fn = tf.__internal__.decorator.unwrap(fn)
return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
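# Editor's note: illustrative sketch (not used by the library) of the
# introspection helpers above; `get_variable` later uses `fn_args` and
# `_has_kwargs` to decide whether a `custom_getter` can accept `constraint`.
# The helper function and its argument names here are hypothetical.
def _example_fn_introspection():
    def getter(getter_fn, name, shape=None, **kwargs):
        return getter_fn(name, shape=shape, **kwargs)

    assert "name" in fn_args(getter)
    assert _has_kwargs(getter)
    # Positional arguments already bound by `functools.partial` are dropped
    # from the reported argument names.
    bound = functools.partial(getter, None)
    assert fn_args(bound) == ("name", "shape")
    return fn_args(bound)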
def validate_synchronization_aggregation_trainable(
synchronization, aggregation, trainable, name
):
"""Given user-provided variable properties, sets defaults and validates."""
if aggregation is None:
aggregation = tf.compat.v1.VariableAggregation.NONE
else:
if not isinstance(
aggregation,
(tf.compat.v1.VariableAggregation, tf.VariableAggregation),
):
try:
aggregation = tf.VariableAggregation(aggregation)
except ValueError:
raise ValueError(
"Invalid variable aggregation mode: {} "
"for variable: {}".format(aggregation, name)
)
if synchronization is None:
synchronization = tf.VariableSynchronization.AUTO
else:
try:
synchronization = tf.VariableSynchronization(synchronization)
except ValueError:
raise ValueError(
"Invalid variable synchronization mode: {} "
"for variable: {}".format(synchronization, name)
)
if trainable is None:
trainable = synchronization != tf.VariableSynchronization.ON_READ
return synchronization, aggregation, trainable
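# Editor's note: a minimal sketch (not used by the library) of the defaulting
# behavior implemented above: AUTO synchronization, NONE aggregation, and
# `trainable=True` unless the variable is ON_READ-synchronized.
def _example_default_synchronization_aggregation_trainable():
    sync, agg, trainable = validate_synchronization_aggregation_trainable(
        None, None, None, name="kernel"
    )
    assert sync == tf.VariableSynchronization.AUTO
    assert agg == tf.compat.v1.VariableAggregation.NONE
    assert trainable  # Would be False for ON_READ synchronization.
    return sync, agg, trainable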
class _EagerVariableStore(tf.Module):
"""TF2-safe VariableStore that avoids collections & tracks regularizers.
New variable names and new variables can be created; variables created
without an explicit initializer are initialized with a dtype-based
default (see `_get_default_initializer`).
All variables get created in `tf.init_scope()` to avoid a bad
interaction between `tf.function` `FuncGraph` internals, Keras
Functional Models, and TPUStrategy variable initialization.
Also, it always acts as if reuse is set to either `True` or
`tf.compat.v1.AUTO_REUSE`.
Attributes:
vars: a dictionary with string names (same as passed in GetVar) as keys
and the corresponding TensorFlow Variables as values.
regularizers: a dictionary with string names as keys and the corresponding
callables that return losses as values.
layers: a dictionary with string names as keys and the corresponding
nested keras layers as values.
"""
def __init__(self):
"""Create a variable store."""
self._vars = {} # A dictionary of the stored TensorFlow variables.
self._regularizers = (
{}
) # A dict mapping var names to their regularizers.
self._layers = {} # A dictionary of stored keras layers.
self._store_eager_variables = True
@contextlib.contextmanager
def scope(self):
with vs.with_variable_store(self):
yield
def get_variable(
self,
name,
shape=None,
dtype=tf.float32,
initializer=None,
regularizer=None,
reuse=None,
trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.compat.v1.VariableAggregation.NONE,
):
"""Gets an existing variable with these parameters or create a new one.
If a variable with the given name is already stored, we return the
stored variable. Otherwise, we create a new one.
Set `reuse` to `True` when you only want to reuse existing Variables.
Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you
want variables to be created if they don't exist or returned if they do.
In this shim, `reuse` of `False` will be treated as auto-reuse.
If initializer is `None` (the default), a default initializer is used
(`glorot_uniform_initializer` for floating-point variables; zeros for
integer, boolean, and string variables). If initializer is a Tensor, we
use it as a value and derive the shape from the initializer.
Partitioned variables are not supported by this shim; passing a
`partitioner` raises a `ValueError` (see the `partitioner` arg below).
Args:
name: The name of the new or existing variable.
shape: Shape of the new or existing variable.
dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
initializer: Initializer for the variable.
regularizer: A (Tensor -> Tensor or None) function; the result of
applying it on a newly created variable will be added to the
collection GraphKeys.REGULARIZATION_LOSSES and can be used for
regularization.
reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
of variables. When eager execution is enabled this argument is
always forced to be False.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
defaults to `True`, unless `synchronization` is set to `ON_READ`, in
which case it defaults to `False`.
collections: List of graph collections keys to add the `Variable` to.
Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
caching_device: Optional device string or function describing where
the Variable should be cached for reading. Defaults to the
Variable's device. If not `None`, caches on another device.
Typical use is to cache on the device where the Ops using the
`Variable` reside, to deduplicate copying through `Switch` and other
conditional statements.
partitioner: Accepted for API compatibility only. Partitioned
variables are unsupported in this shim, and passing a non-`None`
partitioner raises a `ValueError`.
validate_shape: If False, allows the variable to be initialized with a
value of unknown shape. If True, the default, the shape of
initial_value must be known.
use_resource: If False, creates a regular Variable. If True, creates
instead an experimental ResourceVariable which has well-defined
semantics. Defaults to False (will later change to True). When eager
execution is enabled this argument is always forced to be true.
custom_getter: Callable that takes as a first argument the true
getter, and allows overwriting the internal get_variable method. The
signature of `custom_getter` should match that of this method, but
the most future-proof version will allow for changes:
`def custom_getter(getter, *args, **kwargs)`.
Direct access to all `get_variable` parameters is also allowed:
`def custom_getter(getter, name, *args, **kwargs)`.
A simple identity custom getter that simply creates variables with
modified names is:
```python
def custom_getter(getter, name, *args, **kwargs):
return getter(name + '_suffix', *args, **kwargs)
```
constraint: An optional projection function to be applied to the
variable after being updated by an `Optimizer` (e.g. used to
implement norm constraints or value constraints for layer weights).
The function must take as input the unprojected Tensor representing
the value of the variable and return the Tensor for the projected
value (which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
synchronization: Indicates when a distributed variable will be
aggregated. Accepted values are constants defined in the class
`tf.VariableSynchronization`. By default the synchronization is set
to `AUTO` and the current `DistributionStrategy` chooses when to
synchronize.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
`tf.VariableAggregation`.
Returns:
The created or existing `Variable`.
Raises:
ValueError: when creating a new variable and shape is not declared,
when reusing a variable and specifying a conflicting shape,
or when violating reuse during variable creation.
RuntimeError: when eager execution is enabled and not called from an
EagerVariableStore.
"""
if custom_getter is not None and not callable(custom_getter):
raise ValueError(
f"Passed a custom_getter which is not callable: {custom_getter}"
)
with tf.init_scope():
if tf.executing_eagerly():
# Variable creation and initialization takes place in
# `init_scope`s; as such, if an `init_scope` lifts us into the
# eager context, then we need to use `ResourceVariable`s.
use_resource = True
# Note that it's fine to reuse eager variables whose initialization was
# lifted from a function-building graph into the eager context (that's
# why the following clause is not wrapped in an `init_scope`); lifted
# variables are tracked by the graph's `VariableStore`.
if not reuse:
reuse = tf.compat.v1.AUTO_REUSE
# If a *_ref type is passed in an error would be triggered further down
# the stack. We prevent this using base_dtype to get a non-ref version
# of the type, before doing anything else. When _ref types are removed
# in favor of resources, this line can be removed.
try:
dtype = dtype.base_dtype
except AttributeError:
# .base_dtype not existing means that we will try and use the raw
# dtype which was passed in - this might be a NumPy type which is
# valid.
pass
# This is the main logic of get_variable. However, custom_getter
# may override this logic. So we save it as a callable and pass
# it to custom_getter.
# Note: the parameters of _true_getter, and their documentation, match
# *exactly* item-for-item with the docstring of this method.
def _true_getter(
name,
shape=None,
dtype=tf.float32,
initializer=None,
regularizer=None,
reuse=None,
trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.compat.v1.VariableAggregation.NONE,
):
# Partitioned variable currently unsupported w/ the shim
if partitioner is not None:
raise ValueError(
"`partitioner` arg for `get_variable` is unsupported in "
"TF2. File a bug if you need help. "
"You passed %s" % partitioner
)
# Single variable case
if f"{name}/part_0" in self._vars:
raise ValueError(
"No partitioner was provided, but a partitioned version of "
"the variable was found: %s/part_0. Perhaps a variable of "
"the same name was already created with "
"partitioning?" % name
)
return self._get_single_variable(
name=name,
shape=shape,
dtype=dtype,
initializer=initializer,
regularizer=regularizer,
reuse=reuse,
trainable=trainable,
caching_device=caching_device,
validate_shape=validate_shape,
constraint=constraint,
synchronization=synchronization,
aggregation=aggregation,
)
(
synchronization,
aggregation,
trainable,
) = validate_synchronization_aggregation_trainable(
synchronization, aggregation, trainable, name
)
if custom_getter is not None:
# Handle backwards compatibility with getter arguments that were
# added to the API after users started writing custom getters.
custom_getter_kwargs = {
"getter": _true_getter,
"name": name,
"shape": shape,
"dtype": dtype,
"initializer": initializer,
"regularizer": regularizer,
"reuse": reuse,
"trainable": trainable,
"collections": collections,
"caching_device": caching_device,
"partitioner": partitioner,
"validate_shape": validate_shape,
"use_resource": use_resource,
"synchronization": synchronization,
"aggregation": aggregation,
}
# `fn_args` and `has_kwargs` can handle functions,
# `functools.partial`, `lambda`.
if "constraint" in fn_args(custom_getter) or _has_kwargs(
custom_getter
):
custom_getter_kwargs["constraint"] = constraint
return custom_getter(**custom_getter_kwargs)
else:
return _true_getter(
name,
shape=shape,
dtype=dtype,
initializer=initializer,
regularizer=regularizer,
reuse=reuse,
trainable=trainable,
collections=collections,
caching_device=caching_device,
partitioner=partitioner,
validate_shape=validate_shape,
use_resource=use_resource,
constraint=constraint,
synchronization=synchronization,
aggregation=aggregation,
)
def _get_single_variable(
self,
name,
shape=None,
dtype=tf.float32,
initializer=None,
regularizer=None,
partition_info=None,
reuse=None,
trainable=None,
caching_device=None,
validate_shape=True,
constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.compat.v1.VariableAggregation.NONE,
):
"""Get or create a single Variable (e.g. a shard or entire variable).
See the documentation of get_variable above (ignore partitioning
components) for details.
Args:
name: see get_variable.
shape: see get_variable.
dtype: see get_variable.
initializer: see get_variable.
regularizer: see get_variable.
partition_info: _PartitionInfo object.
reuse: see get_variable.
trainable: see get_variable.
caching_device: see get_variable.
validate_shape: see get_variable.
constraint: see get_variable.
synchronization: see get_variable.
aggregation: see get_variable.
Returns:
A Variable. See documentation of get_variable above.
Raises:
ValueError: See documentation of get_variable above.
"""
# Set to true if initializer is a constant.
initializing_from_value = False
if initializer is not None and not callable(initializer):
initializing_from_value = True
if shape is not None and initializing_from_value:
raise ValueError(
"If initializer is a constant, do not specify shape."
)
dtype = tf.as_dtype(dtype)
shape = as_shape(shape)
if name in self._vars:
# Here we handle the case when returning an existing variable.
found_var = self._vars[name]
if not shape.is_compatible_with(found_var.get_shape()):
raise ValueError(
"Trying to share variable %s, but specified shape %s"
" and found shape %s."
% (name, shape, found_var.get_shape())
)
if not dtype.is_compatible_with(found_var.dtype):
dtype_str = dtype.name
found_type_str = found_var.dtype.name
raise ValueError(
"Trying to share variable %s, but specified dtype %s"
" and found dtype %s." % (name, dtype_str, found_type_str)
)
return found_var
# The code below handles only the case of creating a new variable.
if reuse is True:
raise ValueError(
"Variable %s does not exist, or was not created with "
"tf.get_variable(). Did you mean to set "
"reuse=tf.AUTO_REUSE in VarScope?" % name
)
# Create the tensor to initialize the variable with default value.
if initializer is None:
(
initializer,
initializing_from_value,
) = self._get_default_initializer(
name=name, shape=shape, dtype=dtype
)
# Enter an init scope when creating the initializer.
with tf.init_scope():
if initializing_from_value:
init_val = initializer
variable_dtype = None
else:
# Instantiate initializer if provided initializer is a type
# object.
if tf_inspect.isclass(initializer):
initializer = initializer()
if shape.is_fully_defined():
if (
"partition_info"
in tf_inspect.getargspec(initializer).args
):
init_val = functools.partial(
initializer,
shape.as_list(),
dtype=dtype,
partition_info=partition_info,
)
else:
init_val = functools.partial(
initializer, shape.as_list(), dtype=dtype
)
variable_dtype = dtype.base_dtype
else:
init_val = initializer
variable_dtype = None
# Create the variable (Always eagerly as a workaround for a strange
# tpu / funcgraph / keras functional model interaction )
with tf.init_scope():
v = tf.Variable(
initial_value=init_val,
name=name,
trainable=trainable,
caching_device=caching_device,
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
synchronization=synchronization,
aggregation=aggregation,
)
self._vars[name] = v
logging.vlog(
1,
"Created variable %s with shape %s and init %s",
v.name,
format(shape),
initializer,
)
# Run the regularizer if requested and save the resulting loss.
if regularizer:
self.add_regularizer(v, regularizer)
return v
def get_or_create_layer(self, name, create_layer_method):
if name not in self._layers:
layer = create_layer_method()
self._layers[name] = layer
if isinstance(layer, base_layer.Layer):
self._regularizers[name] = lambda: tf.math.reduce_sum(
layer.losses
)
return self._layers[name]
def add_regularizer(self, var, regularizer):
self._regularizers[var.name] = functools.partial(regularizer, var)
# Initialize variable when no initializer provided
def _get_default_initializer(self, name, shape=None, dtype=tf.float32):
"""Provide a default initializer and a corresponding value.
Args:
name: see get_variable.
shape: see get_variable.
dtype: see get_variable.
Returns:
initializer and initializing_from_value. See get_variable above.
Raises:
ValueError: When giving unsupported dtype.
"""
del shape
# If dtype is DT_FLOAT, provide a glorot uniform initializer
if dtype.is_floating:
initializer = tf.compat.v1.glorot_uniform_initializer()
initializing_from_value = False
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
# If dtype is DT_BOOL, provide a default value `FALSE`
elif (
dtype.is_integer
or dtype.is_unsigned
or dtype.is_bool
or dtype == tf.string
):
initializer = tf.compat.v1.zeros_initializer()
initializing_from_value = False
# NOTE: Do we need to support handling DT_STRING and DT_COMPLEX
# here?
else:
raise ValueError(
"An initializer for variable %s of %s is required"
% (name, dtype.base_dtype)
)
return initializer, initializing_from_value
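# Editor's note: usage sketch (not called by library code), assuming TF2 eager
# execution. It shows the store's TF1-style name-based auto-reuse and its
# regularizer tracking when `tf.compat.v1.get_variable` is routed through
# `scope()`, which is what `track_tf1_style_variables` below relies on.
def _example_eager_variable_store_usage():
    store = _EagerVariableStore()
    with store.scope():
        v1 = tf.compat.v1.get_variable(
            "kernel",
            shape=[3, 2],
            regularizer=lambda t: 0.01 * tf.reduce_sum(tf.square(t)),
        )
        # Requesting the same name returns the stored variable (auto-reuse).
        v2 = tf.compat.v1.get_variable("kernel", shape=[3, 2])
    assert v1 is v2
    # Regularization losses are tracked as no-arg callables keyed by name.
    reg_losses = [fn() for fn in store._regularizers.values()]
    return v1, reg_losses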
@keras_export(v1=["keras.utils.track_tf1_style_variables"])
def track_tf1_style_variables(method):
"""Wrap layer & module methods in this decorator to capture tf1-style
weights.
Decorating a `tf.keras.Layer`'s or `tf.Module`'s methods with this
decorator will cause the layer/module to track weights created/used
via `tf.compat.v1.get_variable` (and by extension `tf.compat.v1.layers`)
inside the decorated method.
In addition to tracking the weights themselves under the standard
`layer.variable`/`module.variable`/etc. properties, if the method belongs
to a `tf.keras.Layer` then any regularization losses specified via the
`get_variable` or `tf.compat.v1.layers` regularizer arguments will get
tracked by the layer under the standard `layer.losses` property.
This tracking enables using large classes of TF1-style model-forward-pass
code inside of Keras layers or `tf.Modules` in TF2 with TF2 behaviors
enabled.
Example of capturing tf.compat.v1.layer-based modeling code as a Keras
layer:
```python
class WrappedDoubleDenseLayer(tf.keras.layers.Layer):
def __init__(self, units, *args, **kwargs):
super().__init__(*args, **kwargs)
self.units = units
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs):
with tf.compat.v1.variable_scope("double_dense_layer"):
out = tf.compat.v1.layers.dense(
inputs, self.units, name="dense_one",
kernel_initializer=tf.compat.v1.random_normal_initializer,
kernel_regularizer="l2")
out = tf.compat.v1.layers.dense(
out, self.units, name="dense_two",
kernel_initializer=tf.compat.v1.random_normal_initializer(),
kernel_regularizer="l2")
return out
# Create a layer that can be used as a standard keras layer
layer = WrappedDoubleDenseLayer(10)
# call the layer on inputs
layer(...)
# Variables created/used within the scope will be tracked by the layer
layer.weights
layer.trainable_variables
# Regularization losses will be captured in layer.losses after a call,
# just like any other Keras layer
reg_losses = layer.losses
```
Example of capturing tf.compat.v1.get_variable-based modeling code as
a Keras layer:
```python
class WrappedDoubleDenseLayer(tf.keras.layers.Layer):
def __init__(self, units, *args, **kwargs):
super().__init__(*args, **kwargs)
self.units = units
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs):
out = inputs
with tf.compat.v1.variable_scope("double_dense_layer"):
with tf.compat.v1.variable_scope("dense_one"):
# The weights are created with a `regularizer`,
# so the layer should track their regularization losses
kernel = tf.compat.v1.get_variable(
shape=[out.shape[-1], self.units],
regularizer=regularizers.L2(),
initializer=init_ops.ones_initializer(),
name="kernel")
bias = tf.compat.v1.get_variable(
shape=[self.units,],
initializer=init_ops.zeros_initializer(),
name="bias")
out = tf.compat.v1.math.matmul(out, kernel)
out = tf.compat.v1.nn.bias_add(out, bias)
with tf.compat.v1.variable_scope("dense_two"):
kernel = tf.compat.v1.get_variable(
shape=[out.shape[-1], self.units],
regularizer=regularizers.L2(),
initializer=init_ops.ones_initializer(),
name="kernel")
bias = tf.compat.v1.get_variable(
shape=[self.units,],
initializer=init_ops.zeros_initializer(),
name="bias")
out = tf.compat.v1.math.matmul(out, kernel)
out = tf.compat.v1.nn.bias_add(out, bias)
return out
# Create a layer that can be used as a standard keras layer
layer = WrappedDoubleDenseLayer(10)
# call the layer on inputs
layer(...)
# Variables created/used within the scope will be tracked by the layer
layer.weights
layer.trainable_variables
# Regularization losses will be captured in layer.losses after a call,
# just like any other Keras layer
reg_losses = layer.losses
```
Regularization losses:
Any regularizers specified in the `get_variable` calls or
`compat.v1.layer` creations will get captured if they occur in your
decorated method and the method belongs to a
`tf.keras.Layer`/`tf.Module`. Regularization losses
are accessible in `layer.losses` after a call just like in a standard
Keras layer, and will be captured by any model that includes this layer.
Regularization losses attached to Keras layers/models set as attributes
of your layer will also get captured in the standard Keras regularization
loss tracking.
(While Modules have no `losses` property, no-arg callables to compute
the regularization losses may be tracked as dict values in a private
`module._tf1_style_var_store._regularizers` property, but only for
`tf.compat.v1.layers` and `get_variable` weights and not for any other
nested Keras layers/tf.Modules)
Variable scope / variable reuse:
variable-scope based reuse in your decorated method will be respected,
and work like variable-scope based reuse in TF1.
Variable Names/Pre-trained checkpoint loading:
Variable naming from get_variable and `compat.v1.layer` layers will match
the TF1 names, so you should be able to re-use your old name-based
checkpoints. Variable naming for Keras layers/models or for variables
created by `tf.Variable` may change when going to eager execution.
Training Arg if you decorate `layer.call`:
Keras will pass a `training` arg to this layer if `call` contains
a `training` arg or a `**kwargs` varargs in its call signature,
similarly to how keras passes `training` to other layers in TF2 that have
similar signatures in their `call` implementations.
See more details in the docs
on `tf.keras.layers.Layer` to understand what will be passed and when.
Note: tf.compat.v1.layers are usually not called with `training=None`,
so the training arg to the decorated method might not feed through to
them unless you pass it to their calls explicitly.
Caveats:
* TF2 will not prune unused variable updates (or unused outputs). You may
need to adjust your forward pass code to avoid computations or variable
updates that you don't intend to use.
* Avoid nesting variable creation in tf.function inside of
methods decorated with `track_tf1_style_variables`.
While the method may safely be used from inside a `tf.function`, using
a `tf.function` inside of a decorated method may break the variable
scoping.
* This decorator only adds implicit tracking for legacy tf1-style
get_variable / compat.v1.layers usage.
If you would like to use nested Keras layers/models
inside the decorated method, you need to
assign them as attributes of your layer so that Keras/Module's standard
object-oriented weights (and loss tracking for layers) will kick in.
See the intro to modules, layers, and models
[guide](https://www.tensorflow.org/guide/intro_to_modules) for more
info. As a backup, the `compat.v1.keras.utils.get_or_create_layer`
method will ease tracking nested keras model weights and losses for
existing TF1 code, but new code should use explicit tracking.
Args:
method: The method to decorate. This should belong to a custom tf.Module,
tf.keras.layers.Layer, or tf.keras.Model.
Returns:
The decorated method.
"""
def _method_wrapper(self, *args, **kwargs):
var_store = getattr(self, "_tf1_style_var_store", None)
if not var_store:
if not isinstance(self, tf.Module):
# Raise an error if you incorrectly decorate a method
# that is not a method of a Module, Layer, or Model:
raise ValueError(
"`@tf.compat.v1.keras.utils.track_tf1_layers_and_variables`"
" must be applied to a method of a subclassed `tf.Module`, "
"`tf.keras.layers.Layer`, or `tf.keras.Model` and which "
"takes `self` as the first argument. But, the first "
"argument passed to the decorated method was {}, which "
"does not extend Module, Layer, or Model.".format(self)
)
var_store = _EagerVariableStore()
self._tf1_style_var_store = var_store
existing_regularized_variables = set(var_store._regularizers.keys())
with var_store.scope():
out = method(self, *args, **kwargs)
# If this is a layer method, add the regularization losses
# to the layer for any newly-created regularized variables
if isinstance(self, base_layer.Layer):
for (
var_name,
regularizer,
) in var_store._regularizers.items():
if var_name not in existing_regularized_variables:
self.add_loss(regularizer)
return out
return tf.__internal__.decorator.make_decorator(
target=method, decorator_func=_method_wrapper
)
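# Editor's note: hedged sketch (not used by the library) of the tf.Module case
# described in the docstring above: weights created via
# `tf.compat.v1.get_variable` inside the decorated method are tracked on the
# module, and their regularizers live in the private
# `_tf1_style_var_store._regularizers` mapping, since Modules have no `losses`.
def _example_track_tf1_style_variables_on_module():
    class TF1StyleDense(tf.Module):
        def __init__(self, units):
            super().__init__()
            self.units = units

        @track_tf1_style_variables
        def __call__(self, inputs):
            kernel = tf.compat.v1.get_variable(
                "kernel",
                shape=[inputs.shape[-1], self.units],
                regularizer=lambda t: 0.01 * tf.reduce_sum(tf.square(t)),
            )
            return tf.matmul(inputs, kernel)

    module = TF1StyleDense(4)
    out = module(tf.ones([2, 3]))
    reg_losses = [
        fn() for fn in module._tf1_style_var_store._regularizers.values()
    ]
    return out, module.trainable_variables, reg_losses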
class VariableScopeLayer(base_layer.Layer):
"""Wrapper Layer to capture `compat.v1.get_variable` and `compat.v1.layers`.
This shim layer allows using large sets of TF1 model-forward-pass code as a
Keras layer that works in TF2 with TF2 behaviors enabled. It will capture
both weights and regularization losses of your forward-pass code. To use it,
override this class and put your TF1 model's forward pass inside your
implementation for `forward_pass`. (Unlike standard custom Keras layers,
do not override `call`.)
Below are some examples, and then more details on the functionality of this
shim layer to wrap TF1 model forward passes.
Example of capturing tf.compat.v1.layer-based modeling code as a Keras
layer:
```python
class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeLayer):
def __init__(self, units, *args, **kwargs):
super().__init__(*args, **kwargs)
self.units = units
def forward_pass(self, inputs):
with variable_scope.variable_scope("double_dense_layer"):
out = tf.compat.v1.layers.dense(
inputs, self.units, name="dense_one",
kernel_initializer=tf.compat.v1.random_normal_initializer,
kernel_regularizer="l2")
out = tf.compat.v1.layers.dense(
out, self.units, name="dense_two",
kernel_initializer=tf.compat.v1.random_normal_initializer(),
kernel_regularizer="l2")
return out
# Create a layer that can be used as a standard keras layer
layer = WrappedDoubleDenseLayer(10)
# call the layer on inputs
layer(...)
# Variables created/used within the scope will be tracked by the layer
layer.weights
layer.trainable_variables
# Regularization losses will be captured in layer.losses after a call,
# just like any other Keras layer
reg_losses = layer.losses
```
Example of capturing tf.compat.v1.get_variable-based modeling code as
a Keras layer:
```python
class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeLayer):
def __init__(self, units, *args, **kwargs):
super().__init__(*args, **kwargs)
self.units = units
def forward_pass(self, inputs):
out = inputs
with tf.compat.v1.variable_scope("double_dense_layer"):
with tf.compat.v1.variable_scope("dense_one"):
# The weights are created with a `regularizer`,
# so the layer should track their regularization losses
kernel = tf.compat.v1.get_variable(
shape=[out.shape[-1], self.units],
regularizer=regularizers.L2(),
initializer=init_ops.ones_initializer(),
name="kernel")
bias = tf.compat.v1.get_variable(
shape=[self.units,],
initializer=init_ops.zeros_initializer(),
name="bias")
out = tf.compat.v1.math.matmul(out, kernel)
out = tf.compat.v1.nn.bias_add(out, bias)
with tf.compat.v1.variable_scope("dense_two"):
kernel = tf.compat.v1.get_variable(
shape=[out.shape[-1], self.units],
regularizer=regularizers.L2(),
initializer=init_ops.ones_initializer(),
name="kernel")
bias = tf.compat.v1.get_variable(
shape=[self.units,],
initializer=init_ops.zeros_initializer(),
name="bias")
out = tf.compat.v1.math.matmul(out, kernel)
out = tf.compat.v1.nn.bias_add(out, bias)
return out
# Create a layer that can be used as a standard keras layer
layer = WrappedDoubleDenseLayer(10)
# call the layer on inputs
layer(...)
# Variables created/used within the scope will be tracked by the layer
layer.weights
layer.trainable_variables
# Regularization losses will be captured in layer.losses after a call,
# just like any other Keras layer
reg_losses = layer.losses
```
Regularization losses:
Any regularizers specified in the `get_variable` calls or
`compat.v1.layer` creations will get captured by this wrapper layer.
Regularization losses are accessible in `layer.losses` after a call just
like in a standard Keras layer, and will be captured by any model that
includes this layer. Regularization losses attached to Keras
layers/models set as attributes of your layer will also get captured in
the standard Keras regularization loss tracking.
Variable scope / variable reuse:
variable-scope based reuse in the `forward_pass` will be respected,
and work like variable-scope based reuse in TF1.
Variable Names/Pre-trained checkpoint loading:
Variable naming from get_variable and `compat.v1.layer` layers will match
the TF1 names, so you should be able to re-use your old name-based
checkpoints. Variable naming for Keras layers/models or for variables
created by `tf.Variable` may change when going to eager execution.
Training Arg in `forward_pass`:
Keras will pass a `training` arg to this layer if `forward_pass` contains
a `training` arg or a `**kwargs` varargs in its call signature,
similarly to how keras passes `training` to other layers in TF2 that have
similar signatures in their `call` implementations.
See more details in the docs
on `tf.keras.layers.Layer` to understand what will be passed and when.
Note: tf.compat.v1.layers are usually not called with `training=None`,
so the training arg to `forward_pass` might not feed through to them
unless you pass it to their calls explicitly.
Call signature of the forward pass:
The semantics of the forward pass signature match the standard
Keras layer `call` signature, including how Keras decides when
to pass in a `training` arg, and the semantics applied to
the first positional arg in the call signature.
Caveats:
* TF2 will not prune unused variable updates (or unused outputs). You may
need to adjust your forward pass code to avoid computations or variable
updates that you don't intend to use. (E.g. by adding a flag to the
`forward_pass` call signature and branching on it).
* Avoid nesting variable creation in tf.function inside of
`forward_pass`. While the layer may safely be used from inside a
`tf.function`, using a `tf.function` inside of `forward_pass` will break
the variable scoping.
* If you would like to nest Keras layers/models or other
`VariableScopeLayer`s directly in `forward_pass`, you need to
assign them as attributes of your layer so that Keras's standard
object-oriented weights and loss tracking will kick in.
See the intro to modules, layers, and models
[guide](https://www.tensorflow.org/guide/intro_to_modules) for more info
"""
@property
@layer_utils.cached_per_instance
def _call_full_argspec(self):
# Argspec inspection is expensive and the call spec is used often, so it
# makes sense to cache the result.
return tf_inspect.getfullargspec(self.forward_pass)
def forward_pass(self, *args, **kwargs):
"""Implement this method. It should include your model forward pass."""
raise NotImplementedError
@track_tf1_style_variables
def call(self, *args, **kwargs):
return self.forward_pass(*args, **kwargs)
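# Editor's note: a short hypothetical sketch (not used by the library) of the
# `training` arg behavior described above: because the call spec is taken from
# `forward_pass`, declaring `training` there lets Keras feed it in, but it
# still has to be forwarded to `tf.compat.v1.layers` calls explicitly.
def _example_variable_scope_layer_training_arg():
    class DropoutDense(VariableScopeLayer):
        def __init__(self, units, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.units = units

        def forward_pass(self, inputs, training=None):
            with tf.compat.v1.variable_scope("dropout_dense"):
                out = tf.compat.v1.layers.dense(
                    inputs, self.units, name="dense"
                )
                # Forward `training` explicitly to the compat.v1 layer.
                return tf.compat.v1.layers.dropout(
                    out, rate=0.5, training=training
                )

    layer = DropoutDense(4)
    return layer(tf.ones([2, 3]), training=True)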
@keras_export(v1=["keras.utils.get_or_create_layer"])
def get_or_create_layer(name, create_layer_method):
"""Use this method to track nested keras models in a shim-decorated method.
This method can be used within a `tf.keras.Layer`'s methods decorated by
the `track_tf1_style_variables` shim, to additionally track inner keras Model
objects created within the same method. The inner model's variables and
losses will be accessible via the outer model's `variables` and `losses`
attributes.
This enables tracking of inner keras models using TF2 behaviors, with
minimal changes to existing TF1-style code.
Example:
```python
class NestedLayer(tf.keras.layers.Layer):
def __init__(self, units, *args, **kwargs):
super().__init__(*args, **kwargs)
self.units = units
def build_model(self):
inp = tf.keras.Input(shape=(5, 5))
dense_layer = tf.keras.layers.Dense(
10, name="dense", kernel_regularizer="l2",
kernel_initializer=tf.compat.v1.ones_initializer())
model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
return model
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs):
model = tf.compat.v1.keras.utils.get_or_create_layer(
"dense_model", self.build_model)
return model(inputs)
```
The inner model creation should be confined to its own zero-arg function,
which should be passed into this method. In TF1, this method will
immediately create and return the desired model, without any tracking.
Args:
name: A name to give the nested layer to track.
create_layer_method: a Callable that takes no args and returns the nested
layer.
Returns:
The created layer.
"""
store = vs._get_default_variable_store()
if not isinstance(store, _EagerVariableStore):
if not tf.compat.v1.executing_eagerly_outside_functions():
# tf1 case; just create and return layer
return create_layer_method()
else:
raise ValueError(
"Tried to call get_or_create_layer in eager mode from a method "
"notdecorated with "
"@tf.compat.v1.keras.utils.track_tf1_style_variables."
)
vs_name = tf.compat.v1.get_variable_scope().name
name = f"{vs_name}/{name}"
return store.get_or_create_layer(name, create_layer_method)
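# Editor's note: usage sketch (hypothetical, not part of the original module)
# of `get_or_create_layer` with `track_tf1_style_variables`: the inner model is
# built once, cached under the given name, and its weights and regularization
# losses surface on the outer layer.
def _example_get_or_create_layer_usage():
    class NestedDenseLayer(tf.keras.layers.Layer):
        @track_tf1_style_variables
        def call(self, inputs):
            model = get_or_create_layer(
                "inner_dense",
                lambda: tf.keras.Sequential(
                    [tf.keras.layers.Dense(4, kernel_regularizer="l2")]
                ),
            )
            return model(inputs)

    layer = NestedDenseLayer()
    out = layer(tf.ones([2, 3]))
    # The inner model's weights and regularization losses are tracked by the
    # outer layer, as described in the docstring above.
    return out, layer.trainable_variables, layer.losses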