donut/donut-eval.ipynb

446 lines
12 KiB
Plaintext
Raw Normal View History

2022-12-11 10:43:08 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from transformers import DonutProcessor, VisionEncoderDecoderModel\n",
"from datasets import load_dataset\n",
"import re\n",
"import json\n",
"import torch\n",
"from tqdm.auto import tqdm\n",
"import numpy as np\n",
"\n",
"from donut import JSONParseEvaluator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "24d467a0b2fa49a99506d690dba9e411",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/421 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b69047a501b4e18b831d9e28c90f45e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/544 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "69430fbf988c44e39d02b05144c356df",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/1.30M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "875aa0ed813647d08d7f1e003b92b4e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/4.01M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc5997bbdf6f4adfbea60fdfe18c3503",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/95.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "634540e8020e45129fe422976a7113e1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/355 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6fdce54654504308a7f1ba669823b996",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/5.03k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "552c171e151045e2a9dc6f860428add0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading: 0%| | 0.00/809M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Processor (image preprocessing + tokenizer) and model are loaded from the same checkpoint;\n",
"# keep the id in one place so the two cannot drift apart.\n",
"MODEL_ID = \"Zombely/plwiki-test\"\n",
"processor = DonutProcessor.from_pretrained(MODEL_ID)\n",
"model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "deba176990a74dbd80fb9347758b6e33",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading readme: 0%| | 0.00/527 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading and preparing dataset None/None to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9e238b0b3f1e4176a00657852986df9d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data files: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1f57c7efeb0f4e1e8bf7cc39e1d78a84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/1.44M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "68418bad8c1745faa7ecd802594f7bf6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/9.47M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8fad89bd442646968ac667f427efc8e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/885k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1ff350c50c674a4d81f033c98b273d6f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Extracting data files #0: 0%| | 0/1 [00:00<?, ?obj/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "724a8fb02f144816a61e9e562e86952b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Extracting data files #2: 0%| | 0/1 [00:00<?, ?obj/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "99bb9649389544a6a2d9e3ede599f9d0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Extracting data files #1: 0%| | 0/1 [00:00<?, ?obj/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "93fd550b1e0b4881b22d1c37f6e201c2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/13 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7a4a5650eda941a5b86d8c3d367b6589",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/101 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b1b5c26457ed47dd87a80d9fb6769d75",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating validation split: 0%| | 0/11 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset parquet downloaded and prepared to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.\n"
]
}
],
"source": [
"# Evaluation data: the held-out validation split (the dataset also has train and test splits).\n",
"dataset = load_dataset(\n",
"    path=\"Zombely/pl-text-images\",\n",
"    split=\"validation\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5cecbeda924942488b63727ca3885d60",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/11 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'accuracies': [0, 0.9791666666666666, 0.9156626506024097, 1.0, 0.9836552748885586, 1.0, 0.7335359675785207, 0.9512987012987013, 0.396732788798133, 0.9908675799086758, 0.9452954048140044], 'mean_accuracy': 0.8087468213232427} length : 11\n",
"Mean accuracy: 0.8087468213232427\n"
]
}
],
"source": [
"# Evaluate: greedy-decode each validation image, convert the generated tokens to JSON\n",
"# and score the prediction against the ground-truth parse with JSONParseEvaluator.\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"model.eval()\n",
"model.to(device)\n",
"\n",
"# The decoder prompt is the same for every sample, so tokenize it once outside the loop.\n",
"# NOTE(review): <s_cord-v2> is the CORD task token; confirm it matches the start token\n",
"# this checkpoint was fine-tuned with.\n",
"TASK_PROMPT = \"<s_cord-v2>\"\n",
"decoder_input_ids = processor.tokenizer(\n",
"    TASK_PROMPT, add_special_tokens=False, return_tensors=\"pt\"\n",
").input_ids.to(device)\n",
"\n",
"# One evaluator for the whole run (assumes cal_acc keeps no per-call state; verify in donut).\n",
"evaluator = JSONParseEvaluator()\n",
"\n",
"output_list = []  # predicted JSON per sample, in dataset order\n",
"accs = []  # per-sample accuracy returned by evaluator.cal_acc\n",
"\n",
"for sample in tqdm(dataset, total=len(dataset)):\n",
"    # encoder inputs: RGB image -> pixel tensor on the model's device\n",
"    pixel_values = processor(sample[\"image\"].convert(\"RGB\"), return_tensors=\"pt\").pixel_values\n",
"    pixel_values = pixel_values.to(device)\n",
"\n",
"    # autoregressively generate the sequence with greedy decoding (num_beams=1);\n",
"    # early_stopping only affects beam search, so it is not passed\n",
"    outputs = model.generate(\n",
"        pixel_values,\n",
"        decoder_input_ids=decoder_input_ids,\n",
"        max_length=model.decoder.config.max_position_embeddings,\n",
"        pad_token_id=processor.tokenizer.pad_token_id,\n",
"        eos_token_id=processor.tokenizer.eos_token_id,\n",
"        use_cache=True,\n",
"        num_beams=1,\n",
"        bad_words_ids=[[processor.tokenizer.unk_token_id]],  # never generate <unk>\n",
"        return_dict_in_generate=True,\n",
"    )\n",
"\n",
"    # tokens -> text (drop eos/pad and the leading task token) -> JSON\n",
"    decoded = processor.batch_decode(outputs.sequences)[0]\n",
"    decoded = decoded.replace(processor.tokenizer.eos_token, \"\").replace(processor.tokenizer.pad_token, \"\")\n",
"    decoded = re.sub(r\"<.*?>\", \"\", decoded, count=1).strip()  # remove first task start token\n",
"    prediction = processor.token2json(decoded)\n",
"\n",
"    ground_truth = json.loads(sample[\"ground_truth\"])[\"gt_parse\"]\n",
"    accs.append(evaluator.cal_acc(prediction, ground_truth))\n",
"    output_list.append(prediction)\n",
"\n",
"scores = {\"accuracies\": accs, \"mean_accuracy\": np.mean(accs)}\n",
"print(scores, f\"length : {len(accs)}\")\n",
"print(\"Mean accuracy:\", scores[\"mean_accuracy\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.15 ('donut')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.15"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "11ee3e278e787ae04c18a69549ce58331d512f29053c6ca32ae16833b7cef834"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}