import torch
from torch import nn
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd
class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        return self.sigmoid(out)
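The forward pass computes sigmoid(x·Wᵀ + b), i.e. plain logistic regression. A numerically more stable variant (an alternative sketch, not the setup used below) returns raw logits and lets BCEWithLogitsLoss apply the sigmoid internally:

class LogisticRegressionLogits(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # Return raw logits; BCEWithLogitsLoss fuses sigmoid + BCE for stability
        return self.linear(x)

# Usage sketch with a dummy batch (hypothetical shapes, matching the 6 features below)
logits_model = LogisticRegressionLogits(6, 1)
logits_criterion = nn.BCEWithLogitsLoss()
dummy_x = torch.randn(4, 6)
dummy_y = torch.zeros(4, 1)
loss = logits_criterion(logits_model(dummy_x), dummy_y)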
data_train = pd.read_csv("data_train.csv")
data_test = pd.read_csv("data_test.csv")
data_val = pd.read_csv("data_val.csv")
FEATURES = ['age', 'hypertension', 'heart_disease', 'ever_married', 'avg_glucose_level', 'bmi']
x_train = data_train[FEATURES].astype(np.float32)
y_train = data_train['stroke'].astype(np.float32)
x_test = data_test[FEATURES].astype(np.float32)
y_test = data_test['stroke'].astype(np.float32)
fTrain = torch.from_numpy(x_train.values)
tTrain = torch.from_numpy(y_train.values.reshape(-1, 1))  # column vector to match the model output shape
fTest = torch.from_numpy(x_test.values)
tTest = torch.from_numpy(y_test.values)
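MinMaxScaler is imported above but never called, and avg_glucose_level spans a much wider range than the binary features, which likely contributes to the large initial loss in the training log below. A minimal scaling sketch (an assumption about the intended preprocessing; fTrain_scaled / fTest_scaled are hypothetical names and would replace fTrain / fTest above):

scaler = MinMaxScaler()
# Fit on the training split only, so no test-set statistics leak into training
fTrain_scaled = torch.from_numpy(scaler.fit_transform(x_train.values).astype(np.float32))
fTest_scaled = torch.from_numpy(scaler.transform(x_test.values).astype(np.float32))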
batch_size = 95
n_iters = 1000
num_epochs = int(n_iters / (len(x_train) / batch_size))  # int(1000 / (2945 / 95)) = 32
input_dim = 6   # one input per feature in FEATURES
output_dim = 1  # single stroke probability
model = LogisticRegressionModel(input_dim, output_dim)
learning_rate = 0.001
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
print(list(model.parameters())[0].size())  # weight matrix
print(list(model.parameters())[1].size())  # bias
torch.Size([1, 6])
torch.Size([1])
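These shapes are the six per-feature weights and the single bias of the linear layer, i.e. the coefficients and intercept of a classical logistic regression.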
for epoch in range(num_epochs):
    print("Epoch #", epoch)
    model.train()
    optimizer.zero_grad()
    # Forward pass
    y_pred = model(fTrain)
    # Compute loss
    loss = criterion(y_pred, tTrain)
    print(loss.item())
    # Backward pass
    loss.backward()
    optimizer.step()
Epoch # 0   4.4554009437561035
Epoch # 1   2.887434244155884
Epoch # 2   1.4808591604232788
Epoch # 3   0.6207292079925537
Epoch # 4   0.4031478762626648
Epoch # 5   0.34721270203590393
Epoch # 6   0.32333147525787354
Epoch # 7   0.3105970621109009
Epoch # 8   0.30295372009277344
Epoch # 9   0.2980167269706726
Epoch # 10  0.29466450214385986
Epoch # 11  0.29230451583862305
Epoch # 12  0.29059702157974243
Epoch # 13  0.2893349230289459
Epoch # 14  0.2883857190608978
Epoch # 15  0.2876618504524231
Epoch # 16  0.2871031165122986
Epoch # 17  0.28666743636131287
Epoch # 18  0.28632479906082153
Epoch # 19  0.2860531508922577
Epoch # 20  0.28583624958992004
Epoch # 21  0.2856619954109192
Epoch # 22  0.285521000623703
Epoch # 23  0.2854064106941223
Epoch # 24  0.2853126525878906
Epoch # 25  0.2852354049682617
Epoch # 26  0.2851715385913849
Epoch # 27  0.28511837124824524
Epoch # 28  0.2850736975669861
Epoch # 29  0.2850360572338104
Epoch # 30  0.28500401973724365
Epoch # 31  0.2849765419960022
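Note that the loop above trains on the full dataset every epoch; batch_size only enters the num_epochs arithmetic. A minimal mini-batch variant (a sketch, not the original training procedure):

from torch.utils.data import DataLoader, TensorDataset

train_loader = DataLoader(TensorDataset(fTrain, tTrain), batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    model.train()
    for xb, yb in train_loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(xb), yb)  # loss on one mini-batch of 95 rows
        batch_loss.backward()
        optimizer.step()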
model.eval()
with torch.no_grad():  # no gradients needed for inference
    y_pred = model(fTest)
print("predicted Y value: ", y_pred)
predicted Y value:  tensor([[0.0468],
        [0.0325],
        [0.2577],
        ...,
        [0.1992],
        [0.0109]])
print ("The accuracy is", accuracy_score(tTest, np.argmax(y_pred.detach().numpy(), axis=1)))
The accuracy is 0.9480651731160896
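Every probability printed above sits below 0.5, so thresholding at 0.5 labels every patient "no stroke"; the 0.948 accuracy therefore mostly reflects the class imbalance of the dataset rather than discriminative power. A quick check (a sketch using sklearn; stroke-class recall is the more telling number here):

from sklearn.metrics import confusion_matrix, recall_score

print(confusion_matrix(tTest.numpy(), y_pred_class))
print("Stroke recall:", recall_score(tTest.numpy(), y_pred_class))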
torch.save(model, 'stroke.pkl')  # pickles the entire module; torch.save(model.state_dict(), ...) is the more portable convention
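A load-and-predict sketch for the saved file. Because the full module object is pickled, the LogisticRegressionModel class must be importable at load time, and recent PyTorch versions (2.6+) default torch.load to weights_only=True, so loading a pickled module there additionally needs weights_only=False:

loaded = torch.load('stroke.pkl')
loaded.eval()
with torch.no_grad():
    prob = loaded(fTest[:1]).item()  # predicted stroke probability for one test row
print("stroke probability:", round(prob, 4))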