modifcation to train and eval

This commit is contained in:
s444415 2022-12-16 12:32:08 +00:00
parent 67cc4bdf7c
commit d98383197f
2 changed files with 23 additions and 19 deletions

View File

@ -11,21 +11,21 @@ import json
import torch import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
import numpy as np import numpy as np
import pandas as pd
from donut import JSONParseEvaluator from donut import JSONParseEvaluator
# In[2]: # In[2]:
processor = DonutProcessor.from_pretrained("Zombely/plwiki-proto-fine-tuned") processor = DonutProcessor.from_pretrained("Zombely/plwiki-proto-fine-tuned-v2")
model = VisionEncoderDecoderModel.from_pretrained("Zombely/plwiki-proto-fine-tuned") model = VisionEncoderDecoderModel.from_pretrained("Zombely/plwiki-proto-fine-tuned-v2")
# In[3]: # In[3]:
dataset = load_dataset("Zombely/pl-text-images-5000-whole", split="validation") dataset = load_dataset("Zombely/diachronia-ocr", split='train')
# In[4]: # In[4]:
@ -38,11 +38,11 @@ model.to(device)
output_list = [] output_list = []
accs = [] accs = []
has_metadata = bool(dataset[0].get('ground_truth'))
for idx, sample in tqdm(enumerate(dataset), total=len(dataset)): for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
# prepare encoder inputs # prepare encoder inputs
pixel_values = processor(sample["image"].convert("RGB"), return_tensors="pt").pixel_values pixel_values = processor(sample['image'].convert("RGB"), return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device) pixel_values = pixel_values.to(device)
# prepare decoder inputs # prepare decoder inputs
task_prompt = "<s_cord-v2>" task_prompt = "<s_cord-v2>"
@ -68,16 +68,20 @@ for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
seq = processor.token2json(seq) seq = processor.token2json(seq)
if has_metadata:
ground_truth = json.loads(sample["ground_truth"])
ground_truth = ground_truth["gt_parse"]
evaluator = JSONParseEvaluator()
score = evaluator.cal_acc(seq, ground_truth)
ground_truth = json.loads(sample["ground_truth"]) accs.append(score)
ground_truth = ground_truth["gt_parse"] print(seq)
evaluator = JSONParseEvaluator()
score = evaluator.cal_acc(seq, ground_truth)
accs.append(score)
output_list.append(seq) output_list.append(seq)
df = pd.DataFrame(map(lambda x: x.get('text_sequence', ''), output_list))
df.to_csv('out.tsv', sep='\t', header=False, index=False)
scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)} if has_metadata:
print(scores, f"length : {len(accs)}") scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)}
print("Mean accuracy:", np.mean(accs)) print(scores, f"length : {len(accs)}")
print("Mean accuracy:", np.mean(accs))

View File

@ -22,7 +22,7 @@ from pytorch_lightning.plugins import CheckpointIO
DATASET_PATH = "Zombely/pl-text-images-5000-whole" DATASET_PATH = "Zombely/pl-text-images-5000-whole"
PRETRAINED_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned" PRETRAINED_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned-v2"
START_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned" START_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned"
OUTPUT_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned-v2" OUTPUT_MODEL_PATH = "Zombely/plwiki-proto-fine-tuned-v2"
LOGGING_PATH = "plwiki-proto-ft-second-iter" LOGGING_PATH = "plwiki-proto-ft-second-iter"
@ -30,8 +30,8 @@ CHECKPOINT_PATH = "./checkpoint"
train_config = { train_config = {
"max_epochs":30, "max_epochs":1,
"val_check_interval":0.5, # how many times we want to validate during an epoch "val_check_interval":1.0, # how many times we want to validate during an epoch
"check_val_every_n_epoch":1, "check_val_every_n_epoch":1,
"gradient_clip_val":1.0, "gradient_clip_val":1.0,
"num_training_samples_per_epoch": 800, "num_training_samples_per_epoch": 800,
@ -339,7 +339,7 @@ class PushToHubCallback(Callback):
login(os.environ.get("HUG_TOKKEN", "")) login(os.environ.get("HUG_TOKKEN", None), True)
# ### Wandb.ai link: https://wandb.ai/michalkozlowski936/Donut?workspace=user-michalkozlowski936 # ### Wandb.ai link: https://wandb.ai/michalkozlowski936/Donut?workspace=user-michalkozlowski936