add tracking uri
Some checks failed
s434804-training/pipeline/head There was a failure building this commit

This commit is contained in:
Dawid 2021-05-23 17:38:53 +02:00
parent c727544281
commit 88950624a5
2 changed files with 13 additions and 3 deletions

View File

@ -18,6 +18,7 @@ pipeline {
sh 'chmod +x tensor.py' sh 'chmod +x tensor.py'
sh 'python3 tensor.py' sh 'python3 tensor.py'
sh 'rm -rf country_vaccination' sh 'rm -rf country_vaccination'
sh "export MLFLOW_TRACKING_URI=http://172.17.0.1:5000"
sh 'chmod +x mlflow_model.py' sh 'chmod +x mlflow_model.py'
sh 'python3 mlflow_model.py' sh 'python3 mlflow_model.py'
} }

View File

@ -4,7 +4,7 @@ import mlflow as mlf
from tensorflow import keras from tensorflow import keras
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error as rmse from sklearn.metrics import mean_squared_error as rmse
from urllib.parse import urlparse
def create_model(test_size, epochs, batch_size): def create_model(test_size, epochs, batch_size):
df = pd.read_csv('country_vaccinations.csv').dropna() df = pd.read_csv('country_vaccinations.csv').dropna()
@ -36,9 +36,12 @@ def create_model(test_size, epochs, batch_size):
return model, rmse_result, signature, input_example return model, rmse_result, signature, input_example
if __name__ == "__main__": if __name__ == "__main__":
mlf.set_tracking_uri("http://172.17.0.1:5000")
test_size = float(sys.argv[1]) if len(sys.argv) > 1 else 0.2 test_size = float(sys.argv[1]) if len(sys.argv) > 1 else 0.2
epochs = int(sys.argv[2]) if len(sys.argv) > 1 else 100 epochs = int(sys.argv[2]) if len(sys.argv) > 1 else 100
batch_size = int(sys.argv[3]) if len(sys.argv) > 1 else 32 batch_size = int(sys.argv[3]) if len(sys.argv) > 1 else 32
with mlf.start_run(): with mlf.start_run():
mlf.log_param("Test size", test_size) mlf.log_param("Test size", test_size)
mlf.log_param("Epochs", epochs) mlf.log_param("Epochs", epochs)
@ -49,7 +52,13 @@ if __name__ == "__main__":
epochs=epochs, epochs=epochs,
batch_size=batch_size, batch_size=batch_size,
) )
mlf.log_metric("RMSE", rmse_result) mlf.log_metric("RMSE", rmse_result)
# mlf.keras.log_model(model, "country_vaccination") # mlf.keras.log_model(model, "country_vaccination")
mlf.keras.save_model(model, "country_vaccination", input_example=input_example, signature=signature) mlf.set_experiment("s434804")
tracking_url_type_store = urlparse(mlf.get_tracking_uri()).scheme
if tracking_url_type_store != "file":
mlf.keras.log_model(model, "country_vaccinations", registered_model_name="s434804", signature=signature,
input_example=input_example)
else:
mlf.keras.log_model(model, "vaccines_model", signature=signature, input_example=input_example)
mlf.keras.save_model(model, "country_vaccinations", signature=signature, input_example=input_example)