donut/donut-train.ipynb
from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import json
import random
from typing import Any, List, Tuple
import torch
from torch.utils.data import Dataset, DataLoader
import re
from nltk import edit_distance
import numpy as np
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl
import os
from huggingface_hub import login
DATASET_PATH = "Zombely/pl-text-images"
PRETRAINED_MODEL_PATH = "nielsr/donut-proto"
OUTPUT_MODEL_PATH = "Zombely/plwiki-test-proto"
LOGGING_PATH = "plwiki-test-run-proto"
train_config = {
    "max_epochs":5,
    "val_check_interval":0.2, # how many times we want to validate during an epoch
    "check_val_every_n_epoch":1,
    "gradient_clip_val":1.0,
    "num_training_samples_per_epoch": 800,
    "lr":3e-5,
    "train_batch_sizes": [8],
    "val_batch_sizes": [1],
    # "seed":2022,
    "num_nodes": 1,
    "warmup_steps": 300, # 800/8*30/10, 10%
    "result_path": "./result",
    "verbose": True,
}
dataset = load_dataset(DATASET_PATH)
Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a
Found cached dataset parquet (/home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
max_length = 768
image_size = [1280, 960]
config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_PATH)
config.encoder.image_size = image_size # (height, width)
config.decoder.max_length = max_length
processor = DonutProcessor.from_pretrained(PRETRAINED_MODEL_PATH)
model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_PATH, config=config)
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Some weights of the model checkpoint at nielsr/donut-proto were not used when initializing VisionEncoderDecoderModel: ['encoder.encoder.layers.2.blocks.13.attn_mask', 'encoder.encoder.layers.2.blocks.17.attn_mask', 'encoder.encoder.layers.2.blocks.1.attn_mask', 'encoder.encoder.layers.2.blocks.9.attn_mask', 'encoder.encoder.layers.2.blocks.7.attn_mask', 'encoder.encoder.layers.3.blocks.1.attn_mask', 'encoder.encoder.layers.2.blocks.11.attn_mask', 'encoder.encoder.layers.2.blocks.5.attn_mask', 'encoder.encoder.layers.2.blocks.15.attn_mask', 'encoder.encoder.layers.0.blocks.1.attn_mask', 'encoder.encoder.layers.1.blocks.1.attn_mask', 'encoder.encoder.layers.2.blocks.3.attn_mask']
- This IS expected if you are initializing VisionEncoderDecoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VisionEncoderDecoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at nielsr/donut-proto and are newly initialized: ['encoder.layernorm.weight', 'encoder.layernorm.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
added_tokens = []

class DonutDataset(Dataset):
    """
    PyTorch Dataset wrapping a dataset stored in Hugging Face datasets format (see https://huggingface.co/docs/datasets).
    Each row consists of an image (png/jpg/jpeg) and its ground-truth data (json/jsonl/txt),
    which are converted into pixel_values (the preprocessed image tensor) and input_ids (the tokenized target sequence).
    Args:
        dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
        max_length: the max number of tokens for the target sequences
        split: which split to load ("train", "validation", or "test")
        ignore_id: ignore_index for torch.nn.CrossEntropyLoss
        task_start_token: the special token to be fed to the decoder to conduct the target task
        prompt_end_token: the special token at the end of the sequences
        sort_json_key: whether or not to sort the JSON keys
    """

    def __init__(
        self,
        dataset_name_or_path: str,
        max_length: int,
        split: str = "train",
        ignore_id: int = -100,
        task_start_token: str = "<s>",
        prompt_end_token: str = None,
        sort_json_key: bool = True,
    ):
        super().__init__()

        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        self.task_start_token = task_start_token
        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
        self.sort_json_key = sort_json_key

        self.dataset = load_dataset(dataset_name_or_path, split=self.split)
        self.dataset_length = len(self.dataset)

        self.gt_token_sequences = []
        for sample in self.dataset:
            ground_truth = json.loads(sample["ground_truth"])
            if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
                assert isinstance(ground_truth["gt_parses"], list)
                gt_jsons = ground_truth["gt_parses"]
            else:
                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
                gt_jsons = [ground_truth["gt_parse"]]

            self.gt_token_sequences.append(
                [
                    self.json2token(
                        gt_json,
                        update_special_tokens_for_json_key=self.split == "train",
                        sort_json_key=self.sort_json_key,
                    )
                    + processor.tokenizer.eos_token
                    for gt_json in gt_jsons  # load json from list of json
                ]
            )

        self.add_tokens([self.task_start_token, self.prompt_end_token])
        self.prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                    output += (
                        fr"<s_{k}>"
                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + fr"</s_{k}>"
                    )
                return output
        elif type(obj) == list:
            return r"<sep/>".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj
    
    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            model.decoder.resize_token_embeddings(len(processor.tokenizer))
            added_tokens.extend(list_of_tokens)
    
    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """
        Load image from image_path of given dataset_path and convert into input_tensor and labels
        Convert gt data into input_ids (tokenized string)
        Returns:
            input_tensor : preprocessed image
            input_ids : tokenized gt_data
            labels : masked labels (model doesn't need to predict prompt and pad token)
        """
        sample = self.dataset[idx]

        # inputs
        pixel_values = processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze()

        # targets
        target_sequence = random.choice(self.gt_token_sequences[idx])  # can be more than one, e.g., DocVQA Task 1
        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
        # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id  # model doesn't need to predict prompt (for VQA)
        return pixel_values, labels, target_sequence
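For reference, a hypothetical example of the ground_truth field this dataset class expects per row (assuming the common OCR-style layout where gt_parse holds a single text_sequence; the actual pl-text-images rows may differ):

# Hypothetical ground_truth row (assumption: OCR-style gt_parse with a single text_sequence).
example_ground_truth = json.dumps({"gt_parse": {"text_sequence": "sample page text"}})
# For such a row, json2token returns the text_sequence unchanged
# (a dict containing only "text_sequence" is passed through verbatim),
# and the EOS token is appended when building gt_token_sequences.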
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False

train_dataset = DonutDataset(DATASET_PATH, max_length=max_length,
                             split="train", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
                             sort_json_key=False, # ground truth is already preprocessed, so key sorting is not needed
                             )

val_dataset = DonutDataset(DATASET_PATH, max_length=max_length,
                             split="validation", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
                             sort_json_key=False, # ground truth is already preprocessed, so key sorting is not needed
                             )
Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a
Found cached dataset parquet (/home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a
Found cached dataset parquet (/home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
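As an optional sanity check (a sketch, not part of the original notebook), one batch can be pulled from the train dataloader to confirm tensor shapes and that the masked labels decode back to the target sequence:

# Optional sanity check: inspect one batch from the train dataloader.
pixel_values, labels, target_sequences = next(iter(train_dataloader))
print(pixel_values.shape)  # expected roughly [1, 3, 1280, 960] given the image_size above
decodable = labels[0].clone()
decodable[decodable == -100] = processor.tokenizer.pad_token_id  # undo the ignore_id masking
print(processor.tokenizer.decode(decodable))
print(target_sequences[0])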
class DonutModelPLModule(pl.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch, batch_idx):
        pixel_values, labels, _ = batch
        
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, labels, answers = batch
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
        
        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_input_ids,
                                   max_length=max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)
    
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = list()
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    def validation_epoch_end(self, validation_step_outputs):
        # I set this to 1 manually
        # (previously set to len(self.config.dataset_name_or_paths))
        num_of_loaders = 1
        if num_of_loaders == 1:
            validation_step_outputs = [validation_step_outputs]
        assert len(validation_step_outputs) == num_of_loaders
        cnt = [0] * num_of_loaders
        total_metric = [0] * num_of_loaders
        val_metric = [0] * num_of_loaders
        for i, results in enumerate(validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            val_metric[i] = total_metric[i] / cnt[i]
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        # TODO add scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
    
        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader
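The TODO in configure_optimizers could be addressed with a warmup schedule driven by the currently unused warmup_steps entry in train_config. A minimal sketch (assuming transformers' get_linear_schedule_with_warmup and per-step scheduling), meant as a drop-in replacement for the method above:

# Sketch of configure_optimizers with linear warmup, stepping the scheduler every batch.
from transformers import get_linear_schedule_with_warmup

def configure_optimizers_with_warmup(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
    # assumption: total steps = batches per epoch * max_epochs
    num_training_steps = len(train_dataloader) * self.config.get("max_epochs")
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=self.config.get("warmup_steps"),
        num_training_steps=num_training_steps,
    )
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}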
class PushToHubCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
        pl_module.model.push_to_hub(OUTPUT_MODEL_PATH,
                                    commit_message=f"Training in progress, epoch {trainer.current_epoch}")

    def on_train_end(self, trainer, pl_module):
        print("Pushing model to the hub after training")
        pl_module.processor.push_to_hub(OUTPUT_MODEL_PATH,
                                    commit_message="Training done")
        pl_module.model.push_to_hub(OUTPUT_MODEL_PATH,
                                    commit_message="Training done")
login(os.environ.get("HUG_TOKEN"))
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid.
Your token has been saved to /home/pc/.huggingface/token
Login successful

model_module = DonutModelPLModule(train_config, processor, model)

wandb_logger = WandbLogger(project="Donut", name=LOGGING_PATH)

trainer = pl.Trainer(
        accelerator="cpu", # change to gpu
        devices=1,
        max_epochs=train_config.get("max_epochs"),
        val_check_interval=train_config.get("val_check_interval"),
        check_val_every_n_epoch=train_config.get("check_val_every_n_epoch"),
        gradient_clip_val=train_config.get("gradient_clip_val"),
        precision=16, # mixed precision; on CPU this falls back to bfloat16 AMP (see output below)
        num_sanity_val_steps=0,
        logger=wandb_logger,
        callbacks=[PushToHubCallback()],
)

trainer.fit(model_module)
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                      | Params
----------------------------------------------------
0 | model | VisionEncoderDecoderModel | 213 M 
----------------------------------------------------
213 M     Trainable params
0         Non-trainable params
213 M     Total params
854.597   Total estimated model params size (MB)
Training: 0it [00:00, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_294/2569065759.py in <module>
     27 )
     28 
---> 29 trainer.fit(model_module)

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    581         self.strategy._lightning_module = model
    582         call._call_and_handle_interrupt(
--> 583             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    584         )
    585 

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37         else:
---> 38             return trainer_fn(*args, **kwargs)
     39 
     40     except _TunerExitException:

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    622             model_connected=self.lightning_module is not None,
    623         )
--> 624         self._run(model, ckpt_path=self.ckpt_path)
    625 
    626         assert self.state.stopped

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1059         self._checkpoint_connector.resume_end()
   1060 
-> 1061         results = self._run_stage()
   1062 
   1063         log.detail(f"{self.__class__.__name__}: trainer tearing down")

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
   1138         if self.predicting:
   1139             return self._run_predict()
-> 1140         self._run_train()
   1141 
   1142     def _pre_training_routine(self) -> None:

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
   1161 
   1162         with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1163             self.fit_loop.run()
   1164 
   1165     def _run_evaluate(self) -> _EVALUATE_OUTPUT:

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py in advance(self)
    265         self._data_fetcher.setup(dataloader, batch_to_device=batch_to_device)
    266         with self.trainer.profiler.profile("run_training_epoch"):
--> 267             self._outputs = self.epoch_loop.run(self._data_fetcher)
    268 
    269     def on_advance_end(self) -> None:

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in advance(self, data_fetcher)
    212 
    213             with self.trainer.profiler.profile("run_training_batch"):
--> 214                 batch_output = self.batch_loop.run(kwargs)
    215 
    216         self.batch_progress.increment_processed()

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py in advance(self, kwargs)
     86                 self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
     87             )
---> 88             outputs = self.optimizer_loop.run(optimizers, kwargs)
     89         else:
     90             outputs = self.manual_loop.run(kwargs)

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in advance(self, optimizers, kwargs)
    198         kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
    199 
--> 200         result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
    201         if result.loss is not None:
    202             # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _run_optimization(self, kwargs, optimizer)
    245         else:
    246             # the `batch_idx` is optional with inter-batch parallelism
--> 247             self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
    248 
    249         result = closure.consume_result()

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    364             on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
    365             using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
--> 366             using_lbfgs=is_lbfgs,
    367         )
    368 

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_lightning_module_hook(self, hook_name, pl_module, *args, **kwargs)
   1303 
   1304         with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
-> 1305             output = fn(*args, **kwargs)
   1306 
   1307         # restore current_fx when nested context

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/core/module.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1659 
   1660         """
-> 1661         optimizer.step(closure=optimizer_closure)
   1662 
   1663     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None:

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, **kwargs)
    167 
    168         assert self._strategy is not None
--> 169         step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
    170 
    171         self._on_after_step()

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py in optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
    233         assert isinstance(model, pl.LightningModule)
    234         return self.precision_plugin.optimizer_step(
--> 235             optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs
    236         )
    237 

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/native_amp.py in optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs)
     77             # skip scaler logic, as bfloat16 does not require scaler
     78             return super().optimizer_step(
---> 79                 optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
     80             )
     81         if isinstance(optimizer, LBFGS):

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py in optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs)
    119         """Hook to run the optimizer step."""
    120         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
--> 121         return optimizer.step(closure=closure, **kwargs)
    122 
    123     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
    138                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
    139                 with torch.autograd.profiler.record_function(profile_name):
--> 140                     out = func(*args, **kwargs)
    141                     obj._optimizer_step_code()
    142                     return out

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/optim/optimizer.py in _use_grad(self, *args, **kwargs)
     21         try:
     22             torch.set_grad_enabled(self.defaults['differentiable'])
---> 23             ret = func(self, *args, **kwargs)
     24         finally:
     25             torch.set_grad_enabled(prev_grad)

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/optim/adam.py in step(self, closure, grad_scaler)
    181         if closure is not None:
    182             with torch.enable_grad():
--> 183                 loss = closure()
    184 
    185         for group in self.param_groups:

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py in _wrap_closure(self, model, optimizer, optimizer_idx, closure)
    105         consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
    106         """
--> 107         closure_result = closure()
    108         self._after_closure(model, optimizer, optimizer_idx)
    109         return closure_result

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in __call__(self, *args, **kwargs)
    145 
    146     def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 147         self._result = self.closure(*args, **kwargs)
    148         return self._result.loss
    149 

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in closure(self, *args, **kwargs)
    131 
    132     def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 133         step_output = self._step_fn()
    134 
    135         if step_output.closure_loss is None:

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py in _training_step(self, kwargs)
    404         """
    405         # manually capture logged metrics
--> 406         training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
    407         self.trainer.strategy.post_training_step()
    408 

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_strategy_hook(self, hook_name, *args, **kwargs)
   1441 
   1442         with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1443             output = fn(*args, **kwargs)
   1444 
   1445         # restore current_fx when nested context

~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py in training_step(self, *args, **kwargs)
    376         with self.precision_plugin.train_step_context():
    377             assert isinstance(self.model, TrainingStep)
--> 378             return self.model.training_step(*args, **kwargs)
    379 
    380     def post_training_step(self) -> None:

/tmp/ipykernel_294/1279761003.py in training_step(self, batch, batch_idx)
      9         pixel_values, labels, _ = batch
     10 
---> 11         outputs = self.model(pixel_values, labels=labels)
     12         loss = outputs.loss
     13         self.log_dict({"train_loss": loss}, sync_dist=True)

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/donut/lib/python3.7/site-packages/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py in forward(self, pixel_values, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    584                 output_hidden_states=output_hidden_states,
    585                 return_dict=return_dict,
--> 586                 **kwargs_encoder,
    587             )
    588         elif isinstance(encoder_outputs, tuple):

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/donut/lib/python3.7/site-packages/transformers/models/swin/modeling_swin.py in forward(self, pixel_values, bool_masked_pos, head_mask, output_attentions, output_hidden_states, return_dict)
    973         head_mask = self.get_head_mask(head_mask, len(self.config.depths))
    974 
--> 975         embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
    976 
    977         encoder_outputs = self.encoder(

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/donut/lib/python3.7/site-packages/transformers/models/swin/modeling_swin.py in forward(self, pixel_values, bool_masked_pos)
    251     ) -> Tuple[torch.Tensor]:
    252         embeddings, output_dimensions = self.patch_embeddings(pixel_values)
--> 253         embeddings = self.norm(embeddings)
    254         batch_size, seq_len, _ = embeddings.size()
    255 

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/normalization.py in forward(self, input)
    189     def forward(self, input: Tensor) -> Tensor:
    190         return F.layer_norm(
--> 191             input, self.normalized_shape, self.weight, self.bias, self.eps)
    192 
    193     def extra_repr(self) -> str:

~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/functional.py in layer_norm(input, normalized_shape, weight, bias, eps)
   2513             layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
   2514         )
-> 2515     return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
   2516 
   2517 

RuntimeError: expected scalar type BFloat16 but found Float
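The crash traces back to the precision=16 setting: on a CPU run Lightning falls back to bfloat16 AMP (see the "Using bfloat16 Automatic Mixed Precision (AMP)" line above), and the layer norm after the Swin patch embeddings then sees a mix of bfloat16 and float32 tensors. A minimal workaround sketch, assuming the run stays on CPU (on a GPU the original precision=16 setting is the intended configuration):

# Workaround sketch: train in full precision on CPU to avoid the bfloat16/float mismatch.
trainer = pl.Trainer(
        accelerator="cpu",
        devices=1,
        max_epochs=train_config.get("max_epochs"),
        val_check_interval=train_config.get("val_check_interval"),
        check_val_every_n_epoch=train_config.get("check_val_every_n_epoch"),
        gradient_clip_val=train_config.get("gradient_clip_val"),
        precision=32,  # full precision on CPU
        num_sanity_val_steps=0,
        logger=wandb_logger,
        callbacks=[PushToHubCallback()],
)
trainer.fit(model_module)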