22 lines
1022 B
Python
22 lines
1022 B
Python
from pytorch_lightning.callbacks import Callback
|
|
|
|
|
|
class PushToHubCallback(Callback):
    """Lightning callback that uploads training artifacts to the Hugging Face Hub.

    The model is pushed at the end of every training epoch. The processor and
    the model are both pushed once more when training finishes.

    Assumes ``pl_module`` exposes ``model`` and ``processor`` attributes whose
    ``push_to_hub(repo_id, commit_message=...)`` methods perform the upload
    (presumably Hugging Face ``PushToHubMixin`` objects -- confirm against the
    LightningModule this callback is used with).

    Args:
        output_model_path: Passed as the first argument to ``push_to_hub``;
            presumably the Hub repository id (e.g. ``"user/model"``).
    """

    def __init__(self, output_model_path: str) -> None:
        super().__init__()
        self.output_model_path = output_model_path

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Push the model to the Hub at the end of each training epoch."""
        # Under multi-process training (e.g. DDP) this hook runs on every rank.
        # Only the global-zero process uploads, so ranks do not race each
        # other with duplicate commits.
        if not trainer.is_global_zero:
            return
        epoch = trainer.current_epoch
        print(f"Pushing model to the hub, epoch {epoch}")
        pl_module.model.push_to_hub(
            self.output_model_path,
            commit_message=f"Training in progress, epoch {epoch}",
        )
        # NOTE: the processor is intentionally not pushed per epoch; it is
        # uploaded once in on_train_end.

    def on_train_end(self, trainer, pl_module) -> None:
        """Push the processor and the final model once training completes."""
        # Same single-uploader guard as on_train_epoch_end.
        if not trainer.is_global_zero:
            return
        print("Pushing model to the hub after training")
        pl_module.processor.push_to_hub(
            self.output_model_path,
            commit_message="Training done",
        )
        pl_module.model.push_to_hub(
            self.output_model_path,
            commit_message="Training done",
        )
|