config and params for donut-eval

This commit is contained in:
Michał Kozłowski 2022-12-16 13:53:03 +01:00
parent d98383197f
commit 8ccd1aabb6
2 changed files with 68 additions and 68 deletions

7
config.yaml Normal file
View File

@ -0,0 +1,7 @@
pretrained_processor_path: "Zombely/plwiki-proto-fine-tuned-v2"
pretrained_model_path: "Zombely/plwiki-proto-fine-tuned-v2"
validation_dataset_path: "Zombely/diachronia-ocr"
validation_dataset_split: "train"
has_metadata: False
print_output: True
output_file_dir: "../../gonito-outs"

View File

@ -1,9 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# coding: utf-8 # coding: utf-8
# In[1]:
from transformers import DonutProcessor, VisionEncoderDecoderModel from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset from datasets import load_dataset
import re import re
@ -13,34 +10,20 @@ from tqdm.auto import tqdm
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from donut import JSONParseEvaluator from donut import JSONParseEvaluator
import argparse
from sconf import Config
def main(config):
processor = DonutProcessor.from_pretrained(config.pretrained_processor_path)
model = VisionEncoderDecoderModel.from_pretrained(config.pretrained_model_path)
dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)
output_list = []
accs = []
# In[2]: for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
processor = DonutProcessor.from_pretrained("Zombely/plwiki-proto-fine-tuned-v2")
model = VisionEncoderDecoderModel.from_pretrained("Zombely/plwiki-proto-fine-tuned-v2")
# In[3]:
dataset = load_dataset("Zombely/diachronia-ocr", split='train')
# In[4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)
output_list = []
accs = []
has_metadata = bool(dataset[0].get('ground_truth'))
for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
# prepare encoder inputs # prepare encoder inputs
pixel_values = processor(sample['image'].convert("RGB"), return_tensors="pt").pixel_values pixel_values = processor(sample['image'].convert("RGB"), return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device) pixel_values = pixel_values.to(device)
@ -68,20 +51,30 @@ for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
seq = processor.token2json(seq) seq = processor.token2json(seq)
if has_metadata: if config.has_metadata:
ground_truth = json.loads(sample["ground_truth"]) ground_truth = json.loads(sample["ground_truth"])
ground_truth = ground_truth["gt_parse"] ground_truth = ground_truth["gt_parse"]
evaluator = JSONParseEvaluator() evaluator = JSONParseEvaluator()
score = evaluator.cal_acc(seq, ground_truth) score = evaluator.cal_acc(seq, ground_truth)
accs.append(score) accs.append(score)
if config.print_output:
print(seq) print(seq)
output_list.append(seq) output_list.append(seq)
df = pd.DataFrame(map(lambda x: x.get('text_sequence', ''), output_list)) if config.output_file_dir:
df.to_csv('out.tsv', sep='\t', header=False, index=False) df = pd.DataFrame(map(lambda x: x.get('text_sequence', ''), output_list))
df.to_csv(f'{config.output_file_dir}/{config.pretrained_processor_path}-out.tsv', sep='\t', header=False, index=False)
if has_metadata: if config.has_metadata:
scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)} scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)}
print(scores, f"length : {len(accs)}") print(scores, f"length : {len(accs)}")
print("Mean accuracy:", np.mean(accs)) print("Mean accuracy:", np.mean(accs))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args, left_argv = parser.parse_known_args()
config = Config(args.config)
config.argv_update(left_argv)
main(config)