Compare commits

20 Commits

Author | SHA1 | Date
---|---|---
 | abb72efa82 |
 | 4278d2f8ea |
 | 5a5265fd3d |
 | 4c56aa7160 |
 | 0f731fb82d |
 | a95d7ddac5 |
 | fbe87e2ef1 |
 | 378b5588fe |
 | c2fa4e59db |
 | dbda50ac2b |
 | 9cc73ca767 |
 | 4e07117b92 |
 | 6a99ef51da |
 | e003ad6f34 |
 | aaccbbeb06 |
 | 39c1f3a341 |
 | bb121718aa |
 | 9332c1957b |
 | d877969ac2 |
 | 2a4ab01f29 |
12 README.md

@@ -1,9 +1,3 @@
-Challenging America word-gap prediction
-===================================
-
-Guess a word in a gap.
-
-Evaluation metric
------------------
-
-LikelihoodHashed is the metric
+# Extensions:
+- right context taken into account
+- bidirectional LSTM
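In short, the two extensions relative to the baseline are that the model also sees the three words to the right of the gap, and that the recurrent layer is a bidirectional LSTM; both are implemented in solution.ipynb below.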
10519 dev-0/in.tsv
File diff suppressed because it is too large

BIN dev-0/in.tsv.xz (Normal file)
Binary file not shown.

21038 dev-0/out.tsv
File diff suppressed because it is too large
@@ -1,11 +0,0 @@
-import sys
-
-file = sys.argv[1]
-
-with open(file, encoding='utf-8') as f1, open('out.tsv', 'w', encoding='utf-8') as f2:
-    for line in f1:
-        line = line.split('\t')
-        if line[-1][0].isupper():
-            f2.write('the:0.9 :0.1\n')
-        else:
-            f2.write('the:0.4 a:0.4 :0.2\n')
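The deleted script above was the previous heuristic baseline: it wrote a fixed distribution for every gap ('the:0.9 :0.1' when the right-context field starts with an uppercase character, 'the:0.4 a:0.4 :0.2' otherwise). Each line of out.tsv is a space-separated list of word:probability pairs, where a bare ':p' entry reserves probability mass for all unlisted words. A minimal sanity check of that format (parse_pred_line is an illustrative helper, not part of the repository):

def parse_pred_line(line):
    # Split an out.tsv line into (word, prob) pairs; word == '' covers all unlisted words.
    pairs = []
    for item in line.strip().split(' '):
        word, _, prob = item.rpartition(':')
        pairs.append((word, float(prob)))
    return pairs

# The probability mass of a prediction line should sum to 1.
assert abs(sum(prob for _, prob in parse_pred_line('the:0.4 a:0.4 :0.2')) - 1.0) < 1e-9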
14 gonito.yaml (Normal file)

@@ -0,0 +1,14 @@
+description: lstm with trigram left-right context of 3 ngrams
+tags:
+  - neural-network
+  - left-context
+  - right-context
+  - lstm
+params:
+  vocab_size: 20000
+  embed_size: 150
+  batch_size: 4096
+  hidden_size: 1024
+  learning_rate: 0.0001
+param-files:
+  - "*.yaml"
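gonito.yaml is the submission metadata for the Gonito platform: a free-text description, tags, and the hyperparameters of the submitted run. Note that batch_size: 4096 and hidden_size: 1024 here describe the submitted model, while the notebook's "Sample parameters" cell below uses batch_size = 1024 and hidden_size = 512. A minimal sketch of reading those params back in Python, assuming PyYAML is installed:

import yaml

# Load the submission metadata written alongside the solution.
with open('gonito.yaml', encoding='utf-8') as fh:
    meta = yaml.safe_load(fh)

print(meta['description'])           # lstm with trigram left-right context of 3 ngrams
print(meta['params']['vocab_size'])  # 20000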
246 solution.ipynb (Normal file)

@@ -0,0 +1,246 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "from torchtext.vocab import build_vocab_from_iterator\n",
+    "import pickle\n",
+    "from torch.utils.data import IterableDataset\n",
+    "from itertools import chain\n",
+    "from torch import nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch\n",
+    "import lzma\n",
+    "from torch.utils.data import DataLoader\n",
+    "torch.manual_seed(1)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "is_executing": true
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "outputs": [],
+   "source": [
+    "def simple_preprocess(line):\n",
+    "    return line.replace(r'\\n', ' ')\n",
+    "\n",
+    "def get_words_from_line(line):\n",
+    "    line = line.strip()\n",
+    "    line = simple_preprocess(line)\n",
+    "    yield '<s>'\n",
+    "    for t in line.split():\n",
+    "        yield t\n",
+    "    yield '</s>'\n",
+    "\n",
+    "def get_word_lines_from_file(file_name, n_size=-1):\n",
+    "    with lzma.open(file_name, 'r') as fh:\n",
+    "        n = 0\n",
+    "        for line in fh:\n",
+    "            n += 1\n",
+    "            yield get_words_from_line(line.decode('utf-8'))\n",
+    "            if n == n_size:\n",
+    "                break\n",
+    "\n",
+    "def look_ahead_iterator(gen):\n",
+    "    ngram = []\n",
+    "    for item in gen:\n",
+    "        if len(ngram) < 7:\n",
+    "            ngram.append(item)\n",
+    "            if len(ngram) == 7:\n",
+    "                yield ngram\n",
+    "        else:\n",
+    "            ngram = ngram[1:]\n",
+    "            ngram.append(item)\n",
+    "            yield ngram\n",
+    "\n",
+    "def build_vocab(file, vocab_size):\n",
+    "    try:\n",
+    "        with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
+    "            vocab = pickle.load(handle)\n",
+    "    except:\n",
+    "        vocab = build_vocab_from_iterator(\n",
+    "            get_word_lines_from_file(file),\n",
+    "            max_tokens = vocab_size,\n",
+    "            specials = ['<unk>'])\n",
+    "        with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
+    "            pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
+    "    return vocab\n",
+    "\n",
+    "class Trigrams(IterableDataset):\n",
+    "    def __init__(self, text_file):\n",
+    "        self.vocab = vocab\n",
+    "        self.vocab.set_default_index(self.vocab['<unk>'])\n",
+    "        self.text_file = text_file\n",
+    "\n",
+    "    def __iter__(self):\n",
+    "        return look_ahead_iterator(\n",
+    "            (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
+    "\n",
+    "class LSTMLanguageModel(nn.Module):\n",
+    "    def __init__(self, vocab_size, embed_size, hidden_size):\n",
+    "        super(LSTMLanguageModel, self).__init__()\n",
+    "        self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
+    "        self.lstm_layer = nn.LSTM(6*embed_size, hidden_size, bidirectional=True)\n",
+    "        self.output_layer = nn.Linear(2*hidden_size, vocab_size)\n",
+    "        self.softmax = nn.Softmax(dim=1)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        embeds = [self.embeddings(gram) for gram in x]\n",
+    "        concat_embed = torch.concat(embeds, dim=1)\n",
+    "        z = F.relu(self.lstm_layer(concat_embed)[0])\n",
+    "        y = self.softmax(self.output_layer(z))\n",
+    "        return y"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# Sample parameters\n",
+    "max_steps = -1\n",
+    "vocab_size = 20000\n",
+    "embed_size = 150\n",
+    "batch_size = 1024\n",
+    "hidden_size = 512\n",
+    "learning_rate = 0.0001\n",
+    "vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
+    "train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
+    "if torch.cuda.is_available():\n",
+    "    device = 'cuda'\n",
+    "else:\n",
+    "    raise Exception()\n",
+    "model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
+    "data = DataLoader(train_dataset, batch_size=batch_size)\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
+    "criterion = torch.nn.NLLLoss()\n",
+    "\n",
+    "model.train()\n",
+    "step = 0\n",
+    "for ngram in data:\n",
+    "    x = [gram.to(device) for gram in ngram[:3]+ngram[4:]]\n",
+    "    y = ngram[3].to(device)\n",
+    "    optimizer.zero_grad()\n",
+    "    ypredicted = model(x)\n",
+    "    loss = criterion(torch.log(ypredicted), y)\n",
+    "    if step % 100 == 0:\n",
+    "        print(f'steps: {step}, loss: {loss.item()}')\n",
+    "    if step % 1000 == 0:\n",
+    "        if step != 0:\n",
+    "            torch.save(model.state_dict(), f'{loss}_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "    if step == max_steps:\n",
+    "        break\n",
+    "    step += 1"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "is_executing": true
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "outputs": [],
+   "source": [
+    "vocab_size = 20000\n",
+    "embed_size = 150\n",
+    "hidden_size = 512\n",
+    "vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
+    "vocab.set_default_index(vocab['<unk>'])"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 74,
+   "outputs": [],
+   "source": [
+    "for model_name in ['best.bin']:\n",
+    "    topk = 100\n",
+    "    preds = []\n",
+    "    device = 'cuda'\n",
+    "    model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
+    "    model.load_state_dict(torch.load(model_name))\n",
+    "    model.eval()\n",
+    "    for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
+    "        with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
+    "            for line in fh:\n",
+    "                left_context = simple_preprocess(line.decode('utf-8').split('\\t')[-2].strip()).split()[-3:]\n",
+    "                right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:3]\n",
+    "                full_context = left_context + right_context\n",
+    "                x = [torch.tensor(vocab.forward([word])).to(device) for word in full_context]\n",
+    "                out = model(x)\n",
+    "                top = torch.topk(out[0], topk)\n",
+    "                top_indices = top.indices.tolist()\n",
+    "                top_probs = top.values.tolist()\n",
+    "                top_words = vocab.lookup_tokens(top_indices)\n",
+    "                top_zipped = zip(top_words, top_probs)\n",
+    "                pred = ''\n",
+    "                total_prob = 0\n",
+    "                for word, prob in top_zipped:\n",
+    "                    if word != '<unk>':\n",
+    "                        pred += f'{word}:{prob} '\n",
+    "                        total_prob += prob\n",
+    "                unk_prob = 1 - total_prob\n",
+    "                pred += f':{unk_prob}'\n",
+    "                f_out.write(pred + '\\n')"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "%cd challenging-america-word-gap-prediction/\n",
+    "!./geval --test-name dev-0\n",
+    "%cd ../"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "is_executing": true
+    }
+   }
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
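The heart of solution.ipynb is the sliding 7-gram window: look_ahead_iterator yields consecutive 7-token windows over the token stream, and the training loop takes tokens 0-2 and 4-6 as input (three words of left context, three of right) and token 3, the gap word, as the target; the six context embeddings are concatenated and fed to the bidirectional LSTM. A minimal standalone sketch of that windowing, with plain integers standing in for vocabulary indices:

def look_ahead_iterator(gen):
    # Yield every consecutive 7-item window of the stream.
    ngram = []
    for item in gen:
        if len(ngram) < 7:
            ngram.append(item)
            if len(ngram) == 7:
                yield ngram
        else:
            ngram = ngram[1:]
            ngram.append(item)
            yield ngram

for window in look_ahead_iterator(range(9)):
    left, gap, right = window[:3], window[3], window[4:]
    print(left, gap, right)
# [0, 1, 2] 3 [4, 5, 6]
# [1, 2, 3] 4 [5, 6, 7]
# [2, 3, 4] 5 [6, 7, 8]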
7414 test-A/in.tsv
File diff suppressed because it is too large

BIN test-A/in.tsv.xz (Normal file)
Binary file not shown.

14828 test-A/out.tsv
File diff suppressed because it is too large