"""Bigram neural language model: data helpers, model, settings and vocabulary.

Reconstructed from the code cells of ``simple_neural_network.ipynb``
(IMPORTS, FUNCTIONS, CLASSES, PARAMETERS, VOCAB).
"""

# --- IMPORTS ---
import itertools
import lzma
import pickle
import sys

import regex as re
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset
from torchtext.vocab import build_vocab_from_iterator

print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())


# --- FUNCTIONS ---
# NOTE(review): every '' token literal in this file (sentence boundaries, the
# `specials` entry, the fallback lookup in `predict`) is reproduced exactly as
# it appears in the source.  They look like angle-bracket markers that were
# stripped during extraction -- confirm against the original notebook.

def get_words_from_line(line):
    """Yield the space-separated tokens of *line* between two boundary tokens.

    Trailing whitespace is stripped first; an empty-string token is emitted
    before the first word and again after the last one.
    """
    yield ''
    yield from line.rstrip().split(' ')
    yield ''


def get_word_lines_from_file(file_name):
    """Yield one token generator per line of an xz-compressed UTF-8 file.

    The running line count is printed to stderr every 1000 lines as a
    progress indicator.
    """
    with lzma.open(file_name, 'r') as fh:
        for n, line in enumerate(fh, start=1):
            if n % 1000 == 0:
                print(n, file=sys.stderr)
            yield get_words_from_line(line.decode('utf-8'))


def look_ahead_iterator(gen):
    """Yield consecutive ``(previous, current)`` pairs from *gen*.

    A ``None`` item is treated as "no previous item yet", so it never appears
    as the first element of a pair.
    """
    prev = None
    for item in gen:
        if prev is not None:
            yield (prev, item)
        prev = item


# Applied in order by `clean`; the order matters because the bare '-' rule
# would otherwise consume the hyphen of the hyphen + escaped-newline rule.
_CLEAN_REPLACEMENTS = (
    ('-\\n', ''),
    ('\\n', ' '),
    ('-', ''),
    ('\'s', ' is'),
    ('\'re', ' are'),
    ('\'m', ' am'),
    ('\'ve', ' have'),
    ('\'ll', ' will'),
)
_PUNCTUATION = re.compile(r'\p{P}')


def clean(text):
    """Return *text* lower-cased, de-hyphenated and without punctuation.

    Steps, in order: lower-case; drop a hyphen followed by an escaped newline
    (backslash + 'n'); turn remaining escaped newlines into spaces; drop
    hyphens; expand the contractions 's, 're, 'm, 've, 'll; finally remove
    every Unicode punctuation character.
    """
    text = str(text).lower()
    for old, new in _CLEAN_REPLACEMENTS:
        text = text.replace(old, new)
    return _PUNCTUATION.sub('', text)


def predict(word, model, vocab, top_k=300):
    """Return the most probable successors of *word* as ``token:prob`` pairs.

    Args:
        word: Context token to look up in *vocab*.
        model: Module mapping a tensor of token indices to one score row per
            index.
        vocab: torchtext vocabulary providing ``forward`` and
            ``lookup_tokens``.
        top_k: Maximum number of candidates to report (capped at the width of
            the model output).

    Returns:
        A space-separated string of ``token:probability`` items.  The item
        whose token is '' is moved to the end; if no such item is among the
        candidates, the label of the last candidate is replaced by ''.
    """
    try:
        ixs = torch.tensor(vocab.forward([word])).to(device)
    except Exception:
        # A vocabulary without a default index rejects unknown tokens; fall
        # back to the special '' entry instead of failing the whole run.
        ixs = torch.tensor(vocab.forward([''])).to(device)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        out = model(ixs)

    top = torch.topk(out[0], min(top_k, out.shape[-1]))
    top_words = vocab.lookup_tokens(top.indices.tolist())
    prob_list = list(zip(top_words, top.values.tolist()))
    if not prob_list:
        return ''

    for index, (token, prob) in enumerate(prob_list):
        if token == '':
            del prob_list[index]
            prob_list.append(('', prob))
            break
    else:
        prob_list[-1] = ('', prob_list[-1][1])
    return ' '.join(f'{token}:{prob}' for token, prob in prob_list)


def predicition_for_file(model, vocab, folder, file):
    """Write one prediction line to ``{folder}/out.tsv`` per input row.

    Reads the xz-compressed, tab-separated ``{folder}/{file}``, takes the last
    cleaned token of the seventh column (presumably the left context -- verify
    against the dataset description) and stores the output of ``predict`` for
    it.  Rows whose context is missing or empty are predicted from '' so the
    output keeps exactly one line per input line.
    """
    print('=' * 10, f' do prediction for {folder}/{file} ', '=' * 10)
    with lzma.open(f'{folder}/{file}', mode='rt', encoding='utf-8') as f:
        with open(f'{folder}/out.tsv', 'w', encoding='utf-8') as fid:
            for line in f:
                separated = line.split('\t')
                tokens = clean(separated[6]).split() if len(separated) > 6 else []
                before = tokens[-1] if tokens else ''
                fid.write(predict(before, model, vocab) + '\n')


# --- CLASSES ---

class Bigrams(IterableDataset):
    """Iterable dataset of ``(previous, next)`` vocabulary-index pairs.

    The vocabulary is built from *text_file* with
    ``max_tokens=vocabulary_size``; tokens outside it map to the index of the
    special '' entry.  Every iteration re-reads *text_file* from the start.
    """

    def __init__(self, text_file, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=[''])
        self.vocab.set_default_index(self.vocab[''])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        tokens = itertools.chain.from_iterable(
            get_word_lines_from_file(self.text_file))
        return look_ahead_iterator(self.vocab[t] for t in tokens)


class SimpleBigramNeuralLanguageModel(nn.Module):
    """Embedding -> linear -> softmax model over the next token."""

    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Embedding(vocabulary_size, embedding_size),
            nn.Linear(embedding_size, vocabulary_size),
            # Normalise over the vocabulary axis explicitly; the implicit
            # dimension choice is deprecated and warns at run time.
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return next-token probabilities for the token indices in *x*."""
        return self.model(x)


# --- PARAMETERS ---
vocab_size = 30000
embed_size = 1000
batch_size = 5000
# Use Apple's Metal backend when present, otherwise fall back to the CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
path_to_training_file = './train/in.tsv.xz'
path_to_model_file = 'model_neural_network.bin'
folder_dev_0, file_dev_0 = 'dev-0', 'in.tsv.xz'
folder_test_a, file_test_a = 'test-A', 'in.tsv.xz'
path_to_vocabulary_file = 'vocabulary_neural_network.pickle'


# --- VOCAB ---
vocab = build_vocab_from_iterator(
    get_word_lines_from_file(path_to_training_file),
    max_tokens=vocab_size,
    specials=[''])
# Same default as the vocabulary inside `Bigrams`, so unknown tokens resolve
# to the special entry instead of raising.
vocab.set_default_index(vocab[''])

with open(path_to_vocabulary_file, 'wb') as handle:
    pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
"124000\n", + "125000\n", + "126000\n", + "127000\n", + "128000\n", + "129000\n", + "130000\n", + "131000\n", + "132000\n", + "133000\n", + "134000\n", + "135000\n", + "136000\n", + "137000\n", + "138000\n", + "139000\n", + "140000\n", + "141000\n", + "142000\n", + "143000\n", + "144000\n", + "145000\n", + "146000\n", + "147000\n", + "148000\n", + "149000\n", + "150000\n", + "151000\n", + "152000\n", + "153000\n", + "154000\n", + "155000\n", + "156000\n", + "157000\n", + "158000\n", + "159000\n", + "160000\n", + "161000\n", + "162000\n", + "163000\n", + "164000\n", + "165000\n", + "166000\n", + "167000\n", + "168000\n", + "169000\n", + "170000\n", + "171000\n", + "172000\n", + "173000\n", + "174000\n", + "175000\n", + "176000\n", + "177000\n", + "178000\n", + "179000\n", + "180000\n", + "181000\n", + "182000\n", + "183000\n", + "184000\n", + "185000\n", + "186000\n", + "187000\n", + "188000\n", + "189000\n", + "190000\n", + "191000\n", + "192000\n", + "193000\n", + "194000\n", + "195000\n", + "196000\n", + "197000\n", + "198000\n", + "199000\n", + "200000\n", + "201000\n", + "202000\n", + "203000\n", + "204000\n", + "205000\n", + "206000\n", + "207000\n", + "208000\n", + "209000\n", + "210000\n", + "211000\n", + "212000\n", + "213000\n", + "214000\n", + "215000\n", + "216000\n", + "217000\n", + "218000\n", + "219000\n", + "220000\n", + "221000\n", + "222000\n", + "223000\n", + "224000\n", + "225000\n", + "226000\n", + "227000\n", + "228000\n", + "229000\n", + "230000\n", + "231000\n", + "232000\n", + "233000\n", + "234000\n", + "235000\n", + "236000\n", + "237000\n", + "238000\n", + "239000\n", + "240000\n", + "241000\n", + "242000\n", + "243000\n", + "244000\n", + "245000\n", + "246000\n", + "247000\n", + "248000\n", + "249000\n", + "250000\n", + "251000\n", + "252000\n", + "253000\n", + "254000\n", + "255000\n", + "256000\n", + "257000\n", + "258000\n", + "259000\n", + "260000\n", + "261000\n", + "262000\n", + "263000\n", + "264000\n", + "265000\n", + "266000\n", 
+ "267000\n", + "268000\n", + "269000\n", + "270000\n", + "271000\n", + "272000\n", + "273000\n", + "274000\n", + "275000\n", + "276000\n", + "277000\n", + "278000\n", + "279000\n", + "280000\n", + "281000\n", + "282000\n", + "283000\n", + "284000\n", + "285000\n", + "286000\n", + "287000\n", + "288000\n", + "289000\n", + "290000\n", + "291000\n", + "292000\n", + "293000\n", + "294000\n", + "295000\n", + "296000\n", + "297000\n", + "298000\n", + "299000\n", + "300000\n", + "301000\n", + "302000\n", + "303000\n", + "304000\n", + "305000\n", + "306000\n", + "307000\n", + "308000\n", + "309000\n", + "310000\n", + "311000\n", + "312000\n", + "313000\n", + "314000\n", + "315000\n", + "316000\n", + "317000\n", + "318000\n", + "319000\n", + "320000\n", + "321000\n", + "322000\n", + "323000\n", + "324000\n", + "325000\n", + "326000\n", + "327000\n", + "328000\n", + "329000\n", + "330000\n", + "331000\n", + "332000\n", + "333000\n", + "334000\n", + "335000\n", + "336000\n", + "337000\n", + "338000\n", + "339000\n", + "340000\n", + "341000\n", + "342000\n", + "343000\n", + "344000\n", + "345000\n", + "346000\n", + "347000\n", + "348000\n", + "349000\n", + "350000\n", + "351000\n", + "352000\n", + "353000\n", + "354000\n", + "355000\n", + "356000\n", + "357000\n", + "358000\n", + "359000\n", + "360000\n", + "361000\n", + "362000\n", + "363000\n", + "364000\n", + "365000\n", + "366000\n", + "367000\n", + "368000\n", + "369000\n", + "370000\n", + "371000\n", + "372000\n", + "373000\n", + "374000\n", + "375000\n", + "376000\n", + "377000\n", + "378000\n", + "379000\n", + "380000\n", + "381000\n", + "382000\n", + "383000\n", + "384000\n", + "385000\n", + "386000\n", + "387000\n", + "388000\n", + "389000\n", + "390000\n", + "391000\n", + "392000\n", + "393000\n", + "394000\n", + "395000\n", + "396000\n", + "397000\n", + "398000\n", + "399000\n", + "400000\n", + "401000\n", + "402000\n", + "403000\n", + "404000\n", + "405000\n", + "406000\n", + "407000\n", + "408000\n", + 
"409000\n", + "410000\n", + "411000\n", + "412000\n", + "413000\n", + "414000\n", + "415000\n", + "416000\n", + "417000\n", + "418000\n", + "419000\n", + "420000\n", + "421000\n", + "422000\n", + "423000\n", + "424000\n", + "425000\n", + "426000\n", + "427000\n", + "428000\n", + "429000\n", + "430000\n", + "431000\n", + "432000\n", + "/Users/maciej/miniconda3/envs/mj/lib/python3.11/site-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", + " input = module(input)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 tensor(10.5058, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100 tensor(7.3365, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2000\n", + "3000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "200 tensor(6.6523, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "4000\n", + "5000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "300 tensor(6.1860, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "6000\n", + "7000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "400 tensor(6.0387, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "8000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "500 tensor(5.8481, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "9000\n", + "10000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "600 tensor(5.6081, device='mps:0', grad_fn=)\n" + ] 
+ }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "11000\n", + "12000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "700 tensor(5.5820, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "13000\n", + "14000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "800 tensor(5.5111, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "15000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "900 tensor(5.4927, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "16000\n", + "17000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1000 tensor(5.5190, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "18000\n", + "19000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1100 tensor(5.5600, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "20000\n", + "21000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1200 tensor(5.6395, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "22000\n", + "23000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1300 tensor(5.4455, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "24000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1400 tensor(5.5564, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "25000\n", + "26000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1500 tensor(5.4919, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "27000\n", + "28000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1600 tensor(5.2355, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "29000\n", + "30000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1700 tensor(5.4107, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "31000\n", + "32000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1800 tensor(5.5119, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "33000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1900 tensor(5.3500, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "34000\n", + "35000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2000 tensor(5.3722, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "36000\n", + "37000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2100 tensor(5.2736, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "38000\n", + "39000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2200 tensor(5.3808, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "40000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2300 tensor(5.5186, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "41000\n", + "42000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2400 tensor(5.2746, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ 
+ "43000\n", + "44000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2500 tensor(5.3340, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "45000\n", + "46000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2600 tensor(5.4654, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "47000\n", + "48000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2700 tensor(5.4318, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "49000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2800 tensor(5.3528, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "50000\n", + "51000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2900 tensor(5.1630, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "52000\n", + "53000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3000 tensor(5.4531, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "54000\n", + "55000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3100 tensor(5.4153, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "56000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3200 tensor(5.3299, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "57000\n", + "58000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3300 tensor(5.3637, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "59000\n", + "60000\n" + ] + }, + 
{ + "name": "stdout", + "output_type": "stream", + "text": [ + "3400 tensor(5.3405, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "61000\n", + "62000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3500 tensor(5.3668, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "63000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3600 tensor(5.4104, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "64000\n", + "65000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3700 tensor(5.2142, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "66000\n", + "67000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3800 tensor(5.5528, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "68000\n", + "69000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3900 tensor(5.1879, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "70000\n", + "71000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4000 tensor(5.2014, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "72000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4100 tensor(5.4020, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "73000\n", + "74000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4200 tensor(5.2686, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "75000\n", + "76000\n" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "4300 tensor(5.3070, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "77000\n", + "78000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4400 tensor(5.1891, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "79000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4500 tensor(5.3085, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "80000\n", + "81000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4600 tensor(5.3568, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "82000\n", + "83000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4700 tensor(5.2280, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "84000\n", + "85000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4800 tensor(5.2878, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "86000\n", + "87000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4900 tensor(5.1588, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "88000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5000 tensor(5.1523, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "89000\n", + "90000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5100 tensor(5.2101, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "91000\n", + "92000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ 
+ "5200 tensor(5.2949, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "93000\n", + "94000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5300 tensor(5.3186, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "95000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5400 tensor(5.2580, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "96000\n", + "97000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5500 tensor(5.3632, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "98000\n", + "99000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5600 tensor(5.3885, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100000\n", + "101000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5700 tensor(5.2640, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "102000\n", + "103000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5800 tensor(5.4444, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "104000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5900 tensor(5.1981, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "105000\n", + "106000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6000 tensor(5.2765, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "107000\n", + "108000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6100 tensor(5.3015, 
device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "109000\n", + "110000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6200 tensor(5.1958, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "111000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6300 tensor(5.1862, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "112000\n", + "113000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6400 tensor(5.4609, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "114000\n", + "115000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6500 tensor(5.2700, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "116000\n", + "117000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6600 tensor(5.3814, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "118000\n", + "119000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6700 tensor(5.2443, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "120000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6800 tensor(5.2292, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "121000\n", + "122000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6900 tensor(5.2252, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "123000\n", + "124000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7000 tensor(5.3240, device='mps:0', 
grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "125000\n", + "126000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7100 tensor(5.3584, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "127000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7200 tensor(5.2038, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "128000\n", + "129000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7300 tensor(5.3306, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "130000\n", + "131000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7400 tensor(5.3824, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "132000\n", + "133000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7500 tensor(5.1708, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "134000\n", + "135000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7600 tensor(5.3388, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "136000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7700 tensor(5.2014, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "137000\n", + "138000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7800 tensor(5.3407, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "139000\n", + "140000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7900 tensor(5.3078, device='mps:0', grad_fn=)\n" + ] + 
}, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "141000\n", + "142000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8000 tensor(5.0961, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "143000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8100 tensor(5.1313, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "144000\n", + "145000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8200 tensor(5.2008, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "146000\n", + "147000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8300 tensor(5.1277, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "148000\n", + "149000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8400 tensor(5.3875, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "150000\n", + "151000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8500 tensor(5.3107, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "152000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8600 tensor(5.3640, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "153000\n", + "154000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8700 tensor(5.1869, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "155000\n", + "156000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8800 tensor(5.0180, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "157000\n", + "158000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8900 tensor(5.1767, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "159000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9000 tensor(5.3253, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "160000\n", + "161000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9100 tensor(5.1971, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "162000\n", + "163000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9200 tensor(5.2071, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "164000\n", + "165000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9300 tensor(5.1244, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "166000\n", + "167000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9400 tensor(5.2198, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "168000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9500 tensor(5.3042, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "169000\n", + "170000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9600 tensor(5.3171, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "171000\n", + "172000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9700 tensor(5.1956, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "173000\n", + "174000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9800 tensor(5.1559, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "175000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9900 tensor(5.1519, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "176000\n", + "177000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10000 tensor(5.3396, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "178000\n", + "179000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10100 tensor(5.2106, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "180000\n", + "181000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10200 tensor(5.3356, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "182000\n", + "183000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10300 tensor(5.2105, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "184000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10400 tensor(5.0844, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "185000\n", + "186000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10500 tensor(5.3788, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "187000\n", + "188000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10600 tensor(5.1145, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "189000\n", + "190000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10700 tensor(5.2610, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "191000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10800 tensor(5.2560, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "192000\n", + "193000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10900 tensor(5.2565, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "194000\n", + "195000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11000 tensor(5.2770, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "196000\n", + "197000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11100 tensor(5.1193, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "198000\n", + "199000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11200 tensor(5.1823, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "200000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11300 tensor(5.3099, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "201000\n", + "202000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11400 tensor(5.2330, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "203000\n", + "204000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11500 tensor(5.1722, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "205000\n", + "206000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11600 tensor(5.2136, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "207000\n", + "208000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11700 tensor(5.3126, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "209000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11800 tensor(5.1057, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "210000\n", + "211000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11900 tensor(5.2419, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "212000\n", + "213000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12000 tensor(5.2434, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "214000\n", + "215000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12100 tensor(5.1692, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "216000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12200 tensor(5.2075, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "217000\n", + "218000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12300 tensor(5.1290, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "219000\n", + "220000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12400 tensor(5.2380, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "221000\n", + "222000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12500 tensor(5.2779, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "223000\n", + "224000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12600 tensor(5.3369, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "225000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12700 tensor(5.2351, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "226000\n", + "227000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12800 tensor(5.2434, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "228000\n", + "229000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12900 tensor(5.1963, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "230000\n", + "231000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13000 tensor(5.1363, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "232000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13100 tensor(5.1915, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "233000\n", + "234000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13200 tensor(5.1264, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "235000\n", + "236000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13300 tensor(5.1468, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "237000\n", + "238000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13400 tensor(5.3026, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "239000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13500 tensor(5.2925, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "240000\n", + "241000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13600 tensor(5.1511, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "242000\n", + "243000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13700 tensor(5.4282, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "244000\n", + "245000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13800 tensor(5.2730, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "246000\n", + "247000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13900 tensor(5.2097, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "248000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14000 tensor(5.2728, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "249000\n", + "250000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14100 tensor(5.2134, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "251000\n", + "252000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14200 tensor(5.1931, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "253000\n", + "254000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14300 tensor(5.2459, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "255000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14400 tensor(5.1297, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "256000\n", + "257000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14500 tensor(5.0971, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "258000\n", + "259000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14600 tensor(5.2238, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "260000\n", + "261000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14700 tensor(5.2328, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "262000\n", + "263000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14800 tensor(5.1782, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "264000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14900 tensor(5.3230, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "265000\n", + "266000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15000 tensor(5.1504, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "267000\n", + "268000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15100 tensor(5.1998, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "269000\n", + "270000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15200 tensor(5.2138, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "271000\n", + "272000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15300 tensor(5.4110, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "273000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15400 tensor(5.1748, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "274000\n", + "275000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15500 tensor(5.2118, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "276000\n", + "277000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15600 tensor(5.2297, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "278000\n", + "279000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15700 tensor(5.2977, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "280000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15800 tensor(5.2175, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "281000\n", + "282000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15900 tensor(5.0613, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "283000\n", + "284000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16000 tensor(5.0862, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "285000\n", + "286000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16100 tensor(5.1910, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "287000\n", + "288000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16200 tensor(5.0195, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "289000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16300 tensor(5.1381, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "290000\n", + "291000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16400 tensor(5.2135, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "292000\n", + "293000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16500 tensor(5.2058, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "294000\n", + "295000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16600 tensor(5.2372, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "296000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16700 tensor(5.1753, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "297000\n", + "298000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16800 tensor(5.0765, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "299000\n", + "300000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "16900 tensor(5.3361, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "301000\n", + "302000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17000 tensor(5.2745, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "303000\n", + "304000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17100 tensor(5.2249, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "305000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17200 tensor(5.1877, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "306000\n", + "307000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17300 tensor(5.0891, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "308000\n", + "309000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17400 tensor(5.4181, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "310000\n", + "311000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17500 tensor(5.1299, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "312000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17600 tensor(5.1636, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "313000\n", + "314000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17700 tensor(5.2179, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "315000\n", + "316000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17800 tensor(5.2689, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "317000\n", + "318000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "17900 tensor(5.2410, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "319000\n", + "320000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18000 tensor(5.2342, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "321000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18100 tensor(5.2234, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "322000\n", + "323000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18200 tensor(5.0779, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "324000\n", + "325000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18300 tensor(5.2378, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "326000\n", + "327000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18400 tensor(5.1710, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "328000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18500 tensor(5.1134, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "329000\n", + "330000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18600 tensor(5.2679, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "331000\n", + "332000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18700 tensor(5.2590, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "333000\n", + "334000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18800 tensor(5.1842, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "335000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18900 tensor(5.1379, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "336000\n", + "337000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19000 tensor(5.1416, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "338000\n", + "339000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19100 tensor(5.1602, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "340000\n", + "341000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19200 tensor(5.2670, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "342000\n", + "343000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19300 tensor(5.1622, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "344000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19400 tensor(5.1805, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "345000\n", + "346000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19500 tensor(5.1820, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "347000\n", + "348000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19600 tensor(5.2506, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "349000\n", + "350000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19700 tensor(5.1566, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "351000\n", + "352000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19800 tensor(5.1121, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "353000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "19900 tensor(5.1227, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "354000\n", + "355000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20000 tensor(5.2132, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "356000\n", + "357000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20100 tensor(5.2681, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "358000\n", + "359000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20200 tensor(5.2689, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "360000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20300 tensor(5.1758, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "361000\n", + "362000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20400 tensor(5.1275, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "363000\n", + "364000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20500 tensor(5.1803, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "365000\n", + "366000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20600 tensor(5.1202, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "367000\n", + "368000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20700 tensor(5.2343, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "369000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20800 tensor(5.2035, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "370000\n", + "371000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20900 tensor(5.2992, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "372000\n", + "373000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21000 tensor(5.1540, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "374000\n", + "375000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21100 tensor(5.2739, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "376000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21200 tensor(5.2949, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "377000\n", + "378000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21300 tensor(5.2138, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "379000\n", + "380000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21400 tensor(5.2773, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "381000\n", + "382000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21500 tensor(5.2345, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "383000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21600 tensor(5.2528, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "384000\n", + "385000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21700 tensor(5.1824, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "386000\n", + "387000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21800 tensor(5.1943, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "388000\n", + "389000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "21900 tensor(5.0359, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "390000\n", + "391000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22000 tensor(5.1506, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "392000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22100 tensor(5.1253, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "393000\n", + "394000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22200 tensor(5.0982, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "395000\n", + "396000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22300 tensor(5.1554, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "397000\n", + "398000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22400 tensor(5.1673, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "399000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22500 tensor(5.1957, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "400000\n", + "401000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22600 tensor(5.1328, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "402000\n", + "403000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22700 tensor(5.2231, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "404000\n", + "405000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22800 tensor(5.1370, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "406000\n", + "407000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22900 tensor(5.2334, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "408000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23000 tensor(5.1372, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "409000\n", + "410000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23100 tensor(5.1193, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "411000\n", + "412000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23200 tensor(5.2649, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "413000\n", + "414000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23300 tensor(5.1514, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "415000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23400 tensor(5.2532, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "416000\n", + "417000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23500 tensor(5.3751, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "418000\n", + "419000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23600 tensor(5.0766, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "420000\n", + "421000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23700 tensor(5.0915, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "422000\n", + "423000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23800 tensor(5.3195, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "424000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23900 tensor(5.2758, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "425000\n", + "426000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "24000 tensor(5.0487, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "427000\n", + "428000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "24100 tensor(5.1555, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "429000\n", + "430000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "24200 tensor(5.2140, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "431000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "24300 tensor(5.2729, device='mps:0', grad_fn=)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "432000\n" + ] + } + ], + "source": [ + "train_dataset = Bigrams(path_to_training_file, vocab_size)\n", + "model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n", + "data = DataLoader(train_dataset, batch_size=batch_size)\n", + "optimizer = torch.optim.Adam(model.parameters())\n", + "criterion = torch.nn.NLLLoss()\n", + "\n", + "model.train()\n", + "step = 0\n", + "for x, y in data:\n", + " x = x.to(device)\n", + " y = y.to(device)\n", + " optimizer.zero_grad()\n", + " ypredicted = model(x)\n", + " loss = criterion(torch.log(ypredicted), y)\n", + " if step % 100 == 0:\n", + " print(step, loss)\n", + " step += 1\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "torch.save(model.state_dict(), path_to_model_file)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## LOAD MODEL AND VOCAB" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "SimpleBigramNeuralLanguageModel(\n (model): Sequential(\n (0): Embedding(30000, 1000)\n (1): Linear(in_features=1000, out_features=30000, bias=True)\n (2): Softmax(dim=None)\n )\n)" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with open(path_to_vocabulary_file, 'rb') as handle:\n", + " vocab = pickle.load(handle)\n", + "model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n", + 
"model.load_state_dict(torch.load(path_to_model_file))\n", + "model.eval()" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## CREATE OUTPUTS FILES" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### DEV-0" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "========== do prediction for dev-0/in.tsv.xz ==========\n" + ] + } + ], + "source": [ + "predicition_for_file(model, vocab, folder_dev_0, file_dev_0)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### TEST-A" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "========== do prediction for test-A/in.tsv.xz ==========\n" + ] + } + ], + "source": [ + "predicition_for_file(model, vocab, folder_test_a, file_test_a)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}