{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "8b023ab4", "metadata": {}, "outputs": [], "source": [ "train_file ='train/in.tsv.xz'\n", "test_file = 'dev-0/in.tsv.xz'\n", "out_file = 'dev-0/out.tsv'" ] }, { "cell_type": "code", "execution_count": 4, "id": "39b223cf", "metadata": {}, "outputs": [], "source": [ "from itertools import islice\n", "import regex as re\n", "import sys\n", "from torchtext.vocab import build_vocab_from_iterator\n", "import lzma\n", "import pickle\n", "import re\n", "import torch\n", "from torch import nn\n", "from torch.utils.data import IterableDataset\n", "import itertools\n", "from torch.utils.data import DataLoader\n", "import yaml" ] }, { "cell_type": "code", "execution_count": 27, "id": "a0b0b73e", "metadata": {}, "outputs": [], "source": [ "epochs = 3\n", "embed_size = 200\n", "device = 'cuda'\n", "vocab_size = 30000\n", "batch_s = 1600\n", "learning_rate = 0.01\n", "k = 20 #top k words\n", "wildcard_minweight = 0.01" ] }, { "cell_type": "code", "execution_count": 26, "id": "2ac3a353", "metadata": {}, "outputs": [], "source": [ "params = {\n", "'epochs': 3,\n", "'embed_size': 100,\n", "'device': 'cuda',\n", "'vocab_size': 30000,\n", "'batch_size': 3200,\n", "'learning_rate': 0.0001,\n", "'k': 15, #top k words\n", "'wildcard_minweight': 0.01\n", "}" ] }, { "cell_type": "code", "execution_count": 14, "id": "9668da9f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_37433/1141171476.py:1: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.\n", " params = yaml.load(open('config/params.yaml'))\n" ] } ], "source": [ "params = yaml.load(open('config/params.yaml'))\n", "#then, entire code should go about those params with params[epochs] etc" ] }, { "cell_type": "code", "execution_count": 6, "id": "01a6cf33", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'epochs': 3,\n", " 'embed_size': 100,\n", " 'device': 'cuda',\n", " 'vocab_size': 30000,\n", " 'batch_size': 3200,\n", " 'learning_rate': 0.0001,\n", " 'k': 15,\n", " 'wildcard_minweight': 0.01}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params" ] }, { "cell_type": "code", "execution_count": 12, "id": "7526e30c", "metadata": {}, "outputs": [], "source": [ "def get_words_from_line(line):\n", " line = line.rstrip()\n", " yield ''\n", " line = preprocess(line)\n", " for t in line.split(' '):\n", " yield t\n", " yield ''\n", "\n", "\n", "def get_word_lines_from_file(file_name):\n", " n = 0\n", " with lzma.open(file_name, 'r') as fh:\n", " for line in fh:\n", " n+=1\n", " if n%1000==0:\n", " print(n)\n", " yield get_words_from_line(line.decode('utf-8'))\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "01cde371", "metadata": {}, "outputs": [], "source": [ "def look_ahead_iterator(gen):\n", " prev2 = None\n", " prev1 = None\n", " for item in gen:\n", " if prev2 is not None and prev1 is not None:\n", " yield (prev2, prev1, item)\n", " prev2 = prev1\n", " prev1 = item\n", "\n", "class Trigrams(IterableDataset):\n", " def __init__(self, text_file, vocabulary_size):\n", " self.vocab = build_vocab_from_iterator(\n", " get_word_lines_from_file(text_file),\n", " max_tokens = vocabulary_size,\n", " specials = [''])\n", " self.vocab.set_default_index(self.vocab[''])\n", " self.vocabulary_size = vocabulary_size\n", " self.text_file = text_file\n", "\n", " def __iter__(self):\n", " 
{ "cell_type": "code", "execution_count": 14, "id": "198b1dd3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1000\n", "2000\n", "3000\n", "[... progress prints every 1000 lines trimmed ...]\n", "432000\n" ] } ], "source": [ "vocab = build_vocab_from_iterator(\n", " get_word_lines_from_file(train_file),\n", " max_tokens = params['vocab_size'],\n", " specials = ['<unk>'])" ] }, { "cell_type": "code", "execution_count": 15, "id": "6136fbb9", "metadata": {}, "outputs": [], "source": [ "with open('filename.pickle', 'wb') as handle:\n", " pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)" ] }, { "cell_type": "code", "execution_count": 23, "id": "30a5b26b", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1000\n", "2000\n", "3000\n", "[... progress prints every 1000 lines trimmed ...]\n", "432000\n" ] } ], "source": [ "with open('filename.pickle', 'rb') as handle:\n", " vocab = pickle.load(handle)\n", "\n", "train_dataset = Trigrams(train_file, params['vocab_size'])" ] }, { "cell_type": "code", "execution_count": 21, "id": "eaa681b4", "metadata": {}, "outputs": [], "source": [ "data = DataLoader(train_dataset, batch_size=params['batch_size']) # batches of trigram index triples" ] },
"403000\n", "404000\n", "405000\n", "406000\n", "407000\n", "408000\n", "409000\n", "410000\n", "411000\n", "412000\n", "413000\n", "414000\n", "415000\n", "416000\n", "417000\n", "418000\n", "419000\n", "420000\n", "421000\n", "422000\n", "423000\n", "424000\n", "425000\n", "426000\n", "427000\n", "428000\n", "429000\n", "430000\n", "431000\n", "432000\n" ] } ], "source": [ "with open('filename.pickle','rb') as handle:\n", " vocab = pickle.load(handle)\n", " \n", "train_dataset = Trigrams(train_file, params['vocab_size'])" ] }, { "cell_type": "code", "execution_count": 21, "id": "eaa681b4", "metadata": {}, "outputs": [], "source": [ "data = DataLoader(train_dataset, batch_size=params['batch_size']) #load data " ] }, { "cell_type": "code", "execution_count": 16, "id": "3aea0574", "metadata": {}, "outputs": [], "source": [ "class SimpleTrigramNeuralLanguageModel(nn.Module):\n", " def __init__(self, vocabulary_size, embedding_size):\n", " super(SimpleTrigramNeuralLanguageModel, self).__init__()\n", " self.embeddings = nn.Embedding(vocabulary_size, embedding_size)\n", " self.linear = nn.Linear(2*embedding_size, vocabulary_size)\n", " self.linear_matrix_2 = nn.Linear(embedding_size*2, embedding_size*2)\n", " self.relu = nn.ReLU()\n", " self.softmax = nn.Softmax()\n", " \n", " #for each word in vocabulary theres a separate embedding vector, consisting of embedding_size entries\n", " #self.linear is linear layer consisting of concatenated embeddings of left, and right context words\n", " #self.linear_matrix_2 is linear layer \n", " \n", " def forward(self, x): #x is list of prev and following embeddings\n", " emb_left = self.embeddings(x[0])\n", " emb_right = self.embeddings(x[1])\n", " #create two embeddings vectors, for word before and after, respectively\n", " \n", " first_layer_size_2 = self.linear_matrix_2(torch.cat((emb_left, emb_right), dim=1))\n", " first_relu = self.relu(first_layer_size_2)\n", " concated = self.linear(first_relu)\n", " out = self.softmax(concated)\n", " return out" ] }, { "cell_type": "code", "execution_count": 24, "id": "e4757295", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gc\n", "torch.cuda.empty_cache()\n", "gc.collect()" ] }, { "cell_type": "code", "execution_count": 17, "id": "0a41831e", "metadata": {}, "outputs": [], "source": [ "device = 'cuda'\n", "model = SimpleTrigramNeuralLanguageModel(params['vocab_size'], params['embed_size']).to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])\n", "criterion = torch.nn.NLLLoss()" ] }, { "cell_type": "code", "execution_count": 26, "id": "281b9010", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: = 0\n", "0 tensor(5.3414, device='cuda:0', grad_fn=)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_37433/606935597.py:22: UserWarning: Implicit dimension choice for softmax has been deprecated. 
{ "cell_type": "code", "execution_count": 26, "id": "281b9010", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: = 0\n", "0 tensor(5.3414, device='cuda:0', grad_fn=<NllLossBackward0>)\n", "100 tensor(5.4870, device='cuda:0', grad_fn=<NllLossBackward0>)\n", "200 tensor(5.3542, device='cuda:0', grad_fn=<NllLossBackward0>)\n", "300 tensor(5.3792, device='cuda:0', grad_fn=<NllLossBackward0>)\n", "400 tensor(5.5982, device='cuda:0', grad_fn=<NllLossBackward0>)\n", "[... loss printed every 100 steps, file progress every 1000 lines, trimmed ...]\n" ] } ], "source": [ "# set_detect_anomaly was used to trace a NaN coming out of LogBackward0;\n", "# the clamp below removes the cause (log of an exact-zero probability)\n", "# torch.autograd.set_detect_anomaly(True)\n", "model.load_state_dict(torch.load('model-tri-2following-40000.bin')) # resume from a checkpoint\n", "for i in range(params['epochs']):\n", " print('epoch: =', i)\n", " model.train()\n", " step = 0\n", " for x, y, z in data: # word to predict, following word, 2nd following word\n", " x = x.to(device)\n", " y = y.to(device)\n", " z = z.to(device)\n", " optimizer.zero_grad()\n", " ypredicted = model([y, z]) # predict x from the two words that follow it\n", " loss = criterion(torch.log(ypredicted.clamp(min=1e-9)), x)\n", " if step % 100 == 0:\n", " print(step, loss)\n", " step += 1\n", "# if step % 10000 == 0:\n", "# torch.save(model.state_dict(), f'model-tri-2following-{step}.bin')\n", " loss.backward()\n", " optimizer.step()\n", "# torch.save(model.state_dict(), f'model-tri-2following-{i}.bin')" ] }, { "cell_type": "code", "execution_count": 27, "id": "54b018d8", "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), 'model-tri-2following-final.bin')" ] }, { "cell_type": "code", "execution_count": 30, "id": "7dd5e6f8", "metadata": {}, "outputs": [], "source": [ "def get_first_word(text):\n", " \"\"\"Return the first word of a string.\"\"\"\n", " word = \"\"\n", " for i in range(len(text)-1):\n", " if text[i] == ' ':\n", " return word.rstrip()\n", " else:\n", " word += text[i]\n", " return word.rstrip()\n", "\n", "def get_values_from_model(context: list, model, vocab, k=10):\n", " \"\"\"Return the k most probable words (with probabilities) for a two-word context.\"\"\"\n", " words = [vocab.forward([word]) for word in context]\n", " ixs = torch.tensor(words).to(device)\n", " out = model(ixs)\n", " top = torch.topk(out[0], k)\n", " top_indices = top.indices.tolist()\n", " top_probs = top.values.tolist()\n", " top_words = vocab.lookup_tokens(top_indices)\n", " return list(zip(top_words, top_probs))\n", "\n", "def summarize_probs_unk(dic, const_wildcard=True):\n", " '''\n", " dic: dictionary of probabilities returned by the model\n", " returns: list of (word, prob) pairs, with <unk> specifically as the last element\n", " '''\n", " if const_wildcard or '<unk>' not in dic.keys():\n", " if '<unk>' in dic.keys():\n", " del dic['<unk>']\n", " probsum = sum(float(val) for key, val in dic.items())\n", " for key in dic:\n", " dic[key] = dic[key]/probsum*(1-wildcard_minweight) # leave some probability mass for the wildcard\n", " tab = [(key, val) for key, val in dic.items()]\n", " tab.append(('<unk>', wildcard_minweight))\n", " else:\n", " probsum = sum(float(val) for key, val in dic.items())\n", " for key in dic:\n", " dic[key] = dic[key]/probsum*(1-wildcard_minweight) # leave some probability mass for the wildcard\n", " wildcard_value = dic['<unk>']\n", " del dic['<unk>']\n", " tab = [(key, val) for key, val in dic.items()]\n", " tab.append(('<unk>', wildcard_value))\n", "\n", " return tab\n", "\n", "def gonito_format(dic, const_wildcard=True):\n", " tab = summarize_probs_unk(dic, const_wildcard)\n", " result = ''\n", " for element in tab[:-1]:\n", " result += str(element[0]) + ':' + str(element[1]) + '\\t'\n", " result += ':' + str(tab[-1][1]) + '\\n' # bare ':' carries the wildcard mass\n", " return result" ] },
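{ "cell_type": "markdown", "id": "b81f0a16", "metadata": {}, "source": [ "Each output line is a sequence of tab-separated `word:probability` pairs closed by a bare `:probability` wildcard that absorbs the reserved leftover mass. A toy run with made-up probabilities (with `wildcard_minweight = 0.01`, the named words are renormalised to a total of 0.99):" ] }, { "cell_type": "code", "execution_count": null, "id": "b81f0a17", "metadata": {}, "outputs": [], "source": [ "demo_probs = {'the': 0.5, 'a': 0.3, '<unk>': 0.2}\n", "print(gonito_format(demo_probs, const_wildcard=True))\n", "# expected: 'the:0.61875\\ta:0.37125\\t:0.01' plus a trailing newline" ] },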
{ "cell_type": "code", "execution_count": 11, "id": "2b7513f3", "metadata": {}, "outputs": [], "source": [ "### preprocessing\n", "def preprocess(line):\n", " line = get_rid_of_header(line)\n", " line = replace_endline(line)\n", " return line\n", "\n", "def get_rid_of_header(line):\n", " line = line.split('\\t')[6:]\n", " return \" \".join(line)\n", "\n", "def replace_endline(line):\n", " line = line.replace(\"\\\\n\", \" \")\n", " return line" ] }, { "cell_type": "code", "execution_count": 39, "id": "4b0e66e2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('<unk>', 0, 0.12663832306861877),\n", " ('one', 43, 0.02672259509563446),\n", " ('part', 146, 0.015497211366891861),\n", " ('out', 63, 0.012386629357933998),\n", " ('some', 76, 0.008164796978235245),\n", " ('members', 426, 0.00799479242414236),\n", " ('side', 238, 0.007780702318996191),\n", " ('portion', 634, 0.005733700469136238),\n", " ('office', 282, 0.0053163678385317326),\n", " ('member', 712, 0.005126394797116518)]" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ixs = torch.tensor([vocab.forward(['of']), vocab.forward(['the'])]).to(device)\n", "\n", "out = model(ixs)\n", "top = torch.topk(out[0], 10)\n", "top_indices = top.indices.tolist()\n", "top_probs = top.values.tolist()\n", "top_words = vocab.lookup_tokens(top_indices)\n", "list(zip(top_words, top_indices, top_probs))" ] }, { "cell_type": "code", "execution_count": 18, "id": "a92abbf2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<All keys matched successfully>" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.load_state_dict(torch.load('model-tri-2following-final.bin'))" ] },
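{ "cell_type": "markdown", "id": "b81f0a18", "metadata": {}, "source": [ "Each dev-0 record is a TSV line whose text payload starts at the seventh column; the second payload field is the right context of the gap, and the model conditions on its first two words. A sketch of that parsing step on a fabricated record (`fake_line` is not real data):" ] }, { "cell_type": "code", "execution_count": null, "id": "b81f0a19", "metadata": {}, "outputs": [], "source": [ "fake_line = 'id\\ta\\tb\\tc\\td\\te\\tleft context of the gap\\tcame to the city'\n", "fields = replace_endline(fake_line).split('\\t')[6:]\n", "context = fields[1].rstrip().split(' ')[:2]\n", "context # expected: ['came', 'to'], the two words right of the gap" ] },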
{ "cell_type": "code", "execution_count": 31, "id": "fc7cf293", "metadata": {}, "outputs": [], "source": [ "with lzma.open(test_file, 'rt') as file:\n", " predict_words = []\n", " results = []\n", " for line in file:\n", " line = replace_endline(line) # keep only the relevant text fields\n", " line = line.split('\\t')[6:]\n", " context = line[1].rstrip().split(\" \")[:2] # first two words of the right context\n", " predict_words.append(context)\n", " vocab = train_dataset.vocab\n", " for context_words in predict_words:\n", " results.append(dict(get_values_from_model(context_words, model, vocab, k=10)))\n", "\n", " with open(out_file, 'w') as outfile:\n", " for elem in results:\n", " outfile.write(gonito_format(elem, const_wildcard=False))" ] }, { "cell_type": "code", "execution_count": null, "id": "1c31c8ba", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }