ium_444409/train_model.py
emkarcinos 1a44512224
All checks were successful
s444409-training/pipeline/head This commit looks good
s444409-evaluation/pipeline/head This commit looks good
Enable Mongo
2022-05-09 10:03:04 +02:00

72 lines
1.9 KiB
Python

import argparse
import numpy as np
import pandas as pd
import torch
from sacred.observers import FileStorageObserver, MongoObserver
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sacred import Experiment
from model import PlantsDataset, MLP, train, test
# Module-level fallback hyper-parameters for a training run.
default_batch_size = 64  # samples per DataLoader batch
default_epochs = 5  # passes over the training split

# Preferred compute device name: "cuda" when PyTorch reports an available GPU,
# otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def _make_dataloaders(batch_size):
    """Return ``(train_dataloader, test_dataloader)`` for the Plant 1 CSV splits."""
    # Datasets are constructed in the same order as before (test, then train).
    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)
    return train_dataloader, test_dataloader


def _preview_first_batch(dataloader):
    """Print the shapes and contents of the first batch, if the loader yields one."""
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    data, labels = first_batch
    print(data.shape, labels.shape)
    print(data, labels)


def main(batch_size, epochs, _run, lr=1e-3, model_path='./model_out'):
    """Train the MLP on the Plant 1 generation data and save its weights.

    Args:
        batch_size: Number of samples per batch for both DataLoaders.
        epochs: Number of train/test iterations to run.
        _run: The active sacred ``Run``; the loss returned by ``test`` is
            logged to it once per epoch via ``log_scalar``.
        lr: Learning rate for the Adam optimizer (defaults to the value that
            was previously hard-coded, ``1e-3``).
        model_path: Destination for ``torch.save`` of the model's
            ``state_dict`` (defaults to ``'./model_out'``).
    """
    print(f"Using {device} device")
    # NOTE(review): the model is never moved to `device` here; whether
    # train()/test() place tensors on it is not visible in this file — confirm
    # before assuming GPU training.

    train_dataloader, test_dataloader = _make_dataloaders(batch_size)
    _preview_first_batch(train_dataloader)

    model = MLP()
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        last_loss = test(test_dataloader, model, loss_fn)
        # NOTE(review): this value is returned by test() on the *test* split;
        # the metric key 'training.loss' is kept unchanged so existing runs and
        # downstream consumers remain comparable.
        _run.log_scalar('training.loss', last_loss, t)

    print("Done!")
    torch.save(model.state_dict(), model_path)
    print(f"Model saved in {model_path} file.")
# Fallback MongoDB connection string, used only when neither an explicit URL
# nor the SACRED_MONGO_URL environment variable is provided.
# SECURITY(review): this embeds credentials in source control; prefer supplying
# SACRED_MONGO_URL from the CI secret store and removing this fallback.
_DEFAULT_MONGO_URL = 'mongodb://admin:IUM_2021@172.17.0.1:27017'


def setup_experiment(name='Predict power output for a given time',
                     file_storage_dir='sacred_runs',
                     mongo_url=None,
                     mongo_db='sacred'):
    """Create the sacred ``Experiment`` with file-storage and MongoDB observers.

    Args:
        name: Name passed to ``Experiment``.
        file_storage_dir: Directory handed to ``FileStorageObserver``.
        mongo_url: MongoDB connection URL. When ``None``, the
            ``SACRED_MONGO_URL`` environment variable is used if set and
            non-empty, otherwise ``_DEFAULT_MONGO_URL``.
        mongo_db: Database name handed to ``MongoObserver``.

    Returns:
        The configured ``Experiment``.
    """
    import os  # local import keeps this block self-contained

    if mongo_url is None:
        mongo_url = os.environ.get('SACRED_MONGO_URL') or _DEFAULT_MONGO_URL

    ex = Experiment(name)
    ex.observers.append(FileStorageObserver(file_storage_dir))
    ex.observers.append(MongoObserver(url=mongo_url, db_name=mongo_db))
    return ex


ex = setup_experiment()
@ex.config
def experiment_config():
    # Config entries reuse the module-level defaults so the two cannot drift.
    batch_size = default_batch_size  # samples per DataLoader batch
    epochs = default_epochs  # number of training epochs
@ex.automain
def run(batch_size, epochs, _run):
    """Sacred entry point: train via ``main`` and attach the saved weights.

    ``batch_size`` and ``epochs`` are injected from the experiment config and
    ``_run`` is the active sacred run.
    """
    main(batch_size=batch_size, epochs=epochs, _run=_run)
    # Attach the weights file written by main() to this run as an artifact.
    ex.add_artifact('model_out')