utils for donut

s444415 2022-12-11 10:43:08 +01:00
commit a2b9a287dd
4 changed files with 1251 additions and 0 deletions

0
README.md Normal file

445
donut-eval.ipynb Normal file

@@ -0,0 +1,445 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from transformers import DonutProcessor, VisionEncoderDecoderModel\n",
"from datasets import load_dataset\n",
"import re\n",
"import json\n",
"import torch\n",
"from tqdm.auto import tqdm\n",
"import numpy as np\n",
"\n",
"from donut import JSONParseEvaluator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "24d467a0b2fa49a99506d690dba9e411",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/421 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b69047a501b4e18b831d9e28c90f45e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/544 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "69430fbf988c44e39d02b05144c356df",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/1.30M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "875aa0ed813647d08d7f1e003b92b4e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/4.01M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc5997bbdf6f4adfbea60fdfe18c3503",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/95.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "634540e8020e45129fe422976a7113e1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/355 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6fdce54654504308a7f1ba669823b996",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/5.03k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "552c171e151045e2a9dc6f860428add0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/809M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"processor = DonutProcessor.from_pretrained(\"Zombely/plwiki-test\")\n",
"model = VisionEncoderDecoderModel.from_pretrained(\"Zombely/plwiki-test\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "deba176990a74dbd80fb9347758b6e33",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading readme: 0%| | 0.00/527 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading and preparing dataset None/None to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9e238b0b3f1e4176a00657852986df9d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data files: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1f57c7efeb0f4e1e8bf7cc39e1d78a84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/1.44M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "68418bad8c1745faa7ecd802594f7bf6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/9.47M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8fad89bd442646968ac667f427efc8e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/885k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1ff350c50c674a4d81f033c98b273d6f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Extracting data files #0: 0%| | 0/1 [00:00<?, ?obj/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "724a8fb02f144816a61e9e562e86952b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Extracting data files #2: 0%| | 0/1 [00:00<?, ?obj/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "99bb9649389544a6a2d9e3ede599f9d0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Extracting data files #1: 0%| | 0/1 [00:00<?, ?obj/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "93fd550b1e0b4881b22d1c37f6e201c2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/13 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7a4a5650eda941a5b86d8c3d367b6589",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/101 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b1b5c26457ed47dd87a80d9fb6769d75",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating validation split: 0%| | 0/11 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset parquet downloaded and prepared to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.\n"
]
}
],
"source": [
"dataset = load_dataset(\"Zombely/pl-text-images\", split=\"validation\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5cecbeda924942488b63727ca3885d60",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'accuracies': [0, 0.9791666666666666, 0.9156626506024097, 1.0, 0.9836552748885586, 1.0, 0.7335359675785207, 0.9512987012987013, 0.396732788798133, 0.9908675799086758, 0.9452954048140044], 'mean_accuracy': 0.8087468213232427} length : 11\n",
"Mean accuracy: 0.8087468213232427\n"
]
}
],
"source": [
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"model.eval()\n",
"model.to(device)\n",
"\n",
"output_list = []\n",
"accs = []\n",
"\n",
"\n",
"for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):\n",
" # prepare encoder inputs\n",
" pixel_values = processor(sample[\"image\"].convert(\"RGB\"), return_tensors=\"pt\").pixel_values\n",
" pixel_values = pixel_values.to(device)\n",
" # prepare decoder inputs\n",
" task_prompt = \"<s_cord-v2>\"\n",
" decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors=\"pt\").input_ids\n",
" decoder_input_ids = decoder_input_ids.to(device)\n",
" \n",
" # autoregressively generate sequence\n",
" outputs = model.generate(\n",
" pixel_values,\n",
" decoder_input_ids=decoder_input_ids,\n",
" max_length=model.decoder.config.max_position_embeddings,\n",
" early_stopping=True,\n",
" pad_token_id=processor.tokenizer.pad_token_id,\n",
" eos_token_id=processor.tokenizer.eos_token_id,\n",
" use_cache=True,\n",
" num_beams=1,\n",
" bad_words_ids=[[processor.tokenizer.unk_token_id]],\n",
" return_dict_in_generate=True,\n",
" )\n",
"\n",
" # turn into JSON\n",
" seq = processor.batch_decode(outputs.sequences)[0]\n",
" seq = seq.replace(processor.tokenizer.eos_token, \"\").replace(processor.tokenizer.pad_token, \"\")\n",
" seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip() # remove first task start token\n",
" seq = processor.token2json(seq)\n",
"\n",
" ground_truth = json.loads(sample[\"ground_truth\"])\n",
" ground_truth = ground_truth[\"gt_parse\"]\n",
" evaluator = JSONParseEvaluator()\n",
" score = evaluator.cal_acc(seq, ground_truth)\n",
"\n",
" accs.append(score)\n",
" output_list.append(seq)\n",
"\n",
"scores = {\"accuracies\": accs, \"mean_accuracy\": np.mean(accs)}\n",
"print(scores, f\"length : {len(accs)}\")\n",
"print(\"Mean accuracy:\", np.mean(accs))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.15 ('donut')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.15"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "11ee3e278e787ae04c18a69549ce58331d512f29053c6ca32ae16833b7cef834"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
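
The post-processing in the eval loop above (strip the EOS/pad tokens, drop the leading task tag, run token2json, then score with tree accuracy) can be sanity-checked in isolation. A minimal sketch on a toy sequence; the <s_title> tag is hypothetical, the real tags come from the dataset's ground-truth keys:

import re
from transformers import DonutProcessor
from donut import JSONParseEvaluator

processor = DonutProcessor.from_pretrained("Zombely/plwiki-test")

# hypothetical decoded output of model.generate (after batch_decode)
seq = "<s_cord-v2><s_title>Lorem ipsum</s_title></s>"
seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
pred = processor.token2json(seq)  # -> {'title': 'Lorem ipsum'}

# cal_acc returns 1 - normalized tree edit distance between prediction and ground truth
evaluator = JSONParseEvaluator()
print(evaluator.cal_acc(pred, {"title": "Lorem ipsum"}))  # 1.0 for an exact match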

607
donut-train.ipynb Normal file

@@ -0,0 +1,607 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel\n",
"from datasets import load_dataset\n",
"import json\n",
"import random\n",
"from typing import Any, List, Tuple\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import re\n",
"from nltk import edit_distance\n",
"import numpy as np\n",
"from pytorch_lightning.loggers import WandbLogger\n",
"from pytorch_lightning.callbacks import Callback\n",
"import pytorch_lightning as pl\n",
"import os\n",
"from huggingface_hub import login\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"DATASET_PATH = \"Zombely/pl-text-images\"\n",
"PRETRAINED_MODEL_PATH = \"nielsr/donut-proto\"\n",
"OUTPUT_MODEL_PATH = \"Zombely/plwiki-test-proto\"\n",
"LOGGING_PATH = \"plwiki-test-run-proto\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_config = {\n",
" \"max_epochs\":5,\n",
" \"val_check_interval\":0.2, # how many times we want to validate during an epoch\n",
" \"check_val_every_n_epoch\":1,\n",
" \"gradient_clip_val\":1.0,\n",
" \"num_training_samples_per_epoch\": 800,\n",
" \"lr\":3e-5,\n",
" \"train_batch_sizes\": [8],\n",
" \"val_batch_sizes\": [1],\n",
" # \"seed\":2022,\n",
" \"num_nodes\": 1,\n",
" \"warmup_steps\": 300, # 800/8*30/10, 10%\n",
" \"result_path\": \"./result\",\n",
" \"verbose\": True,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a\n",
"Found cached dataset parquet (/home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e15064cef2d4d4c842ec97c3467b1d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dataset = load_dataset(DATASET_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"max_length = 768\n",
"image_size = [1280, 960]\n",
"config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_PATH)\n",
"config.encoder.image_size = image_size # (height, width)\n",
"config.decoder.max_length = max_length"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.\n",
"Some weights of the model checkpoint at nielsr/donut-proto were not used when initializing VisionEncoderDecoderModel: ['encoder.encoder.layers.2.blocks.13.attn_mask', 'encoder.encoder.layers.2.blocks.17.attn_mask', 'encoder.encoder.layers.2.blocks.1.attn_mask', 'encoder.encoder.layers.2.blocks.9.attn_mask', 'encoder.encoder.layers.2.blocks.7.attn_mask', 'encoder.encoder.layers.3.blocks.1.attn_mask', 'encoder.encoder.layers.2.blocks.11.attn_mask', 'encoder.encoder.layers.2.blocks.5.attn_mask', 'encoder.encoder.layers.2.blocks.15.attn_mask', 'encoder.encoder.layers.0.blocks.1.attn_mask', 'encoder.encoder.layers.1.blocks.1.attn_mask', 'encoder.encoder.layers.2.blocks.3.attn_mask']\n",
"- This IS expected if you are initializing VisionEncoderDecoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing VisionEncoderDecoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at nielsr/donut-proto and are newly initialized: ['encoder.layernorm.weight', 'encoder.layernorm.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"processor = DonutProcessor.from_pretrained(PRETRAINED_MODEL_PATH)\n",
"model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_PATH, config=config)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"added_tokens = []\n",
"\n",
"class DonutDataset(Dataset):\n",
" \"\"\"\n",
" DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)\n",
" Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt),\n",
" and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string).\n",
" Args:\n",
" dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl\n",
" max_length: the max number of tokens for the target sequences\n",
" split: whether to load \"train\", \"validation\" or \"test\" split\n",
" ignore_id: ignore_index for torch.nn.CrossEntropyLoss\n",
" task_start_token: the special token to be fed to the decoder to conduct the target task\n",
" prompt_end_token: the special token at the end of the sequences\n",
" sort_json_key: whether or not to sort the JSON keys\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" dataset_name_or_path: str,\n",
" max_length: int,\n",
" split: str = \"train\",\n",
" ignore_id: int = -100,\n",
" task_start_token: str = \"<s>\",\n",
" prompt_end_token: str = None,\n",
" sort_json_key: bool = True,\n",
" ):\n",
" super().__init__()\n",
"\n",
" self.max_length = max_length\n",
" self.split = split\n",
" self.ignore_id = ignore_id\n",
" self.task_start_token = task_start_token\n",
" self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token\n",
" self.sort_json_key = sort_json_key\n",
"\n",
" self.dataset = load_dataset(dataset_name_or_path, split=self.split)\n",
" self.dataset_length = len(self.dataset)\n",
"\n",
" self.gt_token_sequences = []\n",
" for sample in self.dataset:\n",
" ground_truth = json.loads(sample[\"ground_truth\"])\n",
" if \"gt_parses\" in ground_truth: # when multiple ground truths are available, e.g., docvqa\n",
" assert isinstance(ground_truth[\"gt_parses\"], list)\n",
" gt_jsons = ground_truth[\"gt_parses\"]\n",
" else:\n",
" assert \"gt_parse\" in ground_truth and isinstance(ground_truth[\"gt_parse\"], dict)\n",
" gt_jsons = [ground_truth[\"gt_parse\"]]\n",
"\n",
" self.gt_token_sequences.append(\n",
" [\n",
" self.json2token(\n",
" gt_json,\n",
" update_special_tokens_for_json_key=self.split == \"train\",\n",
" sort_json_key=self.sort_json_key,\n",
" )\n",
" + processor.tokenizer.eos_token\n",
" for gt_json in gt_jsons # load json from list of json\n",
" ]\n",
" )\n",
"\n",
" self.add_tokens([self.task_start_token, self.prompt_end_token])\n",
" self.prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)\n",
"\n",
" def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):\n",
" \"\"\"\n",
" Convert an ordered JSON object into a token sequence\n",
" \"\"\"\n",
" if type(obj) == dict:\n",
" if len(obj) == 1 and \"text_sequence\" in obj:\n",
" return obj[\"text_sequence\"]\n",
" else:\n",
" output = \"\"\n",
" if sort_json_key:\n",
" keys = sorted(obj.keys(), reverse=True)\n",
" else:\n",
" keys = obj.keys()\n",
" for k in keys:\n",
" if update_special_tokens_for_json_key:\n",
" self.add_tokens([fr\"<s_{k}>\", fr\"</s_{k}>\"])\n",
" output += (\n",
" fr\"<s_{k}>\"\n",
" + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)\n",
" + fr\"</s_{k}>\"\n",
" )\n",
" return output\n",
" elif type(obj) == list:\n",
" return r\"<sep/>\".join(\n",
" [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]\n",
" )\n",
" else:\n",
" obj = str(obj)\n",
" if f\"<{obj}/>\" in added_tokens:\n",
" obj = f\"<{obj}/>\" # for categorical special tokens\n",
" return obj\n",
" \n",
" def add_tokens(self, list_of_tokens: List[str]):\n",
" \"\"\"\n",
" Add special tokens to tokenizer and resize the token embeddings of the decoder\n",
" \"\"\"\n",
" newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)\n",
" if newly_added_num > 0:\n",
" model.decoder.resize_token_embeddings(len(processor.tokenizer))\n",
" added_tokens.extend(list_of_tokens)\n",
" \n",
" def __len__(self) -> int:\n",
" return self.dataset_length\n",
"\n",
" def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
" \"\"\"\n",
" Load image from image_path of given dataset_path and convert into input_tensor and labels\n",
" Convert gt data into input_ids (tokenized string)\n",
" Returns:\n",
" input_tensor : preprocessed image\n",
" input_ids : tokenized gt_data\n",
" labels : masked labels (model doesn't need to predict prompt and pad token)\n",
" \"\"\"\n",
" sample = self.dataset[idx]\n",
"\n",
" # inputs\n",
" pixel_values = processor(sample[\"image\"], random_padding=self.split == \"train\", return_tensors=\"pt\").pixel_values\n",
" pixel_values = pixel_values.squeeze()\n",
"\n",
" # targets\n",
" target_sequence = random.choice(self.gt_token_sequences[idx]) # can be more than one, e.g., DocVQA Task 1\n",
" input_ids = processor.tokenizer(\n",
" target_sequence,\n",
" add_special_tokens=False,\n",
" max_length=self.max_length,\n",
" padding=\"max_length\",\n",
" truncation=True,\n",
" return_tensors=\"pt\",\n",
" )[\"input_ids\"].squeeze(0)\n",
"\n",
" labels = input_ids.clone()\n",
" labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id # model doesn't need to predict pad token\n",
" # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id # model doesn't need to predict prompt (for VQA)\n",
" return pixel_values, labels, target_sequence"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a\n",
"Found cached dataset parquet (/home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n",
"Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a\n",
"Found cached dataset parquet (/home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
]
}
],
"source": [
"processor.image_processor.size = image_size[::-1] # should be (width, height)\n",
"processor.image_processor.do_align_long_axis = False\n",
"\n",
"train_dataset = DonutDataset(DATASET_PATH, max_length=max_length,\n",
" split=\"train\", task_start_token=\"<s_cord-v2>\", prompt_end_token=\"<s_cord-v2>\",\n",
" sort_json_key=False, # cord dataset is preprocessed, so no need for this\n",
" )\n",
"\n",
"val_dataset = DonutDataset(DATASET_PATH, max_length=max_length,\n",
" split=\"validation\", task_start_token=\"<s_cord-v2>\", prompt_end_token=\"<s_cord-v2>\",\n",
" sort_json_key=False, # cord dataset is preprocessed, so no need for this\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"model.config.pad_token_id = processor.tokenizer.pad_token_id\n",
"model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)\n",
"val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"class DonutModelPLModule(pl.LightningModule):\n",
" def __init__(self, config, processor, model):\n",
" super().__init__()\n",
" self.config = config\n",
" self.processor = processor\n",
" self.model = model\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" pixel_values, labels, _ = batch\n",
" \n",
" outputs = self.model(pixel_values, labels=labels)\n",
" loss = outputs.loss\n",
" self.log_dict({\"train_loss\": loss}, sync_dist=True)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx, dataset_idx=0):\n",
" pixel_values, labels, answers = batch\n",
" batch_size = pixel_values.shape[0]\n",
" # we feed the prompt to the model\n",
" decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)\n",
" \n",
" outputs = self.model.generate(pixel_values,\n",
" decoder_input_ids=decoder_input_ids,\n",
" max_length=max_length,\n",
" early_stopping=True,\n",
" pad_token_id=self.processor.tokenizer.pad_token_id,\n",
" eos_token_id=self.processor.tokenizer.eos_token_id,\n",
" use_cache=True,\n",
" num_beams=1,\n",
" bad_words_ids=[[self.processor.tokenizer.unk_token_id]],\n",
" return_dict_in_generate=True,)\n",
" \n",
" predictions = []\n",
" for seq in self.processor.tokenizer.batch_decode(outputs.sequences):\n",
" seq = seq.replace(self.processor.tokenizer.eos_token, \"\").replace(self.processor.tokenizer.pad_token, \"\")\n",
" seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip() # remove first task start token\n",
" predictions.append(seq)\n",
"\n",
" scores = list()\n",
" for pred, answer in zip(predictions, answers):\n",
" pred = re.sub(r\"(?:(?<=>) | (?=</s_))\", \"\", pred)\n",
" # NOT NEEDED ANYMORE\n",
" # answer = re.sub(r\"<.*?>\", \"\", answer, count=1)\n",
" answer = answer.replace(self.processor.tokenizer.eos_token, \"\")\n",
" scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))\n",
"\n",
" if self.config.get(\"verbose\", False) and len(scores) == 1:\n",
" print(f\"Prediction: {pred}\")\n",
" print(f\" Answer: {answer}\")\n",
" print(f\" Normed ED: {scores[0]}\")\n",
"\n",
" return scores\n",
"\n",
" def validation_epoch_end(self, validation_step_outputs):\n",
" # I set this to 1 manually\n",
" # (previously set to len(self.config.dataset_name_or_paths))\n",
" num_of_loaders = 1\n",
" if num_of_loaders == 1:\n",
" validation_step_outputs = [validation_step_outputs]\n",
" assert len(validation_step_outputs) == num_of_loaders\n",
" cnt = [0] * num_of_loaders\n",
" total_metric = [0] * num_of_loaders\n",
" val_metric = [0] * num_of_loaders\n",
" for i, results in enumerate(validation_step_outputs):\n",
" for scores in results:\n",
" cnt[i] += len(scores)\n",
" total_metric[i] += np.sum(scores)\n",
" val_metric[i] = total_metric[i] / cnt[i]\n",
" val_metric_name = f\"val_metric_{i}th_dataset\"\n",
" self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)\n",
" self.log_dict({\"val_metric\": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)\n",
"\n",
" def configure_optimizers(self):\n",
" # TODO add scheduler\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get(\"lr\"))\n",
" \n",
" return optimizer\n",
"\n",
" def train_dataloader(self):\n",
" return train_dataloader\n",
"\n",
" def val_dataloader(self):\n",
" return val_dataloader"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class PushToHubCallback(Callback):\n",
" def on_train_epoch_end(self, trainer, pl_module):\n",
" print(f\"Pushing model to the hub, epoch {trainer.current_epoch}\")\n",
" pl_module.model.push_to_hub(OUTPUT_MODEL_PATH,\n",
" commit_message=f\"Training in progress, epoch {trainer.current_epoch}\")\n",
"\n",
" def on_train_end(self, trainer, pl_module):\n",
" print(f\"Pushing model to the hub after training\")\n",
" pl_module.processor.push_to_hub(OUTPUT_MODEL_PATH,\n",
" commit_message=f\"Training done\")\n",
" pl_module.model.push_to_hub(OUTPUT_MODEL_PATH,\n",
" commit_message=f\"Training done\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n",
"Token is valid.\n",
"Your token has been saved to /home/pc/.huggingface/token\n",
"Login successful\n"
]
}
],
"source": [
"login(os.environ.get(\"HUG_TOKEN\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Wandb.ai link: https://wandb.ai/michalkozlowski936/Donut?workspace=user-michalkozlowski936\n",
"### Hugging_face link https://huggingface.co/Zombely"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using bfloat16 Automatic Mixed Precision (AMP)\n",
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------------------------\n",
"0 | model | VisionEncoderDecoderModel | 213 M \n",
"----------------------------------------------------\n",
"213 M Trainable params\n",
"0 Non-trainable params\n",
"213 M Total params\n",
"854.597 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3f0be6eee3a54304ba7b8a333536a418",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "RuntimeError",
"evalue": "expected scalar type BFloat16 but found Float",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_294/2569065759.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 27\u001b[0m )\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lightning_module\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 582\u001b[0m call._call_and_handle_interrupt(\n\u001b[0;32m--> 583\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 584\u001b[0m )\n\u001b[1;32m 585\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py\u001b[0m in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlauncher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlaunch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0m_TunerExitException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 622\u001b[0m \u001b[0mmodel_connected\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 623\u001b[0m )\n\u001b[0;32m--> 624\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 625\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1059\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_checkpoint_connector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresume_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1060\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1061\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1062\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1063\u001b[0m \u001b[0mlog\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetail\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{self.__class__.__name__}: trainer tearing down\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1138\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredicting\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1139\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1140\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1142\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_pre_training_routine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1162\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_detect_anomaly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_detect_anomaly\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1163\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1165\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_evaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0m_EVALUATE_OUTPUT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_to_device\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_to_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 266\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_epoch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepoch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_fetcher\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_batch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 214\u001b[0;31m \u001b[0mbatch_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 215\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_progress\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mincrement_processed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_frequencies\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"batch_idx\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m )\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmanual_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/loop.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, optimizers, kwargs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_build_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_hiddens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_optimization\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim_progress\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_position\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;31m# automatic optimization assumes a loss needs to be returned for extras to be considered as the batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_run_optimization\u001b[0;34m(self, kwargs, optimizer)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[0;31m# the `batch_idx` is optional with inter-batch parallelism\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 247\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"batch_idx\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 248\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 249\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsume_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_optimizer_step\u001b[0;34m(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 364\u001b[0m \u001b[0mon_tpu\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTPUAccelerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0musing_native_amp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp_backend\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mAMPType\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNATIVE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 366\u001b[0;31m \u001b[0musing_lbfgs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mis_lbfgs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 367\u001b[0m )\n\u001b[1;32m 368\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(self, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1304\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"[LightningModule]{pl_module.__class__.__name__}.{hook_name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1305\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1306\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1307\u001b[0m \u001b[0;31m# restore current_fx when nested context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/core/module.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)\u001b[0m\n\u001b[1;32m 1659\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1660\u001b[0m \"\"\"\n\u001b[0;32m-> 1661\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer_closure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1662\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1663\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_strategy\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mstep_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_strategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_on_after_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, opt_idx, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLightningModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 234\u001b[0m return self.precision_plugin.optimizer_step(\n\u001b[0;32m--> 235\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 236\u001b[0m )\n\u001b[1;32m 237\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/native_amp.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, model, optimizer_idx, closure, **kwargs)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;31m# skip scaler logic, as bfloat16 does not require scaler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m return super().optimizer_step(\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m )\n\u001b[1;32m 81\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLBFGS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, model, optimizer_idx, closure, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0mclosure\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_wrap_closure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_track_grad_norm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.Trainer\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/optim/optimizer.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0mprofile_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"Optimizer.step#{}.step\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprofile_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_step_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/optim/optimizer.py\u001b[0m in \u001b[0;36m_use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_grad_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefaults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'differentiable'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_grad_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_grad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure, grad_scaler)\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mclosure\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 183\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_groups\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py\u001b[0m in \u001b[0;36m_wrap_closure\u001b[0;34m(self, model, optimizer, optimizer_idx, closure)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mconsistent\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mthe\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mPrecisionPlugin\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0msubclasses\u001b[0m \u001b[0mthat\u001b[0m \u001b[0mcannot\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0mdirectly\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 106\u001b[0m \"\"\"\n\u001b[0;32m--> 107\u001b[0;31m \u001b[0mclosure_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 108\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_after_closure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mclosure_result\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36mclosure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mClosureResult\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 133\u001b[0;31m \u001b[0mstep_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_step_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstep_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure_loss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 404\u001b[0m \"\"\"\n\u001b[1;32m 405\u001b[0m \u001b[0;31m# manually capture logged metrics\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 406\u001b[0;31m \u001b[0mtraining_step_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_strategy_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 407\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 408\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1441\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1442\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"[Strategy]{self.strategy.__class__.__name__}.{hook_name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1443\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1445\u001b[0m \u001b[0;31m# restore current_fx when nested context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/pytorch_lightning/strategies/strategy.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 376\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 377\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTrainingStep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 378\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 379\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/tmp/ipykernel_294/1279761003.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mpixel_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpixel_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"train_loss\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msync_dist\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1188\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1191\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, pixel_values, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[1;32m 584\u001b[0m \u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 585\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 586\u001b[0;31m \u001b[0;34m**\u001b[0m\u001b[0mkwargs_encoder\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 587\u001b[0m )\n\u001b[1;32m 588\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoder_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1188\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1191\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/transformers/models/swin/modeling_swin.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, pixel_values, bool_masked_pos, head_mask, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 973\u001b[0m \u001b[0mhead_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_head_mask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhead_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdepths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 974\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 975\u001b[0;31m \u001b[0membedding_output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_dimensions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membeddings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpixel_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbool_masked_pos\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbool_masked_pos\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 976\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 977\u001b[0m encoder_outputs = self.encoder(\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1188\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1191\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/transformers/models/swin/modeling_swin.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, pixel_values, bool_masked_pos)\u001b[0m\n\u001b[1;32m 251\u001b[0m ) -> Tuple[torch.Tensor]:\n\u001b[1;32m 252\u001b[0m \u001b[0membeddings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_dimensions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpatch_embeddings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpixel_values\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 253\u001b[0;31m \u001b[0membeddings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membeddings\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 254\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseq_len\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0membeddings\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1188\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1191\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/modules/normalization.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 190\u001b[0m return F.layer_norm(\n\u001b[0;32m--> 191\u001b[0;31m input, self.normalized_shape, self.weight, self.bias, self.eps)\n\u001b[0m\u001b[1;32m 192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/envs/donut/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlayer_norm\u001b[0;34m(input, normalized_shape, weight, bias, eps)\u001b[0m\n\u001b[1;32m 2513\u001b[0m \u001b[0mlayer_norm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnormalized_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2514\u001b[0m )\n\u001b[0;32m-> 2515\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer_norm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnormalized_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2516\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2517\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: expected scalar type BFloat16 but found Float"
]
}
],
"source": [
"\n",
"model_module = DonutModelPLModule(train_config, processor, model)\n",
"\n",
"wandb_logger = WandbLogger(project=\"Donut\", name=LOGGING_PATH)\n",
"\n",
"trainer = pl.Trainer(\n",
" accelerator=\"cpu\", # change to gpu\n",
" devices=1,\n",
" max_epochs=train_config.get(\"max_epochs\"),\n",
" val_check_interval=train_config.get(\"val_check_interval\"),\n",
" check_val_every_n_epoch=train_config.get(\"check_val_every_n_epoch\"),\n",
" gradient_clip_val=train_config.get(\"gradient_clip_val\"),\n",
" precision=16, # we'll use mixed precision\n",
" num_sanity_val_steps=0,\n",
" logger=wandb_logger,\n",
" callbacks=[PushToHubCallback()],\n",
")\n",
"\n",
"trainer.fit(model_module)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.15 ('donut')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.15"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "11ee3e278e787ae04c18a69549ce58331d512f29053c6ca32ae16833b7cef834"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
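A note on the failure recorded in the last cell: the run ends in RuntimeError: expected scalar type BFloat16 but found Float. A plausible reading of the traceback is that accelerator="cpu" combined with precision=16 routes the step through pytorch-lightning's native-AMP plugin with bfloat16 autocast, so the Swin patch embeddings reach LayerNorm as bfloat16 while its weights are still float32, and torch.layer_norm rejects the mixed dtypes. Below is a minimal sketch of a guarded setup (build_trainer is a hypothetical helper, not part of this notebook; it reuses the same train_config keys):

import torch
import pytorch_lightning as pl

def build_trainer(train_config, logger=None, callbacks=None):
    # Hypothetical helper: pick an accelerator/precision pair that is
    # mutually compatible. Mixed 16-bit precision assumes a CUDA device;
    # on CPU, fall back to full precision to avoid the dtype mismatch.
    use_gpu = torch.cuda.is_available()
    return pl.Trainer(
        accelerator="gpu" if use_gpu else "cpu",
        devices=1,
        max_epochs=train_config.get("max_epochs"),
        val_check_interval=train_config.get("val_check_interval"),
        check_val_every_n_epoch=train_config.get("check_val_every_n_epoch"),
        gradient_clip_val=train_config.get("gradient_clip_val"),
        precision=16 if use_gpu else 32,
        num_sanity_val_steps=0,
        logger=logger,
        callbacks=callbacks or [],
    )

Used in place of the pl.Trainer(...) call above, trainer = build_trainer(train_config, logger=wandb_logger, callbacks=[PushToHubCallback()]) should train in float32 on CPU and in mixed precision on GPU.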

199
requirements.txt Normal file

@@ -0,0 +1,199 @@
absl-py==1.3.0
aiohttp==3.8.3
aiosignal==1.3.1
anyio==3.6.2
arabic-reshaper==2.1.4
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
async-timeout==4.0.2
asynctest==0.13.0
attrs==22.1.0
backcall==0.2.0
bcrypt==4.0.1
beautifulsoup4==4.11.1
bleach==5.0.1
blend-modes==2.1.0
cachetools==5.2.0
certifi @ file:///croot/certifi_1665076670883/work/certifi
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
cryptography==38.0.3
cycler==0.11.0
datasets==2.7.0
debugpy==1.6.3
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
docker-pycreds==0.4.0
donut-python @ file:///home/pc/work/donut
entrypoints==0.4
evaluate==0.3.0
fastapi==0.87.0
fastjsonschema==2.16.2
ffmpy==0.3.0
filelock==3.8.0
fire==0.4.0
fonttools==4.38.0
frozenlist==1.3.3
fsspec==2022.11.0
future==0.18.2
gitdb==4.0.10
GitPython==3.1.29
google-auth==2.14.1
google-auth-oauthlib==0.4.6
gradio==3.10.1
grpcio==1.50.0
h11==0.12.0
httpcore==0.15.0
httpx==0.23.1
huggingface-hub==0.11.0
idna==3.4
imageio==2.22.4
imgaug==0.4.0
importlib-metadata==5.0.0
importlib-resources==5.10.0
ipykernel==6.16.2
ipython==7.34.0
ipython-genutils==0.2.0
ipywidgets==8.0.2
jedi==0.18.2
Jinja2==3.1.2
jiwer==2.5.1
joblib==1.2.0
jsonschema==4.17.1
jupyter==1.0.0
jupyter-console==6.4.4
jupyter-server==1.23.3
jupyter_client==7.4.7
jupyter_core==4.11.2
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.3
kiwisolver==1.4.4
Levenshtein==0.20.2
lightning-utilities==0.3.0
linkify-it-py==1.0.3
Markdown==3.4.1
markdown-it-py==2.1.0
MarkupSafe==2.1.1
matplotlib==3.5.3
matplotlib-inline==0.1.6
mdit-py-plugins==0.3.1
mdurl==0.1.2
mistune==2.0.4
multidict==6.0.2
multiprocess==0.70.14
munch==2.5.0
nbclassic==0.4.8
nbclient==0.7.0
nbconvert==7.2.5
nbformat==5.7.0
nest-asyncio==1.5.6
networkx==2.6.3
nltk==3.7
notebook==6.5.2
notebook_shim==0.2.2
numpy==1.21.6
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2
opencv-python==4.6.0.66
orjson==3.8.2
packaging==21.3
pandas==1.3.5
pandocfilters==1.5.0
paramiko==2.12.0
parso==0.8.3
pathtools==0.1.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
pkgutil_resolve_name==1.3.10
prometheus-client==0.15.0
promise==2.3
prompt-toolkit==3.0.33
protobuf==3.20.3
psutil==5.9.4
ptyprocess==0.7.0
pyarrow==10.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pycryptodome==3.15.0
pydantic==1.10.2
pydub==0.25.1
pygame==2.1.2
Pygments==2.13.0
PyNaCl==1.5.0
pyparsing==3.0.9
pyrsistent==0.19.2
python-bidi==0.4.2
python-dateutil==2.8.2
python-multipart==0.0.5
pytorch-lightning==1.8.2
pytweening==1.0.4
pytz==2022.6
PyWavelets==1.3.0
PyYAML==6.0
pyzmq==24.0.1
qtconsole==5.4.0
QtPy==2.3.0
rapidfuzz==2.13.3
regex==2022.10.31
requests==2.28.1
requests-oauthlib==1.3.1
responses==0.18.0
rfc3986==1.5.0
rsa==4.9
ruamel.yaml==0.17.21
ruamel.yaml.clib==0.2.7
scikit-image==0.19.3
scikit-learn==1.0.2
scipy==1.7.3
sconf==0.2.5
Send2Trash==1.8.0
sentencepiece==0.1.97
sentry-sdk==1.11.1
setproctitle==1.3.2
Shapely==1.8.5.post1
shortuuid==1.0.11
six==1.16.0
smmap==5.0.0
sniffio==1.3.0
soupsieve==2.3.2.post1
starlette==0.21.0
synthtiger==1.2.1
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
termcolor==2.1.1
terminado==0.17.0
threadpoolctl==3.1.0
tifffile==2021.11.2
timm==0.6.11
tinycss2==1.2.1
tokenizers==0.13.2
torch==1.13.0
torchmetrics==0.10.3
torchvision==0.14.0
tornado==6.2
tqdm==4.64.1
traitlets==5.5.0
transformers @ git+https://github.com/huggingface/transformers.git@699e90437f984d69ad3c9b891dd2e9d0fc2cffe4
typing_extensions==4.4.0
uc-micro-py==1.0.1
urllib3==1.26.12
uvicorn==0.20.0
wandb==0.13.5
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.4.2
websockets==10.4
Werkzeug==2.2.2
widgetsnbextension==4.0.3
xxhash==3.1.0
yarl==1.8.1
zipp==3.10.0
zss==1.2.0
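
These pins target the Python 3.7.15 'donut' environment used by the notebook. A small sanity check (a sketch, not part of the repository) can confirm an installed environment matches the key pins before training:

import sys

import torch
import pytorch_lightning
import transformers

# Versions mirror the pins above; transformers is pinned to a git commit,
# so it is only checked for importability.
assert sys.version_info[:2] == (3, 7), sys.version
assert torch.__version__.startswith("1.13"), torch.__version__
assert pytorch_lightning.__version__ == "1.8.2", pytorch_lightning.__version__
print("transformers", transformers.__version__)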