DeBERTa classifier

This commit is contained in:
MrPoldi 2022-06-19 18:32:53 +02:00
parent f5fa1779c9
commit a8125bba9d
2 changed files with 53 additions and 0 deletions

34
simple_transformers.py Normal file
View File

@ -0,0 +1,34 @@
from simpletransformers.classification import ClassificationModel, ClassificationArgs
import pandas as pd
import logging
import torch
logging.basicConfig(level=logging.INFO)
transformer_logger = logging.getLogger("transformers")
transformer_logger.setLevel(logging.WARNING)
train_df = pd.read_csv("train/train.tsv", sep="\t")
print(train_df)
dev_df = pd.read_csv("dev-0/dev.tsv", sep="\t")
print(dev_df)
args = {
'train_batch_size': 32,
'learning_rate': 2e-5,
'evaluate_during_training': True,
'save_steps': 1000,
'evaluate_during_training_steps': 1000,
'evaluate_during_training_verbose': True,
'overwrite_output_dir': True,
'save_eval_checkpoints': True,
'use_early_stopping': True,
'early_stopping_patience': 5,
'num_train_epochs': 3
}
model = ClassificationModel("deberta", "microsoft/deberta-base", use_cuda=True, args=args)
model.train_model(train_df, eval_df=dev_df)

View File

@ -0,0 +1,19 @@
from simpletransformers.classification import ClassificationModel
import pandas as pd
model = ClassificationModel("deberta", "outputs/best_model")
dev_df = pd.read_csv("dev-0/dev.tsv", sep="\t")
result, model_outputs, wrong_predictions = model.eval_model(dev_df)
print(result)
tp = result["tp"]
fp = result["fp"]
tn = result["tn"]
fn = result["fn"]
print(f"Accuracy: {(tp+tn)/(tp+fp+tn+fn)}")
precision = tp/(tp+fp)
print(f"Precision: {precision}")
recall = tp/(tp+fn)
print(f"Recall: {recall}")
print(f"F1-score: {2*precision*recall/(precision+recall)}")