{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple neural bigram language model\n",
    "\n",
    "Trains an embedding → linear → softmax bigram model on `/content/train/in.tsv.xz`, then writes the most likely next words for every line of `dev-0` and `test-A` to `out.tsv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KYySXV60UbL4"
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "import lzma\n",
    "import pickle\n",
    "\n",
    "import regex as re\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, IterableDataset\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "\n",
    "# Fix the seed so weight initialisation is reproducible.\n",
    "torch.manual_seed(1);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Wle6wRL0UbL9"
   },
   "source": [
    "### Data loading\n",
    "\n",
    "Helpers that read the xz-compressed TSV files line by line and turn each line into a stream of lowercase tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "63t_aV_IUbL-"
   },
   "outputs": [],
   "source": [
    "def load_pickle(filename):\n",
    "    \"\"\"Load a pickled object from ``filename``.\n",
    "\n",
    "    Only use on trusted files: unpickling can execute arbitrary code.\n",
    "    \"\"\"\n",
    "    with open(filename, \"rb\") as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "def save_pickle(d, filename=\"vocabulary.pkl\"):\n",
    "    \"\"\"Pickle ``d`` to ``filename`` using the highest available protocol.\"\"\"\n",
    "    with open(filename, \"wb\") as f:\n",
    "        pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "\n",
    "# Contractions expanded by ``clean_document``; \"won't\" must come before \"n't\".\n",
    "CONTRACTIONS = (\n",
    "    (\"i'm\", \"i am\"),\n",
    "    (\"won't\", \"will not\"),\n",
    "    (\"n't\", \" not\"),\n",
    "    (\"'ll\", \" will\"),\n",
    ")\n",
    "\n",
    "\n",
    "def clean_document(document: str) -> str:\n",
    "    \"\"\"Lowercase and normalise a raw document.\n",
    "\n",
    "    Removes possessive ``'s`` and hyphenated line breaks, collapses runs of\n",
    "    punctuation, digits, whitespace and literal backslash-n sequences into a\n",
    "    single space, and expands a few English contractions.\n",
    "    \"\"\"\n",
    "    document = document.lower().replace(\"’\", \"'\")\n",
    "    document = re.sub(r\"'s|[\\-\u00ad]\\\\n\", \"\", document)\n",
    "    document = re.sub(\n",
    "        r\"(\\\\+n|[{}\\[\\]”&:•¦()*0-9;\\\"«»$\\-><^,®¬¿?¡!#+. \\t\\n])+\", \" \", document\n",
    "    )\n",
    "    for to_find, substitute in CONTRACTIONS:\n",
    "        document = document.replace(to_find, substitute)\n",
    "    return document\n",
    "\n",
    "\n",
    "def get_words_from_line(line, clean_text=True):\n",
    "    \"\"\"Yield the lowercase tokens of ``line`` between ``<s>`` and ``</s>`` markers.\n",
    "\n",
    "    With ``clean_text`` the line is first passed through ``clean_document``;\n",
    "    otherwise only surrounding whitespace is stripped. A token is a run of\n",
    "    letters/digits/asterisks or a run of punctuation.\n",
    "    \"\"\"\n",
    "    line = clean_document(line) if clean_text else line.strip()\n",
    "    # Distinct boundary markers keep line starts and ends apart from each\n",
    "    # other and from the <unk> token used for out-of-vocabulary words.\n",
    "    yield \"<s>\"\n",
    "    for m in re.finditer(r\"[\\p{L}0-9\\*]+|\\p{P}+\", line):\n",
    "        yield m.group(0).lower()\n",
    "    yield \"</s>\"\n",
    "\n",
    "\n",
    "def get_word_lines_from_file(file_name, clean_text=True, only_text=False):\n",
    "    \"\"\"Yield one token generator (see ``get_words_from_line``) per line of an xz file.\n",
    "\n",
    "    With ``only_text`` only the last two tab-separated columns are kept\n",
    "    (presumably the left and right contexts). Prints the line index every\n",
    "    10000 lines as a progress hint.\n",
    "    \"\"\"\n",
    "    with lzma.open(file_name, \"r\") as fh:\n",
    "        for i, line in enumerate(fh):\n",
    "            line = line.decode(\"utf-8\")\n",
    "            if only_text:\n",
    "                # Inference reads the left context from column [-2], so the text\n",
    "                # is in the LAST two columns; [:-2] would keep everything else.\n",
    "                line = \"\\t\".join(line.split(\"\\t\")[-2:])\n",
    "            if i % 10000 == 0:\n",
    "                print(i)\n",
    "            yield get_words_from_line(line, clean_text)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N7HGIM40UbL-"
   },
   "source": [
    "### Dataset and vocabulary\n",
    "\n",
    "Builds the vocabulary and a streaming dataset of `(previous token, next token)` index pairs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "50ezMyyNUbL_"
   },
   "outputs": [],
   "source": [
    "VOCAB_SIZE = 20000\n",
    "\n",
    "\n",
    "def look_ahead_iterator(gen):\n",
    "    \"\"\"Yield ``(previous, current)`` pairs of consecutive items from ``gen``.\"\"\"\n",
    "    prev = None\n",
    "    for item in gen:\n",
    "        if prev is not None:\n",
    "            yield (prev, item)\n",
    "        prev = item\n",
    "\n",
    "\n",
    "class Bigrams(IterableDataset):\n",
    "    \"\"\"Iterable dataset of ``(token_id, next_token_id)`` pairs from a corpus file.\n",
    "\n",
    "    Unless ``vocab`` is given, a vocabulary of at most ``vocabulary_size``\n",
    "    tokens (including the ``<unk>`` special) is built with one pass over\n",
    "    ``text_file``; tokens outside it map to ``<unk>``. Each iteration re-reads\n",
    "    the file and chains all lines into a single token stream.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, text_file, vocabulary_size, vocab=None, only_text=False, clean_text=True\n",
    "    ):\n",
    "        if vocab is None:\n",
    "            vocab = build_vocab_from_iterator(\n",
    "                get_word_lines_from_file(text_file, clean_text, only_text),\n",
    "                max_tokens=vocabulary_size,\n",
    "                specials=[\"<unk>\"],\n",
    "            )\n",
    "        self.vocab = vocab\n",
    "        # NOTE(review): a vocabulary loaded from disk must also contain \"<unk>\"\n",
    "        # (or already have a default index set).\n",
    "        self.vocab.set_default_index(self.vocab[\"<unk>\"])\n",
    "        self.vocabulary_size = vocabulary_size\n",
    "        self.text_file = text_file\n",
    "        self.clean_text = clean_text\n",
    "        self.only_text = only_text\n",
    "\n",
    "    def __iter__(self):\n",
    "        tokens = itertools.chain.from_iterable(\n",
    "            get_word_lines_from_file(self.text_file, self.clean_text, self.only_text)\n",
    "        )\n",
    "        return look_ahead_iterator(self.vocab[t] for t in tokens)\n",
    "\n",
    "\n",
    "vocab = None  # e.g. torch.load(\"./vocab.pth\") to reuse a saved vocabulary\n",
    "\n",
    "train_dataset = Bigrams(\"/content/train/in.tsv.xz\", VOCAB_SIZE, vocab, clean_text=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SCZ87kVxUbL_"
   },
   "outputs": [],
   "source": [
    "# Earlier experiments saved to vocab.pth, vocab_only_text.pth and\n",
    "# vocab_only_text_clean.pth; this run's vocabulary goes to vocab_2.pth.\n",
    "torch.save(train_dataset.vocab, \"vocab_2.pth\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bv_Adw_lUbMA"
   },
   "source": [
    "### Model definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oaoLz3jPUbMB"
   },
   "outputs": [],
   "source": [
    "class SimpleBigramNeuralLanguageModel(nn.Module):\n",
    "    \"\"\"Bigram language model: embedding -> linear projection -> softmax.\n",
    "\n",
    "    ``forward`` maps token ids to probabilities (not log-probabilities) over\n",
    "    the vocabulary along the last dimension.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocabulary_size, embedding_size):\n",
    "        super().__init__()\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Embedding(vocabulary_size, embedding_size),\n",
    "            nn.Linear(embedding_size, vocabulary_size),\n",
    "            # Explicit dim: normalise over the vocabulary axis for any input\n",
    "            # rank and avoid the implicit-dim deprecation warning.\n",
    "            nn.Softmax(dim=-1),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YBDf3nEvUbMC"
   },
   "outputs": [],
   "source": [
    "EMBED_SIZE = 100\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training\n",
    "\n",
    "One pass over the streamed training bigrams with Adam. Interrupting the kernel stops training early but still saves the weights to `model_2.bin`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CRe3STJUUbMC"
   },
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n",
    "data = DataLoader(train_dataset, batch_size=5000)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "criterion = torch.nn.NLLLoss()\n",
    "\n",
    "model.train()\n",
    "step = 0\n",
    "try:\n",
    "    for x, y in data:\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        ypredicted = model(x)\n",
    "        # NLLLoss expects log-probabilities; the model returns probabilities.\n",
    "        loss = criterion(torch.log(ypredicted), y)\n",
    "        if step % 100 == 0:\n",
    "            print(step, loss.item())\n",
    "        step += 1\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "except KeyboardInterrupt:\n",
    "    # A full pass is long; an interrupted run still saves its current weights.\n",
    "    print(f\"Training interrupted after {step} steps; saving current weights.\")\n",
    "\n",
    "torch.save(model.state_dict(), \"model_2.bin\")\n"
   ]
  },
  { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, 
"id": "kS9NHTGeom3y"
   },
   "outputs": [],
   "source": [
    "# Sanity check: the ten most likely words after \"when\" for the model trained above.\n",
    "with torch.no_grad():\n",
    "    ixs = torch.tensor(train_dataset.vocab.forward([\"when\"])).to(device)\n",
    "    out = model(ixs)\n",
    "top = torch.topk(out[0], 10)\n",
    "top_indices = top.indices.tolist()\n",
    "top_words = train_dataset.vocab.lookup_tokens(top_indices)\n",
    "list(zip(top_words, top_indices, top.values.tolist()))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference\n",
    "\n",
    "Loads a saved checkpoint and, for every input line, predicts the word that follows the last token of the left context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UTYHTCJvC_Nm"
   },
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n",
    "# NOTE(review): the training cell saves \"model_2.bin\" but this loads \"model1.bin\"\n",
    "# from an earlier run; confirm it was trained with the same vocabulary as\n",
    "# train_dataset.vocab, which the inference below uses for lookups.\n",
    "model.load_state_dict(torch.load(\"model1.bin\", map_location=device))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8WexjGIAxaE4"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Tokens never emitted as predictions: an empty word would be read as the\n",
    "# trailing \":rest\" entry, and the markers/unknown token are not real words.\n",
    "SPECIAL_TOKENS = {\"\", \"<unk>\", \"<s>\", \"</s>\"}\n",
    "\n",
    "\n",
    "def predict_word(ixs, model, top_k=5, vocab=None):\n",
    "    \"\"\"Return the ``top_k`` most probable next words for the token ids ``ixs``.\n",
    "\n",
    "    Only the first row of the model output is used. Returns a list of\n",
    "    ``(word, index, probability)`` tuples, most probable first; ``vocab``\n",
    "    defaults to ``train_dataset.vocab``.\n",
    "    \"\"\"\n",
    "    if vocab is None:\n",
    "        vocab = train_dataset.vocab\n",
    "    with torch.no_grad():\n",
    "        out = model(ixs)\n",
    "    top = torch.topk(out[0], top_k)\n",
    "    top_indices = top.indices.tolist()\n",
    "    top_probs = top.values.tolist()\n",
    "    top_words = vocab.lookup_tokens(top_indices)\n",
    "    return list(zip(top_words, top_indices, top_probs))\n",
    "\n",
    "\n",
    "def get_one_word(text, context=\"left\"):\n",
    "    \"\"\"Return the last (``context=\"left\"``) or first token of ``text``, lowercased.\n",
    "\n",
    "    Uses the same token pattern as ``get_words_from_line`` so the word matches\n",
    "    how training tokenised the corpus. Returns ``\"\"`` if ``text`` has no tokens.\n",
    "    \"\"\"\n",
    "    tokens = re.findall(r\"[\\p{L}0-9\\*]+|\\p{P}+\", text)\n",
    "    if not tokens:\n",
    "        return \"\"\n",
    "    word = tokens[-1] if context == \"left\" else tokens[0]\n",
    "    return word.lower()\n",
    "\n",
    "\n",
    "def inference_on_file(filename, model, lines_no=1, top_k=9):\n",
    "    \"\"\"Write next-word predictions for each line of ``filename`` to ``out.tsv``.\n",
    "\n",
    "    ``filename`` is an xz-compressed TSV whose second-to-last column is read as\n",
    "    the left context; ``out.tsv`` is written in the same directory. Each output\n",
    "    line is ``word:prob ... :rest`` where ``rest`` is the probability mass left\n",
    "    for all unlisted words. ``lines_no`` only scales the progress display.\n",
    "    \"\"\"\n",
    "    results_path = os.path.join(os.path.dirname(filename), \"out.tsv\")\n",
    "    n_lines = 0\n",
    "    with lzma.open(filename, \"r\") as fp, open(results_path, \"w\") as out_file:\n",
    "        print(\"Running inference on\", filename)\n",
    "        for n_lines, line in enumerate(fp, start=1):\n",
    "            left = get_one_word(line.decode(\"utf-8\").split(\"\\t\")[-2])\n",
    "            tensor = torch.tensor(train_dataset.vocab.forward([left])).to(device)\n",
    "            # Over-fetch so that dropping special or ':'-containing tokens\n",
    "            # (which would corrupt the word:prob format) still leaves top_k.\n",
    "            candidates = predict_word(tensor, model, top_k + 10)\n",
    "            results = [\n",
    "                (word, prob)\n",
    "                for word, _, prob in candidates\n",
    "                if word not in SPECIAL_TOKENS and \":\" not in word\n",
    "            ][:top_k]\n",
    "            rest = max(0.0, 1.0 - sum(prob for _, prob in results))\n",
    "            out_file.write(\n",
    "                \" \".join(f\"{word}:{prob}\" for word, prob in results) + f\" :{rest}\\n\"\n",
    "            )\n",
    "            if n_lines % 100 == 0:\n",
    "                print(f\"\\rProgress: {n_lines / lines_no * 100:.2f}%\", end=\"\")\n",
    "        print(f\"\\rProgress: {n_lines / lines_no * 100:.2f}%\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "\n",
    "for filepath, lines_no in zip(\n",
    "    (\"/content/dev-0/in.tsv.xz\", \"/content/test-A/in.tsv.xz\"), (10519, 7414)\n",
    "):\n",
    "    inference_on_file(filepath, model, lines_no)\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "mj_venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 0
}