{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f3452caf-df58-4394-b0d6-46459cb47045",
   "metadata": {},
   "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "S:\\WENV_TORCHTEXT\\Lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n", "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n", "S:\\WENV_TORCHTEXT\\Lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n", "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n" ] } ],
   "source": [
    "import itertools\n",
    "import subprocess\n",
    "from collections import deque\n",
    "from itertools import islice\n",
    "\n",
    "import regex as re\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import IterableDataset, DataLoader\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Use the GPU when one is available; otherwise fall back to the CPU so the\n",
    "# notebook still runs on machines without CUDA.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5ee9ad24-a5d2-47e1-a5c6-88981dc22b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_words_from_line(line):\n",
    "    \"\"\"Yield the lower-cased tokens of one line, framed by '' markers.\n",
    "\n",
    "    A token is a run of letters, digits or '*', or a run of punctuation.\n",
    "    The empty string is yielded once before the first token and once after\n",
    "    the last one, so every line is delimited even when it has no tokens.\n",
    "    \"\"\"\n",
    "    line = line.rstrip()\n",
    "    yield ''\n",
    "    for m in re.finditer(r'[\\p{L}0-9\\*]+|\\p{P}+', line):\n",
    "        yield m.group(0).lower()\n",
    "    yield ''\n",
    "\n",
    "\n",
    "def get_word_lines_from_file(file_name):\n",
    "    \"\"\"Yield one token generator (see get_words_from_line) per line of a UTF-8 file.\"\"\"\n",
    "    with open(file_name, 'r', encoding='utf8') as fh:\n",
    "        for line in fh:\n",
    "            yield get_words_from_line(line)\n",
    "\n",
    "\n",
    "def look_ahead_iterator(gen):\n",
    "    \"\"\"Slide a five-item window over gen and yield context plus centre.\n",
    "\n",
    "    For consecutive items a, b, c, d, e the tuple (a, b, d, e, c) is yielded:\n",
    "    the two items before the centre, the two items after it, and the centre\n",
    "    itself in the last position. Nothing is yielded until five items have\n",
    "    been seen.\n",
    "    \"\"\"\n",
    "    window = deque(maxlen=5)\n",
    "    for item in gen:\n",
    "        window.append(item)\n",
    "        if len(window) == 5:\n",
    "            left2, left1, centre, right1, right2 = window\n",
    "            yield (left2, left1, right1, right2, centre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "93279277-0765-4f85-9666-095fc7808c81",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FiveGrams(IterableDataset):\n",
    "    \"\"\"Iterable dataset of five-gram index tuples read from a text file.\n",
    "\n",
    "    The vocabulary is built from the same file, capped at vocabulary_size\n",
    "    entries. Iterating yields the tuples produced by look_ahead_iterator\n",
    "    over the vocabulary indices of all tokens in the file.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, text_file, vocabulary_size):\n",
    "        self.vocab = build_vocab_from_iterator(\n",
    "            get_word_lines_from_file(text_file),\n",
    "            max_tokens=vocabulary_size,\n",
    "            specials=['']\n",
    "        )\n",
    "        # Tokens missing from the vocabulary map to the index of the '' special.\n",
    "        self.vocab.set_default_index(self.vocab[''])\n",
    "        self.vocabulary_size = vocabulary_size\n",
    "        self.text_file = text_file\n",
    "\n",
    "    def __iter__(self):\n",
    "        # The file is re-read on every iteration, so each epoch streams it afresh.\n",
    "        tokens = itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))\n",
    "        return look_ahead_iterator(self.vocab[token] for token in tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6eb5fbd9-bc0f-499d-85f4-3998a4a3f56e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleFiveGramNeuralLanguageModel(nn.Module):\n",
    "    \"\"\"Predict a word from the embeddings of four context words.\n",
    "\n",
    "    The four embeddings are concatenated and passed through two linear\n",
    "    layers; the result is a score per vocabulary entry.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocabulary_size, embedding_size):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocabulary_size, embedding_size)\n",
    "        self.linear1 = nn.Linear(embedding_size * 4, embedding_size)\n",
    "        self.linear2 = nn.Linear(embedding_size, vocabulary_size)\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "        self.embedding_size = embedding_size\n",
    "\n",
    "    def forward(self, x, return_logits=False):\n",
    "        \"\"\"Return per-word probabilities, or raw scores when return_logits is True.\n",
    "\n",
    "        x holds four vocabulary indices per row (linear1 expects exactly\n",
    "        4 * embedding_size inputs). Raw scores are what CrossEntropyLoss\n",
    "        needs during training; the default keeps returning probabilities.\n",
    "        \"\"\"\n",
    "        embeds = self.embedding(x).view(x.size(0), -1)\n",
    "        out = self.linear1(embeds)\n",
    "        out = self.linear2(out)\n",
    "        if return_logits:\n",
    "            return out\n",
    "        return self.softmax(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d0dc7c69-3f27-4f00-9b91-5f3a403df074",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(embed_size,vocab_size,num_epochs,batch_size,train_file_path):\n",
    "    \"\"\"Train a SimpleFiveGramNeuralLanguageModel on train_file_path.\n",
    "\n",
    "    Runs num_epochs passes over the file with Adam and cross-entropy loss,\n",
    "    printing the loss every 5000 steps of each epoch.\n",
    "\n",
    "    Returns the model (switched to eval mode) and the dataset vocabulary.\n",
    "    \"\"\"\n",
    "    train_dataset = FiveGrams(train_file_path, vocab_size)\n",
    "    model = SimpleFiveGramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
    "\n",
    "    data = DataLoader(train_dataset, batch_size=batch_size)\n",
    "    optimizer = torch.optim.Adam(model.parameters())\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    model.train()\n",
    "    for _ in range(num_epochs):\n",
    "        # The step counter restarts at 0 for every epoch.\n",
    "        for step, (x1, x2, x3, x4, y) in enumerate(tqdm(data, desc=\"Train loop\")):\n",
    "            y = y.to(device)\n",
    "            x = torch.stack((x1, x2, x3, x4), dim=1).to(device)\n",
    "            optimizer.zero_grad()\n",
    "            # CrossEntropyLoss applies log-softmax itself, so it gets the raw\n",
    "            # scores; taking log(softmax(...)) first can produce -inf.\n",
    "            logits = model(x, return_logits=True)\n",
    "\n",
    "            loss = criterion(logits, y)\n",
    "            if step % 5000 == 0:\n",
    "                print(step, loss)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "    model.eval()\n",
    "\n",
    "    return model, train_dataset.vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9a1b2240-d2ed-4c56-8443-12113e66b514",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gap_candidates(words, model, vocab, n=20):\n",
    "    \"\"\"Return up to n (word, probability) pairs for the gap between words.\n",
    "\n",
    "    words are the context tokens handed to the model as a single row.\n",
    "    The pairs are ordered as returned by torch.topk.\n",
    "    \"\"\"\n",
    "    ixs = torch.tensor(vocab(words)).unsqueeze(0).to(device)\n",
    "\n",
    "    # Inference only: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        out = model(ixs)\n",
    "    # topk raises when asked for more entries than the output has.\n",
    "    top = torch.topk(out[0], min(n, out.size(1)))\n",
    "    top_indices = top.indices.tolist()\n",
    "    top_probs = top.values.tolist()\n",
    "    top_words = vocab.lookup_tokens(top_indices)\n",
    "    return list(zip(top_words, top_probs))\n",
    "\n",
    "\n",
    "def clean(text):\n",
    "    \"\"\"Normalise raw text to space-separated words without punctuation.\n",
    "\n",
    "    Literal backslash-n and backslash-t sequences and real newlines become\n",
    "    spaces (a hyphen directly before a literal backslash-n is dropped with\n",
    "    it), commas and hyphens between word characters are removed, then all\n",
    "    punctuation is removed and whitespace runs are collapsed.\n",
    "    \"\"\"\n",
    "    text = text.replace('-\\\\n', '').replace('\\\\n', ' ').replace('\\\\t', ' ')\n",
    "    text = re.sub(r'\\n', ' ', text)\n",
    "    text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n",
    "    # Strip punctuation before collapsing whitespace; the other order leaves\n",
    "    # double spaces where a free-standing punctuation mark used to be.\n",
    "    text = re.sub(r'\\p{P}', '', text)\n",
    "    text = re.sub(r'\\s+', ' ', text)\n",
    "    return text.strip()\n",
    "\n",
    "\n",
    "def predictor(prefix, suffix, model, vocab):\n",
    "    \"\"\"Format the model's guesses for the word between prefix and suffix.\n",
    "\n",
    "    Uses the last two words of prefix and the first two of suffix as context.\n",
    "    Returns 'word:prob word:prob ... :rest', where rest is the probability\n",
    "    mass not covered by the listed words; the '' token is never listed.\n",
    "    \"\"\"\n",
    "    # Training tokens are lower-cased, so the context must be as well.\n",
    "    left = clean(prefix).lower().split()[-2:]\n",
    "    right = clean(suffix).lower().split()[:2]\n",
    "    # The model needs exactly four tokens; missing context is filled with the\n",
    "    # '' marker that delimits lines in the training stream.\n",
    "    left = [''] * (2 - len(left)) + left\n",
    "    right = right + [''] * (2 - len(right))\n",
    "    candidates = get_gap_candidates(left + right, model, vocab)\n",
    "\n",
    "    probs_sum = 0\n",
    "    output = ''\n",
    "    for word, prob in candidates:\n",
    "        if word == \"\":\n",
    "            continue\n",
    "        probs_sum += prob\n",
    "        output += f\"{word}:{prob} \"\n",
    "    # Rounding can push the sum slightly above 1; never emit a negative rest.\n",
    "    output += f\":{max(0.0, 1 - probs_sum)}\"\n",
    "\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "40af2781-3807-43e8-b6dd-3b70066e50c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_result(input_path,model, vocab, output_path='out.tsv'):\n",
    "    \"\"\"Write one prediction line per row of the tab-separated input file.\n",
    "\n",
    "    Columns 6 and 7 (0-based) of every input row are used as the text before\n",
    "    and after the gap. Each output line is the string built by predictor.\n",
    "    \"\"\"\n",
    "    lines = []\n",
    "    with open(input_path, encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            columns = line.split('\\t')\n",
    "            lines.append((columns[6], columns[7]))\n",
    "\n",
    "    with open(output_path, 'w', encoding='utf-8') as output_file:\n",
    "        for prefix, suffix in tqdm(lines):\n",
    "            result = predictor(prefix, suffix, model, vocab)\n",
    "            output_file.write(result + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6b7234f-1f40-468f-8c69-2875bb1ec947",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(workdir='/mnt/d/UAM/MODELOWANIE/5GRAM', test_set='dev-0'):\n",
    "    \"\"\"Run geval through WSL in workdir and return its printed score as a float.\n",
    "\n",
    "    Raises RuntimeError with geval's stderr when the command exits non-zero,\n",
    "    instead of failing later on an unparsable empty stdout.\n",
    "    \"\"\"\n",
    "    cmd = f'wsl bash -c \"cd {workdir} && ./geval -t {test_set}\"'\n",
    "    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)\n",
    "    if result.returncode != 0:\n",
    "        raise RuntimeError(\n",
    "            f'geval exited with code {result.returncode}: {result.stderr.strip()}'\n",
    "        )\n",
    "    return float(result.stdout)"
   ]
  },
  { "cell_type": "code", "execution_count": 9, "id": "4c716463-27fe-4c2b-b859-ac9c8aff1942", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1eac733e07974322bfd47dcff96aa8d4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Train loop: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "0 tensor(9.2551, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3df53998bc334a29bd355578738897d3", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(4.6065, device='cuda:0', grad_fn=)\n", "10000 tensor(4.4173, device='cuda:0', grad_fn=)\n", "15000 tensor(4.3352, device='cuda:0', grad_fn=)\n" ] }, {
"data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1e3d1e4344f44285832accfc83ab2233", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3546c4fd825a4057b05c6105b25f6712", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(4.8952, device='cuda:0', grad_fn=)\n", "10000 tensor(4.7382, device='cuda:0', grad_fn=)\n", "15000 tensor(4.6068, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cd0c53ddfb3148609218b56c72ce7777", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "42b54814d7f741ddb38d8e00d6db2126", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(5.0450, device='cuda:0', grad_fn=)\n", "10000 tensor(4.8688, device='cuda:0', grad_fn=)\n", "15000 tensor(4.7152, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "38a391c3471442dc8620524f0d5c5fc8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "65a44fd20fce4ec49528763ebb351f4b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(4.4829, device='cuda:0', grad_fn=)\n", "10000 tensor(4.2794, device='cuda:0', grad_fn=)\n", "15000 tensor(4.2239, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9b2f000b5c7a4ab08fb48c78b45c3a54", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6883012f36d44ab0842b14281ef5eb65", "version_major": 2, "version_minor": 0 }, 
"text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(4.7515, device='cuda:0', grad_fn=)\n", "10000 tensor(4.5669, device='cuda:0', grad_fn=)\n", "15000 tensor(4.4938, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "53d6628208ac4997a290cdde469ccfb1", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ccb8ed60361c4a509afb628f52ce609c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(4.8590, device='cuda:0', grad_fn=)\n", "10000 tensor(4.7090, device='cuda:0', grad_fn=)\n", "15000 tensor(4.5810, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "85f4db52171945fa9fb2890431a276b2", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "32802641dc974b6cb3404a12c51b611b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(4.4134, device='cuda:0', grad_fn=)\n", "10000 tensor(4.2280, device='cuda:0', grad_fn=)\n", "15000 tensor(4.1653, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a97a141657af410b989bbd8ab8710955", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3aa3a60daf6145389312b031ca55e11d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(4.6888, device='cuda:0', grad_fn=)\n", "10000 tensor(4.5068, device='cuda:0', grad_fn=)\n", "15000 tensor(4.4465, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "56e1065d5c01461e8cdfc38b7db2f3ed", "version_major": 2, "version_minor": 
0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "58d964531ad14b69ba72b7d8d538cf2d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00)\n", "5000 tensor(4.8159, device='cuda:0', grad_fn=)\n", "10000 tensor(4.6442, device='cuda:0', grad_fn=)\n", "15000 tensor(4.5524, device='cuda:0', grad_fn=)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aedcfa56f6c345dd976b37535d0de2b4", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10519 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.interpolate import griddata\n",
    "\n",
    "\n",
    "def plot_perplexity_surface(results, nano, title):\n",
    "    \"\"\"Plot interpolated perplexity over vocabulary size and embedding size.\n",
    "\n",
    "    results is a list of dicts with the keys 'vocab_size', 'embed_size',\n",
    "    'perplexity' and 'train_file_path'. Runs whose 'train_file_path'\n",
    "    contains 'nano' are used when nano is True, all other runs when it is\n",
    "    False. The measured points are drawn in red over the contour plot.\n",
    "\n",
    "    Raises ValueError when no run matches the selection.\n",
    "    \"\"\"\n",
    "    runs = [item for item in results if ('nano' in item['train_file_path']) == nano]\n",
    "    if not runs:\n",
    "        raise ValueError(f'no results to plot for nano={nano}')\n",
    "\n",
    "    vocab_size = [item['vocab_size'] for item in runs]\n",
    "    embed_size = [item['embed_size'] for item in runs]\n",
    "    perplexity = [item['perplexity'] for item in runs]\n",
    "\n",
    "    # Interpolate the scattered measurements onto a regular 100 x 100 grid.\n",
    "    grid_x, grid_y = np.meshgrid(\n",
    "        np.linspace(min(vocab_size), max(vocab_size), 100),\n",
    "        np.linspace(min(embed_size), max(embed_size), 100),\n",
    "    )\n",
    "    grid_z = griddata((vocab_size, embed_size), perplexity, (grid_x, grid_y), method='cubic')\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    contour = ax.contourf(grid_x, grid_y, grid_z, cmap='viridis')\n",
    "    fig.colorbar(contour, ax=ax, label='Perplexity')\n",
    "    ax.scatter(vocab_size, embed_size, c='red')  # the actual measured runs\n",
    "    ax.set_xlabel('Vocab Size')\n",
    "    ax.set_ylabel('Embed Size')\n",
    "    ax.set_title(title)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_perplexity_surface(\n",
    "    results,\n",
    "    nano=False,\n",
    "    title='Embed Size vs Vocab Size with Perplexity for whole training set',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fe388a52-9bd3-4ee3-9cf1-838c9ff22c55",
   "metadata": {},
   "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "plot_perplexity_surface(\n",
    "    results,\n",
    "    nano=True,\n",
    "    title='Embed Size vs Vocab Size with Perplexity for nano training set',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a310f1f5-0b2f-4994-b36a-e2ff1a7e6b70",
   "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "{'embed_size': 300,\n", " 'vocab_size': 30000,\n", " 'num_epochs': 1,\n", " 'batch_size': 8192,\n", " 'train_file_path': 'train/train.txt',\n", " 'perplexity': 173.38,\n", " 'logPerplexity': 5.155485717440494}" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "from math import log\n",
    "\n",
    "best_run = min(results, key=lambda run: run['perplexity'])\n",
    "# Build a new dict so that re-running this cell never modifies `results`.\n",
    "best_model_parameters = {**best_run, 'logPerplexity': log(best_run['perplexity'])}\n",
    "best_model_parameters"
   ]
  }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python",
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }