From c36ba1d48998878ac287c966805d222f79ee830d Mon Sep 17 00:00:00 2001 From: nlitkowski Date: Tue, 22 Jun 2021 16:41:34 +0200 Subject: [PATCH] Add GPU --- main.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index a68ee09..53367f4 100644 --- a/main.py +++ b/main.py @@ -57,19 +57,16 @@ def main(dirnames): [t for t in train_set_features[in_cols].agg(' '.join, axis=1)], truncation=True, padding=True) dataset = CustomDataset( train_set_enc, [int(t) for t in train_set_labels[out_cols[0]]]) - + device = torch.device("cuda") + model.to(device) trainer = Trainer( model=model, - args=TrainingArguments( - output_dir='./res', - num_train_epochs=5, - per_device_train_batch_size=16, - per_device_eval_batch_size=16 - ), + args=TrainingArguments("model"), train_dataset=dataset ) - + print("Starting training...") trainer.train() + print("Predictions...") for i in range(len(in_sets)): p = os.path.join(dirnames[i], IN_FILE_NAME)