Compare commits
13 Commits
zad1
...
trigram-nn
Author | SHA1 | Date | |
---|---|---|---|
378b5588fe | |||
|
c2fa4e59db | ||
dbda50ac2b | |||
9cc73ca767 | |||
4e07117b92 | |||
|
6a99ef51da | ||
|
e003ad6f34 | ||
|
aaccbbeb06 | ||
39c1f3a341 | |||
bb121718aa | |||
9332c1957b | |||
|
d877969ac2 | ||
|
2a4ab01f29 |
11
README.md
11
README.md
@ -1,9 +1,2 @@
|
||||
Challenging America word-gap prediction
|
||||
===================================
|
||||
|
||||
Guess a word in a gap.
|
||||
|
||||
Evaluation metric
|
||||
-----------------
|
||||
|
||||
LikelihoodHashed is the metric
|
||||
# Rozwiązanie dla wariantu kontekstu dwóch następnych słów (reszta z dzielenia przez 3 = 2)
|
||||
# Bugfixy inferencji i wstawienie lepszych wyników 24.05.23.
|
||||
|
10498
dev-0/model_steps-1000_vocab-5000_embed-50_batch-5000_out.tsv
Normal file
10498
dev-0/model_steps-1000_vocab-5000_embed-50_batch-5000_out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
10497
dev-0/model_steps-27000_vocab-5000_embed-50_batch-5000_out.tsv
Normal file
10497
dev-0/model_steps-27000_vocab-5000_embed-50_batch-5000_out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,11 +0,0 @@
|
||||
import sys
|
||||
|
||||
file = sys.argv[1]
|
||||
|
||||
with open(file, encoding='utf-8') as f1, open('out.tsv', 'w', encoding='utf-8') as f2:
|
||||
for line in f1:
|
||||
line = line.split('\t')
|
||||
if line[-1][0].isupper():
|
||||
f2.write('the:0.9 :0.1\n')
|
||||
else:
|
||||
f2.write('the:0.4 a:0.4 :0.2\n')
|
13
gonito.yaml
Normal file
13
gonito.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
description: neural network with trigrams, right context
|
||||
tags:
|
||||
- neural-network
|
||||
- right-context
|
||||
- trigrams
|
||||
params:
|
||||
vocab_size: 20000
|
||||
embed_size: 150
|
||||
batch_size: 512, 1024, 4096
|
||||
hidden_size: 256, 1024
|
||||
learning_rate: 0.0001, 0.001
|
||||
param-files:
|
||||
- "*.yaml"
|
283
solution.ipynb
Normal file
283
solution.ipynb
Normal file
@ -0,0 +1,283 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torchtext.vocab import build_vocab_from_iterator\n",
|
||||
"import pickle\n",
|
||||
"from torch.utils.data import IterableDataset\n",
|
||||
"from itertools import chain\n",
|
||||
"from torch import nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torch\n",
|
||||
"import lzma\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import shutil\n",
|
||||
"torch.manual_seed(1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def simple_preprocess(line):\n",
|
||||
" return line.replace(r'\\n', ' ')\n",
|
||||
"\n",
|
||||
"def get_words_from_line(line):\n",
|
||||
" line = line.strip()\n",
|
||||
" line = simple_preprocess(line)\n",
|
||||
" yield '<s>'\n",
|
||||
" for t in line.split():\n",
|
||||
" yield t\n",
|
||||
" yield '</s>'\n",
|
||||
"\n",
|
||||
"def get_word_lines_from_file(file_name, n_size=-1):\n",
|
||||
" with lzma.open(file_name, 'r') as fh:\n",
|
||||
" n = 0\n",
|
||||
" for line in fh:\n",
|
||||
" n += 1\n",
|
||||
" yield get_words_from_line(line.decode('utf-8'))\n",
|
||||
" if n == n_size:\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
"def look_ahead_iterator(gen):\n",
|
||||
" ngram = []\n",
|
||||
" for item in gen:\n",
|
||||
" if len(ngram) < 3:\n",
|
||||
" ngram.append(item)\n",
|
||||
" if len(ngram) == 3:\n",
|
||||
" yield ngram[1], ngram[2], ngram[0]\n",
|
||||
" else:\n",
|
||||
" ngram = ngram[1:]\n",
|
||||
" ngram.append(item)\n",
|
||||
" yield ngram[1], ngram[2], ngram[0]\n",
|
||||
"\n",
|
||||
"def build_vocab(file, vocab_size):\n",
|
||||
" try:\n",
|
||||
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
|
||||
" vocab = pickle.load(handle)\n",
|
||||
" except:\n",
|
||||
" vocab = build_vocab_from_iterator(\n",
|
||||
" get_word_lines_from_file(file),\n",
|
||||
" max_tokens = vocab_size,\n",
|
||||
" specials = ['<unk>'])\n",
|
||||
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
|
||||
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
||||
" return vocab\n",
|
||||
"\n",
|
||||
"class Trigrams(IterableDataset):\n",
|
||||
" def __init__(self, text_file):\n",
|
||||
" self.vocab = vocab\n",
|
||||
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
|
||||
" self.text_file = text_file\n",
|
||||
"\n",
|
||||
" def __iter__(self):\n",
|
||||
" return look_ahead_iterator(\n",
|
||||
" (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
|
||||
"\n",
|
||||
"class TrigramNeuralLanguageModel(nn.Module):\n",
|
||||
" def __init__(self, vocab_size, embed_size, hidden_size):\n",
|
||||
" super(TrigramNeuralLanguageModel, self).__init__()\n",
|
||||
" self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
|
||||
" self.hidden_layer = nn.Linear(2*embed_size, hidden_size)\n",
|
||||
" self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
|
||||
" self.softmax = nn.Softmax(dim=1)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" embeds = self.embeddings(x[0]), self.embeddings(x[1])\n",
|
||||
" concat_embed = torch.concat(embeds, dim=1)\n",
|
||||
" z = F.relu(self.hidden_layer(concat_embed))\n",
|
||||
" y = self.softmax(self.output_layer(z))\n",
|
||||
" return y"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_steps = -1\n",
|
||||
"vocab_size = 20000\n",
|
||||
"embed_size = 150\n",
|
||||
"batch_size = 1024\n",
|
||||
"hidden_size = 1024\n",
|
||||
"learning_rate = 0.001\n",
|
||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||
"train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" device = 'cuda'\n",
|
||||
"else:\n",
|
||||
" raise Exception()\n",
|
||||
"model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
||||
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
||||
"criterion = torch.nn.NLLLoss()\n",
|
||||
"\n",
|
||||
"model.train()\n",
|
||||
"step = 0\n",
|
||||
"for x1, x2, y in data:\n",
|
||||
" x = x1.to(device), x2.to(device)\n",
|
||||
" y = y.to(device)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" ypredicted = model(x)\n",
|
||||
" loss = criterion(torch.log(ypredicted), y)\n",
|
||||
" if step % 1000 == 0:\n",
|
||||
" print(f'steps: {step}, loss: {loss.item()}')\n",
|
||||
" if step != 0:\n",
|
||||
" torch.save(model.state_dict(), f'trigram_nn_model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" if step == max_steps:\n",
|
||||
" break\n",
|
||||
" step += 1"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vocab_size = 20000\n",
|
||||
"embed_size = 150\n",
|
||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||
"vocab.set_default_index(vocab['<unk>'])"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin\n",
|
||||
"512\n",
|
||||
"256\n",
|
||||
"trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin\n",
|
||||
"1024\n",
|
||||
"1024\n",
|
||||
"trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin\n",
|
||||
"4096\n",
|
||||
"256\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for model_name in ['trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin', 'trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin', 'trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin']:\n",
|
||||
" print(model_name)\n",
|
||||
" batch_size = int(model_name.split('_')[-3].split('-')[1])\n",
|
||||
" print(batch_size)\n",
|
||||
" hidden_size = int(model_name.split('_')[-2].split('-')[1])\n",
|
||||
" print(hidden_size)\n",
|
||||
" topk = 10\n",
|
||||
" preds = []\n",
|
||||
" device = 'cuda'\n",
|
||||
" model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" model.load_state_dict(torch.load(model_name))\n",
|
||||
" model.eval()\n",
|
||||
" for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
|
||||
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
|
||||
" for line in fh:\n",
|
||||
" right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:2]\n",
|
||||
" x = torch.tensor(vocab.forward([right_context[0]])).to(device), \\\n",
|
||||
" torch.tensor(vocab.forward([right_context[1]])).to(device)\n",
|
||||
" out = model(x)\n",
|
||||
" top = torch.topk(out[0], topk)\n",
|
||||
" top_indices = top.indices.tolist()\n",
|
||||
" top_probs = top.values.tolist()\n",
|
||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||
" top_zipped = zip(top_words, top_probs)\n",
|
||||
" pred = ''\n",
|
||||
" total_prob = 0\n",
|
||||
" for word, prob in top_zipped:\n",
|
||||
" if word != '<unk>':\n",
|
||||
" pred += f'{word}:{prob} '\n",
|
||||
" total_prob += prob\n",
|
||||
" unk_prob = 1 - total_prob\n",
|
||||
" pred += f':{unk_prob}'\n",
|
||||
" f_out.write(pred + '\\n')\n",
|
||||
" src=f'{path}/out.tsv'\n",
|
||||
" dst=f\"{path}/{model_name.split('.')[0]}_out.tsv\"\n",
|
||||
" shutil.copy(src, dst)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ked/PycharmProjects/mj9/challenging-america-word-gap-prediction\n",
|
||||
"300.66\r\n",
|
||||
"/home/ked/PycharmProjects/mj9\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%cd challenging-america-word-gap-prediction/\n",
|
||||
"!./geval --test-name dev-0\n",
|
||||
"%cd ../"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
7369
test-A/model_steps-1000_vocab-5000_embed-50_batch-5000_out.tsv
Normal file
7369
test-A/model_steps-1000_vocab-5000_embed-50_batch-5000_out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
7381
test-A/model_steps-27000_vocab-5000_embed-50_batch-5000_out.tsv
Normal file
7381
test-A/model_steps-27000_vocab-5000_embed-50_batch-5000_out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user