3RNN/Lib/site-packages/tensorflow/python/training/warm_starting_util.py
2024-05-26 19:49:15 +02:00

562 lines
24 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to warm-start TF.Learn Estimators."""
import collections
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["train.VocabInfo"])
class VocabInfo(
    collections.namedtuple("VocabInfo", [
        "new_vocab",
        "new_vocab_size",
        "num_oov_buckets",
        "old_vocab",
        "old_vocab_size",
        "backup_initializer",
        "axis",
    ])):
  """Vocabulary information used when warm-starting a variable.

  See `tf.estimator.WarmStartSettings` for examples of using VocabInfo to
  warm-start.

  Args:
    new_vocab: [Required] A path to the new vocabulary file (used with the
      model to be trained).
    new_vocab_size: [Required] An integer indicating how many entries of the
      new vocabulary will be used in training.
    num_oov_buckets: [Required] An integer indicating how many OOV buckets are
      associated with the vocabulary.
    old_vocab: [Required] A path to the old vocabulary file (used with the
      checkpoint to be warm-started from).
    old_vocab_size: [Optional] An integer indicating how many entries of the
      old vocabulary were used in the creation of the checkpoint. If not
      provided, the entire old vocabulary will be used.
    backup_initializer: [Optional] A variable initializer used for variables
      corresponding to new vocabulary entries and OOV. If not provided, these
      entries will be zero-initialized.
    axis: [Optional] Denotes what axis the vocabulary corresponds to. The
      default, 0, corresponds to the most common use case (embeddings or
      linear weights for binary classification / regression). An axis of 1
      could be used for warm-starting output layers with class vocabularies.

  Returns:
    A `VocabInfo` which represents the vocabulary information for
    warm-starting.

  Raises:
    ValueError: `axis` is neither 0 nor 1.

  Example Usage:

  ```python
  embeddings_vocab_info = tf.compat.v1.train.VocabInfo(
      new_vocab='embeddings_vocab',
      new_vocab_size=100,
      num_oov_buckets=1,
      old_vocab='pretrained_embeddings_vocab',
      old_vocab_size=10000,
      backup_initializer=tf.compat.v1.truncated_normal_initializer(
          mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
      axis=0)

  softmax_output_layer_kernel_vocab_info = tf.compat.v1.train.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,  # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.compat.v1.glorot_uniform_initializer(),
      axis=1)

  softmax_output_layer_bias_vocab_info = tf.compat.v1.train.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,  # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.compat.v1.zeros_initializer(),
      axis=0)

  # Currently, only axis=0 and axis=1 are supported.
  ```
  """

  def __new__(cls,
              new_vocab,
              new_vocab_size,
              num_oov_buckets,
              old_vocab,
              old_vocab_size=-1,
              backup_initializer=None,
              axis=0):
    # Reject unsupported axes before the tuple is built so that an invalid
    # VocabInfo can never exist.
    if axis not in (0, 1):
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1. Provided axis: {}".format(axis))

    return super().__new__(
        cls,
        new_vocab=new_vocab,
        new_vocab_size=new_vocab_size,
        num_oov_buckets=num_oov_buckets,
        old_vocab=old_vocab,
        old_vocab_size=old_vocab_size,
        backup_initializer=backup_initializer,
        axis=axis,
    )
def _infer_var_name(var):
  """Returns the single name shared by everything in `var`.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    Name of the `var`.

  Raises:
    TypeError: If the entries of `var` resolve to more than one name.
  """
  names_to_vars = saveable_object_util.op_list_to_dict(var)
  if len(names_to_vars) <= 1:
    # Slices of one larger variable collapse to a single key, which is the
    # name we are after.
    return list(names_to_vars)[0]
  raise TypeError("`var` = %s passed as arg violates the constraints.  "
                  "name_to_var_dict = %s" % (var, names_to_vars))
def _get_var_info(var, prev_tensor_name=None):
  """Helper method for standardizing a variable and its checkpoint name.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Returns:
    A tuple of the tensor name and var. A `PartitionedVariable` is returned as
    its list of partition variables; every other input is returned unchanged.

  Raises:
    TypeError: If `var` is none of the supported types.
  """
  # pylint: disable=protected-access
  is_variable = checkpoint_utils._is_variable
  if is_variable(var):
    inferred_name = _infer_var_name([var])
  elif isinstance(var, list) and all(is_variable(v) for v in var):
    inferred_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    inferred_name = _infer_var_name([var])
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  # pylint: enable=protected-access

  # Without an explicit previous name, assume the tensor name is unchanged.
  return prev_tensor_name or inferred_name, var
# pylint: disable=protected-access
# Accesses protected members of tf.Variable to reset the variable's internal
# state.
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by vocabulary. This method stitches
  the given `var` such that values corresponding to individual features in the
  vocabulary remain consistent irrespective of changing order of the features
  between old and new vocabularies.

  The warm-start is performed by replacing the `_initializer_op` of each
  (partition of the) variable with an assignment of the remapped checkpoint
  values.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the current
      vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain previous vocab to the first
      `previous_vocab_size` entries.  -1 means use the entire previous vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
    initializer: Variable initializer to be used for missing entries.  If None,
      missing entries will be zero-initialized.
    axis: Axis of the variable that the provided vocabulary corresponds to.

  Raises:
    ValueError: If required args are not provided, or if `axis` is neither 0
      nor 1.
    TypeError: If `var` is none of the supported types.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path].")

  # Normalize `var` to a flat list of (partition) variables.
  if checkpoint_utils._is_variable(var):
    var = [var]
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  elif not (isinstance(var, list) and
            all(checkpoint_utils._is_variable(v) for v in var)):
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)

  # Fail fast on an unsupported axis instead of discovering it while already
  # iterating over the partitions.
  if axis not in (0, 1):
    raise ValueError("The only supported values for the axis argument are 0 "
                     "and 1.  Provided axis: {}".format(axis))

  # Resolve the checkpoint file once so that every partition is remapped from
  # the same file, rather than re-resolving `prev_ckpt` per partition.
  ckpt_path = checkpoint_utils._get_checkpoint_filename(prev_ckpt)

  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape, var_offset=slice_info.var_offset)

    if axis == 0:
      # The vocabulary indexes rows; columns are loaded without remapping.
      new_row_vocab_size = current_vocab_size
      new_col_vocab_size = v_shape[1]
      old_row_vocab_size = previous_vocab_size
      old_row_vocab_file = prev_vocab_path
      new_row_vocab_file = current_vocab_path
      old_col_vocab_file = None
      new_col_vocab_file = None
      num_row_oov_buckets = current_oov_buckets
      num_col_oov_buckets = 0
    else:
      # axis == 1 (validated above): the vocabulary indexes columns.
      # Note that we must compute the row count across all partitions, whereas
      # in the axis = 0 case, we can simply use v_shape[1] because we don't
      # allow partitioning across axis = 1.
      new_row_vocab_size = total_v_first_axis
      new_col_vocab_size = current_vocab_size
      old_row_vocab_size = -1
      old_row_vocab_file = None
      new_row_vocab_file = None
      old_col_vocab_file = prev_vocab_path
      new_col_vocab_file = current_vocab_path
      num_row_oov_buckets = 0
      num_col_oov_buckets = current_oov_buckets

    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=ckpt_path,
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=new_row_vocab_size,
        new_col_vocab_size=new_col_vocab_size,
        old_row_vocab_size=old_row_vocab_size,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=num_row_oov_buckets,
        num_col_oov_buckets=num_col_oov_buckets,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    # Swap the variable's initializer for an assignment of the remapped values.
    v._initializer_op = state_ops.assign(v, new_init_val)
# pylint: enable=protected-access
def _get_grouped_variables(vars_to_warm_start):
  """Collects and groups (possibly partitioned) variables into a dictionary.

  The variables can be provided explicitly through vars_to_warm_start, or they
  are retrieved from collections (see below).

  Args:
    vars_to_warm_start: One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection).  This expression will
        only consider variables in the TRAINABLE_VARIABLES collection.
      - A list of strings, each representing a full variable name to warm-start.
        These will consider variables in GLOBAL_VARIABLES collection.
      - A list of Variables to warm-start.
      - `None`, in which case all variables in TRAINABLE_VARIABLES will be used.

  Returns:
    A dictionary mapping variable names (strings) to lists of Variables.

  Raises:
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
  if vars_to_warm_start is None or isinstance(vars_to_warm_start, str):
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
    list_of_vars = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
  elif isinstance(vars_to_warm_start, list):
    if all(isinstance(v, str) for v in vars_to_warm_start):
      # Each string is used as a scope over GLOBAL_VARIABLES.
      list_of_vars = []
      for scope in vars_to_warm_start:
        list_of_vars.extend(
            ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope=scope))
    elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
      list_of_vars = vars_to_warm_start
    else:
      raise ValueError("If `vars_to_warm_start` is a list, it must be all "
                       "`Variable` or all `str`.  Given types are {}".format(
                           [type(v) for v in vars_to_warm_start]))
  else:
    raise ValueError("`vars_to_warm_start` must be a `list` or `str`.  Given "
                     "type is {}".format(type(vars_to_warm_start)))

  # We have to deal with partitioned variables, since get_collection flattens
  # out the list: partitions that resolve to the same name are regrouped under
  # that name.
  grouped_variables = {}
  for v in list_of_vars:
    name_probe = v if isinstance(v, list) else [v]
    var_name = _infer_var_name(name_probe)
    grouped_variables.setdefault(var_name, []).append(v)

  return grouped_variables
def _get_object_checkpoint_renames(path, variable_names):
  """Maps variable names to the keys they are stored under in the checkpoint.

  The warm-starting utility expects variable names to match with the variable
  names in the checkpoint. For object-based checkpoints, the variable names
  and names in the checkpoint are different. Thus, for object-based checkpoints,
  this function is used to obtain the map from variable names to checkpoint
  keys.

  Args:
    path: path to checkpoint directory or file.
    variable_names: list of variable names to load from the checkpoint.

  Returns:
    If the checkpoint is object-based, this function returns a map from variable
    names to their corresponding checkpoint keys.
    If the checkpoint is name-based, this returns an empty dict.

  Raises:
    ValueError: If the object-based checkpoint is missing variables.
  """
  ckpt_filename = checkpoint_utils._get_checkpoint_filename(path)  # pylint: disable=protected-access
  try:
    names_to_keys = saver_lib.object_graph_key_mapping(ckpt_filename)
  except errors.NotFoundError:
    # If an error is raised from `object_graph_key_mapping`, then the
    # checkpoint is name-based. There are no renames, so return an empty dict.
    return {}

  missing_names = set(variable_names).difference(names_to_keys.keys())
  if missing_names:
    raise ValueError(
        "Attempting to warm-start from an object-based checkpoint, but found "
        "that the checkpoint did not contain values for all variables. The "
        "following variables were missing: {}"
        .format(missing_names))

  renames = {}
  for name in variable_names:
    renames[name] = names_to_keys[name]
  return renames
@tf_export(v1=["train.warm_start"])
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection).  This expression will only
        consider variables in the TRAINABLE_VARIABLES collection -- if you need
        to warm-start non_TRAINABLE vars (such as optimizer accumulators or
        batch norm statistics), please use the below option.
      - A list of strings, each a regex scope provided to
        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
        tf.compat.v1.get_collection).  For backwards compatibility reasons,
        this is separate from the single-string argument type.
      - A list of Variables to warm-start.  If you do not have access to the
        `Variable` objects at the call site, please use the above option.
      - `None`, in which case only TRAINABLE variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions.  If not explicitly provided, the variable
      is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model.  Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).

  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  logging.info("Warm-starting from: %s", ckpt_to_initialize_from)
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}

  if not var_name_to_prev_var_name:
    # Detect whether the checkpoint is object-based, in which case the
    # var_name_to_prev_var_name dictionary should map variable names to
    # checkpoint keys. If the user has specified var_name_to_prev_var_name, we
    # do not override it.
    var_name_to_prev_var_name = _get_object_checkpoint_renames(
        ckpt_to_initialize_from, grouped_variables.keys())

  warmstarted_count = 0

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
  # exception if any are unused by the end of the loop.  It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in grouped_variables.items():
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      warmstarted_count += 1
      # Lazy %-style args: the message is only formatted if debug logging is
      # actually enabled.
      logging.debug(
          "Warm-starting variable: %s; current_vocab: %s current_vocab_size: %s"
          " prev_vocab: %s prev_vocab_size: %s current_oov: %s prev_tensor: %s"
          " initializer: %s",
          var_name,
          vocab_info.new_vocab,
          vocab_info.new_vocab_size,
          vocab_info.old_vocab,
          vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0 else "All",
          vocab_info.num_oov_buckets,
          prev_var_name or "Unchanged",
          vocab_info.backup_initializer or "zero-initialized")
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    elif vars_to_warm_start:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies,
      # so this branch is skipped in that case.
      warmstarted_count += 1
      logging.debug("Warm-starting variable: %s; prev_var_name: %s",
                    var_name, prev_var_name or "Unchanged")
      # Because we use a default empty list in grouped_variables, single
      # unpartitioned variables will be lists here, which we rectify in order
      # for init_from_checkpoint logic to work correctly.
      if len(variable) == 1:
        variable = variable[0]
      prev_tensor_name, var = _get_var_info(variable, prev_var_name)
      if prev_tensor_name in vocabless_vars:
        # The API for checkpoint_utils.init_from_checkpoint accepts a mapping
        # from checkpoint tensor names to model variable names, so it does not
        # support warm-starting two variables from the same tensor.  Our work-
        # around is to run init_from_checkpoint multiple times, each time we
        # encounter a new variable that should be initialized by a previously-
        # used tensor.
        logging.debug(
            "Requested prev_var_name %s initialize both %s and %s; "
            "calling init_from_checkpoint.",
            prev_tensor_name, vocabless_vars[prev_tensor_name], var)
        checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                              vocabless_vars)
        vocabless_vars.clear()
      vocabless_vars[prev_tensor_name] = var

  # Flush whatever is still batched after the loop.
  if vocabless_vars:
    checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                          vocabless_vars)

  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  logging.info("Warm-started %d variables.", warmstarted_count)

  # Report the viable names as a plain list so the message does not embed a
  # `dict_keys(...)` wrapper.
  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}.  Perhaps you misspelled them?  Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     list(grouped_variables)))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them?  Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, list(grouped_variables)))