# UGP/save_on_end_epoch.py
#
# 26 lines
# 830 B
# Python

from pathlib import Path
from typing import Any
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SaveOnEndEpochTrainerCallback(TrainerCallback):
    """Request a log, an evaluation and a checkpoint save at the end of each epoch.

    At epoch end, if at least one training step has run and no checkpoint
    directory for the current global step exists under ``args.output_dir``,
    the callback sets ``should_log``, ``should_evaluate`` and ``should_save``
    on the ``TrainerControl``. If such a directory already exists (e.g. a
    step-based save happened on the epoch's last step), it requests nothing.
    """

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Any,
    ) -> TrainerControl:
        """Ask the Trainer to log, evaluate and save unless a checkpoint exists.

        Args:
            args: Training arguments; ``args.output_dir`` is the directory in
                which the per-step checkpoint directory is looked up.
            state: Trainer state; ``state.global_step`` identifies the
                checkpoint directory name.
            control: Control flags; mutated in place and also returned.
            **kwargs: Extra objects the Trainer passes to callbacks (unused).

        Returns:
            The ``control`` object (modified or not). The ``TrainerCallback``
            contract is that an event which changes ``control`` returns it.
        """
        training_steps = state.global_step
        # Do not save if no training step has been taken yet.
        if training_steps <= 0:
            return control

        # Directory named "<PREFIX_CHECKPOINT_DIR>-<global_step>" under
        # output_dir, presumably the same path the Trainer saves to.
        # NOTE(review): during hyperparameter search the Trainer may save into
        # a per-trial run directory, which this path ignores — confirm.
        save_path = Path(args.output_dir) / f"{PREFIX_CHECKPOINT_DIR}-{training_steps}"
        # Skip if a checkpoint for this step already exists — no need to save.
        if save_path.exists():
            return control

        control.should_log = True
        # NOTE(review): forcing evaluation presumably requires the Trainer to
        # have an eval dataset configured — confirm for every caller.
        control.should_evaluate = True
        control.should_save = True
        return control