144 lines
5.9 KiB
Python
144 lines
5.9 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Training state management."""
|
|
|
|
import os
|
|
from tensorflow.python.checkpoint import checkpoint as trackable_util
|
|
from tensorflow.python.checkpoint import checkpoint_management
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.keras import backend
|
|
from tensorflow.python.keras.distribute import distributed_file_utils
|
|
from tensorflow.python.keras.utils import mode_keys
|
|
from tensorflow.python.lib.io import file_io
|
|
from tensorflow.python.ops import variables
|
|
|
|
# Name of the `tf.keras.Model` attribute that stores the epoch at which the
# most recent checkpoint was saved.
CKPT_SAVED_EPOCH = '_ckpt_saved_epoch'

# Sentinel held by the saved-epoch variable until a backup records a real
# epoch.
CKPT_SAVED_EPOCH_UNUSED_VALUE = -1
|
|
|
|
|
|
class WorkerTrainingState(object):
  """Training state management class.

  This class provides APIs for backing up and restoring the training state.
  This allows model and epoch information to be saved periodically and
  restored for fault-tolerance, also known as preemption-recovery purposes.
  """

  def __init__(self, model, checkpoint_dir):
    """Creates the tracked epoch variable and the checkpoint managers.

    Args:
      model: The model whose state is backed up. Its `distribute_strategy` is
        consulted to decide where this worker writes its checkpoints.
      checkpoint_dir: Base directory for the backup. Checkpoints are always
        read from its `chief` subdirectory.
    """
    self._model = model

    # The epoch at which the checkpoint is saved. Used for fault-tolerance.
    # GPU device only has int64 dtype registered VarHandleOp.
    self._ckpt_saved_epoch = variables.Variable(
        initial_value=constant_op.constant(
            CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
        name='ckpt_saved_epoch')

    # Variable initialization.
    backend.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

    # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
    # when backing up.
    checkpoint = trackable_util.Checkpoint(
        model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

    # If this is single-worker training, checkpoint_dir is the same for
    # write_checkpoint_manager and read_checkpoint_manager.
    #
    # If this is multi-worker training, and this worker should not
    # save checkpoint, we replace the write_checkpoint_manager's checkpoint_dir
    # with a temp filepath, so it writes to a file that will be removed at the
    # end of back_up() call. This is necessary because the SyncOnReadVariable
    # needs to be synced across all the workers in order to be read, and all
    # workers need to perform `save()`.
    # But all workers should restore from the same checkpoint_dir as passed in
    # read_checkpoint_manager.
    self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir, 'chief'),
        max_to_keep=1)
    write_checkpoint_dir = distributed_file_utils.write_dirpath(
        checkpoint_dir, self._model.distribute_strategy)
    if self._model.distribute_strategy.extended.should_checkpoint:
      self.write_checkpoint_manager = self.read_checkpoint_manager
    else:
      self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

  def back_up(self, epoch):
    """Back up the current state of training into a checkpoint file.

    Args:
      epoch: The current epoch information to be saved.
    """
    backend.set_value(self._ckpt_saved_epoch, epoch)
    # Save the model plus CKPT_SAVED_EPOCH variable. When a checkpoint was
    # written, clean up the temporary write directory (if any) afterwards.
    if self.write_checkpoint_manager.save():
      distributed_file_utils.remove_temp_dirpath(
          self.write_checkpoint_manager.directory,
          self._model.distribute_strategy)

  def restore(self):
    """Restore the training state from the backed up checkpoint file.

    Returns:
      True if a backed up checkpoint was found and restored. False if there
      was no checkpoint to restore from, i.e. the training state did not need
      to be restored.
    """
    # `restore_or_initialize()` yields the restored checkpoint path, or None
    # when no checkpoint exists; translate that into the documented boolean.
    restored_path = self.read_checkpoint_manager.restore_or_initialize()
    return restored_path is not None

  def delete_backup(self):
    """Delete the backup directories.

    Delete the backup directories which should not exist after `fit()`
    successfully finishes. The directory is only removed when this worker
    writes through the same manager it reads from; otherwise this is a no-op.
    """
    if self.write_checkpoint_manager is self.read_checkpoint_manager:
      try:
        file_io.delete_recursively_v2(self.write_checkpoint_manager.directory)
      except errors.NotFoundError:
        # Nothing was ever written (or it is already gone); nothing to clean.
        pass

  def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
    """Maybe load initial epoch from ckpt considering possible worker recovery.

    When the saved-epoch variable holds a non-negative value (i.e. it is not
    `CKPT_SAVED_EPOCH_UNUSED_VALUE`), this is under multi-worker training
    setting and indicates the worker is recovering from previous failure. In
    this case, infer `initial_epoch` from `self._ckpt_saved_epoch` to continue
    previous unfinished training from certain epoch.

    Args:
      initial_epoch: The original initial_epoch the user passes to `fit()`.
      mode: The mode for running `model.fit()`.

    Returns:
      If the training is recovering from previous failure under multi-worker
      training setting, return the epoch the training is supposed to continue
      at. Otherwise, return the `initial_epoch` the user passes in.
    """
    epoch = backend.eval(self._ckpt_saved_epoch)
    if mode == mode_keys.ModeKeys.TRAIN and epoch >= 0:
      # The most recently saved epoch is one epoch prior to the epoch it
      # failed at, so return the value of 'self._ckpt_saved_epoch' plus one.
      return epoch + 1
    return initial_epoch
|