{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false, "id": "8Iy6jV8cXBuT" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "id": "vLUNBqCuXBuV", "pycharm": { "is_executing": true } }, "outputs": [], "source": [ "import itertools\n", "import lzma\n", "\n", "import regex as re\n", "import torch\n", "from torch import nn\n", "from torch.utils.data import IterableDataset, DataLoader\n", "from torchtext.vocab import build_vocab_from_iterator\n", "from google.colab import drive" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "id": "y8M2LxjXXBuY" }, "source": [ "## Definitions" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "id": "wMM1C4pKXBuY" }, "source": [ "### Functions" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "VYFHWbTlXBuZ" }, "outputs": [], "source": [ "def clean_text(line: str):\n", " # Preprocessing\n", " separated = line.split('\\t')\n", " prefix = separated[6].replace(r'\\n', ' ').replace('\\\\n', ' ').replace(' ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '').replace('(', '').replace(')', '').replace(';', '').replace(':', '').replace('\"', '').replace(\"'\", '').replace('-', ' ').replace(' ', ' ')\n", " suffix = separated[7].replace(r'\\n', ' ').replace('\\\\n', ' ').replace(' ', ' ').replace('.', '').replace(',', '').replace('?', '').replace('!', '').replace('(', '').replace(')', '').replace(';', '').replace(':', '').replace('\"', '').replace(\"'\", '').replace('-', ' ').replace(' ', ' ')\n", " return prefix + ' ' + suffix" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "qycsWH4gXBua" }, "outputs": [], "source": [ "def get_words_from_line(line):\n", " line = clean_text(line)\n", " for word in line.split():\n", " yield word" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "S3JF1_zWXBua" }, "outputs": [], "source": [ "def get_word_lines_from_file(file_name):\n", " with lzma.open(file_name, mode='rt', encoding='utf-8') as fid:\n", " for line in fid:\n", " yield get_words_from_line(line)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "-20wlI9hXBub" }, "outputs": [], "source": [ "def look_ahead_iterator(gen):\n", " prev = None\n", " for item in gen:\n", " if prev is not None:\n", " yield (prev, item)\n", " prev = item" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "jL5ZrQGMXBub" }, "outputs": [], "source": [ "def prediction(word: str) -> str:\n", " ixs = torch.tensor(vocab.forward([word])).to(device)\n", " out = model(ixs)\n", " top = torch.topk(out[0], 5)\n", " top_indices = top.indices.tolist()\n", " top_probs = top.values.tolist()\n", " top_words = vocab.lookup_tokens(top_indices)\n", " zipped = list(zip(top_words, top_probs))\n", " for index, element in enumerate(zipped):\n", " unk = None\n", " if '' in element:\n", " unk = zipped.pop(index)\n", " zipped.append(('', unk[1]))\n", " break\n", " if unk is None:\n", " zipped[-1] = ('', zipped[-1][1])\n", " return ' '.join([f'{x[0]}:{x[1]}' for x in zipped])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "KByjDByYXBuc" }, "outputs": [], "source": [ "def save_outs(folder_name):\n", " print(f'Creating outputs in {folder_name}')\n", " with lzma.open(f'/content/drive/MyDrive/Colab Notebooks/{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n", " with open(f'/content/drive/MyDrive/Colab Notebooks/{folder_name}/out.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n", " for line in fid:\n", " separated = line.split('\\t')\n", " prefix = separated[6].replace(r'\\n', ' ').split()[-1]\n", " output_line = prediction(prefix)\n", " f.write(output_line + '\\n')" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "id": "dHW2X57NXBud" }, "source": [ "### Classes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Bigrams(IterableDataset):\n", " def __init__(self, text_file, vocabulary_size):\n", " self.vocab = build_vocab_from_iterator(\n", " get_word_lines_from_file(text_file),\n", " max_tokens=vocabulary_size,\n", " specials=[''])\n", " self.vocab.set_default_index(self.vocab[''])\n", " self.vocabulary_size = vocabulary_size\n", " self.text_file = text_file\n", "\n", " def __iter__(self):\n", " return look_ahead_iterator(\n", " (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XQD2jLnOXBue" }, "outputs": [], "source": [ "class SimpleBigramNeuralLanguageModel(nn.Module):\n", " def __init__(self, vocabulary_size, embedding_size):\n", " super(SimpleBigramNeuralLanguageModel, self).__init__()\n", " self.model = nn.Sequential(\n", " nn.Embedding(vocabulary_size, embedding_size),\n", " nn.Linear(embedding_size, vocabulary_size),\n", " nn.Softmax()\n", " )\n", "\n", " def forward(self, x):\n", " return self.model(x)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "id": "Mvodzlq6XBuf" }, "source": [ "## Training" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "id": "zUDc1k5cXBuf" }, "source": [ "### Params" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "ndnatbe3XBug" }, "outputs": [], "source": [ "vocab_size = 10000\n", "embed_size = 100\n", "batch_size = 2000\n", "device = 'cuda'\n", "path_to_train = '/content/drive/MyDrive/Colab Notebooks/train/in.tsv.xz'\n", "path_to_model = 'modelneural_bigram.bin'" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "id": "7wF-1JG-XBug" }, "source": [ "### Colab" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Sf4dvmOPXBuh", "outputId": "3ac75e94-6acd-4906-e9c0-5a5bbe099566" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mounted at /content/drive\n", "/content/drive/MyDrive\n" ] } ], "source": [ "drive.mount('/content/drive')\n", "%cd /content/drive/MyDrive/" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "id": "aeSaf6vvXBuh" }, "source": [ "### Run" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "dzWDCLo0XBuh" }, "outputs": [], "source": [ "vocab = build_vocab_from_iterator(\n", " get_word_lines_from_file(path_to_train),\n", " max_tokens=vocab_size,\n", " specials=['']\n", ")\n", "\n", "vocab.set_default_index(vocab[''])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "FRo29Q3bXBui" }, "outputs": [], "source": [ "train_dataset = Bigrams(path_to_train, vocab_size)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mYxBeXjwXBui", "outputId": "ebd5218f-6a5b-49ec-a2da-e478d63fe50d" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", " input = module(input)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0 tensor(9.4517, device='cuda:0', grad_fn=)\n", "100 tensor(7.9341, device='cuda:0', grad_fn=)\n", "200 tensor(7.1452, device='cuda:0', grad_fn=)\n", "300 tensor(6.7956, device='cuda:0', grad_fn=)\n", "400 tensor(6.4127, device='cuda:0', grad_fn=)\n", "500 tensor(6.3407, device='cuda:0', grad_fn=)\n", "600 tensor(6.2125, device='cuda:0', grad_fn=)\n", "700 tensor(5.7817, device='cuda:0', grad_fn=)\n", "800 tensor(5.7309, device='cuda:0', grad_fn=)\n", "900 tensor(5.7419, device='cuda:0', grad_fn=)\n", "1000 tensor(5.7372, device='cuda:0', grad_fn=)\n", "1100 tensor(5.2804, device='cuda:0', grad_fn=)\n", "1200 tensor(5.4610, device='cuda:0', grad_fn=)\n", "1300 tensor(5.6610, device='cuda:0', grad_fn=)\n", "1400 tensor(5.3070, device='cuda:0', grad_fn=)\n", "1500 tensor(4.9666, device='cuda:0', grad_fn=)\n", "1600 tensor(5.2102, device='cuda:0', grad_fn=)\n", "1700 tensor(5.4919, device='cuda:0', grad_fn=)\n", "1800 tensor(5.1968, device='cuda:0', grad_fn=)\n", "1900 tensor(5.3336, device='cuda:0', grad_fn=)\n", "2000 tensor(5.2387, device='cuda:0', grad_fn=)\n", "2100 tensor(5.2247, device='cuda:0', grad_fn=)\n", "2200 tensor(5.2544, device='cuda:0', grad_fn=)\n", "2300 tensor(5.3343, device='cuda:0', grad_fn=)\n", "2400 tensor(5.3077, device='cuda:0', grad_fn=)\n", "2500 tensor(5.1209, device='cuda:0', grad_fn=)\n", "2600 tensor(5.3806, device='cuda:0', grad_fn=)\n", "2700 tensor(5.2865, device='cuda:0', grad_fn=)\n", "2800 tensor(5.2625, device='cuda:0', grad_fn=)\n", "2900 tensor(5.2476, device='cuda:0', grad_fn=)\n", "3000 tensor(5.2663, device='cuda:0', grad_fn=)\n", "3100 tensor(5.0200, device='cuda:0', grad_fn=)\n", "3200 tensor(5.2324, device='cuda:0', grad_fn=)\n", "3300 tensor(5.1963, device='cuda:0', grad_fn=)\n", "3400 tensor(5.1108, device='cuda:0', grad_fn=)\n", "3500 tensor(5.1499, device='cuda:0', grad_fn=)\n", "3600 tensor(5.3241, device='cuda:0', grad_fn=)\n", "3700 tensor(5.1977, device='cuda:0', grad_fn=)\n", "3800 tensor(5.1466, device='cuda:0', grad_fn=)\n", "3900 tensor(5.2557, device='cuda:0', grad_fn=)\n", "4000 tensor(5.0468, device='cuda:0', grad_fn=)\n", "4100 tensor(5.1882, device='cuda:0', grad_fn=)\n", "4200 tensor(5.0748, device='cuda:0', grad_fn=)\n", "4300 tensor(4.9577, device='cuda:0', grad_fn=)\n", "4400 tensor(4.8100, device='cuda:0', grad_fn=)\n", "4500 tensor(5.0355, device='cuda:0', grad_fn=)\n", "4600 tensor(5.1247, device='cuda:0', grad_fn=)\n", "4700 tensor(5.0516, device='cuda:0', grad_fn=)\n", "4800 tensor(4.9036, device='cuda:0', grad_fn=)\n", "4900 tensor(5.0096, device='cuda:0', grad_fn=)\n", "5000 tensor(5.2085, device='cuda:0', grad_fn=)\n", "5100 tensor(5.0944, device='cuda:0', grad_fn=)\n", "5200 tensor(5.1592, device='cuda:0', grad_fn=)\n", "5300 tensor(5.2019, device='cuda:0', grad_fn=)\n", "5400 tensor(5.2048, device='cuda:0', grad_fn=)\n", "5500 tensor(5.0499, device='cuda:0', grad_fn=)\n", "5600 tensor(5.0369, device='cuda:0', grad_fn=)\n", "5700 tensor(5.2581, device='cuda:0', grad_fn=)\n", "5800 tensor(5.0312, device='cuda:0', grad_fn=)\n", "5900 tensor(5.0513, device='cuda:0', grad_fn=)\n", "6000 tensor(5.2384, device='cuda:0', grad_fn=)\n", "6100 tensor(5.0257, device='cuda:0', grad_fn=)\n", "6200 tensor(5.1156, device='cuda:0', grad_fn=)\n", "6300 tensor(4.9953, device='cuda:0', grad_fn=)\n", "6400 tensor(5.2028, device='cuda:0', grad_fn=)\n", "6500 tensor(4.8426, device='cuda:0', grad_fn=)\n", "6600 tensor(5.0661, device='cuda:0', grad_fn=)\n", "6700 tensor(5.0976, device='cuda:0', grad_fn=)\n", "6800 tensor(4.9180, device='cuda:0', grad_fn=)\n", "6900 tensor(4.9928, device='cuda:0', grad_fn=)\n", "7000 tensor(5.1889, device='cuda:0', grad_fn=)\n", "7100 tensor(4.9612, device='cuda:0', grad_fn=)\n", "7200 tensor(5.1408, device='cuda:0', grad_fn=)\n", "7300 tensor(5.0562, device='cuda:0', grad_fn=)\n", "7400 tensor(4.8779, device='cuda:0', grad_fn=)\n", "7500 tensor(5.0490, device='cuda:0', grad_fn=)\n", "7600 tensor(5.0678, device='cuda:0', grad_fn=)\n", "7700 tensor(4.9938, device='cuda:0', grad_fn=)\n", "7800 tensor(5.0301, device='cuda:0', grad_fn=)\n", "7900 tensor(5.2542, device='cuda:0', grad_fn=)\n", "8000 tensor(4.8772, device='cuda:0', grad_fn=)\n", "8100 tensor(5.0953, device='cuda:0', grad_fn=)\n", "8200 tensor(5.0217, device='cuda:0', grad_fn=)\n", "8300 tensor(5.0107, device='cuda:0', grad_fn=)\n", "8400 tensor(5.0733, device='cuda:0', grad_fn=)\n", "8500 tensor(4.5262, device='cuda:0', grad_fn=)\n", "8600 tensor(5.0271, device='cuda:0', grad_fn=)\n", "8700 tensor(4.6307, device='cuda:0', grad_fn=)\n", "8800 tensor(4.9917, device='cuda:0', grad_fn=)\n", "8900 tensor(5.1940, device='cuda:0', grad_fn=)\n", "9000 tensor(5.0302, device='cuda:0', grad_fn=)\n", "9100 tensor(5.0956, device='cuda:0', grad_fn=)\n", "9200 tensor(5.0438, device='cuda:0', grad_fn=)\n", "9300 tensor(5.0134, device='cuda:0', grad_fn=)\n", "9400 tensor(5.2201, device='cuda:0', grad_fn=)\n", "9500 tensor(4.8876, device='cuda:0', grad_fn=)\n", "9600 tensor(5.1474, device='cuda:0', grad_fn=)\n", "9700 tensor(5.0169, device='cuda:0', grad_fn=)\n", "9800 tensor(5.0743, device='cuda:0', grad_fn=)\n", "9900 tensor(4.9008, device='cuda:0', grad_fn=)\n", "10000 tensor(5.1381, device='cuda:0', grad_fn=)\n", "10100 tensor(5.0524, device='cuda:0', grad_fn=)\n", "10200 tensor(5.0369, device='cuda:0', grad_fn=)\n", "10300 tensor(5.0595, device='cuda:0', grad_fn=)\n", "10400 tensor(5.0138, device='cuda:0', grad_fn=)\n", "10500 tensor(5.0164, device='cuda:0', grad_fn=)\n", "10600 tensor(4.9153, device='cuda:0', grad_fn=)\n", "10700 tensor(4.9971, device='cuda:0', grad_fn=)\n", "10800 tensor(5.0200, device='cuda:0', grad_fn=)\n", "10900 tensor(4.9631, device='cuda:0', grad_fn=)\n", "11000 tensor(4.9385, device='cuda:0', grad_fn=)\n", "11100 tensor(4.9851, device='cuda:0', grad_fn=)\n", "11200 tensor(5.0681, device='cuda:0', grad_fn=)\n", "11300 tensor(5.1261, device='cuda:0', grad_fn=)\n", "11400 tensor(5.0098, device='cuda:0', grad_fn=)\n", "11500 tensor(5.1261, device='cuda:0', grad_fn=)\n", "11600 tensor(5.1213, device='cuda:0', grad_fn=)\n", "11700 tensor(5.0265, device='cuda:0', grad_fn=)\n", "11800 tensor(4.7047, device='cuda:0', grad_fn=)\n", "11900 tensor(5.1954, device='cuda:0', grad_fn=)\n", "12000 tensor(5.0850, device='cuda:0', grad_fn=)\n", "12100 tensor(4.9762, device='cuda:0', grad_fn=)\n", "12200 tensor(5.0162, device='cuda:0', grad_fn=)\n", "12300 tensor(4.9834, device='cuda:0', grad_fn=)\n", "12400 tensor(4.8953, device='cuda:0', grad_fn=)\n", "12500 tensor(5.0389, device='cuda:0', grad_fn=)\n", "12600 tensor(4.9266, device='cuda:0', grad_fn=)\n", "12700 tensor(5.0132, device='cuda:0', grad_fn=)\n", "12800 tensor(5.1777, device='cuda:0', grad_fn=)\n", "12900 tensor(4.8290, device='cuda:0', grad_fn=)\n", "13000 tensor(5.0639, device='cuda:0', grad_fn=)\n", "13100 tensor(5.0565, device='cuda:0', grad_fn=)\n", "13200 tensor(5.0222, device='cuda:0', grad_fn=)\n", "13300 tensor(5.2150, device='cuda:0', grad_fn=)\n", "13400 tensor(4.9393, device='cuda:0', grad_fn=)\n", "13500 tensor(5.0270, device='cuda:0', grad_fn=)\n", "13600 tensor(4.9520, device='cuda:0', grad_fn=)\n", "13700 tensor(4.9845, device='cuda:0', grad_fn=)\n", "13800 tensor(4.8543, device='cuda:0', grad_fn=)\n", "13900 tensor(4.8892, device='cuda:0', grad_fn=)\n", "14000 tensor(4.9802, device='cuda:0', grad_fn=)\n", "14100 tensor(4.9833, device='cuda:0', grad_fn=)\n", "14200 tensor(4.9348, device='cuda:0', grad_fn=)\n", "14300 tensor(4.9561, device='cuda:0', grad_fn=)\n", "14400 tensor(5.0198, device='cuda:0', grad_fn=)\n", "14500 tensor(4.9878, device='cuda:0', grad_fn=)\n", "14600 tensor(4.7517, device='cuda:0', grad_fn=)\n", "14700 tensor(4.9452, device='cuda:0', grad_fn=)\n", "14800 tensor(4.8229, device='cuda:0', grad_fn=)\n", "14900 tensor(5.1425, device='cuda:0', grad_fn=)\n", "15000 tensor(4.9122, device='cuda:0', grad_fn=)\n", "15100 tensor(4.8217, device='cuda:0', grad_fn=)\n", "15200 tensor(4.8604, device='cuda:0', grad_fn=)\n", "15300 tensor(5.1151, device='cuda:0', grad_fn=)\n", "15400 tensor(4.9545, device='cuda:0', grad_fn=)\n", "15500 tensor(5.0922, device='cuda:0', grad_fn=)\n", "15600 tensor(4.7891, device='cuda:0', grad_fn=)\n", "15700 tensor(4.6318, device='cuda:0', grad_fn=)\n", "15800 tensor(4.9540, device='cuda:0', grad_fn=)\n", "15900 tensor(4.7681, device='cuda:0', grad_fn=)\n", "16000 tensor(4.9602, device='cuda:0', grad_fn=)\n", "16100 tensor(4.9705, device='cuda:0', grad_fn=)\n", "16200 tensor(4.8296, device='cuda:0', grad_fn=)\n", "16300 tensor(5.0188, device='cuda:0', grad_fn=)\n", "16400 tensor(5.1062, device='cuda:0', grad_fn=)\n", "16500 tensor(5.2549, device='cuda:0', grad_fn=)\n", "16600 tensor(5.1164, device='cuda:0', grad_fn=)\n", "16700 tensor(4.9399, device='cuda:0', grad_fn=)\n", "16800 tensor(5.1161, device='cuda:0', grad_fn=)\n", "16900 tensor(4.9115, device='cuda:0', grad_fn=)\n", "17000 tensor(4.7572, device='cuda:0', grad_fn=)\n", "17100 tensor(4.9667, device='cuda:0', grad_fn=)\n", "17200 tensor(4.7463, device='cuda:0', grad_fn=)\n", "17300 tensor(4.9038, device='cuda:0', grad_fn=)\n", "17400 tensor(4.9859, device='cuda:0', grad_fn=)\n", "17500 tensor(5.0652, device='cuda:0', grad_fn=)\n", "17600 tensor(4.6641, device='cuda:0', grad_fn=)\n", "17700 tensor(4.9265, device='cuda:0', grad_fn=)\n", "17800 tensor(5.0095, device='cuda:0', grad_fn=)\n", "17900 tensor(5.1090, device='cuda:0', grad_fn=)\n", "18000 tensor(4.9015, device='cuda:0', grad_fn=)\n", "18100 tensor(4.9997, device='cuda:0', grad_fn=)\n", "18200 tensor(4.8359, device='cuda:0', grad_fn=)\n", "18300 tensor(4.7353, device='cuda:0', grad_fn=)\n", "18400 tensor(4.9657, device='cuda:0', grad_fn=)\n", "18500 tensor(4.9856, device='cuda:0', grad_fn=)\n", "18600 tensor(5.0571, device='cuda:0', grad_fn=)\n", "18700 tensor(4.8566, device='cuda:0', grad_fn=)\n", "18800 tensor(4.9819, device='cuda:0', grad_fn=)\n", "18900 tensor(4.9809, device='cuda:0', grad_fn=)\n", "19000 tensor(5.0202, device='cuda:0', grad_fn=)\n", "19100 tensor(5.1329, device='cuda:0', grad_fn=)\n", "19200 tensor(5.0460, device='cuda:0', grad_fn=)\n", "19300 tensor(4.9174, device='cuda:0', grad_fn=)\n", "19400 tensor(5.1266, device='cuda:0', grad_fn=)\n", "19500 tensor(4.8903, device='cuda:0', grad_fn=)\n", "19600 tensor(5.0548, device='cuda:0', grad_fn=)\n", "19700 tensor(4.9530, device='cuda:0', grad_fn=)\n", "19800 tensor(4.9296, device='cuda:0', grad_fn=)\n", "19900 tensor(4.9925, device='cuda:0', grad_fn=)\n", "20000 tensor(4.9181, device='cuda:0', grad_fn=)\n", "20100 tensor(4.9487, device='cuda:0', grad_fn=)\n", "20200 tensor(5.0580, device='cuda:0', grad_fn=)\n", "20300 tensor(5.1110, device='cuda:0', grad_fn=)\n", "20400 tensor(4.8053, device='cuda:0', grad_fn=)\n", "20500 tensor(4.7658, device='cuda:0', grad_fn=)\n", "20600 tensor(4.7387, device='cuda:0', grad_fn=)\n", "20700 tensor(4.9779, device='cuda:0', grad_fn=)\n", "20800 tensor(4.8901, device='cuda:0', grad_fn=)\n", "20900 tensor(4.9092, device='cuda:0', grad_fn=)\n", "21000 tensor(5.2856, device='cuda:0', grad_fn=)\n", "21100 tensor(4.9803, device='cuda:0', grad_fn=)\n", "21200 tensor(4.6889, device='cuda:0', grad_fn=)\n", "21300 tensor(4.8434, device='cuda:0', grad_fn=)\n", "21400 tensor(4.7451, device='cuda:0', grad_fn=)\n", "21500 tensor(4.9406, device='cuda:0', grad_fn=)\n", "21600 tensor(4.8431, device='cuda:0', grad_fn=)\n", "21700 tensor(4.9932, device='cuda:0', grad_fn=)\n", "21800 tensor(4.6696, device='cuda:0', grad_fn=)\n", "21900 tensor(4.8091, device='cuda:0', grad_fn=)\n", "22000 tensor(4.7533, device='cuda:0', grad_fn=)\n", "22100 tensor(4.6842, device='cuda:0', grad_fn=)\n", "22200 tensor(4.8844, device='cuda:0', grad_fn=)\n", "22300 tensor(5.1038, device='cuda:0', grad_fn=)\n", "22400 tensor(4.9929, device='cuda:0', grad_fn=)\n", "22500 tensor(5.0109, device='cuda:0', grad_fn=)\n", "22600 tensor(4.8278, device='cuda:0', grad_fn=)\n", "22700 tensor(4.8597, device='cuda:0', grad_fn=)\n", "22800 tensor(5.0256, device='cuda:0', grad_fn=)\n", "22900 tensor(4.4663, device='cuda:0', grad_fn=)\n", "23000 tensor(4.6069, device='cuda:0', grad_fn=)\n", "23100 tensor(5.0816, device='cuda:0', grad_fn=)\n", "23200 tensor(4.9038, device='cuda:0', grad_fn=)\n", "23300 tensor(4.9284, device='cuda:0', grad_fn=)\n", "23400 tensor(5.0439, device='cuda:0', grad_fn=)\n", "23500 tensor(4.9640, device='cuda:0', grad_fn=)\n", "23600 tensor(5.0096, device='cuda:0', grad_fn=)\n", "23700 tensor(4.9700, device='cuda:0', grad_fn=)\n", "23800 tensor(4.9461, device='cuda:0', grad_fn=)\n", "23900 tensor(4.8171, device='cuda:0', grad_fn=)\n", "24000 tensor(4.9529, device='cuda:0', grad_fn=)\n", "24100 tensor(4.8525, device='cuda:0', grad_fn=)\n", "24200 tensor(5.0488, device='cuda:0', grad_fn=)\n", "24300 tensor(4.9206, device='cuda:0', grad_fn=)\n", "24400 tensor(5.0900, device='cuda:0', grad_fn=)\n", "24500 tensor(4.9484, device='cuda:0', grad_fn=)\n", "24600 tensor(4.8962, device='cuda:0', grad_fn=)\n", "24700 tensor(4.8884, device='cuda:0', grad_fn=)\n", "24800 tensor(5.1541, device='cuda:0', grad_fn=)\n", "24900 tensor(4.9803, device='cuda:0', grad_fn=)\n", "25000 tensor(4.4473, device='cuda:0', grad_fn=)\n", "25100 tensor(4.7330, device='cuda:0', grad_fn=)\n", "25200 tensor(5.0709, device='cuda:0', grad_fn=)\n", "25300 tensor(4.7139, device='cuda:0', grad_fn=)\n", "25400 tensor(4.8961, device='cuda:0', grad_fn=)\n", "25500 tensor(4.9459, device='cuda:0', grad_fn=)\n", "25600 tensor(4.8840, device='cuda:0', grad_fn=)\n", "25700 tensor(4.7792, device='cuda:0', grad_fn=)\n", "25800 tensor(4.9212, device='cuda:0', grad_fn=)\n", "25900 tensor(4.7168, device='cuda:0', grad_fn=)\n", "26000 tensor(4.7903, device='cuda:0', grad_fn=)\n", "26100 tensor(4.9544, device='cuda:0', grad_fn=)\n", "26200 tensor(4.8421, device='cuda:0', grad_fn=)\n", "26300 tensor(4.8085, device='cuda:0', grad_fn=)\n", "26400 tensor(4.7129, device='cuda:0', grad_fn=)\n", "26500 tensor(5.0808, device='cuda:0', grad_fn=)\n", "26600 tensor(4.8222, device='cuda:0', grad_fn=)\n", "26700 tensor(4.7982, device='cuda:0', grad_fn=)\n", "26800 tensor(4.8482, device='cuda:0', grad_fn=)\n", "26900 tensor(5.0815, device='cuda:0', grad_fn=)\n", "27000 tensor(4.9754, device='cuda:0', grad_fn=)\n", "27100 tensor(5.0156, device='cuda:0', grad_fn=)\n", "27200 tensor(4.7985, device='cuda:0', grad_fn=)\n", "27300 tensor(4.6372, device='cuda:0', grad_fn=)\n", "27400 tensor(4.5098, device='cuda:0', grad_fn=)\n", "27500 tensor(5.0427, device='cuda:0', grad_fn=)\n", "27600 tensor(4.9139, device='cuda:0', grad_fn=)\n", "27700 tensor(4.8924, device='cuda:0', grad_fn=)\n", "27800 tensor(4.9972, device='cuda:0', grad_fn=)\n", "27900 tensor(5.0452, device='cuda:0', grad_fn=)\n", "28000 tensor(4.5323, device='cuda:0', grad_fn=)\n", "28100 tensor(4.8945, device='cuda:0', grad_fn=)\n", "28200 tensor(4.8096, device='cuda:0', grad_fn=)\n", "28300 tensor(5.1238, device='cuda:0', grad_fn=)\n", "28400 tensor(4.9879, device='cuda:0', grad_fn=)\n", "28500 tensor(4.9505, device='cuda:0', grad_fn=)\n", "28600 tensor(4.7750, device='cuda:0', grad_fn=)\n", "28700 tensor(5.0738, device='cuda:0', grad_fn=)\n", "28800 tensor(4.9318, device='cuda:0', grad_fn=)\n", "28900 tensor(5.0403, device='cuda:0', grad_fn=)\n", "29000 tensor(4.9072, device='cuda:0', grad_fn=)\n", "29100 tensor(4.9822, device='cuda:0', grad_fn=)\n", "29200 tensor(4.8701, device='cuda:0', grad_fn=)\n", "29300 tensor(4.8883, device='cuda:0', grad_fn=)\n", "29400 tensor(4.8906, device='cuda:0', grad_fn=)\n", "29500 tensor(5.0658, device='cuda:0', grad_fn=)\n", "29600 tensor(4.7604, device='cuda:0', grad_fn=)\n", "29700 tensor(5.0792, device='cuda:0', grad_fn=)\n", "29800 tensor(4.9074, device='cuda:0', grad_fn=)\n", "29900 tensor(4.8845, device='cuda:0', grad_fn=)\n", "30000 tensor(5.1969, device='cuda:0', grad_fn=)\n", "30100 tensor(4.9648, device='cuda:0', grad_fn=)\n", "30200 tensor(4.9086, device='cuda:0', grad_fn=)\n", "30300 tensor(4.9708, device='cuda:0', grad_fn=)\n", "30400 tensor(4.9155, device='cuda:0', grad_fn=)\n", "30500 tensor(4.9404, device='cuda:0', grad_fn=)\n", "30600 tensor(5.0224, device='cuda:0', grad_fn=)\n", "30700 tensor(5.0298, device='cuda:0', grad_fn=)\n", "30800 tensor(4.9557, device='cuda:0', grad_fn=)\n", "30900 tensor(4.9653, device='cuda:0', grad_fn=)\n", "31000 tensor(4.8938, device='cuda:0', grad_fn=)\n", "31100 tensor(4.6689, device='cuda:0', grad_fn=)\n", "31200 tensor(4.9757, device='cuda:0', grad_fn=)\n", "31300 tensor(4.8805, device='cuda:0', grad_fn=)\n", "31400 tensor(4.9969, device='cuda:0', grad_fn=)\n", "31500 tensor(4.8262, device='cuda:0', grad_fn=)\n", "31600 tensor(4.5519, device='cuda:0', grad_fn=)\n", "31700 tensor(4.9185, device='cuda:0', grad_fn=)\n", "31800 tensor(4.9190, device='cuda:0', grad_fn=)\n", "31900 tensor(4.8702, device='cuda:0', grad_fn=)\n", "32000 tensor(4.9346, device='cuda:0', grad_fn=)\n", "32100 tensor(4.8963, device='cuda:0', grad_fn=)\n", "32200 tensor(4.9017, device='cuda:0', grad_fn=)\n", "32300 tensor(4.9595, device='cuda:0', grad_fn=)\n", "32400 tensor(4.8125, device='cuda:0', grad_fn=)\n", "32500 tensor(4.9593, device='cuda:0', grad_fn=)\n", "32600 tensor(5.0663, device='cuda:0', grad_fn=)\n", "32700 tensor(4.9644, device='cuda:0', grad_fn=)\n", "32800 tensor(4.8500, device='cuda:0', grad_fn=)\n", "32900 tensor(5.0070, device='cuda:0', grad_fn=)\n", "33000 tensor(4.8131, device='cuda:0', grad_fn=)\n", "33100 tensor(5.0183, device='cuda:0', grad_fn=)\n", "33200 tensor(4.8692, device='cuda:0', grad_fn=)\n", "33300 tensor(4.9145, device='cuda:0', grad_fn=)\n", "33400 tensor(5.0221, device='cuda:0', grad_fn=)\n", "33500 tensor(4.9636, device='cuda:0', grad_fn=)\n", "33600 tensor(4.8758, device='cuda:0', grad_fn=)\n", "33700 tensor(4.8713, device='cuda:0', grad_fn=)\n", "33800 tensor(4.7325, device='cuda:0', grad_fn=)\n", "33900 tensor(4.9829, device='cuda:0', grad_fn=)\n", "34000 tensor(4.7823, device='cuda:0', grad_fn=)\n", "34100 tensor(4.9773, device='cuda:0', grad_fn=)\n", "34200 tensor(4.9638, device='cuda:0', grad_fn=)\n", "34300 tensor(5.0311, device='cuda:0', grad_fn=)\n", "34400 tensor(4.9491, device='cuda:0', grad_fn=)\n", "34500 tensor(4.9527, device='cuda:0', grad_fn=)\n", "34600 tensor(4.7559, device='cuda:0', grad_fn=)\n", "34700 tensor(4.9602, device='cuda:0', grad_fn=)\n", "34800 tensor(5.0363, device='cuda:0', grad_fn=)\n", "34900 tensor(4.9509, device='cuda:0', grad_fn=)\n", "35000 tensor(4.8740, device='cuda:0', grad_fn=)\n", "35100 tensor(4.8790, device='cuda:0', grad_fn=)\n", "35200 tensor(4.7886, device='cuda:0', grad_fn=)\n", "35300 tensor(4.9939, device='cuda:0', grad_fn=)\n", "35400 tensor(4.8046, device='cuda:0', grad_fn=)\n", "35500 tensor(5.0125, device='cuda:0', grad_fn=)\n", "35600 tensor(4.8254, device='cuda:0', grad_fn=)\n", "35700 tensor(4.5858, device='cuda:0', grad_fn=)\n", "35800 tensor(5.0067, device='cuda:0', grad_fn=)\n", "35900 tensor(5.0505, device='cuda:0', grad_fn=)\n", "36000 tensor(4.9909, device='cuda:0', grad_fn=)\n", "36100 tensor(4.8610, device='cuda:0', grad_fn=)\n", "36200 tensor(4.9135, device='cuda:0', grad_fn=)\n", "36300 tensor(5.0409, device='cuda:0', grad_fn=)\n", "36400 tensor(4.8932, device='cuda:0', grad_fn=)\n", "36500 tensor(4.8384, device='cuda:0', grad_fn=)\n", "36600 tensor(4.8262, device='cuda:0', grad_fn=)\n", "36700 tensor(4.8363, device='cuda:0', grad_fn=)\n", "36800 tensor(4.9260, device='cuda:0', grad_fn=)\n", "36900 tensor(4.7176, device='cuda:0', grad_fn=)\n", "37000 tensor(4.8836, device='cuda:0', grad_fn=)\n", "37100 tensor(4.7659, device='cuda:0', grad_fn=)\n", "37200 tensor(5.0418, device='cuda:0', grad_fn=)\n", "37300 tensor(4.7165, device='cuda:0', grad_fn=)\n", "37400 tensor(4.7707, device='cuda:0', grad_fn=)\n", "37500 tensor(4.9404, device='cuda:0', grad_fn=)\n", "37600 tensor(4.7666, device='cuda:0', grad_fn=)\n", "37700 tensor(5.0086, device='cuda:0', grad_fn=)\n", "37800 tensor(4.8929, device='cuda:0', grad_fn=)\n", "37900 tensor(5.0537, device='cuda:0', grad_fn=)\n", "38000 tensor(4.8494, device='cuda:0', grad_fn=)\n", "38100 tensor(5.1193, device='cuda:0', grad_fn=)\n", "38200 tensor(4.9035, device='cuda:0', grad_fn=)\n", "38300 tensor(4.7574, device='cuda:0', grad_fn=)\n", "38400 tensor(4.9181, device='cuda:0', grad_fn=)\n", "38500 tensor(5.0186, device='cuda:0', grad_fn=)\n", "38600 tensor(5.0224, device='cuda:0', grad_fn=)\n", "38700 tensor(4.6032, device='cuda:0', grad_fn=)\n", "38800 tensor(5.1368, device='cuda:0', grad_fn=)\n", "38900 tensor(4.9394, device='cuda:0', grad_fn=)\n", "39000 tensor(4.7891, device='cuda:0', grad_fn=)\n", "39100 tensor(4.9718, device='cuda:0', grad_fn=)\n", "39200 tensor(4.9599, device='cuda:0', grad_fn=)\n", "39300 tensor(4.8518, device='cuda:0', grad_fn=)\n", "39400 tensor(4.7832, device='cuda:0', grad_fn=)\n", "39500 tensor(4.9827, device='cuda:0', grad_fn=)\n", "39600 tensor(5.0733, device='cuda:0', grad_fn=)\n", "39700 tensor(4.8859, device='cuda:0', grad_fn=)\n", "39800 tensor(4.9722, device='cuda:0', grad_fn=)\n", "39900 tensor(5.0568, device='cuda:0', grad_fn=)\n", "40000 tensor(4.8251, device='cuda:0', grad_fn=)\n", "40100 tensor(4.8720, device='cuda:0', grad_fn=)\n", "40200 tensor(5.3066, device='cuda:0', grad_fn=)\n", "40300 tensor(4.9435, device='cuda:0', grad_fn=)\n", "40400 tensor(4.9634, device='cuda:0', grad_fn=)\n", "40500 tensor(4.8406, device='cuda:0', grad_fn=)\n", "40600 tensor(4.8050, device='cuda:0', grad_fn=)\n", "40700 tensor(4.6578, device='cuda:0', grad_fn=)\n", "40800 tensor(4.8490, device='cuda:0', grad_fn=)\n", "40900 tensor(5.1542, device='cuda:0', grad_fn=)\n", "41000 tensor(4.8509, device='cuda:0', grad_fn=)\n", "41100 tensor(4.8082, device='cuda:0', grad_fn=)\n", "41200 tensor(4.8444, device='cuda:0', grad_fn=)\n", "41300 tensor(5.1602, device='cuda:0', grad_fn=)\n", "41400 tensor(4.7235, device='cuda:0', grad_fn=)\n", "41500 tensor(5.0334, device='cuda:0', grad_fn=)\n", "41600 tensor(5.0500, device='cuda:0', grad_fn=)\n", "41700 tensor(5.0378, device='cuda:0', grad_fn=)\n", "41800 tensor(4.7989, device='cuda:0', grad_fn=)\n", "41900 tensor(4.9342, device='cuda:0', grad_fn=)\n", "42000 tensor(4.9981, device='cuda:0', grad_fn=)\n", "42100 tensor(4.6723, device='cuda:0', grad_fn=)\n", "42200 tensor(4.9382, device='cuda:0', grad_fn=)\n", "42300 tensor(4.9237, device='cuda:0', grad_fn=)\n", "42400 tensor(4.9302, device='cuda:0', grad_fn=)\n", "42500 tensor(4.8494, device='cuda:0', grad_fn=)\n", "42600 tensor(4.9942, device='cuda:0', grad_fn=)\n", "42700 tensor(4.9581, device='cuda:0', grad_fn=)\n", "42800 tensor(4.8044, device='cuda:0', grad_fn=)\n", "42900 tensor(5.0890, device='cuda:0', grad_fn=)\n", "43000 tensor(4.9422, device='cuda:0', grad_fn=)\n", "43100 tensor(5.0014, device='cuda:0', grad_fn=)\n", "43200 tensor(4.9001, device='cuda:0', grad_fn=)\n", "43300 tensor(4.9133, device='cuda:0', grad_fn=)\n", "43400 tensor(4.8836, device='cuda:0', grad_fn=)\n", "43500 tensor(4.8232, device='cuda:0', grad_fn=)\n", "43600 tensor(4.8052, device='cuda:0', grad_fn=)\n", "43700 tensor(5.0304, device='cuda:0', grad_fn=)\n", "43800 tensor(5.0834, device='cuda:0', grad_fn=)\n", "43900 tensor(4.8242, device='cuda:0', grad_fn=)\n", "44000 tensor(4.8126, device='cuda:0', grad_fn=)\n", "44100 tensor(4.7836, device='cuda:0', grad_fn=)\n", "44200 tensor(5.0763, device='cuda:0', grad_fn=)\n", "44300 tensor(5.0682, device='cuda:0', grad_fn=)\n", "44400 tensor(4.8869, device='cuda:0', grad_fn=)\n", "44500 tensor(4.8527, device='cuda:0', grad_fn=)\n", "44600 tensor(4.8439, device='cuda:0', grad_fn=)\n", "44700 tensor(4.9127, device='cuda:0', grad_fn=)\n", "44800 tensor(4.9628, device='cuda:0', grad_fn=)\n", "44900 tensor(5.0566, device='cuda:0', grad_fn=)\n", "45000 tensor(5.0596, device='cuda:0', grad_fn=)\n", "45100 tensor(5.1187, device='cuda:0', grad_fn=)\n", "45200 tensor(5.0824, device='cuda:0', grad_fn=)\n", "45300 tensor(4.8433, device='cuda:0', grad_fn=)\n", "45400 tensor(4.7299, device='cuda:0', grad_fn=)\n", "45500 tensor(5.1722, device='cuda:0', grad_fn=)\n", "45600 tensor(4.7867, device='cuda:0', grad_fn=)\n", "45700 tensor(4.9631, device='cuda:0', grad_fn=)\n", "45800 tensor(4.6216, device='cuda:0', grad_fn=)\n", "45900 tensor(4.9601, device='cuda:0', grad_fn=)\n", "46000 tensor(4.9055, device='cuda:0', grad_fn=)\n", "46100 tensor(5.0517, device='cuda:0', grad_fn=)\n", "46200 tensor(5.0099, device='cuda:0', grad_fn=)\n", "46300 tensor(4.8178, device='cuda:0', grad_fn=)\n", "46400 tensor(4.9317, device='cuda:0', grad_fn=)\n", "46500 tensor(4.8770, device='cuda:0', grad_fn=)\n", "46600 tensor(4.9668, device='cuda:0', grad_fn=)\n", "46700 tensor(5.1287, device='cuda:0', grad_fn=)\n", "46800 tensor(4.9050, device='cuda:0', grad_fn=)\n", "46900 tensor(4.9622, device='cuda:0', grad_fn=)\n", "47000 tensor(4.6818, device='cuda:0', grad_fn=)\n", "47100 tensor(4.8780, device='cuda:0', grad_fn=)\n", "47200 tensor(4.9493, device='cuda:0', grad_fn=)\n", "47300 tensor(4.7958, device='cuda:0', grad_fn=)\n", "47400 tensor(4.5415, device='cuda:0', grad_fn=)\n", "47500 tensor(5.0651, device='cuda:0', grad_fn=)\n", "47600 tensor(4.9692, device='cuda:0', grad_fn=)\n", "47700 tensor(4.8536, device='cuda:0', grad_fn=)\n", "47800 tensor(4.7306, device='cuda:0', grad_fn=)\n", "47900 tensor(5.1795, device='cuda:0', grad_fn=)\n", "48000 tensor(4.9196, device='cuda:0', grad_fn=)\n", "48100 tensor(5.1446, device='cuda:0', grad_fn=)\n", "48200 tensor(4.9810, device='cuda:0', grad_fn=)\n", "48300 tensor(4.9688, device='cuda:0', grad_fn=)\n", "48400 tensor(5.0246, device='cuda:0', grad_fn=)\n", "48500 tensor(4.7523, device='cuda:0', grad_fn=)\n", "48600 tensor(4.7716, device='cuda:0', grad_fn=)\n", "48700 tensor(4.8938, device='cuda:0', grad_fn=)\n", "48800 tensor(4.9324, device='cuda:0', grad_fn=)\n", "48900 tensor(4.9811, device='cuda:0', grad_fn=)\n", "49000 tensor(4.8818, device='cuda:0', grad_fn=)\n", "49100 tensor(4.9871, device='cuda:0', grad_fn=)\n", "49200 tensor(4.8498, device='cuda:0', grad_fn=)\n", "49300 tensor(4.8027, device='cuda:0', grad_fn=)\n", "49400 tensor(5.0199, device='cuda:0', grad_fn=)\n", "49500 tensor(4.9790, device='cuda:0', grad_fn=)\n", "49600 tensor(5.0995, device='cuda:0', grad_fn=)\n", "49700 tensor(4.8989, device='cuda:0', grad_fn=)\n", "49800 tensor(4.8903, device='cuda:0', grad_fn=)\n", "49900 tensor(4.6744, device='cuda:0', grad_fn=)\n", "50000 tensor(4.9403, device='cuda:0', grad_fn=)\n", "50100 tensor(4.7815, device='cuda:0', grad_fn=)\n", "50200 tensor(4.8617, device='cuda:0', grad_fn=)\n", "50300 tensor(4.4559, device='cuda:0', grad_fn=)\n", "50400 tensor(5.0322, device='cuda:0', grad_fn=)\n", "50500 tensor(4.6867, device='cuda:0', grad_fn=)\n", "50600 tensor(4.9644, device='cuda:0', grad_fn=)\n", "50700 tensor(5.0631, device='cuda:0', grad_fn=)\n", "50800 tensor(4.7992, device='cuda:0', grad_fn=)\n", "50900 tensor(4.9346, device='cuda:0', grad_fn=)\n", "51000 tensor(4.6487, device='cuda:0', grad_fn=)\n", "51100 tensor(4.8758, device='cuda:0', grad_fn=)\n", "51200 tensor(5.0734, device='cuda:0', grad_fn=)\n", "51300 tensor(4.8078, device='cuda:0', grad_fn=)\n", "51400 tensor(4.7628, device='cuda:0', grad_fn=)\n", "51500 tensor(4.8508, device='cuda:0', grad_fn=)\n", "51600 tensor(4.8231, device='cuda:0', grad_fn=)\n", "51700 tensor(5.0122, device='cuda:0', grad_fn=)\n", "51800 tensor(4.8941, device='cuda:0', grad_fn=)\n", "51900 tensor(5.0284, device='cuda:0', grad_fn=)\n", "52000 tensor(4.9158, device='cuda:0', grad_fn=)\n", "52100 tensor(4.8752, device='cuda:0', grad_fn=)\n", "52200 tensor(4.7020, device='cuda:0', grad_fn=)\n", "52300 tensor(4.6001, device='cuda:0', grad_fn=)\n", "52400 tensor(4.7898, device='cuda:0', grad_fn=)\n", "52500 tensor(4.8255, device='cuda:0', grad_fn=)\n", "52600 tensor(4.7331, device='cuda:0', grad_fn=)\n", "52700 tensor(4.8546, device='cuda:0', grad_fn=)\n", "52800 tensor(4.9418, device='cuda:0', grad_fn=)\n", "52900 tensor(4.7536, device='cuda:0', grad_fn=)\n", "53000 tensor(4.9609, device='cuda:0', grad_fn=)\n", "53100 tensor(5.0644, device='cuda:0', grad_fn=)\n", "53200 tensor(4.8919, device='cuda:0', grad_fn=)\n", "53300 tensor(4.7840, device='cuda:0', grad_fn=)\n", "53400 tensor(4.8539, device='cuda:0', grad_fn=)\n", "53500 tensor(4.8023, device='cuda:0', grad_fn=)\n", "53600 tensor(4.9810, device='cuda:0', grad_fn=)\n", "53700 tensor(4.9946, device='cuda:0', grad_fn=)\n", "53800 tensor(4.3504, device='cuda:0', grad_fn=)\n", "53900 tensor(4.8656, device='cuda:0', grad_fn=)\n", "54000 tensor(5.0103, device='cuda:0', grad_fn=)\n", "54100 tensor(4.8503, device='cuda:0', grad_fn=)\n", "54200 tensor(4.9970, device='cuda:0', grad_fn=)\n", "54300 tensor(4.5719, device='cuda:0', grad_fn=)\n", "54400 tensor(4.7891, device='cuda:0', grad_fn=)\n", "54500 tensor(4.8968, device='cuda:0', grad_fn=)\n", "54600 tensor(5.0036, device='cuda:0', grad_fn=)\n", "54700 tensor(4.9487, device='cuda:0', grad_fn=)\n", "54800 tensor(4.8477, device='cuda:0', grad_fn=)\n", "54900 tensor(4.9253, device='cuda:0', grad_fn=)\n", "55000 tensor(4.9079, device='cuda:0', grad_fn=)\n", "55100 tensor(4.9499, device='cuda:0', grad_fn=)\n", "55200 tensor(5.0510, device='cuda:0', grad_fn=)\n", "55300 tensor(4.9320, device='cuda:0', grad_fn=)\n", "55400 tensor(4.5737, device='cuda:0', grad_fn=)\n", "55500 tensor(4.7703, device='cuda:0', grad_fn=)\n", "55600 tensor(5.0166, device='cuda:0', grad_fn=)\n", "55700 tensor(4.9049, device='cuda:0', grad_fn=)\n", "55800 tensor(4.7355, device='cuda:0', grad_fn=)\n", "55900 tensor(4.5776, device='cuda:0', grad_fn=)\n", "56000 tensor(4.9919, device='cuda:0', grad_fn=)\n", "56100 tensor(4.8629, device='cuda:0', grad_fn=)\n", "56200 tensor(5.0123, device='cuda:0', grad_fn=)\n", "56300 tensor(4.3110, device='cuda:0', grad_fn=)\n", "56400 tensor(4.8950, device='cuda:0', grad_fn=)\n", "56500 tensor(4.8415, device='cuda:0', grad_fn=)\n", "56600 tensor(4.7285, device='cuda:0', grad_fn=)\n", "56700 tensor(4.8401, device='cuda:0', grad_fn=)\n", "56800 tensor(4.7972, device='cuda:0', grad_fn=)\n", "56900 tensor(4.7398, device='cuda:0', grad_fn=)\n", "57000 tensor(5.1683, device='cuda:0', grad_fn=)\n", "57100 tensor(4.9399, device='cuda:0', grad_fn=)\n", "57200 tensor(4.9609, device='cuda:0', grad_fn=)\n", "57300 tensor(4.9818, device='cuda:0', grad_fn=)\n", "57400 tensor(4.9719, device='cuda:0', grad_fn=)\n", "57500 tensor(4.8724, device='cuda:0', grad_fn=)\n", "57600 tensor(4.9824, device='cuda:0', grad_fn=)\n", "57700 tensor(5.0357, device='cuda:0', grad_fn=)\n", "57800 tensor(5.0542, device='cuda:0', grad_fn=)\n", "57900 tensor(4.8753, device='cuda:0', grad_fn=)\n", "58000 tensor(4.7773, device='cuda:0', grad_fn=)\n", "58100 tensor(4.7864, device='cuda:0', grad_fn=)\n", "58200 tensor(4.8033, device='cuda:0', grad_fn=)\n", "58300 tensor(4.9997, device='cuda:0', grad_fn=)\n", "58400 tensor(4.9701, device='cuda:0', grad_fn=)\n", "58500 tensor(4.8920, device='cuda:0', grad_fn=)\n", "58600 tensor(4.9408, device='cuda:0', grad_fn=)\n", "58700 tensor(5.1013, device='cuda:0', grad_fn=)\n", "58800 tensor(4.8176, device='cuda:0', grad_fn=)\n", "58900 tensor(4.7466, device='cuda:0', grad_fn=)\n", "59000 tensor(4.9146, device='cuda:0', grad_fn=)\n", "59100 tensor(4.8151, device='cuda:0', grad_fn=)\n", "59200 tensor(4.9928, device='cuda:0', grad_fn=)\n", "59300 tensor(5.0274, device='cuda:0', grad_fn=)\n", "59400 tensor(4.7727, device='cuda:0', grad_fn=)\n", "59500 tensor(5.0648, device='cuda:0', grad_fn=)\n", "59600 tensor(4.9982, device='cuda:0', grad_fn=)\n", "59700 tensor(4.8934, device='cuda:0', grad_fn=)\n", "59800 tensor(4.8285, device='cuda:0', grad_fn=)\n", "59900 tensor(4.8039, device='cuda:0', grad_fn=)\n", "60000 tensor(4.9090, device='cuda:0', grad_fn=)\n", "60100 tensor(4.6927, device='cuda:0', grad_fn=)\n", "60200 tensor(4.8922, device='cuda:0', grad_fn=)\n", "60300 tensor(4.8804, device='cuda:0', grad_fn=)\n", "60400 tensor(4.9676, device='cuda:0', grad_fn=)\n", "60500 tensor(4.7234, device='cuda:0', grad_fn=)\n", "60600 tensor(4.9174, device='cuda:0', grad_fn=)\n", "60700 tensor(4.9062, device='cuda:0', grad_fn=)\n", "60800 tensor(5.0811, device='cuda:0', grad_fn=)\n", "60900 tensor(5.1713, device='cuda:0', grad_fn=)\n", "61000 tensor(4.9471, device='cuda:0', grad_fn=)\n", "61100 tensor(4.8106, device='cuda:0', grad_fn=)\n", "61200 tensor(4.8666, device='cuda:0', grad_fn=)\n", "61300 tensor(4.8624, device='cuda:0', grad_fn=)\n", "61400 tensor(4.5771, device='cuda:0', grad_fn=)\n", "61500 tensor(4.8186, device='cuda:0', grad_fn=)\n", "61600 tensor(4.7787, device='cuda:0', grad_fn=)\n", "61700 tensor(4.9245, device='cuda:0', grad_fn=)\n", "61800 tensor(5.0268, device='cuda:0', grad_fn=)\n", "61900 tensor(5.2582, device='cuda:0', grad_fn=)\n", "62000 tensor(4.8309, device='cuda:0', grad_fn=)\n", "62100 tensor(4.9982, device='cuda:0', grad_fn=)\n", "62200 tensor(4.8859, device='cuda:0', grad_fn=)\n", "62300 tensor(4.5051, device='cuda:0', grad_fn=)\n", "62400 tensor(4.6767, device='cuda:0', grad_fn=)\n", "62500 tensor(4.7197, device='cuda:0', grad_fn=)\n", "62600 tensor(4.6625, device='cuda:0', grad_fn=)\n", "62700 tensor(4.6548, device='cuda:0', grad_fn=)\n", "62800 tensor(4.7307, device='cuda:0', grad_fn=)\n", "62900 tensor(4.9550, device='cuda:0', grad_fn=)\n", "63000 tensor(4.5528, device='cuda:0', grad_fn=)\n", "63100 tensor(4.8676, device='cuda:0', grad_fn=)\n", "63200 tensor(4.9302, device='cuda:0', grad_fn=)\n", "63300 tensor(4.8878, device='cuda:0', grad_fn=)\n", "63400 tensor(4.9172, device='cuda:0', grad_fn=)\n", "63500 tensor(4.7881, device='cuda:0', grad_fn=)\n", "63600 tensor(4.8712, device='cuda:0', grad_fn=)\n", "63700 tensor(4.9398, device='cuda:0', grad_fn=)\n", "63800 tensor(4.9999, device='cuda:0', grad_fn=)\n", "63900 tensor(4.8581, device='cuda:0', grad_fn=)\n", "64000 tensor(4.6726, device='cuda:0', grad_fn=)\n", "64100 tensor(5.0308, device='cuda:0', grad_fn=)\n", "64200 tensor(4.7130, device='cuda:0', grad_fn=)\n", "64300 tensor(4.9586, device='cuda:0', grad_fn=)\n", "64400 tensor(4.9456, device='cuda:0', grad_fn=)\n", "64500 tensor(4.8030, device='cuda:0', grad_fn=)\n", "64600 tensor(4.9885, device='cuda:0', grad_fn=)\n", "64700 tensor(4.9439, device='cuda:0', grad_fn=)\n", "64800 tensor(4.6348, device='cuda:0', grad_fn=)\n", "64900 tensor(4.8772, device='cuda:0', grad_fn=)\n", "65000 tensor(4.9567, device='cuda:0', grad_fn=)\n", "65100 tensor(4.9036, device='cuda:0', grad_fn=)\n", "65200 tensor(4.7526, device='cuda:0', grad_fn=)\n", "65300 tensor(4.9206, device='cuda:0', grad_fn=)\n", "65400 tensor(4.8406, device='cuda:0', grad_fn=)\n", "65500 tensor(4.5461, device='cuda:0', grad_fn=)\n", "65600 tensor(4.9647, device='cuda:0', grad_fn=)\n", "65700 tensor(4.9128, device='cuda:0', grad_fn=)\n", "65800 tensor(4.8554, device='cuda:0', grad_fn=)\n", "65900 tensor(4.8749, device='cuda:0', grad_fn=)\n", "66000 tensor(5.1345, device='cuda:0', grad_fn=)\n", "66100 tensor(4.6254, device='cuda:0', grad_fn=)\n", "66200 tensor(4.9932, device='cuda:0', grad_fn=)\n", "66300 tensor(4.5778, device='cuda:0', grad_fn=)\n", "66400 tensor(4.7925, device='cuda:0', grad_fn=)\n", "66500 tensor(4.9761, device='cuda:0', grad_fn=)\n", "66600 tensor(4.9166, device='cuda:0', grad_fn=)\n", "66700 tensor(4.8186, device='cuda:0', grad_fn=)\n", "66800 tensor(4.9063, device='cuda:0', grad_fn=)\n", "66900 tensor(4.9770, device='cuda:0', grad_fn=)\n", "67000 tensor(4.8087, device='cuda:0', grad_fn=)\n", "67100 tensor(4.7366, device='cuda:0', grad_fn=)\n", "67200 tensor(5.0656, device='cuda:0', grad_fn=)\n", "67300 tensor(4.9718, device='cuda:0', grad_fn=)\n", "67400 tensor(4.8172, device='cuda:0', grad_fn=)\n", "67500 tensor(4.9368, device='cuda:0', grad_fn=)\n", "67600 tensor(4.9278, device='cuda:0', grad_fn=)\n", "67700 tensor(4.8133, device='cuda:0', grad_fn=)\n", "67800 tensor(4.9486, device='cuda:0', grad_fn=)\n", "67900 tensor(4.8521, device='cuda:0', grad_fn=)\n", "68000 tensor(4.9510, device='cuda:0', grad_fn=)\n", "68100 tensor(4.8939, device='cuda:0', grad_fn=)\n", "68200 tensor(4.8088, device='cuda:0', grad_fn=)\n", "68300 tensor(4.9821, device='cuda:0', grad_fn=)\n", "68400 tensor(5.1750, device='cuda:0', grad_fn=)\n", "68500 tensor(4.6476, device='cuda:0', grad_fn=)\n", "68600 tensor(4.8567, device='cuda:0', grad_fn=)\n", "68700 tensor(4.8663, device='cuda:0', grad_fn=)\n", "68800 tensor(5.0268, device='cuda:0', grad_fn=)\n", "68900 tensor(4.8717, device='cuda:0', grad_fn=)\n", "69000 tensor(4.9166, device='cuda:0', grad_fn=)\n", "69100 tensor(4.9094, device='cuda:0', grad_fn=)\n", "69200 tensor(4.7433, device='cuda:0', grad_fn=)\n", "69300 tensor(4.5366, device='cuda:0', grad_fn=)\n", "69400 tensor(5.0260, device='cuda:0', grad_fn=)\n", "69500 tensor(4.7304, device='cuda:0', grad_fn=)\n" ] } ], "source": [ "model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n", "data = DataLoader(train_dataset, batch_size=batch_size)\n", "optimizer = torch.optim.Adam(model.parameters())\n", "criterion = torch.nn.NLLLoss()\n", "\n", "model.train()\n", "step = 0\n", "for x, y in data:\n", " x = x.to(device)\n", " y = y.to(device)\n", " optimizer.zero_grad()\n", " ypredicted = model(x)\n", " loss = criterion(torch.log(ypredicted), y)\n", " if step % 100 == 0:\n", " print(step, loss)\n", " step += 1\n", " loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wfLtxqN6gFCw", "outputId": "1be9876e-eb88-4ed0-a40e-3546aa6c5ad4" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "bp60AtU0XBuj" }, "outputs": [], "source": [ "torch.save(model.state_dict(), path_to_model)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BwN-Q2sFXBuj", "outputId": "a444be6d-bfb3-4235-c48c-41ba6cbfeec1" }, "outputs": [ { "data": { "text/plain": [ "SimpleBigramNeuralLanguageModel(\n", " (model): Sequential(\n", " (0): Embedding(10000, 100)\n", " (1): Linear(in_features=100, out_features=10000, bias=True)\n", " (2): Softmax(dim=None)\n", " )\n", ")" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n", "model.load_state_dict(torch.load(path_to_model))\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QVBhjgB1XBuk", "outputId": "ee63bb8b-57c8-40fb-94fe-cd00e0fa82b8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating outputs in dev-0\n" ] } ], "source": [ "save_outs('dev-0')" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5BglgEAxXBuk", "outputId": "4fda63a1-94d8-4daa-dbd7-d6a640e57f40" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating outputs in test-A\n" ] } ], "source": [ "save_outs('test-A')" ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }