From a4fb1a898ae1443fa3ad88e970f34bd3c715657f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Wa=C5=82=C4=99sa?= Date: Sun, 15 May 2022 14:19:02 +0200 Subject: [PATCH] Zaktualizuj 'ml_pytorch_mlflow.py' --- ml_pytorch_mlflow.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/ml_pytorch_mlflow.py b/ml_pytorch_mlflow.py index 56e615d..80e559d 100644 --- a/ml_pytorch_mlflow.py +++ b/ml_pytorch_mlflow.py @@ -107,7 +107,7 @@ def predict_single(input, target, model): return "Target: "+str(target)+" Predicted: "+str(prediction)+"\n" -def prediction(input, target, model): +def prediction(input, model): inputs = input.unsqueeze(0) predictions = model(inputs) predicted = predictions[0].detach() @@ -133,7 +133,7 @@ def my_main(epochs): for i in range(0, len(val_ds), 1): input_, target = val_ds[i] expected.append(float(target)) - predicted.append(float(prediction(input_, target, model))) + predicted.append(float(prediction(input_, model))) MSE = mean_squared_error(expected, predicted) MAE = mean_absolute_error(expected, predicted) @@ -147,8 +147,19 @@ def my_main(epochs): input_, target = val_ds[i] file.write(str(predict_single(input_, target, model))) + input_example = val_ds[0].unsqueeze(0) + signature = mlflow.models.signature.infer_signature(input_, prediction(input_, model)) + tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme + + if tracking_url_type_store != "file": + mlflow.pytorch.log_model(model, "model", registered_model_name="s444356", signature=siganture, + input_example=input_example) + else: + mlflow.pytorch.log_model(model, "model", signature=siganture, input_example=input_example) + mlflow.pytorch.save_model(model, "my_model", signature=siganture, input_example=input_example) + torch.save(model, "Model_xPosition.pkl") - # ex.add_artifact("Model_xPosition.pkl") with mlflow.start_run() as run: - my_main(epochs) \ No newline at end of file + my_main(epochs) + \ No newline at end of file