mlflow fix2

This commit is contained in:
Witold Woch 2023-05-14 23:17:15 +02:00
parent aa20f3caa4
commit 15bee98414

View File

@ -3,13 +3,14 @@ import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
import os
import mlflow
import mlflow.sklearn
import mlflow.pytorch
from mlflow.models.signature import infer_signature
from urllib.parse import urlparse
import logging
logging.basicConfig(level=logging.WARN)
@ -84,7 +85,7 @@ with mlflow.start_run() as run:
for i in range(epochs):
i = i + 1
y_pred = model(X_train)
y_pred = model.forward(X_train)
loss = loss_function(y_pred, y_train)
final_losses.append(loss)
@ -101,7 +102,15 @@ with mlflow.start_run() as run:
loss.backward()
optimizer.step()
torch.save(model,"classificationn_model.pt")
# Infer model signature to log it
signature = infer_signature(X_train.numpy(), model(X_train).detach().numpy())
input_example = {"input": X_train[0].numpy().tolist()}
# Log model
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
if tracking_url_type_store != "file":
mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example, registered_model_name="ClassificationModel")
else:
mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example)