12 KiB
12 KiB
from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import re
import json
import torch
from tqdm.auto import tqdm
import numpy as np
from donut import JSONParseEvaluator
# Load the Donut processor (image preprocessing + tokenizer) and the
# vision-encoder/text-decoder model. Both come from one shared checkpoint id
# so the tokenizer and the model weights cannot drift apart by accident.
MODEL_CHECKPOINT = "Zombely/plwiki-test"
processor = DonutProcessor.from_pretrained(MODEL_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT)
Downloading: 0%| | 0.00/421 [00:00<?, ?B/s]
Downloading: 0%| | 0.00/544 [00:00<?, ?B/s]
Downloading: 0%| | 0.00/1.30M [00:00<?, ?B/s]
Downloading: 0%| | 0.00/4.01M [00:00<?, ?B/s]
Downloading: 0%| | 0.00/95.0 [00:00<?, ?B/s]
Downloading: 0%| | 0.00/355 [00:00<?, ?B/s]
Downloading: 0%| | 0.00/5.03k [00:00<?, ?B/s]
Downloading: 0%| | 0.00/809M [00:00<?, ?B/s]
# Evaluation data: the "validation" split of the Hub dataset of Polish text images.
dataset = load_dataset(path="Zombely/pl-text-images", split="validation")
Downloading readme: 0%| | 0.00/527 [00:00<?, ?B/s]
Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a
Downloading and preparing dataset None/None to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...
Downloading data files: 0%| | 0/3 [00:00<?, ?it/s]
Downloading data: 0%| | 0.00/1.44M [00:00<?, ?B/s]
Downloading data: 0%| | 0.00/9.47M [00:00<?, ?B/s]
Downloading data: 0%| | 0.00/885k [00:00<?, ?B/s]
Extracting data files #0: 0%| | 0/1 [00:00<?, ?obj/s]
Extracting data files #2: 0%| | 0/1 [00:00<?, ?obj/s]
Extracting data files #1: 0%| | 0/1 [00:00<?, ?obj/s]
Generating test split: 0%| | 0/13 [00:00<?, ? examples/s]
Generating train split: 0%| | 0/101 [00:00<?, ? examples/s]
Generating validation split: 0%| | 0/11 [00:00<?, ? examples/s]
Dataset parquet downloaded and prepared to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.
# Run greedy Donut decoding over every sample in `dataset`, parse each
# prediction into JSON, and score it against the sample's ground-truth parse.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)

# Loop-invariant setup, hoisted out of the per-sample loop: the task-prompt
# token ids and the evaluator are the same for every sample.
# NOTE(review): "<s_cord-v2>" is the CORD-v2 start token; confirm it matches the
# prompt this checkpoint was fine-tuned with.
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(
    task_prompt, add_special_tokens=False, return_tensors="pt"
).input_ids.to(device)
# NOTE(review): assumes JSONParseEvaluator keeps no per-call state, so one
# instance can be reused across samples; verify against the donut package.
evaluator = JSONParseEvaluator()

output_list = []  # parsed predictions, in dataset order
accs = []  # per-sample accuracy from evaluator.cal_acc, in dataset order

for sample in tqdm(dataset, total=len(dataset)):
    # Encoder inputs: the page image converted to RGB and preprocessed to pixel values.
    pixel_values = processor(sample["image"].convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Greedy autoregressive generation seeded with the task prompt.
    # (early_stopping was dropped: it only applies to beam search, and num_beams=1.)
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],  # never emit <unk>
        return_dict_in_generate=True,
    )

    # Decode to text and strip EOS/PAD tokens.
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    # Drop the leading task-start tag, then convert the tag sequence to JSON.
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    seq = processor.token2json(seq)

    ground_truth = json.loads(sample["ground_truth"])["gt_parse"]
    score = evaluator.cal_acc(seq, ground_truth)
    accs.append(score)
    output_list.append(seq)

# Compute the mean once, and avoid numpy's empty-mean RuntimeWarning on an empty split.
mean_accuracy = np.mean(accs) if accs else float("nan")
scores = {"accuracies": accs, "mean_accuracy": mean_accuracy}
print(scores, f"length : {len(accs)}")
print("Mean accuracy:", mean_accuracy)
0%| | 0/11 [00:00<?, ?it/s]
{'accuracies': [0, 0.9791666666666666, 0.9156626506024097, 1.0, 0.9836552748885586, 1.0, 0.7335359675785207, 0.9512987012987013, 0.396732788798133, 0.9908675799086758, 0.9452954048140044], 'mean_accuracy': 0.8087468213232427} length : 11 Mean accuracy: 0.8087468213232427