From da419d66591d08dedbf797ba65028c43ca02175c Mon Sep 17 00:00:00 2001 From: Karol Cyganik Date: Tue, 14 May 2024 12:35:54 +0200 Subject: [PATCH] fix accuracy --- evaluate.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/evaluate.py b/evaluate.py index 699b894..e5e2835 100644 --- a/evaluate.py +++ b/evaluate.py @@ -20,8 +20,6 @@ for filename in os.listdir("data"): if os.path.isfile(f): data_files.append(pd.read_csv(f, sep='\t', header=None)) -recognized = 0 -unrecognized = 0 true_positives = 0 false_positives = 0 false_negatives = 0 @@ -43,11 +41,6 @@ for df in data_files: entries_count = len(user_speeches) found_rules = user_speeches.apply(lambda x: grammar.find_matching_rules(decode_prompt(x))) - parsed = user_speeches.apply(lambda x: bool(grammar.find_matching_rules(decode_prompt(x)))) - true_count = parsed.sum() - false_count = len(parsed) - true_count - recognized += true_count - unrecognized += false_count for line, rules in zip(df.iterrows(), found_rules): act = line[1]['act'].split('(')[0] @@ -62,12 +55,10 @@ for df in data_files: false_negatives += 1 acts_not_recognized[act] += 1 -accuracy = recognized / (recognized + unrecognized) +accuracy = (true_positives + false_positives) / ((true_positives + false_positives) + sum([x for x in acts_not_recognized.values()])) precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) != 0 else 0 recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) != 0 else 0 -print(f"Recognized user utterances: {recognized}") -print(f"Unrecognized user utterances: {unrecognized}") print(f"Accuracy: {accuracy}") print(f"Precision: {precision}") print(f"Recall: {recall}")