{ "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5-final" }, "orig_nbformat": 2, "kernelspec": { "name": "python385jvsc74a57bd02cef13873963874fd5439bd04a135498d1dd9725d9d90f40de0b76178a8e03b1", "display_name": "Python 3.8.5 64-bit (conda)" } }, "nbformat": 4, "nbformat_minor": 2, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.autograd import Variable\n", "import torchvision.transforms as transforms\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import MinMaxScaler\n", "from sklearn.metrics import accuracy_score\n", "import numpy as np\n", "import pandas as pd\n", "\n", "\n", "\n", "class LogisticRegressionModel(nn.Module):\n", " def __init__(self, input_dim, output_dim):\n", " super(LogisticRegressionModel, self).__init__()\n", " self.linear = nn.Linear(input_dim, output_dim)\n", " self.sigmoid = nn.Sigmoid()\n", " def forward(self, x):\n", " out = self.linear(x)\n", " return self.sigmoid(out)\n", "\n", "\n", "data_train = pd.read_csv(\"data_train.csv\")\n", "data_test = pd.read_csv(\"data_test.csv\")\n", "data_val = pd.read_csv(\"data_val.csv\")\n", "FEATURES = [ 'age','hypertension','heart_disease','ever_married', 'avg_glucose_level', 'bmi']\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "x_train = data_train[FEATURES].astype(np.float32)\n", "y_train = data_train['stroke'].astype(np.float32)\n", "\n", "x_test = data_test[FEATURES].astype(np.float32)\n", "y_test = data_test['stroke'].astype(np.float32)\n", "\n", "\n", "\n", "fTrain = torch.from_numpy(x_train.values)\n", "tTrain = torch.from_numpy(y_train.values.reshape(2945,1))\n", "\n", "fTest= torch.from_numpy(x_test.values)\n", "tTest = torch.from_numpy(y_test.values)\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "\n", "batch_size = 95\n", "n_iters = 1000\n", "num_epochs = int(n_iters / (len(x_train) / batch_size))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "input_dim = 6\n", "output_dim = 1\n", "\n", "model = LogisticRegressionModel(input_dim, output_dim)\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.001\n", "\n", "criterion = torch.nn.BCELoss(reduction='mean') \n", "optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([1, 6])\ntorch.Size([1])\n" ] } ], "source": [ "print(list(model.parameters())[0].size())\n", "print(list(model.parameters())[1].size())" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch # 0\n", "4.4554009437561035\n", "Epoch # 1\n", "2.887434244155884\n", "Epoch # 2\n", "1.4808591604232788\n", "Epoch # 3\n", "0.6207292079925537\n", "Epoch # 4\n", "0.4031478762626648\n", "Epoch # 5\n", "0.34721270203590393\n", "Epoch # 6\n", "0.32333147525787354\n", "Epoch # 7\n", "0.3105970621109009\n", "Epoch # 8\n", "0.30295372009277344\n", "Epoch # 9\n", "0.2980167269706726\n", "Epoch # 10\n", "0.29466450214385986\n", "Epoch # 11\n", "0.29230451583862305\n", "Epoch # 12\n", "0.29059702157974243\n", "Epoch # 13\n", "0.2893349230289459\n", "Epoch # 14\n", "0.2883857190608978\n", "Epoch # 15\n", "0.2876618504524231\n", "Epoch # 16\n", "0.2871031165122986\n", "Epoch # 17\n", "0.28666743636131287\n", "Epoch # 18\n", "0.28632479906082153\n", "Epoch # 19\n", "0.2860531508922577\n", "Epoch # 20\n", "0.28583624958992004\n", "Epoch # 21\n", "0.2856619954109192\n", "Epoch # 22\n", "0.285521000623703\n", "Epoch # 23\n", "0.2854064106941223\n", "Epoch # 24\n", "0.2853126525878906\n", "Epoch # 25\n", "0.2852354049682617\n", "Epoch # 26\n", "0.2851715385913849\n", "Epoch # 27\n", "0.28511837124824524\n", "Epoch # 28\n", "0.2850736975669861\n", "Epoch # 29\n", "0.2850360572338104\n", "Epoch # 30\n", "0.28500401973724365\n", "Epoch # 31\n", "0.2849765419960022\n", "X:\\Anaconda2020\\lib\\site-packages\\torch\\autograd\\__init__.py:145: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 10010). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ..\\c10\\cuda\\CUDAFunctions.cpp:109.)\n", " Variable._execution_engine.run_backward(\n" ] } ], "source": [ "for epoch in range(num_epochs):\n", " print (\"Epoch #\",epoch)\n", " model.train()\n", " optimizer.zero_grad()\n", " # Forward pass\n", " y_pred = model(fTrain)\n", " # Compute Loss\n", " loss = criterion(y_pred, tTrain)\n", " print(loss.item())\n", " # Backward pass\n", " loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [] }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "predicted Y value: tensor([[0.0468],\n [0.0325],\n [0.2577],\n [0.2059],\n [0.1090],\n [0.0229],\n [0.2290],\n [0.0689],\n [0.2476],\n [0.0453],\n [0.0150],\n [0.4080],\n [0.0424],\n [0.0981],\n [0.0221],\n [0.1546],\n [0.1400],\n [0.1768],\n [0.1684],\n [0.0229],\n [0.1836],\n [0.1200],\n [0.0137],\n [0.2316],\n [0.0185],\n [0.0179],\n [0.0108],\n [0.0175],\n [0.0471],\n [0.4576],\n [0.0210],\n [0.0103],\n [0.0616],\n [0.1850],\n [0.4114],\n [0.4264],\n [0.0405],\n [0.0788],\n [0.2405],\n [0.0340],\n [0.4345],\n [0.1758],\n [0.0385],\n [0.0749],\n [0.4349],\n [0.0357],\n [0.0295],\n [0.3939],\n [0.1147],\n [0.3812],\n [0.0659],\n [0.0675],\n [0.0263],\n [0.1398],\n [0.0959],\n [0.0406],\n [0.0531],\n [0.0500],\n [0.4259],\n [0.1086],\n [0.0611],\n [0.0855],\n [0.0473],\n [0.2826],\n [0.1734],\n [0.0560],\n [0.0466],\n [0.0290],\n [0.1903],\n [0.4515],\n [0.0118],\n [0.2158],\n [0.1293],\n [0.2488],\n [0.0424],\n [0.1809],\n [0.0122],\n [0.0796],\n [0.0901],\n [0.0879],\n [0.0457],\n [0.0091],\n [0.0196],\n [0.3310],\n [0.0978],\n [0.0843],\n [0.0684],\n [0.0340],\n [0.0583],\n [0.0670],\n [0.0133],\n [0.1165],\n [0.0145],\n [0.1581],\n [0.1677],\n [0.0353],\n [0.0745],\n [0.0108],\n [0.0492],\n [0.0611],\n [0.2977],\n [0.2820],\n [0.0219],\n [0.0580],\n [0.0122],\n [0.0726],\n [0.3315],\n [0.0201],\n [0.2460],\n [0.0110],\n [0.0322],\n [0.0180],\n [0.0135],\n [0.3176],\n [0.1390],\n [0.0678],\n [0.1596],\n [0.0128],\n [0.0900],\n [0.0117],\n [0.0224],\n [0.0357],\n [0.0103],\n [0.1728],\n [0.0135],\n [0.0992],\n [0.4371],\n [0.4525],\n [0.0278],\n [0.0617],\n [0.2499],\n [0.0129],\n [0.0424],\n [0.0292],\n [0.3903],\n [0.0108],\n [0.0404],\n [0.0344],\n [0.4109],\n [0.3936],\n [0.0603],\n [0.4396],\n [0.1155],\n [0.3594],\n [0.0305],\n [0.0307],\n [0.0226],\n [0.1284],\n [0.0474],\n [0.0959],\n [0.0135],\n [0.0289],\n [0.3705],\n [0.1538],\n [0.4535],\n [0.0355],\n [0.0169],\n [0.1648],\n [0.4217],\n [0.0951],\n [0.0767],\n [0.0475],\n [0.0452],\n [0.1625],\n [0.0896],\n [0.0114],\n [0.0423],\n [0.3971],\n [0.0173],\n [0.0250],\n [0.3579],\n [0.0131],\n [0.0201],\n [0.0149],\n [0.2615],\n [0.1773],\n [0.1204],\n [0.3556],\n [0.2390],\n [0.0098],\n [0.0190],\n [0.3040],\n [0.0115],\n [0.2033],\n [0.1327],\n [0.0180],\n [0.0610],\n [0.2927],\n [0.1182],\n [0.0115],\n [0.4474],\n [0.3513],\n [0.0451],\n [0.4089],\n [0.0375],\n [0.0127],\n [0.0630],\n [0.0428],\n [0.2085],\n [0.0529],\n [0.3436],\n [0.0678],\n [0.0717],\n [0.0799],\n [0.0967],\n [0.1246],\n [0.1086],\n [0.0387],\n [0.1742],\n [0.1582],\n [0.1374],\n [0.4205],\n [0.0534],\n [0.3051],\n [0.1204],\n [0.0423],\n [0.0324],\n [0.0141],\n [0.0312],\n [0.0261],\n [0.2619],\n [0.0767],\n [0.1742],\n [0.0311],\n [0.1763],\n [0.0326],\n [0.0529],\n [0.3928],\n [0.1209],\n [0.0724],\n [0.3551],\n [0.2067],\n [0.0288],\n [0.0782],\n [0.0661],\n [0.0469],\n [0.1089],\n [0.0985],\n [0.1032],\n [0.1083],\n [0.0546],\n [0.0983],\n [0.1302],\n [0.0153],\n [0.2179],\n [0.1196],\n [0.0275],\n [0.4366],\n [0.0340],\n [0.0286],\n [0.1193],\n [0.0729],\n [0.0553],\n [0.0159],\n [0.0140],\n [0.2195],\n [0.3792],\n [0.3966],\n [0.0424],\n [0.0872],\n [0.0687],\n [0.1941],\n [0.0179],\n [0.0380],\n [0.2445],\n [0.1905],\n [0.1518],\n [0.2370],\n [0.0706],\n [0.1668],\n [0.1265],\n [0.2363],\n [0.0354],\n [0.0263],\n [0.0653],\n [0.0097],\n [0.0152],\n [0.0495],\n [0.2952],\n [0.3581],\n [0.0388],\n [0.0365],\n [0.2808],\n [0.0189],\n [0.0133],\n [0.0692],\n [0.0256],\n [0.0500],\n [0.1452],\n [0.3315],\n [0.4509],\n [0.2079],\n [0.0140],\n [0.2505],\n [0.1044],\n [0.0121],\n [0.0087],\n [0.0453],\n [0.0173],\n [0.0127],\n [0.0176],\n [0.3826],\n [0.0843],\n [0.0885],\n [0.0133],\n [0.0190],\n [0.0278],\n [0.0612],\n [0.2574],\n [0.0404],\n [0.0735],\n [0.0527],\n [0.0300],\n [0.1384],\n [0.2189],\n [0.0301],\n [0.2303],\n [0.0425],\n [0.0294],\n [0.0103],\n [0.2857],\n [0.0620],\n [0.0938],\n [0.2283],\n [0.0237],\n [0.1175],\n [0.1809],\n [0.0725],\n [0.0107],\n [0.0629],\n [0.4485],\n [0.0233],\n [0.0165],\n [0.2240],\n [0.1200],\n [0.0665],\n [0.1153],\n [0.2957],\n [0.0112],\n [0.0328],\n [0.2823],\n [0.4248],\n [0.0308],\n [0.0216],\n [0.1206],\n [0.2086],\n [0.0115],\n [0.0399],\n [0.0246],\n [0.3960],\n [0.1782],\n [0.0591],\n [0.0092],\n [0.0600],\n [0.2938],\n [0.1144],\n [0.0136],\n [0.2075],\n [0.0426],\n [0.0998],\n [0.0407],\n [0.2944],\n [0.2721],\n [0.0734],\n [0.2927],\n [0.0482],\n [0.2740],\n [0.0363],\n [0.4624],\n [0.0558],\n [0.1669],\n [0.0243],\n [0.0109],\n [0.1209],\n [0.0617],\n [0.0634],\n [0.0183],\n [0.0319],\n [0.1135],\n [0.0121],\n [0.0314],\n [0.0137],\n [0.0195],\n [0.0094],\n [0.3304],\n [0.0694],\n [0.0144],\n [0.0278],\n [0.1393],\n [0.3971],\n [0.0939],\n [0.0489],\n [0.0763],\n [0.0394],\n [0.2953],\n [0.0581],\n [0.0404],\n [0.0489],\n [0.0429],\n [0.1940],\n [0.0098],\n [0.0535],\n [0.2953],\n [0.2188],\n [0.0115],\n [0.1468],\n [0.0210],\n [0.2410],\n [0.0685],\n [0.1935],\n [0.1258],\n [0.0146],\n [0.0279],\n [0.0240],\n [0.3981],\n [0.2131],\n [0.0267],\n [0.0184],\n [0.2806],\n [0.2224],\n [0.2687],\n [0.0207],\n [0.2931],\n [0.0707],\n [0.0408],\n [0.0836],\n [0.0799],\n [0.1043],\n [0.0235],\n [0.1093],\n [0.0915],\n [0.0186],\n [0.0885],\n [0.0143],\n [0.0430],\n [0.4653],\n [0.1440],\n [0.0343],\n [0.1683],\n [0.0222],\n [0.1450],\n [0.0497],\n [0.2624],\n [0.0158],\n [0.0157],\n [0.3439],\n [0.1724],\n [0.1858],\n [0.4211],\n [0.0741],\n [0.0708],\n [0.0437],\n [0.0117],\n [0.1091],\n [0.0450],\n [0.1210],\n [0.0864],\n [0.1131],\n [0.0796],\n [0.3588],\n [0.1135],\n [0.0211],\n [0.0152],\n [0.0109],\n [0.1337],\n [0.0341],\n [0.1293],\n [0.0809],\n [0.1133],\n [0.0163],\n [0.0598],\n [0.1512],\n [0.0480],\n [0.1759],\n [0.1126],\n [0.1127],\n [0.0263],\n [0.0215],\n [0.4364],\n [0.0164],\n [0.0447],\n [0.0979],\n [0.0678],\n [0.0624],\n [0.0317],\n [0.0102],\n [0.0188],\n [0.0698],\n [0.1521],\n [0.0642],\n [0.3642],\n [0.4675],\n [0.0807],\n [0.0191],\n [0.0761],\n [0.0902],\n [0.2640],\n [0.0658],\n [0.0115],\n [0.0166],\n [0.0222],\n [0.0138],\n [0.0588],\n [0.0254],\n [0.3206],\n [0.0467],\n [0.0861],\n [0.0161],\n [0.0726],\n [0.2488],\n [0.0441],\n [0.0508],\n [0.2247],\n [0.2302],\n [0.1532],\n [0.4200],\n [0.0451],\n [0.3172],\n [0.3833],\n [0.1273],\n [0.0189],\n [0.0763],\n [0.2182],\n [0.0644],\n [0.2589],\n [0.1022],\n [0.0138],\n [0.0616],\n [0.2501],\n [0.0143],\n [0.0120],\n [0.3911],\n [0.2098],\n [0.1860],\n [0.0638],\n [0.0235],\n [0.0264],\n [0.1376],\n [0.3966],\n [0.2488],\n [0.1522],\n [0.0390],\n [0.3700],\n [0.1037],\n [0.0729],\n [0.1019],\n [0.0281],\n [0.0292],\n [0.0099],\n [0.2580],\n [0.0105],\n [0.1386],\n [0.0141],\n [0.0188],\n [0.0625],\n [0.2696],\n [0.0582],\n [0.0218],\n [0.1327],\n [0.0290],\n [0.3961],\n [0.0202],\n [0.0209],\n [0.3187],\n [0.1900],\n [0.0237],\n [0.3660],\n [0.1311],\n [0.0300],\n [0.0211],\n [0.4100],\n [0.0311],\n [0.1034],\n [0.4346],\n [0.1150],\n [0.0588],\n [0.1075],\n [0.0989],\n [0.2195],\n [0.4500],\n [0.0934],\n [0.0930],\n [0.1336],\n [0.1932],\n [0.1717],\n [0.0731],\n [0.1601],\n [0.2492],\n [0.0096],\n [0.0759],\n [0.1010],\n [0.0592],\n [0.4519],\n [0.2835],\n [0.0693],\n [0.1462],\n [0.0280],\n [0.4231],\n [0.0400],\n [0.1261],\n [0.0129],\n [0.1344],\n [0.0724],\n [0.0362],\n [0.0444],\n [0.0724],\n [0.0266],\n [0.0624],\n [0.0094],\n [0.0557],\n [0.1328],\n [0.1478],\n [0.1098],\n [0.0486],\n [0.1091],\n [0.1119],\n [0.1213],\n [0.2821],\n [0.4471],\n [0.0485],\n [0.0278],\n [0.2685],\n [0.0907],\n [0.0253],\n [0.0618],\n [0.0361],\n [0.2087],\n [0.1609],\n [0.2896],\n [0.2296],\n [0.1015],\n [0.2034],\n [0.0103],\n [0.0483],\n [0.0843],\n [0.2846],\n [0.0126],\n [0.0839],\n [0.0097],\n [0.3067],\n [0.0319],\n [0.1666],\n [0.0134],\n [0.3563],\n [0.0354],\n [0.2181],\n [0.0399],\n [0.1233],\n [0.0332],\n [0.1251],\n [0.3705],\n [0.0720],\n [0.0260],\n [0.2191],\n [0.0998],\n [0.2293],\n [0.1474],\n [0.0092],\n [0.0092],\n [0.0433],\n [0.1093],\n [0.0165],\n [0.3294],\n [0.0136],\n [0.0735],\n [0.0381],\n [0.1373],\n [0.1616],\n [0.0496],\n [0.1992],\n [0.0342],\n [0.2832],\n [0.0306],\n [0.0188],\n [0.3583],\n [0.1543],\n [0.0188],\n [0.0111],\n [0.0964],\n [0.0963],\n [0.2209],\n [0.1034],\n [0.1088],\n [0.0695],\n [0.0308],\n [0.0280],\n [0.0712],\n [0.0474],\n [0.1890],\n [0.3057],\n [0.0896],\n [0.2190],\n [0.1548],\n [0.4623],\n [0.4395],\n [0.3971],\n [0.0328],\n [0.0132],\n [0.4267],\n [0.0234],\n [0.0202],\n [0.0141],\n [0.1102],\n [0.0159],\n [0.0842],\n [0.0629],\n [0.1334],\n [0.1256],\n [0.2835],\n [0.3958],\n [0.1798],\n [0.0108],\n [0.0144],\n [0.0220],\n [0.0982],\n [0.1031],\n [0.0454],\n [0.0292],\n [0.1306],\n [0.4508],\n [0.0465],\n [0.3683],\n [0.0549],\n [0.0102],\n [0.0222],\n [0.0503],\n [0.0245],\n [0.1466],\n [0.1999],\n [0.0237],\n [0.1147],\n [0.4205],\n [0.0654],\n [0.1033],\n [0.1182],\n [0.0737],\n [0.0561],\n [0.2595],\n [0.0116],\n [0.1922],\n [0.4246],\n [0.3039],\n [0.1907],\n [0.4135],\n [0.3967],\n [0.2716],\n [0.2395],\n [0.2179],\n [0.2798],\n [0.0185],\n [0.0197],\n [0.0482],\n [0.0086],\n [0.0088],\n [0.1054],\n [0.1191],\n [0.0319],\n [0.1223],\n [0.0358],\n [0.2026],\n [0.0206],\n [0.0784],\n [0.1204],\n [0.3416],\n [0.3174],\n [0.0210],\n [0.0305],\n [0.0098],\n [0.3320],\n [0.0258],\n [0.3058],\n [0.3179],\n [0.1205],\n [0.0276],\n [0.0308],\n [0.0135],\n [0.3972],\n [0.2718],\n [0.4641],\n [0.0432],\n [0.0744],\n [0.3558],\n [0.2100],\n [0.1327],\n [0.1907],\n [0.0276],\n [0.1947],\n [0.0844],\n [0.2688],\n [0.0658],\n [0.0929],\n [0.1471],\n [0.1222],\n [0.0212],\n [0.0804],\n [0.4366],\n [0.2472],\n [0.0325],\n [0.0178],\n [0.3040],\n [0.0868],\n [0.0683],\n [0.0470],\n [0.2027],\n [0.0262],\n [0.1257],\n [0.1146],\n [0.0586],\n [0.1017],\n [0.4349],\n [0.0286],\n [0.0723],\n [0.3070],\n [0.0135],\n [0.0380],\n [0.0447],\n [0.0161],\n [0.0729],\n [0.0360],\n [0.1328],\n [0.0126],\n [0.0531],\n [0.1831],\n [0.4434],\n [0.0198],\n [0.0878],\n [0.0382],\n [0.0387],\n [0.0438],\n [0.0720],\n [0.0311],\n [0.1378],\n [0.1178],\n [0.1888],\n [0.1199],\n [0.2023],\n [0.1153],\n [0.0523],\n [0.0420],\n [0.0200],\n [0.1671],\n [0.1086],\n [0.0503],\n [0.0441],\n [0.0242],\n [0.1676],\n [0.0120],\n [0.4497],\n [0.0396],\n [0.0922],\n [0.2931],\n [0.0194],\n [0.0267],\n [0.1875],\n [0.1045],\n [0.0839],\n [0.1251],\n [0.0294],\n [0.1090],\n [0.0136],\n [0.0851],\n [0.0360],\n [0.0158],\n [0.0944],\n [0.0110],\n [0.0114],\n [0.0586],\n [0.4468],\n [0.0760],\n [0.0501],\n [0.2267],\n [0.0528],\n [0.0367],\n [0.0803],\n [0.1456],\n [0.2818],\n [0.0266],\n [0.1995],\n [0.3691],\n [0.2341],\n [0.2593],\n [0.0636],\n [0.1788],\n [0.0479],\n [0.0509],\n [0.1104],\n [0.0918],\n [0.1508],\n [0.3680],\n [0.3948],\n [0.1899],\n [0.2569],\n [0.0363],\n [0.0262],\n [0.0936],\n [0.0550],\n [0.1027],\n [0.1444],\n [0.0330],\n [0.0097],\n [0.0761],\n [0.2207],\n [0.0326],\n [0.2501],\n [0.0394],\n [0.0760],\n [0.0381],\n [0.0115],\n [0.2717],\n [0.0423],\n [0.0731],\n [0.1560],\n [0.0826],\n [0.0092],\n [0.0219],\n [0.0751],\n [0.1322],\n [0.2677],\n [0.1361],\n [0.4089],\n [0.0925],\n [0.0266],\n [0.1068],\n [0.3935],\n [0.0987],\n [0.0115],\n [0.3348],\n [0.0551],\n [0.0817],\n [0.0489],\n [0.1392],\n [0.0596],\n [0.0844],\n [0.2388],\n [0.0960],\n [0.0721],\n [0.1400],\n [0.4667],\n [0.2374],\n [0.0349],\n [0.0857],\n [0.1599],\n [0.1922],\n [0.0281],\n [0.0183],\n [0.4507],\n [0.0167],\n [0.0283],\n [0.0402],\n [0.2076],\n [0.1693],\n [0.1446],\n [0.3547],\n [0.0943],\n [0.3730],\n [0.1823],\n [0.0426],\n [0.0149],\n [0.0327],\n [0.3715],\n [0.0474],\n [0.1343],\n [0.1915],\n [0.0690],\n [0.0092],\n [0.1643],\n [0.2189],\n [0.3149],\n [0.2171],\n [0.2178],\n [0.0097],\n [0.3628],\n [0.0163],\n [0.0684],\n [0.1145],\n [0.4074],\n [0.0514],\n [0.3587],\n [0.0905],\n [0.0159],\n [0.1992],\n [0.0109]])\n" ] } ], "source": [ "\n", "y_pred = model(fTest)\n", "print(\"predicted Y value: \", y_pred.data)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "The accuracy is 0.9480651731160896\n" ] } ], "source": [ "print (\"The accuracy is\", accuracy_score(tTest, np.argmax(y_pred.detach().numpy(), axis=1)))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "torch.save(model, 'stroke.pkl')" ] } ] }