From 771269fc4c0499fa56c939d3eca238f6df0a0cf5 Mon Sep 17 00:00:00 2001 From: Kamila Date: Sun, 15 May 2022 12:53:35 +0200 Subject: [PATCH] mlflow attempt task 2 prediction --- predict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/predict.py b/predict.py index d80f183..f8cdb35 100644 --- a/predict.py +++ b/predict.py @@ -3,16 +3,17 @@ import mlflow import numpy as np import sys import tarfile - +import os file = tarfile.open('mlruns.tar.gz') file.extractall('./ml') input = str((sys.argv[1:])[0]) PATH = "ml/mlruns/1/f65f936936024133a2c03e1e486ba9cf/artifacts/model/" +print(os.listdir(PATH)) model = mlflow.pytorch.load_model(f"{PATH}/MLmodel") with open(f'[PATH]/{input}', 'r') as file: json_data = json.load(file) - + print(f"Input: {json_data['inputs'][0]}") print(f"Prediction: {model.predict(np.array([json_data['inputs'][0]], dtype=np.float32))}") \ No newline at end of file