solution
This commit is contained in:
parent
4278d2f8ea
commit
abb72efa82
@ -1,4 +1,3 @@
|
||||
# Użyte elementy z wykładu/ćwiczeń:
|
||||
- pełny lewy kontekst skompresowany do jednego tensora obok dwustronnego kontekstu trigramowego
|
||||
- warstwy layer norm
|
||||
- warstwy dropout
|
||||
# Rozszerzenia:
|
||||
- uwzględniony prawy kontekst
|
||||
- dwukierunkowy LSTM
|
||||
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
@ -1,15 +1,14 @@
|
||||
description: neural network with trigram left-right context plus full left context tensor
|
||||
description: lstm with trigram left-right context of 3 ngrams
|
||||
tags:
|
||||
- neural-network
|
||||
- left-context
|
||||
- right-context
|
||||
- trigrams
|
||||
- lstm
|
||||
params:
|
||||
vocab_size: 20000
|
||||
embed_size: 150
|
||||
batch_size: 4096
|
||||
hidden_size: 1024
|
||||
learning_rate: 0.001
|
||||
epochs: 10
|
||||
learning_rate: 0.0001
|
||||
param-files:
|
||||
- "*.yaml"
|
||||
|
226
solution.ipynb
226
solution.ipynb
@ -3,146 +3,99 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torchtext.vocab import build_vocab_from_iterator\n",
|
||||
"import pickle\n",
|
||||
"from torch.utils.data import IterableDataset\n",
|
||||
"from itertools import chain\n",
|
||||
"from torch import nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torch\n",
|
||||
"import lzma\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import shutil\n",
|
||||
"torch.manual_seed(1)"
|
||||
]
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 60,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def simple_preprocess(line):\n",
|
||||
" return line.replace(r'\\n', ' ')\n",
|
||||
"\n",
|
||||
"def get_max_left_context_len(file_name):\n",
|
||||
" print('Getting max left context length...')\n",
|
||||
" max_len = 0\n",
|
||||
" with lzma.open(file_name, 'r') as fh:\n",
|
||||
" for line in fh:\n",
|
||||
" line = line.decode('utf-8')\n",
|
||||
" line = line.strip()\n",
|
||||
" line = line.split('\\t')[-2]\n",
|
||||
" line = simple_preprocess(line)\n",
|
||||
" curr_len = len(line.split())\n",
|
||||
" if curr_len > max_len:\n",
|
||||
" max_len = curr_len\n",
|
||||
" print(f'max_len={max_len}')\n",
|
||||
" return max_len\n",
|
||||
"\n",
|
||||
"def get_words_from_line(line):\n",
|
||||
" for t in line:\n",
|
||||
" line = line.strip()\n",
|
||||
" line = simple_preprocess(line)\n",
|
||||
" yield '<s>'\n",
|
||||
" for t in line.split():\n",
|
||||
" yield t\n",
|
||||
" yield '</s>'\n",
|
||||
"\n",
|
||||
"def get_word_lines_from_file(file_name, max_left_context_len, return_gen, n_size=-1):\n",
|
||||
"def get_word_lines_from_file(file_name, n_size=-1):\n",
|
||||
" with lzma.open(file_name, 'r') as fh:\n",
|
||||
" n = 0\n",
|
||||
" for line in fh:\n",
|
||||
" n += 1\n",
|
||||
" line = line.decode('utf-8')\n",
|
||||
" line = line.strip()\n",
|
||||
" padding = '<pad> ' * (max_left_context_len - 1) # <s>\n",
|
||||
" left_context = padding + '<s> ' + simple_preprocess(line.split('\\t')[-2])\n",
|
||||
" right_context = simple_preprocess(line.split('\\t')[-1]) + ' </s> <pad> <pad>'\n",
|
||||
" line = left_context + ' ' + right_context\n",
|
||||
" line = line.split()\n",
|
||||
" if return_gen:\n",
|
||||
" yield get_words_from_line(line)\n",
|
||||
" else:\n",
|
||||
" yield line\n",
|
||||
" yield get_words_from_line(line.decode('utf-8'))\n",
|
||||
" if n == n_size:\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"def look_ahead_iterator(gen, vocab, max_left_context_len):\n",
|
||||
"def look_ahead_iterator(gen):\n",
|
||||
" ngram = []\n",
|
||||
" for item in gen:\n",
|
||||
" start_pos = item.index('<s>') + 1\n",
|
||||
" item = [vocab[t] for t in item]\n",
|
||||
" for i in range(start_pos, len(item) - 4):\n",
|
||||
" yield [item[:i-3][-max_left_context_len+3:], item[i-3:i], item[i], item[i+1:i+4]]\n",
|
||||
" if len(ngram) < 7:\n",
|
||||
" ngram.append(item)\n",
|
||||
" if len(ngram) == 7:\n",
|
||||
" yield ngram\n",
|
||||
" else:\n",
|
||||
" ngram = ngram[1:]\n",
|
||||
" ngram.append(item)\n",
|
||||
" yield ngram\n",
|
||||
"\n",
|
||||
"def build_vocab(file, vocab_size, max_left_context_len):\n",
|
||||
"def build_vocab(file, vocab_size):\n",
|
||||
" try:\n",
|
||||
" with open(f'vocab_{vocab_size}_padded.pickle', 'rb') as handle:\n",
|
||||
" print('Loading vocab...')\n",
|
||||
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
|
||||
" vocab = pickle.load(handle)\n",
|
||||
" except:\n",
|
||||
" print('Building vocab...')\n",
|
||||
" vocab = build_vocab_from_iterator(\n",
|
||||
" get_word_lines_from_file(file, max_left_context_len, return_gen=True),\n",
|
||||
" get_word_lines_from_file(file),\n",
|
||||
" max_tokens = vocab_size,\n",
|
||||
" specials = ['<unk>'])\n",
|
||||
" with open(f'vocab_{vocab_size}_padded.pickle', 'wb') as handle:\n",
|
||||
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
|
||||
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
||||
" return vocab\n",
|
||||
"\n",
|
||||
"class Ngrams(IterableDataset):\n",
|
||||
" def __init__(self, text_file, max_left_context_len):\n",
|
||||
"class Trigrams(IterableDataset):\n",
|
||||
" def __init__(self, text_file):\n",
|
||||
" self.vocab = vocab\n",
|
||||
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
|
||||
" self.text_file = text_file\n",
|
||||
" self.max_left_context_len = max_left_context_len\n",
|
||||
"\n",
|
||||
" def __iter__(self):\n",
|
||||
" return look_ahead_iterator(get_word_lines_from_file(self.text_file, max_left_context_len, return_gen=False), self.vocab, self.max_left_context_len)\n",
|
||||
" return look_ahead_iterator(\n",
|
||||
" (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
|
||||
"\n",
|
||||
"# Dropout, norm layers adjusted on a case-by-case basis. Also gradual hidden layer size reduction vs. no reduction\n",
|
||||
"class NeuralLanguageModel(nn.Module):\n",
|
||||
"class LSTMLanguageModel(nn.Module):\n",
|
||||
" def __init__(self, vocab_size, embed_size, hidden_size):\n",
|
||||
" super(NeuralLanguageModel, self).__init__()\n",
|
||||
" super(LSTMLanguageModel, self).__init__()\n",
|
||||
" self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
|
||||
" self.hidden_1 = nn.Linear(7*embed_size, hidden_size)\n",
|
||||
" self.hidden_2 = nn.Linear(hidden_size, int(hidden_size/2))\n",
|
||||
" self.hidden_3 = nn.Linear(int(hidden_size/2), int(hidden_size/4))\n",
|
||||
" self.output = nn.Linear(int(hidden_size/4), vocab_size)\n",
|
||||
"\n",
|
||||
" self.lstm_layer = nn.LSTM(6*embed_size, hidden_size, bidirectional=True)\n",
|
||||
" self.output_layer = nn.Linear(2*hidden_size, vocab_size)\n",
|
||||
" self.softmax = nn.Softmax(dim=1)\n",
|
||||
" self.norm_input = nn.LayerNorm(7*embed_size)\n",
|
||||
" self.norm_1 = nn.LayerNorm(int(hidden_size))\n",
|
||||
" self.norm_2 = nn.LayerNorm(int(hidden_size/2))\n",
|
||||
" self.norm_3 = nn.LayerNorm(int(hidden_size/4))\n",
|
||||
" self.activation = nn.LeakyReLU()\n",
|
||||
" self.dropout = nn.Dropout(0.1)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x_whole_left, x_left_trigram, x_right_trigram = x\n",
|
||||
" x_whole_left_embed = [self.embeddings(t) for t in x_whole_left]\n",
|
||||
" x_whole_left_embed_len = len(x_whole_left_embed)\n",
|
||||
" x_whole_left_embed = torch.stack(x_whole_left_embed)\n",
|
||||
" x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0) / x_whole_left_embed_len\n",
|
||||
" #x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0)\n",
|
||||
" x_left_trigram_embed = torch.concat([self.embeddings(t) for t in x_left_trigram], dim=1)\n",
|
||||
" x_right_trigram_embed = torch.concat([self.embeddings(t) for t in x_right_trigram], dim=1)\n",
|
||||
" concat_embed = torch.concat((x_whole_left_embed, x_left_trigram_embed, x_right_trigram_embed), dim=1)\n",
|
||||
" if torch.isnan(concat_embed).any():\n",
|
||||
" print('NaN!')\n",
|
||||
" raise Exception(\"Error\")\n",
|
||||
" concat_embed = self.norm_input(concat_embed)\n",
|
||||
" z = self.hidden_1(concat_embed)\n",
|
||||
" z = self.norm_1(z)\n",
|
||||
" z = self.activation(z)\n",
|
||||
" #z = self.dropout(z)\n",
|
||||
" z = self.hidden_2(z)\n",
|
||||
" z = self.norm_2(z)\n",
|
||||
" z = self.activation(z)\n",
|
||||
" #z = self.dropout(z)\n",
|
||||
" z = self.hidden_3(z)\n",
|
||||
" z = self.norm_3(z)\n",
|
||||
" z = self.activation(z)\n",
|
||||
" #z = self.dropout(z)\n",
|
||||
" z = self.output(z)\n",
|
||||
" y = self.softmax(z)\n",
|
||||
" embeds = [self.embeddings(gram) for gram in x]\n",
|
||||
" concat_embed = torch.concat(embeds, dim=1)\n",
|
||||
" z = F.relu(self.lstm_layer(concat_embed)[0])\n",
|
||||
" y = self.softmax(self.output_layer(z))\n",
|
||||
" return y"
|
||||
],
|
||||
"metadata": {
|
||||
@ -158,69 +111,56 @@
|
||||
"max_steps = -1\n",
|
||||
"vocab_size = 20000\n",
|
||||
"embed_size = 150\n",
|
||||
"batch_size = 4096\n",
|
||||
"hidden_size = 1024\n",
|
||||
"learning_rate = 0.001 # < 0.1\n",
|
||||
"epochs = 1\n",
|
||||
"#max_left_context_len = get_max_left_context_len('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
|
||||
"max_left_context_len = 291\n",
|
||||
"torch.manual_seed(1)\n",
|
||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)\n",
|
||||
"train_dataset = Ngrams('challenging-america-word-gap-prediction/train/in.tsv.xz', max_left_context_len)\n",
|
||||
"batch_size = 1024\n",
|
||||
"hidden_size = 512\n",
|
||||
"learning_rate = 0.0001\n",
|
||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||
"train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" device = 'cuda'\n",
|
||||
"else:\n",
|
||||
" raise Exception()\n",
|
||||
"model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
"#model.load_state_dict(torch.load(model_name))\n",
|
||||
"model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
||||
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
||||
"criterion = torch.nn.NLLLoss()\n",
|
||||
"\n",
|
||||
"with torch.autograd.set_detect_anomaly(True):\n",
|
||||
"model.train()\n",
|
||||
" epoch = 0\n",
|
||||
" for i in range(epochs):\n",
|
||||
"step = 0\n",
|
||||
" epoch += 1\n",
|
||||
" print(f'--------epoch {epoch}--------')\n",
|
||||
" for x_whole_left, x_left_trigram, y, x_right_trigram in data:\n",
|
||||
" x = [t.to(device) for t in x_whole_left], [t.to(device) for t in x_left_trigram], [t.to(device) for t in x_right_trigram]\n",
|
||||
" y = y.to(device)\n",
|
||||
"for ngram in data:\n",
|
||||
" x = [gram.to(device) for gram in ngram[:3]+ngram[4:]]\n",
|
||||
" y = ngram[3].to(device)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(torch.log(y_pred), y)\n",
|
||||
" if step % 1000 == 0:\n",
|
||||
" ypredicted = model(x)\n",
|
||||
" loss = criterion(torch.log(ypredicted), y)\n",
|
||||
" if step % 100 == 0:\n",
|
||||
" print(f'steps: {step}, loss: {loss.item()}')\n",
|
||||
" if step % 1000 == 0:\n",
|
||||
" if step != 0:\n",
|
||||
" name = f'loss-{loss.item()}_model_steps-{step}_epoch-{epoch}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin'\n",
|
||||
" torch.save(model.state_dict(), 'models/' + name)\n",
|
||||
" torch.save(model.state_dict(), f'{loss}_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
|
||||
" loss.backward()\n",
|
||||
" torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)\n",
|
||||
" optimizer.step()\n",
|
||||
" if step == max_steps:\n",
|
||||
" break\n",
|
||||
" step += 1"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 69,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"step += 1\n",
|
||||
"vocab_size = 20000\n",
|
||||
"embed_size = 150\n",
|
||||
"batch_size = 4096\n",
|
||||
"hidden_size = 1024\n",
|
||||
"max_left_context_len = 291\n",
|
||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)\n",
|
||||
"vocab.set_default_index(vocab['<unk>'])\n",
|
||||
"model_name = 'models/' + 'best_model_mod_arch.bin'\n",
|
||||
"topk = 10"
|
||||
"hidden_size = 512\n",
|
||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||
"vocab.set_default_index(vocab['<unk>'])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
@ -228,30 +168,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 74,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for model_name in ['best.bin']:\n",
|
||||
" topk = 100\n",
|
||||
" preds = []\n",
|
||||
" device = 'cuda'\n",
|
||||
"model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" model.load_state_dict(torch.load(model_name))\n",
|
||||
" model.eval()\n",
|
||||
"j = 0\n",
|
||||
" for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
|
||||
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
|
||||
" for line in fh:\n",
|
||||
" j += 1\n",
|
||||
" left_context = simple_preprocess(line.decode('utf-8')).split('\\t')[-2].strip()\n",
|
||||
" right_context = simple_preprocess(line.decode('utf-8')).split('\\t')[-1].strip()\n",
|
||||
" padding = '<pad> ' * (max_left_context_len - 1) # <s>\n",
|
||||
" left_context = padding + '<s> ' + left_context\n",
|
||||
" right_context = right_context + ' </s> <pad> <pad>'\n",
|
||||
" x_left_trigram, x_right_trigram = left_context.split()[-3:], right_context.split()[:3]\n",
|
||||
" x = [torch.tensor(vocab.forward([w])).to(device) for w in left_context], [torch.tensor(vocab.forward([w])).to(device) for w in x_left_trigram], [torch.tensor(vocab.forward([w])).to(device) for w in x_right_trigram]\n",
|
||||
" left_context = simple_preprocess(line.decode('utf-8').split('\\t')[-2].strip()).split()[-3:]\n",
|
||||
" right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:3]\n",
|
||||
" full_context = left_context + right_context\n",
|
||||
" x = [torch.tensor(vocab.forward([word])).to(device) for word in full_context]\n",
|
||||
" out = model(x)\n",
|
||||
" top = torch.topk(out[0], topk)\n",
|
||||
" top_indices = top.indices.tolist()\n",
|
||||
" print(j, ' '.join(x_left_trigram), '[[[', vocab.lookup_token(top_indices[0]) if vocab.lookup_token(top_indices[0]) != '<unk>' else vocab.lookup_token(top_indices[1]), ']]]', ' '.join(x_right_trigram))\n",
|
||||
" top_probs = top.values.tolist()\n",
|
||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||
" top_zipped = zip(top_words, top_probs)\n",
|
||||
@ -268,6 +204,22 @@
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%cd challenging-america-word-gap-prediction/\n",
|
||||
"!./geval --test-name dev-0\n",
|
||||
"%cd ../"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user