ium_434766/lab5.ipynb
2021-04-17 13:35:20 +02:00

25 KiB


import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd



class LogisticRegressionModel(nn.Module):
    """Logistic regression expressed as one linear layer plus a sigmoid.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output units produced per sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply the linear layer to ``x`` and squash the result with sigmoid."""
        logits = self.linear(x)
        probabilities = self.sigmoid(logits)
        return probabilities


# Read the three pre-split partitions of the dataset. The validation frame is
# loaded here but is not referenced again in the visible part of the notebook.
data_train, data_test, data_val = (
    pd.read_csv(csv_path)
    for csv_path in ("data_train.csv", "data_test.csv", "data_val.csv")
)

# Columns used as model inputs; the label is taken from the 'stroke' column.
FEATURES = [
    'age',
    'hypertension',
    'heart_disease',
    'ever_married',
    'avg_glucose_level',
    'bmi',
]

# Cast inputs and labels to float32 before they are turned into tensors.
x_train = data_train.loc[:, FEATURES].astype(np.float32)
y_train = data_train.loc[:, 'stroke'].astype(np.float32)

x_test = data_test.loc[:, FEATURES].astype(np.float32)
y_test = data_test.loc[:, 'stroke'].astype(np.float32)



# Wrap the float32 NumPy arrays behind the pandas objects as torch tensors.
fTrain = torch.from_numpy(x_train.values)
# The training target is reshaped into a single column so it lines up with the
# model's one-unit output when the loss is computed. Using -1 lets NumPy infer
# the row count instead of hard-coding the size of one particular CSV.
tTrain = torch.from_numpy(y_train.values.reshape(-1, 1))

fTest = torch.from_numpy(x_test.values)
# The test target stays one-dimensional; it is only used for scoring.
tTest = torch.from_numpy(y_test.values)

# ---- Hyperparameters ----
# NOTE(review): the training loop visible in this notebook feeds the whole
# training set on every step, so batch_size only influences num_epochs below.
batch_size = 95
n_iters = 1000
# Epoch count derived from the iteration budget: n_iters / (batches per epoch).
num_epochs = int(n_iters / (len(x_train) / batch_size))
# One model input per feature column; derived from FEATURES so the layer width
# cannot drift out of sync with the column list.
input_dim = len(FEATURES)
output_dim = 1
learning_rate = 0.001

# ---- Model, loss and optimizer ----
model = LogisticRegressionModel(input_dim, output_dim)

# Binary cross-entropy on the model's sigmoid outputs, averaged over samples.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Print the sizes of the first two parameter tensors as a sanity check.
print(list(model.parameters())[0].size())
print(list(model.parameters())[1].size())
torch.Size([1, 6])
torch.Size([1])
# Training: each epoch performs exactly one optimizer step computed from the
# entire training tensor.
for epoch in range(num_epochs):
    print("Epoch #", epoch)
    model.train()

    # Forward pass over all training rows, then the loss against the targets.
    y_pred = model(fTrain)
    loss = criterion(y_pred, tTrain)
    print(loss.item())

    # Drop gradients left from the previous step, backpropagate, and update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Epoch # 0
4.4554009437561035
Epoch # 1
2.887434244155884
Epoch # 2
1.4808591604232788
Epoch # 3
0.6207292079925537
Epoch # 4
0.4031478762626648
Epoch # 5
0.34721270203590393
Epoch # 6
0.32333147525787354
Epoch # 7
0.3105970621109009
Epoch # 8
0.30295372009277344
Epoch # 9
0.2980167269706726
Epoch # 10
0.29466450214385986
Epoch # 11
0.29230451583862305
Epoch # 12
0.29059702157974243
Epoch # 13
0.2893349230289459
Epoch # 14
0.2883857190608978
Epoch # 15
0.2876618504524231
Epoch # 16
0.2871031165122986
Epoch # 17
0.28666743636131287
Epoch # 18
0.28632479906082153
Epoch # 19
0.2860531508922577
Epoch # 20
0.28583624958992004
Epoch # 21
0.2856619954109192
Epoch # 22
0.285521000623703
Epoch # 23
0.2854064106941223
Epoch # 24
0.2853126525878906
Epoch # 25
0.2852354049682617
Epoch # 26
0.2851715385913849
Epoch # 27
0.28511837124824524
Epoch # 28
0.2850736975669861
Epoch # 29
0.2850360572338104
Epoch # 30
0.28500401973724365
Epoch # 31
0.2849765419960022
X:\Anaconda2020\lib\site-packages\torch\autograd\__init__.py:145: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 10010). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at  ..\c10\cuda\CUDAFunctions.cpp:109.)
  Variable._execution_engine.run_backward(

# Score the test features. No gradients are needed for inference, so the
# forward pass runs under no_grad; the result carries no autograd history and
# can be printed directly without the legacy `.data` accessor.
with torch.no_grad():
    y_pred = model(fTest)
print("predicted Y value: ", y_pred)
predicted Y value:  tensor([[0.0468],
        [0.0325],
        [0.2577],
        [0.2059],
        [0.1090],
        [0.0229],
        [0.2290],
        [0.0689],
        [0.2476],
        [0.0453],
        [0.0150],
        [0.4080],
        [0.0424],
        [0.0981],
        [0.0221],
        [0.1546],
        [0.1400],
        [0.1768],
        [0.1684],
        [0.0229],
        [0.1836],
        [0.1200],
        [0.0137],
        [0.2316],
        [0.0185],
        [0.0179],
        [0.0108],
        [0.0175],
        [0.0471],
        [0.4576],
        [0.0210],
        [0.0103],
        [0.0616],
        [0.1850],
        [0.4114],
        [0.4264],
        [0.0405],
        [0.0788],
        [0.2405],
        [0.0340],
        [0.4345],
        [0.1758],
        [0.0385],
        [0.0749],
        [0.4349],
        [0.0357],
        [0.0295],
        [0.3939],
        [0.1147],
        [0.3812],
        [0.0659],
        [0.0675],
        [0.0263],
        [0.1398],
        [0.0959],
        [0.0406],
        [0.0531],
        [0.0500],
        [0.4259],
        [0.1086],
        [0.0611],
        [0.0855],
        [0.0473],
        [0.2826],
        [0.1734],
        [0.0560],
        [0.0466],
        [0.0290],
        [0.1903],
        [0.4515],
        [0.0118],
        [0.2158],
        [0.1293],
        [0.2488],
        [0.0424],
        [0.1809],
        [0.0122],
        [0.0796],
        [0.0901],
        [0.0879],
        [0.0457],
        [0.0091],
        [0.0196],
        [0.3310],
        [0.0978],
        [0.0843],
        [0.0684],
        [0.0340],
        [0.0583],
        [0.0670],
        [0.0133],
        [0.1165],
        [0.0145],
        [0.1581],
        [0.1677],
        [0.0353],
        [0.0745],
        [0.0108],
        [0.0492],
        [0.0611],
        [0.2977],
        [0.2820],
        [0.0219],
        [0.0580],
        [0.0122],
        [0.0726],
        [0.3315],
        [0.0201],
        [0.2460],
        [0.0110],
        [0.0322],
        [0.0180],
        [0.0135],
        [0.3176],
        [0.1390],
        [0.0678],
        [0.1596],
        [0.0128],
        [0.0900],
        [0.0117],
        [0.0224],
        [0.0357],
        [0.0103],
        [0.1728],
        [0.0135],
        [0.0992],
        [0.4371],
        [0.4525],
        [0.0278],
        [0.0617],
        [0.2499],
        [0.0129],
        [0.0424],
        [0.0292],
        [0.3903],
        [0.0108],
        [0.0404],
        [0.0344],
        [0.4109],
        [0.3936],
        [0.0603],
        [0.4396],
        [0.1155],
        [0.3594],
        [0.0305],
        [0.0307],
        [0.0226],
        [0.1284],
        [0.0474],
        [0.0959],
        [0.0135],
        [0.0289],
        [0.3705],
        [0.1538],
        [0.4535],
        [0.0355],
        [0.0169],
        [0.1648],
        [0.4217],
        [0.0951],
        [0.0767],
        [0.0475],
        [0.0452],
        [0.1625],
        [0.0896],
        [0.0114],
        [0.0423],
        [0.3971],
        [0.0173],
        [0.0250],
        [0.3579],
        [0.0131],
        [0.0201],
        [0.0149],
        [0.2615],
        [0.1773],
        [0.1204],
        [0.3556],
        [0.2390],
        [0.0098],
        [0.0190],
        [0.3040],
        [0.0115],
        [0.2033],
        [0.1327],
        [0.0180],
        [0.0610],
        [0.2927],
        [0.1182],
        [0.0115],
        [0.4474],
        [0.3513],
        [0.0451],
        [0.4089],
        [0.0375],
        [0.0127],
        [0.0630],
        [0.0428],
        [0.2085],
        [0.0529],
        [0.3436],
        [0.0678],
        [0.0717],
        [0.0799],
        [0.0967],
        [0.1246],
        [0.1086],
        [0.0387],
        [0.1742],
        [0.1582],
        [0.1374],
        [0.4205],
        [0.0534],
        [0.3051],
        [0.1204],
        [0.0423],
        [0.0324],
        [0.0141],
        [0.0312],
        [0.0261],
        [0.2619],
        [0.0767],
        [0.1742],
        [0.0311],
        [0.1763],
        [0.0326],
        [0.0529],
        [0.3928],
        [0.1209],
        [0.0724],
        [0.3551],
        [0.2067],
        [0.0288],
        [0.0782],
        [0.0661],
        [0.0469],
        [0.1089],
        [0.0985],
        [0.1032],
        [0.1083],
        [0.0546],
        [0.0983],
        [0.1302],
        [0.0153],
        [0.2179],
        [0.1196],
        [0.0275],
        [0.4366],
        [0.0340],
        [0.0286],
        [0.1193],
        [0.0729],
        [0.0553],
        [0.0159],
        [0.0140],
        [0.2195],
        [0.3792],
        [0.3966],
        [0.0424],
        [0.0872],
        [0.0687],
        [0.1941],
        [0.0179],
        [0.0380],
        [0.2445],
        [0.1905],
        [0.1518],
        [0.2370],
        [0.0706],
        [0.1668],
        [0.1265],
        [0.2363],
        [0.0354],
        [0.0263],
        [0.0653],
        [0.0097],
        [0.0152],
        [0.0495],
        [0.2952],
        [0.3581],
        [0.0388],
        [0.0365],
        [0.2808],
        [0.0189],
        [0.0133],
        [0.0692],
        [0.0256],
        [0.0500],
        [0.1452],
        [0.3315],
        [0.4509],
        [0.2079],
        [0.0140],
        [0.2505],
        [0.1044],
        [0.0121],
        [0.0087],
        [0.0453],
        [0.0173],
        [0.0127],
        [0.0176],
        [0.3826],
        [0.0843],
        [0.0885],
        [0.0133],
        [0.0190],
        [0.0278],
        [0.0612],
        [0.2574],
        [0.0404],
        [0.0735],
        [0.0527],
        [0.0300],
        [0.1384],
        [0.2189],
        [0.0301],
        [0.2303],
        [0.0425],
        [0.0294],
        [0.0103],
        [0.2857],
        [0.0620],
        [0.0938],
        [0.2283],
        [0.0237],
        [0.1175],
        [0.1809],
        [0.0725],
        [0.0107],
        [0.0629],
        [0.4485],
        [0.0233],
        [0.0165],
        [0.2240],
        [0.1200],
        [0.0665],
        [0.1153],
        [0.2957],
        [0.0112],
        [0.0328],
        [0.2823],
        [0.4248],
        [0.0308],
        [0.0216],
        [0.1206],
        [0.2086],
        [0.0115],
        [0.0399],
        [0.0246],
        [0.3960],
        [0.1782],
        [0.0591],
        [0.0092],
        [0.0600],
        [0.2938],
        [0.1144],
        [0.0136],
        [0.2075],
        [0.0426],
        [0.0998],
        [0.0407],
        [0.2944],
        [0.2721],
        [0.0734],
        [0.2927],
        [0.0482],
        [0.2740],
        [0.0363],
        [0.4624],
        [0.0558],
        [0.1669],
        [0.0243],
        [0.0109],
        [0.1209],
        [0.0617],
        [0.0634],
        [0.0183],
        [0.0319],
        [0.1135],
        [0.0121],
        [0.0314],
        [0.0137],
        [0.0195],
        [0.0094],
        [0.3304],
        [0.0694],
        [0.0144],
        [0.0278],
        [0.1393],
        [0.3971],
        [0.0939],
        [0.0489],
        [0.0763],
        [0.0394],
        [0.2953],
        [0.0581],
        [0.0404],
        [0.0489],
        [0.0429],
        [0.1940],
        [0.0098],
        [0.0535],
        [0.2953],
        [0.2188],
        [0.0115],
        [0.1468],
        [0.0210],
        [0.2410],
        [0.0685],
        [0.1935],
        [0.1258],
        [0.0146],
        [0.0279],
        [0.0240],
        [0.3981],
        [0.2131],
        [0.0267],
        [0.0184],
        [0.2806],
        [0.2224],
        [0.2687],
        [0.0207],
        [0.2931],
        [0.0707],
        [0.0408],
        [0.0836],
        [0.0799],
        [0.1043],
        [0.0235],
        [0.1093],
        [0.0915],
        [0.0186],
        [0.0885],
        [0.0143],
        [0.0430],
        [0.4653],
        [0.1440],
        [0.0343],
        [0.1683],
        [0.0222],
        [0.1450],
        [0.0497],
        [0.2624],
        [0.0158],
        [0.0157],
        [0.3439],
        [0.1724],
        [0.1858],
        [0.4211],
        [0.0741],
        [0.0708],
        [0.0437],
        [0.0117],
        [0.1091],
        [0.0450],
        [0.1210],
        [0.0864],
        [0.1131],
        [0.0796],
        [0.3588],
        [0.1135],
        [0.0211],
        [0.0152],
        [0.0109],
        [0.1337],
        [0.0341],
        [0.1293],
        [0.0809],
        [0.1133],
        [0.0163],
        [0.0598],
        [0.1512],
        [0.0480],
        [0.1759],
        [0.1126],
        [0.1127],
        [0.0263],
        [0.0215],
        [0.4364],
        [0.0164],
        [0.0447],
        [0.0979],
        [0.0678],
        [0.0624],
        [0.0317],
        [0.0102],
        [0.0188],
        [0.0698],
        [0.1521],
        [0.0642],
        [0.3642],
        [0.4675],
        [0.0807],
        [0.0191],
        [0.0761],
        [0.0902],
        [0.2640],
        [0.0658],
        [0.0115],
        [0.0166],
        [0.0222],
        [0.0138],
        [0.0588],
        [0.0254],
        [0.3206],
        [0.0467],
        [0.0861],
        [0.0161],
        [0.0726],
        [0.2488],
        [0.0441],
        [0.0508],
        [0.2247],
        [0.2302],
        [0.1532],
        [0.4200],
        [0.0451],
        [0.3172],
        [0.3833],
        [0.1273],
        [0.0189],
        [0.0763],
        [0.2182],
        [0.0644],
        [0.2589],
        [0.1022],
        [0.0138],
        [0.0616],
        [0.2501],
        [0.0143],
        [0.0120],
        [0.3911],
        [0.2098],
        [0.1860],
        [0.0638],
        [0.0235],
        [0.0264],
        [0.1376],
        [0.3966],
        [0.2488],
        [0.1522],
        [0.0390],
        [0.3700],
        [0.1037],
        [0.0729],
        [0.1019],
        [0.0281],
        [0.0292],
        [0.0099],
        [0.2580],
        [0.0105],
        [0.1386],
        [0.0141],
        [0.0188],
        [0.0625],
        [0.2696],
        [0.0582],
        [0.0218],
        [0.1327],
        [0.0290],
        [0.3961],
        [0.0202],
        [0.0209],
        [0.3187],
        [0.1900],
        [0.0237],
        [0.3660],
        [0.1311],
        [0.0300],
        [0.0211],
        [0.4100],
        [0.0311],
        [0.1034],
        [0.4346],
        [0.1150],
        [0.0588],
        [0.1075],
        [0.0989],
        [0.2195],
        [0.4500],
        [0.0934],
        [0.0930],
        [0.1336],
        [0.1932],
        [0.1717],
        [0.0731],
        [0.1601],
        [0.2492],
        [0.0096],
        [0.0759],
        [0.1010],
        [0.0592],
        [0.4519],
        [0.2835],
        [0.0693],
        [0.1462],
        [0.0280],
        [0.4231],
        [0.0400],
        [0.1261],
        [0.0129],
        [0.1344],
        [0.0724],
        [0.0362],
        [0.0444],
        [0.0724],
        [0.0266],
        [0.0624],
        [0.0094],
        [0.0557],
        [0.1328],
        [0.1478],
        [0.1098],
        [0.0486],
        [0.1091],
        [0.1119],
        [0.1213],
        [0.2821],
        [0.4471],
        [0.0485],
        [0.0278],
        [0.2685],
        [0.0907],
        [0.0253],
        [0.0618],
        [0.0361],
        [0.2087],
        [0.1609],
        [0.2896],
        [0.2296],
        [0.1015],
        [0.2034],
        [0.0103],
        [0.0483],
        [0.0843],
        [0.2846],
        [0.0126],
        [0.0839],
        [0.0097],
        [0.3067],
        [0.0319],
        [0.1666],
        [0.0134],
        [0.3563],
        [0.0354],
        [0.2181],
        [0.0399],
        [0.1233],
        [0.0332],
        [0.1251],
        [0.3705],
        [0.0720],
        [0.0260],
        [0.2191],
        [0.0998],
        [0.2293],
        [0.1474],
        [0.0092],
        [0.0092],
        [0.0433],
        [0.1093],
        [0.0165],
        [0.3294],
        [0.0136],
        [0.0735],
        [0.0381],
        [0.1373],
        [0.1616],
        [0.0496],
        [0.1992],
        [0.0342],
        [0.2832],
        [0.0306],
        [0.0188],
        [0.3583],
        [0.1543],
        [0.0188],
        [0.0111],
        [0.0964],
        [0.0963],
        [0.2209],
        [0.1034],
        [0.1088],
        [0.0695],
        [0.0308],
        [0.0280],
        [0.0712],
        [0.0474],
        [0.1890],
        [0.3057],
        [0.0896],
        [0.2190],
        [0.1548],
        [0.4623],
        [0.4395],
        [0.3971],
        [0.0328],
        [0.0132],
        [0.4267],
        [0.0234],
        [0.0202],
        [0.0141],
        [0.1102],
        [0.0159],
        [0.0842],
        [0.0629],
        [0.1334],
        [0.1256],
        [0.2835],
        [0.3958],
        [0.1798],
        [0.0108],
        [0.0144],
        [0.0220],
        [0.0982],
        [0.1031],
        [0.0454],
        [0.0292],
        [0.1306],
        [0.4508],
        [0.0465],
        [0.3683],
        [0.0549],
        [0.0102],
        [0.0222],
        [0.0503],
        [0.0245],
        [0.1466],
        [0.1999],
        [0.0237],
        [0.1147],
        [0.4205],
        [0.0654],
        [0.1033],
        [0.1182],
        [0.0737],
        [0.0561],
        [0.2595],
        [0.0116],
        [0.1922],
        [0.4246],
        [0.3039],
        [0.1907],
        [0.4135],
        [0.3967],
        [0.2716],
        [0.2395],
        [0.2179],
        [0.2798],
        [0.0185],
        [0.0197],
        [0.0482],
        [0.0086],
        [0.0088],
        [0.1054],
        [0.1191],
        [0.0319],
        [0.1223],
        [0.0358],
        [0.2026],
        [0.0206],
        [0.0784],
        [0.1204],
        [0.3416],
        [0.3174],
        [0.0210],
        [0.0305],
        [0.0098],
        [0.3320],
        [0.0258],
        [0.3058],
        [0.3179],
        [0.1205],
        [0.0276],
        [0.0308],
        [0.0135],
        [0.3972],
        [0.2718],
        [0.4641],
        [0.0432],
        [0.0744],
        [0.3558],
        [0.2100],
        [0.1327],
        [0.1907],
        [0.0276],
        [0.1947],
        [0.0844],
        [0.2688],
        [0.0658],
        [0.0929],
        [0.1471],
        [0.1222],
        [0.0212],
        [0.0804],
        [0.4366],
        [0.2472],
        [0.0325],
        [0.0178],
        [0.3040],
        [0.0868],
        [0.0683],
        [0.0470],
        [0.2027],
        [0.0262],
        [0.1257],
        [0.1146],
        [0.0586],
        [0.1017],
        [0.4349],
        [0.0286],
        [0.0723],
        [0.3070],
        [0.0135],
        [0.0380],
        [0.0447],
        [0.0161],
        [0.0729],
        [0.0360],
        [0.1328],
        [0.0126],
        [0.0531],
        [0.1831],
        [0.4434],
        [0.0198],
        [0.0878],
        [0.0382],
        [0.0387],
        [0.0438],
        [0.0720],
        [0.0311],
        [0.1378],
        [0.1178],
        [0.1888],
        [0.1199],
        [0.2023],
        [0.1153],
        [0.0523],
        [0.0420],
        [0.0200],
        [0.1671],
        [0.1086],
        [0.0503],
        [0.0441],
        [0.0242],
        [0.1676],
        [0.0120],
        [0.4497],
        [0.0396],
        [0.0922],
        [0.2931],
        [0.0194],
        [0.0267],
        [0.1875],
        [0.1045],
        [0.0839],
        [0.1251],
        [0.0294],
        [0.1090],
        [0.0136],
        [0.0851],
        [0.0360],
        [0.0158],
        [0.0944],
        [0.0110],
        [0.0114],
        [0.0586],
        [0.4468],
        [0.0760],
        [0.0501],
        [0.2267],
        [0.0528],
        [0.0367],
        [0.0803],
        [0.1456],
        [0.2818],
        [0.0266],
        [0.1995],
        [0.3691],
        [0.2341],
        [0.2593],
        [0.0636],
        [0.1788],
        [0.0479],
        [0.0509],
        [0.1104],
        [0.0918],
        [0.1508],
        [0.3680],
        [0.3948],
        [0.1899],
        [0.2569],
        [0.0363],
        [0.0262],
        [0.0936],
        [0.0550],
        [0.1027],
        [0.1444],
        [0.0330],
        [0.0097],
        [0.0761],
        [0.2207],
        [0.0326],
        [0.2501],
        [0.0394],
        [0.0760],
        [0.0381],
        [0.0115],
        [0.2717],
        [0.0423],
        [0.0731],
        [0.1560],
        [0.0826],
        [0.0092],
        [0.0219],
        [0.0751],
        [0.1322],
        [0.2677],
        [0.1361],
        [0.4089],
        [0.0925],
        [0.0266],
        [0.1068],
        [0.3935],
        [0.0987],
        [0.0115],
        [0.3348],
        [0.0551],
        [0.0817],
        [0.0489],
        [0.1392],
        [0.0596],
        [0.0844],
        [0.2388],
        [0.0960],
        [0.0721],
        [0.1400],
        [0.4667],
        [0.2374],
        [0.0349],
        [0.0857],
        [0.1599],
        [0.1922],
        [0.0281],
        [0.0183],
        [0.4507],
        [0.0167],
        [0.0283],
        [0.0402],
        [0.2076],
        [0.1693],
        [0.1446],
        [0.3547],
        [0.0943],
        [0.3730],
        [0.1823],
        [0.0426],
        [0.0149],
        [0.0327],
        [0.3715],
        [0.0474],
        [0.1343],
        [0.1915],
        [0.0690],
        [0.0092],
        [0.1643],
        [0.2189],
        [0.3149],
        [0.2171],
        [0.2178],
        [0.0097],
        [0.3628],
        [0.0163],
        [0.0684],
        [0.1145],
        [0.4074],
        [0.0514],
        [0.3587],
        [0.0905],
        [0.0159],
        [0.1992],
        [0.0109]])
# The model emits a single sigmoid probability per row, i.e. an (N, 1) array.
# argmax over that size-1 axis is always 0, which would label every sample as
# class 0 regardless of the model output. Threshold the probability at 0.5 to
# obtain the actual predicted class before scoring.
predicted_labels = (y_pred.detach().numpy().ravel() >= 0.5).astype(int)
print ("The accuracy is", accuracy_score(tTest, predicted_labels))
The accuracy is 0.9480651731160896
# Serialize the entire trained module object (not only its state_dict) to disk.
torch.save(obj=model, f='stroke.pkl')