This commit is contained in:
Jakub Zaręba 2023-05-10 21:13:53 +02:00
parent d1d20f6968
commit e089fb1453
2 changed files with 39 additions and 56 deletions

View File

@ -1,67 +1,48 @@
from sacred import Experiment
from sacred.observers import MongoObserver, FileStorageObserver
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import os
ex = Experiment('s487187-evaluate', interactive=True)
ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', db_name='sacred'))
@ex.config
def my_config():
model_path = 'model.h5'
test_data_path = 'data.csv'
metrics_file_path = 'metrics.txt'
plot_path = 'plot.png'
model_path = 'model.h5'
test_data_path = 'data.csv'
metrics_file_path = 'metrics.txt'
plot_path = 'plot.png'
@ex.capture
def evaluate_model(model_path, test_data_path, metrics_file_path, plot_path):
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import os
model = tf.keras.models.load_model(model_path)
model = tf.keras.models.load_model(model_path)
test_data = pd.read_csv(test_data_path, sep=';')
test_data = pd.get_dummies(test_data, columns=['Sex', 'Medal'])
test_data = test_data.drop(columns=['Name', 'Team', 'NOC', 'Games', 'Year', 'Season', 'City', 'Sport', 'Event'])
test_data = pd.read_csv(test_data_path, sep=';')
test_data = pd.get_dummies(test_data, columns=['Sex', 'Medal'])
test_data = test_data.drop(columns=['Name', 'Team', 'NOC', 'Games', 'Year', 'Season', 'City', 'Sport', 'Event'])
scaler = MinMaxScaler()
test_data = pd.DataFrame(scaler.fit_transform(test_data), columns=test_data.columns)
scaler = MinMaxScaler()
test_data = pd.DataFrame(scaler.fit_transform(test_data), columns=test_data.columns)
X_test = test_data.filter(regex='Sex|Age')
y_test = test_data.filter(regex='Medal')
y_test = pd.get_dummies(y_test)
X_test = test_data.filter(regex='Sex|Age')
y_test = test_data.filter(regex='Medal')
y_test = pd.get_dummies(y_test)
X_test = X_test.fillna(0)
y_test = y_test.fillna(0)
X_test = X_test.fillna(0)
y_test = y_test.fillna(0)
y_pred = model.predict(X_test)
y_pred = model.predict(X_test)
top_1_accuracy = tf.keras.metrics.categorical_accuracy(y_test, y_pred)
top_5_accuracy = tf.keras.metrics.top_k_categorical_accuracy(y_test, y_pred, k=5)
top_1_accuracy = tf.keras.metrics.categorical_accuracy(y_test, y_pred)
top_5_accuracy = tf.keras.metrics.top_k_categorical_accuracy(y_test, y_pred, k=5)
if os.path.exists(metrics_file_path):
if os.path.exists(metrics_file_path):
metrics_df = pd.read_csv(metrics_file_path)
else:
else:
metrics_df = pd.DataFrame(columns=['top_1_accuracy', 'top_5_accuracy'])
new_row = pd.DataFrame([{'top_1_accuracy': np.mean(top_1_accuracy.numpy()), 'top_5_accuracy': np.mean(top_5_accuracy.numpy())}])
metrics_df = pd.concat([metrics_df, new_row], ignore_index=True)
metrics_df.to_csv(metrics_file_path, index=False)
new_row = pd.DataFrame([{'top_1_accuracy': np.mean(top_1_accuracy.numpy()), 'top_5_accuracy': np.mean(top_5_accuracy.numpy())}])
metrics_df = pd.concat([metrics_df, new_row], ignore_index=True)
metrics_df.to_csv(metrics_file_path, index=False)
plt.figure(figsize=(10, 6))
plt.plot(metrics_df['top_1_accuracy'], label='Top-1 Accuracy')
plt.plot(metrics_df['top_5_accuracy'], label='Top-5 Accuracy')
plt.legend()
plt.savefig(plot_path)
ex.log_scalar('top_1_accuracy', np.mean(top_1_accuracy.numpy()))
ex.log_scalar('top_5_accuracy', np.mean(top_5_accuracy.numpy()))
ex.add_artifact(model_path)
ex.add_artifact(metrics_file_path)
ex.add_artifact(plot_path)
@ex.automain
def main():
evaluate_model()
plt.figure(figsize=(10, 6))
plt.plot(metrics_df['top_1_accuracy'], label='Top-1 Accuracy')
plt.plot(metrics_df['top_5_accuracy'], label='Top-5 Accuracy')
plt.legend()
plt.savefig(plot_path)

View File

@ -1,7 +1,7 @@
from sacred import Experiment
from sacred.observers import MongoObserver, FileStorageObserver
ex = Experiment('s487187-training')
ex = Experiment('s487187-training', interactive=True)
ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', db_name='sacred'))
@ex.config
@ -65,3 +65,5 @@ def run_experiment():
accuracy = train_model()
ex.log_scalar('accuracy', accuracy)
ex.add_artifact('model.h5')
ex.run()