from simpletransformers.classification import ClassificationModel, ClassificationArgs
import pandas as pd
import logging
import torch

# Keep our own logging at INFO, but silence the verbose transformers output.
logging.basicConfig(level=logging.INFO)
transformer_logger = logging.getLogger("transformers")
transformer_logger.setLevel(logging.WARNING)

# Load the tab-separated training and dev data.
train_df = pd.read_csv("train/train.tsv", sep="\t")
print(train_df)

dev_df = pd.read_csv("dev-0/dev.tsv", sep="\t")
print(dev_df)

# Training configuration: evaluate on the dev set every 1000 steps and stop
# early if the eval metric does not improve for 5 consecutive evaluations.
model_args = ClassificationArgs(
    train_batch_size=32,
    learning_rate=2e-5,
    evaluate_during_training=True,
    save_steps=1000,
    evaluate_during_training_steps=1000,
    evaluate_during_training_verbose=True,
    overwrite_output_dir=True,
    save_eval_checkpoints=True,
    use_early_stopping=True,
    early_stopping_patience=5,
    num_train_epochs=3,
)

# DeBERTa-base classifier; use the GPU when one is available.
model = ClassificationModel(
    "deberta",
    "microsoft/deberta-base",
    use_cuda=torch.cuda.is_available(),
    args=model_args,
)

# Fine-tune on the training set, evaluating on the dev set during training.
model.train_model(train_df, eval_df=dev_df)
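
# Illustrative follow-up (an assumption, not part of the original script):
# after fine-tuning, dev-set predictions could be produced with
# ClassificationModel.predict, which takes a list of raw texts. This sketch
# assumes the first column of dev.tsv holds the input text, the default
# layout simpletransformers expects when columns are not named "text"/"labels".
predictions, raw_outputs = model.predict(dev_df.iloc[:, 0].astype(str).tolist())
print(predictions[:10])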