Add GPU
This commit is contained in:
parent
0249514499
commit
c36ba1d489
13
main.py
13
main.py
@ -57,19 +57,16 @@ def main(dirnames):
|
||||
[t for t in train_set_features[in_cols].agg(' '.join, axis=1)], truncation=True, padding=True)
|
||||
dataset = CustomDataset(
|
||||
train_set_enc, [int(t) for t in train_set_labels[out_cols[0]]])
|
||||
|
||||
device = torch.device("cuda")
|
||||
model.to(device)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=TrainingArguments(
|
||||
output_dir='./res',
|
||||
num_train_epochs=5,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=16
|
||||
),
|
||||
args=TrainingArguments("model"),
|
||||
train_dataset=dataset
|
||||
)
|
||||
|
||||
print("Starting training...")
|
||||
trainer.train()
|
||||
print("Predictions...")
|
||||
|
||||
for i in range(len(in_sets)):
|
||||
p = os.path.join(dirnames[i], IN_FILE_NAME)
|
||||
|
Loading…
Reference in New Issue
Block a user