mlflow fix2
This commit is contained in:
parent
aa20f3caa4
commit
15bee98414
31
train.py
31
train.py
@ -3,13 +3,14 @@ import torch.nn as nn
|
|||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.metrics import accuracy_score
|
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import mlflow
|
import mlflow
|
||||||
import mlflow.sklearn
|
import mlflow.pytorch
|
||||||
|
from mlflow.models.signature import infer_signature
|
||||||
|
from urllib.parse import urlparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.WARN)
|
logging.basicConfig(level=logging.WARN)
|
||||||
@ -27,7 +28,7 @@ bike = data.loc[:, ['Customer_Age', 'Customer_Gender', 'Country','State', 'Produ
|
|||||||
bikes = pd.get_dummies(bike, columns=['Country', 'State', 'Product_Category', 'Sub_Category', 'Customer_Gender'])
|
bikes = pd.get_dummies(bike, columns=['Country', 'State', 'Product_Category', 'Sub_Category', 'Customer_Gender'])
|
||||||
X = bikes.drop('Profit_Category', axis=1).values
|
X = bikes.drop('Profit_Category', axis=1).values
|
||||||
y = bikes['Profit_Category'].values
|
y = bikes['Profit_Category'].values
|
||||||
X_train, X_test, y_train, y_test=train_test_split(X,y,test_size=0.2,random_state=0)
|
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.2,random_state=0)
|
||||||
scaler = StandardScaler()
|
scaler = StandardScaler()
|
||||||
X = scaler.fit_transform(X)
|
X = scaler.fit_transform(X)
|
||||||
#### Tworzenie tensorów
|
#### Tworzenie tensorów
|
||||||
@ -35,10 +36,10 @@ X_train = X_train.astype(np.float32)
|
|||||||
X_test = X_test.astype(np.float32)
|
X_test = X_test.astype(np.float32)
|
||||||
y_train = y_train.astype(np.float32)
|
y_train = y_train.astype(np.float32)
|
||||||
y_test = y_test.astype(np.float32)
|
y_test = y_test.astype(np.float32)
|
||||||
X_train=torch.FloatTensor(X_train)
|
X_train = torch.FloatTensor(X_train)
|
||||||
X_test=torch.FloatTensor(X_test)
|
X_test = torch.FloatTensor(X_test)
|
||||||
y_train=torch.LongTensor(y_train)
|
y_train = torch.LongTensor(y_train)
|
||||||
y_test=torch.LongTensor(y_test)
|
y_test = torch.LongTensor(y_test)
|
||||||
|
|
||||||
#### Model
|
#### Model
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ class ANN_Model(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
torch.manual_seed(20)
|
torch.manual_seed(20)
|
||||||
model=ANN_Model()
|
model = ANN_Model()
|
||||||
|
|
||||||
def calculate_accuracy(model, X, y):
|
def calculate_accuracy(model, X, y):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -84,7 +85,7 @@ with mlflow.start_run() as run:
|
|||||||
|
|
||||||
for i in range(epochs):
|
for i in range(epochs):
|
||||||
i = i + 1
|
i = i + 1
|
||||||
y_pred = model(X_train)
|
y_pred = model.forward(X_train)
|
||||||
loss = loss_function(y_pred, y_train)
|
loss = loss_function(y_pred, y_train)
|
||||||
final_losses.append(loss)
|
final_losses.append(loss)
|
||||||
|
|
||||||
@ -101,7 +102,15 @@ with mlflow.start_run() as run:
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
torch.save(model,"classificationn_model.pt")
|
# Infer model signature to log it
|
||||||
|
signature = infer_signature(X_train.numpy(), model(X_train).detach().numpy())
|
||||||
|
input_example = {"input": X_train[0].numpy().tolist()}
|
||||||
|
|
||||||
|
# Log model
|
||||||
|
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
|
||||||
|
if tracking_url_type_store != "file":
|
||||||
|
mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example, registered_model_name="ClassificationModel")
|
||||||
|
else:
|
||||||
|
mlflow.pytorch.log_model(model, "model", signature=signature, input_example=input_example)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user