This commit is contained in:
Kacper 2023-06-06 22:05:40 +02:00
parent 4278d2f8ea
commit abb72efa82
5 changed files with 18060 additions and 18110 deletions

View File

@ -1,4 +1,3 @@
# Użyte elementy z wykładu/ćwiczeń: # Rozszerzenia:
- pełny lewy kontekst skompresowany do jednego tensora obok dwustronnego kontekstu trigramowego - uwzględniony prawy kontekst
- warstwy layer norm - dwukierunkowy LSTM
- warstwy dropout

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,14 @@
description: neural network with trigram left-right context plus full left context tensor description: lstm with trigram left-right context of 3 ngrams
tags: tags:
- neural-network - neural-network
- left-context - left-context
- right-context - right-context
- trigrams - lstm
params: params:
vocab_size: 20000 vocab_size: 20000
embed_size: 150 embed_size: 150
batch_size: 4096 batch_size: 4096
hidden_size: 1024 hidden_size: 1024
learning_rate: 0.001 learning_rate: 0.0001
epochs: 10
param-files: param-files:
- "*.yaml" - "*.yaml"

View File

@ -3,146 +3,99 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from torchtext.vocab import build_vocab_from_iterator\n", "from torchtext.vocab import build_vocab_from_iterator\n",
"import pickle\n", "import pickle\n",
"from torch.utils.data import IterableDataset\n", "from torch.utils.data import IterableDataset\n",
"from itertools import chain\n",
"from torch import nn\n", "from torch import nn\n",
"import torch.nn.functional as F\n",
"import torch\n", "import torch\n",
"import lzma\n", "import lzma\n",
"from torch.utils.data import DataLoader\n", "from torch.utils.data import DataLoader\n",
"import shutil\n",
"torch.manual_seed(1)" "torch.manual_seed(1)"
] ],
"metadata": {
"collapsed": false,
"pycharm": {
"is_executing": true
}
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 60,
"outputs": [], "outputs": [],
"source": [ "source": [
"def simple_preprocess(line):\n", "def simple_preprocess(line):\n",
" return line.replace(r'\\n', ' ')\n", " return line.replace(r'\\n', ' ')\n",
"\n", "\n",
"def get_max_left_context_len(file_name):\n",
" print('Getting max left context length...')\n",
" max_len = 0\n",
" with lzma.open(file_name, 'r') as fh:\n",
" for line in fh:\n",
" line = line.decode('utf-8')\n",
" line = line.strip()\n",
" line = line.split('\\t')[-2]\n",
" line = simple_preprocess(line)\n",
" curr_len = len(line.split())\n",
" if curr_len > max_len:\n",
" max_len = curr_len\n",
" print(f'max_len={max_len}')\n",
" return max_len\n",
"\n",
"def get_words_from_line(line):\n", "def get_words_from_line(line):\n",
" for t in line:\n", " line = line.strip()\n",
" line = simple_preprocess(line)\n",
" yield '<s>'\n",
" for t in line.split():\n",
" yield t\n", " yield t\n",
" yield '</s>'\n",
"\n", "\n",
"def get_word_lines_from_file(file_name, max_left_context_len, return_gen, n_size=-1):\n", "def get_word_lines_from_file(file_name, n_size=-1):\n",
" with lzma.open(file_name, 'r') as fh:\n", " with lzma.open(file_name, 'r') as fh:\n",
" n = 0\n", " n = 0\n",
" for line in fh:\n", " for line in fh:\n",
" n += 1\n", " n += 1\n",
" line = line.decode('utf-8')\n", " yield get_words_from_line(line.decode('utf-8'))\n",
" line = line.strip()\n",
" padding = '<pad> ' * (max_left_context_len - 1) # <s>\n",
" left_context = padding + '<s> ' + simple_preprocess(line.split('\\t')[-2])\n",
" right_context = simple_preprocess(line.split('\\t')[-1]) + ' </s> <pad> <pad>'\n",
" line = left_context + ' ' + right_context\n",
" line = line.split()\n",
" if return_gen:\n",
" yield get_words_from_line(line)\n",
" else:\n",
" yield line\n",
" if n == n_size:\n", " if n == n_size:\n",
" break\n", " break\n",
"\n", "\n",
"def look_ahead_iterator(gen, vocab, max_left_context_len):\n", "def look_ahead_iterator(gen):\n",
" ngram = []\n",
" for item in gen:\n", " for item in gen:\n",
" start_pos = item.index('<s>') + 1\n", " if len(ngram) < 7:\n",
" item = [vocab[t] for t in item]\n", " ngram.append(item)\n",
" for i in range(start_pos, len(item) - 4):\n", " if len(ngram) == 7:\n",
" yield [item[:i-3][-max_left_context_len+3:], item[i-3:i], item[i], item[i+1:i+4]]\n", " yield ngram\n",
" else:\n",
" ngram = ngram[1:]\n",
" ngram.append(item)\n",
" yield ngram\n",
"\n", "\n",
"def build_vocab(file, vocab_size, max_left_context_len):\n", "def build_vocab(file, vocab_size):\n",
" try:\n", " try:\n",
" with open(f'vocab_{vocab_size}_padded.pickle', 'rb') as handle:\n", " with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
" print('Loading vocab...')\n",
" vocab = pickle.load(handle)\n", " vocab = pickle.load(handle)\n",
" except:\n", " except:\n",
" print('Building vocab...')\n",
" vocab = build_vocab_from_iterator(\n", " vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(file, max_left_context_len, return_gen=True),\n", " get_word_lines_from_file(file),\n",
" max_tokens = vocab_size,\n", " max_tokens = vocab_size,\n",
" specials = ['<unk>'])\n", " specials = ['<unk>'])\n",
" with open(f'vocab_{vocab_size}_padded.pickle', 'wb') as handle:\n", " with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", " pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
" return vocab\n", " return vocab\n",
"\n", "\n",
"class Ngrams(IterableDataset):\n", "class Trigrams(IterableDataset):\n",
" def __init__(self, text_file, max_left_context_len):\n", " def __init__(self, text_file):\n",
" self.vocab = vocab\n", " self.vocab = vocab\n",
" self.vocab.set_default_index(self.vocab['<unk>'])\n", " self.vocab.set_default_index(self.vocab['<unk>'])\n",
" self.text_file = text_file\n", " self.text_file = text_file\n",
" self.max_left_context_len = max_left_context_len\n",
"\n", "\n",
" def __iter__(self):\n", " def __iter__(self):\n",
" return look_ahead_iterator(get_word_lines_from_file(self.text_file, max_left_context_len, return_gen=False), self.vocab, self.max_left_context_len)\n", " return look_ahead_iterator(\n",
" (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
"\n", "\n",
"# Dropout, norm layers adjusted on a case-by-case basis. Also gradual hidden layer size reduction vs. no reduction\n", "class LSTMLanguageModel(nn.Module):\n",
"class NeuralLanguageModel(nn.Module):\n",
" def __init__(self, vocab_size, embed_size, hidden_size):\n", " def __init__(self, vocab_size, embed_size, hidden_size):\n",
" super(NeuralLanguageModel, self).__init__()\n", " super(LSTMLanguageModel, self).__init__()\n",
" self.embeddings = nn.Embedding(vocab_size, embed_size)\n", " self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
" self.hidden_1 = nn.Linear(7*embed_size, hidden_size)\n", " self.lstm_layer = nn.LSTM(6*embed_size, hidden_size, bidirectional=True)\n",
" self.hidden_2 = nn.Linear(hidden_size, int(hidden_size/2))\n", " self.output_layer = nn.Linear(2*hidden_size, vocab_size)\n",
" self.hidden_3 = nn.Linear(int(hidden_size/2), int(hidden_size/4))\n",
" self.output = nn.Linear(int(hidden_size/4), vocab_size)\n",
"\n",
" self.softmax = nn.Softmax(dim=1)\n", " self.softmax = nn.Softmax(dim=1)\n",
" self.norm_input = nn.LayerNorm(7*embed_size)\n",
" self.norm_1 = nn.LayerNorm(int(hidden_size))\n",
" self.norm_2 = nn.LayerNorm(int(hidden_size/2))\n",
" self.norm_3 = nn.LayerNorm(int(hidden_size/4))\n",
" self.activation = nn.LeakyReLU()\n",
" self.dropout = nn.Dropout(0.1)\n",
"\n", "\n",
" def forward(self, x):\n", " def forward(self, x):\n",
" x_whole_left, x_left_trigram, x_right_trigram = x\n", " embeds = [self.embeddings(gram) for gram in x]\n",
" x_whole_left_embed = [self.embeddings(t) for t in x_whole_left]\n", " concat_embed = torch.concat(embeds, dim=1)\n",
" x_whole_left_embed_len = len(x_whole_left_embed)\n", " z = F.relu(self.lstm_layer(concat_embed)[0])\n",
" x_whole_left_embed = torch.stack(x_whole_left_embed)\n", " y = self.softmax(self.output_layer(z))\n",
" x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0) / x_whole_left_embed_len\n",
" #x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0)\n",
" x_left_trigram_embed = torch.concat([self.embeddings(t) for t in x_left_trigram], dim=1)\n",
" x_right_trigram_embed = torch.concat([self.embeddings(t) for t in x_right_trigram], dim=1)\n",
" concat_embed = torch.concat((x_whole_left_embed, x_left_trigram_embed, x_right_trigram_embed), dim=1)\n",
" if torch.isnan(concat_embed).any():\n",
" print('NaN!')\n",
" raise Exception(\"Error\")\n",
" concat_embed = self.norm_input(concat_embed)\n",
" z = self.hidden_1(concat_embed)\n",
" z = self.norm_1(z)\n",
" z = self.activation(z)\n",
" #z = self.dropout(z)\n",
" z = self.hidden_2(z)\n",
" z = self.norm_2(z)\n",
" z = self.activation(z)\n",
" #z = self.dropout(z)\n",
" z = self.hidden_3(z)\n",
" z = self.norm_3(z)\n",
" z = self.activation(z)\n",
" #z = self.dropout(z)\n",
" z = self.output(z)\n",
" y = self.softmax(z)\n",
" return y" " return y"
], ],
"metadata": { "metadata": {
@ -158,69 +111,56 @@
"max_steps = -1\n", "max_steps = -1\n",
"vocab_size = 20000\n", "vocab_size = 20000\n",
"embed_size = 150\n", "embed_size = 150\n",
"batch_size = 4096\n", "batch_size = 1024\n",
"hidden_size = 1024\n", "hidden_size = 512\n",
"learning_rate = 0.001 # < 0.1\n", "learning_rate = 0.0001\n",
"epochs = 1\n", "vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
"#max_left_context_len = get_max_left_context_len('challenging-america-word-gap-prediction/train/in.tsv.xz')\n", "train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
"max_left_context_len = 291\n",
"torch.manual_seed(1)\n",
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)\n",
"train_dataset = Ngrams('challenging-america-word-gap-prediction/train/in.tsv.xz', max_left_context_len)\n",
"if torch.cuda.is_available():\n", "if torch.cuda.is_available():\n",
" device = 'cuda'\n", " device = 'cuda'\n",
"else:\n", "else:\n",
" raise Exception()\n", " raise Exception()\n",
"model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n", "model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
"#model.load_state_dict(torch.load(model_name))\n",
"data = DataLoader(train_dataset, batch_size=batch_size)\n", "data = DataLoader(train_dataset, batch_size=batch_size)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"criterion = torch.nn.NLLLoss()\n", "criterion = torch.nn.NLLLoss()\n",
"\n", "\n",
"with torch.autograd.set_detect_anomaly(True):\n",
"model.train()\n", "model.train()\n",
" epoch = 0\n",
" for i in range(epochs):\n",
"step = 0\n", "step = 0\n",
" epoch += 1\n", "for ngram in data:\n",
" print(f'--------epoch {epoch}--------')\n", " x = [gram.to(device) for gram in ngram[:3]+ngram[4:]]\n",
" for x_whole_left, x_left_trigram, y, x_right_trigram in data:\n", " y = ngram[3].to(device)\n",
" x = [t.to(device) for t in x_whole_left], [t.to(device) for t in x_left_trigram], [t.to(device) for t in x_right_trigram]\n",
" y = y.to(device)\n",
" optimizer.zero_grad()\n", " optimizer.zero_grad()\n",
" y_pred = model(x)\n", " ypredicted = model(x)\n",
" loss = criterion(torch.log(y_pred), y)\n", " loss = criterion(torch.log(ypredicted), y)\n",
" if step % 1000 == 0:\n", " if step % 100 == 0:\n",
" print(f'steps: {step}, loss: {loss.item()}')\n", " print(f'steps: {step}, loss: {loss.item()}')\n",
" if step % 1000 == 0:\n",
" if step != 0:\n", " if step != 0:\n",
" name = f'loss-{loss.item()}_model_steps-{step}_epoch-{epoch}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin'\n", " torch.save(model.state_dict(), f'{loss}_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
" torch.save(model.state_dict(), 'models/' + name)\n",
" loss.backward()\n", " loss.backward()\n",
" torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)\n",
" optimizer.step()\n", " optimizer.step()\n",
" if step == max_steps:\n", " if step == max_steps:\n",
" break\n", " break\n",
" step += 1" " step += 1"
], ],
"metadata": { "metadata": {
"collapsed": false "collapsed": false,
"pycharm": {
"is_executing": true
}
} }
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 69,
"outputs": [], "outputs": [],
"source": [ "source": [
"step += 1\n",
"vocab_size = 20000\n", "vocab_size = 20000\n",
"embed_size = 150\n", "embed_size = 150\n",
"batch_size = 4096\n", "hidden_size = 512\n",
"hidden_size = 1024\n", "vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
"max_left_context_len = 291\n", "vocab.set_default_index(vocab['<unk>'])"
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)\n",
"vocab.set_default_index(vocab['<unk>'])\n",
"model_name = 'models/' + 'best_model_mod_arch.bin'\n",
"topk = 10"
], ],
"metadata": { "metadata": {
"collapsed": false "collapsed": false
@ -228,30 +168,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 74,
"outputs": [], "outputs": [],
"source": [ "source": [
"for model_name in ['best.bin']:\n",
" topk = 100\n",
" preds = []\n", " preds = []\n",
" device = 'cuda'\n", " device = 'cuda'\n",
"model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n", " model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
" model.load_state_dict(torch.load(model_name))\n", " model.load_state_dict(torch.load(model_name))\n",
" model.eval()\n", " model.eval()\n",
"j = 0\n",
" for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n", " for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n", " with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
" for line in fh:\n", " for line in fh:\n",
" j += 1\n", " left_context = simple_preprocess(line.decode('utf-8').split('\\t')[-2].strip()).split()[-3:]\n",
" left_context = simple_preprocess(line.decode('utf-8')).split('\\t')[-2].strip()\n", " right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:3]\n",
" right_context = simple_preprocess(line.decode('utf-8')).split('\\t')[-1].strip()\n", " full_context = left_context + right_context\n",
" padding = '<pad> ' * (max_left_context_len - 1) # <s>\n", " x = [torch.tensor(vocab.forward([word])).to(device) for word in full_context]\n",
" left_context = padding + '<s> ' + left_context\n",
" right_context = right_context + ' </s> <pad> <pad>'\n",
" x_left_trigram, x_right_trigram = left_context.split()[-3:], right_context.split()[:3]\n",
" x = [torch.tensor(vocab.forward([w])).to(device) for w in left_context], [torch.tensor(vocab.forward([w])).to(device) for w in x_left_trigram], [torch.tensor(vocab.forward([w])).to(device) for w in x_right_trigram]\n",
" out = model(x)\n", " out = model(x)\n",
" top = torch.topk(out[0], topk)\n", " top = torch.topk(out[0], topk)\n",
" top_indices = top.indices.tolist()\n", " top_indices = top.indices.tolist()\n",
" print(j, ' '.join(x_left_trigram), '[[[', vocab.lookup_token(top_indices[0]) if vocab.lookup_token(top_indices[0]) != '<unk>' else vocab.lookup_token(top_indices[1]), ']]]', ' '.join(x_right_trigram))\n",
" top_probs = top.values.tolist()\n", " top_probs = top.values.tolist()\n",
" top_words = vocab.lookup_tokens(top_indices)\n", " top_words = vocab.lookup_tokens(top_indices)\n",
" top_zipped = zip(top_words, top_probs)\n", " top_zipped = zip(top_words, top_probs)\n",
@ -268,6 +204,22 @@
"metadata": { "metadata": {
"collapsed": false "collapsed": false
} }
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"%cd challenging-america-word-gap-prediction/\n",
"!./geval --test-name dev-0\n",
"%cd ../"
],
"metadata": {
"collapsed": false,
"pycharm": {
"is_executing": true
}
}
} }
], ],
"metadata": { "metadata": {

File diff suppressed because it is too large Load Diff