{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "dfd117bd-5d6f-46e6-979c-092a8065fa0b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "S:\\WENV_TORCHTEXT\\Lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n", "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n", "S:\\WENV_TORCHTEXT\\Lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n", "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n" ] } ], "source": [ "from torch.utils.data import IterableDataset, DataLoader\n", "from torchtext.vocab import build_vocab_from_iterator\n", "\n", "import regex as re\n", "import sys\n", "import itertools\n", "from itertools import islice\n", "\n", "from torch import nn\n", "import torch\n", "\n", "from tqdm.notebook import tqdm\n", "\n", "embed_size = 300\n", "vocab_size = 30_000\n", "num_epochs = 1\n", "device = 'cuda'\n", "batch_size = 8192\n", "train_file_path = 'train/train.txt'\n", "\n", "with open(train_file_path, 'r', encoding='utf-8') as file:\n", " total = len(file.readlines())" ] }, { "cell_type": "code", "execution_count": 2, "id": "40392665-79bc-4032-a5de-9d189545c9f7", "metadata": {}, "outputs": [], "source": [ "# Function to extract words from a line of text\n", "def get_words_from_line(line):\n", " line = line.rstrip()\n", " yield ''\n", " for m in re.finditer(r'[\\p{L}0-9\\*]+|\\p{P}+', line):\n", " yield m.group(0).lower()\n", " yield ''\n", "\n", "# Generator to read lines from a file\n", "def get_word_lines_from_file(file_name):\n", " limit = total * 2\n", " with open(file_name, 'r', encoding='utf8') as fh:\n", " for line in fh:\n", " limit -= 1\n", " if not limit:\n", " break\n", " yield get_words_from_line(line)\n", "\n", "# Function to create trigrams from a sequence\n", "def look_ahead_iterator(gen):\n", " prev1, prev2 = None, None\n", " for item in gen:\n", " if prev1 is not None and prev2 is not None:\n", " yield (prev2, prev1, item)\n", " prev2 = prev1\n", " prev1 = item\n", "\n", "# Dataset class for trigrams\n", "class Trigrams(IterableDataset):\n", " def __init__(self, text_file, vocabulary_size):\n", " self.vocab = build_vocab_from_iterator(\n", " get_word_lines_from_file(text_file),\n", " max_tokens=vocabulary_size,\n", " specials=['']\n", " )\n", " self.vocab.set_default_index(self.vocab[''])\n", " self.vocabulary_size = vocabulary_size\n", " self.text_file = text_file\n", "\n", " def __iter__(self):\n", " return look_ahead_iterator(\n", " (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file)))\n", " )\n", "\n", "# Instantiate the dataset\n", "train_dataset = Trigrams(train_file_path, vocab_size)" ] }, { "cell_type": "code", "execution_count": 3, "id": "0cf7aa68-37aa-48a4-b647-e0e5002ca5c9", "metadata": {}, "outputs": [], "source": [ "class SimpleTrigramNeuralLanguageModel(nn.Module):\n", " def __init__(self, vocabulary_size, embedding_size):\n", " super(SimpleTrigramNeuralLanguageModel, self).__init__()\n", " self.embedding = nn.Embedding(vocabulary_size, embedding_size)\n", " self.linear1 = nn.Linear(embedding_size * 2, embedding_size)\n", " self.linear2 = nn.Linear(embedding_size, vocabulary_size)\n", " self.softmax = nn.Softmax(dim=1)\n", " self.embedding_size = embedding_size\n", "\n", " def forward(self, x):\n", " embeds = self.embedding(x).view(x.size(0), -1)\n", " out = self.linear1(embeds)\n", " out = self.linear2(out)\n", " return self.softmax(out)\n", "\n", "model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size).to(device)" ] }, { "cell_type": "code", "execution_count": 4, "id": "32ea22db-7259-4549-a9d5-4781d9bc99bc", "metadata": {}, "outputs": [], "source": [ "data = DataLoader(train_dataset, batch_size=batch_size)\n", "optimizer = torch.optim.Adam(model.parameters())\n", "criterion = torch.nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 5, "id": "0858967e-5143-4253-921d-a009dbbdca27", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c3d8f9d5b178490899934860a55c2508", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Train loop: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "0 tensor(10.3631, device='cuda:0', grad_fn=)\n", "5000 tensor(5.7081, device='cuda:0', grad_fn=)\n", "10000 tensor(5.5925, device='cuda:0', grad_fn=)\n", "15000 tensor(5.5097, device='cuda:0', grad_fn=)\n" ] }, { "data": { "text/plain": [ "SimpleTrigramNeuralLanguageModel(\n", " (embedding): Embedding(30000, 300)\n", " (linear1): Linear(in_features=600, out_features=300, bias=True)\n", " (linear2): Linear(in_features=300, out_features=30000, bias=True)\n", " (softmax): Softmax(dim=1)\n", ")" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.train()\n", "step = 0\n", "for _ in range(num_epochs):\n", " for x1,x2,y in tqdm(data, desc=\"Train loop\"):\n", " y = y.to(device)\n", " x = torch.cat((x1.unsqueeze(1),x2.unsqueeze(1)), dim=1).to(device)\n", " optimizer.zero_grad()\n", " ypredicted = model(x)\n", " \n", " loss = criterion(torch.log(ypredicted), y)\n", " if step % 5000 == 0:\n", " print(step, loss)\n", " step += 1\n", " loss.backward()\n", " optimizer.step()\n", " step = 0\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 6, "id": "da4d116c-beec-436d-84d8-577282507226", "metadata": {}, "outputs": [], "source": [ "def get_gap_candidates(words, n=10, vocab=train_dataset.vocab):\n", " ixs = vocab(words)\n", " ixs = torch.tensor(ixs).unsqueeze(0).to(device)\n", "\n", " out = model(ixs)\n", " top = torch.topk(out[0], n)\n", " top_indices = top.indices.tolist()\n", " top_probs = top.values.tolist()\n", " top_words = vocab.lookup_tokens(top_indices)\n", " return list(zip(top_words, top_probs))" ] }, { "cell_type": "code", "execution_count": 7, "id": "0cafd70a-29b3-4a49-b40f-b8ce3143084a", "metadata": {}, "outputs": [], "source": [ "def clean(text):\n", " text = text.replace('-\\\\n', '').replace('\\\\n', ' ').replace('\\\\t', ' ')\n", " text = re.sub(r'\\n', ' ', text)\n", " text = re.sub(r'(?<=\\w)[,-](?=\\w)', '', text)\n", " text = re.sub(r'\\s+', ' ', text)\n", " text = re.sub(r'\\p{P}', '', text)\n", " text = text.strip()\n", " return text\n", " \n", "def predictor(prefix):\n", " words = clean(prefix)\n", " candidates = get_gap_candidates(words.strip().split(' ')[-2:])\n", "\n", " probs_sum = 0\n", " output = ''\n", " for word,prob in candidates:\n", " if word == \"\":\n", " continue\n", " probs_sum += prob\n", " output += f\"{word}:{prob} \"\n", " output += f\":{1-probs_sum}\"\n", "\n", " return output" ] }, { "cell_type": "code", "execution_count": 8, "id": "965ebaf3-4c0b-4462-8ac5-4746ec9489ab", "metadata": {}, "outputs": [], "source": [ "def generate_result(input_path, output_path='out.tsv'):\n", " lines = []\n", " with open(input_path, encoding='utf-8') as f:\n", " for line in f:\n", " columns = line.split('\\t')\n", " prefix = columns[6]\n", " suffix = columns[7]\n", " lines.append(prefix)\n", "\n", " with open(output_path, 'w', encoding='utf-8') as output_file:\n", " for line in lines:\n", " result = predictor(line)\n", " output_file.write(result + '\\n')" ] }, { "cell_type": "code", "execution_count": 9, "id": "80547ba7-9d01-4d2b-9e83-269919513de9", "metadata": {}, "outputs": [], "source": [ "generate_result('dev-0/in.tsv', output_path='dev-0/out.tsv')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }