{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "corrected-wholesale",
   "metadata": {},
   "outputs": [],
   "source": [
    "!kaggle datasets download -d yasserh/breast-cancer-dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ranging-police",
   "metadata": {},
   "outputs": [],
   "source": [
    "!unzip -o breast-cancer-dataset.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "ideal-spouse",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "# NOTE(review): Variable has been a no-op wrapper since PyTorch 0.4; plain tensors can be used directly.\n",
    "from torch.autograd import Variable\n",
    "# NOTE(review): presumably a leftover from an Iris template; remove if unused below.\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "# NOTE(review): imports TensorFlow only for one-hot encoding; torch.nn.functional.one_hot avoids that dependency.\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "major-compromise",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    \"\"\"Feed-forward classifier: input_dim -> 50 -> 20 -> num_classes.\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of input features per sample.\n",
    "        num_classes: size of the output layer. Defaults to 3, the value this\n",
    "            cell previously hard-coded (presumably left over from an Iris\n",
    "            template -- `load_iris` is imported above).\n",
    "            NOTE(review): the breast-cancer `diagnosis` target is presumably\n",
    "            binary (M/B) -- confirm and pass num_classes=2.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, num_classes=3):\n",
    "        super().__init__()\n",
    "        self.layer1 = nn.Linear(input_dim, 50)\n",
    "        self.layer2 = nn.Linear(50, 20)\n",
    "        self.layer3 = nn.Linear(20, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return softmax probabilities over the last axis of the output layer.\"\"\"\n",
    "        x = F.relu(self.layer1(x))\n",
    "        x = F.relu(self.layer2(x))\n",
    "        # Explicit dim: implicit-dim softmax is deprecated and, for 3-D input,\n",
    "        # would normalise over dim 0 (the batch) rather than the class axis.\n",
    "        # NOTE(review): nn.CrossEntropyLoss expects raw logits (it applies\n",
    "        # log-softmax itself); if that loss is used, return the logits instead.\n",
    "        return F.softmax(self.layer3(x), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "czech-regular",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", " | diagnosis | \n", "radius_mean | \n", "texture_mean | \n", "perimeter_mean | \n", "area_mean | \n", "smoothness_mean | \n", "compactness_mean | \n", "concavity_mean | \n", "concave points_mean | \n", "symmetry_mean | \n", "... | \n", "radius_worst | \n", "texture_worst | \n", "perimeter_worst | \n", "area_worst | \n", "smoothness_worst | \n", "compactness_worst | \n", "concavity_worst | \n", "concave points_worst | \n", "symmetry_worst | \n", "fractal_dimension_worst | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
id | \n", "\n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " | \n", " |
842517 | \n", "M | \n", "20.57 | \n", "17.77 | \n", "132.90 | \n", "1326.0 | \n", "0.08474 | \n", "0.07864 | \n", "0.08690 | \n", "0.07017 | \n", "0.1812 | \n", "... | \n", "24.99 | \n", "23.41 | \n", "158.80 | \n", "1956.0 | \n", "0.1238 | \n", "0.1866 | \n", "0.2416 | \n", "0.1860 | \n", "0.2750 | \n", "0.08902 | \n", "
84300903 | \n", "M | \n", "19.69 | \n", "21.25 | \n", "130.00 | \n", "1203.0 | \n", "0.10960 | \n", "0.15990 | \n", "0.19740 | \n", "0.12790 | \n", "0.2069 | \n", "... | \n", "23.57 | \n", "25.53 | \n", "152.50 | \n", "1709.0 | \n", "0.1444 | \n", "0.4245 | \n", "0.4504 | \n", "0.2430 | \n", "0.3613 | \n", "0.08758 | \n", "
84348301 | \n", "M | \n", "11.42 | \n", "20.38 | \n", "77.58 | \n", "386.1 | \n", "0.14250 | \n", "0.28390 | \n", "0.24140 | \n", "0.10520 | \n", "0.2597 | \n", "... | \n", "14.91 | \n", "26.50 | \n", "98.87 | \n", "567.7 | \n", "0.2098 | \n", "0.8663 | \n", "0.6869 | \n", "0.2575 | \n", "0.6638 | \n", "0.17300 | \n", "
84358402 | \n", "M | \n", "20.29 | \n", "14.34 | \n", "135.10 | \n", "1297.0 | \n", "0.10030 | \n", "0.13280 | \n", "0.19800 | \n", "0.10430 | \n", "0.1809 | \n", "... | \n", "22.54 | \n", "16.67 | \n", "152.20 | \n", "1575.0 | \n", "0.1374 | \n", "0.2050 | \n", "0.4000 | \n", "0.1625 | \n", "0.2364 | \n", "0.07678 | \n", "
843786 | \n", "M | \n", "12.45 | \n", "15.70 | \n", "82.57 | \n", "477.1 | \n", "0.12780 | \n", "0.17000 | \n", "0.15780 | \n", "0.08089 | \n", "0.2087 | \n", "... | \n", "15.47 | \n", "23.75 | \n", "103.40 | \n", "741.6 | \n", "0.1791 | \n", "0.5249 | \n", "0.5355 | \n", "0.1741 | \n", "0.3985 | \n", "0.12440 | \n", "
844359 | \n", "M | \n", "18.25 | \n", "19.98 | \n", "119.60 | \n", "1040.0 | \n", "0.09463 | \n", "0.10900 | \n", "0.11270 | \n", "0.07400 | \n", "0.1794 | \n", "... | \n", "22.88 | \n", "27.66 | \n", "153.20 | \n", "1606.0 | \n", "0.1442 | \n", "0.2576 | \n", "0.3784 | \n", "0.1932 | \n", "0.3063 | \n", "0.08368 | \n", "
84458202 | \n", "M | \n", "13.71 | \n", "20.83 | \n", "90.20 | \n", "577.9 | \n", "0.11890 | \n", "0.16450 | \n", "0.09366 | \n", "0.05985 | \n", "0.2196 | \n", "... | \n", "17.06 | \n", "28.14 | \n", "110.60 | \n", "897.0 | \n", "0.1654 | \n", "0.3682 | \n", "0.2678 | \n", "0.1556 | \n", "0.3196 | \n", "0.11510 | \n", "
844981 | \n", "M | \n", "13.00 | \n", "21.82 | \n", "87.50 | \n", "519.8 | \n", "0.12730 | \n", "0.19320 | \n", "0.18590 | \n", "0.09353 | \n", "0.2350 | \n", "... | \n", "15.49 | \n", "30.73 | \n", "106.20 | \n", "739.3 | \n", "0.1703 | \n", "0.5401 | \n", "0.5390 | \n", "0.2060 | \n", "0.4378 | \n", "0.10720 | \n", "
84501001 | \n", "M | \n", "12.46 | \n", "24.04 | \n", "83.97 | \n", "475.9 | \n", "0.11860 | \n", "0.23960 | \n", "0.22730 | \n", "0.08543 | \n", "0.2030 | \n", "... | \n", "15.09 | \n", "40.68 | \n", "97.65 | \n", "711.4 | \n", "0.1853 | \n", "1.0580 | \n", "1.1050 | \n", "0.2210 | \n", "0.4366 | \n", "0.20750 | \n", "
9 rows × 31 columns
\n", "