#!/usr/bin/env python
# coding: utf-8
"""Evaluate a Donut (VisionEncoderDecoder) model on a ``datasets`` split.

For every sample the script generates a token sequence with a single beam,
turns it into JSON with ``DonutProcessor.token2json`` and, when the dataset
carries ground-truth metadata, scores it with Donut's ``JSONParseEvaluator``.
Predictions can optionally be printed and/or written to a TSV file.
"""
import argparse
import json
import os
import re

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from donut import JSONParseEvaluator
from sconf import Config
from tqdm.auto import tqdm
from transformers import DonutProcessor, VisionEncoderDecoderConfig, VisionEncoderDecoderModel


def _load_model_and_processor(config):
    """Load the Donut processor and model described by ``config``.

    When ``config.use_enc_dec_config`` is set, the pretrained
    ``VisionEncoderDecoderConfig`` has its encoder image size and decoder max
    length overridden and is passed to ``from_pretrained`` so the overrides
    apply to the loaded model.

    Returns:
        ``(model, processor)``.
    """
    model_kwargs = {}
    if config.use_enc_dec_config:
        config_vision = VisionEncoderDecoderConfig.from_pretrained(config.pretrained_model_path)
        config_vision.encoder.image_size = config.image_size  # (height, width)
        config_vision.decoder.max_length = config.max_dec_length
        # The modified config must be handed to from_pretrained; otherwise the
        # model is built from the unmodified pretrained config.
        model_kwargs["config"] = config_vision

    processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
    model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path, **model_kwargs)

    # The image processor takes (width, height), the reverse of config.image_size.
    processor.image_processor.size = config.image_size[::-1]
    processor.image_processor.do_align_long_axis = False
    return model, processor


def _sequences_to_json(processor, sequences):
    """Decode the first generated sequence and parse it with ``token2json``."""
    seq = processor.batch_decode(sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    return processor.token2json(seq)


def _write_predictions(output_list, output_file_dir, test_name):
    """Write each prediction's ``text_sequence`` field, one per row, to a TSV.

    The file is ``<output_file_dir>/<test_name>-out.tsv`` with no header or
    index column. Predictions without a ``text_sequence`` key yield an empty
    field.
    """
    # Guard in case token2json returns something other than a dict
    # (e.g. a list); such predictions get an empty field instead of crashing.
    rows = [pred.get("text_sequence", "") if isinstance(pred, dict) else "" for pred in output_list]
    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(output_file_dir, f"{test_name}-out.tsv"), sep="\t", header=False, index=False)


def main(config):
    """Run inference over the configured dataset split and report accuracy.

    Args:
        config: sconf ``Config`` providing model/processor paths, the
            validation dataset path/split, image size, decoder length, and the
            ``has_metadata`` / ``print_output`` / ``output_file_dir`` /
            ``test_name`` options read below.
    """
    model, processor = _load_model_and_processor(config)

    dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    # The decoder prompt and the evaluator do not depend on the sample, so
    # build them once rather than on every iteration.
    task_prompt = ""
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    decoder_input_ids = decoder_input_ids.to(device)
    evaluator = JSONParseEvaluator() if config.has_metadata else None

    output_list = []
    accs = []
    for sample in tqdm(dataset, total=len(dataset)):
        # prepare encoder inputs
        pixel_values = processor(sample["image"].convert("RGB"), return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)

        # autoregressively generate sequence
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # turn into JSON
        seq = _sequences_to_json(processor, outputs.sequences)

        if config.has_metadata:
            ground_truth = json.loads(sample["ground_truth"])["gt_parse"]
            accs.append(evaluator.cal_acc(seq, ground_truth))
        if config.print_output:
            print(seq)
        output_list.append(seq)

    if config.output_file_dir:
        _write_predictions(output_list, config.output_file_dir, config.test_name)

    if config.has_metadata:
        scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)}
        print(scores, f"length : {len(accs)}")
        print("Mean accuracy:", np.mean(accs))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to the sconf config file.")
    # Unrecognised CLI arguments are forwarded as config overrides.
    args, left_argv = parser.parse_known_args()
    config = Config(args.config)
    config.argv_update(left_argv)
    main(config)