s434766
3dfc1beec8
Some checks reported errors
s434766-training/pipeline/head Something is wrong with the build of this commit
211 lines
23 KiB
Plaintext
211 lines
23 KiB
Plaintext
{
|
|
"metadata": {
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.5"
|
|
},
|
|
"orig_nbformat": 2,
|
|
"kernelspec": {
|
|
"name": "python385jvsc74a57bd02cef13873963874fd5439bd04a135498d1dd9725d9d90f40de0b76178a8e03b1",
|
|
"display_name": "Python 3.8.5 64-bit (conda)"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2,
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 74,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"import torch\n",
|
|
"import torch.nn.functional as F\n",
|
|
"from torch import nn\n",
|
|
"from torch.autograd import Variable\n",
|
|
"import torchvision.transforms as transforms\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"from sklearn.preprocessing import MinMaxScaler\n",
|
|
"from sklearn.metrics import accuracy_score\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
class LogisticRegressionModel(nn.Module):
    """Single-layer logistic regression: an affine map followed by a sigmoid.

    Emits per-sample probabilities in (0, 1), which is what BCELoss expects.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Affine projection from feature space to logit space.
        self.linear = nn.Linear(input_dim, output_dim)
        # Element-wise sigmoid squashes logits into probabilities.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Probability = sigmoid(Wx + b).
        return self.sigmoid(self.linear(x))
|
|
"\n",
|
|
# Keep numpy's default scientific-notation printing (suppress=False is the default).
np.set_printoptions(suppress=False)

# Load the pre-split stroke dataset from the working directory.
# NOTE(review): assumes the CSVs were already cleaned/encoded upstream — confirm.
data_train = pd.read_csv("data_train.csv")
data_test = pd.read_csv("data_test.csv")
# NOTE(review): data_val is loaded but never referenced in the visible cells.
data_val = pd.read_csv("data_val.csv")

# Input columns used as model features; the target column is 'stroke' (next cell).
FEATURES = [ 'age','hypertension','heart_disease','ever_married', 'avg_glucose_level', 'bmi']
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 75,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Select the feature matrix and binary target from each split; float32 matches
# the default dtype of the model's nn.Linear parameters.
x_train = data_train[FEATURES].astype(np.float32)
y_train = data_train['stroke'].astype(np.float32)

x_test = data_test[FEATURES].astype(np.float32)
y_test = data_test['stroke'].astype(np.float32)

# Convert to tensors. BCELoss requires the target to have the same shape as
# the model output (N, 1), so the 1-D label vector becomes a column.
fTrain = torch.from_numpy(x_train.values)
# FIX: reshape(-1, 1) replaces the hard-coded reshape(2945, 1) — the magic row
# count silently breaks as soon as the training split changes size.
tTrain = torch.from_numpy(y_train.values.reshape(-1, 1))

fTest = torch.from_numpy(x_test.values)
# Test labels stay 1-D: accuracy_score in the evaluation cell expects 1-D y_true.
tTest = torch.from_numpy(y_test.values)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 76,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
# Training configuration.
# NOTE(review): batch_size and n_iters are never referenced by the training
# loop below, which performs full-batch gradient descent for num_epochs
# iterations — confirm whether mini-batching was intended.
batch_size = 150
n_iters = 1000
num_epochs = 10  # one full-batch update per epoch
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 77,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Model dimensions: one input per feature column, a single probability output.
# Derive input_dim from the feature list instead of hard-coding 6 so the model
# stays in sync if FEATURES changes (same value, 6, for the current list).
input_dim = len(FEATURES)
output_dim = 1

model = LogisticRegressionModel(input_dim, output_dim)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 78,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Optimization hyperparameters.
learning_rate = 0.001

# Binary cross-entropy on probabilities (the model ends in Sigmoid), averaged
# over the batch; plain SGD updates every model parameter.
criterion = torch.nn.BCELoss(reduction='mean') 
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 79,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"torch.Size([1, 6])\ntorch.Size([1])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# Inspect the learned parameter shapes: weight (1, 6) and bias (1,) —
# one logistic unit over the six input features.
params = list(model.parameters())
print(params[0].size())
print(params[1].size())
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 80,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"Epoch # 0\n0.34391772747039795\nEpoch # 1\n0.3400452435016632\nEpoch # 2\n0.33628249168395996\nEpoch # 3\n0.3326331079006195\nEpoch # 4\n0.3291005790233612\nEpoch # 5\n0.32568827271461487\nEpoch # 6\n0.32239940762519836\nEpoch # 7\n0.3192369043827057\nEpoch # 8\n0.3162035048007965\nEpoch # 9\n0.31330153346061707\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# Full-batch gradient descent: one forward/backward pass over the entire
# training set per epoch, logging the loss as it decreases.
for epoch in range(num_epochs):
    print ("Epoch #",epoch)
    model.train()
    optimizer.zero_grad()

    # Forward pass: predicted probabilities for every training row.
    outputs = model(fTrain)

    # Mean binary cross-entropy against the (N, 1) target column.
    loss = criterion(outputs, tTrain)
    print(loss.item())

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 81,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"predicted Y value: tensor([[0.0089],\n [0.0051],\n [0.1535],\n [0.1008],\n [0.0365],\n [0.0014],\n [0.1275],\n [0.0172],\n [0.1439],\n [0.0088],\n [0.0013],\n [0.3466],\n [0.0078],\n [0.0303],\n [0.0024],\n [0.0607],\n [0.0556],\n [0.0826],\n [0.0765],\n [0.0027],\n [0.0869],\n [0.0424],\n [0.0013],\n [0.1338],\n [0.0017],\n [0.0020],\n [0.0009],\n [0.0014],\n [0.0090],\n [0.4073],\n [0.0026],\n [0.0009],\n [0.0141],\n [0.0897],\n [0.3593],\n [0.3849],\n [0.0073],\n [0.0204],\n [0.1406],\n [0.0053],\n [0.3840],\n [0.0802],\n [0.0068],\n [0.0190],\n [0.3849],\n [0.0034],\n [0.0045],\n [0.3272],\n [0.0397],\n [0.3087],\n [0.0162],\n [0.0159],\n [0.0033],\n [0.0559],\n [0.0238],\n [0.0073],\n [0.0113],\n [0.0102],\n [0.3827],\n [0.0359],\n [0.0138],\n [0.0248],\n [0.0080],\n [0.1858],\n [0.0766],\n [0.0123],\n [0.0077],\n [0.0042],\n [0.0908],\n [0.4172],\n [0.0010],\n [0.1105],\n [0.0463],\n [0.1457],\n [0.0078],\n [0.0821],\n [0.0011],\n [0.0210],\n [0.0273],\n [0.0248],\n [0.0082],\n [0.0007],\n [0.0022],\n [0.2436],\n [0.0297],\n [0.0235],\n [0.0168],\n [0.0053],\n [0.0128],\n [0.0156],\n [0.0009],\n [0.0375],\n [0.0008],\n [0.0645],\n [0.0750],\n [0.0055],\n [0.0185],\n [0.0008],\n [0.0082],\n [0.0138],\n [0.2082],\n [0.1823],\n [0.0027],\n [0.0124],\n [0.0010],\n [0.0187],\n [0.2454],\n [0.0019],\n [0.1413],\n [0.0010],\n [0.0050],\n [0.0020],\n [0.0011],\n [0.2266],\n [0.0545],\n [0.0164],\n [0.0678],\n [0.0012],\n [0.0271],\n [0.0009],\n [0.0029],\n [0.0058],\n [0.0009],\n [0.0762],\n [0.0013],\n [0.0276],\n [0.3940],\n [0.4213],\n [0.0041],\n [0.0144],\n [0.1491],\n [0.0011],\n [0.0077],\n [0.0032],\n [0.3155],\n [0.0009],\n [0.0072],\n [0.0056],\n [0.3580],\n [0.3235],\n [0.0130],\n [0.4032],\n [0.0405],\n [0.2882],\n [0.0045],\n [0.0041],\n [0.0026],\n [0.0354],\n [0.0094],\n [0.0278],\n [0.0011],\n [0.0036],\n [0.2996],\n [0.0652],\n [0.4247],\n [0.0048],\n [0.0016],\n [0.0703],\n [0.3676],\n [0.0231],\n [0.0206],\n [0.0093],\n [0.0087],\n [0.0649],\n 
[0.0207],\n [0.0010],\n [0.0076],\n [0.3366],\n [0.0015],\n [0.0034],\n [0.2819],\n [0.0007],\n [0.0024],\n [0.0015],\n [0.1623],\n [0.0838],\n [0.0431],\n [0.2744],\n [0.1369],\n [0.0007],\n [0.0022],\n [0.2049],\n [0.0010],\n [0.1057],\n [0.0503],\n [0.0021],\n [0.0136],\n [0.1939],\n [0.0401],\n [0.0010],\n [0.4003],\n [0.2621],\n [0.0087],\n [0.3507],\n [0.0061],\n [0.0012],\n [0.0103],\n [0.0080],\n [0.1068],\n [0.0098],\n [0.2625],\n [0.0162],\n [0.0178],\n [0.0215],\n [0.0283],\n [0.0444],\n [0.0356],\n [0.0037],\n [0.0779],\n [0.0652],\n [0.0521],\n [0.3626],\n [0.0116],\n [0.2099],\n [0.0431],\n [0.0076],\n [0.0050],\n [0.0010],\n [0.0049],\n [0.0036],\n [0.1638],\n [0.0205],\n [0.0775],\n [0.0026],\n [0.0826],\n [0.0052],\n [0.0111],\n [0.3208],\n [0.0435],\n [0.0140],\n [0.2729],\n [0.1035],\n [0.0041],\n [0.0202],\n [0.0164],\n [0.0089],\n [0.0359],\n [0.0306],\n [0.0325],\n [0.0354],\n [0.0096],\n [0.0304],\n [0.0473],\n [0.0016],\n [0.1147],\n [0.0368],\n [0.0034],\n [0.3918],\n [0.0054],\n [0.0040],\n [0.0415],\n [0.0190],\n [0.0116],\n [0.0014],\n [0.0010],\n [0.1187],\n [0.3011],\n [0.3357],\n [0.0077],\n [0.0243],\n [0.0128],\n [0.0978],\n [0.0020],\n [0.0064],\n [0.1374],\n [0.0916],\n [0.0623],\n [0.1318],\n [0.0144],\n [0.0737],\n [0.0468],\n [0.1312],\n [0.0056],\n [0.0017],\n [0.0157],\n [0.0008],\n [0.0016],\n [0.0098],\n [0.2015],\n [0.2831],\n [0.0054],\n [0.0062],\n [0.1790],\n [0.0022],\n [0.0012],\n [0.0174],\n [0.0033],\n [0.0100],\n [0.0580],\n [0.2458],\n [0.4155],\n [0.1052],\n [0.0007],\n [0.1498],\n [0.0344],\n [0.0011],\n [0.0006],\n [0.0088],\n [0.0018],\n [0.0011],\n [0.0014],\n [0.3143],\n [0.0236],\n [0.0252],\n [0.0010],\n [0.0022],\n [0.0041],\n [0.0105],\n [0.1520],\n [0.0063],\n [0.0197],\n [0.0082],\n [0.0038],\n [0.0533],\n [0.1163],\n [0.0032],\n [0.1304],\n [0.0058],\n [0.0045],\n [0.0008],\n [0.1758],\n [0.0078],\n [0.0282],\n [0.1256],\n [0.0021],\n [0.0388],\n [0.0826],\n [0.0185],\n [0.0009],\n [0.0152],\n 
[0.4050],\n [0.0030],\n [0.0017],\n [0.1281],\n [0.0424],\n [0.0090],\n [0.0404],\n [0.2027],\n [0.0005],\n [0.0053],\n [0.1841],\n [0.3782],\n [0.0047],\n [0.0025],\n [0.0432],\n [0.1070],\n [0.0010],\n [0.0069],\n [0.0033],\n [0.3327],\n [0.0852],\n [0.0135],\n [0.0007],\n [0.0109],\n [0.1961],\n [0.0393],\n [0.0013],\n [0.1051],\n [0.0078],\n [0.0324],\n [0.0065],\n [0.1976],\n [0.1738],\n [0.0195],\n [0.1941],\n [0.0088],\n [0.1793],\n [0.0061],\n [0.4261],\n [0.0122],\n [0.0733],\n [0.0026],\n [0.0009],\n [0.0436],\n [0.0141],\n [0.0124],\n [0.0020],\n [0.0048],\n [0.0381],\n [0.0009],\n [0.0044],\n [0.0014],\n [0.0010],\n [0.0008],\n [0.2416],\n [0.0177],\n [0.0015],\n [0.0041],\n [0.0554],\n [0.3380],\n [0.0284],\n [0.0093],\n [0.0200],\n [0.0057],\n [0.2006],\n [0.0126],\n [0.0063],\n [0.0072],\n [0.0071],\n [0.0981],\n [0.0008],\n [0.0116],\n [0.1803],\n [0.1175],\n [0.0010],\n [0.0610],\n [0.0026],\n [0.1416],\n [0.0168],\n [0.0973],\n [0.0457],\n [0.0012],\n [0.0016],\n [0.0015],\n [0.3407],\n [0.1155],\n [0.0035],\n [0.0010],\n [0.1799],\n [0.1241],\n [0.1657],\n [0.0018],\n [0.1940],\n [0.0171],\n [0.0075],\n [0.0230],\n [0.0113],\n [0.0342],\n [0.0031],\n [0.0369],\n [0.0262],\n [0.0018],\n [0.0224],\n [0.0014],\n [0.0081],\n [0.4363],\n [0.0562],\n [0.0056],\n [0.0754],\n [0.0028],\n [0.0582],\n [0.0100],\n [0.1473],\n [0.0016],\n [0.0013],\n [0.2620],\n [0.0646],\n [0.0910],\n [0.3646],\n [0.0181],\n [0.0146],\n [0.0076],\n [0.0009],\n [0.0363],\n [0.0086],\n [0.0336],\n [0.0206],\n [0.0374],\n [0.0212],\n [0.2859],\n [0.0384],\n [0.0026],\n [0.0016],\n [0.0009],\n [0.0519],\n [0.0054],\n [0.0392],\n [0.0225],\n [0.0382],\n [0.0016],\n [0.0075],\n [0.0619],\n [0.0040],\n [0.0810],\n [0.0372],\n [0.0376],\n [0.0037],\n [0.0028],\n [0.3910],\n [0.0017],\n [0.0084],\n [0.0298],\n [0.0160],\n [0.0112],\n [0.0024],\n [0.0008],\n [0.0019],\n [0.0181],\n [0.0624],\n [0.0127],\n [0.2768],\n [0.4461],\n [0.0224],\n [0.0015],\n [0.0200],\n [0.0274],\n 
[0.1512],\n [0.0161],\n [0.0005],\n [0.0009],\n [0.0028],\n [0.0012],\n [0.0132],\n [0.0028],\n [0.2346],\n [0.0088],\n [0.0227],\n [0.0015],\n [0.0187],\n [0.1476],\n [0.0069],\n [0.0096],\n [0.1176],\n [0.1304],\n [0.0643],\n [0.3622],\n [0.0078],\n [0.2245],\n [0.3165],\n [0.0477],\n [0.0022],\n [0.0203],\n [0.1160],\n [0.0148],\n [0.1549],\n [0.0316],\n [0.0011],\n [0.0141],\n [0.1498],\n [0.0011],\n [0.0011],\n [0.3146],\n [0.1083],\n [0.0908],\n [0.0128],\n [0.0031],\n [0.0038],\n [0.0525],\n [0.3349],\n [0.1460],\n [0.0631],\n [0.0054],\n [0.2970],\n [0.0336],\n [0.0166],\n [0.0312],\n [0.0037],\n [0.0038],\n [0.0008],\n [0.1529],\n [0.0008],\n [0.0544],\n [0.0013],\n [0.0021],\n [0.0079],\n [0.1678],\n [0.0127],\n [0.0026],\n [0.0506],\n [0.0042],\n [0.3334],\n [0.0022],\n [0.0021],\n [0.2290],\n [0.0904],\n [0.0032],\n [0.2832],\n [0.0484],\n [0.0022],\n [0.0026],\n [0.3545],\n [0.0048],\n [0.0331],\n [0.3835],\n [0.0400],\n [0.0132],\n [0.0342],\n [0.0311],\n [0.1180],\n [0.4113],\n [0.0278],\n [0.0277],\n [0.0400],\n [0.0968],\n [0.0736],\n [0.0176],\n [0.0696],\n [0.1483],\n [0.0007],\n [0.0197],\n [0.0305],\n [0.0125],\n [0.4180],\n [0.1870],\n [0.0174],\n [0.0603],\n [0.0042],\n [0.3729],\n [0.0069],\n [0.0466],\n [0.0006],\n [0.0530],\n [0.0183],\n [0.0061],\n [0.0061],\n [0.0185],\n [0.0021],\n [0.0099],\n [0.0007],\n [0.0121],\n [0.0504],\n [0.0619],\n [0.0373],\n [0.0078],\n [0.0366],\n [0.0357],\n [0.0443],\n [0.1825],\n [0.3996],\n [0.0054],\n [0.0040],\n [0.1632],\n [0.0247],\n [0.0019],\n [0.0141],\n [0.0060],\n [0.1078],\n [0.0698],\n [0.1844],\n [0.1288],\n [0.0308],\n [0.1057],\n [0.0009],\n [0.0100],\n [0.0234],\n [0.1900],\n [0.0010],\n [0.0231],\n [0.0007],\n [0.2147],\n [0.0048],\n [0.0733],\n [0.0012],\n [0.2775],\n [0.0033],\n [0.1134],\n [0.0070],\n [0.0424],\n [0.0055],\n [0.0451],\n [0.2994],\n [0.0181],\n [0.0035],\n [0.1174],\n [0.0287],\n [0.1288],\n [0.0613],\n [0.0007],\n [0.0007],\n [0.0083],\n [0.0365],\n [0.0017],\n 
[0.2397],\n [0.0007],\n [0.0196],\n [0.0065],\n [0.0519],\n [0.0722],\n [0.0098],\n [0.0986],\n [0.0025],\n [0.1867],\n [0.0046],\n [0.0015],\n [0.2836],\n [0.0665],\n [0.0021],\n [0.0005],\n [0.0285],\n [0.0281],\n [0.1217],\n [0.0330],\n [0.0361],\n [0.0177],\n [0.0046],\n [0.0041],\n [0.0175],\n [0.0092],\n [0.0877],\n [0.2109],\n [0.0267],\n [0.1180],\n [0.0670],\n [0.4255],\n [0.4027],\n [0.3366],\n [0.0052],\n [0.0009],\n [0.3857],\n [0.0031],\n [0.0013],\n [0.0011],\n [0.0379],\n [0.0017],\n [0.0236],\n [0.0153],\n [0.0518],\n [0.0456],\n [0.1873],\n [0.3320],\n [0.0881],\n [0.0009],\n [0.0015],\n [0.0027],\n [0.0304],\n [0.0332],\n [0.0036],\n [0.0043],\n [0.0409],\n [0.4144],\n [0.0087],\n [0.2914],\n [0.0116],\n [0.0008],\n [0.0028],\n [0.0103],\n [0.0032],\n [0.0606],\n [0.1001],\n [0.0028],\n [0.0397],\n [0.3626],\n [0.0158],\n [0.0331],\n [0.0401],\n [0.0182],\n [0.0125],\n [0.1572],\n [0.0010],\n [0.0937],\n [0.3776],\n [0.2067],\n [0.0914],\n [0.3666],\n [0.3347],\n [0.1722],\n [0.1384],\n [0.1152],\n [0.1753],\n [0.0009],\n [0.0023],\n [0.0088],\n [0.0006],\n [0.0006],\n [0.0313],\n [0.0413],\n [0.0022],\n [0.0413],\n [0.0059],\n [0.1054],\n [0.0024],\n [0.0173],\n [0.0426],\n [0.2546],\n [0.2232],\n [0.0026],\n [0.0045],\n [0.0008],\n [0.2471],\n [0.0035],\n [0.2115],\n [0.2262],\n [0.0433],\n [0.0039],\n [0.0047],\n [0.0013],\n [0.3367],\n [0.1733],\n [0.4315],\n [0.0075],\n [0.0182],\n [0.2763],\n [0.1106],\n [0.0510],\n [0.0907],\n [0.0039],\n [0.0987],\n [0.0239],\n [0.1644],\n [0.0162],\n [0.0271],\n [0.0608],\n [0.0351],\n [0.0014],\n [0.0221],\n [0.3912],\n [0.1440],\n [0.0051],\n [0.0015],\n [0.2073],\n [0.0235],\n [0.0168],\n [0.0090],\n [0.1056],\n [0.0036],\n [0.0452],\n [0.0396],\n [0.0070],\n [0.0314],\n [0.3853],\n [0.0039],\n [0.0184],\n [0.2150],\n [0.0013],\n [0.0055],\n [0.0084],\n [0.0015],\n [0.0190],\n [0.0060],\n [0.0507],\n [0.0007],\n [0.0113],\n [0.0851],\n [0.3850],\n [0.0012],\n [0.0245],\n [0.0066],\n [0.0069],\n 
[0.0040],\n [0.0183],\n [0.0043],\n [0.0529],\n [0.0398],\n [0.0887],\n [0.0426],\n [0.0927],\n [0.0404],\n [0.0049],\n [0.0056],\n [0.0024],\n [0.0739],\n [0.0359],\n [0.0104],\n [0.0081],\n [0.0023],\n [0.0747],\n [0.0011],\n [0.4100],\n [0.0068],\n [0.0268],\n [0.1948],\n [0.0013],\n [0.0018],\n [0.0867],\n [0.0346],\n [0.0231],\n [0.0455],\n [0.0044],\n [0.0365],\n [0.0013],\n [0.0245],\n [0.0059],\n [0.0014],\n [0.0257],\n [0.0010],\n [0.0010],\n [0.0131],\n [0.3990],\n [0.0200],\n [0.0078],\n [0.1217],\n [0.0110],\n [0.0064],\n [0.0219],\n [0.0580],\n [0.1824],\n [0.0020],\n [0.0986],\n [0.2947],\n [0.1087],\n [0.1563],\n [0.0122],\n [0.0784],\n [0.0097],\n [0.0097],\n [0.0347],\n [0.0263],\n [0.0613],\n [0.2907],\n [0.3285],\n [0.0896],\n [0.1502],\n [0.0061],\n [0.0037],\n [0.0282],\n [0.0116],\n [0.0323],\n [0.0576],\n [0.0054],\n [0.0008],\n [0.0200],\n [0.1212],\n [0.0052],\n [0.1496],\n [0.0065],\n [0.0199],\n [0.0065],\n [0.0010],\n [0.1536],\n [0.0076],\n [0.0191],\n [0.0691],\n [0.0187],\n [0.0007],\n [0.0027],\n [0.0162],\n [0.0440],\n [0.1624],\n [0.0502],\n [0.3496],\n [0.0270],\n [0.0034],\n [0.0337],\n [0.3247],\n [0.0274],\n [0.0010],\n [0.2565],\n [0.0099],\n [0.0126],\n [0.0092],\n [0.0546],\n [0.0139],\n [0.0238],\n [0.1364],\n [0.0246],\n [0.0183],\n [0.0558],\n [0.4423],\n [0.1326],\n [0.0059],\n [0.0229],\n [0.0692],\n [0.0944],\n [0.0022],\n [0.0017],\n [0.4130],\n [0.0013],\n [0.0037],\n [0.0071],\n [0.1049],\n [0.0774],\n [0.0569],\n [0.2711],\n [0.0290],\n [0.3081],\n [0.0848],\n [0.0078],\n [0.0015],\n [0.0046],\n [0.3030],\n [0.0093],\n [0.0481],\n [0.0931],\n [0.0174],\n [0.0007],\n [0.0695],\n [0.1172],\n [0.2178],\n [0.1137],\n [0.1141],\n [0.0008],\n [0.2754],\n [0.0008],\n [0.0167],\n [0.0398],\n [0.3444],\n [0.0089],\n [0.2858],\n [0.0251],\n [0.0016],\n [0.0993],\n [0.0009]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
# Predict stroke probabilities for the held-out test set.
y_pred = model(fTest)
# NOTE(review): this prints the full (N, 1) tensor — consider displaying only a
# slice; .data bypasses autograd tracking for display (prefer .detach() in new code).
print("predicted Y value: ", y_pred.data)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 87,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"The accuracy is 0.9480651731160896\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# FIX: the original computed np.argmax(..., axis=1) over the (N, 1) probability
# column, which is always index 0 — so the reported "accuracy" was merely the
# fraction of negative (no-stroke) labels, not model performance. A single
# sigmoid output must be thresholded at 0.5 to produce class predictions.
predicted_classes = (y_pred.detach().numpy().ravel() >= 0.5).astype(np.float32)
print ("The accuracy is", accuracy_score(tTest, predicted_classes))
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 83,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"torch.save(model, 'stroke.pkl')"
|
|
]
|
|
}
|
|
]
|
|
} |