donut/donut-eval.py

91 lines
3.6 KiB
Python

#!/usr/bin/env python
# coding: utf-8
import argparse
import json
import os
import re

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from donut import JSONParseEvaluator
from sconf import Config
from tqdm.auto import tqdm
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
def main(config):
    """Run a Donut checkpoint over a dataset split and report its predictions.

    For every sample the image is fed to the model, the generated token
    sequence is converted to JSON with ``processor.token2json`` and then,
    depending on the config flags, printed, scored against the sample's
    ``gt_parse`` ground truth, and finally written to a TSV file.

    Args:
        config: Settings object read by attribute. Fields used here:
            ``use_enc_dec_config``, ``pretrained_model_path``,
            ``pretrained_processor_path``, ``image_size``,
            ``max_dec_length`` (only read when ``use_enc_dec_config`` is
            truthy), ``validation_dataset_path``,
            ``validation_dataset_split``, ``has_metadata``,
            ``print_output``, ``output_file_dir`` and ``test_name`` (only
            read when ``output_file_dir`` is truthy).

    Returns:
        None. Results are printed and/or written to
        ``<output_file_dir>/<test_name>-out.tsv``.
    """
    # Optionally override the checkpoint's encoder/decoder settings.
    config_vision = None
    if config.use_enc_dec_config:
        config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
        config_vision.encoder.image_size = config.image_size  # (height, width)
        config_vision.decoder.max_length = config.max_dec_length

    processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
    model = VisionEncoderDecoderModel.from_pretrained(
        config.pretrained_model_path, config=config_vision
    )
    processor.image_processor.size = config.image_size[::-1]  # should be (width, height)
    processor.image_processor.do_align_long_axis = False

    dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)

    # Create the output directory before the (expensive) inference loop so a
    # missing or unwritable directory fails early instead of after all
    # predictions have been computed.
    if config.output_file_dir:
        os.makedirs(config.output_file_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    # Everything below is identical for every sample, so build it once.
    tokenizer = processor.tokenizer
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)
    first_tag = re.compile(r"<.*?>")
    evaluator = JSONParseEvaluator() if config.has_metadata else None

    output_list = []
    accs = []
    for sample in tqdm(dataset, total=len(dataset)):
        # prepare encoder inputs
        pixel_values = processor(sample['image'].convert("RGB"), return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)

        # autoregressively generate sequence
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],  # never emit the unknown token
            return_dict_in_generate=True,
        )

        # turn into JSON
        seq = processor.batch_decode(outputs.sequences)[0]
        seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        seq = first_tag.sub("", seq, count=1).strip()  # remove first task start token
        seq = processor.token2json(seq)

        if evaluator is not None:
            ground_truth = json.loads(sample["ground_truth"])["gt_parse"]
            accs.append(evaluator.cal_acc(seq, ground_truth))

        if config.print_output:
            print(seq)
        output_list.append(seq)

    if config.output_file_dir:
        # One row per sample; predictions without a 'text_sequence' key
        # contribute an empty string.
        out_path = os.path.join(config.output_file_dir, f"{config.test_name}-out.tsv")
        df = pd.DataFrame([pred.get('text_sequence', '') for pred in output_list])
        df.to_csv(out_path, sep='\t', header=False, index=False)

    if config.has_metadata:
        # np.mean([]) warns and yields nan; produce nan directly for an
        # empty split and compute the mean only once otherwise.
        mean_accuracy = np.mean(accs) if accs else float("nan")
        scores = {"accuracies": accs, "mean_accuracy": mean_accuracy}
        print(scores, f"length : {len(accs)}")
        print("Mean accuracy:", mean_accuracy)
if __name__ == "__main__":
    # Only --config is declared here; any other command-line tokens are
    # collected by parse_known_args and handed to the config object.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    known, extra_argv = cli.parse_known_args()

    # Load the config file, then apply the leftover arguments to it
    # (presumably as per-run overrides -- confirm against sconf).
    run_config = Config(known.config)
    run_config.argv_update(extra_argv)

    main(run_config)