1545 lines
41 KiB
Plaintext
1545 lines
41 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
"cell_type": "code",
"execution_count": 1,
"id": "8b023ab4",
"metadata": {},
"outputs": [],
"source": [
"# Corpus paths: training input, dev-0 input, and (presumably) the file dev-0 predictions are written to.\n",
"train_file, test_file, out_file = 'train/in.tsv.xz', 'dev-0/in.tsv.xz', 'dev-0/out.tsv'"
]
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "39b223cf",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from itertools import islice\n",
|
||
|
"import regex as re\n",
|
||
|
"import sys\n",
|
||
|
"from torchtext.vocab import build_vocab_from_iterator\n",
|
||
|
"import lzma\n",
|
||
|
"import pickle\n",
|
||
|
"import re\n",
|
||
|
"import torch\n",
|
||
|
"from torch import nn\n",
|
||
|
"from torch.utils.data import IterableDataset\n",
|
||
|
"import itertools\n",
|
||
|
"from torch.utils.data import DataLoader\n",
|
||
|
"import yaml"
|
||
|
]
|
||
|
},
|
||
|
{
"cell_type": "code",
"execution_count": 27,
"id": "a0b0b73e",
"metadata": {},
"outputs": [],
"source": [
"# Stand-alone hyper-parameter constants.\n",
"# NOTE(review): these disagree with the `params` dict defined below\n",
"# (embed_size 200 vs 100, batch_s 1600 vs 3200, learning_rate 0.01 vs 0.0001,\n",
"# k 20 vs 15), and the DataLoader and model-construction cells read params[...].\n",
"# Confirm which values are intended before relying on these names.\n",
"epochs, embed_size, device = 3, 200, 'cuda'\n",
"vocab_size, batch_s = 30_000, 1600\n",
"learning_rate = 0.01\n",
"k = 20  # top k words\n",
"wildcard_minweight = 0.01"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 26,
"id": "2ac3a353",
"metadata": {},
"outputs": [],
"source": [
"# Inline hyper-parameters. The YAML cell below reassigns `params`, so when both\n",
"# cells run in order the values from config/params.yaml win.\n",
"params = dict(\n",
"    epochs=3,\n",
"    embed_size=100,\n",
"    device='cuda',\n",
"    vocab_size=30_000,\n",
"    batch_size=3200,\n",
"    learning_rate=1e-4,\n",
"    k=15,  # top k words\n",
"    wildcard_minweight=0.01,\n",
")"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 14,
"id": "9668da9f",
"metadata": {},
"outputs": [],
"source": [
"# Load hyper-parameters from config/params.yaml; the rest of the notebook reads\n",
"# them as params['epochs'], params['batch_size'], ...\n",
"# yaml.safe_load replaces the bare yaml.load call, which warned that its implicit\n",
"# default Loader is unsafe (newer PyYAML releases reject it without Loader=).\n",
"# The `with` block closes the file handle instead of leaking it.\n",
"with open('config/params.yaml') as params_file:\n",
"    params = yaml.safe_load(params_file)"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 6,
"id": "01a6cf33",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'epochs': 3,\n",
" 'embed_size': 100,\n",
" 'device': 'cuda',\n",
" 'vocab_size': 30000,\n",
" 'batch_size': 3200,\n",
" 'learning_rate': 0.0001,\n",
" 'k': 15,\n",
" 'wildcard_minweight': 0.01}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the hyper-parameters currently in effect.\n",
"params"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 12,
"id": "7526e30c",
"metadata": {},
"outputs": [],
"source": [
"def get_words_from_line(line):\n",
"    \"\"\"Yield the tokens of one corpus line, framed by '<s>' and '</s>'.\n",
"\n",
"    Trailing whitespace is stripped, the text goes through `preprocess`, and the\n",
"    result is split on single spaces (so consecutive spaces yield '' tokens).\n",
"    NOTE(review): `preprocess` is not defined in any visible cell; it must be\n",
"    defined elsewhere for this notebook to run on a fresh kernel.\n",
"    \"\"\"\n",
"    yield '<s>'\n",
"    yield from preprocess(line.rstrip()).split(' ')\n",
"    yield '</s>'\n",
"\n",
"\n",
"def get_word_lines_from_file(file_name):\n",
"    \"\"\"Lazily yield one token generator (see get_words_from_line) per line.\n",
"\n",
"    The file is opened with lzma in binary mode and each line is decoded as\n",
"    UTF-8. The running line count is printed every 1000 lines.\n",
"    \"\"\"\n",
"    with lzma.open(file_name, 'r') as compressed:\n",
"        for line_no, raw_line in enumerate(compressed, start=1):\n",
"            if line_no % 1000 == 0:\n",
"                print(line_no)\n",
"            yield get_words_from_line(raw_line.decode('utf-8'))"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 13,
"id": "01cde371",
"metadata": {},
"outputs": [],
"source": [
"def look_ahead_iterator(gen):\n",
"    \"\"\"Yield sliding (prev2, prev1, item) windows over `gen`.\n",
"\n",
"    The first two items only fill the window, so n items produce n - 2 triples.\n",
"    `None` doubles as the 'window not yet full' marker, so the input is assumed\n",
"    not to contain None in context positions (vocabulary indices here).\n",
"    \"\"\"\n",
"    prev2 = None\n",
"    prev1 = None\n",
"    for item in gen:\n",
"        if prev2 is not None and prev1 is not None:\n",
"            yield (prev2, prev1, item)\n",
"        prev2 = prev1\n",
"        prev1 = item\n",
"\n",
"\n",
"class Trigrams(IterableDataset):\n",
"    \"\"\"Stream (w[i-2], w[i-1], w[i]) vocabulary-index triples from an xz corpus.\n",
"\n",
"    Args:\n",
"        text_file: path of the lzma-compressed corpus, re-read on every iteration.\n",
"        vocabulary_size: max_tokens for build_vocab_from_iterator (counts '<unk>').\n",
"        vocab: optional pre-built vocabulary containing '<unk>' (e.g. the one\n",
"            unpickled from disk). When given, the corpus is not read a second\n",
"            time just to rebuild it. Its default index is set to '<unk>' in\n",
"            place, so unknown words map there instead of raising.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, text_file, vocabulary_size, vocab=None):\n",
"        if vocab is None:\n",
"            vocab = build_vocab_from_iterator(\n",
"                get_word_lines_from_file(text_file),\n",
"                max_tokens = vocabulary_size,\n",
"                specials = ['<unk>'])\n",
"        vocab.set_default_index(vocab['<unk>'])\n",
"        self.vocab = vocab\n",
"        self.vocabulary_size = vocabulary_size\n",
"        self.text_file = text_file\n",
"\n",
"    def __iter__(self):\n",
"        # Lines are chained, so windows span '</s>' '<s>' sentence boundaries.\n",
"        tokens = itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))\n",
"        return look_ahead_iterator(self.vocab[t] for t in tokens)"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 14,
"id": "198b1dd3",
"metadata": {},
"outputs": [],
"source": [
"# Build the vocabulary from the training corpus (one full pass over train_file).\n",
"vocab = build_vocab_from_iterator(\n",
"    get_word_lines_from_file(train_file),\n",
"    max_tokens = params['vocab_size'],\n",
"    specials = ['<unk>'])\n",
"# Without a default index, looking up a word outside the vocabulary raises;\n",
"# map such words to '<unk>', as the Trigrams dataset does for its own vocab.\n",
"vocab.set_default_index(vocab['<unk>'])"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 15,
"id": "6136fbb9",
"metadata": {},
"outputs": [],
"source": [
"# Cache the vocabulary on disk so it can be reloaded without another corpus pass.\n",
"with open('filename.pickle', mode='wb') as vocab_out:\n",
"    pickle.dump(vocab, vocab_out, pickle.HIGHEST_PROTOCOL)"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 23,
"id": "30a5b26b",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Reload the cached vocabulary. NOTE(review): pickle.load can execute arbitrary\n",
"# code, so only load files this notebook wrote itself.\n",
"with open('filename.pickle', mode='rb') as vocab_in:\n",
"    vocab = pickle.load(vocab_in)\n",
"\n",
"# NOTE(review): the unpickled `vocab` is not used here; Trigrams builds its own\n",
"# vocabulary with another full pass over train_file.\n",
"train_dataset = Trigrams(train_file, params['vocab_size'])"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 21,
"id": "eaa681b4",
"metadata": {},
"outputs": [],
"source": [
"# Batch the streamed (prev2, prev1, item) triples; an IterableDataset is read in order, not shuffled.\n",
"data = DataLoader(dataset=train_dataset, batch_size=params['batch_size'])"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 16,
"id": "3aea0574",
"metadata": {},
"outputs": [],
"source": [
"class SimpleTrigramNeuralLanguageModel(nn.Module):\n",
"    \"\"\"Predict a word from two context words with a one-hidden-layer MLP.\n",
"\n",
"    embeddings(x[0]) ++ embeddings(x[1]) -> Linear(2E, 2E) -> ReLU\n",
"    -> Linear(2E, V) -> softmax over the vocabulary.\n",
"\n",
"    Args:\n",
"        vocabulary_size: V, rows of the embedding table and size of the output.\n",
"        embedding_size: E, dimension of each word embedding.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocabulary_size, embedding_size):\n",
"        super(SimpleTrigramNeuralLanguageModel, self).__init__()\n",
"        # one embedding vector of embedding_size entries per vocabulary index\n",
"        self.embeddings = nn.Embedding(vocabulary_size, embedding_size)\n",
"        # output layer: hidden activations (2E) -> one score per vocabulary word\n",
"        self.linear = nn.Linear(2*embedding_size, vocabulary_size)\n",
"        # hidden layer over the concatenated pair of context embeddings (2E -> 2E)\n",
"        self.linear_matrix_2 = nn.Linear(embedding_size*2, embedding_size*2)\n",
"        self.relu = nn.ReLU()\n",
"        # Normalise explicitly over the vocabulary axis (dim 1 of the 2-D scores);\n",
"        # the implicit-dim nn.Softmax() is deprecated and warned during training.\n",
"        self.softmax = nn.Softmax(dim=1)\n",
"\n",
"    def _scores(self, x):\n",
"        \"\"\"Unnormalised vocabulary scores, shape (batch, vocabulary_size).\"\"\"\n",
"        # assumes x[0] and x[1] are 1-D index tensors of equal batch length\n",
"        emb_left = self.embeddings(x[0])\n",
"        emb_right = self.embeddings(x[1])\n",
"        hidden = self.relu(self.linear_matrix_2(torch.cat((emb_left, emb_right), dim=1)))\n",
"        return self.linear(hidden)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return probabilities over the vocabulary for the context pair x.\n",
"\n",
"        Args:\n",
"            x: indexable pair (x[0], x[1]) of word-index tensors for the two\n",
"                context words, each looked up in the shared embedding table.\n",
"        Returns:\n",
"            Softmax probabilities, shape (batch, vocabulary_size).\n",
"        \"\"\"\n",
"        return self.softmax(self._scores(x))\n",
"\n",
"    def log_probabilities(self, x):\n",
"        \"\"\"Numerically stable equivalent of torch.log(self(x)).\n",
"\n",
"        torch.log of a softmax that underflowed to 0 gives -inf, whose backward\n",
"        pass yields NaN (this notebook's training run stopped with\n",
"        'LogBackward0 returned nan values'). Pair this with nn.NLLLoss instead.\n",
"        \"\"\"\n",
"        return torch.log_softmax(self._scores(x), dim=1)"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 24,
"id": "e4757295",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import gc\n",
"\n",
"# Collect unreachable Python objects first so any tensors they held hand their\n",
"# memory back to PyTorch's caching allocator, then release those cached blocks.\n",
"# The number of unreachable objects gc found is shown as the cell's result.\n",
"freed_objects = gc.collect()\n",
"torch.cuda.empty_cache()\n",
"freed_objects"
]
},
|
||
|
{
"cell_type": "code",
"execution_count": 17,
"id": "0a41831e",
"metadata": {},
"outputs": [],
"source": [
"# Use the configured device instead of a hard-coded 'cuda', and fall back to the\n",
"# CPU when CUDA is not available so the notebook still runs.\n",
"device = params['device'] if torch.cuda.is_available() else 'cpu'\n",
"model = SimpleTrigramNeuralLanguageModel(params['vocab_size'], params['embed_size']).to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])\n",
"# NLLLoss expects log-probabilities, while model(x) returns softmax probabilities,\n",
"# so the training loop must take a log of the model output (log_softmax on the\n",
"# scores is the numerically stable way to do that).\n",
"criterion = torch.nn.NLLLoss()"
]
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"id": "281b9010",
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"epoch: = 0\n",
|
||
|
"0 tensor(5.3414, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/tmp/ipykernel_37433/606935597.py:22: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
||
|
" out = self.softmax(concated)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"1000\n",
|
||
|
"100 tensor(5.4870, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"200 tensor(5.3542, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"2000\n",
|
||
|
"300 tensor(5.3792, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"3000\n",
|
||
|
"400 tensor(5.5982, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"4000\n",
|
||
|
"500 tensor(5.4045, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"5000\n",
|
||
|
"600 tensor(5.5620, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"6000\n",
|
||
|
"700 tensor(5.5428, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"7000\n",
|
||
|
"800 tensor(5.3684, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"8000\n",
|
||
|
"900 tensor(5.4198, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"9000\n",
|
||
|
"1000 tensor(5.4100, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"10000\n",
|
||
|
"1100 tensor(5.4554, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"11000\n",
|
||
|
"1200 tensor(5.5284, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
||
|
"12000\n",
|
||
|
"1300 tensor(5.5495, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/gedin/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in LogBackward0. Traceback of forward call that caused the error:\n",
|
||
|
" File \"/usr/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n",
|
||
|
" return _run_code(code, main_globals, None,\n",
|
||
|
" File \"/usr/lib/python3.10/runpy.py\", line 86, in _run_code\n",
|
||
|
" exec(code, run_globals)\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel_launcher.py\", line 17, in <module>\n",
|
||
|
" app.launch_new_instance()\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/traitlets/config/application.py\", line 1043, in launch_instance\n",
|
||
|
" app.start()\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py\", line 725, in start\n",
|
||
|
" self.io_loop.start()\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py\", line 195, in start\n",
|
||
|
" self.asyncio_loop.run_forever()\n",
|
||
|
" File \"/usr/lib/python3.10/asyncio/base_events.py\", line 600, in run_forever\n",
|
||
|
" self._run_once()\n",
|
||
|
" File \"/usr/lib/python3.10/asyncio/base_events.py\", line 1896, in _run_once\n",
|
||
|
" handle._run()\n",
|
||
|
" File \"/usr/lib/python3.10/asyncio/events.py\", line 80, in _run\n",
|
||
|
" self._context.run(self._callback, *self._args)\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 513, in dispatch_queue\n",
|
||
|
" await self.process_one()\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 502, in process_one\n",
|
||
|
" await dispatch(*args)\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 409, in dispatch_shell\n",
|
||
|
" await result\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\", line 729, in execute_request\n",
|
||
|
" reply_content = await reply_content\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py\", line 422, in do_execute\n",
|
||
|
" res = shell.run_cell(\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py\", line 540, in run_cell\n",
|
||
|
" return super().run_cell(*args, **kwargs)\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3009, in run_cell\n",
|
||
|
" result = self._run_cell(\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3064, in _run_cell\n",
|
||
|
" result = runner(coro)\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n",
|
||
|
" coro.send(None)\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3269, in run_cell_async\n",
|
||
|
" has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3448, in run_ast_nodes\n",
|
||
|
" if await self.run_code(code, result, async_=asy):\n",
|
||
|
" File \"/home/gedin/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py\", line 3508, in run_code\n",
|
||
|
" exec(code_obj, self.user_global_ns, self.user_ns)\n",
|
||
|
" File \"/tmp/ipykernel_37433/1707264841.py\", line 13, in <module>\n",
|
||
|
" loss = criterion(torch.log(ypredicted), x) #x is to_predict\n",
|
||
|
" (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)\n",
|
||
|
" Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"ename": "RuntimeError",
|
||
|
"evalue": "Function 'LogBackward0' returned nan values in its 0th output.",
|
||
|
"output_type": "error",
|
||
|
"traceback": [
|
||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
|
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||
|
"Cell \u001b[0;32mIn[26], line 19\u001b[0m\n\u001b[1;32m 16\u001b[0m step \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# if step % 10000 == 0:\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;66;03m# torch.save(model.state_dict(), f'model-tri-2following-{step}.bin')\u001b[39;00m\n\u001b[0;32m---> 19\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 21\u001b[0m \u001b[38;5;66;03m# torch.save(model.state_dict(), f'model-tri-2following-{i}.bin') \u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# torch.save(model.state_dict(), f'model-tri-2following-final.bin')\u001b[39;00m\n",
|
||
|
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 479\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 480\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 485\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 488\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||
|
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 197\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 200\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
|
||
|
"\u001b[0;31mRuntimeError\u001b[0m: Function 'LogBackward0' returned nan values in its 0th output."
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"torch.autograd.set_detect_anomaly(True)\n",
|
||
|
"model.load_state_dict(torch.load(f'model-tri-2following-40000.bin'))\n",
|
||
|
"for i in range(params['epochs']):\n",
|
||
|
" print('epoch: =', i)\n",
|
||
|
" model.train()\n",
|
||
|
" step = 0\n",
|
||
|
" for x, y, z in data: # word, following, 2nd_following words\n",
|
||
|
" x = x.to(device)\n",
|
||
|
" y = y.to(device)\n",
|
||
|
" z = z.to(device)\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" ypredicted = model([y, z]) #following, 2nd_following word\n",
|
||
|
" loss = criterion(torch.log(ypredicted), x) #x is to_predict\n",
|
||
|
" if step % 100 == 0:\n",
|
||
|
" print(step, loss)\n",
|
||
|
" step += 1\n",
|
||
|
"# if step % 10000 == 0:\n",
|
||
|
"# torch.save(model.state_dict(), f'model-tri-2following-{step}.bin')\n",
|
||
|
" loss.backward()\n",
|
||
|
" optimizer.step()\n",
|
||
|
"# torch.save(model.state_dict(), f'model-tri-2following-{i}.bin') \n",
|
||
|
"# torch.save(model.state_dict(), f'model-tri-2following-final.bin')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"id": "54b018d8",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"torch.save(model.state_dict(), f'model-tri-2following-final.bin')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"id": "7dd5e6f8",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def get_first_word(text):\n",
|
||
|
" \"\"\"Return the first word of a string.\"\"\"\n",
|
||
|
" word = \"\"\n",
|
||
|
" for i in range(len(text)-1):\n",
|
||
|
"# if text[i] in [' ', ',', '.']\n",
|
||
|
" if text[i] == ' ':\n",
|
||
|
" return word.rstrip()\n",
|
||
|
" else:\n",
|
||
|
" word += text[i]\n",
|
||
|
" return word.rstrip()\n",
|
||
|
"\n",
|
||
|
"def get_values_from_model(context: list, model, vocab, k=10):\n",
|
||
|
" words = [vocab.forward([word]) for word in context]\n",
|
||
|
" ixs = torch.tensor(words).to(device)\n",
|
||
|
" out = model(ixs)\n",
|
||
|
" top = torch.topk(out[0], k)\n",
|
||
|
" top_indices = top.indices.tolist()\n",
|
||
|
" top_probs = top.values.tolist()\n",
|
||
|
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||
|
" return list(zip(top_words, top_probs))\n",
|
||
|
"\n",
|
||
|
"def summarize_probs_unk(dic, const_wildcard=True):\n",
|
||
|
" ''' \n",
|
||
|
" dic: dictionary of probabilities returned by model \n",
|
||
|
" returns: tab of probabilities, with <unk> specificly as last element\n",
|
||
|
" '''\n",
|
||
|
" if const_wildcard or '<unk>' not in dic.keys(): \n",
|
||
|
" if '<unk>' in dic.keys():\n",
|
||
|
" del dic['<unk>']\n",
|
||
|
" probsum = sum(float(val) for key, val in dic.items())\n",
|
||
|
" for key in dic:\n",
|
||
|
" dic[key] = dic[key]/probsum*(1-wildcard_minweight) ###leave some space for wildcard\n",
|
||
|
" tab = [(key, val) for key, val in dic.items()]\n",
|
||
|
" tab.append(('<unk>', wildcard_minweight))\n",
|
||
|
" else:\n",
|
||
|
" probsum = sum(float(val) for key, val in dic.items())\n",
|
||
|
" for key in dic:\n",
|
||
|
" dic[key] = dic[key]/probsum*(1-wildcard_minweight) ###leave some space for wildcard\n",
|
||
|
" wildcard_value = dic['<unk>']\n",
|
||
|
" del dic['<unk>']\n",
|
||
|
" tab = [(key, val) for key, val in dic.items()]\n",
|
||
|
" tab.append(('<unk>', wildcard_value))\n",
|
||
|
" \n",
|
||
|
" return tab\n",
|
||
|
"\n",
|
||
|
"def gonito_format(dic, const_wildcard = True):\n",
|
||
|
" tab = summarize_probs_unk(dic, const_wildcard)\n",
|
||
|
" result = ''\n",
|
||
|
" for element in tab[:-1]:\n",
|
||
|
" result+=str(element[0])+':'+str(element[1])+'\\t'\n",
|
||
|
" result+=':'+ str(tab[-1][1]) + '\\n'\n",
|
||
|
" return result"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "2b7513f3",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"###preprocessing\n",
|
||
|
"def preprocess(line):\n",
|
||
|
" line = get_rid_of_header(line)\n",
|
||
|
" line = replace_endline(line)\n",
|
||
|
" return line\n",
|
||
|
"\n",
|
||
|
"def get_rid_of_header(line):\n",
|
||
|
" line = line.split('\\t')[6:]\n",
|
||
|
" return \" \".join(line)\n",
|
||
|
" \n",
|
||
|
"def replace_endline(line):\n",
|
||
|
" line = line.replace(\"\\\\n\", \" \")\n",
|
||
|
" return line"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 39,
|
||
|
"id": "4b0e66e2",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/tmp/ipykernel_37433/606935597.py:22: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
||
|
" out = self.softmax(concated)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"[('<unk>', 0, 0.12663832306861877),\n",
|
||
|
" ('one', 43, 0.02672259509563446),\n",
|
||
|
" ('part', 146, 0.015497211366891861),\n",
|
||
|
" ('out', 63, 0.012386629357933998),\n",
|
||
|
" ('some', 76, 0.008164796978235245),\n",
|
||
|
" ('members', 426, 0.00799479242414236),\n",
|
||
|
" ('side', 238, 0.007780702318996191),\n",
|
||
|
" ('portion', 634, 0.005733700469136238),\n",
|
||
|
" ('office', 282, 0.0053163678385317326),\n",
|
||
|
" ('member', 712, 0.005126394797116518)]"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 39,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"\n",
|
||
|
"ixs = torch.tensor([vocab.forward(['of']), vocab.forward(['the'])]).to(device)\n",
|
||
|
"\n",
|
||
|
"out = model(ixs)\n",
|
||
|
"top = torch.topk(out[0], 10)\n",
|
||
|
"top_indices = top.indices.tolist()\n",
|
||
|
"top_probs = top.values.tolist()\n",
|
||
|
"top_words = vocab.lookup_tokens(top_indices)\n",
|
||
|
"list(zip(top_words, top_indices, top_probs))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"id": "a92abbf2",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<All keys matched successfully>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 18,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model.load_state_dict(torch.load(f'model-tri-2following-final.bin'))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 31,
|
||
|
"id": "fc7cf293",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/tmp/ipykernel_4654/606935597.py:22: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
||
|
" out = self.softmax(concated)\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"with lzma.open(test_file, 'rt') as file:\n",
|
||
|
" predict_words = []\n",
|
||
|
" results = []\n",
|
||
|
" for line in file:\n",
|
||
|
" line = replace_endline(line) #get only relevant\n",
|
||
|
" line = line.split('\\t')[6:]\n",
|
||
|
" context = line[1].rstrip().split(\" \")[:2]\n",
|
||
|
" predict_words.append(context) #get_first_word(split[1cd \n",
|
||
|
" vocab = train_dataset.vocab\n",
|
||
|
" for context_words in predict_words:\n",
|
||
|
" results.append(dict(get_values_from_model(context_words, model, vocab, k=10)))\n",
|
||
|
" \n",
|
||
|
" with open(out_file, 'w') as outfile:\n",
|
||
|
" for elem in results: \n",
|
||
|
" outfile.write(gonito_format(elem, const_wildcard=False))\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "1c31c8ba",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|