"
],
"text/plain": [
" PassengerId Survived Pclass \\\n",
"0 1 0 3 \n",
"1 2 1 1 \n",
"2 3 1 3 \n",
"3 4 1 1 \n",
"4 5 0 3 \n",
"\n",
" Name Sex Age SibSp \\\n",
"0 Braund, Mr. Owen Harris male 0.271174 1 \n",
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 0.472229 1 \n",
"2 Heikkinen, Miss. Laina female 0.321438 0 \n",
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 0.434531 1 \n",
"4 Allen, Mr. William Henry male 0.434531 0 \n",
"\n",
" Parch Ticket Fare Cabin Embarked \n",
"0 0 A/5 21171 0.014151 NaN S \n",
"1 0 PC 17599 0.139136 C85 C \n",
"2 0 STON/O2. 3101282 0.015469 NaN S \n",
"3 0 113803 0.103644 C123 S \n",
"4 0 373450 0.015713 NaN S "
]
},
"execution_count": 230,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "e6ffda37",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"df = pd.read_csv(\"train.csv\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f7c33a0",
"metadata": {},
"outputs": [],
"source": [
"# e19191c5.uam.onmicrosoft.com@emea.teams.ms"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "54dd7eaa",
"metadata": {},
"source": [
"## lab 5 ml"
]
},
{
"cell_type": "code",
"execution_count": 231,
"id": "ec55ac92",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',\n",
" 'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],\n",
" dtype='object')\n"
]
}
],
"source": [
"#data\n",
"cols = df.columns\n",
"print(cols)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40225042",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 232,
"id": "11850862",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"from keras.utils import to_categorical\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 259,
"id": "cfecc11c",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, input_dim):\n",
" super(Model, self).__init__()\n",
" self.layer1 = nn.Linear(input_dim, 50)\n",
" self.layer2 = nn.Linear(50, 20)\n",
" self.layer3 = nn.Linear(20, 2)\n",
" \n",
" def forward(self, x):\n",
" x = F.relu(self.layer1(x))\n",
" x = F.relu(self.layer2(x))\n",
" x = F.softmax(self.layer3(x))\n",
" \n",
" return x\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 235,
"id": "0af12074",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_7802/1323642195.py:6: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" X['Sex'].replace(['female', 'male'], [0,1], inplace=True)\n"
]
},
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
"
\n",
"
\n",
"
Pclass
\n",
"
Sex
\n",
"
Age
\n",
"
SibSp
\n",
"
Fare
\n",
"
\n",
" \n",
" \n",
"
\n",
"
1
\n",
"
1
\n",
"
0
\n",
"
0.472229
\n",
"
1
\n",
"
0.139136
\n",
"
\n",
"
\n",
"
3
\n",
"
1
\n",
"
0
\n",
"
0.434531
\n",
"
1
\n",
"
0.103644
\n",
"
\n",
"
\n",
"
6
\n",
"
1
\n",
"
1
\n",
"
0.673285
\n",
"
0
\n",
"
0.101229
\n",
"
\n",
"
\n",
"
10
\n",
"
3
\n",
"
0
\n",
"
0.044986
\n",
"
1
\n",
"
0.032596
\n",
"
\n",
"
\n",
"
11
\n",
"
1
\n",
"
0
\n",
"
0.723549
\n",
"
0
\n",
"
0.051822
\n",
"
\n",
"
\n",
"
...
\n",
"
...
\n",
"
...
\n",
"
...
\n",
"
...
\n",
"
...
\n",
"
\n",
"
\n",
"
871
\n",
"
1
\n",
"
0
\n",
"
0.585323
\n",
"
1
\n",
"
0.102579
\n",
"
\n",
"
\n",
"
872
\n",
"
1
\n",
"
1
\n",
"
0.409399
\n",
"
0
\n",
"
0.009759
\n",
"
\n",
"
\n",
"
879
\n",
"
1
\n",
"
0
\n",
"
0.698417
\n",
"
0
\n",
"
0.162314
\n",
"
\n",
"
\n",
"
887
\n",
"
1
\n",
"
0
\n",
"
0.233476
\n",
"
0
\n",
"
0.058556
\n",
"
\n",
"
\n",
"
889
\n",
"
1
\n",
"
1
\n",
"
0.321438
\n",
"
0
\n",
"
0.058556
\n",
"
\n",
" \n",
"
\n",
"
183 rows × 5 columns
\n",
"
"
],
"text/plain": [
" Pclass Sex Age SibSp Fare\n",
"1 1 0 0.472229 1 0.139136\n",
"3 1 0 0.434531 1 0.103644\n",
"6 1 1 0.673285 0 0.101229\n",
"10 3 0 0.044986 1 0.032596\n",
"11 1 0 0.723549 0 0.051822\n",
".. ... ... ... ... ...\n",
"871 1 0 0.585323 1 0.102579\n",
"872 1 1 0.409399 0 0.009759\n",
"879 1 0 0.698417 0 0.162314\n",
"887 1 0 0.233476 0 0.058556\n",
"889 1 1 0.321438 0 0.058556\n",
"\n",
"[183 rows x 5 columns]"
]
},
"execution_count": 235,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = df.dropna()\n",
"X = df[['Pclass', 'Sex', 'Age','SibSp', 'Fare']]\n",
"Y = df[['Survived']]\n",
"\n",
"# X.loc[:,'Age'] = X.loc[:,'Age'].fillna(X['Age'].mean())\n",
"X['Sex'].replace(['female', 'male'], [0,1], inplace=True)\n",
"\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 236,
"id": "591bfb44",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1 1 0 1 1 1 1 0 1 0 0 1 0 1 0 0 1 0 0 0 1 0 1 0 0 0 1 0 0 0 1 1 1 1 0 1 1\n",
" 1 1 1 0 1 0 0 1 0 0 1 1 0 1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 1 0 1 1 1 1\n",
" 1 1 1 0 1 1 1 1 1 1 0 1 0 1 1 0 1 0 1 0 1 1 1 0 0 1 0 1 0 1 0 1 1 1 0 1 1\n",
" 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 0 0 0 1 1 1 1 0 0 1 1 1 1 1\n",
" 0 1 1 1 1 1 0 1 0 0 1 1 1 1 0 1 1 0 0 1 1 0 1 1 1 1 1 1 1 0 1 0 1 1 1]\n"
]
}
],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"Y = np.ravel(Y)\n",
"encoder = LabelEncoder()\n",
"encoder.fit(Y)\n",
"Y = encoder.transform(Y)\n",
"print(Y)"
]
},
{
"cell_type": "code",
"execution_count": 237,
"id": "8a7cac39",
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, Y_train, Y_test = train_test_split(X,Y, random_state=42, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 260,
"id": "93454e63",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"Xt = torch.tensor(X_train.values, dtype = torch.float32)\n",
"Yt = torch.tensor(Y_train, dtype=torch.long)\n",
"# .reshape(-1,1)\n",
"# Yt = Y_train"
]
},
{
"cell_type": "code",
"execution_count": 261,
"id": "3aac198b",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([137])"
]
},
"execution_count": 261,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Yt.shape"
]
},
{
"cell_type": "code",
"execution_count": 262,
"id": "27591bf8",
"metadata": {},
"outputs": [],
"source": [
"model = Model(Xt.shape[1])\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"epochs = 500\n",
"\n",
"def print_(loss):\n",
" print (\"The loss calculated: \", loss)\n"
]
},
{
"cell_type": "code",
"execution_count": 263,
"id": "9d700f25",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch # 1\n",
"The loss calculated: 0.6927047371864319\n",
"Epoch # 2\n",
"The loss calculated: 0.6760580539703369\n",
"Epoch # 3\n",
"The loss calculated: 0.6577760577201843\n",
"Epoch # 4\n",
"The loss calculated: 0.6410418152809143\n",
"Epoch # 5\n",
"The loss calculated: 0.6274042725563049\n",
"Epoch # 6\n",
"The loss calculated: 0.6176177263259888\n",
"Epoch # 7\n",
"The loss calculated: 0.6114543676376343\n",
"Epoch # 8\n",
"The loss calculated: 0.6079199314117432\n",
"Epoch # 9\n",
"The loss calculated: 0.6057404279708862\n",
"Epoch # 10\n",
"The loss calculated: 0.6039658188819885\n",
"Epoch # 11\n",
"The loss calculated: 0.6018784046173096\n",
"Epoch # 12\n",
"The loss calculated: 0.5988859534263611\n",
"Epoch # 13\n",
"The loss calculated: 0.5944192409515381\n",
"Epoch # 14\n",
"The loss calculated: 0.58795166015625\n",
"Epoch # 15\n",
"The loss calculated: 0.5793240666389465\n",
"Epoch # 16\n",
"The loss calculated: 0.569113552570343\n",
"Epoch # 17\n",
"The loss calculated: 0.5591343641281128\n",
"Epoch # 18\n",
"The loss calculated: 0.5525994300842285\n",
"Epoch # 19\n",
"The loss calculated: 0.549091637134552\n",
"Epoch # 20\n",
"The loss calculated: 0.5478854775428772\n",
"Epoch # 21\n",
"The loss calculated: 0.5459576845169067\n",
"Epoch # 22\n",
"The loss calculated: 0.5430701971054077\n",
"Epoch # 23\n",
"The loss calculated: 0.5398197174072266\n",
"Epoch # 24\n",
"The loss calculated: 0.5366366505622864\n",
"Epoch # 25\n",
"The loss calculated: 0.5338087677955627\n",
"Epoch # 26\n",
"The loss calculated: 0.5315443873405457\n",
"Epoch # 27\n",
"The loss calculated: 0.5298702716827393\n",
"Epoch # 28\n",
"The loss calculated: 0.5285016894340515\n",
"Epoch # 29\n",
"The loss calculated: 0.5272928476333618\n",
"Epoch # 30\n",
"The loss calculated: 0.5261989235877991\n",
"Epoch # 31\n",
"The loss calculated: 0.5251137018203735\n",
"Epoch # 32\n",
"The loss calculated: 0.5238412618637085\n",
"Epoch # 33\n",
"The loss calculated: 0.5226505398750305\n",
"Epoch # 34\n",
"The loss calculated: 0.5215187072753906\n",
"Epoch # 35\n",
"The loss calculated: 0.5204036235809326\n",
"Epoch # 36\n",
"The loss calculated: 0.5194926857948303\n",
"Epoch # 37\n",
"The loss calculated: 0.5188320875167847\n",
"Epoch # 38\n",
"The loss calculated: 0.5182497501373291\n",
"Epoch # 39\n",
"The loss calculated: 0.5176616907119751\n",
"Epoch # 40\n",
"The loss calculated: 0.5170402526855469\n",
"Epoch # 41\n",
"The loss calculated: 0.5162948369979858\n",
"Epoch # 42\n",
"The loss calculated: 0.5155003070831299\n",
"Epoch # 43\n",
"The loss calculated: 0.51481693983078\n",
"Epoch # 44\n",
"The loss calculated: 0.5142836570739746\n",
"Epoch # 45\n",
"The loss calculated: 0.5137770771980286\n",
"Epoch # 46\n",
"The loss calculated: 0.5132609009742737\n",
"Epoch # 47\n",
"The loss calculated: 0.5126983523368835\n",
"Epoch # 48\n",
"The loss calculated: 0.5120936036109924\n",
"Epoch # 49\n",
"The loss calculated: 0.5116094350814819\n",
"Epoch # 50\n",
"The loss calculated: 0.5111839175224304\n",
"Epoch # 51\n",
"The loss calculated: 0.5106979608535767\n",
"Epoch # 52\n",
"The loss calculated: 0.5101208686828613\n",
"Epoch # 53\n",
"The loss calculated: 0.5095392465591431\n",
"Epoch # 54\n",
"The loss calculated: 0.5090041756629944\n",
"Epoch # 55\n",
"The loss calculated: 0.5083613395690918\n",
"Epoch # 56\n",
"The loss calculated: 0.5075969099998474\n",
"Epoch # 57\n",
"The loss calculated: 0.5067813992500305\n",
"Epoch # 58\n",
"The loss calculated: 0.5060149431228638\n",
"Epoch # 59\n",
"The loss calculated: 0.5052304863929749\n",
"Epoch # 60\n",
"The loss calculated: 0.5044183135032654\n",
"Epoch # 61\n",
"The loss calculated: 0.5035461187362671\n",
"Epoch # 62\n",
"The loss calculated: 0.5025045871734619\n",
"Epoch # 63\n",
"The loss calculated: 0.5014879107475281\n",
"Epoch # 64\n",
"The loss calculated: 0.5006436705589294\n",
"Epoch # 65\n",
"The loss calculated: 0.499641090631485\n",
"Epoch # 66\n",
"The loss calculated: 0.4986647367477417\n",
"Epoch # 67\n",
"The loss calculated: 0.497800350189209\n",
"Epoch # 68\n",
"The loss calculated: 0.49712076783180237\n",
"Epoch # 69\n",
"The loss calculated: 0.49643078446388245\n",
"Epoch # 70\n",
"The loss calculated: 0.4957447350025177\n",
"Epoch # 71\n",
"The loss calculated: 0.4950644075870514\n",
"Epoch # 72\n",
"The loss calculated: 0.4944438636302948\n",
"Epoch # 73\n",
"The loss calculated: 0.4937107563018799\n",
"Epoch # 74\n",
"The loss calculated: 0.49320393800735474\n",
"Epoch # 75\n",
"The loss calculated: 0.49250030517578125\n",
"Epoch # 76\n",
"The loss calculated: 0.49141865968704224\n",
"Epoch # 77\n",
"The loss calculated: 0.49071067571640015\n",
"Epoch # 78\n",
"The loss calculated: 0.4899919629096985\n",
"Epoch # 79\n",
"The loss calculated: 0.48904943466186523\n",
"Epoch # 80\n",
"The loss calculated: 0.4885300099849701\n",
"Epoch # 81\n",
"The loss calculated: 0.48774540424346924\n",
"Epoch # 82\n",
"The loss calculated: 0.48720788955688477\n",
"Epoch # 83\n",
"The loss calculated: 0.4868374466896057\n",
"Epoch # 84\n",
"The loss calculated: 0.48623406887054443\n",
"Epoch # 85\n",
"The loss calculated: 0.48583683371543884\n",
"Epoch # 86\n",
"The loss calculated: 0.48502254486083984\n",
"Epoch # 87\n",
"The loss calculated: 0.4844677746295929\n",
"Epoch # 88\n",
"The loss calculated: 0.48361340165138245\n",
"Epoch # 89\n",
"The loss calculated: 0.4827542304992676\n",
"Epoch # 90\n",
"The loss calculated: 0.4817808270454407\n",
"Epoch # 91\n",
"The loss calculated: 0.4809269607067108\n",
"Epoch # 92\n",
"The loss calculated: 0.4804893136024475\n",
"Epoch # 93\n",
"The loss calculated: 0.48043856024742126\n",
"Epoch # 94\n",
"The loss calculated: 0.4801830053329468\n",
"Epoch # 95\n",
"The loss calculated: 0.479977011680603\n",
"Epoch # 96\n",
"The loss calculated: 0.47945544123649597\n",
"Epoch # 97\n",
"The loss calculated: 0.47897064685821533\n",
"Epoch # 98\n",
"The loss calculated: 0.4786403775215149\n",
"Epoch # 99\n",
"The loss calculated: 0.47828078269958496\n",
"Epoch # 100\n",
"The loss calculated: 0.47804537415504456\n",
"Epoch # 101\n",
"The loss calculated: 0.4777425527572632\n",
"Epoch # 102\n",
"The loss calculated: 0.4773750603199005\n",
"Epoch # 103\n",
"The loss calculated: 0.4768853187561035\n",
"Epoch # 104\n",
"The loss calculated: 0.4766947627067566\n",
"Epoch # 105\n",
"The loss calculated: 0.47633618116378784\n",
"Epoch # 106\n",
"The loss calculated: 0.47610870003700256\n",
"Epoch # 107\n",
"The loss calculated: 0.47584590315818787\n",
"Epoch # 108\n",
"The loss calculated: 0.47565311193466187\n",
"Epoch # 109\n",
"The loss calculated: 0.475361168384552\n",
"Epoch # 110\n",
"The loss calculated: 0.475079208612442\n",
"Epoch # 111\n",
"The loss calculated: 0.47482433915138245\n",
"Epoch # 112\n",
"The loss calculated: 0.47465214133262634\n",
"Epoch # 113\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_7802/3372075492.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" x = F.softmax(self.layer3(x))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The loss calculated: 0.4745003283023834\n",
"Epoch # 114\n",
"The loss calculated: 0.47428470849990845\n",
"Epoch # 115\n",
"The loss calculated: 0.47402113676071167\n",
"Epoch # 116\n",
"The loss calculated: 0.4738253355026245\n",
"Epoch # 117\n",
"The loss calculated: 0.47366538643836975\n",
"Epoch # 118\n",
"The loss calculated: 0.47345176339149475\n",
"Epoch # 119\n",
"The loss calculated: 0.47328999638557434\n",
"Epoch # 120\n",
"The loss calculated: 0.47304701805114746\n",
"Epoch # 121\n",
"The loss calculated: 0.47283679246902466\n",
"Epoch # 122\n",
"The loss calculated: 0.47269734740257263\n",
"Epoch # 123\n",
"The loss calculated: 0.47256502509117126\n",
"Epoch # 124\n",
"The loss calculated: 0.4723707437515259\n",
"Epoch # 125\n",
"The loss calculated: 0.4721546471118927\n",
"Epoch # 126\n",
"The loss calculated: 0.4719236493110657\n",
"Epoch # 127\n",
"The loss calculated: 0.4718014895915985\n",
"Epoch # 128\n",
"The loss calculated: 0.4715701937675476\n",
"Epoch # 129\n",
"The loss calculated: 0.47162505984306335\n",
"Epoch # 130\n",
"The loss calculated: 0.47140219807624817\n",
"Epoch # 131\n",
"The loss calculated: 0.47120794653892517\n",
"Epoch # 132\n",
"The loss calculated: 0.47121524810791016\n",
"Epoch # 133\n",
"The loss calculated: 0.4708421230316162\n",
"Epoch # 134\n",
"The loss calculated: 0.47080597281455994\n",
"Epoch # 135\n",
"The loss calculated: 0.470735102891922\n",
"Epoch # 136\n",
"The loss calculated: 0.47046154737472534\n",
"Epoch # 137\n",
"The loss calculated: 0.4704940617084503\n",
"Epoch # 138\n",
"The loss calculated: 0.4704982340335846\n",
"Epoch # 139\n",
"The loss calculated: 0.470112144947052\n",
"Epoch # 140\n",
"The loss calculated: 0.4701041877269745\n",
"Epoch # 141\n",
"The loss calculated: 0.47008904814720154\n",
"Epoch # 142\n",
"The loss calculated: 0.4698803722858429\n",
"Epoch # 143\n",
"The loss calculated: 0.46982747316360474\n",
"Epoch # 144\n",
"The loss calculated: 0.469696044921875\n",
"Epoch # 145\n",
"The loss calculated: 0.46962815523147583\n",
"Epoch # 146\n",
"The loss calculated: 0.469440758228302\n",
"Epoch # 147\n",
"The loss calculated: 0.46939632296562195\n",
"Epoch # 148\n",
"The loss calculated: 0.4695526957511902\n",
"Epoch # 149\n",
"The loss calculated: 0.4697006046772003\n",
"Epoch # 150\n",
"The loss calculated: 0.4692654609680176\n",
"Epoch # 151\n",
"The loss calculated: 0.4700072407722473\n",
"Epoch # 152\n",
"The loss calculated: 0.4690340757369995\n",
"Epoch # 153\n",
"The loss calculated: 0.47001826763153076\n",
"Epoch # 154\n",
"The loss calculated: 0.46880584955215454\n",
"Epoch # 155\n",
"The loss calculated: 0.46919724345207214\n",
"Epoch # 156\n",
"The loss calculated: 0.4687418043613434\n",
"Epoch # 157\n",
"The loss calculated: 0.4687948226928711\n",
"Epoch # 158\n",
"The loss calculated: 0.46873044967651367\n",
"Epoch # 159\n",
"The loss calculated: 0.46848490834236145\n",
"Epoch # 160\n",
"The loss calculated: 0.4686104953289032\n",
"Epoch # 161\n",
"The loss calculated: 0.4683172404766083\n",
"Epoch # 162\n",
"The loss calculated: 0.46831050515174866\n",
"Epoch # 163\n",
"The loss calculated: 0.46828699111938477\n",
"Epoch # 164\n",
"The loss calculated: 0.46824583411216736\n",
"Epoch # 165\n",
"The loss calculated: 0.468075156211853\n",
"Epoch # 166\n",
"The loss calculated: 0.46814292669296265\n",
"Epoch # 167\n",
"The loss calculated: 0.46796467900276184\n",
"Epoch # 168\n",
"The loss calculated: 0.46802079677581787\n",
"Epoch # 169\n",
"The loss calculated: 0.46778491139411926\n",
"Epoch # 170\n",
"The loss calculated: 0.4679405093193054\n",
"Epoch # 171\n",
"The loss calculated: 0.46800506114959717\n",
"Epoch # 172\n",
"The loss calculated: 0.467818945646286\n",
"Epoch # 173\n",
"The loss calculated: 0.4678487181663513\n",
"Epoch # 174\n",
"The loss calculated: 0.46776196360588074\n",
"Epoch # 175\n",
"The loss calculated: 0.46756404638290405\n",
"Epoch # 176\n",
"The loss calculated: 0.4682294726371765\n",
"Epoch # 177\n",
"The loss calculated: 0.46777990460395813\n",
"Epoch # 178\n",
"The loss calculated: 0.4677632451057434\n",
"Epoch # 179\n",
"The loss calculated: 0.46777427196502686\n",
"Epoch # 180\n",
"The loss calculated: 0.46746954321861267\n",
"Epoch # 181\n",
"The loss calculated: 0.4676474630832672\n",
"Epoch # 182\n",
"The loss calculated: 0.46711796522140503\n",
"Epoch # 183\n",
"The loss calculated: 0.4677950441837311\n",
"Epoch # 184\n",
"The loss calculated: 0.46725085377693176\n",
"Epoch # 185\n",
"The loss calculated: 0.4676659107208252\n",
"Epoch # 186\n",
"The loss calculated: 0.4672679901123047\n",
"Epoch # 187\n",
"The loss calculated: 0.46727195382118225\n",
"Epoch # 188\n",
"The loss calculated: 0.466960608959198\n",
"Epoch # 189\n",
"The loss calculated: 0.46708735823631287\n",
"Epoch # 190\n",
"The loss calculated: 0.4671291708946228\n",
"Epoch # 191\n",
"The loss calculated: 0.46684736013412476\n",
"Epoch # 192\n",
"The loss calculated: 0.4667331576347351\n",
"Epoch # 193\n",
"The loss calculated: 0.46685370802879333\n",
"Epoch # 194\n",
"The loss calculated: 0.4668591618537903\n",
"Epoch # 195\n",
"The loss calculated: 0.46671974658966064\n",
"Epoch # 196\n",
"The loss calculated: 0.46653658151626587\n",
"Epoch # 197\n",
"The loss calculated: 0.46659478545188904\n",
"Epoch # 198\n",
"The loss calculated: 0.4665440022945404\n",
"Epoch # 199\n",
"The loss calculated: 0.4664462208747864\n",
"Epoch # 200\n",
"The loss calculated: 0.466394305229187\n",
"Epoch # 201\n",
"The loss calculated: 0.4665300250053406\n",
"Epoch # 202\n",
"The loss calculated: 0.4664006531238556\n",
"Epoch # 203\n",
"The loss calculated: 0.46651187539100647\n",
"Epoch # 204\n",
"The loss calculated: 0.4662490487098694\n",
"Epoch # 205\n",
"The loss calculated: 0.46683457493782043\n",
"Epoch # 206\n",
"The loss calculated: 0.46636930108070374\n",
"Epoch # 207\n",
"The loss calculated: 0.4663969576358795\n",
"Epoch # 208\n",
"The loss calculated: 0.46641668677330017\n",
"Epoch # 209\n",
"The loss calculated: 0.46628400683403015\n",
"Epoch # 210\n",
"The loss calculated: 0.4664050042629242\n",
"Epoch # 211\n",
"The loss calculated: 0.4661887586116791\n",
"Epoch # 212\n",
"The loss calculated: 0.4660308063030243\n",
"Epoch # 213\n",
"The loss calculated: 0.4661027491092682\n",
"Epoch # 214\n",
"The loss calculated: 0.4660954177379608\n",
"Epoch # 215\n",
"The loss calculated: 0.4658938944339752\n",
"Epoch # 216\n",
"The loss calculated: 0.4660359025001526\n",
"Epoch # 217\n",
"The loss calculated: 0.46567121148109436\n",
"Epoch # 218\n",
"The loss calculated: 0.4657202959060669\n",
"Epoch # 219\n",
"The loss calculated: 0.4657045900821686\n",
"Epoch # 220\n",
"The loss calculated: 0.4655347168445587\n",
"Epoch # 221\n",
"The loss calculated: 0.4654804468154907\n",
"Epoch # 222\n",
"The loss calculated: 0.4656883180141449\n",
"Epoch # 223\n",
"The loss calculated: 0.46542859077453613\n",
"Epoch # 224\n",
"The loss calculated: 0.46529003977775574\n",
"Epoch # 225\n",
"The loss calculated: 0.46543607115745544\n",
"Epoch # 226\n",
"The loss calculated: 0.46531468629837036\n",
"Epoch # 227\n",
"The loss calculated: 0.4653342068195343\n",
"Epoch # 228\n",
"The loss calculated: 0.46527451276779175\n",
"Epoch # 229\n",
"The loss calculated: 0.4652668535709381\n",
"Epoch # 230\n",
"The loss calculated: 0.46513044834136963\n",
"Epoch # 231\n",
"The loss calculated: 0.4650672972202301\n",
"Epoch # 232\n",
"The loss calculated: 0.46511510014533997\n",
"Epoch # 233\n",
"The loss calculated: 0.4647628366947174\n",
"Epoch # 234\n",
"The loss calculated: 0.4647744596004486\n",
"Epoch # 235\n",
"The loss calculated: 0.4648566246032715\n",
"Epoch # 236\n",
"The loss calculated: 0.4646404981613159\n",
"Epoch # 237\n",
"The loss calculated: 0.4645318388938904\n",
"Epoch # 238\n",
"The loss calculated: 0.46459120512008667\n",
"Epoch # 239\n",
"The loss calculated: 0.46454647183418274\n",
"Epoch # 240\n",
"The loss calculated: 0.46439239382743835\n",
"Epoch # 241\n",
"The loss calculated: 0.464549720287323\n",
"Epoch # 242\n",
"The loss calculated: 0.4642981290817261\n",
"Epoch # 243\n",
"The loss calculated: 0.4640815258026123\n",
"Epoch # 244\n",
"The loss calculated: 0.4640815258026123\n",
"Epoch # 245\n",
"The loss calculated: 0.4638811945915222\n",
"Epoch # 246\n",
"The loss calculated: 0.46409285068511963\n",
"Epoch # 247\n",
"The loss calculated: 0.46399882435798645\n",
"Epoch # 248\n",
"The loss calculated: 0.4639054536819458\n",
"Epoch # 249\n",
"The loss calculated: 0.46384960412979126\n",
"Epoch # 250\n",
"The loss calculated: 0.46365633606910706\n",
"Epoch # 251\n",
"The loss calculated: 0.4635387361049652\n",
"Epoch # 252\n",
"The loss calculated: 0.46366339921951294\n",
"Epoch # 253\n",
"The loss calculated: 0.4635831415653229\n",
"Epoch # 254\n",
"The loss calculated: 0.46347707509994507\n",
"Epoch # 255\n",
"The loss calculated: 0.4633452892303467\n",
"Epoch # 256\n",
"The loss calculated: 0.4634377658367157\n",
"Epoch # 257\n",
"The loss calculated: 0.46325498819351196\n",
"Epoch # 258\n",
"The loss calculated: 0.46343502402305603\n",
"Epoch # 259\n",
"The loss calculated: 0.46319177746772766\n",
"Epoch # 260\n",
"The loss calculated: 0.4631631076335907\n",
"Epoch # 261\n",
"The loss calculated: 0.4630383253097534\n",
"Epoch # 262\n",
"The loss calculated: 0.4629758596420288\n",
"Epoch # 263\n",
"The loss calculated: 0.46284860372543335\n",
"Epoch # 264\n",
"The loss calculated: 0.46269962191581726\n",
"Epoch # 265\n",
"The loss calculated: 0.4628857374191284\n",
"Epoch # 266\n",
"The loss calculated: 0.4627268314361572\n",
"Epoch # 267\n",
"The loss calculated: 0.46238410472869873\n",
"Epoch # 268\n",
"The loss calculated: 0.4622679352760315\n",
"Epoch # 269\n",
"The loss calculated: 0.46253955364227295\n",
"Epoch # 270\n",
"The loss calculated: 0.46243607997894287\n",
"Epoch # 271\n",
"The loss calculated: 0.4622651934623718\n",
"Epoch # 272\n",
"The loss calculated: 0.4621260166168213\n",
"Epoch # 273\n",
"The loss calculated: 0.4619852304458618\n",
"Epoch # 274\n",
"The loss calculated: 0.4621600806713104\n",
"Epoch # 275\n",
"The loss calculated: 0.46188268065452576\n",
"Epoch # 276\n",
"The loss calculated: 0.4619770050048828\n",
"Epoch # 277\n",
"The loss calculated: 0.4617985486984253\n",
"Epoch # 278\n",
"The loss calculated: 0.46143385767936707\n",
"Epoch # 279\n",
"The loss calculated: 0.4618164002895355\n",
"Epoch # 280\n",
"The loss calculated: 0.461500883102417\n",
"Epoch # 281\n",
"The loss calculated: 0.4614565372467041\n",
"Epoch # 282\n",
"The loss calculated: 0.4613018035888672\n",
"Epoch # 283\n",
"The loss calculated: 0.4612286388874054\n",
"Epoch # 284\n",
"The loss calculated: 0.4610031545162201\n",
"Epoch # 285\n",
"The loss calculated: 0.4609623849391937\n",
"Epoch # 286\n",
"The loss calculated: 0.4608198404312134\n",
"Epoch # 287\n",
"The loss calculated: 0.46074378490448\n",
"Epoch # 288\n",
"The loss calculated: 0.46068280935287476\n",
"Epoch # 289\n",
"The loss calculated: 0.46061643958091736\n",
"Epoch # 290\n",
"The loss calculated: 0.4604104459285736\n",
"Epoch # 291\n",
"The loss calculated: 0.4607124626636505\n",
"Epoch # 292\n",
"The loss calculated: 0.4607458710670471\n",
"Epoch # 293\n",
"The loss calculated: 0.4601185619831085\n",
"Epoch # 294\n",
"The loss calculated: 0.460267573595047\n",
"Epoch # 295\n",
"The loss calculated: 0.4605766832828522\n",
"Epoch # 296\n",
"The loss calculated: 0.46028855443000793\n",
"Epoch # 297\n",
"The loss calculated: 0.4599803388118744\n",
"Epoch # 298\n",
"The loss calculated: 0.4600617587566376\n",
"Epoch # 299\n",
"The loss calculated: 0.46000462770462036\n",
"Epoch # 300\n",
"The loss calculated: 0.4595383405685425\n",
"Epoch # 301\n",
"The loss calculated: 0.4598424732685089\n",
"Epoch # 302\n",
"The loss calculated: 0.4597552418708801\n",
"Epoch # 303\n",
"The loss calculated: 0.45939505100250244\n",
"Epoch # 304\n",
"The loss calculated: 0.459394633769989\n",
"Epoch # 305\n",
"The loss calculated: 0.4592142403125763\n",
"Epoch # 306\n",
"The loss calculated: 0.4591156244277954\n",
"Epoch # 307\n",
"The loss calculated: 0.4590142071247101\n",
"Epoch # 308\n",
"The loss calculated: 0.45902881026268005\n",
"Epoch # 309\n",
"The loss calculated: 0.4590888023376465\n",
"Epoch # 310\n",
"The loss calculated: 0.45860469341278076\n",
"Epoch # 311\n",
"The loss calculated: 0.45852038264274597\n",
"Epoch # 312\n",
"The loss calculated: 0.4585433900356293\n",
"Epoch # 313\n",
"The loss calculated: 0.4586207866668701\n",
"Epoch # 314\n",
"The loss calculated: 0.45869746804237366\n",
"Epoch # 315\n",
"The loss calculated: 0.4585130214691162\n",
"Epoch # 316\n",
"The loss calculated: 0.45780810713768005\n",
"Epoch # 317\n",
"The loss calculated: 0.4584527313709259\n",
"Epoch # 318\n",
"The loss calculated: 0.4584985375404358\n",
"Epoch # 319\n",
"The loss calculated: 0.4577976167201996\n",
"Epoch # 320\n",
"The loss calculated: 0.4578183591365814\n",
"Epoch # 321\n",
"The loss calculated: 0.45760011672973633\n",
"Epoch # 322\n",
"The loss calculated: 0.4573518931865692\n",
"Epoch # 323\n",
"The loss calculated: 0.45755714178085327\n",
"Epoch # 324\n",
"The loss calculated: 0.4574785828590393\n",
"Epoch # 325\n",
"The loss calculated: 0.4572897255420685\n",
"Epoch # 326\n",
"The loss calculated: 0.45682093501091003\n",
"Epoch # 327\n",
"The loss calculated: 0.4571937322616577\n",
"Epoch # 328\n",
"The loss calculated: 0.45755869150161743\n",
"Epoch # 329\n",
"The loss calculated: 0.45663607120513916\n",
"Epoch # 330\n",
"The loss calculated: 0.4570084810256958\n",
"Epoch # 331\n",
"The loss calculated: 0.45761099457740784\n",
"Epoch # 332\n",
"The loss calculated: 0.456558495759964\n",
"Epoch # 333\n",
"The loss calculated: 0.45620036125183105\n",
"Epoch # 334\n",
"The loss calculated: 0.4563443958759308\n",
"Epoch # 335\n",
"The loss calculated: 0.45647644996643066\n",
"Epoch # 336\n",
"The loss calculated: 0.45592716336250305\n",
"Epoch # 337\n",
"The loss calculated: 0.455634742975235\n",
"Epoch # 338\n",
"The loss calculated: 0.4558946192264557\n",
"Epoch # 339\n",
"The loss calculated: 0.45598289370536804\n",
"Epoch # 340\n",
"The loss calculated: 0.4554951786994934\n",
"Epoch # 341\n",
"The loss calculated: 0.4554195702075958\n",
"Epoch # 342\n",
"The loss calculated: 0.4554871618747711\n",
"Epoch # 343\n",
"The loss calculated: 0.4549509584903717\n",
"Epoch # 344\n",
"The loss calculated: 0.4548693597316742\n",
"Epoch # 345\n",
"The loss calculated: 0.4558226466178894\n",
"Epoch # 346\n",
"The loss calculated: 0.45509448647499084\n",
"Epoch # 347\n",
"The loss calculated: 0.45454123616218567\n",
"Epoch # 348\n",
"The loss calculated: 0.4553173780441284\n",
"Epoch # 349\n",
"The loss calculated: 0.4548755884170532\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch # 350\n",
"The loss calculated: 0.45442134141921997\n",
"Epoch # 351\n",
"The loss calculated: 0.4545627236366272\n",
"Epoch # 352\n",
"The loss calculated: 0.4543512463569641\n",
"Epoch # 353\n",
"The loss calculated: 0.4541962146759033\n",
"Epoch # 354\n",
"The loss calculated: 0.4540751874446869\n",
"Epoch # 355\n",
"The loss calculated: 0.45386749505996704\n",
"Epoch # 356\n",
"The loss calculated: 0.4536762833595276\n",
"Epoch # 357\n",
"The loss calculated: 0.4532167911529541\n",
"Epoch # 358\n",
"The loss calculated: 0.4538520872592926\n",
"Epoch # 359\n",
"The loss calculated: 0.45413821935653687\n",
"Epoch # 360\n",
"The loss calculated: 0.45311087369918823\n",
"Epoch # 361\n",
"The loss calculated: 0.45335227251052856\n",
"Epoch # 362\n",
"The loss calculated: 0.45350611209869385\n",
"Epoch # 363\n",
"The loss calculated: 0.45265665650367737\n",
"Epoch # 364\n",
"The loss calculated: 0.4524100124835968\n",
"Epoch # 365\n",
"The loss calculated: 0.4523312449455261\n",
"Epoch # 366\n",
"The loss calculated: 0.4522554874420166\n",
"Epoch # 367\n",
"The loss calculated: 0.4523703455924988\n",
"Epoch # 368\n",
"The loss calculated: 0.4521876573562622\n",
"Epoch # 369\n",
"The loss calculated: 0.4517895579338074\n",
"Epoch # 370\n",
"The loss calculated: 0.4517730474472046\n",
"Epoch # 371\n",
"The loss calculated: 0.4515615999698639\n",
"Epoch # 372\n",
"The loss calculated: 0.45157772302627563\n",
"Epoch # 373\n",
"The loss calculated: 0.4515098035335541\n",
"Epoch # 374\n",
"The loss calculated: 0.45118868350982666\n",
"Epoch # 375\n",
"The loss calculated: 0.45117509365081787\n",
"Epoch # 376\n",
"The loss calculated: 0.45118534564971924\n",
"Epoch # 377\n",
"The loss calculated: 0.45082926750183105\n",
"Epoch # 378\n",
"The loss calculated: 0.4507909119129181\n",
"Epoch # 379\n",
"The loss calculated: 0.45116591453552246\n",
"Epoch # 380\n",
"The loss calculated: 0.45066720247268677\n",
"Epoch # 381\n",
"The loss calculated: 0.45026636123657227\n",
"Epoch # 382\n",
"The loss calculated: 0.4510788321495056\n",
"Epoch # 383\n",
"The loss calculated: 0.4512375593185425\n",
"Epoch # 384\n",
"The loss calculated: 0.450232595205307\n",
"Epoch # 385\n",
"The loss calculated: 0.44986671209335327\n",
"Epoch # 386\n",
"The loss calculated: 0.4502098262310028\n",
"Epoch # 387\n",
"The loss calculated: 0.4510081112384796\n",
"Epoch # 388\n",
"The loss calculated: 0.4499610960483551\n",
"Epoch # 389\n",
"The loss calculated: 0.44945529103279114\n",
"Epoch # 390\n",
"The loss calculated: 0.45030856132507324\n",
"Epoch # 391\n",
"The loss calculated: 0.4493928849697113\n",
"Epoch # 392\n",
"The loss calculated: 0.4490446448326111\n",
"Epoch # 393\n",
"The loss calculated: 0.4496527910232544\n",
"Epoch # 394\n",
"The loss calculated: 0.44922882318496704\n",
"Epoch # 395\n",
"The loss calculated: 0.4484827220439911\n",
"Epoch # 396\n",
"The loss calculated: 0.44952288269996643\n",
"Epoch # 397\n",
"The loss calculated: 0.4490470588207245\n",
"Epoch # 398\n",
"The loss calculated: 0.44837456941604614\n",
"Epoch # 399\n",
"The loss calculated: 0.44843804836273193\n",
"Epoch # 400\n",
"The loss calculated: 0.44825857877731323\n",
"Epoch # 401\n",
"The loss calculated: 0.4478710889816284\n",
"Epoch # 402\n",
"The loss calculated: 0.4478342533111572\n",
"Epoch # 403\n",
"The loss calculated: 0.44727033376693726\n",
"Epoch # 404\n",
"The loss calculated: 0.4474068582057953\n",
"Epoch # 405\n",
"The loss calculated: 0.4473791718482971\n",
"Epoch # 406\n",
"The loss calculated: 0.4471847414970398\n",
"Epoch # 407\n",
"The loss calculated: 0.44691354036331177\n",
"Epoch # 408\n",
"The loss calculated: 0.44677817821502686\n",
"Epoch # 409\n",
"The loss calculated: 0.4468446969985962\n",
"Epoch # 410\n",
"The loss calculated: 0.4465027153491974\n",
"Epoch # 411\n",
"The loss calculated: 0.44606125354766846\n",
"Epoch # 412\n",
"The loss calculated: 0.44594869017601013\n",
"Epoch # 413\n",
"The loss calculated: 0.4456939101219177\n",
"Epoch # 414\n",
"The loss calculated: 0.445888489484787\n",
"Epoch # 415\n",
"The loss calculated: 0.4455548822879791\n",
"Epoch # 416\n",
"The loss calculated: 0.44548290967941284\n",
"Epoch # 417\n",
"The loss calculated: 0.44544851779937744\n",
"Epoch # 418\n",
"The loss calculated: 0.44522538781166077\n",
"Epoch # 419\n",
"The loss calculated: 0.44501474499702454\n",
"Epoch # 420\n",
"The loss calculated: 0.4449530839920044\n",
"Epoch # 421\n",
"The loss calculated: 0.4445208013057709\n",
"Epoch # 422\n",
"The loss calculated: 0.4444122314453125\n",
"Epoch # 423\n",
"The loss calculated: 0.44473087787628174\n",
"Epoch # 424\n",
"The loss calculated: 0.4442698359489441\n",
"Epoch # 425\n",
"The loss calculated: 0.44399431347846985\n",
"Epoch # 426\n",
"The loss calculated: 0.4437970817089081\n",
"Epoch # 427\n",
"The loss calculated: 0.44364386796951294\n",
"Epoch # 428\n",
"The loss calculated: 0.4437081217765808\n",
"Epoch # 429\n",
"The loss calculated: 0.4436897039413452\n",
"Epoch # 430\n",
"The loss calculated: 0.44336003065109253\n",
"Epoch # 431\n",
"The loss calculated: 0.4430985748767853\n",
"Epoch # 432\n",
"The loss calculated: 0.44310933351516724\n",
"Epoch # 433\n",
"The loss calculated: 0.4428543746471405\n",
"Epoch # 434\n",
"The loss calculated: 0.44258877635002136\n",
"Epoch # 435\n",
"The loss calculated: 0.4427826404571533\n",
"Epoch # 436\n",
"The loss calculated: 0.44258812069892883\n",
"Epoch # 437\n",
"The loss calculated: 0.442533403635025\n",
"Epoch # 438\n",
"The loss calculated: 0.44270434975624084\n",
"Epoch # 439\n",
"The loss calculated: 0.4427698850631714\n",
"Epoch # 440\n",
"The loss calculated: 0.44257086515426636\n",
"Epoch # 441\n",
"The loss calculated: 0.4425719976425171\n",
"Epoch # 442\n",
"The loss calculated: 0.4420627951622009\n",
"Epoch # 443\n",
"The loss calculated: 0.4421764612197876\n",
"Epoch # 444\n",
"The loss calculated: 0.44193679094314575\n",
"Epoch # 445\n",
"The loss calculated: 0.44186508655548096\n",
"Epoch # 446\n",
"The loss calculated: 0.44136378169059753\n",
"Epoch # 447\n",
"The loss calculated: 0.44126731157302856\n",
"Epoch # 448\n",
"The loss calculated: 0.44119781255722046\n",
"Epoch # 449\n",
"The loss calculated: 0.4413573145866394\n",
"Epoch # 450\n",
"The loss calculated: 0.4411191940307617\n",
"Epoch # 451\n",
"The loss calculated: 0.4407786428928375\n",
"Epoch # 452\n",
"The loss calculated: 0.4407300055027008\n",
"Epoch # 453\n",
"The loss calculated: 0.4404629170894623\n",
"Epoch # 454\n",
"The loss calculated: 0.44039714336395264\n",
"Epoch # 455\n",
"The loss calculated: 0.44031772017478943\n",
"Epoch # 456\n",
"The loss calculated: 0.44058850407600403\n",
"Epoch # 457\n",
"The loss calculated: 0.44026416540145874\n",
"Epoch # 458\n",
"The loss calculated: 0.4401347041130066\n",
"Epoch # 459\n",
"The loss calculated: 0.44020867347717285\n",
"Epoch # 460\n",
"The loss calculated: 0.43979671597480774\n",
"Epoch # 461\n",
"The loss calculated: 0.44035604596138\n",
"Epoch # 462\n",
"The loss calculated: 0.4401366412639618\n",
"Epoch # 463\n",
"The loss calculated: 0.4404027760028839\n",
"Epoch # 464\n",
"The loss calculated: 0.439935564994812\n",
"Epoch # 465\n",
"The loss calculated: 0.4399685561656952\n",
"Epoch # 466\n",
"The loss calculated: 0.4409003257751465\n",
"Epoch # 467\n",
"The loss calculated: 0.43949607014656067\n",
"Epoch # 468\n",
"The loss calculated: 0.4398217797279358\n",
"Epoch # 469\n",
"The loss calculated: 0.43998679518699646\n",
"Epoch # 470\n",
"The loss calculated: 0.4403824508190155\n",
"Epoch # 471\n",
"The loss calculated: 0.43901607394218445\n",
"Epoch # 472\n",
"The loss calculated: 0.44028377532958984\n",
"Epoch # 473\n",
"The loss calculated: 0.4426659643650055\n",
"Epoch # 474\n",
"The loss calculated: 0.44038379192352295\n",
"Epoch # 475\n",
"The loss calculated: 0.4395928978919983\n",
"Epoch # 476\n",
"The loss calculated: 0.44086745381355286\n",
"Epoch # 477\n",
"The loss calculated: 0.43867841362953186\n",
"Epoch # 478\n",
"The loss calculated: 0.4390256404876709\n",
"Epoch # 479\n",
"The loss calculated: 0.4390667676925659\n",
"Epoch # 480\n",
"The loss calculated: 0.4384021759033203\n",
"Epoch # 481\n",
"The loss calculated: 0.4385366439819336\n",
"Epoch # 482\n",
"The loss calculated: 0.4384676516056061\n",
"Epoch # 483\n",
"The loss calculated: 0.4386775493621826\n",
"Epoch # 484\n",
"The loss calculated: 0.43819159269332886\n",
"Epoch # 485\n",
"The loss calculated: 0.4379732608795166\n",
"Epoch # 486\n",
"The loss calculated: 0.4379722476005554\n",
"Epoch # 487\n",
"The loss calculated: 0.4376266896724701\n",
"Epoch # 488\n",
"The loss calculated: 0.4373808205127716\n",
"Epoch # 489\n",
"The loss calculated: 0.43826723098754883\n",
"Epoch # 490\n",
"The loss calculated: 0.4379383623600006\n",
"Epoch # 491\n",
"The loss calculated: 0.4372965395450592\n",
"Epoch # 492\n",
"The loss calculated: 0.4375162422657013\n",
"Epoch # 493\n",
"The loss calculated: 0.43795913457870483\n",
"Epoch # 494\n",
"The loss calculated: 0.43740007281303406\n",
"Epoch # 495\n",
"The loss calculated: 0.43741703033447266\n",
"Epoch # 496\n",
"The loss calculated: 0.4373546838760376\n",
"Epoch # 497\n",
"The loss calculated: 0.4368191957473755\n",
"Epoch # 498\n",
"The loss calculated: 0.4367024898529053\n",
"Epoch # 499\n",
"The loss calculated: 0.43679192662239075\n",
"Epoch # 500\n",
"The loss calculated: 0.436893105506897\n"
]
}
],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"for epoch in range(1, epochs+1):\n",
" print(\"Epoch #\", epoch)\n",
" y_pred = model(Xt)\n",
"# print(y_pred)\n",
" loss = loss_fn(y_pred, Yt)\n",
" print_(loss.item())\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": 264,
"id": "45d76c95",
"metadata": {},
"outputs": [],
"source": [
"x_test = torch.tensor(X_test.values, dtype=torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": 271,
"id": "5e98206b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_7802/3372075492.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" x = F.softmax(self.layer3(x))\n"
]
}
],
"source": [
"pred = model(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 272,
"id": "35d64340",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1.3141002e-01, 8.6859006e-01],\n",
" [3.0172759e-16, 1.0000000e+00],\n",
" [5.9731257e-21, 1.0000000e+00],\n",
" [8.7287611e-01, 1.2712391e-01],\n",
" [3.3298880e-01, 6.6701120e-01],\n",
" [9.9992323e-01, 7.6730175e-05],\n",
" [6.9742590e-01, 3.0257410e-01],\n",
" [1.8122771e-10, 1.0000000e+00],\n",
" [8.1137923e-18, 1.0000000e+00],\n",
" [9.9391985e-01, 6.0801902e-03],\n",
" [9.9800962e-01, 1.9904438e-03],\n",
" [1.4347603e-12, 1.0000000e+00],\n",
" [8.8945550e-01, 1.1054446e-01],\n",
" [5.3068206e-19, 1.0000000e+00],\n",
" [4.4245785e-01, 5.5754209e-01],\n",
" [3.9323148e-01, 6.0676849e-01],\n",
" [5.0538932e-23, 1.0000000e+00],\n",
" [6.8482041e-01, 3.1517953e-01],\n",
" [9.9650586e-01, 3.4941665e-03],\n",
" [3.6827392e-24, 1.0000000e+00],\n",
" [3.4629088e-12, 1.0000000e+00],\n",
" [2.4781654e-11, 1.0000000e+00],\n",
" [8.4075117e-01, 1.5924890e-01],\n",
" [9.9999881e-01, 1.2382451e-06],\n",
" [9.9950111e-01, 4.9885432e-04],\n",
" [1.1888127e-14, 1.0000000e+00],\n",
" [1.5869159e-14, 1.0000000e+00],\n",
" [9.4683814e-01, 5.3161871e-02],\n",
" [7.3645154e-08, 9.9999988e-01],\n",
" [1.2287432e-11, 1.0000000e+00],\n",
" [5.7253930e-15, 1.0000000e+00],\n",
" [7.9019060e-08, 9.9999988e-01],\n",
" [5.5769521e-01, 4.4230482e-01],\n",
" [1.8103112e-14, 1.0000000e+00],\n",
" [9.9812454e-01, 1.8754901e-03],\n",
" [2.5346470e-05, 9.9997461e-01],\n",
" [1.6169167e-17, 1.0000000e+00],\n",
" [9.3050295e-01, 6.9496997e-02],\n",
" [6.1799776e-02, 9.3820024e-01],\n",
" [9.7120519e-06, 9.9999034e-01],\n",
" [9.9844283e-01, 1.5571705e-03],\n",
" [8.0438519e-01, 1.9561480e-01],\n",
" [2.0653886e-16, 1.0000000e+00],\n",
" [7.0155847e-01, 2.9844159e-01],\n",
" [9.9505252e-01, 4.9475045e-03],\n",
" [9.3824464e-01, 6.1755374e-02]], dtype=float32)"
]
},
"execution_count": 272,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred = pred.detach().numpy()\n",
"pred"
]
},
{
"cell_type": "code",
"execution_count": 269,
"id": "5c18f80f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy is 0.7391304347826086\n"
]
}
],
"source": [
"print (\"The accuracy is\", accuracy_score(Y_test, np.argmax(pred, axis=1)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4638b1d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}