In [1]:
from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel, DonutImageProcessor, XLMRobertaTokenizerFast, BertConfig, ViTConfig
from datasets import load_dataset, interleave_datasets
import json
import random
from typing import Any, List, Tuple
import torch
from torch.utils.data import Dataset, DataLoader
import re
from nltk import edit_distance
import numpy as np
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl
import os
from huggingface_hub import login


In [2]:
import torch
import pytorch_lightning as pl
from nltk import edit_distance
import re
import numpy as np


class DonutModelPLModuleStream(pl.LightningModule):
    def __init__(self, config, processor, model, max_length, train_dataloader, val_dataloader):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.max_length = max_length
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def training_step(self, batch, batch_idx):
        # pixel_values, labels, _ = batch
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        # pixel_values, labels, answers = batch

        pixel_values = batch['pixel_values']
        labels = batch['labels']
        answers = batch['target_sequence'][0]
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
        
        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_input_ids,
                                   max_length=self.max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)
    
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = list()
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    def validation_epoch_end(self, validation_step_outputs):
        # I set this to 1 manually
        # (previously set to len(self.config.dataset_name_or_paths))
        num_of_loaders = 1
        if num_of_loaders == 1:
            validation_step_outputs = [validation_step_outputs]
        assert len(validation_step_outputs) == num_of_loaders
        cnt = [0] * num_of_loaders
        total_metric = [0] * num_of_loaders
        val_metric = [0] * num_of_loaders
        for i, results in enumerate(validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            val_metric[i] = total_metric[i] / cnt[i]
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        # TODO add scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
    
        return optimizer

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader

In [3]:
# dataset = load_dataset

In [4]:
image_processor = DonutImageProcessor(do_resize=True, do_align_long_axis=False, size=[960, 1260])
tokenizer = XLMRobertaTokenizerFast.from_pretrained('xlm-roberta-base')

In [5]:
config_encoder = ViTConfig(image_size=[1260, 960])
config_decoder = BertConfig()
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

In [6]:
processor = DonutProcessor(image_processor=image_processor, tokenizer=tokenizer)
model = VisionEncoderDecoderModel(config=config)

In [7]:
added_tokens = []

### PROCESS FUNC START ###

def add_tokens(list_of_tokens: List[str]):
    """
    Add special tokens to tokenizer and resize the token embeddings of the decoder
    """
    newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
    if newly_added_num > 0:
        model.decoder.resize_token_embeddings(len(processor.tokenizer))
        added_tokens.extend(list_of_tokens)

def json2token(obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
    """
    Convert an ordered JSON object into a token sequence
    """
    if type(obj) == dict:
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]
        else:
            output = ""
            if sort_json_key:
                keys = sorted(obj.keys(), reverse=True)
            else:
                keys = obj.keys()
            for k in keys:
                if update_special_tokens_for_json_key:
                    add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                output += (
                    fr"<s_{k}>"
                    + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                    + fr"</s_{k}>"
                )
            return output
    elif type(obj) == list:
        return r"<sep/>".join(
            [json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
        )
    else:
        obj = str(obj)
        if f"<{obj}/>" in added_tokens:
            obj = f"<{obj}/>"  # for categorical special tokens
        return obj

def process(row, split):
    task_start_token, prompt_end_token = "<s_cord-v2>", "<s_cord-v2>"
    ground_truth = json.loads(row["ground_truth"])
    if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
        assert isinstance(ground_truth["gt_parses"], list)
        gt_jsons = ground_truth["gt_parses"]
    else:
        assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
        gt_jsons = [ground_truth["gt_parse"]]

    gt_token_sequences = (
        [
            json2token(
                gt_json,
                update_special_tokens_for_json_key=split == "train",
                sort_json_key=False,
            )
            + processor.tokenizer.eos_token
            for gt_json in gt_jsons  # load json from list of json
        ]
    )

    add_tokens([task_start_token, prompt_end_token])
    prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(prompt_end_token)

    # change if not 3 channels
    if row['image'].mode != "RGB":
        row['image'] = row['image'].convert("RGB")
    # inputs
    pixel_values = processor(row["image"], random_padding=split == "train", return_tensors="pt").pixel_values
    pixel_values = pixel_values.squeeze()

    # targets
    input_ids = processor.tokenizer(
        gt_token_sequences,
        add_special_tokens=False,
        max_length=config.max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)

    labels = input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # model doesn't need to predict pad token
    return {"pixel_values": pixel_values, "labels": labels, 'target_sequence': gt_token_sequences }

def proces_train(row):
    return process(row, 'train')

def proces_val(row):
    return process(row, 'validation')


In [8]:
dataset = load_dataset('Zombely/wikisource-red', streaming=True)
val_dataset = dataset.pop('validation') 
train_dataset = interleave_datasets(list(dataset.values()))
# train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')
# val_length = list(val_dataset.info.splits.values())[-1].num_examples


train_dataset = train_dataset.map(proces_train, remove_columns = ['image', 'ground_truth'])
val_dataset = val_dataset.map(proces_val, remove_columns = ['image', 'ground_truth'])

train_dataset = train_dataset.with_format('torch')
val_dataset = val_dataset.with_format('torch')

# train_dataset = CustomWrapperIterator(train_dataset, total_len=train_length)
# val_dataset = CustomWrapperIterator(val_dataset, total_len=val_length)

model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]

train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=0)


Using custom data configuration Zombely--wikisource-red-98affb32ced5f2c5


In [9]:
train_config = {
    "max_epochs": 1,
    "val_check_interval": 1.0,
    "check_val_every_n_epoch": 1,
    "gradient_clip_val": 1.0,
    "num_training_samples_per_epoch": 800,
    "lr": 1.0e-4,
    "train_batch_sizes": [8],
    "val_batch_sizes": [1],
    "seed": 2023,
    "num_nodes": 1,
    "warmup_steps": 10,
    "result_path": "./result",
    "verbose": True
}


In [10]:
model_module = DonutModelPLModuleStream(train_config, processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

In [11]:

trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu
    devices=1,
    max_epochs=train_config['max_epochs'],
    val_check_interval=train_config['val_check_interval'],
    check_val_every_n_epoch=train_config['check_val_every_n_epoch'],
    gradient_clip_val=train_config['gradient_clip_val'],
    precision=16, # we'll use mixed precision
    num_sanity_val_steps=0,
)


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


In [12]:
trainer.fit(model_module)

Missing logger folder: /home/wmi/project/donut/notepads/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                      | Params
----------------------------------------------------
0 | model | VisionEncoderDecoderModel | 227 M 
----------------------------------------------------
227 M     Trainable params
0         Non-trainable params
227 M     Total params
455.428   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

(1024, 1579)
(1024, 1473)


OutOfMemoryError: CUDA out of memory. Tried to allocate 368.00 MiB (GPU 0; 23.70 GiB total capacity; 22.07 GiB already allocated; 260.56 MiB free; 22.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF