diff --git a/lab8/trainScript.py b/lab8/trainScript.py index 10916c9..29c634c 100644 --- a/lab8/trainScript.py +++ b/lab8/trainScript.py @@ -108,16 +108,16 @@ def train(): mlflow.log_param('learning_rate', learning_rate) mlflow.log_metric('final_loss', min(hist["val_loss"])) - signature = mlflow.models.signature.infer_signature(house_price_features, linear_model.predict(house_price_features)) - - tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme - - sampleInp = [0.0, 0.0, 2.0, 904.129525, 1.000000, 1.000000, 20.098413, 79.107860] - # expected value is 49.7 - if tracking_url_type_store != "file": - mlflow.keras.log_model(linear_model, "linear-model", registered_model_name="HousePriceLinear", signature=signature) - else: - mlflow.keras.log_model(linear_model, "model", signature=signature, input_example=np.array(sampleInp)) + signature = mlflow.models.signature.infer_signature(house_price_features, linear_model.predict(house_price_features)) + + tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme + + sampleInp = [0.0, 0.0, 2.0, 904.129525, 1.000000, 1.000000, 20.098413, 79.107860] + # expected value is 49.7 + if tracking_url_type_store != "file": + mlflow.keras.log_model(linear_model, "linear-model", registered_model_name="HousePriceLinear", signature=signature) + else: + mlflow.keras.log_model(linear_model, "model", signature=signature, input_example=np.array(sampleInp)) if __name__ == '__main__': train() \ No newline at end of file