This commit is contained in:
zgolebiewska 2024-05-26 14:23:37 +02:00
parent 567074ec4c
commit 48175b11be

View File

@ -1,12 +1,23 @@
from sacred import Experiment
from sacred.observers import MongoObserver, FileStorageObserver
import tensorflow as tf
import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
import json
import mlflow
mlflow.set_tracking_uri("http://localhost:5000") # Ustawienie adresu MLflow Tracking Server
ex = Experiment("s464906_experiment")
mongo_url = "mongodb://admin:IUM_2021@tzietkiewicz.vm.wmi.amu.edu.pl:27017"
ex.observers.append(MongoObserver(url=mongo_url, db_name='sacred'))
ex.observers.append(FileStorageObserver('logs'))
@ex.config
def cfg():
epochs = 100
@ex.automain
def train_model(epochs):
df = pd.read_csv('OrangeQualityData.csv')
encoder = LabelEncoder()
@ -33,19 +44,21 @@ model = tf.keras.Sequential([
model.compile(optimizer='sgd', loss='mse')
with mlflow.start_run():
mlflow.log_param("optimizer", 'sgd')
mlflow.log_param("loss_function", 'mse')
mlflow.log_param("epochs", 100)
history = model.fit(X_train_scaled, y_train, epochs=epochs, verbose=0, validation_data=(X_test_scaled, y_test))
history = model.fit(X_train_scaled, y_train, epochs=100, verbose=0, validation_data=(X_test_scaled, y_test))
ex.log_scalar("epochs", epochs)
for key, value in history.history.items():
mlflow.log_metric(key, value[-1]) # Logujemy ostatnią wartość metryki
ex.add_artifact(__file__)
model.save('orange_quality_model_tf.h5')
ex.add_artifact('orange_quality_model_tf.h5')
for key, value in history.history.items():
ex.log_scalar(key, value[-1])
predictions = model.predict(X_test_scaled)
with open('predictions_tf.json', 'w') as f:
json.dump(predictions.tolist(), f, indent=4)
ex.add_artifact('predictions_tf.json')
return 'Training completed successfully'