Compare commits

...

2 Commits

Author SHA1 Message Date
Karol Cyganik 53628482e7 Merge branch 'main' of https://git.wmi.amu.edu.pl/s464915/JARVIS 2024-05-14 11:59:32 +02:00
Karol Cyganik b4e1f57795 add precision & recall 2024-05-14 11:59:29 +02:00
1 changed files with 41 additions and 0 deletions

View File

@ -3,6 +3,7 @@ import pandas as pd
import jsgf
from unidecode import unidecode
import string
from collections import defaultdict
def decode_prompt(prompt):
@ -21,6 +22,12 @@ for filename in os.listdir("data"):
recognized = 0
unrecognized = 0
true_positives = 0
false_positives = 0
false_negatives = 0
acts_recognized = defaultdict(int)
acts_not_recognized = defaultdict(int)
for df in data_files:
if len(df.columns) == 3:
@ -40,7 +47,41 @@ for df in data_files:
false_count = len(parsed) - true_count
recognized += true_count
unrecognized += false_count
for line, correct in zip(df.iterrows(), parsed):
acts_recognized[line[1]['act'].split('(')[0]] += int(correct)
acts_not_recognized[line[1]['act'].split('(')[0]] += int(not(correct))
print(f"Recognized user utterances: {recognized}")
print(f"Unrecognized user utterances: {unrecognized}")
print(f"Accuracy: {recognized/(recognized+unrecognized)}")
precision_per_class = {}
recall_per_class = {}
for act in acts_recognized.keys():
true_positives = acts_recognized[act]
false_negatives = acts_not_recognized[act]
false_positives = recognized - true_positives
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) != 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) != 0 else 0
precision_per_class[act] = precision
recall_per_class[act] = recall
average_precision = sum(precision_per_class.values()) / len(precision_per_class)
average_recall = sum(recall_per_class.values()) / len(recall_per_class)
print("\nPrecision per class:")
for act, precision in precision_per_class.items():
print(f"{act}: {precision}")
print("\nRecall per class:")
for act, recall in recall_per_class.items():
print(f"{act}: {recall}")
print(f"\nAverage Precision: {average_precision}")
print(f"Average Recall: {average_recall}")