challenging-america-word-ga.../zad8.ipynb

1545 lines
41 KiB
Plaintext
Raw Normal View History

2023-06-28 19:20:16 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "8b023ab4",
"metadata": {},
"outputs": [],
"source": [
"# Relative input/output paths (resolved against the notebook's working directory).\n",
"train_file, test_file, out_file = (\n",
"    'train/in.tsv.xz',  # training text, xz-compressed TSV\n",
"    'dev-0/in.tsv.xz',  # contexts to predict for\n",
"    'dev-0/out.tsv',    # predictions, one line per input line\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "39b223cf",
"metadata": {},
"outputs": [],
"source": [
"from itertools import islice\n",
"import regex as re\n",
"import sys\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"import lzma\n",
"import pickle\n",
"import re\n",
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import IterableDataset\n",
"import itertools\n",
"from torch.utils.data import DataLoader\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "a0b0b73e",
"metadata": {},
"outputs": [],
"source": [
"# Legacy module-level hyperparameters.\n",
"# NOTE(review): several of these duplicate keys of `params` (defined in the next\n",
"# cells) with different values (embed_size, batch size, learning_rate, k); in the\n",
"# visible cells only `device` and `wildcard_minweight` are read as globals.\n",
"epochs, embed_size = 3, 200\n",
"device = 'cuda'\n",
"vocab_size, batch_s = 30000, 1600\n",
"learning_rate = 0.01\n",
"k = 20  # top k words\n",
"wildcard_minweight = 0.01  # probability mass reserved for the '<unk>' wildcard"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2ac3a353",
"metadata": {},
"outputs": [],
"source": [
"# In-notebook default configuration; the next cell replaces `params` with the\n",
"# contents of config/params.yaml.\n",
"params = dict(\n",
"    epochs=3,\n",
"    embed_size=100,\n",
"    device='cuda',\n",
"    vocab_size=30000,\n",
"    batch_size=3200,\n",
"    learning_rate=0.0001,\n",
"    k=15,  # top k words\n",
"    wildcard_minweight=0.01,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9668da9f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_37433/1141171476.py:1: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.\n",
" params = yaml.load(open('config/params.yaml'))\n"
]
}
],
"source": [
"# Load the run configuration; later cells should read their settings from `params`.\n",
"# yaml.safe_load builds only plain Python objects and needs no Loader argument;\n",
"# bare yaml.load(...) without Loader is deprecated and unsafe on untrusted files.\n",
"# The `with` block closes the file, which the previous bare open() never did.\n",
"with open('config/params.yaml', encoding='utf-8') as config_file:\n",
"    params = yaml.safe_load(config_file)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "01a6cf33",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'epochs': 3,\n",
" 'embed_size': 100,\n",
" 'device': 'cuda',\n",
" 'vocab_size': 30000,\n",
" 'batch_size': 3200,\n",
" 'learning_rate': 0.0001,\n",
" 'k': 15,\n",
" 'wildcard_minweight': 0.01}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the configuration currently stored in `params`.\n",
"params"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7526e30c",
"metadata": {},
"outputs": [],
"source": [
"def get_words_from_line(line):\n",
"    \"\"\"Tokenize one raw TSV line into words wrapped in sentence markers.\n",
"\n",
"    The right-stripped line goes through `preprocess` (defined in another cell),\n",
"    is split on single spaces, and is yielded as '<s>', the tokens, then '</s>'.\n",
"    Empty tokens produced by consecutive spaces are skipped.\n",
"    \"\"\"\n",
"    yield '<s>'\n",
"    for token in preprocess(line.rstrip()).split(' '):\n",
"        # Consecutive spaces would otherwise yield '' tokens that take a\n",
"        # vocabulary slot and form bogus trigrams.\n",
"        if token:\n",
"            yield token\n",
"    yield '</s>'\n",
"\n",
"\n",
"def get_word_lines_from_file(file_name, log_every=1000):\n",
"    \"\"\"Yield one word generator (see get_words_from_line) per line of an .xz file.\n",
"\n",
"    file_name: path to an xz-compressed, UTF-8 encoded text file.\n",
"    log_every: print the running line count every `log_every` lines;\n",
"        pass 0 or None to silence the progress output.\n",
"    \"\"\"\n",
"    # Binary mode plus an explicit decode: lines are split only on b'\\n',\n",
"    # not on the extra separators text-mode universal newlines recognizes.\n",
"    with lzma.open(file_name, 'r') as fh:\n",
"        for line_no, raw_line in enumerate(fh, start=1):\n",
"            if log_every and line_no % log_every == 0:\n",
"                print(line_no)\n",
"            yield get_words_from_line(raw_line.decode('utf-8'))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "01cde371",
"metadata": {},
"outputs": [],
"source": [
"def look_ahead_iterator(gen):\n",
"    \"\"\"Slide a width-3 window over `gen`, yielding (item[i-2], item[i-1], item[i]).\n",
"\n",
"    A triple is emitted only when both earlier slots hold non-None values, so\n",
"    the first two items never end a window of their own.\n",
"    \"\"\"\n",
"    older, newer = None, None\n",
"    for current in gen:\n",
"        if not (older is None or newer is None):\n",
"            yield (older, newer, current)\n",
"        older, newer = newer, current\n",
"\n",
"class Trigrams(IterableDataset):\n",
"    \"\"\"Stream (w[i-2], w[i-1], w[i]) token-id triples from an xz-compressed text file.\n",
"\n",
"    text_file: path handed to get_word_lines_from_file on every iteration.\n",
"    vocabulary_size: max_tokens used when the vocabulary is built here.\n",
"    vocab: optional pre-built torchtext vocabulary (for example the pickled\n",
"        one); passing it skips the extra full pass over `text_file` that\n",
"        building a vocabulary costs. Its default index is set to '<unk>'.\n",
"    Words missing from the vocabulary map to the '<unk>' id.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, text_file, vocabulary_size, vocab=None):\n",
"        if vocab is None:\n",
"            vocab = build_vocab_from_iterator(\n",
"                get_word_lines_from_file(text_file),\n",
"                max_tokens=vocabulary_size,\n",
"                specials=['<unk>'])\n",
"        # Without a default index, looking up an unseen word raises instead of\n",
"        # returning the '<unk>' id.\n",
"        vocab.set_default_index(vocab['<unk>'])\n",
"        self.vocab = vocab\n",
"        self.vocabulary_size = vocabulary_size\n",
"        self.text_file = text_file\n",
"\n",
"    def __iter__(self):\n",
"        # The file is re-read and re-tokenized every time iteration starts.\n",
"        tokens = itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))\n",
"        return look_ahead_iterator(self.vocab[token] for token in tokens)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "198b1dd3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000\n",
"2000\n",
"3000\n",
"4000\n",
"5000\n",
"6000\n",
"7000\n",
"8000\n",
"9000\n",
"10000\n",
"11000\n",
"12000\n",
"13000\n",
"14000\n",
"15000\n",
"16000\n",
"17000\n",
"18000\n",
"19000\n",
"20000\n",
"21000\n",
"22000\n",
"23000\n",
"24000\n",
"25000\n",
"26000\n",
"27000\n",
"28000\n",
"29000\n",
"30000\n",
"31000\n",
"32000\n",
"33000\n",
"34000\n",
"35000\n",
"36000\n",
"37000\n",
"38000\n",
"39000\n",
"40000\n",
"41000\n",
"42000\n",
"43000\n",
"44000\n",
"45000\n",
"46000\n",
"47000\n",
"48000\n",
"49000\n",
"50000\n",
"51000\n",
"52000\n",
"53000\n",
"54000\n",
"55000\n",
"56000\n",
"57000\n",
"58000\n",
"59000\n",
"60000\n",
"61000\n",
"62000\n",
"63000\n",
"64000\n",
"65000\n",
"66000\n",
"67000\n",
"68000\n",
"69000\n",
"70000\n",
"71000\n",
"72000\n",
"73000\n",
"74000\n",
"75000\n",
"76000\n",
"77000\n",
"78000\n",
"79000\n",
"80000\n",
"81000\n",
"82000\n",
"83000\n",
"84000\n",
"85000\n",
"86000\n",
"87000\n",
"88000\n",
"89000\n",
"90000\n",
"91000\n",
"92000\n",
"93000\n",
"94000\n",
"95000\n",
"96000\n",
"97000\n",
"98000\n",
"99000\n",
"100000\n",
"101000\n",
"102000\n",
"103000\n",
"104000\n",
"105000\n",
"106000\n",
"107000\n",
"108000\n",
"109000\n",
"110000\n",
"111000\n",
"112000\n",
"113000\n",
"114000\n",
"115000\n",
"116000\n",
"117000\n",
"118000\n",
"119000\n",
"120000\n",
"121000\n",
"122000\n",
"123000\n",
"124000\n",
"125000\n",
"126000\n",
"127000\n",
"128000\n",
"129000\n",
"130000\n",
"131000\n",
"132000\n",
"133000\n",
"134000\n",
"135000\n",
"136000\n",
"137000\n",
"138000\n",
"139000\n",
"140000\n",
"141000\n",
"142000\n",
"143000\n",
"144000\n",
"145000\n",
"146000\n",
"147000\n",
"148000\n",
"149000\n",
"150000\n",
"151000\n",
"152000\n",
"153000\n",
"154000\n",
"155000\n",
"156000\n",
"157000\n",
"158000\n",
"159000\n",
"160000\n",
"161000\n",
"162000\n",
"163000\n",
"164000\n",
"165000\n",
"166000\n",
"167000\n",
"168000\n",
"169000\n",
"170000\n",
"171000\n",
"172000\n",
"173000\n",
"174000\n",
"175000\n",
"176000\n",
"177000\n",
"178000\n",
"179000\n",
"180000\n",
"181000\n",
"182000\n",
"183000\n",
"184000\n",
"185000\n",
"186000\n",
"187000\n",
"188000\n",
"189000\n",
"190000\n",
"191000\n",
"192000\n",
"193000\n",
"194000\n",
"195000\n",
"196000\n",
"197000\n",
"198000\n",
"199000\n",
"200000\n",
"201000\n",
"202000\n",
"203000\n",
"204000\n",
"205000\n",
"206000\n",
"207000\n",
"208000\n",
"209000\n",
"210000\n",
"211000\n",
"212000\n",
"213000\n",
"214000\n",
"215000\n",
"216000\n",
"217000\n",
"218000\n",
"219000\n",
"220000\n",
"221000\n",
"222000\n",
"223000\n",
"224000\n",
"225000\n",
"226000\n",
"227000\n",
"228000\n",
"229000\n",
"230000\n",
"231000\n",
"232000\n",
"233000\n",
"234000\n",
"235000\n",
"236000\n",
"237000\n",
"238000\n",
"239000\n",
"240000\n",
"241000\n",
"242000\n",
"243000\n",
"244000\n",
"245000\n",
"246000\n",
"247000\n",
"248000\n",
"249000\n",
"250000\n",
"251000\n",
"252000\n",
"253000\n",
"254000\n",
"255000\n",
"256000\n",
"257000\n",
"258000\n",
"259000\n",
"260000\n",
"261000\n",
"262000\n",
"263000\n",
"264000\n",
"265000\n",
"266000\n",
"267000\n",
"268000\n",
"269000\n",
"270000\n",
"271000\n",
"272000\n",
"273000\n",
"274000\n",
"275000\n",
"276000\n",
"277000\n",
"278000\n",
"279000\n",
"280000\n",
"281000\n",
"282000\n",
"283000\n",
"284000\n",
"285000\n",
"286000\n",
"287000\n",
"288000\n",
"289000\n",
"290000\n",
"291000\n",
"292000\n",
"293000\n",
"294000\n",
"295000\n",
"296000\n",
"297000\n",
"298000\n",
"299000\n",
"300000\n",
"301000\n",
"302000\n",
"303000\n",
"304000\n",
"305000\n",
"306000\n",
"307000\n",
"308000\n",
"309000\n",
"310000\n",
"311000\n",
"312000\n",
"313000\n",
"314000\n",
"315000\n",
"316000\n",
"317000\n",
"318000\n",
"319000\n",
"320000\n",
"321000\n",
"322000\n",
"323000\n",
"324000\n",
"325000\n",
"326000\n",
"327000\n",
"328000\n",
"329000\n",
"330000\n",
"331000\n",
"332000\n",
"333000\n",
"334000\n",
"335000\n",
"336000\n",
"337000\n",
"338000\n",
"339000\n",
"340000\n",
"341000\n",
"342000\n",
"343000\n",
"344000\n",
"345000\n",
"346000\n",
"347000\n",
"348000\n",
"349000\n",
"350000\n",
"351000\n",
"352000\n",
"353000\n",
"354000\n",
"355000\n",
"356000\n",
"357000\n",
"358000\n",
"359000\n",
"360000\n",
"361000\n",
"362000\n",
"363000\n",
"364000\n",
"365000\n",
"366000\n",
"367000\n",
"368000\n",
"369000\n",
"370000\n",
"371000\n",
"372000\n",
"373000\n",
"374000\n",
"375000\n",
"376000\n",
"377000\n",
"378000\n",
"379000\n",
"380000\n",
"381000\n",
"382000\n",
"383000\n",
"384000\n",
"385000\n",
"386000\n",
"387000\n",
"388000\n",
"389000\n",
"390000\n",
"391000\n",
"392000\n",
"393000\n",
"394000\n",
"395000\n",
"396000\n",
"397000\n",
"398000\n",
"399000\n",
"400000\n",
"401000\n",
"402000\n",
"403000\n",
"404000\n",
"405000\n",
"406000\n",
"407000\n",
"408000\n",
"409000\n",
"410000\n",
"411000\n",
"412000\n",
"413000\n",
"414000\n",
"415000\n",
"416000\n",
"417000\n",
"418000\n",
"419000\n",
"420000\n",
"421000\n",
"422000\n",
"423000\n",
"424000\n",
"425000\n",
"426000\n",
"427000\n",
"428000\n",
"429000\n",
"430000\n",
"431000\n",
"432000\n"
]
}
],
"source": [
"# One full pass over train_file: keep the most frequent tokens, at most\n",
"# params['vocab_size'] entries with the '<unk>' special included.\n",
"vocab = build_vocab_from_iterator(\n",
"    get_word_lines_from_file(train_file),\n",
"    max_tokens=params['vocab_size'],\n",
"    specials=['<unk>'])\n",
"# Map unseen words to '<unk>' (as Trigrams does), so lookups through the\n",
"# pickled vocabulary do not raise on out-of-vocabulary context words.\n",
"vocab.set_default_index(vocab['<unk>'])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "6136fbb9",
"metadata": {},
"outputs": [],
"source": [
"# Cache `vocab` on disk so a later session can restore it with pickle.load.\n",
"with open('filename.pickle', mode='wb') as vocab_out:\n",
"    pickle.dump(vocab, vocab_out, pickle.HIGHEST_PROTOCOL)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "30a5b26b",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000\n",
"2000\n",
"3000\n",
"4000\n",
"5000\n",
"6000\n",
"7000\n",
"8000\n",
"9000\n",
"10000\n",
"11000\n",
"12000\n",
"13000\n",
"14000\n",
"15000\n",
"16000\n",
"17000\n",
"18000\n",
"19000\n",
"20000\n",
"21000\n",
"22000\n",
"23000\n",
"24000\n",
"25000\n",
"26000\n",
"27000\n",
"28000\n",
"29000\n",
"30000\n",
"31000\n",
"32000\n",
"33000\n",
"34000\n",
"35000\n",
"36000\n",
"37000\n",
"38000\n",
"39000\n",
"40000\n",
"41000\n",
"42000\n",
"43000\n",
"44000\n",
"45000\n",
"46000\n",
"47000\n",
"48000\n",
"49000\n",
"50000\n",
"51000\n",
"52000\n",
"53000\n",
"54000\n",
"55000\n",
"56000\n",
"57000\n",
"58000\n",
"59000\n",
"60000\n",
"61000\n",
"62000\n",
"63000\n",
"64000\n",
"65000\n",
"66000\n",
"67000\n",
"68000\n",
"69000\n",
"70000\n",
"71000\n",
"72000\n",
"73000\n",
"74000\n",
"75000\n",
"76000\n",
"77000\n",
"78000\n",
"79000\n",
"80000\n",
"81000\n",
"82000\n",
"83000\n",
"84000\n",
"85000\n",
"86000\n",
"87000\n",
"88000\n",
"89000\n",
"90000\n",
"91000\n",
"92000\n",
"93000\n",
"94000\n",
"95000\n",
"96000\n",
"97000\n",
"98000\n",
"99000\n",
"100000\n",
"101000\n",
"102000\n",
"103000\n",
"104000\n",
"105000\n",
"106000\n",
"107000\n",
"108000\n",
"109000\n",
"110000\n",
"111000\n",
"112000\n",
"113000\n",
"114000\n",
"115000\n",
"116000\n",
"117000\n",
"118000\n",
"119000\n",
"120000\n",
"121000\n",
"122000\n",
"123000\n",
"124000\n",
"125000\n",
"126000\n",
"127000\n",
"128000\n",
"129000\n",
"130000\n",
"131000\n",
"132000\n",
"133000\n",
"134000\n",
"135000\n",
"136000\n",
"137000\n",
"138000\n",
"139000\n",
"140000\n",
"141000\n",
"142000\n",
"143000\n",
"144000\n",
"145000\n",
"146000\n",
"147000\n",
"148000\n",
"149000\n",
"150000\n",
"151000\n",
"152000\n",
"153000\n",
"154000\n",
"155000\n",
"156000\n",
"157000\n",
"158000\n",
"159000\n",
"160000\n",
"161000\n",
"162000\n",
"163000\n",
"164000\n",
"165000\n",
"166000\n",
"167000\n",
"168000\n",
"169000\n",
"170000\n",
"171000\n",
"172000\n",
"173000\n",
"174000\n",
"175000\n",
"176000\n",
"177000\n",
"178000\n",
"179000\n",
"180000\n",
"181000\n",
"182000\n",
"183000\n",
"184000\n",
"185000\n",
"186000\n",
"187000\n",
"188000\n",
"189000\n",
"190000\n",
"191000\n",
"192000\n",
"193000\n",
"194000\n",
"195000\n",
"196000\n",
"197000\n",
"198000\n",
"199000\n",
"200000\n",
"201000\n",
"202000\n",
"203000\n",
"204000\n",
"205000\n",
"206000\n",
"207000\n",
"208000\n",
"209000\n",
"210000\n",
"211000\n",
"212000\n",
"213000\n",
"214000\n",
"215000\n",
"216000\n",
"217000\n",
"218000\n",
"219000\n",
"220000\n",
"221000\n",
"222000\n",
"223000\n",
"224000\n",
"225000\n",
"226000\n",
"227000\n",
"228000\n",
"229000\n",
"230000\n",
"231000\n",
"232000\n",
"233000\n",
"234000\n",
"235000\n",
"236000\n",
"237000\n",
"238000\n",
"239000\n",
"240000\n",
"241000\n",
"242000\n",
"243000\n",
"244000\n",
"245000\n",
"246000\n",
"247000\n",
"248000\n",
"249000\n",
"250000\n",
"251000\n",
"252000\n",
"253000\n",
"254000\n",
"255000\n",
"256000\n",
"257000\n",
"258000\n",
"259000\n",
"260000\n",
"261000\n",
"262000\n",
"263000\n",
"264000\n",
"265000\n",
"266000\n",
"267000\n",
"268000\n",
"269000\n",
"270000\n",
"271000\n",
"272000\n",
"273000\n",
"274000\n",
"275000\n",
"276000\n",
"277000\n",
"278000\n",
"279000\n",
"280000\n",
"281000\n",
"282000\n",
"283000\n",
"284000\n",
"285000\n",
"286000\n",
"287000\n",
"288000\n",
"289000\n",
"290000\n",
"291000\n",
"292000\n",
"293000\n",
"294000\n",
"295000\n",
"296000\n",
"297000\n",
"298000\n",
"299000\n",
"300000\n",
"301000\n",
"302000\n",
"303000\n",
"304000\n",
"305000\n",
"306000\n",
"307000\n",
"308000\n",
"309000\n",
"310000\n",
"311000\n",
"312000\n",
"313000\n",
"314000\n",
"315000\n",
"316000\n",
"317000\n",
"318000\n",
"319000\n",
"320000\n",
"321000\n",
"322000\n",
"323000\n",
"324000\n",
"325000\n",
"326000\n",
"327000\n",
"328000\n",
"329000\n",
"330000\n",
"331000\n",
"332000\n",
"333000\n",
"334000\n",
"335000\n",
"336000\n",
"337000\n",
"338000\n",
"339000\n",
"340000\n",
"341000\n",
"342000\n",
"343000\n",
"344000\n",
"345000\n",
"346000\n",
"347000\n",
"348000\n",
"349000\n",
"350000\n",
"351000\n",
"352000\n",
"353000\n",
"354000\n",
"355000\n",
"356000\n",
"357000\n",
"358000\n",
"359000\n",
"360000\n",
"361000\n",
"362000\n",
"363000\n",
"364000\n",
"365000\n",
"366000\n",
"367000\n",
"368000\n",
"369000\n",
"370000\n",
"371000\n",
"372000\n",
"373000\n",
"374000\n",
"375000\n",
"376000\n",
"377000\n",
"378000\n",
"379000\n",
"380000\n",
"381000\n",
"382000\n",
"383000\n",
"384000\n",
"385000\n",
"386000\n",
"387000\n",
"388000\n",
"389000\n",
"390000\n",
"391000\n",
"392000\n",
"393000\n",
"394000\n",
"395000\n",
"396000\n",
"397000\n",
"398000\n",
"399000\n",
"400000\n",
"401000\n",
"402000\n",
"403000\n",
"404000\n",
"405000\n",
"406000\n",
"407000\n",
"408000\n",
"409000\n",
"410000\n",
"411000\n",
"412000\n",
"413000\n",
"414000\n",
"415000\n",
"416000\n",
"417000\n",
"418000\n",
"419000\n",
"420000\n",
"421000\n",
"422000\n",
"423000\n",
"424000\n",
"425000\n",
"426000\n",
"427000\n",
"428000\n",
"429000\n",
"430000\n",
"431000\n",
"432000\n"
]
}
],
"source": [
"# Restore the vocabulary cached by the pickling cell.\n",
"# NOTE(review): pickle.load can execute arbitrary code; only load files this\n",
"# notebook wrote itself.\n",
"with open('filename.pickle', mode='rb') as vocab_in:\n",
"    vocab = pickle.load(vocab_in)\n",
"\n",
"# NOTE(review): the unpickled `vocab` is not passed to Trigrams here, so this\n",
"# call builds its own vocabulary with another full pass over train_file.\n",
"train_dataset = Trigrams(train_file, params['vocab_size'])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "eaa681b4",
"metadata": {},
"outputs": [],
"source": [
"# Group the streamed (w1, w2, w3) id triples into batches of params['batch_size'].\n",
"data = DataLoader(dataset=train_dataset, batch_size=params['batch_size'])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "3aea0574",
"metadata": {},
"outputs": [],
"source": [
"class SimpleTrigramNeuralLanguageModel(nn.Module):\n",
"    \"\"\"Feed-forward language model scoring one word from a pair of context words.\n",
"\n",
"    Architecture: shared embedding for both context words -> concatenation\n",
"    (2 * embedding_size) -> Linear + ReLU (same width) -> Linear to\n",
"    vocabulary_size scores -> softmax over the vocabulary.\n",
"    As trained in this notebook, the context is the following word and the\n",
"    second following word, and the target is the word before them.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocabulary_size, embedding_size):\n",
"        super().__init__()\n",
"        # One embedding vector of `embedding_size` entries per vocabulary word.\n",
"        self.embeddings = nn.Embedding(vocabulary_size, embedding_size)\n",
"        # Output projection: hidden layer (2 * embedding_size) -> one score per word.\n",
"        self.linear = nn.Linear(2*embedding_size, vocabulary_size)\n",
"        # Hidden layer applied to the concatenated pair of context embeddings.\n",
"        self.linear_matrix_2 = nn.Linear(embedding_size*2, embedding_size*2)\n",
"        self.relu = nn.ReLU()\n",
"        # Explicit dim=1 is the vocabulary axis of the (batch, vocab) scores;\n",
"        # the implicit-dim form is deprecated and emits a UserWarning.\n",
"        self.softmax = nn.Softmax(dim=1)\n",
"        # Parameter-free, so state_dicts saved before it existed still load.\n",
"        self.log_softmax = nn.LogSoftmax(dim=1)\n",
"\n",
"    def forward(self, x, log_probs=False):\n",
"        \"\"\"Score every vocabulary word given the two context words.\n",
"\n",
"        x: pair (x[0], x[1]) of 1-D token-id tensors; each embeds to\n",
"            (batch, embedding_size) so the two concatenate on dim=1.\n",
"        log_probs: if True, return log-probabilities computed with\n",
"            log_softmax, which stays finite where torch.log(softmax(...))\n",
"            underflows to -inf; the result can go straight into nn.NLLLoss.\n",
"        Returns a (batch, vocabulary_size) tensor of probabilities (or\n",
"        log-probabilities when log_probs=True).\n",
"        \"\"\"\n",
"        emb_left = self.embeddings(x[0])\n",
"        emb_right = self.embeddings(x[1])\n",
"        hidden = self.relu(self.linear_matrix_2(torch.cat((emb_left, emb_right), dim=1)))\n",
"        scores = self.linear(hidden)\n",
"        return self.log_softmax(scores) if log_probs else self.softmax(scores)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "e4757295",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import gc\n",
"# Collect unreachable Python objects (and the tensors they hold) first, so that\n",
"# empty_cache() can release the memory they occupied from PyTorch's CUDA cache.\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0a41831e",
"metadata": {},
"outputs": [],
"source": [
"# Take the device from the configuration instead of hard-coding it, so the\n",
"# model, the training batches and inference all follow `params`.\n",
"device = params['device']\n",
"model = SimpleTrigramNeuralLanguageModel(params['vocab_size'], params['embed_size']).to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])\n",
"# NLLLoss expects log-probabilities; the model's forward returns probabilities\n",
"# by default, so they must be converted before computing the loss.\n",
"criterion = torch.nn.NLLLoss()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "281b9010",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: = 0\n",
"0 tensor(5.3414, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_37433/606935597.py:22: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" out = self.softmax(concated)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000\n",
"100 tensor(5.4870, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"200 tensor(5.3542, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"2000\n",
"300 tensor(5.3792, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"3000\n",
"400 tensor(5.5982, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"4000\n",
"500 tensor(5.4045, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"5000\n",
"600 tensor(5.5620, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"6000\n",
"700 tensor(5.5428, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"7000\n",
"800 tensor(5.3684, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"8000\n",
"900 tensor(5.4198, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"9000\n",
"1000 tensor(5.4100, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"10000\n",
"1100 tensor(5.4554, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"11000\n",
"1200 tensor(5.5284, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
"12000\n",
"1300 tensor(5.5495, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/gedin/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in LogBackward0. Traceback of forward call that caused the error:\n",
" File \"/usr/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n",
" return _run_code(code, main_globals, None,\n",
" File \"/usr/lib/python3.10/runpy.py\", line 86, in _run_code\n",
" exec(code, run_globals)\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel_launcher.py\", line 17, in <module>\n",
" app.launch_new_instance()\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/traitlets/config/application.py\", line 1043, in launch_instance\n",
" app.start()\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py\", line 725, in start\n",
" self.io_loop.start()\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py\", line 195, in start\n",
" self.asyncio_loop.run_forever()\n",
" File \"/usr/lib/python3.10/asyncio/base_events.py\", line 600, in run_forever\n",
" self._run_once()\n",
" File \"/usr/lib/python3.10/asyncio/base_events.py\", line 1896, in _run_once\n",
" handle._run()\n",
" File \"/usr/lib/python3.10/asyncio/events.py\", line 80, in _run\n",
" self._context.run(self._callback, *self._args)\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 513, in dispatch_queue\n",
" await self.process_one()\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 502, in process_one\n",
" await dispatch(*args)\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 409, in dispatch_shell\n",
" await result\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 729, in execute_request\n",
" reply_content = await reply_content\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 422, in do_execute\n",
" res = shell.run_cell(\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py\", line 540, in run_cell\n",
" return super().run_cell(*args, **kwargs)\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3009, in run_cell\n",
" result = self._run_cell(\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3064, in _run_cell\n",
" result = runner(coro)\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n",
" coro.send(None)\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3269, in run_cell_async\n",
" has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3448, in run_ast_nodes\n",
" if await self.run_code(code, result, async_=asy):\n",
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3508, in run_code\n",
" exec(code_obj, self.user_global_ns, self.user_ns)\n",
" File \"/tmp/ipykernel_37433/1707264841.py\", line 13, in <module>\n",
" loss = criterion(torch.log(ypredicted), x) #x is to_predict\n",
" (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)\n",
" Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
]
},
{
"ename": "RuntimeError",
"evalue": "Function 'LogBackward0' returned nan values in its 0th output.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[26], line 19\u001b[0m\n\u001b[1;32m 16\u001b[0m step \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# if step % 10000 == 0:\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;66;03m# torch.save(model.state_dict(), f'model-tri-2following-{step}.bin')\u001b[39;00m\n\u001b[0;32m---> 19\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 21\u001b[0m \u001b[38;5;66;03m# torch.save(model.state_dict(), f'model-tri-2following-{i}.bin') \u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# torch.save(model.state_dict(), f'model-tri-2following-final.bin')\u001b[39;00m\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 479\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 480\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 485\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 197\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 200\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Function 'LogBackward0' returned nan values in its 0th output."
]
}
],
"source": [
"# torch.autograd.set_detect_anomaly is deliberately not enabled: it slows every\n",
"# backward pass considerably; turn it on only while debugging.\n",
"# Resume from an earlier checkpoint; set to None to train from scratch.\n",
"resume_checkpoint = 'model-tri-2following-40000.bin'\n",
"# Save a checkpoint every `checkpoint_every` optimizer steps; 0 disables it.\n",
"checkpoint_every = 0\n",
"\n",
"if resume_checkpoint:\n",
"    model.load_state_dict(torch.load(resume_checkpoint, map_location=device))\n",
"\n",
"for epoch in range(params['epochs']):\n",
"    print('epoch: =', epoch)\n",
"    model.train()\n",
"    for step, (x, y, z) in enumerate(data):  # word, following, 2nd_following words\n",
"        x = x.to(device)\n",
"        y = y.to(device)\n",
"        z = z.to(device)\n",
"        optimizer.zero_grad()\n",
"        ypredicted = model([y, z])  # probabilities for the word given its two followers\n",
"        # Clamp before the log: a softmax output that underflowed to exactly 0\n",
"        # gives log(0) = -inf, most likely the cause of the NaN in LogBackward.\n",
"        eps = torch.finfo(ypredicted.dtype).tiny\n",
"        loss = criterion(torch.log(ypredicted.clamp_min(eps)), x)  # x is to_predict\n",
"        if step % 100 == 0:\n",
"            print(step, loss.item())\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        if checkpoint_every and (step + 1) % checkpoint_every == 0:\n",
"            torch.save(model.state_dict(), f'model-tri-2following-{step + 1}.bin')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "54b018d8",
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained weights for the inference cells.\n",
"final_state = model.state_dict()\n",
"torch.save(final_state, 'model-tri-2following-final.bin')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "7dd5e6f8",
"metadata": {},
"outputs": [],
"source": [
"def get_first_word(text):\n",
"    \"\"\"Return the first space-delimited word of a string.\n",
"\n",
"    Everything before the first ' ' is taken (the whole string if there is no\n",
"    space) and trailing whitespace is stripped; a leading space yields ''.\n",
"    The whole string is examined, so a single-word input keeps its last\n",
"    character.\n",
"    \"\"\"\n",
"    return text.split(' ', 1)[0].rstrip()\n",
"\n",
"def get_values_from_model(context: list, model, vocab, k=10):\n",
"    \"\"\"Return the model's top-k (word, probability) predictions for a context.\n",
"\n",
"    context: context words; the first two are fed to the model as x[0], x[1].\n",
"        Shorter contexts are padded as the training stream continues after a\n",
"        line's last token ('</s>', then the next line's '<s>') instead of\n",
"        raising IndexError.\n",
"    model: module mapping [ids0, ids1] to a (batch, vocab) probability tensor.\n",
"    vocab: vocabulary providing forward([word]) and lookup_tokens(ids).\n",
"    k: number of predictions returned, most probable first.\n",
"    The ids are moved to the model's own device and no autograd graph is built.\n",
"    \"\"\"\n",
"    words = (list(context) + ['</s>', '<s>'])[:2]\n",
"    ids = torch.tensor([vocab.forward([word]) for word in words])\n",
"    # Follow the model's device rather than a global, so CPU/GPU always agree.\n",
"    first_param = next(model.parameters(), None)\n",
"    if first_param is not None:\n",
"        ids = ids.to(first_param.device)\n",
"    with torch.no_grad():\n",
"        out = model(ids)\n",
"    top = torch.topk(out[0], k)\n",
"    top_words = vocab.lookup_tokens(top.indices.tolist())\n",
"    return list(zip(top_words, top.values.tolist()))\n",
"\n",
"def summarize_probs_unk(dic, const_wildcard=True, minweight=None):\n",
"    \"\"\"Normalize predicted word probabilities and put the '<unk>' wildcard last.\n",
"\n",
"    dic: mapping word -> probability as returned by the model; not modified.\n",
"    const_wildcard: if True, or if dic has no '<unk>' entry, any '<unk>' entry\n",
"        is dropped, the other words share 1 - minweight proportionally and the\n",
"        wildcard gets exactly minweight. Otherwise all entries (including\n",
"        '<unk>') share 1 - minweight proportionally and the wildcard gets the\n",
"        '<unk>' share plus minweight, so the result still sums to 1.\n",
"    minweight: probability reserved for the wildcard; defaults to the global\n",
"        `wildcard_minweight`.\n",
"    returns: list of (word, probability) pairs with ('<unk>', p) as last element.\n",
"    \"\"\"\n",
"    if minweight is None:\n",
"        minweight = wildcard_minweight\n",
"    keep_unk = not const_wildcard and '<unk>' in dic\n",
"    # Work on a copy so the caller's dictionary is left untouched.\n",
"    probs = {word: p for word, p in dic.items() if keep_unk or word != '<unk>'}\n",
"    probsum = sum(float(p) for p in probs.values())\n",
"    if probsum > 0:\n",
"        # Scale to 1 - minweight, leaving room for the wildcard.\n",
"        probs = {word: p / probsum * (1 - minweight) for word, p in probs.items()}\n",
"    else:\n",
"        # Nothing to distribute (empty or all-zero input): avoid dividing by zero.\n",
"        probs = {word: 0.0 for word in probs}\n",
"    if keep_unk:\n",
"        # The model's own '<unk>' share plus the reserved minimum.\n",
"        wildcard_value = probs.pop('<unk>') + minweight\n",
"    else:\n",
"        wildcard_value = minweight\n",
"    tab = list(probs.items())\n",
"    tab.append(('<unk>', wildcard_value))\n",
"    return tab\n",
"\n",
"def gonito_format(dic, const_wildcard=True):\n",
"    \"\"\"Render one prediction as an output line via summarize_probs_unk.\n",
"\n",
"    Every 'word:prob' field is followed by a tab; the line ends with\n",
"    ':prob' for the '<unk>' wildcard and a newline.\n",
"    \"\"\"\n",
"    *known, (_, wildcard_prob) = summarize_probs_unk(dic, const_wildcard)\n",
"    fields = ''.join(f'{word!s}:{prob!s}\\t' for word, prob in known)\n",
"    return f'{fields}:{wildcard_prob!s}\\n'"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2b7513f3",
"metadata": {},
"outputs": [],
"source": [
"### preprocessing\n",
"def preprocess(line):\n",
"    \"\"\"Keep only the text columns of a raw TSV line and turn escaped line breaks into spaces.\"\"\"\n",
"    return replace_endline(get_rid_of_header(line))\n",
"\n",
"def get_rid_of_header(line):\n",
"    \"\"\"Drop the first six tab-separated columns and join the rest with spaces.\n",
"\n",
"    NOTE(review): for this task's in.tsv the remaining columns appear to be the\n",
"    left and right contexts; joining them places the two contexts side by side\n",
"    without the missing middle word - confirm this is intended for training.\n",
"    \"\"\"\n",
"    kept_columns = line.split('\\t')[6:]\n",
"    return ' '.join(kept_columns)\n",
" \n",
"def replace_endline(line):\n",
"    \"\"\"Turn escaped line breaks (a backslash followed by 'n') into spaces.\"\"\"\n",
"    return line.replace('\\\\n', ' ')"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "4b0e66e2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_37433/606935597.py:22: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" out = self.softmax(concated)\n"
]
},
{
"data": {
"text/plain": [
"[('<unk>', 0, 0.12663832306861877),\n",
" ('one', 43, 0.02672259509563446),\n",
" ('part', 146, 0.015497211366891861),\n",
" ('out', 63, 0.012386629357933998),\n",
" ('some', 76, 0.008164796978235245),\n",
" ('members', 426, 0.00799479242414236),\n",
" ('side', 238, 0.007780702318996191),\n",
" ('portion', 634, 0.005733700469136238),\n",
" ('office', 282, 0.0053163678385317326),\n",
" ('member', 712, 0.005126394797116518)]"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# Sanity check: top-10 candidates for the word preceding 'of the'.\n",
"context_ids = [vocab.forward([word]) for word in ('of', 'the')]\n",
"ixs = torch.tensor(context_ids).to(device)\n",
"\n",
"out = model(ixs)\n",
"top = torch.topk(out[0], 10)\n",
"top_indices, top_probs = top.indices.tolist(), top.values.tolist()\n",
"top_words = vocab.lookup_tokens(top_indices)\n",
"list(zip(top_words, top_indices, top_probs))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "a92abbf2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Reload the final weights written by the saving cell.\n",
"final_state = torch.load('model-tri-2following-final.bin')\n",
"model.load_state_dict(final_state)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "fc7cf293",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_4654/606935597.py:22: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" out = self.softmax(concated)\n"
]
}
],
"source": [
"# Predict, for every dev line, the word preceding its right context, and write\n",
"# one formatted line per input line to out_file.\n",
"# encoding='utf-8' matches the training reader instead of the locale default;\n",
"# newline='\\n' stops stray '\\r' characters from splitting a line in two, which\n",
"# would misalign out_file with the input.\n",
"with lzma.open(test_file, 'rt', encoding='utf-8', newline='\\n') as dev_in:\n",
"    predict_words = []\n",
"    for line in dev_in:\n",
"        line = replace_endline(line)\n",
"        columns = line.split('\\t')[6:]\n",
"        # First two words of the right context; empty tokens from repeated\n",
"        # spaces are skipped.\n",
"        context = [word for word in columns[1].rstrip().split(' ') if word][:2]\n",
"        predict_words.append(context)\n",
"\n",
"vocab = train_dataset.vocab\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    results = [dict(get_values_from_model(context_words, model, vocab, k=params['k']))\n",
"               for context_words in predict_words]\n",
"\n",
"with open(out_file, 'w', encoding='utf-8') as outfile:\n",
"    for elem in results:\n",
"        outfile.write(gonito_format(elem, const_wildcard=False))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c31c8ba",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}