inference and results fixes
This commit is contained in:
parent
aaccbbeb06
commit
035ee66c44
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
219
solution.ipynb
219
solution.ipynb
@ -16,35 +16,6 @@
|
|||||||
"gpuClass": "standard"
|
"gpuClass": "standard"
|
||||||
},
|
},
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "5oSsy7tRYrXO",
|
|
||||||
"outputId": "896cbe7d-61a5-44b0-b4fb-ba308c6ea7b2"
|
|
||||||
},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": [
|
|
||||||
"Cloning into 'challenging-america-word-gap-prediction'...\n",
|
|
||||||
"remote: Wymienianie obiektów: 27, gotowe.\u001b[K\n",
|
|
||||||
"remote: Zliczanie obiektów: 100% (27/27), gotowe.\u001b[K\n",
|
|
||||||
"remote: Kompresowanie obiektów: 100% (23/23), gotowe.\u001b[K\n",
|
|
||||||
"remote: Razem 27 (delty 2), użyte ponownie 18 (delty 0), paczki użyte ponownie 0\u001b[K\n",
|
|
||||||
"Receiving objects: 100% (27/27), 278.33 MiB | 8.66 MiB/s, done.\n",
|
|
||||||
"Resolving deltas: 100% (2/2), done.\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
" !git clone --single-branch git://gonito.net/challenging-america-word-gap-prediction -b master"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"source": [
|
||||||
@ -61,7 +32,7 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "WnglOFA8gGJl"
|
"id": "WnglOFA8gGJl"
|
||||||
},
|
},
|
||||||
"execution_count": 6,
|
"execution_count": 2,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -96,14 +67,14 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"def build_vocab(file, vocab_size):\n",
|
"def build_vocab(file, vocab_size):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" with open(f'vocab_{vocab_size}.pickle', 'rb') as handle:\n",
|
" with open(f'bigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:\n",
|
||||||
" vocab = pickle.load(handle)\n",
|
" vocab = pickle.load(handle)\n",
|
||||||
" except:\n",
|
" except:\n",
|
||||||
" vocab = build_vocab_from_iterator(\n",
|
" vocab = build_vocab_from_iterator(\n",
|
||||||
" get_word_lines_from_file(file),\n",
|
" get_word_lines_from_file(file),\n",
|
||||||
" max_tokens = vocab_size,\n",
|
" max_tokens = vocab_size,\n",
|
||||||
" specials = ['<unk>'])\n",
|
" specials = ['<unk>'])\n",
|
||||||
" with open(f'vocab_{vocab_size}.pickle', 'wb') as handle:\n",
|
" with open(f'bigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:\n",
|
||||||
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
||||||
" return vocab\n",
|
" return vocab\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -124,7 +95,7 @@
|
|||||||
" self.model = nn.Sequential(\n",
|
" self.model = nn.Sequential(\n",
|
||||||
" nn.Embedding(vocabulary_size, embedding_size),\n",
|
" nn.Embedding(vocabulary_size, embedding_size),\n",
|
||||||
" nn.Linear(embedding_size, vocabulary_size),\n",
|
" nn.Linear(embedding_size, vocabulary_size),\n",
|
||||||
" nn.Softmax()\n",
|
" nn.Softmax(dim=1)\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def forward(self, x):\n",
|
" def forward(self, x):\n",
|
||||||
@ -133,16 +104,17 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "aW_3JqSNgLLr"
|
"id": "aW_3JqSNgLLr"
|
||||||
},
|
},
|
||||||
"execution_count": 25,
|
"execution_count": 3,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"source": [
|
||||||
"max_steps=-1\n",
|
"max_steps= -1\n",
|
||||||
"vocab_size = 5000\n",
|
"vocab_size = 20000\n",
|
||||||
"embed_size = 50\n",
|
"embed_size = 150\n",
|
||||||
"batch_size = 5000\n",
|
"batch_size = 5000\n",
|
||||||
|
"learning_rate = 0.001\n",
|
||||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||||
"train_dataset = Bigrams('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
"train_dataset = Bigrams('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||||
"if torch.cuda.is_available():\n",
|
"if torch.cuda.is_available():\n",
|
||||||
@ -151,7 +123,7 @@
|
|||||||
" raise Exception()\n",
|
" raise Exception()\n",
|
||||||
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
||||||
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
"data = DataLoader(train_dataset, batch_size=batch_size)\n",
|
||||||
"optimizer = torch.optim.Adam(model.parameters())\n",
|
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
|
||||||
"criterion = torch.nn.NLLLoss()\n",
|
"criterion = torch.nn.NLLLoss()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model.train()\n",
|
"model.train()\n",
|
||||||
@ -160,12 +132,12 @@
|
|||||||
" x = x.to(device)\n",
|
" x = x.to(device)\n",
|
||||||
" y = y.to(device)\n",
|
" y = y.to(device)\n",
|
||||||
" optimizer.zero_grad()\n",
|
" optimizer.zero_grad()\n",
|
||||||
" ypredicted = model(x)\n",
|
" y_predicted = model(x)\n",
|
||||||
" loss = criterion(torch.log(ypredicted), y)\n",
|
" loss = criterion(torch.log(y_predicted), y)\n",
|
||||||
" if step % 1000 == 0:\n",
|
" if step % 1000 == 0:\n",
|
||||||
" print(step, loss)\n",
|
" print(f'steps: {step}, loss: {loss.item()}')\n",
|
||||||
" if step % 1000 == 0:\n",
|
" if step != 0:\n",
|
||||||
" torch.save(model.state_dict(), f'model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}.bin')\n",
|
" torch.save(model.state_dict(), f'bigram_nn_model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}.bin')\n",
|
||||||
" if step == max_steps:\n",
|
" if step == max_steps:\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
" step += 1\n",
|
" step += 1\n",
|
||||||
@ -179,56 +151,46 @@
|
|||||||
"id": "QQw_E7Ku4h0a",
|
"id": "QQw_E7Ku4h0a",
|
||||||
"outputId": "4a37d9ba-1abd-46ae-b157-cd6d52b951a2"
|
"outputId": "4a37d9ba-1abd-46ae-b157-cd6d52b951a2"
|
||||||
},
|
},
|
||||||
"execution_count": 11,
|
"execution_count": 4,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
|
||||||
"text": [
|
|
||||||
"Currently training: model_steps--1_vocab-5000_embed-50_batch-5000.bin\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
"/home/ked/PycharmProjects/mj9/venv/lib/python3.10/site-packages/torch/nn/modules/container.py:217: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
||||||
" input = module(input)\n"
|
" input = module(input)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"0 tensor(8.6451, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 0, loss: 10.091094017028809\n",
|
||||||
"1000 tensor(4.7971, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 1000, loss: 5.73332405090332\n",
|
||||||
"2000 tensor(4.7606, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 2000, loss: 5.655370712280273\n",
|
||||||
"3000 tensor(4.5784, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 3000, loss: 5.457630634307861\n",
|
||||||
"4000 tensor(4.5029, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 4000, loss: 5.38517427444458\n",
|
||||||
"5000 tensor(4.6751, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 5000, loss: 5.467936992645264\n",
|
||||||
"6000 tensor(4.4452, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 6000, loss: 5.372152328491211\n",
|
||||||
"7000 tensor(4.4145, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 7000, loss: 5.272013187408447\n",
|
||||||
"8000 tensor(4.5194, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 8000, loss: 5.439966201782227\n",
|
||||||
"9000 tensor(4.4242, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 9000, loss: 5.268238544464111\n",
|
||||||
"10000 tensor(4.2885, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 10000, loss: 5.1395182609558105\n",
|
||||||
"11000 tensor(4.3033, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 11000, loss: 5.2558159828186035\n",
|
||||||
"12000 tensor(4.4238, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"steps: 12000, loss: 5.263617515563965\n"
|
||||||
"13000 tensor(4.5368, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
]
|
||||||
"14000 tensor(4.3551, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
},
|
||||||
"15000 tensor(4.3116, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
{
|
||||||
"16000 tensor(4.3750, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"ename": "KeyboardInterrupt",
|
||||||
"17000 tensor(4.4356, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"evalue": "",
|
||||||
"18000 tensor(4.4206, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"output_type": "error",
|
||||||
"19000 tensor(4.5120, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"traceback": [
|
||||||
"20000 tensor(4.4687, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
||||||
"21000 tensor(4.3365, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
|
||||||
"22000 tensor(4.3464, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"Cell \u001B[0;32mIn[4], line 31\u001B[0m\n\u001B[1;32m 29\u001B[0m \u001B[38;5;28;01mbreak\u001B[39;00m\n\u001B[1;32m 30\u001B[0m step \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[0;32m---> 31\u001B[0m \u001B[43mloss\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 32\u001B[0m optimizer\u001B[38;5;241m.\u001B[39mstep()\n",
|
||||||
"23000 tensor(4.4861, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"File \u001B[0;32m~/PycharmProjects/mj9/venv/lib/python3.10/site-packages/torch/_tensor.py:487\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 477\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 478\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 479\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 480\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 485\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m 486\u001B[0m )\n\u001B[0;32m--> 487\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 488\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m 489\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
|
||||||
"24000 tensor(4.3531, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"File \u001B[0;32m~/PycharmProjects/mj9/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:200\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 195\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 197\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m 198\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 199\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 200\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m 201\u001B[0m \u001B[43m \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 202\u001B[0m \u001B[43m \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
|
||||||
"25000 tensor(4.3431, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
"\u001B[0;31mKeyboardInterrupt\u001B[0m: "
|
||||||
"26000 tensor(4.3747, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
|
||||||
"27000 tensor(4.2183, device='cuda:0', grad_fn=<NllLossBackward0>)\n",
|
|
||||||
"28000 tensor(4.4097, device='cuda:0', grad_fn=<NllLossBackward0>)\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@ -236,14 +198,8 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"source": [
|
||||||
"# vocab_size = 5000\n",
|
|
||||||
"# embed_size = 50\n",
|
|
||||||
"# batch_size = 5000\n",
|
|
||||||
"# vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
|
||||||
"# vocab.set_default_index(vocab['<unk>'])\n",
|
|
||||||
"\n",
|
|
||||||
"vocab_size = 20000\n",
|
"vocab_size = 20000\n",
|
||||||
"embed_size = 100\n",
|
"embed_size = 150\n",
|
||||||
"batch_size = 5000\n",
|
"batch_size = 5000\n",
|
||||||
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
"vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)\n",
|
||||||
"vocab.set_default_index(vocab['<unk>'])"
|
"vocab.set_default_index(vocab['<unk>'])"
|
||||||
@ -251,44 +207,38 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "N9-wmLOEZ2aV"
|
"id": "N9-wmLOEZ2aV"
|
||||||
},
|
},
|
||||||
"execution_count": 42,
|
"execution_count": 5,
|
||||||
"outputs": []
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"source": [
|
||||||
|
"topk = 5\n",
|
||||||
"preds = []\n",
|
"preds = []\n",
|
||||||
"device = 'cuda'\n",
|
"device = 'cuda'\n",
|
||||||
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
"model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)\n",
|
||||||
"model.load_state_dict(torch.load('/content/model_steps-27000_vocab-5000_embed-50_batch-5000.bin'))\n",
|
"model.load_state_dict(torch.load('bigram_nn_model_steps-10000_vocab-20000_embed-150_batch-5000.bin'))\n",
|
||||||
"model.eval()\n",
|
"model.eval()\n",
|
||||||
"j = 0\n",
|
|
||||||
"for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
|
"for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:\n",
|
||||||
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
|
" with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:\n",
|
||||||
" for line in fh:\n",
|
" for line in fh:\n",
|
||||||
" previous_word = simple_preprocess(line.decode('utf-8').split('\\t')[-2]).split()[-1]\n",
|
" previous_word = simple_preprocess(line.decode('utf-8').split('\\t')[-2].strip()).split()[-1]\n",
|
||||||
" ixs = torch.tensor(vocab.forward([previous_word])).to(device)\n",
|
" ixs = torch.tensor(vocab.forward([previous_word])).to(device)\n",
|
||||||
" out = model(ixs)\n",
|
" out = model(ixs)\n",
|
||||||
" top = torch.topk(out[0], 5)\n",
|
" top = torch.topk(out[0], topk)\n",
|
||||||
" top_indices = top.indices.tolist()\n",
|
" top_indices = top.indices.tolist()\n",
|
||||||
" top_probs = top.values.tolist()\n",
|
" top_probs = top.values.tolist()\n",
|
||||||
" top_words = vocab.lookup_tokens(top_indices)\n",
|
" top_words = vocab.lookup_tokens(top_indices)\n",
|
||||||
" top_zipped = list(zip(top_words, top_probs))\n",
|
" top_zipped = zip(top_words, top_probs)\n",
|
||||||
" pred = ''\n",
|
" pred = ''\n",
|
||||||
" unk = None\n",
|
" total_prob = 0\n",
|
||||||
" for i, tup in enumerate(top_zipped):\n",
|
" for word, prob in top_zipped:\n",
|
||||||
" if tup[0] == '<unk>':\n",
|
" if word != '<unk>':\n",
|
||||||
" unk = top_zipped.pop(i)\n",
|
" pred += f'{word}:{prob} '\n",
|
||||||
" for tup in top_zipped:\n",
|
" total_prob += prob\n",
|
||||||
" pred += f'{tup[0]}:{tup[1]}\\t'\n",
|
" unk_prob = 1 - total_prob\n",
|
||||||
" if unk:\n",
|
" pred += f':{unk_prob}'\n",
|
||||||
" pred += f':{unk[1]}'\n",
|
" f_out.write(pred + '\\n')"
|
||||||
" else:\n",
|
|
||||||
" pred = pred.rstrip()\n",
|
|
||||||
" f_out.write(pred + '\\n')\n",
|
|
||||||
" if j % 1000 == 0:\n",
|
|
||||||
" print(pred)\n",
|
|
||||||
" j += 1 "
|
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
@ -297,33 +247,40 @@
|
|||||||
"id": "99uioFpVCJL8",
|
"id": "99uioFpVCJL8",
|
||||||
"outputId": "d4267cb1-e557-478a-8cf7-91a90db07698"
|
"outputId": "d4267cb1-e557-478a-8cf7-91a90db07698"
|
||||||
},
|
},
|
||||||
"execution_count": 48,
|
"execution_count": 24,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"the:0.32605835795402527\ta:0.03863263502717018\this:0.019891299307346344\ttho:0.017584890127182007\t:0.1336958259344101\n",
|
"/home/ked/PycharmProjects/mj9/challenging-america-word-gap-prediction\n",
|
||||||
"same:0.008983609266579151\tmost:0.006951075047254562\tfirst:0.005848093423992395\tUnited:0.005354634020477533\t:0.22962644696235657\n",
|
"394.97\r\n",
|
||||||
"of:0.1870267689228058\tNo.:0.05885934457182884\tand:0.0347345806658268\tnumbered:0.017088865861296654\t:0.12375127524137497\n",
|
"/home/ked/PycharmProjects/mj9\n"
|
||||||
"the:0.23099401593208313\ta:0.05134483054280281\this:0.017109891399741173\tthis:0.015690239146351814\t:0.2021108716726303\n",
|
|
||||||
"is:0.16247524321079254\twas:0.08097667992115021\twill:0.03666245937347412\twould:0.031893592327833176\t:0.09085553884506226\n",
|
|
||||||
"the:0.14925561845302582\tbe:0.07023955136537552\ta:0.0237724632024765\thave:0.0131039097905159\t:0.12894178926944733\n",
|
|
||||||
"years:0.11707684397697449\tmiles:0.038641661405563354\tacres:0.0361776202917099\tdays:0.035523977130651474\t:0.1676659733057022\n",
|
|
||||||
"and:0.05091285705566406\tof:0.03853045403957367\tthe:0.02558819204568863\tto:0.019778745248913765\t:0.2338942289352417\n",
|
|
||||||
"to:0.20445719361305237\tthe:0.13792230188846588\ta:0.04136090725660324\tby:0.02959897182881832\t:0.06412851065397263\n",
|
|
||||||
"the:0.14456485211849213\the:0.0543459951877594\tthey:0.0345623604953289\tit:0.03187565878033638\t:0.08283700793981552\n",
|
|
||||||
"to:0.11275122314691544\tof:0.07946161180734634\tlike:0.056227609515190125\tthat:0.05296172574162483\t:0.1051449254155159\n",
|
|
||||||
"of:0.04079027101397514\tday:0.0400676503777504\ttime:0.02808181196451187\tto:0.02239527367055416\t:0.147441565990448\n",
|
|
||||||
"on:0.28541672229766846\tat:0.043499380350112915\tthe:0.04269522428512573\tin:0.03935478255152702\t:0.10247787833213806\n",
|
|
||||||
".:0.26101377606391907\t.,:0.046980664134025574\tand:0.009626681916415691\tM:0.007779326289892197\t:0.3348052203655243\n",
|
|
||||||
"and:0.05091285705566406\tof:0.03853045403957367\tthe:0.02558819204568863\tto:0.019778745248913765\t:0.2338942289352417\n",
|
|
||||||
"the:0.4567626714706421\tsaid:0.053911514580249786\twith:0.04098761826753616\tand:0.02215263620018959\t:0.07401206344366074\n",
|
|
||||||
"and:0.19774483144283295\tbut:0.03353063389658928\tthe:0.029393238946795464\tas:0.026280701160430908\t:0.06644411385059357\n",
|
|
||||||
"and:0.15652838349342346\twho:0.038931723684072495\tbut:0.036329541355371475\tthe:0.03554282337427139\t:0.05828680843114853\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"source": [
|
||||||
|
"%cd challenging-america-word-gap-prediction/\n",
|
||||||
|
"!./geval --test-name dev-0\n",
|
||||||
|
"%cd ../"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user