import torch
from torch import nn
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd
class LogisticRegressionModel(nn.Module):
    """Logistic regression: a single linear layer followed by a sigmoid."""
    def __init__(self, input_dim, output_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        return self.sigmoid(out)
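# Note: applying Sigmoid in forward() and pairing it with BCELoss works, but a
# numerically more stable, mathematically equivalent alternative is to return
# raw logits and use BCEWithLogitsLoss, which fuses the sigmoid into the loss.
# A minimal sketch of that variant (same layer shapes as above):
#
#     class LogisticRegressionLogits(nn.Module):
#         def __init__(self, input_dim, output_dim):
#             super().__init__()
#             self.linear = nn.Linear(input_dim, output_dim)
#         def forward(self, x):
#             return self.linear(x)           # raw logits, no sigmoid
#
#     criterion = nn.BCEWithLogitsLoss()      # applies the sigmoid internally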
np.set_printoptions(suppress=False)  # numpy print formatting (does not affect torch tensor printing)
data_train = pd.read_csv("data_train.csv")
data_test = pd.read_csv("data_test.csv")
data_val = pd.read_csv("data_val.csv")
FEATURES = ['age', 'hypertension', 'heart_disease', 'ever_married', 'avg_glucose_level', 'bmi']
x_train = data_train[FEATURES].astype(np.float32)
y_train = data_train['stroke'].astype(np.float32)
x_test = data_test[FEATURES].astype(np.float32)
y_test = data_test['stroke'].astype(np.float32)
fTrain = torch.from_numpy(x_train.values)
tTrain = torch.from_numpy(y_train.values.reshape(-1, 1))  # column vector to match the model output shape
fTest = torch.from_numpy(x_test.values)
tTest = torch.from_numpy(y_test.values)
# Training below is full-batch for a fixed number of epochs
num_epochs = 10
input_dim = len(FEATURES)  # 6
output_dim = 1
model = LogisticRegressionModel(input_dim, output_dim)
learning_rate = 0.001
criterion = torch.nn.BCELoss(reduction='mean')  # binary cross-entropy on sigmoid outputs
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
print(list(model.parameters())[0].size())
print(list(model.parameters())[1].size())
torch.Size([1, 6])
torch.Size([1])
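# The two parameter tensors are the 1 x 6 weight matrix and the 1-element bias:
# one logit per example, computed from the six input features.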
for epoch in range(num_epochs):
    print("Epoch #", epoch)
    model.train()
    optimizer.zero_grad()
    # Forward pass
    y_pred = model(fTrain)
    # Compute loss
    loss = criterion(y_pred, tTrain)
    print(loss.item())
    # Backward pass
    loss.backward()
    optimizer.step()
Epoch # 0
0.34391772747039795
Epoch # 1
0.3400452435016632
Epoch # 2
0.33628249168395996
Epoch # 3
0.3326331079006195
Epoch # 4
0.3291005790233612
Epoch # 5
0.32568827271461487
Epoch # 6
0.32239940762519836
Epoch # 7
0.3192369043827057
Epoch # 8
0.3162035048007965
Epoch # 9
0.31330153346061707
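# For inference, it is good practice to switch the module to eval mode and
# disable gradient tracking; a minimal sketch of the test pass with that added:
#
#     model.eval()
#     with torch.no_grad():
#         y_pred = model(fTest)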
y_pred = model(fTest)
print("predicted Y value: ", y_pred.data)
predicted Y value:  tensor([[0.0089], [0.0051], [0.1535], ..., [0.0993], [0.0009]])
print ("The accuracy is", accuracy_score(tTest, np.argmax(y_pred.detach().numpy(), axis=1)))
The accuracy is 0.9480651731160896
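# Caveat: the stroke labels are heavily imbalanced, and no test prediction
# exceeds the 0.5 threshold (the largest probability above is ~0.45), so this
# accuracy simply matches the majority-class rate. Recall, precision, or ROC
# AUC are more informative here; a minimal sketch using sklearn:
#
#     from sklearn.metrics import classification_report, roc_auc_score
#     probs = y_pred.detach().numpy()
#     preds = (probs > 0.5).astype(np.float32)
#     print(classification_report(tTest, preds, zero_division=0))
#     print("ROC AUC:", roc_auc_score(tTest, probs))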
torch.save(model, 'stroke.pkl')
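# Note: torch.save(model, ...) pickles the entire module, tying the file to
# this exact class definition. Saving only the parameters is the more portable
# convention (the filename 'stroke.pt' below is just an illustrative choice):
#
#     torch.save(model.state_dict(), 'stroke.pt')
#     # later:
#     # model = LogisticRegressionModel(input_dim, output_dim)
#     # model.load_state_dict(torch.load('stroke.pt'))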