sport-text-classification-ball/w2v.ipynb
Paweł Łączkowski 4ef009f252 initial commit
2024-05-08 20:27:57 +02:00

2424 lines
108 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "4a1063341b779728",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"### Importing libraries"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-08T17:53:05.720813400Z",
"start_time": "2024-05-08T17:53:04.048848Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Data manipulation\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Word2vec\n",
"from gensim.models import KeyedVectors\n",
"\n",
"# NLP\n",
"import spacy\n",
"\n",
"# Neural network\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"# Metrics\n",
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "markdown",
"id": "b4e671142b23b776",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-08T14:07:57.691373500Z",
"start_time": "2024-05-08T14:07:57.676097400Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"## Load word2vec model (100 dimensions)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1014cb8afc56b42",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-08T12:30:47.955107800Z",
"start_time": "2024-05-08T12:30:36.048633200Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"word2vec = KeyedVectors.load('word2vec/word2vec_100_3_polish.bin')"
]
},
{
"cell_type": "markdown",
"id": "ae6f5b70a6d3d61e",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"## Load spacy model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d950bf5e3fc1181e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-08T12:30:48.823353900Z",
"start_time": "2024-05-08T12:30:47.956105400Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"nlp = spacy.load('pl_core_news_sm')"
]
},
{
"cell_type": "markdown",
"id": "da4582862ff1273f",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"## Neural network model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bfcb49cbdfb816a4",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-08T18:02:01.126467Z",
"start_time": "2024-05-08T18:02:01.070554500Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"class NeuralNetwork(nn.Module):\n",
" def __init__(self, input_size, hidden_size):\n",
" super(NeuralNetwork, self).__init__()\n",
" \n",
" self.fc1 = nn.Linear(input_size, hidden_size)\n",
" self.fc2 = nn.Linear(hidden_size, hidden_size // 2)\n",
" self.fc3 = nn.Linear(hidden_size // 2, hidden_size // 4)\n",
" self.fc4 = nn.Linear(hidden_size // 4, hidden_size // 8)\n",
" self.fc5 = nn.Linear(hidden_size // 8, 1)\n",
" \n",
" self.relu = nn.ReLU()\n",
" self.sigmoid = nn.Sigmoid()\n",
" \n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = self.relu(x)\n",
" x = self.fc2(x)\n",
" x = self.relu(x)\n",
" x = self.fc3(x)\n",
" x = self.relu(x)\n",
" x = self.fc4(x)\n",
" x = self.relu(x)\n",
" x = self.fc5(x)\n",
" x = self.sigmoid(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "be7d539b55824f71",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"## Load and preprocess data"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d05f2731752949a7",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-08T17:56:24.056763100Z",
"start_time": "2024-05-08T17:56:23.356652200Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Load data\n",
"df_train = pd.read_csv('train/train.tsv', delimiter='\\t', header=None)\n",
"df_test = pd.read_csv('test-A/in.tsv', delimiter='\\t', header=None)\n",
"df_dev = pd.read_csv('dev-0/in.tsv', delimiter='\\t', header=None)\n",
"df_dev_expected = pd.read_csv('dev-0/expected.tsv', delimiter='\\t', header=None)\n",
"\n",
"# Drop invalid columns\n",
"df_train.drop(columns=2, inplace=True)\n",
"df_test.drop(columns=1, inplace=True)\n",
"\n",
"# Rename columns\n",
"df_train.columns = ['label', 'sentence']\n",
"df_test.columns = ['sentence']\n",
"df_dev.columns = ['sentence']\n",
"df_dev_expected.columns = ['label']\n",
"\n",
"# Convert sentences to lowercase\n",
"df_train['sentence'] = df_train['sentence'].apply(lambda x: x.lower())\n",
"df_test['sentence'] = df_test['sentence'].apply(lambda x: x.lower())\n",
"df_dev['sentence'] = df_dev['sentence'].apply(lambda x: x.lower())"
]
},
{
"cell_type": "markdown",
"id": "a751ad8a-23e5-4261-b873-fb3328e94de0",
"metadata": {},
"source": [
"## Sentence representation with word2vec"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2c574d88e4f47051",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-08T12:30:48.876712200Z",
"start_time": "2024-05-08T12:30:48.855785900Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"def get_sentence_representation(sentence):\n",
" doc = nlp(sentence.lower())\n",
" return np.sum([word2vec[token.text] for token in doc if token.text in word2vec], axis=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7642fb9f9bb374ae",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Train data\n",
"X_train = np.array([get_sentence_representation(sentence) for sentence in df_train['sentence']])\n",
"y_train = df_train['label'].values\n",
"\n",
"# Dev data\n",
"X_dev = np.array([get_sentence_representation(sentence) for sentence in df_dev['sentence']])\n",
"y_dev = df_dev_expected['label'].values\n",
"\n",
"# Test data\n",
"X_test = np.array([get_sentence_representation(sentence) for sentence in df_test['sentence']])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "81b46b4f880fb569",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Load previously saved data (word2vec representation of sentences)\n",
"X_train = np.load('X_train.npy')\n",
"y_train = np.load('y_train.npy')\n",
"\n",
"X_dev = np.load('X_dev.npy')\n",
"y_dev = np.load('y_dev.npy')\n",
"\n",
"X_test = np.load('X_test.npy')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "3609cf343b38a8af",
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-08T13:58:32.182241800Z",
"start_time": "2024-05-08T13:57:10.645340700Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 - loss: 0.9116247892379761\n",
"Epoch 0 - accuracy: 0.6362802641232576\n",
"Epoch 1 - loss: 0.7140407562255859\n",
"Epoch 1 - accuracy: 0.6362802641232576\n",
"Epoch 2 - loss: 0.618240237236023\n",
"Epoch 2 - accuracy: 0.6366471019809244\n",
"Epoch 3 - loss: 0.6239327788352966\n",
"Epoch 3 - accuracy: 0.6977256052824652\n",
"Epoch 4 - loss: 0.6335155367851257\n",
"Epoch 4 - accuracy: 0.7316581071166545\n",
"Epoch 5 - loss: 0.6156240701675415\n",
"Epoch 5 - accuracy: 0.7279897285399853\n",
"Epoch 6 - loss: 0.5953847169876099\n",
"Epoch 6 - accuracy: 0.694424064563463\n",
"Epoch 7 - loss: 0.5810463428497314\n",
"Epoch 7 - accuracy: 0.6766324284666178\n",
"Epoch 8 - loss: 0.5640420317649841\n",
"Epoch 8 - accuracy: 0.6856199559794571\n",
"Epoch 9 - loss: 0.5385629534721375\n",
"Epoch 9 - accuracy: 0.733125458547322\n",
"Epoch 10 - loss: 0.5121918320655823\n",
"Epoch 10 - accuracy: 0.7870506236243581\n",
"Epoch 11 - loss: 0.49032482504844666\n",
"Epoch 11 - accuracy: 0.80997798972854\n",
"Epoch 12 - loss: 0.45911359786987305\n",
"Epoch 12 - accuracy: 0.8105282465150404\n",
"Epoch 13 - loss: 0.43574512004852295\n",
"Epoch 13 - accuracy: 0.8046588407923698\n",
"Epoch 14 - loss: 0.4038088619709015\n",
"Epoch 14 - accuracy: 0.8297872340425532\n",
"Epoch 15 - loss: 0.38172227144241333\n",
"Epoch 15 - accuracy: 0.8406089508437271\n",
"Epoch 16 - loss: 0.34902533888816833\n",
"Epoch 16 - accuracy: 0.8547322083639032\n",
"Epoch 17 - loss: 0.32240182161331177\n",
"Epoch 17 - accuracy: 0.8679383712399119\n",
"Epoch 18 - loss: 0.311143159866333\n",
"Epoch 18 - accuracy: 0.8706896551724138\n",
"Epoch 19 - loss: 0.3087221682071686\n",
"Epoch 19 - accuracy: 0.8650036683785767\n",
"Epoch 20 - loss: 0.2847115099430084\n",
"Epoch 20 - accuracy: 0.8798606016140865\n",
"Epoch 21 - loss: 0.2667429447174072\n",
"Epoch 21 - accuracy: 0.8906823184152605\n",
"Epoch 22 - loss: 0.27080655097961426\n",
"Epoch 22 - accuracy: 0.8857300073367571\n",
"Epoch 23 - loss: 0.2571971118450165\n",
"Epoch 23 - accuracy: 0.8890315480557593\n",
"Epoch 24 - loss: 0.2406086027622223\n",
"Epoch 24 - accuracy: 0.9002201027146002\n",
"Epoch 25 - loss: 0.24326132237911224\n",
"Epoch 25 - accuracy: 0.8972853998532648\n",
"Epoch 26 - loss: 0.23878683149814606\n",
"Epoch 26 - accuracy: 0.898936170212766\n",
"Epoch 27 - loss: 0.2246064394712448\n",
"Epoch 27 - accuracy: 0.9060895084372708\n",
"Epoch 28 - loss: 0.2251722663640976\n",
"Epoch 28 - accuracy: 0.9055392516507703\n",
"Epoch 29 - loss: 0.22923922538757324\n",
"Epoch 29 - accuracy: 0.9022377109317682\n",
"Epoch 30 - loss: 0.2179824411869049\n",
"Epoch 30 - accuracy: 0.9092076302274394\n",
"Epoch 31 - loss: 0.21235334873199463\n",
"Epoch 31 - accuracy: 0.9106749816581071\n",
"Epoch 32 - loss: 0.21798621118068695\n",
"Epoch 32 - accuracy: 0.9048055759354365\n",
"Epoch 33 - loss: 0.2159065306186676\n",
"Epoch 33 - accuracy: 0.913059427732942\n",
"Epoch 34 - loss: 0.20762306451797485\n",
"Epoch 34 - accuracy: 0.9117754952311079\n",
"Epoch 35 - loss: 0.20568040013313293\n",
"Epoch 35 - accuracy: 0.9115920763022743\n",
"Epoch 36 - loss: 0.20884327590465546\n",
"Epoch 36 - accuracy: 0.9161775495231108\n",
"Epoch 37 - loss: 0.2060328722000122\n",
"Epoch 37 - accuracy: 0.9104915627292737\n",
"Epoch 38 - loss: 0.19855882227420807\n",
"Epoch 38 - accuracy: 0.9163609684519443\n",
"Epoch 39 - loss: 0.19888578355312347\n",
"Epoch 39 - accuracy: 0.9174614820249449\n",
"Epoch 40 - loss: 0.20154142379760742\n",
"Epoch 40 - accuracy: 0.9148936170212766\n",
"Epoch 41 - loss: 0.19735418260097504\n",
"Epoch 41 - accuracy: 0.9167278063096111\n",
"Epoch 42 - loss: 0.19287440180778503\n",
"Epoch 42 - accuracy: 0.9181951577402788\n",
"Epoch 43 - loss: 0.19381892681121826\n",
"Epoch 43 - accuracy: 0.9181951577402788\n",
"Epoch 44 - loss: 0.19489999115467072\n",
"Epoch 44 - accuracy: 0.9183785766691123\n",
"Epoch 45 - loss: 0.191724494099617\n",
"Epoch 45 - accuracy: 0.9187454145267792\n",
"Epoch 46 - loss: 0.187753826379776\n",
"Epoch 46 - accuracy: 0.9196625091709465\n",
"Epoch 47 - loss: 0.18749786913394928\n",
"Epoch 47 - accuracy: 0.9196625091709465\n",
"Epoch 48 - loss: 0.18893089890480042\n",
"Epoch 48 - accuracy: 0.9194790902421129\n",
"Epoch 49 - loss: 0.18769621849060059\n",
"Epoch 49 - accuracy: 0.921496698459281\n",
"Epoch 50 - loss: 0.18460743129253387\n",
"Epoch 50 - accuracy: 0.9224137931034483\n",
"Epoch 51 - loss: 0.18182355165481567\n",
"Epoch 51 - accuracy: 0.9222303741746148\n",
"Epoch 52 - loss: 0.18156687915325165\n",
"Epoch 52 - accuracy: 0.923881144534116\n",
"Epoch 53 - loss: 0.183009535074234\n",
"Epoch 53 - accuracy: 0.9231474688187821\n",
"Epoch 54 - loss: 0.1834733784198761\n",
"Epoch 54 - accuracy: 0.9236977256052825\n",
"Epoch 55 - loss: 0.18309442698955536\n",
"Epoch 55 - accuracy: 0.9231474688187821\n",
"Epoch 56 - loss: 0.18024547398090363\n",
"Epoch 56 - accuracy: 0.9257153338224505\n",
"Epoch 57 - loss: 0.17757220566272736\n",
"Epoch 57 - accuracy: 0.9253484959647835\n",
"Epoch 58 - loss: 0.17630420625209808\n",
"Epoch 58 - accuracy: 0.9257153338224505\n",
"Epoch 59 - loss: 0.17675919830799103\n",
"Epoch 59 - accuracy: 0.9269992663242846\n",
"Epoch 60 - loss: 0.17855069041252136\n",
"Epoch 60 - accuracy: 0.9253484959647835\n",
"Epoch 61 - loss: 0.17923885583877563\n",
"Epoch 61 - accuracy: 0.9264490095377843\n",
"Epoch 62 - loss: 0.17824675142765045\n",
"Epoch 62 - accuracy: 0.9249816581071166\n",
"Epoch 63 - loss: 0.17506806552410126\n",
"Epoch 63 - accuracy: 0.927916360968452\n",
"Epoch 64 - loss: 0.17224562168121338\n",
"Epoch 64 - accuracy: 0.9292002934702861\n",
"Epoch 65 - loss: 0.17135733366012573\n",
"Epoch 65 - accuracy: 0.9290168745414527\n",
"Epoch 66 - loss: 0.1721159666776657\n",
"Epoch 66 - accuracy: 0.9303008070432869\n",
"Epoch 67 - loss: 0.17357103526592255\n",
"Epoch 67 - accuracy: 0.927549523110785\n",
"Epoch 68 - loss: 0.17438305914402008\n",
"Epoch 68 - accuracy: 0.9286500366837858\n",
"Epoch 69 - loss: 0.17328017950057983\n",
"Epoch 69 - accuracy: 0.9271826852531181\n",
"Epoch 70 - loss: 0.1704489141702652\n",
"Epoch 70 - accuracy: 0.9301173881144534\n",
"Epoch 71 - loss: 0.16770099103450775\n",
"Epoch 71 - accuracy: 0.9325018341892883\n",
"Epoch 72 - loss: 0.1664922684431076\n",
"Epoch 72 - accuracy: 0.9321349963316214\n",
"Epoch 73 - loss: 0.16692861914634705\n",
"Epoch 73 - accuracy: 0.9308510638297872\n",
"Epoch 74 - loss: 0.1685953289270401\n",
"Epoch 74 - accuracy: 0.9292002934702861\n",
"Epoch 75 - loss: 0.17085592448711395\n",
"Epoch 75 - accuracy: 0.9282831988261189\n",
"Epoch 76 - loss: 0.1718451827764511\n",
"Epoch 76 - accuracy: 0.9282831988261189\n",
"Epoch 77 - loss: 0.17076171934604645\n",
"Epoch 77 - accuracy: 0.9277329420396185\n",
"Epoch 78 - loss: 0.16577975451946259\n",
"Epoch 78 - accuracy: 0.9308510638297872\n",
"Epoch 79 - loss: 0.16221792995929718\n",
"Epoch 79 - accuracy: 0.9352531181217901\n",
"Epoch 80 - loss: 0.16273713111877441\n",
"Epoch 80 - accuracy: 0.9330520909757887\n",
"Epoch 81 - loss: 0.16511954367160797\n",
"Epoch 81 - accuracy: 0.9312179016874541\n",
"Epoch 82 - loss: 0.165837362408638\n",
"Epoch 82 - accuracy: 0.9301173881144534\n",
"Epoch 83 - loss: 0.16341626644134521\n",
"Epoch 83 - accuracy: 0.9317681584739546\n",
"Epoch 84 - loss: 0.16035114228725433\n",
"Epoch 84 - accuracy: 0.9336023477622891\n",
"Epoch 85 - loss: 0.1588079333305359\n",
"Epoch 85 - accuracy: 0.9356199559794571\n",
"Epoch 86 - loss: 0.15911996364593506\n",
"Epoch 86 - accuracy: 0.9337857666911226\n",
"Epoch 87 - loss: 0.16078558564186096\n",
"Epoch 87 - accuracy: 0.9321349963316214\n",
"Epoch 88 - loss: 0.16469542682170868\n",
"Epoch 88 - accuracy: 0.9315847395451211\n",
"Epoch 89 - loss: 0.1707155555486679\n",
"Epoch 89 - accuracy: 0.9284666177549523\n",
"Epoch 90 - loss: 0.17066214978694916\n",
"Epoch 90 - accuracy: 0.927916360968452\n",
"Epoch 91 - loss: 0.16468870639801025\n",
"Epoch 91 - accuracy: 0.9317681584739546\n",
"Epoch 92 - loss: 0.15593402087688446\n",
"Epoch 92 - accuracy: 0.9361702127659575\n",
"Epoch 93 - loss: 0.15852655470371246\n",
"Epoch 93 - accuracy: 0.9332355099046221\n",
"Epoch 94 - loss: 0.1638629287481308\n",
"Epoch 94 - accuracy: 0.9317681584739546\n",
"Epoch 95 - loss: 0.15902242064476013\n",
"Epoch 95 - accuracy: 0.9341526045487895\n",
"Epoch 96 - loss: 0.15367668867111206\n",
"Epoch 96 - accuracy: 0.9372707263389581\n",
"Epoch 97 - loss: 0.1557561308145523\n",
"Epoch 97 - accuracy: 0.9356199559794571\n",
"Epoch 98 - loss: 0.15850110352039337\n",
"Epoch 98 - accuracy: 0.9343360234776229\n",
"Epoch 99 - loss: 0.1556081622838974\n",
"Epoch 99 - accuracy: 0.9356199559794571\n",
"Epoch 100 - loss: 0.15204134583473206\n",
"Epoch 100 - accuracy: 0.9369038884812912\n",
"Epoch 101 - loss: 0.1527155041694641\n",
"Epoch 101 - accuracy: 0.935986793837124\n",
"Epoch 102 - loss: 0.15533211827278137\n",
"Epoch 102 - accuracy: 0.9352531181217901\n",
"Epoch 103 - loss: 0.1561749130487442\n",
"Epoch 103 - accuracy: 0.9345194424064563\n",
"Epoch 104 - loss: 0.15393942594528198\n",
"Epoch 104 - accuracy: 0.9369038884812912\n",
"Epoch 105 - loss: 0.15091024339199066\n",
"Epoch 105 - accuracy: 0.9356199559794571\n",
"Epoch 106 - loss: 0.14931201934814453\n",
"Epoch 106 - accuracy: 0.9391049156272927\n",
"Epoch 107 - loss: 0.14995944499969482\n",
"Epoch 107 - accuracy: 0.9374541452677916\n",
"Epoch 108 - loss: 0.15163299441337585\n",
"Epoch 108 - accuracy: 0.9358033749082906\n",
"Epoch 109 - loss: 0.15240135788917542\n",
"Epoch 109 - accuracy: 0.9365370506236244\n",
"Epoch 110 - loss: 0.1520850509405136\n",
"Epoch 110 - accuracy: 0.9347028613352898\n",
"Epoch 111 - loss: 0.1503436118364334\n",
"Epoch 111 - accuracy: 0.9365370506236244\n",
"Epoch 112 - loss: 0.14841999113559723\n",
"Epoch 112 - accuracy: 0.9365370506236244\n",
"Epoch 113 - loss: 0.14677779376506805\n",
"Epoch 113 - accuracy: 0.9396551724137931\n",
"Epoch 114 - loss: 0.14633171260356903\n",
"Epoch 114 - accuracy: 0.94002201027146\n",
"Epoch 115 - loss: 0.14686140418052673\n",
"Epoch 115 - accuracy: 0.9378209831254586\n",
"Epoch 116 - loss: 0.14767663180828094\n",
"Epoch 116 - accuracy: 0.9369038884812912\n",
"Epoch 117 - loss: 0.14996299147605896\n",
"Epoch 117 - accuracy: 0.9352531181217901\n",
"Epoch 118 - loss: 0.15362921357154846\n",
"Epoch 118 - accuracy: 0.9345194424064563\n",
"Epoch 119 - loss: 0.15794260799884796\n",
"Epoch 119 - accuracy: 0.9334189288334556\n",
"Epoch 120 - loss: 0.16259002685546875\n",
"Epoch 120 - accuracy: 0.9306676449009538\n",
"Epoch 121 - loss: 0.1544409841299057\n",
"Epoch 121 - accuracy: 0.9336023477622891\n",
"Epoch 122 - loss: 0.14526315033435822\n",
"Epoch 122 - accuracy: 0.9383712399119589\n",
"Epoch 123 - loss: 0.1445532888174057\n",
"Epoch 123 - accuracy: 0.9385546588407924\n",
"Epoch 124 - loss: 0.15006223320960999\n",
"Epoch 124 - accuracy: 0.9358033749082906\n",
"Epoch 125 - loss: 0.1499934196472168\n",
"Epoch 125 - accuracy: 0.9370873074101247\n",
"Epoch 126 - loss: 0.1433618664741516\n",
"Epoch 126 - accuracy: 0.9396551724137931\n",
"Epoch 127 - loss: 0.14303633570671082\n",
"Epoch 127 - accuracy: 0.94002201027146\n",
"Epoch 128 - loss: 0.14751005172729492\n",
"Epoch 128 - accuracy: 0.9370873074101247\n",
"Epoch 129 - loss: 0.1479104906320572\n",
"Epoch 129 - accuracy: 0.9361702127659575\n",
"Epoch 130 - loss: 0.14386673271656036\n",
"Epoch 130 - accuracy: 0.9378209831254586\n",
"Epoch 131 - loss: 0.1406869888305664\n",
"Epoch 131 - accuracy: 0.9422230374174615\n",
"Epoch 132 - loss: 0.14180149137973785\n",
"Epoch 132 - accuracy: 0.94002201027146\n",
"Epoch 133 - loss: 0.1451810747385025\n",
"Epoch 133 - accuracy: 0.9378209831254586\n",
"Epoch 134 - loss: 0.14710348844528198\n",
"Epoch 134 - accuracy: 0.9374541452677916\n",
"Epoch 135 - loss: 0.14716187119483948\n",
"Epoch 135 - accuracy: 0.9370873074101247\n",
"Epoch 136 - loss: 0.14351975917816162\n",
"Epoch 136 - accuracy: 0.938004402054292\n",
"Epoch 137 - loss: 0.1399369090795517\n",
"Epoch 137 - accuracy: 0.9411225238444607\n",
"Epoch 138 - loss: 0.13867513835430145\n",
"Epoch 138 - accuracy: 0.9422230374174615\n",
"Epoch 139 - loss: 0.14000000059604645\n",
"Epoch 139 - accuracy: 0.9411225238444607\n",
"Epoch 140 - loss: 0.14197778701782227\n",
"Epoch 140 - accuracy: 0.9391049156272927\n",
"Epoch 141 - loss: 0.1430107057094574\n",
"Epoch 141 - accuracy: 0.9378209831254586\n",
"Epoch 142 - loss: 0.14269666373729706\n",
"Epoch 142 - accuracy: 0.9378209831254586\n",
"Epoch 143 - loss: 0.1405152529478073\n",
"Epoch 143 - accuracy: 0.9413059427732942\n",
"Epoch 144 - loss: 0.13812802731990814\n",
"Epoch 144 - accuracy: 0.9422230374174615\n",
"Epoch 145 - loss: 0.13681475818157196\n",
"Epoch 145 - accuracy: 0.9431401320616287\n",
"Epoch 146 - loss: 0.13699214160442352\n",
"Epoch 146 - accuracy: 0.9431401320616287\n",
"Epoch 147 - loss: 0.13806407153606415\n",
"Epoch 147 - accuracy: 0.9422230374174615\n",
"Epoch 148 - loss: 0.14000557363033295\n",
"Epoch 148 - accuracy: 0.9398385913426266\n",
"Epoch 149 - loss: 0.14341016113758087\n",
"Epoch 149 - accuracy: 0.9374541452677916\n",
"Epoch 150 - loss: 0.14785771071910858\n",
"Epoch 150 - accuracy: 0.9374541452677916\n",
"Epoch 151 - loss: 0.15585456788539886\n",
"Epoch 151 - accuracy: 0.9323184152604549\n",
"Epoch 152 - loss: 0.1521560400724411\n",
"Epoch 152 - accuracy: 0.9356199559794571\n",
"Epoch 153 - loss: 0.14386101067066193\n",
"Epoch 153 - accuracy: 0.9378209831254586\n",
"Epoch 154 - loss: 0.1349945366382599\n",
"Epoch 154 - accuracy: 0.9433235509904622\n",
"Epoch 155 - loss: 0.14102381467819214\n",
"Epoch 155 - accuracy: 0.9405722670579604\n",
"Epoch 156 - loss: 0.14648419618606567\n",
"Epoch 156 - accuracy: 0.935986793837124\n",
"Epoch 157 - loss: 0.13682815432548523\n",
"Epoch 157 - accuracy: 0.9418561995597946\n",
"Epoch 158 - loss: 0.13537228107452393\n",
"Epoch 158 - accuracy: 0.9424064563462949\n",
"Epoch 159 - loss: 0.14196345210075378\n",
"Epoch 159 - accuracy: 0.9389214966984593\n",
"Epoch 160 - loss: 0.13914623856544495\n",
"Epoch 160 - accuracy: 0.9407556859867938\n",
"Epoch 161 - loss: 0.13375478982925415\n",
"Epoch 161 - accuracy: 0.9440572267057961\n",
"Epoch 162 - loss: 0.13424597680568695\n",
"Epoch 162 - accuracy: 0.9425898752751284\n",
"Epoch 163 - loss: 0.1385672688484192\n",
"Epoch 163 - accuracy: 0.9411225238444607\n",
"Epoch 164 - loss: 0.14130806922912598\n",
"Epoch 164 - accuracy: 0.9394717534849596\n",
"Epoch 165 - loss: 0.1381911337375641\n",
"Epoch 165 - accuracy: 0.9411225238444607\n",
"Epoch 166 - loss: 0.13435527682304382\n",
"Epoch 166 - accuracy: 0.9424064563462949\n",
"Epoch 167 - loss: 0.13232551515102386\n",
"Epoch 167 - accuracy: 0.9435069699192957\n",
"Epoch 168 - loss: 0.13362722098827362\n",
"Epoch 168 - accuracy: 0.9431401320616287\n",
"Epoch 169 - loss: 0.136326402425766\n",
"Epoch 169 - accuracy: 0.9429567131327953\n",
"Epoch 170 - loss: 0.13702066242694855\n",
"Epoch 170 - accuracy: 0.9422230374174615\n",
"Epoch 171 - loss: 0.13637429475784302\n",
"Epoch 171 - accuracy: 0.9425898752751284\n",
"Epoch 172 - loss: 0.13346955180168152\n",
"Epoch 172 - accuracy: 0.9424064563462949\n",
"Epoch 173 - loss: 0.1314283013343811\n",
"Epoch 173 - accuracy: 0.9438738077769626\n",
"Epoch 174 - loss: 0.1313088834285736\n",
"Epoch 174 - accuracy: 0.9440572267057961\n",
"Epoch 175 - loss: 0.13260416686534882\n",
"Epoch 175 - accuracy: 0.9429567131327953\n",
"Epoch 176 - loss: 0.13419018685817719\n",
"Epoch 176 - accuracy: 0.9431401320616287\n",
"Epoch 177 - loss: 0.1347162425518036\n",
"Epoch 177 - accuracy: 0.9429567131327953\n",
"Epoch 178 - loss: 0.13532517850399017\n",
"Epoch 178 - accuracy: 0.9433235509904622\n",
"Epoch 179 - loss: 0.1342029869556427\n",
"Epoch 179 - accuracy: 0.9438738077769626\n",
"Epoch 180 - loss: 0.1333332359790802\n",
"Epoch 180 - accuracy: 0.9436903888481292\n",
"Epoch 181 - loss: 0.13152720034122467\n",
"Epoch 181 - accuracy: 0.9436903888481292\n",
"Epoch 182 - loss: 0.1302594542503357\n",
"Epoch 182 - accuracy: 0.9446074834922964\n",
"Epoch 183 - loss: 0.1294606775045395\n",
"Epoch 183 - accuracy: 0.9438738077769626\n",
"Epoch 184 - loss: 0.1292601078748703\n",
"Epoch 184 - accuracy: 0.9446074834922964\n",
"Epoch 185 - loss: 0.1295052468776703\n",
"Epoch 185 - accuracy: 0.9447909024211298\n",
"Epoch 186 - loss: 0.13040119409561157\n",
"Epoch 186 - accuracy: 0.9433235509904622\n",
"Epoch 187 - loss: 0.13275794684886932\n",
"Epoch 187 - accuracy: 0.9433235509904622\n",
"Epoch 188 - loss: 0.1388574093580246\n",
"Epoch 188 - accuracy: 0.9418561995597946\n",
"Epoch 189 - loss: 0.15836314857006073\n",
"Epoch 189 - accuracy: 0.9328686720469552\n",
"Epoch 190 - loss: 0.17806877195835114\n",
"Epoch 190 - accuracy: 0.9268158473954512\n",
"Epoch 191 - loss: 0.20557238161563873\n",
"Epoch 191 - accuracy: 0.919112252384446\n",
"Epoch 192 - loss: 0.13659578561782837\n",
"Epoch 192 - accuracy: 0.9422230374174615\n",
"Epoch 193 - loss: 0.1556226760149002\n",
"Epoch 193 - accuracy: 0.9330520909757887\n",
"Epoch 194 - loss: 0.172604039311409\n",
"Epoch 194 - accuracy: 0.9277329420396185\n",
"Epoch 195 - loss: 0.13436183333396912\n",
"Epoch 195 - accuracy: 0.9422230374174615\n",
"Epoch 196 - loss: 0.17623932659626007\n",
"Epoch 196 - accuracy: 0.9244314013206163\n",
"Epoch 197 - loss: 0.13648562133312225\n",
"Epoch 197 - accuracy: 0.9409391049156273\n",
"Epoch 198 - loss: 0.1569497287273407\n",
"Epoch 198 - accuracy: 0.9323184152604549\n",
"Epoch 199 - loss: 0.13791748881340027\n",
"Epoch 199 - accuracy: 0.9411225238444607\n",
"Epoch 200 - loss: 0.14613977074623108\n",
"Epoch 200 - accuracy: 0.9385546588407924\n",
"Epoch 201 - loss: 0.14238546788692474\n",
"Epoch 201 - accuracy: 0.9383712399119589\n",
"Epoch 202 - loss: 0.13876666128635406\n",
"Epoch 202 - accuracy: 0.9387380777696258\n",
"Epoch 203 - loss: 0.14433924853801727\n",
"Epoch 203 - accuracy: 0.9394717534849596\n",
"Epoch 204 - loss: 0.13128788769245148\n",
"Epoch 204 - accuracy: 0.9446074834922964\n",
"Epoch 205 - loss: 0.1427002102136612\n",
"Epoch 205 - accuracy: 0.9369038884812912\n",
"Epoch 206 - loss: 0.12868227064609528\n",
"Epoch 206 - accuracy: 0.9447909024211298\n",
"Epoch 207 - loss: 0.14365875720977783\n",
"Epoch 207 - accuracy: 0.94002201027146\n",
"Epoch 208 - loss: 0.13067907094955444\n",
"Epoch 208 - accuracy: 0.9433235509904622\n",
"Epoch 209 - loss: 0.1349562257528305\n",
"Epoch 209 - accuracy: 0.94002201027146\n",
"Epoch 210 - loss: 0.13604091107845306\n",
"Epoch 210 - accuracy: 0.9436903888481292\n",
"Epoch 211 - loss: 0.1275433897972107\n",
"Epoch 211 - accuracy: 0.9451577402787967\n",
"Epoch 212 - loss: 0.13538497686386108\n",
"Epoch 212 - accuracy: 0.9414893617021277\n",
"Epoch 213 - loss: 0.1287802904844284\n",
"Epoch 213 - accuracy: 0.9438738077769626\n",
"Epoch 214 - loss: 0.1299477219581604\n",
"Epoch 214 - accuracy: 0.9447909024211298\n",
"Epoch 215 - loss: 0.13284087181091309\n",
"Epoch 215 - accuracy: 0.9427732942039618\n",
"Epoch 216 - loss: 0.12653198838233948\n",
"Epoch 216 - accuracy: 0.9453411592076302\n",
"Epoch 217 - loss: 0.13136284053325653\n",
"Epoch 217 - accuracy: 0.9440572267057961\n",
"Epoch 218 - loss: 0.13037842512130737\n",
"Epoch 218 - accuracy: 0.9440572267057961\n",
"Epoch 219 - loss: 0.1261424571275711\n",
"Epoch 219 - accuracy: 0.9451577402787967\n",
"Epoch 220 - loss: 0.13126927614212036\n",
"Epoch 220 - accuracy: 0.9438738077769626\n",
"Epoch 221 - loss: 0.12874549627304077\n",
"Epoch 221 - accuracy: 0.9451577402787967\n",
"Epoch 222 - loss: 0.12589335441589355\n",
"Epoch 222 - accuracy: 0.9451577402787967\n",
"Epoch 223 - loss: 0.13035798072814941\n",
"Epoch 223 - accuracy: 0.944424064563463\n",
"Epoch 224 - loss: 0.12775404751300812\n",
"Epoch 224 - accuracy: 0.9455245781364637\n",
"Epoch 225 - loss: 0.12526558339595795\n",
"Epoch 225 - accuracy: 0.9460748349229641\n",
"Epoch 226 - loss: 0.1290864199399948\n",
"Epoch 226 - accuracy: 0.944424064563463\n",
"Epoch 227 - loss: 0.1273193061351776\n",
"Epoch 227 - accuracy: 0.9458914159941306\n",
"Epoch 228 - loss: 0.12475359439849854\n",
"Epoch 228 - accuracy: 0.9460748349229641\n",
"Epoch 229 - loss: 0.12801365554332733\n",
"Epoch 229 - accuracy: 0.9453411592076302\n",
"Epoch 230 - loss: 0.12760193645954132\n",
"Epoch 230 - accuracy: 0.9451577402787967\n",
"Epoch 231 - loss: 0.12460287660360336\n",
"Epoch 231 - accuracy: 0.9455245781364637\n",
"Epoch 232 - loss: 0.126450777053833\n",
"Epoch 232 - accuracy: 0.9440572267057961\n",
"Epoch 233 - loss: 0.1277255415916443\n",
"Epoch 233 - accuracy: 0.9458914159941306\n",
"Epoch 234 - loss: 0.1250246912240982\n",
"Epoch 234 - accuracy: 0.9453411592076302\n",
"Epoch 235 - loss: 0.124518021941185\n",
"Epoch 235 - accuracy: 0.9457079970652972\n",
"Epoch 236 - loss: 0.12645651400089264\n",
"Epoch 236 - accuracy: 0.9442406456346295\n",
"Epoch 237 - loss: 0.12585106492042542\n",
"Epoch 237 - accuracy: 0.9447909024211298\n",
"Epoch 238 - loss: 0.12404952198266983\n",
"Epoch 238 - accuracy: 0.9462582538517975\n",
"Epoch 239 - loss: 0.12434504926204681\n",
"Epoch 239 - accuracy: 0.9460748349229641\n",
"Epoch 240 - loss: 0.12558409571647644\n",
"Epoch 240 - accuracy: 0.9451577402787967\n",
"Epoch 241 - loss: 0.12516562640666962\n",
"Epoch 241 - accuracy: 0.9460748349229641\n",
"Epoch 242 - loss: 0.1238495483994484\n",
"Epoch 242 - accuracy: 0.9462582538517975\n",
"Epoch 243 - loss: 0.12349341064691544\n",
"Epoch 243 - accuracy: 0.9462582538517975\n",
"Epoch 244 - loss: 0.124192014336586\n",
"Epoch 244 - accuracy: 0.946441672780631\n",
"Epoch 245 - loss: 0.12476740032434464\n",
"Epoch 245 - accuracy: 0.9458914159941306\n",
"Epoch 246 - loss: 0.12425507605075836\n",
"Epoch 246 - accuracy: 0.9460748349229641\n",
"Epoch 247 - loss: 0.12344306707382202\n",
"Epoch 247 - accuracy: 0.946441672780631\n",
"Epoch 248 - loss: 0.12292434275150299\n",
"Epoch 248 - accuracy: 0.9458914159941306\n",
"Epoch 249 - loss: 0.12305624783039093\n",
"Epoch 249 - accuracy: 0.9455245781364637\n",
"Epoch 250 - loss: 0.12364161759614944\n",
"Epoch 250 - accuracy: 0.9473587674247982\n",
"Epoch 251 - loss: 0.12392168492078781\n",
"Epoch 251 - accuracy: 0.9460748349229641\n",
"Epoch 252 - loss: 0.12429433315992355\n",
"Epoch 252 - accuracy: 0.9475421863536317\n",
"Epoch 253 - loss: 0.12410783767700195\n",
"Epoch 253 - accuracy: 0.9460748349229641\n",
"Epoch 254 - loss: 0.12418146431446075\n",
"Epoch 254 - accuracy: 0.9471753484959647\n",
"Epoch 255 - loss: 0.12378914654254913\n",
"Epoch 255 - accuracy: 0.9460748349229641\n",
"Epoch 256 - loss: 0.12378736585378647\n",
"Epoch 256 - accuracy: 0.9475421863536317\n",
"Epoch 257 - loss: 0.12342856079339981\n",
"Epoch 257 - accuracy: 0.9462582538517975\n",
"Epoch 258 - loss: 0.12355145812034607\n",
"Epoch 258 - accuracy: 0.9477256052824652\n",
"Epoch 259 - loss: 0.12343250960111618\n",
"Epoch 259 - accuracy: 0.9462582538517975\n",
"Epoch 260 - loss: 0.1240469440817833\n",
"Epoch 260 - accuracy: 0.9468085106382979\n",
"Epoch 261 - loss: 0.12469983100891113\n",
"Epoch 261 - accuracy: 0.9471753484959647\n",
"Epoch 262 - loss: 0.12708620727062225\n",
"Epoch 262 - accuracy: 0.9442406456346295\n",
"Epoch 263 - loss: 0.13097906112670898\n",
"Epoch 263 - accuracy: 0.944424064563463\n",
"Epoch 264 - loss: 0.13735362887382507\n",
"Epoch 264 - accuracy: 0.9425898752751284\n",
"Epoch 265 - loss: 0.14637614786624908\n",
"Epoch 265 - accuracy: 0.9389214966984593\n",
"Epoch 266 - loss: 0.13928058743476868\n",
"Epoch 266 - accuracy: 0.9418561995597946\n",
"Epoch 267 - loss: 0.1263727992773056\n",
"Epoch 267 - accuracy: 0.9457079970652972\n",
"Epoch 268 - loss: 0.12248501926660538\n",
"Epoch 268 - accuracy: 0.946441672780631\n",
"Epoch 269 - loss: 0.13156425952911377\n",
"Epoch 269 - accuracy: 0.9436903888481292\n",
"Epoch 270 - loss: 0.13035646080970764\n",
"Epoch 270 - accuracy: 0.9438738077769626\n",
"Epoch 271 - loss: 0.12129738926887512\n",
"Epoch 271 - accuracy: 0.9482758620689655\n",
"Epoch 272 - loss: 0.12735018134117126\n",
"Epoch 272 - accuracy: 0.9453411592076302\n",
"Epoch 273 - loss: 0.13254329562187195\n",
"Epoch 273 - accuracy: 0.944424064563463\n",
"Epoch 274 - loss: 0.12437407672405243\n",
"Epoch 274 - accuracy: 0.9469919295671313\n",
"Epoch 275 - loss: 0.12198557704687119\n",
"Epoch 275 - accuracy: 0.9479090242112986\n",
"Epoch 276 - loss: 0.1282014101743698\n",
"Epoch 276 - accuracy: 0.9449743213499633\n",
"Epoch 277 - loss: 0.13050922751426697\n",
"Epoch 277 - accuracy: 0.9436903888481292\n",
"Epoch 278 - loss: 0.12561355531215668\n",
"Epoch 278 - accuracy: 0.9460748349229641\n",
"Epoch 279 - loss: 0.12143143266439438\n",
"Epoch 279 - accuracy: 0.9482758620689655\n",
"Epoch 280 - loss: 0.12208808958530426\n",
"Epoch 280 - accuracy: 0.9488261188554659\n",
"Epoch 281 - loss: 0.12578213214874268\n",
"Epoch 281 - accuracy: 0.9468085106382979\n",
"Epoch 282 - loss: 0.12954838573932648\n",
"Epoch 282 - accuracy: 0.9442406456346295\n",
"Epoch 283 - loss: 0.12968361377716064\n",
"Epoch 283 - accuracy: 0.9458914159941306\n",
"Epoch 284 - loss: 0.125925675034523\n",
"Epoch 284 - accuracy: 0.9453411592076302\n",
"Epoch 285 - loss: 0.1212596595287323\n",
"Epoch 285 - accuracy: 0.9466250917094644\n",
"Epoch 286 - loss: 0.12094567716121674\n",
"Epoch 286 - accuracy: 0.9477256052824652\n",
"Epoch 287 - loss: 0.12425583600997925\n",
"Epoch 287 - accuracy: 0.9466250917094644\n",
"Epoch 288 - loss: 0.12605202198028564\n",
"Epoch 288 - accuracy: 0.946441672780631\n",
"Epoch 289 - loss: 0.1254625767469406\n",
"Epoch 289 - accuracy: 0.9453411592076302\n",
"Epoch 290 - loss: 0.12259259074926376\n",
"Epoch 290 - accuracy: 0.9455245781364637\n",
"Epoch 291 - loss: 0.12090922147035599\n",
"Epoch 291 - accuracy: 0.9482758620689655\n",
"Epoch 292 - loss: 0.12081866711378098\n",
"Epoch 292 - accuracy: 0.9482758620689655\n",
"Epoch 293 - loss: 0.12197820842266083\n",
"Epoch 293 - accuracy: 0.9451577402787967\n",
"Epoch 294 - loss: 0.12467528134584427\n",
"Epoch 294 - accuracy: 0.9460748349229641\n",
"Epoch 295 - loss: 0.12644042074680328\n",
"Epoch 295 - accuracy: 0.9473587674247982\n",
"Epoch 296 - loss: 0.12967751920223236\n",
"Epoch 296 - accuracy: 0.9431401320616287\n",
"Epoch 297 - loss: 0.13314160704612732\n",
"Epoch 297 - accuracy: 0.9447909024211298\n",
"Epoch 298 - loss: 0.13466300070285797\n",
"Epoch 298 - accuracy: 0.9429567131327953\n",
"Epoch 299 - loss: 0.13335949182510376\n",
"Epoch 299 - accuracy: 0.9438738077769626\n",
"Epoch 300 - loss: 0.12456800788640976\n",
"Epoch 300 - accuracy: 0.946441672780631\n",
"Epoch 301 - loss: 0.12012847512960434\n",
"Epoch 301 - accuracy: 0.9493763756419662\n",
"Epoch 302 - loss: 0.12398750334978104\n",
"Epoch 302 - accuracy: 0.9460748349229641\n",
"Epoch 303 - loss: 0.12710987031459808\n",
"Epoch 303 - accuracy: 0.944424064563463\n",
"Epoch 304 - loss: 0.12448311597108841\n",
"Epoch 304 - accuracy: 0.9468085106382979\n",
"Epoch 305 - loss: 0.12061852961778641\n",
"Epoch 305 - accuracy: 0.948459280997799\n",
"Epoch 306 - loss: 0.12071719765663147\n",
"Epoch 306 - accuracy: 0.9482758620689655\n",
"Epoch 307 - loss: 0.12375523895025253\n",
"Epoch 307 - accuracy: 0.9462582538517975\n",
"Epoch 308 - loss: 0.12774750590324402\n",
"Epoch 308 - accuracy: 0.9442406456346295\n",
"Epoch 309 - loss: 0.12890563905239105\n",
"Epoch 309 - accuracy: 0.9460748349229641\n",
"Epoch 310 - loss: 0.12809687852859497\n",
"Epoch 310 - accuracy: 0.9438738077769626\n",
"Epoch 311 - loss: 0.12481006234884262\n",
"Epoch 311 - accuracy: 0.9475421863536317\n",
"Epoch 312 - loss: 0.12165655940771103\n",
"Epoch 312 - accuracy: 0.9473587674247982\n",
"Epoch 313 - loss: 0.11955662071704865\n",
"Epoch 313 - accuracy: 0.9491929567131328\n",
"Epoch 314 - loss: 0.12011021375656128\n",
"Epoch 314 - accuracy: 0.9490095377842993\n",
"Epoch 315 - loss: 0.1227300688624382\n",
"Epoch 315 - accuracy: 0.9473587674247982\n",
"Epoch 316 - loss: 0.1252894401550293\n",
"Epoch 316 - accuracy: 0.9479090242112986\n",
"Epoch 317 - loss: 0.12868621945381165\n",
"Epoch 317 - accuracy: 0.9438738077769626\n",
"Epoch 318 - loss: 0.13128605484962463\n",
"Epoch 318 - accuracy: 0.9455245781364637\n",
"Epoch 319 - loss: 0.13100145757198334\n",
"Epoch 319 - accuracy: 0.9440572267057961\n",
"Epoch 320 - loss: 0.12706202268600464\n",
"Epoch 320 - accuracy: 0.9460748349229641\n",
"Epoch 321 - loss: 0.12146887183189392\n",
"Epoch 321 - accuracy: 0.9480924431401321\n",
"Epoch 322 - loss: 0.11979363113641739\n",
"Epoch 322 - accuracy: 0.948459280997799\n",
"Epoch 323 - loss: 0.12291908264160156\n",
"Epoch 323 - accuracy: 0.9466250917094644\n",
"Epoch 324 - loss: 0.1266350895166397\n",
"Epoch 324 - accuracy: 0.9446074834922964\n",
"Epoch 325 - loss: 0.12731537222862244\n",
"Epoch 325 - accuracy: 0.9468085106382979\n",
"Epoch 326 - loss: 0.12540048360824585\n",
"Epoch 326 - accuracy: 0.9451577402787967\n",
"Epoch 327 - loss: 0.12146314978599548\n",
"Epoch 327 - accuracy: 0.9473587674247982\n",
"Epoch 328 - loss: 0.11976297944784164\n",
"Epoch 328 - accuracy: 0.9490095377842993\n",
"Epoch 329 - loss: 0.12041870504617691\n",
"Epoch 329 - accuracy: 0.9482758620689655\n",
"Epoch 330 - loss: 0.12252318114042282\n",
"Epoch 330 - accuracy: 0.9468085106382979\n",
"Epoch 331 - loss: 0.1266428381204605\n",
"Epoch 331 - accuracy: 0.9442406456346295\n",
"Epoch 332 - loss: 0.1323375105857849\n",
"Epoch 332 - accuracy: 0.9447909024211298\n",
"Epoch 333 - loss: 0.14109158515930176\n",
"Epoch 333 - accuracy: 0.9407556859867938\n",
"Epoch 334 - loss: 0.15532760322093964\n",
"Epoch 334 - accuracy: 0.9350696991929567\n",
"Epoch 335 - loss: 0.14378340542316437\n",
"Epoch 335 - accuracy: 0.9392883345561261\n",
"Epoch 336 - loss: 0.12407119572162628\n",
"Epoch 336 - accuracy: 0.9479090242112986\n",
"Epoch 337 - loss: 0.12199515104293823\n",
"Epoch 337 - accuracy: 0.9482758620689655\n",
"Epoch 338 - loss: 0.1333143711090088\n",
"Epoch 338 - accuracy: 0.9446074834922964\n",
"Epoch 339 - loss: 0.1249939352273941\n",
"Epoch 339 - accuracy: 0.9469919295671313\n",
"Epoch 340 - loss: 0.11966632306575775\n",
"Epoch 340 - accuracy: 0.948459280997799\n",
"Epoch 341 - loss: 0.13108626008033752\n",
"Epoch 341 - accuracy: 0.9449743213499633\n",
"Epoch 342 - loss: 0.1272350698709488\n",
"Epoch 342 - accuracy: 0.9458914159941306\n",
"Epoch 343 - loss: 0.11856188625097275\n",
"Epoch 343 - accuracy: 0.9479090242112986\n",
"Epoch 344 - loss: 0.12906348705291748\n",
"Epoch 344 - accuracy: 0.9447909024211298\n",
"Epoch 345 - loss: 0.1326865404844284\n",
"Epoch 345 - accuracy: 0.9425898752751284\n",
"Epoch 346 - loss: 0.12189553678035736\n",
"Epoch 346 - accuracy: 0.9479090242112986\n",
"Epoch 347 - loss: 0.12244117259979248\n",
"Epoch 347 - accuracy: 0.9466250917094644\n",
"Epoch 348 - loss: 0.13159716129302979\n",
"Epoch 348 - accuracy: 0.9449743213499633\n",
"Epoch 349 - loss: 0.13086934387683868\n",
"Epoch 349 - accuracy: 0.9440572267057961\n",
"Epoch 350 - loss: 0.12248321622610092\n",
"Epoch 350 - accuracy: 0.948459280997799\n",
"Epoch 351 - loss: 0.11868999153375626\n",
"Epoch 351 - accuracy: 0.9499266324284666\n",
"Epoch 352 - loss: 0.1236417293548584\n",
"Epoch 352 - accuracy: 0.9451577402787967\n",
"Epoch 353 - loss: 0.127396360039711\n",
"Epoch 353 - accuracy: 0.9469919295671313\n",
"Epoch 354 - loss: 0.12435289472341537\n",
"Epoch 354 - accuracy: 0.9451577402787967\n",
"Epoch 355 - loss: 0.11912697553634644\n",
"Epoch 355 - accuracy: 0.9499266324284666\n",
"Epoch 356 - loss: 0.11984141916036606\n",
"Epoch 356 - accuracy: 0.9490095377842993\n",
"Epoch 357 - loss: 0.12398801743984222\n",
"Epoch 357 - accuracy: 0.9462582538517975\n",
"Epoch 358 - loss: 0.12404925376176834\n",
"Epoch 358 - accuracy: 0.9479090242112986\n",
"Epoch 359 - loss: 0.12187787145376205\n",
"Epoch 359 - accuracy: 0.9477256052824652\n",
"Epoch 360 - loss: 0.11939382553100586\n",
"Epoch 360 - accuracy: 0.950476889214967\n",
"Epoch 361 - loss: 0.12002267688512802\n",
"Epoch 361 - accuracy: 0.9491929567131328\n",
"Epoch 362 - loss: 0.12263120710849762\n",
"Epoch 362 - accuracy: 0.9471753484959647\n",
"Epoch 363 - loss: 0.12382364273071289\n",
"Epoch 363 - accuracy: 0.948459280997799\n",
"Epoch 364 - loss: 0.12410726398229599\n",
"Epoch 364 - accuracy: 0.9462582538517975\n",
"Epoch 365 - loss: 0.12248091399669647\n",
"Epoch 365 - accuracy: 0.9480924431401321\n",
"Epoch 366 - loss: 0.12116307765245438\n",
"Epoch 366 - accuracy: 0.9482758620689655\n",
"Epoch 367 - loss: 0.11971309036016464\n",
"Epoch 367 - accuracy: 0.950476889214967\n",
"Epoch 368 - loss: 0.11947835236787796\n",
"Epoch 368 - accuracy: 0.950476889214967\n",
"Epoch 369 - loss: 0.11992073804140091\n",
"Epoch 369 - accuracy: 0.9502934702861335\n",
"Epoch 370 - loss: 0.12049725651741028\n",
"Epoch 370 - accuracy: 0.9491929567131328\n",
"Epoch 371 - loss: 0.12239193171262741\n",
"Epoch 371 - accuracy: 0.9477256052824652\n",
"Epoch 372 - loss: 0.12434761226177216\n",
"Epoch 372 - accuracy: 0.9479090242112986\n",
"Epoch 373 - loss: 0.1290699988603592\n",
"Epoch 373 - accuracy: 0.9431401320616287\n",
"Epoch 374 - loss: 0.1359148472547531\n",
"Epoch 374 - accuracy: 0.9433235509904622\n",
"Epoch 375 - loss: 0.1423913836479187\n",
"Epoch 375 - accuracy: 0.9392883345561261\n",
"Epoch 376 - loss: 0.14881592988967896\n",
"Epoch 376 - accuracy: 0.9391049156272927\n",
"Epoch 377 - loss: 0.13204756379127502\n",
"Epoch 377 - accuracy: 0.9446074834922964\n",
"Epoch 378 - loss: 0.11842061579227448\n",
"Epoch 378 - accuracy: 0.9508437270726339\n",
"Epoch 379 - loss: 0.12411734461784363\n",
"Epoch 379 - accuracy: 0.9477256052824652\n",
"Epoch 380 - loss: 0.12863674759864807\n",
"Epoch 380 - accuracy: 0.9442406456346295\n",
"Epoch 381 - loss: 0.12139783054590225\n",
"Epoch 381 - accuracy: 0.9490095377842993\n",
"Epoch 382 - loss: 0.11913666129112244\n",
"Epoch 382 - accuracy: 0.9493763756419662\n",
"Epoch 383 - loss: 0.12667888402938843\n",
"Epoch 383 - accuracy: 0.9457079970652972\n",
"Epoch 384 - loss: 0.1274142861366272\n",
"Epoch 384 - accuracy: 0.9468085106382979\n",
"Epoch 385 - loss: 0.12151997536420822\n",
"Epoch 385 - accuracy: 0.9482758620689655\n",
"Epoch 386 - loss: 0.11988431960344315\n",
"Epoch 386 - accuracy: 0.9497432134996332\n",
"Epoch 387 - loss: 0.12450485676527023\n",
"Epoch 387 - accuracy: 0.9477256052824652\n",
"Epoch 388 - loss: 0.1301063597202301\n",
"Epoch 388 - accuracy: 0.9438738077769626\n",
"Epoch 389 - loss: 0.13058441877365112\n",
"Epoch 389 - accuracy: 0.9453411592076302\n",
"Epoch 390 - loss: 0.1274021863937378\n",
"Epoch 390 - accuracy: 0.9460748349229641\n",
"Epoch 391 - loss: 0.12132485955953598\n",
"Epoch 391 - accuracy: 0.9497432134996332\n",
"Epoch 392 - loss: 0.11876960098743439\n",
"Epoch 392 - accuracy: 0.950476889214967\n",
"Epoch 393 - loss: 0.12090307474136353\n",
"Epoch 393 - accuracy: 0.9482758620689655\n",
"Epoch 394 - loss: 0.12387574464082718\n",
"Epoch 394 - accuracy: 0.9482758620689655\n",
"Epoch 395 - loss: 0.12585753202438354\n",
"Epoch 395 - accuracy: 0.9436903888481292\n",
"Epoch 396 - loss: 0.12338998168706894\n",
"Epoch 396 - accuracy: 0.9480924431401321\n",
"Epoch 397 - loss: 0.12095915526151657\n",
"Epoch 397 - accuracy: 0.9482758620689655\n",
"Epoch 398 - loss: 0.11895456165075302\n",
"Epoch 398 - accuracy: 0.9512105649303008\n",
"Epoch 399 - loss: 0.11982899159193039\n",
"Epoch 399 - accuracy: 0.9510271460014673\n",
"Epoch 400 - loss: 0.1227630227804184\n",
"Epoch 400 - accuracy: 0.9475421863536317\n",
"Epoch 401 - loss: 0.1245056763291359\n",
"Epoch 401 - accuracy: 0.9477256052824652\n",
"Epoch 402 - loss: 0.1266520470380783\n",
"Epoch 402 - accuracy: 0.9451577402787967\n",
"Epoch 403 - loss: 0.12723323702812195\n",
"Epoch 403 - accuracy: 0.9469919295671313\n",
"Epoch 404 - loss: 0.12804493308067322\n",
"Epoch 404 - accuracy: 0.9447909024211298\n",
"Epoch 405 - loss: 0.12689587473869324\n",
"Epoch 405 - accuracy: 0.9468085106382979\n",
"Epoch 406 - loss: 0.1248333603143692\n",
"Epoch 406 - accuracy: 0.946441672780631\n",
"Epoch 407 - loss: 0.12143062800168991\n",
"Epoch 407 - accuracy: 0.9490095377842993\n",
"Epoch 408 - loss: 0.11979366093873978\n",
"Epoch 408 - accuracy: 0.9510271460014673\n",
"Epoch 409 - loss: 0.11989187449216843\n",
"Epoch 409 - accuracy: 0.9506603081438004\n",
"Epoch 410 - loss: 0.12118031084537506\n",
"Epoch 410 - accuracy: 0.9493763756419662\n",
"Epoch 411 - loss: 0.12442609667778015\n",
"Epoch 411 - accuracy: 0.9468085106382979\n",
"Epoch 412 - loss: 0.12685492634773254\n",
"Epoch 412 - accuracy: 0.9473587674247982\n",
"Epoch 413 - loss: 0.13157705962657928\n",
"Epoch 413 - accuracy: 0.9440572267057961\n",
"Epoch 414 - loss: 0.13618524372577667\n",
"Epoch 414 - accuracy: 0.9429567131327953\n",
"Epoch 415 - loss: 0.13877074420452118\n",
"Epoch 415 - accuracy: 0.9422230374174615\n",
"Epoch 416 - loss: 0.13766561448574066\n",
"Epoch 416 - accuracy: 0.942039618488628\n",
"Epoch 417 - loss: 0.12581269443035126\n",
"Epoch 417 - accuracy: 0.9466250917094644\n",
"Epoch 418 - loss: 0.11829851567745209\n",
"Epoch 418 - accuracy: 0.9502934702861335\n",
"Epoch 419 - loss: 0.12286898493766785\n",
"Epoch 419 - accuracy: 0.948459280997799\n",
"Epoch 420 - loss: 0.12826034426689148\n",
"Epoch 420 - accuracy: 0.9442406456346295\n",
"Epoch 421 - loss: 0.12521657347679138\n",
"Epoch 421 - accuracy: 0.9469919295671313\n",
"Epoch 422 - loss: 0.12012328952550888\n",
"Epoch 422 - accuracy: 0.9488261188554659\n",
"Epoch 423 - loss: 0.12050732970237732\n",
"Epoch 423 - accuracy: 0.9501100513573001\n",
"Epoch 424 - loss: 0.12443410605192184\n",
"Epoch 424 - accuracy: 0.9475421863536317\n",
"Epoch 425 - loss: 0.1285167634487152\n",
"Epoch 425 - accuracy: 0.9457079970652972\n",
"Epoch 426 - loss: 0.12714825570583344\n",
"Epoch 426 - accuracy: 0.9468085106382979\n",
"Epoch 427 - loss: 0.12536460161209106\n",
"Epoch 427 - accuracy: 0.946441672780631\n",
"Epoch 428 - loss: 0.12127231061458588\n",
"Epoch 428 - accuracy: 0.9502934702861335\n",
"Epoch 429 - loss: 0.11975359916687012\n",
"Epoch 429 - accuracy: 0.9497432134996332\n",
"Epoch 430 - loss: 0.11963523924350739\n",
"Epoch 430 - accuracy: 0.9506603081438004\n",
"Epoch 431 - loss: 0.12077296525239944\n",
"Epoch 431 - accuracy: 0.950476889214967\n",
"Epoch 432 - loss: 0.12409121543169022\n",
"Epoch 432 - accuracy: 0.9473587674247982\n",
"Epoch 433 - loss: 0.12621550261974335\n",
"Epoch 433 - accuracy: 0.9477256052824652\n",
"Epoch 434 - loss: 0.1316927820444107\n",
"Epoch 434 - accuracy: 0.944424064563463\n",
"Epoch 435 - loss: 0.13731232285499573\n",
"Epoch 435 - accuracy: 0.9427732942039618\n",
"Epoch 436 - loss: 0.14055246114730835\n",
"Epoch 436 - accuracy: 0.9411225238444607\n",
"Epoch 437 - loss: 0.13798841834068298\n",
"Epoch 437 - accuracy: 0.942039618488628\n",
"Epoch 438 - loss: 0.12439721822738647\n",
"Epoch 438 - accuracy: 0.9471753484959647\n",
"Epoch 439 - loss: 0.11913619935512543\n",
"Epoch 439 - accuracy: 0.9508437270726339\n",
"Epoch 440 - loss: 0.12641268968582153\n",
"Epoch 440 - accuracy: 0.9473587674247982\n",
"Epoch 441 - loss: 0.12938538193702698\n",
"Epoch 441 - accuracy: 0.9451577402787967\n",
"Epoch 442 - loss: 0.12298136949539185\n",
"Epoch 442 - accuracy: 0.9490095377842993\n",
"Epoch 443 - loss: 0.11922991275787354\n",
"Epoch 443 - accuracy: 0.9506603081438004\n",
"Epoch 444 - loss: 0.12371832132339478\n",
"Epoch 444 - accuracy: 0.9477256052824652\n",
"Epoch 445 - loss: 0.12833170592784882\n",
"Epoch 445 - accuracy: 0.9466250917094644\n",
"Epoch 446 - loss: 0.13009405136108398\n",
"Epoch 446 - accuracy: 0.9451577402787967\n",
"Epoch 447 - loss: 0.12564228475093842\n",
"Epoch 447 - accuracy: 0.9482758620689655\n",
"Epoch 448 - loss: 0.12178538739681244\n",
"Epoch 448 - accuracy: 0.9499266324284666\n",
"Epoch 449 - loss: 0.1192527562379837\n",
"Epoch 449 - accuracy: 0.9502934702861335\n",
"Epoch 450 - loss: 0.12046285718679428\n",
"Epoch 450 - accuracy: 0.950476889214967\n",
"Epoch 451 - loss: 0.12494070082902908\n",
"Epoch 451 - accuracy: 0.9479090242112986\n",
"Epoch 452 - loss: 0.12780679762363434\n",
"Epoch 452 - accuracy: 0.9480924431401321\n",
"Epoch 453 - loss: 0.13254830241203308\n",
"Epoch 453 - accuracy: 0.9446074834922964\n",
"Epoch 454 - loss: 0.13431498408317566\n",
"Epoch 454 - accuracy: 0.9442406456346295\n",
"Epoch 455 - loss: 0.13251695036888123\n",
"Epoch 455 - accuracy: 0.944424064563463\n",
"Epoch 456 - loss: 0.1252819150686264\n",
"Epoch 456 - accuracy: 0.9471753484959647\n",
"Epoch 457 - loss: 0.1196482852101326\n",
"Epoch 457 - accuracy: 0.9510271460014673\n",
"Epoch 458 - loss: 0.12122555822134018\n",
"Epoch 458 - accuracy: 0.9495597945707998\n",
"Epoch 459 - loss: 0.12617750465869904\n",
"Epoch 459 - accuracy: 0.9471753484959647\n",
"Epoch 460 - loss: 0.1289263814687729\n",
"Epoch 460 - accuracy: 0.9453411592076302\n",
"Epoch 461 - loss: 0.12516246736049652\n",
"Epoch 461 - accuracy: 0.9473587674247982\n",
"Epoch 462 - loss: 0.12158451974391937\n",
"Epoch 462 - accuracy: 0.9501100513573001\n",
"Epoch 463 - loss: 0.1201198622584343\n",
"Epoch 463 - accuracy: 0.9510271460014673\n",
"Epoch 464 - loss: 0.1223447248339653\n",
"Epoch 464 - accuracy: 0.9501100513573001\n",
"Epoch 465 - loss: 0.12768422067165375\n",
"Epoch 465 - accuracy: 0.9468085106382979\n",
"Epoch 466 - loss: 0.13118062913417816\n",
"Epoch 466 - accuracy: 0.946441672780631\n",
"Epoch 467 - loss: 0.13691848516464233\n",
"Epoch 467 - accuracy: 0.944424064563463\n",
"Epoch 468 - loss: 0.1410398930311203\n",
"Epoch 468 - accuracy: 0.9414893617021277\n",
"Epoch 469 - loss: 0.1366601139307022\n",
"Epoch 469 - accuracy: 0.9436903888481292\n",
"Epoch 470 - loss: 0.12491223961114883\n",
"Epoch 470 - accuracy: 0.948459280997799\n",
"Epoch 471 - loss: 0.11888165026903152\n",
"Epoch 471 - accuracy: 0.9519442406456347\n",
"Epoch 472 - loss: 0.12632633745670319\n",
"Epoch 472 - accuracy: 0.9458914159941306\n",
"Epoch 473 - loss: 0.13229218125343323\n",
"Epoch 473 - accuracy: 0.9457079970652972\n",
"Epoch 474 - loss: 0.1283317506313324\n",
"Epoch 474 - accuracy: 0.9458914159941306\n",
"Epoch 475 - loss: 0.11960683017969131\n",
"Epoch 475 - accuracy: 0.9506603081438004\n",
"Epoch 476 - loss: 0.12003237754106522\n",
"Epoch 476 - accuracy: 0.950476889214967\n",
"Epoch 477 - loss: 0.1274898201227188\n",
"Epoch 477 - accuracy: 0.9471753484959647\n",
"Epoch 478 - loss: 0.12885287404060364\n",
"Epoch 478 - accuracy: 0.9475421863536317\n",
"Epoch 479 - loss: 0.12669768929481506\n",
"Epoch 479 - accuracy: 0.9480924431401321\n",
"Epoch 480 - loss: 0.1212262511253357\n",
"Epoch 480 - accuracy: 0.9510271460014673\n",
"Epoch 481 - loss: 0.1199721023440361\n",
"Epoch 481 - accuracy: 0.9515774027879678\n",
"Epoch 482 - loss: 0.12229137867689133\n",
"Epoch 482 - accuracy: 0.9502934702861335\n",
"Epoch 483 - loss: 0.1249641627073288\n",
"Epoch 483 - accuracy: 0.9479090242112986\n",
"Epoch 484 - loss: 0.12990516424179077\n",
"Epoch 484 - accuracy: 0.9471753484959647\n",
"Epoch 485 - loss: 0.13138818740844727\n",
"Epoch 485 - accuracy: 0.946441672780631\n",
"Epoch 486 - loss: 0.13237611949443817\n",
"Epoch 486 - accuracy: 0.9455245781364637\n",
"Epoch 487 - loss: 0.12806807458400726\n",
"Epoch 487 - accuracy: 0.9473587674247982\n",
"Epoch 488 - loss: 0.12325626611709595\n",
"Epoch 488 - accuracy: 0.9495597945707998\n",
"Epoch 489 - loss: 0.11965951323509216\n",
"Epoch 489 - accuracy: 0.9519442406456347\n",
"Epoch 490 - loss: 0.12142913788557053\n",
"Epoch 490 - accuracy: 0.9506603081438004\n",
"Epoch 491 - loss: 0.1262865960597992\n",
"Epoch 491 - accuracy: 0.9473587674247982\n",
"Epoch 492 - loss: 0.12745560705661774\n",
"Epoch 492 - accuracy: 0.9468085106382979\n",
"Epoch 493 - loss: 0.12800352275371552\n",
"Epoch 493 - accuracy: 0.948459280997799\n",
"Epoch 494 - loss: 0.12423637509346008\n",
"Epoch 494 - accuracy: 0.9482758620689655\n",
"Epoch 495 - loss: 0.12243915349245071\n",
"Epoch 495 - accuracy: 0.9501100513573001\n",
"Epoch 496 - loss: 0.12085887044668198\n",
"Epoch 496 - accuracy: 0.9515774027879678\n",
"Epoch 497 - loss: 0.12142859399318695\n",
"Epoch 497 - accuracy: 0.9508437270726339\n",
"Epoch 498 - loss: 0.12384403496980667\n",
"Epoch 498 - accuracy: 0.9495597945707998\n",
"Epoch 499 - loss: 0.12585672736167908\n",
"Epoch 499 - accuracy: 0.9475421863536317\n",
"0.9475421863536317\n",
"Best epoch: 471\n",
"Max accuracy: 0.9519442406456347\n"
]
}
],
"source": [
"model = NeuralNetwork(100, 256)\n",
"\n",
"criterion = nn.BCELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)\n",
"\n",
"X_train_tensor = torch.from_numpy(X_train).float()\n",
"y_train_tensor = torch.from_numpy(y_train).float().view(-1, 1)\n",
"\n",
"X_dev_tensor = torch.from_numpy(X_dev).float()\n",
"y_dev_tensor = torch.from_numpy(y_dev).float().view(-1, 1)\n",
"\n",
"model.train()\n",
"\n",
"best_epoch = 0\n",
"max_accuracy = 0.0\n",
"\n",
"for epoch in range(500):\n",
" \n",
" optimizer.zero_grad()\n",
" y_pred = model(X_train_tensor)\n",
" loss = criterion(y_pred, y_train_tensor)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" # dev loss\n",
" with torch.no_grad():\n",
" y_pred = model(X_dev_tensor)\n",
" loss = criterion(y_pred, y_dev_tensor)\n",
" accuracy = accuracy_score(y_dev_tensor, np.where(y_pred > 0.5, 1, 0))\n",
"\n",
" if max_accuracy < accuracy:\n",
" best_epoch = epoch\n",
" max_accuracy = accuracy\n",
" \n",
" print(f\"Epoch {epoch} - loss: {loss}\")\n",
" print(f\"Epoch {epoch} - accuracy: {accuracy}\")\n",
" \n",
"model.eval()\n",
"\n",
"with torch.no_grad():\n",
" y_pred = model(X_dev_tensor)\n",
" y_pred = np.where(y_pred > 0.5, 1, 0)\n",
" accuracy = accuracy_score(y_dev_tensor, y_pred)\n",
" print(accuracy)\n",
"\n",
"print(f\"Best epoch: {best_epoch}\")\n",
"print(f\"Max accuracy: {max_accuracy}\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "00f4dbb4-92d5-4c8c-b55b-9f6046e5894e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 - loss: 0.6600890755653381\n",
"Epoch 0 - accuracy: 0.636463683052091\n",
"Epoch 1 - loss: 0.6251927614212036\n",
"Epoch 1 - accuracy: 0.7168011738811445\n",
"Epoch 2 - loss: 0.6044067740440369\n",
"Epoch 2 - accuracy: 0.7454145267791636\n",
"Epoch 3 - loss: 0.5796983242034912\n",
"Epoch 3 - accuracy: 0.6925898752751284\n",
"Epoch 4 - loss: 0.5561812520027161\n",
"Epoch 4 - accuracy: 0.7035950110051358\n",
"Epoch 5 - loss: 0.5200029015541077\n",
"Epoch 5 - accuracy: 0.8017241379310345\n",
"Epoch 6 - loss: 0.4880651533603668\n",
"Epoch 6 - accuracy: 0.8242846661775495\n",
"Epoch 7 - loss: 0.4609140455722809\n",
"Epoch 7 - accuracy: 0.7848495964783566\n",
"Epoch 8 - loss: 0.4216255843639374\n",
"Epoch 8 - accuracy: 0.8431768158473955\n",
"Epoch 9 - loss: 0.39770105481147766\n",
"Epoch 9 - accuracy: 0.8464783565663977\n",
"Epoch 10 - loss: 0.38234901428222656\n",
"Epoch 10 - accuracy: 0.8275862068965517\n",
"Epoch 11 - loss: 0.3460651636123657\n",
"Epoch 11 - accuracy: 0.8617021276595744\n",
"Epoch 12 - loss: 0.3165454864501953\n",
"Epoch 12 - accuracy: 0.8769258987527513\n",
"Epoch 13 - loss: 0.31047651171684265\n",
"Epoch 13 - accuracy: 0.871606749816581\n",
"Epoch 14 - loss: 0.31299567222595215\n",
"Epoch 14 - accuracy: 0.8670212765957447\n",
"Epoch 15 - loss: 0.28543993830680847\n",
"Epoch 15 - accuracy: 0.8826118855465884\n",
"Epoch 16 - loss: 0.2618979215621948\n",
"Epoch 16 - accuracy: 0.8971019809244314\n",
"Epoch 17 - loss: 0.2782493829727173\n",
"Epoch 17 - accuracy: 0.8794937637564196\n",
"Epoch 18 - loss: 0.26039671897888184\n",
"Epoch 18 - accuracy: 0.8947175348495965\n",
"Epoch 19 - loss: 0.2395632565021515\n",
"Epoch 19 - accuracy: 0.9042553191489362\n",
"Epoch 20 - loss: 0.26059266924858093\n",
"Epoch 20 - accuracy: 0.8882978723404256\n",
"Epoch 21 - loss: 0.23744292557239532\n",
"Epoch 21 - accuracy: 0.9038884812912693\n",
"Epoch 22 - loss: 0.23012538254261017\n",
"Epoch 22 - accuracy: 0.9053558327219369\n",
"Epoch 23 - loss: 0.24461759626865387\n",
"Epoch 23 - accuracy: 0.8932501834189288\n",
"Epoch 24 - loss: 0.21968719363212585\n",
"Epoch 24 - accuracy: 0.9101247248716068\n",
"Epoch 25 - loss: 0.22737234830856323\n",
"Epoch 25 - accuracy: 0.9084739545121057\n",
"Epoch 26 - loss: 0.22357280552387238\n",
"Epoch 26 - accuracy: 0.9035216434336023\n",
"Epoch 27 - loss: 0.2126060277223587\n",
"Epoch 27 - accuracy: 0.9088407923697726\n",
"Epoch 28 - loss: 0.22030684351921082\n",
"Epoch 28 - accuracy: 0.9097578870139399\n",
"Epoch 29 - loss: 0.20767514407634735\n",
"Epoch 29 - accuracy: 0.9126925898752751\n",
"Epoch 30 - loss: 0.21199911832809448\n",
"Epoch 30 - accuracy: 0.9103081438004402\n",
"Epoch 31 - loss: 0.21017472445964813\n",
"Epoch 31 - accuracy: 0.911041819515774\n",
"Epoch 32 - loss: 0.20167312026023865\n",
"Epoch 32 - accuracy: 0.9180117388114454\n",
"Epoch 33 - loss: 0.20891046524047852\n",
"Epoch 33 - accuracy: 0.913059427732942\n",
"Epoch 34 - loss: 0.20060847699642181\n",
"Epoch 34 - accuracy: 0.9165443873807777\n",
"Epoch 35 - loss: 0.19901244342327118\n",
"Epoch 35 - accuracy: 0.917094644167278\n",
"Epoch 36 - loss: 0.20174746215343475\n",
"Epoch 36 - accuracy: 0.9145267791636097\n",
"Epoch 37 - loss: 0.19371414184570312\n",
"Epoch 37 - accuracy: 0.9200293470286134\n",
"Epoch 38 - loss: 0.1961180418729782\n",
"Epoch 38 - accuracy: 0.9185619955979457\n",
"Epoch 39 - loss: 0.19458524882793427\n",
"Epoch 39 - accuracy: 0.9194790902421129\n",
"Epoch 40 - loss: 0.18955352902412415\n",
"Epoch 40 - accuracy: 0.9213132795304475\n",
"Epoch 41 - loss: 0.19220395386219025\n",
"Epoch 41 - accuracy: 0.9203961848862803\n",
"Epoch 42 - loss: 0.18957814574241638\n",
"Epoch 42 - accuracy: 0.9211298606016141\n",
"Epoch 43 - loss: 0.18620409071445465\n",
"Epoch 43 - accuracy: 0.9227806309611152\n",
"Epoch 44 - loss: 0.18816402554512024\n",
"Epoch 44 - accuracy: 0.921496698459281\n",
"Epoch 45 - loss: 0.18725506961345673\n",
"Epoch 45 - accuracy: 0.9231474688187821\n",
"Epoch 46 - loss: 0.18374845385551453\n",
"Epoch 46 - accuracy: 0.9249816581071166\n",
"Epoch 47 - loss: 0.18296627700328827\n",
"Epoch 47 - accuracy: 0.9242479823917829\n",
"Epoch 48 - loss: 0.18405590951442719\n",
"Epoch 48 - accuracy: 0.9225972120322817\n",
"Epoch 49 - loss: 0.18362443149089813\n",
"Epoch 49 - accuracy: 0.923881144534116\n",
"Epoch 50 - loss: 0.17992615699768066\n",
"Epoch 50 - accuracy: 0.9246148202494497\n",
"Epoch 51 - loss: 0.17868903279304504\n",
"Epoch 51 - accuracy: 0.9257153338224505\n",
"Epoch 52 - loss: 0.17998185753822327\n",
"Epoch 52 - accuracy: 0.9268158473954512\n",
"Epoch 53 - loss: 0.180391326546669\n",
"Epoch 53 - accuracy: 0.9240645634629494\n",
"Epoch 54 - loss: 0.1815360188484192\n",
"Epoch 54 - accuracy: 0.9249816581071166\n",
"Epoch 55 - loss: 0.1782231479883194\n",
"Epoch 55 - accuracy: 0.9249816581071166\n",
"Epoch 56 - loss: 0.1754382997751236\n",
"Epoch 56 - accuracy: 0.9292002934702861\n",
"Epoch 57 - loss: 0.173308864235878\n",
"Epoch 57 - accuracy: 0.9304842259721203\n",
"Epoch 58 - loss: 0.1733696460723877\n",
"Epoch 58 - accuracy: 0.9286500366837858\n",
"Epoch 59 - loss: 0.17508764564990997\n",
"Epoch 59 - accuracy: 0.9288334556126192\n",
"Epoch 60 - loss: 0.17530100047588348\n",
"Epoch 60 - accuracy: 0.9271826852531181\n",
"Epoch 61 - loss: 0.1759965568780899\n",
"Epoch 61 - accuracy: 0.9280997798972854\n",
"Epoch 62 - loss: 0.1731904298067093\n",
"Epoch 62 - accuracy: 0.9284666177549523\n",
"Epoch 63 - loss: 0.17060300707817078\n",
"Epoch 63 - accuracy: 0.9304842259721203\n",
"Epoch 64 - loss: 0.16826492547988892\n",
"Epoch 64 - accuracy: 0.9323184152604549\n",
"Epoch 65 - loss: 0.16778990626335144\n",
"Epoch 65 - accuracy: 0.931951577402788\n",
"Epoch 66 - loss: 0.16876132786273956\n",
"Epoch 66 - accuracy: 0.9306676449009538\n",
"Epoch 67 - loss: 0.17013192176818848\n",
"Epoch 67 - accuracy: 0.9310344827586207\n",
"Epoch 68 - loss: 0.17322558164596558\n",
"Epoch 68 - accuracy: 0.9290168745414527\n",
"Epoch 69 - loss: 0.1745549589395523\n",
"Epoch 69 - accuracy: 0.9277329420396185\n",
"Epoch 70 - loss: 0.17566947638988495\n",
"Epoch 70 - accuracy: 0.927549523110785\n",
"Epoch 71 - loss: 0.16949035227298737\n",
"Epoch 71 - accuracy: 0.9304842259721203\n",
"Epoch 72 - loss: 0.16411995887756348\n",
"Epoch 72 - accuracy: 0.9328686720469552\n",
"Epoch 73 - loss: 0.1636730283498764\n",
"Epoch 73 - accuracy: 0.9336023477622891\n",
"Epoch 74 - loss: 0.16683687269687653\n",
"Epoch 74 - accuracy: 0.931951577402788\n",
"Epoch 75 - loss: 0.168714702129364\n",
"Epoch 75 - accuracy: 0.9315847395451211\n",
"Epoch 76 - loss: 0.16501154005527496\n",
"Epoch 76 - accuracy: 0.9315847395451211\n",
"Epoch 77 - loss: 0.16135883331298828\n",
"Epoch 77 - accuracy: 0.933969185619956\n",
"Epoch 78 - loss: 0.16116301715373993\n",
"Epoch 78 - accuracy: 0.9341526045487895\n",
"Epoch 79 - loss: 0.16375577449798584\n",
"Epoch 79 - accuracy: 0.9315847395451211\n",
"Epoch 80 - loss: 0.16777744889259338\n",
"Epoch 80 - accuracy: 0.9315847395451211\n",
"Epoch 81 - loss: 0.1683170050382614\n",
"Epoch 81 - accuracy: 0.9297505502567865\n",
"Epoch 82 - loss: 0.16626843810081482\n",
"Epoch 82 - accuracy: 0.931951577402788\n",
"Epoch 83 - loss: 0.16042152047157288\n",
"Epoch 83 - accuracy: 0.9326852531181218\n",
"Epoch 84 - loss: 0.1577889770269394\n",
"Epoch 84 - accuracy: 0.9352531181217901\n",
"Epoch 85 - loss: 0.15964850783348083\n",
"Epoch 85 - accuracy: 0.9345194424064563\n",
"Epoch 86 - loss: 0.1621726006269455\n",
"Epoch 86 - accuracy: 0.9315847395451211\n",
"Epoch 87 - loss: 0.1624235063791275\n",
"Epoch 87 - accuracy: 0.9336023477622891\n",
"Epoch 88 - loss: 0.15911513566970825\n",
"Epoch 88 - accuracy: 0.9337857666911226\n",
"Epoch 89 - loss: 0.15607108175754547\n",
"Epoch 89 - accuracy: 0.9378209831254586\n",
"Epoch 90 - loss: 0.15539272129535675\n",
"Epoch 90 - accuracy: 0.9367204695524578\n",
"Epoch 91 - loss: 0.15680216252803802\n",
"Epoch 91 - accuracy: 0.9354365370506236\n",
"Epoch 92 - loss: 0.15915803611278534\n",
"Epoch 92 - accuracy: 0.9337857666911226\n",
"Epoch 93 - loss: 0.161188542842865\n",
"Epoch 93 - accuracy: 0.9315847395451211\n",
"Epoch 94 - loss: 0.1637449860572815\n",
"Epoch 94 - accuracy: 0.9326852531181218\n",
"Epoch 95 - loss: 0.16255930066108704\n",
"Epoch 95 - accuracy: 0.9315847395451211\n",
"Epoch 96 - loss: 0.15967024862766266\n",
"Epoch 96 - accuracy: 0.9337857666911226\n",
"Epoch 97 - loss: 0.1543407291173935\n",
"Epoch 97 - accuracy: 0.9356199559794571\n",
"Epoch 98 - loss: 0.15191201865673065\n",
"Epoch 98 - accuracy: 0.9369038884812912\n",
"Epoch 99 - loss: 0.15335297584533691\n",
"Epoch 99 - accuracy: 0.9378209831254586\n",
"Epoch 100 - loss: 0.1558246612548828\n",
"Epoch 100 - accuracy: 0.9341526045487895\n",
"Epoch 101 - loss: 0.15698470175266266\n",
"Epoch 101 - accuracy: 0.9348862802641232\n",
"Epoch 102 - loss: 0.15483127534389496\n",
"Epoch 102 - accuracy: 0.9354365370506236\n",
"Epoch 103 - loss: 0.15184034407138824\n",
"Epoch 103 - accuracy: 0.9370873074101247\n",
"Epoch 104 - loss: 0.14983399212360382\n",
"Epoch 104 - accuracy: 0.9383712399119589\n",
"Epoch 105 - loss: 0.1496090143918991\n",
"Epoch 105 - accuracy: 0.9383712399119589\n",
"Epoch 106 - loss: 0.1509258896112442\n",
"Epoch 106 - accuracy: 0.9383712399119589\n",
"Epoch 107 - loss: 0.15413691103458405\n",
"Epoch 107 - accuracy: 0.9361702127659575\n",
"Epoch 108 - loss: 0.16182859241962433\n",
"Epoch 108 - accuracy: 0.9321349963316214\n",
"Epoch 109 - loss: 0.17049157619476318\n",
"Epoch 109 - accuracy: 0.9293837123991195\n",
"Epoch 110 - loss: 0.18406210839748383\n",
"Epoch 110 - accuracy: 0.921496698459281\n",
"Epoch 111 - loss: 0.1651865839958191\n",
"Epoch 111 - accuracy: 0.9330520909757887\n",
"Epoch 112 - loss: 0.14784397184848785\n",
"Epoch 112 - accuracy: 0.9403888481291269\n",
"Epoch 113 - loss: 0.15461425483226776\n",
"Epoch 113 - accuracy: 0.9354365370506236\n",
"Epoch 114 - loss: 0.1594894975423813\n",
"Epoch 114 - accuracy: 0.9336023477622891\n",
"Epoch 115 - loss: 0.14919613301753998\n",
"Epoch 115 - accuracy: 0.9389214966984593\n",
"Epoch 116 - loss: 0.14946608245372772\n",
"Epoch 116 - accuracy: 0.9385546588407924\n",
"Epoch 117 - loss: 0.15539346635341644\n",
"Epoch 117 - accuracy: 0.9352531181217901\n",
"Epoch 118 - loss: 0.14897421002388\n",
"Epoch 118 - accuracy: 0.9385546588407924\n",
"Epoch 119 - loss: 0.14688338339328766\n",
"Epoch 119 - accuracy: 0.9396551724137931\n",
"Epoch 120 - loss: 0.15270763635635376\n",
"Epoch 120 - accuracy: 0.9370873074101247\n",
"Epoch 121 - loss: 0.1505187451839447\n",
"Epoch 121 - accuracy: 0.9374541452677916\n",
"Epoch 122 - loss: 0.14485368132591248\n",
"Epoch 122 - accuracy: 0.9407556859867938\n",
"Epoch 123 - loss: 0.14813244342803955\n",
"Epoch 123 - accuracy: 0.9389214966984593\n",
"Epoch 124 - loss: 0.15244553983211517\n",
"Epoch 124 - accuracy: 0.935986793837124\n",
"Epoch 125 - loss: 0.14820507168769836\n",
"Epoch 125 - accuracy: 0.9385546588407924\n",
"Epoch 126 - loss: 0.1435617357492447\n",
"Epoch 126 - accuracy: 0.9409391049156273\n",
"Epoch 127 - loss: 0.14422369003295898\n",
"Epoch 127 - accuracy: 0.9405722670579604\n",
"Epoch 128 - loss: 0.14768363535404205\n",
"Epoch 128 - accuracy: 0.9385546588407924\n",
"Epoch 129 - loss: 0.1491793990135193\n",
"Epoch 129 - accuracy: 0.9365370506236244\n",
"Epoch 130 - loss: 0.14625929296016693\n",
"Epoch 130 - accuracy: 0.94002201027146\n",
"Epoch 131 - loss: 0.14273685216903687\n",
"Epoch 131 - accuracy: 0.9414893617021277\n",
"Epoch 132 - loss: 0.14178012311458588\n",
"Epoch 132 - accuracy: 0.9409391049156273\n",
"Epoch 133 - loss: 0.1434500366449356\n",
"Epoch 133 - accuracy: 0.9414893617021277\n",
"Epoch 134 - loss: 0.14543531835079193\n",
"Epoch 134 - accuracy: 0.9387380777696258\n",
"Epoch 135 - loss: 0.14560234546661377\n",
"Epoch 135 - accuracy: 0.9424064563462949\n",
"Epoch 136 - loss: 0.1443878710269928\n",
"Epoch 136 - accuracy: 0.9392883345561261\n",
"Epoch 137 - loss: 0.14207443594932556\n",
"Epoch 137 - accuracy: 0.9422230374174615\n",
"Epoch 138 - loss: 0.14036691188812256\n",
"Epoch 138 - accuracy: 0.9422230374174615\n",
"Epoch 139 - loss: 0.1399471014738083\n",
"Epoch 139 - accuracy: 0.9422230374174615\n",
"Epoch 140 - loss: 0.14061304926872253\n",
"Epoch 140 - accuracy: 0.9425898752751284\n",
"Epoch 141 - loss: 0.14191102981567383\n",
"Epoch 141 - accuracy: 0.9392883345561261\n",
"Epoch 142 - loss: 0.1439162641763687\n",
"Epoch 142 - accuracy: 0.9422230374174615\n",
"Epoch 143 - loss: 0.14785940945148468\n",
"Epoch 143 - accuracy: 0.9370873074101247\n",
"Epoch 144 - loss: 0.15084269642829895\n",
"Epoch 144 - accuracy: 0.9385546588407924\n",
"Epoch 145 - loss: 0.15596060454845428\n",
"Epoch 145 - accuracy: 0.9336023477622891\n",
"Epoch 146 - loss: 0.1506507396697998\n",
"Epoch 146 - accuracy: 0.9376375641966251\n",
"Epoch 147 - loss: 0.1430141031742096\n",
"Epoch 147 - accuracy: 0.9378209831254586\n",
"Epoch 148 - loss: 0.1378999948501587\n",
"Epoch 148 - accuracy: 0.9429567131327953\n",
"Epoch 149 - loss: 0.14189480245113373\n",
"Epoch 149 - accuracy: 0.9429567131327953\n",
"Epoch 150 - loss: 0.1461869180202484\n",
"Epoch 150 - accuracy: 0.938004402054292\n",
"Epoch 151 - loss: 0.1414678692817688\n",
"Epoch 151 - accuracy: 0.9429567131327953\n",
"Epoch 152 - loss: 0.13736973702907562\n",
"Epoch 152 - accuracy: 0.9446074834922964\n",
"Epoch 153 - loss: 0.13923421502113342\n",
"Epoch 153 - accuracy: 0.9405722670579604\n",
"Epoch 154 - loss: 0.14242184162139893\n",
"Epoch 154 - accuracy: 0.9427732942039618\n",
"Epoch 155 - loss: 0.14257796108722687\n",
"Epoch 155 - accuracy: 0.9381878209831255\n",
"Epoch 156 - loss: 0.138951376080513\n",
"Epoch 156 - accuracy: 0.9451577402787967\n",
"Epoch 157 - loss: 0.1361865997314453\n",
"Epoch 157 - accuracy: 0.9451577402787967\n",
"Epoch 158 - loss: 0.13608776032924652\n",
"Epoch 158 - accuracy: 0.9446074834922964\n",
"Epoch 159 - loss: 0.1380961388349533\n",
"Epoch 159 - accuracy: 0.9458914159941306\n",
"Epoch 160 - loss: 0.14209063351154327\n",
"Epoch 160 - accuracy: 0.9392883345561261\n",
"Epoch 161 - loss: 0.1477213203907013\n",
"Epoch 161 - accuracy: 0.9403888481291269\n",
"Epoch 162 - loss: 0.16057507693767548\n",
"Epoch 162 - accuracy: 0.931951577402788\n",
"Epoch 163 - loss: 0.16030831634998322\n",
"Epoch 163 - accuracy: 0.9326852531181218\n",
"Epoch 164 - loss: 0.15553373098373413\n",
"Epoch 164 - accuracy: 0.933969185619956\n",
"Epoch 165 - loss: 0.13658468425273895\n",
"Epoch 165 - accuracy: 0.9473587674247982\n",
"Epoch 166 - loss: 0.14066576957702637\n",
"Epoch 166 - accuracy: 0.9446074834922964\n",
"Epoch 167 - loss: 0.15113432705402374\n",
"Epoch 167 - accuracy: 0.9358033749082906\n",
"Epoch 168 - loss: 0.13788758218288422\n",
"Epoch 168 - accuracy: 0.9451577402787967\n",
"Epoch 169 - loss: 0.13813276588916779\n",
"Epoch 169 - accuracy: 0.9458914159941306\n",
"Epoch 170 - loss: 0.14637058973312378\n",
"Epoch 170 - accuracy: 0.9370873074101247\n",
"Epoch 171 - loss: 0.13659228384494781\n",
"Epoch 171 - accuracy: 0.9466250917094644\n",
"Epoch 172 - loss: 0.13671350479125977\n",
"Epoch 172 - accuracy: 0.9460748349229641\n",
"Epoch 173 - loss: 0.1432092934846878\n",
"Epoch 173 - accuracy: 0.9376375641966251\n",
"Epoch 174 - loss: 0.13640692830085754\n",
"Epoch 174 - accuracy: 0.9473587674247982\n",
"Epoch 175 - loss: 0.13370223343372345\n",
"Epoch 175 - accuracy: 0.9479090242112986\n",
"Epoch 176 - loss: 0.13890092074871063\n",
"Epoch 176 - accuracy: 0.9394717534849596\n",
"Epoch 177 - loss: 0.1390204131603241\n",
"Epoch 177 - accuracy: 0.9446074834922964\n",
"Epoch 178 - loss: 0.13492031395435333\n",
"Epoch 178 - accuracy: 0.9435069699192957\n",
"Epoch 179 - loss: 0.13199980556964874\n",
"Epoch 179 - accuracy: 0.9469919295671313\n",
"Epoch 180 - loss: 0.1336917281150818\n",
"Epoch 180 - accuracy: 0.9475421863536317\n",
"Epoch 181 - loss: 0.13739150762557983\n",
"Epoch 181 - accuracy: 0.9414893617021277\n",
"Epoch 182 - loss: 0.1385558694601059\n",
"Epoch 182 - accuracy: 0.9449743213499633\n",
"Epoch 183 - loss: 0.13827867805957794\n",
"Epoch 183 - accuracy: 0.9409391049156273\n",
"Epoch 184 - loss: 0.1347222477197647\n",
"Epoch 184 - accuracy: 0.9473587674247982\n",
"Epoch 185 - loss: 0.13180093467235565\n",
"Epoch 185 - accuracy: 0.9453411592076302\n",
"Epoch 186 - loss: 0.1310291290283203\n",
"Epoch 186 - accuracy: 0.9471753484959647\n",
"Epoch 187 - loss: 0.13239546120166779\n",
"Epoch 187 - accuracy: 0.9490095377842993\n",
"Epoch 188 - loss: 0.1344117820262909\n",
"Epoch 188 - accuracy: 0.9438738077769626\n",
"Epoch 189 - loss: 0.13522079586982727\n",
"Epoch 189 - accuracy: 0.9468085106382979\n",
"Epoch 190 - loss: 0.1355038285255432\n",
"Epoch 190 - accuracy: 0.9431401320616287\n",
"Epoch 191 - loss: 0.13365335762500763\n",
"Epoch 191 - accuracy: 0.9469919295671313\n",
"Epoch 192 - loss: 0.13177931308746338\n",
"Epoch 192 - accuracy: 0.9451577402787967\n",
"Epoch 193 - loss: 0.13017825782299042\n",
"Epoch 193 - accuracy: 0.9488261188554659\n",
"Epoch 194 - loss: 0.1296418458223343\n",
"Epoch 194 - accuracy: 0.9475421863536317\n",
"Epoch 195 - loss: 0.13013602793216705\n",
"Epoch 195 - accuracy: 0.946441672780631\n",
"Epoch 196 - loss: 0.13140098750591278\n",
"Epoch 196 - accuracy: 0.948459280997799\n",
"Epoch 197 - loss: 0.13433723151683807\n",
"Epoch 197 - accuracy: 0.9433235509904622\n",
"Epoch 198 - loss: 0.13919727504253387\n",
"Epoch 198 - accuracy: 0.9440572267057961\n",
"Epoch 199 - loss: 0.15434369444847107\n",
"Epoch 199 - accuracy: 0.9343360234776229\n",
"Epoch 200 - loss: 0.1615518182516098\n",
"Epoch 200 - accuracy: 0.931951577402788\n",
"Epoch 201 - loss: 0.17248667776584625\n",
"Epoch 201 - accuracy: 0.9257153338224505\n",
"Epoch 202 - loss: 0.13753581047058105\n",
"Epoch 202 - accuracy: 0.9451577402787967\n",
"Epoch 203 - loss: 0.13390885293483734\n",
"Epoch 203 - accuracy: 0.9473587674247982\n",
"Epoch 204 - loss: 0.15163034200668335\n",
"Epoch 204 - accuracy: 0.9350696991929567\n",
"Epoch 205 - loss: 0.1322300136089325\n",
"Epoch 205 - accuracy: 0.9493763756419662\n",
"Epoch 206 - loss: 0.1383107304573059\n",
"Epoch 206 - accuracy: 0.9436903888481292\n",
"Epoch 207 - loss: 0.14488989114761353\n",
"Epoch 207 - accuracy: 0.9376375641966251\n",
"Epoch 208 - loss: 0.12969642877578735\n",
"Epoch 208 - accuracy: 0.9475421863536317\n",
"Epoch 209 - loss: 0.1437886655330658\n",
"Epoch 209 - accuracy: 0.9418561995597946\n",
"Epoch 210 - loss: 0.13960619270801544\n",
"Epoch 210 - accuracy: 0.9398385913426266\n",
"Epoch 211 - loss: 0.1303216964006424\n",
"Epoch 211 - accuracy: 0.9451577402787967\n",
"Epoch 212 - loss: 0.1448250114917755\n",
"Epoch 212 - accuracy: 0.9409391049156273\n",
"Epoch 213 - loss: 0.1382424384355545\n",
"Epoch 213 - accuracy: 0.9409391049156273\n",
"Epoch 214 - loss: 0.12859448790550232\n",
"Epoch 214 - accuracy: 0.9466250917094644\n",
"Epoch 215 - loss: 0.14336782693862915\n",
"Epoch 215 - accuracy: 0.9403888481291269\n",
"Epoch 216 - loss: 0.1435861885547638\n",
"Epoch 216 - accuracy: 0.9394717534849596\n",
"Epoch 217 - loss: 0.12729865312576294\n",
"Epoch 217 - accuracy: 0.9471753484959647\n",
"Epoch 218 - loss: 0.14043588936328888\n",
"Epoch 218 - accuracy: 0.9422230374174615\n",
"Epoch 219 - loss: 0.14857016503810883\n",
"Epoch 219 - accuracy: 0.9370873074101247\n",
"Epoch 220 - loss: 0.12810075283050537\n",
"Epoch 220 - accuracy: 0.9495597945707998\n",
"Epoch 221 - loss: 0.13681010901927948\n",
"Epoch 221 - accuracy: 0.9440572267057961\n",
"Epoch 222 - loss: 0.1480584442615509\n",
"Epoch 222 - accuracy: 0.9376375641966251\n",
"Epoch 223 - loss: 0.12819060683250427\n",
"Epoch 223 - accuracy: 0.9499266324284666\n",
"Epoch 224 - loss: 0.13606873154640198\n",
"Epoch 224 - accuracy: 0.9442406456346295\n",
"Epoch 225 - loss: 0.14506947994232178\n",
"Epoch 225 - accuracy: 0.9385546588407924\n",
"Epoch 226 - loss: 0.12688535451889038\n",
"Epoch 226 - accuracy: 0.9501100513573001\n",
"Epoch 227 - loss: 0.13793440163135529\n",
"Epoch 227 - accuracy: 0.9436903888481292\n",
"Epoch 228 - loss: 0.14175213873386383\n",
"Epoch 228 - accuracy: 0.9405722670579604\n",
"Epoch 229 - loss: 0.12606659531593323\n",
"Epoch 229 - accuracy: 0.9471753484959647\n",
"Epoch 230 - loss: 0.14068996906280518\n",
"Epoch 230 - accuracy: 0.9422230374174615\n",
"Epoch 231 - loss: 0.13893744349479675\n",
"Epoch 231 - accuracy: 0.9407556859867938\n",
"Epoch 232 - loss: 0.12637808918952942\n",
"Epoch 232 - accuracy: 0.9471753484959647\n",
"Epoch 233 - loss: 0.1423971951007843\n",
"Epoch 233 - accuracy: 0.9411225238444607\n",
"Epoch 234 - loss: 0.13614536821842194\n",
"Epoch 234 - accuracy: 0.9422230374174615\n",
"Epoch 235 - loss: 0.12686654925346375\n",
"Epoch 235 - accuracy: 0.9466250917094644\n",
"Epoch 236 - loss: 0.14213667809963226\n",
"Epoch 236 - accuracy: 0.9402054292002935\n",
"Epoch 237 - loss: 0.1342005431652069\n",
"Epoch 237 - accuracy: 0.9429567131327953\n",
"Epoch 238 - loss: 0.12690769135951996\n",
"Epoch 238 - accuracy: 0.9460748349229641\n",
"Epoch 239 - loss: 0.1406722366809845\n",
"Epoch 239 - accuracy: 0.9402054292002935\n",
"Epoch 240 - loss: 0.1323944330215454\n",
"Epoch 240 - accuracy: 0.9427732942039618\n",
"Epoch 241 - loss: 0.12687557935714722\n",
"Epoch 241 - accuracy: 0.9458914159941306\n",
"Epoch 242 - loss: 0.13867004215717316\n",
"Epoch 242 - accuracy: 0.9429567131327953\n",
"Epoch 243 - loss: 0.13049884140491486\n",
"Epoch 243 - accuracy: 0.9453411592076302\n",
"Epoch 244 - loss: 0.12657293677330017\n",
"Epoch 244 - accuracy: 0.9458914159941306\n",
"Epoch 245 - loss: 0.13628236949443817\n",
"Epoch 245 - accuracy: 0.9440572267057961\n",
"Epoch 246 - loss: 0.12921735644340515\n",
"Epoch 246 - accuracy: 0.9457079970652972\n",
"Epoch 247 - loss: 0.125750333070755\n",
"Epoch 247 - accuracy: 0.9468085106382979\n",
"Epoch 248 - loss: 0.13418523967266083\n",
"Epoch 248 - accuracy: 0.9447909024211298\n",
"Epoch 249 - loss: 0.128926083445549\n",
"Epoch 249 - accuracy: 0.9455245781364637\n",
"Epoch 250 - loss: 0.12465361505746841\n",
"Epoch 250 - accuracy: 0.9469919295671313\n",
"Epoch 251 - loss: 0.13247524201869965\n",
"Epoch 251 - accuracy: 0.9457079970652972\n",
"Epoch 252 - loss: 0.1298513263463974\n",
"Epoch 252 - accuracy: 0.9453411592076302\n",
"Epoch 253 - loss: 0.12355927377939224\n",
"Epoch 253 - accuracy: 0.9479090242112986\n",
"Epoch 254 - loss: 0.13069619238376617\n",
"Epoch 254 - accuracy: 0.9469919295671313\n",
"Epoch 255 - loss: 0.13157017529010773\n",
"Epoch 255 - accuracy: 0.9440572267057961\n",
"Epoch 256 - loss: 0.12315743416547775\n",
"Epoch 256 - accuracy: 0.9495597945707998\n",
"Epoch 257 - loss: 0.12831072509288788\n",
"Epoch 257 - accuracy: 0.948459280997799\n",
"Epoch 258 - loss: 0.13248488306999207\n",
"Epoch 258 - accuracy: 0.9442406456346295\n",
"Epoch 259 - loss: 0.12374118715524673\n",
"Epoch 259 - accuracy: 0.9506603081438004\n",
"Epoch 260 - loss: 0.12544500827789307\n",
"Epoch 260 - accuracy: 0.9499266324284666\n",
"Epoch 261 - loss: 0.13119831681251526\n",
"Epoch 261 - accuracy: 0.944424064563463\n",
"Epoch 262 - loss: 0.12471841275691986\n",
"Epoch 262 - accuracy: 0.950476889214967\n",
"Epoch 263 - loss: 0.12288667261600494\n",
"Epoch 263 - accuracy: 0.9501100513573001\n",
"Epoch 264 - loss: 0.12759484350681305\n",
"Epoch 264 - accuracy: 0.9466250917094644\n",
"Epoch 265 - loss: 0.125613272190094\n",
"Epoch 265 - accuracy: 0.9491929567131328\n",
"Epoch 266 - loss: 0.12219160795211792\n",
"Epoch 266 - accuracy: 0.9486426999266324\n",
"Epoch 267 - loss: 0.12327056378126144\n",
"Epoch 267 - accuracy: 0.9477256052824652\n",
"Epoch 268 - loss: 0.12530098855495453\n",
"Epoch 268 - accuracy: 0.9491929567131328\n",
"Epoch 269 - loss: 0.1247311532497406\n",
"Epoch 269 - accuracy: 0.9471753484959647\n",
"Epoch 270 - loss: 0.1219773069024086\n",
"Epoch 270 - accuracy: 0.9501100513573001\n",
"Epoch 271 - loss: 0.12205666303634644\n",
"Epoch 271 - accuracy: 0.9499266324284666\n",
"Epoch 272 - loss: 0.12405984103679657\n",
"Epoch 272 - accuracy: 0.9475421863536317\n",
"Epoch 273 - loss: 0.12427092343568802\n",
"Epoch 273 - accuracy: 0.9499266324284666\n",
"Epoch 274 - loss: 0.123012013733387\n",
"Epoch 274 - accuracy: 0.9479090242112986\n",
"Epoch 275 - loss: 0.12140244245529175\n",
"Epoch 275 - accuracy: 0.9502934702861335\n",
"Epoch 276 - loss: 0.12125080823898315\n",
"Epoch 276 - accuracy: 0.9506603081438004\n",
"Epoch 277 - loss: 0.12223225086927414\n",
"Epoch 277 - accuracy: 0.9488261188554659\n",
"Epoch 278 - loss: 0.12314152717590332\n",
"Epoch 278 - accuracy: 0.9499266324284666\n",
"Epoch 279 - loss: 0.12346971035003662\n",
"Epoch 279 - accuracy: 0.9482758620689655\n",
"Epoch 280 - loss: 0.12271338701248169\n",
"Epoch 280 - accuracy: 0.9497432134996332\n",
"Epoch 281 - loss: 0.12176275253295898\n",
"Epoch 281 - accuracy: 0.9495597945707998\n",
"Epoch 282 - loss: 0.12081919610500336\n",
"Epoch 282 - accuracy: 0.9508437270726339\n",
"Epoch 283 - loss: 0.12036316096782684\n",
"Epoch 283 - accuracy: 0.9501100513573001\n",
"Epoch 284 - loss: 0.12042579054832458\n",
"Epoch 284 - accuracy: 0.9499266324284666\n",
"Epoch 285 - loss: 0.1209179013967514\n",
"Epoch 285 - accuracy: 0.9506603081438004\n",
"Epoch 286 - loss: 0.12179320305585861\n",
"Epoch 286 - accuracy: 0.9499266324284666\n",
"Epoch 287 - loss: 0.12336383759975433\n",
"Epoch 287 - accuracy: 0.9499266324284666\n",
"Epoch 288 - loss: 0.12664794921875\n",
"Epoch 288 - accuracy: 0.9460748349229641\n",
"Epoch 289 - loss: 0.12985500693321228\n",
"Epoch 289 - accuracy: 0.9495597945707998\n",
"Epoch 290 - loss: 0.13782480359077454\n",
"Epoch 290 - accuracy: 0.9424064563462949\n",
"Epoch 291 - loss: 0.13601623475551605\n",
"Epoch 291 - accuracy: 0.9429567131327953\n",
"Epoch 292 - loss: 0.13354693353176117\n",
"Epoch 292 - accuracy: 0.9424064563462949\n",
"Epoch 293 - loss: 0.12155157327651978\n",
"Epoch 293 - accuracy: 0.9517608217168012\n",
"Epoch 294 - loss: 0.1212334856390953\n",
"Epoch 294 - accuracy: 0.9517608217168012\n",
"Epoch 295 - loss: 0.12846817076206207\n",
"Epoch 295 - accuracy: 0.9460748349229641\n",
"Epoch 296 - loss: 0.12484529614448547\n",
"Epoch 296 - accuracy: 0.9512105649303008\n",
"Epoch 297 - loss: 0.11943589150905609\n",
"Epoch 297 - accuracy: 0.9501100513573001\n",
"Epoch 298 - loss: 0.12043111026287079\n",
"Epoch 298 - accuracy: 0.9493763756419662\n",
"Epoch 299 - loss: 0.12429048120975494\n",
"Epoch 299 - accuracy: 0.9510271460014673\n",
"Epoch 300 - loss: 0.12505868077278137\n",
"Epoch 300 - accuracy: 0.9458914159941306\n",
"Epoch 301 - loss: 0.12101427465677261\n",
"Epoch 301 - accuracy: 0.9502934702861335\n",
"Epoch 302 - loss: 0.11874640733003616\n",
"Epoch 302 - accuracy: 0.9513939838591343\n",
"Epoch 303 - loss: 0.11972319334745407\n",
"Epoch 303 - accuracy: 0.9513939838591343\n",
"Epoch 304 - loss: 0.12259635329246521\n",
"Epoch 304 - accuracy: 0.9508437270726339\n",
"Epoch 305 - loss: 0.1272718459367752\n",
"Epoch 305 - accuracy: 0.9453411592076302\n",
"Epoch 306 - loss: 0.1310117095708847\n",
"Epoch 306 - accuracy: 0.9482758620689655\n",
"Epoch 307 - loss: 0.13956663012504578\n",
"Epoch 307 - accuracy: 0.9414893617021277\n",
"Epoch 308 - loss: 0.13723066449165344\n",
"Epoch 308 - accuracy: 0.9425898752751284\n",
"Epoch 309 - loss: 0.13382814824581146\n",
"Epoch 309 - accuracy: 0.9436903888481292\n",
"Epoch 310 - loss: 0.12054771184921265\n",
"Epoch 310 - accuracy: 0.9513939838591343\n",
"Epoch 311 - loss: 0.12119341641664505\n",
"Epoch 311 - accuracy: 0.9515774027879678\n",
"Epoch 312 - loss: 0.12905658781528473\n",
"Epoch 312 - accuracy: 0.9449743213499633\n",
"Epoch 313 - loss: 0.1224476620554924\n",
"Epoch 313 - accuracy: 0.9517608217168012\n",
"Epoch 314 - loss: 0.11810442805290222\n",
"Epoch 314 - accuracy: 0.9508437270726339\n",
"Epoch 315 - loss: 0.12269062548875809\n",
"Epoch 315 - accuracy: 0.9473587674247982\n",
"Epoch 316 - loss: 0.12358292192220688\n",
"Epoch 316 - accuracy: 0.9515774027879678\n",
"Epoch 317 - loss: 0.11988306790590286\n",
"Epoch 317 - accuracy: 0.9502934702861335\n",
"Epoch 318 - loss: 0.11729229986667633\n",
"Epoch 318 - accuracy: 0.9512105649303008\n",
"Epoch 319 - loss: 0.11976367235183716\n",
"Epoch 319 - accuracy: 0.9501100513573001\n",
"Epoch 320 - loss: 0.12398409843444824\n",
"Epoch 320 - accuracy: 0.9466250917094644\n",
"Epoch 321 - loss: 0.12489615380764008\n",
"Epoch 321 - accuracy: 0.950476889214967\n",
"Epoch 322 - loss: 0.12479665130376816\n",
"Epoch 322 - accuracy: 0.9469919295671313\n",
"Epoch 323 - loss: 0.12163161486387253\n",
"Epoch 323 - accuracy: 0.9510271460014673\n",
"Epoch 324 - loss: 0.11900309473276138\n",
"Epoch 324 - accuracy: 0.9510271460014673\n",
"Epoch 325 - loss: 0.11766552925109863\n",
"Epoch 325 - accuracy: 0.9521276595744681\n",
"Epoch 326 - loss: 0.11795462667942047\n",
"Epoch 326 - accuracy: 0.9521276595744681\n",
"Epoch 327 - loss: 0.11942996829748154\n",
"Epoch 327 - accuracy: 0.9506603081438004\n",
"Epoch 328 - loss: 0.12166863679885864\n",
"Epoch 328 - accuracy: 0.9513939838591343\n",
"Epoch 329 - loss: 0.12576095759868622\n",
"Epoch 329 - accuracy: 0.9469919295671313\n",
"Epoch 330 - loss: 0.12968149781227112\n",
"Epoch 330 - accuracy: 0.9479090242112986\n",
"Epoch 331 - loss: 0.1390303373336792\n",
"Epoch 331 - accuracy: 0.9422230374174615\n",
"Epoch 332 - loss: 0.13620175421237946\n",
"Epoch 332 - accuracy: 0.9435069699192957\n",
"Epoch 333 - loss: 0.13187989592552185\n",
"Epoch 333 - accuracy: 0.9438738077769626\n",
"Epoch 334 - loss: 0.1182360053062439\n",
"Epoch 334 - accuracy: 0.9521276595744681\n",
"Epoch 335 - loss: 0.121561199426651\n",
"Epoch 335 - accuracy: 0.9515774027879678\n",
"Epoch 336 - loss: 0.12973010540008545\n",
"Epoch 336 - accuracy: 0.9455245781364637\n",
"Epoch 337 - loss: 0.11987362802028656\n",
"Epoch 337 - accuracy: 0.9521276595744681\n",
"Epoch 338 - loss: 0.11774610728025436\n",
"Epoch 338 - accuracy: 0.9532281731474688\n",
"Epoch 339 - loss: 0.12471359968185425\n",
"Epoch 339 - accuracy: 0.9473587674247982\n",
"Epoch 340 - loss: 0.12189881503582001\n",
"Epoch 340 - accuracy: 0.9513939838591343\n",
"Epoch 341 - loss: 0.11661942303180695\n",
"Epoch 341 - accuracy: 0.9515774027879678\n",
"Epoch 342 - loss: 0.11699726432561874\n",
"Epoch 342 - accuracy: 0.9513939838591343\n",
"Epoch 343 - loss: 0.12145952135324478\n",
"Epoch 343 - accuracy: 0.9513939838591343\n",
"Epoch 344 - loss: 0.12609325349330902\n",
"Epoch 344 - accuracy: 0.9469919295671313\n",
"Epoch 345 - loss: 0.12498581409454346\n",
"Epoch 345 - accuracy: 0.9499266324284666\n",
"Epoch 346 - loss: 0.1231597438454628\n",
"Epoch 346 - accuracy: 0.9480924431401321\n",
"Epoch 347 - loss: 0.11873132735490799\n",
"Epoch 347 - accuracy: 0.9510271460014673\n",
"Epoch 348 - loss: 0.11646643280982971\n",
"Epoch 348 - accuracy: 0.9528613352898019\n",
"Epoch 349 - loss: 0.11734868586063385\n",
"Epoch 349 - accuracy: 0.9512105649303008\n",
"Epoch 350 - loss: 0.11994626373052597\n",
"Epoch 350 - accuracy: 0.9528613352898019\n",
"Epoch 351 - loss: 0.12329673767089844\n",
"Epoch 351 - accuracy: 0.9475421863536317\n",
"Epoch 352 - loss: 0.12438608705997467\n",
"Epoch 352 - accuracy: 0.950476889214967\n",
"Epoch 353 - loss: 0.12616336345672607\n",
"Epoch 353 - accuracy: 0.9473587674247982\n",
"Epoch 354 - loss: 0.12313596159219742\n",
"Epoch 354 - accuracy: 0.950476889214967\n",
"Epoch 355 - loss: 0.12048415839672089\n",
"Epoch 355 - accuracy: 0.9491929567131328\n",
"Epoch 356 - loss: 0.11673296988010406\n",
"Epoch 356 - accuracy: 0.9521276595744681\n",
"Epoch 357 - loss: 0.11539475619792938\n",
"Epoch 357 - accuracy: 0.9515774027879678\n",
"Epoch 358 - loss: 0.11649461835622787\n",
"Epoch 358 - accuracy: 0.952494497432135\n",
"Epoch 359 - loss: 0.11848600953817368\n",
"Epoch 359 - accuracy: 0.9517608217168012\n",
"Epoch 360 - loss: 0.12056515365839005\n",
"Epoch 360 - accuracy: 0.9499266324284666\n",
"Epoch 361 - loss: 0.12084736675024033\n",
"Epoch 361 - accuracy: 0.9508437270726339\n",
"Epoch 362 - loss: 0.12160326540470123\n",
"Epoch 362 - accuracy: 0.9486426999266324\n",
"Epoch 363 - loss: 0.12055764347314835\n",
"Epoch 363 - accuracy: 0.9510271460014673\n",
"Epoch 364 - loss: 0.12027326971292496\n",
"Epoch 364 - accuracy: 0.9490095377842993\n",
"Epoch 365 - loss: 0.1186472624540329\n",
"Epoch 365 - accuracy: 0.9519442406456347\n",
"Epoch 366 - loss: 0.11754312366247177\n",
"Epoch 366 - accuracy: 0.9517608217168012\n",
"Epoch 367 - loss: 0.1162467822432518\n",
"Epoch 367 - accuracy: 0.9517608217168012\n",
"Epoch 368 - loss: 0.11557508260011673\n",
"Epoch 368 - accuracy: 0.9515774027879678\n",
"Epoch 369 - loss: 0.11531295627355576\n",
"Epoch 369 - accuracy: 0.9510271460014673\n",
"Epoch 370 - loss: 0.11541815102100372\n",
"Epoch 370 - accuracy: 0.9515774027879678\n",
"Epoch 371 - loss: 0.11582695692777634\n",
"Epoch 371 - accuracy: 0.9523110785033015\n",
"Epoch 372 - loss: 0.1167873665690422\n",
"Epoch 372 - accuracy: 0.9517608217168012\n",
"Epoch 373 - loss: 0.11975935101509094\n",
"Epoch 373 - accuracy: 0.9512105649303008\n",
"Epoch 374 - loss: 0.12819969654083252\n",
"Epoch 374 - accuracy: 0.9486426999266324\n",
"Epoch 375 - loss: 0.16193169355392456\n",
"Epoch 375 - accuracy: 0.9369038884812912\n",
"Epoch 376 - loss: 0.20821644365787506\n",
"Epoch 376 - accuracy: 0.9154438738077769\n",
"Epoch 377 - loss: 0.325961709022522\n",
"Epoch 377 - accuracy: 0.8919662509170947\n",
"Epoch 378 - loss: 0.1364455372095108\n",
"Epoch 378 - accuracy: 0.9440572267057961\n",
"Epoch 379 - loss: 0.21241942048072815\n",
"Epoch 379 - accuracy: 0.907006603081438\n",
"Epoch 380 - loss: 0.1626487672328949\n",
"Epoch 380 - accuracy: 0.9269992663242846\n",
"Epoch 381 - loss: 0.20652639865875244\n",
"Epoch 381 - accuracy: 0.9092076302274394\n",
"Epoch 382 - loss: 0.13588370382785797\n",
"Epoch 382 - accuracy: 0.9455245781364637\n",
"Epoch 383 - loss: 0.20422253012657166\n",
"Epoch 383 - accuracy: 0.909024211298606\n",
"Epoch 384 - loss: 0.13599459826946259\n",
"Epoch 384 - accuracy: 0.9495597945707998\n",
"Epoch 385 - loss: 0.17693686485290527\n",
"Epoch 385 - accuracy: 0.921496698459281\n",
"Epoch 386 - loss: 0.1579907238483429\n",
"Epoch 386 - accuracy: 0.9314013206162876\n",
"Epoch 387 - loss: 0.14672908186912537\n",
"Epoch 387 - accuracy: 0.942039618488628\n",
"Epoch 388 - loss: 0.16746172308921814\n",
"Epoch 388 - accuracy: 0.927549523110785\n",
"Epoch 389 - loss: 0.1338583081960678\n",
"Epoch 389 - accuracy: 0.948459280997799\n",
"Epoch 390 - loss: 0.15159721672534943\n",
"Epoch 390 - accuracy: 0.9341526045487895\n",
"Epoch 391 - loss: 0.13681034743785858\n",
"Epoch 391 - accuracy: 0.9436903888481292\n",
"Epoch 392 - loss: 0.14098820090293884\n",
"Epoch 392 - accuracy: 0.9440572267057961\n",
"Epoch 393 - loss: 0.1385105550289154\n",
"Epoch 393 - accuracy: 0.9451577402787967\n",
"Epoch 394 - loss: 0.13083040714263916\n",
"Epoch 394 - accuracy: 0.9446074834922964\n",
"Epoch 395 - loss: 0.13558636605739594\n",
"Epoch 395 - accuracy: 0.9425898752751284\n",
"Epoch 396 - loss: 0.12347723543643951\n",
"Epoch 396 - accuracy: 0.9512105649303008\n",
"Epoch 397 - loss: 0.13910643756389618\n",
"Epoch 397 - accuracy: 0.9436903888481292\n",
"Epoch 398 - loss: 0.12173766642808914\n",
"Epoch 398 - accuracy: 0.9479090242112986\n",
"Epoch 399 - loss: 0.13300205767154694\n",
"Epoch 399 - accuracy: 0.944424064563463\n",
"Epoch 400 - loss: 0.1191490963101387\n",
"Epoch 400 - accuracy: 0.9499266324284666\n",
"Epoch 401 - loss: 0.13313132524490356\n",
"Epoch 401 - accuracy: 0.9480924431401321\n",
"Epoch 402 - loss: 0.11875271052122116\n",
"Epoch 402 - accuracy: 0.9521276595744681\n",
"Epoch 403 - loss: 0.1281844675540924\n",
"Epoch 403 - accuracy: 0.9455245781364637\n",
"Epoch 404 - loss: 0.11837610602378845\n",
"Epoch 404 - accuracy: 0.9534115920763023\n",
"Epoch 405 - loss: 0.12806600332260132\n",
"Epoch 405 - accuracy: 0.9501100513573001\n",
"Epoch 406 - loss: 0.11814288049936295\n",
"Epoch 406 - accuracy: 0.9519442406456347\n",
"Epoch 407 - loss: 0.12468785047531128\n",
"Epoch 407 - accuracy: 0.9475421863536317\n",
"Epoch 408 - loss: 0.11862160265445709\n",
"Epoch 408 - accuracy: 0.9526779163609684\n",
"Epoch 409 - loss: 0.12392288446426392\n",
"Epoch 409 - accuracy: 0.952494497432135\n",
"Epoch 410 - loss: 0.11869090050458908\n",
"Epoch 410 - accuracy: 0.9497432134996332\n",
"Epoch 411 - loss: 0.12116280943155289\n",
"Epoch 411 - accuracy: 0.9501100513573001\n",
"Epoch 412 - loss: 0.1193477138876915\n",
"Epoch 412 - accuracy: 0.9528613352898019\n",
"Epoch 413 - loss: 0.11886186897754669\n",
"Epoch 413 - accuracy: 0.9541452677916361\n",
"Epoch 414 - loss: 0.11949761211872101\n",
"Epoch 414 - accuracy: 0.9506603081438004\n",
"Epoch 415 - loss: 0.11689911782741547\n",
"Epoch 415 - accuracy: 0.9519442406456347\n",
"Epoch 416 - loss: 0.1198965311050415\n",
"Epoch 416 - accuracy: 0.9530447542186353\n",
"Epoch 417 - loss: 0.11652497947216034\n",
"Epoch 417 - accuracy: 0.9532281731474688\n",
"Epoch 418 - loss: 0.1186857745051384\n",
"Epoch 418 - accuracy: 0.9506603081438004\n",
"Epoch 419 - loss: 0.11728774011135101\n",
"Epoch 419 - accuracy: 0.9537784299339692\n",
"Epoch 420 - loss: 0.11748381704092026\n",
"Epoch 420 - accuracy: 0.9537784299339692\n",
"Epoch 421 - loss: 0.11850572377443314\n",
"Epoch 421 - accuracy: 0.9515774027879678\n",
"Epoch 422 - loss: 0.11655177175998688\n",
"Epoch 422 - accuracy: 0.9537784299339692\n",
"Epoch 423 - loss: 0.11829046159982681\n",
"Epoch 423 - accuracy: 0.9521276595744681\n",
"Epoch 424 - loss: 0.11714666336774826\n",
"Epoch 424 - accuracy: 0.9528613352898019\n",
"Epoch 425 - loss: 0.11655426770448685\n",
"Epoch 425 - accuracy: 0.9530447542186353\n",
"Epoch 426 - loss: 0.11783893406391144\n",
"Epoch 426 - accuracy: 0.9515774027879678\n",
"Epoch 427 - loss: 0.11623657494783401\n",
"Epoch 427 - accuracy: 0.9528613352898019\n",
"Epoch 428 - loss: 0.11643465608358383\n",
"Epoch 428 - accuracy: 0.9526779163609684\n",
"Epoch 429 - loss: 0.11700281500816345\n",
"Epoch 429 - accuracy: 0.9528613352898019\n",
"Epoch 430 - loss: 0.11578302085399628\n",
"Epoch 430 - accuracy: 0.9530447542186353\n",
"Epoch 431 - loss: 0.11609523743391037\n",
"Epoch 431 - accuracy: 0.9517608217168012\n",
"Epoch 432 - loss: 0.11639384180307388\n",
"Epoch 432 - accuracy: 0.952494497432135\n",
"Epoch 433 - loss: 0.11552444845438004\n",
"Epoch 433 - accuracy: 0.9530447542186353\n",
"Epoch 434 - loss: 0.11577931046485901\n",
"Epoch 434 - accuracy: 0.9521276595744681\n",
"Epoch 435 - loss: 0.11588909476995468\n",
"Epoch 435 - accuracy: 0.9528613352898019\n",
"Epoch 436 - loss: 0.11526598036289215\n",
"Epoch 436 - accuracy: 0.9534115920763023\n",
"Epoch 437 - loss: 0.11537007242441177\n",
"Epoch 437 - accuracy: 0.9526779163609684\n",
"Epoch 438 - loss: 0.11550384759902954\n",
"Epoch 438 - accuracy: 0.9532281731474688\n",
"Epoch 439 - loss: 0.11503107100725174\n",
"Epoch 439 - accuracy: 0.9530447542186353\n",
"Epoch 440 - loss: 0.11492138355970383\n",
"Epoch 440 - accuracy: 0.9535950110051358\n",
"Epoch 441 - loss: 0.11515788733959198\n",
"Epoch 441 - accuracy: 0.9532281731474688\n",
"Epoch 442 - loss: 0.11495614796876907\n",
"Epoch 442 - accuracy: 0.9534115920763023\n",
"Epoch 443 - loss: 0.11473667621612549\n",
"Epoch 443 - accuracy: 0.9535950110051358\n",
"Epoch 444 - loss: 0.11485560238361359\n",
"Epoch 444 - accuracy: 0.9539618488628027\n",
"Epoch 445 - loss: 0.11495267599821091\n",
"Epoch 445 - accuracy: 0.9539618488628027\n",
"Epoch 446 - loss: 0.11481142789125443\n",
"Epoch 446 - accuracy: 0.9537784299339692\n",
"Epoch 447 - loss: 0.11461487412452698\n",
"Epoch 447 - accuracy: 0.9543286867204696\n",
"Epoch 448 - loss: 0.11467333883047104\n",
"Epoch 448 - accuracy: 0.9535950110051358\n",
"Epoch 449 - loss: 0.11486823111772537\n",
"Epoch 449 - accuracy: 0.9535950110051358\n",
"Epoch 450 - loss: 0.1148831769824028\n",
"Epoch 450 - accuracy: 0.9535950110051358\n",
"Epoch 451 - loss: 0.11478026211261749\n",
"Epoch 451 - accuracy: 0.9541452677916361\n",
"Epoch 452 - loss: 0.11460951715707779\n",
"Epoch 452 - accuracy: 0.9534115920763023\n",
"Epoch 453 - loss: 0.11460217833518982\n",
"Epoch 453 - accuracy: 0.9534115920763023\n",
"Epoch 454 - loss: 0.11475560069084167\n",
"Epoch 454 - accuracy: 0.9534115920763023\n",
"Epoch 455 - loss: 0.11476694047451019\n",
"Epoch 455 - accuracy: 0.9539618488628027\n",
"Epoch 456 - loss: 0.11479894816875458\n",
"Epoch 456 - accuracy: 0.9530447542186353\n",
"Epoch 457 - loss: 0.11459555476903915\n",
"Epoch 457 - accuracy: 0.9541452677916361\n",
"Epoch 458 - loss: 0.11456222832202911\n",
"Epoch 458 - accuracy: 0.9535950110051358\n",
"Epoch 459 - loss: 0.11451135575771332\n",
"Epoch 459 - accuracy: 0.9543286867204696\n",
"Epoch 460 - loss: 0.11464311182498932\n",
"Epoch 460 - accuracy: 0.9530447542186353\n",
"Epoch 461 - loss: 0.11459161341190338\n",
"Epoch 461 - accuracy: 0.9541452677916361\n",
"Epoch 462 - loss: 0.11457392573356628\n",
"Epoch 462 - accuracy: 0.9530447542186353\n",
"Epoch 463 - loss: 0.1144469603896141\n",
"Epoch 463 - accuracy: 0.9537784299339692\n",
"Epoch 464 - loss: 0.11453792452812195\n",
"Epoch 464 - accuracy: 0.9528613352898019\n",
"Epoch 465 - loss: 0.11465635150671005\n",
"Epoch 465 - accuracy: 0.9530447542186353\n",
"Epoch 466 - loss: 0.11518797278404236\n",
"Epoch 466 - accuracy: 0.9526779163609684\n",
"Epoch 467 - loss: 0.11605854332447052\n",
"Epoch 467 - accuracy: 0.9537784299339692\n",
"Epoch 468 - loss: 0.11784745752811432\n",
"Epoch 468 - accuracy: 0.9523110785033015\n",
"Epoch 469 - loss: 0.12037881463766098\n",
"Epoch 469 - accuracy: 0.9510271460014673\n",
"Epoch 470 - loss: 0.12108893692493439\n",
"Epoch 470 - accuracy: 0.9508437270726339\n",
"Epoch 471 - loss: 0.12257847934961319\n",
"Epoch 471 - accuracy: 0.9512105649303008\n",
"0.9512105649303008\n"
]
}
],
"source": [
"best_model = NeuralNetwork(100, 256)\n",
"\n",
"criterion = nn.BCELoss()\n",
"optimizer = optim.Adam(best_model.parameters(), lr=0.001, weight_decay=0.001)\n",
"\n",
"X_train_tensor = torch.from_numpy(X_train).float()\n",
"y_train_tensor = torch.from_numpy(y_train).float().view(-1, 1)\n",
"\n",
"X_dev_tensor = torch.from_numpy(X_dev).float()\n",
"y_dev_tensor = torch.from_numpy(y_dev).float().view(-1, 1)\n",
"\n",
"X_test_tensor = torch.from_numpy(X_test).float()\n",
"\n",
"best_model.train()\n",
"\n",
"best_epoch = 471\n",
"\n",
"for epoch in range(best_epoch + 1):\n",
" \n",
" optimizer.zero_grad()\n",
" y_pred = best_model(X_train_tensor)\n",
" loss = criterion(y_pred, y_train_tensor)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" # dev loss\n",
" with torch.no_grad():\n",
" y_pred = best_model(X_dev_tensor)\n",
" loss = criterion(y_pred, y_dev_tensor)\n",
" accuracy = accuracy_score(y_dev_tensor, np.where(y_pred > 0.5, 1, 0))\n",
" \n",
" print(f\"Epoch {epoch} - loss: {loss}\")\n",
" print(f\"Epoch {epoch} - accuracy: {accuracy}\")\n",
" \n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "dee40e5a-7dd5-46c3-8504-8399e8356cfb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dev accuracy: 0.9512105649303008\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" y_pred_dev = best_model(X_dev_tensor)\n",
" y_pred_dev = np.where(y_pred_dev > 0.5, 1, 0)\n",
" df_dev_out = pd.DataFrame(y_pred_dev)\n",
" df_dev_out.to_csv('dev-0/out.tsv', sep = '\\t', index = None)\n",
" accuracy_dev = accuracy_score(y_dev_tensor, y_pred_dev)\n",
" print(f\"Dev accuracy: {accuracy_dev}\")\n",
"\n",
" y_pred_test = best_model(X_test_tensor)\n",
" y_pred_test = np.where(y_pred_test > 0.5, 1, 0)\n",
" df_test_out = pd.DataFrame(y_pred_test)\n",
" df_test_out.to_csv('test-A/out.tsv', sep = '\\t', index = None)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}