add solution code

Kacper 2023-05-08 16:52:16 +02:00
parent e003ad6f34
commit 6a99ef51da


@@ -1,71 +1,33 @@

Notebook metadata, old (removed from the top of the file): "nbformat": 4, "nbformat_minor": 0 and a metadata block with a Colab "provenance" entry, a Python 3 kernelspec, "language_info": {"name": "python"}, "accelerator": "GPU" and "gpuClass": "standard". The new version keeps its metadata at the end of the file instead (see the bottom of this diff); throughout the notebook, Colab cell ids, execution counts and stored outputs are cleared and replaced by plain PyCharm-style cell metadata.

Removed cell (Colab-only setup), dropped together with its stored clone log (enumerating, counting and compressing 27 objects; 278.33 MiB received; 2 deltas resolved):

    !git clone --single-branch git://gonito.net/challenging-america-word-gap-prediction -b master

Imports cell, old:

    from torchtext.vocab import build_vocab_from_iterator
    import pickle
    from torch.utils.data import IterableDataset
    import itertools
    from torch import nn
    import torch
    import lzma
    from torch.utils.data import DataLoader
    from tqdm import tqdm

Imports cell, new:

    from torchtext.vocab import build_vocab_from_iterator
    import pickle
    from torch.utils.data import IterableDataset
    from itertools import chain
    from torch import nn
    import torch.nn.functional as F
    import torch
    import lzma
    from torch.utils.data import DataLoader
    import shutil
    torch.manual_seed(1)
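The switch from "import itertools" to "from itertools import chain" matches the dataset code further down, which flattens a generator of tokenized lines into one continuous token stream. As a quick reminder (an illustration, not part of the commit), chain.from_iterable does exactly that:

    from itertools import chain

    # Toy stand-in for the per-line token generator read from the xz-compressed corpus.
    lines = (["the", "quick", "fox"], ["jumped", "over"])

    # Flatten the per-line token lists into a single stream of tokens.
    tokens = list(chain.from_iterable(lines))
    print(tokens)  # ['the', 'quick', 'fox', 'jumped', 'over']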
Definitions cell (hunks @@ -88,11 +50,16 @@ and @@ -107,57 +74,59 @@). Unchanged context at the top:

    def simple_preprocess(line):
        return line.replace(r'\n', ' ')

followed by the tail of the corpus-reading helper (its body lies outside the hunks; only a closing "break" is visible).

look_ahead_iterator, old: yields (previous word, current word) bigram pairs:

    def look_ahead_iterator(gen):
        prev = None
        for item in gen:
            if prev is not None:
                yield prev, item
            prev = item

look_ahead_iterator, new: slides a three-token window and yields the two right-context words followed by the word to be predicted:

    def look_ahead_iterator(gen):
        ngram = []
        for item in gen:
            if len(ngram) < 3:
                ngram.append(item)
                if len(ngram) == 3:
                    yield ngram[1], ngram[2], ngram[0]
            else:
                ngram = ngram[1:]
                ngram.append(item)
                yield ngram[1], ngram[2], ngram[0]

build_vocab(file, vocab_size) is unchanged: it builds the vocabulary, caches it with pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL) and returns it (the middle of its body lies outside the hunks).

Dataset, old:

    class Bigrams(IterableDataset):
        def __init__(self, text_file, vocabulary_size):
            self.vocab = vocab
            self.vocab.set_default_index(self.vocab['<unk>'])
            self.vocabulary_size = vocabulary_size
            self.text_file = text_file

        def __iter__(self):
            return look_ahead_iterator(
                (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))

Dataset, new:

    class Trigrams(IterableDataset):
        def __init__(self, text_file):
            self.vocab = vocab
            self.vocab.set_default_index(self.vocab['<unk>'])
            self.text_file = text_file

        def __iter__(self):
            return look_ahead_iterator(
                (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))

Model, old: a single embedding/linear/softmax stack taking one context word:

    class SimpleBigramNeuralLanguageModel(nn.Module):
        def __init__(self, vocabulary_size, embedding_size):
            super(SimpleBigramNeuralLanguageModel, self).__init__()
            self.model = nn.Sequential(
                nn.Embedding(vocabulary_size, embedding_size),
                nn.Linear(embedding_size, vocabulary_size),
                nn.Softmax()
            )

        def forward(self, x):
            return self.model(x)

Model, new: the two context-word embeddings are concatenated and passed through a 64-unit ReLU hidden layer before the softmax output:

    class TrigramNeuralLanguageModel(nn.Module):
        def __init__(self, vocab_size, embed_size):
            super(TrigramNeuralLanguageModel, self).__init__()
            self.embeddings = nn.Embedding(vocab_size, embed_size)
            self.hidden_layer = nn.Linear(2*embed_size, 64)
            self.output_layer = nn.Linear(64, vocab_size)

        def forward(self, x):
            embeds = self.embeddings(x[0]), self.embeddings(x[1])
            concat_embed = torch.concat(embeds, dim=1)
            z = F.relu(self.hidden_layer(concat_embed))
            softmax = nn.Softmax(dim=1)
            y = softmax(self.output_layer(z))
            return y
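To make the new window semantics concrete: for a token stream w0 w1 w2 ..., the iterator yields (w1, w2, w0), (w2, w3, w1), and so on, i.e. the two words to the right first and the word to be predicted last. DataLoader's default collation then turns a batch of such integer triples into three tensors, which is why the training cell below unpacks "for x1, x2, y in data". A tiny illustrative check (not part of the commit):

    # Toy, already-numericalized stream w0..w4.
    tokens = [10, 11, 12, 13, 14]
    print(list(look_ahead_iterator(iter(tokens))))
    # [(11, 12, 10), (12, 13, 11), (13, 14, 12)]
    #  = (first right-context word, second right-context word, target word)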
Training cell (continuing into hunk @@ -166,109 +135,54 @@), new version:

    max_steps = -1
    vocab_size = 5000
    embed_size = 50
    batch_size = 5000
    vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)
    train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        raise Exception()
    model = TrigramNeuralLanguageModel(vocab_size, embed_size).to(device)
    data = DataLoader(train_dataset, batch_size=batch_size)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.NLLLoss()

    model.train()
    step = 0
    for x1, x2, y in data:
        x = x1.to(device), x2.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        ypredicted = model(x)
        # (loss computation lies outside the hunk)
        print(step, loss)
        if step % 1000 == 0:
            torch.save(model.state_dict(), f'model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}.bin')
        loss.backward()
        optimizer.step()
        if step == max_steps:
            break
        step += 1

The old cell differed only in using Bigrams(..., vocab_size) and SimpleBigramNeuralLanguageModel, iterating with "for x, y in data:" over single-word inputs (x = x.to(device)), and calling loss.backward() and optimizer.step() at the very end of the loop body, after "step += 1".

Removed stored output of the old Colab run: the line "Currently training: model_steps--1_vocab-5000_embed-50_batch-5000.bin", a UserWarning from torch/nn/modules/container.py:217 that the implicit dimension choice for softmax is deprecated and dim=X should be passed explicitly, and the recorded loss values:

    step 0: 8.6451    1000: 4.7971    2000: 4.7606    3000: 4.5784    4000: 4.5029
    5000: 4.6751      6000: 4.4452    7000: 4.4145    8000: 4.5194    9000: 4.4242
    10000: 4.2885     11000: 4.3033   12000: 4.4238   13000: 4.5368   14000: 4.3551
    15000: 4.3116     16000: 4.3750   17000: 4.4356   18000: 4.4206   19000: 4.5120
    20000: 4.4687     21000: 4.3365   22000: 4.3464   23000: 4.4861   24000: 4.3531
    25000: 4.3431     26000: 4.3747   27000: 4.2183   28000: 4.4097
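A note on the loss, relevant to both versions (and related to the deprecation warning preserved above): torch.nn.NLLLoss expects log-probabilities, while forward() ends in nn.Softmax and returns plain probabilities, so a log has to be taken in the part of the loop that lies outside the hunk for the optimized quantity to be the usual negative log-likelihood. For reference only (an assumption about standard PyTorch usage, not something this commit shows), the conventional pairing looks like this:

    import torch
    from torch import nn
    import torch.nn.functional as F

    logits = torch.randn(4, 10)           # toy batch of 4, vocabulary of 10
    targets = torch.tensor([1, 3, 0, 7])  # gold word indices

    # NLLLoss over log-probabilities ...
    log_probs = F.log_softmax(logits, dim=1)
    loss_nll = nn.NLLLoss()(log_probs, targets)

    # ... is the same quantity as CrossEntropyLoss over the raw logits.
    loss_ce = nn.CrossEntropyLoss()(logits, targets)
    assert torch.isclose(loss_nll, loss_ce)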
Vocabulary cell, old: a commented-out 5000-word / 50-dimensional configuration followed by a live 20000-word / 100-dimensional one:

    # vocab_size = 5000
    # embed_size = 50
    # batch_size = 5000
    # vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)
    # vocab.set_default_index(vocab['<unk>'])

    vocab_size = 20000
    embed_size = 100
    batch_size = 5000
    vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)
    vocab.set_default_index(vocab['<unk>'])

Vocabulary cell, new: back to the 5000-word / 50-dimensional setting that the saved checkpoints were trained with:

    vocab_size = 5000
    embed_size = 50
    batch_size = 5000
    vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)
    vocab.set_default_index(vocab['<unk>'])
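set_default_index is what lets the evaluation loop below numericalize arbitrary dev/test tokens: any word outside the 5000-word vocabulary is mapped to the index of '<unk>'. A small illustration with a toy vocabulary (not the one built from the training file):

    from torchtext.vocab import build_vocab_from_iterator

    # Toy vocabulary from two tokenized "lines"; '<unk>' is reserved explicitly.
    toy_vocab = build_vocab_from_iterator(
        [["the", "quick", "fox"], ["the", "lazy", "dog"]],
        specials=["<unk>"],
    )
    toy_vocab.set_default_index(toy_vocab["<unk>"])

    # forward() maps a list of tokens to indices; unseen words fall back to '<unk>'.
    print(toy_vocab.forward(["the", "zebra"]))  # e.g. [1, 0]: 'zebra' maps to '<unk>'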
Evaluation cell, old: a single checkpoint was loaded into the bigram model and, for every line of dev-0 and test-A, the last word of the left context (field [-2] of the tab-separated input) was used as the model input:

    preds = []
    device = 'cuda'
    model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
    model.load_state_dict(torch.load('/content/model_steps-27000_vocab-5000_embed-50_batch-5000.bin'))
    model.eval()
    j = 0
    for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:
        with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:
            for line in fh:
                previous_word = simple_preprocess(line.decode('utf-8').split('\t')[-2]).split()[-1]
                ixs = torch.tensor(vocab.forward([previous_word])).to(device)
                out = model(ixs)
                top = torch.topk(out[0], 5)
                top_indices = top.indices.tolist()
                top_probs = top.values.tolist()

The rest of the old loop matches the new version below, apart from the per-checkpoint file copy at the end.

Evaluation cell, new (hunk @@ -288,42 +202,35@@ covers the shared tail): several checkpoints are evaluated in turn, the first two words of the right context (field [-1]) form the model input, and each run's out.tsv is copied to a per-checkpoint file:

    for model_name in ['model_steps-1000_vocab-5000_embed-50_batch-5000.bin',
                       'model_steps-1000_vocab-5000_embed-50_batch-5000.bin', 'model_steps-27000_vocab-5000_embed-50_batch-5000.bin']:
        preds = []
        device = 'cuda'
        model = TrigramNeuralLanguageModel(vocab_size, embed_size).to(device)
        model.load_state_dict(torch.load(model_name))
        model.eval()
        j = 0
        for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:
            with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:
                for line in fh:
                    right_context = simple_preprocess(line.decode('utf-8').split('\t')[-1]).split()[:2]
                    x = torch.tensor(vocab.forward([right_context[0]])).to(device), \
                        torch.tensor(vocab.forward([right_context[1]])).to(device)
                    out = model(x)
                    top = torch.topk(out[0], 5)
                    top_indices = top.indices.tolist()
                    top_probs = top.values.tolist()
                    # (conversion of the top-5 lists into the prediction string lies outside the hunk)
                    f_out.write(pred + '\n')
                    if j % 1000 == 0:
                        print(pred)
                    j += 1
            src = f'{path}/out.tsv'
            dst = f"{path}/{model_name.split('.')[0]}_out.tsv"
            shutil.copy(src, dst)

Removed stored output of the old Colab run: 18 printed prediction lines in the out.tsv format, tab-separated fields of the form word:probability with a final nameless field, for example (tabs written as \t):

    the:0.32605835795402527\ta:0.03863263502717018\this:0.019891299307346344\ttho:0.017584890127182007\t:0.1336958259344101
    same:0.008983609266579151\tmost:0.006951075047254562\tfirst:0.005848093423992395\tUnited:0.005354634020477533\t:0.22962644696235657
    of:0.1870267689228058\tNo.:0.05885934457182884\tand:0.0347345806658268\tnumbered:0.017088865861296654\t:0.12375127524137497
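The code that turns top_indices and top_probs into those lines sits outside the hunks shown here, so the following is only a hypothetical reconstruction of the format visible in the old output (the most probable words with their probabilities, and the '<unk>' mass apparently folded into a trailing nameless field); the helper name format_prediction is made up for illustration:

    # Hypothetical sketch of the out.tsv line format; not the committed code.
    def format_prediction(top_indices, top_probs, vocab):
        named, rest = [], 0.0
        for idx, prob in zip(top_indices, top_probs):
            word = vocab.lookup_token(idx)
            if word == '<unk>':
                rest += prob              # assumption: '<unk>' mass goes to the nameless field
            else:
                named.append(f'{word}:{prob}')
        named.append(f':{rest}')          # trailing nameless field, as in the old output
        return '\t'.join(named)

    # pred = format_prediction(top_indices, top_probs, vocab)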
Notebook metadata, new (added at the end of the file):

    "metadata": {
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {"name": "ipython", "version": 2},
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython2",
            "version": "2.7.6"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 0