donut/utils/callbacks.py

22 lines
1022 B
Python

from pytorch_lightning.callbacks import Callback
class PushToHubCallback(Callback):
    """Lightning callback that uploads the LightningModule's model to the Hub.

    After every training epoch the module's ``model`` is pushed with a commit
    message naming the epoch.  When training ends, the module's ``processor``
    and then its ``model`` are pushed with the message ``"Training done"``.

    Uploads (and the accompanying log lines) happen only on the global-zero
    process so that multi-process runs do not push the same weights once per
    rank.

    Args:
        output_model_path: Destination passed unchanged as the first argument
            to ``push_to_hub`` on both the model and the processor (presumably
            a Hub repo id such as ``"user/repo"`` -- confirm against callers).
    """

    def __init__(self, output_model_path: str) -> None:
        super().__init__()
        self.output_model_path = output_model_path

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Push the current model weights, tagged with the epoch number.

        Expects ``pl_module`` to expose a ``model`` attribute with a
        ``push_to_hub`` method.
        """
        # Lightning invokes callback hooks on every process in distributed
        # training; only one process should talk to the Hub.
        if not trainer.is_global_zero:
            return
        epoch = trainer.current_epoch
        print(f"Pushing model to the hub, epoch {epoch}")
        pl_module.model.push_to_hub(
            self.output_model_path,
            commit_message=f"Training in progress, epoch {epoch}",
        )

    def on_train_end(self, trainer, pl_module) -> None:
        """Push the processor and the final model once training has finished.

        Expects ``pl_module`` to expose ``processor`` and ``model`` attributes,
        each with a ``push_to_hub`` method.
        """
        if not trainer.is_global_zero:
            return
        print("Pushing model to the hub after training")
        # NOTE: the processor is uploaded only here, not per epoch, so an
        # interrupted run leaves epoch checkpoints on the Hub without it.
        pl_module.processor.push_to_hub(
            self.output_model_path,
            commit_message="Training done",
        )
        pl_module.model.push_to_hub(
            self.output_model_path,
            commit_message="Training done",
        )