diff --git a/dane.ipynb b/dane.ipynb index 379b36b..d65a001 100644 --- a/dane.ipynb +++ b/dane.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "3473477b", "metadata": {}, "outputs": [ @@ -10,9 +10,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Downloading titanic.zip to /home/gedin/Studia/InzUczeniaMaszynowego/zadania\n", - "100%|███████████████████████████████████████| 34.1k/34.1k [00:00<00:00, 212kB/s]\n", - "100%|███████████████████████████████████████| 34.1k/34.1k [00:00<00:00, 212kB/s]\n" + "/home/gedin/.local/lib/python3.10/site-packages/requests/__init__.py:102: RequestsDependencyWarning: urllib3 (1.26.13) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!\n", + " warnings.warn(\"urllib3 ({}) or chardet ({})/charset_normalizer ({}) doesn't match a supported \"\n", + "titanic.zip: Skipping, found more recently modified local copy (use --force to force download)\n" ] } ], @@ -42,6 +42,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "b6adfde9", "metadata": {}, @@ -51,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "a9d9a8ee", "metadata": {}, "outputs": [ @@ -71,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "bf08fe16", "metadata": {}, "outputs": [], @@ -81,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 228, "id": "fc59f320", "metadata": {}, "outputs": [], @@ -91,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 47, "id": "aa5ea30b", "metadata": {}, "outputs": [ @@ -233,7 +234,7 @@ "4 0 373450 8.0500 NaN S " ] }, - "execution_count": 28, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } @@ -244,7 +245,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 48, "id": "32d4140c", "metadata": {}, "outputs": [ @@ -385,7 +386,7 @@ "max 6.000000 512.329200 " ] }, - "execution_count": 10, + "execution_count": 48, "metadata": {}, "output_type": "execute_result" } @@ -396,7 +397,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 49, "id": "920ea21b", "metadata": {}, "outputs": [ @@ -407,7 +408,7 @@ " ]], dtype=object)" ] }, - "execution_count": 26, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" }, @@ -428,9 +429,11 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 50, "id": "be20c939", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { @@ -438,7 +441,7 @@ "" ] }, - "execution_count": 27, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" }, @@ -461,153 +464,12 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 51, "id": "8286046e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
1211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
3411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
4503Allen, Mr. William Henrymale35.0003734508.0500NaNS
\n", - "
" - ], - "text/plain": [ - " PassengerId Survived Pclass \\\n", - "0 1 0 3 \n", - "1 2 1 1 \n", - "2 3 1 3 \n", - "3 4 1 1 \n", - "4 5 0 3 \n", - "\n", - " Name Sex Age SibSp \\\n", - "0 Braund, Mr. Owen Harris male 22.0 1 \n", - "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", - "2 Heikkinen, Miss. Laina female 26.0 0 \n", - "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", - "4 Allen, Mr. William Henry male 35.0 0 \n", - "\n", - " Parch Ticket Fare Cabin Embarked \n", - "0 0 A/5 21171 7.2500 NaN S \n", - "1 0 PC 17599 71.2833 C85 C \n", - "2 0 STON/O2. 3101282 7.9250 NaN S \n", - "3 0 113803 53.1000 C123 S \n", - "4 0 373450 8.0500 NaN S " - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "metadata": { + "scrolled": true + }, + "outputs": [], "source": [ "# df.dropna()\n", "#df.fillna()" @@ -615,7 +477,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 229, "id": "1ed8c693", "metadata": {}, "outputs": [], @@ -627,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 230, "id": "d5a0fa72", "metadata": {}, "outputs": [ @@ -769,7 +631,7 @@ "4 0 373450 0.015713 NaN S " ] }, - "execution_count": 42, + "execution_count": 230, "metadata": {}, "output_type": "execute_result" } @@ -780,10 +642,1536 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 52, "id": "e6ffda37", "metadata": {}, "outputs": [], + "source": [ + "import pandas as pd\n", + "df = pd.read_csv(\"train.csv\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f7c33a0", + "metadata": {}, + "outputs": [], + "source": [ + "# e19191c5.uam.onmicrosoft.com@emea.teams.ms" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "54dd7eaa", + "metadata": {}, + "source": [ + "## lab 5 ml" + ] + }, + { + "cell_type": "code", + "execution_count": 231, + "id": "ec55ac92", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',\n", + " 'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],\n", + " dtype='object')\n" + ] + } + ], + "source": [ + "#data\n", + "cols = df.columns\n", + "print(cols)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40225042", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 232, + "id": "11850862", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "from torch import nn\n", + "from torch.autograd import Variable\n", + "from sklearn.datasets import load_iris\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score\n", + "from keras.utils import to_categorical\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 259, + "id": "cfecc11c", + "metadata": {}, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " def __init__(self, input_dim):\n", + " super(Model, self).__init__()\n", + " self.layer1 = nn.Linear(input_dim, 50)\n", + " self.layer2 = nn.Linear(50, 20)\n", + " self.layer3 = nn.Linear(20, 2)\n", + " \n", + " def forward(self, x):\n", + " x = F.relu(self.layer1(x))\n", + " x = F.relu(self.layer2(x))\n", + " x = F.softmax(self.layer3(x))\n", + " \n", + " return x\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 235, + "id": "0af12074", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_7802/1323642195.py:6: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " X['Sex'].replace(['female', 'male'], [0,1], inplace=True)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
PclassSexAgeSibSpFare
1100.47222910.139136
3100.43453110.103644
6110.67328500.101229
10300.04498610.032596
11100.72354900.051822
..................
871100.58532310.102579
872110.40939900.009759
879100.69841700.162314
887100.23347600.058556
889110.32143800.058556
\n", + "

183 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " Pclass Sex Age SibSp Fare\n", + "1 1 0 0.472229 1 0.139136\n", + "3 1 0 0.434531 1 0.103644\n", + "6 1 1 0.673285 0 0.101229\n", + "10 3 0 0.044986 1 0.032596\n", + "11 1 0 0.723549 0 0.051822\n", + ".. ... ... ... ... ...\n", + "871 1 0 0.585323 1 0.102579\n", + "872 1 1 0.409399 0 0.009759\n", + "879 1 0 0.698417 0 0.162314\n", + "887 1 0 0.233476 0 0.058556\n", + "889 1 1 0.321438 0 0.058556\n", + "\n", + "[183 rows x 5 columns]" + ] + }, + "execution_count": 235, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = df.dropna()\n", + "X = df[['Pclass', 'Sex', 'Age','SibSp', 'Fare']]\n", + "Y = df[['Survived']]\n", + "\n", + "# X.loc[:,'Age'] = X.loc[:,'Age'].fillna(X['Age'].mean())\n", + "X['Sex'].replace(['female', 'male'], [0,1], inplace=True)\n", + "\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 236, + "id": "591bfb44", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1 1 0 1 1 1 1 0 1 0 0 1 0 1 0 0 1 0 0 0 1 0 1 0 0 0 1 0 0 0 1 1 1 1 0 1 1\n", + " 1 1 1 0 1 0 0 1 0 0 1 1 0 1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 1 0 1 1 1 1\n", + " 1 1 1 0 1 1 1 1 1 1 0 1 0 1 1 0 1 0 1 0 1 1 1 0 0 1 0 1 0 1 0 1 1 1 0 1 1\n", + " 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 0 0 0 1 1 1 1 0 0 1 1 1 1 1\n", + " 0 1 1 1 1 1 0 1 0 0 1 1 1 1 0 1 1 0 0 1 1 0 1 1 1 1 1 1 1 0 1 0 1 1 1]\n" + ] + } + ], + "source": [ + "from sklearn.preprocessing import LabelEncoder\n", + "Y = np.ravel(Y)\n", + "encoder = LabelEncoder()\n", + "encoder.fit(Y)\n", + "Y = encoder.transform(Y)\n", + "print(Y)" + ] + }, + { + "cell_type": "code", + "execution_count": 237, + "id": "8a7cac39", + "metadata": {}, + "outputs": [], + "source": [ + "X_train, X_test, Y_train, Y_test = train_test_split(X,Y, random_state=42, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 260, + "id": "93454e63", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "Xt = torch.tensor(X_train.values, dtype = torch.float32)\n", + "Yt = torch.tensor(Y_train, dtype=torch.long)\n", + "# .reshape(-1,1)\n", + "# Yt = Y_train" + ] + }, + { + "cell_type": "code", + "execution_count": 261, + "id": "3aac198b", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([137])" + ] + }, + "execution_count": 261, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Yt.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 262, + "id": "27591bf8", + "metadata": {}, + "outputs": [], + "source": [ + "model = Model(Xt.shape[1])\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", + "loss_fn = nn.CrossEntropyLoss()\n", + "epochs = 500\n", + "\n", + "def print_(loss):\n", + " print (\"The loss calculated: \", loss)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 263, + "id": "9d700f25", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch # 1\n", + "The loss calculated: 0.6927047371864319\n", + "Epoch # 2\n", + "The loss calculated: 0.6760580539703369\n", + "Epoch # 3\n", + "The loss calculated: 0.6577760577201843\n", + "Epoch # 4\n", + "The loss calculated: 0.6410418152809143\n", + "Epoch # 5\n", + "The loss calculated: 0.6274042725563049\n", + "Epoch # 6\n", + "The loss calculated: 0.6176177263259888\n", + "Epoch # 7\n", + "The loss calculated: 0.6114543676376343\n", + "Epoch # 8\n", + "The loss calculated: 0.6079199314117432\n", + "Epoch # 9\n", + "The loss calculated: 0.6057404279708862\n", + "Epoch # 10\n", + "The loss calculated: 0.6039658188819885\n", + "Epoch # 11\n", + "The loss calculated: 0.6018784046173096\n", + "Epoch # 12\n", + "The loss calculated: 0.5988859534263611\n", + "Epoch # 13\n", + "The loss calculated: 0.5944192409515381\n", + "Epoch # 14\n", + "The loss calculated: 0.58795166015625\n", + "Epoch # 15\n", + "The loss calculated: 0.5793240666389465\n", + "Epoch # 16\n", + "The loss calculated: 0.569113552570343\n", + "Epoch # 17\n", + "The loss calculated: 0.5591343641281128\n", + "Epoch # 18\n", + "The loss calculated: 0.5525994300842285\n", + "Epoch # 19\n", + "The loss calculated: 0.549091637134552\n", + "Epoch # 20\n", + "The loss calculated: 0.5478854775428772\n", + "Epoch # 21\n", + "The loss calculated: 0.5459576845169067\n", + "Epoch # 22\n", + "The loss calculated: 0.5430701971054077\n", + "Epoch # 23\n", + "The loss calculated: 0.5398197174072266\n", + "Epoch # 24\n", + "The loss calculated: 0.5366366505622864\n", + "Epoch # 25\n", + "The loss calculated: 0.5338087677955627\n", + "Epoch # 26\n", + "The loss calculated: 0.5315443873405457\n", + "Epoch # 27\n", + "The loss calculated: 0.5298702716827393\n", + "Epoch # 28\n", + "The loss calculated: 0.5285016894340515\n", + "Epoch # 29\n", + "The loss calculated: 0.5272928476333618\n", + "Epoch # 30\n", + "The loss calculated: 0.5261989235877991\n", + "Epoch # 31\n", + "The loss calculated: 0.5251137018203735\n", + "Epoch # 32\n", + "The loss calculated: 0.5238412618637085\n", + "Epoch # 33\n", + "The loss calculated: 0.5226505398750305\n", + "Epoch # 34\n", + "The loss calculated: 0.5215187072753906\n", + "Epoch # 35\n", + "The loss calculated: 0.5204036235809326\n", + "Epoch # 36\n", + "The loss calculated: 0.5194926857948303\n", + "Epoch # 37\n", + "The loss calculated: 0.5188320875167847\n", + "Epoch # 38\n", + "The loss calculated: 0.5182497501373291\n", + "Epoch # 39\n", + "The loss calculated: 0.5176616907119751\n", + "Epoch # 40\n", + "The loss calculated: 0.5170402526855469\n", + "Epoch # 41\n", + "The loss calculated: 0.5162948369979858\n", + "Epoch # 42\n", + "The loss calculated: 0.5155003070831299\n", + "Epoch # 43\n", + "The loss calculated: 0.51481693983078\n", + "Epoch # 44\n", + "The loss calculated: 0.5142836570739746\n", + "Epoch # 45\n", + "The loss calculated: 0.5137770771980286\n", + "Epoch # 46\n", + "The loss calculated: 0.5132609009742737\n", + "Epoch # 47\n", + "The loss calculated: 0.5126983523368835\n", + "Epoch # 48\n", + "The loss calculated: 0.5120936036109924\n", + "Epoch # 49\n", + "The loss calculated: 0.5116094350814819\n", + "Epoch # 50\n", + "The loss calculated: 0.5111839175224304\n", + "Epoch # 51\n", + "The loss calculated: 0.5106979608535767\n", + "Epoch # 52\n", + "The loss calculated: 0.5101208686828613\n", + "Epoch # 53\n", + "The loss calculated: 0.5095392465591431\n", + "Epoch # 54\n", + "The loss calculated: 0.5090041756629944\n", + "Epoch # 55\n", + "The loss calculated: 0.5083613395690918\n", + "Epoch # 56\n", + "The loss calculated: 0.5075969099998474\n", + "Epoch # 57\n", + "The loss calculated: 0.5067813992500305\n", + "Epoch # 58\n", + "The loss calculated: 0.5060149431228638\n", + "Epoch # 59\n", + "The loss calculated: 0.5052304863929749\n", + "Epoch # 60\n", + "The loss calculated: 0.5044183135032654\n", + "Epoch # 61\n", + "The loss calculated: 0.5035461187362671\n", + "Epoch # 62\n", + "The loss calculated: 0.5025045871734619\n", + "Epoch # 63\n", + "The loss calculated: 0.5014879107475281\n", + "Epoch # 64\n", + "The loss calculated: 0.5006436705589294\n", + "Epoch # 65\n", + "The loss calculated: 0.499641090631485\n", + "Epoch # 66\n", + "The loss calculated: 0.4986647367477417\n", + "Epoch # 67\n", + "The loss calculated: 0.497800350189209\n", + "Epoch # 68\n", + "The loss calculated: 0.49712076783180237\n", + "Epoch # 69\n", + "The loss calculated: 0.49643078446388245\n", + "Epoch # 70\n", + "The loss calculated: 0.4957447350025177\n", + "Epoch # 71\n", + "The loss calculated: 0.4950644075870514\n", + "Epoch # 72\n", + "The loss calculated: 0.4944438636302948\n", + "Epoch # 73\n", + "The loss calculated: 0.4937107563018799\n", + "Epoch # 74\n", + "The loss calculated: 0.49320393800735474\n", + "Epoch # 75\n", + "The loss calculated: 0.49250030517578125\n", + "Epoch # 76\n", + "The loss calculated: 0.49141865968704224\n", + "Epoch # 77\n", + "The loss calculated: 0.49071067571640015\n", + "Epoch # 78\n", + "The loss calculated: 0.4899919629096985\n", + "Epoch # 79\n", + "The loss calculated: 0.48904943466186523\n", + "Epoch # 80\n", + "The loss calculated: 0.4885300099849701\n", + "Epoch # 81\n", + "The loss calculated: 0.48774540424346924\n", + "Epoch # 82\n", + "The loss calculated: 0.48720788955688477\n", + "Epoch # 83\n", + "The loss calculated: 0.4868374466896057\n", + "Epoch # 84\n", + "The loss calculated: 0.48623406887054443\n", + "Epoch # 85\n", + "The loss calculated: 0.48583683371543884\n", + "Epoch # 86\n", + "The loss calculated: 0.48502254486083984\n", + "Epoch # 87\n", + "The loss calculated: 0.4844677746295929\n", + "Epoch # 88\n", + "The loss calculated: 0.48361340165138245\n", + "Epoch # 89\n", + "The loss calculated: 0.4827542304992676\n", + "Epoch # 90\n", + "The loss calculated: 0.4817808270454407\n", + "Epoch # 91\n", + "The loss calculated: 0.4809269607067108\n", + "Epoch # 92\n", + "The loss calculated: 0.4804893136024475\n", + "Epoch # 93\n", + "The loss calculated: 0.48043856024742126\n", + "Epoch # 94\n", + "The loss calculated: 0.4801830053329468\n", + "Epoch # 95\n", + "The loss calculated: 0.479977011680603\n", + "Epoch # 96\n", + "The loss calculated: 0.47945544123649597\n", + "Epoch # 97\n", + "The loss calculated: 0.47897064685821533\n", + "Epoch # 98\n", + "The loss calculated: 0.4786403775215149\n", + "Epoch # 99\n", + "The loss calculated: 0.47828078269958496\n", + "Epoch # 100\n", + "The loss calculated: 0.47804537415504456\n", + "Epoch # 101\n", + "The loss calculated: 0.4777425527572632\n", + "Epoch # 102\n", + "The loss calculated: 0.4773750603199005\n", + "Epoch # 103\n", + "The loss calculated: 0.4768853187561035\n", + "Epoch # 104\n", + "The loss calculated: 0.4766947627067566\n", + "Epoch # 105\n", + "The loss calculated: 0.47633618116378784\n", + "Epoch # 106\n", + "The loss calculated: 0.47610870003700256\n", + "Epoch # 107\n", + "The loss calculated: 0.47584590315818787\n", + "Epoch # 108\n", + "The loss calculated: 0.47565311193466187\n", + "Epoch # 109\n", + "The loss calculated: 0.475361168384552\n", + "Epoch # 110\n", + "The loss calculated: 0.475079208612442\n", + "Epoch # 111\n", + "The loss calculated: 0.47482433915138245\n", + "Epoch # 112\n", + "The loss calculated: 0.47465214133262634\n", + "Epoch # 113\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_7802/3372075492.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", + " x = F.softmax(self.layer3(x))\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The loss calculated: 0.4745003283023834\n", + "Epoch # 114\n", + "The loss calculated: 0.47428470849990845\n", + "Epoch # 115\n", + "The loss calculated: 0.47402113676071167\n", + "Epoch # 116\n", + "The loss calculated: 0.4738253355026245\n", + "Epoch # 117\n", + "The loss calculated: 0.47366538643836975\n", + "Epoch # 118\n", + "The loss calculated: 0.47345176339149475\n", + "Epoch # 119\n", + "The loss calculated: 0.47328999638557434\n", + "Epoch # 120\n", + "The loss calculated: 0.47304701805114746\n", + "Epoch # 121\n", + "The loss calculated: 0.47283679246902466\n", + "Epoch # 122\n", + "The loss calculated: 0.47269734740257263\n", + "Epoch # 123\n", + "The loss calculated: 0.47256502509117126\n", + "Epoch # 124\n", + "The loss calculated: 0.4723707437515259\n", + "Epoch # 125\n", + "The loss calculated: 0.4721546471118927\n", + "Epoch # 126\n", + "The loss calculated: 0.4719236493110657\n", + "Epoch # 127\n", + "The loss calculated: 0.4718014895915985\n", + "Epoch # 128\n", + "The loss calculated: 0.4715701937675476\n", + "Epoch # 129\n", + "The loss calculated: 0.47162505984306335\n", + "Epoch # 130\n", + "The loss calculated: 0.47140219807624817\n", + "Epoch # 131\n", + "The loss calculated: 0.47120794653892517\n", + "Epoch # 132\n", + "The loss calculated: 0.47121524810791016\n", + "Epoch # 133\n", + "The loss calculated: 0.4708421230316162\n", + "Epoch # 134\n", + "The loss calculated: 0.47080597281455994\n", + "Epoch # 135\n", + "The loss calculated: 0.470735102891922\n", + "Epoch # 136\n", + "The loss calculated: 0.47046154737472534\n", + "Epoch # 137\n", + "The loss calculated: 0.4704940617084503\n", + "Epoch # 138\n", + "The loss calculated: 0.4704982340335846\n", + "Epoch # 139\n", + "The loss calculated: 0.470112144947052\n", + "Epoch # 140\n", + "The loss calculated: 0.4701041877269745\n", + "Epoch # 141\n", + "The loss calculated: 0.47008904814720154\n", + "Epoch # 142\n", + "The loss calculated: 0.4698803722858429\n", + "Epoch # 143\n", + "The loss calculated: 0.46982747316360474\n", + "Epoch # 144\n", + "The loss calculated: 0.469696044921875\n", + "Epoch # 145\n", + "The loss calculated: 0.46962815523147583\n", + "Epoch # 146\n", + "The loss calculated: 0.469440758228302\n", + "Epoch # 147\n", + "The loss calculated: 0.46939632296562195\n", + "Epoch # 148\n", + "The loss calculated: 0.4695526957511902\n", + "Epoch # 149\n", + "The loss calculated: 0.4697006046772003\n", + "Epoch # 150\n", + "The loss calculated: 0.4692654609680176\n", + "Epoch # 151\n", + "The loss calculated: 0.4700072407722473\n", + "Epoch # 152\n", + "The loss calculated: 0.4690340757369995\n", + "Epoch # 153\n", + "The loss calculated: 0.47001826763153076\n", + "Epoch # 154\n", + "The loss calculated: 0.46880584955215454\n", + "Epoch # 155\n", + "The loss calculated: 0.46919724345207214\n", + "Epoch # 156\n", + "The loss calculated: 0.4687418043613434\n", + "Epoch # 157\n", + "The loss calculated: 0.4687948226928711\n", + "Epoch # 158\n", + "The loss calculated: 0.46873044967651367\n", + "Epoch # 159\n", + "The loss calculated: 0.46848490834236145\n", + "Epoch # 160\n", + "The loss calculated: 0.4686104953289032\n", + "Epoch # 161\n", + "The loss calculated: 0.4683172404766083\n", + "Epoch # 162\n", + "The loss calculated: 0.46831050515174866\n", + "Epoch # 163\n", + "The loss calculated: 0.46828699111938477\n", + "Epoch # 164\n", + "The loss calculated: 0.46824583411216736\n", + "Epoch # 165\n", + "The loss calculated: 0.468075156211853\n", + "Epoch # 166\n", + "The loss calculated: 0.46814292669296265\n", + "Epoch # 167\n", + "The loss calculated: 0.46796467900276184\n", + "Epoch # 168\n", + "The loss calculated: 0.46802079677581787\n", + "Epoch # 169\n", + "The loss calculated: 0.46778491139411926\n", + "Epoch # 170\n", + "The loss calculated: 0.4679405093193054\n", + "Epoch # 171\n", + "The loss calculated: 0.46800506114959717\n", + "Epoch # 172\n", + "The loss calculated: 0.467818945646286\n", + "Epoch # 173\n", + "The loss calculated: 0.4678487181663513\n", + "Epoch # 174\n", + "The loss calculated: 0.46776196360588074\n", + "Epoch # 175\n", + "The loss calculated: 0.46756404638290405\n", + "Epoch # 176\n", + "The loss calculated: 0.4682294726371765\n", + "Epoch # 177\n", + "The loss calculated: 0.46777990460395813\n", + "Epoch # 178\n", + "The loss calculated: 0.4677632451057434\n", + "Epoch # 179\n", + "The loss calculated: 0.46777427196502686\n", + "Epoch # 180\n", + "The loss calculated: 0.46746954321861267\n", + "Epoch # 181\n", + "The loss calculated: 0.4676474630832672\n", + "Epoch # 182\n", + "The loss calculated: 0.46711796522140503\n", + "Epoch # 183\n", + "The loss calculated: 0.4677950441837311\n", + "Epoch # 184\n", + "The loss calculated: 0.46725085377693176\n", + "Epoch # 185\n", + "The loss calculated: 0.4676659107208252\n", + "Epoch # 186\n", + "The loss calculated: 0.4672679901123047\n", + "Epoch # 187\n", + "The loss calculated: 0.46727195382118225\n", + "Epoch # 188\n", + "The loss calculated: 0.466960608959198\n", + "Epoch # 189\n", + "The loss calculated: 0.46708735823631287\n", + "Epoch # 190\n", + "The loss calculated: 0.4671291708946228\n", + "Epoch # 191\n", + "The loss calculated: 0.46684736013412476\n", + "Epoch # 192\n", + "The loss calculated: 0.4667331576347351\n", + "Epoch # 193\n", + "The loss calculated: 0.46685370802879333\n", + "Epoch # 194\n", + "The loss calculated: 0.4668591618537903\n", + "Epoch # 195\n", + "The loss calculated: 0.46671974658966064\n", + "Epoch # 196\n", + "The loss calculated: 0.46653658151626587\n", + "Epoch # 197\n", + "The loss calculated: 0.46659478545188904\n", + "Epoch # 198\n", + "The loss calculated: 0.4665440022945404\n", + "Epoch # 199\n", + "The loss calculated: 0.4664462208747864\n", + "Epoch # 200\n", + "The loss calculated: 0.466394305229187\n", + "Epoch # 201\n", + "The loss calculated: 0.4665300250053406\n", + "Epoch # 202\n", + "The loss calculated: 0.4664006531238556\n", + "Epoch # 203\n", + "The loss calculated: 0.46651187539100647\n", + "Epoch # 204\n", + "The loss calculated: 0.4662490487098694\n", + "Epoch # 205\n", + "The loss calculated: 0.46683457493782043\n", + "Epoch # 206\n", + "The loss calculated: 0.46636930108070374\n", + "Epoch # 207\n", + "The loss calculated: 0.4663969576358795\n", + "Epoch # 208\n", + "The loss calculated: 0.46641668677330017\n", + "Epoch # 209\n", + "The loss calculated: 0.46628400683403015\n", + "Epoch # 210\n", + "The loss calculated: 0.4664050042629242\n", + "Epoch # 211\n", + "The loss calculated: 0.4661887586116791\n", + "Epoch # 212\n", + "The loss calculated: 0.4660308063030243\n", + "Epoch # 213\n", + "The loss calculated: 0.4661027491092682\n", + "Epoch # 214\n", + "The loss calculated: 0.4660954177379608\n", + "Epoch # 215\n", + "The loss calculated: 0.4658938944339752\n", + "Epoch # 216\n", + "The loss calculated: 0.4660359025001526\n", + "Epoch # 217\n", + "The loss calculated: 0.46567121148109436\n", + "Epoch # 218\n", + "The loss calculated: 0.4657202959060669\n", + "Epoch # 219\n", + "The loss calculated: 0.4657045900821686\n", + "Epoch # 220\n", + "The loss calculated: 0.4655347168445587\n", + "Epoch # 221\n", + "The loss calculated: 0.4654804468154907\n", + "Epoch # 222\n", + "The loss calculated: 0.4656883180141449\n", + "Epoch # 223\n", + "The loss calculated: 0.46542859077453613\n", + "Epoch # 224\n", + "The loss calculated: 0.46529003977775574\n", + "Epoch # 225\n", + "The loss calculated: 0.46543607115745544\n", + "Epoch # 226\n", + "The loss calculated: 0.46531468629837036\n", + "Epoch # 227\n", + "The loss calculated: 0.4653342068195343\n", + "Epoch # 228\n", + "The loss calculated: 0.46527451276779175\n", + "Epoch # 229\n", + "The loss calculated: 0.4652668535709381\n", + "Epoch # 230\n", + "The loss calculated: 0.46513044834136963\n", + "Epoch # 231\n", + "The loss calculated: 0.4650672972202301\n", + "Epoch # 232\n", + "The loss calculated: 0.46511510014533997\n", + "Epoch # 233\n", + "The loss calculated: 0.4647628366947174\n", + "Epoch # 234\n", + "The loss calculated: 0.4647744596004486\n", + "Epoch # 235\n", + "The loss calculated: 0.4648566246032715\n", + "Epoch # 236\n", + "The loss calculated: 0.4646404981613159\n", + "Epoch # 237\n", + "The loss calculated: 0.4645318388938904\n", + "Epoch # 238\n", + "The loss calculated: 0.46459120512008667\n", + "Epoch # 239\n", + "The loss calculated: 0.46454647183418274\n", + "Epoch # 240\n", + "The loss calculated: 0.46439239382743835\n", + "Epoch # 241\n", + "The loss calculated: 0.464549720287323\n", + "Epoch # 242\n", + "The loss calculated: 0.4642981290817261\n", + "Epoch # 243\n", + "The loss calculated: 0.4640815258026123\n", + "Epoch # 244\n", + "The loss calculated: 0.4640815258026123\n", + "Epoch # 245\n", + "The loss calculated: 0.4638811945915222\n", + "Epoch # 246\n", + "The loss calculated: 0.46409285068511963\n", + "Epoch # 247\n", + "The loss calculated: 0.46399882435798645\n", + "Epoch # 248\n", + "The loss calculated: 0.4639054536819458\n", + "Epoch # 249\n", + "The loss calculated: 0.46384960412979126\n", + "Epoch # 250\n", + "The loss calculated: 0.46365633606910706\n", + "Epoch # 251\n", + "The loss calculated: 0.4635387361049652\n", + "Epoch # 252\n", + "The loss calculated: 0.46366339921951294\n", + "Epoch # 253\n", + "The loss calculated: 0.4635831415653229\n", + "Epoch # 254\n", + "The loss calculated: 0.46347707509994507\n", + "Epoch # 255\n", + "The loss calculated: 0.4633452892303467\n", + "Epoch # 256\n", + "The loss calculated: 0.4634377658367157\n", + "Epoch # 257\n", + "The loss calculated: 0.46325498819351196\n", + "Epoch # 258\n", + "The loss calculated: 0.46343502402305603\n", + "Epoch # 259\n", + "The loss calculated: 0.46319177746772766\n", + "Epoch # 260\n", + "The loss calculated: 0.4631631076335907\n", + "Epoch # 261\n", + "The loss calculated: 0.4630383253097534\n", + "Epoch # 262\n", + "The loss calculated: 0.4629758596420288\n", + "Epoch # 263\n", + "The loss calculated: 0.46284860372543335\n", + "Epoch # 264\n", + "The loss calculated: 0.46269962191581726\n", + "Epoch # 265\n", + "The loss calculated: 0.4628857374191284\n", + "Epoch # 266\n", + "The loss calculated: 0.4627268314361572\n", + "Epoch # 267\n", + "The loss calculated: 0.46238410472869873\n", + "Epoch # 268\n", + "The loss calculated: 0.4622679352760315\n", + "Epoch # 269\n", + "The loss calculated: 0.46253955364227295\n", + "Epoch # 270\n", + "The loss calculated: 0.46243607997894287\n", + "Epoch # 271\n", + "The loss calculated: 0.4622651934623718\n", + "Epoch # 272\n", + "The loss calculated: 0.4621260166168213\n", + "Epoch # 273\n", + "The loss calculated: 0.4619852304458618\n", + "Epoch # 274\n", + "The loss calculated: 0.4621600806713104\n", + "Epoch # 275\n", + "The loss calculated: 0.46188268065452576\n", + "Epoch # 276\n", + "The loss calculated: 0.4619770050048828\n", + "Epoch # 277\n", + "The loss calculated: 0.4617985486984253\n", + "Epoch # 278\n", + "The loss calculated: 0.46143385767936707\n", + "Epoch # 279\n", + "The loss calculated: 0.4618164002895355\n", + "Epoch # 280\n", + "The loss calculated: 0.461500883102417\n", + "Epoch # 281\n", + "The loss calculated: 0.4614565372467041\n", + "Epoch # 282\n", + "The loss calculated: 0.4613018035888672\n", + "Epoch # 283\n", + "The loss calculated: 0.4612286388874054\n", + "Epoch # 284\n", + "The loss calculated: 0.4610031545162201\n", + "Epoch # 285\n", + "The loss calculated: 0.4609623849391937\n", + "Epoch # 286\n", + "The loss calculated: 0.4608198404312134\n", + "Epoch # 287\n", + "The loss calculated: 0.46074378490448\n", + "Epoch # 288\n", + "The loss calculated: 0.46068280935287476\n", + "Epoch # 289\n", + "The loss calculated: 0.46061643958091736\n", + "Epoch # 290\n", + "The loss calculated: 0.4604104459285736\n", + "Epoch # 291\n", + "The loss calculated: 0.4607124626636505\n", + "Epoch # 292\n", + "The loss calculated: 0.4607458710670471\n", + "Epoch # 293\n", + "The loss calculated: 0.4601185619831085\n", + "Epoch # 294\n", + "The loss calculated: 0.460267573595047\n", + "Epoch # 295\n", + "The loss calculated: 0.4605766832828522\n", + "Epoch # 296\n", + "The loss calculated: 0.46028855443000793\n", + "Epoch # 297\n", + "The loss calculated: 0.4599803388118744\n", + "Epoch # 298\n", + "The loss calculated: 0.4600617587566376\n", + "Epoch # 299\n", + "The loss calculated: 0.46000462770462036\n", + "Epoch # 300\n", + "The loss calculated: 0.4595383405685425\n", + "Epoch # 301\n", + "The loss calculated: 0.4598424732685089\n", + "Epoch # 302\n", + "The loss calculated: 0.4597552418708801\n", + "Epoch # 303\n", + "The loss calculated: 0.45939505100250244\n", + "Epoch # 304\n", + "The loss calculated: 0.459394633769989\n", + "Epoch # 305\n", + "The loss calculated: 0.4592142403125763\n", + "Epoch # 306\n", + "The loss calculated: 0.4591156244277954\n", + "Epoch # 307\n", + "The loss calculated: 0.4590142071247101\n", + "Epoch # 308\n", + "The loss calculated: 0.45902881026268005\n", + "Epoch # 309\n", + "The loss calculated: 0.4590888023376465\n", + "Epoch # 310\n", + "The loss calculated: 0.45860469341278076\n", + "Epoch # 311\n", + "The loss calculated: 0.45852038264274597\n", + "Epoch # 312\n", + "The loss calculated: 0.4585433900356293\n", + "Epoch # 313\n", + "The loss calculated: 0.4586207866668701\n", + "Epoch # 314\n", + "The loss calculated: 0.45869746804237366\n", + "Epoch # 315\n", + "The loss calculated: 0.4585130214691162\n", + "Epoch # 316\n", + "The loss calculated: 0.45780810713768005\n", + "Epoch # 317\n", + "The loss calculated: 0.4584527313709259\n", + "Epoch # 318\n", + "The loss calculated: 0.4584985375404358\n", + "Epoch # 319\n", + "The loss calculated: 0.4577976167201996\n", + "Epoch # 320\n", + "The loss calculated: 0.4578183591365814\n", + "Epoch # 321\n", + "The loss calculated: 0.45760011672973633\n", + "Epoch # 322\n", + "The loss calculated: 0.4573518931865692\n", + "Epoch # 323\n", + "The loss calculated: 0.45755714178085327\n", + "Epoch # 324\n", + "The loss calculated: 0.4574785828590393\n", + "Epoch # 325\n", + "The loss calculated: 0.4572897255420685\n", + "Epoch # 326\n", + "The loss calculated: 0.45682093501091003\n", + "Epoch # 327\n", + "The loss calculated: 0.4571937322616577\n", + "Epoch # 328\n", + "The loss calculated: 0.45755869150161743\n", + "Epoch # 329\n", + "The loss calculated: 0.45663607120513916\n", + "Epoch # 330\n", + "The loss calculated: 0.4570084810256958\n", + "Epoch # 331\n", + "The loss calculated: 0.45761099457740784\n", + "Epoch # 332\n", + "The loss calculated: 0.456558495759964\n", + "Epoch # 333\n", + "The loss calculated: 0.45620036125183105\n", + "Epoch # 334\n", + "The loss calculated: 0.4563443958759308\n", + "Epoch # 335\n", + "The loss calculated: 0.45647644996643066\n", + "Epoch # 336\n", + "The loss calculated: 0.45592716336250305\n", + "Epoch # 337\n", + "The loss calculated: 0.455634742975235\n", + "Epoch # 338\n", + "The loss calculated: 0.4558946192264557\n", + "Epoch # 339\n", + "The loss calculated: 0.45598289370536804\n", + "Epoch # 340\n", + "The loss calculated: 0.4554951786994934\n", + "Epoch # 341\n", + "The loss calculated: 0.4554195702075958\n", + "Epoch # 342\n", + "The loss calculated: 0.4554871618747711\n", + "Epoch # 343\n", + "The loss calculated: 0.4549509584903717\n", + "Epoch # 344\n", + "The loss calculated: 0.4548693597316742\n", + "Epoch # 345\n", + "The loss calculated: 0.4558226466178894\n", + "Epoch # 346\n", + "The loss calculated: 0.45509448647499084\n", + "Epoch # 347\n", + "The loss calculated: 0.45454123616218567\n", + "Epoch # 348\n", + "The loss calculated: 0.4553173780441284\n", + "Epoch # 349\n", + "The loss calculated: 0.4548755884170532\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch # 350\n", + "The loss calculated: 0.45442134141921997\n", + "Epoch # 351\n", + "The loss calculated: 0.4545627236366272\n", + "Epoch # 352\n", + "The loss calculated: 0.4543512463569641\n", + "Epoch # 353\n", + "The loss calculated: 0.4541962146759033\n", + "Epoch # 354\n", + "The loss calculated: 0.4540751874446869\n", + "Epoch # 355\n", + "The loss calculated: 0.45386749505996704\n", + "Epoch # 356\n", + "The loss calculated: 0.4536762833595276\n", + "Epoch # 357\n", + "The loss calculated: 0.4532167911529541\n", + "Epoch # 358\n", + "The loss calculated: 0.4538520872592926\n", + "Epoch # 359\n", + "The loss calculated: 0.45413821935653687\n", + "Epoch # 360\n", + "The loss calculated: 0.45311087369918823\n", + "Epoch # 361\n", + "The loss calculated: 0.45335227251052856\n", + "Epoch # 362\n", + "The loss calculated: 0.45350611209869385\n", + "Epoch # 363\n", + "The loss calculated: 0.45265665650367737\n", + "Epoch # 364\n", + "The loss calculated: 0.4524100124835968\n", + "Epoch # 365\n", + "The loss calculated: 0.4523312449455261\n", + "Epoch # 366\n", + "The loss calculated: 0.4522554874420166\n", + "Epoch # 367\n", + "The loss calculated: 0.4523703455924988\n", + "Epoch # 368\n", + "The loss calculated: 0.4521876573562622\n", + "Epoch # 369\n", + "The loss calculated: 0.4517895579338074\n", + "Epoch # 370\n", + "The loss calculated: 0.4517730474472046\n", + "Epoch # 371\n", + "The loss calculated: 0.4515615999698639\n", + "Epoch # 372\n", + "The loss calculated: 0.45157772302627563\n", + "Epoch # 373\n", + "The loss calculated: 0.4515098035335541\n", + "Epoch # 374\n", + "The loss calculated: 0.45118868350982666\n", + "Epoch # 375\n", + "The loss calculated: 0.45117509365081787\n", + "Epoch # 376\n", + "The loss calculated: 0.45118534564971924\n", + "Epoch # 377\n", + "The loss calculated: 0.45082926750183105\n", + "Epoch # 378\n", + "The loss calculated: 0.4507909119129181\n", + "Epoch # 379\n", + "The loss calculated: 0.45116591453552246\n", + "Epoch # 380\n", + "The loss calculated: 0.45066720247268677\n", + "Epoch # 381\n", + "The loss calculated: 0.45026636123657227\n", + "Epoch # 382\n", + "The loss calculated: 0.4510788321495056\n", + "Epoch # 383\n", + "The loss calculated: 0.4512375593185425\n", + "Epoch # 384\n", + "The loss calculated: 0.450232595205307\n", + "Epoch # 385\n", + "The loss calculated: 0.44986671209335327\n", + "Epoch # 386\n", + "The loss calculated: 0.4502098262310028\n", + "Epoch # 387\n", + "The loss calculated: 0.4510081112384796\n", + "Epoch # 388\n", + "The loss calculated: 0.4499610960483551\n", + "Epoch # 389\n", + "The loss calculated: 0.44945529103279114\n", + "Epoch # 390\n", + "The loss calculated: 0.45030856132507324\n", + "Epoch # 391\n", + "The loss calculated: 0.4493928849697113\n", + "Epoch # 392\n", + "The loss calculated: 0.4490446448326111\n", + "Epoch # 393\n", + "The loss calculated: 0.4496527910232544\n", + "Epoch # 394\n", + "The loss calculated: 0.44922882318496704\n", + "Epoch # 395\n", + "The loss calculated: 0.4484827220439911\n", + "Epoch # 396\n", + "The loss calculated: 0.44952288269996643\n", + "Epoch # 397\n", + "The loss calculated: 0.4490470588207245\n", + "Epoch # 398\n", + "The loss calculated: 0.44837456941604614\n", + "Epoch # 399\n", + "The loss calculated: 0.44843804836273193\n", + "Epoch # 400\n", + "The loss calculated: 0.44825857877731323\n", + "Epoch # 401\n", + "The loss calculated: 0.4478710889816284\n", + "Epoch # 402\n", + "The loss calculated: 0.4478342533111572\n", + "Epoch # 403\n", + "The loss calculated: 0.44727033376693726\n", + "Epoch # 404\n", + "The loss calculated: 0.4474068582057953\n", + "Epoch # 405\n", + "The loss calculated: 0.4473791718482971\n", + "Epoch # 406\n", + "The loss calculated: 0.4471847414970398\n", + "Epoch # 407\n", + "The loss calculated: 0.44691354036331177\n", + "Epoch # 408\n", + "The loss calculated: 0.44677817821502686\n", + "Epoch # 409\n", + "The loss calculated: 0.4468446969985962\n", + "Epoch # 410\n", + "The loss calculated: 0.4465027153491974\n", + "Epoch # 411\n", + "The loss calculated: 0.44606125354766846\n", + "Epoch # 412\n", + "The loss calculated: 0.44594869017601013\n", + "Epoch # 413\n", + "The loss calculated: 0.4456939101219177\n", + "Epoch # 414\n", + "The loss calculated: 0.445888489484787\n", + "Epoch # 415\n", + "The loss calculated: 0.4455548822879791\n", + "Epoch # 416\n", + "The loss calculated: 0.44548290967941284\n", + "Epoch # 417\n", + "The loss calculated: 0.44544851779937744\n", + "Epoch # 418\n", + "The loss calculated: 0.44522538781166077\n", + "Epoch # 419\n", + "The loss calculated: 0.44501474499702454\n", + "Epoch # 420\n", + "The loss calculated: 0.4449530839920044\n", + "Epoch # 421\n", + "The loss calculated: 0.4445208013057709\n", + "Epoch # 422\n", + "The loss calculated: 0.4444122314453125\n", + "Epoch # 423\n", + "The loss calculated: 0.44473087787628174\n", + "Epoch # 424\n", + "The loss calculated: 0.4442698359489441\n", + "Epoch # 425\n", + "The loss calculated: 0.44399431347846985\n", + "Epoch # 426\n", + "The loss calculated: 0.4437970817089081\n", + "Epoch # 427\n", + "The loss calculated: 0.44364386796951294\n", + "Epoch # 428\n", + "The loss calculated: 0.4437081217765808\n", + "Epoch # 429\n", + "The loss calculated: 0.4436897039413452\n", + "Epoch # 430\n", + "The loss calculated: 0.44336003065109253\n", + "Epoch # 431\n", + "The loss calculated: 0.4430985748767853\n", + "Epoch # 432\n", + "The loss calculated: 0.44310933351516724\n", + "Epoch # 433\n", + "The loss calculated: 0.4428543746471405\n", + "Epoch # 434\n", + "The loss calculated: 0.44258877635002136\n", + "Epoch # 435\n", + "The loss calculated: 0.4427826404571533\n", + "Epoch # 436\n", + "The loss calculated: 0.44258812069892883\n", + "Epoch # 437\n", + "The loss calculated: 0.442533403635025\n", + "Epoch # 438\n", + "The loss calculated: 0.44270434975624084\n", + "Epoch # 439\n", + "The loss calculated: 0.4427698850631714\n", + "Epoch # 440\n", + "The loss calculated: 0.44257086515426636\n", + "Epoch # 441\n", + "The loss calculated: 0.4425719976425171\n", + "Epoch # 442\n", + "The loss calculated: 0.4420627951622009\n", + "Epoch # 443\n", + "The loss calculated: 0.4421764612197876\n", + "Epoch # 444\n", + "The loss calculated: 0.44193679094314575\n", + "Epoch # 445\n", + "The loss calculated: 0.44186508655548096\n", + "Epoch # 446\n", + "The loss calculated: 0.44136378169059753\n", + "Epoch # 447\n", + "The loss calculated: 0.44126731157302856\n", + "Epoch # 448\n", + "The loss calculated: 0.44119781255722046\n", + "Epoch # 449\n", + "The loss calculated: 0.4413573145866394\n", + "Epoch # 450\n", + "The loss calculated: 0.4411191940307617\n", + "Epoch # 451\n", + "The loss calculated: 0.4407786428928375\n", + "Epoch # 452\n", + "The loss calculated: 0.4407300055027008\n", + "Epoch # 453\n", + "The loss calculated: 0.4404629170894623\n", + "Epoch # 454\n", + "The loss calculated: 0.44039714336395264\n", + "Epoch # 455\n", + "The loss calculated: 0.44031772017478943\n", + "Epoch # 456\n", + "The loss calculated: 0.44058850407600403\n", + "Epoch # 457\n", + "The loss calculated: 0.44026416540145874\n", + "Epoch # 458\n", + "The loss calculated: 0.4401347041130066\n", + "Epoch # 459\n", + "The loss calculated: 0.44020867347717285\n", + "Epoch # 460\n", + "The loss calculated: 0.43979671597480774\n", + "Epoch # 461\n", + "The loss calculated: 0.44035604596138\n", + "Epoch # 462\n", + "The loss calculated: 0.4401366412639618\n", + "Epoch # 463\n", + "The loss calculated: 0.4404027760028839\n", + "Epoch # 464\n", + "The loss calculated: 0.439935564994812\n", + "Epoch # 465\n", + "The loss calculated: 0.4399685561656952\n", + "Epoch # 466\n", + "The loss calculated: 0.4409003257751465\n", + "Epoch # 467\n", + "The loss calculated: 0.43949607014656067\n", + "Epoch # 468\n", + "The loss calculated: 0.4398217797279358\n", + "Epoch # 469\n", + "The loss calculated: 0.43998679518699646\n", + "Epoch # 470\n", + "The loss calculated: 0.4403824508190155\n", + "Epoch # 471\n", + "The loss calculated: 0.43901607394218445\n", + "Epoch # 472\n", + "The loss calculated: 0.44028377532958984\n", + "Epoch # 473\n", + "The loss calculated: 0.4426659643650055\n", + "Epoch # 474\n", + "The loss calculated: 0.44038379192352295\n", + "Epoch # 475\n", + "The loss calculated: 0.4395928978919983\n", + "Epoch # 476\n", + "The loss calculated: 0.44086745381355286\n", + "Epoch # 477\n", + "The loss calculated: 0.43867841362953186\n", + "Epoch # 478\n", + "The loss calculated: 0.4390256404876709\n", + "Epoch # 479\n", + "The loss calculated: 0.4390667676925659\n", + "Epoch # 480\n", + "The loss calculated: 0.4384021759033203\n", + "Epoch # 481\n", + "The loss calculated: 0.4385366439819336\n", + "Epoch # 482\n", + "The loss calculated: 0.4384676516056061\n", + "Epoch # 483\n", + "The loss calculated: 0.4386775493621826\n", + "Epoch # 484\n", + "The loss calculated: 0.43819159269332886\n", + "Epoch # 485\n", + "The loss calculated: 0.4379732608795166\n", + "Epoch # 486\n", + "The loss calculated: 0.4379722476005554\n", + "Epoch # 487\n", + "The loss calculated: 0.4376266896724701\n", + "Epoch # 488\n", + "The loss calculated: 0.4373808205127716\n", + "Epoch # 489\n", + "The loss calculated: 0.43826723098754883\n", + "Epoch # 490\n", + "The loss calculated: 0.4379383623600006\n", + "Epoch # 491\n", + "The loss calculated: 0.4372965395450592\n", + "Epoch # 492\n", + "The loss calculated: 0.4375162422657013\n", + "Epoch # 493\n", + "The loss calculated: 0.43795913457870483\n", + "Epoch # 494\n", + "The loss calculated: 0.43740007281303406\n", + "Epoch # 495\n", + "The loss calculated: 0.43741703033447266\n", + "Epoch # 496\n", + "The loss calculated: 0.4373546838760376\n", + "Epoch # 497\n", + "The loss calculated: 0.4368191957473755\n", + "Epoch # 498\n", + "The loss calculated: 0.4367024898529053\n", + "Epoch # 499\n", + "The loss calculated: 0.43679192662239075\n", + "Epoch # 500\n", + "The loss calculated: 0.436893105506897\n" + ] + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "for epoch in range(1, epochs+1):\n", + " print(\"Epoch #\", epoch)\n", + " y_pred = model(Xt)\n", + "# print(y_pred)\n", + " loss = loss_fn(y_pred, Yt)\n", + " print_(loss.item())\n", + " \n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()" + ] + }, + { + "cell_type": "code", + "execution_count": 264, + "id": "45d76c95", + "metadata": {}, + "outputs": [], + "source": [ + "x_test = torch.tensor(X_test.values, dtype=torch.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 271, + "id": "5e98206b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_7802/3372075492.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", + " x = F.softmax(self.layer3(x))\n" + ] + } + ], + "source": [ + "pred = model(x_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 272, + "id": "35d64340", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1.3141002e-01, 8.6859006e-01],\n", + " [3.0172759e-16, 1.0000000e+00],\n", + " [5.9731257e-21, 1.0000000e+00],\n", + " [8.7287611e-01, 1.2712391e-01],\n", + " [3.3298880e-01, 6.6701120e-01],\n", + " [9.9992323e-01, 7.6730175e-05],\n", + " [6.9742590e-01, 3.0257410e-01],\n", + " [1.8122771e-10, 1.0000000e+00],\n", + " [8.1137923e-18, 1.0000000e+00],\n", + " [9.9391985e-01, 6.0801902e-03],\n", + " [9.9800962e-01, 1.9904438e-03],\n", + " [1.4347603e-12, 1.0000000e+00],\n", + " [8.8945550e-01, 1.1054446e-01],\n", + " [5.3068206e-19, 1.0000000e+00],\n", + " [4.4245785e-01, 5.5754209e-01],\n", + " [3.9323148e-01, 6.0676849e-01],\n", + " [5.0538932e-23, 1.0000000e+00],\n", + " [6.8482041e-01, 3.1517953e-01],\n", + " [9.9650586e-01, 3.4941665e-03],\n", + " [3.6827392e-24, 1.0000000e+00],\n", + " [3.4629088e-12, 1.0000000e+00],\n", + " [2.4781654e-11, 1.0000000e+00],\n", + " [8.4075117e-01, 1.5924890e-01],\n", + " [9.9999881e-01, 1.2382451e-06],\n", + " [9.9950111e-01, 4.9885432e-04],\n", + " [1.1888127e-14, 1.0000000e+00],\n", + " [1.5869159e-14, 1.0000000e+00],\n", + " [9.4683814e-01, 5.3161871e-02],\n", + " [7.3645154e-08, 9.9999988e-01],\n", + " [1.2287432e-11, 1.0000000e+00],\n", + " [5.7253930e-15, 1.0000000e+00],\n", + " [7.9019060e-08, 9.9999988e-01],\n", + " [5.5769521e-01, 4.4230482e-01],\n", + " [1.8103112e-14, 1.0000000e+00],\n", + " [9.9812454e-01, 1.8754901e-03],\n", + " [2.5346470e-05, 9.9997461e-01],\n", + " [1.6169167e-17, 1.0000000e+00],\n", + " [9.3050295e-01, 6.9496997e-02],\n", + " [6.1799776e-02, 9.3820024e-01],\n", + " [9.7120519e-06, 9.9999034e-01],\n", + " [9.9844283e-01, 1.5571705e-03],\n", + " [8.0438519e-01, 1.9561480e-01],\n", + " [2.0653886e-16, 1.0000000e+00],\n", + " [7.0155847e-01, 2.9844159e-01],\n", + " [9.9505252e-01, 4.9475045e-03],\n", + " [9.3824464e-01, 6.1755374e-02]], dtype=float32)" + ] + }, + "execution_count": 272, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred = pred.detach().numpy()\n", + "pred" + ] + }, + { + "cell_type": "code", + "execution_count": 269, + "id": "5c18f80f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The accuracy is 0.7391304347826086\n" + ] + } + ], + "source": [ + "print (\"The accuracy is\", accuracy_score(Y_test, np.argmax(pred, axis=1)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4638b1d", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/dockerfile b/dockerfile deleted file mode 100644 index 4e565dd..0000000 --- a/dockerfile +++ /dev/null @@ -1,23 +0,0 @@ -FROM ubuntu:latest - -RUN apt-get update --fix-missing -RUN apt install python3-pip -y -RUN apt install unzip -y -RUN apt install git -y - -RUN pip install --user kaggle -RUN pip install --user pandas - - -RUN ln -s ~/.local/bin/kaggle /usr/bin/kaggle -ENV PATH="$PATH:~/.local/bin/kaggle" -ENV KAGGLE_USERNAME="wiktorbombola" -ENV KAGGLE_KEY="" - -# RUN echo "alias kaggle='~/.local/bin/kaggle'" >> ~/.bashrc - -COPY ./script.sh ./ -COPY ./learning.py ./ - -# CMD ./script.sh 300 -# CMD ./learning.py \ No newline at end of file diff --git a/dockerfiles/docker-titanic b/dockerfiles/docker-titanic deleted file mode 100644 index 4e565dd..0000000 --- a/dockerfiles/docker-titanic +++ /dev/null @@ -1,23 +0,0 @@ -FROM ubuntu:latest - -RUN apt-get update --fix-missing -RUN apt install python3-pip -y -RUN apt install unzip -y -RUN apt install git -y - -RUN pip install --user kaggle -RUN pip install --user pandas - - -RUN ln -s ~/.local/bin/kaggle /usr/bin/kaggle -ENV PATH="$PATH:~/.local/bin/kaggle" -ENV KAGGLE_USERNAME="wiktorbombola" -ENV KAGGLE_KEY="" - -# RUN echo "alias kaggle='~/.local/bin/kaggle'" >> ~/.bashrc - -COPY ./script.sh ./ -COPY ./learning.py ./ - -# CMD ./script.sh 300 -# CMD ./learning.py \ No newline at end of file diff --git a/dockerfiles/docker-train b/dockerfiles/docker-train deleted file mode 100644 index e860450..0000000 --- a/dockerfiles/docker-train +++ /dev/null @@ -1,19 +0,0 @@ -FROM ubuntu:latest - -RUN apt-get update --fix-missing -RUN apt install python3-pip -y -RUN apt install unzip -y -RUN apt install git -y - -RUN pip install --user pandas -RUN pip install --user torch -RUN pip install --user keras -RUN pip install --user tensorflow -RUN pip install --user scikit-learn - -# RUN echo "alias kaggle='~/.local/bin/kaggle'" >> ~/.bashrc - -COPY ./learning.py ./ - -# CMD ./script.sh 300 -# CMD ./learning.pyRUN pip install --user numpy diff --git a/train/dockerfile b/train/dockerfile deleted file mode 100644 index a87600e..0000000 --- a/train/dockerfile +++ /dev/null @@ -1,19 +0,0 @@ -FROM ubuntu:latest - -RUN apt-get update --fix-missing -RUN apt install python3-pip -y -RUN apt install unzip -y -RUN apt install git -y - -RUN pip install --user pandas -RUN pip install --user torch -RUN pip install --user keras -RUN pip install --user tensorflow -RUN pip install --user scikit-learn - -# RUN echo "alias kaggle='~/.local/bin/kaggle'" >> ~/.bashrc - -COPY ./../learning.py ./ - -# CMD ./script.sh 300 -# CMD ./learning.pyRUN pip install --user numpy