Perplexity: 422

This commit is contained in:
Marcin Czerniak 2024-05-22 01:51:51 +02:00
parent a2d183f2e3
commit 0c45624062
2 changed files with 10672 additions and 10576 deletions

File diff suppressed because it is too large Load Diff

View File

@ -16,33 +16,43 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
]
}
],
"source": [ "source": [
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"import re\n", "import re\n",
"import nltk\n",
"import os\n", "import os\n",
"import csv\n", "import csv\n",
"import pandas as pd\n", "import pandas as pd\n",
"import torch\n", "import torch\n",
"import torch.nn as nn\n", "import torch.nn as nn\n",
"import torch.optim as optim\n", "import torch.optim as optim\n",
"import sys\n",
"import numpy as np\n", "import numpy as np\n",
"from torch.utils.data import DataLoader, TensorDataset\n", "from torch.utils.data import DataLoader, TensorDataset\n",
"from bidict import bidict\n",
"import math\n",
"from sklearn.utils import shuffle\n", "from sklearn.utils import shuffle\n",
"from collections import Counter\n",
"import random\n", "import random\n",
"from torchtext.vocab import build_vocab_from_iterator" "from torchtext.vocab import build_vocab_from_iterator"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -59,7 +69,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 144, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -75,7 +85,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 102, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -101,7 +111,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 46, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -121,7 +131,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 103, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -135,7 +145,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 433/433 [01:34<00:00, 4.60it/s]\n" "100%|██████████| 433/433 [01:34<00:00, 4.57it/s]\n"
] ]
} }
], ],
@ -148,7 +158,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 104, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -157,16 +167,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 105, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"['<unk>', 'the', 'of', 'me.']" "['<unk>', 'the', 'of', 'houses']"
] ]
}, },
"execution_count": 105, "execution_count": 8,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -177,14 +187,21 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 106, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 433/433 [01:06<00:00, 6.50it/s]\n" " 0%| | 0/433 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 433/433 [01:08<00:00, 6.31it/s]\n"
] ]
} }
], ],
@ -219,15 +236,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 107, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 432022/432022 [00:02<00:00, 168428.47it/s]\n", "100%|██████████| 432022/432022 [00:02<00:00, 172153.99it/s]\n",
"100%|██████████| 432022/432022 [00:01<00:00, 332294.03it/s]\n" "100%|██████████| 432022/432022 [00:01<00:00, 316818.82it/s]\n"
] ]
} }
], ],
@ -250,7 +267,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 108, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -266,7 +283,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 110, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -278,7 +295,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 111, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -288,7 +305,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 112, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -307,15 +324,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 132, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"['silver', 'republicans', 'silver', 'den']\n", "['a', 'charming', 'woman', 'would']\n",
"['and']\n" "['young']\n"
] ]
} }
], ],
@ -334,7 +351,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 121, "execution_count": 16,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -343,7 +360,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 122, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -359,15 +376,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 141, "execution_count": 18,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"class TrigramNN(nn.Module):\n", "class NGramNN(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):\n", " def __init__(self, vocab_size, embedding_dim):\n",
" super(TrigramNN, self).__init__()\n", " super(NGramNN, self).__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
" self.linear = nn.Linear(embedding_dim * (left_tokens + right_tokens), output_size)\n", " self.linear = nn.Linear(embedding_dim * (left_tokens + right_tokens), vocab_size)\n",
" \n", " \n",
" def forward(self, inputs):\n", " def forward(self, inputs):\n",
" out = self.embedding(inputs)\n", " out = self.embedding(inputs)\n",
@ -386,11 +403,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 145, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [ "source": [
"model = TrigramNN(vocab_size, embedding_dim, hidden_dim, output_size)\n", "model = NGramNN(vocab_size, embedding_dim)\n",
"criterion = nn.CrossEntropyLoss()\n", "criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)" "optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
] ]
@ -404,7 +430,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 146, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -418,70 +444,140 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 1480/1480 [00:32<00:00, 44.97it/s]\n" "100%|██████████| 1480/1480 [00:32<00:00, 45.24it/s]\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 1, Loss: 7.505966403999844\n" "Epoch 1, Loss: 7.488175111847955\n"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 1480/1480 [00:32<00:00, 45.57it/s]\n" "100%|██████████| 1480/1480 [00:32<00:00, 45.62it/s]\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 2, Loss: 5.1014555966531905\n" "Epoch 2, Loss: 5.083534079629022\n"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 1480/1480 [00:32<00:00, 45.42it/s]\n" "100%|██████████| 1480/1480 [00:32<00:00, 44.91it/s]\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 3, Loss: 3.835972652886365\n" "Epoch 3, Loss: 3.8214319522316393\n"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 1480/1480 [00:32<00:00, 45.60it/s]\n" "100%|██████████| 1480/1480 [00:33<00:00, 44.65it/s]\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 4, Loss: 3.1567180975063427\n" "Epoch 4, Loss: 3.1464366490776476\n"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 1480/1480 [00:32<00:00, 44.90it/s]" "100%|██████████| 1480/1480 [00:32<00:00, 45.94it/s]\n"
] ]
}, },
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Epoch 5, Loss: 2.749172909517546\n" "Epoch 5, Loss: 2.743303858589482\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1480/1480 [00:32<00:00, 46.07it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6, Loss: 2.456264268949225\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1480/1480 [00:32<00:00, 45.59it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7, Loss: 2.2358319317972337\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1480/1480 [00:32<00:00, 46.15it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8, Loss: 2.0536118873067806\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1480/1480 [00:32<00:00, 45.90it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9, Loss: 1.89841981795994\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1480/1480 [00:32<00:00, 44.89it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 1.7637179977990485\n"
] ]
}, },
{ {
@ -521,7 +617,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 157, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -555,15 +651,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 158, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"a:0.07 the:0.06 The:0.06 A:0.06 my:0.05 ,:0.05 to:0.05 and:0.05 -:0.05 an:0.05 :0.45\n", "the:0.07 his:0.06 a:0.06 their:0.06 John:0.06 tho:0.05 he:0.05 its:0.05 and:0.05 my:0.05 :0.44\n",
"of:0.07 on:0.06 and:0.06 be:0.06 for:0.05 in:0.05 school,:0.05 ol:0.05 it:0.05 the:0.05 :0.45\n" "to:0.06 a:0.06 the:0.06 and:0.06 in:0.05 when:0.05 of:0.05 up:0.05 such:0.05 for:0.05 :0.46\n"
] ]
} }
], ],
@ -581,7 +677,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 159, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -594,14 +690,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 160, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 10519/10519 [00:51<00:00, 203.44it/s]\n" "100%|██████████| 10519/10519 [00:50<00:00, 206.40it/s]\n"
] ]
} }
], ],