17 lines
644 B
Python
17 lines
644 B
Python
|
from pytorch_lightning.plugins import CheckpointIO
|
||
|
import torch
|
||
|
|
||
|
|
||
|
class CustomCheckpointIO(CheckpointIO):
    """Checkpoint IO that keeps model weights separate from the rest of a checkpoint.

    ``save_checkpoint`` writes every checkpoint entry *except* ``"state_dict"``
    to ``path``. ``load_checkpoint`` rebuilds a full checkpoint from two files
    located via ``path``: ``ARTIFACTS_FILENAME`` (the non-weight entries) and
    ``WEIGHTS_FILENAME`` (a bare model state dict whose keys get
    ``STATE_DICT_PREFIX`` prepended).

    NOTE(review): nothing in this class writes ``WEIGHTS_FILENAME``, and
    ``save_checkpoint`` writes to ``path`` itself rather than to
    ``path + ARTIFACTS_FILENAME``. Presumably the weights are saved elsewhere
    and callers pick paths so the two layouts line up; confirm with callers.
    """

    # Names of the files read by ``load_checkpoint``; override in a subclass
    # to use a different layout.
    ARTIFACTS_FILENAME = "artifacts.ckpt"
    WEIGHTS_FILENAME = "pytorch_model.bin"
    # Prepended to every weight key on load. Presumably the LightningModule
    # stores the wrapped model as ``self.model``; TODO confirm.
    STATE_DICT_PREFIX = "model."

    @staticmethod
    def _join(path, filename):
        """Return the location of ``filename`` relative to checkpoint ``path``.

        A ``str`` path is treated as a prefix and concatenated as-is, as the
        original implementation did, so callers must include a trailing
        separator to address a directory. Any other path-like (e.g.
        ``pathlib.Path``, which cannot be ``+``-ed with a ``str``) is treated
        as a directory.
        """
        import os

        if isinstance(path, str):
            return path + filename
        return os.path.join(os.fspath(path), filename)

    def save_checkpoint(self, checkpoint, path, storage_options=None):
        """Save ``checkpoint`` to ``path`` without its ``"state_dict"`` entry.

        Args:
            checkpoint: Checkpoint dict to persist. It is not modified.
            path: Destination file, passed to ``torch.save``. Missing parent
                directories are created.
            storage_options: Accepted for interface compatibility; ignored.
        """
        import os

        # Build a filtered copy instead of ``del``-ing from the caller's dict,
        # so the caller's checkpoint keeps its weights. A checkpoint without a
        # "state_dict" key is saved as-is rather than raising KeyError.
        artifacts = {key: value for key, value in checkpoint.items() if key != "state_dict"}

        directory = os.path.dirname(os.fspath(path))
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(artifacts, path)

    def load_checkpoint(self, path, storage_options=None, map_location=None):
        """Rebuild a full checkpoint from the artifacts file and the weights file.

        Args:
            path: Prefix (``str``) or directory (other path-like) under which
                ``ARTIFACTS_FILENAME`` and ``WEIGHTS_FILENAME`` live.
            storage_options: Accepted for interface compatibility; ignored.
            map_location: Forwarded to both ``torch.load`` calls. ``None`` is
                ``torch.load``'s own default.

        Returns:
            The artifacts dict, with ``"state_dict"`` set to the loaded weights.
            Every weight key is prefixed with ``STATE_DICT_PREFIX``.
        """
        # SECURITY: torch.load unpickles its input; only load trusted files.
        checkpoint = torch.load(self._join(path, self.ARTIFACTS_FILENAME), map_location=map_location)
        state_dict = torch.load(self._join(path, self.WEIGHTS_FILENAME), map_location=map_location)
        checkpoint["state_dict"] = {
            self.STATE_DICT_PREFIX + key: value for key, value in state_dict.items()
        }
        return checkpoint

    def remove_checkpoint(self, path) -> None:
        """Delete the checkpoint at ``path``; a missing path is not an error.

        A directory is removed recursively and a file is unlinked. The base
        ``CheckpointIO.remove_checkpoint`` is abstract with no implementation,
        so the previous ``super()`` delegation never deleted anything.
        """
        import os
        import shutil

        if os.path.isdir(path):
            shutil.rmtree(path)
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone; removal is best-effort, matching "if exists" semantics.
            pass