72 lines
1.9 KiB
Python
72 lines
1.9 KiB
Python
import argparse
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
from sacred.observers import FileStorageObserver, MongoObserver
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from sacred import Experiment
|
|
|
|
from model import PlantsDataset, MLP, train, test
|
|
|
|
# Default hyper-parameter values for training.
default_batch_size, default_epochs = 64, 5

# Compute device: CUDA when PyTorch reports an available GPU, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
|
def main(batch_size, epochs, _run, model_path='./model_out'):
    """Train the MLP on the Plant 1 generation data and save its weights.

    Args:
        batch_size: Number of samples per batch for both data loaders.
        epochs: Number of passes over the training set.
        _run: Active Sacred run; used to log the per-epoch loss via
            ``log_scalar``.
        model_path: File the trained ``state_dict`` is written to. Defaults
            to the previously hard-coded ``./model_out``.

    Returns:
        The value returned by ``test`` after the final epoch, or ``None``
        when ``epochs`` is 0.
    """
    print(f"Using {device} device")

    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Sanity-check the shapes and values of the first training batch.
    # The None default avoids StopIteration when the dataset is empty.
    first_batch = next(iter(train_dataloader), None)
    if first_batch is not None:
        data, labels = first_batch
        print(data.shape, labels.shape)
        print(data, labels)

    # Place the model on the device announced above. Do this before building
    # the optimizer so it holds the moved parameters.
    # NOTE(review): assumes train/test in model.py move each batch to the
    # same `device` (as in the PyTorch quickstart) — confirm.
    model = MLP().to(device)
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Initialised so the return value is defined even when epochs == 0.
    last_loss = None
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        last_loss = test(test_dataloader, model, loss_fn)
        # NOTE(review): this is the loss that `test` reports on the test
        # loader, although the metric key says "training". The key is kept
        # as-is so existing run records stay comparable.
        _run.log_scalar('training.loss', last_loss, t)
    print("Done!")

    torch.save(model.state_dict(), model_path)
    print(f"Model saved in {model_path} file.")
    return last_loss
|
|
|
|
|
|
def setup_experiment(mongo_url=None, runs_dir='sacred_runs'):
    """Create the Sacred experiment and attach its observers.

    A ``FileStorageObserver`` is always attached. A ``MongoObserver`` is added
    only when ``mongo_url`` is given. Supply that URL (including any
    credentials) at runtime, e.g. from an environment variable, rather than
    hard-coding it in source.

    Args:
        mongo_url: Optional MongoDB connection URL. When set, runs are also
            recorded to the ``sacred`` database.
        runs_dir: Directory for the ``FileStorageObserver``.

    Returns:
        The configured ``Experiment``.
    """
    ex = Experiment('Predict power output for a given time')
    ex.observers.append(FileStorageObserver(runs_dir))
    if mongo_url is not None:
        ex.observers.append(MongoObserver(url=mongo_url, db_name='sacred'))
    return ex
|
|
|
|
|
|
# Module-level experiment instance; the decorated functions below register on it.
ex = setup_experiment()


# Sacred config scope: each local assigned here becomes a config entry that
# Sacred injects by name into captured functions such as `run`. The values
# come from the module-level defaults, so the numbers live in one place.
@ex.config
def experiment_config():
    batch_size = default_batch_size  # samples per batch for both data loaders
    epochs = default_epochs  # number of training epochs
|
|
|
|
|
|
@ex.automain
def run(batch_size, epochs, _run):
    """Sacred entry point: train via ``main`` and attach the saved model.

    ``@ex.automain`` starts the experiment while this decorator executes
    (when the file is run as a script). Module-level code placed after this
    function would therefore only run once the run has already finished.
    For that reason the model file is registered as an artifact here, on the
    still-active ``_run``.

    Args:
        batch_size: Injected by Sacred from the experiment config.
        epochs: Injected by Sacred from the experiment config.
        _run: The active Sacred run, injected by Sacred.

    Returns:
        Whatever ``main`` returns; Sacred records it as the run's result.
    """
    result = main(batch_size, epochs, _run)
    # `main` writes the weights to ./model_out by default; attach that file.
    _run.add_artifact('model_out')
    return result
|