add solution etc

This commit is contained in:
Kacper 2023-05-26 11:24:47 +02:00
parent 5a5265fd3d
commit 4278d2f8ea
3 changed files with 313 additions and 1 deletions

View File

@ -1 +1,4 @@
# WIP # Użyte elementy z wykładu/ćwiczeń:
- pełny lewy kontekst skompresowany do jednego tensora obok dwustronnego kontekstu trigramowego
- warstwy layer norm
- warstwy dropout

15
gonito.yaml Normal file
View File

@ -0,0 +1,15 @@
description: neural network with trigram left-right context plus full left context tensor
tags:
- neural-network
- left-context
- right-context
- trigrams
params:
vocab_size: 20000
embed_size: 150
batch_size: 4096
hidden_size: 1024
learning_rate: 0.001
epochs: 10
param-files:
- "*.yaml"

294
solution.ipynb Normal file
View File

@ -0,0 +1,294 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from torchtext.vocab import build_vocab_from_iterator\n",
"import pickle\n",
"from torch.utils.data import IterableDataset\n",
"from torch import nn\n",
"import torch\n",
"import lzma\n",
"from torch.utils.data import DataLoader\n",
"import shutil\n",
"torch.manual_seed(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"def simple_preprocess(line):\n",
" return line.replace(r'\\n', ' ')\n",
"\n",
"def get_max_left_context_len(file_name):\n",
" print('Getting max left context length...')\n",
" max_len = 0\n",
" with lzma.open(file_name, 'r') as fh:\n",
" for line in fh:\n",
" line = line.decode('utf-8')\n",
" line = line.strip()\n",
" line = line.split('\\t')[-2]\n",
" line = simple_preprocess(line)\n",
" curr_len = len(line.split())\n",
" if curr_len > max_len:\n",
" max_len = curr_len\n",
" print(f'max_len={max_len}')\n",
" return max_len\n",
"\n",
"def get_words_from_line(line):\n",
" for t in line:\n",
" yield t\n",
"\n",
"def get_word_lines_from_file(file_name, max_left_context_len, return_gen, n_size=-1):\n",
" with lzma.open(file_name, 'r') as fh:\n",
" n = 0\n",
" for line in fh:\n",
" n += 1\n",
" line = line.decode('utf-8')\n",
" line = line.strip()\n",
" padding = '<pad> ' * (max_left_context_len - 1) # <s>\n",
" left_context = padding + '<s> ' + simple_preprocess(line.split('\\t')[-2])\n",
" right_context = simple_preprocess(line.split('\\t')[-1]) + ' </s> <pad> <pad>'\n",
" line = left_context + ' ' + right_context\n",
" line = line.split()\n",
" if return_gen:\n",
" yield get_words_from_line(line)\n",
" else:\n",
" yield line\n",
" if n == n_size:\n",
" break\n",
"\n",
"def look_ahead_iterator(gen, vocab, max_left_context_len):\n",
" for item in gen:\n",
" start_pos = item.index('<s>') + 1\n",
" item = [vocab[t] for t in item]\n",
" for i in range(start_pos, len(item) - 4):\n",
" yield [item[:i-3][-max_left_context_len+3:], item[i-3:i], item[i], item[i+1:i+4]]\n",
"\n",
"def build_vocab(file, vocab_size, max_left_context_len):\n",
" try:\n",
" with open(f'vocab_{vocab_size}_padded.pickle', 'rb') as handle:\n",
" print('Loading vocab...')\n",
" vocab = pickle.load(handle)\n",
" except:\n",
" print('Building vocab...')\n",
" vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(file, max_left_context_len, return_gen=True),\n",
" max_tokens = vocab_size,\n",
" specials = ['<unk>'])\n",
" with open(f'vocab_{vocab_size}_padded.pickle', 'wb') as handle:\n",
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
" return vocab\n",
"\n",
"class Ngrams(IterableDataset):\n",
" def __init__(self, text_file, max_left_context_len):\n",
" self.vocab = vocab\n",
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
" self.text_file = text_file\n",
" self.max_left_context_len = max_left_context_len\n",
"\n",
" def __iter__(self):\n",
" return look_ahead_iterator(get_word_lines_from_file(self.text_file, max_left_context_len, return_gen=False), self.vocab, self.max_left_context_len)\n",
"\n",
"# Dropout, norm layers adjusted on a case-by-case basis. Also gradual hidden layer size reduction vs. no reduction\n",
"class NeuralLanguageModel(nn.Module):\n",
" def __init__(self, vocab_size, embed_size, hidden_size):\n",
" super(NeuralLanguageModel, self).__init__()\n",
" self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
" self.hidden_1 = nn.Linear(7*embed_size, hidden_size)\n",
" self.hidden_2 = nn.Linear(hidden_size, int(hidden_size/2))\n",
" self.hidden_3 = nn.Linear(int(hidden_size/2), int(hidden_size/4))\n",
" self.output = nn.Linear(int(hidden_size/4), vocab_size)\n",
"\n",
" self.softmax = nn.Softmax(dim=1)\n",
" self.norm_input = nn.LayerNorm(7*embed_size)\n",
" self.norm_1 = nn.LayerNorm(int(hidden_size))\n",
" self.norm_2 = nn.LayerNorm(int(hidden_size/2))\n",
" self.norm_3 = nn.LayerNorm(int(hidden_size/4))\n",
" self.activation = nn.LeakyReLU()\n",
" self.dropout = nn.Dropout(0.1)\n",
"\n",
" def forward(self, x):\n",
" x_whole_left, x_left_trigram, x_right_trigram = x\n",
" x_whole_left_embed = [self.embeddings(t) for t in x_whole_left]\n",
" x_whole_left_embed_len = len(x_whole_left_embed)\n",
" x_whole_left_embed = torch.stack(x_whole_left_embed)\n",
" x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0) / x_whole_left_embed_len\n",
" #x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0)\n",
" x_left_trigram_embed = torch.concat([self.embeddings(t) for t in x_left_trigram], dim=1)\n",
" x_right_trigram_embed = torch.concat([self.embeddings(t) for t in x_right_trigram], dim=1)\n",
" concat_embed = torch.concat((x_whole_left_embed, x_left_trigram_embed, x_right_trigram_embed), dim=1)\n",
" if torch.isnan(concat_embed).any():\n",
" print('NaN!')\n",
" raise Exception(\"Error\")\n",
" concat_embed = self.norm_input(concat_embed)\n",
" z = self.hidden_1(concat_embed)\n",
" z = self.norm_1(z)\n",
" z = self.activation(z)\n",
" #z = self.dropout(z)\n",
" z = self.hidden_2(z)\n",
" z = self.norm_2(z)\n",
" z = self.activation(z)\n",
" #z = self.dropout(z)\n",
" z = self.hidden_3(z)\n",
" z = self.norm_3(z)\n",
" z = self.activation(z)\n",
" #z = self.dropout(z)\n",
" z = self.output(z)\n",
" y = self.softmax(z)\n",
" return y"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Sample parameters\n",
"max_steps = -1\n",
"vocab_size = 20000\n",
"embed_size = 150\n",
"batch_size = 4096\n",
"hidden_size = 1024\n",
"learning_rate = 0.001 # < 0.1\n",
"epochs = 1\n",
"#max_left_context_len = get_max_left_context_len('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
"max_left_context_len = 291\n",
"torch.manual_seed(1)\n",
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)\n",
"train_dataset = Ngrams('challenging-america-word-gap-prediction/train/in.tsv.xz', max_left_context_len)\n",
"if torch.cuda.is_available():\n",
" device = 'cuda'\n",
"else:\n",
" raise Exception()\n",
"model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
"#model.load_state_dict(torch.load(model_name))\n",
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"criterion = torch.nn.NLLLoss()\n",
"\n",
"with torch.autograd.set_detect_anomaly(True):\n",
" model.train()\n",
" epoch = 0\n",
" for i in range(epochs):\n",
" step = 0\n",
" epoch += 1\n",
" print(f'--------epoch {epoch}--------')\n",
" for x_whole_left, x_left_trigram, y, x_right_trigram in data:\n",
" x = [t.to(device) for t in x_whole_left], [t.to(device) for t in x_left_trigram], [t.to(device) for t in x_right_trigram]\n",
" y = y.to(device)\n",
" optimizer.zero_grad()\n",
" y_pred = model(x)\n",
" loss = criterion(torch.log(y_pred), y)\n",
" if step % 1000 == 0:\n",
" print(f'steps: {step}, loss: {loss.item()}')\n",
" if step != 0:\n",
" name = f'loss-{loss.item()}_model_steps-{step}_epoch-{epoch}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin'\n",
" torch.save(model.state_dict(), 'models/' + name)\n",
" loss.backward()\n",
" torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)\n",
" optimizer.step()\n",
" if step == max_steps:\n",
" break\n",
" step += 1"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"step += 1\n",
"vocab_size = 20000\n",
"embed_size = 150\n",
"batch_size = 4096\n",
"hidden_size = 1024\n",
"max_left_context_len = 291\n",
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)\n",
"vocab.set_default_index(vocab['<unk>'])\n",
"model_name = 'models/' + 'best_model_mod_arch.bin'\n",
"topk = 10"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"preds = []\n",
"device = 'cuda'\n",
"model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
"model.load_state_dict(torch.load(model_name))\n",
"model.eval()\n",
"j = 0\n",
"for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
" for line in fh:\n",
" j += 1\n",
" left_context = simple_preprocess(line.decode('utf-8')).split('\\t')[-2].strip()\n",
" right_context = simple_preprocess(line.decode('utf-8')).split('\\t')[-1].strip()\n",
" padding = '<pad> ' * (max_left_context_len - 1) # <s>\n",
" left_context = padding + '<s> ' + left_context\n",
" right_context = right_context + ' </s> <pad> <pad>'\n",
" x_left_trigram, x_right_trigram = left_context.split()[-3:], right_context.split()[:3]\n",
" x = [torch.tensor(vocab.forward([w])).to(device) for w in left_context], [torch.tensor(vocab.forward([w])).to(device) for w in x_left_trigram], [torch.tensor(vocab.forward([w])).to(device) for w in x_right_trigram]\n",
" out = model(x)\n",
" top = torch.topk(out[0], topk)\n",
" top_indices = top.indices.tolist()\n",
" print(j, ' '.join(x_left_trigram), '[[[', vocab.lookup_token(top_indices[0]) if vocab.lookup_token(top_indices[0]) != '<unk>' else vocab.lookup_token(top_indices[1]), ']]]', ' '.join(x_right_trigram))\n",
" top_probs = top.values.tolist()\n",
" top_words = vocab.lookup_tokens(top_indices)\n",
" top_zipped = zip(top_words, top_probs)\n",
" pred = ''\n",
" total_prob = 0\n",
" for word, prob in top_zipped:\n",
" if word != '<unk>':\n",
" pred += f'{word}:{prob} '\n",
" total_prob += prob\n",
" unk_prob = 1 - total_prob\n",
" pred += f':{unk_prob}'\n",
" f_out.write(pred + '\\n')"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}