mlflow fix2

This commit is contained in:
Witold Woch 2023-05-14 23:17:15 +02:00
parent aa20f3caa4
commit 15bee98414

View File

@ -3,13 +3,14 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torch.nn.functional as F import torch.nn.functional as F
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import os import os
import mlflow import mlflow
import mlflow.sklearn import mlflow.pytorch
from mlflow.models.signature import infer_signature
from urllib.parse import urlparse
import logging import logging
logging.basicConfig(level=logging.WARN) logging.basicConfig(level=logging.WARN)
@ -84,7 +85,7 @@ with mlflow.start_run() as run:
for i in range(epochs): for i in range(epochs):
i = i + 1 i = i + 1
y_pred = model(X_train) y_pred = model.forward(X_train)
loss = loss_function(y_pred, y_train) loss = loss_function(y_pred, y_train)
final_losses.append(loss) final_losses.append(loss)
@ -101,7 +102,15 @@ with mlflow.start_run() as run:
loss.backward() loss.backward()
optimizer.step() optimizer.step()
torch.save(model,"classificationn_model.pt") # Infer model signature to log it
signature = infer_signature(X_train.numpy(), model(X_train).detach().numpy())
input_example = {"input": X_train[0].numpy().tolist()}
# Log model
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
if tracking_url_type_store != "file":
mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example, registered_model_name="ClassificationModel")
else:
mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example)