ium_z487175/DL-prediction.py

81 lines
2.4 KiB
Python

import pickle
import os
import pandas as pd
import numpy as np
import argparse
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error
from sklearn.preprocessing import label_binarize
# NOTE(review): WORKSPACE is read here but not used anywhere in the visible script.
workspace_path = os.getenv('WORKSPACE')
# Location of the pickled model+data bundle inside the container image.
pickle_path = '/app/model_with_data.pickle'

# Fall back to a file in the current working directory so the script can also
# be run locally, outside the container.
_bundle_path = pickle_path if os.path.exists(pickle_path) else 'model_with_data.pickle'

# SECURITY: pickle.load can execute arbitrary code -- only load bundles that
# were produced by a trusted training step.
with open(_bundle_path, 'rb') as file:
    loaded_data = pickle.load(file)
# Unpack the bundle: item 0 is the trained model; items 3 and 4 are the scaled
# test features and the encoded test labels.
model = loaded_data[0]
X_test_scaled, y_test_encoded = loaded_data[3], loaded_data[4]

# Inference on the test set. y_pred is reduced along axis 1, so it is expected
# to be 2-D with one score column per class; the class index is the argmax.
y_pred = model.predict(X_test_scaled)
y_pred_classes = np.argmax(y_pred, axis=1)

# Persist the raw per-class scores (not the argmax classes) as CSV.
np.savetxt('results_prediction.csv', y_pred, delimiter=',')

# Collapse the encoded labels to class indices the same way.
# Presumably y_test_encoded is one-hot -- TODO confirm against the training step.
y_test_classes = np.argmax(y_test_encoded, axis=1)
# Aggregate metrics over the whole test set (task 2, LAB 06).
accuracy = accuracy_score(y_test_classes, y_pred_classes)

# Micro-averaged precision, recall and F1, evaluated in that order.
# NOTE(review): for single-label multiclass targets (as produced by argmax
# above), micro-averaged precision/recall/F1 all coincide with accuracy;
# consider average='macro' if per-class balance matters.
precision, recall, f1 = (
    scorer(y_test_classes, y_pred_classes, average='micro')
    for scorer in (precision_score, recall_score, f1_score)
)

# NOTE(review): RMSE here is computed between integer class indices, which is
# only meaningful if the class ordering carries numeric meaning.
rmse = np.sqrt(mean_squared_error(y_test_classes, y_pred_classes))

# Report every metric on stdout, one per line.
print("Metrics results")
for label, value in (
    ("Accuracy:", accuracy),
    ("Micro-avg Precision:", precision),
    ("Micro-avg Recall:", recall),
    ("F1 Score:", f1),
    ("RMSE:", rmse),
):
    print(label, value)
# Optional --build-number CLI argument (None when omitted); it tags this run's
# metrics row.
# NOTE(review): arguments are parsed only after inference has already run, so
# `--help` or an invalid value is reported only after the model was evaluated.
parser = argparse.ArgumentParser()
parser.add_argument('--build-number', type=int, help='Build number')
args = parser.parse_args()
build_number = args.build_number
print(f"Numer builda: {build_number}")

# Append this run's metrics as a new row of a cumulative CSV in the current
# working directory.
metrics_file = 'metrics.csv'
metrics_data = {
    'Build Number': [build_number],
    'Accuracy': [accuracy],
    'Micro-avg Precision': [precision],
    'Micro-avg Recall': [recall],
    'F1 Score': [f1],
    'RMSE': [rmse],
}
df = pd.DataFrame(metrics_data)
try:
    existing_df = pd.read_csv(metrics_file)
except (FileNotFoundError, pd.errors.EmptyDataError):
    # No usable history yet (file missing, or created empty): start a new CSV
    # with just this run instead of crashing.
    pass
else:
    df = pd.concat([existing_df, df], ignore_index=True)

# Write the full history (previous rows + this run) back to the CSV.
df.to_csv(metrics_file, index=False)
# Report where the metrics were saved.
print("Metryki zostały zapisane do pliku:", metrics_file)