diff --git a/gonito.yml b/gonito.yml new file mode 100644 index 0000000..caa409b --- /dev/null +++ b/gonito.yml @@ -0,0 +1,13 @@ +description: Lab 7 - Ex. 1 +tags: + - bigram + - neural-network +params: + epochs: 1 + batch-size: 5000 + learning-rate: 0.001 + embed_size: 100 + vocab_size: 20000 +links: + - title: "Git WMI" + url: "https://git.wmi.amu.edu.pl/s478841/challenging-america-word-gap-prediction" \ No newline at end of file diff --git a/neural-networks/pytorch_n_gram.ipynb b/neural-networks/pytorch_n_gram.ipynb new file mode 100644 index 0000000..f76b43b --- /dev/null +++ b/neural-networks/pytorch_n_gram.ipynb @@ -0,0 +1,645 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KYySXV60UbL4", + "outputId": "bb7d4752-ccc2-48cf-f5d4-540ffbcc1243" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "\n", + "torch.manual_seed(1)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Wle6wRL0UbL9" + }, + "source": [ + "### Data loading" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "63t_aV_IUbL-" + }, + "outputs": [], + "source": [ + "import pickle\n", + "import lzma\n", + "import regex as re\n", + "\n", + "\n", + "def load_pickle(filename):\n", + " with open(filename, \"rb\") as f:\n", + " return pickle.load(f)\n", + "\n", + "\n", + "def save_pickle(d):\n", + " with open(\"vocabulary.pkl\", \"wb\") as f:\n", + " pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)\n", + "\n", + "\n", + "def clean_document(document: str) -> str:\n", + " document = document.lower().replace(\"’\", \"'\")\n", + " document = re.sub(r\"'s|[\\-­]\\\\n\", \"\", document)\n", + " document = re.sub(\n", + " r\"(\\\\+n|[{}\\[\\]”&:•¦()*0-9;\\\"«»$\\-><^,®¬¿?¡!#+. \\t\\n])+\", \" \", document\n", + " )\n", + " for to_find, substitute in zip(\n", + " [\"i'm\", \"won't\", \"n't\", \"'ll\"], [\"i am\", \"will not\", \" not\", \" will\"]\n", + " ):\n", + " document = document.replace(to_find, substitute)\n", + " return document\n", + "\n", + "\n", + "def get_words_from_line(line, clean_text=True):\n", + " if clean_text:\n", + " line = clean_document(line) # .rstrip()\n", + " else:\n", + " line = line.strip()\n", + " yield \"\"\n", + " for m in re.finditer(r\"[\\p{L}0-9\\*]+|\\p{P}+\", line):\n", + " yield m.group(0).lower()\n", + " yield \"\"\n", + "\n", + "\n", + "def get_word_lines_from_file(file_name, clean_text=True, only_text=False):\n", + " with lzma.open(file_name, \"r\") as fh:\n", + " for i, line in enumerate(fh):\n", + " if only_text:\n", + " line = \"\\t\".join(line.decode(\"utf-8\").split(\"\\t\")[:-2])\n", + " else:\n", + " line = line.decode(\"utf-8\")\n", + " if i % 10000 == 0:\n", + " print(i)\n", + " yield get_words_from_line(line, clean_text)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N7HGIM40UbL-" + }, + "source": [ + "### Dataclasses" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "50ezMyyNUbL_", + "outputId": "88af632f-fa88-43ea-ea7e-a709a2428c11" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "10000\n", + "20000\n", + "30000\n", + "40000\n", + "50000\n", + "60000\n", + "70000\n", + "80000\n", + "90000\n", + "100000\n", + "110000\n", + "120000\n", + "130000\n", + "140000\n", + "150000\n", + "160000\n", + "170000\n", + "180000\n", + "190000\n", + "200000\n", + "210000\n", + "220000\n", + "230000\n", + "240000\n", + "250000\n", + "260000\n", + "270000\n", + "280000\n", + "290000\n", + "300000\n", + "310000\n", + "320000\n", + "330000\n", + "340000\n", + "350000\n", + "360000\n", + "370000\n", + "380000\n", + "390000\n", + "400000\n", + "410000\n", + "420000\n", + "430000\n" + ] + } + ], + "source": [ + "from torch.utils.data import IterableDataset\n", + "from torchtext.vocab import build_vocab_from_iterator\n", + "import itertools\n", + "\n", + "\n", + "VOCAB_SIZE = 20000\n", + "\n", + "\n", + "def look_ahead_iterator(gen):\n", + " prev = None\n", + " for item in gen:\n", + " if prev is not None:\n", + " yield (prev, item)\n", + " prev = item\n", + "\n", + "\n", + "class Bigrams(IterableDataset):\n", + " def __init__(\n", + " self, text_file, vocabulary_size, vocab=None, only_text=False, clean_text=True\n", + " ):\n", + " self.vocab = (\n", + " build_vocab_from_iterator(\n", + " get_word_lines_from_file(text_file, clean_text, only_text),\n", + " max_tokens=vocabulary_size,\n", + " specials=[\"\"],\n", + " )\n", + " if vocab is None\n", + " else vocab\n", + " )\n", + " self.vocab.set_default_index(self.vocab[\"\"])\n", + " self.vocabulary_size = vocabulary_size\n", + " self.text_file = text_file\n", + " self.clean_text = clean_text\n", + " self.only_text = only_text\n", + "\n", + " def __iter__(self):\n", + " return look_ahead_iterator(\n", + " (\n", + " self.vocab[t]\n", + " for t in itertools.chain.from_iterable(\n", + " get_word_lines_from_file(\n", + " self.text_file, self.clean_text, self.only_text\n", + " )\n", + " )\n", + " )\n", + " )\n", + "\n", + "\n", + "vocab = None # torch.load('./vocab.pth')\n", + "\n", + "train_dataset = Bigrams(\"/content/train/in.tsv.xz\", VOCAB_SIZE, vocab, clean_text=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "SCZ87kVxUbL_" + }, + "outputs": [], + "source": [ + "# torch.save(train_dataset.vocab, \"vocab.pth\")\n", + "# torch.save(train_dataset.vocab, \"vocab_only_text.pth\")\n", + "# torch.save(train_dataset.vocab, \"vocab_only_text_clean.pth\")\n", + "torch.save(train_dataset.vocab, \"vocab_2.pth\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bv_Adw_lUbMA" + }, + "source": [ + "### Model definition" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "oaoLz3jPUbMB" + }, + "outputs": [], + "source": [ + "class SimpleBigramNeuralLanguageModel(nn.Module):\n", + " def __init__(self, vocabulary_size, embedding_size):\n", + " super(SimpleBigramNeuralLanguageModel, self).__init__()\n", + " self.model = nn.Sequential(\n", + " nn.Embedding(vocabulary_size, embedding_size),\n", + " nn.Linear(embedding_size, vocabulary_size),\n", + " nn.Softmax(),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "YBDf3nEvUbMC" + }, + "outputs": [], + "source": [ + "EMBED_SIZE = 100\n", + "\n", + "model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "CRe3STJUUbMC", + "outputId": "ab4b234a-49c7-4878-f8e4-e21072efa9ee" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", + " input = module(input)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 tensor(10.0674, device='cuda:0', grad_fn=)\n", + "100 tensor(8.4352, device='cuda:0', grad_fn=)\n", + "200 tensor(7.6662, device='cuda:0', grad_fn=)\n", + "300 tensor(7.0716, device='cuda:0', grad_fn=)\n", + "400 tensor(6.6710, device='cuda:0', grad_fn=)\n", + "500 tensor(6.4540, device='cuda:0', grad_fn=)\n", + "600 tensor(5.9974, device='cuda:0', grad_fn=)\n", + "700 tensor(5.7973, device='cuda:0', grad_fn=)\n", + "800 tensor(5.8026, device='cuda:0', grad_fn=)\n", + "10000\n", + "900 tensor(5.7118, device='cuda:0', grad_fn=)\n", + "1000 tensor(5.7471, device='cuda:0', grad_fn=)\n", + "1100 tensor(5.6865, device='cuda:0', grad_fn=)\n", + "1200 tensor(5.4205, device='cuda:0', grad_fn=)\n", + "1300 tensor(5.4954, device='cuda:0', grad_fn=)\n", + "1400 tensor(5.5415, device='cuda:0', grad_fn=)\n", + "1500 tensor(5.3322, device='cuda:0', grad_fn=)\n", + "1600 tensor(5.4665, device='cuda:0', grad_fn=)\n", + "1700 tensor(5.4710, device='cuda:0', grad_fn=)\n", + "20000\n", + "1800 tensor(5.3953, device='cuda:0', grad_fn=)\n", + "1900 tensor(5.4881, device='cuda:0', grad_fn=)\n", + "2000 tensor(5.4915, device='cuda:0', grad_fn=)\n", + "2100 tensor(5.3621, device='cuda:0', grad_fn=)\n", + "2200 tensor(5.2872, device='cuda:0', grad_fn=)\n", + "2300 tensor(5.2590, device='cuda:0', grad_fn=)\n", + "2400 tensor(5.3661, device='cuda:0', grad_fn=)\n", + "2500 tensor(5.3305, device='cuda:0', grad_fn=)\n", + "30000\n", + "2600 tensor(5.3789, device='cuda:0', grad_fn=)\n", + "2700 tensor(5.3548, device='cuda:0', grad_fn=)\n", + "2800 tensor(5.4579, device='cuda:0', grad_fn=)\n", + "2900 tensor(5.2660, device='cuda:0', grad_fn=)\n", + "3000 tensor(5.3253, device='cuda:0', grad_fn=)\n", + "3100 tensor(5.4020, device='cuda:0', grad_fn=)\n", + "3200 tensor(5.2962, device='cuda:0', grad_fn=)\n", + "3300 tensor(5.2570, device='cuda:0', grad_fn=)\n", + "3400 tensor(5.2317, device='cuda:0', grad_fn=)\n", + "40000\n", + "3500 tensor(5.2410, device='cuda:0', grad_fn=)\n", + "3600 tensor(5.2404, device='cuda:0', grad_fn=)\n", + "3700 tensor(5.1738, device='cuda:0', grad_fn=)\n", + "3800 tensor(5.2654, device='cuda:0', grad_fn=)\n", + "3900 tensor(5.2595, device='cuda:0', grad_fn=)\n", + "4000 tensor(5.2850, device='cuda:0', grad_fn=)\n", + "4100 tensor(5.2995, device='cuda:0', grad_fn=)\n", + "4200 tensor(5.2581, device='cuda:0', grad_fn=)\n", + "4300 tensor(5.3323, device='cuda:0', grad_fn=)\n", + "50000\n", + "4400 tensor(5.2498, device='cuda:0', grad_fn=)\n", + "4500 tensor(5.2674, device='cuda:0', grad_fn=)\n", + "4600 tensor(5.3033, device='cuda:0', grad_fn=)\n", + "4700 tensor(5.2066, device='cuda:0', grad_fn=)\n", + "4800 tensor(5.2302, device='cuda:0', grad_fn=)\n", + "4900 tensor(5.2617, device='cuda:0', grad_fn=)\n", + "5000 tensor(5.2306, device='cuda:0', grad_fn=)\n", + "5100 tensor(5.2781, device='cuda:0', grad_fn=)\n", + "60000\n", + "5200 tensor(5.1833, device='cuda:0', grad_fn=)\n", + "5300 tensor(5.2166, device='cuda:0', grad_fn=)\n", + "5400 tensor(5.0845, device='cuda:0', grad_fn=)\n", + "5500 tensor(5.2272, device='cuda:0', grad_fn=)\n", + "5600 tensor(5.3175, device='cuda:0', grad_fn=)\n", + "5700 tensor(5.2425, device='cuda:0', grad_fn=)\n", + "5800 tensor(5.2449, device='cuda:0', grad_fn=)\n", + "5900 tensor(5.3225, device='cuda:0', grad_fn=)\n", + "6000 tensor(5.2786, device='cuda:0', grad_fn=)\n", + "70000\n", + "6100 tensor(5.1489, device='cuda:0', grad_fn=)\n", + "6200 tensor(5.1793, device='cuda:0', grad_fn=)\n", + "6300 tensor(5.2194, device='cuda:0', grad_fn=)\n", + "6400 tensor(5.1708, device='cuda:0', grad_fn=)\n", + "6500 tensor(5.1394, device='cuda:0', grad_fn=)\n", + "6600 tensor(5.1280, device='cuda:0', grad_fn=)\n", + "6700 tensor(5.0869, device='cuda:0', grad_fn=)\n", + "6800 tensor(5.3255, device='cuda:0', grad_fn=)\n", + "6900 tensor(5.3426, device='cuda:0', grad_fn=)\n", + "80000\n", + "7000 tensor(5.1176, device='cuda:0', grad_fn=)\n", + "7100 tensor(5.1991, device='cuda:0', grad_fn=)\n", + "7200 tensor(5.1227, device='cuda:0', grad_fn=)\n", + "7300 tensor(5.1744, device='cuda:0', grad_fn=)\n", + "7400 tensor(5.2222, device='cuda:0', grad_fn=)\n", + "7500 tensor(5.2110, device='cuda:0', grad_fn=)\n", + "7600 tensor(5.1553, device='cuda:0', grad_fn=)\n", + "7700 tensor(5.3283, device='cuda:0', grad_fn=)\n", + "90000\n", + "7800 tensor(5.2544, device='cuda:0', grad_fn=)\n", + "7900 tensor(5.1871, device='cuda:0', grad_fn=)\n", + "8000 tensor(5.2215, device='cuda:0', grad_fn=)\n", + "8100 tensor(5.1744, device='cuda:0', grad_fn=)\n", + "8200 tensor(5.1087, device='cuda:0', grad_fn=)\n", + "8300 tensor(5.1639, device='cuda:0', grad_fn=)\n", + "8400 tensor(5.1604, device='cuda:0', grad_fn=)\n", + "8500 tensor(5.1612, device='cuda:0', grad_fn=)\n", + "8600 tensor(5.2307, device='cuda:0', grad_fn=)\n", + "100000\n", + "8700 tensor(5.1648, device='cuda:0', grad_fn=)\n", + "8800 tensor(5.1066, device='cuda:0', grad_fn=)\n", + "8900 tensor(5.2405, device='cuda:0', grad_fn=)\n", + "9000 tensor(5.2184, device='cuda:0', grad_fn=)\n", + "9100 tensor(5.2677, device='cuda:0', grad_fn=)\n", + "9200 tensor(5.0773, device='cuda:0', grad_fn=)\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "ignored", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "device = \"cuda\"\n", + "model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n", + "data = DataLoader(train_dataset, batch_size=5000)\n", + "optimizer = torch.optim.Adam(model.parameters())\n", + "criterion = torch.nn.NLLLoss()\n", + "\n", + "model.train()\n", + "step = 0\n", + "for x, y in data:\n", + " x = x.to(device)\n", + " y = y.to(device)\n", + " optimizer.zero_grad()\n", + " ypredicted = model(x)\n", + " loss = criterion(torch.log(ypredicted), y)\n", + " if step % 100 == 0:\n", + " print(step, loss)\n", + " step += 1\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "torch.save(model.state_dict(), \"model_2.bin\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kS9NHTGeom3y", + "outputId": "ce83640e-d2fe-41e6-cd5d-38a5b0410323" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[('the', 2, 0.15899169445037842),\n", + " ('\\\\', 1, 0.10546761751174927),\n", + " ('he', 28, 0.06849857419729233),\n", + " ('it', 15, 0.05329886078834534),\n", + " ('i', 26, 0.0421920120716095),\n", + " ('they', 50, 0.03895237296819687),\n", + " ('a', 8, 0.03352600708603859),\n", + " ('', 0, 0.031062396243214607),\n", + " ('we', 61, 0.02323235757648945),\n", + " ('she', 104, 0.02003088779747486)]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ixs = torch.tensor(train_dataset.vocab.forward([\"when\"])).to(device)\n", + "out = model(ixs)\n", + "top = torch.topk(out[0], 10)\n", + "top_indices = top.indices.tolist()\n", + "top_probs = top.values.tolist()\n", + "top_words = train_dataset.vocab.lookup_tokens(top_indices)\n", + "list(zip(top_words, top_indices, top_probs))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UTYHTCJvC_Nm", + "outputId": "576288fa-fdad-4c21-d924-f613eaf33063" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = \"cuda\"\n", + "model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n", + "model.load_state_dict(torch.load(\"model1.bin\"))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8WexjGIAxaE4", + "outputId": "52252b81-3b98-42d3-b137-472af00dbb26" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training on /content/dev-0/in.tsv.xz\n", + "\rProgress: 0.01%\rProgress: 0.02%\rProgress: 0.03%\rProgress: 0.04%\rProgress: 0.05%\rProgress: 0.06%\rProgress: 0.07%\rProgress: 0.08%\rProgress: 0.09%\rProgress: 0.10%\rProgress: 0.10%\rProgress: 0.11%\rProgress: 0.12%\rProgress: 0.13%\rProgress: 0.14%\rProgress: 0.15%\rProgress: 0.16%\rProgress: 0.17%\rProgress: 0.18%\rProgress: 0.19%\rProgress: 0.20%\rProgress: 0.21%\rProgress: 0.22%\rProgress: 0.23%\rProgress: 0.24%\rProgress: 0.25%\rProgress: 0.26%\rProgress: 0.27%\rProgress: 0.28%\rProgress: 0.29%\rProgress: 0.29%\rProgress: 0.30%\rProgress: 0.31%\rProgress: 0.32%\rProgress: 0.33%\rProgress: 0.34%\rProgress: 0.35%\rProgress: 0.36%\rProgress: 0.37%\rProgress: 0.38%\rProgress: 0.39%\rProgress: 0.40%\rProgress: 0.41%\rProgress: 0.42%\rProgress: 0.43%\rProgress: 0.44%\rProgress: 0.45%\rProgress: 0.46%\rProgress: 0.47%\rProgress: 0.48%\rProgress: 0.48%\rProgress: 0.49%\rProgress: 0.50%\rProgress: 0.51%\rProgress: 0.52%\rProgress: 0.53%\rProgress: 0.54%\rProgress: 0.55%\rProgress: 0.56%\rProgress: 0.57%\rProgress: 0.58%\rProgress: 0.59%\rProgress: 0.60%\rProgress: 0.61%\rProgress: 0.62%\rProgress: 0.63%\rProgress: 0.64%\rProgress: 0.65%\rProgress: 0.66%\rProgress: 0.67%\rProgress: 0.67%\rProgress: 0.68%\rProgress: 0.69%\rProgress: 0.70%\rProgress: 0.71%\rProgress: 0.72%\rProgress: 0.73%\rProgress: 0.74%\rProgress: 0.75%\rProgress: 0.76%\rProgress: 0.77%\rProgress: 0.78%\rProgress: 0.79%\rProgress: 0.80%\rProgress: 0.81%\rProgress: 0.82%\rProgress: 0.83%\rProgress: 0.84%\rProgress: 0.85%\rProgress: 0.86%\rProgress: 0.87%\rProgress: 0.87%\rProgress: 0.88%\rProgress: 0.89%\rProgress: 0.90%\rProgress: 0.91%\rProgress: 0.92%\rProgress: 0.93%\rProgress: 0.94%\rProgress: 0.95%\rProgress: 0.96%\rProgress: 0.97%\rProgress: 0.98%\rProgress: 0.99%\rProgress: 1.00%\rProgress: 1.01%\rProgress: 1.02%\rProgress: 1.03%\rProgress: 1.04%\rProgress: 1.05%\rProgress: 1.06%\rProgress: 1.06%\rProgress: 1.07%\rProgress: 1.08%\rProgress: 1.09%\rProgress: 1.10%\rProgress: 1.11%\rProgress: 1.12%\rProgress: 1.13%\rProgress: 1.14%\rProgress: 1.15%\rProgress: 1.16%\rProgress: 1.17%\rProgress: 1.18%\rProgress: 1.19%\rProgress: 1.20%\rProgress: 1.21%\rProgress: 1.22%\rProgress: 1.23%\rProgress: 1.24%\rProgress: 1.25%\rProgress: 1.25%\rProgress: 1.26%\rProgress: 1.27%\rProgress: 1.28%\rProgress: 1.29%\rProgress: 1.30%\rProgress: 1.31%\rProgress: 1.32%\rProgress: 1.33%\rProgress: 1.34%\rProgress: 1.35%\rProgress: 1.36%\rProgress: 1.37%\rProgress: 1.38%\rProgress: 1.39%\rProgress: 1.40%\rProgress: 1.41%\rProgress: 1.42%\rProgress: 1.43%\rProgress: 1.44%\rProgress: 1.45%\rProgress: 1.45%\rProgress: 1.46%\rProgress: 1.47%\rProgress: 1.48%\rProgress: 1.49%\rProgress: 1.50%\rProgress: 1.51%\rProgress: 1.52%\rProgress: 1.53%\rProgress: 1.54%\rProgress: 1.55%\rProgress: 1.56%\rProgress: 1.57%\rProgress: 1.58%\rProgress: 1.59%\rProgress: 1.60%\rProgress: 1.61%\rProgress: 1.62%" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", + " input = module(input)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Progress: 100.00%\n", + "Training on /content/test-A/in.tsv.xz\n", + "Progress: 100.00%\n" + ] + } + ], + "source": [ + "def predict_word(ixs, model, top_k=5):\n", + " out = model(ixs)\n", + " top = torch.topk(out[0], 10)\n", + " top_indices = top.indices.tolist()\n", + " top_probs = top.values.tolist()\n", + " top_words = train_dataset.vocab.lookup_tokens(top_indices)\n", + " return list(zip(top_words, top_indices, top_probs))\n", + "\n", + "\n", + "def get_one_word(text, context=\"left\"):\n", + " # print(\"Getting word from:\", text)\n", + " if context == \"left\":\n", + " context = -1\n", + " else:\n", + " context = 0\n", + " return text.rstrip().split(\" \")[context]\n", + "\n", + "\n", + "def inference_on_file(filename, model, lines_no=1):\n", + " results_path = \"/\".join(filename.split(\"/\")[:-1]) + \"/out.tsv\"\n", + " with lzma.open(filename, \"r\") as fp, open(results_path, \"w\") as out_file:\n", + " print(\"Training on\", filename)\n", + " for i, line in enumerate(fp):\n", + " # left, right = [ get_one_word(text_part, context)\n", + " # for context, text_part in zip(line.split('\\t')[:-2], ('left', 'right'))]\n", + " line = line.decode(\"utf-8\")\n", + " # print(line)\n", + " left = get_one_word(line.split(\"\\t\")[-2])\n", + " # print(\"Current word:\", left)\n", + " tensor = torch.tensor(train_dataset.vocab.forward([left])).to(device)\n", + " results = predict_word(tensor, model, 9)\n", + " prob_sum = sum([word[2] for word in results])\n", + " result_line = (\n", + " \" \".join([f\"{word[0]}:{word[2]}\" for word in results])\n", + " + f\" :{prob_sum}\\n\"\n", + " )\n", + " # print(result_line)\n", + " out_file.write(result_line)\n", + " print(f\"\\rProgress: {(((i+1) / lines_no) * 100):.2f}%\", end=\"\")\n", + " print()\n", + "\n", + "\n", + "model.eval()\n", + "\n", + "for filepath, lines_no in zip(\n", + " (\"/content/dev-0/in.tsv.xz\", \"/content/test-A/in.tsv.xz\"), (10519.0, 7414.0)\n", + "):\n", + " inference_on_file(filepath, model, lines_no)\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "mj_venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 0 +}