"""Train the power-plant MLP and record the run in MLflow.

Run as a script. The datasets are read from
``data/Plant_1_Generation_Data.csv.train`` and
``data/Plant_1_Generation_Data.csv.test``, the model is trained for the
requested number of epochs, its ``state_dict`` is written to
``./model_out`` and the model is logged to the MLflow tracking server
configured below.
"""
import argparse
from urllib.parse import urlparse

import mlflow
import torch
from torch import nn
from torch.utils.data import DataLoader

from model import PlantsDataset, MLP, train, test

default_batch_size = 64
default_epochs = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

# Tracking configuration is applied at import time, before any run starts.
mlflow.set_tracking_uri("http://172.17.0.1:5000")
mlflow.set_experiment("s444409")


def setup_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: ``batchSize`` (``-b``/``--batchSize``, int,
        defaults to ``default_batch_size``) and ``epochs``
        (``-e``/``--epochs``, int, defaults to ``default_epochs``).
    """
    args_parser = argparse.ArgumentParser(prefix_chars='-')
    args_parser.add_argument('-b', '--batchSize', type=int, default=default_batch_size)
    args_parser.add_argument('-e', '--epochs', type=int, default=default_epochs)
    return args_parser.parse_args()


def main():
    """Train the model, save its weights and log everything to MLflow."""
    args = setup_args()
    batch_size = args.batchSize
    epochs = args.epochs

    # NOTE(review): `device` is only reported; nothing in this script moves
    # the model or the tensors to it - confirm whether model.train/test do.
    print(f"Using {device} device")

    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Show the first training batch only, as a quick sanity check.
    for data, labels in train_dataloader:
        print(data.shape, labels.shape)
        print(data, labels)
        break

    model = MLP()
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    with mlflow.start_run() as run:
        print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
        print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))

        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("epochs", epochs)

        for t in range(epochs):
            print(f"Epoch {t + 1}\n-------------------------------")
            train(train_dataloader, model, loss_fn, optimizer)
            last_loss = test(test_dataloader, model, loss_fn)
            # NOTE(review): stored under "rmse" although loss_fn is
            # nn.MSELoss - confirm what model.test actually returns.
            mlflow.log_metric("rmse", last_loss)
        print("Done!")

        # Persist the weights only after training has finished; saving
        # before the epoch loop would store the untrained initial weights.
        torch.save(model.state_dict(), './model_out')
        print("Model saved in ./model_out file.")

        # Predictions on the test inputs are used solely to infer the
        # input/output schema stored alongside the logged model.
        with torch.no_grad():
            preds = model(plant_test.x_train)
        signature = mlflow.models.signature.infer_signature(
            plant_test.x_train.numpy(), preds.numpy()
        )

        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
        if tracking_url_type_store != "file":
            # The model is registered by name only when the tracking URI
            # does not use the "file" scheme.
            mlflow.pytorch.log_model(
                model,
                "s444409-power-plant-model",
                registered_model_name="s444409PowerPlant",
                signature=signature,
            )
        else:
            mlflow.pytorch.log_model(model, "s444409-power-plant-model", signature=signature)


if __name__ == "__main__":
    main()