Mlflow predict from s434704

This commit is contained in:
Jan Nowak 2021-05-23 19:47:55 +02:00
parent 44840aecf2
commit 39bec6d769
4 changed files with 15 additions and 17 deletions

2
.gitignore vendored
View File

@ -9,3 +9,5 @@ mlruns
my_model my_model
1/ 1/
mydb.sqlite mydb.sqlite
movies_on_streaming_platforms_model.zip
movies_on_streaming_platforms_model

View File

@ -15,6 +15,7 @@ RUN pip3 install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f ht
RUN pip3 install sacred RUN pip3 install sacred
RUN pip3 install pymongo RUN pip3 install pymongo
RUN pip3 install mlflow RUN pip3 install mlflow
RUN pip3 install tensorflow==2.5.0rc1
# Stwórzmy w kontenerze (jeśli nie istnieje) katalog /app i przejdźmy do niego (wszystkie kolejne polecenia RUN, CMD, ENTRYPOINT, COPY i ADD będą w nim wykonywane) # Stwórzmy w kontenerze (jeśli nie istnieje) katalog /app i przejdźmy do niego (wszystkie kolejne polecenia RUN, CMD, ENTRYPOINT, COPY i ADD będą w nim wykonywane)
WORKDIR /app WORKDIR /app

View File

@ -1,17 +1,14 @@
import mlflow import mlflow
import mlflow.pytorch import mlflow.keras
import sys import sys
import json import json
import numpy as np
import torch
input = sys.argv[1] input = sys.argv[1]
model = mlflow.pytorch.load_model("my_model") model = mlflow.keras.load_model("movies_on_streaming_platforms_model")
with open('my_model/'+input) as json_file: with open('movies_on_streaming_platforms_model/'+input) as json_file:
data = json.load(json_file) data = json.load(json_file)
#print(np.array(data['inputs'])) #print(data)
print(model(torch.tensor(np.array(data['inputs'])).float())) print(model.predict(data['inputs']))

View File

@ -1,16 +1,14 @@
import mlflow
import mlflow.pytorch
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
import numpy as np import mlflow
import torch import mlflow.keras
import json import json
#mlflow.set_tracking_uri("http://127.0.0.1:5000") #mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_tracking_uri("http://172.17.0.1:5000") mlflow.set_tracking_uri("http://172.17.0.1:5000")
client = MlflowClient() client = MlflowClient()
version = 0 version = 0
model_name = "s426206" model_name = "s434704"
for mv in client.search_model_versions("name='s426206'"): for mv in client.search_model_versions(f"name='{model_name}'"):
if int(mv.version) > version: if int(mv.version) > version:
version = int(mv.version) version = int(mv.version)
@ -18,7 +16,7 @@ model = mlflow.pytorch.load_model(
model_uri=f"models:/{model_name}/{version}" model_uri=f"models:/{model_name}/{version}"
) )
with open('my_model/input_example.json') as json_file: with open('movies_on_streaming_platforms_model/input_example.json') as json_file:
data = json.load(json_file) data = json.load(json_file)
#print(np.array(data['inputs'])) #print(np.array(data['inputs']))
print(model(torch.tensor(np.array(data['inputs'])).float())) print(model.predict(data['inputs']))