446 lines
12 KiB
Plaintext
446 lines
12 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers import DonutProcessor, VisionEncoderDecoderModel\n",
|
|
"from datasets import load_dataset\n",
|
|
"import re\n",
|
|
"import json\n",
|
|
"import torch\n",
|
|
"from tqdm.auto import tqdm\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"from donut import JSONParseEvaluator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "24d467a0b2fa49a99506d690dba9e411",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/421 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "3b69047a501b4e18b831d9e28c90f45e",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/544 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "69430fbf988c44e39d02b05144c356df",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/1.30M [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "875aa0ed813647d08d7f1e003b92b4e5",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/4.01M [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "cc5997bbdf6f4adfbea60fdfe18c3503",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/95.0 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "634540e8020e45129fe422976a7113e1",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/355 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "6fdce54654504308a7f1ba669823b996",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/5.03k [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "552c171e151045e2a9dc6f860428add0",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/809M [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
"# Single source of truth for the checkpoint id: the processor and the model are\n",
"# loaded from the same repository so their vocabulary and weights cannot drift apart.\n",
"MODEL_ID = \"Zombely/plwiki-test\"\n",
"\n",
"processor = DonutProcessor.from_pretrained(MODEL_ID)\n",
"model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "deba176990a74dbd80fb9347758b6e33",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading readme: 0%| | 0.00/527 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using custom data configuration Zombely--pl-text-images-f3f66e614f4d9a7a\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Downloading and preparing dataset None/None to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "9e238b0b3f1e4176a00657852986df9d",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading data files: 0%| | 0/3 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "1f57c7efeb0f4e1e8bf7cc39e1d78a84",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading data: 0%| | 0.00/1.44M [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "68418bad8c1745faa7ecd802594f7bf6",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading data: 0%| | 0.00/9.47M [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "8fad89bd442646968ac667f427efc8e5",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading data: 0%| | 0.00/885k [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "1ff350c50c674a4d81f033c98b273d6f",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Extracting data files #0: 0%| | 0/1 [00:00<?, ?obj/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "724a8fb02f144816a61e9e562e86952b",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Extracting data files #2: 0%| | 0/1 [00:00<?, ?obj/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "99bb9649389544a6a2d9e3ede599f9d0",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Extracting data files #1: 0%| | 0/1 [00:00<?, ?obj/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "93fd550b1e0b4881b22d1c37f6e201c2",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Generating test split: 0%| | 0/13 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "7a4a5650eda941a5b86d8c3d367b6589",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Generating train split: 0%| | 0/101 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "b1b5c26457ed47dd87a80d9fb6769d75",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Generating validation split: 0%| | 0/11 [00:00<?, ? examples/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Dataset parquet downloaded and prepared to /home/pc/.cache/huggingface/datasets/Zombely___parquet/Zombely--pl-text-images-f3f66e614f4d9a7a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Only the validation split is loaded; it is the split scored by the evaluation loop.\n",
"dataset = load_dataset(\n",
"    \"Zombely/pl-text-images\",\n",
"    split=\"validation\",\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "5cecbeda924942488b63727ca3885d60",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/11 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{'accuracies': [0, 0.9791666666666666, 0.9156626506024097, 1.0, 0.9836552748885586, 1.0, 0.7335359675785207, 0.9512987012987013, 0.396732788798133, 0.9908675799086758, 0.9452954048140044], 'mean_accuracy': 0.8087468213232427} length : 11\n",
|
|
"Mean accuracy: 0.8087468213232427\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"model.eval()\n",
"model.to(device)\n",
"\n",
"# Loop-invariant setup: the decoder prompt and the evaluator do not depend on\n",
"# the sample, so they are built once instead of once per iteration.\n",
"task_prompt = \"<s_cord-v2>\"\n",
"decoder_input_ids = processor.tokenizer(\n",
"    task_prompt, add_special_tokens=False, return_tensors=\"pt\"\n",
").input_ids.to(device)\n",
"evaluator = JSONParseEvaluator()\n",
"\n",
"output_list = []  # parsed prediction for every sample, in dataset order\n",
"accs = []  # per-sample score returned by evaluator.cal_acc\n",
"\n",
"for sample in tqdm(dataset, total=len(dataset)):\n",
"    # prepare encoder inputs\n",
"    pixel_values = processor(sample[\"image\"].convert(\"RGB\"), return_tensors=\"pt\").pixel_values\n",
"    pixel_values = pixel_values.to(device)\n",
"\n",
"    # autoregressively generate sequence; num_beams=1 disables beam search,\n",
"    # so the beam-only `early_stopping` flag is intentionally not passed\n",
"    outputs = model.generate(\n",
"        pixel_values,\n",
"        decoder_input_ids=decoder_input_ids,\n",
"        max_length=model.decoder.config.max_position_embeddings,\n",
"        pad_token_id=processor.tokenizer.pad_token_id,\n",
"        eos_token_id=processor.tokenizer.eos_token_id,\n",
"        use_cache=True,\n",
"        num_beams=1,\n",
"        bad_words_ids=[[processor.tokenizer.unk_token_id]],\n",
"        return_dict_in_generate=True,\n",
"    )\n",
"\n",
"    # turn into JSON\n",
"    seq = processor.batch_decode(outputs.sequences)[0]\n",
"    seq = seq.replace(processor.tokenizer.eos_token, \"\").replace(processor.tokenizer.pad_token, \"\")\n",
"    seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip()  # remove first task start token\n",
"    seq = processor.token2json(seq)\n",
"\n",
"    # score the prediction against the reference parse stored with the sample\n",
"    ground_truth = json.loads(sample[\"ground_truth\"])\n",
"    ground_truth = ground_truth[\"gt_parse\"]\n",
"    score = evaluator.cal_acc(seq, ground_truth)\n",
"\n",
"    accs.append(score)\n",
"    output_list.append(seq)\n",
"\n",
"# compute the mean once and reuse it for both the summary dict and the final print\n",
"mean_accuracy = np.mean(accs)\n",
"scores = {\"accuracies\": accs, \"mean_accuracy\": mean_accuracy}\n",
"print(scores, f\"length : {len(accs)}\")\n",
"print(\"Mean accuracy:\", mean_accuracy)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3.7.15 ('donut')",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.15"
|
|
},
|
|
"orig_nbformat": 4,
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "11ee3e278e787ae04c18a69549ce58331d512f29053c6ca32ae16833b7cef834"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|