2023-12-14 16:39:46 +01:00
|
|
|
import io
import logging

import distance
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.metrics import f1_score
|
|
|
|
|
|
|
|
|
|
|
|
# Payload returned by the endpoint whenever a score cannot be computed
# (unknown metric, empty submission, length mismatch, or an exception).
ERROR_RESPONSE = {"status": 400}
|
|
|
|
# Application object on which the route handlers in this module are registered.
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
|
# Request body accepted by the scoring endpoint.
class Data(BaseModel):
    # Name of the metric to compute; the handler compares it against a fixed
    # set of metric names.
    metric: str
    # Raw text content of the reference (gold) file.
    file_expected: str = ""
    # Raw text content of the submitted prediction file.
    file_out: str = ""
    # Raw text content of the input file; only parsed for the F1-based
    # emotion-recognition metrics.
    file_in: str = ""
|
2023-12-14 16:39:46 +01:00
|
|
|
|
|
|
|
|
|
|
|
def preprocess_data(out, expected):
    """Split prediction and reference text into aligned lists of lines.

    Args:
        out: Prediction file content, one answer per line.
        expected: Reference file content, one answer per line.

    Returns:
        A ``(out_lines, expected_lines)`` tuple. The final element produced
        by splitting ``expected`` on newlines is always discarded
        (presumably ``expected`` ends with a trailing newline -- verify
        against the data files), and the prediction lines are truncated to
        at most the number of remaining reference lines. The prediction
        list may still be shorter than the reference list.
    """
    predicted_lines = out.split("\n")
    reference_lines = expected.split("\n")
    # str.split always yields at least one element, so pop() cannot fail.
    reference_lines.pop()
    return predicted_lines[:len(reference_lines)], reference_lines
|
|
|
|
|
|
|
|
|
|
|
|
def get_levenshtein_score(trues, preds):
    """Return the mean Levenshtein similarity (0-100) over answerable items.

    Pairs whose expected answer is the empty string are skipped. For the
    remaining pairs an empty prediction scores 0; otherwise the score is
    ``1 - distance.nlevenshtein(true, pred)`` computed on lower-cased text.

    Args:
        trues: Expected answers, one string per item.
        preds: Predicted answers, paired with ``trues`` by position
            (extra items on either side are ignored by ``zip``).

    Returns:
        The average per-item score multiplied by 100. When no pair has a
        non-empty expected answer (including empty input) the result is
        ``0.0`` instead of a ``ZeroDivisionError``, mirroring the
        ``zero_division=0.0`` convention used for the F1 metrics.
    """
    scores = []
    for true, pred in zip(trues, preds):
        if true == "":
            # Items with an empty expected answer do not count here.
            continue
        if pred == "":
            scores.append(0)
        else:
            # Comparison is case-insensitive.
            scores.append(1 - distance.nlevenshtein(true.lower(), pred.lower()))

    if not scores:
        # Nothing to average: avoid dividing by zero.
        return 0.0

    return sum(scores) / len(scores) * 100
|
|
|
|
|
|
|
|
|
|
|
|
def get_answerability_f1(trues, preds):
    """Return the F1 score (0-100) for detecting empty answers.

    Every answer is mapped to a binary label: 1 when it is the empty
    string, 0 otherwise. The F1 score of the predicted labels against the
    expected labels is then computed with ``zero_division=0.0``.

    Args:
        trues: Expected answers, one string per item.
        preds: Predicted answers, one string per item.

    Returns:
        The F1 score multiplied by 100.
    """
    expected_flags = [1 if answer == "" else 0 for answer in trues]
    predicted_flags = [1 if answer == "" else 0 for answer in preds]
    return f1_score(expected_flags, predicted_flags, zero_division=0.0) * 100
|
|
|
|
|
|
|
|
|
2024-01-12 10:20:16 +01:00
|
|
|
def get_final_score(trues, preds):
    """Return the mean of the Levenshtein and answerability-F1 scores.

    Args:
        trues: Expected answers, one string per item.
        preds: Predicted answers, one string per item.

    Returns:
        The arithmetic mean of ``get_levenshtein_score`` and
        ``get_answerability_f1``, rounded to two decimal places.
    """
    levenshtein = get_levenshtein_score(trues, preds)
    answerability = get_answerability_f1(trues, preds)
    return round((levenshtein + answerability) / 2, 2)
|
2023-12-14 16:39:46 +01:00
|
|
|
|
|
|
|
|
|
|
|
def get_emotion_recognition_scores(df_in, df_expected, df_predition):
    """Compute macro-F1 scores separately for two groups of rows.

    Rows are partitioned by the ``text`` column of ``df_in``: a row whose
    ``text`` value consists solely of ``#`` characters (this includes the
    empty string) contributes to ``TextF1``; every other row contributes
    to ``SentenceF1``.

    Args:
        df_in: Input frame; must have a ``text`` column of strings.
        df_expected: Reference labels, indexed like ``df_in``.
        df_predition: Predicted labels, indexed like ``df_in``.
            NOTE(review): the label column layout is not visible here --
            the frames are passed to ``f1_score`` as-is; confirm what
            shape the callers supply.

    Returns:
        A dict with ``SentenceF1``, ``TextF1`` and ``FinalF1`` (the mean of
        the other two), each on a 0-100 scale and rounded to two decimals.
    """
    is_text_row = df_in['text'].apply(lambda x: x == '#' * len(x))

    # Both calls use sklearn's (y_true, y_pred) order: expected first.
    f1_text_score = f1_score(
        df_expected[is_text_row],
        df_predition[is_text_row],
        average='macro',
        zero_division=0.0,
    ) * 100

    f1_sentence_score = f1_score(
        df_expected[~is_text_row],
        df_predition[~is_text_row],
        average='macro',
        zero_division=0.0,
    ) * 100

    final_score = (f1_text_score + f1_sentence_score) / 2

    return {
        "SentenceF1": round(f1_sentence_score, 2),
        "TextF1": round(f1_text_score, 2),
        "FinalF1": round(final_score, 2),
    }
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)

# Metrics computed from line-aligned answer files.
_QA_METRICS = ("Final", "Levenshtein", "AnswerabilityF1")
# Metrics computed from tab-separated tables; each name is also a key of the
# dict returned by get_emotion_recognition_scores.
_EMOTION_METRICS = ("FinalF1", "SentenceF1", "TextF1")


def _qa_result(metric, file_out, file_expected):
    """Score a line-based submission; return None on a line-count mismatch."""
    preds, trues = preprocess_data(file_out, file_expected)
    if len(preds) != len(trues):
        return None
    if metric == "Final":
        # get_final_score already rounds its result.
        return get_final_score(trues, preds)
    if metric == "Levenshtein":
        return round(get_levenshtein_score(trues, preds), 2)
    return round(get_answerability_f1(trues, preds), 2)


def _emotion_result(metric, file_in, file_out, file_expected):
    """Parse the three tab-separated payloads and return the requested F1."""
    df_in = pd.read_table(io.StringIO(file_in))
    df_expected = pd.read_table(io.StringIO(file_expected))
    df_predition = pd.read_table(io.StringIO(file_out))
    return get_emotion_recognition_scores(df_in, df_expected, df_predition)[metric]


@app.get("/")
async def root(data: Data):
    """Compute the requested metric for a submission.

    Args:
        data: Request body carrying the metric name and the raw text of the
            expected, predicted and (for the F1 metrics) input files.

    Returns:
        ``{"status": 200, "result": <score>}`` on success, otherwise
        ``ERROR_RESPONSE``. The error response is returned for an unknown
        metric, an empty ``file_out``, a prediction with fewer lines than
        the reference, or any exception raised while scoring.
    """
    metric = data.metric
    file_out = data.file_out

    # An empty submission is never scored, whatever the metric.
    if not file_out:
        return ERROR_RESPONSE

    try:
        if metric in _QA_METRICS:
            result = _qa_result(metric, file_out, data.file_expected)
            if result is None:
                return ERROR_RESPONSE
        elif metric in _EMOTION_METRICS:
            result = _emotion_result(
                metric, data.file_in, file_out, data.file_expected
            )
        else:
            return ERROR_RESPONSE
    except Exception:
        # Catch Exception rather than everything so that task cancellation
        # and interpreter-exit signals still propagate; log the failure so
        # it is not silently lost.
        _logger.exception("Failed to compute metric %r", metric)
        return ERROR_RESPONSE

    return {
        "status": 200,
        "result": result,
    }
|