fix inference and results

parent dbda50ac2b
commit c2fa4e59db
@@ -1 +1,2 @@
 # Solution for the variant with a context of the two next words (remainder of division by 3 = 2)
+# Bugfixed inference and uploaded correctly generated results on 24.05.23.
21038 dev-0/out.tsv
File diff suppressed because it is too large
@@ -4,8 +4,10 @@ tags:
 - right-context
 - trigrams
 params:
-  vocab_size: 5000
-  embed_size: 50
-  batch_size: 5000
+  vocab_size: 20000
+  embed_size: 150
+  batch_size: 512, 1024, 4096
+  hidden_size: 256, 1024
+  learning_rate: 0.0001, 0.001
 param-files:
 - "*.yaml"
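The params block above now records every hyperparameter value tried across the runs, not a single setting. As a rough illustration (assuming the metadata lives in a gonito.yaml-style file and that PyYAML is available — neither is confirmed by this diff), the comma-separated entries can be read back as lists:

```python
import yaml  # PyYAML; assumed to be installed

# Hypothetical: read the experiment metadata file shown in the hunk above.
with open('gonito.yaml') as fh:
    meta = yaml.safe_load(fh)

# 'batch_size: 512, 1024, 4096' parses as a single string; split it into
# the individual values that were tried.
batch_sizes = [int(v) for v in str(meta['params']['batch_size']).split(',')]
print(batch_sizes)  # [512, 1024, 4096]
```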
131 solution.ipynb
@@ -26,7 +26,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 65,
+"execution_count": 3,
 "outputs": [],
 "source": [
 "def simple_preprocess(line):\n",
@@ -63,14 +63,14 @@
 "\n",
 "def build_vocab(file, vocab_size):\n",
 "    try:\n",
-"        with open(f'vocab_{vocab_size}.pickle', 'rb') as handle:\n",
+"        with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
 "            vocab = pickle.load(handle)\n",
 "    except:\n",
 "        vocab = build_vocab_from_iterator(\n",
 "            get_word_lines_from_file(file),\n",
 "            max_tokens = vocab_size,\n",
 "            specials = ['<unk>'])\n",
-"        with open(f'vocab_{vocab_size}.pickle', 'wb') as handle:\n",
+"        with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
 "            pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
 "    return vocab\n",
 "\n",
@@ -85,18 +85,18 @@
 "            (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
 "\n",
 "class TrigramNeuralLanguageModel(nn.Module):\n",
-"    def __init__(self, vocab_size, embed_size):\n",
+"    def __init__(self, vocab_size, embed_size, hidden_size):\n",
 "        super(TrigramNeuralLanguageModel, self).__init__()\n",
 "        self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
-"        self.hidden_layer = nn.Linear(2*embed_size, 64)\n",
-"        self.output_layer = nn.Linear(64, vocab_size)\n",
+"        self.hidden_layer = nn.Linear(2*embed_size, hidden_size)\n",
+"        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
+"        self.softmax = nn.Softmax(dim=1)\n",
 "\n",
 "    def forward(self, x):\n",
 "        embeds = self.embeddings(x[0]), self.embeddings(x[1])\n",
 "        concat_embed = torch.concat(embeds, dim=1)\n",
 "        z = F.relu(self.hidden_layer(concat_embed))\n",
-"        softmax = nn.Softmax(dim=1)\n",
-"        y = softmax(self.output_layer(z))\n",
+"        y = self.softmax(self.output_layer(z))\n",
 "        return y"
 ],
 "metadata": {
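For readability, this is the model as it stands after the hunk above is applied (a sketch assembled from the + lines; the torch imports match those used elsewhere in the notebook):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TrigramNeuralLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(TrigramNeuralLanguageModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        # The hidden width is now a constructor argument instead of the
        # hard-coded 64, and softmax is built once in __init__ rather than
        # on every forward pass.
        self.hidden_layer = nn.Linear(2*embed_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x is a pair of token-id tensors: the two words of right context.
        embeds = self.embeddings(x[0]), self.embeddings(x[1])
        concat_embed = torch.concat(embeds, dim=1)
        z = F.relu(self.hidden_layer(concat_embed))
        y = self.softmax(self.output_layer(z))
        return y
```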
@@ -109,18 +109,20 @@
 "outputs": [],
 "source": [
 "max_steps = -1\n",
-"vocab_size = 5000\n",
-"embed_size = 50\n",
-"batch_size = 5000\n",
+"vocab_size = 20000\n",
+"embed_size = 150\n",
+"batch_size = 1024\n",
+"hidden_size = 1024\n",
+"learning_rate = 0.001\n",
 "vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
 "train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
 "if torch.cuda.is_available():\n",
 "    device = 'cuda'\n",
 "else:\n",
 "    raise Exception()\n",
-"model = TrigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
+"model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
 "data = DataLoader(train_dataset, batch_size=batch_size)\n",
-"optimizer = torch.optim.Adam(model.parameters())\n",
+"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
 "criterion = torch.nn.NLLLoss()\n",
 "\n",
 "model.train()\n",
@@ -132,9 +134,9 @@
 "    ypredicted = model(x)\n",
 "    loss = criterion(torch.log(ypredicted), y)\n",
 "    if step % 1000 == 0:\n",
-"        print(step, loss)\n",
-"    if step % 1000 == 0:\n",
-"        torch.save(model.state_dict(), f'model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}.bin')\n",
+"        print(f'steps: {step}, loss: {loss.item()}')\n",
+"        if step != 0:\n",
+"            torch.save(model.state_dict(), f'trigram_nn_model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
 "    loss.backward()\n",
 "    optimizer.step()\n",
 "    if step == max_steps:\n",
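The new checkpoint name encodes every hyperparameter, and the inference cell further down recovers batch_size and hidden_size by splitting on '_' and '-'. A small sketch of that round trip, using values taken from the run names in this diff:

```python
# Build a checkpoint name the same way the training loop now does.
step, vocab_size, embed_size = 13000, 20000, 150
batch_size, hidden_size, learning_rate = 512, 256, 0.0001
name = (f'trigram_nn_model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}'
        f'_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')

# Recover the fields the way the inference cell does.
assert int(name.split('_')[-3].split('-')[1]) == batch_size   # 'batch-512'
assert int(name.split('_')[-2].split('-')[1]) == hidden_size  # 'hidden-256'
```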
@@ -150,12 +152,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 9,
 "outputs": [],
 "source": [
-"vocab_size = 5000\n",
-"embed_size = 50\n",
-"batch_size = 5000\n",
+"vocab_size = 20000\n",
+"embed_size = 150\n",
 "vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
 "vocab.set_default_index(vocab['<unk>'])"
 ],
@@ -165,44 +166,58 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
-"outputs": [],
+"execution_count": 11,
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin\n",
+"512\n",
+"256\n",
+"trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin\n",
+"1024\n",
+"1024\n",
+"trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin\n",
+"4096\n",
+"256\n"
+]
+}
+],
 "source": [
-"for model_name in ['model_steps-1000_vocab-5000_embed-50_batch-5000.bin',\n",
-"                   'model_steps-1000_vocab-5000_embed-50_batch-5000.bin', 'model_steps-27000_vocab-5000_embed-50_batch-5000.bin']:\n",
+"for model_name in ['trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin', 'trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin', 'trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin']:\n",
+"    print(model_name)\n",
+"    batch_size = int(model_name.split('_')[-3].split('-')[1])\n",
+"    print(batch_size)\n",
+"    hidden_size = int(model_name.split('_')[-2].split('-')[1])\n",
+"    print(hidden_size)\n",
+"    topk = 10\n",
 "    preds = []\n",
 "    device = 'cuda'\n",
-"    model = TrigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
+"    model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
 "    model.load_state_dict(torch.load(model_name))\n",
 "    model.eval()\n",
-"    j = 0\n",
 "    for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
 "        with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
 "            for line in fh:\n",
-"                right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1]).split()[:2]\n",
+"                right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:2]\n",
 "                x = torch.tensor(vocab.forward([right_context[0]])).to(device), \\\n",
 "                    torch.tensor(vocab.forward([right_context[1]])).to(device)\n",
 "                out = model(x)\n",
-"                top = torch.topk(out[0], 5)\n",
+"                top = torch.topk(out[0], topk)\n",
 "                top_indices = top.indices.tolist()\n",
 "                top_probs = top.values.tolist()\n",
 "                top_words = vocab.lookup_tokens(top_indices)\n",
-"                top_zipped = list(zip(top_words, top_probs))\n",
+"                top_zipped = zip(top_words, top_probs)\n",
 "                pred = ''\n",
-"                unk = None\n",
-"                for i, tup in enumerate(top_zipped):\n",
-"                    if tup[0] == '<unk>':\n",
-"                        unk = top_zipped.pop(i)\n",
-"                for tup in top_zipped:\n",
-"                    pred += f'{tup[0]}:{tup[1]}\\t'\n",
-"                if unk:\n",
-"                    pred += f':{unk[1]}'\n",
-"                else:\n",
-"                    pred = pred.rstrip()\n",
+"                total_prob = 0\n",
+"                for word, prob in top_zipped:\n",
+"                    if word != '<unk>':\n",
+"                        pred += f'{word}:{prob} '\n",
+"                        total_prob += prob\n",
+"                unk_prob = 1 - total_prob\n",
+"                pred += f':{unk_prob}'\n",
 "                f_out.write(pred + '\\n')\n",
-"                if j % 1000 == 0:\n",
-"                    print(pred)\n",
-"                j += 1\n",
 "        src=f'{path}/out.tsv'\n",
 "        dst=f\"{path}/{model_name.split('.')[0]}_out.tsv\"\n",
 "        shutil.copy(src, dst)"
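The core of the bugfix is the prediction format: each output line now lists the top-k non-`<unk>` words with their probabilities, followed by a bare `:p` entry assigning the leftover probability mass to everything else. A minimal sketch of the fixed formatting, with invented probabilities:

```python
# Simplified from the fixed loop above; the numbers are made up.
top_zipped = [('the', 0.25), ('<unk>', 0.1), ('a', 0.125)]

pred = ''
total_prob = 0
for word, prob in top_zipped:
    if word != '<unk>':
        pred += f'{word}:{prob} '
        total_prob += prob
unk_prob = 1 - total_prob      # leftover mass, including <unk> itself
pred += f':{unk_prob}'
print(pred)  # the:0.25 a:0.125 :0.625
```

The old code popped `<unk>` from the list while iterating over it and joined entries with tabs instead of spaces, which appears to be the inference bug the commit message refers to.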
@@ -210,6 +225,38 @@
 "metadata": {
 "collapsed": false
 }
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"/home/ked/PycharmProjects/mj9/challenging-america-word-gap-prediction\n",
+"300.66\r\n",
+"/home/ked/PycharmProjects/mj9\n"
+]
+}
+],
+"source": [
+"%cd challenging-america-word-gap-prediction/\n",
+"!./geval --test-name dev-0\n",
+"%cd ../"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"outputs": [],
+"source": [],
+"metadata": {
+"collapsed": false
+}
 }
 ],
 "metadata": {
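The added geval cell scores the regenerated dev-0 predictions at 300.66 — presumably the challenge's perplexity-style metric, where lower is better (the metric name itself is not shown in this diff).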
14828 test-A/out.tsv
File diff suppressed because it is too large