# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training state management."""
import os

from tensorflow.python.checkpoint import checkpoint as trackable_util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras.distribute import distributed_file_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables

# Constant for `tf.keras.Model` attribute to store the epoch at which the most
# recently saved checkpoint was saved.
CKPT_SAVED_EPOCH = '_ckpt_saved_epoch'
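
# Sentinel value indicating that no checkpoint (and hence no completed epoch)
# has been saved yet.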
CKPT_SAVED_EPOCH_UNUSED_VALUE = -1


class WorkerTrainingState(object):
  """Training state management class.

  This class provides APIs for backing up and restoring the training state.
  It allows the model and epoch information to be saved periodically and
  restored after a failure, for fault tolerance (also known as preemption
  recovery).
  """

  def __init__(self, model, checkpoint_dir):
    self._model = model

    # The epoch at which the checkpoint is saved. Used for fault-tolerance.
    # GPU devices only have an int64 dtype VarHandleOp registered.
    self._ckpt_saved_epoch = variables.Variable(
        initial_value=constant_op.constant(
            CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
        name='ckpt_saved_epoch')

    # Variable initialization.
    backend.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

    # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
    # when backing up.
    checkpoint = trackable_util.Checkpoint(
        model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

    # If this is single-worker training, checkpoint_dir is the same for
    # write_checkpoint_manager and read_checkpoint_manager.
    #
    # If this is multi-worker training and this worker should not save
    # checkpoints, we replace the write_checkpoint_manager's checkpoint_dir
    # with a temp filepath, so it writes to a file that will be removed at the
    # end of the back_up() call. This is necessary because the
    # SyncOnReadVariable needs to be synced across all the workers in order to
    # be read, and all workers need to perform `save()`.
    # But all workers should restore from the same checkpoint_dir as passed to
    # read_checkpoint_manager.
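    # Illustrative resulting layout (the non-chief subdirectory name is chosen
    # by `distributed_file_utils.write_dirpath`; the name shown is only an
    # example):
    #   <checkpoint_dir>/chief/        <- read location for every worker and
    #                                     write location for the chief
    #   <checkpoint_dir>/workertemp_1/ <- temporary write location for a
    #                                     non-chief worker, removed at the end
    #                                     of back_up()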
    self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir, 'chief'),
        max_to_keep=1)
    write_checkpoint_dir = distributed_file_utils.write_dirpath(
        checkpoint_dir, self._model.distribute_strategy)
    if self._model.distribute_strategy.extended.should_checkpoint:
      self.write_checkpoint_manager = self.read_checkpoint_manager
    else:
      self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

  def back_up(self, epoch):
    """Back up the current state of training into a checkpoint file.

    Args:
      epoch: The current epoch information to be saved.
    """
    backend.set_value(self._ckpt_saved_epoch, epoch)
    # Save the model plus CKPT_SAVED_EPOCH variable.
    if self.write_checkpoint_manager.save():
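      # Non-chief workers saved into a temporary directory (see __init__);
      # remove it now that the coordinated save has completed.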
      distributed_file_utils.remove_temp_dirpath(
          self.write_checkpoint_manager.directory,
          self._model.distribute_strategy)

  def restore(self):
    """Restore the training state from the backed up checkpoint file.

    Delegates to `self.read_checkpoint_manager.restore_or_initialize()`, which
    restores from the latest checkpoint in the read directory if one exists and
    otherwise does nothing. The result is not returned to the caller.
    """
    self.read_checkpoint_manager.restore_or_initialize()

  def delete_backup(self):
    """Delete the backup directories.

    Delete the backup directories which should not exist after `fit()`
    successfully finishes.
    """
    if self.write_checkpoint_manager is self.read_checkpoint_manager:
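      # Only the worker that writes to the real (chief) checkpoint directory
      # removes the backup; other workers wrote to temporary directories that
      # back_up() already cleaned up.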
      try:
        file_io.delete_recursively_v2(self.write_checkpoint_manager.directory)
      except errors.NotFoundError:
        pass

  def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
    """Maybe load initial epoch from ckpt considering possible worker recovery.

    When the `_ckpt_saved_epoch` attribute exists and is not
    `CKPT_SAVED_EPOCH_UNUSED_VALUE`, training is running in a multi-worker
    setting and the worker is recovering from a previous failure. In this case,
    infer `initial_epoch` from `self._ckpt_saved_epoch` so that the previously
    unfinished training continues from that epoch.

    Args:
      initial_epoch: The original `initial_epoch` the user passed to `fit()`.
      mode: The mode under which `model.fit()` is running.

    Returns:
      If the training is recovering from a previous failure under a
      multi-worker training setting, the epoch at which training should
      continue. Otherwise, the `initial_epoch` the user passed in.
    """
    epoch = backend.eval(self._ckpt_saved_epoch)
    if mode == mode_keys.ModeKeys.TRAIN and epoch >= 0:
      # The most recently saved epoch is one epoch prior to the epoch it
      # failed at, so return the value of 'self._ckpt_saved_epoch' plus one.
      return epoch + 1
    return initial_epoch