# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variable class."""

import abc
import enum
import functools
import itertools
import os

from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_getitem_override
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_should_use
from tensorflow.python.util import traceback_utils
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export


def default_variable_creator_v2(next_creator=None, **kwds):
  from tensorflow.python.ops import resource_variable_ops  # pylint: disable=g-import-not-at-top

  return resource_variable_ops.default_variable_creator_v2(
      next_creator=next_creator, **kwds)


def _make_getter(captured_getter, captured_previous):
  """To avoid capturing loop variables."""

  def getter(**kwargs):
    return captured_getter(captured_previous, **kwargs)

  return getter
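
# Note on `_make_getter`: building the getter chain directly in a loop with a
# plain `lambda **kws: getter(previous, **kws)` would late-bind `getter` and
# `previous` to their values from the *last* iteration. Binding them as
# arguments of a factory function freezes each iteration's values. A minimal
# sketch of the pitfall (illustrative only, not part of this module's API):
#
#   fns = [lambda: i for i in range(3)]
#   [f() for f in fns]                         # [2, 2, 2] -- late binding
#   fns = [(lambda j: (lambda: j))(i) for i in range(3)]
#   [f() for f in fns]                         # [0, 1, 2] -- captured per iteration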


@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
  """Indicates when a distributed variable will be synced.

  * `AUTO`: Indicates that the synchronization will be determined by the current
    `DistributionStrategy` (eg. With `MirroredStrategy` this would be
    `ON_WRITE`).
  * `NONE`: Indicates that there will only be one copy of the variable, so
    there is no need to sync.
  * `ON_WRITE`: Indicates that the variable will be updated across devices
    every time it is written.
  * `ON_READ`: Indicates that the variable will be aggregated across devices
    when it is read (eg. when checkpointing or when evaluating an op that uses
    the variable).

  Example:
  >>> temp_grad=[tf.Variable([0.], trainable=False,
  ...                        synchronization=tf.VariableSynchronization.ON_READ,
  ...                        aggregation=tf.VariableAggregation.MEAN
  ...                        )]
  """
  AUTO = 0
  NONE = 1
  ON_WRITE = 2
  ON_READ = 3


# LINT.IfChange
@tf_export("VariableAggregation", v1=[])
class VariableAggregationV2(enum.Enum):
  """Indicates how a distributed variable will be aggregated.

  `tf.distribute.Strategy` distributes a model by making multiple copies
  (called "replicas") acting on different elements of the input batch in a
  data parallel model. When performing some variable-update operation,
  for example `var.assign_add(x)`, in a model, we need to resolve how to combine
  the different values for `x` computed in the different replicas.

  * `NONE`: This is the default, giving an error if you use a
    variable-update operation with multiple replicas.
  * `SUM`: Add the updates across replicas.
  * `MEAN`: Take the arithmetic mean ("average") of the updates across replicas.
  * `ONLY_FIRST_REPLICA`: This is for when every replica is performing the same
    update, but we only want to perform the update once. Used, e.g., for the
    global step counter.

  For example:

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   v = tf.Variable(5.0, aggregation=tf.VariableAggregation.MEAN)
  >>> @tf.function
  ... def update_fn():
  ...   return v.assign_add(1.0)
  >>> strategy.run(update_fn)
  PerReplica:{
    0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
    1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  }

  """
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3

  def __hash__(self):
    return hash(self.value)

  def __eq__(self, other):
    if self is other:
      return True
    elif isinstance(other, VariableAggregation):
      return int(self.value) == int(other.value)
    else:
      return False


@tf_export(v1=["VariableAggregation"])
class VariableAggregation(enum.Enum):
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_REPLICA = 3
  ONLY_FIRST_TOWER = 3  # DEPRECATED

  def __hash__(self):
    return hash(self.value)


# LINT.ThenChange(//tensorflow/core/framework/variable.proto)
#
# Note that we are currently relying on the integer values of the Python enums
# matching the integer values of the proto enums.

VariableAggregation.__doc__ = (
    VariableAggregationV2.__doc__ +
    "* `ONLY_FIRST_TOWER`: Deprecated alias for `ONLY_FIRST_REPLICA`.\n  ")


def validate_synchronization_aggregation_trainable(synchronization, aggregation,
                                                   trainable, name):
  """Given user-provided variable properties, sets defaults and validates."""
  if aggregation is None:
    aggregation = VariableAggregation.NONE
  else:
    if not isinstance(aggregation,
                      (VariableAggregation, VariableAggregationV2)):
      try:
        aggregation = VariableAggregationV2(aggregation)
      except ValueError:
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, name))
  if synchronization is None:
    synchronization = VariableSynchronization.AUTO
  else:
    try:
      synchronization = VariableSynchronization(synchronization)
    except ValueError:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}".format(
              synchronization, name))
  if trainable is None:
    trainable = synchronization != VariableSynchronization.ON_READ
  return synchronization, aggregation, trainable
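
# Illustrative sketch of the defaulting behavior above (derived from the code,
# not executed here): validate_synchronization_aggregation_trainable(
# None, None, None, "v") resolves to (VariableSynchronization.AUTO,
# VariableAggregation.NONE, True), while passing
# synchronization=VariableSynchronization.ON_READ flips the default
# `trainable` to False, matching the `tf.Variable` constructor documentation
# below.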


class VariableMetaclass(abc.ABCMeta):
  """Metaclass to allow construction of tf.Variable to be overridden."""

  @traceback_utils.filter_traceback
  def __call__(cls, *args, **kwargs):
    if hasattr(cls, "_variable_call") and callable(cls._variable_call):
      variable_call = cls._variable_call(*args, **kwargs)
      if variable_call is not None:
        return variable_call
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)


@tf_export("Variable", v1=[])
# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
class Variable(trackable.Trackable, metaclass=VariableMetaclass):
  """See the [variable guide](https://tensorflow.org/guide/variable).

  A variable maintains shared, persistent state manipulated by a program.

  The `Variable()` constructor requires an initial value for the variable, which
  can be a `Tensor` of any type and shape. This initial value defines the type
  and shape of the variable. After construction, the type and shape of the
  variable are fixed. The value can be changed using one of the assign methods.

  >>> v = tf.Variable(1.)
  >>> v.assign(2.)
  <tf.Variable ... shape=() dtype=float32, numpy=2.0>
  >>> v.assign_add(0.5)
  <tf.Variable ... shape=() dtype=float32, numpy=2.5>

  The `shape` argument to `Variable`'s constructor allows you to construct a
  variable with a less defined shape than its `initial_value`:

  >>> v = tf.Variable(1., shape=tf.TensorShape(None))
  >>> v.assign([[1.]])
  <tf.Variable ... shape=<unknown> dtype=float32, numpy=array([[1.]], ...)>

  Just like any `Tensor`, variables created with `Variable()` can be used as
  inputs to operations. Additionally, all the operators overloaded for the
  `Tensor` class are carried over to variables.

  >>> w = tf.Variable([[1.], [2.]])
  >>> x = tf.constant([[3., 4.]])
  >>> tf.matmul(w, x)
  <tf.Tensor:... shape=(2, 2), ... numpy=
  array([[3., 4.],
         [6., 8.]], dtype=float32)>
  >>> tf.sigmoid(w + x)
  <tf.Tensor:... shape=(2, 2), ...>

  When building a machine learning model it is often convenient to distinguish
  between variables holding trainable model parameters and other variables such
  as a `step` variable used to count training steps. To make this easier, the
  variable constructor supports a `trainable=<bool>`
  parameter. `tf.GradientTape` watches trainable variables by default:

  >>> with tf.GradientTape(persistent=True) as tape:
  ...   trainable = tf.Variable(1.)
  ...   non_trainable = tf.Variable(2., trainable=False)
  ...   x1 = trainable * 2.
  ...   x2 = non_trainable * 3.
  >>> tape.gradient(x1, trainable)
  <tf.Tensor:... shape=(), dtype=float32, numpy=2.0>
  >>> assert tape.gradient(x2, non_trainable) is None  # Unwatched

  Variables are automatically tracked when assigned to attributes of types
  inheriting from `tf.Module`.

  >>> m = tf.Module()
  >>> m.v = tf.Variable([1.])
  >>> m.trainable_variables
  (<tf.Variable ... shape=(1,) ... numpy=array([1.], dtype=float32)>,)

  This tracking then allows saving variable values to
  [training checkpoints](https://www.tensorflow.org/guide/checkpoint), or to
  [SavedModels](https://www.tensorflow.org/guide/saved_model) which include
  serialized TensorFlow graphs.

  Variables are often captured and manipulated by `tf.function`s. This works the
  same way the un-decorated function would have:

  >>> v = tf.Variable(0.)
  >>> read_and_decrement = tf.function(lambda: v.assign_sub(0.1))
  >>> read_and_decrement()
  <tf.Tensor: shape=(), dtype=float32, numpy=-0.1>
  >>> read_and_decrement()
  <tf.Tensor: shape=(), dtype=float32, numpy=-0.2>

  Variables created inside a `tf.function` must be owned outside the function
  and be created only once:

  >>> class M(tf.Module):
  ...   @tf.function
  ...   def __call__(self, x):
  ...     if not hasattr(self, "v"):  # Or set self.v to None in __init__
  ...       self.v = tf.Variable(x)
  ...     return self.v * x
  >>> m = M()
  >>> m(2.)
  <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
  >>> m(3.)
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> m.v
  <tf.Variable ... shape=() dtype=float32, numpy=2.0>

  See the `tf.function` documentation for details.
  """

  @deprecated_args(
      None, "A variable's value can be manually cached by calling "
      "tf.Variable.read_value() under a tf.device scope. The caching_device "
      "argument does not work properly.", "caching_device")
  def __init__(self,
               initial_value=None,
               trainable=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               import_scope=None,
               constraint=None,
               synchronization=VariableSynchronization.AUTO,
               aggregation=VariableAggregation.NONE,
               shape=None,
               experimental_enable_variable_lifting=True,
               ):
    """Creates a new variable with value `initial_value`.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called. In
        that case, `dtype` must be specified. (Note that initializer functions
        from init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        variable. Defaults to `True`, unless `synchronization` is set to
        `ON_READ`, in which case it defaults to `False`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Note: This argument is only valid when using a v1-style
        `Session`. Optional device string describing where the Variable should
        be cached for reading. Defaults to the Variable's device. If not `None`,
        caches on another device. Typical use is to cache on the device where
        the Ops using the Variable reside, to deduplicate copying through
        `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      variable_def: `VariableDef` protocol buffer. If not `None`, recreates the
        Variable object with its contents, referencing the variable's nodes in
        the graph, which must already exist. The graph is not changed.
        `variable_def` and the other arguments are mutually exclusive.
      dtype: If set, initial_value will be converted to the given type. If
        `None`, either the datatype will be kept (if `initial_value` is a
        Tensor), or `convert_to_tensor` will decide.
      import_scope: Optional `string`. Name scope to add to the `Variable.` Only
        used when initializing from protocol buffer.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes.
      experimental_enable_variable_lifting: Whether to lift the variable out if
        it's in a `tf.function`. Default is `True`. When this argument
        is `True`, variable creation will follow the behavior and
        restrictions described
        [here](https://www.tensorflow.org/guide/function#creating_tfvariables).
        If this argument is `False`, that description doesn't apply,
        and you can freely create and use the variable in the
        `tf.function`, as if it's a "mutable `tf.Tensor`". You can't
        return the variable though.

    Raises:
      ValueError: If both `variable_def` and initial_value are specified.
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    raise NotImplementedError

  def __repr__(self):
    raise NotImplementedError

  def value(self):
    """Returns the last snapshot of this variable.

    You usually do not need to call this method as all ops that need the value
    of the variable call it automatically through a `convert_to_tensor()` call.

    Returns a `Tensor` which holds the value of the variable. You can not
    assign a new value to this tensor as it is not a reference to the variable.

    To avoid copies, if the consumer of the returned value is on the same device
    as the variable, this actually returns the live value of the variable, not
    a copy. Updates to the variable are seen by the consumer. If the consumer
    is on a different device it will get a copy of the variable.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def read_value(self):
    """Returns the value of this variable, read in the current context.

    Can be different from value() if it's on another device, with control
    dependencies, etc.

    Returns:
      A `Tensor` containing the value of the variable.
    """
    raise NotImplementedError

  def set_shape(self, shape):
    """Overrides the shape for this variable.

    Args:
      shape: the `TensorShape` representing the overridden shape.
    """
    raise NotImplementedError

  @property
  def trainable(self):
    raise NotImplementedError

  @property
  def synchronization(self):
    raise NotImplementedError

  @property
  def aggregation(self):
    raise NotImplementedError

  def eval(self, session=None):
    """In a session, computes and returns the value of this variable.

    This is not a graph construction method, it does not add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.compat.v1.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session() as sess:
        sess.run(init)
        # Usage passing the session explicitly.
        print(v.eval(sess))
        # Usage with the default session. The 'with' block
        # above makes 'sess' the default session.
        print(v.eval())
    ```

    Args:
      session: The session to use to evaluate this variable. If none, the
        default session is used.

    Returns:
      A numpy `ndarray` with a copy of the value of this variable.
    """
    raise NotImplementedError

  @deprecated(
      None, "Use Variable.read_value. Variables in 2.X are initialized "
      "automatically both in eager and graph (inside tf.defun) contexts.")
  def initialized_value(self):
    """Returns the value of the initialized variable.

    You should use this instead of the variable itself to initialize another
    variable with a value that depends on the value of this variable.

    ```python
    # Initialize 'v' with a random tensor.
    v = tf.Variable(tf.random.truncated_normal([10, 40]))
    # Use `initialized_value` to guarantee that `v` has been
    # initialized before its value is used to initialize `w`.
    # The random values are picked only once.
    w = tf.Variable(v.initialized_value() * 2.0)
    ```

    Returns:
      A `Tensor` holding the value of this variable after its initializer
      has run.
    """
    raise NotImplementedError

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable.

    Note that this is different from `initialized_value()` which runs
    the op that initializes the variable before returning its value.
    This method returns the tensor that is used by the op that initializes
    the variable.

    Returns:
      A `Tensor`.
    """
    raise NotImplementedError

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
    """
    raise NotImplementedError

  def assign(self, value, use_locking=False, name=None, read_value=True):
    """Assigns a new value to the variable.

    This is essentially a shortcut for `assign(self, value)`.

    Args:
      value: A `Tensor`. The new value for this variable.
      use_locking: If `True`, use locking during the assignment.
      name: The name of the operation to be created
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      The updated variable. If `read_value` is false, instead returns None in
      Eager mode and the assign op in graph mode.
    """
    raise NotImplementedError

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds a value to this variable.

    This is essentially a shortcut for `assign_add(self, delta)`.

    Args:
      delta: A `Tensor`. The value to add to this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      The updated variable. If `read_value` is false, instead returns None in
      Eager mode and the assign op in graph mode.
    """
    raise NotImplementedError

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts a value from this variable.

    This is essentially a shortcut for `assign_sub(self, delta)`.

    Args:
      delta: A `Tensor`. The value to subtract from this variable.
      use_locking: If `True`, use locking during the operation.
      name: The name of the operation to be created
      read_value: if True, will return something which evaluates to the new
        value of the variable; if False will return the assign op.

    Returns:
      The updated variable. If `read_value` is false, instead returns None in
      Eager mode and the assign op in graph mode.
    """
    raise NotImplementedError
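
  # Illustrative usage sketch for the assign family above (concrete
  # tf.Variable subclasses implement these; values shown assume eager
  # execution):
  #
  #   v = tf.Variable(10.0)
  #   v.assign(3.0)        # v now holds 3.0
  #   v.assign_add(2.0)    # v now holds 5.0
  #   v.assign_sub(1.0)    # v now holds 4.0
  #   op = v.assign(7.0, read_value=False)  # eager: returns None,
  #                                         # graph: returns the assign op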

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Subtracts `tf.IndexedSlices` from this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Adds `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be added to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the max of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of max with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Updates this variable with the min of `tf.IndexedSlices` and itself.

    Args:
      sparse_delta: `tf.IndexedSlices` to use as an argument of min with this
        variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Multiply this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to multiply this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Divide this variable by `tf.IndexedSlices`.

    Args:
      sparse_delta: `tf.IndexedSlices` to divide this variable by.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError
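
  # Illustrative sketch for the scatter_* methods above. A `tf.IndexedSlices`
  # pairs update values with the rows they target (eager values assumed):
  #
  #   v = tf.Variable([1.0, 2.0, 3.0, 4.0])
  #   delta = tf.IndexedSlices(values=[10.0, 20.0], indices=[0, 2])
  #   v.scatter_add(delta)     # -> [11.0, 2.0, 23.0, 4.0]
  #   v.scatter_update(delta)  # -> [10.0, 2.0, 20.0, 4.0]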

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Assigns `tf.IndexedSlices` to this variable batch-wise.

    Analogous to `batch_gather`. This assumes that this variable and the
    sparse_delta IndexedSlices have a series of leading dimensions that are the
    same for all of them, and the updates are performed on the last dimension of
    indices. In other words, the dimensions should be the following:

    `num_prefix_dims = sparse_delta.indices.ndims - 1`
    `batch_dim = num_prefix_dims + 1`
    `sparse_delta.updates.shape = sparse_delta.indices.shape + var.shape[
         batch_dim:]`

    where

    `sparse_delta.updates.shape[:num_prefix_dims]`
    `== sparse_delta.indices.shape[:num_prefix_dims]`
    `== var.shape[:num_prefix_dims]`

    And the operation performed can be expressed as:

    `var[i_1, ..., i_n,
         sparse_delta.indices[i_1, ..., i_n, j]] = sparse_delta.updates[
            i_1, ..., i_n, j]`

    When sparse_delta.indices is a 1D tensor, this operation is equivalent to
    `scatter_update`.

    To avoid this operation one can loop over the first `ndims` of the
    variable and use `scatter_update` on the subtensors that result from
    slicing the first dimension. This is a valid option for `ndims = 1`, but
    less efficient than this implementation.

    Args:
      sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
      use_locking: If `True`, use locking during the operation.
      name: the name of the operation.

    Returns:
      The updated variable.

    Raises:
      TypeError: if `sparse_delta` is not an `IndexedSlices`.
    """
    raise NotImplementedError

  def scatter_nd_sub(self, indices, updates, name=None):
    """Applies sparse subtraction to individual values or slices in a Variable.

    Assuming the variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to subtract 4 scattered elements from a rank-1
    tensor with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    v.scatter_nd_sub(indices, updates)
    print(v)
    ```

    After the update `v` would look like this:

        [1, -9, 3, -6, -4, 6, 7, -4]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def scatter_nd_add(self, indices, updates, name=None):
    """Applies sparse addition to individual values or slices in a Variable.

    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to add 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    v.scatter_nd_add(indices, updates)
    print(v)
    ```

    The resulting update to v would look like this:

        [1, 13, 3, 14, 14, 6, 7, 20]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def scatter_nd_update(self, indices, updates, name=None):
    """Applies sparse assignment to individual values or slices in a Variable.

    The Variable has rank `P` and `indices` is a `Tensor` of rank `Q`.

    `indices` must be an integer tensor, containing indices into self.
    It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

    The innermost dimension of `indices` (with length `K`) corresponds to
    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
    dimension of self.

    `updates` is `Tensor` of rank `Q-1+P-K` with shape:

    ```
    [d_0, ..., d_{Q-2}, self.shape[K], ..., self.shape[P-1]].
    ```

    For example, say we want to assign 4 scattered elements to a rank-1 tensor
    with 8 elements. In Python, that update would look like this:

    ```python
    v = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
    indices = tf.constant([[4], [3], [1], [7]])
    updates = tf.constant([9, 10, 11, 12])
    v.scatter_nd_update(indices, updates)
    print(v)
    ```

    The resulting update to v would look like this:

        [1, 11, 3, 10, 9, 6, 7, 12]

    See `tf.scatter_nd` for more details about how to make updates to
    slices.

    Args:
      indices: The indices to be used in the operation.
      updates: The values to be used in the operation.
      name: the name of the operation.

    Returns:
      The updated variable.
    """
    raise NotImplementedError

  def sparse_read(self, indices, name=None):
    r"""Gather slices from params along axis `axis` according to indices.

    This function supports a subset of tf.gather, see tf.gather for details on
    usage.

    Args:
      indices: The index `Tensor`. Must be one of the following types: `int32`,
        `int64`. Must be in range `[0, params.shape[axis])`.
      name: A name for the operation (optional).

    Returns:
      A `Tensor`. Has the same type as `params`.
    """
    raise AttributeError

  def gather_nd(self, indices, name=None):
    r"""Gather slices from `params` into a Tensor with shape specified by `indices`.

    See tf.gather_nd for details.

    Args:
      indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
        Index tensor.
      name: A name for the operation (optional).

    Returns:
      A `Tensor`. Has the same type as `params`.
    """
    raise AttributeError

  @deprecated(None, "Prefer Dataset.range instead.")
  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When that Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
    """
    raise NotImplementedError

  @deprecated(None,
              "Prefer Variable.assign which has equivalent behavior in 2.X.")
  def load(self, value, session=None):
    """Load new value into this variable.

    Writes new value to variable's memory. Doesn't add ops to the graph.

    This convenience method requires a session where the graph
    containing this variable has been launched. If no session is
    passed, the default session is used. See `tf.compat.v1.Session` for more
    information on launching a graph and on sessions.

    ```python
    v = tf.Variable([1, 2])
    init = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session() as sess:
        sess.run(init)
        # Usage passing the session explicitly.
        v.load([2, 3], sess)
        print(v.eval(sess))  # prints [2 3]
        # Usage with the default session. The 'with' block
        # above makes 'sess' the default session.
        v.load([3, 4], sess)
        print(v.eval())  # prints [3 4]
    ```

    Args:
      value: New variable value
      session: The session to use to evaluate this variable. If none, the
        default session is used.

    Raises:
      ValueError: Session is not passed and no default session
    """
    if context.executing_eagerly():
      self.assign(value)
    else:
      session = session or ops.get_default_session()
      if session is None:
        raise ValueError(
            "Either session argument should be provided or default session "
            "should be established")
      session.run(self.initializer, {self.initializer.inputs[1]: value})

  # Conversion to tensor.
  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
    """Utility function for converting a Variable to a Tensor."""
    _ = name
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          f"Incompatible type conversion requested to type '{dtype.name}' for "
          f"variable of type '{v.dtype.name}' (Variable: {v}).")
    if as_ref:
      return v._ref()  # pylint: disable=protected-access
    else:
      return v.value()

  @classmethod
  def _OverloadAllOperators(cls):  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in tensor_lib.Tensor.OVERLOADABLE_OPERATORS:
      cls._OverloadOperator(operator)
    # For slicing, bind getitem differently than a tensor (use _slice_helper_var
    # instead)
    # pylint: disable=protected-access
    setattr(cls, "__getitem__", tensor_getitem_override._slice_helper_var)

  @classmethod
  def _OverloadOperator(cls, operator):  # pylint: disable=invalid-name
    """Defer an operator overload to `tensor_lib.Tensor`.

    We pull the operator out of tensor_lib.Tensor dynamically to avoid ordering
    issues.

    Args:
      operator: string. The operator name.
    """
    # We can't use the overload mechanism on __eq__ & __ne__ since __eq__ is
    # called when adding a variable to sets. As a result we call a.value() which
    # causes infinite recursion when operating within a GradientTape
    # TODO(gjn): Consider removing this
    if operator == "__eq__" or operator == "__ne__":
      return

    tensor_oper = getattr(tensor_lib.Tensor, operator)

    def _run_op(a, *args, **kwargs):
      # pylint: disable=protected-access
      return tensor_oper(a.value(), *args, **kwargs)

    functools.update_wrapper(_run_op, tensor_oper)
    setattr(cls, operator, _run_op)

  def __hash__(self):
    if (
        tensor_lib.Tensor._USE_EQUALITY
        and ops.executing_eagerly_outside_functions()
    ):  # pylint: disable=protected-access
      raise TypeError(
          "Variable is unhashable. "
          f"Instead, use variable.ref() as the key. (Variable: {self})"
      )
    else:
      return id(self)

  # TODO(gjn): duplicate of math_ops.tensor_equals, consider removing
  def __eq__(self, other):
    """Compares two variables element-wise for equality."""
    if (
        tensor_lib.Tensor._USE_EQUALITY
        and ops.executing_eagerly_outside_functions()
    ):  # pylint: disable=protected-access
      return gen_math_ops.equal(self, other, incompatible_shape_error=False)
    else:
      # In legacy graph mode, tensor equality is object equality
      return self is other

  # TODO(gjn): duplicate of math_ops.tensor_not_equals, consider removing
  def __ne__(self, other):
    """Compares two variables element-wise for inequality."""
    if (
        tensor_lib.Tensor._USE_EQUALITY
        and ops.executing_eagerly_outside_functions()
    ):  # pylint: disable=protected-access
      return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
    else:
      # In legacy graph mode, tensor equality is object equality
      return self is not other

  def __iter__(self):
    """When executing eagerly, iterates over the value of the variable."""
    return iter(self.read_value())

  # NOTE(mrry): This enables the Variable's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Variable class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Variables interact
  # with ndarrays.
  __array_priority__ = 100

  @property
  def name(self):
    """The name of this variable."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """The shared name of the variable.

    Unlike name(), shared_name doesn't have ":0" suffix. It is user-specified
    name with name scope prefix.

    Returns:
      variable name.
    """
    return self.name[:self.name.index(":")]

  @property
  def initializer(self):
    """The initializer operation for this variable."""
    raise NotImplementedError

  @property
  def device(self):
    """The device of this variable."""
    raise NotImplementedError

  @property
  def dtype(self):
    """The `DType` of this variable."""
    raise NotImplementedError

  @property
  def op(self):
    """The `Operation` of this variable."""
    raise NotImplementedError

  @property
  def graph(self):
    """The `Graph` of this variable."""
    raise NotImplementedError

  @property
  def shape(self):
    """The `TensorShape` of this variable.

    Returns:
      A `TensorShape`.
    """
    raise NotImplementedError

  def get_shape(self) -> tensor_shape.TensorShape:
    """Alias of `Variable.shape`."""
    return self.shape

  def _gather_saveables_for_checkpoint(self):
    """For implementing `Trackable`. This object is saveable on its own."""
    return {trackable.VARIABLE_VALUE_KEY: self}

  def to_proto(self, export_scope=None):
    """Converts a `Variable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    raise NotImplementedError

  @staticmethod
  def from_proto(variable_def, import_scope=None):
    """Returns a `Variable` object created from `variable_def`."""
    raise NotImplementedError

  def _set_save_slice_info(self, save_slice_info):
    """Sets the slice info for this `Variable`.

    Args:
      save_slice_info: A `Variable.SaveSliceInfo` object.
    """
    self._save_slice_info = save_slice_info

  def _get_save_slice_info(self):
    return self._save_slice_info

  @deprecated(None, "Use ref() instead.")
  def experimental_ref(self):
    return self.ref()

  def ref(self):
    # tf.Tensor also has the same ref() API. If you update the
    # documentation here, please update tf.Tensor.ref() as well.
    """Returns a hashable reference object to this Variable.

    The primary use case for this API is to put variables in a set/dictionary.
    We can't put variables in a set/dictionary as `variable.__hash__()` is no
    longer available starting in TensorFlow 2.0.

    The following will raise an exception starting 2.0

    >>> x = tf.Variable(5)
    >>> y = tf.Variable(10)
    >>> z = tf.Variable(10)
    >>> variable_set = {x, y, z}
    Traceback (most recent call last):
      ...
    TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
    >>> variable_dict = {x: 'five', y: 'ten'}
    Traceback (most recent call last):
      ...
    TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.

    Instead, we can use `variable.ref()`.

    >>> variable_set = {x.ref(), y.ref(), z.ref()}
    >>> x.ref() in variable_set
    True
    >>> variable_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
    >>> variable_dict[y.ref()]
    'ten'

    Also, the reference object provides a `.deref()` function that returns the
    original Variable.

    >>> x = tf.Variable(5)
    >>> x.ref().deref()
    <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=5>
    """
    return object_identity.Reference(self)

  @classmethod
  def _variable_call(
      cls,
      initial_value=None,
      trainable=None,
      validate_shape=True,
      caching_device=None,
      name=None,
      variable_def=None,
      dtype=None,
      import_scope=None,
      constraint=None,
      synchronization=VariableSynchronization.AUTO,
      aggregation=VariableAggregation.NONE,
      shape=None,
      experimental_enable_variable_lifting=None,
      **kwargs,
  ):
    """Variable class getter. Useful to force the signature."""
    if cls is not Variable:
      return None
    previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
    for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
      previous_getter = _make_getter(getter, previous_getter)

    # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
    if aggregation is None:
      aggregation = VariableAggregation.NONE
    return previous_getter(
        initial_value=initial_value,
        trainable=trainable,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        import_scope=import_scope,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        shape=shape,
        experimental_enable_variable_lifting=experimental_enable_variable_lifting,
        **kwargs
    )
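
  # Sketch of how the creator stack composed above gets populated from user
  # code. A minimal `tf.variable_creator_scope` example (values hypothetical);
  # each creator receives the next creator in the chain plus the constructor
  # keyword arguments:
  #
  #   def creator(next_creator, **kwargs):
  #     kwargs["trainable"] = False      # tweak arguments before delegating
  #     return next_creator(**kwargs)
  #
  #   with tf.variable_creator_scope(creator):
  #     v = tf.Variable(1.0)             # routed through `creator` first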

  class SaveSliceInfo:
    """Information on how to save this Variable as a slice.

    Provides internal support for saving variables as slices of a larger
    variable. This API is not public and is subject to change.

    Available properties:

    * full_name
    * full_shape
    * var_offset
    * var_shape
    """

    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
      """Create a `SaveSliceInfo`.

      Args:
        full_name: Name of the full variable of which this `Variable` is a
          slice.
        full_shape: Shape of the full variable, as a list of int.
        var_offset: Offset of this `Variable` into the full variable, as a list
          of int.
        var_shape: Shape of this `Variable`, as a list of int.
        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
          recreates the SaveSliceInfo object from its contents.
          `save_slice_info_def` and other arguments are mutually exclusive.
        import_scope: Optional `string`. Name scope to add. Only used when
          initializing from protocol buffer.
      """
      if save_slice_info_def:
        assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
        self.full_name = ops.prepend_name_scope(
            save_slice_info_def.full_name, import_scope=import_scope)
        self.full_shape = list(save_slice_info_def.full_shape)
        self.var_offset = list(save_slice_info_def.var_offset)
        self.var_shape = list(save_slice_info_def.var_shape)
      else:
        self.full_name = full_name
        self.full_shape = full_shape
        self.var_offset = var_offset
        self.var_shape = var_shape

    @property
    def spec(self):
      """Computes the spec string used for saving."""
      full_shape_str = " ".join("%d" % d for d in self.full_shape) + " "
      sl_spec = ":".join(
          "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape))
      return full_shape_str + sl_spec
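
    # Illustrative sketch (values are hypothetical): for a slice with
    # full_shape=[10, 20], var_offset=[0, 5] and var_shape=[10, 5], the `spec`
    # property above yields "10 20 0,10:5,5" -- the full shape first, then one
    # "offset,length" pair per dimension, joined by ":".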
|
||
|
|
||
|
def to_proto(self, export_scope=None):
|
||
|
"""Returns a SaveSliceInfoDef() proto.
|
||
|
|
||
|
Args:
|
||
|
export_scope: Optional `string`. Name scope to remove.
|
||
|
|
||
|
Returns:
|
||
|
A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
|
||
|
in the specified name scope.
|
||
|
"""
|
||
|
if (export_scope is None or self.full_name.startswith(export_scope)):
|
||
|
save_slice_info_def = variable_pb2.SaveSliceInfoDef()
|
||
|
save_slice_info_def.full_name = ops.strip_name_scope(
|
||
|
self.full_name, export_scope)
|
||
|
for i in self.full_shape:
|
||
|
save_slice_info_def.full_shape.append(i)
|
||
|
for i in self.var_offset:
|
||
|
save_slice_info_def.var_offset.append(i)
|
||
|
for i in self.var_shape:
|
||
|
save_slice_info_def.var_shape.append(i)
|
||
|
return save_slice_info_def
|
||
|
else:
|
||
|
return None
|
||
|
|
||
|
|
||
|
Variable._OverloadAllOperators() # pylint: disable=protected-access
|
||
|
|
||
|
|
||
|
def _try_guard_against_uninitialized_dependencies(name, initial_value):
|
||
|
"""Attempt to guard against dependencies on uninitialized variables.
|
||
|
|
||
|
Replace references to variables in `initial_value` with references to the
|
||
|
variable's initialized values. The initialized values are essentially
|
||
|
conditional TensorFlow graphs that return a variable's value if it is
|
||
|
initialized or its `initial_value` if it hasn't been initialized. This
|
||
|
replacement is done on a best effort basis:
|
||
|
|
||
|
- If the `initial_value` graph contains cycles, we don't do any
|
||
|
replacements for that graph.
|
||
|
- If the variables that `initial_value` depends on are not present in the
|
||
|
`GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them.
|
||
|
|
||
|
In these cases, it is up to the caller to ensure that the `initial_value`
|
||
|
graph uses initialized variables or that they guard access to variables
|
||
|
using their `initialized_value` method.
|
||
|
|
||
|
Args:
|
||
|
name: Variable name.
|
||
|
initial_value: `Tensor`. The initial value.
|
||
|
|
||
|
Returns:
|
||
|
A `Tensor` suitable to initialize a variable.
|
||
|
Raises:
|
||
|
TypeError: If `initial_value` is not a `Tensor`.
|
||
|
"""
|
||
|
if not isinstance(initial_value, tensor_lib.Tensor):
|
||
|
raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)
|
||
|
|
||
|
# Don't modify initial_value if it contains any cyclic dependencies.
|
||
|
if _has_cycle(initial_value.op, state={}):
|
||
|
return initial_value
|
||
|
return _safe_initial_value_from_tensor(name, initial_value, op_cache={})
|
||
|
|
||
|
|
||
|
_UNKNOWN, _STARTED, _FINISHED = range(3)
|
||
|
|
||
|
|
||
|
def _has_cycle(op, state):
|
||
|
"""Detect cycles in the dependencies of `initial_value`."""
|
||
|
op_state = state.get(op.name, _UNKNOWN)
|
||
|
if op_state == _STARTED:
|
||
|
return True
|
||
|
elif op_state == _FINISHED:
|
||
|
return False
|
||
|
|
||
|
state[op.name] = _STARTED
|
||
|
for i in itertools.chain((i.op for i in op.inputs), op.control_inputs):
|
||
|
if _has_cycle(i, state):
|
||
|
return True
|
||
|
state[op.name] = _FINISHED
|
||
|
return False
|
||
|
|
||
|
|
||
|
def _safe_initial_value_from_tensor(name, tensor, op_cache):
|
||
|
"""Replace dependencies on variables with their initialized values.
|
||
|
|
||
|
Args:
|
||
|
name: Variable name.
|
||
|
tensor: A `Tensor`. The tensor to replace.
|
||
|
op_cache: A dict mapping operation names to `Operation`s. Used to memoize
|
||
|
the results so as to avoid creating redundant operations.
|
||
|
|
||
|
Returns:
|
||
|
A `Tensor` compatible with `tensor`. Any inputs that lead to variable
|
||
|
values will be replaced with a corresponding graph that uses the
|
||
|
variable's initialized values. This is done on a best-effort basis. If no
|
||
|
modifications need to be made then `tensor` will be returned unchanged.
|
||
|
"""
|
||
|
op = tensor.op
|
||
|
new_op = op_cache.get(op.name)
|
||
|
if new_op is None:
|
||
|
new_op = _safe_initial_value_from_op(name, op, op_cache)
|
||
|
op_cache[op.name] = new_op
|
||
|
return new_op.outputs[tensor.value_index]
|
||
|
|
||
|
|
||
|
def _safe_initial_value_from_op(name, op, op_cache):
|
||
|
"""Replace dependencies on variables with their initialized values.
|
||
|
|
||
|
Args:
|
||
|
name: Variable name.
|
||
|
op: An `Operation`. The operation to replace.
|
||
|
op_cache: A dict mapping operation names to `Operation`s. Used to memoize
|
||
|
the results so as to avoid creating redundant operations.
|
||
|
|
||
|
Returns:
|
||
|
An `Operation` compatible with `op`. Any inputs that lead to variable
|
||
|
values will be replaced with a corresponding graph that uses the
|
||
|
variable's initialized values. This is done on a best-effort basis. If no
|
||
|
modifications need to be made then `op` will be returned unchanged.
|
||
|
"""
|
||
|
op_type = op.node_def.op
|
||
|
if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
|
||
|
"ReadVariableOp", "If"):
|
||
|
return op
|
||
|
|
||
|
# Attempt to find the initialized_value of any variable reference / handles.
|
||
|
# TODO(b/70206927): Fix handling of ResourceVariables.
|
||
|
if op_type in ("Variable", "VariableV2", "VarHandleOp"):
|
||
|
initialized_value = _find_initialized_value_for_variable(op)
|
||
|
return op if initialized_value is None else initialized_value.op
|
||
|
|
||
|
# Recursively build initializer expressions for inputs.
|
||
|
modified = False
|
||
|
new_op_inputs = []
|
||
|
for op_input in op.inputs:
|
||
|
new_op_input = _safe_initial_value_from_tensor(name, op_input, op_cache)
|
||
|
new_op_inputs.append(new_op_input)
|
||
|
modified = modified or (new_op_input != op_input)
|
||
|
|
||
|
# If at least one input was modified, replace the op.
|
||
|
if modified:
|
||
|
new_op_type = op_type
|
||
|
if new_op_type == "RefSwitch":
|
||
|
new_op_type = "Switch"
|
||
|
new_op_name = op.node_def.name + "_" + name
|
||
|
new_op_name = new_op_name.replace(":", "_")
|
||
|
return op.graph.create_op(
|
||
|
new_op_type,
|
||
|
new_op_inputs,
|
||
|
op._output_types, # pylint: disable=protected-access
|
||
|
name=new_op_name,
|
||
|
attrs=op.node_def.attr)
|
||
|
|
||
|
return op
|
||
|
|
||
|
|
||
|
def _find_initialized_value_for_variable(variable_op):
  """Find the initialized value for a variable op.

  To do so, look up the variable op in the variables collection.

  Args:
    variable_op: A variable `Operation`.

  Returns:
    A `Tensor` representing the initialized value for the variable or `None`
    if the initialized value could not be found.
  """
  try:
    var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"]
    for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
                            ops.GraphKeys.LOCAL_VARIABLES):
      for var in variable_op.graph.get_collection(collection_name):
        if var.name in var_names:
          return var.initialized_value()
  except AttributeError:
    # Return None when an incomplete user-defined variable type was put in
    # the collection.
    return None
  return None


class PartitionedVariable:
  """A container for partitioned `Variable` objects.

  @compatibility(eager) `tf.PartitionedVariable` is not compatible with
  eager execution. Use `tf.Variable` instead which is compatible
  with both eager execution and graph construction. See [the
  TensorFlow Eager Execution
  guide](https://www.tensorflow.org/guide/eager#variables_and_optimizers)
  for details on how variables work in eager execution.
  @end_compatibility
  """

  def __init__(self, name, shape, dtype, variable_list, partitions):
    """Creates a new partitioned variable wrapper.

    Variables passed via the variable_list must contain a save_slice_info
    field. Concatenation and iteration is in lexicographic order according
    to the var_offset property of the save_slice_info.

    Args:
      name: String. Overall name of the variables.
      shape: List of integers. Overall shape of the variables.
      dtype: Type of the variables.
      variable_list: List of `Variable` that comprise this partitioned
        variable.
      partitions: List of integers. Number of partitions for each dimension.

    Raises:
      TypeError: If `variable_list` is not a list of `Variable` objects, or
        `partitions` is not a list.
      ValueError: If `variable_list` is empty, or the `Variable` shape
        information does not match `shape`, or `partitions` has invalid values.
    """
    if not isinstance(variable_list, (list, tuple)):
      raise TypeError("variable_list is not a list or tuple: %s" %
                      variable_list)
    if not isinstance(partitions, (list, tuple)):
      raise TypeError("partitions is not a list or tuple: %s" % partitions)
    if not all(p >= 1 for p in partitions):
      raise ValueError("partition values must be positive: %s" % partitions)
    if not variable_list:
      raise ValueError("variable_list may not be empty")
    # pylint: disable=protected-access
    for v in variable_list:
      # Sort the variable_list lexicographically according to var offset value.
      if not all(v._get_save_slice_info() is not None for v in variable_list):
        raise ValueError(
            "All variables must have a save_slice_info available: %s" %
            [v.name for v in variable_list])
      if len(shape) != len(partitions):
        raise ValueError("len(shape) != len(partitions): %s vs. %s" %
                         (shape, partitions))
      if v._get_save_slice_info().full_shape != shape:
        raise ValueError("All variables' full shapes must match shape: %s; "
                         "but full shapes were: %s" %
                         (shape, str([v._get_save_slice_info().full_shape])))
    self._variable_list = sorted(
        variable_list, key=lambda v: v._get_save_slice_info().var_offset)
    # pylint: enable=protected-access

    self._name = name
    self._shape = shape
    self._dtype = dtype
    self._partitions = partitions
    self._as_tensor = None

  def __iter__(self):
    """Return an iterable for accessing the underlying partition Variables."""
    return iter(self._variable_list)

  def __len__(self):
    num_partition_axes = len(self._partition_axes())
    if num_partition_axes > 1:
      raise ValueError("Cannot get a length for %d > 1 partition axes" %
                       num_partition_axes)
    return len(self._variable_list)

  def _partition_axes(self):
    if all(p == 1 for p in self._partitions):
      return [0]
    else:
      return [i for i, p in enumerate(self._partitions) if p > 1]

  def _concat(self):
    """Returns the overall concatenated value as a `Tensor`.

    This is different from using the partitioned variable directly as a tensor
    (through tensor conversion and `as_tensor`) in that it creates a new set of
    operations that keeps the control dependencies from its scope.

    Returns:
      `Tensor` containing the concatenated value.
    """
    if len(self._variable_list) == 1:
      with ops.name_scope(None):
        return array_ops.identity(self._variable_list[0], name=self._name)

    partition_axes = self._partition_axes()

    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot concatenate along more than one dimension: %s. "
          "Multi-axis partition concat is not supported" % str(partition_axes))
    partition_ix = partition_axes[0]

    with ops.name_scope(self._name + "/ConcatPartitions/"):
      concatenated = array_ops.concat(self._variable_list, partition_ix)

    with ops.name_scope(None):
      return array_ops.identity(concatenated, name=self._name)

  def as_tensor(self):
    """Returns the overall concatenated value as a `Tensor`.

    The returned tensor will not inherit the control dependencies from the
    scope where the value is used, which is similar to getting the value of
    `Variable`.

    Returns:
      `Tensor` containing the concatenated value.
    """
    with ops.control_dependencies(None):
      return self._concat()

  @staticmethod
  def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
    # pylint: disable=invalid-name
    _ = name
    if dtype is not None and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if as_ref:
      raise NotImplementedError(
          "PartitionedVariable doesn't support being used as a reference.")
    else:
      return v.as_tensor()

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._dtype

  @property
  def shape(self):
    return self.get_shape()

  @property
  def _distribute_strategy(self):
    """The `tf.distribute.Strategy` that this variable was created under."""
    # NOTE(yuefengz): Today, no partitioned variables in a distribute strategy.
    return None

  def get_shape(self) -> tensor_shape.TensorShape:
    return self._shape

  def _get_variable_list(self):
    return self._variable_list

  def _get_partitions(self):
    return self._partitions

  def _apply_assign_fn(self, assign_fn, value):
    partition_axes = self._partition_axes()
    if len(partition_axes) > 1:
      raise NotImplementedError(
          "Cannot do assign action along more than one dimension: %s. "
          "Multi-axis partition assign action is not supported " %
          str(partition_axes))
    if isinstance(value, list):
      assert len(value) == len(self._variable_list)
      value_list = value
    elif isinstance(value, PartitionedVariable):
      value_list = list(value)
    else:
      partition_ix = partition_axes[0]
      size_splits_list = [
          tensor_shape.dimension_value(var.shape[partition_ix])
          for var in self._variable_list
      ]
      value_list = array_ops.split(value, size_splits_list, axis=partition_ix)

    op_list = [
        assign_fn(var, value_list[idx])
        for idx, var in enumerate(self._variable_list)
    ]
    return op_list

  def assign(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign_add(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    assign_fn = lambda var, r_value: var.assign_sub(
        r_value, use_locking=use_locking, name=name, read_value=read_value)
    assign_list = self._apply_assign_fn(assign_fn, value)
    if read_value:
      return assign_list
    return [assign.op for assign in assign_list]


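# A minimal sketch of the splitting rule `PartitionedVariable._apply_assign_fn`
# relies on when it is given a single full-shape value (the shapes, sizes and
# names here are illustrative): with `partitions=[1, 4]` the single partitioned
# axis is 1, and the value is split along that axis into per-shard pieces whose
# sizes come from the shard shapes.
def _example_partitioned_assign_split():
  partitions = [1, 4]
  partition_axes = [i for i, p in enumerate(partitions) if p > 1]  # -> [1]
  partition_ix = partition_axes[0]
  full_value = array_ops.ones([2, 8])
  size_splits = [2, 2, 2, 2]  # per-shard sizes along the partitioned axis
  shard_values = array_ops.split(full_value, size_splits, axis=partition_ix)
  # Each of the four pieces would then be passed to the matching shard's
  # `assign` by `_apply_assign_fn`.
  return shard_values

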
@tf_export(v1=["global_variables"])
|
||
|
def global_variables(scope=None):
|
||
|
"""Returns global variables.
|
||
|
|
||
|
Global variables are variables that are shared across machines in a
|
||
|
distributed environment. The `Variable()` constructor or `get_variable()`
|
||
|
automatically adds new variables to the graph collection
|
||
|
`GraphKeys.GLOBAL_VARIABLES`.
|
||
|
This convenience function returns the contents of that collection.
|
||
|
|
||
|
An alternative to global variables are local variables. See
|
||
|
`tf.compat.v1.local_variables`
|
||
|
|
||
|
@compatibility(TF2)
|
||
|
Not compatible with eager execution and `tf.function`. In particular, Graph
|
||
|
collections are deprecated in TF2. Instead please create a
|
||
|
[tf.Module](https://www.tensorflow.org/guide/intro_to_modules)
|
||
|
container for all your model state, including variables.
|
||
|
You can then list all the variables in your `tf.Module` through the
|
||
|
`variables` attribute.
|
||
|
@end_compatibility
|
||
|
|
||
|
Args:
|
||
|
scope: (Optional.) A string. If supplied, the resulting list is filtered to
|
||
|
include only items whose `name` attribute matches `scope` using
|
||
|
`re.match`. Items without a `name` attribute are never returned if a scope
|
||
|
is supplied. The choice of `re.match` means that a `scope` without special
|
||
|
tokens filters by prefix.
|
||
|
|
||
|
Returns:
|
||
|
A list of `Variable` objects.
|
||
|
"""
|
||
|
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
|
||
|
|
||
|
|
||
|
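# A short usage sketch for `global_variables` (v1 graph workflow; the scope
# string and variable names are illustrative, and variables are assumed to be
# added to the GLOBAL_VARIABLES collection by the `Variable` constructor as
# described above): variables created under a name scope can be listed by
# passing that prefix as `scope`.
def _example_global_variables_by_scope():
  graph = ops.Graph()
  with graph.as_default():
    with ops.name_scope("encoder"):
      Variable([1.0], name="w")
    Variable([2.0], name="b")
    # Only the variable whose name starts with "encoder" is returned.
    return global_variables(scope="encoder")

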
@tf_export(v1=["all_variables"])
|
||
|
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
|
||
|
def all_variables():
|
||
|
"""Use `tf.compat.v1.global_variables` instead."""
|
||
|
return global_variables()
|
||
|
|
||
|
|
||
|
def _all_saveable_objects(scope=None):
|
||
|
"""Returns all variables and `SaveableObject`s that must be checkpointed.
|
||
|
|
||
|
Args:
|
||
|
scope: (Optional.) A string. If supplied, the resulting list is filtered to
|
||
|
include only items whose `name` attribute matches `scope` using
|
||
|
`re.match`. Items without a `name` attribute are never returned if a scope
|
||
|
is supplied. The choice of `re.match` means that a `scope` without special
|
||
|
tokens filters by prefix.
|
||
|
|
||
|
Returns:
|
||
|
A list of `Variable` and `SaveableObject` to be checkpointed
|
||
|
"""
|
||
|
# TODO(andreasst): make this function public once things are settled.
|
||
|
return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) +
|
||
|
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
|
||
|
|
||
|
|
||
|
@tf_export(v1=["local_variables"])
|
||
|
def local_variables(scope=None):
|
||
|
"""Returns local variables.
|
||
|
|
||
|
Local variables - per process variables, usually not saved/restored to
|
||
|
checkpoint and used for temporary or intermediate values.
|
||
|
For example, they can be used as counters for metrics computation or
|
||
|
number of epochs this machine has read data.
|
||
|
The `tf.contrib.framework.local_variable()` function automatically adds the
|
||
|
new variable to `GraphKeys.LOCAL_VARIABLES`.
|
||
|
This convenience function returns the contents of that collection.
|
||
|
|
||
|
An alternative to local variables are global variables. See
|
||
|
`tf.compat.v1.global_variables`
|
||
|
|
||
|
Args:
|
||
|
scope: (Optional.) A string. If supplied, the resulting list is filtered to
|
||
|
include only items whose `name` attribute matches `scope` using
|
||
|
`re.match`. Items without a `name` attribute are never returned if a scope
|
||
|
is supplied. The choice of `re.match` means that a `scope` without special
|
||
|
tokens filters by prefix.
|
||
|
|
||
|
Returns:
|
||
|
A list of local `Variable` objects.
|
||
|
"""
|
||
|
return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
|
||
|
|
||
|
|
||
|
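# A short sketch of how a variable ends up in `local_variables` (illustrative
# names; here the variable is added to the LOCAL_VARIABLES collection
# explicitly, which is what the v1 helpers that create local variables do on
# the caller's behalf).
def _example_local_variables():
  graph = ops.Graph()
  with graph.as_default():
    counter = Variable(0, name="epochs_seen")
    graph.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, counter)
    return local_variables()

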
@tf_export(v1=["model_variables"])
|
||
|
def model_variables(scope=None):
|
||
|
"""Returns all variables in the MODEL_VARIABLES collection.
|
||
|
|
||
|
Args:
|
||
|
scope: (Optional.) A string. If supplied, the resulting list is filtered to
|
||
|
include only items whose `name` attribute matches `scope` using
|
||
|
`re.match`. Items without a `name` attribute are never returned if a scope
|
||
|
is supplied. The choice of `re.match` means that a `scope` without special
|
||
|
tokens filters by prefix.
|
||
|
|
||
|
Returns:
|
||
|
A list of local Variable objects.
|
||
|
"""
|
||
|
return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
|
||
|
|
||
|
|
||
|
@tf_export(v1=["trainable_variables"])
|
||
|
def trainable_variables(scope=None):
|
||
|
"""Returns all variables created with `trainable=True`.
|
||
|
|
||
|
When passed `trainable=True`, the `Variable()` constructor automatically
|
||
|
adds new variables to the graph collection
|
||
|
`GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
|
||
|
contents of that collection.
|
||
|
|
||
|
@compatibility(TF2)
|
||
|
Not compatible with eager execution and `tf.function`. In particular, Graph
|
||
|
collections are deprecated in TF2. Instead please create a `tf.Module`
|
||
|
container for all your model state, including variables.
|
||
|
You can then list all the trainable variables in your `tf.Module` through the
|
||
|
`trainable_variables` attribute.
|
||
|
@end_compatibility
|
||
|
|
||
|
Args:
|
||
|
scope: (Optional.) A string. If supplied, the resulting list is filtered to
|
||
|
include only items whose `name` attribute matches `scope` using
|
||
|
`re.match`. Items without a `name` attribute are never returned if a scope
|
||
|
is supplied. The choice of `re.match` means that a `scope` without special
|
||
|
tokens filters by prefix.
|
||
|
|
||
|
Returns:
|
||
|
A list of Variable objects.
|
||
|
"""
|
||
|
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
|
||
|
|
||
|
|
||
|
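# A short sketch of the TF2 replacement suggested in the docstring above
# (illustrative; the import path for `tf.Module` is an assumption): track state
# in a `tf.Module` and read `trainable_variables` from the module rather than
# from graph collections.
def _example_trainable_variables_tf2():
  from tensorflow.python.module import module as module_lib  # assumed path

  class Dense(module_lib.Module):

    def __init__(self):
      super().__init__()
      self.w = Variable([[1.0]], trainable=True, name="w")
      self.b = Variable([0.0], trainable=False, name="b")

  layer = Dense()
  # Only `w` is listed; `b` was created with trainable=False.
  return layer.trainable_variables

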
@tf_export(v1=["moving_average_variables"])
|
||
|
def moving_average_variables(scope=None):
|
||
|
"""Returns all variables that maintain their moving averages.
|
||
|
|
||
|
If an `ExponentialMovingAverage` object is created and the `apply()`
|
||
|
method is called on a list of variables, these variables will
|
||
|
be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
|
||
|
This convenience function returns the contents of that collection.
|
||
|
|
||
|
Args:
|
||
|
scope: (Optional.) A string. If supplied, the resulting list is filtered to
|
||
|
include only items whose `name` attribute matches `scope` using
|
||
|
`re.match`. Items without a `name` attribute are never returned if a scope
|
||
|
is supplied. The choice of `re.match` means that a `scope` without special
|
||
|
tokens filters by prefix.
|
||
|
|
||
|
Returns:
|
||
|
A list of Variable objects.
|
||
|
"""
|
||
|
return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)
|
||
|
|
||
|
|
||
|
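# A minimal sketch of the workflow the docstring above describes (illustrative;
# the import path for ExponentialMovingAverage is an assumption): calling
# `apply()` on a variable is what places it in the MOVING_AVERAGE_VARIABLES
# collection that this helper reads back.
def _example_moving_average_variables():
  from tensorflow.python.training import moving_averages  # assumed path
  graph = ops.Graph()
  with graph.as_default():
    w = Variable(1.0, name="w")
    ema = moving_averages.ExponentialMovingAverage(decay=0.99)
    ema.apply([w])  # adds `w` to the MOVING_AVERAGE_VARIABLES collection
    return moving_average_variables()

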
@tf_export(v1=["initializers.variables", "variables_initializer"])
|
||
|
def variables_initializer(var_list, name="init"):
|
||
|
"""Returns an Op that initializes a list of variables.
|
||
|
|
||
|
After you launch the graph in a session, you can run the returned Op to
|
||
|
initialize all the variables in `var_list`. This Op runs all the
|
||
|
initializers of the variables in `var_list` in parallel.
|
||
|
|
||
|
Calling `initialize_variables()` is equivalent to passing the list of
|
||
|
initializers to `Group()`.
|
||
|
|
||
|
If `var_list` is empty, however, the function still returns an Op that can
|
||
|
be run. That Op just has no effect.
|
||
|
|
||
|
@compatibility(TF2)
|
||
|
In TF2, variables are initialized immediately when they are created. There is
|
||
|
no longer a need to run variable initializers before using them.
|
||
|
@end_compatibility
|
||
|
|
||
|
Args:
|
||
|
var_list: List of `Variable` objects to initialize.
|
||
|
name: Optional name for the returned operation.
|
||
|
|
||
|
Returns:
|
||
|
An Op that run the initializers of all the specified variables.
|
||
|
"""
|
||
|
if var_list and not context.executing_eagerly():
|
||
|
return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
|
||
|
return control_flow_ops.no_op(name=name)
|
||
|
|
||
|
|
||
|
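# A minimal v1-style usage sketch for `variables_initializer` (illustrative
# names; the session import path is an assumption and is kept local so the
# module-level imports are unchanged).
def _example_variables_initializer():
  from tensorflow.python.client import session as session_lib  # assumed path
  graph = ops.Graph()
  with graph.as_default():
    w = Variable([1.0], name="w")
    b = Variable([0.0], name="b")
    init_op = variables_initializer([w, b], name="init_wb")
    with session_lib.Session(graph=graph) as sess:
      sess.run(init_op)
      return sess.run([w, b])

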
@tf_export(v1=["initialize_variables"])
|
||
|
@tf_should_use.should_use_result
|
||
|
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
|
||
|
def initialize_variables(var_list, name="init"):
|
||
|
"""See `tf.compat.v1.variables_initializer`."""
|
||
|
return variables_initializer(var_list, name=name)
|
||
|
|
||
|
|
||
|
@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
|
||
|
def global_variables_initializer():
|
||
|
"""Returns an Op that initializes global variables.
|
||
|
|
||
|
This is just a shortcut for `variables_initializer(global_variables())`
|
||
|
|
||
|
@compatibility(TF2)
|
||
|
In TF2, variables are initialized immediately when they are created. There is
|
||
|
no longer a need to run variable initializers before using them.
|
||
|
@end_compatibility
|
||
|
|
||
|
Returns:
|
||
|
An Op that initializes global variables in the graph.
|
||
|
"""
|
||
|
if context.executing_eagerly():
|
||
|
return control_flow_ops.no_op(name="global_variables_initializer")
|
||
|
return variables_initializer(global_variables())
|
||
|
|
||
|
|
||
|
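# The classic v1 pattern, as a sketch (illustrative names; the session import
# path is an assumption): run the returned op once, before the first use of any
# global variable in the session.
def _example_global_variables_initializer():
  from tensorflow.python.client import session as session_lib  # assumed path
  graph = ops.Graph()
  with graph.as_default():
    total = Variable(0.0, name="total")
    update = total.assign_add(1.0)
    with session_lib.Session(graph=graph) as sess:
      sess.run(global_variables_initializer())
      return sess.run(update)

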
@tf_export(v1=["initialize_all_variables"])
|
||
|
@tf_should_use.should_use_result
|
||
|
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
|
||
|
def initialize_all_variables():
|
||
|
"""See `tf.compat.v1.global_variables_initializer`."""
|
||
|
return global_variables_initializer()
|
||
|
|
||
|
|
||
|
@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
|
||
|
def local_variables_initializer():
|
||
|
"""Returns an Op that initializes all local variables.
|
||
|
|
||
|
This is just a shortcut for `variables_initializer(local_variables())`
|
||
|
|
||
|
@compatibility(TF2)
|
||
|
In TF2, variables are initialized immediately when they are created. There is
|
||
|
no longer a need to run variable initializers before using them.
|
||
|
@end_compatibility
|
||
|
|
||
|
Returns:
|
||
|
An Op that initializes all local variables in the graph.
|
||
|
"""
|
||
|
if context.executing_eagerly():
|
||
|
return control_flow_ops.no_op(name="local_variables_initializer")
|
||
|
return variables_initializer(local_variables())
|
||
|
|
||
|
|
||
|
@tf_export(v1=["initialize_local_variables"])
|
||
|
@tf_should_use.should_use_result
|
||
|
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
|
||
|
def initialize_local_variables():
|
||
|
"""See `tf.compat.v1.local_variables_initializer`."""
|
||
|
return local_variables_initializer()
|
||
|
|
||
|
|
||
|
@tf_export(v1=["assert_variables_initialized"])
|
||
|
@tf_should_use.should_use_result
|
||
|
def assert_variables_initialized(var_list=None):
|
||
|
"""Returns an Op to check if variables are initialized.
|
||
|
|
||
|
NOTE: This function is obsolete and will be removed in 6 months. Please
|
||
|
change your implementation to use `report_uninitialized_variables()`.
|
||
|
|
||
|
When run, the returned Op will raise the exception `FailedPreconditionError`
|
||
|
if any of the variables has not yet been initialized.
|
||
|
|
||
|
Note: This function is implemented by trying to fetch the values of the
|
||
|
variables. If one of the variables is not initialized a message may be
|
||
|
logged by the C++ runtime. This is expected.
|
||
|
|
||
|
Args:
|
||
|
var_list: List of `Variable` objects to check. Defaults to the value of
|
||
|
`global_variables().`
|
||
|
|
||
|
Returns:
|
||
|
An Op, or None if there are no variables.
|
||
|
"""
|
||
|
if var_list is None:
|
||
|
var_list = global_variables() + local_variables()
|
||
|
# Backwards compatibility for old-style variables. TODO(touts): remove.
|
||
|
if not var_list:
|
||
|
var_list = []
|
||
|
for op in ops.get_default_graph().get_operations():
|
||
|
if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
|
||
|
var_list.append(op.outputs[0])
|
||
|
if not var_list:
|
||
|
return None
|
||
|
else:
|
||
|
ranks = []
|
||
|
for var in var_list:
|
||
|
with ops.colocate_with(var.op):
|
||
|
ranks.append(array_ops.rank_internal(var, optimize=False))
|
||
|
if len(ranks) == 1:
|
||
|
return ranks[0]
|
||
|
else:
|
||
|
return array_ops_stack.stack(ranks)
|
||
|
|
||
|
|
||
|
@tf_export(v1=["report_uninitialized_variables"])
|
||
|
@tf_should_use.should_use_result
|
||
|
def report_uninitialized_variables(var_list=None,
|
||
|
name="report_uninitialized_variables"):
|
||
|
"""Adds ops to list the names of uninitialized variables.
|
||
|
|
||
|
When run, it returns a 1-D tensor containing the names of uninitialized
|
||
|
variables if there are any, or an empty array if there are none.
|
||
|
|
||
|
Args:
|
||
|
var_list: List of `Variable` objects to check. Defaults to the value of
|
||
|
`global_variables() + local_variables()`
|
||
|
name: Optional name of the `Operation`.
|
||
|
|
||
|
Returns:
|
||
|
A 1-D tensor containing names of the uninitialized variables, or an empty
|
||
|
1-D tensor if there are no variables or no uninitialized variables.
|
||
|
"""
|
||
|
if var_list is None:
|
||
|
var_list = global_variables() + local_variables()
|
||
|
# Backwards compatibility for old-style variables. TODO(touts): remove.
|
||
|
if not var_list:
|
||
|
var_list = []
|
||
|
for op in ops.get_default_graph().get_operations():
|
||
|
if op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
|
||
|
var_list.append(op.outputs[0])
|
||
|
with ops.name_scope(name):
|
||
|
# Run all operations on CPU
|
||
|
if var_list:
|
||
|
init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
|
||
|
local_device = os.environ.get(
|
||
|
"TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
|
||
|
with ops.device(local_device):
|
||
|
if not var_list:
|
||
|
# Return an empty tensor so we only need to check for returned tensor
|
||
|
# size being 0 as an indication of model ready.
|
||
|
return array_ops.constant([], dtype=dtypes.string)
|
||
|
else:
|
||
|
# Get a 1-D boolean tensor listing whether each variable is initialized.
|
||
|
variables_mask = math_ops.logical_not(array_ops_stack.stack(init_vars))
|
||
|
# Get a 1-D string tensor containing all the variable names.
|
||
|
variable_names_tensor = array_ops.constant(
|
||
|
[s.op.name for s in var_list])
|
||
|
# Return a 1-D tensor containing all the names of
|
||
|
# uninitialized variables.
|
||
|
return array_ops.boolean_mask(variable_names_tensor, variables_mask)
|
||
|
|
||
|
|
||
|
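# A minimal v1-style sketch for `report_uninitialized_variables` (illustrative
# names; the session import path is an assumption): before the initializer runs
# the op reports both variable names, afterwards it returns an empty tensor.
def _example_report_uninitialized_variables():
  from tensorflow.python.client import session as session_lib  # assumed path
  graph = ops.Graph()
  with graph.as_default():
    w = Variable([1.0], name="w")
    b = Variable([0.0], name="b")
    report = report_uninitialized_variables([w, b])
    with session_lib.Session(graph=graph) as sess:
      names_before = sess.run(report)  # e.g. [b"w", b"b"]
      sess.run(variables_initializer([w, b]))
      names_after = sess.run(report)   # empty once everything is initialized
      return names_before, names_after

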
tensor_conversion_registry.register_tensor_conversion_function(
    PartitionedVariable, PartitionedVariable._TensorConversionFunction)  # pylint: disable=protected-access