# ium_444409/train_model.py
# (Captured from a web file view: 84 lines, 2.7 KiB, Python — "Raw | Normal View | History";
#  blame timestamp of the first hunk: 2022-04-24 22:20:14 +02:00.)
import argparse
import os
from urllib.parse import urlparse

import mlflow
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from model import PlantsDataset, MLP, train, test
# Default hyper-parameters; setup_args exposes them as -b/--batchSize and -e/--epochs.
default_batch_size = 64
default_epochs = 5

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# MLflow tracking server. A non-empty MLFLOW_TRACKING_URI environment variable takes
# precedence; calling set_tracking_uri with a hard-coded value would otherwise override it.
# The fallback 172.17.0.1 is presumably the Docker host as seen from a container on the
# default bridge network — confirm for your deployment.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI") or "http://172.17.0.1:5000")
mlflow.set_experiment("s444409")
def _positive_int(text):
    """Argparse ``type=`` callback: parse *text* as an int and require it to be >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        # 0 epochs would leave the run without any logged metric;
        # a non-positive batch size is rejected by DataLoader.
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def setup_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``batchSize`` and ``epochs``. Explicitly passed
        values are validated to be positive integers; when omitted, they fall back
        to ``default_batch_size`` / ``default_epochs``.
    """
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('-b', '--batchSize', type=_positive_int, default=default_batch_size)
    args_parser.add_argument('-e', '--epochs', type=_positive_int, default=default_epochs)
    return args_parser.parse_args(argv)
if __name__ == "__main__":
    args = setup_args()
    batch_size = args.batchSize
    epochs = args.epochs

    # NOTE(review): `device` is only reported here; the model is never moved to it, so
    # training presumably runs on the CPU unless `train`/`test` handle placement — confirm.
    print(f"Using {device} device")
    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')

    # A batch of one (the first row of the test features), stored with the logged
    # model as an example of its expected input.
    input_example = np.array([plant_test.x_train.numpy()[0]])

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Sanity check: show the shapes and contents of the first training batch (if any).
    first_batch = next(iter(train_dataloader), None)
    if first_batch is not None:
        data, labels = first_batch
        print(data.shape, labels.shape)
        print(data, labels)

    model = MLP()
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    with mlflow.start_run() as run:
        print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
        print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("epochs", epochs)

        for t in range(epochs):
            print(f"Epoch {t + 1}\n-------------------------------")
            train(train_dataloader, model, loss_fn, optimizer)
            # Logged every epoch (step=t) so the run keeps the whole loss curve; the
            # latest value is the final test loss. This also avoids referencing an
            # unassigned variable when epochs == 0.
            # NOTE(review): loss_fn is MSELoss, so unless `test` takes a square root this
            # value is an MSE rather than an RMSE — confirm before renaming the metric.
            mlflow.log_metric("rmse", test(test_dataloader, model, loss_fn), step=t)
        print("Done!")

        # Save only after training; the weights on disk must be the trained ones.
        torch.save(model.state_dict(), './model_out')
        print("Model saved in ./model_out file.")

        # Infer the model signature from predictions on the full test feature set.
        model.eval()
        with torch.no_grad():
            preds = model(plant_test.x_train)
        signature = mlflow.models.signature.infer_signature(plant_test.x_train.numpy(), preds.numpy())

        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
        if tracking_url_type_store != "file":
            # Non-file tracking store: also register the model under "s444409"
            # (file stores do not support the Model Registry).
            mlflow.pytorch.log_model(
                model,
                "s444409",
                registered_model_name="s444409",
                signature=signature,
                input_example=input_example,
            )
        else:
            mlflow.pytorch.log_model(model, "s444409", signature=signature, input_example=input_example)