diff --git a/simple_neural_network.ipynb b/simple_neural_network.ipynb
new file mode 100644
index 0000000..ef10ea2
--- /dev/null
+++ b/simple_neural_network.ipynb
@@ -0,0 +1,4874 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## IMPORTS"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "True\n",
+ "True\n"
+ ]
+ }
+ ],
+ "source": [
+ "import regex as re\n",
+ "import sys\n",
+ "from torchtext.vocab import build_vocab_from_iterator\n",
+ "import lzma\n",
+ "from torch.utils.data import IterableDataset\n",
+ "import itertools\n",
+ "from torch import nn\n",
+ "import torch\n",
+ "import pickle\n",
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "print(torch.backends.mps.is_available())\n",
+ "print(torch.backends.mps.is_built())"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## FUNCTIONS"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "outputs": [],
+ "source": [
+ "def get_words_from_line(line):\n",
+ " line = line.rstrip()\n",
+ " yield ''\n",
+ " for t in line.split(' '):\n",
+ " yield t\n",
+ " yield ''\n",
+ "\n",
+ "def get_word_lines_from_file(file_name):\n",
+ " n = 0\n",
+ " with lzma.open(file_name, 'r') as fh:\n",
+ " for line in fh:\n",
+ " n += 1\n",
+ " if n % 1000 == 0:\n",
+ " print(n ,file=sys.stderr)\n",
+ " yield get_words_from_line(line.decode('utf-8'))\n",
+ "\n",
+ "def look_ahead_iterator(gen):\n",
+ " prev = None\n",
+ " for item in gen:\n",
+ " if prev is not None:\n",
+ " yield (prev, item)\n",
+ " prev = item\n",
+ "\n",
+ "def clean(text):\n",
+ " text = str(text).lower().replace('-\\\\n', '').replace('\\\\n', ' ').replace('-', '').replace('\\'s', ' is').replace('\\'re', ' are').replace('\\'m', ' am').replace('\\'ve', ' have').replace('\\'ll', ' will')\n",
+ " text = re.sub(r'\\p{P}', '', text)\n",
+ " return text\n",
+ "\n",
+ "def predict(word, model, vocab):\n",
+ " try:\n",
+ " ixs = torch.tensor(vocab.forward([word])).to(device)\n",
+ " except:\n",
+ " ixs = torch.tensor(vocab.forward([''])).to(device)\n",
+ " word = ''\n",
+ " out = model(ixs)\n",
+ " top = torch.topk(out[0], 300)\n",
+ " top_indices = top.indices.tolist()\n",
+ " top_probs = top.values.tolist()\n",
+ " top_words = vocab.lookup_tokens(top_indices)\n",
+ " prob_list = list(zip(top_words, top_probs))\n",
+ " for index, element in enumerate(prob_list):\n",
+ " unk = None\n",
+ " if '' in element:\n",
+ " unk = prob_list.pop(index)\n",
+ " prob_list.append(('', unk[1]))\n",
+ " break\n",
+ " if unk is None:\n",
+ " prob_list[-1] = ('', prob_list[-1][1])\n",
+ " return ' '.join([f'{x[0]}:{x[1]}' for x in prob_list])\n",
+ "\n",
+ "def predicition_for_file(model, vocab, folder, file):\n",
+ " print('=' * 10, f' do prediction for {folder}/{file} ', '=' * 10)\n",
+ " with lzma.open(f'{folder}/in.tsv.xz', mode='rt', encoding='utf-8') as f:\n",
+ " with open(f'{folder}/out.tsv', 'w', encoding='utf-8') as fid:\n",
+ "\n",
+ " for line in f:\n",
+ " separated = line.split('\\t')\n",
+ " before = clean(separated[6]).split()[-1]\n",
+ " new_line = predict(before, model, vocab)\n",
+ " fid.write(new_line + '\\n')"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## CLASSES"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "outputs": [],
+ "source": [
+ "class Bigrams(IterableDataset):\n",
+ " def __init__(self, text_file, vocabulary_size):\n",
+ " self.vocab = build_vocab_from_iterator(\n",
+ " get_word_lines_from_file(text_file),\n",
+ " max_tokens = vocabulary_size,\n",
+ " specials = [''])\n",
+ " self.vocab.set_default_index(self.vocab[''])\n",
+ " self.vocabulary_size = vocabulary_size\n",
+ " self.text_file = text_file\n",
+ "\n",
+ " def __iter__(self):\n",
+ " return look_ahead_iterator(\n",
+ " (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
+ "\n",
+ "class SimpleBigramNeuralLanguageModel(nn.Module):\n",
+ " def __init__(self, vocabulary_size, embedding_size):\n",
+ " super(SimpleBigramNeuralLanguageModel, self).__init__()\n",
+ " self.model = nn.Sequential(\n",
+ " nn.Embedding(vocabulary_size, embedding_size),\n",
+ " nn.Linear(embedding_size, vocabulary_size),\n",
+ " nn.Softmax()\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return self.model(x)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## PARAMETERS"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "outputs": [],
+ "source": [
+ "vocab_size = 30000\n",
+ "embed_size = 1000\n",
+ "batch_size = 5000\n",
+ "device = 'mps'\n",
+ "path_to_training_file = './train/in.tsv.xz'\n",
+ "path_to_model_file = 'model_neural_network.bin'\n",
+ "folder_dev_0, file_dev_0 = 'dev-0', 'in.tsv.xz'\n",
+ "folder_test_a, file_test_a = 'test-A', 'in.tsv.xz'\n",
+ "path_to_vocabulary_file = 'vocabulary_neural_network.pickle'"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## VOCAB"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "1000\n",
+ "2000\n",
+ "3000\n",
+ "4000\n",
+ "5000\n",
+ "6000\n",
+ "7000\n",
+ "8000\n",
+ "9000\n",
+ "10000\n",
+ "11000\n",
+ "12000\n",
+ "13000\n",
+ "14000\n",
+ "15000\n",
+ "16000\n",
+ "17000\n",
+ "18000\n",
+ "19000\n",
+ "20000\n",
+ "21000\n",
+ "22000\n",
+ "23000\n",
+ "24000\n",
+ "25000\n",
+ "26000\n",
+ "27000\n",
+ "28000\n",
+ "29000\n",
+ "30000\n",
+ "31000\n",
+ "32000\n",
+ "33000\n",
+ "34000\n",
+ "35000\n",
+ "36000\n",
+ "37000\n",
+ "38000\n",
+ "39000\n",
+ "40000\n",
+ "41000\n",
+ "42000\n",
+ "43000\n",
+ "44000\n",
+ "45000\n",
+ "46000\n",
+ "47000\n",
+ "48000\n",
+ "49000\n",
+ "50000\n",
+ "51000\n",
+ "52000\n",
+ "53000\n",
+ "54000\n",
+ "55000\n",
+ "56000\n",
+ "57000\n",
+ "58000\n",
+ "59000\n",
+ "60000\n",
+ "61000\n",
+ "62000\n",
+ "63000\n",
+ "64000\n",
+ "65000\n",
+ "66000\n",
+ "67000\n",
+ "68000\n",
+ "69000\n",
+ "70000\n",
+ "71000\n",
+ "72000\n",
+ "73000\n",
+ "74000\n",
+ "75000\n",
+ "76000\n",
+ "77000\n",
+ "78000\n",
+ "79000\n",
+ "80000\n",
+ "81000\n",
+ "82000\n",
+ "83000\n",
+ "84000\n",
+ "85000\n",
+ "86000\n",
+ "87000\n",
+ "88000\n",
+ "89000\n",
+ "90000\n",
+ "91000\n",
+ "92000\n",
+ "93000\n",
+ "94000\n",
+ "95000\n",
+ "96000\n",
+ "97000\n",
+ "98000\n",
+ "99000\n",
+ "100000\n",
+ "101000\n",
+ "102000\n",
+ "103000\n",
+ "104000\n",
+ "105000\n",
+ "106000\n",
+ "107000\n",
+ "108000\n",
+ "109000\n",
+ "110000\n",
+ "111000\n",
+ "112000\n",
+ "113000\n",
+ "114000\n",
+ "115000\n",
+ "116000\n",
+ "117000\n",
+ "118000\n",
+ "119000\n",
+ "120000\n",
+ "121000\n",
+ "122000\n",
+ "123000\n",
+ "124000\n",
+ "125000\n",
+ "126000\n",
+ "127000\n",
+ "128000\n",
+ "129000\n",
+ "130000\n",
+ "131000\n",
+ "132000\n",
+ "133000\n",
+ "134000\n",
+ "135000\n",
+ "136000\n",
+ "137000\n",
+ "138000\n",
+ "139000\n",
+ "140000\n",
+ "141000\n",
+ "142000\n",
+ "143000\n",
+ "144000\n",
+ "145000\n",
+ "146000\n",
+ "147000\n",
+ "148000\n",
+ "149000\n",
+ "150000\n",
+ "151000\n",
+ "152000\n",
+ "153000\n",
+ "154000\n",
+ "155000\n",
+ "156000\n",
+ "157000\n",
+ "158000\n",
+ "159000\n",
+ "160000\n",
+ "161000\n",
+ "162000\n",
+ "163000\n",
+ "164000\n",
+ "165000\n",
+ "166000\n",
+ "167000\n",
+ "168000\n",
+ "169000\n",
+ "170000\n",
+ "171000\n",
+ "172000\n",
+ "173000\n",
+ "174000\n",
+ "175000\n",
+ "176000\n",
+ "177000\n",
+ "178000\n",
+ "179000\n",
+ "180000\n",
+ "181000\n",
+ "182000\n",
+ "183000\n",
+ "184000\n",
+ "185000\n",
+ "186000\n",
+ "187000\n",
+ "188000\n",
+ "189000\n",
+ "190000\n",
+ "191000\n",
+ "192000\n",
+ "193000\n",
+ "194000\n",
+ "195000\n",
+ "196000\n",
+ "197000\n",
+ "198000\n",
+ "199000\n",
+ "200000\n",
+ "201000\n",
+ "202000\n",
+ "203000\n",
+ "204000\n",
+ "205000\n",
+ "206000\n",
+ "207000\n",
+ "208000\n",
+ "209000\n",
+ "210000\n",
+ "211000\n",
+ "212000\n",
+ "213000\n",
+ "214000\n",
+ "215000\n",
+ "216000\n",
+ "217000\n",
+ "218000\n",
+ "219000\n",
+ "220000\n",
+ "221000\n",
+ "222000\n",
+ "223000\n",
+ "224000\n",
+ "225000\n",
+ "226000\n",
+ "227000\n",
+ "228000\n",
+ "229000\n",
+ "230000\n",
+ "231000\n",
+ "232000\n",
+ "233000\n",
+ "234000\n",
+ "235000\n",
+ "236000\n",
+ "237000\n",
+ "238000\n",
+ "239000\n",
+ "240000\n",
+ "241000\n",
+ "242000\n",
+ "243000\n",
+ "244000\n",
+ "245000\n",
+ "246000\n",
+ "247000\n",
+ "248000\n",
+ "249000\n",
+ "250000\n",
+ "251000\n",
+ "252000\n",
+ "253000\n",
+ "254000\n",
+ "255000\n",
+ "256000\n",
+ "257000\n",
+ "258000\n",
+ "259000\n",
+ "260000\n",
+ "261000\n",
+ "262000\n",
+ "263000\n",
+ "264000\n",
+ "265000\n",
+ "266000\n",
+ "267000\n",
+ "268000\n",
+ "269000\n",
+ "270000\n",
+ "271000\n",
+ "272000\n",
+ "273000\n",
+ "274000\n",
+ "275000\n",
+ "276000\n",
+ "277000\n",
+ "278000\n",
+ "279000\n",
+ "280000\n",
+ "281000\n",
+ "282000\n",
+ "283000\n",
+ "284000\n",
+ "285000\n",
+ "286000\n",
+ "287000\n",
+ "288000\n",
+ "289000\n",
+ "290000\n",
+ "291000\n",
+ "292000\n",
+ "293000\n",
+ "294000\n",
+ "295000\n",
+ "296000\n",
+ "297000\n",
+ "298000\n",
+ "299000\n",
+ "300000\n",
+ "301000\n",
+ "302000\n",
+ "303000\n",
+ "304000\n",
+ "305000\n",
+ "306000\n",
+ "307000\n",
+ "308000\n",
+ "309000\n",
+ "310000\n",
+ "311000\n",
+ "312000\n",
+ "313000\n",
+ "314000\n",
+ "315000\n",
+ "316000\n",
+ "317000\n",
+ "318000\n",
+ "319000\n",
+ "320000\n",
+ "321000\n",
+ "322000\n",
+ "323000\n",
+ "324000\n",
+ "325000\n",
+ "326000\n",
+ "327000\n",
+ "328000\n",
+ "329000\n",
+ "330000\n",
+ "331000\n",
+ "332000\n",
+ "333000\n",
+ "334000\n",
+ "335000\n",
+ "336000\n",
+ "337000\n",
+ "338000\n",
+ "339000\n",
+ "340000\n",
+ "341000\n",
+ "342000\n",
+ "343000\n",
+ "344000\n",
+ "345000\n",
+ "346000\n",
+ "347000\n",
+ "348000\n",
+ "349000\n",
+ "350000\n",
+ "351000\n",
+ "352000\n",
+ "353000\n",
+ "354000\n",
+ "355000\n",
+ "356000\n",
+ "357000\n",
+ "358000\n",
+ "359000\n",
+ "360000\n",
+ "361000\n",
+ "362000\n",
+ "363000\n",
+ "364000\n",
+ "365000\n",
+ "366000\n",
+ "367000\n",
+ "368000\n",
+ "369000\n",
+ "370000\n",
+ "371000\n",
+ "372000\n",
+ "373000\n",
+ "374000\n",
+ "375000\n",
+ "376000\n",
+ "377000\n",
+ "378000\n",
+ "379000\n",
+ "380000\n",
+ "381000\n",
+ "382000\n",
+ "383000\n",
+ "384000\n",
+ "385000\n",
+ "386000\n",
+ "387000\n",
+ "388000\n",
+ "389000\n",
+ "390000\n",
+ "391000\n",
+ "392000\n",
+ "393000\n",
+ "394000\n",
+ "395000\n",
+ "396000\n",
+ "397000\n",
+ "398000\n",
+ "399000\n",
+ "400000\n",
+ "401000\n",
+ "402000\n",
+ "403000\n",
+ "404000\n",
+ "405000\n",
+ "406000\n",
+ "407000\n",
+ "408000\n",
+ "409000\n",
+ "410000\n",
+ "411000\n",
+ "412000\n",
+ "413000\n",
+ "414000\n",
+ "415000\n",
+ "416000\n",
+ "417000\n",
+ "418000\n",
+ "419000\n",
+ "420000\n",
+ "421000\n",
+ "422000\n",
+ "423000\n",
+ "424000\n",
+ "425000\n",
+ "426000\n",
+ "427000\n",
+ "428000\n",
+ "429000\n",
+ "430000\n",
+ "431000\n",
+ "432000\n"
+ ]
+ }
+ ],
+ "source": [
+ "vocab = build_vocab_from_iterator(\n",
+ " get_word_lines_from_file(path_to_training_file),\n",
+ " max_tokens = vocab_size,\n",
+ " specials = [''])\n",
+ "\n",
+ "with open(path_to_vocabulary_file, 'wb') as handle:\n",
+ " pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## TRAIN MODEL"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "1000\n",
+ "2000\n",
+ "3000\n",
+ "4000\n",
+ "5000\n",
+ "6000\n",
+ "7000\n",
+ "8000\n",
+ "9000\n",
+ "10000\n",
+ "11000\n",
+ "12000\n",
+ "13000\n",
+ "14000\n",
+ "15000\n",
+ "16000\n",
+ "17000\n",
+ "18000\n",
+ "19000\n",
+ "20000\n",
+ "21000\n",
+ "22000\n",
+ "23000\n",
+ "24000\n",
+ "25000\n",
+ "26000\n",
+ "27000\n",
+ "28000\n",
+ "29000\n",
+ "30000\n",
+ "31000\n",
+ "32000\n",
+ "33000\n",
+ "34000\n",
+ "35000\n",
+ "36000\n",
+ "37000\n",
+ "38000\n",
+ "39000\n",
+ "40000\n",
+ "41000\n",
+ "42000\n",
+ "43000\n",
+ "44000\n",
+ "45000\n",
+ "46000\n",
+ "47000\n",
+ "48000\n",
+ "49000\n",
+ "50000\n",
+ "51000\n",
+ "52000\n",
+ "53000\n",
+ "54000\n",
+ "55000\n",
+ "56000\n",
+ "57000\n",
+ "58000\n",
+ "59000\n",
+ "60000\n",
+ "61000\n",
+ "62000\n",
+ "63000\n",
+ "64000\n",
+ "65000\n",
+ "66000\n",
+ "67000\n",
+ "68000\n",
+ "69000\n",
+ "70000\n",
+ "71000\n",
+ "72000\n",
+ "73000\n",
+ "74000\n",
+ "75000\n",
+ "76000\n",
+ "77000\n",
+ "78000\n",
+ "79000\n",
+ "80000\n",
+ "81000\n",
+ "82000\n",
+ "83000\n",
+ "84000\n",
+ "85000\n",
+ "86000\n",
+ "87000\n",
+ "88000\n",
+ "89000\n",
+ "90000\n",
+ "91000\n",
+ "92000\n",
+ "93000\n",
+ "94000\n",
+ "95000\n",
+ "96000\n",
+ "97000\n",
+ "98000\n",
+ "99000\n",
+ "100000\n",
+ "101000\n",
+ "102000\n",
+ "103000\n",
+ "104000\n",
+ "105000\n",
+ "106000\n",
+ "107000\n",
+ "108000\n",
+ "109000\n",
+ "110000\n",
+ "111000\n",
+ "112000\n",
+ "113000\n",
+ "114000\n",
+ "115000\n",
+ "116000\n",
+ "117000\n",
+ "118000\n",
+ "119000\n",
+ "120000\n",
+ "121000\n",
+ "122000\n",
+ "123000\n",
+ "124000\n",
+ "125000\n",
+ "126000\n",
+ "127000\n",
+ "128000\n",
+ "129000\n",
+ "130000\n",
+ "131000\n",
+ "132000\n",
+ "133000\n",
+ "134000\n",
+ "135000\n",
+ "136000\n",
+ "137000\n",
+ "138000\n",
+ "139000\n",
+ "140000\n",
+ "141000\n",
+ "142000\n",
+ "143000\n",
+ "144000\n",
+ "145000\n",
+ "146000\n",
+ "147000\n",
+ "148000\n",
+ "149000\n",
+ "150000\n",
+ "151000\n",
+ "152000\n",
+ "153000\n",
+ "154000\n",
+ "155000\n",
+ "156000\n",
+ "157000\n",
+ "158000\n",
+ "159000\n",
+ "160000\n",
+ "161000\n",
+ "162000\n",
+ "163000\n",
+ "164000\n",
+ "165000\n",
+ "166000\n",
+ "167000\n",
+ "168000\n",
+ "169000\n",
+ "170000\n",
+ "171000\n",
+ "172000\n",
+ "173000\n",
+ "174000\n",
+ "175000\n",
+ "176000\n",
+ "177000\n",
+ "178000\n",
+ "179000\n",
+ "180000\n",
+ "181000\n",
+ "182000\n",
+ "183000\n",
+ "184000\n",
+ "185000\n",
+ "186000\n",
+ "187000\n",
+ "188000\n",
+ "189000\n",
+ "190000\n",
+ "191000\n",
+ "192000\n",
+ "193000\n",
+ "194000\n",
+ "195000\n",
+ "196000\n",
+ "197000\n",
+ "198000\n",
+ "199000\n",
+ "200000\n",
+ "201000\n",
+ "202000\n",
+ "203000\n",
+ "204000\n",
+ "205000\n",
+ "206000\n",
+ "207000\n",
+ "208000\n",
+ "209000\n",
+ "210000\n",
+ "211000\n",
+ "212000\n",
+ "213000\n",
+ "214000\n",
+ "215000\n",
+ "216000\n",
+ "217000\n",
+ "218000\n",
+ "219000\n",
+ "220000\n",
+ "221000\n",
+ "222000\n",
+ "223000\n",
+ "224000\n",
+ "225000\n",
+ "226000\n",
+ "227000\n",
+ "228000\n",
+ "229000\n",
+ "230000\n",
+ "231000\n",
+ "232000\n",
+ "233000\n",
+ "234000\n",
+ "235000\n",
+ "236000\n",
+ "237000\n",
+ "238000\n",
+ "239000\n",
+ "240000\n",
+ "241000\n",
+ "242000\n",
+ "243000\n",
+ "244000\n",
+ "245000\n",
+ "246000\n",
+ "247000\n",
+ "248000\n",
+ "249000\n",
+ "250000\n",
+ "251000\n",
+ "252000\n",
+ "253000\n",
+ "254000\n",
+ "255000\n",
+ "256000\n",
+ "257000\n",
+ "258000\n",
+ "259000\n",
+ "260000\n",
+ "261000\n",
+ "262000\n",
+ "263000\n",
+ "264000\n",
+ "265000\n",
+ "266000\n",
+ "267000\n",
+ "268000\n",
+ "269000\n",
+ "270000\n",
+ "271000\n",
+ "272000\n",
+ "273000\n",
+ "274000\n",
+ "275000\n",
+ "276000\n",
+ "277000\n",
+ "278000\n",
+ "279000\n",
+ "280000\n",
+ "281000\n",
+ "282000\n",
+ "283000\n",
+ "284000\n",
+ "285000\n",
+ "286000\n",
+ "287000\n",
+ "288000\n",
+ "289000\n",
+ "290000\n",
+ "291000\n",
+ "292000\n",
+ "293000\n",
+ "294000\n",
+ "295000\n",
+ "296000\n",
+ "297000\n",
+ "298000\n",
+ "299000\n",
+ "300000\n",
+ "301000\n",
+ "302000\n",
+ "303000\n",
+ "304000\n",
+ "305000\n",
+ "306000\n",
+ "307000\n",
+ "308000\n",
+ "309000\n",
+ "310000\n",
+ "311000\n",
+ "312000\n",
+ "313000\n",
+ "314000\n",
+ "315000\n",
+ "316000\n",
+ "317000\n",
+ "318000\n",
+ "319000\n",
+ "320000\n",
+ "321000\n",
+ "322000\n",
+ "323000\n",
+ "324000\n",
+ "325000\n",
+ "326000\n",
+ "327000\n",
+ "328000\n",
+ "329000\n",
+ "330000\n",
+ "331000\n",
+ "332000\n",
+ "333000\n",
+ "334000\n",
+ "335000\n",
+ "336000\n",
+ "337000\n",
+ "338000\n",
+ "339000\n",
+ "340000\n",
+ "341000\n",
+ "342000\n",
+ "343000\n",
+ "344000\n",
+ "345000\n",
+ "346000\n",
+ "347000\n",
+ "348000\n",
+ "349000\n",
+ "350000\n",
+ "351000\n",
+ "352000\n",
+ "353000\n",
+ "354000\n",
+ "355000\n",
+ "356000\n",
+ "357000\n",
+ "358000\n",
+ "359000\n",
+ "360000\n",
+ "361000\n",
+ "362000\n",
+ "363000\n",
+ "364000\n",
+ "365000\n",
+ "366000\n",
+ "367000\n",
+ "368000\n",
+ "369000\n",
+ "370000\n",
+ "371000\n",
+ "372000\n",
+ "373000\n",
+ "374000\n",
+ "375000\n",
+ "376000\n",
+ "377000\n",
+ "378000\n",
+ "379000\n",
+ "380000\n",
+ "381000\n",
+ "382000\n",
+ "383000\n",
+ "384000\n",
+ "385000\n",
+ "386000\n",
+ "387000\n",
+ "388000\n",
+ "389000\n",
+ "390000\n",
+ "391000\n",
+ "392000\n",
+ "393000\n",
+ "394000\n",
+ "395000\n",
+ "396000\n",
+ "397000\n",
+ "398000\n",
+ "399000\n",
+ "400000\n",
+ "401000\n",
+ "402000\n",
+ "403000\n",
+ "404000\n",
+ "405000\n",
+ "406000\n",
+ "407000\n",
+ "408000\n",
+ "409000\n",
+ "410000\n",
+ "411000\n",
+ "412000\n",
+ "413000\n",
+ "414000\n",
+ "415000\n",
+ "416000\n",
+ "417000\n",
+ "418000\n",
+ "419000\n",
+ "420000\n",
+ "421000\n",
+ "422000\n",
+ "423000\n",
+ "424000\n",
+ "425000\n",
+ "426000\n",
+ "427000\n",
+ "428000\n",
+ "429000\n",
+ "430000\n",
+ "431000\n",
+ "432000\n",
+ "/Users/maciej/miniconda3/envs/mj/lib/python3.11/site-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
+ " input = module(input)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0 tensor(10.5058, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "1000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "100 tensor(7.3365, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2000\n",
+ "3000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "200 tensor(6.6523, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4000\n",
+ "5000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "300 tensor(6.1860, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "6000\n",
+ "7000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "400 tensor(6.0387, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "8000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "500 tensor(5.8481, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "9000\n",
+ "10000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "600 tensor(5.6081, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "11000\n",
+ "12000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "700 tensor(5.5820, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "13000\n",
+ "14000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "800 tensor(5.5111, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "15000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "900 tensor(5.4927, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "16000\n",
+ "17000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1000 tensor(5.5190, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "18000\n",
+ "19000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1100 tensor(5.5600, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "20000\n",
+ "21000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1200 tensor(5.6395, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "22000\n",
+ "23000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1300 tensor(5.4455, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "24000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1400 tensor(5.5564, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25000\n",
+ "26000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1500 tensor(5.4919, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "27000\n",
+ "28000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1600 tensor(5.2355, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "29000\n",
+ "30000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1700 tensor(5.4107, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "31000\n",
+ "32000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1800 tensor(5.5119, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "33000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1900 tensor(5.3500, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "34000\n",
+ "35000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2000 tensor(5.3722, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "36000\n",
+ "37000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2100 tensor(5.2736, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "38000\n",
+ "39000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2200 tensor(5.3808, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "40000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2300 tensor(5.5186, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "41000\n",
+ "42000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2400 tensor(5.2746, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "43000\n",
+ "44000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2500 tensor(5.3340, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "45000\n",
+ "46000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2600 tensor(5.4654, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "47000\n",
+ "48000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2700 tensor(5.4318, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "49000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2800 tensor(5.3528, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "50000\n",
+ "51000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2900 tensor(5.1630, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "52000\n",
+ "53000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3000 tensor(5.4531, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "54000\n",
+ "55000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3100 tensor(5.4153, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "56000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3200 tensor(5.3299, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "57000\n",
+ "58000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3300 tensor(5.3637, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "59000\n",
+ "60000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3400 tensor(5.3405, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "61000\n",
+ "62000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3500 tensor(5.3668, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "63000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3600 tensor(5.4104, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "64000\n",
+ "65000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3700 tensor(5.2142, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "66000\n",
+ "67000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3800 tensor(5.5528, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "68000\n",
+ "69000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3900 tensor(5.1879, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "70000\n",
+ "71000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4000 tensor(5.2014, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "72000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4100 tensor(5.4020, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "73000\n",
+ "74000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4200 tensor(5.2686, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "75000\n",
+ "76000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4300 tensor(5.3070, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "77000\n",
+ "78000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4400 tensor(5.1891, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "79000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4500 tensor(5.3085, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "80000\n",
+ "81000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4600 tensor(5.3568, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "82000\n",
+ "83000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4700 tensor(5.2280, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "84000\n",
+ "85000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4800 tensor(5.2878, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "86000\n",
+ "87000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "4900 tensor(5.1588, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "88000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5000 tensor(5.1523, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "89000\n",
+ "90000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5100 tensor(5.2101, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "91000\n",
+ "92000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5200 tensor(5.2949, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "93000\n",
+ "94000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5300 tensor(5.3186, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "95000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5400 tensor(5.2580, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "96000\n",
+ "97000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5500 tensor(5.3632, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "98000\n",
+ "99000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5600 tensor(5.3885, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100000\n",
+ "101000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5700 tensor(5.2640, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "102000\n",
+ "103000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5800 tensor(5.4444, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "104000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5900 tensor(5.1981, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "105000\n",
+ "106000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6000 tensor(5.2765, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "107000\n",
+ "108000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6100 tensor(5.3015, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "109000\n",
+ "110000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6200 tensor(5.1958, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "111000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6300 tensor(5.1862, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "112000\n",
+ "113000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6400 tensor(5.4609, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "114000\n",
+ "115000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6500 tensor(5.2700, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "116000\n",
+ "117000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6600 tensor(5.3814, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "118000\n",
+ "119000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6700 tensor(5.2443, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "120000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6800 tensor(5.2292, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "121000\n",
+ "122000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6900 tensor(5.2252, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "123000\n",
+ "124000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7000 tensor(5.3240, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "125000\n",
+ "126000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7100 tensor(5.3584, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "127000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7200 tensor(5.2038, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "128000\n",
+ "129000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7300 tensor(5.3306, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "130000\n",
+ "131000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7400 tensor(5.3824, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "132000\n",
+ "133000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7500 tensor(5.1708, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "134000\n",
+ "135000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7600 tensor(5.3388, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "136000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7700 tensor(5.2014, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "137000\n",
+ "138000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7800 tensor(5.3407, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "139000\n",
+ "140000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7900 tensor(5.3078, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "141000\n",
+ "142000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8000 tensor(5.0961, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "143000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8100 tensor(5.1313, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "144000\n",
+ "145000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8200 tensor(5.2008, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "146000\n",
+ "147000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8300 tensor(5.1277, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "148000\n",
+ "149000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8400 tensor(5.3875, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "150000\n",
+ "151000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8500 tensor(5.3107, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "152000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8600 tensor(5.3640, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "153000\n",
+ "154000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8700 tensor(5.1869, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "155000\n",
+ "156000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8800 tensor(5.0180, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "157000\n",
+ "158000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "8900 tensor(5.1767, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "159000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9000 tensor(5.3253, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "160000\n",
+ "161000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9100 tensor(5.1971, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "162000\n",
+ "163000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9200 tensor(5.2071, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "164000\n",
+ "165000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9300 tensor(5.1244, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "166000\n",
+ "167000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9400 tensor(5.2198, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "168000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9500 tensor(5.3042, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "169000\n",
+ "170000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9600 tensor(5.3171, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "171000\n",
+ "172000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9700 tensor(5.1956, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "173000\n",
+ "174000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9800 tensor(5.1559, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "175000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9900 tensor(5.1519, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "176000\n",
+ "177000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10000 tensor(5.3396, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "178000\n",
+ "179000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10100 tensor(5.2106, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "180000\n",
+ "181000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10200 tensor(5.3356, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "182000\n",
+ "183000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10300 tensor(5.2105, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "184000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10400 tensor(5.0844, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "185000\n",
+ "186000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10500 tensor(5.3788, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "187000\n",
+ "188000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10600 tensor(5.1145, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "189000\n",
+ "190000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10700 tensor(5.2610, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "191000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10800 tensor(5.2560, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "192000\n",
+ "193000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "10900 tensor(5.2565, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "194000\n",
+ "195000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11000 tensor(5.2770, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "196000\n",
+ "197000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11100 tensor(5.1193, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "198000\n",
+ "199000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11200 tensor(5.1823, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "200000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11300 tensor(5.3099, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "201000\n",
+ "202000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11400 tensor(5.2330, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "203000\n",
+ "204000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11500 tensor(5.1722, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "205000\n",
+ "206000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11600 tensor(5.2136, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "207000\n",
+ "208000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11700 tensor(5.3126, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "209000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11800 tensor(5.1057, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "210000\n",
+ "211000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "11900 tensor(5.2419, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "212000\n",
+ "213000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12000 tensor(5.2434, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "214000\n",
+ "215000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12100 tensor(5.1692, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "216000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12200 tensor(5.2075, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "217000\n",
+ "218000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12300 tensor(5.1290, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "219000\n",
+ "220000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12400 tensor(5.2380, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "221000\n",
+ "222000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12500 tensor(5.2779, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "223000\n",
+ "224000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12600 tensor(5.3369, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "225000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12700 tensor(5.2351, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "226000\n",
+ "227000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12800 tensor(5.2434, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "228000\n",
+ "229000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "12900 tensor(5.1963, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "230000\n",
+ "231000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13000 tensor(5.1363, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "232000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13100 tensor(5.1915, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "233000\n",
+ "234000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13200 tensor(5.1264, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "235000\n",
+ "236000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13300 tensor(5.1468, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "237000\n",
+ "238000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13400 tensor(5.3026, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "239000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13500 tensor(5.2925, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "240000\n",
+ "241000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13600 tensor(5.1511, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "242000\n",
+ "243000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13700 tensor(5.4282, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "244000\n",
+ "245000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13800 tensor(5.2730, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "246000\n",
+ "247000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "13900 tensor(5.2097, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "248000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14000 tensor(5.2728, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "249000\n",
+ "250000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14100 tensor(5.2134, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "251000\n",
+ "252000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14200 tensor(5.1931, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "253000\n",
+ "254000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14300 tensor(5.2459, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "255000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14400 tensor(5.1297, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "256000\n",
+ "257000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14500 tensor(5.0971, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "258000\n",
+ "259000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14600 tensor(5.2238, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "260000\n",
+ "261000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14700 tensor(5.2328, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "262000\n",
+ "263000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14800 tensor(5.1782, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "264000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "14900 tensor(5.3230, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "265000\n",
+ "266000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15000 tensor(5.1504, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "267000\n",
+ "268000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15100 tensor(5.1998, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "269000\n",
+ "270000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15200 tensor(5.2138, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "271000\n",
+ "272000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15300 tensor(5.4110, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "273000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15400 tensor(5.1748, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "274000\n",
+ "275000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15500 tensor(5.2118, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "276000\n",
+ "277000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15600 tensor(5.2297, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "278000\n",
+ "279000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15700 tensor(5.2977, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "280000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15800 tensor(5.2175, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "281000\n",
+ "282000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15900 tensor(5.0613, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "283000\n",
+ "284000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16000 tensor(5.0862, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "285000\n",
+ "286000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16100 tensor(5.1910, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "287000\n",
+ "288000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16200 tensor(5.0195, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "289000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16300 tensor(5.1381, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "290000\n",
+ "291000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16400 tensor(5.2135, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "292000\n",
+ "293000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16500 tensor(5.2058, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "294000\n",
+ "295000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16600 tensor(5.2372, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "296000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16700 tensor(5.1753, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "297000\n",
+ "298000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16800 tensor(5.0765, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "299000\n",
+ "300000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "16900 tensor(5.3361, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "301000\n",
+ "302000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17000 tensor(5.2745, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "303000\n",
+ "304000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17100 tensor(5.2249, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "305000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17200 tensor(5.1877, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "306000\n",
+ "307000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17300 tensor(5.0891, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "308000\n",
+ "309000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17400 tensor(5.4181, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "310000\n",
+ "311000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17500 tensor(5.1299, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "312000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17600 tensor(5.1636, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "313000\n",
+ "314000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17700 tensor(5.2179, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "315000\n",
+ "316000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17800 tensor(5.2689, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "317000\n",
+ "318000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "17900 tensor(5.2410, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "319000\n",
+ "320000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18000 tensor(5.2342, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "321000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18100 tensor(5.2234, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "322000\n",
+ "323000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18200 tensor(5.0779, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "324000\n",
+ "325000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18300 tensor(5.2378, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "326000\n",
+ "327000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18400 tensor(5.1710, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "328000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18500 tensor(5.1134, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "329000\n",
+ "330000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18600 tensor(5.2679, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "331000\n",
+ "332000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18700 tensor(5.2590, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "333000\n",
+ "334000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18800 tensor(5.1842, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "335000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18900 tensor(5.1379, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "336000\n",
+ "337000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19000 tensor(5.1416, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "338000\n",
+ "339000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19100 tensor(5.1602, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "340000\n",
+ "341000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19200 tensor(5.2670, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "342000\n",
+ "343000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19300 tensor(5.1622, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "344000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19400 tensor(5.1805, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "345000\n",
+ "346000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19500 tensor(5.1820, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "347000\n",
+ "348000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19600 tensor(5.2506, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "349000\n",
+ "350000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19700 tensor(5.1566, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "351000\n",
+ "352000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19800 tensor(5.1121, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "353000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "19900 tensor(5.1227, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "354000\n",
+ "355000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20000 tensor(5.2132, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "356000\n",
+ "357000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20100 tensor(5.2681, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "358000\n",
+ "359000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20200 tensor(5.2689, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "360000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20300 tensor(5.1758, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "361000\n",
+ "362000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20400 tensor(5.1275, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "363000\n",
+ "364000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20500 tensor(5.1803, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "365000\n",
+ "366000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20600 tensor(5.1202, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "367000\n",
+ "368000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20700 tensor(5.2343, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "369000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20800 tensor(5.2035, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "370000\n",
+ "371000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "20900 tensor(5.2992, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "372000\n",
+ "373000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21000 tensor(5.1540, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "374000\n",
+ "375000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21100 tensor(5.2739, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "376000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21200 tensor(5.2949, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "377000\n",
+ "378000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21300 tensor(5.2138, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "379000\n",
+ "380000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21400 tensor(5.2773, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "381000\n",
+ "382000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21500 tensor(5.2345, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "383000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21600 tensor(5.2528, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "384000\n",
+ "385000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21700 tensor(5.1824, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "386000\n",
+ "387000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21800 tensor(5.1943, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "388000\n",
+ "389000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21900 tensor(5.0359, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "390000\n",
+ "391000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22000 tensor(5.1506, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "392000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22100 tensor(5.1253, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "393000\n",
+ "394000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22200 tensor(5.0982, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "395000\n",
+ "396000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22300 tensor(5.1554, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "397000\n",
+ "398000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22400 tensor(5.1673, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "399000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22500 tensor(5.1957, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "400000\n",
+ "401000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22600 tensor(5.1328, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "402000\n",
+ "403000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22700 tensor(5.2231, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "404000\n",
+ "405000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22800 tensor(5.1370, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "406000\n",
+ "407000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "22900 tensor(5.2334, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "408000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23000 tensor(5.1372, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "409000\n",
+ "410000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23100 tensor(5.1193, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "411000\n",
+ "412000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23200 tensor(5.2649, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "413000\n",
+ "414000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23300 tensor(5.1514, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "415000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23400 tensor(5.2532, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "416000\n",
+ "417000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23500 tensor(5.3751, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "418000\n",
+ "419000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23600 tensor(5.0766, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "420000\n",
+ "421000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23700 tensor(5.0915, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "422000\n",
+ "423000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23800 tensor(5.3195, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "424000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23900 tensor(5.2758, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "425000\n",
+ "426000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "24000 tensor(5.0487, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "427000\n",
+ "428000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "24100 tensor(5.1555, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "429000\n",
+ "430000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "24200 tensor(5.2140, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "431000\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "24300 tensor(5.2729, device='mps:0', grad_fn=)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "432000\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_dataset = Bigrams(path_to_training_file, vocab_size)\n",
+ "model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
+ "data = DataLoader(train_dataset, batch_size=batch_size)\n",
+ "optimizer = torch.optim.Adam(model.parameters())\n",
+ "criterion = torch.nn.NLLLoss()\n",
+ "\n",
+ "model.train()\n",
+ "step = 0\n",
+ "for x, y in data:\n",
+ " x = x.to(device)\n",
+ " y = y.to(device)\n",
+ " optimizer.zero_grad()\n",
+ " ypredicted = model(x)\n",
+ " loss = criterion(torch.log(ypredicted), y)\n",
+ " if step % 100 == 0:\n",
+ " print(step, loss)\n",
+ " step += 1\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ "torch.save(model.state_dict(), path_to_model_file)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## LOAD MODEL AND VOCAB"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "SimpleBigramNeuralLanguageModel(\n (model): Sequential(\n (0): Embedding(30000, 1000)\n (1): Linear(in_features=1000, out_features=30000, bias=True)\n (2): Softmax(dim=None)\n )\n)"
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with open(path_to_vocabulary_file, 'rb') as handle:\n",
+ " vocab = pickle.load(handle)\n",
+ "model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
+ "model.load_state_dict(torch.load(path_to_model_file))\n",
+ "model.eval()"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## CREATE OUTPUTS FILES"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### DEV-0"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "========== do prediction for dev-0/in.tsv.xz ==========\n"
+ ]
+ }
+ ],
+ "source": [
+ "predicition_for_file(model, vocab, folder_dev_0, file_dev_0)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### TEST-A"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "========== do prediction for test-A/in.tsv.xz ==========\n"
+ ]
+ }
+ ],
+ "source": [
+ "predicition_for_file(model, vocab, folder_test_a, file_test_a)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}