24 lines
674 B
Python
24 lines
674 B
Python
|
import json
import os

import mlflow
import mlflow.pytorch
import numpy as np
import torch
from mlflow.tracking import MlflowClient
# Point MLflow at the tracking server. MLFLOW_TRACKING_URI, when set, overrides
# the default below (use "http://127.0.0.1:5000" for a server on this machine).
# NOTE(review): 172.17.0.1 is presumably the Docker bridge gateway, i.e. the
# host as seen from inside a container — confirm for this deployment.
tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "http://172.17.0.1:5000")
mlflow.set_tracking_uri(tracking_uri)

client = MlflowClient()

# Registered-model name whose newest version is loaded and evaluated.
model_name = "s426206"

# Select the highest registered version number. int() keeps the comparison
# numeric even if the client reports versions as strings ("10" must beat "9").
# The filter is built from model_name so the two can never drift apart.
registered_versions = client.search_model_versions(f"name='{model_name}'")
version = max((int(mv.version) for mv in registered_versions), default=0)

# Without any registered version there is nothing to load; stop with a clear
# message instead of failing later on a nonexistent "models:/<name>/0" URI.
if version == 0:
    raise SystemExit(f"No registered versions found for model '{model_name}'")
# Fetch the chosen version from the MLflow Model Registry as a PyTorch model.
model_uri = f"models:/{model_name}/{version}"
model = mlflow.pytorch.load_model(model_uri=model_uri)
# Load the example input stored with the model. The path is relative to the
# current working directory. NOTE(review): "my_model/" is presumably the
# directory written by mlflow.pytorch.save_model — confirm.
with open("my_model/input_example.json", encoding="utf-8") as json_file:
    data = json.load(json_file)

# Assumes data["inputs"] is a (nested) list of numbers matching the model's
# expected input shape — TODO confirm against the training script.
# Build the float32 tensor directly rather than float64 followed by .float().
inputs = torch.as_tensor(np.asarray(data["inputs"], dtype=np.float32))

# Inference only: eval() switches layers such as dropout/batch-norm to their
# inference behaviour, and no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    print(model(inputs))