# donut/utils/checkpoint.py
# (listing metadata: 17 lines, 644 B, Python)

import os

import torch
from pytorch_lightning.plugins import CheckpointIO
class CustomCheckpointIO(CheckpointIO):
    """Checkpoint I/O that keeps model weights out of the Lightning checkpoint.

    ``save_checkpoint`` writes the Lightning checkpoint dict *without* its
    ``"state_dict"`` entry. ``load_checkpoint`` treats ``path`` as a directory:
    it reads the training state from ``CHECKPOINT_FILENAME`` and the raw model
    weights from ``WEIGHTS_FILENAME``, then re-attaches the weights under
    ``"state_dict"`` with ``STATE_DICT_PREFIX`` prepended to every key.

    NOTE(review): this class never writes the weights file itself; presumably
    it is produced elsewhere (e.g. a ``save_pretrained`` call) into the same
    directory — confirm against the training code.
    """

    # File names inside a checkpoint directory; override in a subclass if needed.
    CHECKPOINT_FILENAME = "artifacts.ckpt"
    WEIGHTS_FILENAME = "pytorch_model.bin"
    # Prepended to weight keys so they match the LightningModule's parameter
    # names (presumably the network is stored as ``self.model`` — confirm).
    STATE_DICT_PREFIX = "model."

    @staticmethod
    def _artifact_path(path, filename):
        """Return the path of ``filename`` inside checkpoint directory ``path``."""
        # os.path.join accepts str and path-like objects and inserts a separator
        # only when needed, so "dir/" and "dir" both give "dir/<filename>".
        return os.path.join(path, filename)

    def save_checkpoint(self, checkpoint, path, storage_options=None) -> None:
        """Save ``checkpoint`` to ``path`` without its ``"state_dict"`` entry.

        Args:
            checkpoint: Lightning checkpoint dict. It is not modified; a shallow
                copy without the ``"state_dict"`` key is written instead.
            path: Destination file (``str`` or path-like). Missing parent
                directories are created.
            storage_options: Accepted for ``CheckpointIO`` compatibility; ignored.
        """
        # Copy instead of ``del``: the caller's dict stays intact, and a dict
        # that has no "state_dict" key no longer raises KeyError.
        artifacts = {
            key: value for key, value in checkpoint.items() if key != "state_dict"
        }
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(artifacts, path)

    def load_checkpoint(self, path, storage_options=None, map_location="cpu") -> dict:
        """Load a split checkpoint stored in directory ``path``.

        Args:
            path: Checkpoint directory (``str`` or path-like) containing
                ``CHECKPOINT_FILENAME`` and ``WEIGHTS_FILENAME``; a trailing
                separator is optional.
            storage_options: Accepted for ``CheckpointIO`` compatibility; ignored.
            map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
                tensors saved on a GPU can be loaded on any machine.

        Returns:
            The dict loaded from ``CHECKPOINT_FILENAME`` with ``"state_dict"``
            set to the loaded weights, each key prefixed by ``STATE_DICT_PREFIX``.

        Raises:
            FileNotFoundError: If either file is missing.
        """
        # NOTE: torch.load unpickles arbitrary objects — load trusted files only.
        checkpoint = torch.load(
            self._artifact_path(path, self.CHECKPOINT_FILENAME),
            map_location=map_location,
        )
        state_dict = torch.load(
            self._artifact_path(path, self.WEIGHTS_FILENAME),
            map_location=map_location,
        )
        checkpoint["state_dict"] = {
            self.STATE_DICT_PREFIX + key: value for key, value in state_dict.items()
        }
        return checkpoint

    def remove_checkpoint(self, path) -> None:
        """Delete the checkpoint file at ``path`` if it exists.

        Lightning's base ``CheckpointIO.remove_checkpoint`` is an abstract stub,
        so delegating to it left stale checkpoints on disk. Only regular files
        are removed. A directory is left untouched because it may also hold the
        separately stored weights file.
        """
        if os.path.isfile(path):
            os.remove(path)