2024-04-04 09:06:39 +02:00
|
|
|
import torch
|
|
|
|
import os
|
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
|
2024-04-19 11:40:54 +02:00
|
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
2024-04-04 09:06:39 +02:00
|
|
|
|
2024-04-19 12:02:09 +02:00
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import seaborn as sns
|
|
|
|
|
2024-04-19 11:40:54 +02:00
|
|
|
from NeuralNetwork import NeuralNetwork
|
2024-04-04 09:06:39 +02:00
|
|
|
|
|
|
|
# Evaluate the saved model on the test set and keep a running history of
# its metrics across runs.
#
# Side effects (paths are relative to the current working directory):
#   - predictions.csv : one thresholded 0/1 prediction per test row
#   - metrics.csv     : one Accuracy/Precision/Recall/F1 row per run (history)
#   - metrics.png     : line chart of the full metrics history
#
# Raises FileNotFoundError('Model not found') when the model file is missing.

MODEL_PATH = './models/model.pth'
TEST_DATA_PATH = './datasets/test.csv'
PREDICTIONS_PATH = 'predictions.csv'
METRICS_PATH = 'metrics.csv'
METRICS_PLOT_PATH = 'metrics.png'
METRIC_COLUMNS = ['Accuracy', 'Precision', 'Recall', 'F1']
# Model outputs at or above this value are labelled 1, otherwise 0.
DECISION_THRESHOLD = 0.5

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError('Model not found')

# The checkpoint is a whole pickled module (it is called directly below), so
# unpickling needs the NeuralNetwork class importable and must allow arbitrary
# objects: weights_only=False (PyTorch >= 2.6 defaults it to True; the keyword
# exists since PyTorch 1.13).
# SECURITY: this runs pickle -- only load checkpoints from a trusted source.
# map_location='cpu' keeps a GPU-saved model on the same device as the CPU
# input tensor built below.
model = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
# Switch to inference mode in case the network uses dropout / batch norm.
model.eval()

# Load the test set: 'diagnosis' is the label, 'id' is not a feature.
test = pd.read_csv(TEST_DATA_PATH)
X_test = torch.tensor(
    test.drop(columns=['id', 'diagnosis']).to_numpy(dtype=np.float32)
)
y_test = test['diagnosis'].to_numpy()

# Predict and threshold into hard 0/1 labels.
# NOTE(review): assumes the network's output is a probability (e.g. ends in a
# sigmoid); if it emits raw logits the threshold should be 0 -- confirm.
with torch.no_grad():
    outputs = model(X_test).numpy()
y_pred = (outputs >= DECISION_THRESHOLD).astype(int).ravel()

pd.DataFrame(y_pred, columns=['Prediction']).to_csv(PREDICTIONS_PATH, index=False)

# Score this run.
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

# Add this run to the metrics history. The complete history is rewritten
# (instead of opening the file in append mode) so earlier rows are never
# duplicated, and `metrics` is defined on the very first run as well.
current_run = pd.DataFrame([[accuracy, precision, recall, f1]], columns=METRIC_COLUMNS)
if os.path.exists(METRICS_PATH):
    metrics = pd.concat([pd.read_csv(METRICS_PATH), current_run], ignore_index=True)
else:
    metrics = current_run
metrics.to_csv(METRICS_PATH, index=False)

# Plot every metric against its run index.
sns.set(style='whitegrid')
fig, ax = plt.subplots(figsize=(8, 6))
sns.lineplot(data=metrics, ax=ax)
ax.set_title('Metrics history')
ax.set_xlabel('History number')
ax.set_ylabel('Value')
ax.legend()
fig.savefig(METRICS_PLOT_PATH)
# Release the figure so repeated runs in one process do not accumulate them.
plt.close(fig)
|