import contextlib
import os
from typing import Any, Dict, Optional, Union

import torch
from pytorch_lightning.plugins import CheckpointIO

# Path types accepted by the checkpoint methods below.
_PATH = Union[str, os.PathLike]


class CustomCheckpointIO(CheckpointIO):
    """Checkpoint IO that stores model weights separately from the other state.

    ``save_checkpoint`` writes every checkpoint entry except ``"state_dict"``.
    ``load_checkpoint`` rebuilds a full checkpoint from two files: the
    non-weight entries in ``ARTIFACTS_FILENAME`` and a bare model state dict in
    ``WEIGHTS_FILENAME``, whose keys are re-prefixed with ``STATE_DICT_PREFIX``.

    NOTE(review): nothing in this class writes ``WEIGHTS_FILENAME``; presumably
    the weights are exported elsewhere — confirm. Also, saving writes to
    ``path`` as given while loading reads ``path`` + ``ARTIFACTS_FILENAME``, so
    callers must pass matching locations — confirm.
    """

    # Name of the file holding the non-weight checkpoint entries.
    ARTIFACTS_FILENAME = "artifacts.ckpt"
    # Name of the file holding the bare model state dict.
    WEIGHTS_FILENAME = "pytorch_model.bin"
    # Prepended to every weight key on load; presumably maps the bare model's
    # keys onto a ``model`` attribute of the LightningModule — verify.
    STATE_DICT_PREFIX = "model."

    def save_checkpoint(
        self,
        checkpoint: Dict[str, Any],
        path: _PATH,
        storage_options: Optional[Any] = None,
    ) -> None:
        """Save ``checkpoint`` to ``path`` without its ``"state_dict"`` entry.

        Args:
            checkpoint: Checkpoint dict to save. It is not modified.
            path: Destination passed straight to ``torch.save``.
            storage_options: Accepted for interface compatibility; unused.
        """
        # Copy instead of ``del``-ing from the caller's dict: the caller keeps an
        # intact checkpoint, and a checkpoint without weights is not a KeyError.
        artifacts = {
            key: value for key, value in checkpoint.items() if key != "state_dict"
        }
        torch.save(artifacts, path)

    def load_checkpoint(
        self,
        path: _PATH,
        storage_options: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Reassemble a checkpoint from the artifacts file and the weights file.

        Args:
            path: Location of the two files. A ``str`` is used as a raw prefix
                (so a directory must end with a separator, as before); a
                path-like object such as ``pathlib.Path`` is treated as a
                directory.
            storage_options: Accepted for interface compatibility; unused.

        Returns:
            The artifacts dict with ``"state_dict"`` set to the loaded weights,
            each key prefixed with ``STATE_DICT_PREFIX``.
        """
        # NOTE(review): no ``map_location`` is given, so ``torch.load`` restores
        # tensors to the devices they were saved from — confirm this is intended
        # when resuming on different hardware.
        checkpoint = torch.load(self._resolve(path, self.ARTIFACTS_FILENAME))
        state_dict = torch.load(self._resolve(path, self.WEIGHTS_FILENAME))
        checkpoint["state_dict"] = {
            self.STATE_DICT_PREFIX + key: value for key, value in state_dict.items()
        }
        return checkpoint

    def remove_checkpoint(self, path: _PATH) -> None:
        """Delete the checkpoint file previously written to ``path``.

        A file that is already gone is ignored. Only ``path`` itself is removed;
        any separately exported weights file is left untouched.
        """
        # ``CheckpointIO.remove_checkpoint`` is an abstract stub with no body, so
        # delegating to ``super()`` silently kept every stale checkpoint on disk.
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)

    @staticmethod
    def _resolve(path: _PATH, filename: str) -> str:
        """Return the location of ``filename`` relative to ``path``.

        ``str`` paths keep the original raw-concatenation behavior; path-like
        objects, on which ``+`` would raise ``TypeError``, are joined as
        directories.
        """
        if isinstance(path, os.PathLike):
            return os.path.join(os.fspath(path), filename)
        return path + filename