26 lines
830 B
Python
26 lines
830 B
Python
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
|
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
|
|
|
|
|
class SaveOnEndEpochTrainerCallback(TrainerCallback):
    """Callback that requests logging, evaluation and a save at the end of each epoch.

    The request is made by setting flags on the ``TrainerControl`` object; how
    the trainer acts on those flags is decided outside this class.
    """

    def on_epoch_end(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs: Any
    ) -> None:
        """Raise the log/evaluate/save flags unless there is nothing new to save.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; only ``global_step`` is read.
            control: Control object whose ``should_log``, ``should_evaluate``
                and ``should_save`` flags are set to ``True`` in place.
            **kwargs: Extra keyword arguments passed by the caller; unused.
        """
        training_steps = state.global_step

        # No step has been taken yet, so there is nothing worth saving.
        if training_steps <= 0:
            return

        # Directory named after the current global step inside the output dir.
        checkpoint_dir = Path(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{training_steps}")

        # Only ask for a save when that directory is not already on disk.
        if not checkpoint_dir.exists():
            control.should_log = control.should_evaluate = control.should_save = True
|