import re
import os
import pandas as pd
import numpy as np
from flair.models import SequenceTagger
from conllu import parse_incr
from flair.data import Corpus
from src.utils.nlu_utils import conllu2flair, nolabel2o, predict_multiple

# Frame model evaluation
# Load the production frame tagger and score it on the held-out dialog file.
frame_model = SequenceTagger.load('frame-model-prod/best-model.pt')

with open('data/test_dialog_46.conllu', encoding='utf-8') as conllu_file:
    # No custom field parsers are supplied for this pass.
    testset = list(
        parse_incr(
            conllu_file,
            fields=['id', 'form', 'frame', 'slot'],
            field_parsers={},
        )
    )

corpus = Corpus(test=conllu2flair(testset, "frame"))
result = frame_model.evaluate(
    corpus.test,
    mini_batch_size=1,
    gold_label_type="frame",
)
print(result.detailed_results)

# Slot model evaluation
# Load the production slot tagger and score it on the held-out dialog file.
slot_model = SequenceTagger.load('slot-model-prod/best-model.pt')

with open('data/test_dialog_46.conllu', encoding='utf-8') as conllu_file:
    # The 'slot' column is run through nolabel2o while parsing
    # (presumably mapping unlabelled tokens to the 'O' tag - confirm in nlu_utils).
    testset = list(
        parse_incr(
            conllu_file,
            fields=['id', 'form', 'frame', 'slot'],
            field_parsers={'slot': nolabel2o},
        )
    )

corpus = Corpus(test=conllu2flair(testset, "slot"))
result = slot_model.evaluate(
    corpus.test,
    mini_batch_size=8,
    gold_label_type="slot",
)
print(result.detailed_results)

# Custom evaluation

# Matches a bracketed argument span such as "(name=x)" or "[phone]".
# Raw string so the backslash escapes reach the regex engine unchanged; a plain
# literal containing "\(" is an invalid string escape (SyntaxWarning on 3.12+).
_SLOT_ARGS_RE = re.compile(r"[\(\[].*?[\)\]]")

def __parse_acts(acts):
    """Return the set of act names in an '&'-separated act annotation.

    Every "(...)" / "[...]" span is removed from each act, so
    ``"inform(name=x)&request[phone]"`` becomes ``{"inform", "request"}``.
    """
    return {_SLOT_ARGS_RE.sub("", act) for act in acts.split('&')}

def __parse_predictions(predictions):
    """Collect the label part (text before the first '/') of each prediction.

    Returns a set, so duplicate labels collapse into one entry.
    """
    labels = set()
    for raw_prediction in predictions:
        head, _sep, _rest = raw_prediction.partition('/')
        labels.add(head)
    return labels

# Counters for the act-level evaluation over every data/*.tsv dialog.
total_acts = 0
act_correct_predictions = 0
# NOTE(review): never updated in this script; kept so existing references resolve.
slot_correct_predictions = 0

for file_name in os.listdir('data'):
    # Only tab-separated dialog transcripts take part in this evaluation.
    if not file_name.endswith('.tsv'):
        continue

    # dtype=str + keep_default_na=False keep every cell as literal text:
    # otherwise pandas converts empty cells and words like "NA"/"null" to NaN
    # (a float), and numeric-looking cells to numbers, which crash .split().
    df = pd.read_csv(
        f'data/{file_name}',
        sep='\t',
        names=['kto', 'treść', 'akt'],
        dtype=str,
        keep_default_na=False,
    )
    df = df[df.kto == 'user']

    for sentence, raw_acts in zip(df['treść'], df['akt']):
        # A turn without an utterance or without an act annotation cannot be scored.
        if not sentence.strip() or not raw_acts.strip():
            continue

        acts = __parse_acts(raw_acts)

        predictions_raw = predict_multiple(frame_model, sentence.split(), 'frame')
        predictions = __parse_predictions(predictions_raw)

        # Every gold act counts once; it is a hit when the same name was predicted.
        total_acts += len(acts)
        act_correct_predictions += len(acts & predictions)

# Guard against an empty evaluation set instead of raising ZeroDivisionError.
if total_acts:
    print(f"Accuracy - predicting acts: {(act_correct_predictions / total_acts)*100} ({act_correct_predictions}/{total_acts})")
else:
    print("Accuracy - predicting acts: no user acts found in data/*.tsv")