# chat-restaruacja/evaluate.py
# 63 lines, 2.2 KiB, Python (header captured from the repository file view)
# Last change: 2024-05-10 01:25:13 +02:00
import re
import os
import pandas as pd
import numpy as np
from nlu_utils import predict_multiple
from flair.models import SequenceTagger
from conllu import parse_incr
from flair.data import Corpus
from nlu_utils import conllu2flair, nolabel2o
# --- Frame model evaluation -------------------------------------------------
# Load the production frame tagger and score it on the held-out CoNLL-U dialog.
frame_model = SequenceTagger.load('frame-model-prod/best-model.pt')
with open('data/test_dialog_46.conllu', encoding='utf-8') as test_file:
    # The four CoNLL-U columns are read as id, form, frame and slot; no custom
    # column parsers are applied for the frame evaluation.
    testset = list(
        parse_incr(
            test_file,
            fields=['id', 'form', 'frame', 'slot'],
            field_parsers={},
        )
    )
corpus = Corpus(test=conllu2flair(testset, "frame"))
result = frame_model.evaluate(
    corpus.test,
    mini_batch_size=1,
    gold_label_type="frame",
)
print(result.detailed_results)
# --- Slot model evaluation --------------------------------------------------
# Load the production slot tagger and score it on the same held-out dialog.
slot_model = SequenceTagger.load('slot-model-prod/best-model.pt')
with open('data/test_dialog_46.conllu', encoding='utf-8') as test_file:
    # The 'slot' column is passed through nolabel2o from nlu_utils;
    # presumably it maps unlabeled tokens to the "O" tag — confirm there.
    testset = list(
        parse_incr(
            test_file,
            fields=['id', 'form', 'frame', 'slot'],
            field_parsers={'slot': nolabel2o},
        )
    )
corpus = Corpus(test=conllu2flair(testset, "slot"))
result = slot_model.evaluate(
    corpus.test,
    mini_batch_size=8,
    gold_label_type="slot",
)
print(result.detailed_results)
# Custom evaluation
# Matches an opening "(" or "[" up to the nearest ")" or "]", e.g. the slot
# argument list "(food=pizza)" in "inform(food=pizza)". Raw string so the
# backslash escapes reach the regex engine instead of Python's string parser.
_SLOT_ARGS_RE = re.compile(r"[\(\[].*?[\)\]]")


def __parse_acts(acts):
    """Return the set of dialogue-act names in an ``&``-separated act string.

    Slot arguments in parentheses or square brackets are removed, so
    ``"inform(food=pizza)&request(price)"`` yields ``{"inform", "request"}``.

    Args:
        acts: The raw act annotation (the ``akt`` column of a dialog TSV).
            Non-string values, such as the NaN pandas produces for an empty
            cell, yield an empty set instead of raising.

    Returns:
        A set of act names with surrounding whitespace stripped; empty
        fragments (e.g. from a trailing ``&``) are dropped.
    """
    if not isinstance(acts, str):
        return set()
    names = (_SLOT_ARGS_RE.sub("", act).strip() for act in acts.split('&'))
    return {name for name in names if name}
def __parse_predictions(predictions):
    """Collect the frame names from a sequence of predicted labels.

    Each label is cut at its first ``/`` and only the leading part is kept;
    duplicates collapse because the result is a set.
    """
    frames = set()
    for label in predictions:
        frames.add(label.split('/')[0])
    return frames
# Act-level accuracy over the user turns of every data/*.tsv dialog: an
# annotated act counts as correct when its name appears among the frame names
# the frame model predicts for the same (whitespace-tokenized) utterance.
total_acts = 0
act_correct_predictions = 0
# NOTE(review): never updated below; presumably reserved for a slot-level
# metric that has not been written yet.
slot_correct_predictions = 0
for file_name in sorted(os.listdir('data')):
    # Only tab-separated dialog transcripts are scored.
    if not file_name.endswith('.tsv'):
        continue
    df = pd.read_csv(os.path.join('data', file_name), sep='\t', names=['kto', 'treść', 'akt'])
    # Keep user turns only. pandas reads empty cells as NaN, which would crash
    # the string handling below, so turns missing text or acts are skipped.
    df = df[df.kto == 'user'].dropna(subset=['treść', 'akt'])
    for sentence, raw_acts in zip(df['treść'], df['akt']):
        acts = __parse_acts(raw_acts)
        predictions_raw = predict_multiple(frame_model, sentence.split(), 'frame')
        predictions = __parse_predictions(predictions_raw)
        total_acts += len(acts)
        act_correct_predictions += sum(1 for act in acts if act in predictions)
if total_acts:
    print(f"Accuracy - predicting acts: {(act_correct_predictions / total_acts)*100} ({act_correct_predictions}/{total_acts})")
else:
    # Avoid ZeroDivisionError when no .tsv file contributed any user act.
    print("Accuracy - predicting acts: no user acts found in data/*.tsv")