donut/donut-eval.py
Michał Kozłowski bdb7f5ef7e additional config
2022-12-17 10:29:25 +01:00

86 lines
3.4 KiB
Python

#!/usr/bin/env python
# coding: utf-8
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
from datasets import load_dataset
import re
import json
import torch
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
from donut import JSONParseEvaluator
import argparse
from sconf import Config
def main(config):
    """Run a Donut model over a dataset, optionally scoring and saving the output.

    For every sample the image is encoded, a sequence is generated starting
    from the ``<s_cord-v2>`` task prompt, and the generated tokens are
    converted to JSON with ``processor.token2json``.

    Args:
        config: Configuration object read by attribute. The fields used are
            ``pretrained_model_path``, ``pretrained_processor_path``,
            ``validation_dataset_path``, ``validation_dataset_split``,
            ``has_metadata`` (samples carry a ``ground_truth`` JSON string
            with a ``gt_parse`` key, so accuracy is computed),
            ``print_output`` (print each parsed prediction),
            ``output_file_dir`` (when truthy, predictions are written there)
            and ``test_name`` (prefix of the output file name).

    Side effects:
        Prints predictions and/or accuracy figures depending on the config,
        and writes ``<output_file_dir>/<test_name>-out.tsv`` when
        ``output_file_dir`` is set.
    """
    import os  # local import: only needed to prepare the output directory

    config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
    config_vision.encoder.image_size = [1920, 2560]  # (height, width)
    config_vision.decoder.max_length = 768

    processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
    model = VisionEncoderDecoderModel.from_pretrained(
        config.pretrained_model_path, config=config_vision
    )
    dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    tokenizer = processor.tokenizer

    # The decoder prompt is identical for every sample, so tokenize it once
    # instead of on every loop iteration.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    # A single evaluator is reused for all samples; it is only needed when
    # ground truth is available.
    evaluator = JSONParseEvaluator() if config.has_metadata else None

    output_list = []
    accs = []

    for sample in tqdm(dataset, total=len(dataset)):
        # prepare encoder inputs
        pixel_values = processor(
            sample['image'].convert("RGB"), return_tensors="pt"
        ).pixel_values.to(device)

        # autoregressively generate sequence
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # turn into JSON
        seq = processor.batch_decode(outputs.sequences)[0]
        seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
        seq = processor.token2json(seq)

        if config.has_metadata:
            ground_truth = json.loads(sample["ground_truth"])["gt_parse"]
            accs.append(evaluator.cal_acc(seq, ground_truth))

        if config.print_output:
            print(seq)

        output_list.append(seq)

    if config.output_file_dir:
        # Create the target directory up front: to_csv fails on a missing
        # directory, which would discard the whole inference run.
        os.makedirs(config.output_file_dir, exist_ok=True)
        df = pd.DataFrame([parsed.get('text_sequence', '') for parsed in output_list])
        df.to_csv(f'{config.output_file_dir}/{config.test_name}-out.tsv', sep='\t', header=False, index=False)

    if config.has_metadata:
        # Compute the mean once; an empty dataset yields NaN without the
        # RuntimeWarning np.mean raises on an empty list.
        mean_accuracy = np.mean(accs) if accs else float("nan")
        scores = {"accuracies": accs, "mean_accuracy": mean_accuracy}
        print(scores, f"length : {len(accs)}")
        print("Mean accuracy:", mean_accuracy)
if __name__ == "__main__":
    # Script entry point: only --config is declared here; every argument the
    # parser does not recognize is handed to the config object afterwards.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    known_args, remaining_argv = cli.parse_known_args()

    # Load the configuration file, then pass the leftover command-line
    # arguments to argv_update before running the evaluation.
    run_config = Config(known_args.config)
    run_config.argv_update(remaining_argv)

    main(run_config)