{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f3452caf-df58-4394-b0d6-46459cb47045",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "S:\\WENV_TORCHTEXT\\Lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
      "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
      "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
      "  warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
      "S:\\WENV_TORCHTEXT\\Lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
      "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
      "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
      "  warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import IterableDataset, DataLoader\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "\n",
    "import regex as re\n",
    "import itertools\n",
    "from itertools import islice\n",
    "\n",
    "from torch import nn\n",
    "import torch\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "embed_size = 300\n",
    "vocab_size = 30_000\n",
    "num_epochs = 1\n",
    "device = 'cuda'\n",
    "batch_size = 8192\n",
    "train_file_path = 'train/train.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "93279277-0765-4f85-9666-095fc7808c81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize a line into lower-cased words and punctuation, wrapped in sentence markers\n",
    "def get_words_from_line(line):\n",
    "    line = line.rstrip()\n",
    "    yield '<s>'\n",
    "    for m in re.finditer(r'[\\p{L}0-9\\*]+|\\p{P}+', line):\n",
    "        yield m.group(0).lower()\n",
    "    yield '</s>'\n",
    "\n",
    "# Yield a token generator for each line of a file\n",
    "def get_word_lines_from_file(file_name):\n",
    "    with open(file_name, 'r', encoding='utf8') as fh:\n",
    "        for line in fh:\n",
    "            yield get_words_from_line(line)\n",
    "\n",
    "# Slide a 5-token window over the stream: for tokens (w1, w2, w3, w4, w5),\n",
    "# yield (w1, w2, w4, w5, w3), i.e. the four context words followed by the middle target\n",
    "def look_ahead_iterator(gen):\n",
    "    prev2, prev1, next1, next2 = None, None, None, None\n",
    "    for item in gen:\n",
    "        if prev2 is not None and prev1 is not None and next1 is not None and next2 is not None:\n",
    "            yield (prev2, prev1, next2, item, next1)\n",
    "        prev2, prev1, next1, next2 = prev1, next1, next2, item\n",
    "\n",
    "# Dataset streaming (context, target) 5-grams as vocabulary indices\n",
    "class FiveGrams(IterableDataset):\n",
    "    def __init__(self, text_file, vocabulary_size):\n",
    "        self.vocab = build_vocab_from_iterator(\n",
    "            get_word_lines_from_file(text_file),\n",
    "            max_tokens=vocabulary_size,\n",
    "            specials=['<unk>']\n",
    "        )\n",
    "        self.vocab.set_default_index(self.vocab['<unk>'])\n",
    "        self.vocabulary_size = vocabulary_size\n",
    "        self.text_file = text_file\n",
    "\n",
    "    def __iter__(self):\n",
    "        return look_ahead_iterator(\n",
    "            (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file)))\n",
    "        )\n",
    "\n",
    "# Instantiate the dataset\n",
    "train_dataset = FiveGrams(train_file_path, vocab_size)"
   ]
  },
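  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d4e6a17-92b3-44c0-bb1e-6f0d2c9a51f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (illustrative, not part of the recorded run): for five\n",
    "# consecutive tokens w1 ... w5, look_ahead_iterator yields (w1, w2, w4, w5, w3),\n",
    "# i.e. the four surrounding context words followed by the middle word as target.\n",
    "list(look_ahead_iterator(['<s>', 'a', 'b', 'c', 'd', '</s>']))\n",
    "# expected: [('<s>', 'a', 'c', 'd', 'b'), ('a', 'b', 'd', '</s>', 'c')]"
   ]
  },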
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "980103d6-05a3-4b9a-a539-b59815f6a45d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<s>', 'came', 'the', 'last', 'fiom']\n",
      "['came', 'fiom', 'last', 'place', 'the']\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first two (context, target) examples\n",
    "for x in islice(train_dataset, 2):\n",
    "    print(train_dataset.vocab.lookup_tokens(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6eb5fbd9-bc0f-499d-85f4-3998a4a3f56e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleFiveGramNeuralLanguageModel(nn.Module):\n",
    "    def __init__(self, vocabulary_size, embedding_size):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocabulary_size, embedding_size)\n",
    "        # Four context embeddings are concatenated, hence embedding_size * 4 inputs\n",
    "        self.linear1 = nn.Linear(embedding_size * 4, embedding_size)\n",
    "        self.linear2 = nn.Linear(embedding_size, vocabulary_size)\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "        self.embedding_size = embedding_size\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: (batch, 4) context indices -> (batch, 4 * embedding_size)\n",
    "        embeds = self.embedding(x).view(x.size(0), -1)\n",
    "        out = self.linear1(embeds)\n",
    "        out = self.linear2(out)\n",
    "        # Returns a probability distribution over the vocabulary\n",
    "        return self.softmax(out)\n",
    "\n",
    "model = SimpleFiveGramNeuralLanguageModel(vocab_size, embed_size).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d0dc7c69-3f27-4f00-9b91-5f3a403df074",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3064f4f089604c8c8d0d6a6a826876bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train loop: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor(10.3575, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
      "5000 tensor(4.8030, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
      "10000 tensor(4.6310, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
      "15000 tensor(4.5446, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "SimpleFiveGramNeuralLanguageModel(\n",
       "  (embedding): Embedding(30000, 300)\n",
       "  (linear1): Linear(in_features=1200, out_features=300, bias=True)\n",
       "  (linear2): Linear(in_features=300, out_features=30000, bias=True)\n",
       "  (softmax): Softmax(dim=1)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = DataLoader(train_dataset, batch_size=batch_size)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "model.train()\n",
    "step = 0\n",
    "for _ in range(num_epochs):\n",
    "    for x1, x2, x3, x4, y in tqdm(data, desc=\"Train loop\"):\n",
    "        y = y.to(device)\n",
    "        # Stack the four context columns into a (batch, 4) index tensor\n",
    "        x = torch.stack((x1, x2, x3, x4), dim=1).to(device)\n",
    "        optimizer.zero_grad()\n",
    "        ypredicted = model(x)\n",
    "\n",
    "        # The model outputs probabilities, so take the log before CrossEntropyLoss;\n",
    "        # feeding raw logits to CrossEntropyLoss would be the numerically stabler choice\n",
    "        loss = criterion(torch.log(ypredicted), y)\n",
    "        if step % 5000 == 0:\n",
    "            print(step, loss)\n",
    "        step += 1\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    step = 0\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9a1b2240-d2ed-4c56-8443-12113e66b514",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gap_candidates(words, n=20, vocab=train_dataset.vocab):\n",
    "    # Map the four context words to indices, shape (1, 4)\n",
    "    ixs = vocab(words)\n",
    "    ixs = torch.tensor(ixs).unsqueeze(0).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        out = model(ixs)\n",
    "    top = torch.topk(out[0], n)\n",
    "    top_indices = top.indices.tolist()\n",
    "    top_probs = top.values.tolist()\n",
    "    top_words = vocab.lookup_tokens(top_indices)\n",
    "    return list(zip(top_words, top_probs))\n",
    "\n",
    "def clean(text):\n",
    "    # The input files encode line breaks literally; drop them and normalize whitespace\n",
    "    text = text.replace('-\\\\n', '').replace('\\\\n', ' ').replace('\\\\t', ' ')\n",
    "    text = re.sub(r'\\n', ' ', text)\n",
    "    text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n",
    "    text = re.sub(r'\\s+', ' ', text)\n",
    "    text = re.sub(r'\\p{P}', '', text)\n",
    "    text = text.strip()\n",
    "    return text\n",
    "\n",
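  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b9e2f64-0c31-4a8e-9d77-8e1f3a2b6c90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test (illustrative only: the sentence is made up, so the candidate\n",
    "# words and probabilities depend entirely on the trained weights).\n",
    "# predictor takes the raw text before and after the gap and returns the\n",
    "# 'word:prob ... :rest' string expected by the evaluation below.\n",
    "print(predictor('It came from the last', 'of the world'))"
   ]
  },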
    "def predictor(prefix, suffix):\n",
    "    prefix = clean(prefix)\n",
    "    suffix = clean(suffix)\n",
    "    # Two words before the gap and two words after it form the model context\n",
    "    words = prefix.split(' ')[-2:] + suffix.split(' ')[:2]\n",
    "    candidates = get_gap_candidates(words)\n",
    "\n",
    "    probs_sum = 0\n",
    "    output = ''\n",
    "    for word, prob in candidates:\n",
    "        if word == '<unk>':\n",
    "            continue\n",
    "        probs_sum += prob\n",
    "        output += f\"{word}:{prob} \"\n",
    "    # Reserve the remaining probability mass for the catch-all (empty) token\n",
    "    output += f\":{1-probs_sum}\"\n",
    "\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "40af2781-3807-43e8-b6dd-3b70066e50c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "743a03c2e3064f9485d196e8eafe80e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10519 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": []
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}