This commit is contained in:
zgolebiewska 2024-05-26 14:16:50 +02:00
parent b06db2db4b
commit 9d66acbdad
3 changed files with 40 additions and 5 deletions

14
PMProject.yaml Normal file
View File

@ -0,0 +1,14 @@
name: OrangeQualityModel
conda_env: conda.yaml
entry_points:
train:
parameters:
epochs: {type: int, default: 100}
command: "python model.py"
test:
parameters:
model_path: {type: str, default: "orange_quality_model_tf.h5"}
command: "python test_model.py --model_path {model_path}"

10
conda.yaml Normal file
View File

@ -0,0 +1,10 @@
name: orange_quality_model_env
channels:
- defaults
dependencies:
- python=3.8
- pip
- pandas
- scikit-learn
- tensorflow
- mlflow

View File

@ -3,6 +3,9 @@ import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
import json
import mlflow
mlflow.set_tracking_uri("http://localhost:5000") # Ustawienie adresu MLflow Tracking Server
df = pd.read_csv('OrangeQualityData.csv')
@ -30,8 +33,16 @@ model = tf.keras.Sequential([
model.compile(optimizer='sgd', loss='mse')
with mlflow.start_run():
mlflow.log_param("optimizer", 'sgd')
mlflow.log_param("loss_function", 'mse')
mlflow.log_param("epochs", 100)
history = model.fit(X_train_scaled, y_train, epochs=100, verbose=0, validation_data=(X_test_scaled, y_test))
for key, value in history.history.items():
mlflow.log_metric(key, value[-1]) # Logujemy ostatnią wartość metryki
model.save('orange_quality_model_tf.h5')
predictions = model.predict(X_test_scaled)