# Source listing metadata (from the original paste): 19 lines, 558 B, Python.
from simpletransformers.classification import ClassificationModel

import pandas as pd

# Evaluate the fine-tuned DeBERTa checkpoint saved in outputs/best_model on
# the dev split (dev-0/dev.tsv) and print accuracy, precision, recall and F1
# derived from the confusion-matrix counts returned by eval_model.


def _safe_div(numerator, denominator, default=0.0):
    """Return ``numerator / denominator``, or *default* if the denominator is 0.

    Precision, recall and F1 are undefined when their denominator is zero
    (e.g. the model never predicts the positive class, or the dev set is
    empty). Plain division would then raise ZeroDivisionError for Python
    ints, or yield nan/inf with a RuntimeWarning for NumPy integers;
    reporting 0.0 instead matches scikit-learn's ``zero_division=0``
    convention.
    """
    return numerator / denominator if denominator else default


# NOTE(review): ClassificationModel appears to default to use_cuda=True; on a
# CPU-only machine pass use_cuda=False — confirm against the installed version.
model = ClassificationModel("deberta", "outputs/best_model")

# pandas treats the TSV's first row as the header by default. Assumes the
# columns are what eval_model expects (e.g. "text" and "labels") — TODO confirm.
dev_df = pd.read_csv("dev-0/dev.tsv", sep="\t")

result, model_outputs, wrong_predictions = model.eval_model(dev_df)
print(result)

# NOTE(review): presumably eval_model only reports tp/fp/tn/fn for two-label
# models (a KeyError here would indicate otherwise) — verify.
tp = result["tp"]
fp = result["fp"]
tn = result["tn"]
fn = result["fn"]

accuracy = _safe_div(tp + tn, tp + fp + tn + fn)
print(f"Accuracy: {accuracy}")

precision = _safe_div(tp, tp + fp)
print(f"Precision: {precision}")

recall = _safe_div(tp, tp + fn)
print(f"Recall: {recall}")

# 2*tp / (2*tp + fp + fn) equals the harmonic mean 2PR / (P + R) wherever the
# latter is defined, but needs only a single zero check.
f1 = _safe_div(2 * tp, 2 * tp + fp + fn)
print(f"F1-score: {f1}")