mlflow fix2
This commit is contained in:
parent
aa20f3caa4
commit
15bee98414
19
train.py
19
train.py
@ -3,13 +3,14 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
import mlflow
|
||||
import mlflow.sklearn
|
||||
import mlflow.pytorch
|
||||
from mlflow.models.signature import infer_signature
|
||||
from urllib.parse import urlparse
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.WARN)
|
||||
@ -84,7 +85,7 @@ with mlflow.start_run() as run:
|
||||
|
||||
for i in range(epochs):
|
||||
i = i + 1
|
||||
y_pred = model(X_train)
|
||||
y_pred = model.forward(X_train)
|
||||
loss = loss_function(y_pred, y_train)
|
||||
final_losses.append(loss)
|
||||
|
||||
@ -101,7 +102,15 @@ with mlflow.start_run() as run:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
torch.save(model,"classificationn_model.pt")
|
||||
|
||||
# Infer model signature to log it
|
||||
signature = infer_signature(X_train.numpy(), model(X_train).detach().numpy())
|
||||
input_example = {"input": X_train[0].numpy().tolist()}
|
||||
|
||||
# Log model
|
||||
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
|
||||
if tracking_url_type_store != "file":
|
||||
mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example, registered_model_name="ClassificationModel")
|
||||
else:
|
||||
mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user