ium_464863/IUM_13.ipynb

1602 lines
470 KiB
Plaintext
Raw Permalink Normal View History

2024-06-10 12:20:15 +02:00
{
"cells": [
{
"cell_type": "markdown",
"source": [
"## Experiments: neural networks for breast cancer classification"
],
"metadata": {
"collapsed": false
},
"id": "e1e08c454a98dd01"
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"# Data manipulation\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Data visualization\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"sns.set_style('whitegrid')\n",
"\n",
"# Data preprocessing\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Metrics\n",
"from sklearn.metrics import confusion_matrix, classification_report\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Deep learning\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T15:50:23.912931300Z",
"start_time": "2024-06-08T15:50:19.472582100Z"
}
},
"id": "c0c219cc1bbd4c7a"
},
{
"cell_type": "markdown",
"source": [
"#### Helper functions for plotting the confusion matrix and the classification report"
],
"metadata": {
"collapsed": false
},
"id": "6064474e7a56f80e"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Plot confusion matrix\n",
"def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap='Blues', figsize=(10, 6), axis=None):\n",
"    \"\"\"\n",
"    Draw a confusion matrix as an annotated heatmap.\n",
"\n",
"    Args:\n",
"        cm: square matrix of integer counts (annotated with fmt='d'), e.g. from\n",
"            sklearn.metrics.confusion_matrix; rows are labelled as true classes\n",
"            and columns as predicted classes.\n",
"        classes: tick labels used for both axes.\n",
"        title: axes title.\n",
"        cmap: colormap name passed to seaborn.heatmap.\n",
"        figsize: size of the new figure; ignored when `axis` is given.\n",
"        axis: matplotlib Axes to draw on. When None, a new figure is created\n",
"            and shown immediately.\n",
"    \"\"\"\n",
"    owns_figure = axis is None\n",
"    ax = plt.subplots(figsize=figsize)[1] if owns_figure else axis\n",
"\n",
"    sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,\n",
"                xticklabels=classes, yticklabels=classes, ax=ax)\n",
"    ax.set_title(title)\n",
"    ax.set_xlabel('Predicted label')\n",
"    ax.set_ylabel('True label')\n",
"\n",
"    # Only show the figure when this function created it; a caller-supplied\n",
"    # axis is left for the caller to display (e.g. as part of a subplot grid).\n",
"    if owns_figure:\n",
"        plt.show()\n",
"\n",
"# Plot classification report\n",
"def plot_classification_report(report, title='Classification report', axis=None, cmap='Blues', figsize=(10, 6)):\n",
"    \"\"\"\n",
"    Draw a classification report as an annotated heatmap.\n",
"\n",
"    Args:\n",
"        report: dict returned by sklearn.metrics.classification_report(..., output_dict=True).\n",
"        title: axes title (the previous version ignored this argument).\n",
"        axis: matplotlib Axes to draw on. When None, a new figure is created\n",
"            and shown immediately.\n",
"        cmap: colormap name passed to seaborn.heatmap.\n",
"        figsize: size of the new figure; ignored when `axis` is given.\n",
"    \"\"\"\n",
"    if axis is None:\n",
"        _, ax = plt.subplots(figsize=figsize)\n",
"    else:\n",
"        ax = axis\n",
"\n",
"    # pd.DataFrame(report) has one row per metric; the last row (sklearn's\n",
"    # 'support' counts, on a different scale) is dropped, then the frame is\n",
"    # transposed so classes / averages become rows and metrics become columns.\n",
"    metrics = pd.DataFrame(report).iloc[:-1, :].T\n",
"    sns.heatmap(metrics, annot=True, cmap=cmap, ax=ax)\n",
"\n",
"    ax.set_title(title)\n",
"    ax.set_xlabel('Metrics')\n",
"    ax.set_ylabel('Classes')\n",
"\n",
"    if axis is None:\n",
"        plt.show()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T15:51:26.166904900Z",
"start_time": "2024-06-08T15:51:26.145794400Z"
}
},
"id": "689b41e45e990a1b"
},
{
"cell_type": "markdown",
"source": [
"#### Load data"
],
"metadata": {
"collapsed": false
},
"id": "c7ad4d251442c34c"
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"# Load the dataset and discard the identifier column (unnecessary for modelling)\n",
"data = pd.read_csv('datasets/data.csv').drop(columns=['id'])"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T15:52:50.237396100Z",
"start_time": "2024-06-08T15:52:50.201061700Z"
}
},
"id": "54411dcad03637c2"
},
{
"cell_type": "code",
"execution_count": 76,
"outputs": [
{
"data": {
"text/plain": " diagnosis radius_mean texture_mean perimeter_mean area_mean \\\n0 1.0 0.521037 0.022658 0.545989 0.363733 \n1 1.0 0.643144 0.272574 0.615783 0.501591 \n2 1.0 0.601496 0.390260 0.595743 0.449417 \n3 1.0 0.210090 0.360839 0.233501 0.102906 \n4 1.0 0.629893 0.156578 0.630986 0.489290 \n.. ... ... ... ... ... \n564 1.0 0.690000 0.428813 0.678668 0.566490 \n565 1.0 0.622320 0.626987 0.604036 0.474019 \n566 1.0 0.455251 0.621238 0.445788 0.303118 \n567 1.0 0.644564 0.663510 0.665538 0.475716 \n568 0.0 0.036869 0.501522 0.028540 0.015907 \n\n smoothness_mean compactness_mean concavity_mean concave points_mean \\\n0 0.593753 0.792037 0.703140 0.731113 \n1 0.289880 0.181768 0.203608 0.348757 \n2 0.514309 0.431017 0.462512 0.635686 \n3 0.811321 0.811361 0.565604 0.522863 \n4 0.430351 0.347893 0.463918 0.518390 \n.. ... ... ... ... \n564 0.526948 0.296055 0.571462 0.690358 \n565 0.407782 0.257714 0.337395 0.486630 \n566 0.288165 0.254340 0.216753 0.263519 \n567 0.588336 0.790197 0.823336 0.755467 \n568 0.000000 0.074351 0.000000 0.000000 \n\n symmetry_mean ... radius_worst texture_worst perimeter_worst \\\n0 0.686364 ... 0.620776 0.141525 0.668310 \n1 0.379798 ... 0.606901 0.303571 0.539818 \n2 0.509596 ... 0.556386 0.360075 0.508442 \n3 0.776263 ... 0.248310 0.385928 0.241347 \n4 0.378283 ... 0.519744 0.123934 0.506948 \n.. ... ... ... ... ... \n564 0.336364 ... 0.623266 0.383262 0.576174 \n565 0.349495 ... 0.560655 0.699094 0.520892 \n566 0.267677 ... 0.393099 0.589019 0.379949 \n567 0.675253 ... 0.633582 0.730277 0.668310 \n568 0.266162 ... 0.054287 0.489072 0.043578 \n\n area_worst smoothness_worst compactness_worst concavity_worst \\\n0 0.450698 0.601136 0.619292 0.568610 \n1 0.435214 0.347553 0.154563 0.192971 \n2 0.374508 0.483590 0.385375 0.359744 \n3 0.094008 0.915472 0.814012 0.548642 \n4 0.341575 0.437364 0.172415 0.319489 \n.. ... ... ... ... 
\n564 0.452664 0.461137 0.178527 0.328035 \n565 0.379915 0.300007 0.159997 0.256789 \n566 0.230731 0.282177 0.273705 0.271805 \n567 0.402035 0.619626 0.815758 0.749760 \n568 0.020497 0.124084 0.036043 0.000000 \n\n concave points_worst symmetry_worst fractal_dimension_worst \n0 0.912027 0.598462 0.418864 \n1 0.639175 0.233590 0.222878 \n2 0.835052 0.403706 0.213433 \n3 0.884880 1.000000 0.773711 \n4 0.558419 0.157500 0.142595 \n.. ...
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>diagnosis</th>\n <th>radius_mean</th>\n <th>texture_mean</th>\n <th>perimeter_mean</th>\n <th>area_mean</th>\n <th>smoothness_mean</th>\n <th>compactness_mean</th>\n <th>concavity_mean</th>\n <th>concave points_mean</th>\n <th>symmetry_mean</th>\n <th>...</th>\n <th>radius_worst</th>\n <th>texture_worst</th>\n <th>perimeter_worst</th>\n <th>area_worst</th>\n <th>smoothness_worst</th>\n <th>compactness_worst</th>\n <th>concavity_worst</th>\n <th>concave points_worst</th>\n <th>symmetry_worst</th>\n <th>fractal_dimension_worst</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1.0</td>\n <td>0.521037</td>\n <td>0.022658</td>\n <td>0.545989</td>\n <td>0.363733</td>\n <td>0.593753</td>\n <td>0.792037</td>\n <td>0.703140</td>\n <td>0.731113</td>\n <td>0.686364</td>\n <td>...</td>\n <td>0.620776</td>\n <td>0.141525</td>\n <td>0.668310</td>\n <td>0.450698</td>\n <td>0.601136</td>\n <td>0.619292</td>\n <td>0.568610</td>\n <td>0.912027</td>\n <td>0.598462</td>\n <td>0.418864</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1.0</td>\n <td>0.643144</td>\n <td>0.272574</td>\n <td>0.615783</td>\n <td>0.501591</td>\n <td>0.289880</td>\n <td>0.181768</td>\n <td>0.203608</td>\n <td>0.348757</td>\n <td>0.379798</td>\n <td>...</td>\n <td>0.606901</td>\n <td>0.303571</td>\n <td>0.539818</td>\n <td>0.435214</td>\n <td>0.347553</td>\n <td>0.154563</td>\n <td>0.192971</td>\n <td>0.639175</td>\n <td>0.233590</td>\n <td>0.222878</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1.0</td>\n <td>0.601496</td>\n <td>0.390260</td>\n <td>0.595743</td>\n <td>0.449417</td>\n <td>0.514309</td>\n <td>0.431017</td>\n <td>0.462512</td>\n <td>0.635686</td>\n <td>0.509596</td>\n 
<td>...</td>\n <td>0.556386</td>\n <td>0.360075</td>\n <td>0.508442</td>\n <td>0.374508</td>\n <td>0.483590</td>\n <td>0.385375</td>\n <td>0.359744</td>\n <td>0.835052</td>\n <td>0.403706</td>\n <td>0.213433</td>\n </tr>\n <tr>\n <th>3</th>\n <td>1.0</td>\n <td>0.210090</td>\n <td>0.360839</td>\n <td>0.233501</td>\n <td>0.102906</td>\n <td>0.811321</td>\n <td>0.811361</td>\n <td>0.565604</td>\n <td>0.522863</td>\n <td>0.776263</td>\n <td>...</td>\n <td>0.248310</td>\n <td>0.385928</td>\n <td>0.241347</td>\n <td>0.094008</td>\n <td>0.915472</td>\n <td>0.814012</td>\n <td>0.548642</td>\n <td>0.884880</td>\n <td>1.000000</td>\n <td>0.773711</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1.0</td>\n <td>0.629893</td>\n <td>0.156578</td>\n <td>0.630986</td>\n <td>0.489290</td>\n <td>0.430351</td>\n <td>0.347893</td>\n <td>0.463918</td>\n <td>0.518390</td>\n <td>0.378283</td>\n <td>...</td>\n <td>0.519744</td>\n <td>0.123934</td>\n <td>0.506948</td>\n <td>0.341575</td>\n <td>0.437364</td>\n <td>0.172415</td>\n <td>0.319489</td>\n <td>0.558419</td>\n <td>0.157500</td>\n <td>0.142595</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</t
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect the loaded data. NOTE(review): the saved output shows feature values in [0, 1] and\n",
"# diagnosis encoded as 1.0/0.0, so datasets/data.csv appears to be pre-processed - verify.\n",
"data"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T17:45:09.614420700Z",
"start_time": "2024-06-08T17:45:09.479821700Z"
}
},
"id": "4294a8ce6b3bf0d5"
},
{
"cell_type": "markdown",
"source": [
"#### Data preprocessing"
],
"metadata": {
"collapsed": false
},
"id": "74e70a93ab91ece2"
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"# Split data into features (every column after the first) and target (first column: diagnosis)\n",
"X = data.iloc[:, 1:]\n",
"y = data.iloc[:, 0]\n",
"\n",
"# Stratify on the label so the train and test sets keep the same class proportions;\n",
"# without it the small test set (20%) can end up with a skewed class balance.\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)\n",
"\n",
"# Standardize the data: fit on the training set only so no test statistics leak in\n",
"scaler = StandardScaler()\n",
"X_train = scaler.fit_transform(X_train)\n",
"X_test = scaler.transform(X_test)\n",
"\n",
"# Use the GPU if available\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# Convert to float32 tensors directly on the target device. Targets are reshaped to\n",
"# column vectors (N, 1) to match the (N, 1) sigmoid output expected by nn.BCELoss.\n",
"X_train = torch.tensor(X_train, dtype=torch.float32, device=device)\n",
"X_test = torch.tensor(X_test, dtype=torch.float32, device=device)\n",
"y_train = torch.tensor(y_train.values, dtype=torch.float32, device=device).view(-1, 1)\n",
"y_test = torch.tensor(y_test.values, dtype=torch.float32, device=device).view(-1, 1)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:03:06.981298Z",
"start_time": "2024-06-08T16:03:06.957759100Z"
}
},
"id": "1fcdf94e46907575"
},
{
"cell_type": "markdown",
"source": [
"#### Neural network architectures"
],
"metadata": {
"collapsed": false
},
"id": "1a150bd5c3a959"
},
{
"cell_type": "code",
"execution_count": 28,
"outputs": [],
"source": [
"# V1\n",
"# Three fully connected layers with ReLU activation function\n",
"# Output layer with Sigmoid activation function\n",
"class NeuralNetworkV1(nn.Module):\n",
"    \"\"\"\n",
"    Feed-forward binary classifier: input -> hidden_size -> hidden_size // 2 -> 1.\n",
"\n",
"    ReLU follows each of the first two linear layers; the single output unit goes\n",
"    through a sigmoid, so forward() returns probabilities in (0, 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size):\n",
"        super().__init__()\n",
"        half_width = hidden_size // 2\n",
"\n",
"        self.fc1 = nn.Linear(input_size, hidden_size)\n",
"        self.fc2 = nn.Linear(hidden_size, half_width)\n",
"        self.fc3 = nn.Linear(half_width, 1)\n",
"\n",
"        self.relu = nn.ReLU()\n",
"        self.sigmoid = nn.Sigmoid()\n",
"\n",
"    def forward(self, x):\n",
"        hidden = self.relu(self.fc1(x))\n",
"        hidden = self.relu(self.fc2(hidden))\n",
"        return self.sigmoid(self.fc3(hidden))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:04:14.970423300Z",
"start_time": "2024-06-08T16:04:14.953618100Z"
}
},
"id": "9fba79643e9e76f3"
},
{
"cell_type": "code",
"execution_count": 29,
"outputs": [],
"source": [
"# V2\n",
"# Four fully connected layers with ReLU activation function and dropout layers\n",
"# Output layer with Sigmoid activation function\n",
"class NeuralNetworkV2(nn.Module):\n",
"    \"\"\"\n",
"    Feed-forward binary classifier with dropout regularisation.\n",
"\n",
"    Widths: input -> hidden_size -> hidden_size // 2 -> hidden_size // 4 -> 1, with\n",
"    ReLU after every hidden layer and Dropout(p=0.5) after the first two. nn.Dropout\n",
"    only acts in training mode, so call model.eval() before inference.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size):\n",
"        super().__init__()\n",
"        widths = (hidden_size, hidden_size // 2, hidden_size // 4)\n",
"\n",
"        self.fc1 = nn.Linear(input_size, widths[0])\n",
"        self.dropout1 = nn.Dropout(0.5)\n",
"        self.fc2 = nn.Linear(widths[0], widths[1])\n",
"        self.dropout2 = nn.Dropout(0.5)\n",
"        self.fc3 = nn.Linear(widths[1], widths[2])\n",
"        self.fc4 = nn.Linear(widths[2], 1)\n",
"\n",
"        self.relu = nn.ReLU()\n",
"        self.sigmoid = nn.Sigmoid()\n",
"\n",
"    def forward(self, x):\n",
"        hidden = self.dropout1(self.relu(self.fc1(x)))\n",
"        hidden = self.dropout2(self.relu(self.fc2(hidden)))\n",
"        hidden = self.relu(self.fc3(hidden))  # no dropout before the output layer\n",
"        return self.sigmoid(self.fc4(hidden))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:04:15.743609500Z",
"start_time": "2024-06-08T16:04:15.731835400Z"
}
},
"id": "d9d6297f85fb83fc"
},
{
"cell_type": "code",
"execution_count": 30,
"outputs": [],
"source": [
"# V3\n",
"# Four fully connected layers with Leaky ReLU activation function\n",
"# Output layer with Sigmoid activation function\n",
"class NeuralNetworkV3(nn.Module):\n",
"    \"\"\"\n",
"    Feed-forward binary classifier using LeakyReLU (negative slope 0.1).\n",
"\n",
"    Widths: input -> hidden_size -> hidden_size // 2 -> hidden_size // 4 -> 1;\n",
"    the output unit is passed through a sigmoid.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size):\n",
"        super().__init__()\n",
"        widths = (hidden_size, hidden_size // 2, hidden_size // 4)\n",
"\n",
"        self.fc1 = nn.Linear(input_size, widths[0])\n",
"        self.fc2 = nn.Linear(widths[0], widths[1])\n",
"        self.fc3 = nn.Linear(widths[1], widths[2])\n",
"        self.fc4 = nn.Linear(widths[2], 1)\n",
"\n",
"        self.leaky_relu = nn.LeakyReLU(0.1)\n",
"        self.sigmoid = nn.Sigmoid()\n",
"\n",
"    def forward(self, x):\n",
"        hidden = x\n",
"        for layer in (self.fc1, self.fc2, self.fc3):\n",
"            hidden = self.leaky_relu(layer(hidden))\n",
"        return self.sigmoid(self.fc4(hidden))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:04:15.989218400Z",
"start_time": "2024-06-08T16:04:15.971488Z"
}
},
"id": "e9d8ad2014841d93"
},
{
"cell_type": "code",
"execution_count": 59,
"outputs": [],
"source": [
"# V4\n",
"# Two convolutional layers with ReLU activation function and max pooling layers\n",
"# Two fully connected layers with ReLU activation function\n",
"# Output layer with Sigmoid activation function\n",
"class NeuralNetworkV4(nn.Module):\n",
"    \"\"\"\n",
"    1-D CNN binary classifier that treats each feature vector as a one-channel signal.\n",
"\n",
"    Two stages of Conv1d(kernel 3, padding 1) -> ReLU -> MaxPool1d(2), then a\n",
"    flatten, Linear -> ReLU -> Linear(1) -> Sigmoid. Expects x of shape\n",
"    (batch, input_size).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size):\n",
"        super().__init__()\n",
"\n",
"        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)\n",
"        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1)\n",
"        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)\n",
"\n",
"        # kernel_size=3 with padding=1 keeps the length unchanged; each pooling\n",
"        # stage halves it (rounding down). Two stages -> input_size // 2 // 2.\n",
"        pooled_length = input_size // 2 // 2\n",
"\n",
"        self.fc1 = nn.Linear(32 * pooled_length, hidden_size)\n",
"        self.fc2 = nn.Linear(hidden_size, 1)\n",
"\n",
"        self.relu = nn.ReLU()\n",
"        self.sigmoid = nn.Sigmoid()\n",
"\n",
"    def forward(self, x):\n",
"        features = x.unsqueeze(1)  # add the single input channel: (batch, 1, input_size)\n",
"        for conv in (self.conv1, self.conv2):\n",
"            features = self.pool(self.relu(conv(features)))\n",
"        features = features.view(features.size(0), -1)  # flatten channels x length\n",
"        hidden = self.relu(self.fc1(features))\n",
"        return self.sigmoid(self.fc2(hidden))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:25:44.294141500Z",
"start_time": "2024-06-08T16:25:44.266927300Z"
}
},
"id": "19adc0031ce51f33"
},
{
"cell_type": "code",
"execution_count": 60,
"outputs": [],
"source": [
"# V5\n",
"# LSTM layer followed by two fully connected layers with ReLU between them\n",
"# Output layer with Sigmoid activation function\n",
"class NeuralNetworkV5(nn.Module):\n",
"    \"\"\"\n",
"    LSTM-based binary classifier.\n",
"\n",
"    A 2-D input of shape (batch, input_size) is treated as a batch of length-1\n",
"    sequences. A 3-D input of shape (batch, seq_len, input_size) is used as-is\n",
"    (batch_first=True). The LSTM output at the last time step feeds the fully\n",
"    connected head; forward() returns sigmoid probabilities of shape (batch, 1).\n",
"    \"\"\"\n",
"    def __init__(self, input_size, hidden_size):\n",
"        super(NeuralNetworkV5, self).__init__()\n",
"\n",
"        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)\n",
"        self.fc1 = nn.Linear(hidden_size, hidden_size // 2)\n",
"        self.fc2 = nn.Linear(hidden_size // 2, 1)\n",
"\n",
"        self.relu = nn.ReLU()\n",
"        self.sigmoid = nn.Sigmoid()\n",
"\n",
"    def forward(self, x):\n",
"        # nn.LSTM interprets a 2-D tensor as ONE unbatched sequence (seq_len, features):\n",
"        # every sample would be fed through a single recurrence and the 3-D indexing\n",
"        # below would raise IndexError. Add a time axis so each row is its own sequence.\n",
"        if x.dim() == 2:\n",
"            x = x.unsqueeze(1)\n",
"        out, _ = self.lstm(x)\n",
"        out = out[:, -1, :]  # last time step: (batch, hidden_size)\n",
"        out = self.fc1(out)\n",
"        out = self.relu(out)\n",
"        out = self.fc2(out)\n",
"        out = self.sigmoid(out)\n",
"        return out"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:25:45.194410100Z",
"start_time": "2024-06-08T16:25:45.173984800Z"
}
},
"id": "3b404d9e7019d04"
},
{
"cell_type": "markdown",
"source": [
"#### Training and evaluation"
],
"metadata": {
"collapsed": false
},
"id": "7966001ff35b88d7"
},
{
"cell_type": "code",
"execution_count": 61,
"outputs": [],
"source": [
"# Training function\n",
"def train(model, X_train, y_train, criterion, optimizer, epochs=100, log_every=10):\n",
"    \"\"\"\n",
"    Train the neural network with full-batch gradient descent.\n",
"\n",
"    Every epoch runs one forward/backward pass over the whole of X_train and\n",
"    takes a single optimizer step.\n",
"\n",
"    Args:\n",
"        model: nn.Module whose output is compared with y_train by `criterion`.\n",
"        X_train, y_train: training inputs and targets (expected on the model's device).\n",
"        criterion: loss called as criterion(predictions, targets).\n",
"        optimizer: optimizer over model.parameters().\n",
"        epochs: number of full-batch update steps.\n",
"        log_every: print the loss every `log_every` epochs; 0 or None disables logging.\n",
"    \"\"\"\n",
"    # Re-enable training behaviour (dropout etc.): the model may have been left in\n",
"    # eval mode by an earlier evaluation or by the caller.\n",
"    model.train()\n",
"    for epoch in range(epochs):\n",
"        optimizer.zero_grad()\n",
"        y_pred = model(X_train)\n",
"        loss = criterion(y_pred, y_train)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        if log_every and (epoch + 1) % log_every == 0:\n",
"            print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:25:46.534167500Z",
"start_time": "2024-06-08T16:25:46.524741Z"
}
},
"id": "ebab09eac524c418"
},
{
"cell_type": "code",
"execution_count": 62,
"outputs": [],
"source": [
"# Evaluation function\n",
"def evaluate(model, X_test, y_test, target_names=('Benign', 'Malignant'), threshold=0.5):\n",
"    \"\"\"\n",
"    Evaluate a binary classifier on held-out data.\n",
"\n",
"    The model is switched to eval mode for the forward pass, so dropout layers are\n",
"    disabled and predictions are deterministic. Its previous mode is restored afterwards.\n",
"\n",
"    Args:\n",
"        model: nn.Module returning probabilities for class 1.\n",
"        X_test, y_test: test inputs and 0/1 targets.\n",
"        target_names: report names for class 0 and class 1, in that order.\n",
"        threshold: probability above which a prediction counts as class 1.\n",
"\n",
"    Returns:\n",
"        (cm, cr, acc): confusion matrix, classification report as a dict, and accuracy.\n",
"    \"\"\"\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            y_pred = (model(X_test) > threshold).float()\n",
"    finally:\n",
"        model.train(was_training)\n",
"\n",
"    # sklearn works on 1-D host arrays; targets/predictions are (N, 1) tensors\n",
"    y_true = y_test.cpu().numpy().ravel()\n",
"    y_hat = y_pred.cpu().numpy().ravel()\n",
"    cm = confusion_matrix(y_true, y_hat)\n",
"    cr = classification_report(y_true, y_hat, target_names=list(target_names), output_dict=True)\n",
"    acc = accuracy_score(y_true, y_hat)\n",
"\n",
"    return cm, cr, acc"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:25:46.825080Z",
"start_time": "2024-06-08T16:25:46.805541600Z"
}
},
"id": "b89d7e317f1c8b63"
},
{
"cell_type": "code",
"execution_count": 63,
"outputs": [],
"source": [
"# Neural network parameters\n",
"# Seed PyTorch (all devices) so weight initialisation and dropout masks are repeatable\n",
"# on Restart & Run All; the train/test split is already seeded via random_state=42.\n",
"random_seed = 42\n",
"torch.manual_seed(random_seed)\n",
"\n",
"input_size = X_train.shape[1]  # number of input features\n",
"hidden_size = 128\n",
"learning_rate = 0.001\n",
"weight_decay = 0.0001\n",
"epochs = 1000"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:25:47.991207300Z",
"start_time": "2024-06-08T16:25:47.974823800Z"
}
},
"id": "d967f53803f97353"
},
{
"cell_type": "markdown",
"source": [
"#### Neural network V1"
],
"metadata": {
"collapsed": false
},
"id": "b8df2e2d1ce6e786"
},
{
"cell_type": "code",
"execution_count": 45,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/1000, Loss: 0.4779872000217438\n",
"Epoch 20/1000, Loss: 0.27251550555229187\n",
"Epoch 30/1000, Loss: 0.15164388716220856\n",
"Epoch 40/1000, Loss: 0.10145729035139084\n",
"Epoch 50/1000, Loss: 0.07789547741413116\n",
"Epoch 60/1000, Loss: 0.06484830379486084\n",
"Epoch 70/1000, Loss: 0.05551866441965103\n",
"Epoch 80/1000, Loss: 0.04840008169412613\n",
"Epoch 90/1000, Loss: 0.042578838765621185\n",
"Epoch 100/1000, Loss: 0.03765585273504257\n",
"Epoch 110/1000, Loss: 0.03336893767118454\n",
"Epoch 120/1000, Loss: 0.02957029454410076\n",
"Epoch 130/1000, Loss: 0.02613999880850315\n",
"Epoch 140/1000, Loss: 0.023030269891023636\n",
"Epoch 150/1000, Loss: 0.02021847851574421\n",
"Epoch 160/1000, Loss: 0.01771070621907711\n",
"Epoch 170/1000, Loss: 0.015495861880481243\n",
"Epoch 180/1000, Loss: 0.013572530820965767\n",
"Epoch 190/1000, Loss: 0.011898759752511978\n",
"Epoch 200/1000, Loss: 0.010441687889397144\n",
"Epoch 210/1000, Loss: 0.009163236245512962\n",
"Epoch 220/1000, Loss: 0.00809294544160366\n",
"Epoch 230/1000, Loss: 0.007174020167440176\n",
"Epoch 240/1000, Loss: 0.006399359088391066\n",
"Epoch 250/1000, Loss: 0.005741355009377003\n",
"Epoch 260/1000, Loss: 0.005172951612621546\n",
"Epoch 270/1000, Loss: 0.004688368644565344\n",
"Epoch 280/1000, Loss: 0.004283885937184095\n",
"Epoch 290/1000, Loss: 0.003929890692234039\n",
"Epoch 300/1000, Loss: 0.0036243184003978968\n",
"Epoch 310/1000, Loss: 0.0033553754910826683\n",
"Epoch 320/1000, Loss: 0.0031130558345466852\n",
"Epoch 330/1000, Loss: 0.0028987200930714607\n",
"Epoch 340/1000, Loss: 0.0027084406465291977\n",
"Epoch 350/1000, Loss: 0.0025372877717018127\n",
"Epoch 360/1000, Loss: 0.002381594618782401\n",
"Epoch 370/1000, Loss: 0.002238793997094035\n",
"Epoch 380/1000, Loss: 0.002110145753249526\n",
"Epoch 390/1000, Loss: 0.001990949036553502\n",
"Epoch 400/1000, Loss: 0.0018805447034537792\n",
"Epoch 410/1000, Loss: 0.0017787017859518528\n",
"Epoch 420/1000, Loss: 0.0016847399529069662\n",
"Epoch 430/1000, Loss: 0.001597930327989161\n",
"Epoch 440/1000, Loss: 0.001518208417110145\n",
"Epoch 450/1000, Loss: 0.0014447758439928293\n",
"Epoch 460/1000, Loss: 0.0013784606708213687\n",
"Epoch 470/1000, Loss: 0.0013172627659514546\n",
"Epoch 480/1000, Loss: 0.0012608648976311088\n",
"Epoch 490/1000, Loss: 0.001208683243021369\n",
"Epoch 500/1000, Loss: 0.0011611600639298558\n",
"Epoch 510/1000, Loss: 0.0011176610132679343\n",
"Epoch 520/1000, Loss: 0.0010775947012007236\n",
"Epoch 530/1000, Loss: 0.0010414356365799904\n",
"Epoch 540/1000, Loss: 0.0010077828774228692\n",
"Epoch 550/1000, Loss: 0.000977047486230731\n",
"Epoch 560/1000, Loss: 0.0009483486064709723\n",
"Epoch 570/1000, Loss: 0.0009219619678333402\n",
"Epoch 580/1000, Loss: 0.000897533493116498\n",
"Epoch 590/1000, Loss: 0.0008748812833800912\n",
"Epoch 600/1000, Loss: 0.0008537117973901331\n",
"Epoch 610/1000, Loss: 0.0008338657789863646\n",
"Epoch 620/1000, Loss: 0.0008152445661835372\n",
"Epoch 630/1000, Loss: 0.000797846878413111\n",
"Epoch 640/1000, Loss: 0.0007811780087649822\n",
"Epoch 650/1000, Loss: 0.0007652725907973945\n",
"Epoch 660/1000, Loss: 0.0007502862135879695\n",
"Epoch 670/1000, Loss: 0.0007362824399024248\n",
"Epoch 680/1000, Loss: 0.0007233091746456921\n",
"Epoch 690/1000, Loss: 0.0007109907455742359\n",
"Epoch 700/1000, Loss: 0.0006993163260631263\n",
"Epoch 710/1000, Loss: 0.0006883330643177032\n",
"Epoch 720/1000, Loss: 0.000678009819239378\n",
"Epoch 730/1000, Loss: 0.0006681602098979056\n",
"Epoch 740/1000, Loss: 0.0006588594405911863\n",
"Epoch 750/1000, Loss: 0.0006500912713818252\n",
"Epoch 760/1000, Loss: 0.0006416388787329197\n",
"Epoch 770/1000, Loss: 0.0006337724043987691\n",
"Epoch 780/1000, Loss: 0.0006261061644181609\n",
"Epoch 790/1000, Loss: 0.0006186887621879578\n",
"Epoch 800/1000, Loss: 0.0006119576282799244\n",
"Epoch 810/1000, Loss: 0.0006054439581930637\n",
"Epoch 820/1000, Loss: 0.0005992540973238647\n",
"Epoch 830/1000, Loss: 0.0005932800122536719\n",
"Epoch 840/1000, Loss: 0.000587515824008733\n",
"Epoch 850/1000, Loss: 0.0005819853395223618\n",
"Epoch 860/1000, Loss: 0.0005764576490037143\n",
"Epoch 870/1000, Loss: 0.0005712246056646109\n",
"Epoch 880/1000, Loss: 0.0005661725299432874\n",
"Epoch 890/1000, Loss: 0.0005614019464701414\n",
"Epoch 900/1000, Loss: 0.0005567952175624669\n",
"Epoch 910/1000, Loss: 0.0005523671861737967\n",
"Epoch 920/1000, Loss: 0.000548191019333899\n",
"Epoch 930/1000, Loss: 0.0005440027453005314\n",
"Epoch 940/1000, Loss: 0.0005400135414674878\n",
"Epoch 950/1000, Loss: 0.0005360668292269111\n",
"Epoch 960/1000, Loss: 0.0005323308287188411\n",
"Epoch 970/1000, Loss: 0.0005282927886582911\n",
"Epoch 980/1000, Loss: 0.0005240424652583897\n",
"Epoch 990/1000, Loss: 0.0005197781720198691\n",
"Epoch 1000/1000, Loss: 0.0005156174884177744\n"
]
}
],
"source": [
"# Neural network V1: build, train and evaluate\n",
"model_v1 = NeuralNetworkV1(input_size, hidden_size).to(device)\n",
"criterion_v1 = nn.BCELoss()  # the model already ends in a sigmoid, so BCE is applied to probabilities\n",
"optimizer_v1 = optim.Adam(\n",
"    model_v1.parameters(),\n",
"    lr=learning_rate,\n",
"    weight_decay=weight_decay,\n",
")\n",
"\n",
"train(model_v1, X_train, y_train, criterion_v1, optimizer_v1, epochs=epochs)\n",
"cm_v1, cr_v1, acc_v1 = evaluate(model_v1, X_test, y_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:07:07.255395800Z",
"start_time": "2024-06-08T16:07:04.326790700Z"
}
},
"id": "7f48dfd17faaa0b3"
},
{
"cell_type": "code",
"execution_count": 46,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwIAAAIhCAYAAAD98w2UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABVu0lEQVR4nO3deVxUZf//8feADCCgJu5m5JJLqIi4pFkZ6u2WuaW3LRhqouXSXZqmZmqKuLQqprhVLrnlkpWVe2qu4Z7L7ZZppDfkFoKMwvz+8Od8ZwSTMWAmzuvp4zxu5jpnrvM5wE3nM5/ruo7JarVaBQAAAMBQPFwdAAAAAIC8RyIAAAAAGBCJAAAAAGBAJAIAAACAAZEIAAAAAAZEIgAAAAAYEIkAAAAAYEAkAgAAAIABkQgAwB3wvMV/Nn5+APDXSAQAN3DgwAG98cYbaty4sWrWrKmmTZtq+PDhOnPmTK6d89NPP9Wjjz6qmjVr6uOPP86RPnfs2KEqVapox44dOdKfK3388ceaNWvWXY8LDw/Xm2++mQcR3TxX7dq1lZCQkOX+KlWqaPLkyXkSS3ZFREQoIiLijvtv/c688sorWe5ftmyZqlSporNnzzp13uz+/HLDm2++qfDw8GwfP3XqVFWpUkX79++/4zGjR49WaGiokpOTM+3r379/nv0OAshfSAQAF5s/f766dOmiP/74QwMGDNCMGTMUFRWlnTt36plnntGRI0dy/JzJyckaP368atasqVmzZql9+/Y50m9wcLAWLVqk4ODgHOnPlT766COlpqbe9bjY2Ng73sTmhqtXr+qtt97Ks/PllXXr1mnlypU51l92f37uoH379vLw8NBXX32V5X6LxaKvv/5aLVq0kL+/v609IyND0dHR+v777/MqVAD5DIkA4ELx8fGKjo7Wc889p9mzZ6tNmzaqX7++OnfurAULFsjb21tDhw7N8fNevnxZGRkZatq0qerWravSpUvnSL/+/v6qVauWw81Kfvfwww/rgQceyLPzFSpUSD/++KMWL16cZ+fMC4UKFVJ0dLSSkpJcHUqeK1WqlBo1aqRVq1YpPT090/4ffvhBly5d0jPPPGNrO3LkiLp27arFixfLx8cnL8MFkI+QCAAuNGvWLAUEBOj111/PtK9o0aJ688031aRJE6WkpEiS0tPTNX/+fLVp00Y1a9ZU48aN9e677yotLc32vjfffFORkZFaunSpmjdvrurVq6tt27batGmTpJtDLW4NWxg6dKiqVKkiKeshLrcPy7h27ZpGjhypxx9/XNWrV1eLFi0chl9kNTTowIED6tGjh+rXr6/atWurd+/eOnbsWKb3bNu2Td27d1dISIgeffRRTZw4McubIvvYatSooZ9++kkdO3ZUjRo11Lx5c61fv14nT57Uiy++qJCQEDVr1kzffPONw3t37dqlHj16qG7duqpevbrCw8M1efJkZWRkSJLtexIbG2v7evLkyWrWrJliY2NVr149NWrUSJcvX3b4vsXExKhKlSravn17pu/hihUr7ngtzggPD1e9evU0fvx4/f777395bEZGhqZPn65mzZqpevXqat68uebOnZupv7v93O907deuXdN7772nf/3rX6pevbpq166tbt266fDhw05f12uvvaaUlBSNHDnyrscmJCTo9ddfV7169RQSEqIXX3xRhw4dsu2//ec3Z84cVa1aVRcvXrQdM2XKFNvv3S1r165V1apVdf78eUnZ/91duHChnnzySdWuXVs//vhjpngPHTqkOnXqqGfPnrJYLFleU8eOHZWUlOQQzy3Lly9XhQoVFBYWZmsbPHiw0tPTtWjRIgUGBt71ewYAWSERAFzEarVqy5YtatCggXx9fbM8plWrVurTp48KFiwoSXr77bcVExOjpk2baurUqXr++ec1b948vfLKKw4TIw8ePKhZs2apf//+mjJlijw9PdWvXz9dvnxZjRs3VmxsrCTp5Zdf1qJFi7
Id89ixY7Vp0yYNHjxYs2bNUpMmTTRhwgQtXbo0y+O3b9+uZ5991vbeMWPG6Pfff1eXLl104sQJh2MHDhyosLAwTZs2TU899ZRmzpypJUuW/GU8N27c0IABA9SlSxdNnTpVvr6+GjhwoHr37q3GjRtr2rRpKlGihAYPHqxz585JuvlJamRkpIoUKaIPPvhAU6dOVZ06dRQbG6tvv/1Wkmzfk2eeecbh+5OQkKAffvhBH3zwgYYMGaLChQs7xPPaa6/pwQcf1IgRI2SxWJSQkKDo6Gi1bNlS7dq1y/b3+a+YTCaNHTtWGRkZdx0iNHLkSE2aNElPP/20pk2bphYtWmjs2LGaMmWK0+fN6toHDRqkpUuXKioqSrNnz9aQIUN07NgxDRgwwOmJuhUrVlS/fv20Zs0aff3113c87sKFC+rSpYt+/vlnDR8+XO+9954yMjL0/PPP236nbv/5NW7cWFar1SFBu/X1rl27bG2bNm3Sww8/rJIlSzr1uxsbG6vBgwfr7bffVmhoqMO+EydOqEePHgoJCdGUKVNkNpuzvK7w8HDdd999mYYHXbhwQZs2bXKoBkjShAkTtGDBAlWtWvWO3ysAuJsCrg4AMKqLFy8qLS1N999/f7aOP378uL744gsNGDBAUVFRkqRHH31UJUqU0KBBg7Rp0yY98cQTkqQ///xTy5Ytsw1ZKViwoF544QVt375dzZs3V7Vq1SRJDzzwgGrVqpXtmHfu3KlHH31UrVu3liTVr19fBQsWvOMnku+9956CgoI0ffp0eXp6SpIaNWqkZs2aadKkSfroo49sx3bq1El9+vSRJDVo0EBr167Vxo0b1aVLlzvGk5GRod69e6tTp06SpCtXrui1117Tiy++qG7dukmSAgIC1LFjRx08eFClSpXSkSNH1LBhQ02cOFEeHh627+P69eu1Y8cOtW7d2vY9KVWqlMP358aNGxo8eLDq1KmTZTw+Pj4aN26cnnvuOU2fPl27d++Wv7+/Ro0adbdvrVPKlSun119/XWPGjNGSJUts12/v1KlTWrx4sV5//XXb70ujRo1kMpkUFxen5557Tvfdd1+2z3n7tVssFtt8hVatWkmS6tWrp+TkZI0bN05JSUkqXry4U9fVo0cPrVmzRqNHj9YjjzyiYsWKZTrms88+06VLl7RgwQKVLVtWkvT444+rVatW+uijjzRp0qQsf37ly5fXtm3b1LJlS6WmpmrPnj0KDg52SAQ2b96sDh06SHLud/e5555TixYtMsV65swZRUZGqmrVqvr444/vmARIktls1tNPP60vvvhCo0aNsg33uVXNuj2RvFX1AIC/g4oA4CK3bi7+aviLvZ07d0qS7Sb8ltatW8vT09NhOE7RokUdxq2XKlVKkv725Mn69etr8eLF6tmzp+bNm6czZ86oT58+aty4caZjU1JSdODAAbVs2dJ2rdLNseBPPvmk7Xpuuf2T1FKlStmGRP0V+/fdSkhCQkJsbUWKFJF0M0mQbt5QzZgxQ9evX9eRI0f0/fffa9KkSUpPT9f169fver5bSdRfxRMZGakpU6Zo69atGjduXKbKgb309HTduHHDtmX39+GFF15Q3bp1NW7cOFu1w9727dtltVoVHh7u0H94eLjS0tIUHx+frfPYs792s9msWbNmqVWrVjp//ry2b9+uhQsXasOGDZJ0xyEwf8XT01MxMTFKSUm5Y/K0bds2VatWTSVLlrRdk4eHhx5//HFt3br1jn03btzYtj8+Pl5eXl7q2rWr9u3bJ4vFouPHjyshIUGNGzd2+nc3q9+Jq1evKjIyUomJiRo1apS8vb3vev0dO3bU1atXtX79elvb8uXL1bhxY4b/AMgVJAKAixQuXFh+fn53XApSunkzffnyZUmy/e/tn7IWKFBA9913n/78809b2+1DjUwmkyTZxsDfq2HDhuk///mPzp49q9GjR6tp06bq0qVLlisb/fnnn7JarVl+qlusWDGHeCVlmvDo4eGRre
ElWU1MvtNQK+nmPIdhw4YpLCxM7dq108SJE/Xbb7+pQIEC2Tqfn5/fXY9p3769MjIyVKxYMYekJCuRkZEKDg62bZG
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Confusion matrix for Neural Network V1\n",
"plot_confusion_matrix(cm_v1, classes=['Benign', 'Malignant'], title='Confusion matrix - Neural Network V1')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:07:09.111692Z",
"start_time": "2024-06-08T16:07:08.856357500Z"
}
},
"id": "dd60663756e0d8c0"
},
{
"cell_type": "code",
"execution_count": 47,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAw0AAAIhCAYAAAAM+FYZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAACHk0lEQVR4nOzdd3hU1dbH8e+kTBJSIAQIVToEAoQQIHRBQOkglov6gggqKCBKU+BKkyYqKqACKioWpFsQEOlSBAxdEjoYWgik9zbvH8iYcRBJrslJyO9zn3kus8+eM+tENpk1a+99TBaLxYKIiIiIiMjfcDA6ABERERERKdiUNIiIiIiIyG0paRARERERkdtS0iAiIiIiIrelpEFERERERG5LSYOIiIiIiNyWkgYREREREbktJQ0iIiIiInJbShpERAoB3Yfz7+lnIyKS95Q0iEiBcuTIEUaPHk3btm1p0KABHTp04NVXXyU8PNymX+3atZk7d26+xjZ37lxq165tfZ6QkMDgwYMJCAigSZMmnDt3jtq1a7Nq1ap/9X03bdrEyy+/bH2+Z88eateuzZ49e/7V9yls4uLiGDNmDL/++qvRoYiI3PWcjA5AROSmL7/8kunTpxMcHMzIkSMpU6YM58+f5+OPP2bDhg189tln+Pn5GRbfI488QuvWra3Pv/nmG7Zs2cKECROoWbMm5cuXZ+nSpdxzzz3/6vt++umnNs/9/f1ZunQpNWrU+Fffp7AJDQ3l22+/5aGHHjI6FBGRu56SBhEpEEJCQpg2bRpPPPEE48ePt7YHBwfToUMHevXqxbhx4/71b/FzomzZspQtW9b6PCYmBoDHH38ck8kEQMOGDfM8Dg8Pj3x5HxERkZs0PUlECoSPP/4YT09PRowYYXesZMmSvPLKK7Rv356kpKRbvj4sLIyhQ4fSrFkz/P39ad26NVOnTiUlJcXaZ+fOnTz66KMEBgbSpEkTnnvuOU6fPm09/vvvvzN48GCCg4MJCAjgP//5D9u2bbMezz49qW/fvtbpUX5+frzyyitcuHDBbnrSmTNnGDp0KE2bNqVJkyYMGjTI5j0vXLjAmDFjaNWqFf7+/jRv3pwxY8YQHR1tfZ+9e/eyd+9e65SkW01POnLkCAMHDiQ4OJhGjRoxePBgTp48aT1+8zW7d+9mwIABBAQE0LJlS9544w0yMzP/9r/LqlWrqFu3LsuXL6dly5Y0bdqUU6dOAbBx40Z69+5N/fr1admyJVOnTrX57zN37lzuu+8+tmzZQqdOnQgICODRRx+1m1Z19epVxo4dy7333kuDBg14+OGH2bRpk02f2rVrM2/ePHr37k2DBg2YN28e/fr1A6Bfv3707dv3b69BRET+d0oaRMRwFouFHTt20Lx5c9zc3G7Zp0uXLgwZMoRixYrZHbt69SpPPPEEycnJzJw5kw8//JCuXbvy+eefs3jxYgDCw8N5/vnnqVevHh988AHTpk3j7NmzPPvss2RlZZGVlcWgQYNITk5m1qxZvP/++5QoUYLnnnuO8+fP273nxIkTefjhhwFYunQpzz//vF2fiIgI/vOf/3Du3DkmTZrEG2+8wbVr13jyySeJiYkhOTmZfv36cfr0aSZOnMjHH39Mv379+OGHH3j77bet71O3bl3q1q3L0qVL8ff3t3ufX375hcceewyA6dOnM3XqVC5fvkyfPn1sEhSAUaNGERQUxPz58+nWrRsfffQRy5cvv91/HjIzM1m0aBHTpk1j7NixVK9ene+//54hQ4ZQrVo13nvvPYYOHcp3333H888/b7MwOSoqipdffpnHH3+cd999F1dXVwYOHEhoaCgA165d4+GHH+bXX3/lpZdeYu7cuVSoUIEhQ4bw3Xff2cQxf/58unfvzpw5c+jQoQMTJkwAYMKECUycOPG21yAiIv8bTU8SEcNFR0eTmppKxYoVc/X6EydOUKdOHd599108PDwAaNGiBTt37mTPnj08++yzHD
58mJSUFAYNGoSvry9wY7rRpk2bSEpKIjk5mTNnzvD8889z7733Ali/0U5LS7N7zxo1alinKt2cKnThwgWbPp9++ilpaWl88sknlC5dGrhRlXjsscc4dOgQZcqUoWzZsrz++utUqlQJgGbNmnHo0CH27t1rfZ+b1/R3U5LeeustKleuzMKFC3F0dASgVatWdOzYkTlz5vDuu+9a+z7yyCMMGTIEgObNm7Nx40a2bt1Knz59bvszHjx4MG3btgVuJHlvvvkmrVu35s0337T2qVKlCv3792fbtm3WvsnJyUyaNIlevXpZr69Dhw4sXLiQt99+m08++YSoqCh+/PFHKlSoAMC9995L//79mTVrFt26dcPB4cb3W40bN+app56yvl9sbKz1Z1TU13eIiOQ1JQ0iYribH3RvN03mdlq1akWrVq1IT0/n1KlTnD9/nhMnThAVFUWJEiUACAgIwMXFhYcffphOnTrRpk0bgoODadCgAQDu7u7UqFGDV199lR07dtCqVSvatGnD2LFjc31dISEhNGzY0JowwI1EZcuWLdbnX331FVlZWZw7d47z589z6tQpzpw5Q0ZGxh29R1JSEkeOHGHo0KHWnyOAl5cX7dq1s5leBRAYGGjzvGzZsn875Su7OnXqWP985swZrly5wqBBg2zibNKkCR4eHuzcudOaNDg5OdGtWzdrH1dXV9q0acP27dsB2Lt3L4GBgdaE4aYePXowduxYzpw5Y00IsscgIiL5S0mDiBiuePHiuLu7c+nSpb/tk5SURHp6OsWLF7c7lpWVxezZs/nyyy9JSkqiXLlyNGjQABcXF2ufihUr8sUXX7Bw4UJWrFjB4sWL8fLy4vHHH+fFF1/EZDKxaNEiPvjgA3766Se++eYbnJ2d6dChA5MnT77l+/6TmJiYf6yefPLJJ8yfP5+YmBhKlSpFvXr1cHNzIz4+/o7eIz4+HovFQqlSpeyOlSpVyu48rq6uNs8dHBzu6D4H2aeF3VwAPnnyZCZPnmzX9+rVqzYxODnZ/qrx8fGxniM2NtZaZflr7HBjW9VbxSAiIvlLSYOIFAitWrViz549pKam2nzYv2nZsmW8/vrrrFixwm5e/8KFC/n000+ZPHky999/P56engDWNQc3ZZ9uFBISwtKlS5k/fz5+fn507twZX19fJk2axMSJEwkLC2P9+vV8+OGHeHt752rOvKenJ1FRUXbtu3fvpmLFihw8eJCZM2cyevRoevfuTcmSJQEYPnw4R44cueP3MJlMXLt2ze5YZGSktdLyb/Ly8gJgzJgxNG3a1O549gTrZnKQ3bVr1/Dx8bH2jYyMtOtzs83b2/vfCFlERP5HWggtIgXCgAEDiImJ4Z133rE7FhkZyaJFi6hRo8YtFwKHhIRQo0YNHnroIWvCEBERwYkTJ8jKygJurC9o164daWlpmM1mmjdvzmuvvQbApUuXOHDgAC1atODw4cOYTCbq1KnDSy+9RK1atW5bAbmdxo0bc+jQIZvE4fr16zz99NNs27aNkJAQvLy8ePrpp60JQ2JiIiEhIda4Aeuc/lspVqwY9erVY926dTbTu+Lj49m6dStBQUG5iv12qlWrho+PDxcuXKB+/frWh6+vL2+99RbHjh2z9k1JSeHnn3+2eb59+3aaN28O3JjSdODAAS5evGjzHt999x2lS5emcuXKfxtH9ulYIiKSt1RpEJECoWHDhgwfPpx33nmH06dP06tXL7y9vTl58iQff/wxqampt0wo4EYF4f3332fhwoU0bNiQ8+fPs2DBAtLS0khOTgZuLMB98803GTJkCP/3f/+Ho6MjX3/9NWazmXbt2lGhQgVcXV0ZM2YMw4YNo1SpUuzatYvQ0FDr1p451b9/f7755huefvppBg0ahLOzMx988AFly5ale/fubNq0iSVLljBz5kzatWvH1atX+fjjj7l27ZrNt/VeXl4cOHCA3bt3U7duXbv3GTlyJAMHDuTZZ5/l8ccfJz09nYULF5
KWlmZd9PxvcnR05KWXXmLChAk4OjrSrl074uLieP/994mIiLBL7MaOHcuLL76Ij48PH3/8MUlJSTz33HMAPPXUU3z
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot classification report\n",
"plot_classification_report(cr_v1)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:07:10.698060100Z",
"start_time": "2024-06-08T16:07:10.328701100Z"
}
},
"id": "9a5ab9310c986f72"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "c3b0d585910aef3e"
},
{
"cell_type": "markdown",
"source": [
"#### Neural network V2"
],
"metadata": {
"collapsed": false
},
"id": "19a30f696980b3ee"
},
{
"cell_type": "code",
"execution_count": 48,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/1000, Loss: 0.6100791096687317\n",
"Epoch 20/1000, Loss: 0.40138334035873413\n",
"Epoch 30/1000, Loss: 0.2066049426794052\n",
"Epoch 40/1000, Loss: 0.1198694109916687\n",
"Epoch 50/1000, Loss: 0.10492949932813644\n",
"Epoch 60/1000, Loss: 0.08525355905294418\n",
"Epoch 70/1000, Loss: 0.07265784591436386\n",
"Epoch 80/1000, Loss: 0.07437089085578918\n",
"Epoch 90/1000, Loss: 0.04634793847799301\n",
"Epoch 100/1000, Loss: 0.04539191350340843\n",
"Epoch 110/1000, Loss: 0.03791217878460884\n",
"Epoch 120/1000, Loss: 0.056155040860176086\n",
"Epoch 130/1000, Loss: 0.02974613755941391\n",
"Epoch 140/1000, Loss: 0.028519876301288605\n",
"Epoch 150/1000, Loss: 0.02841273508965969\n",
"Epoch 160/1000, Loss: 0.02827402390539646\n",
"Epoch 170/1000, Loss: 0.03137960284948349\n",
"Epoch 180/1000, Loss: 0.021595297381281853\n",
"Epoch 190/1000, Loss: 0.03367958217859268\n",
"Epoch 200/1000, Loss: 0.03138892352581024\n",
"Epoch 210/1000, Loss: 0.024735331535339355\n",
"Epoch 220/1000, Loss: 0.013547890819609165\n",
"Epoch 230/1000, Loss: 0.016778510063886642\n",
"Epoch 240/1000, Loss: 0.013662113808095455\n",
"Epoch 250/1000, Loss: 0.014643474481999874\n",
"Epoch 260/1000, Loss: 0.04232195019721985\n",
"Epoch 270/1000, Loss: 0.011198709718883038\n",
"Epoch 280/1000, Loss: 0.014255641028285027\n",
"Epoch 290/1000, Loss: 0.017376599833369255\n",
"Epoch 300/1000, Loss: 0.006715432740747929\n",
"Epoch 310/1000, Loss: 0.015104355290532112\n",
"Epoch 320/1000, Loss: 0.005779958330094814\n",
"Epoch 330/1000, Loss: 0.006878014653921127\n",
"Epoch 340/1000, Loss: 0.010289205238223076\n",
"Epoch 350/1000, Loss: 0.008154270239174366\n",
"Epoch 360/1000, Loss: 0.0052977693267166615\n",
"Epoch 370/1000, Loss: 0.0059393011033535\n",
"Epoch 380/1000, Loss: 0.003750022267922759\n",
"Epoch 390/1000, Loss: 0.006243106909096241\n",
"Epoch 400/1000, Loss: 0.0048174685798585415\n",
"Epoch 410/1000, Loss: 0.008404634892940521\n",
"Epoch 420/1000, Loss: 0.005285304039716721\n",
"Epoch 430/1000, Loss: 0.003210554365068674\n",
"Epoch 440/1000, Loss: 0.0030219131149351597\n",
"Epoch 450/1000, Loss: 0.003663143841549754\n",
"Epoch 460/1000, Loss: 0.004113232716917992\n",
"Epoch 470/1000, Loss: 0.011188282631337643\n",
"Epoch 480/1000, Loss: 0.008383953012526035\n",
"Epoch 490/1000, Loss: 0.005484223831444979\n",
"Epoch 500/1000, Loss: 0.001833457383327186\n",
"Epoch 510/1000, Loss: 0.0026361148338764906\n",
"Epoch 520/1000, Loss: 0.0018964618211612105\n",
"Epoch 530/1000, Loss: 0.005411419551819563\n",
"Epoch 540/1000, Loss: 0.005162812303751707\n",
"Epoch 550/1000, Loss: 0.004074939526617527\n",
"Epoch 560/1000, Loss: 0.001993684796616435\n",
"Epoch 570/1000, Loss: 0.002496592467650771\n",
"Epoch 580/1000, Loss: 0.012827489525079727\n",
"Epoch 590/1000, Loss: 0.0010587115539237857\n",
"Epoch 600/1000, Loss: 0.0020602247677743435\n",
"Epoch 610/1000, Loss: 0.0010980992810800672\n",
"Epoch 620/1000, Loss: 0.0023741163313388824\n",
"Epoch 630/1000, Loss: 0.00123070168774575\n",
"Epoch 640/1000, Loss: 0.011475415900349617\n",
"Epoch 650/1000, Loss: 0.00989847257733345\n",
"Epoch 660/1000, Loss: 0.0012280159862712026\n",
"Epoch 670/1000, Loss: 0.0017485406715422869\n",
"Epoch 680/1000, Loss: 0.0012420162092894316\n",
"Epoch 690/1000, Loss: 0.0004315624828450382\n",
"Epoch 700/1000, Loss: 0.0007627215818502009\n",
"Epoch 710/1000, Loss: 0.00213691801764071\n",
"Epoch 720/1000, Loss: 0.0021272513549774885\n",
"Epoch 730/1000, Loss: 0.0009174205479212105\n",
"Epoch 740/1000, Loss: 0.0015678246272727847\n",
"Epoch 750/1000, Loss: 0.0018770662136375904\n",
"Epoch 760/1000, Loss: 0.00043499123421497643\n",
"Epoch 770/1000, Loss: 0.001615240820683539\n",
"Epoch 780/1000, Loss: 0.0023441719822585583\n",
"Epoch 790/1000, Loss: 0.0004717250994872302\n",
"Epoch 800/1000, Loss: 0.0006681744707748294\n",
"Epoch 810/1000, Loss: 0.0018840441480278969\n",
"Epoch 820/1000, Loss: 0.0016851990949362516\n",
"Epoch 830/1000, Loss: 0.006307181902229786\n",
"Epoch 840/1000, Loss: 0.00034771396894939244\n",
"Epoch 850/1000, Loss: 0.0006758614326827228\n",
"Epoch 860/1000, Loss: 0.0003922640171367675\n",
"Epoch 870/1000, Loss: 0.0011983055155724287\n",
"Epoch 880/1000, Loss: 0.0005702194175682962\n",
"Epoch 890/1000, Loss: 0.0027277739718556404\n",
"Epoch 900/1000, Loss: 0.000657720142044127\n",
"Epoch 910/1000, Loss: 0.0007201721309684217\n",
"Epoch 920/1000, Loss: 0.002012366196140647\n",
"Epoch 930/1000, Loss: 0.0005625162739306688\n",
"Epoch 940/1000, Loss: 0.0008252564002759755\n",
"Epoch 950/1000, Loss: 0.0021678118500858545\n",
"Epoch 960/1000, Loss: 0.002342545660212636\n",
"Epoch 970/1000, Loss: 0.0015022775623947382\n",
"Epoch 980/1000, Loss: 0.0007447443204000592\n",
"Epoch 990/1000, Loss: 0.00044577810331247747\n",
"Epoch 1000/1000, Loss: 0.0006745096179656684\n"
]
}
],
"source": [
"# Neural network V2\n",
"model_v2 = NeuralNetworkV2(input_size, hidden_size).to(device)\n",
"criterion_v2 = nn.BCELoss()\n",
"optimizer_v2 = optim.Adam(model_v2.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
"\n",
"# Train the model\n",
"train(model_v2, X_train, y_train, criterion_v2, optimizer_v2, epochs)\n",
"\n",
"# Evaluate the model\n",
"cm_v2, cr_v2, acc_v2 = evaluate(model_v2, X_test, y_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:07:41.007361Z",
"start_time": "2024-06-08T16:07:37.052988100Z"
}
},
"id": "cb37345ed4ad8443"
},
{
"cell_type": "code",
"execution_count": 49,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwIAAAIhCAYAAAD98w2UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABVFElEQVR4nO3deVxUZf//8feAbIoruWu4Je6oKKhZGWpumXu3WRhqLuXSXZqmZmqKaGaZYoqm3ZXllktWWu5puYYbrrdbhpHdkEuhwCDM7w9/zndGMBkDZuS8nj3O42auc811Pge48Xzmc13nmCwWi0UAAAAADMXN2QEAAAAAyHskAgAAAIABkQgAAAAABkQiAAAAABgQiQAAAABgQCQCAAAAgAGRCAAAAAAGRCIAAAAAGBCJAADcAc9bvL/x8wOAv0ciALiA2NhYvfbaa2rRooXq1aunVq1aady4cYqLi8u1Y/7nP//Rww8/rHr16umDDz7IkTH37NmjgIAA7dmzJ0fGc6YPPvhACxcuvGu/0NBQvf7663kQ0c1jNWzYUPHx8VnuDwgI0OzZs/MkluwKCwtTWFjYHfff+p156aWXsty/atUqBQQE6MKFCw4dN7s/v9zw+uuvKzQ0NNv9586dq4CAAB0+fPiOfSZNmqQGDRooKSlJknT06FH1799fTZo0UUhIiPr27aujR4/+49gBGAuJAOBkn332mXr27Kk//vhDw4cP14IFCzRgwADt3btX3bt314kTJ3L8mElJSZo2bZrq1aunhQsXqkuXLjkybu3atbVs2TLVrl07R8Zzpvfff1/Jycl37RcVFXXHi9jccO3aNb3xxht5dry8snnzZq1duzbHxsvuz88VdOnSRW5ubvrqq6+y3G82m/X111+rbdu28vX11fnz5/Xcc88pJSVFERERioyMlNlsVq9evXT27Nk8jh7A/YxEAHCimJgYRUREqFevXlq0aJE6duyokJAQPf3001qyZIm8vLw0ZsyYHD/u1atXlZGRoVatWqlx48YqW7Zsjozr6+ur+vXry9fXN0fGux/UqlVLDz74YJ4dr0iRIvrxxx+1fPnyPDtmXihSpIgiIiKUmJjo7FDyXJkyZdS8eXOtW7dO6enpmfZ///33unLlirp37y5J+vTTT+Xj46Po6Gi1bNlSoaGhmj9/vnx8fLR48eK8Dh/AfYxEAHCihQsXqnDhwnr11Vcz7StRooRef/11tWzZUtevX5ckpaen67PPPlPHjh1Vr149tWjRQu+8845SU1Ot73v99dcVHh6ulStXqk2bNqpTp446deqk7du3S7o51eLWtIUxY8YoICBAUtZTXG6flpGSkqIJEybo0UcfVZ06ddS2bVu76RdZTQ2KjY1Vv379FBISooYNG2rQoEE6depUpvfs2rVLffv2VWBgoB5++GFNnz49y4si29jq1q2rn376Sd26dVPdunXVpk0bbdmyRWfPntXzzz+vwMBAtW7dWt98843de/ft26d+/fqpcePGqlOnjkJDQzV79mxlZGRIkvV7EhUVZf169uzZat26taKiohQcHKzmzZvr6tWrdt+3yMhIBQQEaPfu3Zm+h2vWrLnjuTgiNDRUwcHBmjZtmn777be/7ZuRkaH58+erdevWqlOnjtq0aaNPP/0003h3+7nf6dxTUlI0Y8YMPfHEE6pTp44aNmyoPn366Pjx4w6f1yuvvKLr169rwoQJd+0bHx+vV199VcHBwQoMDNTzzz+vY8eOWfff/vP75JNPVKNGDV2+fNnaZ86cOdbfu1s2bdqkGjVq6Pfff5eU/d/dpUuX6vHHH1fDhg31448/Zor32LFjatSokfr37y+z2ZzlOXXr1k2JiYl28dyyevVqValSRUFBQZKkKlWqqG/fvipYsKC1T8GCBVWmTBn98ssvd/3+AcAtJAKAk1gsFv3www9q2rSpfHx8suzTvn17DR482PoP/ptvvqnIyEi1atVKc+fO1bPPPq
vFixfrpZdeslsYeeTIES1cuFDDhg3TnDlz5O7urqFDh+rq1atq0aKFoqKiJEkvvviili1blu2Yp0yZou3bt2vUqFFauHChWrZsqbffflsrV67Msv/u3bv1zDPPWN87efJk/fbbb+rZs6fOnDlj13fEiBEKCgrSvHnz9OSTT+rDDz/UihUr/jaeGzduaPjw4erZs6fmzp0rHx8fjRgxQoMGDVKLFi00b948lSpVSqNGjdLFixclSSdOnFB4eLiKFSum9957T3PnzlWjRo0UFRWl9evXS5L1e9K9e3e77098fLy+//57vffeexo9erSKFi1qF88rr7yiSpUqafz48TKbzYqPj1dERITatWunzp07Z/v7/HdMJpOmTJmijIyMu04RmjBhgmbNmqWnnnpK8+bNU9u2bTVlyhTNmTPH4eNmde4jR47UypUrNWDAAC1atEijR4/WqVOnNHz4cIcX6latWlVDhw7Vxo0b9fXXX9+x36VLl9SzZ08dPXpU48aN04wZM5SRkaFnn33W+jt1+8+vRYsWslgsdgnara/37dtnbdu+fbtq1aql0qVLO/S7GxUVpVGjRunNN99UgwYN7PadOXNG/fr1U2BgoObMmSNPT88szys0NFTFixfPND3o0qVL2r59u7UaIEm9evXSCy+8YNfv/PnzOnXqlB566KE7fu8A4HYFnB0AYFSXL19WamqqKlSokK3+p0+f1hdffKHhw4drwIABkqSHH35YpUqV0siRI7V9+3Y99thjkqS//vpLq1atsk5ZKViwoJ577jnt3r1bbdq0Uc2aNSVJDz74oOrXr5/tmPfu3auHH35YHTp0kCSFhISoYMGC8vPzy7L/jBkz5O/vr/nz58vd3V2S1Lx5c7Vu3VqzZs3S+++/b+3bo0cPDR48WJLUtGlTbdq0Sdu2bVPPnj3vGE9GRoYGDRqkHj16SJL+/PNPvfLKK3r++efVp08fSVLhwoXVrVs3HTlyRGXKlNGJEyfUrFkzTZ8+XW5ubtbv45YtW7Rnzx516NDB+j0pU6aM3ffnxo0bGjVqlBo1apRlPN7e3po6dap69eql+fPna//+/fL19dXEiRPv9q11SMWKFfXqq69q8uTJWrFihfX8bZ07d07Lly/Xq6++av19ad68uUwmk6Kjo9WrVy8VL14828e8/dzNZrN1vUL79u0lScHBwUpKStLUqVOVmJiokiVLOnRe/fr108aNGzVp0iQ1adJEDzzwQKY+H3/8sa5cuaIlS5aofPnykqRHH31U7du31/vvv69Zs2Zl+fOrXLmydu3apXbt2ik5OVkHDhxQ7dq17RKBHTt2qGvXrpIc+93t1auX2rZtmynWuLg4hYeHq0aNGvrggw/umARIkqenp5566il98cUXmjhxory9vSXJWs36u0QyJSVFo0aNkqenp5577rk79gOA21ERAJzk1sXF301/sbV3715Jsl6E39KhQwe5u7vbTccpUaKE3bz1MmXKSNI/XjwZEhKi5cuXq3///lq8eLHi4uI0ePBgtWjRIlPf69evKzY2Vu3atbOeq3RzLvjjjz9uPZ9bbv8ktUyZMtYpUX/H9n23EpLAwEBrW7FixSTdTBKkmxdUCxYsUFpamk6cOKHvvvtOs2bNUnp6utLS0u56vFtJ1N/FEx4erjlz5mjnzp2aOnVqpsqBrfT0dN24ccO6Zff34bnnnlPjxo01depUa7XD1u7du2WxWBQaGmo3fmhoqFJTUxUTE5Ot49iyPXdPT08tXLhQ7du31++//67du3dr6dKl2rp1qyTdcQrM33F3d1dkZKSuX79+x+Rp165dqlmzpkqXLm09Jzc3Nz366KPauXPnHcdu0aKFdX9MTIw8PDzUu3dvHTp0SGazWadPn1Z8fLxatGjh8O9uVr8T165dU3h4uBISEjRx4kR5eXnd9fy7deuma9euacuWLda21atXq0WLFndMtpOSkjRw4EDFxsZq+vTp1uQIALKDRABwkqJFi6pQoU
J3vBWkdPNi+urVq5Jk/d/bP2UtUKCAihcvrr/++svadvtUI5PJJEnWOfD3auzYsfr3v/+tCxcuaNKkSWrVqpV69uy
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot confusion matrix\n",
"plot_confusion_matrix(cm_v2, ['Benign', 'Malignant'], title='Confusion matrix - Neural Network V2')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:07:41.268092300Z",
"start_time": "2024-06-08T16:07:41.007361Z"
}
},
"id": "d288d528576840f0"
},
{
"cell_type": "code",
"execution_count": 50,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxUAAAIhCAYAAAA4gZcFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAACTbElEQVR4nOzdd3gUVdvH8e+mbBJSSCWEgPQSQouU0BUBpTf1kUcfKeJLERAFEQEVULrYADWIYEWkiQrSkSJFQCAUAYFQDC0EUkhIL+8fyJoliGwwmRB+H6+5LvfM2Zl7AkP23vucM6bs7OxsRERERERE8sjO6ABEREREROTupqRCRERERETuiJIKERERERG5I0oqRERERETkjiipEBERERGRO6KkQkRERERE7oiSChERERERuSNKKkRERERE5I4oqRARuQvoOaV/Tz8bERHjKakQkULlwIEDDB8+nAcffJBatWrRqlUrXnvtNSIjI636Va1alRkzZhRobDNmzKBq1aqW14mJifTv35/atWtTv359Tp06RdWqVfn222//1fOuX7+eESNGWF7v2LGDqlWrsmPHjn/1PHebK1eu8PLLL/Prr78aHYqIyD3PwegARESumzdvHhMnTiQ0NJRhw4ZRokQJTp8+zZw5c1izZg2ff/451apVMyy+xx9/nGbNmllef/fdd2zYsIHXX3+dypUrU6pUKRYsWMB99933r573s88+s3odHBzMggULqFSp0r96nrvN4cOH+f7773n00UeNDkVE5J6npEJECoXdu3czYcIEnnrqKUaPHm1pDw0NpVWrVnTp0oVRo0b961UAW5QsWZKSJUtaXsfFxQHw5JNPYjKZAKhTp06+x+Hm5lYg5xEREbldGv4kIoXCnDlzcHd3Z+jQobn2eXt788orr9CyZUuSkpJu+v4jR44waNAgGjZsSHBwMM2aNWP8+PGkpKRY+mzdupX//Oc/hISEUL9+fQYMGEBERIRl/x9//EH//v0JDQ2ldu3aPPHEE2zatMmyP+fwp6efftoy/KpatWq88sornDlzJtfwpxMnTjBo0CAaNGhA/fr16devn9U5z5w5w8svv0zTpk0JDg6mUaNGvPzyy8TGxlrOs3PnTnbu3GkZ8nSz4U8HDhygT58+hIaGcv/999O/f3+OHTtm2X/9Pdu3b+eZZ56hdu3aNGnShLfeeovMzMy//XP59ttvqV69OosWLaJJkyY0aNCA48ePA7Bu3Tq6detGzZo1adKkCePHj7f685kxYwYPPfQQGzZsoE2bNtSuXZv//Oc/uYZtXbx4kZEjR/LAAw9Qq1YtHnvsMdavX2/Vp2rVqsycOZNu3bpRq1YtZs6cSY8ePQDo0aMHTz/99N9eg4iI5D8lFSJiuOzsbLZs2UKjRo1wcXG5aZ927doxcOBAihUrlmvfxYsXeeqpp0hOTmby5MnMnj2b9u3b8+WXX/LFF18AEBkZyXPPPUeNGjX46KOPmDBhAidPnqRv375kZWWRlZVFv379SE5OZurUqXz44Yd4enoyYMAATp8+neucY8aM4bHHHgNgwYIFPPfcc7n6REVF8cQTT3Dq1CnGjh3LW2+9xaVLl+jZsydxcXEkJyfTo0cPIiIiGDNmDHPmzKFHjx78+OOPvPvuu5bzVK9enerVq7NgwQKCg4NzneeXX37hv//9LwATJ05k/PjxnD9/nu7du1slMAAvvfQSdevWJSwsjA4dOvDJJ5+waNGiW/3xkJmZydy5c5kwYQIjR46kYsWKLFu2jIEDB1KhQgU++OADBg0axA8//MBzzz1nNXE6JiaGESNG8OSTT/L+++/j7OxMnz59OHz4MACXLl3iscce49dff+XFF19kxowZBAYGMnDgQH744QerOMLCwujYsSPTp0+nVatWvP766wC8/vrrjBkz5pbXICIi+UvDn0TEcLGxsaSmplK6dOk8vf/o0aMEBQXx/vvv4+bmBkDjxo3ZunUrO3bsoG
/fvuzfv5+UlBT69euHv78/cG040/r160lKSiI5OZkTJ07w3HPP8cADDwBYvhFPS0vLdc5KlSpZhkJdH4p05swZqz6fffYZaWlpfPrpp/j5+QHXqhr//e9/2bdvHyVKlKBkyZJMmTKFMmXKANCwYUP27dvHzp07Lee5fk1/N+Tp7bffpmzZsnz88cfY29sD0LRpU1q3bs306dN5//33LX0ff/xxBg4cCECjRo1Yt24dGzdupHv37rf8Gffv358HH3wQuJYETps2jWbNmjFt2jRLn3LlytGrVy82bdpk6ZucnMzYsWPp0qWL5fpatWrFxx9/zLvvvsunn35KTEwMq1evJjAwEIAHHniAXr16MXXqVDp06ICd3bXvv+rVq0fv3r0t54uPj7f8jO71+SUiIkZTUiEihrv+QfhWw3BupWnTpjRt2pT09HSOHz/O6dOnOXr0KDExMXh6egJQu3ZtnJyceOyxx2jTpg3NmzcnNDSUWrVqAeDq6kqlSpV47bXX2LJlC02bNqV58+aMHDkyz9e1e/du6tSpY0ko4Fois2HDBsvrr7/+mqysLE6dOsXp06c5fvw4J06cICMj47bOkZSUxIEDBxg0aJDl5wjg4eFBixYtrIZvAYSEhFi9Llmy5N8OKcspKCjI8v8nTpzgwoUL9OvXzyrO+vXr4+bmxtatWy1JhYODAx06dLD0cXZ2pnnz5mzevBmAnTt3EhISYkkoruvUqRMjR47kxIkTloQhZwwiIlK4KKkQEcMVL14cV1dXzp0797d9kpKSSE9Pp3jx4rn2ZWVl8c477zBv3jySkpIICAigVq1aODk5WfqULl2ar776io8//pjFixfzxRdf4OHhwZNPPskLL7yAyWRi7ty5fPTRR6xdu5bvvvsOR0dHWrVqxbhx42563n8SFxf3j9WXTz/9lLCwMOLi4vD19aVGjRq4uLiQkJBwW+dISEggOzsbX1/fXPt8fX1zHcfZ2dnqtZ2d3W095yHnsLPrE9THjRvHuHHjcvW9ePGiVQwODta/anx8fCzHiI+Pt1Rpbowdri0be7MYRESkcFFSISKFQtOmTdmxYwepqalWycB1CxcuZMqUKSxevDjXvIKPP/6Yzz77jHHjxvHwww/j7u4OYJnzcF3O4Uy7d+9mwYIFhIWFUa1aNdq2bYu/vz9jx45lzJgxHDlyhFWrVjF79my8vLzyNGbf3d2dmJiYXO3bt2+ndOnShIeHM3nyZIYPH063bt3w9vYGYMiQIRw4cOC2z2Eymbh06VKufdHR0ZZKzb/Jw8MDgJdffpkGDRrk2p8zAbuePOR06dIlfHx8LH2jo6Nz9bne5uXl9W+ELCIi+UwTtUWkUHjmmWeIi4vjvffey7UvOjqauXPnUqlSpZtOVN69ezeVKlXi0UcftSQUUVFRHD16lKysLODa/IYWLVqQlpaG2WymUaNGvPnmmwCcO3eOvXv30rhxY/bv34/JZCIoKIgXX3yRKlWq3LKCciv16tVj3759VonF5cuXefbZZ9m0aRO7d+/Gw8ODZ5991pJQXL16ld27d1viBixzCm6mWLFi1KhRg5UrV1oNH0tISGDjxo3UrVs3T7HfSoUKFfDx8eHMmTPUrFnTsvn7+/P2229z6NAhS9+UlBR+/vlnq9ebN2+mUaNGwLUhU3v37uXs2bNW5/jhhx/w8/OjbNmyfxtHzuFeIiJiLFUqRKRQqFOnDkOGDOG9994jIiKCLl264OXlxbFjx5gzZw6pqak3TTjgWgXiww8/5OOPP6ZOnTqcPn2aWbNmkZaWRnJyMnBtgvC0adMYOHAg//vf/7C3t+ebb77BbDbTokULAgMDcXZ25uWXX2bw4MH4+vqybds2Dh8+bFm61Fa9evXiu+++49lnn6Vfv344Ojry0UcfUbJkSTp27Mj69euZP38+kydPpkWLFly8eJE5c+Zw6dIlq2/7PTw82Lt3L9u3b6d69eq5zjNs2DD69OlD3759efLJJ0lPT+fjjz
8mLS3NMin732Rvb8+LL77I66+/jr29PS1atODKlSt8+OGHREVF5Ur8Ro4cyQsvvICPjw9z5swhKSmJAQMGANC7d29
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot classification report\n",
"plot_classification_report(cr_v2)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:07:43.343425300Z",
"start_time": "2024-06-08T16:07:42.949867200Z"
}
},
"id": "967251e118adbb60"
},
{
"cell_type": "markdown",
"source": [
"#### Neural network V3"
],
"metadata": {
"collapsed": false
},
"id": "9900eac2f18370e5"
},
{
"cell_type": "code",
"execution_count": 51,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/1000, Loss: 0.5781828761100769\n",
"Epoch 20/1000, Loss: 0.35954442620277405\n",
"Epoch 30/1000, Loss: 0.16433778405189514\n",
"Epoch 40/1000, Loss: 0.09435044229030609\n",
"Epoch 50/1000, Loss: 0.06897494941949844\n",
"Epoch 60/1000, Loss: 0.05602168291807175\n",
"Epoch 70/1000, Loss: 0.046810463070869446\n",
"Epoch 80/1000, Loss: 0.03961234167218208\n",
"Epoch 90/1000, Loss: 0.033818356692790985\n",
"Epoch 100/1000, Loss: 0.02865622565150261\n",
"Epoch 110/1000, Loss: 0.023877572268247604\n",
"Epoch 120/1000, Loss: 0.019604215398430824\n",
"Epoch 130/1000, Loss: 0.01610736735165119\n",
"Epoch 140/1000, Loss: 0.01334200520068407\n",
"Epoch 150/1000, Loss: 0.011029877699911594\n",
"Epoch 160/1000, Loss: 0.009049472399055958\n",
"Epoch 170/1000, Loss: 0.007442420814186335\n",
"Epoch 180/1000, Loss: 0.0061035449616611\n",
"Epoch 190/1000, Loss: 0.004539611749351025\n",
"Epoch 200/1000, Loss: 0.0030651914421468973\n",
"Epoch 210/1000, Loss: 0.0022714021615684032\n",
"Epoch 220/1000, Loss: 0.001738564227707684\n",
"Epoch 230/1000, Loss: 0.0013824186753481627\n",
"Epoch 240/1000, Loss: 0.0011372618610039353\n",
"Epoch 250/1000, Loss: 0.0009611304849386215\n",
"Epoch 260/1000, Loss: 0.0008317003957927227\n",
"Epoch 270/1000, Loss: 0.0007334426045417786\n",
"Epoch 280/1000, Loss: 0.0006564902723766863\n",
"Epoch 290/1000, Loss: 0.0005941776908002794\n",
"Epoch 300/1000, Loss: 0.0005428834119811654\n",
"Epoch 310/1000, Loss: 0.0005009892047382891\n",
"Epoch 320/1000, Loss: 0.00046646010014228523\n",
"Epoch 330/1000, Loss: 0.00043750536860898137\n",
"Epoch 340/1000, Loss: 0.0004131169698666781\n",
"Epoch 350/1000, Loss: 0.0003923749318346381\n",
"Epoch 360/1000, Loss: 0.00037453131517395377\n",
"Epoch 370/1000, Loss: 0.0003590169653762132\n",
"Epoch 380/1000, Loss: 0.0003454721299931407\n",
"Epoch 390/1000, Loss: 0.00033352847094647586\n",
"Epoch 400/1000, Loss: 0.0003230379370506853\n",
"Epoch 410/1000, Loss: 0.00031372407102026045\n",
"Epoch 420/1000, Loss: 0.00030547339702025056\n",
"Epoch 430/1000, Loss: 0.0002980876306537539\n",
"Epoch 440/1000, Loss: 0.0002914170909207314\n",
"Epoch 450/1000, Loss: 0.00028551131254062057\n",
"Epoch 460/1000, Loss: 0.000280180451227352\n",
"Epoch 470/1000, Loss: 0.00027529962244443595\n",
"Epoch 480/1000, Loss: 0.00027088262140750885\n",
"Epoch 490/1000, Loss: 0.00026690459344536066\n",
"Epoch 500/1000, Loss: 0.0002632274990901351\n",
"Epoch 510/1000, Loss: 0.00025982430088333786\n",
"Epoch 520/1000, Loss: 0.000256663013715297\n",
"Epoch 530/1000, Loss: 0.00025376983103342354\n",
"Epoch 540/1000, Loss: 0.0002510569174773991\n",
"Epoch 550/1000, Loss: 0.0002485134173184633\n",
"Epoch 560/1000, Loss: 0.0002461685216985643\n",
"Epoch 570/1000, Loss: 0.00024397078959736973\n",
"Epoch 580/1000, Loss: 0.00024182727793231606\n",
"Epoch 590/1000, Loss: 0.0002398582291789353\n",
"Epoch 600/1000, Loss: 0.00023796758614480495\n",
"Epoch 610/1000, Loss: 0.00023617268016096205\n",
"Epoch 620/1000, Loss: 0.00023440059158019722\n",
"Epoch 630/1000, Loss: 0.00023271011014003307\n",
"Epoch 640/1000, Loss: 0.00023108486493583769\n",
"Epoch 650/1000, Loss: 0.00022952101426199079\n",
"Epoch 660/1000, Loss: 0.00022804134641774\n",
"Epoch 670/1000, Loss: 0.00022659693786408752\n",
"Epoch 680/1000, Loss: 0.00022519213962368667\n",
"Epoch 690/1000, Loss: 0.0002238261658931151\n",
"Epoch 700/1000, Loss: 0.0002225149655714631\n",
"Epoch 710/1000, Loss: 0.000221233261981979\n",
"Epoch 720/1000, Loss: 0.00022013885609339923\n",
"Epoch 730/1000, Loss: 0.0002189427614212036\n",
"Epoch 740/1000, Loss: 0.00021769681188743562\n",
"Epoch 750/1000, Loss: 0.00021648130496032536\n",
"Epoch 760/1000, Loss: 0.00021531064703594893\n",
"Epoch 770/1000, Loss: 0.00021418495452962816\n",
"Epoch 780/1000, Loss: 0.0002131147193722427\n",
"Epoch 790/1000, Loss: 0.00021202574134804308\n",
"Epoch 800/1000, Loss: 0.0002110681962221861\n",
"Epoch 810/1000, Loss: 0.00021014301455579698\n",
"Epoch 820/1000, Loss: 0.0002092804352287203\n",
"Epoch 830/1000, Loss: 0.00020846931147389114\n",
"Epoch 840/1000, Loss: 0.00020768567628692836\n",
"Epoch 850/1000, Loss: 0.00020691761164925992\n",
"Epoch 860/1000, Loss: 0.00020612987282220274\n",
"Epoch 870/1000, Loss: 0.00020533577480819076\n",
"Epoch 880/1000, Loss: 0.00020457926439121366\n",
"Epoch 890/1000, Loss: 0.00020382019283715636\n",
"Epoch 900/1000, Loss: 0.00020311155822128057\n",
"Epoch 910/1000, Loss: 0.00020234761177562177\n",
"Epoch 920/1000, Loss: 0.00020164126181043684\n",
"Epoch 930/1000, Loss: 0.00020095478976145387\n",
"Epoch 940/1000, Loss: 0.0002003153640544042\n",
"Epoch 950/1000, Loss: 0.00019972550217062235\n",
"Epoch 960/1000, Loss: 0.00019920482009183615\n",
"Epoch 970/1000, Loss: 0.0001986794959520921\n",
"Epoch 980/1000, Loss: 0.00019817189604509622\n",
"Epoch 990/1000, Loss: 0.00019766329205594957\n",
"Epoch 1000/1000, Loss: 0.00019715774396900088\n"
]
}
],
"source": [
"# Neural network V3\n",
"model_v3 = NeuralNetworkV3(input_size, hidden_size).to(device)\n",
"criterion_v3 = nn.BCELoss()\n",
"optimizer_v3 = optim.Adam(model_v3.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
"\n",
"# Train the model\n",
"train(model_v3, X_train, y_train, criterion_v3, optimizer_v3, epochs)\n",
"\n",
"# Evaluate the model\n",
"cm_v3, cr_v3, acc_v3 = evaluate(model_v3, X_test, y_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:08:02.306988400Z",
"start_time": "2024-06-08T16:07:58.364980300Z"
}
},
"id": "1647870c14f1d6eb"
},
{
"cell_type": "code",
"execution_count": 53,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwIAAAIhCAYAAAD98w2UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABU+klEQVR4nO3deXxMZ///8fckspAElRJLNbVUaBAEQbXViNqqdrcuIajQWnoXpaiiRKjqQpTYehetrZZqS2svam1sUfS2VWnKLbW0ITIk8/vDz3xnJCqjSWbkvJ4e53FnrnPNdT4nyZ2ez3yu6xyTxWKxCAAAAIChuDk7AAAAAAB5j0QAAAAAMCASAQAAAMCASAQAAAAAAyIRAAAAAAyIRAAAAAAwIBIBAAAAwIBIBAAAAAADIhEAgL/BMxfvb/z8AODOSAQAF5GYmKg33nhDjRo1UvXq1RUREaERI0bo9OnTuXbM//znP3r88cdVvXp1ffzxxzky5s6dOxUUFKSdO3fmyHjO9PHHH2v27Nl37RceHq4333wzDyK6eaxatWopKSkpy/1BQUGaMmVKnsSSXZGRkYqMjLzj/lu/M6+++mqW+5ctW6agoCCdOXPGoeNm9+eXG958802Fh4dnu/+0adMUFBSkAwcO3LHPmDFjVLNmTaWkpEiS9u7dq8jISNWsWVMNGzbU2LFjrfsAIDtIBAAX8Nlnn6lz5876448/NHDgQM2cOVPR0dHatWuXOnTooCNHjuT4MVNSUjRhwgRVr15ds2fPVtu2bXNk3ODgYC1atEjBwcE5Mp4zffTRR0pNTb1rv7i4uDtexOaGK1eu6K233sqz4+WV9evXa+XKlTk2XnZ/fq6gbdu2cnNz01dffZXlfrPZrK+//lrNmjWTr6+vjhw5oqioKPn4+GjKlCkaMGCAVq9erddeey2PIwdwPyMRAJwsISFBMTExeuGFFzRnzhy1atVKYWFh6tSpkxYsWCAvLy8NGzYsx497+fJlZWRkKCIiQnXq1FGpUqVyZFxfX1/VqFFDvr6+OTLe/eCxxx7Tww8/nGfHK1y4sH744QctXrw4z46ZFwoXLqyYmBglJyc7O5Q8V7JkSTVs2FCrVq1Senp6pv3ff/+9Ll26pA4dOkiSPv30UxUpUkSTJ09Ww4YN1a5dOw0cOFBbt27ViRMn8jp8APcpEgHAyWbPni0/Pz8NGDAg075ixYrpzTffVOPGjXX16lVJUnp6uj777DO1atVK1atXV6NGjfTee+8pLS3N+r4333xTUVFRWrp0qZo2baqqVauqdevW2rx5s6SbUy1uTVsYNmyYgoKCJGU9xeX2aRnXrl3TqFGj9OSTT6pq1apq1qyZ3fSLrKYGJSYmqkePHgoLC1OtWrXUu3dvHT16NNN7tm/fru7duyskJESPP/64Jk6cmOVFkW1s1apV048//qj27durWrVqatq0qTZs2KATJ06oa9euCgkJUZMmTfTNN9/YvXf37t3q0aOH6tSpo6pVqyo8PFxTpkxRRkaGJFm/J3Fxcdavp0yZoiZNmiguLk5169ZVw4YNdfnyZbvvW2xsrIKCgrRjx45M38MVK1bc8VwcER4errp162rChAn6/fff/7ZvRkaGZsyYoSZNmqhq1apq2rSp5s2bl2m8u/3c73Tu165d06RJk/TMM8+oatWqqlWrlrp166bDhw87fF6vv/66rl69qlGjRt21b1JSkgYMGKC6desqJCREXbt21aFDh6z7b//5zZ07V5UrV9bFixetfaZOnWr9vbtl3bp1qly5ss6dOycp+7+7Cxcu1NNPP61atWrphx9+yBTvoUOHVLt2bfXs2VNmsznLc2rfvr2Sk5Pt4rll+fLlKl++vEJDQyVJ//73vzVjxgx5enpa+3h4eEjSHccHgNuRCABOZLFYtHXrVtWvX18FCxbMsk+LFi3Up08fFSpUSJL09ttvKzY2VhEREZo2bZpefPFFzZ8/X6+++qrdws
iDBw9q9uzZ6t+/v6ZOnSp3d3f169dPly9fVqNGjRQXFydJeuWVV7Ro0aJsxzxu3Dht3rxZQ4YM0ezZs9W4cWO9++67Wrp0aZb9d+zYoeeff9763rFjx+r3339X586ddfz4cbu+gwYNUmhoqKZPn65nn31Ws2bN0pIlS/42nhs3bmjgwIHq3Lmzpk2bpoIFC2rQoEHq3bu3GjVqpOnTp6tEiRIaMmSIzp49K0nWaRVFixbVBx98oGnTpql27dqKi4vT6tWrJcn6PenQoYPd9ycpKUnff/+9PvjgAw0dOlRFihSxi+f111/XI488opEjR8psNispKUkxMTFq3ry52rRpk+3v898xmUwaN26cMjIy7jpFaNSoUZo8ebKee+45TZ8+Xc2aNdO4ceM0depUh4+b1bkPHjxYS5cuVXR0tObMmaOhQ4fq6NGjGjhwoMMLdStUqKB+/fpp7dq1+vrrr+/Y78KFC+rcubN++uknjRgxQpMmTVJGRoZefPFF6+/U7T+/Ro0ayWKx2CVot77evXu3tW3z5s167LHHFBAQ4NDvblxcnIYMGaK3335bNWvWtNt3/Phx9ejRQyEhIZo6dardxbut8PBwPfDAA5mmB124cEGbN2+2VgMkKSAgQJUrV5YkXb16Vdu2bdMHH3ygWrVqWdsB4G4KODsAwMguXryotLQ0PfTQQ9nqf+zYMX3xxRcaOHCgoqOjJUmPP/64SpQoocGDB2vz5s166qmnJEl//fWXli1bZp2yUqhQIb300kvasWOHmjZtqipVqkiSHn74YdWoUSPbMe/atUuPP/64WrZsKUkKCwtToUKF5O/vn2X/SZMmKTAwUDNmzJC7u7skqWHDhmrSpIkmT56sjz76yNq3Y8eO6tOnjySpfv36WrdunTZt2qTOnTvfMZ6MjAz17t1bHTt2lCT9+eefev3119W1a1d169ZNkuTn56f27dvr4MGDKlmypI4cOaIGDRpo4sSJcnNzs34fN2zYoJ07d6ply5bW70nJkiXtvj83btzQkCFDVLt27Szj8fb21vjx4/XCCy9oxowZ2rNnj3x9fTV69Oi7fWsdUrZsWQ0YMEBjx47VkiVLrOdv6+TJk1q8eLEGDBhg/X1p2LChTCaT4uPj9cILL+iBBx7I9jFvP3ez2Wxdr9CiRQtJUt26dZWSkqLx48crOTlZxYsXd+i8evToobVr12rMmDGqV6+eHnzwwUx9Pv30U126dEkLFixQmTJlJElPPvmkWrRooY8++kiTJ0/O8udXrlw5bd++Xc2bN1dqaqr27t2r4OBgu0Rgy5YtateunSTHfndfeOEFNWvWLFOsp0+fVlRUlCpXrqyPP/74jkmAJHl6euq5557TF198odGjR8vb21uSrNWsrBJJi8WievXqKS0tTUWLFtWIESPuOD4A3I6KAOBEty4u/m76i61du3ZJkvUi/JaWLVvK3d3dbjpOsWLF7OatlyxZUpL+8eLJsLAwLV68WD179tT8+fN1+vRp9enTR40aNcrU9+rVq0pMTFTz5s2t5yrdnAv+9NNPW8/nlts/SS1ZsqR1StTfsX3frYQkJCTE2la0aFFJN5ME6eYF1cyZM3X9+nUdOXJE3333nSZPnqz09HRdv379rse7lUT9XTxRUVGaOnWqtm3bpvHjx2eqHNhKT0/XjRs3rFt2fx9eeukl1alTR+PHj7dWO2zt2LFDFotF4eHhduOHh4crLS1NCQkJ2TqOLdtz9/T01OzZs9WiRQudO3dOO3bs0MKFC7Vx40ZJ9zZFxd3dXbGxsbp69eodk6ft27erSpUqCggIsJ6Tm5ubnnzySW3btu2OYzdq1Mi6PyEhQR4eHurSpYv2798vs9msY8eOKSkpSY0aNXL4dzer34krV64oKipK58+f1+jRo+Xl5XXX82/fvr2uXLmiDRs2WNuWL1+uRo0aZZls37hxQ9OmTdO0adNUrlw5vfjii7lycwEA+ROJAOBERYoUkY+Pzx1vBSndvJ
i+fPmyJFn/9/ZPWQsUKKAHHnhAf/31l7Xt9qlGJpNJkqxz4O/V8OHD9e9//1tnzpzRmDFjFBERoc6dO2d58fHXX3/
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot confusion matrix\n",
"plot_confusion_matrix(cm_v3, ['Benign', 'Malignant'], title='Confusion matrix - Neural Network V3')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:08:05.178318100Z",
"start_time": "2024-06-08T16:08:04.939888900Z"
}
},
"id": "352c5a8e9037cbf9"
},
{
"cell_type": "code",
"execution_count": 54,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxUAAAIhCAYAAAA4gZcFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAACTbElEQVR4nOzdd3gUVdvH8e+mbBJSSCWEgPQSQouU0BUBpTf1kUcfKeJLERAFEQEVULrYADWIYEWkiQrSkSJFQCAUAYFQDC0EUkhIL+8fyJoliGwwmRB+H6+5LvfM2Zl7AkP23vucM6bs7OxsRERERERE8sjO6ABEREREROTupqRCRERERETuiJIKERERERG5I0oqRERERETkjiipEBERERGRO6KkQkRERERE7oiSChERERERuSNKKkRERERE5I4oqRARuQvoOaV/Tz8bERHjKakQkULlwIEDDB8+nAcffJBatWrRqlUrXnvtNSIjI636Va1alRkzZhRobDNmzKBq1aqW14mJifTv35/atWtTv359Tp06RdWqVfn222//1fOuX7+eESNGWF7v2LGDqlWrsmPHjn/1PHebK1eu8PLLL/Prr78aHYqIyD3PwegARESumzdvHhMnTiQ0NJRhw4ZRokQJTp8+zZw5c1izZg2ff/451apVMyy+xx9/nGbNmllef/fdd2zYsIHXX3+dypUrU6pUKRYsWMB99933r573s88+s3odHBzMggULqFSp0r96nrvN4cOH+f7773n00UeNDkVE5J6npEJECoXdu3czYcIEnnrqKUaPHm1pDw0NpVWrVnTp0oVRo0b961UAW5QsWZKSJUtaXsfFxQHw5JNPYjKZAKhTp06+x+Hm5lYg5xEREbldGv4kIoXCnDlzcHd3Z+jQobn2eXt788orr9CyZUuSkpJu+v4jR44waNAgGjZsSHBwMM2aNWP8+PGkpKRY+mzdupX//Oc/hISEUL9+fQYMGEBERIRl/x9//EH//v0JDQ2ldu3aPPHEE2zatMmyP+fwp6efftoy/KpatWq88sornDlzJtfwpxMnTjBo0CAaNGhA/fr16devn9U5z5w5w8svv0zTpk0JDg6mUaNGvPzyy8TGxlrOs3PnTnbu3GkZ8nSz4U8HDhygT58+hIaGcv/999O/f3+OHTtm2X/9Pdu3b+eZZ56hdu3aNGnShLfeeovMzMy//XP59ttvqV69OosWLaJJkyY0aNCA48ePA7Bu3Tq6detGzZo1adKkCePHj7f685kxYwYPPfQQGzZsoE2bNtSuXZv//Oc/uYZtXbx4kZEjR/LAAw9Qq1YtHnvsMdavX2/Vp2rVqsycOZNu3bpRq1YtZs6cSY8ePQDo0aMHTz/99N9eg4iI5D8lFSJiuOzsbLZs2UKjRo1wcXG5aZ927doxcOBAihUrlmvfxYsXeeqpp0hOTmby5MnMnj2b9u3b8+WXX/LFF18AEBkZyXPPPUeNGjX46KOPmDBhAidPnqRv375kZWWRlZVFv379SE5OZurUqXz44Yd4enoyYMAATp8+neucY8aM4bHHHgNgwYIFPPfcc7n6REVF8cQTT3Dq1CnGjh3LW2+9xaVLl+jZsydxcXEkJyfTo0cPIiIiGDNmDHPmzKFHjx78+OOPvPvuu5bzVK9enerVq7NgwQKCg4NzneeXX37hv//9LwATJ05k/PjxnD9/nu7du1slMAAvvfQSdevWJSwsjA4dOvDJJ5+waNGiW/3xkJmZydy5c5kwYQIjR46kYsWKLFu2jIEDB1KhQgU++OADBg0axA8//MBzzz1nNXE6JiaGESNG8OSTT/L+++/j7OxMnz59OHz4MACXLl3iscce49dff+XFF19kxowZBAYGMnDgQH744QerOMLCwujYsSPTp0+nVatWvP766wC8/vrrjBkz5pbXICIi+UvDn0TEcLGxsaSmplK6dOk8vf/o0aMEBQXx/vvv4+bmBkDjxo3ZunUrO3bsoG
/fvuzfv5+UlBT69euHv78/cG040/r160lKSiI5OZkTJ07w3HPP8cADDwBYvhFPS0vLdc5KlSpZhkJdH4p05swZqz6fffYZaWlpfPrpp/j5+QHXqhr//e9/2bdvHyVKlKBkyZJMmTKFMmXKANCwYUP27dvHzp07Lee5fk1/N+Tp7bffpmzZsnz88cfY29sD0LRpU1q3bs306dN5//33LX0ff/xxBg4cCECjRo1Yt24dGzdupHv37rf8Gffv358HH3wQuJYETps2jWbNmjFt2jRLn3LlytGrVy82bdpk6ZucnMzYsWPp0qWL5fpatWrFxx9/zLvvvsunn35KTEwMq1evJjAwEIAHHniAXr16MXXqVDp06ICd3bXvv+rVq0fv3r0t54uPj7f8jO71+SUiIkZTUiEihrv+QfhWw3BupWnTpjRt2pT09HSOHz/O6dOnOXr0KDExMXh6egJQu3ZtnJyceOyxx2jTpg3NmzcnNDSUWrVqAeDq6kqlSpV47bXX2LJlC02bNqV58+aMHDkyz9e1e/du6tSpY0ko4Fois2HDBsvrr7/+mqysLE6dOsXp06c5fvw4J06cICMj47bOkZSUxIEDBxg0aJDl5wjg4eFBixYtrIZvAYSEhFi9Llmy5N8OKcspKCjI8v8nTpzgwoUL9OvXzyrO+vXr4+bmxtatWy1JhYODAx06dLD0cXZ2pnnz5mzevBmAnTt3EhISYkkoruvUqRMjR47kxIkTloQhZwwiIlK4KKkQEcMVL14cV1dXzp0797d9kpKSSE9Pp3jx4rn2ZWVl8c477zBv3jySkpIICAigVq1aODk5WfqULl2ar776io8//pjFixfzxRdf4OHhwZNPPskLL7yAyWRi7ty5fPTRR6xdu5bvvvsOR0dHWrVqxbhx42563n8SFxf3j9WXTz/9lLCwMOLi4vD19aVGjRq4uLiQkJBwW+dISEggOzsbX1/fXPt8fX1zHcfZ2dnqtZ2d3W095yHnsLPrE9THjRvHuHHjcvW9ePGiVQwODta/anx8fCzHiI+Pt1Rpbowdri0be7MYRESkcFFSISKFQtOmTdmxYwepqalWycB1CxcuZMqUKSxevDjXvIKPP/6Yzz77jHHjxvHwww/j7u4OYJnzcF3O4Uy7d+9mwYIFhIWFUa1aNdq2bYu/vz9jx45lzJgxHDlyhFWrVjF79my8vLzyNGbf3d2dmJiYXO3bt2+ndOnShIeHM3nyZIYPH063bt3w9vYGYMiQIRw4cOC2z2Eymbh06VKufdHR0ZZKzb/Jw8MDgJdffpkGDRrk2p8zAbuePOR06dIlfHx8LH2jo6Nz9bne5uXl9W+ELCIi+UwTtUWkUHjmmWeIi4vjvffey7UvOjqauXPnUqlSpZtOVN69ezeVKlXi0UcftSQUUVFRHD16lKysLODa/IYWLVqQlpaG2WymUaNGvPnmmwCcO3eOvXv30rhxY/bv34/JZCIoKIgXX3yRKlWq3LKCciv16tVj3759VonF5cuXefbZZ9m0aRO7d+/Gw8ODZ5991pJQXL16ld27d1viBixzCm6mWLFi1KhRg5UrV1oNH0tISGDjxo3UrVs3T7HfSoUKFfDx8eHMmTPUrFnTsvn7+/P2229z6NAhS9+UlBR+/vlnq9ebN2+mUaNGwLUhU3v37uXs2bNW5/jhhx/w8/OjbNmyfxtHzuFeIiJiLFUqRKRQqFOnDkOGDOG9994jIiKCLl264OXlxbFjx5gzZw6pqak3TTjgWgXiww8/5OOPP6ZOnTqcPn2aWbNmkZaWRnJyMnBtgvC0adMYOHAg//vf/7C3t+ebb77BbDbTokULAgMDcXZ25uWXX2bw4MH4+vqybds2Dh8+bFm61Fa9evXiu+++49lnn6Vfv344Ojry0UcfUbJkSTp27Mj69euZP38+kydPpkWLFly8eJE5c+Zw6dIlq2/7PTw82Lt3L9u3b6d69eq5zjNs2DD69OlD3759efLJJ0lPT+fjjz
8mLS3NMin732Rvb8+LL77I66+/jr29PS1atODKlSt8+OGHREVF5Ur8Ro4cyQsvvICPjw9z5swhKSmJAQMGANC7d29
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot classification report\n",
"plot_classification_report(cr_v3)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:08:10.217195900Z",
"start_time": "2024-06-08T16:08:09.908834700Z"
}
},
"id": "e7870deb507d0eab"
},
{
"cell_type": "markdown",
"source": [
"#### Neural network V4"
],
"metadata": {
"collapsed": false
},
"id": "d32e69740aecec91"
},
{
"cell_type": "code",
"execution_count": 64,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/1000, Loss: 0.4186047613620758\n",
"Epoch 20/1000, Loss: 0.20972707867622375\n",
"Epoch 30/1000, Loss: 0.14643631875514984\n",
"Epoch 40/1000, Loss: 0.1106308326125145\n",
"Epoch 50/1000, Loss: 0.09039844572544098\n",
"Epoch 60/1000, Loss: 0.07973706722259521\n",
"Epoch 70/1000, Loss: 0.07261230796575546\n",
"Epoch 80/1000, Loss: 0.06730187684297562\n",
"Epoch 90/1000, Loss: 0.06239919736981392\n",
"Epoch 100/1000, Loss: 0.05727291852235794\n",
"Epoch 110/1000, Loss: 0.05158916860818863\n",
"Epoch 120/1000, Loss: 0.04558803513646126\n",
"Epoch 130/1000, Loss: 0.03927084803581238\n",
"Epoch 140/1000, Loss: 0.03285034000873566\n",
"Epoch 150/1000, Loss: 0.026732975617051125\n",
"Epoch 160/1000, Loss: 0.021154126152396202\n",
"Epoch 170/1000, Loss: 0.016432005912065506\n",
"Epoch 180/1000, Loss: 0.012630755081772804\n",
"Epoch 190/1000, Loss: 0.009745683521032333\n",
"Epoch 200/1000, Loss: 0.007611896842718124\n",
"Epoch 210/1000, Loss: 0.0060510095208883286\n",
"Epoch 220/1000, Loss: 0.004891632590442896\n",
"Epoch 230/1000, Loss: 0.004030665848404169\n",
"Epoch 240/1000, Loss: 0.003379524452611804\n",
"Epoch 250/1000, Loss: 0.002879904117435217\n",
"Epoch 260/1000, Loss: 0.002489611506462097\n",
"Epoch 270/1000, Loss: 0.002184198470786214\n",
"Epoch 280/1000, Loss: 0.0019437369192019105\n",
"Epoch 290/1000, Loss: 0.0017541276756674051\n",
"Epoch 300/1000, Loss: 0.0016026693629100919\n",
"Epoch 310/1000, Loss: 0.0014796099858358502\n",
"Epoch 320/1000, Loss: 0.0013786517083644867\n",
"Epoch 330/1000, Loss: 0.0012941397726535797\n",
"Epoch 340/1000, Loss: 0.0012196070747449994\n",
"Epoch 350/1000, Loss: 0.0011414645705372095\n",
"Epoch 360/1000, Loss: 0.001061757793650031\n",
"Epoch 370/1000, Loss: 0.0009867347544059157\n",
"Epoch 380/1000, Loss: 0.0009273262694478035\n",
"Epoch 390/1000, Loss: 0.0008805892430245876\n",
"Epoch 400/1000, Loss: 0.0008431114838458598\n",
"Epoch 410/1000, Loss: 0.0008109930204227567\n",
"Epoch 420/1000, Loss: 0.0007825929205864668\n",
"Epoch 430/1000, Loss: 0.0007562871905975044\n",
"Epoch 440/1000, Loss: 0.0007321131997741759\n",
"Epoch 450/1000, Loss: 0.0007092245505191386\n",
"Epoch 460/1000, Loss: 0.0006881365552544594\n",
"Epoch 470/1000, Loss: 0.0006688928115181625\n",
"Epoch 480/1000, Loss: 0.0006500301533378661\n",
"Epoch 490/1000, Loss: 0.0006336761871352792\n",
"Epoch 500/1000, Loss: 0.0006172554567456245\n",
"Epoch 510/1000, Loss: 0.00060228758957237\n",
"Epoch 520/1000, Loss: 0.0005888384766876698\n",
"Epoch 530/1000, Loss: 0.0005763943772763014\n",
"Epoch 540/1000, Loss: 0.0005643281037919223\n",
"Epoch 550/1000, Loss: 0.0005541218561120331\n",
"Epoch 560/1000, Loss: 0.0005427465075626969\n",
"Epoch 570/1000, Loss: 0.0005325390957295895\n",
"Epoch 580/1000, Loss: 0.0005224572960287333\n",
"Epoch 590/1000, Loss: 0.0005137791740708053\n",
"Epoch 600/1000, Loss: 0.0005050148465670645\n",
"Epoch 610/1000, Loss: 0.0004971114685758948\n",
"Epoch 620/1000, Loss: 0.0004892157157883048\n",
"Epoch 630/1000, Loss: 0.00048245518701151013\n",
"Epoch 640/1000, Loss: 0.0004754861583933234\n",
"Epoch 650/1000, Loss: 0.0004669078625738621\n",
"Epoch 660/1000, Loss: 0.0004601963155437261\n",
"Epoch 670/1000, Loss: 0.00045247498201206326\n",
"Epoch 680/1000, Loss: 0.00044596134102903306\n",
"Epoch 690/1000, Loss: 0.0004407193046063185\n",
"Epoch 700/1000, Loss: 0.000434816291090101\n",
"Epoch 710/1000, Loss: 0.00042790695442818105\n",
"Epoch 720/1000, Loss: 0.0004230188496876508\n",
"Epoch 730/1000, Loss: 0.00041854308801703155\n",
"Epoch 740/1000, Loss: 0.0004127133288420737\n",
"Epoch 750/1000, Loss: 0.00040834766696207225\n",
"Epoch 760/1000, Loss: 0.0004034344747196883\n",
"Epoch 770/1000, Loss: 0.00039839293458499014\n",
"Epoch 780/1000, Loss: 0.0003941936884075403\n",
"Epoch 790/1000, Loss: 0.00039072500658221543\n",
"Epoch 800/1000, Loss: 0.0003858723503071815\n",
"Epoch 810/1000, Loss: 0.0003812974609900266\n",
"Epoch 820/1000, Loss: 0.00037719475221820176\n",
"Epoch 830/1000, Loss: 0.0003737532824743539\n",
"Epoch 840/1000, Loss: 0.00036979030119255185\n",
"Epoch 850/1000, Loss: 0.00036644612555392087\n",
"Epoch 860/1000, Loss: 0.0003625146928243339\n",
"Epoch 870/1000, Loss: 0.000359336263500154\n",
"Epoch 880/1000, Loss: 0.00035612270585261285\n",
"Epoch 890/1000, Loss: 0.00035236674011684954\n",
"Epoch 900/1000, Loss: 0.0003493966069072485\n",
"Epoch 910/1000, Loss: 0.0003470622468739748\n",
"Epoch 920/1000, Loss: 0.00034316032542847097\n",
"Epoch 930/1000, Loss: 0.00034029941889457405\n",
"Epoch 940/1000, Loss: 0.00033748531132005155\n",
"Epoch 950/1000, Loss: 0.0003346360463183373\n",
"Epoch 960/1000, Loss: 0.00033203166094608605\n",
"Epoch 970/1000, Loss: 0.0003288303851149976\n",
"Epoch 980/1000, Loss: 0.0003261156671214849\n",
"Epoch 990/1000, Loss: 0.000323760905303061\n",
"Epoch 1000/1000, Loss: 0.0003211983712390065\n"
]
}
],
"source": [
"# Neural network V4: build the model, its loss and optimizer, then train and score it\n",
"model_v4 = NeuralNetworkV4(input_size, hidden_size)\n",
"model_v4 = model_v4.to(device)\n",
"\n",
"# BCELoss takes probabilities, so NeuralNetworkV4 is assumed to end in a sigmoid (TODO confirm)\n",
"criterion_v4 = nn.BCELoss()\n",
"optimizer_v4 = optim.Adam(\n",
"    model_v4.parameters(),\n",
"    lr=learning_rate,\n",
"    weight_decay=weight_decay,\n",
")\n",
"\n",
"# Fit on the training split for `epochs` epochs\n",
"train(model_v4, X_train, y_train, criterion_v4, optimizer_v4, epochs)\n",
"\n",
"# Score on the test split: confusion matrix, classification report, accuracy\n",
"cm_v4, cr_v4, acc_v4 = evaluate(model_v4, X_test, y_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:26:02.339239600Z",
"start_time": "2024-06-08T16:25:56.963561200Z"
}
},
"id": "b85c2f139ab4eeca"
},
{
"cell_type": "code",
"execution_count": 65,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwIAAAIhCAYAAAD98w2UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABWNElEQVR4nO3deVxUZf//8feADKDghrsZueQSKiIuaVaGertlbultFoZaarl0l6aJmSui2aqY4tamuS9ZWbmnpqnhnsvtlmmkN+SKIIMwvz/8Od8ZwWQMmInzevY4j5jrnLnO5wDR+cznuq5jslqtVgEAAAAwFA9XBwAAAAAg75EIAAAAAAZEIgAAAAAYEIkAAAAAYEAkAgAAAIABkQgAAAAABkQiAAAAABgQiQAAAABgQCQCAHAHPG/xn42fHwD8NRIBwA0cOHBAr7/+upo2baratWurefPmGjlypM6cOZNr5/zkk0/0yCOPqHbt2vroo49ypM8dO3aoWrVq2rFjR47050offfSR5syZc9fjwsLC9MYbb+RBRDfPVbduXcXHx2e5v1q1apo6dWqexJJd4eHhCg8Pv+P+W78zL7/8cpb7ly9frmrVquns2bNOnTe7P7/c8MYbbygsLCzbx0+fPl3VqlXT/v3773jMuHHjFBISoqSkpEz71q1bl2/+uwOQt0gEABebP3++unXrpj///FODBw/WrFmz1KdPH+3cuVNPP/20jhw5kuPnTEpK0qRJk1S7dm3NmTNHHTt2zJF+g4KCtGjRIgUFBeVIf6704YcfKiUl5a7HxcTE3PEmNjdcu3ZNb775Zp6dL6+sX79eq1atyrH+svvzcwcdO3aUh4eHvvrqqyz3WywWff3112rVqpX8/Pwc9l28eFGjRo3KizAB5EMkAoALxcXFKSoqSt27d9fcuXPVrl07NWzYUF27dtWCBQvk7e2tyMjIHD/v5cuXlZGRoebNm6t+/foqW7ZsjvTr5+enOnXqZLpZyc8eeugh3X///Xl2vsKFC+vHH3/U4sWL8+yceaFw4cKKiopSYmKiq0PJc2XKlFGTJk20evVqpaenZ9r/ww8/6NKlS3r66acz7RszZowKFCiQF2ECyIdIBAAXmjNnjvz9/fXaa69l2le8eHG98cYbatasmZKTkyVJ6enpmj9/vtq1a6fatWuradOmeuedd5Sammp73xtvvKGIiAgtW7ZMLVu2VM2aNdW+fXtt3rxZ0s2hFreGLURGRqpatWqSsh7icvuwjOvXr2v06NF67LHHVLNmTbVq1cph+EVWQ4MOHDig3r17q2HDhqpbt6769eunY8eOZXrP9u3b1atXLwUHB+uRRx7R5MmTs7wpso+tVq1a+vnnn9W5c2fVqlVLLVu21IYNG3Ty5Ek9//zzCg4OVosWLfTNN984vHfXrl3q3bu36tevr5o1ayosLExTp05VRkaGJNm+JzExMbavp06dqhYtWigmJkYNGjRQkyZNdPnyZYfvW3R0tKpVq6affvop0/dw5cqVd7wWZ4SFhalBgwaaNGmS/vjjj788NiMjQzNnzlSLFi1Us2ZNtWzZUp9//nmm/u72c7/TtV+/fl3vvvuu/vWvf6lmzZqqW7euevbsqcOHDzt9Xa+++qqSk5M1evToux4bHx+v1157TQ0aNFBwcLCef/55HTp0yLb/9p/fZ599purVq+vixYu2Y6ZNm2b7vbtl3bp1ql69us6fPy8p+7+7Cxcu1BNPPKG6devqxx9/zBTvoUOHVK9ePb344ouyWCxZXlPnzp2VmJjoEM8tK1asUKVKlRQaGurQvnr1am3btk2vv/76Xb9nAJAVEgHARaxWq7Zu3apGjRrJ19c3y2PatGmj/v37q2DBgpKkt956S9HR0WrevLmmT5+uZ599VvPmzdPLL7/sMDHy4MGDmjNnjgYNGqRp06bJ09NTAwcO1OXLl9W0aVPFxMRIkl566SUtWrQo2zFPmDBBmzdv1rBhwz
Rnzhw1a9ZMb7/9tpYtW5bl8T/99JOeeeYZ23vHjx+vP/74Q926ddOJEyccjh0yZIhCQ0M1Y8YMPfnkk5o9e7aWLFnyl/HcuHFDgwcPVrdu3TR9+nT5+vpqyJAh6tevn5o2baoZM2aoVKlSGjZsmM6dOydJOnLkiCIiIlS0aFG9//77mj59uurVq6eYmBh9++23kmT7njz99NMO35/4+Hj98MMPev/99zV8+HAVKVLEIZ5XX31VDzzwgEaNGiWLxaL4+HhFRUWpdevW6tChQ7a/z3/FZDJpwoQJysjIuOsQodGjR2vKlCl66qmnNGPGDLVq1UoTJkzQtGnTnD5vVtc+dOhQLVu2TH369NHcuXM1fPhwHTt2TIMHD3Z6om7lypU1cOBArV27Vl9//fUdj7tw4YK6deumX375RSNHjtS7776rjIwMPfvss7bfqdt/fk2bNpXVanVI0G59vWvXLlvb5s2b9dBDD6l06dJO/e7GxMRo2LBheuuttxQSEuKw78SJE+rdu7eCg4M1bdo0mc3mLK8rLCxMxYoVyzQ86MKFC9q8eXOmakBiYqLGjBmjyMhIlSxZ8o7fLwD4K9QTARe5ePGiUlNTdd9992Xr+OPHj2vp0qUaPHiw+vTpI0l65JFHVKpUKQ0dOlSbN2/W448/Lkm6evWqli9fbhuyUrBgQT333HP66aef1LJlS9WoUUOSdP/996tOnTrZjnnnzp165JFH1LZtW0lSw4YNVbBgQQUEBGR5/LvvvqvAwEDNnDlTnp6ekqQmTZqoRYsWmjJlij788EPbsV26dFH//v0lSY0aNdK6deu0adMmdevW7Y7xZGRkqF+/furSpYsk6cqVK3r11Vf1/PPPq2fPnpIkf39/de7cWQcPHlSZMmV05MgRNW7cWJMnT5aHh4ft+7hhwwbt2LFDbdu2tX1PypQp4/D9uXHjhoYNG6Z69eplGY+Pj48mTpyo7t27a+bMmdq9e7f8/Pw0ZsyYu31rnVKhQgW99tprGj9+vJYsWWK7fnunTp3S4sWL9dprr9l+X5o0aSKTyaTY2Fh1795dxYoVy/Y5b792i8Vim6/Qpk0bSVKDBg2UlJSkiRMnKjEx0ekb1N69e2vt2rUaN26cHn74YZUoUSLTMZ9++qkuXbqkBQsWqHz58pKkxx57TG3atNGHH36oKVOmZPnzq1ixorZv367WrVsrJSVFe/bsUVBQkEMisGXLFnXq1EmSc7+73bt3V6tWrTLFeubMGUVERKh69er66KOP7pgESJLZbNZTTz2lpUuXasyYMfLx8ZEkWzXr9kRy5MiRCgkJUYcOHZgkDOCeUREAXOTWzcVfDX+xt3PnTkmy3YTf0rZtW3l6ejrcDBQvXtxh3HqZMmUk6W9PnmzYsKEWL16sF198UfPmzdOZM2fUv39/NW3aNNOxycnJOnDggFq3bm27VunmWPAnnnjCdj233P5JapkyZWxDov6K/ftuJSTBwcG2tqJFi0q6mSRIN2+oZs2apbS0NB05ckTff/+9pkyZovT0dKWlpd31fLeSqL+KJyIiQtOmTdO2bds0ceLETJUDe+np6bpx44Zty+7vw3PPPaf69etr4sSJtmqHvZ9++klWq1VhYWEO/YeFhSk1NVVxcXHZOo89+2s3m82aM2eO2rRpo/Pnz+unn37SwoULtXHjRkm64xCYv+Lp6ano6GglJyffMXnavn27atSoodKlS9uuycPDQ4899pi2bdt2x76bNm1q2x8XFycvLy/16NFD+/btk8Vi0fHjxxUfH6+mTZs6/bub1e/EtWvXFBERoYSEBI0ZM0be3t53vf7OnTvr2rVr2rBhg61txYoVatq0qUOyvWLFCsXFxWns2LF37RMA/gqJAOAiRYoUUaFChe64FKR082b68uXLkmT79+2fshYoUEDFihXT1atXbW23DzUymUySZBsDf69GjBih//znPzp79qzGjRun5s2bq1u3blmubHT16lVZrdYsP9UtUaKEQ7
ySbJ+A3uLh4ZGt4SVZTUy+01Ar6eY8hxEjRig0NFQdOnTQ5MmT9fvvv6tAgQLZOl+hQoXuekzHjh2VkZGhEiVKOCQ
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Confusion matrix of model V4 on the test set (tick labels: index 0 -> Benign, 1 -> Malignant)\n",
"plot_confusion_matrix(\n",
"    cm_v4,\n",
"    classes=['Benign', 'Malignant'],\n",
"    title='Confusion matrix - Neural Network V4',\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:26:02.503093100Z",
"start_time": "2024-06-08T16:26:02.331956800Z"
}
},
"id": "73954f54243ecb26"
},
{
"cell_type": "code",
"execution_count": 66,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAw0AAAIhCAYAAAAM+FYZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAACHk0lEQVR4nOzdd3hU1dbH8e+kTBJSIAQIVToEAoQQIHRBQOkglov6gggqKCBKU+BKkyYqKqACKioWpFsQEOlSBAxdEjoYWgik9zbvH8iYcRBJrslJyO9zn3kus8+eM+tENpk1a+99TBaLxYKIiIiIiMjfcDA6ABERERERKdiUNIiIiIiIyG0paRARERERkdtS0iAiIiIiIrelpEFERERERG5LSYOIiIiIiNyWkgYREREREbktJQ0iIiIiInJbShpERAoB3Yfz7+lnIyKS95Q0iEiBcuTIEUaPHk3btm1p0KABHTp04NVXXyU8PNymX+3atZk7d26+xjZ37lxq165tfZ6QkMDgwYMJCAigSZMmnDt3jtq1a7Nq1ap/9X03bdrEyy+/bH2+Z88eateuzZ49e/7V9yls4uLiGDNmDL/++qvRoYiI3PWcjA5AROSmL7/8kunTpxMcHMzIkSMpU6YM58+f5+OPP2bDhg189tln+Pn5GRbfI488QuvWra3Pv/nmG7Zs2cKECROoWbMm5cuXZ+nSpdxzzz3/6vt++umnNs/9/f1ZunQpNWrU+Fffp7AJDQ3l22+/5aGHHjI6FBGRu56SBhEpEEJCQpg2bRpPPPEE48ePt7YHBwfToUMHevXqxbhx4/71b/FzomzZspQtW9b6PCYmBoDHH38ck8kEQMOGDfM8Dg8Pj3x5HxERkZs0PUlECoSPP/4YT09PRowYYXesZMmSvPLKK7Rv356kpKRbvj4sLIyhQ4fSrFkz/P39ad26NVOnTiUlJcXaZ+fOnTz66KMEBgbSpEkTnnvuOU6fPm09/vvvvzN48GCCg4MJCAjgP//5D9u2bbMezz49qW/fvtbpUX5+frzyyitcuHDBbnrSmTNnGDp0KE2bNqVJkyYMGjTI5j0vXLjAmDFjaNWqFf7+/jRv3pwxY8YQHR1tfZ+9e/eyd+9e65SkW01POnLkCAMHDiQ4OJhGjRoxePBgTp48aT1+8zW7d+9mwIABBAQE0LJlS9544w0yMzP/9r/LqlWrqFu3LsuXL6dly5Y0bdqUU6dOAbBx40Z69+5N/fr1admyJVOnTrX57zN37lzuu+8+tmzZQqdOnQgICODRRx+1m1Z19epVxo4dy7333kuDBg14+OGH2bRpk02f2rVrM2/ePHr37k2DBg2YN28e/fr1A6Bfv3707dv3b69BRET+d0oaRMRwFouFHTt20Lx5c9zc3G7Zp0uXLgwZMoRixYrZHbt69SpPPPEEycnJzJw5kw8//JCuXbvy+eefs3jxYgDCw8N5/vnnqVevHh988AHTpk3j7NmzPPvss2RlZZGVlcWgQYNITk5m1qxZvP/++5QoUYLnnnuO8+fP273nxIkTefjhhwFYunQpzz//vF2fiIgI/vOf/3Du3DkmTZrEG2+8wbVr13jyySeJiYkhOTmZfv36cfr0aSZOnMjHH39Mv379+OGHH3j77bet71O3bl3q1q3L0qVL8ff3t3ufX375hcceewyA6dOnM3XqVC5fvkyfPn1sEhSAUaNGERQUxPz58+nWrRsfffQRy5cvv91/HjIzM1m0aBHTpk1j7NixVK9ene+//54hQ4ZQrVo13nvvPYYOHcp3333H888/b7MwOSoqipdffpnHH3+cd999F1dXVwYOHEhoaCgA165d4+GHH+bXX3/lpZdeYu7cuVSoUIEhQ4bw3Xff2cQxf/58unfvzpw5c+jQoQMTJkwAYMKECUycOPG21yAiIv8bTU8SEcNFR0eTmppKxYoVc/X6EydOUKdOHd599108PDwAaNGiBTt37mTPnj08++yzHD
58mJSUFAYNGoSvry9wY7rRpk2bSEpKIjk5mTNnzvD8889z7733Ali/0U5LS7N7zxo1alinKt2cKnThwgWbPp9++ilpaWl88sknlC5dGrhRlXjsscc4dOgQZcqUoWzZsrz++utUqlQJgGbNmnHo0CH27t1rfZ+b1/R3U5LeeustKleuzMKFC3F0dASgVatWdOzYkTlz5vDuu+9a+z7yyCMMGTIEgObNm7Nx40a2bt1Knz59bvszHjx4MG3btgVuJHlvvvkmrVu35s0337T2qVKlCv3792fbtm3WvsnJyUyaNIlevXpZr69Dhw4sXLiQt99+m08++YSoqCh+/PFHKlSoAMC9995L//79mTVrFt26dcPB4cb3W40bN+app56yvl9sbKz1Z1TU13eIiOQ1JQ0iYribH3RvN03mdlq1akWrVq1IT0/n1KlTnD9/nhMnThAVFUWJEiUACAgIwMXFhYcffphOnTrRpk0bgoODadCgAQDu7u7UqFGDV199lR07dtCqVSvatGnD2LFjc31dISEhNGzY0JowwI1EZcuWLdbnX331FVlZWZw7d47z589z6tQpzpw5Q0ZGxh29R1JSEkeOHGHo0KHWnyOAl5cX7dq1s5leBRAYGGjzvGzZsn875Su7OnXqWP985swZrly5wqBBg2zibNKkCR4eHuzcudOaNDg5OdGtWzdrH1dXV9q0acP27dsB2Lt3L4GBgdaE4aYePXowduxYzpw5Y00IsscgIiL5S0mDiBiuePHiuLu7c+nSpb/tk5SURHp6OsWLF7c7lpWVxezZs/nyyy9JSkqiXLlyNGjQABcXF2ufihUr8sUXX7Bw4UJWrFjB4sWL8fLy4vHHH+fFF1/EZDKxaNEiPvjgA3766Se++eYbnJ2d6dChA5MnT77l+/6TmJiYf6yefPLJJ8yfP5+YmBhKlSpFvXr1cHNzIz4+/o7eIz4+HovFQqlSpeyOlSpVyu48rq6uNs8dHBzu6D4H2aeF3VwAPnnyZCZPnmzX9+rVqzYxODnZ/qrx8fGxniM2NtZaZflr7HBjW9VbxSAiIvlLSYOIFAitWrViz549pKam2nzYv2nZsmW8/vrrrFixwm5e/8KFC/n000+ZPHky999/P56engDWNQc3ZZ9uFBISwtKlS5k/fz5+fn507twZX19fJk2axMSJEwkLC2P9+vV8+OGHeHt752rOvKenJ1FRUXbtu3fvpmLFihw8eJCZM2cyevRoevfuTcmSJQEYPnw4R44cueP3MJlMXLt2ze5YZGSktdLyb/Ly8gJgzJgxNG3a1O549gTrZnKQ3bVr1/Dx8bH2jYyMtOtzs83b2/vfCFlERP5HWggtIgXCgAEDiImJ4Z133rE7FhkZyaJFi6hRo8YtFwKHhIRQo0YNHnroIWvCEBERwYkTJ8jKygJurC9o164daWlpmM1mmjdvzmuvvQbApUuXOHDgAC1atODw4cOYTCbq1KnDSy+9RK1atW5bAbmdxo0bc+jQIZvE4fr16zz99NNs27aNkJAQvLy8ePrpp60JQ2JiIiEhIda4Aeuc/lspVqwY9erVY926dTbTu+Lj49m6dStBQUG5iv12qlWrho+PDxcuXKB+/frWh6+vL2+99RbHjh2z9k1JSeHnn3+2eb59+3aaN28O3JjSdODAAS5evGjzHt999x2lS5emcuXKfxtH9ulYIiKSt1RpEJECoWHDhgwfPpx33nmH06dP06tXL7y9vTl58iQff/wxqampt0wo4EYF4f3332fhwoU0bNiQ8+fPs2DBAtLS0khOTgZuLMB98803GTJkCP/3f/+Ho6MjX3/9NWazmXbt2lGhQgVcXV0ZM2YMw4YNo1SpUuzatYvQ0FDr1p451b9/f7755huefvppBg0ahLOzMx988AFly5ale/fubNq0iSVLljBz5kzatWvH1atX+fjjj7l27ZrNt/VeXl4cOHCA3bt3U7duXbv3GTlyJAMHDuTZZ5/l8ccfJz09nYULF5
KWlmZd9PxvcnR05KWXXmLChAk4OjrSrl074uLieP/994mIiLBL7MaOHcuLL76Ij48PH3/8MUlJSTz33HMAPPXUU3z
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualise the classification report that evaluate() returned for model V4\n",
"plot_classification_report(cr_v4)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:26:02.703936200Z",
"start_time": "2024-06-08T16:26:02.503093100Z"
}
},
"id": "7edf4cb0b06ec6"
},
{
"cell_type": "markdown",
"source": [
"#### Neural network V5"
],
"metadata": {
"collapsed": false
},
"id": "dff009e888af5074"
},
{
"cell_type": "code",
"execution_count": 67,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/1000, Loss: 0.6472257971763611\n",
"Epoch 20/1000, Loss: 0.5392731428146362\n",
"Epoch 30/1000, Loss: 0.36916524171829224\n",
"Epoch 40/1000, Loss: 0.21396635472774506\n",
"Epoch 50/1000, Loss: 0.13106301426887512\n",
"Epoch 60/1000, Loss: 0.09312008321285248\n",
"Epoch 70/1000, Loss: 0.07527194172143936\n",
"Epoch 80/1000, Loss: 0.06560871005058289\n",
"Epoch 90/1000, Loss: 0.05922595039010048\n",
"Epoch 100/1000, Loss: 0.054361775517463684\n",
"Epoch 110/1000, Loss: 0.050068166106939316\n",
"Epoch 120/1000, Loss: 0.04590252414345741\n",
"Epoch 130/1000, Loss: 0.042265865951776505\n",
"Epoch 140/1000, Loss: 0.03903531655669212\n",
"Epoch 150/1000, Loss: 0.03611545264720917\n",
"Epoch 160/1000, Loss: 0.03336109220981598\n",
"Epoch 170/1000, Loss: 0.030720891430974007\n",
"Epoch 180/1000, Loss: 0.028099115937948227\n",
"Epoch 190/1000, Loss: 0.02529078908264637\n",
"Epoch 200/1000, Loss: 0.022276364266872406\n",
"Epoch 210/1000, Loss: 0.01945570483803749\n",
"Epoch 220/1000, Loss: 0.016959620639681816\n",
"Epoch 230/1000, Loss: 0.014633278362452984\n",
"Epoch 240/1000, Loss: 0.012433870695531368\n",
"Epoch 250/1000, Loss: 0.01063645537942648\n",
"Epoch 260/1000, Loss: 0.009215538389980793\n",
"Epoch 270/1000, Loss: 0.008075335994362831\n",
"Epoch 280/1000, Loss: 0.007146317046135664\n",
"Epoch 290/1000, Loss: 0.006382191088050604\n",
"Epoch 300/1000, Loss: 0.005759004037827253\n",
"Epoch 310/1000, Loss: 0.005265630315989256\n",
"Epoch 320/1000, Loss: 0.004859541077166796\n",
"Epoch 330/1000, Loss: 0.0045250700786709785\n",
"Epoch 340/1000, Loss: 0.00425321189686656\n",
"Epoch 350/1000, Loss: 0.0040284693241119385\n",
"Epoch 360/1000, Loss: 0.0038447484839707613\n",
"Epoch 370/1000, Loss: 0.0036921321880072355\n",
"Epoch 380/1000, Loss: 0.003561016172170639\n",
"Epoch 390/1000, Loss: 0.003450624644756317\n",
"Epoch 400/1000, Loss: 0.003364400239661336\n",
"Epoch 410/1000, Loss: 0.0032908162102103233\n",
"Epoch 420/1000, Loss: 0.003222778672352433\n",
"Epoch 430/1000, Loss: 0.003165710251778364\n",
"Epoch 440/1000, Loss: 0.0031129783019423485\n",
"Epoch 450/1000, Loss: 0.0030633530113846064\n",
"Epoch 460/1000, Loss: 0.0029811717104166746\n",
"Epoch 470/1000, Loss: 0.0028529497794806957\n",
"Epoch 480/1000, Loss: 0.0026854067109525204\n",
"Epoch 490/1000, Loss: 0.0025055729784071445\n",
"Epoch 500/1000, Loss: 0.002334937220439315\n",
"Epoch 510/1000, Loss: 0.002182029653340578\n",
"Epoch 520/1000, Loss: 0.0020450232550501823\n",
"Epoch 530/1000, Loss: 0.0019291980424895883\n",
"Epoch 540/1000, Loss: 0.0018309173174202442\n",
"Epoch 550/1000, Loss: 0.0017435038462281227\n",
"Epoch 560/1000, Loss: 0.0016699342522770166\n",
"Epoch 570/1000, Loss: 0.0016051153652369976\n",
"Epoch 580/1000, Loss: 0.0015472021186724305\n",
"Epoch 590/1000, Loss: 0.0014925338327884674\n",
"Epoch 600/1000, Loss: 0.0014427873538807034\n",
"Epoch 610/1000, Loss: 0.0013984435936436057\n",
"Epoch 620/1000, Loss: 0.0013581964885815978\n",
"Epoch 630/1000, Loss: 0.00132227991707623\n",
"Epoch 640/1000, Loss: 0.0012905021430924535\n",
"Epoch 650/1000, Loss: 0.0012572426348924637\n",
"Epoch 660/1000, Loss: 0.001225507934577763\n",
"Epoch 670/1000, Loss: 0.0012008383637294173\n",
"Epoch 680/1000, Loss: 0.001180665334686637\n",
"Epoch 690/1000, Loss: 0.001158036757260561\n",
"Epoch 700/1000, Loss: 0.0011395757319405675\n",
"Epoch 710/1000, Loss: 0.0011222346220165491\n",
"Epoch 720/1000, Loss: 0.0011062605772167444\n",
"Epoch 730/1000, Loss: 0.0010908363619819283\n",
"Epoch 740/1000, Loss: 0.0010754970135167241\n",
"Epoch 750/1000, Loss: 0.0010610275203362107\n",
"Epoch 760/1000, Loss: 0.0010471524437889457\n",
"Epoch 770/1000, Loss: 0.0010340394219383597\n",
"Epoch 780/1000, Loss: 0.0010215329239144921\n",
"Epoch 790/1000, Loss: 0.0010098188649863005\n",
"Epoch 800/1000, Loss: 0.0009981722105294466\n",
"Epoch 810/1000, Loss: 0.000987424748018384\n",
"Epoch 820/1000, Loss: 0.0009770964970812201\n",
"Epoch 830/1000, Loss: 0.0009671879815869033\n",
"Epoch 840/1000, Loss: 0.000957623531576246\n",
"Epoch 850/1000, Loss: 0.00094812415773049\n",
"Epoch 860/1000, Loss: 0.0009392902138642967\n",
"Epoch 870/1000, Loss: 0.0009308967855758965\n",
"Epoch 880/1000, Loss: 0.0009222882217727602\n",
"Epoch 890/1000, Loss: 0.0009139023022726178\n",
"Epoch 900/1000, Loss: 0.0009058943251147866\n",
"Epoch 910/1000, Loss: 0.0008985912427306175\n",
"Epoch 920/1000, Loss: 0.0008904503774829209\n",
"Epoch 930/1000, Loss: 0.0008828166173771024\n",
"Epoch 940/1000, Loss: 0.0008756622555665672\n",
"Epoch 950/1000, Loss: 0.000868525356054306\n",
"Epoch 960/1000, Loss: 0.0008616363047622144\n",
"Epoch 970/1000, Loss: 0.0008551458013243973\n",
"Epoch 980/1000, Loss: 0.0008486591395922005\n",
"Epoch 990/1000, Loss: 0.0008415335905738175\n",
"Epoch 1000/1000, Loss: 0.0008350368589162827\n"
]
}
],
"source": [
"# Neural network V5: build the model, its loss and optimizer, then train and score it\n",
"# NOTE(review): unsqueeze(1) inserts a size-1 axis at position 1, presumably the layout NeuralNetworkV5 expects - confirm against its definition\n",
"X_train_v5 = X_train.unsqueeze(1)\n",
"X_test_v5 = X_test.unsqueeze(1)\n",
"\n",
"model_v5 = NeuralNetworkV5(input_size, hidden_size)\n",
"model_v5 = model_v5.to(device)\n",
"\n",
"# BCELoss takes probabilities, so NeuralNetworkV5 is assumed to end in a sigmoid (TODO confirm)\n",
"criterion_v5 = nn.BCELoss()\n",
"optimizer_v5 = optim.Adam(\n",
"    model_v5.parameters(),\n",
"    lr=learning_rate,\n",
"    weight_decay=weight_decay,\n",
")\n",
"\n",
"# Fit on the reshaped training split for `epochs` epochs\n",
"train(model_v5, X_train_v5, y_train, criterion_v5, optimizer_v5, epochs)\n",
"\n",
"# Score on the reshaped test split: confusion matrix, classification report, accuracy\n",
"cm_v5, cr_v5, acc_v5 = evaluate(model_v5, X_test_v5, y_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:26:18.741105800Z",
"start_time": "2024-06-08T16:26:14.401969800Z"
}
},
"id": "ea7a761090fe566a"
},
{
"cell_type": "code",
"execution_count": 68,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwIAAAIhCAYAAAD98w2UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABUhUlEQVR4nO3deVyUVf//8feIgCiuuOUSuWsuiChoaRlqbpm7t1kYaqLldqfmmrkimlmmuOBWmeZuZtnqlpo7mksuiZrpTRrkFoKgML8//DnfGdFkjGFGrtfzflyPh5zrzLk+FxD3fOZzzrlMZrPZLAAAAACGksPZAQAAAADIeiQCAAAAgAGRCAAAAAAGRCIAAAAAGBCJAAAAAGBAJAIAAACAAZEIAAAAAAZEIgAAAAAYEIkAAPwDnrn4aOPnBwD3RyIAuIjDhw/rrbfeUsOGDVWjRg01btxYo0aN0rlz5xx2zY8//lhPP/20atSooVmzZmXKmLt371alSpW0e/fuTBnPmWbNmqUFCxY8sF9wcLCGDRuWBRHdvlatWrUUGxt7z/OVKlXSjBkzsiSWjAoJCVFISMh9z9/5nXnjjTfueX7NmjWqVKmSzp8/b9d1M/rzc4Rhw4YpODg4w/1nz56tSpUq6dChQ/ftM378ePn7+yshIUGSNHjwYFWqVCnd8e233/7r+AEYA4kA4AKWLFmizp0766+//tKgQYM0b948hYWFac+ePerQoYOOHz+e6ddMSEjQ5MmTVaNGDS1YsEBt27bNlHGrVq2q5cuXq2rVqpkynjN9+OGHSkpKemC/yMjI+76JdYTr16/r7bffzrLrZZWNGzdq3bp1mTZeRn9+rqBt27bKkSOHvvzyy3ueT0lJ0VdffaVmzZrJ29tbknT8+HG98MILWr58uc1Rt27drAwdwCOMRABwsujoaIWHh6tLly5auHChWrVqpaCgIHXq1ElLly6Vp6enRowYkenXvXr1qtLS0tS4cWPVqVNHjz32WKaM6+3trZo1a1rerBjBk08+qccffzzLrpcvXz799NNPWrFiRZZdMyvky5dP4eHhio+Pd3YoWa548eKqX7++vv76a6WmpqY7/+OPP+rKlSvq0KGDJCk5OVlnzpxR3bp1VbNmTZujQIECWRw9gEcViQDgZAsWLFDevHk1cODAdOcKFSqkYcOGqVGjRkpMTJQkpaamasmSJWrVqpVq1Kihhg0b6r333lNycrLldcOGDVNoaKhWr16tpk2bqlq1amrdurW2bt0q6fZUizvTFkaMGKFKlSpJuvcUl7unZdy4cUNjxozRM888o2rVqqlZs2Y20y/uNTXo8OHD6tGjh4KCglSrVi317t1bJ0+eTPeanTt3qnv37vLz89PTTz+tKVOm3PNNkXVs1atX1759+9S+fXtVr15dTZs21aZNm3T69Gm9+uqr8vPzU5MmTbR+/Xqb1+7du1c9evRQnTp1VK1aNQUHB2vGjBlKS0uTJMv3JDIy0vLvGTNmqEmTJoqMjFRgYKDq16+vq1ev2nzfIiIiVKlSJe3atSvd93Dt2rX3vRd7BAcHKzAwUJMnT9Yff/zxj33T0tI0d+5cNWnSRNWqVVPTpk316aefphvvQT/3+937jRs3NHXqVD3//POqVq2aatWqpW7duunYsWN239ebb76pxMREjRkz5oF9Y2NjNXDgQAUGBsrPz0+vvvqqjh49ajl/989v0aJFqly5si5fvmzpM3PmTMvv3R0bNmxQ5cqVdfHiRUkZ/91dtmyZnnvuOdWqVUs//fRTuniPHj2q2rVrq2fPnkpJSbnnPbVv317x8fE28dzx+eefq2zZsgoICJAk/frrr7p165aqVKnywO8VANwPiQDgRGazWdu3b1e9evXk5eV1zz4tWrRQnz59lDt3bknSO++8o4iICDVu3FizZ8/Wyy+/rMWLF+uNN96wWRh55MgRLViwQP3799fMmTPl5uamfv366erVq2rYsKEiIy
MlSa+//rqWL1+e4ZgnTpyorVu3aujQoVqwYIEaNWqkd999V6tXr75n/127dumll16yvHbChAn6448/1LlzZ506dcqm7+DBgxUQEKA5c+bohRde0Pz587Vy5cp/jOfWrVsaNGiQOnfurNmzZ8vLy0uDBw9W79691bBhQ82ZM0dFixbV0KFDdeHCBUm3p1SEhoaqQIEC+uCDDzR79mzVrl1bkZGR+uabbyTJ8j3p0KGDzfcnNjZWP/74oz744AMNHz5c+fPnt4nnzTff1BNPPKHRo0crJSVFsbGxCg8PV/PmzdWmTZsMf5//iclk0sSJE5WWlvbAKUJjxozR9OnT9eKLL2rOnDlq1qyZJk6cqJkzZ9p93Xvd+5AhQ7R69WqFhYVp4cKFGj58uE6ePKlBgwbZvVC3XLly6tevn3744Qd99dVX9+136dIlde7cWb/88otGjRqlqVOnKi0tTS+//LLld+run1/Dhg1lNpttErQ7/967d6+lbevWrXryySdVrFgxu353IyMjNXToUL3zzjvy9/e3OXfq1Cn16NFDfn5+mjlzpjw8PO55X8HBwSpYsGC66UGXLl3S1q1bLdUASZbpgitXrlT9+vVVrVo1denSRQcPHrzv9w0A7pbT2QEARnb58mUlJyerVKlSGeofExOjVatWadCgQQoLC5MkPf300ypatKiGDBmirVu36tlnn5Uk/f3331qzZo1lykru3Ln1yiuvaNeuXWratKnlk8THH39cNWvWzHDMe/bs0dNPP62WLVtKkoKCgpQ7d275+Pjcs//UqVPl6+uruXPnys3NTZJUv359NWnSRNOnT9eHH35o6duxY0f16dNHklSvXj1t2LBBW7ZsUefOne8bT1pamnr37q2OHTtKkq5du6Y333xTr776qrp16yZJyps3r9q3b68jR46oePHiOn78uJ566ilNmTJFOXLksHwfN23apN27d6tly5aW70nx4sVtvj+3bt3S0KFDVbt27XvGkytXLk2aNEldunTR3LlztX//fnl7e2vs2LEP+tbapXTp0ho4cKAmTJiglStXWu7f2pkzZ7RixQoNHDjQ8vtSv359mUwmRUVFqUuXLipYsGCGr3n3vaekpFjWK7Ro0UKSFBgYqISEBE2aNEnx8fEqUqSIXffVo0cP/fDDDxo/frzq1q2rwoULp+vzySef6MqVK1q6dKlKliwpSXrmmWfUokULffjhh5o+ffo9f35lypTRzp071bx5cyUlJenAgQOqWrWqTSKwbds2tWvXTpJ9v7tdunRRs2bN0sV67tw5hYaGqnLlypo1a9Z9kwBJ8vDw0IsvvqhVq1Zp7NixypUrlyRZqlnWieSdiktSUpKmTp2qK1euaO7cueratauWL1+uypUr/+P3GQAkKgKAU915c/FP01+s7dmzR5Isb8LvaNmypdzc3Gym4xQqVMhm3nrx4sUl6V8vngwKCtKKFSvUs2dPLV68WOfOnVOfPn3UsGHDdH0TExN1+PBhNW/e3HKv0u254M8995zlfu64+5PU4sWLW6ZE/RPr191JSPz8/Cxtd+ZMX7t2TdLtN1Tz5s3TzZs3dfz4cX333XeaPn26UlNTdfPmzQde70HTMfz9/RUaGqqZM2dqx44dmjRpUrrKgbXU1FTdunXLcmT09+GVV15RnTp1NGnSJEu1w9quXbtkNpsVHBxsM35wcLCSk5MVHR2doetYs753Dw8PLViwQC1atNDFixe1a9cuLVu2TJs3b5ak+06B+Sdubm6KiIhQYmLifZOnnTt3qkqVKipWrJjlnnLkyKFnnnlGO3bsuO/YDRs2tJyPjo6Wu7u7unbtqoMHDyolJUUxMTGKjY1Vw4YN7f7dvdfvxPXr1xUaGqq4uDiNHTtWnp6eD7z/9u3b6/r169q0aZOl7fPPP1fDhg1tku1XXnlF8+fP17vvvqugoCA1bdpUH330kby8vDRnzpwHXgcAJCoCgFPlz59fefLkue9WkNLtN9M3b95U/v
z5dfXqVUlK9ylrzpw5VbBgQf3999+WtrunGplMJkmyzIF/WCNHjlTx4sW1bt06jR8/3rKl4ZgxY9J9Cvn333/LbDb
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Confusion matrix of model V5 on the test set (tick labels: index 0 -> Benign, 1 -> Malignant)\n",
"plot_confusion_matrix(\n",
"    cm_v5,\n",
"    classes=['Benign', 'Malignant'],\n",
"    title='Confusion matrix - Neural Network V5',\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:26:18.901207300Z",
"start_time": "2024-06-08T16:26:18.741105800Z"
}
},
"id": "c50f0edf9c024c66"
},
{
"cell_type": "code",
"execution_count": 69,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAx0AAAIhCAYAAAArVtfxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAACHmUlEQVR4nOzdeXxN1/7/8dfJHIkYgpha8xBBRBBDDC1atCpV7XXba+41q5aiKKLGolpDaypt9ap5aCmq5poVMRRtxNCgCJIQmZPz+8PX+eU02kqak30q72cf+1Fn7XXW/uzIlvPJZ629TWaz2YyIiIiIiIiNOBgdgIiIiIiIPN6UdIiIiIiIiE0p6RAREREREZtS0iEiIiIiIjalpENERERERGxKSYeIiIiIiNiUkg4REREREbEpJR0iIiIiImJTSjpERP4B9BzXP6avjYiI/VPSISJ25eTJkwwZMoRmzZpRs2ZNWrRowahRo4iMjLTqV6VKFWbNmpWrsc2aNYsqVapYXsfFxdG7d2/8/f2pW7cuFy9epEqVKqxZsyZHj7tt2zaGDRtmeX3w4EGqVKnCwYMHc/Q4/zR37txh6NCh/Pjjj0aHIiIif8HJ6ABERB5YsmQJEydOJCgoiMGDB1OsWDEuXbrEwoUL2bJlC1988QVVq1Y1LL6XX36Zxo0bW16vW7eOHTt2MHr0aCpVqkTJkiVZvnw5Tz75ZI4e9/PPP7d67efnx/Lly6lYsWKOHuef5syZM3z99de89NJLRociIiJ/QUmHiNiFI0eOMGHCBF577TVGjhxpaQ8KCqJFixaEhIQwYsSIHK8iZEXx4sUpXry45XVMTAwAr776KiaTCYBatWrZPA5PT89cOY6IiEhO0fQqEbELCxcuJH/+/AwaNCjTvsKFC/POO+/QvHlz4uPjH/r+s2fP0r9/f+rXr4+fnx+NGzdm/PjxJCYmWvrs3buXV155hYCAAOrWrUufPn2IiIiw7P/111/p3bs3QUFB+Pv7869//Ytdu3ZZ9mecXtWpUyfL9K6qVavyzjvvcPny5UzTq86fP0///v2pV68edevWpVevXlbHvHz5MkOHDiU4OBg/Pz8aNGjA0KFDiY6Othzn0KFDHDp0yDKl6mHTq06ePEmPHj0ICgqidu3a9O7dm/DwcMv+B+/Zv38/3bt3x9/fn0aNGjF16lTS0tL+8O9lzZo1VKtWjZUrV9KoUSPq1avHuXPnANi6dSvt27enRo0aNGrUiPHjx1v9/cyaNYunn36aHTt20KpVK/z9/XnllVcyTQu7ceMGw4cPp2nTptSsWZMOHTqwbds2qz5VqlRh9uzZtG/fnpo1azJ79mw6d+4MQOfOnenUqdMfnoOIiBhPSYeIGM5sNrNnzx4aNGiAu7v7Q/u0adOGfv36kS9fvkz7bty4wWuvvUZCQgKTJ09mwYIFPPfcc3z55ZcsXrwYgMjISPr27Uv16tWZM2cOEyZM4MKFC/Ts2ZP09HTS09Pp1asXCQkJTJkyhU8++YSCBQvSp08fLl26lOmYY8aMoUOHDgAsX76cvn37Zupz/fp1/vWvf3Hx4kVCQ0OZOnUqN2/epEuXLsTExJCQkEDnzp2JiIhgzJgxLFy4kM6dO/Ptt9/y4YcfWo5TrVo1qlWrxvLly/Hz88t0nAMHDvDvf/8bgIkTJzJ+/Hh+++03OnbsaJXgALz99tsEBgYyd+5cnn/+eT799FNWrlz5Z389pKWlsWjRIiZMmMDw4cOpUKEC69evp1+/fpQvX56PP/6Y/v37880339C3b1+rhd23b99m2LBhvPrqq8yYMQM3Nzd69OjBmTNnALh58yYdOnTgxx9/5K233mLWrFmUKlWKfv368c0331jFMXfuXNq2bcvMmTNp0aIFo0ePBmD06NGMGTPmT89BRESMpelVImK46OhokpKSKF26dLbe/8svv+Dr68uMGTPw9PQEoGHDhuzdu5eDBw/Ss2dPTpw4QWJiIr
169cLHxwe4P11q27ZtxMfHk5CQwPnz5+nbty9NmzYFsPxGPTk5OdMxK1asaJlq9WCq0+XLl636fP755yQnJ/PZZ59RtGhR4H5V5N///jfHjx+nWLFiFC9enPfff58nnngCgPr163P8+HEOHTpkOc6Dc/qjKVUffPABZcqUYf78+Tg6OgIQHBxMy5YtmTlzJjNmzLD0ffnll+nXrx8ADRo0YOvWrezcuZOOHTv+6de4d+/eNGvWDLifJE6bNo3GjRszbdo0S5+yZcvStWtXdu3aZembkJBAaGgoISEhlvNr0aIF8+fP58MPP+Szzz7j9u3bfPfdd5QqVQqApk2b0rVrV6ZMmcLzzz+Pg8P934/VqVOHbt26WY4XGxtr+Rrl9fUtIiL2TkmHiBjuwQflP5vm82eCg4MJDg4mJSWFc+fOcenSJX755Rdu375NwYIFAfD398fV1ZUOHTrQqlUrmjRpQlBQEDVr1gTAw8ODihUrMmrUKPbs2UNwcDBNmjRh+PDh2T6vI0eOUKtWLUvCAfcTnR07dlhef/XVV6Snp3Px4kUuXbrEuXPnOH/+PKmpqY90jPj4eE6ePEn//v0tX0cALy8vnnrqKavpYQABAQFWr4sXL/6HU9Yy8vX1tfz5/PnzXLt2jV69elnFWbduXTw9Pdm7d68l6XBycuL555+39HFzc6NJkybs3r0bgEOHDhEQEGBJOB544YUXGD58OOfPn7ckFBljEBGRfxYlHSJiuAIFCuDh4cHVq1f/sE98fDwpKSkUKFAg07709HSmT5/OkiVLiI+Pp0SJEtSsWRNXV1dLn9KlS/O///2P+fPns2rVKhYvXoyXlxevvvoqb775JiaTiUWLFjFnzhy+//571q1bh7OzMy1atGDs2LEPPe5fiYmJ+cvqzWeffcbcuXOJiYmhSJEiVK9eHXd3d+7evftIx7h79y5ms5kiRYpk2lekSJFM47i5uVm9dnBweKTnXGSc1vZgAf3YsWMZO3Zspr43btywisHJyfpHjbe3t2WM2NhYS5Xn97HD/dviPiwGERH5Z1HSISJ2ITg4mIMHD5KUlGSVLDywYsUK3n//fVatWpVpXcP8+fP5/PPPGTt2LM888wz58+cHsKy5eCDjdKkjR46wfPly5s6dS9WqVWndujU+Pj6EhoYyZswYzp49y+bNm1mwYAGFChXK1pqB/Pnzc/v27Uzt+/fvp3Tp0oSFhTF58mSGDBlC+/btKVy4MAADBw7k5MmTj3wMk8nEzZs3M+2LioqyVHpykpeXFwBDhw6lXr16mfZnTNAeJBcZ3bx5E29vb0vfqKioTH0etBUqVCgnQhYREYNpIbmI2IXu3bsTExPDRx99lGlfVFQUixYtomLFig9dSH3kyBEqVqzISy+9ZEk4rl+/zi+//EJ6ejpwf33FU089RXJyMi4uLjRo0IBx48YBcPXqVY4dO0bDhg05ceIEJpMJX19f3nrrLSpXrvynFZg/U6dOHY4fP26VeNy6dYvXX3+dXbt2ceTIEby8vHj99dctCce9e/c4cuSIJW7AsqbhYfLly0f16tXZtGmT1fS0u3fvsnPnTgIDA7MV+58pX7483t7eXL58mRo1alg2Hx8fPvjgA06fPm3pm5iYyA8//GD1evfu3TRo0AC4PyXr2LFjXLlyxeoY33zzDUWLFqVMmTJ/GEfG6WQiImLfVOkQEbtQq1YtBg4cyEcffURERAQhISEUKlSI8PBwFi5cSFJS0kMTErhfwfjkk0+YP38+tWrV4tKlS8ybN4/k5GQSEhKA+wuYp02bRr9+/fjPf/6Do6Mjy5Ytw8XFhaeeeopSpUrh5ubG0KFDGTBgAEWKFGHfvn2cOXPGcmvWrOratSvr1q3j9ddfp1evXjg7OzNnzhyKFy9O27Zt2bZtG0uXLmXy5Mk89dRT3Lhxg4ULF3Lz5k2raoGXlxfHjh1j//79VKtWLdNxBg8eTI8ePejZsyevvvoqKSkpzJ8/n+TkZMui8Z
zk6OjIW2+9xejRo3F0dOSpp57izp07fPLJJ1y/fj1TYjh8+HDefPNNvL29WbhwIfHx8fTp0weAbt268c0339C1a1f
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualise the classification report that evaluate() returned for model V5\n",
"plot_classification_report(cr_v5)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T16:26:19.522283200Z",
"start_time": "2024-06-08T16:26:19.336932800Z"
}
},
"id": "a5685db72a4db790"
},
{
"cell_type": "code",
"execution_count": 90,
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x600 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0sAAAIhCAYAAACfXCH+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABLdElEQVR4nO3deVyU9f7//yeLCIoaomJh4haICIqQW5iFu5UiWcfsuKVppZalpYgLZi6YlR/luFRSah4tE3E3M7VOlprmmuFRXI5KGSqUyibD/P7w53ybuDRQmAF53G83bjfnfb2vuV7X+ILhybWMg9lsNgsAAAAAYMXR3gUAAAAAQElEWAIAAAAAA4QlAAAAADBAWAIAAAAAA4QlAAAAADBAWAIAAAAAA4QlAAAAADBAWAIAAAAAA4QlAAAAADBAWAKAu8TIkSPl5+en+Ph4e5cCG9i1a5f8/Py0a9cue5cCAHctB7PZbLZ3EQCAO3P58mWFhYWpdu3aysnJ0aZNm+Tg4GDvslCMrly5ouPHj6tBgwZyd3e3dzkAcFfiyBIA3AXWrVsnSYqOjtapU6e0c+dOO1eE4ubu7q6mTZsSlACgGBGWAOAusHLlSrVq1UotW7aUj4+Pli9fnm9OYmKievTooSZNmuiRRx7RO++8o5ycHMvy/fv367nnnlOzZs3UsmVLvfbaazp//rwkKSEhQX5+fjp79qzVc4aHh2vMmDGWx35+foqLi1NkZKSCgoIUFxcnSfrhhx80cOBAPfjgg2rcuLHCw8M1Z84c5eXlWda9cuWKJk+erDZt2qhp06Z68skntX37dklSbGysgoKCdPnyZavtz507VyEhIcrMzDR8Xcxmsz7++GN16dJFQUFB6tChgxYuXKg/n1SxY8cO9e7dWyEhIWrRooVGjhypX375xbI8ISFBgYGB2rNnj5588kkFBgaqU6dO2rp1q06cOKF+/fqpSZMm6tChg9avX2+1np+fnw4cOKAePXooKChITzzxhDZt2mRV49mzZ/XGG28oLCxMAQEBatWqld544w2lpaVZvc5Tp05Vv379FBQUpOjo6Hyn4WVlZSkmJkYPP/ywGjdurM6dO2vhwoVW2/rtt98UFRWltm3bKigoSD179tRXX31lNcfPz09Lly5VdHS0mjdvruDgYL3yyiu6cOGC4WsMAHczwhIAlHLHjh3ToUOHFBERIUmKiIjQV199ZfXL7dKlSzV69GgFBAQoLi5OgwcP1pIlS/TWW29Jko4cOaJ//vOfys7O1owZMzRp0iQdPnxYAwcOVG5ubqHqmT9/vp544gnNnj1bnTp1UlJSkvr376977rlH7733nubNm6fQ0FDFxcVp48aNkiSTyaTnnntOa9eu1ZAhQzR37lzVq1dPQ4cO1Z49e9SzZ09lZ2fnCxqrV69W165d5ebmZljLjBkzNGPGDIWHh2v+/Pnq2bOnZs6cqffff1/S9QD53HPP6d5779W7776rqKgo7du3T//4xz908eJFy/Pk5uZq5MiR6tWrl+bNmyc3NzeNGjVKL7zwgh555BHNnz9fNWrU0OjRo/Xrr79a1TBkyBC1a9dOcXFxqlu3rkaMGKGvv/5akpSZmam+ffsqOTlZEydO1MKFC9W3b1+tX79e7733ntXzLF26VIGBgZo7d6569uyZb1+nTp2qb775RqNHj9bChQvVrl07zZgxQytXrpQkXbhwQT179tSePXv06quvas6cOfL29tbQoUO1Zs0aq+d67733lJeXp3fffVdvvPGGtm3bpqlTp/7t/z0A3HXMAIBSbdq0aebmzZubs7OzzWaz2ZySkmJu2LChed68eWaz2Ww2mUzmVq1amV966SWr9T788ENzjx49zDk5Oebhw4ebH3roIXNWVpZl+Y8//mh+9NFHzUeOHDGvXLnS7Ovraz5z5ozVczz66KPm0aNHWx77+vqa+/XrZzVn1apV5kGDBplNJpNlzGQymUNCQszjx483m8
1m89atW82+vr7mL7/80mrOP/7xD/OcOXPMZrPZ/I9//MP87LPPWpbv3bvX7Ovra/7xxx8NX5fff//d3KhRI/OUKVOsxidPnmweOHCg2WQymR966CHzc889Z7X89OnT5oCAAHNsbKzZbDZb9v3f//63Zc769evNvr6+5lmzZlnGDh06ZLUPN9aLi4uzzMnLyzN3797d/NRTT5nNZrP5yJEj5meeecb8v//9z6qGIUOGmDt16mR5/Oijj5rbt29vNWfnzp1mX19f886dO81ms9ncqVMn87hx46zmxMXFmbdt22Y2m83mGTNmmAMCAsxnz561mtOvXz/zQw89ZPn/8fX1NT/zzDNWc8aMGWNu2rSpGQDKGmd7hzUAwO27du2a1qxZo/bt2ysrK0tZWVmqWLGiQkJC9Nlnn2nw4ME6efKkLl68qA4dOlitO3DgQA0cOFCStHfvXrVt21bly5e3LA8ODtbWrVslST///HOBa/L397d6HBERoYiICGVnZ+vkyZM6ffq0fv75Z5lMJl27ds2y/XLlyik8PNyynqOjo9XphE8++aTGjx+vc+fOydvbW6tWrVLdunUVHBxsWMf+/fuVm5urjh07Wo2PGzdOkpScnKzU1FSNHDnSannt2rUVHBys3bt3W43/eTuenp6SpCZNmljG7rnnHknSH3/8YbVejx49LP92cHBQhw4dNGfOHGVlZcnf31///ve/lZeXp1OnTun06dM6fvy4Tpw4ke+I3l9f179q0aKFli9frl9//VVt27ZV27ZtNXToUMvy3bt3Kzg4WN7e3lbrdevWTVFRUTpx4oQaNGggSWratKnVnJo1a970VEcAuJsRlgCgFNu+fbsuXryozz//XJ9//nm+5f/5z38sNwC48Qu+kfT09FsuL4wKFSpYPc7KytLkyZO1evVq5ebmqlatWgoODpazs7Pl2qH09HTdc889cnS8+dnhXbt21dSpU7V69WoNHDhQGzdu1ODBg2+5T5JUtWrVWy6vVq1avmXVqlXTkSNHrMaMbqRws9P//qxGjRpWjz09PWU2m/XHH3/I1dVVH330kebPn6/09HRVq1ZNjRs3lpubW77rs/76uv5VdHS0atasqTVr1mjy5MmaPHmygoODFRMTo4YNG+r333/X/fffb7ivknXI++t+OTo6Wl3nBQBlBWEJAEqxlStX6v7779eUKVOsxs1ms4YNG6bly5frtddekyRdunTJak5aWpqOHDmi4OBgVapUKd9ySfr666/l7+9vuQ35n2/IIElXr1792xqnTJmiL774QrNmzVLr1q0tv/S3atXKMqdSpUpKT0+X2Wy2uuX5kSNHZDabFRAQoIoVK6pz587auHGjfH19lZGRoe7du990u5UrV7bsd7169SzjKSkp+t///icPDw9JMrxxQWpqqmX5nboRgm64cOGCnJycdM8992jt2rWaPn26Xn/9dUVGRlqC3SuvvKJDhw4VajsuLi568cUX9eKLLyolJUXbtm3T3LlzNXLkSK1fv15VqlRRampqvvVujBXV/gLA3YQbPABAKZWamqr//Oc/euyxx9SiRQurr5YtW6pz5876+uuvVblyZXl4eGjbtm1W669evVqDBw/WtWvXFBoaqh07dljdHe/IkSMaPHiwfvrpJ8tRlT/fvCA5OdlydOZW9u7dqxYtWqh9+/aWoHT48GFdunTJEr5CQ0N17do1ffPNN5b1zGazoqKitGDBAstYz5499d///leLFi1S69at5eXlddPtBgUFqVy5cvn2Oz4+Xq+99poeeOABVa9e3XLb9RvOnDmj/fv3q1mzZn+7bwWxZcsWq33avHmzQkJC5OLior1796py5coaNGiQJShdvXpVe/fuzRdMbyUrK0udOnWyfCDxfffdp2effVaPPfaYUlJSJEkPPvig9u3bp3Pnzlmtu2bNGlWvXl0+Pj53uqsAcNfhyBIAlFKJiYnKzc3VY489Zrg8IiJCK1as0Geffabhw4frzTfflK
enp8LDw3Xy5EnNnj1bzz77rKpUqaKXXnpJ//jHPzRkyBD17dtXWVlZmjVrloKCgvTQQw8pKytLrq6umj59ul555RV
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Accuracy comparison across the five networks\n",
"accuracies = [acc_v1, acc_v2, acc_v3, acc_v4, acc_v5]\n",
"models = ['Neural Network V1', 'Neural Network V2', 'Neural Network V3', 'Neural Network V4', 'Neural Network V5']\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"sns.barplot(x=models, y=accuracies, hue=models, ax=ax)\n",
"\n",
"# hue only duplicates the x categories (to colour each bar), so any legend seaborn adds is redundant\n",
"if ax.get_legend() is not None:\n",
"    ax.get_legend().remove()\n",
"\n",
"# Label each bar with 3 decimals: with 2, close scores such as 0.982 and 0.978 both print as 0.98\n",
"for i, v in enumerate(accuracies):\n",
"    ax.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')\n",
"\n",
"# Accuracy lies in [0, 1]; autoscaling ignores text, so reserve headroom for the labels above the bars\n",
"ax.set_ylim(0, 1.1)\n",
"\n",
"ax.set_title('Accuracy comparison')\n",
"ax.set_xlabel('Model')\n",
"ax.set_ylabel('Accuracy')\n",
"\n",
"plt.show()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T18:14:36.782522200Z",
"start_time": "2024-06-08T18:14:36.623693300Z"
}
},
"id": "822f2e6732f1a75d"
},
{
"cell_type": "code",
"execution_count": 89,
"outputs": [
{
"data": {
"text/plain": "0.9824561403508771"
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Quick check: the unrounded V4 accuracy returned by evaluate() on the test set\n",
"acc_v4"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-08T18:14:27.106543Z",
"start_time": "2024-06-08T18:14:27.067155600Z"
}
},
"id": "e85c80a8f36fd6a7"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "aa65d038e747f33"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}