Compare commits: master ... vm-changes

No commits in common. "master" and "vm-changes" have entirely different histories.

.gitignore (vendored): 6 changes
@@ -2,9 +2,3 @@ env
Donut
nohup.out
wandb
__pycache__/
checkpoint
.vscode
donut_env
env_donut
*.out
README.md: 17 changes
@@ -1,17 +0,0 @@
-# Donut script for train and eval
-
-## Installation/setup:
--> `pip install -r requirements.txt`
--> `export HUG_TOKEN=YourToken` - token needed for pushing the model to Hugging Face after training
-
-## Usage:
-
-### Train
--> `python train --config config-train.yaml`
-
-
-### Evaluation
--> `python eval --config config-eval.yaml`
-
-## Configuration
-Supply a config file like `config-eval.yaml` for evaluation or `config-train.yaml` for training
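
Note: both commands in the deleted README follow the same pattern: parse `--config`, load the YAML, and hand the result to `main`. A minimal sketch of that wiring, assuming the `sconf.Config(path)` loader that the eval.py hunk below imports (`from sconf import Config`); the `main` body is a placeholder, not the repo's actual code:

```python
# Hypothetical sketch of the CLI wiring the README describes.
# Assumes sconf's Config(path) YAML loader, as imported in eval.py below.
import argparse
from sconf import Config

def main(config):
    ...  # the real scripts build the processor/model/dataset from the config

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()
    main(Config(args.config))
```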
config-eval.yaml
@@ -1,11 +1,8 @@
-pretrained_processor_path: "Zombely/pl-donut-v1.2"
-pretrained_model_path: "Zombely/pl-donut-v1.2"
-validation_dataset_path: "Zombely/diachronia-ocr-train"
+pretrained_processor_path: "Zombely/plwiki-proto-fine-tuned-v2"
+pretrained_model_path: "Zombely/plwiki-proto-fine-tuned-v2"
+validation_dataset_path: "Zombely/diachronia-ocr"
 validation_dataset_split: "train"
-has_metadata: True
+has_metadata: False
 print_output: True
-output_file_dir: ""
-test_name: "fine-tuned"
-image_size: [1280, 960]
-use_enc_dec_config: False
-max_dec_length: 768
+output_file_dir: "../../gonito-outs"
+test_name: "fine-tuned-test"
config-train.yaml
@@ -1,22 +0,0 @@
-dataset_path: "Zombely/wikisource-red"
-pretrained_model_path: "Zombely/pl-donut-v1.1"
-start_model_path: "Zombely/pl-donut-v1.1"
-output_model_path: "Zombely/pl-donut-v1.2"
-wandb_test_name: "pl-donut-v1.2"
-checkpoint_path: "./checkpoint"
-max_length: 768
-image_size: [1260, 960]
-train_config:
-  max_epochs: 1
-  val_check_interval: 1.0
-  check_val_every_n_epoch: 1
-  gradient_clip_val: 1.0
-  num_training_samples_per_epoch: 800
-  lr: 1.0e-4
-  train_batch_sizes: [8]
-  val_batch_sizes: [1]
-  seed: 2023
-  num_nodes: 1
-  warmup_steps: 10
-  result_path: "./result"
-  verbose: True
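
Note: the deleted `train_config` block above mirrors the dict that the training notebook later in this diff builds by hand and feeds to PyTorch Lightning. A condensed sketch of that mapping, taken from the notebook cell further down (values illustrative, not an official schema):

```python
# How config-train.yaml's train_config keys map onto pl.Trainer
# (condensed from the notebook cell later in this diff).
import pytorch_lightning as pl

train_config = {
    "max_epochs": 1,
    "val_check_interval": 1.0,
    "check_val_every_n_epoch": 1,
    "gradient_clip_val": 1.0,
}

trainer = pl.Trainer(
    accelerator="gpu",  # the notebook falls back to "cpu" when CUDA is absent
    devices=1,
    max_epochs=train_config["max_epochs"],
    val_check_interval=train_config["val_check_interval"],
    check_val_every_n_epoch=train_config["check_val_every_n_epoch"],
    gradient_clip_val=train_config["gradient_clip_val"],
    precision=16,  # mixed precision, as in the notebook
    num_sanity_val_steps=0,
)
```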
eval.py
@@ -15,15 +15,15 @@ from sconf import Config

 def main(config):

-    if config.use_enc_dec_config:
-        config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
-        config_vision.encoder.image_size = config.image_size  # (height, width)
-        config_vision.decoder.max_length = config.max_dec_length
+    # image_size = [1920, 2560]
+    # config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
+    # config_vision.encoder.image_size = image_size  # (height, width)
+    # config_vision.decoder.max_length = 768

     processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
-    model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, config=config_vision if config.use_enc_dec_config else None)
+    model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path)

-    processor.image_processor.size = config.image_size[::-1]  # should be (width, height)
+    # processor.image_processor.size = image_size[::-1]  # should be (width, height)
     processor.image_processor.do_align_long_axis = False

     dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
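
Note: the removed branch of this hunk is where the `(height, width)` vs `(width, height)` convention is handled. A standalone sketch of the same override (model name and sizes copied from the old config-eval.yaml above; nothing here is a new API):

```python
# Sketch of the image-size override removed in this hunk. The config stores
# image_size as [height, width]; the processor's `size` wants (width, height),
# hence the [::-1] reversal.
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("Zombely/pl-donut-v1.2")
image_size = [1280, 960]  # (height, width), as in the old config-eval.yaml
processor.image_processor.size = image_size[::-1]  # -> (width, height)
processor.image_processor.do_align_long_axis = False
```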
@@ -69,11 +69,7 @@ def main(config):

         accs.append(score)
         if config.print_output:
-            if 'ground_truth' in sample:
-                ground_truth = json.loads(sample["ground_truth"])
-                ground_truth = str(ground_truth["gt_parse"])
-                print("Original: ", ground_truth + "\n")
-            print("Prediction: ", str(seq) + "\n")
+            print(seq)
         output_list.append(seq)
     if config.output_file_dir:
         df = pd.DataFrame(map(lambda x: x.get('text_sequence', ''), output_list))
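
Note: the `score` accumulated in `accs` above is a normalized edit distance; the formula appears verbatim in the notebook's `validation_step` later in this diff. A self-contained example of that metric (the helper name `normed_ed` is illustrative):

```python
# Normalized edit distance, the evaluation score used across this diff
# (formula from the notebook's validation_step).
from nltk import edit_distance

def normed_ed(pred: str, answer: str) -> float:
    return edit_distance(pred, answer) / max(len(pred), len(answer))

print(normed_ed("donut", "donat"))  # one substitution over 5 chars -> 0.2
```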
train.py
@@ -376,3 +376,4 @@ trainer = pl.Trainer(
)

trainer.fit(model_module)

(notebook .ipynb, filename not captured)
@@ -794,7 +794,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "donut",
+   "display_name": "hug_donut",
    "language": "python",
    "name": "python3"
   },
@@ -808,12 +808,12 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.8 | packaged by conda-forge | (main, Nov 24 2022, 14:07:00) [MSC v.1916 64 bit (AMD64)]"
+  "version": "3.9.15 (main, Nov 4 2022, 16:13:54) \n[GCC 11.2.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
-   "hash": "5f15394fbb90e53eb87c79cee123a308177758b46ab7bd2ba3c7b07360ea775a"
+   "hash": "8f1c1b41577d000ca6512e75d22d324bbd1d5e060e99f4f49d98cf0adf636690"
   }
  }
 },

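Note: the rest of the diff is a deleted training notebook. Its core data step is `json2token`, which flattens a `gt_parse` dict into a tagged target string. A toy version showing only the output format (the real function also registers `<s_k>`/`</s_k>` as special tokens and short-circuits `{"text_sequence": ...}` dicts):

```python
# Toy json2token: reproduces the notebook's target-sequence format only.
def json2token(obj):
    if isinstance(obj, dict):
        return "".join(f"<s_{k}>{json2token(v)}</s_{k}>" for k, v in obj.items())
    if isinstance(obj, list):
        return "<sep/>".join(json2token(item) for item in obj)
    return str(obj)

print(json2token({"name": "donut", "items": ["a", "b"]}))
# -> <s_name>donut</s_name><s_items>a<sep/>b</s_items>
```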
(second notebook .ipynb, filename not captured)
@@ -1,494 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel, DonutImageProcessor, XLMRobertaTokenizerFast, BertConfig, ViTConfig\n",
    "from datasets import load_dataset, interleave_datasets\n",
    "import json\n",
    "import random\n",
    "from typing import Any, List, Tuple\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import re\n",
    "from nltk import edit_distance\n",
    "import numpy as np\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "from pytorch_lightning.callbacks import Callback\n",
    "import pytorch_lightning as pl\n",
    "import os\n",
    "from huggingface_hub import login\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pytorch_lightning as pl\n",
    "from nltk import edit_distance\n",
    "import re\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class DonutModelPLModuleStream(pl.LightningModule):\n",
    "    def __init__(self, config, processor, model, max_length, train_dataloader, val_dataloader):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.processor = processor\n",
    "        self.model = model\n",
    "        self.max_length = max_length\n",
    "        self._train_dataloader = train_dataloader\n",
    "        self._val_dataloader = val_dataloader\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        # pixel_values, labels, _ = batch\n",
    "        pixel_values = batch['pixel_values']\n",
    "        labels = batch['labels']\n",
    "        outputs = self.model(pixel_values, labels=labels)\n",
    "        loss = outputs.loss\n",
    "        self.log_dict({\"train_loss\": loss}, sync_dist=True)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx, dataset_idx=0):\n",
    "        # pixel_values, labels, answers = batch\n",
    "\n",
    "        pixel_values = batch['pixel_values']\n",
    "        labels = batch['labels']\n",
    "        answers = batch['target_sequence'][0]\n",
    "        batch_size = pixel_values.shape[0]\n",
    "        # we feed the prompt to the model\n",
    "        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)\n",
    "\n",
    "        outputs = self.model.generate(pixel_values,\n",
    "                                      decoder_input_ids=decoder_input_ids,\n",
    "                                      max_length=self.max_length,\n",
    "                                      early_stopping=True,\n",
    "                                      pad_token_id=self.processor.tokenizer.pad_token_id,\n",
    "                                      eos_token_id=self.processor.tokenizer.eos_token_id,\n",
    "                                      use_cache=True,\n",
    "                                      num_beams=1,\n",
    "                                      bad_words_ids=[[self.processor.tokenizer.unk_token_id]],\n",
    "                                      return_dict_in_generate=True,)\n",
    "\n",
    "        predictions = []\n",
    "        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):\n",
    "            seq = seq.replace(self.processor.tokenizer.eos_token, \"\").replace(self.processor.tokenizer.pad_token, \"\")\n",
    "            seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip()  # remove first task start token\n",
    "            predictions.append(seq)\n",
    "\n",
    "        scores = list()\n",
    "        for pred, answer in zip(predictions, answers):\n",
    "            pred = re.sub(r\"(?:(?<=>) | (?=</s_))\", \"\", pred)\n",
    "            # NOT NEEDED ANYMORE\n",
    "            # answer = re.sub(r\"<.*?>\", \"\", answer, count=1)\n",
    "            answer = answer.replace(self.processor.tokenizer.eos_token, \"\")\n",
    "            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))\n",
    "\n",
    "        if self.config.get(\"verbose\", False) and len(scores) == 1:\n",
    "            print(f\"Prediction: {pred}\")\n",
    "            print(f\"    Answer: {answer}\")\n",
    "            print(f\" Normed ED: {scores[0]}\")\n",
    "\n",
    "        return scores\n",
    "\n",
    "    def validation_epoch_end(self, validation_step_outputs):\n",
    "        # I set this to 1 manually\n",
    "        # (previously set to len(self.config.dataset_name_or_paths))\n",
    "        num_of_loaders = 1\n",
    "        if num_of_loaders == 1:\n",
    "            validation_step_outputs = [validation_step_outputs]\n",
    "        assert len(validation_step_outputs) == num_of_loaders\n",
    "        cnt = [0] * num_of_loaders\n",
    "        total_metric = [0] * num_of_loaders\n",
    "        val_metric = [0] * num_of_loaders\n",
    "        for i, results in enumerate(validation_step_outputs):\n",
    "            for scores in results:\n",
    "                cnt[i] += len(scores)\n",
    "                total_metric[i] += np.sum(scores)\n",
    "            val_metric[i] = total_metric[i] / cnt[i]\n",
    "            val_metric_name = f\"val_metric_{i}th_dataset\"\n",
    "            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)\n",
    "        self.log_dict({\"val_metric\": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        # TODO add scheduler\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get(\"lr\"))\n",
    "\n",
    "        return optimizer\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return self._train_dataloader\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return self._val_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_processor = DonutImageProcessor(do_resize=True, do_align_long_axis=False, size=[960, 1260])\n",
    "tokenizer = XLMRobertaTokenizerFast.from_pretrained('xlm-roberta-base')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_encoder = ViTConfig(image_size=[1260, 960])\n",
    "config_decoder = BertConfig()\n",
    "config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "processor = DonutProcessor(image_processor=image_processor, tokenizer=tokenizer)\n",
    "model = VisionEncoderDecoderModel(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "added_tokens = []\n",
    "\n",
    "### PROCESS FUNC START ###\n",
    "\n",
    "def add_tokens(list_of_tokens: List[str]):\n",
    "    \"\"\"\n",
    "    Add special tokens to tokenizer and resize the token embeddings of the decoder\n",
    "    \"\"\"\n",
    "    newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)\n",
    "    if newly_added_num > 0:\n",
    "        model.decoder.resize_token_embeddings(len(processor.tokenizer))\n",
    "        added_tokens.extend(list_of_tokens)\n",
    "\n",
    "def json2token(obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):\n",
    "    \"\"\"\n",
    "    Convert an ordered JSON object into a token sequence\n",
    "    \"\"\"\n",
    "    if type(obj) == dict:\n",
    "        if len(obj) == 1 and \"text_sequence\" in obj:\n",
    "            return obj[\"text_sequence\"]\n",
    "        else:\n",
    "            output = \"\"\n",
    "            if sort_json_key:\n",
    "                keys = sorted(obj.keys(), reverse=True)\n",
    "            else:\n",
    "                keys = obj.keys()\n",
    "            for k in keys:\n",
    "                if update_special_tokens_for_json_key:\n",
    "                    add_tokens([fr\"<s_{k}>\", fr\"</s_{k}>\"])\n",
    "                output += (\n",
    "                    fr\"<s_{k}>\"\n",
    "                    + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)\n",
    "                    + fr\"</s_{k}>\"\n",
    "                )\n",
    "            return output\n",
    "    elif type(obj) == list:\n",
    "        return r\"<sep/>\".join(\n",
    "            [json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]\n",
    "        )\n",
    "    else:\n",
    "        obj = str(obj)\n",
    "        if f\"<{obj}/>\" in added_tokens:\n",
    "            obj = f\"<{obj}/>\"  # for categorical special tokens\n",
    "        return obj\n",
    "\n",
    "def process(row, split):\n",
    "    task_start_token, prompt_end_token = \"<s_cord-v2>\", \"<s_cord-v2>\"\n",
    "    ground_truth = json.loads(row[\"ground_truth\"])\n",
    "    if \"gt_parses\" in ground_truth:  # when multiple ground truths are available, e.g., docvqa\n",
    "        assert isinstance(ground_truth[\"gt_parses\"], list)\n",
    "        gt_jsons = ground_truth[\"gt_parses\"]\n",
    "    else:\n",
    "        assert \"gt_parse\" in ground_truth and isinstance(ground_truth[\"gt_parse\"], dict)\n",
    "        gt_jsons = [ground_truth[\"gt_parse\"]]\n",
    "\n",
    "    gt_token_sequences = (\n",
    "        [\n",
    "            json2token(\n",
    "                gt_json,\n",
    "                update_special_tokens_for_json_key=split == \"train\",\n",
    "                sort_json_key=False,\n",
    "            )\n",
    "            + processor.tokenizer.eos_token\n",
    "            for gt_json in gt_jsons  # load json from list of json\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    add_tokens([task_start_token, prompt_end_token])\n",
    "    prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(prompt_end_token)\n",
    "\n",
    "    # change if not 3 channels\n",
    "    if row['image'].mode != \"RGB\":\n",
    "        row['image'] = row['image'].convert(\"RGB\")\n",
    "    # inputs\n",
    "    pixel_values = processor(row[\"image\"], random_padding=split == \"train\", return_tensors=\"pt\").pixel_values\n",
    "    pixel_values = pixel_values.squeeze()\n",
    "\n",
    "    # targets\n",
    "    input_ids = processor.tokenizer(\n",
    "        gt_token_sequences,\n",
    "        add_special_tokens=False,\n",
    "        max_length=config.max_length,\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )[\"input_ids\"].squeeze(0)\n",
    "\n",
    "    labels = input_ids.clone()\n",
    "    labels[labels == processor.tokenizer.pad_token_id] = -100  # model doesn't need to predict pad token\n",
    "    return {\"pixel_values\": pixel_values, \"labels\": labels, 'target_sequence': gt_token_sequences}\n",
    "\n",
    "def proces_train(row):\n",
    "    return process(row, 'train')\n",
    "\n",
    "def proces_val(row):\n",
    "    return process(row, 'validation')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration Zombely--wikisource-red-98affb32ced5f2c5\n"
     ]
    }
   ],
   "source": [
    "dataset = load_dataset('Zombely/wikisource-red', streaming=True)\n",
    "val_dataset = dataset.pop('validation')\n",
    "train_dataset = interleave_datasets(list(dataset.values()))\n",
    "# train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')\n",
    "# val_length = list(val_dataset.info.splits.values())[-1].num_examples\n",
    "\n",
    "\n",
    "train_dataset = train_dataset.map(proces_train, remove_columns=['image', 'ground_truth'])\n",
    "val_dataset = val_dataset.map(proces_val, remove_columns=['image', 'ground_truth'])\n",
    "\n",
    "train_dataset = train_dataset.with_format('torch')\n",
    "val_dataset = val_dataset.with_format('torch')\n",
    "\n",
    "# train_dataset = CustomWrapperIterator(train_dataset, total_len=train_length)\n",
    "# val_dataset = CustomWrapperIterator(val_dataset, total_len=val_length)\n",
    "\n",
    "model.config.pad_token_id = processor.tokenizer.pad_token_id\n",
    "model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]\n",
    "\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=0)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_config = {\n",
    "    \"max_epochs\": 1,\n",
    "    \"val_check_interval\": 1.0,\n",
    "    \"check_val_every_n_epoch\": 1,\n",
    "    \"gradient_clip_val\": 1.0,\n",
    "    \"num_training_samples_per_epoch\": 800,\n",
    "    \"lr\": 1.0e-4,\n",
    "    \"train_batch_sizes\": [8],\n",
    "    \"val_batch_sizes\": [1],\n",
    "    \"seed\": 2023,\n",
    "    \"num_nodes\": 1,\n",
    "    \"warmup_steps\": 10,\n",
    "    \"result_path\": \"./result\",\n",
    "    \"verbose\": True\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_module = DonutModelPLModuleStream(train_config, processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 16bit native Automatic Mixed Precision (AMP)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..\n"
     ]
    }
   ],
   "source": [
    "\n",
    "trainer = pl.Trainer(\n",
    "    accelerator=\"gpu\" if torch.cuda.is_available() else 'cpu',  # change to gpu\n",
    "    devices=1,\n",
    "    max_epochs=train_config['max_epochs'],\n",
    "    val_check_interval=train_config['val_check_interval'],\n",
    "    check_val_every_n_epoch=train_config['check_val_every_n_epoch'],\n",
    "    gradient_clip_val=train_config['gradient_clip_val'],\n",
    "    precision=16,  # we'll use mixed precision\n",
    "    num_sanity_val_steps=0,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Missing logger folder: /home/wmi/project/donut/notepads/lightning_logs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name  | Type                      | Params\n",
      "----------------------------------------------------\n",
      "0 | model | VisionEncoderDecoderModel | 227 M \n",
      "----------------------------------------------------\n",
      "227 M     Trainable params\n",
      "0         Non-trainable params\n",
      "227 M     Total params\n",
      "455.428   Total estimated model params size (MB)\n",
      "/home/wmi/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 14 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n",
      "/home/wmi/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 14 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "15d82804a8fe4aa6b2a02a16ce144496",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1024, 1579)\n",
      "(1024, 1473)\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 368.00 MiB (GPU 0; 23.70 GiB total capacity; 22.07 GiB already allocated; 260.56 MiB free; 22.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
     "output_type": "error",
     "traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_32385/828374167.py\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 580\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lightning_module\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 582\u001b[0;31m call._call_and_handle_interrupt(\n\u001b[0m\u001b[1;32m 583\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_impl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_dataloaders\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatamodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 584\u001b[0m )\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py\u001b[0m in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlauncher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlaunch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrainer_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0m_TunerExitException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 622\u001b[0m \u001b[0mmodel_connected\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 623\u001b[0m )\n\u001b[0;32m--> 624\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mckpt_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 625\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1059\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_checkpoint_connector\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresume_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1060\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1061\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_stage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1062\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1063\u001b[0m \u001b[0mlog\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetail\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{self.__class__.__name__}: trainer tearing down\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1138\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredicting\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1139\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1140\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1142\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_pre_training_routine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1162\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_detect_anomaly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_detect_anomaly\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1163\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1165\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_evaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0m_EVALUATE_OUTPUT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_to_device\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_to_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 266\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_epoch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepoch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_fetcher\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_batch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 214\u001b[0;31m \u001b[0mbatch_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 215\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_progress\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mincrement_processed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_frequencies\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"batch_idx\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m )\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmanual_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_start\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_advance_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_restarting\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36madvance\u001b[0;34m(self, optimizers, kwargs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_build_kwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_hiddens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_optimization\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim_progress\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_position\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;31m# automatic optimization assumes a loss needs to be returned for extras to be considered as the batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_run_optimization\u001b[0;34m(self, kwargs, optimizer)\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[0;31m# the `batch_idx` is optional with inter-batch parallelism\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 247\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"batch_idx\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 248\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 249\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsume_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m_optimizer_step\u001b[0;34m(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 355\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0;31m# model hook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 357\u001b[0;31m self.trainer._call_lightning_module_hook(\n\u001b[0m\u001b[1;32m 358\u001b[0m \u001b[0;34m\"optimizer_step\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurrent_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(self, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1304\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"[LightningModule]{pl_module.__class__.__name__}.{hook_name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1305\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1306\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1307\u001b[0m \u001b[0;31m# restore current_fx when nested context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/core/module.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)\u001b[0m\n\u001b[1;32m 1659\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1660\u001b[0m \"\"\"\n\u001b[0;32m-> 1661\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer_closure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1662\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1663\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_strategy\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mstep_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_strategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_on_after_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, opt_idx, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;31m# TODO(lite): remove assertion once strategy's optimizer_step typing is fixed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLightningModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m return self.precision_plugin.optimizer_step(\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m )\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/native_amp.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, model, optimizer_idx, closure, **kwargs)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34mf\"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx}).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m )\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0mclosure_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0m_optimizer_handles_unscaling\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36mclosure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backward_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mstep_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure_loss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backward_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosure_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstep_output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py\u001b[0m in \u001b[0;36mbackward_fn\u001b[0;34m(loss)\u001b[0m\n\u001b[1;32m 301\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbackward_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 303\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_strategy_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"backward\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 304\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mbackward_fn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36m_call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1441\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1442\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"[Strategy]{self.strategy.__class__.__name__}.{hook_name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1443\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1445\u001b[0m \u001b[0;31m# restore current_fx when nested context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, closure_loss, optimizer, optimizer_idx, *args, **kwargs)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0mclosure_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpre_backward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 207\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mclosure_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_backward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, tensor, model, optimizer, optimizer_idx, *args, **kwargs)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;31m\\\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mKeyword\u001b[0m \u001b[0marguments\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mthe\u001b[0m \u001b[0msame\u001b[0m \u001b[0mpurpose\u001b[0m \u001b[0;32mas\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 68\u001b[0m \"\"\"\n\u001b[0;32m---> 69\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_backward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"pl.LightningModule\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore[override]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/pytorch_lightning/core/module.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, loss, optimizer, optimizer_idx, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1404\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1405\u001b[0m \"\"\"\n\u001b[0;32m-> 1406\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1407\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1408\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtoggle_optimizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLightningOptimizer\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 486\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 487\u001b[0m )\n\u001b[0;32m--> 488\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 489\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 490\u001b[0m )\n",
|
||||
"\u001b[0;32m~/project/donut/env_donut/lib/python3.10/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 197\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 198\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
|
||||
"\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 368.00 MiB (GPU 0; 23.70 GiB total capacity; 22.07 GiB already allocated; 260.56 MiB free; 22.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
     ]
    }
   ],
   "source": [
    "trainer.fit(model_module)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_donut",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
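The run recorded in this notebook died in `loss.backward()` with the allocator error quoted in the traceback. A minimal sketch of the usual mitigations, assuming the Lightning setup used by the training scripts later in this diff (the concrete values here are illustrative, not settings taken from the repo's configs):

```python
import os

# Cap the allocator's split block size, as the OOM message itself suggests;
# this must be set before the first CUDA allocation happens.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,               # mixed precision roughly halves activation memory
    accumulate_grad_batches=4,  # keep the effective batch size while shrinking the per-step one
)
```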
@ -14,7 +14,7 @@ beautifulsoup4==4.11.1
bleach==5.0.1
blend-modes==2.1.0
cachetools==5.2.0
certifi==2022.12.7
certifi
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
@ -26,7 +26,7 @@ decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
docker-pycreds==0.4.0
donut-python==1.0.9
donut-python
entrypoints==0.4
evaluate==0.3.0
fastapi==0.87.0
@ -109,9 +109,8 @@ parso==0.8.3
pathtools==0.1.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.4.0
Pillow==9.3.0
pkgutil_resolve_name==1.3.10
portalocker==2.7.0
prometheus-client==0.15.0
promise==2.3
prompt-toolkit==3.0.33
@ -176,10 +175,9 @@ tifffile==2021.11.2
timm==0.6.11
tinycss2==1.2.1
tokenizers==0.13.2
torch==1.13.1
torchdata==0.5.1
torchmetrics==0.11.4
torchvision==0.14.1
torch==1.13.0
torchmetrics==0.10.3
torchvision==0.14.0
tornado==6.2
tqdm==4.64.1
traitlets==5.5.0
115
train.py
115
train.py
@ -1,115 +0,0 @@
from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel
import torch
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
import os
from huggingface_hub import login
import argparse
from sconf import Config
from utils.checkpoint import CustomCheckpointIO
from utils.donut_dataset import DonutDataset
from utils.donut_model_pl import DonutModelPLModule
from utils.callbacks import PushToHubCallback
import warnings
from datasets import load_dataset


def main(config, hug_token):

    config_vision = VisionEncoderDecoderConfig.from_pretrained(
        config.pretrained_model_path)
    config_vision.encoder.image_size = config.image_size
    config_vision.decoder.max_length = config.max_length

    processor = DonutProcessor.from_pretrained(config.start_model_path)
    model = VisionEncoderDecoderModel.from_pretrained(
        config.pretrained_model_path, config=config_vision)

    processor.image_processor.size = config.image_size[::-1]  # should be (width, height)
    processor.image_processor.do_align_long_axis = False

    added_tokens = []

    train_dataset = DonutDataset(
        config.dataset_path,
        processor=processor,
        model=model,
        max_length=config.max_length,
        split="train",
        task_start_token="<s_cord-v2>",
        prompt_end_token="<s_cord-v2>",
        added_tokens=added_tokens,
        sort_json_key=False,  # cord dataset is preprocessed, so no need for this
    )

    val_dataset = DonutDataset(
        config.dataset_path,
        processor=processor,
        model=model,
        max_length=config.max_length,
        split="validation",
        task_start_token="<s_cord-v2>",
        prompt_end_token="<s_cord-v2>",
        added_tokens=added_tokens,
        sort_json_key=False,  # cord dataset is preprocessed, so no need for this
    )

    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]

    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

    login(hug_token, True)

    model_module = DonutModelPLModule(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

    wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_metric",
        dirpath=config.checkpoint_path,
        filename="artifacts",
        save_top_k=1,
        save_last=False,
        mode="min",
    )

    custom_ckpt = CustomCheckpointIO()

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        max_epochs=config.train_config.max_epochs,
        val_check_interval=config.train_config.val_check_interval,
        check_val_every_n_epoch=config.train_config.check_val_every_n_epoch,
        gradient_clip_val=config.train_config.gradient_clip_val,
        precision=16,  # we'll use mixed precision
        plugins=custom_ckpt,
        num_sanity_val_steps=0,
        logger=wandb_logger,
        callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback],
    )

    trainer.fit(model_module)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args, left_argv = parser.parse_known_args()
    config = Config(args.config)
    config.argv_update(left_argv)

    hug_token = os.environ.get("HUG_TOKEN", None)

    if not torch.cuda.is_available():
        warnings.warn("CUDA is not available; training may be very slow or impossible")

    if not hug_token:
        raise Exception("You need to set HUG_TOKEN in the environment to push the output model to the hub")
    main(config, hug_token)
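Because `train.py` parses only `--config` itself and forwards the leftover argv to `sconf`'s `argv_update`, individual config keys can be overridden at launch without editing the YAML. A hedged sketch of that flow (the `["--max_length", "512"]` override and its exact syntax are assumptions about sconf's CLI parsing, not something exercised in this repo):

```python
from sconf import Config

# Mirrors the __main__ block above: load the YAML, then let leftover
# CLI arguments patch individual keys before main() runs.
config = Config("config-train.yaml")
config.argv_update(["--max_length", "512"])  # hypothetical override
print(config.max_length)  # 512, if the assumed override syntax holds
```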
255
train_stream.py
255
train_stream.py
@ -1,255 +0,0 @@
from typing import Any, List
from transformers import VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel
import torch
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
import os
from huggingface_hub import login
import argparse
from sconf import Config
from utils.checkpoint import CustomCheckpointIO
from utils.donut_dataset_stream import DonutDatasetStream
from utils.donut_model_pl_stream import DonutModelPLModuleStream
from utils.callbacks import PushToHubCallback
import warnings
from datasets import load_dataset, interleave_datasets
from torchdata.datapipes.iter import IterableWrapper
import json


class CustomWrapperIterator(IterableWrapper):
    def __init__(self, iterable, deepcopy=True, total_len=None):
        super().__init__(iterable, deepcopy)
        self.total_len = total_len

    def __len__(self):
        if self.total_len:
            return self.total_len
        return super().__len__()


def main(config, hug_token):

    config_vision = VisionEncoderDecoderConfig.from_pretrained(
        config.pretrained_model_path)
    config_vision.encoder.image_size = config.image_size
    config_vision.decoder.max_length = config.max_length

    processor = DonutProcessor.from_pretrained(config.start_model_path)
    model = VisionEncoderDecoderModel.from_pretrained(
        config.pretrained_model_path, config=config_vision)

    processor.image_processor.size = config.image_size[::-1]  # should be (width, height)
    processor.image_processor.do_align_long_axis = False

    added_tokens = []

    ### PROCESS FUNC START ###

    def add_tokens(list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            model.decoder.resize_token_embeddings(len(processor.tokenizer))
            added_tokens.extend(list_of_tokens)

    def json2token(obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                    output += (
                        fr"<s_{k}>"
                        + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + fr"</s_{k}>"
                    )
                return output
        elif type(obj) == list:
            return r"<sep/>".join(
                [json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj

    def process(row, split):
        task_start_token, prompt_end_token = "<s_cord-v2>", "<s_cord-v2>"
        ground_truth = json.loads(row["ground_truth"])
        if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
            assert isinstance(ground_truth["gt_parses"], list)
            gt_jsons = ground_truth["gt_parses"]
        else:
            assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
            gt_jsons = [ground_truth["gt_parse"]]

        gt_token_sequences = [
            json2token(
                gt_json,
                update_special_tokens_for_json_key=split == "train",
                sort_json_key=False,
            )
            + processor.tokenizer.eos_token
            for gt_json in gt_jsons  # load json from list of json
        ]

        add_tokens([task_start_token, prompt_end_token])
        prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(prompt_end_token)

        # convert if not 3 channels
        if row['image'].mode != "RGB":
            row['image'] = row['image'].convert("RGB")

        # inputs
        pixel_values = processor(row["image"], random_padding=split == "train", return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze()

        # targets
        input_ids = processor.tokenizer(
            gt_token_sequences,
            add_special_tokens=False,
            max_length=config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100  # model doesn't need to predict pad token
        return {"pixel_values": pixel_values, "labels": labels, "target_sequence": gt_token_sequences}

    def process_train(row):
        return process(row, 'train')

    def process_val(row):
        return process(row, 'validation')

    ### PROCESS FUNC END ###

    # train_dataset_process = DonutDatasetStream(
    #     processor=processor,
    #     model=model,
    #     max_length=config.max_length,
    #     split="train",
    #     task_start_token="<s_cord-v2>",
    #     prompt_end_token="<s_cord-v2>",
    #     added_tokens=added_tokens,
    #     sort_json_key=False,  # cord dataset is preprocessed, so no need for this
    # )

    # val_dataset_process = DonutDatasetStream(
    #     processor=processor,
    #     model=model,
    #     max_length=config.max_length,
    #     split="validation",
    #     task_start_token="<s_cord-v2>",
    #     prompt_end_token="<s_cord-v2>",
    #     added_tokens=added_tokens,
    #     sort_json_key=False,  # cord dataset is preprocessed, so no need for this
    # )

    # dataset_green = load_dataset("Zombely/wikisource-green", streaming=True)
    # val_dataset = dataset_green.pop('validation')
    # val_length = list(val_dataset.info.splits.values())[-1].num_examples
    # dataset_yellow = load_dataset("Zombely/wikisource-yellow", streaming=True)
    # dataset_red = load_dataset("Zombely/wikisource-red", streaming=True)
    # train_dataset = interleave_datasets(list(dataset_green.values()) + list(dataset_yellow.values()) + list(dataset_red.values()))
    # train_length_green = sum(split.num_examples for split in dataset_green[list(dataset_green.keys())[0]].info.splits.values() if split.name != 'validation')
    # train_length_yellow = sum(split.num_examples for split in dataset_yellow[list(dataset_yellow.keys())[0]].info.splits.values())
    # train_length_red = sum(split.num_examples for split in dataset_red[list(dataset_red.keys())[0]].info.splits.values())
    # train_length = train_length_green + train_length_yellow + train_length_red

    dataset = load_dataset(config.dataset_path, streaming=True)
    val_dataset = dataset.pop('validation')
    train_dataset = interleave_datasets(list(dataset.values()))
    train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')
    val_length = list(val_dataset.info.splits.values())[-1].num_examples

    train_dataset = train_dataset.map(process_train, remove_columns=['image', 'ground_truth'])
    val_dataset = val_dataset.map(process_val, remove_columns=['image', 'ground_truth'])

    # train_dataset = train_dataset.with_format('torch')
    # val_dataset = val_dataset.with_format('torch')

    train_dataset = CustomWrapperIterator(train_dataset, total_len=train_length)
    val_dataset = CustomWrapperIterator(val_dataset, total_len=val_length)

    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]

    train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=0)

    login(hug_token, True)

    model_module = DonutModelPLModuleStream(config.train_config.toDict(), processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

    wandb_logger = WandbLogger(project="Donut", name=config.wandb_test_name)

    if not os.path.exists(config.checkpoint_path):
        os.mkdir(config.checkpoint_path)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_metric",
        dirpath=config.checkpoint_path,
        filename="artifacts",
        save_top_k=1,
        save_last=False,
        mode="min",
    )

    custom_ckpt = CustomCheckpointIO()

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        max_epochs=config.train_config.max_epochs,
        val_check_interval=config.train_config.val_check_interval,
        check_val_every_n_epoch=config.train_config.check_val_every_n_epoch,
        gradient_clip_val=config.train_config.gradient_clip_val,
        precision=16,  # we'll use mixed precision
        plugins=custom_ckpt,
        num_sanity_val_steps=0,
        logger=wandb_logger,
        callbacks=[PushToHubCallback(output_model_path=config.output_model_path), checkpoint_callback],
    )

    trainer.fit(model_module)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args, left_argv = parser.parse_known_args()
    config = Config(args.config)
    config.argv_update(left_argv)

    # never hard-code credentials; the token must come from the environment
    hug_token = os.environ.get("HUG_TOKEN", None)

    if not torch.cuda.is_available():
        warnings.warn("CUDA is not available; training may be very slow or impossible")

    if not hug_token:
        raise Exception("You need to set HUG_TOKEN in the environment to push the output model to the hub")
    main(config, hug_token)
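The one genuinely non-obvious piece above is `CustomWrapperIterator`: streamed `datasets` splits have no `__len__`, but Lightning's progress bar and `val_check_interval` accounting want one, so the wrapper carries an externally computed length. A stripped-down sketch of the same idea (the generator and the count are toy stand-ins):

```python
from torchdata.datapipes.iter import IterableWrapper

class LengthAwareWrapper(IterableWrapper):
    """IterableWrapper that reports an externally supplied length."""

    def __init__(self, iterable, total_len=None):
        super().__init__(iterable)
        self.total_len = total_len

    def __len__(self):
        return self.total_len if self.total_len else super().__len__()

stream = LengthAwareWrapper((x for x in range(100)), total_len=100)
print(len(stream))  # 100, even though the underlying generator has no length
```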
21
utils/callbacks.py
21
utils/callbacks.py
@ -1,21 +0,0 @@
from pytorch_lightning.callbacks import Callback


class PushToHubCallback(Callback):
    def __init__(self, output_model_path) -> None:
        super().__init__()
        self.output_model_path = output_model_path

    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
        pl_module.model.push_to_hub(self.output_model_path,
                                    commit_message=f"Training in progress, epoch {trainer.current_epoch}")
        # pl_module.processor.push_to_hub(self.output_model_path, commit_message=f"Training in progress, epoch {trainer.current_epoch}")

    def on_train_end(self, trainer, pl_module):
        print("Pushing model to the hub after training")
        pl_module.processor.push_to_hub(self.output_model_path,
                                        commit_message="Training done")
        pl_module.model.push_to_hub(self.output_model_path,
                                    commit_message="Training done")
17
utils/checkpoint.py
17
utils/checkpoint.py
@ -1,17 +0,0 @@
from pytorch_lightning.plugins import CheckpointIO
import torch


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint, path, storage_options=None):
        # the weights are pushed to the hub separately, so keep only trainer state
        del checkpoint["state_dict"]
        torch.save(checkpoint, path)

    def load_checkpoint(self, path, storage_options=None):
        # `path` must end with a slash: file names are joined by concatenation
        checkpoint = torch.load(path + "artifacts.ckpt")
        state_dict = torch.load(path + "pytorch_model.bin")
        checkpoint["state_dict"] = {"model." + key: value for key, value in state_dict.items()}
        return checkpoint

    def remove_checkpoint(self, path) -> None:
        return super().remove_checkpoint(path)
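The intent here: `PushToHubCallback` already uploads the weights every epoch, so the local checkpoint only needs the trainer state, and `load_checkpoint` later reunites that state with a separately obtained `pytorch_model.bin`. A hedged resume sketch (the `./checkpoint/` directory is illustrative; note the trailing slash, which the string concatenation above requires):

```python
from utils.checkpoint import CustomCheckpointIO

ckpt_io = CustomCheckpointIO()
# Expects ./checkpoint/artifacts.ckpt (trainer state, no weights) and
# ./checkpoint/pytorch_model.bin (model weights) to sit side by side.
checkpoint = ckpt_io.load_checkpoint("./checkpoint/")
print(sorted(checkpoint.keys()))  # trainer state plus the rebuilt "state_dict"
```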
155
utils/donut_dataset.py
155
utils/donut_dataset.py
@ -1,155 +0,0 @@
from datasets import load_dataset
from torch.utils.data import Dataset
import json
from typing import Any, List, Tuple
import random
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel


class DonutDataset(Dataset):
    """
    DonutDataset, stored in the Hugging Face datasets format (see details at https://huggingface.co/docs/datasets).
    Each row consists of an image (png/jpg/jpeg) and ground-truth data (json/jsonl/txt),
    and is converted into an input_tensor (vectorized image) and input_ids (tokenized string).
    Args:
        dataset_name_or_path: name of the dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
        max_length: the max number of tokens for the target sequences
        split: whether to load the "train", "validation" or "test" split
        ignore_id: ignore_index for torch.nn.CrossEntropyLoss
        task_start_token: the special token fed to the decoder to condition it on the target task
        prompt_end_token: the special token at the end of the sequences
        sort_json_key: whether or not to sort the JSON keys
    """

    def __init__(
        self,
        dataset_name_or_path: str,
        max_length: int,
        processor: DonutProcessor,
        model: VisionEncoderDecoderModel,
        split: str = "train",
        ignore_id: int = -100,
        task_start_token: str = "<s>",
        prompt_end_token: str = None,
        sort_json_key: bool = True,
        added_tokens: list = []
    ):
        super().__init__()

        self.max_length = max_length
        self.split = split
        self.processor = processor
        self.model = model
        self.ignore_id = ignore_id
        self.task_start_token = task_start_token
        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
        self.sort_json_key = sort_json_key
        self.added_tokens = added_tokens

        self.dataset = load_dataset(dataset_name_or_path, split=self.split)
        self.dataset_length = len(self.dataset)

        self.gt_token_sequences = []
        for sample in self.dataset:
            ground_truth = json.loads(sample["ground_truth"])
            if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
                assert isinstance(ground_truth["gt_parses"], list)
                gt_jsons = ground_truth["gt_parses"]
            else:
                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
                gt_jsons = [ground_truth["gt_parse"]]

            self.gt_token_sequences.append(
                [
                    self.json2token(
                        gt_json,
                        update_special_tokens_for_json_key=self.split == "train",
                        sort_json_key=self.sort_json_key,
                    )
                    + self.processor.tokenizer.eos_token
                    for gt_json in gt_jsons  # load json from list of json
                ]
            )

        self.add_tokens([self.task_start_token, self.prompt_end_token])
        self.prompt_end_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                    output += (
                        fr"<s_{k}>"
                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + fr"</s_{k}>"
                    )
                return output
        elif type(obj) == list:
            return r"<sep/>".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in self.added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj

    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.model.decoder.resize_token_embeddings(len(self.processor.tokenizer))
            self.added_tokens.extend(list_of_tokens)

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Load image from image_path of given dataset_path and convert into input_tensor and labels
        Convert gt data into input_ids (tokenized string)
        Returns:
            input_tensor : preprocessed image
            input_ids : tokenized gt_data
            labels : masked labels (model doesn't need to predict prompt and pad token)
        """
        sample = self.dataset[idx]

        # convert if not 3 channels
        if sample['image'].mode != "RGB":
            sample['image'] = sample['image'].convert("RGB")

        # inputs
        pixel_values = self.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze()

        # targets
        target_sequence = random.choice(self.gt_token_sequences[idx])  # can be more than one, e.g., DocVQA Task 1
        input_ids = self.processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
        # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id  # model doesn't need to predict prompt (for VQA)
        return pixel_values, labels, target_sequence
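For intuition about what `json2token` produces: a bare `{"text_sequence": ...}` parse collapses to the raw string, while nested parses gain `<s_key>...</s_key>` wrappers with `<sep/>` between list items. A small usage sketch, assuming `processor` and `model` are already loaded as in `train.py` (the dataset path is the one from the training config; the printed shapes depend on the processor's image size):

```python
from utils.donut_dataset import DonutDataset

# e.g. {"menu": [{"nm": "tea"}, {"nm": "cake"}]}
#   -> "<s_menu><s_nm>tea</s_nm><sep/><s_nm>cake</s_nm></s_menu>"
dataset = DonutDataset(
    "Zombely/wikisource-red",
    max_length=768,
    processor=processor,
    model=model,
    split="train",
    task_start_token="<s_cord-v2>",
    prompt_end_token="<s_cord-v2>",
    sort_json_key=False,
)
pixel_values, labels, target_sequence = dataset[0]
print(pixel_values.shape, labels.shape, target_sequence[:80])
```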
125
utils/donut_dataset_stream.py
125
utils/donut_dataset_stream.py
@ -1,125 +0,0 @@
from datasets import load_dataset
from torch.utils.data import Dataset
import json
from typing import Any, List, Tuple
import random
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel


class DonutDatasetStream:

    def __init__(
        self,
        processor: DonutProcessor,
        model: VisionEncoderDecoderModel,
        max_length: int,
        ignore_id: int = -100,
        split: str = 'train',
        task_start_token: str = "<s>",
        prompt_end_token: str = None,
        sort_json_key: bool = True,
        added_tokens: list = []
    ):

        self.split = split
        self.max_length = max_length
        self.processor = processor
        self.model = model
        self.ignore_id = ignore_id
        self.task_start_token = task_start_token
        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
        self.sort_json_key = sort_json_key
        self.added_tokens = added_tokens

    def process(self, row):
        # called once per streamed row (e.g. from datasets.map); unlike
        # DonutDataset, token sequences are built lazily, one row at a time
        ground_truth = json.loads(row["ground_truth"])
        if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
            assert isinstance(ground_truth["gt_parses"], list)
            gt_jsons = ground_truth["gt_parses"]
        else:
            assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
            gt_jsons = [ground_truth["gt_parse"]]

        self.gt_token_sequences = [
            self.json2token(
                gt_json,
                update_special_tokens_for_json_key=self.split == "train",
                sort_json_key=self.sort_json_key,
            )
            + self.processor.tokenizer.eos_token
            for gt_json in gt_jsons  # load json from list of json
        ]

        self.add_tokens([self.task_start_token, self.prompt_end_token])
        self.prompt_end_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)

        # convert if not 3 channels
        if row['image'].mode != "RGB":
            row['image'] = row['image'].convert("RGB")

        # inputs
        pixel_values = self.processor(row["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze()

        # targets
        target_sequence = random.choice(self.gt_token_sequences)  # can be more than one, e.g., DocVQA Task 1
        input_ids = self.processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
        return {"pixel_values": pixel_values, "labels": labels, "target_sequence": target_sequence}

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                    output += (
                        fr"<s_{k}>"
                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + fr"</s_{k}>"
                    )
                return output
        elif type(obj) == list:
            return r"<sep/>".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in self.added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj

    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            self.model.decoder.resize_token_embeddings(len(self.processor.tokenizer))
            self.added_tokens.extend(list_of_tokens)
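Unlike `DonutDataset`, this variant never materializes the split: `process` is meant to run one row at a time inside `datasets.map` on a streamed dataset, adding special tokens on the fly as rows pass through. A minimal wiring sketch (the dataset path and split are illustrative; `processor` and `model` are assumed loaded as in the training scripts):

```python
from datasets import load_dataset
from utils.donut_dataset_stream import DonutDatasetStream

stream = DonutDatasetStream(
    processor=processor,
    model=model,
    max_length=768,
    split="train",
    task_start_token="<s_cord-v2>",
    prompt_end_token="<s_cord-v2>",
)
raw = load_dataset("Zombely/wikisource-red", split="train", streaming=True)
# map() is lazy on a streamed dataset: process() only runs during iteration
mapped = raw.map(stream.process, remove_columns=["image", "ground_truth"])
```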
93
utils/donut_model_pl.py
93
utils/donut_model_pl.py
@ -1,93 +0,0 @@
import torch
import pytorch_lightning as pl
from nltk import edit_distance
import re
import numpy as np


class DonutModelPLModule(pl.LightningModule):
    def __init__(self, config, processor, model, max_length, train_dataloader, val_dataloader):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.max_length = max_length
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def training_step(self, batch, batch_idx):
        pixel_values, labels, _ = batch

        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, labels, answers = batch
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)

        outputs = self.model.generate(pixel_values,
                                      decoder_input_ids=decoder_input_ids,
                                      max_length=self.max_length,
                                      early_stopping=True,
                                      pad_token_id=self.processor.tokenizer.pad_token_id,
                                      eos_token_id=self.processor.tokenizer.eos_token_id,
                                      use_cache=True,
                                      num_beams=1,
                                      bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                      return_dict_in_generate=True,)

        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = list()
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    def validation_epoch_end(self, validation_step_outputs):
        # I set this to 1 manually
        # (previously set to len(self.config.dataset_name_or_paths))
        num_of_loaders = 1
        if num_of_loaders == 1:
            validation_step_outputs = [validation_step_outputs]
        assert len(validation_step_outputs) == num_of_loaders
        cnt = [0] * num_of_loaders
        total_metric = [0] * num_of_loaders
        val_metric = [0] * num_of_loaders
        for i, results in enumerate(validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            val_metric[i] = total_metric[i] / cnt[i]
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        # TODO add scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))

        return optimizer

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader
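The `val_metric` being minimized is edit distance normalized by the longer of the two strings, so 0.0 means a perfect transcription and values near 1.0 mean nothing matched. A self-contained check of the exact formula from `validation_step` (the sample strings are made up):

```python
from nltk import edit_distance

pred = "ala ma kota"
answer = "ala ma kota i psa"
score = edit_distance(pred, answer) / max(len(pred), len(answer))
print(f"Normed ED: {score:.3f}")  # 6 edits / 17 chars ~= 0.353
```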
98
utils/donut_model_pl_stream.py
98
utils/donut_model_pl_stream.py
@ -1,98 +0,0 @@
import torch
import pytorch_lightning as pl
from nltk import edit_distance
import re
import numpy as np


class DonutModelPLModuleStream(pl.LightningModule):
    def __init__(self, config, processor, model, max_length, train_dataloader, val_dataloader):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.max_length = max_length
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def training_step(self, batch, batch_idx):
        # the streaming dataloader yields dicts instead of tuples
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        # target_sequence is emitted as a list, so collation nests it;
        # [0] unwraps the outer list (the dataloader uses batch_size 1)
        answers = batch['target_sequence'][0]
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)

        outputs = self.model.generate(pixel_values,
                                      decoder_input_ids=decoder_input_ids,
                                      max_length=self.max_length,
                                      early_stopping=True,
                                      pad_token_id=self.processor.tokenizer.pad_token_id,
                                      eos_token_id=self.processor.tokenizer.eos_token_id,
                                      use_cache=True,
                                      num_beams=1,
                                      bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                      return_dict_in_generate=True,)

        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = list()
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    def validation_epoch_end(self, validation_step_outputs):
        # I set this to 1 manually
        # (previously set to len(self.config.dataset_name_or_paths))
        num_of_loaders = 1
        if num_of_loaders == 1:
            validation_step_outputs = [validation_step_outputs]
        assert len(validation_step_outputs) == num_of_loaders
        cnt = [0] * num_of_loaders
        total_metric = [0] * num_of_loaders
        val_metric = [0] * num_of_loaders
        for i, results in enumerate(validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            val_metric[i] = total_metric[i] / cnt[i]
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        # TODO add scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))

        return optimizer

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader