GOATS/evaluate.py
2024-05-29 19:26:14 +02:00

54 lines
1.4 KiB
Python

import os
import pandas as pd
import re
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
# Evaluate NaturalLanguageAnalyzer's dialogue-act predictions against the
# annotated user turns found in the tab-separated files under DATA_DIRECTORY.
# When run as a script it prints the predicted and ground-truth act sets for
# every user turn, followed by the overall score.

DATA_DIRECTORY = 'data'

# Normalizes annotation act names to the labels that are compared against the
# analyzer's predictions. Annotated acts not listed here are ignored in scoring.
CHANGE_ACT_FORMAT = {
    "thankyou": "thank",
    "bye": "thank",
    "hello": "greet",
    "inform": "inform",
    "request": "request",
    "reqmore": "request",
}

# Captures every run of characters other than '(' and '&' that is immediately
# followed by '(' -- e.g. "inform(food=x)&request(area)" yields "inform" and
# "request". Compiled once at import time instead of once per row.
ACT_NAME_PATTERN = re.compile(r'([^(&]+)(?=\()')


def load_user_turns(directory):
    """Load every tab-separated file in *directory* and keep only user turns.

    Rows whose ``role`` column equals ``'user'`` are kept and the ``role``
    column is dropped. Files are read in sorted name order so the row order is
    reproducible; subdirectories and hidden files (such as ``.DS_Store``) are
    skipped because ``read_csv`` cannot parse them as dialogue files.

    Args:
        directory: path of the folder holding the data files.

    Returns:
        One DataFrame with the user rows of all files, re-indexed from 0.

    Raises:
        ValueError: if *directory* contains no data files.
    """
    frames = []
    for file_name in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, file_name)
        if file_name.startswith('.') or not os.path.isfile(file_path):
            continue
        df = pd.read_csv(file_path, sep='\t', encoding='utf-8')
        frames.append(df[df['role'] == 'user'].drop('role', axis=1))
    if not frames:
        raise ValueError(f"no data files found in {directory!r}")
    return pd.concat(frames, ignore_index=True)


def extract_ground_acts(ground_act, act_format=None):
    """Return the normalized set of act names contained in an annotation.

    Args:
        ground_act: annotation string; NaN/None (an empty cell) means no acts.
        act_format: mapping from raw act name to normalized label; defaults
            to ``CHANGE_ACT_FORMAT``. Names missing from it are dropped.

    Returns:
        A set of normalized act labels (possibly empty).
    """
    if act_format is None:
        act_format = CHANGE_ACT_FORMAT
    # read_csv turns empty cells into NaN, on which re.findall would raise.
    if pd.isna(ground_act):
        return set()
    return {
        act_format[name]
        for name in ACT_NAME_PATTERN.findall(str(ground_act))
        if name in act_format
    }


def score_turn(predicted_acts, ground_acts):
    """Return ``(hits, misses)``: ground-truth acts that were / were not predicted."""
    hits = len(ground_acts & predicted_acts)
    return hits, len(ground_acts) - hits


def evaluate(texts, ground_acts, predict, *, verbose=True):
    """Score predicted dialogue acts against annotations, turn by turn.

    Args:
        texts: utterances to analyze.
        ground_acts: annotation strings aligned with *texts* (extra items in
            the longer sequence are ignored, as with ``zip``).
        predict: callable taking an utterance and returning an iterable of
            items whose first element is a predicted act name.
        verbose: print the predicted and ground-truth act sets for each turn.

    Returns:
        The fraction of normalized ground-truth acts that were predicted, or
        ``None`` when there was no ground-truth act to score. This is recall:
        extra (false-positive) predicted acts are not penalized.
    """
    correct = incorrect = 0
    for text, ground_act in zip(texts, ground_acts):
        predicted_act = {item[0] for item in predict(text)}
        ground_act_processed = extract_ground_acts(ground_act)
        hits, misses = score_turn(predicted_act, ground_act_processed)
        correct += hits
        incorrect += misses
        if verbose:
            print("Predicted:", predicted_act)
            print("Ground truth:", ground_act_processed)
            print()
    total = correct + incorrect
    if total == 0:
        return None
    return correct / total


def main():
    """Load the user turns, run the analyzer on each one and print the score."""
    combined_df = load_user_turns(DATA_DIRECTORY)
    # Build the analyzer once instead of once per utterance.
    # NOTE(review): assumes predict() keeps no state between calls -- confirm.
    nla = NaturalLanguageAnalyzer()
    accuracy = evaluate(
        combined_df["value"].values, combined_df["act"].values, nla.predict
    )
    if accuracy is None:
        print("Accuracy: n/a (no ground-truth acts to score)")
    else:
        print("Accuracy: ", accuracy)


if __name__ == "__main__":
    main()