diff --git a/Zad_7.ipynb b/Zad_7.ipynb
new file mode 100644
index 0000000..04dcc57
--- /dev/null
+++ b/Zad_7.ipynb
@@ -0,0 +1,4995 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "gpuClass": "standard"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## connect to google drive (working on colab)"
+ ],
+ "metadata": {
+ "id": "G0ujnpy2tuBE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Lwuh_S5pWY1j",
+ "outputId": "27838dab-7be0-4447-883a-95559887c7c8"
+ },
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Mounted at /content/drive\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!mkdir moj7\n"
+ ],
+ "metadata": {
+ "id": "vlnrhRaEWNJF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd drive"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "539oWG3pXOAX",
+ "outputId": "a9ae634d-d4a2-47dd-97d2-10c245c7c5d2"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content/drive\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd MyDrive"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "wgmiKs4BiAiT",
+ "outputId": "fbbd0bc7-76bd-47bf-e38b-e051239e5ba7"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content/drive/MyDrive\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd moj7"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "jNdkji_hiAlt",
+ "outputId": "962b875a-8d3f-433d-8d7b-dcd664ee1674"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content/drive/MyDrive/moj7\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pwd"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "IOHV3Iz4WNLc",
+ "outputId": "f56a6ab7-73e6-4b03-824b-b5e749c8a82e"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content/drive/MyDrive/moj7\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Preprocess"
+ ],
+ "metadata": {
+ "id": "D7jhQfbttn9D"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import re"
+ ],
+ "metadata": {
+ "id": "_IPWOt2BZ_-q"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_file ='train/in.tsv.xz'\n",
+ "test_file = 'test-A/in.tsv.xz'\n",
+ "out_file = 'test-A/out.tsv'\n",
+ "\n",
+ "def preprocess(line):\n",
+ " line = replace_endline(line)\n",
+ " line = get_rid_of_header(line)\n",
+ " return line\n",
+ "\n",
+ "def get_rid_of_header(line):\n",
+ " line = line.split('\\t')[6:]\n",
+ " return \"\".join(line)\n",
+ " \n",
+ "def replace_endline(line):\n",
+ " line = re.sub(\"\\\\n|\\\\+\", \" \", line)\n",
+ " return line"
+ ],
+ "metadata": {
+ "id": "qDnZdPblWNNr"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
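+ {
+ "cell_type": "markdown",
+ "source": [
+ "A quick, hypothetical sanity check of `preprocess`: the tab-separated fields below are made up for illustration only (they are not taken from `in.tsv.xz`). Only the columns from index 6 onwards should survive, concatenated into one string."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# hypothetical sample line with eight tab-separated fields\n",
+ "sample = 'id1\\tmeta1\\tmeta2\\tmeta3\\tmeta4\\tmeta5\\tleft context of the gap\\tright context'\n",
+ "# the first six columns are dropped, the remaining text columns are concatenated\n",
+ "print(preprocess(sample))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },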
+ {
+ "cell_type": "code",
+ "source": [
+ "from itertools import islice\n",
+ "import regex as re\n",
+ "import sys\n",
+ "from torchtext.vocab import build_vocab_from_iterator\n",
+ "import lzma\n",
+ "import pickle\n",
+ "\n",
+ "\n",
+ "\n",
+ "def get_words_from_line(line):\n",
+ " line = line.rstrip()\n",
+ " yield ''\n",
+ " line = preprocess(line)\n",
+ " for t in line.split(' '):\n",
+ " yield t\n",
+ " yield ''\n",
+ "\n",
+ "\n",
+ "def get_word_lines_from_file(file_name):\n",
+ " n = 0\n",
+ " with lzma.open(file_name, 'r') as fh:\n",
+ " for line in fh:\n",
+ " n+=1\n",
+ " if n%1000==0:\n",
+ " print(n)\n",
+ " yield get_words_from_line(line.decode('utf-8'))\n",
+ "#vocab_size = 20000\n",
+ "vocab_size = 20000\n",
+ "\n",
+ "vocab = build_vocab_from_iterator(\n",
+ " get_word_lines_from_file(train_file),\n",
+ " max_tokens = vocab_size,\n",
+ " specials = [''])\n",
+ "\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "RVN0lKVZfwMe",
+ "outputId": "305b03e4-f626-4560-a371-41bc5a0ea9c7"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "1000\n",
+ "2000\n",
+ "3000\n",
+ "4000\n",
+ "5000\n",
+ "6000\n",
+ "7000\n",
+ "8000\n",
+ "9000\n",
+ "10000\n",
+ "11000\n",
+ "12000\n",
+ "13000\n",
+ "14000\n",
+ "15000\n",
+ "16000\n",
+ "17000\n",
+ "18000\n",
+ "19000\n",
+ "20000\n",
+ "21000\n",
+ "22000\n",
+ "23000\n",
+ "24000\n",
+ "25000\n",
+ "26000\n",
+ "27000\n",
+ "28000\n",
+ "29000\n",
+ "30000\n",
+ "31000\n",
+ "32000\n",
+ "33000\n",
+ "34000\n",
+ "35000\n",
+ "36000\n",
+ "37000\n",
+ "38000\n",
+ "39000\n",
+ "40000\n",
+ "41000\n",
+ "42000\n",
+ "43000\n",
+ "44000\n",
+ "45000\n",
+ "46000\n",
+ "47000\n",
+ "48000\n",
+ "49000\n",
+ "50000\n",
+ "51000\n",
+ "52000\n",
+ "53000\n",
+ "54000\n",
+ "55000\n",
+ "56000\n",
+ "57000\n",
+ "58000\n",
+ "59000\n",
+ "60000\n",
+ "61000\n",
+ "62000\n",
+ "63000\n",
+ "64000\n",
+ "65000\n",
+ "66000\n",
+ "67000\n",
+ "68000\n",
+ "69000\n",
+ "70000\n",
+ "71000\n",
+ "72000\n",
+ "73000\n",
+ "74000\n",
+ "75000\n",
+ "76000\n",
+ "77000\n",
+ "78000\n",
+ "79000\n",
+ "80000\n",
+ "81000\n",
+ "82000\n",
+ "83000\n",
+ "84000\n",
+ "85000\n",
+ "86000\n",
+ "87000\n",
+ "88000\n",
+ "89000\n",
+ "90000\n",
+ "91000\n",
+ "92000\n",
+ "93000\n",
+ "94000\n",
+ "95000\n",
+ "96000\n",
+ "97000\n",
+ "98000\n",
+ "99000\n",
+ "100000\n",
+ "101000\n",
+ "102000\n",
+ "103000\n",
+ "104000\n",
+ "105000\n",
+ "106000\n",
+ "107000\n",
+ "108000\n",
+ "109000\n",
+ "110000\n",
+ "111000\n",
+ "112000\n",
+ "113000\n",
+ "114000\n",
+ "115000\n",
+ "116000\n",
+ "117000\n",
+ "118000\n",
+ "119000\n",
+ "120000\n",
+ "121000\n",
+ "122000\n",
+ "123000\n",
+ "124000\n",
+ "125000\n",
+ "126000\n",
+ "127000\n",
+ "128000\n",
+ "129000\n",
+ "130000\n",
+ "131000\n",
+ "132000\n",
+ "133000\n",
+ "134000\n",
+ "135000\n",
+ "136000\n",
+ "137000\n",
+ "138000\n",
+ "139000\n",
+ "140000\n",
+ "141000\n",
+ "142000\n",
+ "143000\n",
+ "144000\n",
+ "145000\n",
+ "146000\n",
+ "147000\n",
+ "148000\n",
+ "149000\n",
+ "150000\n",
+ "151000\n",
+ "152000\n",
+ "153000\n",
+ "154000\n",
+ "155000\n",
+ "156000\n",
+ "157000\n",
+ "158000\n",
+ "159000\n",
+ "160000\n",
+ "161000\n",
+ "162000\n",
+ "163000\n",
+ "164000\n",
+ "165000\n",
+ "166000\n",
+ "167000\n",
+ "168000\n",
+ "169000\n",
+ "170000\n",
+ "171000\n",
+ "172000\n",
+ "173000\n",
+ "174000\n",
+ "175000\n",
+ "176000\n",
+ "177000\n",
+ "178000\n",
+ "179000\n",
+ "180000\n",
+ "181000\n",
+ "182000\n",
+ "183000\n",
+ "184000\n",
+ "185000\n",
+ "186000\n",
+ "187000\n",
+ "188000\n",
+ "189000\n",
+ "190000\n",
+ "191000\n",
+ "192000\n",
+ "193000\n",
+ "194000\n",
+ "195000\n",
+ "196000\n",
+ "197000\n",
+ "198000\n",
+ "199000\n",
+ "200000\n",
+ "201000\n",
+ "202000\n",
+ "203000\n",
+ "204000\n",
+ "205000\n",
+ "206000\n",
+ "207000\n",
+ "208000\n",
+ "209000\n",
+ "210000\n",
+ "211000\n",
+ "212000\n",
+ "213000\n",
+ "214000\n",
+ "215000\n",
+ "216000\n",
+ "217000\n",
+ "218000\n",
+ "219000\n",
+ "220000\n",
+ "221000\n",
+ "222000\n",
+ "223000\n",
+ "224000\n",
+ "225000\n",
+ "226000\n",
+ "227000\n",
+ "228000\n",
+ "229000\n",
+ "230000\n",
+ "231000\n",
+ "232000\n",
+ "233000\n",
+ "234000\n",
+ "235000\n",
+ "236000\n",
+ "237000\n",
+ "238000\n",
+ "239000\n",
+ "240000\n",
+ "241000\n",
+ "242000\n",
+ "243000\n",
+ "244000\n",
+ "245000\n",
+ "246000\n",
+ "247000\n",
+ "248000\n",
+ "249000\n",
+ "250000\n",
+ "251000\n",
+ "252000\n",
+ "253000\n",
+ "254000\n",
+ "255000\n",
+ "256000\n",
+ "257000\n",
+ "258000\n",
+ "259000\n",
+ "260000\n",
+ "261000\n",
+ "262000\n",
+ "263000\n",
+ "264000\n",
+ "265000\n",
+ "266000\n",
+ "267000\n",
+ "268000\n",
+ "269000\n",
+ "270000\n",
+ "271000\n",
+ "272000\n",
+ "273000\n",
+ "274000\n",
+ "275000\n",
+ "276000\n",
+ "277000\n",
+ "278000\n",
+ "279000\n",
+ "280000\n",
+ "281000\n",
+ "282000\n",
+ "283000\n",
+ "284000\n",
+ "285000\n",
+ "286000\n",
+ "287000\n",
+ "288000\n",
+ "289000\n",
+ "290000\n",
+ "291000\n",
+ "292000\n",
+ "293000\n",
+ "294000\n",
+ "295000\n",
+ "296000\n",
+ "297000\n",
+ "298000\n",
+ "299000\n",
+ "300000\n",
+ "301000\n",
+ "302000\n",
+ "303000\n",
+ "304000\n",
+ "305000\n",
+ "306000\n",
+ "307000\n",
+ "308000\n",
+ "309000\n",
+ "310000\n",
+ "311000\n",
+ "312000\n",
+ "313000\n",
+ "314000\n",
+ "315000\n",
+ "316000\n",
+ "317000\n",
+ "318000\n",
+ "319000\n",
+ "320000\n",
+ "321000\n",
+ "322000\n",
+ "323000\n",
+ "324000\n",
+ "325000\n",
+ "326000\n",
+ "327000\n",
+ "328000\n",
+ "329000\n",
+ "330000\n",
+ "331000\n",
+ "332000\n",
+ "333000\n",
+ "334000\n",
+ "335000\n",
+ "336000\n",
+ "337000\n",
+ "338000\n",
+ "339000\n",
+ "340000\n",
+ "341000\n",
+ "342000\n",
+ "343000\n",
+ "344000\n",
+ "345000\n",
+ "346000\n",
+ "347000\n",
+ "348000\n",
+ "349000\n",
+ "350000\n",
+ "351000\n",
+ "352000\n",
+ "353000\n",
+ "354000\n",
+ "355000\n",
+ "356000\n",
+ "357000\n",
+ "358000\n",
+ "359000\n",
+ "360000\n",
+ "361000\n",
+ "362000\n",
+ "363000\n",
+ "364000\n",
+ "365000\n",
+ "366000\n",
+ "367000\n",
+ "368000\n",
+ "369000\n",
+ "370000\n",
+ "371000\n",
+ "372000\n",
+ "373000\n",
+ "374000\n",
+ "375000\n",
+ "376000\n",
+ "377000\n",
+ "378000\n",
+ "379000\n",
+ "380000\n",
+ "381000\n",
+ "382000\n",
+ "383000\n",
+ "384000\n",
+ "385000\n",
+ "386000\n",
+ "387000\n",
+ "388000\n",
+ "389000\n",
+ "390000\n",
+ "391000\n",
+ "392000\n",
+ "393000\n",
+ "394000\n",
+ "395000\n",
+ "396000\n",
+ "397000\n",
+ "398000\n",
+ "399000\n",
+ "400000\n",
+ "401000\n",
+ "402000\n",
+ "403000\n",
+ "404000\n",
+ "405000\n",
+ "406000\n",
+ "407000\n",
+ "408000\n",
+ "409000\n",
+ "410000\n",
+ "411000\n",
+ "412000\n",
+ "413000\n",
+ "414000\n",
+ "415000\n",
+ "416000\n",
+ "417000\n",
+ "418000\n",
+ "419000\n",
+ "420000\n",
+ "421000\n",
+ "422000\n",
+ "423000\n",
+ "424000\n",
+ "425000\n",
+ "426000\n",
+ "427000\n",
+ "428000\n",
+ "429000\n",
+ "430000\n",
+ "431000\n",
+ "432000\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "vocab['no']"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "b9vMOTlZxISl",
+ "outputId": "a3a71b17-3fb2-4794-ae5d-a43d9af49b69"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "50"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 19
+ }
+ ]
+ },
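+ {
+ "cell_type": "markdown",
+ "source": [
+ "The mapping also works in the other direction: `lookup_token` returns the token stored under a given index. This is just a quick check of the vocabulary API, not something required for training."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# index -> token, the inverse of the lookup above\n",
+ "vocab.lookup_token(50)"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },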
+ {
+ "cell_type": "code",
+ "source": [
+ "with open('filename.pickle', 'wb') as handle:\n",
+ " pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)"
+ ],
+ "metadata": {
+ "id": "6R9l6tuPxB_B"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
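+ {
+ "cell_type": "markdown",
+ "source": [
+ "A minimal sketch of loading the vocabulary back from the pickle above, so it does not have to be rebuilt from the corpus in a later session (it assumes `filename.pickle` is still in the working directory)."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# reload the saved vocabulary instead of rebuilding it from the corpus\n",
+ "with open('filename.pickle', 'rb') as handle:\n",
+ " vocab = pickle.load(handle)\n",
+ "print(len(vocab))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },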
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Create NN"
+ ],
+ "metadata": {
+ "id": "Be25rS6Uvl4V"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch import nn\n",
+ "import torch\n",
+ "import pickle\n",
+ "# embed_size = 150\n",
+ "embed_size = 150\n",
+ "\n",
+ "class Bigram(nn.Module):\n",
+ " def __init__(self, vocabulary_size, embedding_size):\n",
+ " super(Bigram, self).__init__()\n",
+ " self.model = nn.Sequential(\n",
+ " nn.Embedding(vocabulary_size, embedding_size),\n",
+ " nn.Linear(embedding_size, vocabulary_size),\n",
+ " nn.Softmax()\n",
+ " )\n",
+ " def forward(self, x):\n",
+ " return self.model(x)\n",
+ "\n",
+ "model = Bigram(vocab_size, embed_size)\n",
+ "\n",
+ "vocab.set_default_index(vocab[''])\n",
+ "res = torch.tensor(vocab.forward(['order']))\n",
+ "print(res)\n"
+ ],
+ "metadata": {
+ "id": "dGTOmcHwWNSi",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "88c30492-7a9a-4b96-9119-ecdf5865bb51"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "tensor([215])\n"
+ ]
+ }
+ ]
+ },
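+ {
+ "cell_type": "markdown",
+ "source": [
+ "A small sketch of querying the model for next-word candidates after `order`. The weights are still untrained at this point, so the ranking is meaningless; this only checks shapes and shows how indices can be mapped back to tokens with `lookup_tokens`."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# distribution over the next word after 'order' (untrained model, API/shape check only)\n",
+ "with torch.no_grad():\n",
+ " probs = model(torch.tensor(vocab.forward(['order'])))\n",
+ "top = torch.topk(probs[0], 5)\n",
+ "list(zip(vocab.lookup_tokens(top.indices.tolist()), top.values.tolist()))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },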
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.utils.data import IterableDataset\n",
+ "import itertools\n",
+ "\n",
+ "def look_ahead_iterator(gen):\n",
+ " prev = None\n",
+ " for item in gen:\n",
+ " if prev is not None:\n",
+ " yield (prev, item)\n",
+ " prev = item\n",
+ "\n",
+ "class Bigrams(IterableDataset):\n",
+ " def __init__(self, text_file, vocabulary_size):\n",
+ " self.vocab = build_vocab_from_iterator(\n",
+ " get_word_lines_from_file(text_file),\n",
+ " max_tokens = vocabulary_size,\n",
+ " specials = [''])\n",
+ " self.vocab.set_default_index(self.vocab[''])\n",
+ " self.vocabulary_size = vocabulary_size\n",
+ " self.text_file = text_file\n",
+ "\n",
+ " def __iter__(self):\n",
+ " return look_ahead_iterator(\n",
+ " (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
+ "\n",
+ "\n",
+ "train_dataset = Bigrams(train_file, vocab_size)"
+ ],
+ "metadata": {
+ "id": "5CSigeomWNVT",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "4fb3b1ff-f91b-4799-fc5f-17bae10b94ec"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "1000\n",
+ "2000\n",
+ "3000\n",
+ "4000\n",
+ "5000\n",
+ "6000\n",
+ "7000\n",
+ "8000\n",
+ "9000\n",
+ "10000\n",
+ "11000\n",
+ "12000\n",
+ "13000\n",
+ "14000\n",
+ "15000\n",
+ "16000\n",
+ "17000\n",
+ "18000\n",
+ "19000\n",
+ "20000\n",
+ "21000\n",
+ "22000\n",
+ "23000\n",
+ "24000\n",
+ "25000\n",
+ "26000\n",
+ "27000\n",
+ "28000\n",
+ "29000\n",
+ "30000\n",
+ "31000\n",
+ "32000\n",
+ "33000\n",
+ "34000\n",
+ "35000\n",
+ "36000\n",
+ "37000\n",
+ "38000\n",
+ "39000\n",
+ "40000\n",
+ "41000\n",
+ "42000\n",
+ "43000\n",
+ "44000\n",
+ "45000\n",
+ "46000\n",
+ "47000\n",
+ "48000\n",
+ "49000\n",
+ "50000\n",
+ "51000\n",
+ "52000\n",
+ "53000\n",
+ "54000\n",
+ "55000\n",
+ "56000\n",
+ "57000\n",
+ "58000\n",
+ "59000\n",
+ "60000\n",
+ "61000\n",
+ "62000\n",
+ "63000\n",
+ "64000\n",
+ "65000\n",
+ "66000\n",
+ "67000\n",
+ "68000\n",
+ "69000\n",
+ "70000\n",
+ "71000\n",
+ "72000\n",
+ "73000\n",
+ "74000\n",
+ "75000\n",
+ "76000\n",
+ "77000\n",
+ "78000\n",
+ "79000\n",
+ "80000\n",
+ "81000\n",
+ "82000\n",
+ "83000\n",
+ "84000\n",
+ "85000\n",
+ "86000\n",
+ "87000\n",
+ "88000\n",
+ "89000\n",
+ "90000\n",
+ "91000\n",
+ "92000\n",
+ "93000\n",
+ "94000\n",
+ "95000\n",
+ "96000\n",
+ "97000\n",
+ "98000\n",
+ "99000\n",
+ "100000\n",
+ "101000\n",
+ "102000\n",
+ "103000\n",
+ "104000\n",
+ "105000\n",
+ "106000\n",
+ "107000\n",
+ "108000\n",
+ "109000\n",
+ "110000\n",
+ "111000\n",
+ "112000\n",
+ "113000\n",
+ "114000\n",
+ "115000\n",
+ "116000\n",
+ "117000\n",
+ "118000\n",
+ "119000\n",
+ "120000\n",
+ "121000\n",
+ "122000\n",
+ "123000\n",
+ "124000\n",
+ "125000\n",
+ "126000\n",
+ "127000\n",
+ "128000\n",
+ "129000\n",
+ "130000\n",
+ "131000\n",
+ "132000\n",
+ "133000\n",
+ "134000\n",
+ "135000\n",
+ "136000\n",
+ "137000\n",
+ "138000\n",
+ "139000\n",
+ "140000\n",
+ "141000\n",
+ "142000\n",
+ "143000\n",
+ "144000\n",
+ "145000\n",
+ "146000\n",
+ "147000\n",
+ "148000\n",
+ "149000\n",
+ "150000\n",
+ "151000\n",
+ "152000\n",
+ "153000\n",
+ "154000\n",
+ "155000\n",
+ "156000\n",
+ "157000\n",
+ "158000\n",
+ "159000\n",
+ "160000\n",
+ "161000\n",
+ "162000\n",
+ "163000\n",
+ "164000\n",
+ "165000\n",
+ "166000\n",
+ "167000\n",
+ "168000\n",
+ "169000\n",
+ "170000\n",
+ "171000\n",
+ "172000\n",
+ "173000\n",
+ "174000\n",
+ "175000\n",
+ "176000\n",
+ "177000\n",
+ "178000\n",
+ "179000\n",
+ "180000\n",
+ "181000\n",
+ "182000\n",
+ "183000\n",
+ "184000\n",
+ "185000\n",
+ "186000\n",
+ "187000\n",
+ "188000\n",
+ "189000\n",
+ "190000\n",
+ "191000\n",
+ "192000\n",
+ "193000\n",
+ "194000\n",
+ "195000\n",
+ "196000\n",
+ "197000\n",
+ "198000\n",
+ "199000\n",
+ "200000\n",
+ "201000\n",
+ "202000\n",
+ "203000\n",
+ "204000\n",
+ "205000\n",
+ "206000\n",
+ "207000\n",
+ "208000\n",
+ "209000\n",
+ "210000\n",
+ "211000\n",
+ "212000\n",
+ "213000\n",
+ "214000\n",
+ "215000\n",
+ "216000\n",
+ "217000\n",
+ "218000\n",
+ "219000\n",
+ "220000\n",
+ "221000\n",
+ "222000\n",
+ "223000\n",
+ "224000\n",
+ "225000\n",
+ "226000\n",
+ "227000\n",
+ "228000\n",
+ "229000\n",
+ "230000\n",
+ "231000\n",
+ "232000\n",
+ "233000\n",
+ "234000\n",
+ "235000\n",
+ "236000\n",
+ "237000\n",
+ "238000\n",
+ "239000\n",
+ "240000\n",
+ "241000\n",
+ "242000\n",
+ "243000\n",
+ "244000\n",
+ "245000\n",
+ "246000\n",
+ "247000\n",
+ "248000\n",
+ "249000\n",
+ "250000\n",
+ "251000\n",
+ "252000\n",
+ "253000\n",
+ "254000\n",
+ "255000\n",
+ "256000\n",
+ "257000\n",
+ "258000\n",
+ "259000\n",
+ "260000\n",
+ "261000\n",
+ "262000\n",
+ "263000\n",
+ "264000\n",
+ "265000\n",
+ "266000\n",
+ "267000\n",
+ "268000\n",
+ "269000\n",
+ "270000\n",
+ "271000\n",
+ "272000\n",
+ "273000\n",
+ "274000\n",
+ "275000\n",
+ "276000\n",
+ "277000\n",
+ "278000\n",
+ "279000\n",
+ "280000\n",
+ "281000\n",
+ "282000\n",
+ "283000\n",
+ "284000\n",
+ "285000\n",
+ "286000\n",
+ "287000\n",
+ "288000\n",
+ "289000\n",
+ "290000\n",
+ "291000\n",
+ "292000\n",
+ "293000\n",
+ "294000\n",
+ "295000\n",
+ "296000\n",
+ "297000\n",
+ "298000\n",
+ "299000\n",
+ "300000\n",
+ "301000\n",
+ "302000\n",
+ "303000\n",
+ "304000\n",
+ "305000\n",
+ "306000\n",
+ "307000\n",
+ "308000\n",
+ "309000\n",
+ "310000\n",
+ "311000\n",
+ "312000\n",
+ "313000\n",
+ "314000\n",
+ "315000\n",
+ "316000\n",
+ "317000\n",
+ "318000\n",
+ "319000\n",
+ "320000\n",
+ "321000\n",
+ "322000\n",
+ "323000\n",
+ "324000\n",
+ "325000\n",
+ "326000\n",
+ "327000\n",
+ "328000\n",
+ "329000\n",
+ "330000\n",
+ "331000\n",
+ "332000\n",
+ "333000\n",
+ "334000\n",
+ "335000\n",
+ "336000\n",
+ "337000\n",
+ "338000\n",
+ "339000\n",
+ "340000\n",
+ "341000\n",
+ "342000\n",
+ "343000\n",
+ "344000\n",
+ "345000\n",
+ "346000\n",
+ "347000\n",
+ "348000\n",
+ "349000\n",
+ "350000\n",
+ "351000\n",
+ "352000\n",
+ "353000\n",
+ "354000\n",
+ "355000\n",
+ "356000\n",
+ "357000\n",
+ "358000\n",
+ "359000\n",
+ "360000\n",
+ "361000\n",
+ "362000\n",
+ "363000\n",
+ "364000\n",
+ "365000\n",
+ "366000\n",
+ "367000\n",
+ "368000\n",
+ "369000\n",
+ "370000\n",
+ "371000\n",
+ "372000\n",
+ "373000\n",
+ "374000\n",
+ "375000\n",
+ "376000\n",
+ "377000\n",
+ "378000\n",
+ "379000\n",
+ "380000\n",
+ "381000\n",
+ "382000\n",
+ "383000\n",
+ "384000\n",
+ "385000\n",
+ "386000\n",
+ "387000\n",
+ "388000\n",
+ "389000\n",
+ "390000\n",
+ "391000\n",
+ "392000\n",
+ "393000\n",
+ "394000\n",
+ "395000\n",
+ "396000\n",
+ "397000\n",
+ "398000\n",
+ "399000\n",
+ "400000\n",
+ "401000\n",
+ "402000\n",
+ "403000\n",
+ "404000\n",
+ "405000\n",
+ "406000\n",
+ "407000\n",
+ "408000\n",
+ "409000\n",
+ "410000\n",
+ "411000\n",
+ "412000\n",
+ "413000\n",
+ "414000\n",
+ "415000\n",
+ "416000\n",
+ "417000\n",
+ "418000\n",
+ "419000\n",
+ "420000\n",
+ "421000\n",
+ "422000\n",
+ "423000\n",
+ "424000\n",
+ "425000\n",
+ "426000\n",
+ "427000\n",
+ "428000\n",
+ "429000\n",
+ "430000\n",
+ "431000\n",
+ "432000\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "next(iter(DataLoader(train_dataset, batch_size=5)))"
+ ],
+ "metadata": {
+ "id": "oYAZ772rWNX7",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "da3a811a-36cf-4d34-82d8-b1ce2447c232"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[tensor([ 23, 191, 5791, 1, 112]),\n",
+ " tensor([ 191, 5791, 1, 112, 159])]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 23
+ }
+ ]
+ },
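+ {
+ "cell_type": "markdown",
+ "source": [
+ "Each batch is a pair of tensors: the first holds the current tokens (x) and the second the tokens that follow them (y), i.e. the (previous word, next word) pairs produced by `look_ahead_iterator`. A small check of that assumption: within one batch the second tensor should simply be the first shifted by one position."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# y should be x shifted by one token within a batch\n",
+ "x, y = next(iter(DataLoader(train_dataset, batch_size=5)))\n",
+ "print(torch.equal(x[1:], y[:-1]))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },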
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Train"
+ ],
+ "metadata": {
+ "id": "1H_dI372vrNh"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
+ ],
+ "metadata": {
+ "id": "N2u4Qmadgtdn"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "model = Bigram(vocab_size, embed_size).to(device)\n",
+ "data = DataLoader(train_dataset, batch_size=1000)\n",
+ "optimizer = torch.optim.Adam(model.parameters())\n",
+ "criterion = torch.nn.NLLLoss()\n",
+ "## epochs=2\n",
+ "for i in range(2):\n",
+ " print('epoch: =', i)\n",
+ " model.train()\n",
+ " step = 0\n",
+ " for x, y in data:\n",
+ " x = x.to(device)\n",
+ " y = y.to(device)\n",
+ " optimizer.zero_grad()\n",
+ " ypredicted = model(x)\n",
+ " loss = criterion(torch.log(ypredicted), y)\n",
+ " if step % 100 == 0:\n",
+ " print(step, loss)\n",
+ " step += 1\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " torch.save(model.state_dict(), 'model.bin') \n"
+ ],
+ "metadata": {
+ "id": "OGk2tjbvWNag",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "868503a1-2849-40ff-e703-886fba094927"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "metadata": {
+ "tags": null
+ },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch: = 0\n"
+ ]
+ },
+ {
+ "metadata": {
+ "tags": null
+ },
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
+ " input = module(input)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "0 tensor(10.3037, device='cuda:0', grad_fn=)\n",
+ "100 tensor(8.7506, device='cuda:0', grad_fn=)\n",
+ "200 tensor(7.8141, device='cuda:0', grad_fn=)\n",
+ "1000\n",
+ "300 tensor(7.4218, device='cuda:0', grad_fn=)\n",
+ "400 tensor(7.1627, device='cuda:0', grad_fn=)\n",
+ "500 tensor(6.7964, device='cuda:0', grad_fn=)\n",
+ "2000\n",
+ "600 tensor(6.4704, device='cuda:0', grad_fn=)\n",
+ "700 tensor(6.3798, device='cuda:0', grad_fn=)\n",
+ "800 tensor(6.2849, device='cuda:0', grad_fn=)\n",
+ "3000\n",
+ "900 tensor(6.3975, device='cuda:0', grad_fn=)\n",
+ "1000 tensor(6.0096, device='cuda:0', grad_fn=)\n",
+ "1100 tensor(5.7434, device='cuda:0', grad_fn=)\n",
+ "4000\n",
+ "1200 tensor(5.9602, device='cuda:0', grad_fn=)\n",
+ "1300 tensor(6.1623, device='cuda:0', grad_fn=)\n",
+ "1400 tensor(6.1647, device='cuda:0', grad_fn=)\n",
+ "5000\n",
+ "1500 tensor(6.1010, device='cuda:0', grad_fn=)\n",
+ "1600 tensor(6.0634, device='cuda:0', grad_fn=)\n",
+ "6000\n",
+ "1700 tensor(5.9149, device='cuda:0', grad_fn=)\n",
+ "1800 tensor(5.7918, device='cuda:0', grad_fn=)\n",
+ "1900 tensor(5.6739, device='cuda:0', grad_fn=)\n",
+ "7000\n",
+ "2000 tensor(5.5298, device='cuda:0', grad_fn=)\n",
+ "2100 tensor(5.8011, device='cuda:0', grad_fn=)\n",
+ "2200 tensor(5.4338, device='cuda:0', grad_fn=)\n",
+ "8000\n",
+ "2300 tensor(5.7522, device='cuda:0', grad_fn=)\n",
+ "2400 tensor(5.0313, device='cuda:0', grad_fn=)\n",
+ "2500 tensor(5.7116, device='cuda:0', grad_fn=)\n",
+ "9000\n",
+ "2600 tensor(5.2706, device='cuda:0', grad_fn=)\n",
+ "2700 tensor(5.6324, device='cuda:0', grad_fn=)\n",
+ "2800 tensor(5.0710, device='cuda:0', grad_fn=)\n",
+ "10000\n",
+ "2900 tensor(5.5921, device='cuda:0', grad_fn=)\n",
+ "3000 tensor(5.4808, device='cuda:0', grad_fn=)\n",
+ "11000\n",
+ "3100 tensor(5.3611, device='cuda:0', grad_fn=)\n",
+ "3200 tensor(5.6228, device='cuda:0', grad_fn=)\n",
+ "3300 tensor(5.4286, device='cuda:0', grad_fn=)\n",
+ "12000\n",
+ "3400 tensor(5.3550, device='cuda:0', grad_fn=)\n",
+ "3500 tensor(5.4032, device='cuda:0', grad_fn=)\n",
+ "3600 tensor(5.1070, device='cuda:0', grad_fn=)\n",
+ "13000\n",
+ "3700 tensor(5.4506, device='cuda:0', grad_fn=)\n",
+ "3800 tensor(5.4622, device='cuda:0', grad_fn=)\n",
+ "3900 tensor(5.4984, device='cuda:0', grad_fn=)\n",
+ "14000\n",
+ "4000 tensor(5.1740, device='cuda:0', grad_fn=)\n",
+ "4100 tensor(5.6064, device='cuda:0', grad_fn=)\n",
+ "4200 tensor(5.0705, device='cuda:0', grad_fn=)\n",
+ "15000\n",
+ "4300 tensor(5.5181, device='cuda:0', grad_fn=)\n",
+ "4400 tensor(5.2919, device='cuda:0', grad_fn=)\n",
+ "16000\n",
+ "4500 tensor(5.5021, device='cuda:0', grad_fn=)\n",
+ "4600 tensor(5.5308, device='cuda:0', grad_fn=)\n",
+ "4700 tensor(5.4699, device='cuda:0', grad_fn=)\n",
+ "17000\n",
+ "4800 tensor(5.2686, device='cuda:0', grad_fn=)\n",
+ "4900 tensor(5.4776, device='cuda:0', grad_fn=)\n",
+ "5000 tensor(5.5061, device='cuda:0', grad_fn=)\n",
+ "18000\n",
+ "5100 tensor(5.3180, device='cuda:0', grad_fn=)\n",
+ "5200 tensor(5.5524, device='cuda:0', grad_fn=)\n",
+ "5300 tensor(5.3481, device='cuda:0', grad_fn=)\n",
+ "19000\n",
+ "5400 tensor(5.2153, device='cuda:0', grad_fn=)\n",
+ "5500 tensor(5.4478, device='cuda:0', grad_fn=)\n",
+ "20000\n",
+ "5600 tensor(5.3441, device='cuda:0', grad_fn=)\n",
+ "5700 tensor(5.3958, device='cuda:0', grad_fn=)\n",
+ "5800 tensor(5.8945, device='cuda:0', grad_fn=)\n",
+ "21000\n",
+ "5900 tensor(5.5684, device='cuda:0', grad_fn=)\n",
+ "6000 tensor(5.5715, device='cuda:0', grad_fn=)\n",
+ "6100 tensor(5.2367, device='cuda:0', grad_fn=)\n",
+ "22000\n",
+ "6200 tensor(5.6976, device='cuda:0', grad_fn=)\n",
+ "6300 tensor(5.5367, device='cuda:0', grad_fn=)\n",
+ "6400 tensor(5.3024, device='cuda:0', grad_fn=)\n",
+ "23000\n",
+ "6500 tensor(5.3010, device='cuda:0', grad_fn=)\n",
+ "6600 tensor(6.0962, device='cuda:0', grad_fn=)\n",
+ "6700 tensor(5.0961, device='cuda:0', grad_fn=)\n",
+ "24000\n",
+ "6800 tensor(5.1091, device='cuda:0', grad_fn=)\n",
+ "6900 tensor(5.4123, device='cuda:0', grad_fn=)\n",
+ "25000\n",
+ "7000 tensor(5.3128, device='cuda:0', grad_fn=)\n",
+ "7100 tensor(5.3416, device='cuda:0', grad_fn=)\n",
+ "7200 tensor(5.4973, device='cuda:0', grad_fn=)\n",
+ "26000\n",
+ "7300 tensor(5.4418, device='cuda:0', grad_fn=)\n",
+ "7400 tensor(5.2171, device='cuda:0', grad_fn=)\n",
+ "7500 tensor(5.6509, device='cuda:0', grad_fn=)\n",
+ "27000\n",
+ "7600 tensor(5.0550, device='cuda:0', grad_fn=)\n",
+ "7700 tensor(5.4937, device='cuda:0', grad_fn=)\n",
+ "7800 tensor(5.9218, device='cuda:0', grad_fn=)\n",
+ "28000\n",
+ "7900 tensor(5.2853, device='cuda:0', grad_fn=)\n",
+ "8000 tensor(5.3146, device='cuda:0', grad_fn=)\n",
+ "8100 tensor(4.8552, device='cuda:0', grad_fn=)\n",
+ "29000\n",
+ "8200 tensor(5.3389, device='cuda:0', grad_fn=)\n",
+ "8300 tensor(5.2421, device='cuda:0', grad_fn=)\n",
+ "30000\n",
+ "8400 tensor(5.2460, device='cuda:0', grad_fn=)\n",
+ "8500 tensor(5.0331, device='cuda:0', grad_fn=)\n",
+ "8600 tensor(5.0050, device='cuda:0', grad_fn=)\n",
+ "31000\n",
+ "8700 tensor(5.3844, device='cuda:0', grad_fn=)\n",
+ "8800 tensor(5.4491, device='cuda:0', grad_fn=)\n",
+ "8900 tensor(5.6790, device='cuda:0', grad_fn=)\n",
+ "32000\n",
+ "9000 tensor(5.1118, device='cuda:0', grad_fn=)\n",
+ "9100 tensor(5.3567, device='cuda:0', grad_fn=)\n",
+ "9200 tensor(5.4141, device='cuda:0', grad_fn=)\n",
+ "33000\n",
+ "9300 tensor(5.3085, device='cuda:0', grad_fn=)\n",
+ "9400 tensor(5.2808, device='cuda:0', grad_fn=)\n",
+ "34000\n",
+ "9500 tensor(5.0931, device='cuda:0', grad_fn=)\n",
+ "9600 tensor(5.1090, device='cuda:0', grad_fn=)\n",
+ "9700 tensor(5.2519, device='cuda:0', grad_fn=)\n",
+ "35000\n",
+ "9800 tensor(5.3852, device='cuda:0', grad_fn=)\n",
+ "9900 tensor(5.0943, device='cuda:0', grad_fn=)\n",
+ "10000 tensor(5.4690, device='cuda:0', grad_fn=)\n",
+ "36000\n",
+ "10100 tensor(5.4348, device='cuda:0', grad_fn=)\n",
+ "10200 tensor(5.3262, device='cuda:0', grad_fn=)\n",
+ "10300 tensor(5.4878, device='cuda:0', grad_fn=)\n",
+ "37000\n",
+ "10400 tensor(5.2384, device='cuda:0', grad_fn=)\n",
+ "10500 tensor(5.2151, device='cuda:0', grad_fn=)\n",
+ "10600 tensor(4.8722, device='cuda:0', grad_fn=)\n",
+ "38000\n",
+ "10700 tensor(5.4325, device='cuda:0', grad_fn=)\n",
+ "10800 tensor(4.8699, device='cuda:0', grad_fn=)\n",
+ "39000\n",
+ "10900 tensor(5.3448, device='cuda:0', grad_fn=)\n",
+ "11000 tensor(5.1358, device='cuda:0', grad_fn=)\n",
+ "11100 tensor(5.0432, device='cuda:0', grad_fn=)\n",
+ "40000\n",
+ "11200 tensor(5.4062, device='cuda:0', grad_fn=)\n",
+ "11300 tensor(5.4040, device='cuda:0', grad_fn=)\n",
+ "11400 tensor(5.5312, device='cuda:0', grad_fn=)\n",
+ "41000\n",
+ "11500 tensor(5.4374, device='cuda:0', grad_fn=)\n",
+ "11600 tensor(5.0998, device='cuda:0', grad_fn=)\n",
+ "11700 tensor(5.4217, device='cuda:0', grad_fn=)\n",
+ "42000\n",
+ "11800 tensor(5.5747, device='cuda:0', grad_fn=)\n",
+ "11900 tensor(5.0467, device='cuda:0', grad_fn=)\n",
+ "12000 tensor(5.4270, device='cuda:0', grad_fn=)\n",
+ "43000\n",
+ "12100 tensor(5.2043, device='cuda:0', grad_fn=)\n",
+ "12200 tensor(5.2369, device='cuda:0', grad_fn=)\n",
+ "44000\n",
+ "12300 tensor(5.4465, device='cuda:0', grad_fn=)\n",
+ "12400 tensor(4.9839, device='cuda:0', grad_fn=)\n",
+ "12500 tensor(5.3214, device='cuda:0', grad_fn=)\n",
+ "45000\n",
+ "12600 tensor(5.1928, device='cuda:0', grad_fn=)\n",
+ "12700 tensor(4.9646, device='cuda:0', grad_fn=)\n",
+ "12800 tensor(5.3325, device='cuda:0', grad_fn=)\n",
+ "46000\n",
+ "12900 tensor(5.4429, device='cuda:0', grad_fn=)\n",
+ "13000 tensor(5.0652, device='cuda:0', grad_fn=)\n",
+ "13100 tensor(5.3126, device='cuda:0', grad_fn=)\n",
+ "47000\n",
+ "13200 tensor(5.4124, device='cuda:0', grad_fn=)\n",
+ "13300 tensor(5.5385, device='cuda:0', grad_fn=)\n",
+ "13400 tensor(5.0986, device='cuda:0', grad_fn=)\n",
+ "48000\n",
+ "13500 tensor(5.2693, device='cuda:0', grad_fn=)\n",
+ "13600 tensor(5.2136, device='cuda:0', grad_fn=)\n",
+ "49000\n",
+ "13700 tensor(5.5169, device='cuda:0', grad_fn=)\n",
+ "13800 tensor(5.1840, device='cuda:0', grad_fn=)\n",
+ "13900 tensor(5.2700, device='cuda:0', grad_fn=)\n",
+ "50000\n",
+ "14000 tensor(5.2077, device='cuda:0', grad_fn=)\n",
+ "14100 tensor(5.3791, device='cuda:0', grad_fn=)\n",
+ "14200 tensor(5.4008, device='cuda:0', grad_fn=)\n",
+ "51000\n",
+ "14300 tensor(5.3506, device='cuda:0', grad_fn=)\n",
+ "14400 tensor(4.7662, device='cuda:0', grad_fn=)\n",
+ "14500 tensor(4.9474, device='cuda:0', grad_fn=)\n",
+ "52000\n",
+ "14600 tensor(5.0245, device='cuda:0', grad_fn=)\n",
+ "14700 tensor(5.3977, device='cuda:0', grad_fn=)\n",
+ "14800 tensor(4.9653, device='cuda:0', grad_fn=)\n",
+ "53000\n",
+ "14900 tensor(4.8947, device='cuda:0', grad_fn=)\n",
+ "15000 tensor(5.3548, device='cuda:0', grad_fn=)\n",
+ "54000\n",
+ "15100 tensor(4.7244, device='cuda:0', grad_fn=)\n",
+ "15200 tensor(4.9752, device='cuda:0', grad_fn=)\n",
+ "15300 tensor(5.3929, device='cuda:0', grad_fn=)\n",
+ "55000\n",
+ "15400 tensor(5.3096, device='cuda:0', grad_fn=)\n",
+ "15500 tensor(5.1247, device='cuda:0', grad_fn=)\n",
+ "15600 tensor(5.2753, device='cuda:0', grad_fn=)\n",
+ "56000\n",
+ "15700 tensor(5.2373, device='cuda:0', grad_fn=)\n",
+ "15800 tensor(4.9997, device='cuda:0', grad_fn=)\n",
+ "15900 tensor(5.1718, device='cuda:0', grad_fn=)\n",
+ "57000\n",
+ "16000 tensor(5.5952, device='cuda:0', grad_fn=)\n",
+ "16100 tensor(5.3699, device='cuda:0', grad_fn=)\n",
+ "16200 tensor(5.0923, device='cuda:0', grad_fn=)\n",
+ "58000\n",
+ "16300 tensor(4.9985, device='cuda:0', grad_fn=)\n",
+ "16400 tensor(5.3076, device='cuda:0', grad_fn=)\n",
+ "59000\n",
+ "16500 tensor(5.1994, device='cuda:0', grad_fn=)\n",
+ "16600 tensor(5.3672, device='cuda:0', grad_fn=)\n",
+ "16700 tensor(5.2054, device='cuda:0', grad_fn=)\n",
+ "60000\n",
+ "16800 tensor(5.3379, device='cuda:0', grad_fn=)\n",
+ "16900 tensor(5.2785, device='cuda:0', grad_fn=)\n",
+ "17000 tensor(5.2590, device='cuda:0', grad_fn=)\n",
+ "61000\n",
+ "17100 tensor(5.3564, device='cuda:0', grad_fn=)\n",
+ "17200 tensor(5.3598, device='cuda:0', grad_fn=)\n",
+ "17300 tensor(4.7786, device='cuda:0', grad_fn=)\n",
+ "62000\n",
+ "17400 tensor(5.2639, device='cuda:0', grad_fn=)\n",
+ "17500 tensor(5.2037, device='cuda:0', grad_fn=)\n",
+ "17600 tensor(5.1158, device='cuda:0', grad_fn=)\n",
+ "63000\n",
+ "17700 tensor(4.9831, device='cuda:0', grad_fn=)\n",
+ "17800 tensor(4.8950, device='cuda:0', grad_fn=)\n",
+ "64000\n",
+ "17900 tensor(5.0928, device='cuda:0', grad_fn=)\n",
+ "18000 tensor(5.3423, device='cuda:0', grad_fn=)\n",
+ "18100 tensor(5.1760, device='cuda:0', grad_fn=)\n",
+ "65000\n",
+ "18200 tensor(5.2021, device='cuda:0', grad_fn=)\n",
+ "18300 tensor(5.1306, device='cuda:0', grad_fn=)\n",
+ "18400 tensor(5.1199, device='cuda:0', grad_fn=)\n",
+ "66000\n",
+ "18500 tensor(5.2082, device='cuda:0', grad_fn=)\n",
+ "18600 tensor(5.3290, device='cuda:0', grad_fn=)\n",
+ "18700 tensor(5.2257, device='cuda:0', grad_fn=)\n",
+ "67000\n",
+ "18800 tensor(4.9107, device='cuda:0', grad_fn=)\n",
+ "18900 tensor(5.3400, device='cuda:0', grad_fn=)\n",
+ "68000\n",
+ "19000 tensor(5.1366, device='cuda:0', grad_fn=)\n",
+ "19100 tensor(5.1199, device='cuda:0', grad_fn=)\n",
+ "19200 tensor(5.2202, device='cuda:0', grad_fn=)\n",
+ "69000\n",
+ "19300 tensor(5.2236, device='cuda:0', grad_fn=)\n",
+ "19400 tensor(5.2953, device='cuda:0', grad_fn=)\n",
+ "19500 tensor(5.1308, device='cuda:0', grad_fn=)\n",
+ "70000\n",
+ "19600 tensor(5.3578, device='cuda:0', grad_fn=)\n",
+ "19700 tensor(5.1600, device='cuda:0', grad_fn=)\n",
+ "19800 tensor(4.6220, device='cuda:0', grad_fn=)\n",
+ "71000\n",
+ "19900 tensor(5.3731, device='cuda:0', grad_fn=)\n",
+ "20000 tensor(4.9936, device='cuda:0', grad_fn=)\n",
+ "20100 tensor(5.0817, device='cuda:0', grad_fn=)\n",
+ "72000\n",
+ "20200 tensor(5.1613, device='cuda:0', grad_fn=)\n",
+ "20300 tensor(5.3877, device='cuda:0', grad_fn=)\n",
+ "73000\n",
+ "20400 tensor(5.4114, device='cuda:0', grad_fn=)\n",
+ "20500 tensor(5.2609, device='cuda:0', grad_fn=)\n",
+ "20600 tensor(5.1378, device='cuda:0', grad_fn=)\n",
+ "74000\n",
+ "20700 tensor(5.0799, device='cuda:0', grad_fn=)\n",
+ "20800 tensor(5.3615, device='cuda:0', grad_fn=)\n",
+ "20900 tensor(5.3365, device='cuda:0', grad_fn=)\n",
+ "75000\n",
+ "21000 tensor(4.9244, device='cuda:0', grad_fn=)\n",
+ "21100 tensor(5.5084, device='cuda:0', grad_fn=)\n",
+ "21200 tensor(4.8769, device='cuda:0', grad_fn=)\n",
+ "76000\n",
+ "21300 tensor(5.3414, device='cuda:0', grad_fn=)\n",
+ "21400 tensor(5.0648, device='cuda:0', grad_fn=)\n",
+ "21500 tensor(5.0594, device='cuda:0', grad_fn=)\n",
+ "77000\n",
+ "21600 tensor(5.2537, device='cuda:0', grad_fn=)\n",
+ "21700 tensor(5.1834, device='cuda:0', grad_fn=)\n",
+ "21800 tensor(4.8151, device='cuda:0', grad_fn=)\n",
+ "78000\n",
+ "21900 tensor(5.3335, device='cuda:0', grad_fn=)\n",
+ "22000 tensor(4.9580, device='cuda:0', grad_fn=)\n",
+ "79000\n",
+ "22100 tensor(5.2262, device='cuda:0', grad_fn=)\n",
+ "22200 tensor(5.1946, device='cuda:0', grad_fn=)\n",
+ "22300 tensor(5.2404, device='cuda:0', grad_fn=)\n",
+ "80000\n",
+ "22400 tensor(4.9491, device='cuda:0', grad_fn=)\n",
+ "22500 tensor(4.6901, device='cuda:0', grad_fn=)\n",
+ "22600 tensor(5.1937, device='cuda:0', grad_fn=)\n",
+ "81000\n",
+ "22700 tensor(4.9937, device='cuda:0', grad_fn=)\n",
+ "22800 tensor(5.1401, device='cuda:0', grad_fn=)\n",
+ "22900 tensor(5.0599, device='cuda:0', grad_fn=)\n",
+ "82000\n",
+ "23000 tensor(5.4315, device='cuda:0', grad_fn=)\n",
+ "23100 tensor(5.1854, device='cuda:0', grad_fn=)\n",
+ "83000\n",
+ "23200 tensor(5.1033, device='cuda:0', grad_fn=)\n",
+ "23300 tensor(5.2352, device='cuda:0', grad_fn=)\n",
+ "23400 tensor(5.2004, device='cuda:0', grad_fn=)\n",
+ "84000\n",
+ "23500 tensor(5.0866, device='cuda:0', grad_fn=)\n",
+ "23600 tensor(5.2372, device='cuda:0', grad_fn=)\n",
+ "23700 tensor(5.4711, device='cuda:0', grad_fn=)\n",
+ "85000\n",
+ "23800 tensor(5.4030, device='cuda:0', grad_fn=)\n",
+ "23900 tensor(5.3589, device='cuda:0', grad_fn=)\n",
+ "24000 tensor(5.1646, device='cuda:0', grad_fn=)\n",
+ "86000\n",
+ "24100 tensor(5.4865, device='cuda:0', grad_fn=)\n",
+ "24200 tensor(5.3663, device='cuda:0', grad_fn=)\n",
+ "24300 tensor(5.1760, device='cuda:0', grad_fn=)\n",
+ "87000\n",
+ "24400 tensor(5.2950, device='cuda:0', grad_fn=)\n",
+ "24500 tensor(5.0376, device='cuda:0', grad_fn=)\n",
+ "88000\n",
+ "24600 tensor(5.1229, device='cuda:0', grad_fn=)\n",
+ "24700 tensor(5.3261, device='cuda:0', grad_fn=)\n",
+ "24800 tensor(5.3953, device='cuda:0', grad_fn=)\n",
+ "89000\n",
+ "24900 tensor(5.2734, device='cuda:0', grad_fn=)\n",
+ "25000 tensor(5.5544, device='cuda:0', grad_fn=