Perplexity: 422
This commit is contained in:
parent
a2d183f2e3
commit
0c45624062
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
@ -16,33 +16,43 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
|
||||
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
|
||||
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
|
||||
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
|
||||
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
|
||||
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
|
||||
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
|
||||
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from tqdm import tqdm\n",
|
||||
"import re\n",
|
||||
"import nltk\n",
|
||||
"import os\n",
|
||||
"import csv\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import sys\n",
|
||||
"import numpy as np\n",
|
||||
"from torch.utils.data import DataLoader, TensorDataset\n",
|
||||
"from bidict import bidict\n",
|
||||
"import math\n",
|
||||
"from sklearn.utils import shuffle\n",
|
||||
"from collections import Counter\n",
|
||||
"import random\n",
|
||||
"from torchtext.vocab import build_vocab_from_iterator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -59,7 +69,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 144,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -75,7 +85,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 102,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -101,7 +111,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -121,7 +131,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 103,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -135,7 +145,7 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 433/433 [01:34<00:00, 4.60it/s]\n"
|
||||
"100%|██████████| 433/433 [01:34<00:00, 4.57it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -148,7 +158,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 104,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -157,16 +167,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 105,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['<unk>', 'the', 'of', 'me.']"
|
||||
"['<unk>', 'the', 'of', 'houses']"
|
||||
]
|
||||
},
|
||||
"execution_count": 105,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -177,14 +187,21 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 106,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 433/433 [01:06<00:00, 6.50it/s]\n"
|
||||
" 0%| | 0/433 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 433/433 [01:08<00:00, 6.31it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -219,15 +236,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 107,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 432022/432022 [00:02<00:00, 168428.47it/s]\n",
|
||||
"100%|██████████| 432022/432022 [00:01<00:00, 332294.03it/s]\n"
|
||||
"100%|██████████| 432022/432022 [00:02<00:00, 172153.99it/s]\n",
|
||||
"100%|██████████| 432022/432022 [00:01<00:00, 316818.82it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -250,7 +267,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 108,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -266,7 +283,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 110,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -278,7 +295,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 111,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -288,7 +305,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 112,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -307,15 +324,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 132,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['silver', 'republicans', 'silver', 'den']\n",
|
||||
"['and']\n"
|
||||
"['a', 'charming', 'woman', 'would']\n",
|
||||
"['young']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -334,7 +351,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 121,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -343,7 +360,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 122,
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -359,15 +376,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 141,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class TrigramNN(nn.Module):\n",
|
||||
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):\n",
|
||||
" super(TrigramNN, self).__init__()\n",
|
||||
"class NGramNN(nn.Module):\n",
|
||||
" def __init__(self, vocab_size, embedding_dim):\n",
|
||||
" super(NGramNN, self).__init__()\n",
|
||||
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
||||
" self.linear = nn.Linear(embedding_dim * (left_tokens + right_tokens), output_size)\n",
|
||||
" self.linear = nn.Linear(embedding_dim * (left_tokens + right_tokens), vocab_size)\n",
|
||||
" \n",
|
||||
" def forward(self, inputs):\n",
|
||||
" out = self.embedding(inputs)\n",
|
||||
@ -386,11 +403,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 145,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = TrigramNN(vocab_size, embedding_dim, hidden_dim, output_size)\n",
|
||||
"model = NGramNN(vocab_size, embedding_dim)\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
|
||||
]
|
||||
@ -404,7 +430,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 146,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -418,70 +444,140 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 44.97it/s]\n"
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 45.24it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1, Loss: 7.505966403999844\n"
|
||||
"Epoch 1, Loss: 7.488175111847955\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 45.57it/s]\n"
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 45.62it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 2, Loss: 5.1014555966531905\n"
|
||||
"Epoch 2, Loss: 5.083534079629022\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 45.42it/s]\n"
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 44.91it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 3, Loss: 3.835972652886365\n"
|
||||
"Epoch 3, Loss: 3.8214319522316393\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 45.60it/s]\n"
|
||||
"100%|██████████| 1480/1480 [00:33<00:00, 44.65it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 4, Loss: 3.1567180975063427\n"
|
||||
"Epoch 4, Loss: 3.1464366490776476\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 44.90it/s]"
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 45.94it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 5, Loss: 2.749172909517546\n"
|
||||
"Epoch 5, Loss: 2.743303858589482\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 46.07it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 6, Loss: 2.456264268949225\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 45.59it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 7, Loss: 2.2358319317972337\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 46.15it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 8, Loss: 2.0536118873067806\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 45.90it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 9, Loss: 1.89841981795994\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1480/1480 [00:32<00:00, 44.89it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 10, Loss: 1.7637179977990485\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -521,7 +617,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 157,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -555,15 +651,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 158,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a:0.07 the:0.06 The:0.06 A:0.06 my:0.05 ,:0.05 to:0.05 and:0.05 -:0.05 an:0.05 :0.45\n",
|
||||
"of:0.07 on:0.06 and:0.06 be:0.06 for:0.05 in:0.05 school,:0.05 ol:0.05 it:0.05 the:0.05 :0.45\n"
|
||||
"the:0.07 his:0.06 a:0.06 their:0.06 John:0.06 tho:0.05 he:0.05 its:0.05 and:0.05 my:0.05 :0.44\n",
|
||||
"to:0.06 a:0.06 the:0.06 and:0.06 in:0.05 when:0.05 of:0.05 up:0.05 such:0.05 for:0.05 :0.46\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -581,7 +677,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 159,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -594,14 +690,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 160,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 10519/10519 [00:51<00:00, 203.44it/s]\n"
|
||||
"100%|██████████| 10519/10519 [00:50<00:00, 206.40it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
Loading…
Reference in New Issue
Block a user