{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from transformers import DonutProcessor, VisionEncoderDecoderModel\n", "from datasets import load_dataset\n", "import re\n", "import json\n", "import torch\n", "from tqdm.auto import tqdm\n", "import numpy as np\n", "\n", "from donut import JSONParseEvaluator" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "24d467a0b2fa49a99506d690dba9e411", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/421 [00:00\"\n", " decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors=\"pt\").input_ids\n", " decoder_input_ids = decoder_input_ids.to(device)\n", " \n", " # autoregressively generate sequence\n", " outputs = model.generate(\n", " pixel_values,\n", " decoder_input_ids=decoder_input_ids,\n", " max_length=model.decoder.config.max_position_embeddings,\n", " early_stopping=True,\n", " pad_token_id=processor.tokenizer.pad_token_id,\n", " eos_token_id=processor.tokenizer.eos_token_id,\n", " use_cache=True,\n", " num_beams=1,\n", " bad_words_ids=[[processor.tokenizer.unk_token_id]],\n", " return_dict_in_generate=True,\n", " )\n", "\n", " # turn into JSON\n", " seq = processor.batch_decode(outputs.sequences)[0]\n", " seq = seq.replace(processor.tokenizer.eos_token, \"\").replace(processor.tokenizer.pad_token, \"\")\n", " seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip() # remove first task start token\n", " seq = processor.token2json(seq)\n", "\n", " ground_truth = json.loads(sample[\"ground_truth\"])\n", " ground_truth = ground_truth[\"gt_parse\"]\n", " evaluator = JSONParseEvaluator()\n", " score = evaluator.cal_acc(seq, ground_truth)\n", "\n", " accs.append(score)\n", " output_list.append(seq)\n", "\n", "scores = {\"accuracies\": accs, \"mean_accuracy\": np.mean(accs)}\n", "print(scores, f\"length : {len(accs)}\")\n", "print(\"Mean accuracy:\", np.mean(accs))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7.15 ('donut')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.15" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "11ee3e278e787ae04c18a69549ce58331d512f29053c6ca32ae16833b7cef834" } } }, "nbformat": 4, "nbformat_minor": 2 }