donut/notepads/donut-from-zero-train.ipynb

from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel, DonutImageProcessor, XLMRobertaTokenizerFast, BertConfig, ViTConfig
from datasets import load_dataset, interleave_datasets
import json
import random
from typing import Any, List, Tuple
import torch
from torch.utils.data import Dataset, DataLoader
import re
from nltk import edit_distance
import numpy as np
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl
import os
from huggingface_hub import login


class DonutModelPLModuleStream(pl.LightningModule):
    def __init__(self, config, processor, model, max_length, train_dataloader, val_dataloader):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.max_length = max_length
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def training_step(self, batch, batch_idx):
        # pixel_values, labels, _ = batch
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        # pixel_values, labels, answers = batch

        pixel_values = batch['pixel_values']
        labels = batch['labels']
        answers = batch['target_sequence'][0]
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
        
        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_input_ids,
                                   max_length=self.max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)
    
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = list()
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    def validation_epoch_end(self, validation_step_outputs):
        # I set this to 1 manually
        # (previously set to len(self.config.dataset_name_or_paths))
        num_of_loaders = 1
        if num_of_loaders == 1:
            validation_step_outputs = [validation_step_outputs]
        assert len(validation_step_outputs) == num_of_loaders
        cnt = [0] * num_of_loaders
        total_metric = [0] * num_of_loaders
        val_metric = [0] * num_of_loaders
        for i, results in enumerate(validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            val_metric[i] = total_metric[i] / cnt[i]
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        # TODO add scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
    
        return optimizer

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader
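
# Illustrative example (not part of the original notebook): validation_step above
# scores each prediction with a length-normalized Levenshtein distance, so 0.0 is
# an exact match and values close to 1.0 mean the strings share almost nothing.
example_pred = "<s_text>hello</s_text>"
example_answer = "<s_text>hallo</s_text>"
example_score = edit_distance(example_pred, example_answer) / max(len(example_pred), len(example_answer))
# one substitution over 22 characters ~= 0.045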
# dataset = load_dataset
image_processor = DonutImageProcessor(do_resize=True, do_align_long_axis=False, size=[960, 1260])
tokenizer = XLMRobertaTokenizerFast.from_pretrained('xlm-roberta-base')
config_encoder = ViTConfig(image_size=[1260, 960])
config_decoder = BertConfig()
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
processor = DonutProcessor(image_processor=image_processor, tokenizer=tokenizer)
model = VisionEncoderDecoderModel(config=config)
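
# Note (not in the original notebook): config.max_length is read later when
# tokenizing the labels and when generating during validation; unless it is set
# here it falls back to the PretrainedConfig default of 20 tokens. A hedged
# sketch, if longer targets are intended (the BERT decoder allows up to 512
# positions):
# config.max_length = 512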
added_tokens = []

### PROCESS FUNC START ###

def add_tokens(list_of_tokens: List[str]):
    """
    Add special tokens to tokenizer and resize the token embeddings of the decoder
    """
    newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
    if newly_added_num > 0:
        model.decoder.resize_token_embeddings(len(processor.tokenizer))
        added_tokens.extend(list_of_tokens)

def json2token(obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
    """
    Convert an ordered JSON object into a token sequence
    """
    if type(obj) == dict:
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]
        else:
            output = ""
            if sort_json_key:
                keys = sorted(obj.keys(), reverse=True)
            else:
                keys = obj.keys()
            for k in keys:
                if update_special_tokens_for_json_key:
                    add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                output += (
                    fr"<s_{k}>"
                    + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                    + fr"</s_{k}>"
                )
            return output
    elif type(obj) == list:
        return r"<sep/>".join(
            [json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
        )
    else:
        obj = str(obj)
        if f"<{obj}/>" in added_tokens:
            obj = f"<{obj}/>"  # for categorical special tokens
        return obj
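
# Illustrative sanity check (not part of the original notebook): with
# update_special_tokens_for_json_key=False the call below has no side effects on
# the tokenizer, so it only shows how a nested ground truth is flattened into the
# tag sequence the decoder is trained to emit.
assert json2token(
    {"menu": [{"nm": "latte"}, {"nm": "tea"}]},
    update_special_tokens_for_json_key=False,
    sort_json_key=False,
) == "<s_menu><s_nm>latte</s_nm><sep/><s_nm>tea</s_nm></s_menu>"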

def process(row, split):
    task_start_token, prompt_end_token = "<s_cord-v2>", "<s_cord-v2>"
    ground_truth = json.loads(row["ground_truth"])
    if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
        assert isinstance(ground_truth["gt_parses"], list)
        gt_jsons = ground_truth["gt_parses"]
    else:
        assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
        gt_jsons = [ground_truth["gt_parse"]]

    gt_token_sequences = (
        [
            json2token(
                gt_json,
                update_special_tokens_for_json_key=split == "train",
                sort_json_key=False,
            )
            + processor.tokenizer.eos_token
            for gt_json in gt_jsons  # load json from list of json
        ]
    )

    add_tokens([task_start_token, prompt_end_token])
    prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(prompt_end_token)

    # change if not 3 channels
    if row['image'].mode != "RGB":
        row['image'] = row['image'].convert("RGB")
    # inputs
    pixel_values = processor(row["image"], random_padding=split == "train", return_tensors="pt").pixel_values
    pixel_values = pixel_values.squeeze()

    # targets
    input_ids = processor.tokenizer(
        gt_token_sequences,
        add_special_tokens=False,
        max_length=config.max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)

    labels = input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # model doesn't need to predict pad token
    return {"pixel_values": pixel_values, "labels": labels, 'target_sequence': gt_token_sequences }

def process_train(row):
    return process(row, 'train')

def process_val(row):
    return process(row, 'validation')
dataset = load_dataset('Zombely/wikisource-red', streaming=True)
val_dataset = dataset.pop('validation') 
train_dataset = interleave_datasets(list(dataset.values()))
# train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')
# val_length = list(val_dataset.info.splits.values())[-1].num_examples


train_dataset = train_dataset.map(process_train, remove_columns=['image', 'ground_truth'])
val_dataset = val_dataset.map(process_val, remove_columns=['image', 'ground_truth'])

train_dataset = train_dataset.with_format('torch')
val_dataset = val_dataset.with_format('torch')

# train_dataset = CustomWrapperIterator(train_dataset, total_len=train_length)
# val_dataset = CustomWrapperIterator(val_dataset, total_len=val_length)

# register the task start token eagerly: with a streaming dataset .map is lazy,
# so process() has not added '<s_cord-v2>' to the tokenizer yet and
# convert_tokens_to_ids would otherwise return the unk id here
add_tokens(['<s_cord-v2>'])
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]

train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=0)
Using custom data configuration Zombely--wikisource-red-98affb32ced5f2c5
train_config = {
    "max_epochs": 1,
    "val_check_interval": 1.0,
    "check_val_every_n_epoch": 1,
    "gradient_clip_val": 1.0,
    "num_training_samples_per_epoch": 800,
    "lr": 1.0e-4,
    "train_batch_sizes": [8],
    "val_batch_sizes": [1],
    "seed": 2023,
    "num_nodes": 1,
    "warmup_steps": 10,
    "result_path": "./result",
    "verbose": True
}
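# Note (not in the original notebook): only max_epochs, val_check_interval,
# check_val_every_n_epoch, gradient_clip_val, lr and verbose are consumed below;
# keys such as train_batch_sizes, warmup_steps, seed and result_path appear to be
# carried over from the reference Donut configs and are not wired into this
# module or Trainer.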
model_module = DonutModelPLModuleStream(train_config, processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else 'cpu', # change to gpu
    devices=1,
    max_epochs=train_config['max_epochs'],
    val_check_interval=train_config['val_check_interval'],
    check_val_every_n_epoch=train_config['check_val_every_n_epoch'],
    gradient_clip_val=train_config['gradient_clip_val'],
    precision=16, # we'll use mixed precision
    num_sanity_val_steps=0,
)
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
trainer.fit(model_module)
Missing logger folder: /home/wmi/project/donut/notepads/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                      | Params
----------------------------------------------------
0 | model | VisionEncoderDecoderModel | 227 M 
----------------------------------------------------
227 M     Trainable params
0         Non-trainable params
227 M     Total params
455.428   Total estimated model params size (MB)
/home/wmi/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 14 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/wmi/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 14 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Training: 0it [00:00, ?it/s]
(1024, 1579)
(1024, 1473)
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
/tmp/ipykernel_32385/828374167.py in <cell line: 1>()
----> 1 trainer.fit(model_module)

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    580             raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
    581         self.strategy._lightning_module = model
--> 582         call._call_and_handle_interrupt(
    583             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    584         )

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37         else:
---> 38             return trainer_fn(*args, **kwargs)
     39 
     40     except _TunerExitException:

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    622             model_connected=self.lightning_module is not None,
    623         )
--> 624         self._run(model, ckpt_path=self.ckpt_path)
    625 
    626         assert self.state.stopped

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1059         self._checkpoint_connector.resume_end()
   1060 
-> 1061         results = self._run_stage()
   1062 
   1063         log.detail(f"{self.__class__.__name__}: trainer tearing down")

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
   1138         if self.predicting:
   1139             return self._run_predict()
-> 1140         self._run_train()
   1141 
   1142     def _pre_training_routine(self) -> None:

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
   1161 
   1162         with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1163             self.fit_loop.run()
   1164 
   1165     def _run_evaluate(self) -> _EVALUATE_OUTPUT:

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py in advance(self)
    265         self._data_fetcher.setup(dataloader, batch_to_device=batch_to_device)
    266         with self.trainer.profiler.profile("run_training_epoch"):
--> 267             self._outputs = self.epoch_loop.run(self._data_fetcher)
    268 
    269     def on_advance_end(self) -> None:

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in advance(self, data_fetcher)
    212 
    213             with self.trainer.profiler.profile("run_training_batch"):
--> 214                 batch_output = self.batch_loop.run(kwargs)
    215 
    216         self.batch_progress.increment_processed()

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py in advance(self, kwargs)
     86                 self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
     87             )
---> 88             outputs = self.optimizer_loop.run(optimizers, kwargs)
     89         else:
     90             outputs = self.manual_loop.run(kwargs)

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in advance(self, optimizers, kwargs)
    198         kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
    199 
--> 200         result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
    201         if result.loss is not None:
    202             # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _run_optimization(self, kwargs, optimizer)
    245         else:
    246             # the `batch_idx` is optional with inter-batch parallelism
--> 247             self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
    248 
    249         result = closure.consume_result()

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    355 
    356         # model hook
--> 357         self.trainer._call_lightning_module_hook(
    358             "optimizer_step",
    359             self.trainer.current_epoch,

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py in _call_lightning_module_hook(self, hook_name, pl_module, *args, **kwargs)
   1303 
   1304         with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
-> 1305             output = fn(*args, **kwargs)
   1306 
   1307         # restore current_fx when nested context

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/core/module.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1659 
   1660         """
-> 1661         optimizer.step(closure=optimizer_closure)
   1662 
   1663     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None:

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, **kwargs)
    167 
    168         assert self._strategy is not None
--> 169         step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
    170 
    171         self._on_after_step()

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py in optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
    232         # TODO(lite): remove assertion once strategy's optimizer_step typing is fixed
    233         assert isinstance(model, pl.LightningModule)
--> 234         return self.precision_plugin.optimizer_step(
    235             optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs
    236         )

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/native_amp.py in optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs)
     83                 f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
     84             )
---> 85         closure_result = closure()
     86 
     87         if not _optimizer_handles_unscaling(optimizer):

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in __call__(self, *args, **kwargs)
    145 
    146     def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 147         self._result = self.closure(*args, **kwargs)
    148         return self._result.loss
    149 

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in closure(self, *args, **kwargs)
    140 
    141         if self._backward_fn is not None and step_output.closure_loss is not None:
--> 142             self._backward_fn(step_output.closure_loss)
    143 
    144         return step_output

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in backward_fn(loss)
    301 
    302         def backward_fn(loss: Tensor) -> None:
--> 303             self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
    304 
    305         return backward_fn

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py in _call_strategy_hook(self, hook_name, *args, **kwargs)
   1441 
   1442         with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1443             output = fn(*args, **kwargs)
   1444 
   1445         # restore current_fx when nested context

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py in backward(self, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
    205         closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module)
    206 
--> 207         self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, optimizer_idx, *args, **kwargs)
    208 
    209         closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module)

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py in backward(self, tensor, model, optimizer, optimizer_idx, *args, **kwargs)
     67             \**kwargs: Keyword arguments for the same purpose as ``*args``.
     68         """
---> 69         model.backward(tensor, optimizer, optimizer_idx, *args, **kwargs)
     70 
     71     def post_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:  # type: ignore[override]

~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/core/module.py in backward(self, loss, optimizer, optimizer_idx, *args, **kwargs)
   1404                 loss.backward()
   1405         """
-> 1406         loss.backward(*args, **kwargs)
   1407 
   1408     def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], optimizer_idx: int) -> None:

~/project/donut/env_donut/lib/python3.10/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    486                 inputs=inputs,
    487             )
--> 488         torch.autograd.backward(
    489             self, gradient, retain_graph, create_graph, inputs=inputs
    490         )

~/project/donut/env_donut/lib/python3.10/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     # some Python versions print out the first line of a multi-line function
    196     # calls in the traceback and some print out the last line
--> 197     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    198         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    199         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

OutOfMemoryError: CUDA out of memory. Tried to allocate 368.00 MiB (GPU 0; 23.70 GiB total capacity; 22.07 GiB already allocated; 260.56 MiB free; 22.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
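
# The run above ends with a CUDA OOM in the backward pass. A hedged sketch of
# mitigations one might try next (not part of the original notebook; whether any
# single change fits a 24 GiB card with 1260x960 inputs is untested):
#   - set PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128" in the environment
#     before CUDA is initialized, as the error message suggests;
#   - enable activation checkpointing on both halves of the model, e.g.
#     model.encoder.gradient_checkpointing_enable() and
#     model.decoder.gradient_checkpointing_enable();
#   - keep batch_size=1 and pass accumulate_grad_batches to pl.Trainer to emulate
#     a larger effective batch, or reduce the input resolution / config.max_length.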