challenging-america-word-ga.../main.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "f3452caf-df58-4394-b0d6-46459cb47045",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"S:\\WENV_TORCHTEXT\\Lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
"S:\\WENV_TORCHTEXT\\Lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
]
}
],
"source": [
"from torch.utils.data import IterableDataset, DataLoader\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"\n",
"import regex as re\n",
"import itertools\n",
"from itertools import islice\n",
"\n",
"from torch import nn\n",
"import torch\n",
"\n",
"from tqdm.notebook import tqdm\n",
"device = 'cuda'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5ee9ad24-a5d2-47e1-a5c6-88981dc22b99",
"metadata": {},
"outputs": [],
"source": [
"def get_words_from_line(line):\n",
" line = line.rstrip()\n",
" yield '<s>'\n",
" for m in re.finditer(r'[\\p{L}0-9\\*]+|\\p{P}+', line):\n",
" yield m.group(0).lower()\n",
" yield '</s>'\n",
"\n",
"def get_word_lines_from_file(file_name):\n",
" with open(file_name, 'r', encoding='utf8') as fh:\n",
" for line in fh:\n",
" yield get_words_from_line(line)\n",
"\n",
"def look_ahead_iterator(gen):\n",
" prev2, prev1, next1, next2 = None, None, None, None\n",
" for item in gen:\n",
" if prev2 is not None and prev1 is not None and next1 is not None and next2 is not None:\n",
" yield (prev2, prev1, next2, item, next1)\n",
" prev2, prev1, next1, next2 = prev1, next1, next2, item"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "93279277-0765-4f85-9666-095fc7808c81",
"metadata": {},
"outputs": [],
"source": [
"class FiveGrams(IterableDataset):\n",
" def __init__(self, text_file, vocabulary_size):\n",
" self.vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(text_file),\n",
" max_tokens=vocabulary_size,\n",
" specials=['<unk>']\n",
" )\n",
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
" self.vocabulary_size = vocabulary_size\n",
" self.text_file = text_file\n",
"\n",
" def __iter__(self):\n",
" return look_ahead_iterator(\n",
" (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file)))\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6eb5fbd9-bc0f-499d-85f4-3998a4a3f56e",
"metadata": {},
"outputs": [],
"source": [
"class SimpleFiveGramNeuralLanguageModel(nn.Module):\n",
" def __init__(self, vocabulary_size, embedding_size):\n",
" super(SimpleFiveGramNeuralLanguageModel, self).__init__()\n",
" self.embedding = nn.Embedding(vocabulary_size, embedding_size)\n",
" self.linear1 = nn.Linear(embedding_size * 4, embedding_size)\n",
" self.linear2 = nn.Linear(embedding_size, vocabulary_size)\n",
" self.softmax = nn.Softmax(dim=1)\n",
" self.embedding_size = embedding_size\n",
"\n",
" def forward(self, x):\n",
" embeds = self.embedding(x).view(x.size(0), -1)\n",
" out = self.linear1(embeds)\n",
" out = self.linear2(out)\n",
" return self.softmax(out)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d0dc7c69-3f27-4f00-9b91-5f3a403df074",
"metadata": {},
"outputs": [],
"source": [
"def train(embed_size,vocab_size,num_epochs,batch_size,train_file_path):\n",
" train_dataset = FiveGrams(train_file_path, vocab_size)\n",
" model = SimpleFiveGramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
" \n",
" data = DataLoader(train_dataset, batch_size=batch_size)\n",
" optimizer = torch.optim.Adam(model.parameters())\n",
" criterion = torch.nn.CrossEntropyLoss()\n",
" \n",
" model.train()\n",
" step = 0\n",
" for _ in range(num_epochs):\n",
" for x1, x2, x3, x4, y in tqdm(data, desc=\"Train loop\"):\n",
" y = y.to(device)\n",
" x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1), x3.unsqueeze(1), x4.unsqueeze(1)), dim=1).to(device)\n",
" optimizer.zero_grad()\n",
" ypredicted = model(x)\n",
" \n",
" loss = criterion(torch.log(ypredicted), y)\n",
" if step % 5000 == 0:\n",
" print(step, loss)\n",
" step += 1\n",
" loss.backward()\n",
" optimizer.step()\n",
" step = 0\n",
" break\n",
" model.eval()\n",
"\n",
" return model, train_dataset.vocab"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9a1b2240-d2ed-4c56-8443-12113e66b514",
"metadata": {},
"outputs": [],
"source": [
"def get_gap_candidates(words, model, vocab, n=20):\n",
" ixs = vocab(words)\n",
" ixs = torch.tensor(ixs).unsqueeze(0).to(device)\n",
"\n",
" out = model(ixs)\n",
" top = torch.topk(out[0], n)\n",
" top_indices = top.indices.tolist()\n",
" top_probs = top.values.tolist()\n",
" top_words = vocab.lookup_tokens(top_indices)\n",
" return list(zip(top_words, top_probs))\n",
"\n",
"def clean(text):\n",
" text = text.replace('-\\\\n', '').replace('\\\\n', ' ').replace('\\\\t', ' ')\n",
" text = re.sub(r'\\n', ' ', text)\n",
" text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n",
" text = re.sub(r'\\s+', ' ', text)\n",
" text = re.sub(r'\\p{P}', '', text)\n",
" text = text.strip()\n",
" return text\n",
" \n",
"def predictor(prefix, suffix, model, vocab):\n",
" prefix = clean(prefix)\n",
" suffix = clean(suffix)\n",
" words = prefix.split(' ')[-2:] + suffix.split(' ')[:2]\n",
" candidates = get_gap_candidates(words, model, vocab)\n",
"\n",
" probs_sum = 0\n",
" output = ''\n",
" for word, prob in candidates:\n",
" if word == \"<unk>\":\n",
" continue\n",
" probs_sum += prob\n",
" output += f\"{word}:{prob} \"\n",
" output += f\":{1-probs_sum}\"\n",
"\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "40af2781-3807-43e8-b6dd-3b70066e50c1",
"metadata": {},
"outputs": [],
"source": [
"def generate_result(input_path,model, vocab, output_path='out.tsv'):\n",
" lines = []\n",
" with open(input_path, encoding='utf-8') as f:\n",
" for line in f:\n",
" columns = line.split('\\t')\n",
" prefix = columns[6]\n",
" suffix = columns[7]\n",
" lines.append((prefix, suffix))\n",
"\n",
" with open(output_path, 'w', encoding='utf-8') as output_file:\n",
" for prefix, suffix in tqdm(lines):\n",
" result = predictor(prefix, suffix, model, vocab)\n",
" output_file.write(result + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d6b7234f-1f40-468f-8c69-2875bb1ec947",
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"\n",
"def evaluate():\n",
" cmd = 'wsl bash -c \"cd /mnt/d/UAM/MODELOWANIE/5GRAM && ./geval -t dev-0\"'\n",
" result = subprocess.run(cmd, shell=True, capture_output=True, text=True)\n",
" return float(result.stdout)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4c716463-27fe-4c2b-b859-ac9c8aff1942",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1eac733e07974322bfd47dcff96aa8d4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.2551, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3df53998bc334a29bd355578738897d3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 100, 'vocab_size': 10000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 335.54}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "df35ef1138644a00a86f999a6cb8a0cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.2881, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(4.6065, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.4173, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.3352, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1e3d1e4344f44285832accfc83ab2233",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 100, 'vocab_size': 10000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 199.0}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6d6813fbb2c4fcba084fffc5c0105b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.9662, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3546c4fd825a4057b05c6105b25f6712",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 100, 'vocab_size': 20000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 342.38}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bd6d71e568a848c19bb8ce762d5b53ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.9619, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(4.8952, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.7382, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.6068, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cd0c53ddfb3148609218b56c72ce7777",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 100, 'vocab_size': 20000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 190.06}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5261144cdc7e4c69aa28d88fd9fad7fe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(10.3571, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "42b54814d7f741ddb38d8e00d6db2126",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 100, 'vocab_size': 30000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 344.0}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bc582df4aaa741e2a46ac64a282f36fc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(10.3726, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(5.0450, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.8688, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.7152, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38a391c3471442dc8620524f0d5c5fc8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 100, 'vocab_size': 30000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 188.45}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5d3d20e040bb4c609143cc335143b6ed",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.2857, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "65a44fd20fce4ec49528763ebb351f4b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 200, 'vocab_size': 10000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 288.34}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba3783d3ce5c47ce8c02443ce7601288",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.2722, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(4.4829, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.2794, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.2239, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9b2f000b5c7a4ab08fb48c78b45c3a54",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 200, 'vocab_size': 10000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 187.39}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1eff87ac9b594efb8406f3fc5e77a8ce",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.9422, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6883012f36d44ab0842b14281ef5eb65",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 200, 'vocab_size': 20000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 294.44}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1006d625eea94078a56bb393800b2e7f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.9541, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(4.7515, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.5669, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.4938, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "53d6628208ac4997a290cdde469ccfb1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 200, 'vocab_size': 20000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 178.27}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b9ed3ef499f44fd2942c0d9e2d827442",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(10.3758, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ccb8ed60361c4a509afb628f52ce609c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 200, 'vocab_size': 30000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 297.89}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "62c15cbb03c14cc7aee9f3245b3f536e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(10.3769, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(4.8590, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.7090, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.5810, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "85f4db52171945fa9fb2890431a276b2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 200, 'vocab_size': 30000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 177.66}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a87614dd8b3447dea6447d7d73a2677b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.2580, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "32802641dc974b6cb3404a12c51b611b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 300, 'vocab_size': 10000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 271.01}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f4ebf3dde5274a90ac74dcc17f066b10",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.2549, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(4.4134, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.2280, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.1653, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a97a141657af410b989bbd8ab8710955",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 300, 'vocab_size': 10000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 181.91}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "391718c3103d4b9ebfb571b7234e946f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.9555, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3aa3a60daf6145389312b031ca55e11d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 300, 'vocab_size': 20000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 275.53}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "939dfbbf01d5425da9f885060ae44451",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(9.9647, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(4.6888, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.5068, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.4465, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "56e1065d5c01461e8cdfc38b7db2f3ed",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 300, 'vocab_size': 20000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 174.42}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d026a38b16b74410b238e9c6b7c997a6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(10.3563, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "58d964531ad14b69ba72b7d8d538cf2d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 300, 'vocab_size': 30000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/nano.txt', 'perplexity': 277.32}\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0bddf3d40c054e92805c100cad8aaccd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Train loop: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(10.3580, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000 tensor(4.8159, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000 tensor(4.6442, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"15000 tensor(4.5524, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aedcfa56f6c345dd976b37535d0de2b4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10519 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'embed_size': 300, 'vocab_size': 30000, 'num_epochs': 1, 'batch_size': 8192, 'train_file_path': 'train/train.txt', 'perplexity': 173.38}\n"
]
}
],
"source": [
"embed_sizes = [100,200,300]\n",
"vocab_sizes = [10_000, 20_000, 30_000]\n",
"num_epochss = [1]\n",
"batch_sizes = [8192]\n",
"train_file_paths = ['train/nano.txt', 'train/train.txt']\n",
"\n",
"results = []\n",
"\n",
"for embed_size in embed_sizes:\n",
" for vocab_size in vocab_sizes:\n",
" for num_epochs in num_epochss:\n",
" for batch_size in batch_sizes:\n",
" for train_file_path in train_file_paths:\n",
" model, vocab = train(embed_size,vocab_size,num_epochs,batch_size,train_file_path)\n",
" generate_result('dev-0/in.tsv', model, vocab, output_path='dev-0/out.tsv')\n",
" result = evaluate()\n",
"\n",
" config = {\"embed_size\": embed_size, \"vocab_size\": vocab_size, \"num_epochs\": num_epochs, \"batch_size\": batch_size, \"train_file_path\": train_file_path, \"perplexity\": result }\n",
" print(config)\n",
" results.append( config )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5d835463-01ed-4b44-a652-1ea469542d89",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'embed_size': 100,\n",
" 'vocab_size': 10000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 335.54},\n",
" {'embed_size': 100,\n",
" 'vocab_size': 10000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 199.0},\n",
" {'embed_size': 100,\n",
" 'vocab_size': 20000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 342.38},\n",
" {'embed_size': 100,\n",
" 'vocab_size': 20000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 190.06},\n",
" {'embed_size': 100,\n",
" 'vocab_size': 30000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 344.0},\n",
" {'embed_size': 100,\n",
" 'vocab_size': 30000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 188.45},\n",
" {'embed_size': 200,\n",
" 'vocab_size': 10000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 288.34},\n",
" {'embed_size': 200,\n",
" 'vocab_size': 10000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 187.39},\n",
" {'embed_size': 200,\n",
" 'vocab_size': 20000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 294.44},\n",
" {'embed_size': 200,\n",
" 'vocab_size': 20000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 178.27},\n",
" {'embed_size': 200,\n",
" 'vocab_size': 30000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 297.89},\n",
" {'embed_size': 200,\n",
" 'vocab_size': 30000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 177.66},\n",
" {'embed_size': 300,\n",
" 'vocab_size': 10000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 271.01},\n",
" {'embed_size': 300,\n",
" 'vocab_size': 10000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 181.91},\n",
" {'embed_size': 300,\n",
" 'vocab_size': 20000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 275.53},\n",
" {'embed_size': 300,\n",
" 'vocab_size': 20000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 174.42},\n",
" {'embed_size': 300,\n",
" 'vocab_size': 30000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/nano.txt',\n",
" 'perplexity': 277.32},\n",
" {'embed_size': 300,\n",
" 'vocab_size': 30000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 173.38}]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "b5058255-4478-427a-84f1-fe1f57fc3828",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAy0AAAIjCAYAAAAObfTCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAACP/UlEQVR4nOzdeVhUZf8G8HvYhn1TFonVJRUXNFLCUlEJIV/3N9NcsNwDzSgzet3NMLW0zK1SsMwlck3Tct+XNHFJwyXcAVdAUNnm+f3hjxMjwzIwMNv9uS6umnPOPPOc42Hm3Hyf54xMCCFARERERESko0y03QEiIiIiIqKyMLQQEREREZFOY2ghIiIiIiKdxtBCREREREQ6jaGFiIiIiIh0GkMLERERERHpNIYWIiIiIiLSaQwtRERERESk0xhaiIiIiIhIpzG0kFp8fX3xn//8p9pf58qVK5DJZEhISNBYmyEhIQgJCdFYe1S2wYMHw9bWttpfZ8+ePZDJZNizZ0+1v1ZVqXMOhoSEoGnTptXboRpW3b+DNXkuZGdnY+jQoXB3d4dMJsPYsWOr/TXVlZCQAJlMhuPHj2uszcGDB8PX11dj7VU3mUyGKVOmVOq5vr6+GDx4sEb7Q0SVx9BiAIo+mEr7OXLkiLa7WK2uXLmCt956C/Xq1YOlpSXc3d3Rrl07TJ48Wdtdqzbr1q2DTCbDd999V+o227dvh0wmw1dffVWDPdMchUKB77//HkFBQXB2doadnR2ef/55DBo0yGDO6Vu3bmHKlClISkrSeNu+vr5K7wOurq5o27Yt1q9fr/HX0mUrV67EvHnzNN7up59+ioSEBIwaNQo//PADBg4cqPHXMAa//vprpUMFle/Ro0eYMmWKXvxRh6g8ZtruAGnOtGnT4OfnV2J5/fr1tdCbmnHp0iW0atUKVlZWePvtt+Hr64vU1FT8+eef+OyzzzB16lRp299//12LPdWsLl26wMHBAStXrsTQoUNVbrNy5UqYmpqib9++Ndw7zRgzZgwWLFiA7t27o3///jAzM0NycjK2bt2KunXr4qWXXgIAtGvXDo8fP4aFhYWWe1y+Z8/BW7duYerUqfD19UWLFi00/notWrTA+++/L73WkiVL0KtXLyxatAgjR47U+Otpm6pzYeXKlTh79qzGKyG7du3CSy+9ZNB/HKkJv/76KxYsWFBtweXx48cwM6vcpU5ycjJMTPT7b7uPHj2SPgc50oD0HUOLAYmIiMCLL76o7W7UqLlz5yI7OxtJSUnw8fFRWnf79m2lx/pwUVtRcrkc//3vfxEfH49bt27Bw8NDaf2TJ0+wfv16vPrqq3B1ddVSLysvPT0dCxcuxLBhw/DNN98orZs3bx7u3LkjPTYxMYGlpWVNd7FSavocfO655zBgwADp8aBBg1C/fn3MnTu3yqHlyZMnsLCw0KmLupo8F27fvg1/f3+NtVdQUACFQmFQ71OaVpljVJXzQS6XV/q5RKR5uvNpQ9WuaJ7InDlzsGDBAtStWxfW1tYICwvD9evXIYTA9OnT4enpCSsrK3Tv3h33799X2dbvv/+OFi1awNLSEv7+/li3bl2JbTIyMjB27Fh4eXlBLpejfv36+Oyzz6BQKEpsN3jwYDg4OMDR0RGRkZHIyMio0D5dvnwZnp6eJQILgBIX68+Op392+Ezxn+Kl9Js3b+Ltt9+Gm5sb5HI5mjRpgmXLlpXbt6ZNm6JDhw4llisUCjz33HP473//Ky1bvXo1AgMDYWdnB3t7ezRr1gxffvllme0PGDAACoUCq1evLrFuy5YtyMzMRP/+/QE8/bCfPn066tWrB7lcDl9fX3z88cfIzc0t8dytW7eiffv2Ul9atWqFlStXSuv379+P119/Hd7e3pDL5fDy8sJ7772Hx48fq+znP//8g86dO8PGxgYeHh6YNm0ahBBl7ltKSgqEEHj55ZdLrCsa6lTk2XkMZQ2XfPYvjStWrEBgYCCsrKzg7OyMvn374vr162X27fTp05DJZNi0aZO07MSJE5DJZHjhhReUto2IiEBQUJD0uPg5uGfPHrRq1QoA8NZbb0l9fHYe17lz59ChQwdYW1vjueeew6xZs8rsX1nc3d3RuHFjpKSkSMsqcn4XHePVq1djwoQJeO6552BtbY2srCzpeO/btw8jRoxArVq1YG9vj0GDBuHBgwfl9ik3NxeTJ09G/fr1pfPpww8/VDo3IyMjYWlpifPnzys9t3PnznBycsKtW7eU+ll0LoSEhGDLli24evWqdHx9fX2RnZ0NGxsbvPvuuyX6c+PGDZiamiIuLk5lf4teIyUlBVu2bJHavXLlCoCnYWbIkCFwc3ODpaUlAgICsHz5cqU2ir8Xz5s3T/q9PHfunMrX7NWrV4lzq2vXriXOw6NHj0Imk2Hr1q0ljnFMTAxcXFxgY2ODnj17KgX/IgsXLkSTJk0gl8vh4eGBqKioCr0XKxQKzJs3D02aNIGlpSXc3NwwYsSIcv/9Bw8ejAULFgCA0u9peccoLy8PkyZNQmBgIBwcHGBjY4O2bdti9+7dJV7j2TktU6ZMgUwmw6VLlzB48GA4OjrCwcEBb731Fh49eqT03GfntBSd6wcPHiz3eCoUCkyZMgUeHh6wtrZGhw4dcO7cuQrPk6nIZ0J5n7FXrlyBi4sLAGDq1KnS8eVwPNJXrLQYkMzMTNy9e1dpmUwmQ61atZSW/fjjj8jLy8Po0aNx//59zJo1C3369EHHjh2xZ88ejB8/HpcuXcL8+fPxwQcflLiAuXjxIt544w2MHDkSkZGRiI+Px+uvv45t27bh1VdfBfC0JN2+fXvcvHkTI0aMgLe3Nw4dOoTY2FikpqZKY8yFEOjevTsOHDiAkSNHonHjxli/fj0iIyMrtM8+Pj7YsWMHdu3ahY4dO6p1vObNm4fs7GylZXPnzkVSUpJ0zNLT0/HSSy9BJpMhOjoaLi4u2Lp1K4YMGYKsrKwyh5y88cYbmDJlCtLS0uDu7i4tP3DgAG7duiUN29q+fTv69euHTp064bPPPgMAnD9/HgcPHlR5UVWkXbt28PT0xMqVKxETE6O0buXKlbC2tkaPHj0AAEOHDsXy5cvx3//+F++//z6OHj2KuLg4nD9/XmmOQ0JCAt5++200adIEsbGxcHR0xMmTJ7Ft2za8+eabAIDExEQ8evQIo0aNQq1atXDs2DHMnz8fN27cQGJiolI/CgsLER4ejpdeegmzZs3Ctm3bMHnyZBQUFGDatGml7ltRCE1MTMTrr78Oa2vrUrdVdVx++OEHpWVXr17FhAkTlMLOjBkzMHHiRPTp0wdDhw7FnTt3MH/+fLRr1w4nT56Eo6OjyvabNm0KR0dH7Nu3D926dQPwNMiZmJjg1KlTyMrKgr29PRQKBQ4dOoThw4erbKdx48aYNm0aJk2ahOHDh6Nt27YAgDZt2kjbPHjwAOHh4ejVqxf69OmDn3/+GePHj0ezZs0QERFR4WNSJD8/H9evX6/0+T19+nRYWFjggw8+QG5urtJfvKOjo+Ho6IgpU6YgOTkZi
xYtwtWrV6WLfFUUCgW6deuGAwcOYPjw4WjcuDHOnDmDuXPn4sKFC9iwYQMA4Msvv8SuXbsQGRmJw4cPw9TUFEuWLMHvv/+OH374oUSlscj//vc/ZGZm4saNG5g7dy4AwNbWFra2tujZsyfWrFmDL774AqamptJzVq1aBSGEFPif1bhxY/zwww9477334OnpKQ2/c3FxwePHjxESEoJLly4hOjoafn5+SExMxODBg5GRkVHi9zk+Ph5PnjzB8OHDIZfL4ezsrPI127Zti40bN0rnlhACBw8ehImJCfbv31/iPHw27I8ePRpOTk6YPHkyrly5gnnz5iE6Ohpr1qyRtpkyZQqmTp2K0NBQjBo1Svo3/OOPP3Dw4EGYm5ur7BsAjBgxAgkJCXjrrbcwZswYpKSk4Ouvv8bJkyfLfO6IESNw69YtbN++vcTvbFnHKCsrC9999x369euHYcOG4eHDh1i6dCk6d+6MY8eOVWioZZ8+feDn54e4uDj8+eef+O677+Dq6iq9B5elIsczNjYWs2b
"text/plain": [
"<Figure size 1000x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from scipy.interpolate import griddata\n",
"\n",
"# Sample data\n",
"data = results\n",
"\n",
"# Extracting data\n",
"vocab_size = [item['vocab_size'] for item in data if 'nano' not in item['train_file_path'] ]\n",
"embed_size = [item['embed_size'] for item in data if 'nano' not in item['train_file_path'] ]\n",
"perplexity = [item['perplexity'] for item in data if 'nano' not in item['train_file_path'] ]\n",
"\n",
"# Plotting\n",
"grid_x, grid_y = np.meshgrid(np.linspace(min(vocab_size), max(vocab_size), 100),\n",
" np.linspace(min(embed_size), max(embed_size), 100))\n",
"grid_z = griddata((vocab_size, embed_size), perplexity, (grid_x, grid_y), method='cubic')\n",
"\n",
"# Plotting\n",
"plt.figure(figsize=(10, 6))\n",
"contour = plt.contourf(grid_x, grid_y, grid_z, cmap='viridis')\n",
"plt.colorbar(contour, label='Perplexity')\n",
"plt.scatter(vocab_size, embed_size, c='red') # Optional: plot actual data points\n",
"plt.xlabel('Vocab Size')\n",
"plt.ylabel('Embed Size')\n",
"plt.title('Embed Size vs Vocab Size with Perplexity for whole training set')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "fe388a52-9bd3-4ee3-9cf1-838c9ff22c55",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAy0AAAIjCAYAAAAObfTCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAACKGklEQVR4nOzdeVwU9eM/8NdyLPeCKIfEIYqpKKiRB5pGaoiSR1qeCZp3YKlpimleJaZ9UitFK0U7yNJES0PDCy88E8WL1PBKDpUAQQFh378//DFfR85FYBd4PR+Pfei+573vfc/s7DCvnZn3KIQQAkRERERERDpKT9sdICIiIiIiKg1DCxERERER6TSGFiIiIiIi0mkMLUREREREpNMYWoiIiIiISKcxtBARERERkU5jaCEiIiIiIp3G0EJERERERDqNoYWIiIiIiHQaQws9k0aNGuG1116r8ve5du0aFAoF1q9fX2lt+vj4wMfHp9Lao9KNHDkS5ubmVf4++/fvh0KhwP79+6v8vZ6VJuugj48PWrVqVbUdqmZV/R2sznUhKysLY8aMgb29PRQKBSZPnlzl70lFNWrUCCNHjqzQa/k3gUi3MbTUQuvXr4dCoSjxcfToUW13sUpdu3YNo0aNQpMmTWBsbAx7e3t07doVc+fO1XbXqsyWLVugUCjw7bffllgnOjoaCoUCX3zxRTX2rPKo1Wp899136NChA6ytrWFhYYHnn38eAQEBtWadvn37NubNm4e4uLhKb7tRo0ay7YCtrS26dOmCyMjISn8vXRYREYHly5dXeruLFi3C+vXrMXHiRHz//fcYMWJEpb9HbXDkyBHMmzcP6enp2u5KrbVo0SJs3bpV290gqnQG2u4AVZ0FCxbA1dW1SLmbm5sWelM9rly5gnbt2sHExARvv/02GjVqhKSkJPz111/49NNPMX/+fKnun3/+qcWeVi5/f39YWloiIiICY8aMKbZOREQE9PX1MWTIkGruXeV49913sXLlSvTr1w/Dhw+HgYEBEhISEBUVhcaNG6Njx44AgK5du+Lhw4dQKpVa7nHZnl4Hb9++jfnz56NRo0Zo06ZNpb9fmzZt8P7770vvtWbNGgwYMABhYWGYMGFCpb+fthW3LkRERODcuXOVfiRk79696NixY63+caQyHDlyBPPnz8fIkSNhZWVV6e0nJCRAT69iv8fWlr8JixYtwhtvvIH+/ftruytElYqhpRbr1asXXnzxRW13o1otW7YMWVlZiIuLg4uLi2xaamqq7HlN2KktLyMjI7zxxhsIDw/H7du34eDgIJuek5ODyMhIvPrqq7C1tdVSLysuJSUFq1atwtixY/H111/Lpi1fvhx37tyRnuvp6cHY2Li6u1gh1b0OPvfcc3jrrbek5wEBAXBzc8OyZcueObTk5ORAqVRWeIexKlTnupCamgp3d/dKay8/Px9qtbpWbac0pVarkZeXp9FnaGRkVOH3q8vLmqgm0J2/LlTtCq8T+eyzz7By5Uo0btwYpqam8PX1xc2bNyGEwMKFC+Ho6AgTExP069cPaWlpxbb1559/ok2bNjA2Noa7uzu2bNlSpE56ejomT54MJycnGBkZwc3NDZ9++inUanWReiNHjoSlpSWsrKwQGBhY7lMJrl69CkdHxyKBBUCRnfWnz19++vSZJx9PnhP/77//4u2334adnR2MjIzQsmVLrFu3rsy+tWrVCq+88kqRcrVajeeeew5vvPGGVLZx40Z4eXnBwsICKpUKHh4eWLFiRantv/XWW1Cr1di4cWORaTt27EBGRgaGDx8O4PEO0cKFC9GkSRMYGRmhUaNGmDVrFnJzc4u8NioqCi+//LLUl3bt2iEiIkKafvDgQbz55ptwdnaGkZERnJycMGXKFDx8+LDYfv7zzz/o2bMnzMzM4ODggAULFkAIUeq8JSYmQgiBzp07F5lWeKpToaevYyjtdMmnz1//4Ycf4OXlBRMTE1hbW2PIkCG4efNmqX07e/YsFAoFfvvtN6ns1KlTUCgUeOGFF2R1e/XqhQ4dOkjPn1wH9+/fj3bt2gEARo0aJfXx6eu4Lly4gFdeeQWmpqZ47rnnsGTJklL7Vxp7e3u0aNECiYmJUll51u/CZbxx40bMnj0bzz33HExNTZGZmSkt7wMHDmD8+PGoX78+VCoVAgIC8N9//5XZp9zcXMydOxdubm7S+vTBBx/I1s3AwEAYGxvj4sWLstf27NkT9erVw+3bt2X9LFwXfHx8sGPHDly/fl1avo0aNUJWVhbMzMzw3nvvFenPrVu3oK+vj9DQ0GL7W/geiYmJ2LFjh9TutWvXADwOM6NHj4adnR2MjY3RunVrbNiwQdbGk9vi5cuXS9/LCxculLicFAoFgoODsXXrVrRq1Ur6rHbu3Cmrd/36dbzzzjto1qwZTExMUL9+fbz55ptS/woVfm6HDx/G1KlTYWNjAzMzM7z++uuyHwUKrVq1Ci1btoSRkREcHBwQFBRU5nZ63rx5mD59OgDA1dW1yLIqnKcff/xRartwfj777DN06tQJ9evXh4mJCby8vLB58+Yi7/H0NS2azNfTfxMKP9tffvkFn3zyCRwdHWFsbIzu3bvjypUrRd678O+oiYkJ2rdvj4MHD5b7Opno6Gi89NJLsLKygrm5OZo1a4ZZs2bJ6pTnu6FQKJCdnY0NGzZIy7ei1/gQ6RoeaanFMjIycPfuXVmZQqFA/fr1ZWU//vgj8vLyMGnSJKSlpWHJkiUYNGgQunXrhv3792PGjBm4cuUKvvzyS0ybNq3IDszly5cxePBgTJgwAYGBgQgPD8ebb76JnTt34tVXXwUAPHjwAC+//DL+/fdfjB8/Hs7Ozjhy5AhCQkKQlJQknWMuhEC/fv1w6NAhTJgwAS1atEBkZCQCAwPLNc8uLi7YvXs39u7di27dumm0vJYvX46srCxZ2bJlyxAXFycts5SUFHTs2FH642pjY4OoqCiMHj0amZmZpZ5yMnjwYMybNw/Jycmwt7eXyg8dOoTbt29Lp21FR0dj6NCh6N69Oz799FMAwMWLF3H48OFid6oKde3aFY6OjoiIiMDUqVNl0yIiImBqaiqdLjBmzBhs2LABb7zxBt5//30cO3YMoaGhuHjxouwah/Xr1+Ptt99Gy5YtERISAisrK5w+fRo7d+7EsGHDAACbNm3CgwcPMHHiRNSvXx/Hjx/Hl19+iVu3bmHTpk2yfhQUFMDPzw8dO3bEkiVLsHPnTsydOxf5+flYsGBBifNWGEI3bdqEN998E6ampiXWLW65fP/997Ky69evY/bs2bKw88knn2DOnDkYNGgQxowZgzt37uDLL79E165dcfr06RJPZWnVqhWsrKxw4MAB9O3bF8DjIKenp4czZ84gMzMTKpUKarUaR44cwbhx44ptp0WLFliwYAE++ugjjBs3Dl26dAEAdOrUSarz33//wc/PDwMGDMCgQYOwefNmzJgxAx4eHujVq1e5l0mhR48e4ebNmxVevxcuXAilUolp06YhNzdX9kt1cHAwrKysMG/ePCQkJ
CAsLAzXr1+XdgSLo1ar0bdvXxw6dAjjxo1DixYtEB8fj2XLluHvv/+WztNfsWIF9u7di8DAQMTGxkJfXx9r1qzBn3/+ie+//77IkcZCH374ITIyMnDr1i0sW7YMAGBubg5zc3O8/vrr+Pnnn/H5559DX19fes1PP/0EIYQU+J/WokULfP/995gyZQocHR2l0+9sbGzw8OFD+Pj44MqVKwgODoarqys2bdqEkSNHIj09vcj3OTw8HDk5ORg3bhyMjIxgbW1dwif32KFDh7Blyxa88847sLCwwBdffIGBAwfixo0b0md64sQJHDlyBEOGDIGjoyOuXbuGsLAw+Pj44MKFC0W+S5MmTUK9evUwd+5cXLt2DcuXL0dwcDB+/vlnqc68efMwf/589OjRAxMnTpQ+3xMnTuDw4cMwNDQstr8DBgzA33//jZ9++gnLli1DgwYNpGVVaO/evfjll18QHByMBg0aoFGjRgAef+Z9+/bF8OHDkZeXh40bN+LNN9/E9u3b4e/vX+pyKu98lWTx4sX
"text/plain": [
"<Figure size 1000x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Extracting data\n",
"vocab_size = [item['vocab_size'] for item in data if 'nano' in item['train_file_path'] ]\n",
"embed_size = [item['embed_size'] for item in data if 'nano' in item['train_file_path'] ]\n",
"perplexity = [item['perplexity'] for item in data if 'nano' in item['train_file_path'] ]\n",
"\n",
"# Plotting\n",
"grid_x, grid_y = np.meshgrid(np.linspace(min(vocab_size), max(vocab_size), 100),\n",
" np.linspace(min(embed_size), max(embed_size), 100))\n",
"grid_z = griddata((vocab_size, embed_size), perplexity, (grid_x, grid_y), method='cubic')\n",
"\n",
"# Plotting\n",
"plt.figure(figsize=(10, 6))\n",
"contour = plt.contourf(grid_x, grid_y, grid_z, cmap='viridis')\n",
"plt.colorbar(contour, label='Perplexity')\n",
"plt.scatter(vocab_size, embed_size, c='red') # Optional: plot actual data points\n",
"plt.xlabel('Vocab Size')\n",
"plt.ylabel('Embed Size')\n",
"plt.title('Embed Size vs Vocab Size with Perplexity for nano training set')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "a310f1f5-0b2f-4994-b36a-e2ff1a7e6b70",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'embed_size': 300,\n",
" 'vocab_size': 30000,\n",
" 'num_epochs': 1,\n",
" 'batch_size': 8192,\n",
" 'train_file_path': 'train/train.txt',\n",
" 'perplexity': 173.38,\n",
" 'logPerplexity': 5.155485717440494}"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from math import log\n",
"\n",
"best_model_parameters = min(results, key=lambda x: x['perplexity'])\n",
"best_model_parameters['logPerplexity'] = log(best_model_parameters['perplexity'])\n",
"best_model_parameters"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}