Fix training args
This commit is contained in:
parent
415cad97e2
commit
16771b5293
16
main.py
16
main.py
@ -12,8 +12,8 @@ EXP_FILE_NAME = "expected.tsv"
|
|||||||
FILE_SEP = "\t"
|
FILE_SEP = "\t"
|
||||||
IN_HEADER_FILE_NAME = "in-header.tsv"
|
IN_HEADER_FILE_NAME = "in-header.tsv"
|
||||||
OUT_HEADER_FILE_NAME = "out-header.tsv"
|
OUT_HEADER_FILE_NAME = "out-header.tsv"
|
||||||
PT_MODEL_NAME = "bert-base-cased"
|
# PT_MODEL_NAME = "bert-base-cased"
|
||||||
# PT_MODEL_NAME = "roberta-base"
|
PT_MODEL_NAME = "roberta-base"
|
||||||
|
|
||||||
|
|
||||||
class CustomDataset(torch.utils.data.Dataset):
|
class CustomDataset(torch.utils.data.Dataset):
|
||||||
@ -60,11 +60,13 @@ def main(dirnames):
|
|||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=TrainingArguments('./res'),
|
args=TrainingArguments(
|
||||||
train_dataset=dataset,
|
output_dir='./res',
|
||||||
num_train_epochs=5,
|
num_train_epochs=5,
|
||||||
per_device_train_batch_size=16,
|
per_device_train_batch_size=16,
|
||||||
per_device_eval_batch_size=16,
|
per_device_eval_batch_size=16
|
||||||
|
),
|
||||||
|
train_dataset=dataset
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
Loading…
Reference in New Issue
Block a user