Fix training args
This commit is contained in:
parent
415cad97e2
commit
16771b5293
12
main.py
12
main.py
@ -12,8 +12,8 @@ EXP_FILE_NAME = "expected.tsv"
|
||||
FILE_SEP = "\t"
|
||||
IN_HEADER_FILE_NAME = "in-header.tsv"
|
||||
OUT_HEADER_FILE_NAME = "out-header.tsv"
|
||||
PT_MODEL_NAME = "bert-base-cased"
|
||||
# PT_MODEL_NAME = "roberta-base"
|
||||
# PT_MODEL_NAME = "bert-base-cased"
|
||||
PT_MODEL_NAME = "roberta-base"
|
||||
|
||||
|
||||
class CustomDataset(torch.utils.data.Dataset):
|
||||
@ -60,11 +60,13 @@ def main(dirnames):
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=TrainingArguments('./res'),
|
||||
train_dataset=dataset,
|
||||
args=TrainingArguments(
|
||||
output_dir='./res',
|
||||
num_train_epochs=5,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=16,
|
||||
per_device_eval_batch_size=16
|
||||
),
|
||||
train_dataset=dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
Loading…
Reference in New Issue
Block a user