import json
import os
import random
import re
from typing import Any, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
from nltk import edit_distance
from datasets import load_dataset, interleave_datasets
from huggingface_hub import login
from transformers import (
    VisionEncoderDecoderConfig,
    DonutProcessor,
    VisionEncoderDecoderModel,
    DonutImageProcessor,
    XLMRobertaTokenizerFast,
    BertConfig,
    ViTConfig,
)
class DonutModelPLModuleStream(pl.LightningModule):
    def __init__(self, config, processor, model, max_length, train_dataloader, val_dataloader):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.max_length = max_length
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def training_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        # after default collation, batch["target_sequence"][0] holds the
        # ground-truth sequences for the samples in this batch
        answers = batch["target_sequence"][0]
        batch_size = pixel_values.shape[0]
        # feed only the task prompt to the decoder; the rest is generated
        decoder_input_ids = torch.full(
            (batch_size, 1), self.model.config.decoder_start_token_id, device=self.device
        )
        outputs = self.model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=self.max_length,
            early_stopping=True,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    def validation_epoch_end(self, validation_step_outputs):
        # a single validation dataloader is used here
        # (previously: len(self.config.dataset_name_or_paths))
        num_of_loaders = 1
        if num_of_loaders == 1:
            validation_step_outputs = [validation_step_outputs]
        assert len(validation_step_outputs) == num_of_loaders
        cnt = [0] * num_of_loaders
        total_metric = [0] * num_of_loaders
        val_metric = [0] * num_of_loaders
        for i, results in enumerate(validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            val_metric[i] = total_metric[i] / cnt[i]
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        # TODO: add a learning-rate scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
        return optimizer

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader
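# Toy illustration of the validation metric above: normalized Levenshtein
# distance, 0.0 for identical strings, approaching 1.0 for disjoint ones
# (the example strings here are made up).
example_pred = "<s_text>hello</s_text>"
example_answer = "<s_text>hallo</s_text>"
print(edit_distance(example_pred, example_answer) / max(len(example_pred), len(example_answer)))
# 1 substitution over 22 characters ~= 0.045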
image_processor = DonutImageProcessor(do_resize=True, do_align_long_axis=False, size=[960, 1260])
tokenizer = XLMRobertaTokenizerFast.from_pretrained('xlm-roberta-base')
config_encoder = ViTConfig(image_size=[1260, 960])
config_decoder = BertConfig()
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
processor = DonutProcessor(image_processor=image_processor, tokenizer=tokenizer)
model = VisionEncoderDecoderModel(config=config)
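# Note: nothing in this notebook sets config.max_length, and transformers'
# PretrainedConfig defaults it to 20, so the tokenizer call in process() below
# would truncate every target sequence to 20 tokens. A plausible fix (the value
# 768 is an assumption, matching common Donut fine-tuning setups):
config.max_length = 768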
added_tokens = []
### PROCESS FUNC START ###
def add_tokens(list_of_tokens: List[str]):
    """
    Add special tokens to the tokenizer and resize the token embeddings of the decoder.
    """
    newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
    if newly_added_num > 0:
        model.decoder.resize_token_embeddings(len(processor.tokenizer))
        added_tokens.extend(list_of_tokens)
def json2token(obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
    """
    Convert an ordered JSON object into a token sequence.
    """
    if isinstance(obj, dict):
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]
        else:
            output = ""
            if sort_json_key:
                keys = sorted(obj.keys(), reverse=True)
            else:
                keys = obj.keys()
            for k in keys:
                if update_special_tokens_for_json_key:
                    add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                output += (
                    fr"<s_{k}>"
                    + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                    + fr"</s_{k}>"
                )
            return output
    elif isinstance(obj, list):
        return r"<sep/>".join(
            [json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
        )
    else:
        obj = str(obj)
        if f"<{obj}/>" in added_tokens:
            obj = f"<{obj}/>"  # for categorical special tokens
        return obj
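# Hypothetical sanity check (not part of the original pipeline): json2token
# flattens a ground-truth dict into an XML-like token sequence and, as a side
# effect, registers the field tokens (<s_...>, </s_...>) with the tokenizer
# and resizes the decoder embeddings.
example_gt = {"title": "On Liberty", "authors": ["Mill", "Someone"]}
print(json2token(example_gt, sort_json_key=False))
# <s_title>On Liberty</s_title><s_authors>Mill<sep/>Someone</s_authors>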
def process(row, split):
    task_start_token, prompt_end_token = "<s_cord-v2>", "<s_cord-v2>"
    ground_truth = json.loads(row["ground_truth"])
    if "gt_parses" in ground_truth:  # multiple ground truths per image, e.g. DocVQA
        assert isinstance(ground_truth["gt_parses"], list)
        gt_jsons = ground_truth["gt_parses"]
    else:
        assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
        gt_jsons = [ground_truth["gt_parse"]]
    gt_token_sequences = [
        json2token(
            gt_json,
            update_special_tokens_for_json_key=split == "train",
            sort_json_key=False,
        )
        + processor.tokenizer.eos_token
        for gt_json in gt_jsons
    ]
    add_tokens([task_start_token, prompt_end_token])
    prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(prompt_end_token)
    # convert to RGB if the image does not have 3 channels
    if row["image"].mode != "RGB":
        row["image"] = row["image"].convert("RGB")
    # inputs
    pixel_values = processor(row["image"], random_padding=split == "train", return_tensors="pt").pixel_values
    pixel_values = pixel_values.squeeze()
    # targets
    input_ids = processor.tokenizer(
        gt_token_sequences,
        add_special_tokens=False,
        max_length=config.max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)
    labels = input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # the model doesn't need to predict pad tokens
    return {"pixel_values": pixel_values, "labels": labels, "target_sequence": gt_token_sequences}
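# Minimal sketch of the -100 masking above: Hugging Face models exclude label
# index -100 from the cross-entropy loss, so padded positions don't contribute.
# (pad id 1 is XLM-RoBERTa's actual pad_token_id; the other ids are made up)
example_ids = torch.tensor([5, 8, 13, 2, 1, 1])
example_labels = example_ids.clone()
example_labels[example_labels == 1] = -100
print(example_labels)  # tensor([   5,    8,   13,    2, -100, -100])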
def process_train(row):
    return process(row, "train")

def process_val(row):
    return process(row, "validation")
dataset = load_dataset('Zombely/wikisource-red', streaming=True)
val_dataset = dataset.pop('validation')
train_dataset = interleave_datasets(list(dataset.values()))
# train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')
# val_length = list(val_dataset.info.splits.values())[-1].num_examples
train_dataset = train_dataset.map(process_train, remove_columns=["image", "ground_truth"])
val_dataset = val_dataset.map(process_val, remove_columns=["image", "ground_truth"])
train_dataset = train_dataset.with_format('torch')
val_dataset = val_dataset.with_format('torch')
# train_dataset = CustomWrapperIterator(train_dataset, total_len=train_length)
# val_dataset = CustomWrapperIterator(val_dataset, total_len=val_length)
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=0)
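# Optional sanity check (assumes the streaming pipeline yields at least one
# example): pull a single batch and confirm the shapes the model will see.
sample_batch = next(iter(train_dataloader))
print(sample_batch["pixel_values"].shape)  # expected: torch.Size([1, 3, 1260, 960])
print(sample_batch["labels"].shape)        # expected: torch.Size([1, config.max_length])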
train_config = {
"max_epochs": 1,
"val_check_interval": 1.0,
"check_val_every_n_epoch": 1,
"gradient_clip_val": 1.0,
"num_training_samples_per_epoch": 800,
"lr": 1.0e-4,
"train_batch_sizes": [8],
"val_batch_sizes": [1],
"seed": 2023,
"num_nodes": 1,
"warmup_steps": 10,
"result_path": "./result",
"verbose": True
}
model_module = DonutModelPLModuleStream(train_config, processor, model, max_length=config.max_length, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=train_config["max_epochs"],
    val_check_interval=train_config["val_check_interval"],
    check_val_every_n_epoch=train_config["check_val_every_n_epoch"],
    gradient_clip_val=train_config["gradient_clip_val"],
    precision=16,  # mixed precision
    num_sanity_val_steps=0,
)
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch.
trainer.fit(model_module)
Missing logger folder: /home/wmi/project/donut/notepads/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                      | Params
----------------------------------------------------
0 | model | VisionEncoderDecoderModel | 227 M
----------------------------------------------------
227 M     Trainable params
0         Non-trainable params
227 M     Total params
455.428   Total estimated model params size (MB)

PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument (try 14 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument (try 14 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
/tmp/ipykernel_32385/828374167.py in <cell line: 1>()
----> 1 trainer.fit(model_module)

[... pytorch_lightning fit loop / optimizer step / loss.backward() frames elided ...]

~/project/donut/env_donut/lib/python3.10/site-packages/torch/autograd/__init__.py in backward(...)
--> 197 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

OutOfMemoryError: CUDA out of memory. Tried to allocate 368.00 MiB (GPU 0; 23.70 GiB total capacity; 22.07 GiB already allocated; 260.56 MiB free; 22.09 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
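# The run above exhausts a 24 GiB GPU during loss.backward() even at batch
# size 1. A few standard mitigations to try before rerunning (a sketch, not
# what this notebook did):

# reduce allocator fragmentation, as the error message suggests
# (must be set before the first CUDA allocation to take effect)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# trade compute for activation memory in both halves of the model
model.encoder.gradient_checkpointing_enable()
model.decoder.gradient_checkpointing_enable()

# or shrink the input resolution, the dominant memory cost for a ViT encoder
# with full attention (the smaller size below is an assumption)
image_processor = DonutImageProcessor(do_resize=True, do_align_long_axis=False, size=[720, 960])
config_encoder = ViTConfig(image_size=[960, 720])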