cleanup

This commit is contained in:
parent fbe87e2ef1
commit a95d7ddac5

10519  dev-0/in.tsv
File diff suppressed because it is too large
BIN  dev-0/in.tsv.xz  Normal file
Binary file not shown.
13  gonito.yaml
@@ -1,13 +0,0 @@
description: neural network with trigrams, right context
tags:
  - neural-network
  - right-context
  - trigrams
params:
  vocab_size: 20000
  embed_size: 150
  batch_size: 512, 1024, 4096
  hidden_size: 256, 1024
  learning_rate: 0.0001, 0.001
param-files:
  - "*.yaml"
283  solution.ipynb
@@ -1,283 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "import pickle\n",
    "from torch.utils.data import IterableDataset\n",
    "from itertools import chain\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "import lzma\n",
    "from torch.utils.data import DataLoader\n",
    "import shutil\n",
    "torch.manual_seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "def simple_preprocess(line):\n",
    "    return line.replace(r'\\n', ' ')\n",
    "\n",
    "def get_words_from_line(line):\n",
    "    line = line.strip()\n",
    "    line = simple_preprocess(line)\n",
    "    yield '<s>'\n",
    "    for t in line.split():\n",
    "        yield t\n",
    "    yield '</s>'\n",
    "\n",
    "def get_word_lines_from_file(file_name, n_size=-1):\n",
    "    with lzma.open(file_name, 'r') as fh:\n",
    "        n = 0\n",
    "        for line in fh:\n",
    "            n += 1\n",
    "            yield get_words_from_line(line.decode('utf-8'))\n",
    "            if n == n_size:\n",
    "                break\n",
    "\n",
    "def look_ahead_iterator(gen):\n",
    "    ngram = []\n",
    "    for item in gen:\n",
    "        if len(ngram) < 3:\n",
    "            ngram.append(item)\n",
    "            if len(ngram) == 3:\n",
    "                yield ngram[1], ngram[2], ngram[0]\n",
    "        else:\n",
    "            ngram = ngram[1:]\n",
    "            ngram.append(item)\n",
    "            yield ngram[1], ngram[2], ngram[0]\n",
    "\n",
    "def build_vocab(file, vocab_size):\n",
    "    try:\n",
    "        with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
    "            vocab = pickle.load(handle)\n",
    "    except FileNotFoundError:\n",
    "        vocab = build_vocab_from_iterator(\n",
    "            get_word_lines_from_file(file),\n",
    "            max_tokens = vocab_size,\n",
    "            specials = ['<unk>'])\n",
    "        with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
    "            pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "    return vocab\n",
    "\n",
    "class Trigrams(IterableDataset):\n",
    "    def __init__(self, text_file, vocab):\n",
    "        self.vocab = vocab\n",
    "        self.vocab.set_default_index(self.vocab['<unk>'])\n",
    "        self.text_file = text_file\n",
    "\n",
    "    def __iter__(self):\n",
    "        return look_ahead_iterator(\n",
    "            (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
    "\n",
    "class TrigramNeuralLanguageModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size, hidden_size):\n",
    "        super(TrigramNeuralLanguageModel, self).__init__()\n",
    "        self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
    "        self.hidden_layer = nn.Linear(2*embed_size, hidden_size)\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        embeds = self.embeddings(x[0]), self.embeddings(x[1])\n",
    "        concat_embed = torch.concat(embeds, dim=1)\n",
    "        z = F.relu(self.hidden_layer(concat_embed))\n",
    "        y = self.softmax(self.output_layer(z))\n",
    "        return y"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "max_steps = -1\n",
    "vocab_size = 20000\n",
    "embed_size = 150\n",
    "batch_size = 1024\n",
    "hidden_size = 1024\n",
    "learning_rate = 0.001\n",
    "vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
    "train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab)\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    raise RuntimeError('CUDA device required for training')\n",
    "model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
    "data = DataLoader(train_dataset, batch_size=batch_size)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "criterion = torch.nn.NLLLoss()\n",
    "\n",
    "model.train()\n",
    "step = 0\n",
    "for x1, x2, y in data:\n",
    "    x = x1.to(device), x2.to(device)\n",
    "    y = y.to(device)\n",
    "    optimizer.zero_grad()\n",
    "    ypredicted = model(x)\n",
    "    loss = criterion(torch.log(ypredicted), y)\n",
    "    if step % 1000 == 0:\n",
    "        print(f'steps: {step}, loss: {loss.item()}')\n",
    "        if step != 0:\n",
    "            torch.save(model.state_dict(), f'trigram_nn_model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    if step == max_steps:\n",
    "        break\n",
    "    step += 1"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "is_executing": true
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "vocab_size = 20000\n",
    "embed_size = 150\n",
    "vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
    "vocab.set_default_index(vocab['<unk>'])"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin\n",
      "512\n",
      "256\n",
      "trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin\n",
      "1024\n",
      "1024\n",
      "trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin\n",
      "4096\n",
      "256\n"
     ]
    }
   ],
   "source": [
    "for model_name in ['trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin', 'trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin', 'trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin']:\n",
    "    print(model_name)\n",
    "    batch_size = int(model_name.split('_')[-3].split('-')[1])\n",
    "    print(batch_size)\n",
    "    hidden_size = int(model_name.split('_')[-2].split('-')[1])\n",
    "    print(hidden_size)\n",
    "    topk = 10\n",
    "    preds = []\n",
    "    device = 'cuda'\n",
    "    model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
    "    model.load_state_dict(torch.load(model_name))\n",
    "    model.eval()\n",
    "    for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
    "        with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
    "            for line in fh:\n",
    "                right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:2]\n",
    "                x = torch.tensor(vocab.forward([right_context[0]])).to(device), \\\n",
    "                    torch.tensor(vocab.forward([right_context[1]])).to(device)\n",
    "                out = model(x)\n",
    "                top = torch.topk(out[0], topk)\n",
    "                top_indices = top.indices.tolist()\n",
    "                top_probs = top.values.tolist()\n",
    "                top_words = vocab.lookup_tokens(top_indices)\n",
    "                top_zipped = zip(top_words, top_probs)\n",
    "                pred = ''\n",
    "                total_prob = 0\n",
    "                for word, prob in top_zipped:\n",
    "                    if word != '<unk>':\n",
    "                        pred += f'{word}:{prob} '\n",
    "                        total_prob += prob\n",
    "                unk_prob = 1 - total_prob\n",
    "                pred += f':{unk_prob}'\n",
    "                f_out.write(pred + '\\n')\n",
    "        src = f'{path}/out.tsv'\n",
    "        dst = f\"{path}/{model_name.split('.')[0]}_out.tsv\"\n",
    "        shutil.copy(src, dst)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/ked/PycharmProjects/mj9/challenging-america-word-gap-prediction\n",
      "300.66\r\n",
      "/home/ked/PycharmProjects/mj9\n"
     ]
    }
   ],
   "source": [
    "%cd challenging-america-word-gap-prediction/\n",
    "!./geval --test-name dev-0\n",
    "%cd ../"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
7414  test-A/in.tsv
File diff suppressed because it is too large
BIN  test-A/in.tsv.xz  Normal file
Binary file not shown.