diff --git a/main.py b/main.py index a65ff80..a68ee09 100644 --- a/main.py +++ b/main.py @@ -12,8 +12,8 @@ EXP_FILE_NAME = "expected.tsv" FILE_SEP = "\t" IN_HEADER_FILE_NAME = "in-header.tsv" OUT_HEADER_FILE_NAME = "out-header.tsv" -PT_MODEL_NAME = "bert-base-cased" -# PT_MODEL_NAME = "roberta-base" +# PT_MODEL_NAME = "bert-base-cased" +PT_MODEL_NAME = "roberta-base" class CustomDataset(torch.utils.data.Dataset): @@ -60,11 +60,13 @@ def main(dirnames): trainer = Trainer( model=model, - args=TrainingArguments('./res'), - train_dataset=dataset, - num_train_epochs=5, - per_device_train_batch_size=16, - per_device_eval_batch_size=16, + args=TrainingArguments( + output_dir='./res', + num_train_epochs=5, + per_device_train_batch_size=16, + per_device_eval_batch_size=16 + ), + train_dataset=dataset ) trainer.train()