ium_444409/train_model.py
2022-05-06 22:09:36 +02:00

70 lines
1.7 KiB
Python

import argparse
import numpy as np
import pandas as pd
import torch
from sacred.observers import FileStorageObserver
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sacred import Experiment
from model import PlantsDataset, MLP, train, test
default_batch_size = 64
default_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"
def main(batch_size, epochs, _run):
    """Train the MLP on the plant power-generation data and save the weights.

    Args:
        batch_size: number of samples per dataloader batch.
        epochs: number of full passes over the training split.
        _run: sacred Run object, used to log the per-epoch loss.
    """
    print(f"Using {device} device")

    # Pre-split CSV files produced by an earlier pipeline stage.
    test_set = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    train_set = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
    loader_train = DataLoader(train_set, batch_size=batch_size)
    loader_test = DataLoader(test_set, batch_size=batch_size)

    # Peek at one batch to sanity-check tensor shapes before training.
    for features, targets in loader_train:
        print(features.shape, targets.shape)
        print(features, targets)
        break

    net = MLP()
    print(net)
    # NOTE(review): `device` is detected above but neither the model nor the
    # batches are moved to it here — presumably handled inside train()/test();
    # confirm, otherwise training always runs on CPU.
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}\n-------------------------------")
        train(loader_train, net, criterion, optimizer)
        last_loss = test(loader_test, net, criterion)
        # NOTE(review): this value comes from the test split but is logged
        # under the key 'training.loss' — verify the metric name is intended.
        _run.log_scalar('training.loss', last_loss, epoch)
    print("Done!")

    torch.save(net.state_dict(), './model_out')
    print("Model saved in ./model_out file.")
def setup_experiment():
    """Create the sacred Experiment and attach a file-storage observer.

    Returns:
        The configured Experiment; runs are recorded under ./sacred_runs.
    """
    experiment = Experiment('Predict power output for a given time')
    observer = FileStorageObserver('sacred_runs')
    experiment.observers.append(observer)
    return experiment
ex = setup_experiment()


@ex.config
def experiment_config():
    """Default hyperparameters for a sacred run (overridable from the CLI)."""
    # Reuse the module-level defaults instead of repeating the magic
    # numbers 64 and 5, so the two stay in sync.
    batch_size = default_batch_size
    epochs = default_epochs
@ex.automain
def run(batch_size, epochs, _run):
    """Sacred entry point: train the model, then register the saved
    weights file as a run artifact."""
    main(batch_size, epochs, _run)
    # main() has already written ./model_out via torch.save by this point.
    ex.add_artifact('model_out')