donut/notepads/donut-eval.ipynb
2023-01-04 09:50:56 +01:00

12 KiB

from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import re
import json
import torch
from tqdm.auto import tqdm
import numpy as np

from donut import JSONParseEvaluator

# Processor and model weights are loaded from the same checkpoint id.
checkpoint = "Zombely/plwiki-test"
processor = DonutProcessor.from_pretrained(checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
Downloading:   0%|          | 0.00/421 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/544 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/1.30M [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/4.01M [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/95.0 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/355 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/5.03k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/809M [00:00<?, ?B/s]
# Only the validation split is needed for the evaluation below.
dataset = load_dataset(
    path="Zombely/pl-text-images",
    split="validation",
)
Downloading readme:   0%|          | 0.00/527 [00:00<?, ?B/s]
Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a
Downloading and preparing dataset None/None to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...
Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]
Downloading data:   0%|          | 0.00/1.44M [00:00<?, ?B/s]
Downloading data:   0%|          | 0.00/9.47M [00:00<?, ?B/s]
Downloading data:   0%|          | 0.00/885k [00:00<?, ?B/s]
   
Extracting data files #0:   0%|          | 0/1 [00:00<?, ?obj/s]
Extracting data files #2:   0%|          | 0/1 [00:00<?, ?obj/s]
Extracting data files #1:   0%|          | 0/1 [00:00<?, ?obj/s]
Generating test split:   0%|          | 0/13 [00:00<?, ? examples/s]
Generating train split:   0%|          | 0/101 [00:00<?, ? examples/s]
Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]
Dataset parquet downloaded and prepared to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.

# Evaluate the loaded Donut model on every sample of `dataset`: generate a
# token sequence per image, convert it to JSON with the processor, and score
# it against the sample's `gt_parse` ground truth.

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model.eval()
model.to(device)

output_list = []  # parsed prediction for each sample, in dataset order
accs = []  # per-sample score returned by JSONParseEvaluator.cal_acc

# The task prompt, its token ids and the evaluator do not depend on the
# sample, so they are built once instead of on every iteration.
# NOTE(review): the prompt presumably has to match the task token the
# checkpoint was fine-tuned with — confirm "<s_cord-v2>" is the right one.
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(
    task_prompt, add_special_tokens=False, return_tensors="pt"
).input_ids.to(device)
evaluator = JSONParseEvaluator()

for sample in tqdm(dataset, total=len(dataset)):
    # prepare encoder inputs
    pixel_values = processor(sample["image"].convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # autoregressively generate sequence
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # turn into JSON: strip EOS/PAD tokens, then drop the leading task token
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    seq = processor.token2json(seq)

    # the ground truth is stored as a JSON string; the target lives under "gt_parse"
    ground_truth = json.loads(sample["ground_truth"])["gt_parse"]
    score = evaluator.cal_acc(seq, ground_truth)

    accs.append(score)
    output_list.append(seq)

# Compute the mean once and reuse it for both the summary dict and the print.
mean_accuracy = np.mean(accs)
scores = {"accuracies": accs, "mean_accuracy": mean_accuracy}
print(scores, f"length : {len(accs)}")
print("Mean accuracy:", mean_accuracy)
  0%|          | 0/11 [00:00<?, ?it/s]
{'accuracies': [0, 0.9791666666666666, 0.9156626506024097, 1.0, 0.9836552748885586, 1.0, 0.7335359675785207, 0.9512987012987013, 0.396732788798133, 0.9908675799086758, 0.9452954048140044], 'mean_accuracy': 0.8087468213232427} length : 11
Mean accuracy: 0.8087468213232427