added checkpoints

s444415 2022-12-14 16:26:35 +01:00
parent e3e0fa495d
commit 322a495d9c


@@ -1,8 +1,6 @@
 #!/usr/bin/env python
 # coding: utf-8
-# In[19]:
 from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel
 from datasets import load_dataset

@@ -15,23 +13,20 @@ import re
 from nltk import edit_distance
 import numpy as np
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
 import pytorch_lightning as pl
 import os
 from huggingface_hub import login
+from pytorch_lightning.plugins import CheckpointIO
-# In[8]:
 DATASET_PATH = "Zombely/pl-text-images-5000-whole"
 PRETRAINED_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
 START_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
 OUTPUT_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned-v2"
 LOGGING_PATH = "plwiki-proto-ft-second-iter"
+CHECKPOINT_PATH = "./checkpoint"
-# In[ ]:
 train_config = {

@@ -43,7 +38,7 @@ train_config = {
     "lr":3e-5,
     "train_batch_sizes": [8],
     "val_batch_sizes": [1],
-    # "seed":2022,
+    "seed":2022,
     "num_nodes": 1,
     "warmup_steps": 300, # 800/8*30/10, 10%
     "result_path": "./result",

@@ -51,15 +46,10 @@ train_config = {
 }
-# In[9]:
 dataset = load_dataset(DATASET_PATH)
-# In[10]:
 max_length = 768
 image_size = [1920, 2560]
 config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_PATH)

@@ -67,18 +57,26 @@ config.encoder.image_size = image_size # (height, width)
 config.decoder.max_length = max_length
-# In[11]:
 processor = DonutProcessor.from_pretrained(START_MODEL_PATH)
 model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_PATH, config=config)
-# In[12]:
 added_tokens = []
+class CustomCheckpointIO(CheckpointIO):
+    def save_checkpoint(self, checkpoint, path, storage_options=None):
+        del checkpoint["state_dict"]
+        torch.save(checkpoint, path)
+
+    def load_checkpoint(self, path, storage_options=None):
+        checkpoint = torch.load(path + "artifacts.ckpt")
+        state_dict = torch.load(path + "pytorch_model.bin")
+        checkpoint["state_dict"] = {"model." + key: value for key, value in state_dict.items()}
+        return checkpoint
+
+    def remove_checkpoint(self, path) -> None:
+        return super().remove_checkpoint(path)
 class DonutDataset(Dataset):
     """
     DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
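Note on the CustomCheckpointIO plugin added above: save_checkpoint deletes the state_dict before writing, so artifacts.ckpt holds only the trainer state (optimizer, schedulers, loop counters), while the model weights live in the separately saved pytorch_model.bin that PushToHubCallback pushes each epoch. load_checkpoint reverses this by re-merging the two files; since it builds paths by string concatenation, it expects a directory prefix ending in a separator. (torch is assumed to be imported earlier in the script; the import block shown in this diff does not include it.) A minimal sketch of that recombination, assuming both files sit under CHECKPOINT_PATH:

import torch

# Sketch only; the file layout is an assumption based on CHECKPOINT_PATH and
# the file names used in load_checkpoint.
ckpt = torch.load("./checkpoint/artifacts.ckpt")        # trainer state, no weights
weights = torch.load("./checkpoint/pytorch_model.bin")  # bare model weights
ckpt["state_dict"] = {f"model.{k}": v for k, v in weights.items()}  # re-prefix keys for the LightningModule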
@@ -217,9 +215,6 @@ class DonutDataset(Dataset):
         return pixel_values, labels, target_sequence
-# In[13]:
 processor.image_processor.size = image_size[::-1] # should be (width, height)
 processor.image_processor.do_align_long_axis = False

@@ -234,23 +229,15 @@ val_dataset = DonutDataset(DATASET_PATH, max_length=max_length,
 )
-# In[14]:
 model.config.pad_token_id = processor.tokenizer.pad_token_id
 model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
-# In[15]:
 train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
 val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
-# In[16]:
 class DonutModelPLModule(pl.LightningModule):
     def __init__(self, config, processor, model):
         super().__init__()

@@ -336,9 +323,6 @@ class DonutModelPLModule(pl.LightningModule):
         return val_dataloader
-# In[17]:
 class PushToHubCallback(Callback):
     def on_train_epoch_end(self, trainer, pl_module):
         print(f"Pushing model to the hub, epoch {trainer.current_epoch}")

@@ -354,8 +338,6 @@ class PushToHubCallback(Callback):
             commit_message=f"Training done")
-# In[18]:
 login(os.environ.get("HUG_TOKKEN", ""))

@@ -363,13 +345,22 @@ login(os.environ.get("HUG_TOKKEN", ""))
 # ### Wandb.ai link: https://wandb.ai/michalkozlowski936/Donut?workspace=user-michalkozlowski936
 # ### Hugging_face link https://huggingface.co/Zombely
-# In[22]:
 model_module = DonutModelPLModule(train_config, processor, model)
 wandb_logger = WandbLogger(project="Donut", name=LOGGING_PATH)
+checkpoint_callback = ModelCheckpoint(
+    monitor="val_metric",
+    dirpath=CHECKPOINT_PATH,
+    filename="artifacts",
+    save_top_k=1,
+    save_last=False,
+    mode="min",
+)
+custom_ckpt = CustomCheckpointIO()
 trainer = pl.Trainer(
     accelerator="gpu", # change to gpu
     devices=1,
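The ModelCheckpoint added above keeps the single best checkpoint by "val_metric" with mode="min", which assumes the LightningModule logs a metric under exactly that name where lower is better (for Donut fine-tuning this is typically a normalized edit distance between predicted and target sequences). The logging hook itself is not part of this diff; a hedged sketch of the contract the callback depends on:

import numpy as np
import pytorch_lightning as pl

class MetricLoggingSketch(pl.LightningModule):
    # Illustration only: DonutModelPLModule is assumed to do the equivalent.
    def validation_epoch_end(self, outputs):
        # outputs: normalized edit distances collected by validation_step
        self.log_dict({"val_metric": np.mean(outputs)})  # lower is better -> mode="min"

With filename="artifacts", Lightning appends the default ".ckpt" extension and writes CHECKPOINT_PATH/artifacts.ckpt, matching the name CustomCheckpointIO.load_checkpoint looks for.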
@@ -378,9 +369,10 @@ trainer = pl.Trainer(
     check_val_every_n_epoch=train_config.get("check_val_every_n_epoch"),
     gradient_clip_val=train_config.get("gradient_clip_val"),
     precision=16, # we'll use mixed precision
+    plugins=custom_ckpt,
     num_sanity_val_steps=0,
     logger=wandb_logger,
-    callbacks=[PushToHubCallback()],
+    callbacks=[PushToHubCallback(), checkpoint_callback],
 )
 trainer.fit(model_module)
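Wiring note: passing custom_ckpt through plugins= makes the Trainer route all checkpoint file I/O, including ModelCheckpoint's saves, through CustomCheckpointIO, so the file written to ./checkpoint is the slimmed trainer-state-only artifacts.ckpt. Resuming therefore requires pytorch_model.bin to be placed next to artifacts.ckpt first (e.g. downloaded from the hub). A sketch of a resume run under those assumptions:

# Sketch only; the trailing slash matters because load_checkpoint builds paths
# by string concatenation, and pytorch_model.bin must already sit in ./checkpoint/.
resume_trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    plugins=custom_ckpt,  # same plugin, so the custom load path is used
    logger=wandb_logger,
    callbacks=[PushToHubCallback(), checkpoint_callback],
)
resume_trainer.fit(model_module, ckpt_path="./checkpoint/")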