ium_444409/train_model.py

69 lines
1.6 KiB
Python
Raw Normal View History

2022-05-06 20:20:22 +02:00
import argparse
2022-04-24 22:20:14 +02:00
import numpy as np
import pandas as pd
import torch
2022-05-06 21:37:04 +02:00
from sacred.observers import FileStorageObserver
2022-04-24 22:20:14 +02:00
from torch import nn
from torch.utils.data import DataLoader, Dataset
2022-05-06 20:43:53 +02:00
from sacred import Experiment
2022-04-24 22:20:14 +02:00
from model import PlantsDataset, MLP, train, test
2022-05-05 22:11:32 +02:00
# Default hyper-parameter values: batch size and number of training epochs.
default_batch_size, default_epochs = 64, 5

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
2022-05-06 20:43:53 +02:00
def main(batch_size, epochs, *, model_path='./model_out', lr=1e-3):
    """Train the MLP on the Plant 1 generation data and save its weights.

    Loads the ``.test`` / ``.train`` splits of
    ``data/Plant_1_Generation_Data.csv`` through ``PlantsDataset`` and prints
    one sample batch and the model architecture. It then runs ``epochs``
    rounds of ``train`` followed by ``test``, using an MSE loss and the Adam
    optimizer.

    Args:
        batch_size: Number of samples per batch for both data loaders.
        epochs: Number of train/test rounds to run.
        model_path: File the trained ``state_dict`` is written to with
            ``torch.save`` (keyword-only; defaults to the historical path).
        lr: Learning rate for Adam (keyword-only).
    """
    print(f"Using {device} device")
    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Sanity check: show the shape and contents of the first training batch
    # (prints nothing if the training loader is empty).
    for data, labels in train_dataloader:
        print(data.shape, labels.shape)
        print(data, labels)
        break

    # Actually place the model on the device announced above; previously it
    # always stayed on the CPU. NOTE(review): this assumes model.train/test
    # move each batch to the same ``device`` (as the PyTorch quickstart
    # pattern does) — confirm in model.py.
    model = MLP().to(device)
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print("Done!")

    torch.save(model.state_dict(), model_path)
    print(f"Model saved in {model_path} file.")
2022-04-24 22:23:53 +02:00
2022-05-06 20:20:22 +02:00
2022-05-06 21:37:04 +02:00
def setup_experiment():
    """Build the Sacred experiment used by this script.

    A ``FileStorageObserver`` rooted at ``sacred_runs`` is attached so each
    run is recorded on disk.
    """
    experiment = Experiment('Predict power output for a given time')
    storage = FileStorageObserver('sacred_runs')
    experiment.observers.append(storage)
    return experiment


# Created at import time so the @ex.config / @ex.automain decorators further
# down the file can register against it.
ex = setup_experiment()
2022-05-06 20:43:53 +02:00
@ex.config
def experiment_config():
    # Sacred evaluates this config scope with the module globals, so the
    # values come from the module-level defaults instead of duplicated literals.
    batch_size = default_batch_size  # samples per train/test batch
    epochs = default_epochs  # number of train/test rounds
@ex.automain
def run(batch_size, epochs):
    """Sacred entry point: train with the configured hyper-parameters.

    ``batch_size`` and ``epochs`` are injected by Sacred from the experiment
    config (or command-line overrides). After training, the saved weights file
    ``model_out`` is attached to the run as an artifact.
    """
    main(batch_size, epochs)
    # Register the artifact while the run is still active. When the file is
    # executed as a script, ``automain`` starts the experiment as soon as this
    # decorator is applied. Module-level code placed after it therefore runs
    # only after the run has ended, and on plain import it runs with no
    # current run at all, which fails Sacred's "during a run" assertion.
    ex.add_artifact('model_out')