3RNN/Lib/site-packages/tensorflow/python/distribute/sharded_variable.py
2024-05-26 19:49:15 +02:00

1041 lines
37 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
import copy
import math
from typing import Sequence
import weakref
import numpy as np
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Examples of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions) # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where i-th value corresponds to i-th axis.

    Raises:
      NotImplementedError: Always; subclasses must override this method.
    """
    raise NotImplementedError
@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitioner(tf.TensorShape([10, 3]), tf.float32)
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    """Returns `[1, ..., n, ..., 1]` with `n` shards on `axis`.

    The number of shards on `axis` is `num_shards`, capped by the size of
    that dimension. All other axes get a single partition.

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: Unused.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of per-axis partition counts.
    """
    del dtype
    result = [1] * len(shape)
    # Never create more shards than there are elements along `axis`.
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result
@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitioner(tf.TensorShape([6, 1]), tf.float32)
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitioner(tf.TensorShape([6, 1]), tf.float32)
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(
      self, min_shard_bytes=256 << 10, max_shards=1, bytes_per_string=16
  ):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: If any argument is smaller than 1.
    """
    if min_shard_bytes < 1:
      raise ValueError(
          'Argument `min_shard_bytes` must be positive. '
          f'Received: {min_shard_bytes}'
      )
    if max_shards < 1:
      raise ValueError(
          f'Argument `max_shards` must be positive. Received: {max_shards}'
      )
    if bytes_per_string < 1:
      raise ValueError(
          'Argument `bytes_per_string` must be positive. '
          f'Received: {bytes_per_string}'
      )
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    """Delegates to `partitioned_variables.min_max_variable_partitioner`."""
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string,
    )(shape, dtype)
@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitioner(tf.TensorShape([6, 1]), tf.float32)
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitioner(tf.TensorShape([6, 1]), tf.float32)
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitioner(tf.TensorShape([6, 1]), tf.float32)
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in `int` created taking
        precedence over `max_shard_bytes`. `None` (the default) means no limit.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: If `max_shard_bytes` or `bytes_per_string` is smaller than 1,
        or if `max_shards` is given (not `None`) and smaller than 1.
    """
    if max_shard_bytes < 1:
      raise ValueError(
          'Argument `max_shard_bytes` must be positive. '
          f'Received {max_shard_bytes}'
      )
    # Compare against None explicitly: a truthiness test would let
    # `max_shards=0` through even though it is not a positive shard count.
    if max_shards is not None and max_shards < 1:
      raise ValueError(
          f'Argument `max_shards` must be positive. Received {max_shards}'
      )
    if bytes_per_string < 1:
      raise ValueError(
          'Argument `bytes_per_string` must be positive. '
          f'Received: {bytes_per_string}'
      )
    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    """Delegates to `partitioned_variables.variable_axis_size_partitioner`."""
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis,
    )(shape, dtype)
class ShardedVariableSpec(type_spec.TypeSpec):
  """`TypeSpec` for a `ShardedVariable`, built from one spec per shard.

  The component specs are exactly the per-shard variable specs, in shard
  order. Decoding components rebuilds a `ShardedVariable` from the shards.
  """

  __slots__ = ['_variable_specs']

  @property
  def value_type(self):
    return ShardedVariable

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  @property
  def _component_specs(self):
    return self._variable_specs

  def _serialize(self):
    # Serialization is just the tuple of shard specs.
    return self._variable_specs

  def _from_components(self, variables):
    return ShardedVariable(variables)

  def _to_components(self, value):
    return tuple(value.variables)

  def _cast(self, value, _):
    # No conversion is performed; the value is returned unchanged.
    return value
class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable.

  Treats a list of variables as shards, split along the first axis, of one
  larger variable.
  """

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUEmbeddingVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".

    Raises:
      TypeError: If `variables` is not a non-empty sequence of
        `variables.Variable`s.
      ValueError: If the shards have different dtypes, differ in shape on any
        axis other than the first, or already have `SaveSliceInfo` set.
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    if (
        not isinstance(variables, Sequence)
        or not variables
        or any(not isinstance(v, variables_lib.Variable) for v in variables)
    ):
      raise TypeError(
          'Argument `variables` should be a non-empty list of '
          f'`variables.Variable`s. Received {variables}'
      )

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}'
      )

    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same shapes '
          'except for the first axis. '
          f'Received shapes: {[v.shape for v in variables]}'
      )
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape(
        [first_dim] + first_var.shape.as_list()[1:]
    )

    for v in variables:
      v._sharded_container = weakref.ref(self)  # pylint: disable=protected-access

    # self._var_offsets[i] is the (per-axis) offset of shard i within the
    # full variable.
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0]
      )

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError(
          '`SaveSliceInfo` should not be set for all elements in argument '
          '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according '
          'to the order of the elements `variables`. '
          f'Received save slice info {save_slice_info}'
      )

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape,
        dtype=self._dtype,
        name=self._name,
        trainable=self._variables[0].trainable,
        synchronization=variables_lib.VariableSynchronization.NONE,
        aggregation=variables_lib.VariableAggregation.NONE,
    )

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing of
        the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If `slice_spec` contains Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (
        isinstance(slice_spec, bool)
        or (
            isinstance(slice_spec, tensor_lib.Tensor)
            and slice_spec.dtype == dtypes.bool
        )
        or (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)
    ):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if isinstance(slice_spec, list):
      # A list is a sequence of per-axis specs (as in `Tensor.__getitem__`).
      # Normalize it to a tuple, because the code below concatenates
      # `slice_spec[1:]` onto tuples and `tuple + list` raises TypeError.
      slice_spec = tuple(slice_spec)
    elif not isinstance(slice_spec, tuple):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant(
            [], dtype=self._dtype, shape=((0,) + self._shape[1:])
        )
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat(
          [var[slice_spec] for var in self._variables], axis=0
      )
    elif s is array_ops.newaxis:
      return array_ops.concat(
          [var[slice_spec[1:]] for var in self._variables], axis=0
      )[array_ops.newaxis]
    else:
      if isinstance(s, tensor_lib.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.'
        )
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError(
            f'ShardedVariable: slice index {s} of dimension 0 out of bounds.'
        )
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (
            s >= self._var_offsets[i][0] and s < self._var_offsets[i + 1][0]
        ):
          return self._variables[i][
              (s - self._var_offsets[i][0],) + slice_spec[1:]
          ]

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first dimension partitioning, thus
    `slice_spec` must be for first dimension. Out-of-range `start`/`stop`
    values are clamped the same way Python sequence slicing clamps them.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

      For example, given component variables:
        v0 = [0, 1, 2]
        v1 = [3, 4, 5]
        v2 = [6, 7, 8, 9]

      If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
        v0[returned[0]] = [0, 1, 2]
        v1[returned[1]] = [3, 4, 5]
        v2[returned[2]] = [6, 7, 8, 9]
      If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
        v0[returned[0]] = [2]
        v1[returned[1]] = [5]
        returned[2] == None
      If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
        returned[0] == None
        v1[returned[1]] = [5]
        v2[returned[2]] = [9, 7]

    Raises:
      TypeError: If `start`, `stop` or `step` is a Tensor.
      ValueError: If the slice step is zero.
    """
    if (
        isinstance(slice_spec.start, tensor_lib.Tensor)
        or isinstance(slice_spec.stop, tensor_lib.Tensor)
        or isinstance(slice_spec.step, tensor_lib.Tensor)
    ):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.'
      )

    if slice_spec.step == 0:
      raise ValueError('slice step cannot be zero')

    def _ceil_div(numerator, denominator):
      # Exact integer ceil(numerator / denominator) for any non-zero
      # denominator (avoids float rounding of `math.ceil(a / b)`).
      return -(-numerator // denominator)

    # Normalize start, stop and step. `slice.indices` resolves negative indices
    # and clamps out-of-range bounds exactly like Python slicing, e.g.
    # `[-100::3]` on 10 rows starts at row 0 and `[100::-3]` starts at row 9.
    # For a negative step an open stop becomes -1, which conceptually refers to
    # the element before the first one and eases the decomposition below.
    dim0 = int(self._shape[0])
    slice_start, slice_end, slice_step = slice_spec.indices(dim0)
    num_vars = len(self._var_offsets)

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between global slice end and variable range
    # end.
    result = []
    cur = slice_start
    if slice_step > 0:
      for i in range(num_vars):
        var_start = self._var_offsets[i][0]
        var_end = self._var_offsets[i + 1][0] if i < num_vars - 1 else dim0
        if cur < var_start:
          cur += slice_step * _ceil_div(var_start - cur, slice_step)
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(num_vars - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = self._var_offsets[i + 1][0] if i < num_vars - 1 else dim0
        if cur >= var_end:
          cur += slice_step * _ceil_div(var_end - cur - 1, slice_step)
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      # Shards were visited last-to-first; restore shard order.
      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(
            resource_variable_ops.VariableSpec(v.shape, v.dtype)
            for v in self._variables
        )
    )

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns `value` by slicing it into per-shard pieces.

    `use_locking`, `name` and `read_value` are accepted for API compatibility
    but are not used.

    Returns:
      `self`.
    """
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds `delta` shard by shard; extra arguments are unused. Returns self."""
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list())
      )
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts `delta` shard by shard; extra arguments are unused."""
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list())
      )
    return self

  def _decompose_indices(self, indices):
    """Decompose a global 1D indices into a list of per-variable indices.

    Args:
      indices: A rank-1 Tensor of global first-axis indices.

    Returns:
      A pair `(per_var_indices, partition_assignments)`: the shard-local
      indices grouped by shard (a list with one entry per shard), and the int32
      shard id of every element of `indices`.

    Raises:
      ValueError: If `indices` is not rank 1.
      NotImplementedError: If the shards do not follow "div" sharding.
    """
    if indices.shape.rank != 1:
      raise ValueError(
          'ShardedVariable: indices must be 1D Tensor for sparse operations. '
          f'Received shape: {indices.shape}'
      )

    base = self._shape[0] // len(self._variables)
    extra = self._shape[0] % len(self._variables)

    # Assert that sharding conforms to "div" sharding
    expect_first_dim = [base] * len(self._variables)
    for i in range(extra):
      expect_first_dim[i] = expect_first_dim[i] + 1
    actual_first_dim = [v.shape.as_list()[0] for v in self._variables]
    if expect_first_dim != actual_first_dim:
      raise NotImplementedError(
          'scater_xxx ops are not supported in ShardedVariale that does not '
          'conform to "div" sharding'
      )

    # For index that falls into the partition that has extra 1, assignment is
    # `index // (base + 1)` (no less than `(indices - extra) // base`)
    # For index that falls into the partition that doesn't have extra 1,
    # assignment is `(indices - extra) // base` (no less than
    # `indices // (base + 1)`)
    #
    # Example:
    #   base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32)
    #   index = 10 -> partition_assignment = 0
    #   index = 22 -> partition_assignment = 2
    partition_assignments = math_ops.maximum(
        indices // (base + 1), (indices - extra) // base
    )
    local_indices = array_ops.where(
        partition_assignments < extra,
        indices % (base + 1),
        (indices - extra) % base,
    )
    # For whatever reason `dynamic_partition` only supports int32
    partition_assignments = math_ops.cast(partition_assignments, dtypes.int32)
    per_var_indices = data_flow_ops.dynamic_partition(
        local_indices, partition_assignments, len(self._variables)
    )

    return per_var_indices, partition_assignments

  def _decompose_indexed_slices(self, indexed_slices):
    """Decompose a global `IndexedSlices` into a list of per-variable ones."""
    per_var_indices, partition_assignments = self._decompose_indices(
        indexed_slices.indices
    )
    per_var_values = data_flow_ops.dynamic_partition(
        indexed_slices.values, partition_assignments, len(self._variables)
    )

    return [
        indexed_slices_lib.IndexedSlices(
            values=per_var_values[i], indices=per_var_indices[i]
        )
        for i in range(len(self._variables))
    ]

  # ==================== scatter ops implementations ======================== #

  def _scatter_per_shard(self, method_name, sparse_delta, name):
    """Splits `sparse_delta` by shard and calls `method_name` on each shard.

    Args:
      method_name: Name of the per-shard `Variable` scatter method to call.
      sparse_delta: A global `IndexedSlices`.
      name: Optional op name; shard `i` uses `'{name}/part_{i}'`.

    Returns:
      `self`.
    """
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      getattr(v, method_name)(per_var_sparse_delta[i], name=new_name)
    return self

  # In the scatter methods below, `use_locking` is accepted for API
  # compatibility but is not forwarded to the shards.

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_add."""
    return self._scatter_per_shard('scatter_add', sparse_delta, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_div."""
    return self._scatter_per_shard('scatter_div', sparse_delta, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_max."""
    return self._scatter_per_shard('scatter_max', sparse_delta, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_min."""
    return self._scatter_per_shard('scatter_min', sparse_delta, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_mul."""
    return self._scatter_per_shard('scatter_mul', sparse_delta, name)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_sub."""
    return self._scatter_per_shard('scatter_sub', sparse_delta, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_update."""
    return self._scatter_per_shard('scatter_update', sparse_delta, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.batch_scatter_update."""
    return self._scatter_per_shard('batch_scatter_update', sparse_delta, name)

  # ================== scatter ops implementations END ====================== #

  def sparse_read(self, indices, name=None):
    """Implements tf.Variable.sparse_read.

    Rows are returned in the order given by `indices`.

    Args:
      indices: A rank-1 Tensor of global first-axis indices.
      name: Optional op name; shard `i` uses `'{name}/part_{i}'`.

    Returns:
      A Tensor whose j-th row is the row at global index `indices[j]`.
    """
    per_var_indices, partition_assignments = self._decompose_indices(indices)
    # `dynamic_partition` groups indices by shard, so a plain concat of the
    # per-shard reads would return rows in shard order rather than in the
    # caller's order. Remember each index's original position and stitch the
    # per-shard results back into place.
    original_positions = data_flow_ops.dynamic_partition(
        math_ops.range(array_ops.size(indices)),
        partition_assignments,
        len(self._variables),
    )
    result = []
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      result.append(v.sparse_read(per_var_indices[i], name=new_name))
    return data_flow_ops.dynamic_stitch(original_positions, result)

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list(),
        )
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name
            )
        )
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _copy_trackable_to_cpu(self, object_map):
    """For implementing `Trackable` async checkpointing."""
    # This is not implemented in the ShardedVariableMixin class because multiple
    # classes inherit from it. If your class contains values that should be
    # copied to CPU for async checkpointing, please implement this in the class
    # definition.

  def _export_to_saved_model_graph(
      self, object_map, tensor_map, options, **kwargs
  ):
    """For implementing `Trackable` SavedModel export."""
    resource_list = []
    # `self._variables` may be any Sequence (e.g. the tuple produced by
    # `ShardedVariableSpec._from_components`), so copy to a list before
    # appending the saving variable; `tuple + list` would raise TypeError.
    for v in list(self._variables) + [self._saving_variable]:
      resource_list.extend(
          v._export_to_saved_model_graph(  # pylint:disable=protected-access
              object_map, tensor_map, options, **kwargs
          )
      )
    object_map[self] = ShardedVariable(
        [object_map[self._saving_variable]], name=self.name
    )
    return resource_list

  @property
  def _unique_id(self):
    # String-replace to ensure uniqueness for checkpoint tracking
    return self.variables[0]._unique_id.replace('part_0', 'sharded')  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name

  @property
  def is_sharded_variable(self):
    return True

  def numpy(self):
    """Copies the values in this ShardedVariable to a NumPy array.

    First converts to a single Tensor using the registered conversion function,
    which concatenates the shards, then uses Tensor.numpy() to convert to
    a NumPy array.

    Returns:
      A NumPy array of the same shape and dtype.
    """
    return _var_to_tensor(self).numpy()
@tf_export('__internal__.distribute.ShardedVariable', v1=[])
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since `ShardedVariable` can be saved and then restored to different number of
  shards depending on the restore environments, for example, TF serving APIs
  would restore to one shard for serving efficiency, when using
  `ShardedVariable` in a tf.function, one should generally not assume it has the
  same number of shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    shard_specs = [
        resource_variable_ops.VariableSpec(shard.shape, shard.dtype)
        for shard in self._variables
    ]
    return ShardedVariableSpec(*shard_specs)

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for op_name in tensor_lib.Tensor.OVERLOADABLE_OPERATORS:
      # `__getitem__` keeps the sharding-aware implementation from the mixin.
      if op_name != '__getitem__':
        cls._overload_operator(op_name)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to `tensor_lib.Tensor`."""
    tensor_method = getattr(tensor_lib.Tensor, operator)

    def _operator(v, *args, **kwargs):
      # Concatenate the shards into one Tensor, then apply the Tensor operator.
      return tensor_method(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)

  def __tf_experimental_restore_capture__(
      self, concrete_function, internal_capture
  ):
    # Avoid restoring captures for functions that use ShardedVariable - the
    # layer will be recreated during Keras model loading
    # TODO(jmullenbach): support loading models with ShardedVariables using
    # tf.saved_model.load
    del concrete_function, internal_capture
    return None

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    return True

  def _write_object_proto(self, proto, options):
    # The proto describes the single full-shape saving variable.
    resource_variable_ops.write_object_proto_for_resource_variable(
        self._saving_variable, proto, options, enforce_naming=False
    )

  def _copy_trackable_to_cpu(self, object_map):
    """For implementing `Trackable` async checkpointing."""
    if self not in object_map:
      # First time: copy every shard (which also creates its CPU copy), then
      # register a ShardedVariable built from those copies.
      cpu_shards = []
      for shard in self._variables:
        shard._copy_trackable_to_cpu(object_map)  # pylint: disable=protected-access
        cpu_shards.append(object_map[shard])
      object_map[self] = ShardedVariable(cpu_shards, name=self.name)
      return
    # Already populated: only refresh the values of each shard.
    for shard in self._variables:
      shard._copy_trackable_to_cpu(object_map)  # pylint: disable=protected-access
def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
"""Converts a `ShardedVariable` to a `Tensor`."""
del name
if dtype is not None and not dtype.is_compatible_with(var.dtype):
raise ValueError(
'Incompatible type conversion requested to type {!r} for variable '
'of type {!r}'.format(dtype.name, var.dtype.name)
)
if as_ref:
raise NotImplementedError(
"ShardedVariable doesn't support being used as a reference."
)
# We use op dispatch mechanism to override embedding_lookup ops when called
# with ShardedVariable. This requires embedding_lookup ops to raise TypeError
# when called with ShardedVariable. However since ShardedVariable can be
# converted to a tensor via concat, embedding_lookup ops would silently
# do the convertion and never raise a TypeError. To be able to properly
# raise a TypeError, namescope is used to detect if this method is called
# within a embedding_lookup op.
# NOTE: This doesn't work in eager mode since op namescope is always cleared
# in eager. This also breaks if user sets the name of embedding_lookup op
# with something that doesn't contain str "embedding_lookup".
#
# TODO(chenkai): Find a more robust way to do this, which should not rely
# on namescope.
if 'embedding_lookup' in ops.get_name_scope():
raise TypeError(
'Converting ShardedVariable to tensor in embedding lookup'
' ops is disallowed.'
)
return array_ops.concat(var.variables, axis=0)
# Make every Tensor operator except `__getitem__` work on ShardedVariable by
# first converting the variable to a single concatenated Tensor.
ShardedVariable._overload_all_operators()  # pylint: disable=protected-access

# Register a conversion function that concatenates the shards, allowing
# instances of the class to be used wherever a Tensor is expected.
tensor_conversion_registry.register_tensor_conversion_function(
    ShardedVariable, _var_to_tensor)
# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(
    params,
    ids,
    partition_strategy='mod',
    name=None,
    validate_indices=True,
    max_norm=None,
):
  """Runs `embedding_lookup` over the shards of a `ShardedVariable`.

  If `params` is a list, only its first element is used (presumably the
  `ShardedVariable` that triggered dispatch). The remaining arguments are
  forwarded unchanged, and the lookup runs over that variable's `.variables`.
  """
  sharded = params[0] if isinstance(params, list) else params
  return embedding_ops.embedding_lookup(
      sharded.variables,
      ids,
      partition_strategy,
      name,
      validate_indices,
      max_norm,
  )
# Separately override safe_embedding_lookup_sparse, to avoid conversion of
# ShardedVariable to tensor.
@dispatch.dispatch_for_api(embedding_ops.safe_embedding_lookup_sparse)
def safe_embedding_lookup_sparse(
    embedding_weights: ShardedVariable,
    sparse_ids,
    sparse_weights=None,
    combiner='mean',
    default_id=None,
    name=None,
    partition_strategy='div',
    max_norm=None,
    allow_fast_lookup=False,
):
  """Pass the individual shard variables as a list.

  All other arguments are forwarded by keyword, unchanged. The type annotation
  on `embedding_weights` selects this override for `ShardedVariable` inputs.
  """
  shard_list = embedding_weights.variables
  return embedding_ops.safe_embedding_lookup_sparse(
      shard_list,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy=partition_strategy,
      max_norm=max_norm,
      allow_fast_lookup=allow_fast_lookup,
  )