fix inference and results
This commit is contained in:
parent
dbda50ac2b
commit
c2fa4e59db
@ -1 +1,2 @@
|
||||
# Rozwiązanie dla wariantu kontekstu dwóch następnych słów (reszta z dzielenia przez 3 = 2)
|
||||
# Rozwiązanie dla wariantu kontekstu dwóch następnych słów (reszta z dzielenia przez 3 = 2)
|
||||
# Bugfixed inference and uploaded correctly generated results on 24.05.23.
|
||||
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -4,8 +4,10 @@ tags:
|
||||
- right-context
|
||||
- trigrams
|
||||
params:
|
||||
vocab_size: 5000
|
||||
embed_size: 50
|
||||
batch_size: 5000
|
||||
vocab_size: 20000
|
||||
embed_size: 150
|
||||
batch_size: 512, 1024, 4096
|
||||
hidden_size: 256, 1024
|
||||
learning_rate: 0.0001, 0.001
|
||||
param-files:
|
||||
- "*.yaml"
|
||||
|
165
solution.ipynb
165
solution.ipynb
@ -26,19 +26,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 65,
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def simple_preprocess(line):\n",
|
||||
" return line.replace(r'\\n', ' ')\n",
|
||||
" return line.replace(r'\\n', ' ')\n",
|
||||
"\n",
|
||||
"def get_words_from_line(line):\n",
|
||||
" line = line.strip()\n",
|
||||
" line = simple_preprocess(line)\n",
|
||||
" yield '<s>'\n",
|
||||
" for t in line.split():\n",
|
||||
" yield t\n",
|
||||
" yield '</s>'\n",
|
||||
" line = line.strip()\n",
|
||||
" line = simple_preprocess(line)\n",
|
||||
" yield '<s>'\n",
|
||||
" for t in line.split():\n",
|
||||
" yield t\n",
|
||||
" yield '</s>'\n",
|
||||
"\n",
|
||||
"def get_word_lines_from_file(file_name, n_size=-1):\n",
|
||||
" with lzma.open(file_name, 'r') as fh:\n",
|
||||
@ -63,14 +63,14 @@
|
||||
"\n",
|
||||
"def build_vocab(file, vocab_size):\n",
|
||||
" try:\n",
|
||||
" with open(f'vocab_{vocab_size}.pickle', 'rb') as handle:\n",
|
||||
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
|
||||
" vocab = pickle.load(handle)\n",
|
||||
" except:\n",
|
||||
" vocab = build_vocab_from_iterator(\n",
|
||||
" get_word_lines_from_file(file),\n",
|
||||
" max_tokens = vocab_size,\n",
|
||||
" specials = ['<unk>'])\n",
|
||||
" with open(f'vocab_{vocab_size}.pickle', 'wb') as handle:\n",
|
||||
" with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
|
||||
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
||||
" return vocab\n",
|
||||
"\n",
|
||||
@ -85,18 +85,18 @@
|
||||
" (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
|
||||
"\n",
|
||||
"class TrigramNeuralLanguageModel(nn.Module):\n",
|
||||
" def __init__(self, vocab_size, embed_size):\n",
|
||||
" def __init__(self, vocab_size, embed_size, hidden_size):\n",
|
||||
" super(TrigramNeuralLanguageModel, self).__init__()\n",
|
||||
" self.embeddings = nn.Embedding(vocab_size, embed_size)\n",
|
||||
" self.hidden_layer = nn.Linear(2*embed_size, 64)\n",
|
||||
" self.output_layer = nn.Linear(64, vocab_size)\n",
|
||||
" self.hidden_layer = nn.Linear(2*embed_size, hidden_size)\n",
|
||||
" self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
|
||||
" self.softmax = nn.Softmax(dim=1)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" embeds = self.embeddings(x[0]), self.embeddings(x[1])\n",
|
||||
" concat_embed = torch.concat(embeds, dim=1)\n",
|
||||
" z = F.relu(self.hidden_layer(concat_embed))\n",
|
||||
" softmax = nn.Softmax(dim=1)\n",
|
||||
" y = softmax(self.output_layer(z))\n",
|
||||
" y = self.softmax(self.output_layer(z))\n",
|
||||
" return y"
|
||||
],
|
||||
"metadata": {
|
||||
@ -109,18 +109,20 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_steps = -1\n",
|
||||
"vocab_size = 5000\n",
|
||||
"embed_size = 50\n",
|
||||
"batch_size = 5000\n",
|
||||
"vocab_size = 20000\n",
|
||||
"embed_size = 150\n",
|
||||
"batch_size = 1024\n",
|
||||
"hidden_size = 1024\n",
|
||||
"learning_rate = 0.001\n",
|
||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||
"train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" device = 'cuda'\n",
|
||||
"else:\n",
|
||||
" raise Exception()\n",
|
||||
"model = TrigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
||||
"model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
||||
"optimizer = torch.optim.Adam(model.parameters())\n",
|
||||
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
||||
"criterion = torch.nn.NLLLoss()\n",
|
||||
"\n",
|
||||
"model.train()\n",
|
||||
@ -132,9 +134,9 @@
|
||||
" ypredicted = model(x)\n",
|
||||
" loss = criterion(torch.log(ypredicted), y)\n",
|
||||
" if step % 1000 == 0:\n",
|
||||
" print(step, loss)\n",
|
||||
" if step % 1000 == 0:\n",
|
||||
" torch.save(model.state_dict(), f'model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}.bin')\n",
|
||||
" print(f'steps: {step}, loss: {loss.item()}')\n",
|
||||
" if step != 0:\n",
|
||||
" torch.save(model.state_dict(), f'trigram_nn_model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" if step == max_steps:\n",
|
||||
@ -150,12 +152,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vocab_size = 5000\n",
|
||||
"embed_size = 50\n",
|
||||
"batch_size = 5000\n",
|
||||
"vocab_size = 20000\n",
|
||||
"embed_size = 150\n",
|
||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||
"vocab.set_default_index(vocab['<unk>'])"
|
||||
],
|
||||
@ -165,44 +166,58 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"execution_count": 11,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin\n",
|
||||
"512\n",
|
||||
"256\n",
|
||||
"trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin\n",
|
||||
"1024\n",
|
||||
"1024\n",
|
||||
"trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin\n",
|
||||
"4096\n",
|
||||
"256\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for model_name in ['model_steps-1000_vocab-5000_embed-50_batch-5000.bin',\n",
|
||||
" 'model_steps-1000_vocab-5000_embed-50_batch-5000.bin', 'model_steps-27000_vocab-5000_embed-50_batch-5000.bin']:\n",
|
||||
"for model_name in ['trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin', 'trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin', 'trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin']:\n",
|
||||
" print(model_name)\n",
|
||||
" batch_size = int(model_name.split('_')[-3].split('-')[1])\n",
|
||||
" print(batch_size)\n",
|
||||
" hidden_size = int(model_name.split('_')[-2].split('-')[1])\n",
|
||||
" print(hidden_size)\n",
|
||||
" topk = 10\n",
|
||||
" preds = []\n",
|
||||
" device = 'cuda'\n",
|
||||
" model = TrigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
||||
" model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
|
||||
" model.load_state_dict(torch.load(model_name))\n",
|
||||
" model.eval()\n",
|
||||
" j = 0\n",
|
||||
" for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
|
||||
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
|
||||
" for line in fh:\n",
|
||||
" right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1]).split()[:2]\n",
|
||||
" x = torch.tensor(vocab.forward([right_context[0]])).to(device), \\\n",
|
||||
" torch.tensor(vocab.forward([right_context[1]])).to(device)\n",
|
||||
" out = model(x)\n",
|
||||
" top = torch.topk(out[0], 5)\n",
|
||||
" top_indices = top.indices.tolist()\n",
|
||||
" top_probs = top.values.tolist()\n",
|
||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||
" top_zipped = list(zip(top_words, top_probs))\n",
|
||||
" pred = ''\n",
|
||||
" unk = None\n",
|
||||
" for i, tup in enumerate(top_zipped):\n",
|
||||
" if tup[0] == '<unk>':\n",
|
||||
" unk = top_zipped.pop(i)\n",
|
||||
" for tup in top_zipped:\n",
|
||||
" pred += f'{tup[0]}:{tup[1]}\\t'\n",
|
||||
" if unk:\n",
|
||||
" pred += f':{unk[1]}'\n",
|
||||
" else:\n",
|
||||
" pred = pred.rstrip()\n",
|
||||
" f_out.write(pred + '\\n')\n",
|
||||
" if j % 1000 == 0:\n",
|
||||
" print(pred)\n",
|
||||
" j += 1\n",
|
||||
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
|
||||
" for line in fh:\n",
|
||||
" right_context = simple_preprocess(line.decode('utf-8').split('\\t')[-1].strip()).split()[:2]\n",
|
||||
" x = torch.tensor(vocab.forward([right_context[0]])).to(device), \\\n",
|
||||
" torch.tensor(vocab.forward([right_context[1]])).to(device)\n",
|
||||
" out = model(x)\n",
|
||||
" top = torch.topk(out[0], topk)\n",
|
||||
" top_indices = top.indices.tolist()\n",
|
||||
" top_probs = top.values.tolist()\n",
|
||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||
" top_zipped = zip(top_words, top_probs)\n",
|
||||
" pred = ''\n",
|
||||
" total_prob = 0\n",
|
||||
" for word, prob in top_zipped:\n",
|
||||
" if word != '<unk>':\n",
|
||||
" pred += f'{word}:{prob} '\n",
|
||||
" total_prob += prob\n",
|
||||
" unk_prob = 1 - total_prob\n",
|
||||
" pred += f':{unk_prob}'\n",
|
||||
" f_out.write(pred + '\\n')\n",
|
||||
" src=f'{path}/out.tsv'\n",
|
||||
" dst=f\"{path}/{model_name.split('.')[0]}_out.tsv\"\n",
|
||||
" shutil.copy(src, dst)"
|
||||
@ -210,6 +225,38 @@
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ked/PycharmProjects/mj9/challenging-america-word-gap-prediction\n",
|
||||
"300.66\r\n",
|
||||
"/home/ked/PycharmProjects/mj9\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%cd challenging-america-word-gap-prediction/\n",
|
||||
"!./geval --test-name dev-0\n",
|
||||
"%cd ../"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user