ium_444409/train_model.py
Marcin Kostrzewski e93d644059
All checks were successful
s444409-evaluation/pipeline/head This commit looks good
s444409-training/pipeline/head This commit looks good
Log to url
2022-05-11 20:04:13 +02:00

88 lines
2.9 KiB
Python

import argparse
import os
from urllib.parse import urlparse

import mlflow
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from model import PlantsDataset, MLP, train, test
# Defaults for the command-line options parsed in setup_args.
default_batch_size = 64
default_epochs = 5

# "cuda" when PyTorch reports a usable GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

# MLflow tracking server and experiment. The hard-coded values stay the
# defaults; the standard MLflow environment variables override them so the
# script can log elsewhere without editing code.
# NOTE(review): 172.17.0.1 is typically Docker's default bridge gateway (the
# host as seen from a container) — presumably where the MLflow server runs.
default_tracking_uri = "http://172.17.0.1:5000"
default_experiment_name = "s444409"
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", default_tracking_uri))
mlflow.set_experiment(os.getenv("MLFLOW_EXPERIMENT_NAME", default_experiment_name))
def _positive_int(text: str) -> int:
    """Parse *text* as an integer and reject values below 1.

    Used as an argparse ``type=`` converter so that, e.g., ``--epochs 0``
    fails at parse time instead of silently skipping training.

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer >= 1.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def setup_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before.

    Returns:
        Namespace with ``batchSize`` and ``epochs``, both positive ints.
        Omitted options fall back to ``default_batch_size`` / ``default_epochs``.
    """
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument(
        '-b', '--batchSize',
        type=_positive_int,
        default=default_batch_size,
        help="mini-batch size for the train/test DataLoaders (default: %(default)s)",
    )
    args_parser.add_argument(
        '-e', '--epochs',
        type=_positive_int,
        default=default_epochs,
        help="number of training epochs (default: %(default)s)",
    )
    return args_parser.parse_args(argv)
def _log_model(model, plant_test, input_example):
    """Log *model* to the active MLflow run together with its I/O signature.

    The signature is inferred from the model's predictions on the whole
    ``plant_test.x_train`` tensor. When the tracking store is not a local
    ``file`` store, ``model.py`` is also uploaded as an artifact and the model
    is registered under the name ``s444409``.

    Args:
        model: Trained torch module to log.
        plant_test: Test dataset; its ``x_train`` tensor drives signature inference.
        input_example: Example model input stored alongside the logged model.
    """
    # no_grad: preds must not require grad, otherwise .numpy() raises.
    with torch.no_grad():
        preds = model(plant_test.x_train)
    signature = mlflow.models.signature.infer_signature(plant_test.x_train.numpy(), preds.numpy())

    tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
    if tracking_url_type_store != "file":
        mlflow.log_artifact('model.py')
        mlflow.pytorch.log_model(
            model,
            "s444409",
            registered_model_name="s444409",
            signature=signature,
            input_example=input_example,
            code_paths=['model.py'],
        )
    else:
        # Registration is skipped for a local file store (MLflow's model
        # registry requires a database-backed store).
        mlflow.pytorch.log_model(model, "s444409", signature=signature, input_example=input_example,
                                 code_paths=['model.py'])


def main():
    """Train the MLP, save its trained weights to ./model_out and log the run to MLflow."""
    args = setup_args()
    batch_size = args.batchSize
    epochs = args.epochs
    print(f"Using {device} device")

    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
    # One-row batch taken from the test features, stored as MLflow's input example.
    input_example = np.array([plant_test.x_train.numpy()[0]])

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Sanity check: show the first training batch (skipped if the loader is empty).
    first_batch = next(iter(train_dataloader), None)
    if first_batch is not None:
        data, labels = first_batch
        print(data.shape, labels.shape)
        print(data, labels)

    # NOTE(review): the model is never moved to `device` here; confirm that
    # train/test in model.py handle device placement consistently.
    model = MLP()
    print(model)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    with mlflow.start_run() as run:
        print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
        print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("epochs", epochs)

        for t in range(epochs):
            print(f"Epoch {t + 1}\n-------------------------------")
            train(train_dataloader, model, loss_fn, optimizer)
            last_loss = test(test_dataloader, model, loss_fn)
            # Logged every epoch so the run keeps a loss curve.
            # NOTE(review): loss_fn is MSELoss — confirm `test` returns RMSE
            # rather than MSE before relying on the "rmse" name.
            mlflow.log_metric("rmse", last_loss, step=t)
        print("Done!")

        # Save only after training, so ./model_out holds the trained weights
        # (the original saved the freshly initialised model before training).
        torch.save(model.state_dict(), './model_out')
        print("Model saved in ./model_out file.")

        _log_model(model, plant_test, input_example)


if __name__ == "__main__":
    main()