{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## Experiments: neural networks for a breast cancer classification problem"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e1e08c454a98dd01"
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "# Data manipulation\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Data visualization\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style('whitegrid')\n",
    "\n",
    "# Data preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Metrics\n",
    "from sklearn.metrics import confusion_matrix, classification_report\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Deep learning\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "# Reproducibility: seed the numpy and torch global RNGs so network\n",
    "# initialisation and any numpy-based shuffling give the same results on every run.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-08T15:50:23.912931300Z",
     "start_time": "2024-06-08T15:50:19.472582100Z"
    }
   },
   "id": "c0c219cc1bbd4c7a"
  },
  {
   "cell_type": "markdown",
   "source": [
    "#### Helper functions for plotting the confusion matrix and classification report"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "6064474e7a56f80e"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# Plot confusion matrix\n",
    "def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap='Blues', figsize=(10, 6), axis=None):\n",
    "    \"\"\"\n",
    "    Plot a confusion matrix as an annotated heatmap.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    cm : array-like of shape (n_classes, n_classes)\n",
    "        Confusion matrix, e.g. from sklearn's `confusion_matrix`. Cells are\n",
    "        annotated with fmt='d', so the values must be integers (raw counts,\n",
    "        not a normalised matrix).\n",
    "    classes : sequence of str\n",
    "        Tick labels, used for both the predicted (x) and true (y) axes.\n",
    "    title : str\n",
    "        Axes title.\n",
    "    cmap : str or Colormap\n",
    "        Colormap passed to seaborn.heatmap.\n",
    "    figsize : tuple of float\n",
    "        Figure size; only used when a new figure is created (axis is None).\n",
    "    axis : matplotlib Axes, optional\n",
    "        Axes to draw on. If None, a new figure is created and shown immediately.\n",
    "    \"\"\"\n",
    "    # Draw into the caller's axes (e.g. one panel of a subplot grid) or a fresh figure.\n",
    "    if axis is None:\n",
    "        _, ax = plt.subplots(figsize=figsize)\n",
    "    else:\n",
    "        ax = axis\n",
    "\n",
    "    sns.heatmap(cm, annot=True, fmt='d', xticklabels=classes, yticklabels=classes, cmap=cmap, ax=ax)\n",
    "\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('Predicted label')\n",
    "    ax.set_ylabel('True label')\n",
    "\n",
    "    # Only show standalone figures; with a caller-provided axes the caller decides when to show.\n",
    "    if axis is None:\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "# Plot classification report\n",
    "def plot_classification_report(report, title='Classification report', axis=None, figsize=(10, 6), cmap='Blues'):\n",
    "    \"\"\"\n",
    "    Plot a classification report as an annotated heatmap (rows: classes/averages, columns: metrics).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    report : dict\n",
    "        Classification report as a nested dict -- presumably the result of\n",
    "        `classification_report(..., output_dict=True)`; confirm at call sites.\n",
    "        The last row of `pd.DataFrame(report)` is dropped; for such a report that\n",
    "        is the 'support' row, presumably removed so raw counts do not swamp the\n",
    "        0-1 colour scale of precision/recall/f1.\n",
    "    title : str\n",
    "        Axes title.\n",
    "    axis : matplotlib Axes, optional\n",
    "        Axes to draw on. If None, a new figure is created and shown immediately.\n",
    "    figsize : tuple of float\n",
    "        Figure size; only used when a new figure is created (axis is None).\n",
    "    cmap : str or Colormap\n",
    "        Colormap passed to seaborn.heatmap.\n",
    "    \"\"\"\n",
    "    if axis is None:\n",
    "        _, ax = plt.subplots(figsize=figsize)\n",
    "    else:\n",
    "        ax = axis\n",
    "\n",
    "    # DataFrame(report) has metrics as rows; drop the last row and transpose so\n",
    "    # classes/averages run down the y-axis and metrics across the x-axis.\n",
    "    metrics = pd.DataFrame(report).iloc[:-1, :].T\n",
    "    sns.heatmap(metrics, annot=True, cmap=cmap, ax=ax)\n",
    "\n",
    "    # Use the caller's title (it was previously ignored in favour of a fixed string).\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('Metrics')\n",
    "    ax.set_ylabel('Classes')\n",
    "\n",
    "    if axis is None:\n",
    "        plt.show()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-08T15:51:26.166904900Z",
     "start_time": "2024-06-08T15:51:26.145794400Z"
    }
   },
   "id": "689b41e45e990a1b"
  },
  {
   "cell_type": "markdown",
   "source": [
    "#### Load data"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c7ad4d251442c34c"
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# Load data (path is relative to the kernel's working directory)\n",
    "DATA_PATH = 'datasets/data.csv'\n",
    "data = pd.read_csv(DATA_PATH)\n",
    "\n",
    "# Drop the `id` column, which is not used as a feature.\n",
    "# Reassign instead of inplace=True so the cell is explicit and idempotent.\n",
    "data = data.drop(columns=['id'])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-08T15:52:50.237396100Z",
     "start_time": "2024-06-08T15:52:50.201061700Z"
    }
   },
   "id": "54411dcad03637c2"
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "outputs": [
    {
     "data": {
      "text/plain": " diagnosis radius_mean texture_mean perimeter_mean area_mean \\\n0 1.0 0.521037 0.022658 0.545989 0.363733 \n1 1.0 0.643144 0.272574 0.615783 0.501591 \n2 1.0 0.601496 0.390260 0.595743 0.449417 \n3 1.0 0.210090 0.360839 0.233501 0.102906 \n4 1.0 0.629893 0.156578 0.630986 0.489290 \n.. ... ... ... ... ... \n564 1.0 0.690000 0.428813 0.678668 0.566490 \n565 1.0 0.622320 0.626987 0.604036 0.474019 \n566 1.0 0.455251 0.621238 0.445788 0.303118 \n567 1.0 0.644564 0.663510 0.665538 0.475716 \n568 0.0 0.036869 0.501522 0.028540 0.015907 \n\n smoothness_mean compactness_mean concavity_mean concave points_mean \\\n0 0.593753 0.792037 0.703140 0.731113 \n1 0.289880 0.181768 0.203608 0.348757 \n2 0.514309 0.431017 0.462512 0.635686 \n3 0.811321 0.811361 0.565604 0.522863 \n4 0.430351 0.347893 0.463918 0.518390 \n.. ... 
... ... ... \n564 0.526948 0.296055 0.571462 0.690358 \n565 0.407782 0.257714 0.337395 0.486630 \n566 0.288165 0.254340 0.216753 0.263519 \n567 0.588336 0.790197 0.823336 0.755467 \n568 0.000000 0.074351 0.000000 0.000000 \n\n symmetry_mean ... radius_worst texture_worst perimeter_worst \\\n0 0.686364 ... 0.620776 0.141525 0.668310 \n1 0.379798 ... 0.606901 0.303571 0.539818 \n2 0.509596 ... 0.556386 0.360075 0.508442 \n3 0.776263 ... 0.248310 0.385928 0.241347 \n4 0.378283 ... 0.519744 0.123934 0.506948 \n.. ... ... ... ... ... \n564 0.336364 ... 0.623266 0.383262 0.576174 \n565 0.349495 ... 0.560655 0.699094 0.520892 \n566 0.267677 ... 0.393099 0.589019 0.379949 \n567 0.675253 ... 0.633582 0.730277 0.668310 \n568 0.266162 ... 0.054287 0.489072 0.043578 \n\n area_worst smoothness_worst compactness_worst concavity_worst \\\n0 0.450698 0.601136 0.619292 0.568610 \n1 0.435214 0.347553 0.154563 0.192971 \n2 0.374508 0.483590 0.385375 0.359744 \n3 0.094008 0.915472 0.814012 0.548642 \n4 0.341575 0.437364 0.172415 0.319489 \n.. ... ... ... ... \n564 0.452664 0.461137 0.178527 0.328035 \n565 0.379915 0.300007 0.159997 0.256789 \n566 0.230731 0.282177 0.273705 0.271805 \n567 0.402035 0.619626 0.815758 0.749760 \n568 0.020497 0.124084 0.036043 0.000000 \n\n concave points_worst symmetry_worst fractal_dimension_worst \n0 0.912027 0.598462 0.418864 \n1 0.639175 0.233590 0.222878 \n2 0.835052 0.403706 0.213433 \n3 0.884880 1.000000 0.773711 \n4 0.558419 0.157500 0.142595 \n.. ... ... ... \n564 0.761512 0.097575 0.105667 \n565 0.559450 0.198502 0.074315 \n566 0.487285 0.128721 0.151909 \n567 0.910653 0.497142 0.452315 \n568 0.000000 0.257441 0.100682 \n\n[569 rows x 31 columns]", "text/html": "
\n | diagnosis | \nradius_mean | \ntexture_mean | \nperimeter_mean | \narea_mean | \nsmoothness_mean | \ncompactness_mean | \nconcavity_mean | \nconcave points_mean | \nsymmetry_mean | \n... | \nradius_worst | \ntexture_worst | \nperimeter_worst | \narea_worst | \nsmoothness_worst | \ncompactness_worst | \nconcavity_worst | \nconcave points_worst | \nsymmetry_worst | \nfractal_dimension_worst | \n
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n1.0 | \n0.521037 | \n0.022658 | \n0.545989 | \n0.363733 | \n0.593753 | \n0.792037 | \n0.703140 | \n0.731113 | \n0.686364 | \n... | \n0.620776 | \n0.141525 | \n0.668310 | \n0.450698 | \n0.601136 | \n0.619292 | \n0.568610 | \n0.912027 | \n0.598462 | \n0.418864 | \n
1 | \n1.0 | \n0.643144 | \n0.272574 | \n0.615783 | \n0.501591 | \n0.289880 | \n0.181768 | \n0.203608 | \n0.348757 | \n0.379798 | \n... | \n0.606901 | \n0.303571 | \n0.539818 | \n0.435214 | \n0.347553 | \n0.154563 | \n0.192971 | \n0.639175 | \n0.233590 | \n0.222878 | \n
2 | \n1.0 | \n0.601496 | \n0.390260 | \n0.595743 | \n0.449417 | \n0.514309 | \n0.431017 | \n0.462512 | \n0.635686 | \n0.509596 | \n... | \n0.556386 | \n0.360075 | \n0.508442 | \n0.374508 | \n0.483590 | \n0.385375 | \n0.359744 | \n0.835052 | \n0.403706 | \n0.213433 | \n
3 | \n1.0 | \n0.210090 | \n0.360839 | \n0.233501 | \n0.102906 | \n0.811321 | \n0.811361 | \n0.565604 | \n0.522863 | \n0.776263 | \n... | \n0.248310 | \n0.385928 | \n0.241347 | \n0.094008 | \n0.915472 | \n0.814012 | \n0.548642 | \n0.884880 | \n1.000000 | \n0.773711 | \n
4 | \n1.0 | \n0.629893 | \n0.156578 | \n0.630986 | \n0.489290 | \n0.430351 | \n0.347893 | \n0.463918 | \n0.518390 | \n0.378283 | \n... | \n0.519744 | \n0.123934 | \n0.506948 | \n0.341575 | \n0.437364 | \n0.172415 | \n0.319489 | \n0.558419 | \n0.157500 | \n0.142595 | \n
... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n... | \n
564 | \n1.0 | \n0.690000 | \n0.428813 | \n0.678668 | \n0.566490 | \n0.526948 | \n0.296055 | \n0.571462 | \n0.690358 | \n0.336364 | \n... | \n0.623266 | \n0.383262 | \n0.576174 | \n0.452664 | \n0.461137 | \n0.178527 | \n0.328035 | \n0.761512 | \n0.097575 | \n0.105667 | \n
565 | \n1.0 | \n0.622320 | \n0.626987 | \n0.604036 | \n0.474019 | \n0.407782 | \n0.257714 | \n0.337395 | \n0.486630 | \n0.349495 | \n... | \n0.560655 | \n0.699094 | \n0.520892 | \n0.379915 | \n0.300007 | \n0.159997 | \n0.256789 | \n0.559450 | \n0.198502 | \n0.074315 | \n
566 | \n1.0 | \n0.455251 | \n0.621238 | \n0.445788 | \n0.303118 | \n0.288165 | \n0.254340 | \n0.216753 | \n0.263519 | \n0.267677 | \n... | \n0.393099 | \n0.589019 | \n0.379949 | \n0.230731 | \n0.282177 | \n0.273705 | \n0.271805 | \n0.487285 | \n0.128721 | \n0.151909 | \n
567 | \n1.0 | \n0.644564 | \n0.663510 | \n0.665538 | \n0.475716 | \n0.588336 | \n0.790197 | \n0.823336 | \n0.755467 | \n0.675253 | \n... | \n0.633582 | \n0.730277 | \n0.668310 | \n0.402035 | \n0.619626 | \n0.815758 | \n0.749760 | \n0.910653 | \n0.497142 | \n0.452315 | \n
568 | \n0.0 | \n0.036869 | \n0.501522 | \n0.028540 | \n0.015907 | \n0.000000 | \n0.074351 | \n0.000000 | \n0.000000 | \n0.266162 | \n... | \n0.054287 | \n0.489072 | \n0.043578 | \n0.020497 | \n0.124084 | \n0.036043 | \n0.000000 | \n0.000000 | \n0.257441 | \n0.100682 | \n
569 rows × 31 columns
\n