From 16771b5293034500113da9c01cdd2c68d91b6884 Mon Sep 17 00:00:00 2001 From: nlitkowski Date: Tue, 22 Jun 2021 15:55:04 +0200 Subject: [PATCH] Fix training args --- main.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index a65ff80..a68ee09 100644 --- a/main.py +++ b/main.py @@ -12,8 +12,8 @@ EXP_FILE_NAME = "expected.tsv" FILE_SEP = "\t" IN_HEADER_FILE_NAME = "in-header.tsv" OUT_HEADER_FILE_NAME = "out-header.tsv" -PT_MODEL_NAME = "bert-base-cased" -# PT_MODEL_NAME = "roberta-base" +# PT_MODEL_NAME = "bert-base-cased" +PT_MODEL_NAME = "roberta-base" class CustomDataset(torch.utils.data.Dataset): @@ -60,11 +60,13 @@ def main(dirnames): trainer = Trainer( model=model, - args=TrainingArguments('./res'), - train_dataset=dataset, - num_train_epochs=5, - per_device_train_batch_size=16, - per_device_eval_batch_size=16, + args=TrainingArguments( + output_dir='./res', + num_train_epochs=5, + per_device_train_batch_size=16, + per_device_eval_batch_size=16 + ), + train_dataset=dataset ) trainer.train()