{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "corrected-wholesale", "metadata": {}, "outputs": [], "source": [ "!kaggle datasets download -d yasserh/breast-cancer-dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "ranging-police", "metadata": {}, "outputs": [], "source": [ "!unzip -o breast-cancer-dataset.zip" ] }, { "cell_type": "code", "execution_count": 109, "id": "ideal-spouse", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from torch import nn\n", "from torch.autograd import Variable\n", "from sklearn.datasets import load_iris\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.preprocessing import LabelEncoder\n", "from tensorflow.keras.utils import to_categorical\n", "import torch.nn.functional as F\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 110, "id": "major-compromise", "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", " def __init__(self, input_dim):\n", " super(Model, self).__init__()\n", " self.layer1 = nn.Linear(input_dim,50)\n", " self.layer2 = nn.Linear(50, 20)\n", " self.layer3 = nn.Linear(20, 3)\n", " \n", " def forward(self, x):\n", " x = F.relu(self.layer1(x))\n", " x = F.relu(self.layer2(x))\n", " x = F.softmax(self.layer3(x)) # To check with the loss function\n", " return x" ] }, { "cell_type": "code", "execution_count": 111, "id": "czech-regular", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
diagnosisradius_meantexture_meanperimeter_meanarea_meansmoothness_meancompactness_meanconcavity_meanconcave points_meansymmetry_mean...radius_worsttexture_worstperimeter_worstarea_worstsmoothness_worstcompactness_worstconcavity_worstconcave points_worstsymmetry_worstfractal_dimension_worst
id
842517M20.5717.77132.901326.00.084740.078640.086900.070170.1812...24.9923.41158.801956.00.12380.18660.24160.18600.27500.08902
84300903M19.6921.25130.001203.00.109600.159900.197400.127900.2069...23.5725.53152.501709.00.14440.42450.45040.24300.36130.08758
84348301M11.4220.3877.58386.10.142500.283900.241400.105200.2597...14.9126.5098.87567.70.20980.86630.68690.25750.66380.17300
84358402M20.2914.34135.101297.00.100300.132800.198000.104300.1809...22.5416.67152.201575.00.13740.20500.40000.16250.23640.07678
843786M12.4515.7082.57477.10.127800.170000.157800.080890.2087...15.4723.75103.40741.60.17910.52490.53550.17410.39850.12440
844359M18.2519.98119.601040.00.094630.109000.112700.074000.1794...22.8827.66153.201606.00.14420.25760.37840.19320.30630.08368
84458202M13.7120.8390.20577.90.118900.164500.093660.059850.2196...17.0628.14110.60897.00.16540.36820.26780.15560.31960.11510
844981M13.0021.8287.50519.80.127300.193200.185900.093530.2350...15.4930.73106.20739.30.17030.54010.53900.20600.43780.10720
84501001M12.4624.0483.97475.90.118600.239600.227300.085430.2030...15.0940.6897.65711.40.18531.05801.10500.22100.43660.20750
\n", "

9 rows × 31 columns

\n", "
" ], "text/plain": [ " diagnosis radius_mean texture_mean perimeter_mean area_mean \\\n", "id \n", "842517 M 20.57 17.77 132.90 1326.0 \n", "84300903 M 19.69 21.25 130.00 1203.0 \n", "84348301 M 11.42 20.38 77.58 386.1 \n", "84358402 M 20.29 14.34 135.10 1297.0 \n", "843786 M 12.45 15.70 82.57 477.1 \n", "844359 M 18.25 19.98 119.60 1040.0 \n", "84458202 M 13.71 20.83 90.20 577.9 \n", "844981 M 13.00 21.82 87.50 519.8 \n", "84501001 M 12.46 24.04 83.97 475.9 \n", "\n", " smoothness_mean compactness_mean concavity_mean \\\n", "id \n", "842517 0.08474 0.07864 0.08690 \n", "84300903 0.10960 0.15990 0.19740 \n", "84348301 0.14250 0.28390 0.24140 \n", "84358402 0.10030 0.13280 0.19800 \n", "843786 0.12780 0.17000 0.15780 \n", "844359 0.09463 0.10900 0.11270 \n", "84458202 0.11890 0.16450 0.09366 \n", "844981 0.12730 0.19320 0.18590 \n", "84501001 0.11860 0.23960 0.22730 \n", "\n", " concave points_mean symmetry_mean ... \\\n", "id ... \n", "842517 0.07017 0.1812 ... \n", "84300903 0.12790 0.2069 ... \n", "84348301 0.10520 0.2597 ... \n", "84358402 0.10430 0.1809 ... \n", "843786 0.08089 0.2087 ... \n", "844359 0.07400 0.1794 ... \n", "84458202 0.05985 0.2196 ... \n", "844981 0.09353 0.2350 ... \n", "84501001 0.08543 0.2030 ... \n", "\n", " radius_worst texture_worst perimeter_worst area_worst \\\n", "id \n", "842517 24.99 23.41 158.80 1956.0 \n", "84300903 23.57 25.53 152.50 1709.0 \n", "84348301 14.91 26.50 98.87 567.7 \n", "84358402 22.54 16.67 152.20 1575.0 \n", "843786 15.47 23.75 103.40 741.6 \n", "844359 22.88 27.66 153.20 1606.0 \n", "84458202 17.06 28.14 110.60 897.0 \n", "844981 15.49 30.73 106.20 739.3 \n", "84501001 15.09 40.68 97.65 711.4 \n", "\n", " smoothness_worst compactness_worst concavity_worst \\\n", "id \n", "842517 0.1238 0.1866 0.2416 \n", "84300903 0.1444 0.4245 0.4504 \n", "84348301 0.2098 0.8663 0.6869 \n", "84358402 0.1374 0.2050 0.4000 \n", "843786 0.1791 0.5249 0.5355 \n", "844359 0.1442 0.2576 0.3784 \n", "84458202 0.1654 0.3682 0.2678 \n", "844981 0.1703 0.5401 0.5390 \n", "84501001 0.1853 1.0580 1.1050 \n", "\n", " concave points_worst symmetry_worst fractal_dimension_worst \n", "id \n", "842517 0.1860 0.2750 0.08902 \n", "84300903 0.2430 0.3613 0.08758 \n", "84348301 0.2575 0.6638 0.17300 \n", "84358402 0.1625 0.2364 0.07678 \n", "843786 0.1741 0.3985 0.12440 \n", "844359 0.1932 0.3063 0.08368 \n", "84458202 0.1556 0.3196 0.11510 \n", "844981 0.2060 0.4378 0.10720 \n", "84501001 0.2210 0.4366 0.20750 \n", "\n", "[9 rows x 31 columns]" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_csv('breast-cancer.csv', index_col=0)\n", "data[1:10]" ] }, { "cell_type": "code", "execution_count": 112, "id": "outdoor-element", "metadata": {}, "outputs": [], "source": [ "lb = LabelEncoder()\n", "data['diagnosis'] = lb.fit_transform(data['diagnosis'])\n", "features = data.iloc[:, 1:32].values\n", "labels = np.array(data['diagnosis'])" ] }, { "cell_type": "code", "execution_count": 113, "id": "buried-community", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[2.057e+01, 1.777e+01, 1.329e+02, 1.326e+03, 8.474e-02, 7.864e-02,\n", " 8.690e-02, 7.017e-02, 1.812e-01, 5.667e-02, 5.435e-01, 7.339e-01,\n", " 3.398e+00, 7.408e+01, 5.225e-03, 1.308e-02, 1.860e-02, 1.340e-02,\n", " 1.389e-02, 3.532e-03, 2.499e+01, 2.341e+01, 1.588e+02, 1.956e+03,\n", " 1.238e-01, 1.866e-01, 2.416e-01, 1.860e-01, 2.750e-01, 8.902e-02],\n", " [1.969e+01, 2.125e+01, 1.300e+02, 1.203e+03, 1.096e-01, 1.599e-01,\n", " 1.974e-01, 1.279e-01, 2.069e-01, 5.999e-02, 7.456e-01, 7.869e-01,\n", " 4.585e+00, 9.403e+01, 6.150e-03, 4.006e-02, 3.832e-02, 2.058e-02,\n", " 2.250e-02, 4.571e-03, 2.357e+01, 2.553e+01, 1.525e+02, 1.709e+03,\n", " 1.444e-01, 4.245e-01, 4.504e-01, 2.430e-01, 3.613e-01, 8.758e-02],\n", " [1.142e+01, 2.038e+01, 7.758e+01, 3.861e+02, 1.425e-01, 2.839e-01,\n", " 2.414e-01, 1.052e-01, 2.597e-01, 9.744e-02, 4.956e-01, 1.156e+00,\n", " 3.445e+00, 2.723e+01, 9.110e-03, 7.458e-02, 5.661e-02, 1.867e-02,\n", " 5.963e-02, 9.208e-03, 1.491e+01, 2.650e+01, 9.887e+01, 5.677e+02,\n", " 2.098e-01, 8.663e-01, 6.869e-01, 2.575e-01, 6.638e-01, 1.730e-01],\n", " [2.029e+01, 1.434e+01, 1.351e+02, 1.297e+03, 1.003e-01, 1.328e-01,\n", " 1.980e-01, 1.043e-01, 1.809e-01, 5.883e-02, 7.572e-01, 7.813e-01,\n", " 5.438e+00, 9.444e+01, 1.149e-02, 2.461e-02, 5.688e-02, 1.885e-02,\n", " 1.756e-02, 5.115e-03, 2.254e+01, 1.667e+01, 1.522e+02, 1.575e+03,\n", " 1.374e-01, 2.050e-01, 4.000e-01, 1.625e-01, 2.364e-01, 7.678e-02],\n", " [1.245e+01, 1.570e+01, 8.257e+01, 4.771e+02, 1.278e-01, 1.700e-01,\n", " 1.578e-01, 8.089e-02, 2.087e-01, 7.613e-02, 3.345e-01, 8.902e-01,\n", " 2.217e+00, 2.719e+01, 7.510e-03, 3.345e-02, 3.672e-02, 1.137e-02,\n", " 2.165e-02, 5.082e-03, 1.547e+01, 2.375e+01, 1.034e+02, 7.416e+02,\n", " 1.791e-01, 5.249e-01, 5.355e-01, 1.741e-01, 3.985e-01, 1.244e-01],\n", " [1.825e+01, 1.998e+01, 1.196e+02, 1.040e+03, 9.463e-02, 1.090e-01,\n", " 1.127e-01, 7.400e-02, 1.794e-01, 5.742e-02, 4.467e-01, 7.732e-01,\n", " 3.180e+00, 5.391e+01, 4.314e-03, 1.382e-02, 2.254e-02, 1.039e-02,\n", " 1.369e-02, 2.179e-03, 2.288e+01, 2.766e+01, 1.532e+02, 1.606e+03,\n", " 1.442e-01, 2.576e-01, 3.784e-01, 1.932e-01, 3.063e-01, 8.368e-02],\n", " [1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n", " 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n", " 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n", " 1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n", " 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01],\n", " [1.300e+01, 2.182e+01, 8.750e+01, 5.198e+02, 1.273e-01, 1.932e-01,\n", " 1.859e-01, 9.353e-02, 2.350e-01, 7.389e-02, 3.063e-01, 1.002e+00,\n", " 2.406e+00, 2.432e+01, 5.731e-03, 3.502e-02, 3.553e-02, 1.226e-02,\n", " 2.143e-02, 3.749e-03, 1.549e+01, 3.073e+01, 1.062e+02, 7.393e+02,\n", " 1.703e-01, 5.401e-01, 5.390e-01, 2.060e-01, 4.378e-01, 1.072e-01],\n", " [1.246e+01, 2.404e+01, 8.397e+01, 4.759e+02, 1.186e-01, 2.396e-01,\n", " 2.273e-01, 8.543e-02, 2.030e-01, 8.243e-02, 2.976e-01, 1.599e+00,\n", " 2.039e+00, 2.394e+01, 7.149e-03, 7.217e-02, 7.743e-02, 1.432e-02,\n", " 1.789e-02, 1.008e-02, 1.509e+01, 4.068e+01, 9.765e+01, 7.114e+02,\n", " 1.853e-01, 1.058e+00, 1.105e+00, 2.210e-01, 4.366e-01, 2.075e-01]])" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "features[1:10]" ] }, { "cell_type": "code", "execution_count": 114, "id": "incredible-quantum", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 1, 1, 1, 1, 1, 1, 1, 1])" ] }, "execution_count": 114, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[1:10]" ] }, { "cell_type": "code", "execution_count": 115, "id": "brazilian-butler", "metadata": {}, "outputs": [], "source": [ "features_train, features_test, labels_train, labels_test = train_test_split(features, labels, random_state=42, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 116, "id": "exotic-method", "metadata": {}, "outputs": [], "source": [ "# Training\n", "model = Model(features_train.shape[1])\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", "loss_fn = nn.CrossEntropyLoss()\n", "epochs = 100\n", "\n", "def print_(loss):\n", " print (\"The loss calculated: \", loss)" ] }, { "cell_type": "code", "execution_count": 117, "id": "sharp-month", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch # 1\n", "The loss calculated: 0.922476053237915\n", "Epoch # 2\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 3\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 4\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 5\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 6\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 7\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 8\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 9\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 10\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 11\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 12\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 13\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 14\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 15\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 16\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 17\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 18\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 19\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 20\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 21\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 22\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 23\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 24\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 25\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 26\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 27\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 28\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 29\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 30\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 31\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 32\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 33\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 34\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 35\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 36\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 37\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 38\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 39\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 40\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 41\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 42\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 43\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 44\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 45\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 46\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 47\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 48\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 49\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 50\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 51\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 52\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 53\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 54\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 55\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 56\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 57\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 58\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 59\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 60\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 61\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 62\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 63\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 64\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 65\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 66\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 67\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 68\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 69\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 70\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 71\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 72\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 73\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 74\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 75\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 76\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 77\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 78\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 79\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 80\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 81\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 82\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 83\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 84\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 85\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 86\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 87\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 88\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 89\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 90\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 91\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 92\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 93\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 94\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 95\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 96\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 97\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 98\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 99\n", "The loss calculated: 0.9223369359970093\n", "Epoch # 100\n", "The loss calculated: 0.9223369359970093\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", " # This is added back by InteractiveShellApp.init_path()\n" ] } ], "source": [ "# Not using dataloader\n", "x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long()\n", "for epoch in range(1, epochs+1):\n", " print (\"Epoch #\",epoch)\n", " y_pred = model(x_train)\n", " loss = loss_fn(y_pred, y_train)\n", " print_(loss.item())\n", " \n", " # Zero gradients\n", " optimizer.zero_grad()\n", " loss.backward() # Gradients\n", " optimizer.step() # Update" ] }, { "cell_type": "code", "execution_count": 118, "id": "mechanical-humidity", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", " # This is added back by InteractiveShellApp.init_path()\n" ] } ], "source": [ "# Prediction\n", "x_test = Variable(torch.from_numpy(features_test)).float()\n", "pred = model(x_test)" ] }, { "cell_type": "code", "execution_count": 119, "id": "based-charleston", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.],\n", " [1., 0., 0.]], dtype=float32)" ] }, "execution_count": 119, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = pred.detach().numpy()\n", "pred[1:10]" ] }, { "cell_type": "code", "execution_count": 120, "id": "dried-accessory", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The accuracy is 0.6223776223776224\n" ] } ], "source": [ "print (\"The accuracy is\", accuracy_score(labels_test, np.argmax(pred, axis=1)))" ] }, { "cell_type": "code", "execution_count": 121, "id": "effective-characterization", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 121, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels_test[0]" ] }, { "cell_type": "code", "execution_count": 122, "id": "oriented-determination", "metadata": {}, "outputs": [], "source": [ "torch.save(model, \"travel_insurance-pytorch.pkl\")" ] }, { "cell_type": "code", "execution_count": 123, "id": "infectious-wagon", "metadata": {}, "outputs": [], "source": [ "saved_model = torch.load(\"travel_insurance-pytorch.pkl\")" ] }, { "cell_type": "code", "execution_count": 124, "id": "built-contributor", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", " # This is added back by InteractiveShellApp.init_path()\n" ] }, { "data": { "text/plain": [ "0" ] }, "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmax(saved_model(x_test[0]).detach().numpy(), axis=0)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 5 }