Compare commits

...

20 Commits
zad1 ... zad_10

Author SHA1 Message Date
Kacper
abb72efa82 solution 2023-06-06 22:05:40 +02:00
Kacper
4278d2f8ea add solution etc 2023-05-26 11:24:47 +02:00
Kacper
5a5265fd3d replace test scores 2023-05-26 01:06:05 +02:00
Kacper
4c56aa7160 update scores 2023-05-25 19:38:47 +02:00
Kacper
0f731fb82d update 2023-05-25 19:19:31 +02:00
Kacper
a95d7ddac5 cleanup 2023-05-25 19:16:30 +02:00
Kacper
fbe87e2ef1 test scores 2023-05-25 19:09:59 +02:00
378b5588fe Update 'README.md' 2023-05-25 11:30:04 +02:00
Kacper
c2fa4e59db fix inference and results 2023-05-24 20:03:23 +02:00
dbda50ac2b Update 'README.md' 2023-05-18 19:46:44 +02:00
9cc73ca767 Delete 'lab5.py' 2023-05-13 10:52:54 +02:00
4e07117b92 Delete 'generations.txt' 2023-05-13 10:52:48 +02:00
Kacper
6a99ef51da add solution code 2023-05-08 16:52:16 +02:00
Kacper
e003ad6f34 result files without visible solution for now 2023-05-01 17:23:37 +02:00
Kacper
aaccbbeb06 solution 2023-04-27 22:57:39 +02:00
39c1f3a341 Update 'README.md' 2023-04-15 03:01:41 +02:00
bb121718aa Delete 'bigram_model.py' 2023-04-15 02:50:00 +02:00
9332c1957b Delete 'generate_out.py' 2023-04-15 02:49:49 +02:00
Kacper
d877969ac2 lab5 2023-04-15 02:48:34 +02:00
Kacper
2a4ab01f29 zad2 2023-04-04 22:08:35 +02:00
10 changed files with 18196 additions and 35886 deletions

View File

@ -1,9 +1,3 @@
Challenging America word-gap prediction
===================================
Guess a word in a gap.
Evaluation metric
-----------------
LikelihoodHashed is the metric
# Rozszerzenia:
- uwzględniony prawy kontekst
- dwukierunkowy LSTM

10519
dev-0/in.tsv

File diff suppressed because it is too large Load Diff

BIN
dev-0/in.tsv.xz Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +0,0 @@
import sys
file = sys.argv[1]
with open(file, encoding='utf-8') as f1, open('out.tsv', 'w', encoding='utf-8') as f2:
for line in f1:
line = line.split('\t')
if line[-1][0].isupper():
f2.write('the:0.9 :0.1\n')
else:
f2.write('the:0.4 a:0.4 :0.2\n')

14
gonito.yaml Normal file
View File

@ -0,0 +1,14 @@
description: lstm with trigram left-right context of 3 ngrams
tags:
- neural-network
- left-context
- right-context
- lstm
params:
vocab_size: 20000
embed_size: 150
batch_size: 4096
hidden_size: 1024
learning_rate: 0.0001
param-files:
- "*.yaml"

246
solution.ipynb Normal file
View File

@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from torchtext.vocab import build_vocab_from_iterator\n",
"import pickle\n",
"from torch.utils.data import IterableDataset\n",
"from itertools import chain\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"import torch\n",
"import lzma\n",
"from torch.utils.data import DataLoader\n",
"torch.manual_seed(1)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"is_executing": true
}
}
},
{
"cell_type": "code",
"execution_count": 60,
"outputs": [],
"source": [
"def simple_preprocess(line):\n",
" return line.replace(r'\\n', ' ')\n",
"\n",
"def get_words_from_line(line):\n",
" line = line.strip()\n",
" line = simple_preprocess(line)\n",
" yield '<s>'\n",
" for t in line.split():\n",
" yield t\n",
" yield '</s>'\n",
"\n",
"def get_word_lines_from_file(file_name, n_size=-1):\n",
" with lzma.open(file_name, 'r') as fh:\n",
" n = 0\n",
" for line in fh:\n",
" n += 1\n",
" yield get_words_from_line(line.decode('utf-8'))\n",
" if n == n_size:\n",
" break\n",
"\n",
"def look_ahead_iterator(gen):\n",
" ngram = []\n",
" for item in gen:\n",
" if len(ngram) < 7:\n",
" ngram.append(item)\n",
" if len(ngram) == 7:\n",
" yield ngram\n",
" else:\n",
" ngram = ngram[1:]\n",
" ngram.append(item)\n",
" yield ngram\n",
"\n",
"def build_vocab(file, vocab_size):\n",
" try:\n",
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
" vocab = pickle.load(handle)\n",
" except:\n",
" vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(file),\n",
" max_tokens = vocab_size,\n",
" specials = ['<unk>'])\n",
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
" return vocab\n",
"\n",
"class Trigrams(IterableDataset):\n",
" def __init__(self, text_file):\n",
" self.vocab = vocab\n",
" self.vocab.set_default_index(self.vocab['<unk>'])\n",
" self.text_file = text_file\n",
"\n",
" def __iter__(self):\n",
" return look_ahead_iterator(\n",
" (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
"\n",
"class LSTMLanguageModel(nn.Module):\n",
" def __init__(self, vocab_size, embed_size, hidden_size):\n",
" super(LSTMLanguageModel, self).__init__()\n",
" self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
" self.lstm_layer = nn.LSTM(6*embed_size, hidden_size, bidirectional=True)\n",
" self.output_layer = nn.Linear(2*hidden_size, vocab_size)\n",
" self.softmax = nn.Softmax(dim=1)\n",
"\n",
" def forward(self, x):\n",
" embeds = [self.embeddings(gram) for gram in x]\n",
" concat_embed = torch.concat(embeds, dim=1)\n",
" z = F.relu(self.lstm_layer(concat_embed)[0])\n",
" y = self.softmax(self.output_layer(z))\n",
" return y"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Sample parameters\n",
"max_steps = -1\n",
"vocab_size = 20000\n",
"embed_size = 150\n",
"batch_size = 1024\n",
"hidden_size = 512\n",
"learning_rate = 0.0001\n",
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
"train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
"if torch.cuda.is_available():\n",
" device = 'cuda'\n",
"else:\n",
" raise Exception()\n",
"model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"criterion = torch.nn.NLLLoss()\n",
"\n",
"model.train()\n",
"step = 0\n",
"for ngram in data:\n",
" x = [gram.to(device) for gram in ngram[:3]+ngram[4:]]\n",
" y = ngram[3].to(device)\n",
" optimizer.zero_grad()\n",
" ypredicted = model(x)\n",
" loss = criterion(torch.log(ypredicted), y)\n",
" if step % 100 == 0:\n",
" print(f'steps: {step}, loss: {loss.item()}')\n",
" if step % 1000 == 0:\n",
" if step != 0:\n",
" torch.save(model.state_dict(), f'{loss}_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
" loss.backward()\n",
" optimizer.step()\n",
" if step == max_steps:\n",
" break\n",
" step += 1"
],
"metadata": {
"collapsed": false,
"pycharm": {
"is_executing": true
}
}
},
{
"cell_type": "code",
"execution_count": 69,
"outputs": [],
"source": [
"vocab_size = 20000\n",
"embed_size = 150\n",
"hidden_size = 512\n",
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
"vocab.set_default_index(vocab['<unk>'])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 74,
"outputs": [],
"source": [
"for model_name in ['best.bin']:\n",
" topk = 100\n",
" preds = []\n",
" device = 'cuda'\n",
" model = LSTMLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
" model.load_state_dict(torch.load(model_name))\n",
" model.eval()\n",
" for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
" for line in fh:\n",
" left_context = simple_preprocess(line.decode('utf-8').split('\\t')[-2].strip()).split()[-3:]\n",
" right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:3]\n",
" full_context = left_context + right_context\n",
" x = [torch.tensor(vocab.forward([word])).to(device) for word in full_context]\n",
" out = model(x)\n",
" top = torch.topk(out[0], topk)\n",
" top_indices = top.indices.tolist()\n",
" top_probs = top.values.tolist()\n",
" top_words = vocab.lookup_tokens(top_indices)\n",
" top_zipped = zip(top_words, top_probs)\n",
" pred = ''\n",
" total_prob = 0\n",
" for word, prob in top_zipped:\n",
" if word != '<unk>':\n",
" pred += f'{word}:{prob} '\n",
" total_prob += prob\n",
" unk_prob = 1 - total_prob\n",
" pred += f':{unk_prob}'\n",
" f_out.write(pred + '\\n')"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"%cd challenging-america-word-gap-prediction/\n",
"!./geval --test-name dev-0\n",
"%cd ../"
],
"metadata": {
"collapsed": false,
"pycharm": {
"is_executing": true
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

File diff suppressed because it is too large Load Diff

BIN
test-A/in.tsv.xz Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff