ium_444409/train_model.py
Marcin Kostrzewski d0ab5ae997
All checks were successful
s444409-evaluation/pipeline/head This commit looks good
s444409-training/pipeline/head This commit looks good
Move common methods and functions to model.py
2022-05-06 21:51:49 +02:00

69 lines
1.6 KiB
Python

import argparse
import numpy as np
import pandas as pd
import torch
from sacred.observers import FileStorageObserver
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sacred import Experiment
from model import PlantsDataset, MLP, train, test
# Default hyper-parameters for a training run.
default_batch_size, default_epochs = 64, 5

# Device name string: "cuda" when PyTorch reports an available CUDA device,
# otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def main(batch_size, epochs, lr=1e-3, model_path='./model_out'):
    """Train an MLP on the Plant 1 generation data and save its weights.

    Loads the train/test splits from ``data/``, prints the first training
    batch as a sanity check, trains for ``epochs`` epochs (evaluating on the
    test split after each one), then saves the model's ``state_dict``.

    Args:
        batch_size: Mini-batch size used by both the train and test DataLoaders.
        epochs: Number of train/test rounds to run.
        lr: Learning rate for the Adam optimizer (default keeps the original 1e-3).
        model_path: File the trained ``state_dict`` is written to
            (default keeps the original ``./model_out``).
    """
    print(f"Using {device} device")

    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Sanity check: show the shape and contents of the first training batch.
    # next(..., None) avoids the index-less enumerate/break loop and simply
    # skips the printout if the loader yields nothing.
    first_batch = next(iter(train_dataloader), None)
    if first_batch is not None:
        data, labels = first_batch
        print(data.shape, labels.shape)
        print(data, labels)

    # NOTE(review): the model is not moved to `device` here; whether that is
    # correct depends on how train()/test() in model.py place the batches - confirm.
    model = MLP()
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print("Done!")

    torch.save(model.state_dict(), model_path)
    print(f"Model saved in {model_path} file.")
def setup_experiment():
    """Build the Sacred experiment with a file-storage observer attached.

    The observer is a ``FileStorageObserver`` rooted at ``sacred_runs``.

    Returns:
        The configured ``Experiment`` instance.
    """
    experiment = Experiment('Predict power output for a given time')
    storage = FileStorageObserver('sacred_runs')
    experiment.observers.append(storage)
    return experiment


# Module-level experiment object; the @ex.config / @ex.automain decorators
# in this file register onto it.
ex = setup_experiment()
@ex.config
def experiment_config():
    """Sacred config scope: each local assigned here becomes a config entry.

    Values reuse the module-level defaults instead of repeating the literals,
    so the two cannot drift apart. They can be overridden from the command
    line, e.g. ``with batch_size=128``.
    """
    batch_size = default_batch_size  # mini-batch size for the train/test DataLoaders
    epochs = default_epochs  # number of training epochs
@ex.automain
def run(batch_size, epochs):
    """Sacred entry point: train with the injected config, then attach the weights.

    ``batch_size`` and ``epochs`` are injected by Sacred from the config scope.
    """
    main(batch_size=batch_size, epochs=epochs)
    # Register the saved weights file with the Sacred run.
    ex.add_artifact('model_out')