fix evaluation

This commit is contained in:
Karol Cyganik 2024-05-14 12:28:20 +02:00
parent 0b4c46e239
commit 793b7ca37b

View File

@ -28,6 +28,8 @@ false_negatives = 0
acts_recognized = defaultdict(int) acts_recognized = defaultdict(int)
acts_not_recognized = defaultdict(int) acts_not_recognized = defaultdict(int)
false_negatives = 0
false_positives = 0
for df in data_files: for df in data_files:
if len(df.columns) == 3: if len(df.columns) == 3:
@ -40,48 +42,33 @@ for df in data_files:
user_speeches = user_speech_rows["message"] user_speeches = user_speech_rows["message"]
entries_count = len(user_speeches) entries_count = len(user_speeches)
found_rules = user_speeches.apply(lambda x: grammar.find_matching_rules(decode_prompt(x)))
parsed = user_speeches.apply( parsed = user_speeches.apply(lambda x: bool(grammar.find_matching_rules(decode_prompt(x))))
lambda x: bool(grammar.find_matching_rules(decode_prompt(x))))
true_count = parsed.sum() true_count = parsed.sum()
false_count = len(parsed) - true_count false_count = len(parsed) - true_count
recognized += true_count recognized += true_count
unrecognized += false_count unrecognized += false_count
for line, correct in zip(df.iterrows(), parsed): for line, rules in zip(df.iterrows(), found_rules):
acts_recognized[line[1]['act'].split('(')[0]] += int(correct) act = line[1]['act'].split('(')[0]
acts_not_recognized[line[1]['act'].split('(')[0]] += int(not(correct)) if len(rules) > 0:
recognized_act = rules[0].name
if recognized_act in act:
true_positives += 1
else:
false_positives += 1
acts_not_recognized[act] += 1
else:
false_negatives += 1
acts_not_recognized[act] += 1
accuracy = recognized / (recognized + unrecognized)
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) != 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) != 0 else 0
print(f"Recognized user utterances: {recognized}") print(f"Recognized user utterances: {recognized}")
print(f"Unrecognized user utterances: {unrecognized}") print(f"Unrecognized user utterances: {unrecognized}")
print(f"Accuracy: {recognized/(recognized+unrecognized)}") print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
precision_per_class = {}
recall_per_class = {}
for act in acts_recognized.keys():
true_positives = acts_recognized[act]
false_negatives = acts_not_recognized[act]
false_positives = recognized - true_positives
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) != 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) != 0 else 0
precision_per_class[act] = precision
recall_per_class[act] = recall
average_precision = sum(precision_per_class.values()) / len(precision_per_class)
average_recall = sum(recall_per_class.values()) / len(recall_per_class)
print("\nPrecision per class:")
for act, precision in precision_per_class.items():
print(f"{act}: {precision}")
print("\nRecall per class:")
for act, recall in recall_per_class.items():
print(f"{act}: {recall}")
print(f"\nAverage Precision: {average_precision}")
print(f"Average Recall: {average_recall}")