Score: 6.04
This commit is contained in:
parent
337d2ffc42
commit
a2d183f2e3
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
643
src/08_word2vec.ipynb
Normal file
643
src/08_word2vec.ipynb
Normal file
@ -0,0 +1,643 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# <b>Trigram</b> neural network model for gap fill task"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Import required packages"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 42,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"import re\n",
|
||||||
|
"import nltk\n",
|
||||||
|
"import os\n",
|
||||||
|
"import csv\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from torch.utils.data import DataLoader, TensorDataset\n",
|
||||||
|
"from bidict import bidict\n",
|
||||||
|
"import math\n",
|
||||||
|
"from sklearn.utils import shuffle\n",
|
||||||
|
"from collections import Counter\n",
|
||||||
|
"import random\n",
|
||||||
|
"from torchtext.vocab import build_vocab_from_iterator"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 43,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
|
||||||
|
"os.environ['TORCH_USE_CUDA_DSA'] = '1'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Global configuration variables"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 144,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"vocab_size = 40_000\n",
|
||||||
|
"batch_size = 256\n",
|
||||||
|
"embedding_dim = 128\n",
|
||||||
|
"hidden_dim = 1024\n",
|
||||||
|
"learning_rate = 0.001\n",
|
||||||
|
"epochs = 5\n",
|
||||||
|
"\n",
|
||||||
|
"output_size = vocab_size"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 102,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"cuda\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
|
"# device = torch.device(\"cpu\")\n",
|
||||||
|
"print(device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build vocabulary"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 46,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def get_word_lines_from_dataset():\n",
|
||||||
|
" dataset_dir = os.path.join('..', 'train', 'in.tsv.xz')\n",
|
||||||
|
" expected_dir = os.path.join('..', 'train', 'expected.tsv')\n",
|
||||||
|
"\n",
|
||||||
|
" df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
" expected_df = pd.read_csv(expected_dir, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
"\n",
|
||||||
|
" for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):\n",
|
||||||
|
" df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)\n",
|
||||||
|
" \n",
|
||||||
|
" for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):\n",
|
||||||
|
" yield re.split(r\"\\s+\", left_context.strip()) + [str(word).strip()] + re.split(r\"\\s+\", right_context.strip())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 103,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/433 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 433/433 [01:34<00:00, 4.60it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"vocab = build_vocab_from_iterator(\n",
|
||||||
|
" get_word_lines_from_dataset(),\n",
|
||||||
|
" max_tokens = vocab_size,\n",
|
||||||
|
" specials = ['<unk>'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 104,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"vocab.set_default_index(vocab['<unk>'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 105,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"['<unk>', 'the', 'of', 'me.']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 105,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"vocab.lookup_tokens([0, 1, 2, 1245])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 106,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 433/433 [01:06<00:00, 6.50it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"dataset_dir = os.path.join('..', 'train', 'in.tsv.xz')\n",
|
||||||
|
"expected_dir = os.path.join('..', 'train', 'expected.tsv')\n",
|
||||||
|
"\n",
|
||||||
|
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
"expected_df = pd.read_csv(expected_dir, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"input_corpus = []\n",
|
||||||
|
"target_corpus = []\n",
|
||||||
|
"\n",
|
||||||
|
"left_tokens = 2\n",
|
||||||
|
"right_tokens = 2\n",
|
||||||
|
"\n",
|
||||||
|
"for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):\n",
|
||||||
|
" df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)\n",
|
||||||
|
" \n",
|
||||||
|
" for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):\n",
|
||||||
|
" target_corpus.append([str(word).strip()])\n",
|
||||||
|
" input_corpus.append(re.split(r\"\\s+\", left_context.strip())[-left_tokens:] + re.split(r\"\\s+\", right_context.strip())[:right_tokens])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Tokenize entire corpus"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 107,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 432022/432022 [00:02<00:00, 168428.47it/s]\n",
|
||||||
|
"100%|██████████| 432022/432022 [00:01<00:00, 332294.03it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"def tokenize(w):\n",
|
||||||
|
" return vocab[w]\n",
|
||||||
|
" \n",
|
||||||
|
"def detokenize(t):\n",
|
||||||
|
" return vocab.lookup_tokens([t])[0]\n",
|
||||||
|
"\n",
|
||||||
|
"tokenized_input_corpus = []\n",
|
||||||
|
"tokenized_target_corpus = []\n",
|
||||||
|
"\n",
|
||||||
|
"for words in tqdm(input_corpus):\n",
|
||||||
|
" tokenized_input_corpus.append([tokenize(word) for word in words])\n",
|
||||||
|
"\n",
|
||||||
|
"for words in tqdm(target_corpus):\n",
|
||||||
|
" tokenized_target_corpus.append([tokenize(word) for word in words])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 108,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tokenized_input_corpus, tokenized_target_corpus = shuffle(tokenized_input_corpus, tokenized_target_corpus)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Create dataset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 110,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"indices = np.nonzero(np.array(tokenized_target_corpus).flatten())\n",
|
||||||
|
"\n",
|
||||||
|
"tokenized_input_corpus = np.take(tokenized_input_corpus, indices, axis=0)\n",
|
||||||
|
"tokenized_target_corpus = np.take(tokenized_target_corpus, indices, axis=0)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 111,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"input_corpus_tensor = torch.flatten(torch.tensor(tokenized_input_corpus, dtype=torch.long, device=device), end_dim=-2)\n",
|
||||||
|
"target_corpus_tensor = torch.flatten(torch.tensor(tokenized_target_corpus, dtype=torch.long, device=device)).reshape(-1, 1)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 112,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"torch.Size([378666, 4])\n",
|
||||||
|
"torch.Size([378666, 1])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(input_corpus_tensor.size())\n",
|
||||||
|
"print(target_corpus_tensor.size())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 132,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['silver', 'republicans', 'silver', 'den']\n",
|
||||||
|
"['and']\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"random_index = random.randint(0, len(input_corpus_tensor) - 1)\n",
|
||||||
|
"\n",
|
||||||
|
"# Get random element from input corpus\n",
|
||||||
|
"random_input_element = input_corpus_tensor[random_index]\n",
|
||||||
|
"\n",
|
||||||
|
"# Get corresponding element from target corpus\n",
|
||||||
|
"random_target_element = target_corpus_tensor[random_index]\n",
|
||||||
|
"\n",
|
||||||
|
"print([detokenize(int(idx)) for idx in random_input_element])\n",
|
||||||
|
"print([detokenize(int(idx)) for idx in random_target_element])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 121,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset = TensorDataset(input_corpus_tensor, target_corpus_tensor)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 122,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Define the trigram neural network model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 141,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class TrigramNN(nn.Module):\n",
|
||||||
|
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):\n",
|
||||||
|
" super(TrigramNN, self).__init__()\n",
|
||||||
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
||||||
|
" self.linear = nn.Linear(embedding_dim * (left_tokens + right_tokens), output_size)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, inputs):\n",
|
||||||
|
" out = self.embedding(inputs)\n",
|
||||||
|
" out = out.view(inputs.size(0), -1)\n",
|
||||||
|
" out = self.linear(out)\n",
|
||||||
|
" # out = torch.softmax(out, dim=1)\n",
|
||||||
|
" return out"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Initialize the model, loss function, and optimizer"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 145,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = TrigramNN(vocab_size, embedding_dim, hidden_dim, output_size)\n",
|
||||||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||||||
|
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Training loop"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 146,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/1480 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 44.97it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 1, Loss: 7.505966403999844\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 45.57it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 2, Loss: 5.1014555966531905\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 45.42it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 3, Loss: 3.835972652886365\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 45.60it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 4, Loss: 3.1567180975063427\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 44.90it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 5, Loss: 2.749172909517546\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(epochs):\n",
|
||||||
|
" total_loss = 0\n",
|
||||||
|
" for batch_inputs, batch_targets in tqdm(dataloader):\n",
|
||||||
|
" batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" model.zero_grad()\n",
|
||||||
|
" output = model(batch_inputs)\n",
|
||||||
|
"\n",
|
||||||
|
" loss = criterion(output, batch_targets.view(-1))\n",
|
||||||
|
" total_loss += loss.item()\n",
|
||||||
|
"\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## test the model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 157,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def predict(left_context, right_context):\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" context = left_context + right_context\n",
|
||||||
|
" test_context_idxs = torch.tensor([[tokenize(x) for x in context]], device=device)\n",
|
||||||
|
" output = model(test_context_idxs)\n",
|
||||||
|
" top_predicted_scores, top_predicted_indices = torch.topk(output, vocab_size)\n",
|
||||||
|
"\n",
|
||||||
|
" top_predicted_scores = np.array(top_predicted_scores[0].cpu())\n",
|
||||||
|
" top_predicted_indices = top_predicted_indices[0]\n",
|
||||||
|
"\n",
|
||||||
|
" top_predicted_scores = top_predicted_scores[top_predicted_scores > 0]\n",
|
||||||
|
" top_predicted_indices = top_predicted_indices[:len(top_predicted_scores)]\n",
|
||||||
|
"\n",
|
||||||
|
" total_score = np.sum([score for score in top_predicted_scores[:20]])\n",
|
||||||
|
"\n",
|
||||||
|
" predictions = list(zip(top_predicted_scores, top_predicted_indices))\n",
|
||||||
|
" predictions = [(round(float(score), 2), detokenize(idx)) for score, idx in predictions[:10]]\n",
|
||||||
|
" \n",
|
||||||
|
" words = [word for _, word in predictions]\n",
|
||||||
|
" scores = [round(score/total_score, 2) for score, _ in predictions]\n",
|
||||||
|
"\n",
|
||||||
|
" remaining_score = round(1.0 - np.sum(scores), 2)\n",
|
||||||
|
"\n",
|
||||||
|
" predictions = ' '.join([f\"{word}:{score}\" for score, word in zip(scores, words)]) + ' :' + str(remaining_score)\n",
|
||||||
|
"\n",
|
||||||
|
" return predictions"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 158,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"a:0.07 the:0.06 The:0.06 A:0.06 my:0.05 ,:0.05 to:0.05 and:0.05 -:0.05 an:0.05 :0.45\n",
|
||||||
|
"of:0.07 on:0.06 and:0.06 be:0.06 for:0.05 in:0.05 school,:0.05 ol:0.05 it:0.05 the:0.05 :0.45\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(predict([\"came\", \"fiom\"], [\"26th\", \"place\"]))\n",
|
||||||
|
"print(predict([\"will\", \"buy\"], [\"telephone\", \"and\"]))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Generate result for dev dataset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 159,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset_dir = os.path.join('..', 'dev-0', 'in.tsv.xz')\n",
|
||||||
|
"output_dir = os.path.join('..', 'dev-0', 'out.tsv')\n",
|
||||||
|
"\n",
|
||||||
|
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)\n",
|
||||||
|
"df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 160,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 10519/10519 [00:51<00:00, 203.44it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"final = \"\"\n",
|
||||||
|
"\n",
|
||||||
|
"for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):\n",
|
||||||
|
" left_context = re.split(r\"\\s+\", row['LeftContext'].strip())[-left_tokens:]\n",
|
||||||
|
" right_context = re.split(r\"\\s+\", row['RightContext'].strip())[:right_tokens]\n",
|
||||||
|
"\n",
|
||||||
|
" final += predict(left_context, right_context) + '\\n'\n",
|
||||||
|
"\n",
|
||||||
|
"with open(output_dir, 'w', encoding=\"UTF-8\") as f:\n",
|
||||||
|
" f.write(final)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "p311-cu121",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user