ium_426206/generate_MLmodel.py

43 lines
1.6 KiB
Python
Raw Normal View History

2021-05-22 23:16:29 +02:00
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader
import mlflow
import mlflow.pytorch
from urllib.parse import urlparse
from mlflow.models.signature import infer_signature
class LayerLinearRegression(nn.Module):
    """Single-feature linear regression implemented with one ``nn.Linear(1, 1)``.

    The submodule attribute is deliberately named ``linear``: ``load_state_dict``
    matches parameters by attribute name (``linear.weight``, ``linear.bias``),
    so renaming it would break loading previously saved state dicts.
    """

    def __init__(self):
        super(LayerLinearRegression, self).__init__()
        # A built-in linear layer (one input feature, one output) stands in
        # for hand-rolled weight/bias parameters.
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Return the layer's predictions for ``x`` (last dimension must be 1)."""
        prediction = self.linear(x)
        return prediction
# Rebuild the trained model from its checkpoint and log it to MLflow.
#
# map_location="cpu" lets this run on a host without the GPU the checkpoint
# may have been saved from.
checkpoint = torch.load('model.pt', map_location='cpu')

model = LayerLinearRegression()
model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode for signature inference and logging.
model.eval()

# NOTE(review): assumes train_dataset converts to a 2-D array whose column 0 is
# the model's single input feature (Sales Sum) — confirm against the training script.
train_dataset = torch.load('train_dataset.pt')
x_train = np.array(train_dataset)[:, 0]  # (Sales Sum row)

# The model consumes a (batch, 1) float32 tensor. Use exactly this array both as
# the logged input example and as the signature's input so the two agree.
input_example = np.reshape(x_train, (-1, 1)).astype(np.float32)

with torch.no_grad():
    predictions = model(torch.from_numpy(input_example)).numpy()
signature = infer_signature(input_example, predictions)

# NOTE(review): 172.17.0.1 is presumably the Docker host as seen from a container.
mlflow.set_tracking_uri("http://172.17.0.1:5000")
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme

# The model registry does not work with a file store, so only request
# registration when tracking against a server. The pytorch flavor is used
# because `model` is a torch nn.Module, not a scikit-learn estimator.
if tracking_url_type_store != "file":
    mlflow.pytorch.log_model(
        model,
        "model",
        registered_model_name="s426206",
        signature=signature,
        input_example=input_example,
    )
else:
    mlflow.pytorch.log_model(
        model,
        "model",
        signature=signature,
        input_example=input_example,
    )