import argparse
import os

import numpy as np
import pandas as pd
import torch
from sacred import Experiment
from sacred.observers import FileStorageObserver, MongoObserver
from torch import nn
from torch.utils.data import DataLoader, Dataset

from model import PlantsDataset, MLP, train, test

default_batch_size = 64
default_epochs = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

# File the trained weights are written to; also attached to the Sacred run.
MODEL_OUT_PATH = './model_out'

# NOTE(review): this fallback URL embeds a plaintext MongoDB password. It is
# kept only so existing setups keep working; prefer setting SACRED_MONGO_URL
# in the environment and removing the literal from source control.
_DEFAULT_MONGO_URL = 'mongodb://admin:IUM_2021@172.17.0.1:27017'


def _preview_first_batch(dataloader):
    """Print the shapes and contents of the first batch of *dataloader*.

    Prints nothing when the loader yields no batches.
    """
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    data, labels = first_batch
    print(data.shape, labels.shape)
    print(data, labels)


def main(batch_size, epochs, _run):
    """Train an MLP on the plant generation data and save its weights.

    Args:
        batch_size: mini-batch size used for both the train and test loaders.
        epochs: number of passes over the training data.
        _run: the active Sacred run; one 'training.loss' scalar is logged
            per epoch.
    """
    print(f"Using {device} device")

    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    _preview_first_batch(train_dataloader)

    # NOTE(review): the model is never moved to `device`; confirm that
    # model.train/test keep their inputs on the same device as the model.
    model = MLP()
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        # The value returned by model.test is logged under 'training.loss';
        # presumably it is the loss on the test split -- confirm in model.py.
        last_loss = test(test_dataloader, model, loss_fn)
        _run.log_scalar('training.loss', last_loss, t)
    print("Done!")

    torch.save(model.state_dict(), MODEL_OUT_PATH)
    print(f"Model saved in {MODEL_OUT_PATH} file.")


def setup_experiment():
    """Create the Sacred experiment with file-system and MongoDB observers.

    The MongoDB URL is taken from the SACRED_MONGO_URL environment variable
    when set, otherwise from the historical hard-coded default.
    """
    ex = Experiment('Predict power output for a given time')
    ex.observers.append(FileStorageObserver('sacred_runs'))
    mongo_url = os.environ.get('SACRED_MONGO_URL', _DEFAULT_MONGO_URL)
    ex.observers.append(MongoObserver(url=mongo_url, db_name='sacred'))
    return ex


ex = setup_experiment()


@ex.config
def experiment_config():
    # Sacred turns each local assignment below into a config entry; reading
    # the module constants keeps the defaults defined in a single place.
    batch_size = default_batch_size  # mini-batch size for train/test loaders
    epochs = default_epochs  # number of training epochs


@ex.automain
def run(batch_size, epochs, _run):
    """Train the model, save it, and attach the weights file as an artifact."""
    main(batch_size, epochs, _run)
    ex.add_artifact(MODEL_OUT_PATH)