JARVIS/evaluate.py

68 lines
2.1 KiB
Python
Raw Normal View History

import os
import pandas as pd
import jsgf
2024-05-11 19:41:13 +02:00
from unidecode import unidecode
import string
2024-05-14 11:59:29 +02:00
from collections import defaultdict
2024-05-11 19:41:13 +02:00
def decode_prompt(prompt):
    """Normalize a user utterance before it is matched against the grammar.

    The text is first transliterated to ASCII with ``unidecode``; every
    character listed in ``string.punctuation`` is then deleted.

    Args:
        prompt: the raw utterance text.

    Returns:
        The transliterated string with all ASCII punctuation removed.
    """
    # Deletion table: each punctuation code point maps to None, which
    # str.translate treats as "remove this character".
    strip_punctuation = {ord(mark): None for mark in string.punctuation}
    return unidecode(prompt).translate(strip_punctuation)
# Grammar whose rule names are compared against the gold dialogue acts below.
grammar = jsgf.parse_grammar_file('book.jsgf')

# One header-less, tab-separated DataFrame per regular file directly inside
# data/ (kept in os.listdir order; subdirectories are ignored).
data_files = [
    pd.read_csv(path, sep='\t', header=None)
    for path in (os.path.join("data", entry) for entry in os.listdir("data"))
    if os.path.isfile(path)
]
2024-05-14 11:59:29 +02:00
# Confusion counts over every scored user utterance:
#   true_positives  - a rule matched and the first match names the gold act
#   false_positives - a rule matched, but the first match is not the gold act
#   false_negatives - no grammar rule matched at all
true_positives = 0
false_positives = 0
false_negatives = 0
# Per-gold-act tallies: acts_recognized counts true positives, while
# acts_not_recognized counts false positives plus false negatives.
acts_recognized = defaultdict(int)
acts_not_recognized = defaultdict(int)

for df in data_files:
    # Only 3-column files carry a gold "act" label to score against. 2-column
    # files (agent, message) would raise KeyError on the act lookup, and any
    # other width is not in the expected format, so skip both.
    if len(df.columns) != 3:
        continue
    df.columns = ["agent", "message", "act"]

    # Only user turns are scored. Empty cells are read as NaN by read_csv and
    # can be neither decoded nor split, so those rows are dropped.
    user_rows = df[df['agent'] == "user"].dropna(subset=["message", "act"])

    found_rules = user_rows["message"].apply(
        lambda message: grammar.find_matching_rules(decode_prompt(message))
    )

    for raw_act, rules in zip(user_rows["act"], found_rules):
        # The gold label is the act name before its "(...)" argument list.
        act = str(raw_act).split('(')[0]
        if not rules:
            false_negatives += 1
            acts_not_recognized[act] += 1
        elif rules[0].name in act:
            # NOTE(review): this is substring containment, not equality. It is
            # kept from the original; confirm it is the intended criterion.
            true_positives += 1
            acts_recognized[act] += 1
        else:
            false_positives += 1
            acts_not_recognized[act] += 1
2024-05-14 12:35:54 +02:00
# Every scored utterance falls into exactly one of TP / FP / FN, so their sum
# is the number of evaluated user turns.
total = true_positives + false_positives + false_negatives
# Accuracy: the share of utterances whose first matching rule is the gold act.
# False positives are errors, so they belong only in the denominator.
accuracy = true_positives / total if total else 0

# Precision: of the utterances where some rule matched, how many were right.
predicted = true_positives + false_positives
precision = true_positives / predicted if predicted else 0
# Recall: TP / (TP + FN), where FN means no rule matched at all.
relevant = true_positives + false_negatives
recall = true_positives / relevant if relevant else 0

print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
2024-05-14 11:59:29 +02:00