ium_434766/lab5.ipynb
s434766 3dfc1beec8
Some checks reported errors
s434766-training/pipeline/head Something is wrong with the build of this commit
jobs
2021-05-07 21:30:35 +02:00

23 KiB


import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd



class LogisticRegressionModel(nn.Module):
    """Logistic regression as a single linear layer followed by a sigmoid.

    Computes ``sigmoid(x @ W.T + b)``.

    Args:
        input_dim: number of input features per sample.
        output_dim: number of output units (this notebook uses 1).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The submodule names ``linear`` and ``sigmoid`` determine the
        # state_dict keys ("linear.weight", "linear.bias"); keep them stable.
        self.linear = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid-activated scores in (0, 1), one column per output unit."""
        logits = self.linear(x)
        probabilities = self.sigmoid(logits)
        return probabilities

np.set_printoptions(suppress=False)

# Load the pre-split datasets. NOTE(review): data_val is loaded but never used
# in this notebook.
data_train = pd.read_csv("data_train.csv")
data_test = pd.read_csv("data_test.csv")
data_val = pd.read_csv("data_val.csv")

# Input columns fed to the model. NOTE(review): assumes every column (including
# 'ever_married') is already numerically encoded in the CSVs, since they are
# cast straight to float32 below.
FEATURES = ['age', 'hypertension', 'heart_disease', 'ever_married', 'avg_glucose_level', 'bmi']
x_train = data_train[FEATURES].astype(np.float32)
y_train = data_train['stroke'].astype(np.float32)

x_test = data_test[FEATURES].astype(np.float32)
y_test = data_test['stroke'].astype(np.float32)

# Feature matrix: one row per sample, one column per entry of FEATURES.
fTrain = torch.from_numpy(x_train.values)
# BCELoss needs targets shaped like the model output, i.e. (n_rows, 1).
# reshape(-1, 1) infers n_rows instead of hard-coding it, so this keeps working
# when data_train.csv changes size.
tTrain = torch.from_numpy(y_train.values.reshape(-1, 1))

fTest = torch.from_numpy(x_test.values)
# Left 1-D: only used as the flat label vector for accuracy scoring.
tTest = torch.from_numpy(y_test.values)

# Hyper-parameters. NOTE(review): batch_size and n_iters are not used anywhere;
# the training loop does full-batch gradient descent for num_epochs steps.
batch_size = 150
n_iters = 1000
num_epochs = 10
# Derive the input width from the feature tensor so it cannot fall out of sync
# with the FEATURES list used to build it.
input_dim = fTrain.shape[1]
output_dim = 1

model = LogisticRegressionModel(input_dim, output_dim)
learning_rate = 0.001

# Binary cross-entropy on the model's sigmoid outputs, averaged over samples.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Sanity check of parameter shapes: weight (output_dim, input_dim), bias (output_dim,).
print(list(model.parameters())[0].size())
print(list(model.parameters())[1].size())
torch.Size([1, 6])
torch.Size([1])
# Full-batch gradient descent: every epoch is a single forward/backward pass
# over the whole training set followed by one SGD update.
for epoch in range(num_epochs):
    model.train()  # (re)enter training mode at the start of each epoch
    print("Epoch #", epoch)

    # Drop gradients left over from the previous update before backprop.
    optimizer.zero_grad()

    y_pred = model(fTrain)                # forward pass
    loss = criterion(y_pred, tTrain)      # mean BCE over all training rows
    print(loss.item())

    loss.backward()                       # populate .grad on the parameters
    optimizer.step()                      # apply one SGD update
Epoch # 0
0.34391772747039795
Epoch # 1
0.3400452435016632
Epoch # 2
0.33628249168395996
Epoch # 3
0.3326331079006195
Epoch # 4
0.3291005790233612
Epoch # 5
0.32568827271461487
Epoch # 6
0.32239940762519836
Epoch # 7
0.3192369043827057
Epoch # 8
0.3162035048007965
Epoch # 9
0.31330153346061707

# Score the held-out test set. This is inference only, so switch to eval mode
# and disable autograd so no computation graph is built for ~1000 rows.
# (The model has no dropout/batch-norm, so eval() does not change the values.)
model.eval()
with torch.no_grad():
    y_pred = model(fTest)
print("predicted Y value: ", y_pred.data)
predicted Y value:  tensor([[0.0089],
        [0.0051],
        [0.1535],
        [0.1008],
        [0.0365],
        [0.0014],
        [0.1275],
        [0.0172],
        [0.1439],
        [0.0088],
        [0.0013],
        [0.3466],
        [0.0078],
        [0.0303],
        [0.0024],
        [0.0607],
        [0.0556],
        [0.0826],
        [0.0765],
        [0.0027],
        [0.0869],
        [0.0424],
        [0.0013],
        [0.1338],
        [0.0017],
        [0.0020],
        [0.0009],
        [0.0014],
        [0.0090],
        [0.4073],
        [0.0026],
        [0.0009],
        [0.0141],
        [0.0897],
        [0.3593],
        [0.3849],
        [0.0073],
        [0.0204],
        [0.1406],
        [0.0053],
        [0.3840],
        [0.0802],
        [0.0068],
        [0.0190],
        [0.3849],
        [0.0034],
        [0.0045],
        [0.3272],
        [0.0397],
        [0.3087],
        [0.0162],
        [0.0159],
        [0.0033],
        [0.0559],
        [0.0238],
        [0.0073],
        [0.0113],
        [0.0102],
        [0.3827],
        [0.0359],
        [0.0138],
        [0.0248],
        [0.0080],
        [0.1858],
        [0.0766],
        [0.0123],
        [0.0077],
        [0.0042],
        [0.0908],
        [0.4172],
        [0.0010],
        [0.1105],
        [0.0463],
        [0.1457],
        [0.0078],
        [0.0821],
        [0.0011],
        [0.0210],
        [0.0273],
        [0.0248],
        [0.0082],
        [0.0007],
        [0.0022],
        [0.2436],
        [0.0297],
        [0.0235],
        [0.0168],
        [0.0053],
        [0.0128],
        [0.0156],
        [0.0009],
        [0.0375],
        [0.0008],
        [0.0645],
        [0.0750],
        [0.0055],
        [0.0185],
        [0.0008],
        [0.0082],
        [0.0138],
        [0.2082],
        [0.1823],
        [0.0027],
        [0.0124],
        [0.0010],
        [0.0187],
        [0.2454],
        [0.0019],
        [0.1413],
        [0.0010],
        [0.0050],
        [0.0020],
        [0.0011],
        [0.2266],
        [0.0545],
        [0.0164],
        [0.0678],
        [0.0012],
        [0.0271],
        [0.0009],
        [0.0029],
        [0.0058],
        [0.0009],
        [0.0762],
        [0.0013],
        [0.0276],
        [0.3940],
        [0.4213],
        [0.0041],
        [0.0144],
        [0.1491],
        [0.0011],
        [0.0077],
        [0.0032],
        [0.3155],
        [0.0009],
        [0.0072],
        [0.0056],
        [0.3580],
        [0.3235],
        [0.0130],
        [0.4032],
        [0.0405],
        [0.2882],
        [0.0045],
        [0.0041],
        [0.0026],
        [0.0354],
        [0.0094],
        [0.0278],
        [0.0011],
        [0.0036],
        [0.2996],
        [0.0652],
        [0.4247],
        [0.0048],
        [0.0016],
        [0.0703],
        [0.3676],
        [0.0231],
        [0.0206],
        [0.0093],
        [0.0087],
        [0.0649],
        [0.0207],
        [0.0010],
        [0.0076],
        [0.3366],
        [0.0015],
        [0.0034],
        [0.2819],
        [0.0007],
        [0.0024],
        [0.0015],
        [0.1623],
        [0.0838],
        [0.0431],
        [0.2744],
        [0.1369],
        [0.0007],
        [0.0022],
        [0.2049],
        [0.0010],
        [0.1057],
        [0.0503],
        [0.0021],
        [0.0136],
        [0.1939],
        [0.0401],
        [0.0010],
        [0.4003],
        [0.2621],
        [0.0087],
        [0.3507],
        [0.0061],
        [0.0012],
        [0.0103],
        [0.0080],
        [0.1068],
        [0.0098],
        [0.2625],
        [0.0162],
        [0.0178],
        [0.0215],
        [0.0283],
        [0.0444],
        [0.0356],
        [0.0037],
        [0.0779],
        [0.0652],
        [0.0521],
        [0.3626],
        [0.0116],
        [0.2099],
        [0.0431],
        [0.0076],
        [0.0050],
        [0.0010],
        [0.0049],
        [0.0036],
        [0.1638],
        [0.0205],
        [0.0775],
        [0.0026],
        [0.0826],
        [0.0052],
        [0.0111],
        [0.3208],
        [0.0435],
        [0.0140],
        [0.2729],
        [0.1035],
        [0.0041],
        [0.0202],
        [0.0164],
        [0.0089],
        [0.0359],
        [0.0306],
        [0.0325],
        [0.0354],
        [0.0096],
        [0.0304],
        [0.0473],
        [0.0016],
        [0.1147],
        [0.0368],
        [0.0034],
        [0.3918],
        [0.0054],
        [0.0040],
        [0.0415],
        [0.0190],
        [0.0116],
        [0.0014],
        [0.0010],
        [0.1187],
        [0.3011],
        [0.3357],
        [0.0077],
        [0.0243],
        [0.0128],
        [0.0978],
        [0.0020],
        [0.0064],
        [0.1374],
        [0.0916],
        [0.0623],
        [0.1318],
        [0.0144],
        [0.0737],
        [0.0468],
        [0.1312],
        [0.0056],
        [0.0017],
        [0.0157],
        [0.0008],
        [0.0016],
        [0.0098],
        [0.2015],
        [0.2831],
        [0.0054],
        [0.0062],
        [0.1790],
        [0.0022],
        [0.0012],
        [0.0174],
        [0.0033],
        [0.0100],
        [0.0580],
        [0.2458],
        [0.4155],
        [0.1052],
        [0.0007],
        [0.1498],
        [0.0344],
        [0.0011],
        [0.0006],
        [0.0088],
        [0.0018],
        [0.0011],
        [0.0014],
        [0.3143],
        [0.0236],
        [0.0252],
        [0.0010],
        [0.0022],
        [0.0041],
        [0.0105],
        [0.1520],
        [0.0063],
        [0.0197],
        [0.0082],
        [0.0038],
        [0.0533],
        [0.1163],
        [0.0032],
        [0.1304],
        [0.0058],
        [0.0045],
        [0.0008],
        [0.1758],
        [0.0078],
        [0.0282],
        [0.1256],
        [0.0021],
        [0.0388],
        [0.0826],
        [0.0185],
        [0.0009],
        [0.0152],
        [0.4050],
        [0.0030],
        [0.0017],
        [0.1281],
        [0.0424],
        [0.0090],
        [0.0404],
        [0.2027],
        [0.0005],
        [0.0053],
        [0.1841],
        [0.3782],
        [0.0047],
        [0.0025],
        [0.0432],
        [0.1070],
        [0.0010],
        [0.0069],
        [0.0033],
        [0.3327],
        [0.0852],
        [0.0135],
        [0.0007],
        [0.0109],
        [0.1961],
        [0.0393],
        [0.0013],
        [0.1051],
        [0.0078],
        [0.0324],
        [0.0065],
        [0.1976],
        [0.1738],
        [0.0195],
        [0.1941],
        [0.0088],
        [0.1793],
        [0.0061],
        [0.4261],
        [0.0122],
        [0.0733],
        [0.0026],
        [0.0009],
        [0.0436],
        [0.0141],
        [0.0124],
        [0.0020],
        [0.0048],
        [0.0381],
        [0.0009],
        [0.0044],
        [0.0014],
        [0.0010],
        [0.0008],
        [0.2416],
        [0.0177],
        [0.0015],
        [0.0041],
        [0.0554],
        [0.3380],
        [0.0284],
        [0.0093],
        [0.0200],
        [0.0057],
        [0.2006],
        [0.0126],
        [0.0063],
        [0.0072],
        [0.0071],
        [0.0981],
        [0.0008],
        [0.0116],
        [0.1803],
        [0.1175],
        [0.0010],
        [0.0610],
        [0.0026],
        [0.1416],
        [0.0168],
        [0.0973],
        [0.0457],
        [0.0012],
        [0.0016],
        [0.0015],
        [0.3407],
        [0.1155],
        [0.0035],
        [0.0010],
        [0.1799],
        [0.1241],
        [0.1657],
        [0.0018],
        [0.1940],
        [0.0171],
        [0.0075],
        [0.0230],
        [0.0113],
        [0.0342],
        [0.0031],
        [0.0369],
        [0.0262],
        [0.0018],
        [0.0224],
        [0.0014],
        [0.0081],
        [0.4363],
        [0.0562],
        [0.0056],
        [0.0754],
        [0.0028],
        [0.0582],
        [0.0100],
        [0.1473],
        [0.0016],
        [0.0013],
        [0.2620],
        [0.0646],
        [0.0910],
        [0.3646],
        [0.0181],
        [0.0146],
        [0.0076],
        [0.0009],
        [0.0363],
        [0.0086],
        [0.0336],
        [0.0206],
        [0.0374],
        [0.0212],
        [0.2859],
        [0.0384],
        [0.0026],
        [0.0016],
        [0.0009],
        [0.0519],
        [0.0054],
        [0.0392],
        [0.0225],
        [0.0382],
        [0.0016],
        [0.0075],
        [0.0619],
        [0.0040],
        [0.0810],
        [0.0372],
        [0.0376],
        [0.0037],
        [0.0028],
        [0.3910],
        [0.0017],
        [0.0084],
        [0.0298],
        [0.0160],
        [0.0112],
        [0.0024],
        [0.0008],
        [0.0019],
        [0.0181],
        [0.0624],
        [0.0127],
        [0.2768],
        [0.4461],
        [0.0224],
        [0.0015],
        [0.0200],
        [0.0274],
        [0.1512],
        [0.0161],
        [0.0005],
        [0.0009],
        [0.0028],
        [0.0012],
        [0.0132],
        [0.0028],
        [0.2346],
        [0.0088],
        [0.0227],
        [0.0015],
        [0.0187],
        [0.1476],
        [0.0069],
        [0.0096],
        [0.1176],
        [0.1304],
        [0.0643],
        [0.3622],
        [0.0078],
        [0.2245],
        [0.3165],
        [0.0477],
        [0.0022],
        [0.0203],
        [0.1160],
        [0.0148],
        [0.1549],
        [0.0316],
        [0.0011],
        [0.0141],
        [0.1498],
        [0.0011],
        [0.0011],
        [0.3146],
        [0.1083],
        [0.0908],
        [0.0128],
        [0.0031],
        [0.0038],
        [0.0525],
        [0.3349],
        [0.1460],
        [0.0631],
        [0.0054],
        [0.2970],
        [0.0336],
        [0.0166],
        [0.0312],
        [0.0037],
        [0.0038],
        [0.0008],
        [0.1529],
        [0.0008],
        [0.0544],
        [0.0013],
        [0.0021],
        [0.0079],
        [0.1678],
        [0.0127],
        [0.0026],
        [0.0506],
        [0.0042],
        [0.3334],
        [0.0022],
        [0.0021],
        [0.2290],
        [0.0904],
        [0.0032],
        [0.2832],
        [0.0484],
        [0.0022],
        [0.0026],
        [0.3545],
        [0.0048],
        [0.0331],
        [0.3835],
        [0.0400],
        [0.0132],
        [0.0342],
        [0.0311],
        [0.1180],
        [0.4113],
        [0.0278],
        [0.0277],
        [0.0400],
        [0.0968],
        [0.0736],
        [0.0176],
        [0.0696],
        [0.1483],
        [0.0007],
        [0.0197],
        [0.0305],
        [0.0125],
        [0.4180],
        [0.1870],
        [0.0174],
        [0.0603],
        [0.0042],
        [0.3729],
        [0.0069],
        [0.0466],
        [0.0006],
        [0.0530],
        [0.0183],
        [0.0061],
        [0.0061],
        [0.0185],
        [0.0021],
        [0.0099],
        [0.0007],
        [0.0121],
        [0.0504],
        [0.0619],
        [0.0373],
        [0.0078],
        [0.0366],
        [0.0357],
        [0.0443],
        [0.1825],
        [0.3996],
        [0.0054],
        [0.0040],
        [0.1632],
        [0.0247],
        [0.0019],
        [0.0141],
        [0.0060],
        [0.1078],
        [0.0698],
        [0.1844],
        [0.1288],
        [0.0308],
        [0.1057],
        [0.0009],
        [0.0100],
        [0.0234],
        [0.1900],
        [0.0010],
        [0.0231],
        [0.0007],
        [0.2147],
        [0.0048],
        [0.0733],
        [0.0012],
        [0.2775],
        [0.0033],
        [0.1134],
        [0.0070],
        [0.0424],
        [0.0055],
        [0.0451],
        [0.2994],
        [0.0181],
        [0.0035],
        [0.1174],
        [0.0287],
        [0.1288],
        [0.0613],
        [0.0007],
        [0.0007],
        [0.0083],
        [0.0365],
        [0.0017],
        [0.2397],
        [0.0007],
        [0.0196],
        [0.0065],
        [0.0519],
        [0.0722],
        [0.0098],
        [0.0986],
        [0.0025],
        [0.1867],
        [0.0046],
        [0.0015],
        [0.2836],
        [0.0665],
        [0.0021],
        [0.0005],
        [0.0285],
        [0.0281],
        [0.1217],
        [0.0330],
        [0.0361],
        [0.0177],
        [0.0046],
        [0.0041],
        [0.0175],
        [0.0092],
        [0.0877],
        [0.2109],
        [0.0267],
        [0.1180],
        [0.0670],
        [0.4255],
        [0.4027],
        [0.3366],
        [0.0052],
        [0.0009],
        [0.3857],
        [0.0031],
        [0.0013],
        [0.0011],
        [0.0379],
        [0.0017],
        [0.0236],
        [0.0153],
        [0.0518],
        [0.0456],
        [0.1873],
        [0.3320],
        [0.0881],
        [0.0009],
        [0.0015],
        [0.0027],
        [0.0304],
        [0.0332],
        [0.0036],
        [0.0043],
        [0.0409],
        [0.4144],
        [0.0087],
        [0.2914],
        [0.0116],
        [0.0008],
        [0.0028],
        [0.0103],
        [0.0032],
        [0.0606],
        [0.1001],
        [0.0028],
        [0.0397],
        [0.3626],
        [0.0158],
        [0.0331],
        [0.0401],
        [0.0182],
        [0.0125],
        [0.1572],
        [0.0010],
        [0.0937],
        [0.3776],
        [0.2067],
        [0.0914],
        [0.3666],
        [0.3347],
        [0.1722],
        [0.1384],
        [0.1152],
        [0.1753],
        [0.0009],
        [0.0023],
        [0.0088],
        [0.0006],
        [0.0006],
        [0.0313],
        [0.0413],
        [0.0022],
        [0.0413],
        [0.0059],
        [0.1054],
        [0.0024],
        [0.0173],
        [0.0426],
        [0.2546],
        [0.2232],
        [0.0026],
        [0.0045],
        [0.0008],
        [0.2471],
        [0.0035],
        [0.2115],
        [0.2262],
        [0.0433],
        [0.0039],
        [0.0047],
        [0.0013],
        [0.3367],
        [0.1733],
        [0.4315],
        [0.0075],
        [0.0182],
        [0.2763],
        [0.1106],
        [0.0510],
        [0.0907],
        [0.0039],
        [0.0987],
        [0.0239],
        [0.1644],
        [0.0162],
        [0.0271],
        [0.0608],
        [0.0351],
        [0.0014],
        [0.0221],
        [0.3912],
        [0.1440],
        [0.0051],
        [0.0015],
        [0.2073],
        [0.0235],
        [0.0168],
        [0.0090],
        [0.1056],
        [0.0036],
        [0.0452],
        [0.0396],
        [0.0070],
        [0.0314],
        [0.3853],
        [0.0039],
        [0.0184],
        [0.2150],
        [0.0013],
        [0.0055],
        [0.0084],
        [0.0015],
        [0.0190],
        [0.0060],
        [0.0507],
        [0.0007],
        [0.0113],
        [0.0851],
        [0.3850],
        [0.0012],
        [0.0245],
        [0.0066],
        [0.0069],
        [0.0040],
        [0.0183],
        [0.0043],
        [0.0529],
        [0.0398],
        [0.0887],
        [0.0426],
        [0.0927],
        [0.0404],
        [0.0049],
        [0.0056],
        [0.0024],
        [0.0739],
        [0.0359],
        [0.0104],
        [0.0081],
        [0.0023],
        [0.0747],
        [0.0011],
        [0.4100],
        [0.0068],
        [0.0268],
        [0.1948],
        [0.0013],
        [0.0018],
        [0.0867],
        [0.0346],
        [0.0231],
        [0.0455],
        [0.0044],
        [0.0365],
        [0.0013],
        [0.0245],
        [0.0059],
        [0.0014],
        [0.0257],
        [0.0010],
        [0.0010],
        [0.0131],
        [0.3990],
        [0.0200],
        [0.0078],
        [0.1217],
        [0.0110],
        [0.0064],
        [0.0219],
        [0.0580],
        [0.1824],
        [0.0020],
        [0.0986],
        [0.2947],
        [0.1087],
        [0.1563],
        [0.0122],
        [0.0784],
        [0.0097],
        [0.0097],
        [0.0347],
        [0.0263],
        [0.0613],
        [0.2907],
        [0.3285],
        [0.0896],
        [0.1502],
        [0.0061],
        [0.0037],
        [0.0282],
        [0.0116],
        [0.0323],
        [0.0576],
        [0.0054],
        [0.0008],
        [0.0200],
        [0.1212],
        [0.0052],
        [0.1496],
        [0.0065],
        [0.0199],
        [0.0065],
        [0.0010],
        [0.1536],
        [0.0076],
        [0.0191],
        [0.0691],
        [0.0187],
        [0.0007],
        [0.0027],
        [0.0162],
        [0.0440],
        [0.1624],
        [0.0502],
        [0.3496],
        [0.0270],
        [0.0034],
        [0.0337],
        [0.3247],
        [0.0274],
        [0.0010],
        [0.2565],
        [0.0099],
        [0.0126],
        [0.0092],
        [0.0546],
        [0.0139],
        [0.0238],
        [0.1364],
        [0.0246],
        [0.0183],
        [0.0558],
        [0.4423],
        [0.1326],
        [0.0059],
        [0.0229],
        [0.0692],
        [0.0944],
        [0.0022],
        [0.0017],
        [0.4130],
        [0.0013],
        [0.0037],
        [0.0071],
        [0.1049],
        [0.0774],
        [0.0569],
        [0.2711],
        [0.0290],
        [0.3081],
        [0.0848],
        [0.0078],
        [0.0015],
        [0.0046],
        [0.3030],
        [0.0093],
        [0.0481],
        [0.0931],
        [0.0174],
        [0.0007],
        [0.0695],
        [0.1172],
        [0.2178],
        [0.1137],
        [0.1141],
        [0.0008],
        [0.2754],
        [0.0008],
        [0.0167],
        [0.0398],
        [0.3444],
        [0.0089],
        [0.2858],
        [0.0251],
        [0.0016],
        [0.0993],
        [0.0009]])
# y_pred holds one sigmoid probability per test row (shape (N, 1)), so hard
# labels come from thresholding at 0.5. The previous np.argmax(..., axis=1) on a
# single column always returned 0, which made the reported "accuracy" equal to
# the share of negative rows in the test set regardless of the model.
# NOTE(review): with a heavily imbalanced target, also consider recall/F1.
y_label = (y_pred.detach().numpy().reshape(-1) >= 0.5).astype(np.float32)
print("The accuracy is", accuracy_score(tTest.numpy(), y_label))
The accuracy is 0.9480651731160896
# Pickle the entire module object (class reference + weights) to stroke.pkl.
# Loading it back with torch.load needs LogisticRegressionModel importable.
torch.save(obj=model, f='stroke.pkl')