84 lines
2.7 KiB
Python
84 lines
2.7 KiB
Python
import torch
|
|
import mlflow
|
|
import argparse
|
|
import numpy as np
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader
|
|
from urllib.parse import urlparse
|
|
|
|
from model import PlantsDataset, MLP, train, test
|
|
|
|
# Fallback values for the -b/--batchSize and -e/--epochs flags of setup_args().
default_batch_size = 64
default_epochs = 5

# Name of the preferred compute device: "cuda" when PyTorch reports a usable
# GPU, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Uncomment to send runs to a remote MLflow tracking server instead of the
# default tracking URI.
# mlflow.set_tracking_uri("http://172.17.0.1:5000")

# Runs started later in this script are recorded under this experiment name.
mlflow.set_experiment("s444409")
|
|
|
|
|
|
def setup_args():
|
|
args_parser = argparse.ArgumentParser(prefix_chars='-')
|
|
args_parser.add_argument('-b', '--batchSize', type=int, default=default_batch_size)
|
|
args_parser.add_argument('-e', '--epochs', type=int, default=default_epochs)
|
|
|
|
return args_parser.parse_args()
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: parse CLI options, build datasets and the model, train and
    # evaluate inside an MLflow run, then save and log the trained model.
    args = setup_args()
    batch_size = args.batchSize
    epochs = args.epochs

    # NOTE(review): `device` is only reported here; the model below is never
    # moved to it (no `.to(device)`), so whether training uses the GPU depends
    # on what `train`/`test` in model.py do — confirm before relying on CUDA.
    print(f"Using {device} device")

    plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
    plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')

    # First test row wrapped in an extra leading dimension (a batch of one);
    # stored alongside the logged model as an example input.
    input_example = np.array([plant_test.x_train.numpy()[0]])

    train_dataloader = DataLoader(plant_train, batch_size=batch_size)
    test_dataloader = DataLoader(plant_test, batch_size=batch_size)

    # Sanity check: show the shapes and contents of the first training batch
    # (skipped silently if the training set is empty).
    first_batch = next(iter(train_dataloader), None)
    if first_batch is not None:
        data, labels = first_batch
        print(data.shape, labels.shape)
        print(data, labels)

    model = MLP()
    print(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    with mlflow.start_run() as run:
        print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
        print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))

        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("epochs", epochs)

        # Stays None when epochs == 0, so nothing is logged instead of
        # raising NameError on an unbound variable.
        last_loss = None
        for t in range(epochs):
            print(f"Epoch {t + 1}\n-------------------------------")
            train(train_dataloader, model, loss_fn, optimizer)
            last_loss = test(test_dataloader, model, loss_fn)
            # Logged every epoch (step=t) so the learning curve is visible;
            # the last point is the final-epoch value.
            # NOTE(review): loss_fn is MSELoss, so unless `test` takes a square
            # root this value is MSE rather than RMSE — confirm the metric name.
            mlflow.log_metric("rmse", last_loss, step=t)

        print("Done!")

        # Saved only after training so the file holds the trained weights.
        torch.save(model.state_dict(), './model_out')
        print("Model saved in ./model_out file.")

        # Inference mode for the example predictions used to infer the
        # model's input/output signature (no gradient tracking needed).
        model.eval()
        with torch.no_grad():
            preds = model(plant_test.x_train)
        signature = mlflow.models.signature.infer_signature(
            plant_test.x_train.numpy(), preds.numpy()
        )

        # Register the model only when the tracking URI is not a local
        # `file:` store, preserving the original branching (the MLflow model
        # registry has historically required a database-backed store).
        # `registered_model_name=None` is log_model's default, i.e. "do not
        # register".
        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
        registered_model_name = (
            "s444409" if tracking_url_type_store != "file" else None
        )
        mlflow.pytorch.log_model(
            model,
            "s444409",
            registered_model_name=registered_model_name,
            signature=signature,
            input_example=input_example,
        )
|