# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=g-classes-have-attributes
"""Contains a shim to allow using TF1 get_variable code in TF2."""

import functools

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.module import module
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator


def as_shape(shape):
  """Converts the given object to a TensorShape."""
  if isinstance(shape, tensor_shape.TensorShape):
    return shape
  else:
    return tensor_shape.TensorShape(shape)


def _is_callable_object(obj):
  return hasattr(obj, "__call__") and tf_inspect.ismethod(obj.__call__)


def _has_kwargs(fn):
  """Returns whether the passed callable has **kwargs in its signature.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: if `fn` has **kwargs in its signature.

  Raises:
    `TypeError`: If fn is not a Function, or function-like object.
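
  For example (an illustrative sketch; `dense` here is a hypothetical function):

  ```python
  def dense(units, **kwargs):
    return units

  _has_kwargs(dense)                         # True
  _has_kwargs(functools.partial(dense, 10))  # True: the partial is unwrapped
  _has_kwargs(lambda x: x)                   # False: no **kwargs in signature
  ```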
  """
  if isinstance(fn, functools.partial):
    fn = fn.func
  elif _is_callable_object(fn):
    fn = fn.__call__
  elif not callable(fn):
    raise TypeError(
        "fn should be a function-like object, but is of type {}.".format(
            type(fn)))
  return tf_inspect.getfullargspec(fn).varkw is not None


def fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
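
  For example (an illustrative sketch; `compute` and `Model` are hypothetical):

  ```python
  def compute(a, b, c=1):
    return a + b + c

  class Model(object):

    def apply(self, x, training=None):
      return x

  fn_args(compute)                         # ('a', 'b', 'c')
  fn_args(functools.partial(compute, 1))   # ('b', 'c'): bound args are dropped
  fn_args(Model().apply)                   # ('x', 'training'): `self` dropped
  ```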
  """
  if isinstance(fn, functools.partial):
    args = fn_args(fn.func)
    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
  else:
    if hasattr(fn, "__call__") and tf_inspect.ismethod(fn.__call__):
      fn = fn.__call__
    args = tf_inspect.getfullargspec(fn).args
    if _is_bound_method(fn) and args:
      # If it's a bound method, it may or may not have a self/cls first
      # argument; for example, self could be captured in *args.
      # If it does have a positional argument, it is self/cls.
      args.pop(0)
  return tuple(args)


def _is_bound_method(fn):
  _, fn = tf_decorator.unwrap(fn)
  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)


def validate_synchronization_aggregation_trainable(
    synchronization, aggregation, trainable, name):
  """Given user-provided variable properties, sets defaults and validates."""
  if aggregation is None:
    aggregation = variables.VariableAggregation.NONE
  else:
    if not isinstance(aggregation,
                      (variables.VariableAggregation,
                       variables.VariableAggregationV2)):
      try:
        aggregation = variables.VariableAggregationV2(aggregation)
      except ValueError:
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, name))
  if synchronization is None:
    synchronization = variables.VariableSynchronization.AUTO
  else:
    try:
      synchronization = variables.VariableSynchronization(synchronization)
    except ValueError:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}".format(
              synchronization, name))
  if trainable is None:
    trainable = synchronization != variables.VariableSynchronization.ON_READ
  return synchronization, aggregation, trainable
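
# For example (illustrative): with `synchronization` set to `ON_READ` and
# `trainable` left as None, the helper above defaults `trainable` to False,
# since ON_READ variables (e.g. sync-on-read metric counters under a
# distribution strategy) are not meant to be trained:
#
#   sync, agg, trainable = validate_synchronization_aggregation_trainable(
#       variables.VariableSynchronization.ON_READ, None, None, "metric_count")
#   # -> trainable is False, agg is VariableAggregation.NONE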


class _EagerVariableStore(object):
  """TF2-compatible VariableStore that avoids collections & tracks regularizers.

  New variable names and new variables can be created; all stored variables
  are initialized with the initializer passed to `get_variable` (or a default
  initializer if none is given).

  All variables get created in `tf.init_scope` to avoid a bad
  interaction between `tf.function` `FuncGraph` internals, Keras
  Functional Models, and TPUStrategy variable initialization.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys and
      the corresponding TensorFlow Variables as values.
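
  For example (an illustrative sketch, not a transcript of real usage):

  ```python
  store = _EagerVariableStore()
  # The first call creates the variable; asking for the same name again
  # returns the stored object instead of creating a new one.
  v1 = store.get_variable("dense/kernel", shape=[3, 4])
  v2 = store.get_variable("dense/kernel", shape=[3, 4])
  assert v1 is v2
  ```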
  """

  __slots__ = ["_vars", "_regularizers", "_store_eager_variables"]

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._regularizers = {}  # A dict mapping var names to their regularizers.
    self._store_eager_variables = True

  def get_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      reuse=None,
      trainable=None,
      collections=None,
      caching_device=None,
      partitioner=None,
      validate_shape=True,
      use_resource=None,
      custom_getter=None,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), a default initializer is used: a
    new `glorot_uniform_initializer` for floating-point variables (see
    `_get_default_initializer` for other dtypes). If initializer is a Tensor,
    we use it as a value and derive the shape from the initializer.

    Partitioned variables are not supported by this shim; passing a
    non-`None` `partitioner` raises a `ValueError`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be tf.AUTO_REUSE.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
        Unsupported by this shim; must be `None`.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of
        initial_value must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be true.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes: `def
        custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed: `def
        custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:

        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable`.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
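
    For example (an illustrative sketch of the reuse semantics):

    ```python
    store = _EagerVariableStore()
    v = store.get_variable("w", shape=[2])
    store.get_variable("w", shape=[2]) is v   # True: the stored "w" is reused.
    store.get_variable("w", shape=[3])        # Raises ValueError (shape clash).
    ```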
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      reuse = vs.AUTO_REUSE

    # If a *_ref type is passed in an error would be triggered further down the
    # stack. We prevent this using base_dtype to get a non-ref version of the
    # type, before doing anything else. When _ref types are removed in favor of
    # resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,  # pylint: disable=unused-argument
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,  # pylint: disable=unused-argument
        constraint=None,
        synchronization=vs.VariableSynchronization.AUTO,
        aggregation=vs.VariableAggregation.NONE):
      # Partitioned variables are currently unsupported with the shim.
      if partitioner is not None:
        raise ValueError(
            "`partitioner` arg for `get_variable` is unsupported in TF2. "
            "File a bug if you need help. You passed %s" % partitioner)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          caching_device=caching_device,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `_has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
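      # For example (illustrative): a getter defined with **kwargs (or with an
      # explicit `constraint` arg) receives `constraint` below; a legacy getter
      # with a fixed, older argument list does not, which avoids calling it
      # with an unexpected keyword argument.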
      if ("constraint" in fn_args(custom_getter) or
          _has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_single_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      partition_info=None,
      reuse=None,
      trainable=None,
      caching_device=None,
      validate_shape=True,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning
    components) for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:  # pylint: disable=g-bool-id-comparison
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback to report.
        raise ValueError(err_msg)
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = functools.partial(initializer,
                                         shape.as_list(),
                                         dtype=dtype,
                                         partition_info=partition_info)
          else:
            init_val = functools.partial(initializer,
                                         shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        else:
          init_val = initializer
          variable_dtype = None

    # Create the variable (always eagerly, as a workaround for a strange
    # tpu / funcgraph / keras functional model interaction).
    with ops.init_scope():
      v = variables.Variable(
          initial_value=init_val,
          name=name,
          trainable=trainable,
          caching_device=caching_device,
          dtype=variable_dtype,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      self.add_regularizer(v, regularizer)

    return v

  def add_regularizer(self, var, regularizer):
    self._regularizers[var.name] = functools.partial(regularizer, var)

  # Provide a default initializer when no initializer was given.
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
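
    For example (illustrative; `store` is an `_EagerVariableStore` instance):
    floating-point dtypes get a glorot uniform initializer, while integer,
    bool, and string dtypes get a zeros initializer:

    ```python
    store._get_default_initializer("w", dtype=dtypes.float32)
    # -> (<glorot_uniform_initializer>, False)
    store._get_default_initializer("step", dtype=dtypes.int64)
    # -> (<zeros_initializer>, False)
    ```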
    """
    del shape
    # If dtype is DT_FLOAT, provide a glorot uniform initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required" %
                       (name, dtype.base_dtype))

    return initializer, initializing_from_value


class VariableAndLossTracker(module.Module):
  """Module that has a scope to capture vars/losses made by `get_variable`."""

  def __init__(self):
    self._var_store = _EagerVariableStore()  # pylint: disable=protected-access
    self._variables = {}

  def _variable_creator(self, next_creator, **kwargs):
    var = next_creator(**kwargs)
    self._variables[var.name] = var

    return var

  @tf_contextlib.contextmanager
  def scope(self):
    with vs.variable_creator_scope(
        self._variable_creator), vs.with_variable_store(self._var_store):
      yield

  def get_regularization_losses(self):
    # TODO(kaftan): Consider adding a regex scope like the collection access.
    # But, < 40-50 usages of get_regularization_loss(es) with `scope`
    # & possible to do manually?
    losses = {}
    for var_name, regularizer in self._var_store._regularizers.items():  # pylint: disable=protected-access
      losses[var_name] = regularizer()
    return losses


class VariableScopeWrapperLayer(base_layer.Layer):
  """Wrapper Layer to capture `compat.v1.get_variable` and `compat.v1.layers`.

  See go/tf2-migration-model-bookkeeping for background.

  This shim layer allows using large sets of TF1 model-forward-pass code as a
  Keras layer that works in TF2 with TF2 behaviors enabled. To use it,
  override this class and put your TF1 model's forward pass inside your
  implementation for `forward_pass`.

  Below are some examples, and then more details on the functionality of this
  shim layer to wrap TF1 model forward passes.

  Example of capturing tf.compat.v1.layer-based modeling code as a Keras layer:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = tf.compat.v1.layers.dense(
          inputs, self.units, name="dense_one",
          kernel_initializer=init_ops.ones_initializer(),
          kernel_regularizer="l2")
      with variable_scope.variable_scope("nested_scope"):
        out = tf.compat.v1.layers.dense(
            out, self.units, name="dense_two",
            kernel_initializer=init_ops.ones_initializer(),
            kernel_regularizer="l2")
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```

  Example of capturing `tf.compat.v1.get_variable`-based modeling code as a
  Keras layer:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = inputs
      with tf.compat.v1.variable_scope("dense_one"):
        # The weights are created with a `regularizer`,
        # so the layer should track their regularization losses
        kernel = tf.compat.v1.get_variable(
            shape=[out.shape[-1], self.units],
            regularizer=regularizers.L2(),
            initializer=init_ops.ones_initializer(),
            name="kernel")
        bias = tf.compat.v1.get_variable(
            shape=[self.units,],
            initializer=init_ops.zeros_initializer(),
            name="bias")
        out = tf.compat.v1.math.matmul(out, kernel)
        out = tf.compat.v1.nn.bias_add(out, bias)
      with tf.compat.v1.variable_scope("nested_scope"):
        with tf.compat.v1.variable_scope("dense_two"):
          kernel = tf.compat.v1.get_variable(
              shape=[out.shape[-1], self.units],
              regularizer=regularizers.L2(),
              initializer=init_ops.ones_initializer(),
              name="kernel")
          bias = tf.compat.v1.get_variable(
              shape=[self.units,],
              initializer=init_ops.zeros_initializer(),
              name="bias")
          out = tf.compat.v1.math.matmul(out, kernel)
          out = tf.compat.v1.nn.bias_add(out, bias)
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```

  Regularization losses:
    Any regularizers specified in the `get_variable` calls or `compat.v1.layer`
    creations will get captured by this wrapper layer. Regularization losses
    are accessible in `layer.losses` after a call just like in a standard
    Keras layer, and will be captured by any model that includes this layer.

  Variable scope / variable reuse:
    variable-scope based reuse in the `forward_pass` will be respected,
    and work like variable-scope based reuse in TF1.

  Variable Names/Pre-trained checkpoint loading:
    variable naming from get_variable and `compat.v1.layer` layers will match
    the TF1 names, so you should be able to re-use your old name-based
    checkpoints.

  Training Arg in `forward_pass`:
    Keras will pass a `training` arg to this layer similarly to how it
    passes `training` to other layers in TF2. See more details in the docs
    on `tf.keras.layers.Layer` to understand what will be passed and when.
    Note: tf.compat.v1.layers are usually not called with `training=None`,
    so the training arg to `forward_pass` might not feed through to them
    unless you pass it to their calls explicitly.
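
  For example (an illustrative sketch), forwarding `training` into a
  `compat.v1.layers` call explicitly:

  ```python
  def forward_pass(self, inputs, training=None):
    return tf.compat.v1.layers.dropout(inputs, rate=0.5, training=training)
  ```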

  Call signature of the forward pass:
    The semantics of the forward pass signature roughly match the standard
    Keras layer `call` signature, except that a `training` arg will *always*
    be passed, so your `forward_pass` must accept a `training` argument
    (either explicitly or via `**kwargs`).

  Limitations:
    * TF2 will not prune unused variable updates (or unused outputs). You may
      need to adjust your forward pass code to avoid computations or variable
      updates that you don't intend to use. (E.g. by adding a flag to the
      `forward_pass` call signature and branching on it).
    * Avoid nesting variable creation in tf.function inside of `forward_pass`.
      While the layer may safely be used from inside a `tf.function`, using
      a function inside of `forward_pass` will break the variable scoping.
    * TBD: Nesting keras layers/models or other `VariableScopeWrapperLayer`s
      directly in `forward_pass` may not work correctly just yet.
      Support for this/instructions for how to do this is still being worked
      on.

  Coming soon: A better guide, testing/verification guide.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Relies on keras layers tracking Modules
    self.tracker = VariableAndLossTracker()
    # May need to inspect func to see if it should pass a `training` arg or not

  def forward_pass(self, *args, **kwargs):
    raise NotImplementedError

  def call(self, *args, **kwargs):
    with self.tracker.scope():
      out = self.forward_pass(*args, **kwargs)
    if not self._eager_losses:
      # We have to record regularization losses in the call as if they
      # are activity losses.
      # So, don't double-count regularization losses if the layer is used
      # multiple times in a model
      for loss in self.tracker.get_regularization_losses().values():
        self.add_loss(loss)
    return out
|