diff --git a/predict.py b/predict.py index f8cdb35..af22b09 100644 --- a/predict.py +++ b/predict.py @@ -3,14 +3,13 @@ import mlflow import numpy as np import sys import tarfile -import os + file = tarfile.open('mlruns.tar.gz') file.extractall('./ml') input = str((sys.argv[1:])[0]) PATH = "ml/mlruns/1/f65f936936024133a2c03e1e486ba9cf/artifacts/model/" -print(os.listdir(PATH)) -model = mlflow.pytorch.load_model(f"{PATH}/MLmodel") +model =mlflow.pyfunc.load_model(f"{PATH}/MLmodel") with open(f'[PATH]/{input}', 'r') as file: json_data = json.load(file)