added checkpoints
This commit is contained in:
parent
e3e0fa495d
commit
322a495d9c
@ -1,8 +1,6 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
# In[19]:
|
|
||||||
|
|
||||||
|
|
||||||
from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel
|
from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
@ -15,23 +13,20 @@ import re
|
|||||||
from nltk import edit_distance
|
from nltk import edit_distance
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
from pytorch_lightning.callbacks import Callback
|
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import os
|
import os
|
||||||
from huggingface_hub import login
|
from huggingface_hub import login
|
||||||
|
from pytorch_lightning.plugins import CheckpointIO
|
||||||
|
|
||||||
|
|
||||||
# In[8]:
|
|
||||||
|
|
||||||
|
|
||||||
DATASET_PATH = "Zombely/pl-text-images-5000-whole"
|
DATASET_PATH = "Zombely/pl-text-images-5000-whole"
|
||||||
PRETRAINED_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
|
PRETRAINED_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
|
||||||
START_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
|
START_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
|
||||||
OUTPUT_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned-v2"
|
OUTPUT_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned-v2"
|
||||||
LOGGING_PATH = "plwiki-proto-ft-second-iter"
|
LOGGING_PATH = "plwiki-proto-ft-second-iter"
|
||||||
|
CHECKPOINT_PATH = "./checkpoint"
|
||||||
|
|
||||||
# In[ ]:
|
|
||||||
|
|
||||||
|
|
||||||
train_config = {
|
train_config = {
|
||||||
@ -43,7 +38,7 @@ train_config = {
|
|||||||
"lr":3e-5,
|
"lr":3e-5,
|
||||||
"train_batch_sizes": [8],
|
"train_batch_sizes": [8],
|
||||||
"val_batch_sizes": [1],
|
"val_batch_sizes": [1],
|
||||||
# "seed":2022,
|
"seed":2022,
|
||||||
"num_nodes": 1,
|
"num_nodes": 1,
|
||||||
"warmup_steps": 300, # 800/8*30/10, 10%
|
"warmup_steps": 300, # 800/8*30/10, 10%
|
||||||
"result_path": "./result",
|
"result_path": "./result",
|
||||||
@ -51,15 +46,10 @@ train_config = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# In[9]:
|
|
||||||
|
|
||||||
|
|
||||||
dataset = load_dataset(DATASET_PATH)
|
dataset = load_dataset(DATASET_PATH)
|
||||||
|
|
||||||
|
|
||||||
# In[10]:
|
|
||||||
|
|
||||||
|
|
||||||
max_length = 768
|
max_length = 768
|
||||||
image_size = [1920, 2560]
|
image_size = [1920, 2560]
|
||||||
config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_PATH)
|
config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_PATH)
|
||||||
@ -67,18 +57,26 @@ config.encoder.image_size = image_size # (height, width)
|
|||||||
config.decoder.max_length = max_length
|
config.decoder.max_length = max_length
|
||||||
|
|
||||||
|
|
||||||
# In[11]:
|
|
||||||
|
|
||||||
|
|
||||||
processor = DonutProcessor.from_pretrained(START_MODEL_PATH)
|
processor = DonutProcessor.from_pretrained(START_MODEL_PATH)
|
||||||
model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_PATH, config=config)
|
model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_PATH, config=config)
|
||||||
|
|
||||||
|
|
||||||
# In[12]:
|
|
||||||
|
|
||||||
|
|
||||||
added_tokens = []
|
added_tokens = []
|
||||||
|
|
||||||
|
class CustomCheckpointIO(CheckpointIO):
|
||||||
|
def save_checkpoint(self, checkpoint, path, storage_options=None):
|
||||||
|
del checkpoint["state_dict"]
|
||||||
|
torch.save(checkpoint, path)
|
||||||
|
|
||||||
|
def load_checkpoint(self, path, storage_options=None):
|
||||||
|
checkpoint = torch.load(path + "artifacts.ckpt")
|
||||||
|
state_dict = torch.load(path + "pytorch_model.bin")
|
||||||
|
checkpoint["state_dict"] = {"model." + key: value for key, value in state_dict.items()}
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
def remove_checkpoint(self, path) -> None:
|
||||||
|
return super().remove_checkpoint(path)
|
||||||
|
|
||||||
class DonutDataset(Dataset):
|
class DonutDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
|
DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
|
||||||
@ -217,9 +215,6 @@ class DonutDataset(Dataset):
|
|||||||
return pixel_values, labels, target_sequence
|
return pixel_values, labels, target_sequence
|
||||||
|
|
||||||
|
|
||||||
# In[13]:
|
|
||||||
|
|
||||||
|
|
||||||
processor.image_processor.size = image_size[::-1] # should be (width, height)
|
processor.image_processor.size = image_size[::-1] # should be (width, height)
|
||||||
processor.image_processor.do_align_long_axis = False
|
processor.image_processor.do_align_long_axis = False
|
||||||
|
|
||||||
@ -234,23 +229,15 @@ val_dataset = DonutDataset(DATASET_PATH, max_length=max_length,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# In[14]:
|
|
||||||
|
|
||||||
|
|
||||||
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
||||||
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
|
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
|
||||||
|
|
||||||
|
|
||||||
# In[15]:
|
|
||||||
|
|
||||||
|
|
||||||
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
|
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
|
||||||
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
|
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
|
||||||
# In[16]:
|
|
||||||
|
|
||||||
|
|
||||||
class DonutModelPLModule(pl.LightningModule):
|
class DonutModelPLModule(pl.LightningModule):
|
||||||
def __init__(self, config, processor, model):
|
def __init__(self, config, processor, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -336,9 +323,6 @@ class DonutModelPLModule(pl.LightningModule):
|
|||||||
return val_dataloader
|
return val_dataloader
|
||||||
|
|
||||||
|
|
||||||
# In[17]:
|
|
||||||
|
|
||||||
|
|
||||||
class PushToHubCallback(Callback):
|
class PushToHubCallback(Callback):
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
|
print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
|
||||||
@ -354,8 +338,6 @@ class PushToHubCallback(Callback):
|
|||||||
commit_message=f"Training done")
|
commit_message=f"Training done")
|
||||||
|
|
||||||
|
|
||||||
# In[18]:
|
|
||||||
|
|
||||||
|
|
||||||
login(os.environ.get("HUG_TOKKEN", ""))
|
login(os.environ.get("HUG_TOKKEN", ""))
|
||||||
|
|
||||||
@ -363,13 +345,22 @@ login(os.environ.get("HUG_TOKKEN", ""))
|
|||||||
# ### Wandb.ai link: https://wandb.ai/michalkozlowski936/Donut?workspace=user-michalkozlowski936
|
# ### Wandb.ai link: https://wandb.ai/michalkozlowski936/Donut?workspace=user-michalkozlowski936
|
||||||
# ### Hugging_face link https://huggingface.co/Zombely
|
# ### Hugging_face link https://huggingface.co/Zombely
|
||||||
|
|
||||||
# In[22]:
|
|
||||||
|
|
||||||
|
|
||||||
model_module = DonutModelPLModule(train_config, processor, model)
|
model_module = DonutModelPLModule(train_config, processor, model)
|
||||||
|
|
||||||
wandb_logger = WandbLogger(project="Donut", name=LOGGING_PATH)
|
wandb_logger = WandbLogger(project="Donut", name=LOGGING_PATH)
|
||||||
|
|
||||||
|
checkpoint_callback = ModelCheckpoint(
|
||||||
|
monitor="val_metric",
|
||||||
|
dirpath=CHECKPOINT_PATH,
|
||||||
|
filename="artifacts",
|
||||||
|
save_top_k=1,
|
||||||
|
save_last=False,
|
||||||
|
mode="min",
|
||||||
|
)
|
||||||
|
|
||||||
|
custom_ckpt = CustomCheckpointIO()
|
||||||
|
|
||||||
|
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
accelerator="gpu", # change to gpu
|
accelerator="gpu", # change to gpu
|
||||||
devices=1,
|
devices=1,
|
||||||
@ -378,9 +369,10 @@ trainer = pl.Trainer(
|
|||||||
check_val_every_n_epoch=train_config.get("check_val_every_n_epoch"),
|
check_val_every_n_epoch=train_config.get("check_val_every_n_epoch"),
|
||||||
gradient_clip_val=train_config.get("gradient_clip_val"),
|
gradient_clip_val=train_config.get("gradient_clip_val"),
|
||||||
precision=16, # we'll use mixed precision
|
precision=16, # we'll use mixed precision
|
||||||
|
plugins=custom_ckpt,
|
||||||
num_sanity_val_steps=0,
|
num_sanity_val_steps=0,
|
||||||
logger=wandb_logger,
|
logger=wandb_logger,
|
||||||
callbacks=[PushToHubCallback()],
|
callbacks=[PushToHubCallback(), checkpoint_callback],
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.fit(model_module)
|
trainer.fit(model_module)
|
||||||
|
Loading…
Reference in New Issue
Block a user