Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
0c45624062 | ||
a2d183f2e3 | |||
|
337d2ffc42 | ||
|
d380959afc |
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
248
run.py
Normal file
248
run.py
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
# %% [markdown]
|
||||||
|
# # <b>Trigram</b> neural network model for gap fill task
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Import required packages
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from tqdm import tqdm
|
||||||
|
import re
|
||||||
|
import nltk
|
||||||
|
import os
|
||||||
|
import csv
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
from bidict import bidict
|
||||||
|
import math
|
||||||
|
from sklearn.utils import shuffle
|
||||||
|
from collections import Counter
|
||||||
|
import random
|
||||||
|
|
||||||
|
# %%
|
||||||
|
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
||||||
|
os.environ['TORCH_USE_CUDA_DSA'] = '1'
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Global configuration variables
|
||||||
|
|
||||||
|
# %%
|
||||||
|
vocab_size = 60_000
|
||||||
|
batch_size = 64
|
||||||
|
embedding_dim = 64
|
||||||
|
hidden_dim = 1024
|
||||||
|
learning_rate = 0.001
|
||||||
|
epochs = 20
|
||||||
|
|
||||||
|
output_size = vocab_size
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
device = torch.device("cpu")
|
||||||
|
print(device)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Load train data corpus
|
||||||
|
|
||||||
|
# %%
|
||||||
|
dataset_dir = os.path.join('..', 'train', 'in.tsv.xz')
|
||||||
|
expected_dir = os.path.join('..', 'train', 'expected.tsv')
|
||||||
|
|
||||||
|
df = pd.read_csv(dataset_dir, sep='\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)
|
||||||
|
expected_df = pd.read_csv(expected_dir, sep='\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)
|
||||||
|
|
||||||
|
|
||||||
|
input_corpus = []
|
||||||
|
target_corpus = []
|
||||||
|
|
||||||
|
left_tokens = 1
|
||||||
|
right_tokens = 1
|
||||||
|
|
||||||
|
for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):
|
||||||
|
df = df.replace(r'\\r+|\\n+|\\t+', ' ', regex=True)
|
||||||
|
|
||||||
|
for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):
|
||||||
|
target_corpus.append([str(word).strip()])
|
||||||
|
input_corpus.append(re.split(r"\s+", left_context.strip())[-left_tokens:] + re.split(r"\s+", right_context.strip())[:right_tokens])
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Create dictionaries for mapping words to indices
|
||||||
|
|
||||||
|
# %%
|
||||||
|
def flatten(matrix):
|
||||||
|
flat_list = []
|
||||||
|
for row in matrix:
|
||||||
|
flat_list += row
|
||||||
|
return flat_list
|
||||||
|
|
||||||
|
# %%
|
||||||
|
word_to_ix = bidict({})
|
||||||
|
words_corpus = flatten(input_corpus) + flatten(target_corpus)
|
||||||
|
|
||||||
|
counts = Counter(words_corpus)
|
||||||
|
|
||||||
|
for word, _ in tqdm(counts.most_common(vocab_size - 1)):
|
||||||
|
if word not in word_to_ix:
|
||||||
|
word_to_ix[word] = len(word_to_ix) + 1
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Tokenize entire corpus
|
||||||
|
|
||||||
|
# %%
|
||||||
|
def tokenize(w):
|
||||||
|
if w in word_to_ix:
|
||||||
|
return word_to_ix[w]
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
tokenized_input_corpus = []
|
||||||
|
tokenized_target_corpus = []
|
||||||
|
|
||||||
|
for words in tqdm(input_corpus):
|
||||||
|
tokenized_input_corpus.append([tokenize(word) for word in words])
|
||||||
|
|
||||||
|
for words in tqdm(target_corpus):
|
||||||
|
tokenized_target_corpus.append([tokenize(word) for word in words])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
tokenized_input_corpus, tokenized_target_corpus = shuffle(tokenized_input_corpus, tokenized_target_corpus)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Create dataset
|
||||||
|
|
||||||
|
# %%
|
||||||
|
indices = np.nonzero(np.array(tokenized_target_corpus).flatten())
|
||||||
|
|
||||||
|
tokenized_input_corpus = np.take(tokenized_input_corpus, indices, axis=0)
|
||||||
|
tokenized_target_corpus = np.take(tokenized_target_corpus, indices, axis=0)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
input_corpus_tensor = torch.flatten(torch.tensor(tokenized_input_corpus, dtype=torch.long, device=device), end_dim=-2)
|
||||||
|
target_corpus_tensor = torch.flatten(torch.tensor(tokenized_target_corpus, dtype=torch.long, device=device)).reshape(-1, 1)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
print(input_corpus_tensor.size())
|
||||||
|
print(target_corpus_tensor.size())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
random_index = random.randint(0, len(input_corpus_tensor) - 1)
|
||||||
|
|
||||||
|
# Get random element from input corpus
|
||||||
|
random_input_element = input_corpus_tensor[random_index]
|
||||||
|
|
||||||
|
# Get corresponding element from target corpus
|
||||||
|
random_target_element = target_corpus_tensor[random_index]
|
||||||
|
|
||||||
|
print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_input_element])
|
||||||
|
print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_target_element])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
dataset = TensorDataset(input_corpus_tensor[:10_000], target_corpus_tensor[:10_000])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Define the trigram neural network model
|
||||||
|
|
||||||
|
# %%
|
||||||
|
class TrigramNN(nn.Module):
|
||||||
|
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):
|
||||||
|
super(TrigramNN, self).__init__()
|
||||||
|
self.embedding = nn.Embedding(vocab_size, embedding_dim)
|
||||||
|
self.linear1 = nn.Linear(embedding_dim * (left_tokens + right_tokens), hidden_dim)
|
||||||
|
self.linear2 = nn.Linear(hidden_dim, output_size)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
out = self.embedding(inputs)
|
||||||
|
out = out.view(inputs.size(0), -1)
|
||||||
|
out = torch.softmax(self.linear1(out), dim=1)
|
||||||
|
out = self.linear2(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Initialize the model, loss function, and optimizer
|
||||||
|
|
||||||
|
# %%
|
||||||
|
model = TrigramNN(vocab_size, embedding_dim, hidden_dim, output_size)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Training loop
|
||||||
|
|
||||||
|
# %%
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
total_loss = 0
|
||||||
|
for batch_inputs, batch_targets in tqdm(dataloader):
|
||||||
|
batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
|
||||||
|
|
||||||
|
model.zero_grad()
|
||||||
|
output = model(batch_inputs)
|
||||||
|
|
||||||
|
loss = criterion(output, batch_targets.view(-1))
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}")
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Write function to convert index to word
|
||||||
|
|
||||||
|
# %%
|
||||||
|
def idx_to_word(idx):
|
||||||
|
idx = int(idx)
|
||||||
|
if idx not in word_to_ix.inverse:
|
||||||
|
return '<UNK>'
|
||||||
|
return word_to_ix.inverse[idx]
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## test the model
|
||||||
|
|
||||||
|
# %%
|
||||||
|
def predict(left_context, right_context):
|
||||||
|
with torch.no_grad():
|
||||||
|
context = left_context + right_context
|
||||||
|
test_context_idxs = torch.tensor([[tokenize(x) for x in context]], device=device)
|
||||||
|
output = model(test_context_idxs)
|
||||||
|
top_predicted_scores, top_predicted_indices = torch.topk(output, 5)
|
||||||
|
predictions = list(zip(top_predicted_scores[0], top_predicted_indices[0]))
|
||||||
|
predictions = [(float(score), idx_to_word(idx)) for score, idx in predictions]
|
||||||
|
total_score = np.sum([score for score, _ in predictions])
|
||||||
|
predictions = ' '.join([f"{word}:{score}" for score, word in predictions]) + ' :' + str(1.0 - total_score)
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
# %%
|
||||||
|
print(predict(["came", "fiom"], []))
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# # Generate result for dev dataset
|
||||||
|
|
||||||
|
# %%
|
||||||
|
dataset_dir = os.path.join('..', 'dev-0', 'in.tsv.xz')
|
||||||
|
output_dir = os.path.join('..', 'dev-0', 'out.tsv')
|
||||||
|
|
||||||
|
df = pd.read_csv(dataset_dir, sep='\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)
|
||||||
|
df = df.replace(r'\\r+|\\n+|\\t+', ' ', regex=True)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
final = ""
|
||||||
|
|
||||||
|
for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):
|
||||||
|
left_context = re.split(r"\s+", row['LeftContext'].strip())[-left_tokens:]
|
||||||
|
right_context = re.split(r"\s+", row['RightContext'].strip())[:right_tokens]
|
||||||
|
|
||||||
|
final += predict(left_context, right_context) + '\n'
|
||||||
|
|
||||||
|
with open(output_dir, 'w', encoding="UTF-8") as f:
|
||||||
|
f.write(final)
|
||||||
|
|
||||||
|
|
824
src/07_trigram_neural.ipynb
Normal file
824
src/07_trigram_neural.ipynb
Normal file
@ -0,0 +1,824 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# <b>Trigram</b> neural network model for gap fill task"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Import required packages"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 414,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"import re\n",
|
||||||
|
"import nltk\n",
|
||||||
|
"import os\n",
|
||||||
|
"import csv\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from torch.utils.data import DataLoader, TensorDataset\n",
|
||||||
|
"from bidict import bidict\n",
|
||||||
|
"import math\n",
|
||||||
|
"from sklearn.utils import shuffle\n",
|
||||||
|
"from collections import Counter\n",
|
||||||
|
"import random"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 415,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
|
||||||
|
"os.environ['TORCH_USE_CUDA_DSA'] = '1'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Global configuration variables"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 416,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"vocab_size = 60_000\n",
|
||||||
|
"batch_size = 64\n",
|
||||||
|
"embedding_dim = 64\n",
|
||||||
|
"hidden_dim = 1024\n",
|
||||||
|
"learning_rate = 0.001\n",
|
||||||
|
"epochs = 20\n",
|
||||||
|
"\n",
|
||||||
|
"output_size = vocab_size"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 417,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"cpu\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
|
"device = torch.device(\"cpu\")\n",
|
||||||
|
"print(device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Load train data corpus"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 418,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/433 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 433/433 [01:03<00:00, 6.77it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"dataset_dir = os.path.join('..', 'train', 'in.tsv.xz')\n",
|
||||||
|
"expected_dir = os.path.join('..', 'train', 'expected.tsv')\n",
|
||||||
|
"\n",
|
||||||
|
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
"expected_df = pd.read_csv(expected_dir, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"input_corpus = []\n",
|
||||||
|
"target_corpus = []\n",
|
||||||
|
"\n",
|
||||||
|
"left_tokens = 1\n",
|
||||||
|
"right_tokens = 1\n",
|
||||||
|
"\n",
|
||||||
|
"for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):\n",
|
||||||
|
" df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)\n",
|
||||||
|
" \n",
|
||||||
|
" for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):\n",
|
||||||
|
" target_corpus.append([str(word).strip()])\n",
|
||||||
|
" input_corpus.append(re.split(r\"\\s+\", left_context.strip())[-left_tokens:] + re.split(r\"\\s+\", right_context.strip())[:right_tokens])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Create dictionaries for mapping words to indices"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 419,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def flatten(matrix):\n",
|
||||||
|
" flat_list = []\n",
|
||||||
|
" for row in matrix:\n",
|
||||||
|
" flat_list += row\n",
|
||||||
|
" return flat_list"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 420,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 59999/59999 [00:00<00:00, 131034.12it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"word_to_ix = bidict({})\n",
|
||||||
|
"words_corpus = flatten(input_corpus) + flatten(target_corpus)\n",
|
||||||
|
"\n",
|
||||||
|
"counts = Counter(words_corpus)\n",
|
||||||
|
"\n",
|
||||||
|
"for word, _ in tqdm(counts.most_common(vocab_size - 1)):\n",
|
||||||
|
" if word not in word_to_ix:\n",
|
||||||
|
" word_to_ix[word] = len(word_to_ix) + 1"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Tokenize entire corpus"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 421,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 432022/432022 [00:01<00:00, 255044.26it/s]\n",
|
||||||
|
"100%|██████████| 432022/432022 [00:01<00:00, 348618.53it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"def tokenize(w):\n",
|
||||||
|
" if w in word_to_ix:\n",
|
||||||
|
" return word_to_ix[w]\n",
|
||||||
|
" else:\n",
|
||||||
|
" return 0\n",
|
||||||
|
"\n",
|
||||||
|
"tokenized_input_corpus = []\n",
|
||||||
|
"tokenized_target_corpus = []\n",
|
||||||
|
"\n",
|
||||||
|
"for words in tqdm(input_corpus):\n",
|
||||||
|
" tokenized_input_corpus.append([tokenize(word) for word in words])\n",
|
||||||
|
"\n",
|
||||||
|
"for words in tqdm(target_corpus):\n",
|
||||||
|
" tokenized_target_corpus.append([tokenize(word) for word in words])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 422,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tokenized_input_corpus, tokenized_target_corpus = shuffle(tokenized_input_corpus, tokenized_target_corpus)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Create dataset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 423,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"indices = np.nonzero(np.array(tokenized_target_corpus).flatten())\n",
|
||||||
|
"\n",
|
||||||
|
"tokenized_input_corpus = np.take(tokenized_input_corpus, indices, axis=0)\n",
|
||||||
|
"tokenized_target_corpus = np.take(tokenized_target_corpus, indices, axis=0)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 424,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"input_corpus_tensor = torch.flatten(torch.tensor(tokenized_input_corpus, dtype=torch.long, device=device), end_dim=-2)\n",
|
||||||
|
"target_corpus_tensor = torch.flatten(torch.tensor(tokenized_target_corpus, dtype=torch.long, device=device)).reshape(-1, 1)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 425,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"torch.Size([389892, 2])\n",
|
||||||
|
"torch.Size([389892, 1])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(input_corpus_tensor.size())\n",
|
||||||
|
"print(target_corpus_tensor.size())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 426,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['end', 'the']\n",
|
||||||
|
"['of']\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"random_index = random.randint(0, len(input_corpus_tensor) - 1)\n",
|
||||||
|
"\n",
|
||||||
|
"# Get random element from input corpus\n",
|
||||||
|
"random_input_element = input_corpus_tensor[random_index]\n",
|
||||||
|
"\n",
|
||||||
|
"# Get corresponding element from target corpus\n",
|
||||||
|
"random_target_element = target_corpus_tensor[random_index]\n",
|
||||||
|
"\n",
|
||||||
|
"print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_input_element])\n",
|
||||||
|
"print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_target_element])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 427,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset = TensorDataset(input_corpus_tensor[:10_000], target_corpus_tensor[:10_000])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 428,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Define the trigram neural network model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 429,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class TrigramNN(nn.Module):\n",
|
||||||
|
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):\n",
|
||||||
|
" super(TrigramNN, self).__init__()\n",
|
||||||
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
||||||
|
" self.linear1 = nn.Linear(embedding_dim * (left_tokens + right_tokens), hidden_dim)\n",
|
||||||
|
" self.linear2 = nn.Linear(hidden_dim, output_size)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, inputs):\n",
|
||||||
|
" out = self.embedding(inputs)\n",
|
||||||
|
" out = out.view(inputs.size(0), -1)\n",
|
||||||
|
" out = torch.softmax(self.linear1(out), dim=1)\n",
|
||||||
|
" out = self.linear2(out)\n",
|
||||||
|
" return out"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Initialize the model, loss function, and optimizer"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 430,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = TrigramNN(vocab_size, embedding_dim, hidden_dim, output_size)\n",
|
||||||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||||||
|
"optimizer = optim.SGD(model.parameters(), lr=learning_rate)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Training loop"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 431,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.81it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 1, Loss: 10.999195001687214\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.86it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 2, Loss: 10.997720451112006\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.88it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 3, Loss: 10.99624701214444\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.17it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 4, Loss: 10.994744385883306\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.21it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 5, Loss: 10.993266263585182\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.22it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 6, Loss: 10.991843545512788\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:31<00:00, 4.92it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 7, Loss: 10.990350304135852\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:28<00:00, 5.60it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 8, Loss: 10.988877800619527\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.81it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 9, Loss: 10.987337306806236\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:29<00:00, 5.32it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 10, Loss: 10.985873113012618\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.13it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 11, Loss: 10.98438450637137\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:28<00:00, 5.45it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 12, Loss: 10.9829175548189\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.11it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 13, Loss: 10.981461263765954\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.08it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 14, Loss: 10.97996347269435\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.22it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 15, Loss: 10.978485234983408\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:31<00:00, 4.98it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 16, Loss: 10.977057912547117\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.23it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 17, Loss: 10.97553843601494\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:29<00:00, 5.34it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 18, Loss: 10.974108489455691\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.82it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 19, Loss: 10.972679308265638\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.23it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 20, Loss: 10.971182902147815\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(epochs):\n",
|
||||||
|
" total_loss = 0\n",
|
||||||
|
" for batch_inputs, batch_targets in tqdm(dataloader):\n",
|
||||||
|
" batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" model.zero_grad()\n",
|
||||||
|
" output = model(batch_inputs)\n",
|
||||||
|
"\n",
|
||||||
|
" loss = criterion(output, batch_targets.view(-1))\n",
|
||||||
|
" total_loss += loss.item()\n",
|
||||||
|
"\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Write function to convert index to word"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 432,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def idx_to_word(idx):\n",
|
||||||
|
" idx = int(idx)\n",
|
||||||
|
" if idx not in word_to_ix.inverse:\n",
|
||||||
|
" return '<UNK>'\n",
|
||||||
|
" return word_to_ix.inverse[idx]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## test the model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 433,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def predict(left_context, right_context):\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" context = left_context + right_context\n",
|
||||||
|
" test_context_idxs = torch.tensor([[tokenize(x) for x in context]], device=device)\n",
|
||||||
|
" output = model(test_context_idxs)\n",
|
||||||
|
" top_predicted_scores, top_predicted_indices = torch.topk(output, 5)\n",
|
||||||
|
" predictions = list(zip(top_predicted_scores[0], top_predicted_indices[0]))\n",
|
||||||
|
" predictions = [(float(score), idx_to_word(idx)) for score, idx in predictions]\n",
|
||||||
|
" total_score = np.sum([score for score, _ in predictions])\n",
|
||||||
|
" predictions = ' '.join([f\"{word}:{score}\" for score, word in predictions]) + ' :' + str(1.0 - total_score)\n",
|
||||||
|
" return predictions"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 434,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"the:0.210836723446846 of:0.13834647834300995 and:0.11819174885749817 to:0.09819918870925903 a:0.0662047415971756 :0.36822111904621124\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(predict([\"came\", \"fiom\"], []))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Generate result for dev dataset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 435,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset_dir = os.path.join('..', 'dev-0', 'in.tsv.xz')\n",
|
||||||
|
"output_dir = os.path.join('..', 'dev-0', 'out.tsv')\n",
|
||||||
|
"\n",
|
||||||
|
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)\n",
|
||||||
|
"df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 436,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 10519/10519 [02:25<00:00, 72.19it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"final = \"\"\n",
|
||||||
|
"\n",
|
||||||
|
"for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):\n",
|
||||||
|
" left_context = re.split(r\"\\s+\", row['LeftContext'].strip())[-left_tokens:]\n",
|
||||||
|
" right_context = re.split(r\"\\s+\", row['RightContext'].strip())[:right_tokens]\n",
|
||||||
|
"\n",
|
||||||
|
" final += predict(left_context, right_context) + '\\n'\n",
|
||||||
|
"\n",
|
||||||
|
"with open(output_dir, 'w', encoding=\"UTF-8\") as f:\n",
|
||||||
|
" f.write(final)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "p311-cu121",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
739
src/08_word2vec.ipynb
Normal file
739
src/08_word2vec.ipynb
Normal file
@ -0,0 +1,739 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# <b>Trigram</b> neural network model for gap fill task"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Import required packages"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
|
||||||
|
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
|
||||||
|
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
|
||||||
|
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
|
||||||
|
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
|
||||||
|
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
|
||||||
|
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
|
||||||
|
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"import re\n",
|
||||||
|
"import os\n",
|
||||||
|
"import csv\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from torch.utils.data import DataLoader, TensorDataset\n",
|
||||||
|
"from sklearn.utils import shuffle\n",
|
||||||
|
"import random\n",
|
||||||
|
"from torchtext.vocab import build_vocab_from_iterator"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
|
||||||
|
"os.environ['TORCH_USE_CUDA_DSA'] = '1'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Global configuration variables"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"vocab_size = 40_000\n",
|
||||||
|
"batch_size = 256\n",
|
||||||
|
"embedding_dim = 128\n",
|
||||||
|
"hidden_dim = 1024\n",
|
||||||
|
"learning_rate = 0.001\n",
|
||||||
|
"epochs = 5\n",
|
||||||
|
"\n",
|
||||||
|
"output_size = vocab_size"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"cuda\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
|
"# device = torch.device(\"cpu\")\n",
|
||||||
|
"print(device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build vocabulary"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def get_word_lines_from_dataset():\n",
|
||||||
|
" dataset_dir = os.path.join('..', 'train', 'in.tsv.xz')\n",
|
||||||
|
" expected_dir = os.path.join('..', 'train', 'expected.tsv')\n",
|
||||||
|
"\n",
|
||||||
|
" df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
" expected_df = pd.read_csv(expected_dir, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
"\n",
|
||||||
|
" for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):\n",
|
||||||
|
" df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)\n",
|
||||||
|
" \n",
|
||||||
|
" for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):\n",
|
||||||
|
" yield re.split(r\"\\s+\", left_context.strip()) + [str(word).strip()] + re.split(r\"\\s+\", right_context.strip())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/433 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 433/433 [01:34<00:00, 4.57it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"vocab = build_vocab_from_iterator(\n",
|
||||||
|
" get_word_lines_from_dataset(),\n",
|
||||||
|
" max_tokens = vocab_size,\n",
|
||||||
|
" specials = ['<unk>'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"vocab.set_default_index(vocab['<unk>'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"['<unk>', 'the', 'of', 'houses']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"vocab.lookup_tokens([0, 1, 2, 1245])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/433 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 433/433 [01:08<00:00, 6.31it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"dataset_dir = os.path.join('..', 'train', 'in.tsv.xz')\n",
|
||||||
|
"expected_dir = os.path.join('..', 'train', 'expected.tsv')\n",
|
||||||
|
"\n",
|
||||||
|
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
"expected_df = pd.read_csv(expected_dir, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"input_corpus = []\n",
|
||||||
|
"target_corpus = []\n",
|
||||||
|
"\n",
|
||||||
|
"left_tokens = 2\n",
|
||||||
|
"right_tokens = 2\n",
|
||||||
|
"\n",
|
||||||
|
"for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):\n",
|
||||||
|
" df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)\n",
|
||||||
|
" \n",
|
||||||
|
" for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):\n",
|
||||||
|
" target_corpus.append([str(word).strip()])\n",
|
||||||
|
" input_corpus.append(re.split(r\"\\s+\", left_context.strip())[-left_tokens:] + re.split(r\"\\s+\", right_context.strip())[:right_tokens])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Tokenize entire corpus"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 432022/432022 [00:02<00:00, 172153.99it/s]\n",
|
||||||
|
"100%|██████████| 432022/432022 [00:01<00:00, 316818.82it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"def tokenize(w):\n",
|
||||||
|
" return vocab[w]\n",
|
||||||
|
" \n",
|
||||||
|
"def detokenize(t):\n",
|
||||||
|
" return vocab.lookup_tokens([t])[0]\n",
|
||||||
|
"\n",
|
||||||
|
"tokenized_input_corpus = []\n",
|
||||||
|
"tokenized_target_corpus = []\n",
|
||||||
|
"\n",
|
||||||
|
"for words in tqdm(input_corpus):\n",
|
||||||
|
" tokenized_input_corpus.append([tokenize(word) for word in words])\n",
|
||||||
|
"\n",
|
||||||
|
"for words in tqdm(target_corpus):\n",
|
||||||
|
" tokenized_target_corpus.append([tokenize(word) for word in words])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tokenized_input_corpus, tokenized_target_corpus = shuffle(tokenized_input_corpus, tokenized_target_corpus)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Create dataset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"indices = np.nonzero(np.array(tokenized_target_corpus).flatten())\n",
|
||||||
|
"\n",
|
||||||
|
"tokenized_input_corpus = np.take(tokenized_input_corpus, indices, axis=0)\n",
|
||||||
|
"tokenized_target_corpus = np.take(tokenized_target_corpus, indices, axis=0)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"input_corpus_tensor = torch.flatten(torch.tensor(tokenized_input_corpus, dtype=torch.long, device=device), end_dim=-2)\n",
|
||||||
|
"target_corpus_tensor = torch.flatten(torch.tensor(tokenized_target_corpus, dtype=torch.long, device=device)).reshape(-1, 1)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"torch.Size([378666, 4])\n",
|
||||||
|
"torch.Size([378666, 1])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(input_corpus_tensor.size())\n",
|
||||||
|
"print(target_corpus_tensor.size())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['a', 'charming', 'woman', 'would']\n",
|
||||||
|
"['young']\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"random_index = random.randint(0, len(input_corpus_tensor) - 1)\n",
|
||||||
|
"\n",
|
||||||
|
"# Get random element from input corpus\n",
|
||||||
|
"random_input_element = input_corpus_tensor[random_index]\n",
|
||||||
|
"\n",
|
||||||
|
"# Get corresponding element from target corpus\n",
|
||||||
|
"random_target_element = target_corpus_tensor[random_index]\n",
|
||||||
|
"\n",
|
||||||
|
"print([detokenize(int(idx)) for idx in random_input_element])\n",
|
||||||
|
"print([detokenize(int(idx)) for idx in random_target_element])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset = TensorDataset(input_corpus_tensor, target_corpus_tensor)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Define the trigram neural network model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class NGramNN(nn.Module):\n",
|
||||||
|
" def __init__(self, vocab_size, embedding_dim):\n",
|
||||||
|
" super(NGramNN, self).__init__()\n",
|
||||||
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
||||||
|
" self.linear = nn.Linear(embedding_dim * (left_tokens + right_tokens), vocab_size)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, inputs):\n",
|
||||||
|
" out = self.embedding(inputs)\n",
|
||||||
|
" out = out.view(inputs.size(0), -1)\n",
|
||||||
|
" out = self.linear(out)\n",
|
||||||
|
" # out = torch.softmax(out, dim=1)\n",
|
||||||
|
" return out"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Initialize the model, loss function, and optimizer"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||||
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model = NGramNN(vocab_size, embedding_dim)\n",
|
||||||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||||||
|
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Training loop"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/1480 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 45.24it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 1, Loss: 7.488175111847955\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 45.62it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 2, Loss: 5.083534079629022\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 44.91it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 3, Loss: 3.8214319522316393\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:33<00:00, 44.65it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 4, Loss: 3.1464366490776476\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 45.94it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 5, Loss: 2.743303858589482\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 46.07it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 6, Loss: 2.456264268949225\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 45.59it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 7, Loss: 2.2358319317972337\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 46.15it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 8, Loss: 2.0536118873067806\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 45.90it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 9, Loss: 1.89841981795994\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 1480/1480 [00:32<00:00, 44.89it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 10, Loss: 1.7637179977990485\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(epochs):\n",
|
||||||
|
" total_loss = 0\n",
|
||||||
|
" for batch_inputs, batch_targets in tqdm(dataloader):\n",
|
||||||
|
" batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" model.zero_grad()\n",
|
||||||
|
" output = model(batch_inputs)\n",
|
||||||
|
"\n",
|
||||||
|
" loss = criterion(output, batch_targets.view(-1))\n",
|
||||||
|
" total_loss += loss.item()\n",
|
||||||
|
"\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## test the model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def predict(left_context, right_context):\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" context = left_context + right_context\n",
|
||||||
|
" test_context_idxs = torch.tensor([[tokenize(x) for x in context]], device=device)\n",
|
||||||
|
" output = model(test_context_idxs)\n",
|
||||||
|
" top_predicted_scores, top_predicted_indices = torch.topk(output, vocab_size)\n",
|
||||||
|
"\n",
|
||||||
|
" top_predicted_scores = np.array(top_predicted_scores[0].cpu())\n",
|
||||||
|
" top_predicted_indices = top_predicted_indices[0]\n",
|
||||||
|
"\n",
|
||||||
|
" top_predicted_scores = top_predicted_scores[top_predicted_scores > 0]\n",
|
||||||
|
" top_predicted_indices = top_predicted_indices[:len(top_predicted_scores)]\n",
|
||||||
|
"\n",
|
||||||
|
" total_score = np.sum([score for score in top_predicted_scores[:20]])\n",
|
||||||
|
"\n",
|
||||||
|
" predictions = list(zip(top_predicted_scores, top_predicted_indices))\n",
|
||||||
|
" predictions = [(round(float(score), 2), detokenize(idx)) for score, idx in predictions[:10]]\n",
|
||||||
|
" \n",
|
||||||
|
" words = [word for _, word in predictions]\n",
|
||||||
|
" scores = [round(score/total_score, 2) for score, _ in predictions]\n",
|
||||||
|
"\n",
|
||||||
|
" remaining_score = round(1.0 - np.sum(scores), 2)\n",
|
||||||
|
"\n",
|
||||||
|
" predictions = ' '.join([f\"{word}:{score}\" for score, word in zip(scores, words)]) + ' :' + str(remaining_score)\n",
|
||||||
|
"\n",
|
||||||
|
" return predictions"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"the:0.07 his:0.06 a:0.06 their:0.06 John:0.06 tho:0.05 he:0.05 its:0.05 and:0.05 my:0.05 :0.44\n",
|
||||||
|
"to:0.06 a:0.06 the:0.06 and:0.06 in:0.05 when:0.05 of:0.05 up:0.05 such:0.05 for:0.05 :0.46\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(predict([\"came\", \"fiom\"], [\"26th\", \"place\"]))\n",
|
||||||
|
"print(predict([\"will\", \"buy\"], [\"telephone\", \"and\"]))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Generate result for dev dataset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset_dir = os.path.join('..', 'dev-0', 'in.tsv.xz')\n",
|
||||||
|
"output_dir = os.path.join('..', 'dev-0', 'out.tsv')\n",
|
||||||
|
"\n",
|
||||||
|
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)\n",
|
||||||
|
"df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 10519/10519 [00:50<00:00, 206.40it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"final = \"\"\n",
|
||||||
|
"\n",
|
||||||
|
"for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):\n",
|
||||||
|
" left_context = re.split(r\"\\s+\", row['LeftContext'].strip())[-left_tokens:]\n",
|
||||||
|
" right_context = re.split(r\"\\s+\", row['RightContext'].strip())[:right_tokens]\n",
|
||||||
|
"\n",
|
||||||
|
" final += predict(left_context, right_context) + '\\n'\n",
|
||||||
|
"\n",
|
||||||
|
"with open(output_dir, 'w', encoding=\"UTF-8\") as f:\n",
|
||||||
|
" f.write(final)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "p311-cu121",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user