solution
This commit is contained in:
parent
4278d2f8ea
commit
abb72efa82
@ -1,4 +1,3 @@
|
|||||||
# Użyte elementy z wykładu/ćwiczeń:
|
# Rozszerzenia:
|
||||||
- pełny lewy kontekst skompresowany do jednego tensora obok dwustronnego kontekstu trigramowego
|
- uwzględniony prawy kontekst
|
||||||
- warstwy layer norm
|
- dwukierunkowy LSTM
|
||||||
- warstwy dropout
|
|
||||||
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
@ -1,15 +1,14 @@
|
|||||||
description: neural network with trigram left-right context plus full left context tensor
|
description: lstm with trigram left-right context of 3 ngrams
|
||||||
tags:
|
tags:
|
||||||
- neural-network
|
- neural-network
|
||||||
- left-context
|
- left-context
|
||||||
- right-context
|
- right-context
|
||||||
- trigrams
|
- lstm
|
||||||
params:
|
params:
|
||||||
vocab_size: 20000
|
vocab_size: 20000
|
||||||
embed_size: 150
|
embed_size: 150
|
||||||
batch_size: 4096
|
batch_size: 4096
|
||||||
hidden_size: 1024
|
hidden_size: 1024
|
||||||
learning_rate: 0.001
|
learning_rate: 0.0001
|
||||||
epochs: 10
|
|
||||||
param-files:
|
param-files:
|
||||||
- "*.yaml"
|
- "*.yaml"
|
||||||
|
290
solution.ipynb
290
solution.ipynb
@ -3,146 +3,99 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {
|
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from torchtext.vocab import build_vocab_from_iterator\n",
|
"from torchtext.vocab import build_vocab_from_iterator\n",
|
||||||
"import pickle\n",
|
"import pickle\n",
|
||||||
"from torch.utils.data import IterableDataset\n",
|
"from torch.utils.data import IterableDataset\n",
|
||||||
|
"from itertools import chain\n",
|
||||||
"from torch import nn\n",
|
"from torch import nn\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"import lzma\n",
|
"import lzma\n",
|
||||||
"from torch.utils.data import DataLoader\n",
|
"from torch.utils.data import DataLoader\n",
|
||||||
"import shutil\n",
|
|
||||||
"torch.manual_seed(1)"
|
"torch.manual_seed(1)"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 60,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def simple_preprocess(line):\n",
|
"def simple_preprocess(line):\n",
|
||||||
" return line.replace(r'\\n', ' ')\n",
|
" return line.replace(r'\\n', ' ')\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def get_max_left_context_len(file_name):\n",
|
|
||||||
" print('Getting max left context length...')\n",
|
|
||||||
" max_len = 0\n",
|
|
||||||
" with lzma.open(file_name, 'r') as fh:\n",
|
|
||||||
" for line in fh:\n",
|
|
||||||
" line = line.decode('utf-8')\n",
|
|
||||||
" line = line.strip()\n",
|
|
||||||
" line = line.split('\\t')[-2]\n",
|
|
||||||
" line = simple_preprocess(line)\n",
|
|
||||||
" curr_len = len(line.split())\n",
|
|
||||||
" if curr_len > max_len:\n",
|
|
||||||
" max_len = curr_len\n",
|
|
||||||
" print(f'max_len={max_len}')\n",
|
|
||||||
" return max_len\n",
|
|
||||||
"\n",
|
|
||||||
"def get_words_from_line(line):\n",
|
"def get_words_from_line(line):\n",
|
||||||
" for t in line:\n",
|
" line = line.strip()\n",
|
||||||
|
" line = simple_preprocess(line)\n",
|
||||||
|
" yield '<s>'\n",
|
||||||
|
" for t in line.split():\n",
|
||||||
" yield t\n",
|
" yield t\n",
|
||||||
|
" yield '</s>'\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def get_word_lines_from_file(file_name, max_left_context_len, return_gen, n_size=-1):\n",
|
"def get_word_lines_from_file(file_name, n_size=-1):\n",
|
||||||
" with lzma.open(file_name, 'r') as fh:\n",
|
" with lzma.open(file_name, 'r') as fh:\n",
|
||||||
" n = 0\n",
|
" n = 0\n",
|
||||||
" for line in fh:\n",
|
" for line in fh:\n",
|
||||||
" n += 1\n",
|
" n += 1\n",
|
||||||
" line = line.decode('utf-8')\n",
|
" yield get_words_from_line(line.decode('utf-8'))\n",
|
||||||
" line = line.strip()\n",
|
|
||||||
" padding = '<pad> ' * (max_left_context_len - 1) # <s>\n",
|
|
||||||
" left_context = padding + '<s> ' + simple_preprocess(line.split('\\t')[-2])\n",
|
|
||||||
" right_context = simple_preprocess(line.split('\\t')[-1]) + ' </s> <pad> <pad>'\n",
|
|
||||||
" line = left_context + ' ' + right_context\n",
|
|
||||||
" line = line.split()\n",
|
|
||||||
" if return_gen:\n",
|
|
||||||
" yield get_words_from_line(line)\n",
|
|
||||||
" else:\n",
|
|
||||||
" yield line\n",
|
|
||||||
" if n == n_size:\n",
|
" if n == n_size:\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def look_ahead_iterator(gen, vocab, max_left_context_len):\n",
|
"def look_ahead_iterator(gen):\n",
|
||||||
|
" ngram = []\n",
|
||||||
" for item in gen:\n",
|
" for item in gen:\n",
|
||||||
" start_pos = item.index('<s>') + 1\n",
|
" if len(ngram) < 7:\n",
|
||||||
" item = [vocab[t] for t in item]\n",
|
" ngram.append(item)\n",
|
||||||
" for i in range(start_pos, len(item) - 4):\n",
|
" if len(ngram) == 7:\n",
|
||||||
" yield [item[:i-3][-max_left_context_len+3:], item[i-3:i], item[i], item[i+1:i+4]]\n",
|
" yield ngram\n",
|
||||||
|
" else:\n",
|
||||||
|
" ngram = ngram[1:]\n",
|
||||||
|
" ngram.append(item)\n",
|
||||||
|
" yield ngram\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def build_vocab(file, vocab_size, max_left_context_len):\n",
|
"def build_vocab(file, vocab_size):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" with open(f'vocab_{vocab_size}_padded.pickle', 'rb') as handle:\n",
|
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
|
||||||
" print('Loading vocab...')\n",
|
|
||||||
" vocab = pickle.load(handle)\n",
|
" vocab = pickle.load(handle)\n",
|
||||||
" except:\n",
|
" except:\n",
|
||||||
" print('Building vocab...')\n",
|
|
||||||
" vocab = build_vocab_from_iterator(\n",
|
" vocab = build_vocab_from_iterator(\n",
|
||||||
" get_word_lines_from_file(file, max_left_context_len, return_gen=True),\n",
|
" get_word_lines_from_file(file),\n",
|
||||||
" max_tokens = vocab_size,\n",
|
" max_tokens = vocab_size,\n",
|
||||||
" specials = ['<unk>'])\n",
|
" specials = ['<unk>'])\n",
|
||||||
" with open(f'vocab_{vocab_size}_padded.pickle', 'wb') as handle:\n",
|
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
|
||||||
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
||||||
" return vocab\n",
|
" return vocab\n",
|
||||||
"\n",
|
"\n",
|
||||||
"class Ngrams(IterableDataset):\n",
|
"class Trigrams(IterableDataset):\n",
|
||||||
" def __init__(self, text_file, max_left_context_len):\n",
|
" def __init__(self, text_file):\n",
|
||||||
" self.vocab = vocab\n",
|
" self.vocab = vocab\n",
|
||||||
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
|
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
|
||||||
" self.text_file = text_file\n",
|
" self.text_file = text_file\n",
|
||||||
" self.max_left_context_len = max_left_context_len\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" def __iter__(self):\n",
|
" def __iter__(self):\n",
|
||||||
" return look_ahead_iterator(get_word_lines_from_file(self.text_file, max_left_context_len, return_gen=False), self.vocab, self.max_left_context_len)\n",
|
" return look_ahead_iterator(\n",
|
||||||
|
" (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Dropout, norm layers adjusted on a case-by-case basis. Also gradual hidden layer size reduction vs. no reduction\n",
|
"class LSTMLanguageModel(nn.Module):\n",
|
||||||
"class NeuralLanguageModel(nn.Module):\n",
|
|
||||||
" def __init__(self, vocab_size, embed_size, hidden_size):\n",
|
" def __init__(self, vocab_size, embed_size, hidden_size):\n",
|
||||||
" super(NeuralLanguageModel, self).__init__()\n",
|
" super(LSTMLanguageModel, self).__init__()\n",
|
||||||
" self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
|
" self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
|
||||||
" self.hidden_1 = nn.Linear(7*embed_size, hidden_size)\n",
|
" self.lstm_layer = nn.LSTM(6*embed_size, hidden_size, bidirectional=True)\n",
|
||||||
" self.hidden_2 = nn.Linear(hidden_size, int(hidden_size/2))\n",
|
" self.output_layer = nn.Linear(2*hidden_size, vocab_size)\n",
|
||||||
" self.hidden_3 = nn.Linear(int(hidden_size/2), int(hidden_size/4))\n",
|
|
||||||
" self.output = nn.Linear(int(hidden_size/4), vocab_size)\n",
|
|
||||||
"\n",
|
|
||||||
" self.softmax = nn.Softmax(dim=1)\n",
|
" self.softmax = nn.Softmax(dim=1)\n",
|
||||||
" self.norm_input = nn.LayerNorm(7*embed_size)\n",
|
|
||||||
" self.norm_1 = nn.LayerNorm(int(hidden_size))\n",
|
|
||||||
" self.norm_2 = nn.LayerNorm(int(hidden_size/2))\n",
|
|
||||||
" self.norm_3 = nn.LayerNorm(int(hidden_size/4))\n",
|
|
||||||
" self.activation = nn.LeakyReLU()\n",
|
|
||||||
" self.dropout = nn.Dropout(0.1)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" def forward(self, x):\n",
|
" def forward(self, x):\n",
|
||||||
" x_whole_left, x_left_trigram, x_right_trigram = x\n",
|
" embeds = [self.embeddings(gram) for gram in x]\n",
|
||||||
" x_whole_left_embed = [self.embeddings(t) for t in x_whole_left]\n",
|
" concat_embed = torch.concat(embeds, dim=1)\n",
|
||||||
" x_whole_left_embed_len = len(x_whole_left_embed)\n",
|
" z = F.relu(self.lstm_layer(concat_embed)[0])\n",
|
||||||
" x_whole_left_embed = torch.stack(x_whole_left_embed)\n",
|
" y = self.softmax(self.output_layer(z))\n",
|
||||||
" x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0) / x_whole_left_embed_len\n",
|
|
||||||
" #x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0)\n",
|
|
||||||
" x_left_trigram_embed = torch.concat([self.embeddings(t) for t in x_left_trigram], dim=1)\n",
|
|
||||||
" x_right_trigram_embed = torch.concat([self.embeddings(t) for t in x_right_trigram], dim=1)\n",
|
|
||||||
" concat_embed = torch.concat((x_whole_left_embed, x_left_trigram_embed, x_right_trigram_embed), dim=1)\n",
|
|
||||||
" if torch.isnan(concat_embed).any():\n",
|
|
||||||
" print('NaN!')\n",
|
|
||||||
" raise Exception(\"Error\")\n",
|
|
||||||
" concat_embed = self.norm_input(concat_embed)\n",
|
|
||||||
" z = self.hidden_1(concat_embed)\n",
|
|
||||||
" z = self.norm_1(z)\n",
|
|
||||||
" z = self.activation(z)\n",
|
|
||||||
" #z = self.dropout(z)\n",
|
|
||||||
" z = self.hidden_2(z)\n",
|
|
||||||
" z = self.norm_2(z)\n",
|
|
||||||
" z = self.activation(z)\n",
|
|
||||||
" #z = self.dropout(z)\n",
|
|
||||||
" z = self.hidden_3(z)\n",
|
|
||||||
" z = self.norm_3(z)\n",
|
|
||||||
" z = self.activation(z)\n",
|
|
||||||
" #z = self.dropout(z)\n",
|
|
||||||
" z = self.output(z)\n",
|
|
||||||
" y = self.softmax(z)\n",
|
|
||||||
" return y"
|
" return y"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -158,69 +111,95 @@
|
|||||||
"max_steps = -1\n",
|
"max_steps = -1\n",
|
||||||
"vocab_size = 20000\n",
|
"vocab_size = 20000\n",
|
||||||
"embed_size = 150\n",
|
"embed_size = 150\n",
|
||||||
"batch_size = 4096\n",
|
"batch_size = 1024\n",
|
||||||
"hidden_size = 1024\n",
|
"hidden_size = 512\n",
|
||||||
"learning_rate = 0.001 # < 0.1\n",
|
"learning_rate = 0.0001\n",
|
||||||
"epochs = 1\n",
|
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||||
"#max_left_context_len = get_max_left_context_len('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
|
"train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
|
||||||
"max_left_context_len = 291\n",
|
|
||||||
"torch.manual_seed(1)\n",
|
|
||||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)\n",
|
|
||||||
"train_dataset = Ngrams('challenging-america-word-gap-prediction/train/in.tsv.xz', max_left_context_len)\n",
|
|
||||||
"if torch.cuda.is_available():\n",
|
"if torch.cuda.is_available():\n",
|
||||||
" device = 'cuda'\n",
|
" device = 'cuda'\n",
|
||||||
"else:\n",
|
"else:\n",
|
||||||
" raise Exception()\n",
|
" raise Exception()\n",
|
||||||
"model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
"model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||||
"#model.load_state_dict(torch.load(model_name))\n",
|
|
||||||
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
||||||
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
||||||
"criterion = torch.nn.NLLLoss()\n",
|
"criterion = torch.nn.NLLLoss()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"with torch.autograd.set_detect_anomaly(True):\n",
|
"model.train()\n",
|
||||||
" model.train()\n",
|
"step = 0\n",
|
||||||
" epoch = 0\n",
|
"for ngram in data:\n",
|
||||||
" for i in range(epochs):\n",
|
" x = [gram.to(device) for gram in ngram[:3]+ngram[4:]]\n",
|
||||||
" step = 0\n",
|
" y = ngram[3].to(device)\n",
|
||||||
" epoch += 1\n",
|
" optimizer.zero_grad()\n",
|
||||||
" print(f'--------epoch {epoch}--------')\n",
|
" ypredicted = model(x)\n",
|
||||||
" for x_whole_left, x_left_trigram, y, x_right_trigram in data:\n",
|
" loss = criterion(torch.log(ypredicted), y)\n",
|
||||||
" x = [t.to(device) for t in x_whole_left], [t.to(device) for t in x_left_trigram], [t.to(device) for t in x_right_trigram]\n",
|
" if step % 100 == 0:\n",
|
||||||
" y = y.to(device)\n",
|
" print(f'steps: {step}, loss: {loss.item()}')\n",
|
||||||
" optimizer.zero_grad()\n",
|
" if step % 1000 == 0:\n",
|
||||||
" y_pred = model(x)\n",
|
" if step != 0:\n",
|
||||||
" loss = criterion(torch.log(y_pred), y)\n",
|
" torch.save(model.state_dict(), f'{loss}_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
|
||||||
" if step % 1000 == 0:\n",
|
" loss.backward()\n",
|
||||||
" print(f'steps: {step}, loss: {loss.item()}')\n",
|
" optimizer.step()\n",
|
||||||
" if step != 0:\n",
|
" if step == max_steps:\n",
|
||||||
" name = f'loss-{loss.item()}_model_steps-{step}_epoch-{epoch}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin'\n",
|
" break\n",
|
||||||
" torch.save(model.state_dict(), 'models/' + name)\n",
|
" step += 1"
|
||||||
" loss.backward()\n",
|
|
||||||
" torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)\n",
|
|
||||||
" optimizer.step()\n",
|
|
||||||
" if step == max_steps:\n",
|
|
||||||
" break\n",
|
|
||||||
" step += 1"
|
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 69,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"step += 1\n",
|
|
||||||
"vocab_size = 20000\n",
|
"vocab_size = 20000\n",
|
||||||
"embed_size = 150\n",
|
"embed_size = 150\n",
|
||||||
"batch_size = 4096\n",
|
"hidden_size = 512\n",
|
||||||
"hidden_size = 1024\n",
|
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||||
"max_left_context_len = 291\n",
|
"vocab.set_default_index(vocab['<unk>'])"
|
||||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)\n",
|
],
|
||||||
"vocab.set_default_index(vocab['<unk>'])\n",
|
"metadata": {
|
||||||
"model_name = 'models/' + 'best_model_mod_arch.bin'\n",
|
"collapsed": false
|
||||||
"topk = 10"
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 74,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"for model_name in ['best.bin']:\n",
|
||||||
|
" topk = 100\n",
|
||||||
|
" preds = []\n",
|
||||||
|
" device = 'cuda'\n",
|
||||||
|
" model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||||
|
" model.load_state_dict(torch.load(model_name))\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
|
||||||
|
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
|
||||||
|
" for line in fh:\n",
|
||||||
|
" left_context = simple_preprocess(line.decode('utf-8').split('\\t')[-2].strip()).split()[-3:]\n",
|
||||||
|
" right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:3]\n",
|
||||||
|
" full_context = left_context + right_context\n",
|
||||||
|
" x = [torch.tensor(vocab.forward([word])).to(device) for word in full_context]\n",
|
||||||
|
" out = model(x)\n",
|
||||||
|
" top = torch.topk(out[0], topk)\n",
|
||||||
|
" top_indices = top.indices.tolist()\n",
|
||||||
|
" top_probs = top.values.tolist()\n",
|
||||||
|
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||||
|
" top_zipped = zip(top_words, top_probs)\n",
|
||||||
|
" pred = ''\n",
|
||||||
|
" total_prob = 0\n",
|
||||||
|
" for word, prob in top_zipped:\n",
|
||||||
|
" if word != '<unk>':\n",
|
||||||
|
" pred += f'{word}:{prob} '\n",
|
||||||
|
" total_prob += prob\n",
|
||||||
|
" unk_prob = 1 - total_prob\n",
|
||||||
|
" pred += f':{unk_prob}'\n",
|
||||||
|
" f_out.write(pred + '\\n')"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false
|
"collapsed": false
|
||||||
@ -231,42 +210,15 @@
|
|||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"preds = []\n",
|
"%cd challenging-america-word-gap-prediction/\n",
|
||||||
"device = 'cuda'\n",
|
"!./geval --test-name dev-0\n",
|
||||||
"model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
"%cd ../"
|
||||||
"model.load_state_dict(torch.load(model_name))\n",
|
|
||||||
"model.eval()\n",
|
|
||||||
"j = 0\n",
|
|
||||||
"for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
|
|
||||||
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
|
|
||||||
" for line in fh:\n",
|
|
||||||
" j += 1\n",
|
|
||||||
" left_context = simple_preprocess(line.decode('utf-8')).split('\\t')[-2].strip()\n",
|
|
||||||
" right_context = simple_preprocess(line.decode('utf-8')).split('\\t')[-1].strip()\n",
|
|
||||||
" padding = '<pad> ' * (max_left_context_len - 1) # <s>\n",
|
|
||||||
" left_context = padding + '<s> ' + left_context\n",
|
|
||||||
" right_context = right_context + ' </s> <pad> <pad>'\n",
|
|
||||||
" x_left_trigram, x_right_trigram = left_context.split()[-3:], right_context.split()[:3]\n",
|
|
||||||
" x = [torch.tensor(vocab.forward([w])).to(device) for w in left_context], [torch.tensor(vocab.forward([w])).to(device) for w in x_left_trigram], [torch.tensor(vocab.forward([w])).to(device) for w in x_right_trigram]\n",
|
|
||||||
" out = model(x)\n",
|
|
||||||
" top = torch.topk(out[0], topk)\n",
|
|
||||||
" top_indices = top.indices.tolist()\n",
|
|
||||||
" print(j, ' '.join(x_left_trigram), '[[[', vocab.lookup_token(top_indices[0]) if vocab.lookup_token(top_indices[0]) != '<unk>' else vocab.lookup_token(top_indices[1]), ']]]', ' '.join(x_right_trigram))\n",
|
|
||||||
" top_probs = top.values.tolist()\n",
|
|
||||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
|
||||||
" top_zipped = zip(top_words, top_probs)\n",
|
|
||||||
" pred = ''\n",
|
|
||||||
" total_prob = 0\n",
|
|
||||||
" for word, prob in top_zipped:\n",
|
|
||||||
" if word != '<unk>':\n",
|
|
||||||
" pred += f'{word}:{prob} '\n",
|
|
||||||
" total_prob += prob\n",
|
|
||||||
" unk_prob = 1 - total_prob\n",
|
|
||||||
" pred += f':{unk_prob}'\n",
|
|
||||||
" f_out.write(pred + '\\n')"
|
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user