import re

import numpy as np
import pytorch_lightning as pl
import torch
from nltk import edit_distance

# First "<...>" tag in a string; used to drop the leading task start token
# from a decoded prediction.
_FIRST_TAG_RE = re.compile(r"<.*?>")

# A single space that directly follows ">" or directly precedes "</s_".
# Removing it undoes the spacing the tokenizer inserts around tags when
# decoding, so predictions can be compared to the raw target string.
_TAG_SPACING_RE = re.compile(r"(?:(?<=>) | (?=</s_))")


class DonutModelPLModuleStream(pl.LightningModule):
    """Lightning wrapper that trains a model and validates it with a
    normalized edit distance between generated and target sequences.

    Batches are dict-like and must provide ``"pixel_values"`` and
    ``"labels"`` for training, and ``"pixel_values"`` and
    ``"target_sequence"`` for validation.

    Args:
        config: Dict-like settings object; ``"lr"`` and (optionally)
            ``"verbose"`` are read through ``config.get``.
        processor: Object exposing a ``tokenizer`` used for decoding.
        model: Model called as ``model(pixel_values, labels=...)`` and
            ``model.generate(...)``.
        max_length: Maximum length passed to ``model.generate``.
        train_dataloader: Returned as-is by :meth:`train_dataloader`.
        val_dataloader: Returned as-is by :meth:`val_dataloader`.
    """

    def __init__(self, config, processor, model, max_length, train_dataloader, val_dataloader):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.max_length = max_length
        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader

    def training_step(self, batch, batch_idx):
        """Run a forward pass, log ``train_loss`` and return the loss."""
        outputs = self.model(batch["pixel_values"], labels=batch["labels"])
        loss = outputs.loss
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """Generate a sequence per sample and score it against its target.

        Returns:
            A list with one normalized edit distance per sample (0.0 means
            the prediction equals the target).
        """
        pixel_values = batch["pixel_values"]
        answers = batch["target_sequence"]
        batch_size = pixel_values.shape[0]
        tokenizer = self.processor.tokenizer
        start_token_id = self.model.config.decoder_start_token_id

        # The decoder prompt is just the start token, once per sample.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            start_token_id,
            device=self.device,
        )

        outputs = self.model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=self.max_length,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        predictions = []
        for seq in tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
            # Remove the first task start token.
            predictions.append(_FIRST_TAG_RE.sub("", seq, count=1).strip())

        # NOTE(review): the previous scoring line held an unterminated regex
        # applied to `answer` and assigned to `pred`; it could not run. The
        # prediction clean-up is restored below. Whether the target carries
        # the task start token depends on the dataset (not visible here), so
        # it is stripped only when the target really starts with it, which
        # mirrors what was done to the prediction above.
        task_start_token = tokenizer.convert_ids_to_tokens(start_token_id)
        verbose = self.config.get("verbose", False)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = _TAG_SPACING_RE.sub("", pred)
            if isinstance(task_start_token, str) and task_start_token and answer.startswith(task_start_token):
                answer = answer[len(task_start_token):]
            answer = answer.replace(tokenizer.eos_token, "")
            scores.append(self._normalized_edit_distance(pred, answer))

            # Only the first sample of the batch is printed.
            if verbose and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f" Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        return scores

    @staticmethod
    def _normalized_edit_distance(pred, answer):
        """Return edit distance divided by the longer string's length.

        Two empty strings are identical, so they score 0.0 instead of
        raising ``ZeroDivisionError``.
        """
        longest = max(len(pred), len(answer))
        if longest == 0:
            return 0.0
        return edit_distance(pred, answer) / longest

    def validation_epoch_end(self, validation_step_outputs):
        """Aggregate per-step scores and log the epoch-level metrics.

        Logs ``val_metric_{i}th_dataset`` per loader and the overall
        ``val_metric``. Nothing is logged for a loader (or overall) that
        produced no scores, which avoids a division by zero.

        NOTE(review): this hook name belongs to the pre-2.0 Lightning API;
        confirm the pinned pytorch_lightning version still calls it.
        """
        # I set this to 1 manually
        # (previously set to len(self.config.dataset_name_or_paths))
        num_of_loaders = 1
        if num_of_loaders == 1:
            validation_step_outputs = [validation_step_outputs]
        assert len(validation_step_outputs) == num_of_loaders

        total_count = 0
        total_score = 0.0
        for i, results in enumerate(validation_step_outputs):
            loader_count = 0
            loader_score = 0.0
            for scores in results:
                loader_count += len(scores)
                loader_score += float(np.sum(scores))

            total_count += loader_count
            total_score += loader_score

            if loader_count == 0:
                continue
            self.log_dict(
                {f"val_metric_{i}th_dataset": loader_score / loader_count},
                sync_dist=True,
            )

        if total_count > 0:
            self.log_dict({"val_metric": total_score / total_count}, sync_dist=True)

    def configure_optimizers(self):
        """Return an Adam optimizer using the learning rate from ``config``."""
        # TODO add scheduler
        return torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))

    def train_dataloader(self):
        """Return the training dataloader supplied at construction."""
        return self._train_dataloader

    def val_dataloader(self):
        """Return the validation dataloader supplied at construction."""
        return self._val_dataloader