This commit is contained in:
Jakub Zaręba 2023-05-10 21:13:53 +02:00
parent d1d20f6968
commit e089fb1453
2 changed files with 39 additions and 56 deletions

View File

@ -1,67 +1,48 @@
from sacred import Experiment import tensorflow as tf
from sacred.observers import MongoObserver, FileStorageObserver import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import os
ex = Experiment('s487187-evaluate', interactive=True)
ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', db_name='sacred'))
@ex.config model_path = 'model.h5'
def my_config(): test_data_path = 'data.csv'
model_path = 'model.h5' metrics_file_path = 'metrics.txt'
test_data_path = 'data.csv' plot_path = 'plot.png'
metrics_file_path = 'metrics.txt'
plot_path = 'plot.png'
@ex.capture model = tf.keras.models.load_model(model_path)
def evaluate_model(model_path, test_data_path, metrics_file_path, plot_path):
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import os
model = tf.keras.models.load_model(model_path) test_data = pd.read_csv(test_data_path, sep=';')
test_data = pd.get_dummies(test_data, columns=['Sex', 'Medal'])
test_data = test_data.drop(columns=['Name', 'Team', 'NOC', 'Games', 'Year', 'Season', 'City', 'Sport', 'Event'])
test_data = pd.read_csv(test_data_path, sep=';') scaler = MinMaxScaler()
test_data = pd.get_dummies(test_data, columns=['Sex', 'Medal']) test_data = pd.DataFrame(scaler.fit_transform(test_data), columns=test_data.columns)
test_data = test_data.drop(columns=['Name', 'Team', 'NOC', 'Games', 'Year', 'Season', 'City', 'Sport', 'Event'])
scaler = MinMaxScaler() X_test = test_data.filter(regex='Sex|Age')
test_data = pd.DataFrame(scaler.fit_transform(test_data), columns=test_data.columns) y_test = test_data.filter(regex='Medal')
y_test = pd.get_dummies(y_test)
X_test = test_data.filter(regex='Sex|Age') X_test = X_test.fillna(0)
y_test = test_data.filter(regex='Medal') y_test = y_test.fillna(0)
y_test = pd.get_dummies(y_test)
X_test = X_test.fillna(0) y_pred = model.predict(X_test)
y_test = y_test.fillna(0)
y_pred = model.predict(X_test) top_1_accuracy = tf.keras.metrics.categorical_accuracy(y_test, y_pred)
top_5_accuracy = tf.keras.metrics.top_k_categorical_accuracy(y_test, y_pred, k=5)
top_1_accuracy = tf.keras.metrics.categorical_accuracy(y_test, y_pred) if os.path.exists(metrics_file_path):
top_5_accuracy = tf.keras.metrics.top_k_categorical_accuracy(y_test, y_pred, k=5)
if os.path.exists(metrics_file_path):
metrics_df = pd.read_csv(metrics_file_path) metrics_df = pd.read_csv(metrics_file_path)
else: else:
metrics_df = pd.DataFrame(columns=['top_1_accuracy', 'top_5_accuracy']) metrics_df = pd.DataFrame(columns=['top_1_accuracy', 'top_5_accuracy'])
new_row = pd.DataFrame([{'top_1_accuracy': np.mean(top_1_accuracy.numpy()), 'top_5_accuracy': np.mean(top_5_accuracy.numpy())}]) new_row = pd.DataFrame([{'top_1_accuracy': np.mean(top_1_accuracy.numpy()), 'top_5_accuracy': np.mean(top_5_accuracy.numpy())}])
metrics_df = pd.concat([metrics_df, new_row], ignore_index=True) metrics_df = pd.concat([metrics_df, new_row], ignore_index=True)
metrics_df.to_csv(metrics_file_path, index=False) metrics_df.to_csv(metrics_file_path, index=False)
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
plt.plot(metrics_df['top_1_accuracy'], label='Top-1 Accuracy') plt.plot(metrics_df['top_1_accuracy'], label='Top-1 Accuracy')
plt.plot(metrics_df['top_5_accuracy'], label='Top-5 Accuracy') plt.plot(metrics_df['top_5_accuracy'], label='Top-5 Accuracy')
plt.legend() plt.legend()
plt.savefig(plot_path) plt.savefig(plot_path)
ex.log_scalar('top_1_accuracy', np.mean(top_1_accuracy.numpy()))
ex.log_scalar('top_5_accuracy', np.mean(top_5_accuracy.numpy()))
ex.add_artifact(model_path)
ex.add_artifact(metrics_file_path)
ex.add_artifact(plot_path)
@ex.automain
def main():
evaluate_model()

View File

@ -1,7 +1,7 @@
from sacred import Experiment from sacred import Experiment
from sacred.observers import MongoObserver, FileStorageObserver from sacred.observers import MongoObserver, FileStorageObserver
ex = Experiment('s487187-training') ex = Experiment('s487187-training', interactive=True)
ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', db_name='sacred')) ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', db_name='sacred'))
@ex.config @ex.config
@ -65,3 +65,5 @@ def run_experiment():
accuracy = train_model() accuracy = train_model()
ex.log_scalar('accuracy', accuracy) ex.log_scalar('accuracy', accuracy)
ex.add_artifact('model.h5') ex.add_artifact('model.h5')
ex.run()