ium_444409/train_model.py
Marcin Kostrzewski 858e9ec215
All checks were successful
s444409-evaluation/pipeline/head This commit looks good
s444409-training/pipeline/head This commit looks good
Archive MLFlow artifacts
2022-05-09 18:23:12 +02:00

84 lines
2.7 KiB
Python

import torch
import mlflow
import argparse
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from urllib.parse import urlparse
from model import PlantsDataset, MLP, train, test
# Fallback values for the --batchSize / --epochs command-line options.
default_batch_size = 64
default_epochs = 5

# Prefer the GPU whenever PyTorch reports a usable CUDA device.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# No tracking URI is set here, so MLflow uses the MLFLOW_TRACKING_URI
# environment variable when present (otherwise its default local store).
# To always log to the shared tracking server, uncomment:
# mlflow.set_tracking_uri("http://172.17.0.1:5000")
mlflow.set_experiment("s444409")
def setup_args(argv=None):
    """Parse the training command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with integer attributes ``batchSize`` and
        ``epochs``.

    Both options must be positive: a batch size below 1 is rejected by
    ``DataLoader`` and zero epochs would train nothing, so the error is
    reported here, at parse time, with a clear message.
    """
    def positive_int(text):
        # argparse turns ValueError / ArgumentTypeError into a usage error.
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(
                f"expected a positive integer, got {text!r}")
        return value

    args_parser = argparse.ArgumentParser(
        description="Train the MLP on the plant generation data and log the run to MLflow.")
    args_parser.add_argument('-b', '--batchSize', type=positive_int, default=default_batch_size,
                             help="mini-batch size (default: %(default)s)")
    args_parser.add_argument('-e', '--epochs', type=positive_int, default=default_epochs,
                             help="number of training epochs (default: %(default)s)")
    return args_parser.parse_args(argv)
def main():
    """Train the MLP, save its weights and log the run to MLflow.

    Reads ``--batchSize`` / ``--epochs`` via ``setup_args``, then:
    - trains for the requested number of epochs;
    - logs the value returned by ``test`` after each epoch as the ``rmse``
      metric;
    - saves the *trained* weights to ``./model_out``;
    - logs the model with an inferred input/output signature.
    """
    args = setup_args()
    batch_size = args.batchSize
    epochs = args.epochs

    print(f"Using {device} device")

    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')

    # First test row with a leading axis of length 1 added; logged as the
    # model's input example.
    input_example = np.array([plant_test.x_train.numpy()[0]])

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Sanity check: show the shape and content of the first training batch
    # (prints nothing when the training set is empty).
    first_batch = next(iter(train_dataloader), None)
    if first_batch is not None:
        data, labels = first_batch
        print(data.shape, labels.shape)
        print(data, labels)

    # NOTE(review): `device` is only reported above; the model is never moved
    # with .to(device). Confirm how model.train/test place their batches
    # before adding it here.
    model = MLP()
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    with mlflow.start_run() as run:
        print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
        print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("epochs", epochs)

        for t in range(epochs):
            print(f"Epoch {t + 1}\n-------------------------------")
            train(train_dataloader, model, loss_fn, optimizer)
            test_loss = test(test_dataloader, model, loss_fn)
            # Logging every epoch (with its step) keeps the loss history; the
            # latest value is still the final epoch's loss.
            # NOTE(review): loss_fn is MSELoss - confirm `test` returns RMSE,
            # otherwise the "rmse" key actually holds MSE.
            mlflow.log_metric("rmse", test_loss, step=t)
        print("Done!")

        # Save only after training; previously the untrained initial weights
        # were written here before the loop ran.
        torch.save(model.state_dict(), './model_out')
        print("Model saved in ./model_out file.")

        with torch.no_grad():
            preds = model(plant_test.x_train)
        signature = mlflow.models.signature.infer_signature(plant_test.x_train.numpy(), preds.numpy())

        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
        # Registering a model needs a registry-capable tracking backend; the
        # plain local file store does not provide one, so there the model is
        # only logged.
        if tracking_url_type_store != "file":
            mlflow.pytorch.log_model(
                model,
                "s444409",
                registered_model_name="s444409",
                signature=signature,
                input_example=input_example,
            )
        else:
            mlflow.pytorch.log_model(model, "s444409", signature=signature, input_example=input_example)


if __name__ == "__main__":
    main()