This commit is contained in:
nlitkowski 2021-06-22 16:41:34 +02:00
parent 0249514499
commit c36ba1d489

13
main.py
View File

@ -57,19 +57,16 @@ def main(dirnames):
[t for t in train_set_features[in_cols].agg(' '.join, axis=1)], truncation=True, padding=True)
dataset = CustomDataset(
train_set_enc, [int(t) for t in train_set_labels[out_cols[0]]])
device = torch.device("cuda")
model.to(device)
trainer = Trainer(
model=model,
args=TrainingArguments(
output_dir='./res',
num_train_epochs=5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16
),
args=TrainingArguments("model"),
train_dataset=dataset
)
print("Starting training...")
trainer.train()
print("Predictions...")
for i in range(len(in_sets)):
p = os.path.join(dirnames[i], IN_FILE_NAME)