# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Checkpoint policies that determine how tensors are split into shards."""

import math
from typing import MutableSequence, Optional, Sequence

from absl import logging

from tensorflow.python.checkpoint.sharding import sharding_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import base
from tensorflow.python.util import tf_export


@tf_export.tf_export("train.experimental.ShardByTaskPolicy")
class ShardByTaskPolicy(sharding_util.ShardingCallback):
  """Policy that splits tensors into shards based on their device spec task."""

  @property
  def description(self) -> str:
    return "Split tensors into shards based on their device spec task."

  def __call__(
      self,
      shardable_tensors: Sequence[sharding_util.ShardableTensor]
  ) -> Sequence[sharding_util.TensorSliceDict]:
    """Callback to split tensors into shards based on their device spec task.

    Args:
      shardable_tensors: A list of ShardableTensors.

    Returns:
      List of shard dicts containing tensors.
          [ {checkpoint key: {slice_spec: tensor} } ]
    """
    tensors_by_task = {}

    for shardable_tensor in shardable_tensors:
      tensor = shardable_tensor.tensor
      checkpoint_key = shardable_tensor.checkpoint_key
      slice_spec = shardable_tensor.slice_spec

      (tensors_by_task
       .setdefault(checkpoint_key, {})[slice_spec]) = tensor

    return [tensors_by_task]
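
# Illustrative usage (a sketch, not part of this module): a sharding policy is
# typically supplied to checkpoint saving through `tf.train.CheckpointOptions`.
# The `experimental_sharding_callback` option name below is an assumption and
# should be verified against the installed TensorFlow version.
#
#   import tensorflow as tf
#
#   ckpt = tf.train.Checkpoint(model=model)
#   options = tf.train.CheckpointOptions(
#       experimental_sharding_callback=(
#           tf.train.experimental.ShardByTaskPolicy()))
#   ckpt.save("/tmp/model_ckpt", options=options)
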

_PartitionAxisAndSize = tuple[int, int]
_OffsetAndShape = tuple[Sequence[int], Sequence[int]]


@tf_export.tf_export("train.experimental.MaxShardSizePolicy")
class MaxShardSizePolicy(sharding_util.ShardingCallback):
  """Policy that splits tensors into shards with a max shard size.

  Shards may exceed the max shard size if they contain
    1. a single scalar/string tensor that could not be sliced and exceeds the
       max shard size or
    2. the checkpoint object graph, whose size cannot be calculated when
       saving.
  """

  def __init__(self, max_shard_size: int):
    self.max_shard_size = max_shard_size

  @property
  def description(self) -> str:
    return "Split tensors into shards with a max shard size."

  def _get_next_partition(
      self,
      shard_size_remaining: int,
      shape: tensor_shape.TensorShape,
      dtype_size: int,
      num_elems: int
  ) -> _PartitionAxisAndSize:
    """Gets tensor partition with size closest to shard_size_remaining.

    Args:
      shard_size_remaining: Size in bytes of the space remaining in the shard.
      shape: Shape of the working tensor to partition in the remaining shard
        space.
      dtype_size: Size in bytes of the dtype of the working tensor.
      num_elems: Number of elements in the working tensor.

    Returns:
      A tuple containing the axis of the next partition and that partition
      size.
    """
    if shape.rank is None or shape.rank == 0:
      return 0, math.inf

    # Find the axis with the fewest partitions, i.e., the axis whose maximum
    # partition size is as close as possible to shard_size_remaining.
    bytes_per_slice = num_elems // shape.dims[0].value * dtype_size
    slices_per_shard = max(
        1, math.floor(shard_size_remaining / bytes_per_slice))
    min_parts = math.ceil(shape.dims[0].value / slices_per_shard)
    min_axis = 0
    for axis in range(1, shape.rank):
      bytes_per_slice = num_elems // shape.dims[axis].value * dtype_size
      slices_per_shard = max(
          1, math.floor(shard_size_remaining / bytes_per_slice))
      axis_parts = math.ceil(shape.dims[axis].value / slices_per_shard)
      partition_size = num_elems * dtype_size / axis_parts
      if (axis_parts < min_parts and
          partition_size < shard_size_remaining):
        min_axis, min_parts = axis, int(axis_parts)
    return min_axis, math.ceil(int(shape[min_axis]) / min_parts)
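
  # Worked example (illustrative numbers, not library code): for a float32
  # tensor of shape [10, 4] (40 elements * 4 bytes = 160 bytes) and
  # shard_size_remaining = 100, axis 0 gives bytes_per_slice = 16,
  # slices_per_shard = 6, hence 2 partitions; axis 1 gives bytes_per_slice =
  # 40, slices_per_shard = 2, also 2 partitions, which is not strictly fewer.
  # _get_next_partition therefore returns (0, 5): slice axis 0 into partitions
  # of 5 rows, i.e. [5, 4] slices of 80 bytes each.
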
  def _add_partition(
      self,
      root_shardable_tensor: sharding_util.ShardableTensor,
      dtype_size: int,
      working_tensor_offset: Sequence[int],
      part_axis_and_size: _PartitionAxisAndSize,
      shard_size_remaining: int,
      max_shard_size: int,
      tensors_by_shard: MutableSequence[sharding_util.TensorSliceDict],
      large_scalars: MutableSequence[sharding_util.TensorSliceDict],
  ) -> tuple[Optional[int], _OffsetAndShape]:
    """Adds the tensor partition to the shard, if possible.

    Args:
      root_shardable_tensor: The full tensor being partitioned.
      dtype_size: Size in bytes of the dtype of the working tensor.
      working_tensor_offset: The offset of the working tensor in the full
        tensor.
      part_axis_and_size: A tuple containing the axis of the partition and that
        partition size.
      shard_size_remaining: Size in bytes of the space remaining in the shard.
      max_shard_size: Max size in bytes allowed for a checkpoint shard.
      tensors_by_shard: List of shard dicts containing tensors.
          [ {checkpoint key: {slice_spec: tensor} } ]
      large_scalars: List of shard dicts containing scalars too large to fit in
        the max_shard_size.
          [ {checkpoint key: {slice_spec: tensor} } ]

    Returns:
      A tuple containing the shard space in bytes remaining after the slice was
      added (None if the slice could not be added to the current shard) and the
      offset & shape of the remaining portion of the tensor.
    """
    root_tensor = root_shardable_tensor.tensor
    root_tensor_shape = root_shardable_tensor.shape
    checkpoint_key = root_shardable_tensor.checkpoint_key

    if root_tensor_shape.rank is None or root_tensor_shape.rank == 0:
      return None, (None, None)

    min_axis, part_size = part_axis_and_size

    # Add what we can to the current shard.
    slice_offset = working_tensor_offset
    slice_shape = [root_tensor_shape[i] - slice_offset[i]
                   for i in range(root_tensor_shape.rank)]
    slice_shape[min_axis] = part_size
    slice_size_in_bytes = int(math.prod(slice_shape)) * dtype_size
    with ops.device(root_shardable_tensor.device):
      tensor_slice = array_ops.slice(
          root_tensor, begin=slice_offset, size=slice_shape)
    slice_spec = variables.Variable.SaveSliceInfo(
        full_name=checkpoint_key,
        full_shape=root_tensor_shape,
        var_offset=slice_offset,
        var_shape=slice_shape).spec.strip()

    remaining_size = shard_size_remaining
    if slice_size_in_bytes > max_shard_size:
      logging.warning("Slice %s of tensor %s is %s bytes and cannot be "
                      "partitioned into a shard of max shard size %s bytes. "
                      "It will be added as an individual shard that exceeds "
                      "the max shard size.",
                      slice_spec, checkpoint_key, slice_size_in_bytes,
                      max_shard_size)
      large_scalars.append({checkpoint_key: {slice_spec: tensor_slice}})
    elif slice_size_in_bytes > shard_size_remaining:
      # Smallest partition can't fit in the remaining shard space. Start fresh
      # with a new shard.
      return None, (None, None)
    else:
      if not tensors_by_shard or shard_size_remaining < 1:
        tensors_by_shard.append({})
        remaining_size = max_shard_size
      (tensors_by_shard[-1]
       .setdefault(checkpoint_key, {})[slice_spec]) = tensor_slice
      remaining_size -= slice_size_in_bytes

    # Get remaining portion of tensor to add to the next shard(s).
    slice_offset[min_axis] += part_size
    slice_shape = [root_tensor_shape[i] - slice_offset[i]
                   for i in range(root_tensor_shape.rank)]

    return (remaining_size, (slice_offset, slice_shape))

  def __call__(
      self,
      shardable_tensors: Sequence[sharding_util.ShardableTensor]
  ) -> Sequence[sharding_util.TensorSliceDict]:
    """Callback to split tensors into shards with a max shard size.

    Args:
      shardable_tensors: A list of ShardableTensors.

    Returns:
      List of shard dicts containing tensors.
          [ {checkpoint key: {slice_spec: tensor} } ]
    """
    tensors_by_shard = []
    large_scalars = []

    shard_size_remaining = self.max_shard_size
    for shardable_tensor in shardable_tensors:
      root_tensor = shardable_tensor.tensor
      root_shape = shardable_tensor.shape
      dtype = shardable_tensor.dtype
      checkpoint_key = shardable_tensor.checkpoint_key

      dtype_size = dtypes.as_dtype(dtype).size
      total_size = root_shape.num_elements() * dtype_size  # in bytes

      # Calculate string tensor sizes.
      if checkpoint_key == base.OBJECT_GRAPH_PROTO_KEY:
        # In graph mode, the object graph is populated using feed_additions
        # when the session is run. So, we can't calculate the size here.
        # Fortunately, the serialized object graph string will never be that
        # big, so we just place it in the current shard without worrying about
        # its size.
        total_size = dtype_size = 0
      elif dtype == dtypes.string:
        if not context.executing_eagerly():
          with ops.device(shardable_tensor.device):
            root_tensor = ops.get_default_session().run(root_tensor)

        if root_shape.rank is None or root_shape.rank == 0:
          sizes = [string_ops.string_length(root_tensor, unit="BYTE")]
        else:
          sizes = [string_ops.string_length(elem, unit="BYTE")
                   for elem in root_tensor]

        if context.executing_eagerly():
          sizes = [size.numpy() for size in sizes]
        else:
          with ops.device(shardable_tensor.device):
            sizes = ops.get_default_session().run(sizes)

        total_size = sum(sizes)
        dtype_size = max(sizes)

      if (total_size > self.max_shard_size and
          (root_shape.rank is None or root_shape.rank == 0)):
        logging.warning("Tensor %s is a scalar of size %s bytes and cannot be "
                        "partitioned into a shard of max shard size %s bytes. "
                        "It will be added as an individual shard that exceeds "
                        "the max shard size.",
                        checkpoint_key, total_size, self.max_shard_size)
        large_scalars.append(
            {checkpoint_key: {shardable_tensor.slice_spec: root_tensor}})
        continue

      # Partition tensor and add partitions to shards.
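      # The loop below repeatedly slices off the largest piece of the working
      # tensor that fits in the current shard's remaining space. Illustrative
      # numbers (an assumption, assuming this is the only tensor being saved):
      # a 400-byte rank-1 float32 tensor saved with max_shard_size=128 ends up
      # spread over four shards holding 128, 128, 128, and 16 bytes of it, with
      # the final remainder added by the block after the loop.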
      working_tensor = root_tensor
      working_tensor_var_offset = [0] * root_shape.rank
      working_tensor_shape = root_shape
      working_tensor_size = total_size
      while working_tensor_size > shard_size_remaining:
        part_axis_and_size = self._get_next_partition(
            shard_size_remaining=shard_size_remaining,
            shape=working_tensor_shape,
            dtype_size=dtype_size,
            num_elems=working_tensor_shape.num_elements())

        (remaining_size,
         (remaining_offset, remaining_shape)) = self._add_partition(
             root_shardable_tensor=shardable_tensor,
             dtype_size=dtype_size,
             working_tensor_offset=working_tensor_var_offset,
             part_axis_and_size=part_axis_and_size,
             shard_size_remaining=shard_size_remaining,
             max_shard_size=self.max_shard_size,
             tensors_by_shard=tensors_by_shard,
             large_scalars=large_scalars)

        if remaining_size is None:
          # Tensor partition couldn't fit in remaining shard space. Try again
          # with the next full shard.
          tensors_by_shard.append({})
          shard_size_remaining = self.max_shard_size
        else:
          working_tensor = array_ops.slice(
              root_tensor, begin=remaining_offset, size=remaining_shape)
          working_tensor_var_offset = remaining_offset
          working_tensor_shape = working_tensor.shape
          working_tensor_size = int(math.prod(remaining_shape)) * dtype_size
          shard_size_remaining = remaining_size

      if working_tensor_shape.num_elements() > 0:
        remaining_tensor_slice_spec = variables.Variable.SaveSliceInfo(
            full_name=checkpoint_key,
            full_shape=root_shape,
            var_offset=working_tensor_var_offset,
            var_shape=working_tensor_shape).spec.strip()
        if not tensors_by_shard:
          tensors_by_shard.append({})
        (tensors_by_shard[-1]
         .setdefault(checkpoint_key, {})
         [remaining_tensor_slice_spec]) = working_tensor
      shard_size_remaining -= working_tensor_size

    return tensors_by_shard + large_scalars
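
# Illustrative usage of MaxShardSizePolicy (a sketch mirroring the
# ShardByTaskPolicy example near the top of this file; the
# `experimental_sharding_callback` option name is likewise an assumption):
#
#   options = tf.train.CheckpointOptions(
#       experimental_sharding_callback=(
#           tf.train.experimental.MaxShardSizePolicy(
#               max_shard_size=10 * 2**20)))  # target roughly 10 MiB per shard
#   tf.train.Checkpoint(model=model).save("/tmp/model_ckpt", options=options)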