{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KYySXV60UbL4",
        "outputId": "bb7d4752-ccc2-48cf-f5d4-540ffbcc1243"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<torch._C.Generator at 0x7f3f2c178990>"
            ]
          },
          "execution_count": 2,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "\n",
        "torch.manual_seed(1)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wle6wRL0UbL9"
      },
      "source": [
        "### Data loading"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "63t_aV_IUbL-"
      },
      "outputs": [],
      "source": [
        "import pickle\n",
        "import lzma\n",
        "import regex as re\n",
        "\n",
        "\n",
        "def load_pickle(filename):\n",
        "    with open(filename, \"rb\") as f:\n",
        "        return pickle.load(f)\n",
        "\n",
        "\n",
        "def save_pickle(d):\n",
        "    with open(\"vocabulary.pkl\", \"wb\") as f:\n",
        "        pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
        "\n",
        "\n",
        "def clean_document(document: str) -> str:\n",
        "    document = document.lower().replace(\"’\", \"'\")\n",
        "    document = re.sub(r\"'s|[\\-­]\\\\n\", \"\", document)\n",
        "    document = re.sub(\n",
        "        r\"(\\\\+n|[{}\\[\\]”&:•¦()*0-9;\\\"«»$\\-><^,®¬¿?¡!#+. \\t\\n])+\", \" \", document\n",
        "    )\n",
        "    for to_find, substitute in zip(\n",
        "        [\"i'm\", \"won't\", \"n't\", \"'ll\"], [\"i am\", \"will not\", \" not\", \" will\"]\n",
        "    ):\n",
        "        document = document.replace(to_find, substitute)\n",
        "    return document\n",
        "\n",
        "\n",
        "def get_words_from_line(line, clean_text=True):\n",
        "    if clean_text:\n",
        "        line = clean_document(line)  # .rstrip()\n",
        "    else:\n",
        "        line = line.strip()\n",
        "    yield \"<s>\"\n",
        "    for m in re.finditer(r\"[\\p{L}0-9\\*]+|\\p{P}+\", line):\n",
        "        yield m.group(0).lower()\n",
        "    yield \"</s>\"\n",
        "\n",
        "\n",
        "def get_word_lines_from_file(file_name, clean_text=True, only_text=False):\n",
        "    with lzma.open(file_name, \"r\") as fh:\n",
        "        for i, line in enumerate(fh):\n",
        "            if only_text:\n",
        "                line = \"\\t\".join(line.decode(\"utf-8\").split(\"\\t\")[:-2])\n",
        "            else:\n",
        "                line = line.decode(\"utf-8\")\n",
        "            if i % 10000 == 0:\n",
        "                print(i)\n",
        "            yield get_words_from_line(line, clean_text)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N7HGIM40UbL-"
      },
      "source": [
        "### Dataclasses"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "50ezMyyNUbL_",
        "outputId": "88af632f-fa88-43ea-ea7e-a709a2428c11"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0\n",
            "10000\n",
            "20000\n",
            "30000\n",
            "40000\n",
            "50000\n",
            "60000\n",
            "70000\n",
            "80000\n",
            "90000\n",
            "100000\n",
            "110000\n",
            "120000\n",
            "130000\n",
            "140000\n",
            "150000\n",
            "160000\n",
            "170000\n",
            "180000\n",
            "190000\n",
            "200000\n",
            "210000\n",
            "220000\n",
            "230000\n",
            "240000\n",
            "250000\n",
            "260000\n",
            "270000\n",
            "280000\n",
            "290000\n",
            "300000\n",
            "310000\n",
            "320000\n",
            "330000\n",
            "340000\n",
            "350000\n",
            "360000\n",
            "370000\n",
            "380000\n",
            "390000\n",
            "400000\n",
            "410000\n",
            "420000\n",
            "430000\n"
          ]
        }
      ],
      "source": [
        "from torch.utils.data import IterableDataset\n",
        "from torchtext.vocab import build_vocab_from_iterator\n",
        "import itertools\n",
        "\n",
        "\n",
        "VOCAB_SIZE = 20000\n",
        "\n",
        "\n",
        "def look_ahead_iterator(gen):\n",
        "    prev = None\n",
        "    for item in gen:\n",
        "        if prev is not None:\n",
        "            yield (prev, item)\n",
        "        prev = item\n",
        "\n",
        "\n",
        "class Bigrams(IterableDataset):\n",
        "    def __init__(\n",
        "        self, text_file, vocabulary_size, vocab=None, only_text=False, clean_text=True\n",
        "    ):\n",
        "        self.vocab = (\n",
        "            build_vocab_from_iterator(\n",
        "                get_word_lines_from_file(text_file, clean_text, only_text),\n",
        "                max_tokens=vocabulary_size,\n",
        "                specials=[\"<unk>\"],\n",
        "            )\n",
        "            if vocab is None\n",
        "            else vocab\n",
        "        )\n",
        "        self.vocab.set_default_index(self.vocab[\"<unk>\"])\n",
        "        self.vocabulary_size = vocabulary_size\n",
        "        self.text_file = text_file\n",
        "        self.clean_text = clean_text\n",
        "        self.only_text = only_text\n",
        "\n",
        "    def __iter__(self):\n",
        "        return look_ahead_iterator(\n",
        "            (\n",
        "                self.vocab[t]\n",
        "                for t in itertools.chain.from_iterable(\n",
        "                    get_word_lines_from_file(\n",
        "                        self.text_file, self.clean_text, self.only_text\n",
        "                    )\n",
        "                )\n",
        "            )\n",
        "        )\n",
        "\n",
        "\n",
        "vocab = None  # torch.load('./vocab.pth')\n",
        "\n",
        "train_dataset = Bigrams(\"/content/train/in.tsv.xz\", VOCAB_SIZE, vocab, clean_text=False)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "SCZ87kVxUbL_"
      },
      "outputs": [],
      "source": [
        "# torch.save(train_dataset.vocab, \"vocab.pth\")\n",
        "# torch.save(train_dataset.vocab, \"vocab_only_text.pth\")\n",
        "# torch.save(train_dataset.vocab, \"vocab_only_text_clean.pth\")\n",
        "torch.save(train_dataset.vocab, \"vocab_2.pth\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bv_Adw_lUbMA"
      },
      "source": [
        "### Model definition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "oaoLz3jPUbMB"
      },
      "outputs": [],
      "source": [
        "class SimpleBigramNeuralLanguageModel(nn.Module):\n",
        "    def __init__(self, vocabulary_size, embedding_size):\n",
        "        super(SimpleBigramNeuralLanguageModel, self).__init__()\n",
        "        self.model = nn.Sequential(\n",
        "            nn.Embedding(vocabulary_size, embedding_size),\n",
        "            nn.Linear(embedding_size, vocabulary_size),\n",
        "            nn.Softmax(),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.model(x)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "YBDf3nEvUbMC"
      },
      "outputs": [],
      "source": [
        "EMBED_SIZE = 100\n",
        "\n",
        "model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "CRe3STJUUbMC",
        "outputId": "ab4b234a-49c7-4878-f8e4-e21072efa9ee"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
            "  input = module(input)\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0 tensor(10.0674, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "100 tensor(8.4352, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "200 tensor(7.6662, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "300 tensor(7.0716, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "400 tensor(6.6710, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "500 tensor(6.4540, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "600 tensor(5.9974, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "700 tensor(5.7973, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "800 tensor(5.8026, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "10000\n",
            "900 tensor(5.7118, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1000 tensor(5.7471, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1100 tensor(5.6865, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1200 tensor(5.4205, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1300 tensor(5.4954, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1400 tensor(5.5415, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1500 tensor(5.3322, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1600 tensor(5.4665, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1700 tensor(5.4710, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "20000\n",
            "1800 tensor(5.3953, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "1900 tensor(5.4881, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2000 tensor(5.4915, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2100 tensor(5.3621, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2200 tensor(5.2872, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2300 tensor(5.2590, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2400 tensor(5.3661, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2500 tensor(5.3305, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "30000\n",
            "2600 tensor(5.3789, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2700 tensor(5.3548, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2800 tensor(5.4579, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "2900 tensor(5.2660, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3000 tensor(5.3253, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3100 tensor(5.4020, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3200 tensor(5.2962, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3300 tensor(5.2570, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3400 tensor(5.2317, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "40000\n",
            "3500 tensor(5.2410, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3600 tensor(5.2404, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3700 tensor(5.1738, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3800 tensor(5.2654, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "3900 tensor(5.2595, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4000 tensor(5.2850, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4100 tensor(5.2995, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4200 tensor(5.2581, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4300 tensor(5.3323, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "50000\n",
            "4400 tensor(5.2498, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4500 tensor(5.2674, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4600 tensor(5.3033, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4700 tensor(5.2066, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4800 tensor(5.2302, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "4900 tensor(5.2617, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5000 tensor(5.2306, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5100 tensor(5.2781, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "60000\n",
            "5200 tensor(5.1833, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5300 tensor(5.2166, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5400 tensor(5.0845, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5500 tensor(5.2272, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5600 tensor(5.3175, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5700 tensor(5.2425, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5800 tensor(5.2449, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "5900 tensor(5.3225, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6000 tensor(5.2786, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "70000\n",
            "6100 tensor(5.1489, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6200 tensor(5.1793, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6300 tensor(5.2194, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6400 tensor(5.1708, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6500 tensor(5.1394, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6600 tensor(5.1280, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6700 tensor(5.0869, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6800 tensor(5.3255, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "6900 tensor(5.3426, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "80000\n",
            "7000 tensor(5.1176, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7100 tensor(5.1991, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7200 tensor(5.1227, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7300 tensor(5.1744, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7400 tensor(5.2222, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7500 tensor(5.2110, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7600 tensor(5.1553, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7700 tensor(5.3283, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "90000\n",
            "7800 tensor(5.2544, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "7900 tensor(5.1871, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8000 tensor(5.2215, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8100 tensor(5.1744, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8200 tensor(5.1087, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8300 tensor(5.1639, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8400 tensor(5.1604, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8500 tensor(5.1612, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8600 tensor(5.2307, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "100000\n",
            "8700 tensor(5.1648, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8800 tensor(5.1066, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "8900 tensor(5.2405, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9000 tensor(5.2184, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9100 tensor(5.2677, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
            "9200 tensor(5.0773, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
          ]
        },
        {
          "ename": "KeyboardInterrupt",
          "evalue": "ignored",
          "output_type": "error",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-9-c690ed9ba7ad>\u001b[0m in \u001b[0;36m<cell line: 11>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     18\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0mstep\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m     \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m     \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "\n",
        "device = \"cuda\"\n",
        "model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n",
        "data = DataLoader(train_dataset, batch_size=5000)\n",
        "optimizer = torch.optim.Adam(model.parameters())\n",
        "criterion = torch.nn.NLLLoss()\n",
        "\n",
        "model.train()\n",
        "step = 0\n",
        "for x, y in data:\n",
        "    x = x.to(device)\n",
        "    y = y.to(device)\n",
        "    optimizer.zero_grad()\n",
        "    ypredicted = model(x)\n",
        "    loss = criterion(torch.log(ypredicted), y)\n",
        "    if step % 100 == 0:\n",
        "        print(step, loss)\n",
        "    step += 1\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "torch.save(model.state_dict(), \"model_2.bin\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kS9NHTGeom3y",
        "outputId": "ce83640e-d2fe-41e6-cd5d-38a5b0410323"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[('the', 2, 0.15899169445037842),\n",
              " ('\\\\', 1, 0.10546761751174927),\n",
              " ('he', 28, 0.06849857419729233),\n",
              " ('it', 15, 0.05329886078834534),\n",
              " ('i', 26, 0.0421920120716095),\n",
              " ('they', 50, 0.03895237296819687),\n",
              " ('a', 8, 0.03352600708603859),\n",
              " ('<unk>', 0, 0.031062396243214607),\n",
              " ('we', 61, 0.02323235757648945),\n",
              " ('she', 104, 0.02003088779747486)]"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "ixs = torch.tensor(train_dataset.vocab.forward([\"when\"])).to(device)\n",
        "out = model(ixs)\n",
        "top = torch.topk(out[0], 10)\n",
        "top_indices = top.indices.tolist()\n",
        "top_probs = top.values.tolist()\n",
        "top_words = train_dataset.vocab.lookup_tokens(top_indices)\n",
        "list(zip(top_words, top_indices, top_probs))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UTYHTCJvC_Nm",
        "outputId": "576288fa-fdad-4c21-d924-f613eaf33063"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "device = \"cuda\"\n",
        "model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)\n",
        "model.load_state_dict(torch.load(\"model1.bin\"))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8WexjGIAxaE4",
        "outputId": "52252b81-3b98-42d3-b137-472af00dbb26"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Training on /content/dev-0/in.tsv.xz\n",
            "\rProgress: 0.01%\rProgress: 0.02%\rProgress: 0.03%\rProgress: 0.04%\rProgress: 0.05%\rProgress: 0.06%\rProgress: 0.07%\rProgress: 0.08%\rProgress: 0.09%\rProgress: 0.10%\rProgress: 0.10%\rProgress: 0.11%\rProgress: 0.12%\rProgress: 0.13%\rProgress: 0.14%\rProgress: 0.15%\rProgress: 0.16%\rProgress: 0.17%\rProgress: 0.18%\rProgress: 0.19%\rProgress: 0.20%\rProgress: 0.21%\rProgress: 0.22%\rProgress: 0.23%\rProgress: 0.24%\rProgress: 0.25%\rProgress: 0.26%\rProgress: 0.27%\rProgress: 0.28%\rProgress: 0.29%\rProgress: 0.29%\rProgress: 0.30%\rProgress: 0.31%\rProgress: 0.32%\rProgress: 0.33%\rProgress: 0.34%\rProgress: 0.35%\rProgress: 0.36%\rProgress: 0.37%\rProgress: 0.38%\rProgress: 0.39%\rProgress: 0.40%\rProgress: 0.41%\rProgress: 0.42%\rProgress: 0.43%\rProgress: 0.44%\rProgress: 0.45%\rProgress: 0.46%\rProgress: 0.47%\rProgress: 0.48%\rProgress: 0.48%\rProgress: 0.49%\rProgress: 0.50%\rProgress: 0.51%\rProgress: 0.52%\rProgress: 0.53%\rProgress: 0.54%\rProgress: 0.55%\rProgress: 0.56%\rProgress: 0.57%\rProgress: 0.58%\rProgress: 0.59%\rProgress: 0.60%\rProgress: 0.61%\rProgress: 0.62%\rProgress: 0.63%\rProgress: 0.64%\rProgress: 0.65%\rProgress: 0.66%\rProgress: 0.67%\rProgress: 0.67%\rProgress: 0.68%\rProgress: 0.69%\rProgress: 0.70%\rProgress: 0.71%\rProgress: 0.72%\rProgress: 0.73%\rProgress: 0.74%\rProgress: 0.75%\rProgress: 0.76%\rProgress: 0.77%\rProgress: 0.78%\rProgress: 0.79%\rProgress: 0.80%\rProgress: 0.81%\rProgress: 0.82%\rProgress: 0.83%\rProgress: 0.84%\rProgress: 0.85%\rProgress: 0.86%\rProgress: 0.87%\rProgress: 0.87%\rProgress: 0.88%\rProgress: 0.89%\rProgress: 0.90%\rProgress: 0.91%\rProgress: 0.92%\rProgress: 0.93%\rProgress: 0.94%\rProgress: 0.95%\rProgress: 0.96%\rProgress: 0.97%\rProgress: 0.98%\rProgress: 0.99%\rProgress: 1.00%\rProgress: 1.01%\rProgress: 1.02%\rProgress: 1.03%\rProgress: 1.04%\rProgress: 1.05%\rProgress: 1.06%\rProgress: 1.06%\rProgress: 1.07%\rProgress: 1.08%\rProgress: 1.09%\rProgress: 1.10%\rProgress: 1.11%\rProgress: 1.12%\rProgress: 1.13%\rProgress: 1.14%\rProgress: 1.15%\rProgress: 1.16%\rProgress: 1.17%\rProgress: 1.18%\rProgress: 1.19%\rProgress: 1.20%\rProgress: 1.21%\rProgress: 1.22%\rProgress: 1.23%\rProgress: 1.24%\rProgress: 1.25%\rProgress: 1.25%\rProgress: 1.26%\rProgress: 1.27%\rProgress: 1.28%\rProgress: 1.29%\rProgress: 1.30%\rProgress: 1.31%\rProgress: 1.32%\rProgress: 1.33%\rProgress: 1.34%\rProgress: 1.35%\rProgress: 1.36%\rProgress: 1.37%\rProgress: 1.38%\rProgress: 1.39%\rProgress: 1.40%\rProgress: 1.41%\rProgress: 1.42%\rProgress: 1.43%\rProgress: 1.44%\rProgress: 1.45%\rProgress: 1.45%\rProgress: 1.46%\rProgress: 1.47%\rProgress: 1.48%\rProgress: 1.49%\rProgress: 1.50%\rProgress: 1.51%\rProgress: 1.52%\rProgress: 1.53%\rProgress: 1.54%\rProgress: 1.55%\rProgress: 1.56%\rProgress: 1.57%\rProgress: 1.58%\rProgress: 1.59%\rProgress: 1.60%\rProgress: 1.61%\rProgress: 1.62%"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
            "  input = module(input)\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Progress: 100.00%\n",
            "Training on /content/test-A/in.tsv.xz\n",
            "Progress: 100.00%\n"
          ]
        }
      ],
      "source": [
        "def predict_word(ixs, model, top_k=5):\n",
        "    out = model(ixs)\n",
        "    top = torch.topk(out[0], 10)\n",
        "    top_indices = top.indices.tolist()\n",
        "    top_probs = top.values.tolist()\n",
        "    top_words = train_dataset.vocab.lookup_tokens(top_indices)\n",
        "    return list(zip(top_words, top_indices, top_probs))\n",
        "\n",
        "\n",
        "def get_one_word(text, context=\"left\"):\n",
        "    # print(\"Getting word from:\", text)\n",
        "    if context == \"left\":\n",
        "        context = -1\n",
        "    else:\n",
        "        context = 0\n",
        "    return text.rstrip().split(\" \")[context]\n",
        "\n",
        "\n",
        "def inference_on_file(filename, model, lines_no=1):\n",
        "    results_path = \"/\".join(filename.split(\"/\")[:-1]) + \"/out.tsv\"\n",
        "    with lzma.open(filename, \"r\") as fp, open(results_path, \"w\") as out_file:\n",
        "        print(\"Training on\", filename)\n",
        "        for i, line in enumerate(fp):\n",
        "            # left, right = [ get_one_word(text_part, context)\n",
        "            # for context, text_part in zip(line.split('\\t')[:-2], ('left', 'right'))]\n",
        "            line = line.decode(\"utf-8\")\n",
        "            # print(line)\n",
        "            left = get_one_word(line.split(\"\\t\")[-2])\n",
        "            # print(\"Current word:\", left)\n",
        "            tensor = torch.tensor(train_dataset.vocab.forward([left])).to(device)\n",
        "            results = predict_word(tensor, model, 9)\n",
        "            prob_sum = sum([word[2] for word in results])\n",
        "            result_line = (\n",
        "                \" \".join([f\"{word[0]}:{word[2]}\" for word in results])\n",
        "                + f\" :{prob_sum}\\n\"\n",
        "            )\n",
        "            # print(result_line)\n",
        "            out_file.write(result_line)\n",
        "            print(f\"\\rProgress: {(((i+1) / lines_no) * 100):.2f}%\", end=\"\")\n",
        "        print()\n",
        "\n",
        "\n",
        "model.eval()\n",
        "\n",
        "for filepath, lines_no in zip(\n",
        "    (\"/content/dev-0/in.tsv.xz\", \"/content/test-A/in.tsv.xz\"), (10519.0, 7414.0)\n",
        "):\n",
        "    inference_on_file(filepath, model, lines_no)\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "mj_venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.6"
    },
    "orig_nbformat": 4
  },
  "nbformat": 4,
  "nbformat_minor": 0
}