# paranormal-or-skeptic-ISI-p.../simple_transformers.py
# (35 lines, 926 B, Python)

from simpletransformers.classification import ClassificationModel, ClassificationArgs
import pandas as pd
import logging
import torch
# Show INFO-level messages from this script, but raise the threshold of the
# "transformers" logger so only its warnings and errors come through.
logging.basicConfig(level=logging.INFO)
transformer_logger = logging.getLogger(name="transformers")
transformer_logger.setLevel(level=logging.WARNING)
# Load the tab-separated train and dev splits, printing each right after it
# is read so the parsed frames can be eyeballed before training starts.
# NOTE(review): presumably each file holds a text column followed by a label
# column, as ClassificationModel expects — confirm against the data files.
# NOTE(review): read_csv's default quoting treats `"` specially; if the raw
# text contains unbalanced quotes, rows may be merged — check the row counts.
train_df = pd.read_csv("train/train.tsv", delimiter="\t")
print(train_df)

dev_df = pd.read_csv("dev-0/dev.tsv", delimiter="\t")
print(dev_df)
# Training configuration handed to ClassificationModel as a plain dict.
# Key order matches the original literal.
args = dict(
    train_batch_size=32,
    learning_rate=2e-5,
    # Evaluate on the dev set periodically while training.
    evaluate_during_training=True,
    # Checkpoint and evaluate on the same 1000-step cadence.
    save_steps=1000,
    evaluate_during_training_steps=1000,
    evaluate_during_training_verbose=True,
    overwrite_output_dir=True,
    save_eval_checkpoints=True,
    # NOTE(review): patience is presumably counted in evaluation rounds
    # (i.e. 5 x 1000 steps without improvement) — confirm in library docs.
    use_early_stopping=True,
    early_stopping_patience=5,
    num_train_epochs=3,
)
# Fine-tune DeBERTa-base as a sequence classifier on train_df, evaluating on
# dev_df during training according to `args`.
# Use the GPU only when one is actually present: a hard-coded use_cuda=True
# makes ClassificationModel raise a ValueError on CPU-only machines.
model = ClassificationModel(
    "deberta",
    "microsoft/deberta-base",
    use_cuda=torch.cuda.is_available(),
    args=args,
)
model.train_model(train_df, eval_df=dev_df)