diff --git a/main.py b/main.py index a68ee09..53367f4 100644 --- a/main.py +++ b/main.py @@ -57,19 +57,16 @@ def main(dirnames): [t for t in train_set_features[in_cols].agg(' '.join, axis=1)], truncation=True, padding=True) dataset = CustomDataset( train_set_enc, [int(t) for t in train_set_labels[out_cols[0]]]) - + device = torch.device("cuda") + model.to(device) trainer = Trainer( model=model, - args=TrainingArguments( - output_dir='./res', - num_train_epochs=5, - per_device_train_batch_size=16, - per_device_eval_batch_size=16 - ), + args=TrainingArguments("model"), train_dataset=dataset ) - + print("Starting training...") trainer.train() + print("Predictions...") for i in range(len(in_sets)): p = os.path.join(dirnames[i], IN_FILE_NAME)