import re
import os
import pandas as pd
import numpy as np
from flair.models import SequenceTagger
from conllu import parse_incr
from flair.data import Corpus
from src.utils.nlu_utils import conllu2flair, nolabel2o, predict_multiple

# Frame model evaluation
frame_model = SequenceTagger.load('frame-model-prod/best-model.pt')
with open('data/test_dialog_46.conllu', encoding='utf-8') as test_file:
    testset = list(parse_incr(test_file,
                              fields=['id', 'form', 'frame', 'slot'],
                              field_parsers={}))
corpus = Corpus(test=conllu2flair(testset, "frame"))
result = frame_model.evaluate(corpus.test, mini_batch_size=1,
                              gold_label_type="frame")
print(result.detailed_results)

# Slot model evaluation
slot_model = SequenceTagger.load('slot-model-prod/best-model.pt')
with open('data/test_dialog_46.conllu', encoding='utf-8') as test_file:
    # Unlike the frame pass, the 'slot' column goes through nolabel2o here.
    testset = list(parse_incr(test_file,
                              fields=['id', 'form', 'frame', 'slot'],
                              field_parsers={'slot': nolabel2o}))
corpus = Corpus(test=conllu2flair(testset, "slot"))
result = slot_model.evaluate(corpus.test, mini_batch_size=8,
                             gold_label_type="slot")
print(result.detailed_results)

# Custom evaluation

# Shortest span that opens with '(' or '[' and closes with ')' or ']'.
# Raw string: the backslashes belong to the regex, not to Python's string
# escapes. Compiled once because it is applied to every act of every row.
_SLOT_ARGS_RE = re.compile(r"[\(\[].*?[\)\]]")


def __parse_acts(acts):
    """Return the set of act names contained in an '&'-joined act string.

    The string is split on '&' and every bracketed span ('(...)' or
    '[...]') is removed from each part, e.g. ``"a(x=1)&b"`` -> ``{"a", "b"}``.
    """
    return {_SLOT_ARGS_RE.sub("", act) for act in acts.split('&')}


def __parse_predictions(predictions):
    """Return the set of the parts before the first '/' of each prediction."""
    return {prediction.split('/')[0] for prediction in predictions}


total_acts = 0
act_correct_predictions = 0
# NOTE(review): initialised but never updated in the code visible here.
slot_correct_predictions = 0

for file_name in os.listdir('data'):
    if file_name.split('.')[-1] != 'tsv':
        continue

    df = pd.read_csv(f'data/{file_name}', sep='\t',
                     names=['kto', 'treść', 'akt'])
    df = df[df.kto == 'user']
    all_data = np.array(df)

    for row in all_data:
        sentence = row[1]
        raw_acts = row[2]
        # pandas yields NaN (a float) for empty cells; such rows carry no
        # text or no gold act and would crash on .split(), so skip them.
        if not isinstance(sentence, str) or not isinstance(raw_acts, str):
            continue

        acts = __parse_acts(raw_acts)
        predictions_raw = predict_multiple(frame_model, sentence.split(),
                                           'frame')
        predictions = __parse_predictions(predictions_raw)

        # Every gold act counts once; it is a hit when the frame model
        # predicted the same act name for this sentence.
        for act in acts:
            total_acts += 1
            if act in predictions:
                act_correct_predictions += 1

if total_acts:
    print(f"Accuracy - predicting acts: "
          f"{(act_correct_predictions / total_acts)*100} "
          f"({act_correct_predictions}/{total_acts})")
else:
    # Avoid ZeroDivisionError when no user acts were collected.
    print("Accuracy - predicting acts: n/a (0/0)")