add tracking uri
Some checks failed
s434804-training/pipeline/head There was a failure building this commit
Some checks failed
s434804-training/pipeline/head There was a failure building this commit
This commit is contained in:
parent
c727544281
commit
88950624a5
@ -18,6 +18,7 @@ pipeline {
|
|||||||
sh 'chmod +x tensor.py'
|
sh 'chmod +x tensor.py'
|
||||||
sh 'python3 tensor.py'
|
sh 'python3 tensor.py'
|
||||||
sh 'rm -rf country_vaccination'
|
sh 'rm -rf country_vaccination'
|
||||||
|
sh "export MLFLOW_TRACKING_URI=http://172.17.0.1:5000"
|
||||||
sh 'chmod +x mlflow_model.py'
|
sh 'chmod +x mlflow_model.py'
|
||||||
sh 'python3 mlflow_model.py'
|
sh 'python3 mlflow_model.py'
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@ import mlflow as mlf
|
|||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.metrics import mean_squared_error as rmse
|
from sklearn.metrics import mean_squared_error as rmse
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
def create_model(test_size, epochs, batch_size):
|
def create_model(test_size, epochs, batch_size):
|
||||||
df = pd.read_csv('country_vaccinations.csv').dropna()
|
df = pd.read_csv('country_vaccinations.csv').dropna()
|
||||||
@ -36,9 +36,12 @@ def create_model(test_size, epochs, batch_size):
|
|||||||
return model, rmse_result, signature, input_example
|
return model, rmse_result, signature, input_example
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
mlf.set_tracking_uri("http://172.17.0.1:5000")
|
||||||
|
|
||||||
test_size = float(sys.argv[1]) if len(sys.argv) > 1 else 0.2
|
test_size = float(sys.argv[1]) if len(sys.argv) > 1 else 0.2
|
||||||
epochs = int(sys.argv[2]) if len(sys.argv) > 1 else 100
|
epochs = int(sys.argv[2]) if len(sys.argv) > 1 else 100
|
||||||
batch_size = int(sys.argv[3]) if len(sys.argv) > 1 else 32
|
batch_size = int(sys.argv[3]) if len(sys.argv) > 1 else 32
|
||||||
|
|
||||||
with mlf.start_run():
|
with mlf.start_run():
|
||||||
mlf.log_param("Test size", test_size)
|
mlf.log_param("Test size", test_size)
|
||||||
mlf.log_param("Epochs", epochs)
|
mlf.log_param("Epochs", epochs)
|
||||||
@ -49,7 +52,13 @@ if __name__ == "__main__":
|
|||||||
epochs=epochs,
|
epochs=epochs,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
mlf.log_metric("RMSE", rmse_result)
|
mlf.log_metric("RMSE", rmse_result)
|
||||||
# mlf.keras.log_model(model, "country_vaccination")
|
# mlf.keras.log_model(model, "country_vaccination")
|
||||||
mlf.keras.save_model(model, "country_vaccination", input_example=input_example, signature=signature)
|
mlf.set_experiment("s434804")
|
||||||
|
tracking_url_type_store = urlparse(mlf.get_tracking_uri()).scheme
|
||||||
|
if tracking_url_type_store != "file":
|
||||||
|
mlf.keras.log_model(model, "country_vaccinations", registered_model_name="s434804", signature=signature,
|
||||||
|
input_example=input_example)
|
||||||
|
else:
|
||||||
|
mlf.keras.log_model(model, "vaccines_model", signature=signature, input_example=input_example)
|
||||||
|
mlf.keras.save_model(model, "country_vaccinations", signature=signature, input_example=input_example)
|
Loading…
Reference in New Issue
Block a user