diff --git a/donut-train.py b/donut-train.py
index 6590e1d..ff4c3de 100644
--- a/donut-train.py
+++ b/donut-train.py
@@ -1,8 +1,6 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-# In[19]:
-
 
 from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel
 from datasets import load_dataset
@@ -15,23 +13,20 @@ import re
 from nltk import edit_distance
 import numpy as np
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
 import pytorch_lightning as pl
 import os
 from huggingface_hub import login
+from pytorch_lightning.plugins import CheckpointIO
 
 
-# In[8]:
-
 
 DATASET_PATH = "Zombely/pl-text-images-5000-whole"
 PRETRAINED_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
 START_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
 OUTPUT_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned-v2"
 LOGGING_PATH = "plwiki-proto-ft-second-iter"
-
-
-# In[ ]:
+CHECKPOINT_PATH = "./checkpoint"
 
 
 train_config = {
@@ -43,7 +38,7 @@ train_config = {
     "lr":3e-5,
     "train_batch_sizes": [8],
     "val_batch_sizes": [1],
-    # "seed":2022,
+    "seed":2022,
     "num_nodes": 1,
     "warmup_steps": 300, # 800/8*30/10, 10%
     "result_path": "./result",
@@ -51,15 +46,10 @@ train_config = {
 }
 
 
-# In[9]:
-
 
 dataset = load_dataset(DATASET_PATH)
 
 
-# In[10]:
-
-
 max_length = 768
 image_size = [1920, 2560]
 config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_PATH)
@@ -67,18 +57,26 @@ config.encoder.image_size = image_size # (height, width)
 config.decoder.max_length = max_length
 
 
-# In[11]:
-
 
 processor = DonutProcessor.from_pretrained(START_MODEL_PATH)
 model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_PATH, config=config)
 
-
-# In[12]:
-
-
 added_tokens = []
 
+class CustomCheckpointIO(CheckpointIO):
+    def save_checkpoint(self, checkpoint, path, storage_options=None):
+        del checkpoint["state_dict"]
+        torch.save(checkpoint, path)
+
+    def load_checkpoint(self, path, storage_options=None):
+        checkpoint = torch.load(path + "artifacts.ckpt")
+        state_dict = torch.load(path + "pytorch_model.bin")
+        checkpoint["state_dict"] = {"model." + key: value for key, value in state_dict.items()}
+        return checkpoint
+
+    def remove_checkpoint(self, path) -> None:
+        return super().remove_checkpoint(path)
+
 class DonutDataset(Dataset):
     """
     DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
@@ -217,9 +215,6 @@ class DonutDataset(Dataset):
         return pixel_values, labels, target_sequence
 
 
-# In[13]:
-
-
 processor.image_processor.size = image_size[::-1] # should be (width, height)
 processor.image_processor.do_align_long_axis = False
 
@@ -234,23 +229,15 @@ val_dataset = DonutDataset(DATASET_PATH, max_length=max_length,
     )
 
 
-# In[14]:
 
 
 model.config.pad_token_id = processor.tokenizer.pad_token_id
 model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0]
 
-
-# In[15]:
-
-
 train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
 val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
 
 
-# In[16]:
-
-
 class DonutModelPLModule(pl.LightningModule):
     def __init__(self, config, processor, model):
         super().__init__()
@@ -336,9 +323,6 @@ class DonutModelPLModule(pl.LightningModule):
         return val_dataloader
 
 
-# In[17]:
-
-
 class PushToHubCallback(Callback):
     def on_train_epoch_end(self, trainer, pl_module):
         print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
@@ -354,8 +338,6 @@ class PushToHubCallback(Callback):
             commit_message=f"Training done")
 
 
-# In[18]:
-
 login(os.environ.get("HUG_TOKKEN", ""))
 
 
@@ -363,13 +345,22 @@ login(os.environ.get("HUG_TOKKEN", ""))
 # ### Wandb.ai link: https://wandb.ai/michalkozlowski936/Donut?workspace=user-michalkozlowski936
 # ### Hugging_face link https://huggingface.co/Zombely
 
-# In[22]:
-
-
 model_module = DonutModelPLModule(train_config, processor, model)
 
 wandb_logger = WandbLogger(project="Donut", name=LOGGING_PATH)
 
+checkpoint_callback = ModelCheckpoint(
+    monitor="val_metric",
+    dirpath=CHECKPOINT_PATH,
+    filename="artifacts",
+    save_top_k=1,
+    save_last=False,
+    mode="min",
+    )
+
+custom_ckpt = CustomCheckpointIO()
+
+
 trainer = pl.Trainer(
         accelerator="gpu", # change to gpu
         devices=1,
@@ -378,9 +369,10 @@ trainer = pl.Trainer(
         check_val_every_n_epoch=train_config.get("check_val_every_n_epoch"),
         gradient_clip_val=train_config.get("gradient_clip_val"),
         precision=16, # we'll use mixed precision
+        plugins=custom_ckpt,
         num_sanity_val_steps=0,
         logger=wandb_logger,
-        callbacks=[PushToHubCallback()],
+        callbacks=[PushToHubCallback(), checkpoint_callback],
 )
 
 trainer.fit(model_module)
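
For reference, a minimal resume sketch (not part of the patch above) of how a run could be restarted from the checkpoint that ModelCheckpoint and CustomCheckpointIO write together. It assumes the names defined in donut-train.py (model, model_module, custom_ckpt, wandb_logger, checkpoint_callback, PushToHubCallback, CHECKPOINT_PATH) and that pytorch_model.bin has been placed next to artifacts.ckpt in CHECKPOINT_PATH, since load_checkpoint reads both files by string-concatenating the directory path with the file names.

import os
import torch

# Put the raw HF state_dict next to artifacts.ckpt; load_checkpoint re-prefixes
# its keys with "model." to rebuild the LightningModule state_dict.
torch.save(model.state_dict(), os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))

trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        precision=16,
        plugins=custom_ckpt,
        logger=wandb_logger,
        callbacks=[PushToHubCallback(), checkpoint_callback],
)

# load_checkpoint does torch.load(path + "artifacts.ckpt"), so the directory
# path passed as ckpt_path must end with a separator.
trainer.fit(model_module, ckpt_path=CHECKPOINT_PATH + "/")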