Fix training args

This commit is contained in:
nlitkowski 2021-06-22 15:55:04 +02:00
parent 415cad97e2
commit 16771b5293

16
main.py
View File

@ -12,8 +12,8 @@ EXP_FILE_NAME = "expected.tsv"
FILE_SEP = "\t" FILE_SEP = "\t"
IN_HEADER_FILE_NAME = "in-header.tsv" IN_HEADER_FILE_NAME = "in-header.tsv"
OUT_HEADER_FILE_NAME = "out-header.tsv" OUT_HEADER_FILE_NAME = "out-header.tsv"
PT_MODEL_NAME = "bert-base-cased" # PT_MODEL_NAME = "bert-base-cased"
# PT_MODEL_NAME = "roberta-base" PT_MODEL_NAME = "roberta-base"
class CustomDataset(torch.utils.data.Dataset): class CustomDataset(torch.utils.data.Dataset):
@ -60,11 +60,13 @@ def main(dirnames):
trainer = Trainer( trainer = Trainer(
model=model, model=model,
args=TrainingArguments('./res'), args=TrainingArguments(
train_dataset=dataset, output_dir='./res',
num_train_epochs=5, num_train_epochs=5,
per_device_train_batch_size=16, per_device_train_batch_size=16,
per_device_eval_batch_size=16, per_device_eval_batch_size=16
),
train_dataset=dataset
) )
trainer.train() trainer.train()