Fix training args

This commit is contained in:
nlitkowski 2021-06-22 15:55:04 +02:00
parent 415cad97e2
commit 16771b5293

16
main.py
View File

@ -12,8 +12,8 @@ EXP_FILE_NAME = "expected.tsv"
FILE_SEP = "\t"
IN_HEADER_FILE_NAME = "in-header.tsv"
OUT_HEADER_FILE_NAME = "out-header.tsv"
PT_MODEL_NAME = "bert-base-cased"
# PT_MODEL_NAME = "roberta-base"
# PT_MODEL_NAME = "bert-base-cased"
PT_MODEL_NAME = "roberta-base"
class CustomDataset(torch.utils.data.Dataset):
@ -60,11 +60,13 @@ def main(dirnames):
trainer = Trainer(
model=model,
args=TrainingArguments('./res'),
train_dataset=dataset,
num_train_epochs=5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
args=TrainingArguments(
output_dir='./res',
num_train_epochs=5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16
),
train_dataset=dataset
)
trainer.train()