614 lines
21 KiB
Python
614 lines
21 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Contains AutoCastVariable, a variable which automatically casts itself."""
|
|
|
|
import threading
|
|
|
|
import tensorflow.compat.v2 as tf
|
|
|
|
from keras.distribute import distributed_training_utils
|
|
|
|
# Thread-local holder for the active autocast setting.
# _autocast_dtype.dtype is the dtype AutoCastVariables should be cast to, or
# None if AutoCastVariables should not be cast. It is set and restored by the
# `enable_auto_cast_variables` context manager; because it is a
# `threading.local`, each thread sees its own value (the attribute is absent
# until first set in a thread, so readers use getattr(..., "dtype", None)).
_autocast_dtype = threading.local()
|
|
|
|
|
|
def numpy_text(tensor, is_repr=False):
    """Returns a human-readable string for the numpy value of `tensor`.

    Args:
      tensor: An eager tensor-like object exposing `dtype.is_numpy_compatible`
        and `_numpy()`.
      is_repr: If True, the value is formatted with `repr`; otherwise with
        `str`.

    Returns:
      The formatted value, or "<unprintable>" when the dtype is not numpy
      compatible. Multi-line text gets a leading newline so it starts on its
      own line when embedded in a larger repr.
    """
    if not tensor.dtype.is_numpy_compatible:
        text = "<unprintable>"
    else:
        formatter = repr if is_repr else str
        text = formatter(tensor._numpy())
    return "\n" + text if "\n" in text else text
|
|
|
|
|
|
class AutoCastVariableSpec(tf.types.experimental.TraceType):
    """TraceType describing an AutoCastVariable for tf.function tracing.

    The spec holds the AutoCastVariable itself and returns it unchanged as the
    placeholder value. Equality is object identity of the spec, so two specs
    created separately never compare equal, even for the same variable.
    NOTE(review): this presumably forces a fresh trace per call (e.g. so a
    graph traced under one autocast dtype is not reused under another) —
    confirm before relaxing `__eq__`.
    """

    def __init__(self, value):
        self._value = value

    def is_subtype_of(self, other) -> bool:
        """Returns True only when `other` equals this spec."""
        return self == other

    def most_specific_common_supertype(self, others):
        """Returns `self` if every spec in `others` equals it, else None."""
        for candidate in others:
            if not self == candidate:
                return None
        return self

    def placeholder_value(self, placeholder_context=None):
        """Returns the wrapped AutoCastVariable as its own placeholder."""
        del placeholder_context  # Unused.
        return self._value

    def _to_tensors(self, value):
        # The variable contributes no tensor inputs to the traced function.
        del value  # Unused.
        return []

    def __hash__(self) -> int:
        # Hash by the wrapped object's identity; consistent with __eq__, since
        # identical specs always wrap the same value.
        return hash(id(self._value))

    def __eq__(self, other) -> bool:
        return other is self
|
|
|
|
|
|
class AutoCastVariable(tf.Variable, tf.__internal__.types.Tensor):
    """Variable that casts itself to a different dtype in applicable contexts.

    This class wraps a floating-point `tf.Variable`. It emulates the variable
    interface and delegates to the wrapped variable, but it additionally will
    cast the wrapped variable under an `enable_auto_cast_variables(dtype)`
    context manager.

    For example:

    >>> v = tf.Variable(1.0, dtype=tf.float32)
    >>> v = AutoCastVariable(v)
    >>> tf.identity(v).dtype
    tf.float32
    >>> with enable_auto_cast_variables(tf.float16):
    ...   tf.identity(v).dtype
    tf.float16

    The purpose of this class is to allow Keras layers to create variables in
    float32, and automatically cast them to float16 or bfloat16 when the layer
    is called.
    """

    def __init__(self, variable):
        """Creates an AutoCastVariable instance.

        Args:
          variable: A floating-point resource variable to wrap.

        Raises:
          ValueError: If `variable` is not a `tf.Variable`, or its dtype is not
            floating-point.
        """
        if not isinstance(variable, tf.Variable):
            # Wrap in a 1-tuple so %-formatting never tries to unpack a tuple
            # argument (which would raise TypeError instead of ValueError).
            raise ValueError(
                "variable must be of type tf.ResourceVariable, but got: %s"
                % (variable,)
            )
        if not variable.dtype.is_floating:
            raise ValueError(
                "variable must be a floating point variable but has type: %s"
                % variable.dtype.name
            )
        self._variable = variable
        # 'delegate' means AutoCastVariable.op return self._variable.op, which
        # will raise an AttributeError in Eager (as intended). If set to any
        # other value, AutoCastVariable.op returns that value instead, which is
        # used to set the op attribute in AutoCastVariable.assign().
        self._op = "delegate"

    def _should_cast(self):
        """Returns True if this variable should be casted when accessed."""
        autocast_dtype = getattr(_autocast_dtype, "dtype", None)
        return autocast_dtype is not None and self.dtype != autocast_dtype

    @property
    def dtype(self):
        """The dtype of the underlying variable, before any casts are done."""
        return self._variable.dtype

    @property
    def true_dtype(self):
        """Deprecated alias of `dtype`."""
        return self._variable.dtype

    @property
    def _cast_dtype(self):
        """The active autocast dtype, or the variable's own dtype if none."""
        dtype = getattr(_autocast_dtype, "dtype", None)
        return dtype or self._variable.dtype

    def value(self):
        val = self._variable.value()
        if not self._should_cast():
            return val
        return tf.cast(val, self._cast_dtype)

    def read_value(self):
        val = self._variable.read_value()
        return tf.cast(val, self._cast_dtype)

    def sparse_read(self, indices, name=None):
        """Reads the value of this variable sparsely, using `gather`."""
        val = self._variable.sparse_read(indices, name=name)
        return tf.cast(val, self._cast_dtype)

    def gather_nd(self, indices, name=None):
        """Gather slices of the variable into a Tensor."""
        val = self._variable.gather_nd(indices, name=name)
        return tf.cast(val, self._cast_dtype)

    def __getattr__(self, name):
        """Delegates unknown attribute lookups to the wrapped variable."""
        # __getattr__ only runs when normal lookup fails. If `_variable` itself
        # is missing (e.g. an instance created via __new__ without __init__,
        # as copy/unpickling does), delegating would recurse forever.
        if name == "_variable":
            raise AttributeError(name)
        return getattr(self._variable, name)

    def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
        """Converts this variable to a tensor."""
        if as_ref:
            # This ValueError should not occur in practice since it is
            # impossible to pass as_ref=True using public APIs.
            raise ValueError(
                "Cannot convert AutoCastVariable to a tensor if "
                "as_ref=True is passed to convert_to_tensor"
            )
        if not self._should_cast():
            return tf.convert_to_tensor(self._variable, dtype=dtype, name=name)
        if dtype is not None and not dtype.is_compatible_with(self._cast_dtype):
            raise ValueError(
                "Incompatible type conversion requested to type {!r} for "
                "AutoCastVariable which is casted to type {!r}".format(
                    dtype.name, self._cast_dtype.name
                )
            )
        val = tf.convert_to_tensor(
            self._variable, dtype=self._variable.dtype, name=name
        )
        return tf.cast(val, self._cast_dtype)

    def _should_act_as_resource_variable(self):
        """Pass resource_variable_ops.is_resource_variable check."""
        pass

    def __repr__(self):
        if tf.executing_eagerly() and not self._in_graph_mode:
            repr_str = (
                "<AutoCastVariable '{v.name}' shape={v.shape} "
                "dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}, "
                "numpy={np_repr}>"
            )
            return repr_str.format(
                v=self, np_repr=numpy_text(self.read_value(), is_repr=True)
            )
        else:
            repr_str = (
                "<AutoCastVariable '{v.name}' shape={v.shape} "
                "dtype={v.dtype.name} dtype_to_cast_to={v._cast_dtype.name}>"
            )
            return repr_str.format(v=self)

    # Method delegations: We delegate the following methods to self._variable.
    # Each of these methods simply calls the same method on self._variable. The
    # base Variable raises NotImplementedError for most of these, so we must
    # override them.
    #
    # We do not define the following methods from Variable for the following
    # reasons:
    #   * 'count_up_to': This method only applies to int variables, which cannot
    #     be wrapped with an AutoCastVariable.
    #   * 'ref': Instead we inherit the definition from Variable.
    #     If we defined and delegated to Variable, the ref of an
    #     AutoCastVariable would be the same as the ref of the underlying
    #     variable, which would be strange as they are different Python objects.

    def set_shape(self, shape):
        # `self._variable.set_shape` is already bound to the wrapped variable,
        # so only `shape` is passed (passing `self` too raised TypeError).
        return self._variable.set_shape(shape)

    @property
    def trainable(self):
        return self._variable.trainable

    @property
    def synchronization(self):
        return self._variable.synchronization

    @property
    def aggregation(self):
        return self._variable.aggregation

    def eval(self, session=None):
        return self._variable.eval(session)

    def initialized_value(self):
        return self._variable.initialized_value()

    @property
    def initial_value(self):
        return self._variable.initial_value

    @property
    def constraint(self):
        return self._variable.constraint

    def _apply_assign_update(
        self, update_fn, value, use_locking=None, name=None, read_value=True
    ):
        """Runs an assign-style `update_fn` and wraps its result if needed."""
        # TODO(b/146181571): This logic can be simplified once
        # DistributedVariable.assign returns a DistributedVariable. Currently
        # for MirroredStrategy, it returns a Mirrored value.
        if tf.compat.v1.executing_eagerly_outside_functions():
            assign_op = update_fn(value, use_locking, name, False)
            if read_value:
                # We create a new AutoCastVariable with the same underlying
                # tf.Variable. The new AutoCastVariable is identical except the
                # 'op' attribute is defined. This matches the behavior of
                # tf.Variable.assign.
                var = create_autocast_variable(self._variable)
                var._op = assign_op
                return var
            return assign_op

        # Fallback to wrapping the returned variable in graph mode if possible
        assign_var = update_fn(value, use_locking, name, read_value)
        if read_value and tf.__internal__.ops.is_resource_variable(assign_var):
            return create_autocast_variable(assign_var)
        return assign_var

    def _apply_update(self, update_fn, *args, **kwargs):
        """Runs a scatter-style `update_fn` and wraps its result if needed."""
        update_var = update_fn(*args, **kwargs)
        if tf.compat.v1.executing_eagerly_outside_functions():
            return self

        # Fallback to wrapping the returned variable in graph mode if possible
        if tf.__internal__.ops.is_resource_variable(update_var):
            return create_autocast_variable(update_var)
        return update_var

    def assign(self, value, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign, value, use_locking, name, read_value
        )

    def assign_add(self, delta, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign_add, delta, use_locking, name, read_value
        )

    def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
        return self._apply_assign_update(
            self._variable.assign_sub, delta, use_locking, name, read_value
        )

    def scatter_sub(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_sub, sparse_delta, use_locking, name
        )

    def scatter_add(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_add, sparse_delta, use_locking, name
        )

    def scatter_max(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_max, sparse_delta, use_locking, name
        )

    def scatter_min(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_min, sparse_delta, use_locking, name
        )

    def scatter_mul(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_mul, sparse_delta, use_locking, name
        )

    def scatter_div(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_div, sparse_delta, use_locking, name
        )

    def scatter_update(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.scatter_update, sparse_delta, use_locking, name
        )

    def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
        return self._apply_update(
            self._variable.batch_scatter_update, sparse_delta, use_locking, name
        )

    def scatter_nd_sub(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_sub, indices, updates, name
        )

    def scatter_nd_add(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_add, indices, updates, name
        )

    def scatter_nd_update(self, indices, updates, name=None):
        return self._apply_update(
            self._variable.scatter_nd_update, indices, updates, name
        )

    def load(self, value, session=None):
        return self._variable.load(value, session)

    @property
    def name(self):
        return self._variable.name

    @property
    def _shared_name(self):
        return self._variable._shared_name

    @property
    def initializer(self):
        return self._variable.initializer

    @property
    def device(self):
        return self._variable.device

    @property
    def op(self):
        if self._op == "delegate":
            return self._variable.op
        return self._op

    def _as_graph_element(self):
        graph_element = self._variable._as_graph_element()
        if graph_element is None:
            return self._op
        return graph_element

    @property
    def graph(self):
        return self._variable.graph

    @property
    def shape(self):
        return self._variable.shape

    def get_shape(self):
        return self._variable.get_shape()

    def __tf_tracing_type__(self, context):
        return AutoCastVariableSpec(self)

    def _gather_saveables_for_checkpoint(self):
        # By delegating this method to the wrapped variable, checkpoints with
        # AutoCastVariables are identical to checkpoints with normal variables.
        # Therefore models checkpointed with AutoCastVariables can be restored
        # on models with normal variables, and vice versa.
        return self._variable._gather_saveables_for_checkpoint()

    def _export_to_saved_model_graph(
        self, object_map, tensor_map, options, **kwargs
    ):
        # By delegating this method to the wrapped variable, SavedModel with
        # AutoCastVariables are identical to SavedModel with normal variables.
        resource_list = self._variable._export_to_saved_model_graph(
            object_map, tensor_map, options, **kwargs
        )
        object_map[self] = object_map[self._variable]
        return resource_list

    # TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
    # to_proto().
    def to_proto(self, export_scope=None):
        return self._variable.to_proto(export_scope)

    def from_proto(self, variable_def, import_scope=None):
        return self._variable.from_proto(variable_def, import_scope)

    # Delegate the private attributes _handle_name and _initializer_op to
    # self._variable. SavedModel sets these attributes when loading a model. For
    # example, it sets _handle_name here:
    # https://github.com/tensorflow/tensorflow/blob/db26bd574fa95b5bdd53c08463dd19407cc0297e/tensorflow/python/keras/saving/saved_model/load.py#L211
    # We need to expose these attributes on AutoCastVariable as well for
    # SavedModel to work properly.
    # TODO(reedwm/kathywu): Find a better way to support SavedModel. Exposing
    # private attributes is hacky and difficult to maintain.
    @property
    def _handle_name(self):
        return self._variable._handle_name

    @_handle_name.setter
    def _handle_name(self, handle_name):
        self._variable._handle_name = handle_name

    @property
    def _initializer_op(self):
        return self._variable._initializer_op

    @_initializer_op.setter
    def _initializer_op(self, initializer_op):
        self._variable._initializer_op = initializer_op

    # Operator overloads:
    # Note we only overload operators that support floating-point types, as
    # non-float variables cannot be wrapped with an AutoCastVariable.
    # Also note: We call read_value() instead of value(), because value() causes
    # gradients not to work properly when TPUStrategy is used: b/143380936

    def __add__(self, o):
        return self.read_value() + o

    def __radd__(self, o):
        return o + self.read_value()

    def __sub__(self, o):
        return self.read_value() - o

    def __rsub__(self, o):
        return o - self.read_value()

    def __mul__(self, o):
        return self.read_value() * o

    def __rmul__(self, o):
        return o * self.read_value()

    def __truediv__(self, o):
        return self.read_value() / o

    def __rtruediv__(self, o):
        return o / self.read_value()

    def __floordiv__(self, o):
        return self.read_value() // o

    def __rfloordiv__(self, o):
        return o // self.read_value()

    def __mod__(self, o):
        return self.read_value() % o

    def __rmod__(self, o):
        return o % self.read_value()

    def __lt__(self, o):
        return self.read_value() < o

    def __le__(self, o):
        return self.read_value() <= o

    def __gt__(self, o):
        return self.read_value() > o

    def __ge__(self, o):
        return self.read_value() >= o

    def __getitem__(self, o):
        return self.read_value()[o]

    def __pow__(self, o, modulo=None):
        return pow(self.read_value(), o, modulo)

    def __rpow__(self, o):
        return pow(o, self.read_value())

    def __neg__(self):
        return -self.read_value()

    def __abs__(self):
        return abs(self.read_value())

    def __div__(self, o):
        try:
            return self.read_value().__div__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __rdiv__(self, o):
        try:
            return self.read_value().__rdiv__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __matmul__(self, o):
        try:
            return self.read_value().__matmul__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented

    def __rmatmul__(self, o):
        try:
            return self.read_value().__rmatmul__(o)
        except AttributeError:
            # See
            # https://docs.python.org/3/library/constants.html#NotImplemented
            return NotImplemented
|
|
|
|
|
|
# Let tf.convert_to_tensor (and ops that call it) accept AutoCastVariables,
# routing the conversion through AutoCastVariable._dense_var_to_tensor so the
# autocast dtype is applied.
tf.register_tensor_conversion_function(
    base_type=AutoCastVariable,
    conversion_func=AutoCastVariable._dense_var_to_tensor,
)
|
|
|
|
|
|
def create_autocast_variable(variable):
    """Wraps `variable` in an AutoCastVariable.

    For ordinary variables this is simply `AutoCastVariable(variable)`. When
    `distributed_training_utils.is_distributed_variable(variable)` is true, a
    class deriving from both AutoCastVariable and `variable.__class__` is built
    on the fly. The result therefore still passes
    `isinstance(result, variable.__class__)`, which DistributedVariables and
    their subclasses require in order to work properly.

    Args:
      variable: A floating-point resource variable to wrap.

    Returns:
      An AutoCastVariable that wraps the variable.
    """
    if distributed_training_utils.is_distributed_variable(variable):

        class AutoCastDistributedVariable(AutoCastVariable, variable.__class__):
            """An AutoCastVariable that also subclasses from variable.__class__.

            variable.__class__ is either a DistributedVariable or an
            AggregatingVariable.
            """

            def __repr__(self):
                template = (
                    "<AutoCastDistributedVariable dtype={v.dtype.name} "
                    "dtype_to_cast_to={v._cast_dtype.name} "
                    "inner_variable={v._variable}>"
                )
                return template.format(v=self)

        return AutoCastDistributedVariable(variable)

    return AutoCastVariable(variable)
|
|
|
|
|
|
class enable_auto_cast_variables:
    """Context manager that turns on autocasting of `AutoCastVariable`s.

    While active, `AutoCastVariable`s are cast to `dtype` when `dtype` is a
    floating-point dtype; for any other dtype they are not cast. On exit the
    setting that was active on entry is restored.
    """

    __slots__ = ["_dtype", "_prev_dtype"]

    def __init__(self, dtype):
        # A non-floating dtype disables casting (stored as None); a falsy
        # dtype is stored unchanged.
        self._dtype = dtype if not dtype or dtype.is_floating else None

    def __enter__(self):
        previous = getattr(_autocast_dtype, "dtype", None)
        self._prev_dtype = previous
        _autocast_dtype.dtype = self._dtype

    def __exit__(self, type_arg, value_arg, traceback_arg):
        # Put back whatever setting was active when the context was entered.
        _autocast_dtype.dtype = self._prev_dtype
|