DeBERTa classifier
This commit is contained in:
parent
f5fa1779c9
commit
a8125bba9d
34
simple_transformers.py
Normal file
34
simple_transformers.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from simpletransformers.classification import ClassificationModel, ClassificationArgs
|
||||||
|
import pandas as pd
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
transformer_logger = logging.getLogger("transformers")
|
||||||
|
transformer_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
train_df = pd.read_csv("train/train.tsv", sep="\t")
|
||||||
|
print(train_df)
|
||||||
|
|
||||||
|
dev_df = pd.read_csv("dev-0/dev.tsv", sep="\t")
|
||||||
|
print(dev_df)
|
||||||
|
|
||||||
|
|
||||||
|
args = {
|
||||||
|
'train_batch_size': 32,
|
||||||
|
'learning_rate': 2e-5,
|
||||||
|
'evaluate_during_training': True,
|
||||||
|
'save_steps': 1000,
|
||||||
|
'evaluate_during_training_steps': 1000,
|
||||||
|
'evaluate_during_training_verbose': True,
|
||||||
|
'overwrite_output_dir': True,
|
||||||
|
'save_eval_checkpoints': True,
|
||||||
|
'use_early_stopping': True,
|
||||||
|
'early_stopping_patience': 5,
|
||||||
|
'num_train_epochs': 3
|
||||||
|
}
|
||||||
|
|
||||||
|
model = ClassificationModel("deberta", "microsoft/deberta-base", use_cuda=True, args=args)
|
||||||
|
|
||||||
|
model.train_model(train_df, eval_df=dev_df)
|
19
simple_transformers_eval.py
Normal file
19
simple_transformers_eval.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from simpletransformers.classification import ClassificationModel
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
model = ClassificationModel("deberta", "outputs/best_model")
|
||||||
|
|
||||||
|
dev_df = pd.read_csv("dev-0/dev.tsv", sep="\t")
|
||||||
|
|
||||||
|
result, model_outputs, wrong_predictions = model.eval_model(dev_df)
|
||||||
|
print(result)
|
||||||
|
tp = result["tp"]
|
||||||
|
fp = result["fp"]
|
||||||
|
tn = result["tn"]
|
||||||
|
fn = result["fn"]
|
||||||
|
print(f"Accuracy: {(tp+tn)/(tp+fp+tn+fn)}")
|
||||||
|
precision = tp/(tp+fp)
|
||||||
|
print(f"Precision: {precision}")
|
||||||
|
recall = tp/(tp+fn)
|
||||||
|
print(f"Recall: {recall}")
|
||||||
|
print(f"F1-score: {2*precision*recall/(precision+recall)}")
|
Loading…
Reference in New Issue
Block a user