98 lines
4.1 KiB
Python
98 lines
4.1 KiB
Python
|
import torch
|
||
|
import pytorch_lightning as pl
|
||
|
from nltk import edit_distance
|
||
|
import re
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
class DonutModelPLModuleStream(pl.LightningModule):
    """LightningModule that trains a wrapped model and validates it by edit distance.

    Training minimises the loss returned by ``model(pixel_values, labels=...)``.
    Validation generates a token sequence per image, decodes it with the
    processor's tokenizer and scores it against the reference string using a
    length-normalised edit distance (0.0 means an exact match).

    The dataloaders are injected through the constructor and handed back to
    Lightning by :meth:`train_dataloader` / :meth:`val_dataloader`.
    """

    def __init__(self, config, processor, model, max_length, train_dataloader, val_dataloader):
        """Store the collaborators used by the training and validation hooks.

        Args:
            config: Settings object read through ``config.get``; this class
                looks up ``"lr"`` and the optional ``"verbose"`` flag.
            processor: Object exposing a ``tokenizer`` attribute, used for
                decoding and for the pad/eos/unk special tokens.
            model: Model invoked as ``model(pixel_values, labels=...)`` during
                training and ``model.generate(...)`` during validation.
            max_length: Forwarded as ``max_length`` to ``model.generate``.
            train_dataloader: Returned unchanged by :meth:`train_dataloader`.
            val_dataloader: Returned unchanged by :meth:`val_dataloader`.
        """
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.max_length = max_length
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def training_step(self, batch, batch_idx):
        """Run one forward pass and return the model's loss.

        Args:
            batch: Mapping with the keys ``'pixel_values'`` and ``'labels'``.
            batch_idx: Index of the batch; unused.

        Returns:
            The ``loss`` attribute of the model output, also logged as
            ``train_loss``.
        """
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    @staticmethod
    def _normalized_edit_distance(pred, answer):
        """Return the edit distance of two strings divided by the longer length.

        Two empty strings are identical, so they score 0.0 instead of
        triggering a division by zero.
        """
        longest = max(len(pred), len(answer))
        if longest == 0:
            return 0.0
        return edit_distance(pred, answer) / longest

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """Generate predictions for a batch and score them against the answers.

        Args:
            batch: Mapping with the keys ``'pixel_values'`` (its first
                dimension is taken as the batch size) and
                ``'target_sequence'`` (an iterable of reference strings).
            batch_idx: Index of the batch; unused.
            dataset_idx: Index of the validation dataloader; unused.

        Returns:
            A list with one normalised edit distance per
            (prediction, answer) pair.
        """
        pixel_values = batch['pixel_values']
        answers = batch['target_sequence']
        tokenizer = self.processor.tokenizer
        batch_size = pixel_values.shape[0]

        # We feed the prompt to the model: a single decoder start token per sample.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.model.config.decoder_start_token_id,
            device=self.device,
        )

        outputs = self.model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=self.max_length,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        predictions = []
        for seq in tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        verbose = self.config.get("verbose", False)
        scores = []
        for pred, answer in zip(predictions, answers):
            # Drop the space that follows a tag and the one that precedes a closing </s_...> tag.
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # The answer's task start token is no longer stripped here; only eos is removed.
            answer = answer.replace(tokenizer.eos_token, "")
            scores.append(self._normalized_edit_distance(pred, answer))

            # Only the first sample of each batch is printed, to keep the output short.
            if verbose and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f" Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    def validation_epoch_end(self, validation_step_outputs):
        """Aggregate the per-step scores and log the epoch's validation metrics.

        Logs ``val_metric_0th_dataset`` and ``val_metric``, each the mean of all
        scores gathered during the epoch. If no score was collected the metric
        is logged as NaN rather than raising ``ZeroDivisionError``, so the
        logging calls still happen.

        NOTE(review): this hook name belongs to the pre-2.0 pytorch_lightning
        API; confirm the installed version still calls it.

        Args:
            validation_step_outputs: List of the score lists returned by
                :meth:`validation_step`, one entry per validation step.
        """
        # The number of validation loaders is fixed to 1 manually (it was
        # previously len(self.config.dataset_name_or_paths)), so the step
        # outputs are wrapped to look like the multi-loader layout.
        per_loader_outputs = [validation_step_outputs]

        counts = []
        totals = []
        for i, results in enumerate(per_loader_outputs):
            count = 0
            total = 0
            for scores in results:
                count += len(scores)
                total += np.sum(scores)
            counts.append(count)
            totals.append(total)

            loader_metric = total / count if count else float("nan")
            self.log_dict({f"val_metric_{i}th_dataset": loader_metric}, sync_dist=True)

        overall_count = np.sum(counts)
        overall_metric = np.sum(totals) / overall_count if overall_count else float("nan")
        self.log_dict({"val_metric": overall_metric}, sync_dist=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using ``config.get("lr")``."""
        # TODO add scheduler
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
        return optimizer

    def train_dataloader(self):
        """Return the training dataloader passed to the constructor."""
        return self._train_dataloader

    def val_dataloader(self):
        """Return the validation dataloader passed to the constructor."""
        return self._val_dataloader
|