Compare commits

2 commits: main ... neural-net

| Author | SHA1 | Date |
| --- | --- | --- |
|  | 337d2ffc42 |  |
|  | d380959afc |  |
dev-0/out.tsv (21038 changes)

File diff suppressed because it is too large.
run.py (normal file, 248 additions)
@@ -0,0 +1,248 @@
# %% [markdown]
# # <b>Trigram</b> neural network model for gap fill task

# %% [markdown]
# ## Import required packages

# %%
from tqdm import tqdm
import re
import nltk
import os
import csv
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import sys
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from bidict import bidict
import math
from sklearn.utils import shuffle
from collections import Counter
import random

# %%
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

# %% [markdown]
# ## Global configuration variables

# %%
vocab_size = 60_000
batch_size = 64
embedding_dim = 64
hidden_dim = 1024
learning_rate = 0.001
epochs = 20

output_size = vocab_size

# %%
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(device)

# %% [markdown]
# ## Load train data corpus

# %%
dataset_dir = os.path.join('..', 'train', 'in.tsv.xz')
expected_dir = os.path.join('..', 'train', 'expected.tsv')

df = pd.read_csv(dataset_dir, sep='\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)
expected_df = pd.read_csv(expected_dir, sep='\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)


input_corpus = []
target_corpus = []

left_tokens = 1
right_tokens = 1

for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):
    df = df.replace(r'\\r+|\\n+|\\t+', ' ', regex=True)

    for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):
        target_corpus.append([str(word).strip()])
        input_corpus.append(re.split(r"\s+", left_context.strip())[-left_tokens:] + re.split(r"\s+", right_context.strip())[:right_tokens])
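
# %% [markdown]
# Each training example pairs the one token to the left and the one token to the right
# of the gap with the gap word itself (hence "trigram"). As a rough sketch, for a
# fragment like "came from the" with "from" removed, the corresponding entries should
# be input_corpus[i] == ['came', 'the'] and target_corpus[i] == ['from'].

# %%
# Illustrative check only -- assumes the loading loop above has finished.
for ctx, tgt in zip(input_corpus[:3], target_corpus[:3]):
    print(ctx, '->', tgt)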
# %% [markdown]
# ## Create dictionaries for mapping words to indices

# %%
def flatten(matrix):
    flat_list = []
    for row in matrix:
        flat_list += row
    return flat_list

# %%
word_to_ix = bidict({})
words_corpus = flatten(input_corpus) + flatten(target_corpus)

counts = Counter(words_corpus)

for word, _ in tqdm(counts.most_common(vocab_size - 1)):
    if word not in word_to_ix:
        word_to_ix[word] = len(word_to_ix) + 1
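
# %% [markdown]
# Only the vocab_size - 1 most frequent words receive an index, starting from 1,
# so index 0 stays free for out-of-vocabulary words (see tokenize below).
# A small illustration, assuming the mapping above has been built:

# %%
# The most frequent word gets index 1, the next one index 2, and so on.
print(len(word_to_ix))               # at most vocab_size - 1 entries
print(list(word_to_ix.items())[:5])  # e.g. [('the', 1), ('of', 2), ...] on this corpus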
# %% [markdown]
# ## Tokenize entire corpus

# %%
def tokenize(w):
    if w in word_to_ix:
        return word_to_ix[w]
    else:
        return 0

tokenized_input_corpus = []
tokenized_target_corpus = []

for words in tqdm(input_corpus):
    tokenized_input_corpus.append([tokenize(word) for word in words])

for words in tqdm(target_corpus):
    tokenized_target_corpus.append([tokenize(word) for word in words])
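
# %% [markdown]
# tokenize maps a known word to its index and anything else to 0 (the <UNK> slot).
# A quick illustration, assuming the vocabulary built above:

# %%
print(tokenize('the'))      # small index for a frequent word
print(tokenize('qwxzzyq'))  # unseen token -> 0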
# %%
tokenized_input_corpus, tokenized_target_corpus = shuffle(tokenized_input_corpus, tokenized_target_corpus)

# %% [markdown]
# ## Create dataset

# %%
indices = np.nonzero(np.array(tokenized_target_corpus).flatten())

tokenized_input_corpus = np.take(tokenized_input_corpus, indices, axis=0)
tokenized_target_corpus = np.take(tokenized_target_corpus, indices, axis=0)
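
# %% [markdown]
# np.nonzero keeps only the examples whose target word made it into the vocabulary
# (a tokenized target of 0 means <UNK>), so gaps with rare words are dropped before
# training. The same idea on toy data:

# %%
# e.g. targets [[5], [0], [9]] -> (array([0, 2]),): rows 0 and 2 survive, row 1 is dropped.
print(np.nonzero(np.array([[5], [0], [9]]).flatten()))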
# %%
input_corpus_tensor = torch.flatten(torch.tensor(tokenized_input_corpus, dtype=torch.long, device=device), end_dim=-2)
target_corpus_tensor = torch.flatten(torch.tensor(tokenized_target_corpus, dtype=torch.long, device=device)).reshape(-1, 1)

# %%
print(input_corpus_tensor.size())
print(target_corpus_tensor.size())

# %%
random_index = random.randint(0, len(input_corpus_tensor) - 1)

# Get random element from input corpus
random_input_element = input_corpus_tensor[random_index]

# Get corresponding element from target corpus
random_target_element = target_corpus_tensor[random_index]

print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_input_element])
print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_target_element])

# %%
dataset = TensorDataset(input_corpus_tensor[:10_000], target_corpus_tensor[:10_000])

# %%
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# %% [markdown]
# ## Define the trigram neural network model

# %%
class TrigramNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):
        super(TrigramNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim * (left_tokens + right_tokens), hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, output_size)

    def forward(self, inputs):
        out = self.embedding(inputs)
        out = out.view(inputs.size(0), -1)
        out = torch.softmax(self.linear1(out), dim=1)
        out = self.linear2(out)
        return out
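
# %% [markdown]
# The model embeds both context tokens, concatenates the embeddings and maps them through
# one hidden layer to a score for every vocabulary word. A minimal shape check with a
# dummy batch (sketch only, reusing the configuration above):

# %%
_dummy_batch = torch.zeros((2, left_tokens + right_tokens), dtype=torch.long)
_sketch_model = TrigramNN(vocab_size, embedding_dim, hidden_dim, output_size)
print(_sketch_model(_dummy_batch).shape)  # expected: torch.Size([2, 60000])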
# %% [markdown]
# ## Initialize the model, loss function, and optimizer

# %%
model = TrigramNN(vocab_size, embedding_dim, hidden_dim, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# %% [markdown]
# ## Training loop

# %%
model.to(device)

for epoch in range(epochs):
    total_loss = 0
    for batch_inputs, batch_targets in tqdm(dataloader):
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)

        model.zero_grad()
        output = model(batch_inputs)

        loss = criterion(output, batch_targets.view(-1))
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}")

# %% [markdown]
# ## Write function to convert index to word

# %%
def idx_to_word(idx):
    idx = int(idx)
    if idx not in word_to_ix.inverse:
        return '<UNK>'
    return word_to_ix.inverse[idx]

# %% [markdown]
# ## test the model

# %%
def predict(left_context, right_context):
    with torch.no_grad():
        context = left_context + right_context
        test_context_idxs = torch.tensor([[tokenize(x) for x in context]], device=device)
        output = model(test_context_idxs)
        top_predicted_scores, top_predicted_indices = torch.topk(output, 5)
        predictions = list(zip(top_predicted_scores[0], top_predicted_indices[0]))
        predictions = [(float(score), idx_to_word(idx)) for score, idx in predictions]
        total_score = np.sum([score for score, _ in predictions])
        predictions = ' '.join([f"{word}:{score}" for score, word in predictions]) + ' :' + str(1.0 - total_score)
        return predictions
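
# %% [markdown]
# predict returns one line in the format used by the gap-fill challenge: up to five
# "word:score" pairs built from the top-5 output values, followed by ':' and
# 1.0 minus the sum of those scores for the remaining mass, as the call below shows.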
# %%
print(predict(["came", "fiom"], []))

# %% [markdown]
# # Generate result for dev dataset

# %%
dataset_dir = os.path.join('..', 'dev-0', 'in.tsv.xz')
output_dir = os.path.join('..', 'dev-0', 'out.tsv')

df = pd.read_csv(dataset_dir, sep='\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)
df = df.replace(r'\\r+|\\n+|\\t+', ' ', regex=True)

# %%
final = ""

for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):
    left_context = re.split(r"\s+", row['LeftContext'].strip())[-left_tokens:]
    right_context = re.split(r"\s+", row['RightContext'].strip())[:right_tokens]

    final += predict(left_context, right_context) + '\n'

with open(output_dir, 'w', encoding="UTF-8") as f:
    f.write(final)
src/07_trigram_neural.ipynb (normal file, 824 additions)

@@ -0,0 +1,824 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# <b>Trigram</b> neural network model for gap fill task"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import required packages"
]
},
{
"cell_type": "code",
"execution_count": 414,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"import re\n",
"import nltk\n",
"import os\n",
"import csv\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import sys\n",
"import numpy as np\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"from bidict import bidict\n",
"import math\n",
"from sklearn.utils import shuffle\n",
"from collections import Counter\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 415,
"metadata": {},
"outputs": [],
"source": [
"os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
"os.environ['TORCH_USE_CUDA_DSA'] = '1'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Global configuration variables"
]
},
{
"cell_type": "code",
"execution_count": 416,
"metadata": {},
"outputs": [],
"source": [
"vocab_size = 60_000\n",
"batch_size = 64\n",
"embedding_dim = 64\n",
"hidden_dim = 1024\n",
"learning_rate = 0.001\n",
"epochs = 20\n",
"\n",
"output_size = vocab_size"
]
},
{
"cell_type": "code",
"execution_count": 417,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cpu\n"
]
}
],
"source": [
"# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"device = torch.device(\"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load train data corpus"
]
},
{
"cell_type": "code",
"execution_count": 418,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/433 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 433/433 [01:03<00:00, 6.77it/s]\n"
]
}
],
"source": [
"dataset_dir = os.path.join('..', 'train', 'in.tsv.xz')\n",
"expected_dir = os.path.join('..', 'train', 'expected.tsv')\n",
"\n",
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
"expected_df = pd.read_csv(expected_dir, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
"\n",
"\n",
"input_corpus = []\n",
"target_corpus = []\n",
"\n",
"left_tokens = 1\n",
"right_tokens = 1\n",
"\n",
"for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):\n",
"    df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)\n",
"    \n",
"    for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):\n",
"        target_corpus.append([str(word).strip()])\n",
"        input_corpus.append(re.split(r\"\\s+\", left_context.strip())[-left_tokens:] + re.split(r\"\\s+\", right_context.strip())[:right_tokens])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create dictionaries for mapping words to indices"
]
},
{
"cell_type": "code",
"execution_count": 419,
"metadata": {},
"outputs": [],
"source": [
"def flatten(matrix):\n",
"    flat_list = []\n",
"    for row in matrix:\n",
"        flat_list += row\n",
"    return flat_list"
]
},
{
"cell_type": "code",
"execution_count": 420,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 59999/59999 [00:00<00:00, 131034.12it/s]\n"
]
}
],
"source": [
"word_to_ix = bidict({})\n",
"words_corpus = flatten(input_corpus) + flatten(target_corpus)\n",
"\n",
"counts = Counter(words_corpus)\n",
"\n",
"for word, _ in tqdm(counts.most_common(vocab_size - 1)):\n",
"    if word not in word_to_ix:\n",
"        word_to_ix[word] = len(word_to_ix) + 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenize entire corpus"
]
},
{
"cell_type": "code",
"execution_count": 421,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 432022/432022 [00:01<00:00, 255044.26it/s]\n",
"100%|██████████| 432022/432022 [00:01<00:00, 348618.53it/s]\n"
]
}
],
"source": [
"def tokenize(w):\n",
"    if w in word_to_ix:\n",
"        return word_to_ix[w]\n",
"    else:\n",
"        return 0\n",
"\n",
"tokenized_input_corpus = []\n",
"tokenized_target_corpus = []\n",
"\n",
"for words in tqdm(input_corpus):\n",
"    tokenized_input_corpus.append([tokenize(word) for word in words])\n",
"\n",
"for words in tqdm(target_corpus):\n",
"    tokenized_target_corpus.append([tokenize(word) for word in words])"
]
},
{
"cell_type": "code",
"execution_count": 422,
"metadata": {},
"outputs": [],
"source": [
"tokenized_input_corpus, tokenized_target_corpus = shuffle(tokenized_input_corpus, tokenized_target_corpus)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create dataset"
]
},
{
"cell_type": "code",
"execution_count": 423,
"metadata": {},
"outputs": [],
"source": [
"indices = np.nonzero(np.array(tokenized_target_corpus).flatten())\n",
"\n",
"tokenized_input_corpus = np.take(tokenized_input_corpus, indices, axis=0)\n",
"tokenized_target_corpus = np.take(tokenized_target_corpus, indices, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 424,
"metadata": {},
"outputs": [],
"source": [
"input_corpus_tensor = torch.flatten(torch.tensor(tokenized_input_corpus, dtype=torch.long, device=device), end_dim=-2)\n",
"target_corpus_tensor = torch.flatten(torch.tensor(tokenized_target_corpus, dtype=torch.long, device=device)).reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 425,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([389892, 2])\n",
"torch.Size([389892, 1])\n"
]
}
],
"source": [
"print(input_corpus_tensor.size())\n",
"print(target_corpus_tensor.size())"
]
},
{
"cell_type": "code",
"execution_count": 426,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['end', 'the']\n",
"['of']\n"
]
}
],
"source": [
"random_index = random.randint(0, len(input_corpus_tensor) - 1)\n",
"\n",
"# Get random element from input corpus\n",
"random_input_element = input_corpus_tensor[random_index]\n",
"\n",
"# Get corresponding element from target corpus\n",
"random_target_element = target_corpus_tensor[random_index]\n",
"\n",
"print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_input_element])\n",
"print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_target_element])"
]
},
{
"cell_type": "code",
"execution_count": 427,
"metadata": {},
"outputs": [],
"source": [
"dataset = TensorDataset(input_corpus_tensor[:10_000], target_corpus_tensor[:10_000])"
]
},
{
"cell_type": "code",
"execution_count": 428,
"metadata": {},
"outputs": [],
"source": [
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the trigram neural network model"
]
},
{
"cell_type": "code",
"execution_count": 429,
"metadata": {},
"outputs": [],
"source": [
"class TrigramNN(nn.Module):\n",
"    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):\n",
"        super(TrigramNN, self).__init__()\n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
"        self.linear1 = nn.Linear(embedding_dim * (left_tokens + right_tokens), hidden_dim)\n",
"        self.linear2 = nn.Linear(hidden_dim, output_size)\n",
"    \n",
"    def forward(self, inputs):\n",
"        out = self.embedding(inputs)\n",
"        out = out.view(inputs.size(0), -1)\n",
"        out = torch.softmax(self.linear1(out), dim=1)\n",
"        out = self.linear2(out)\n",
"        return out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize the model, loss function, and optimizer"
]
},
{
"cell_type": "code",
"execution_count": 430,
"metadata": {},
"outputs": [],
"source": [
"model = TrigramNN(vocab_size, embedding_dim, hidden_dim, output_size)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(model.parameters(), lr=learning_rate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training loop"
]
},
{
"cell_type": "code",
"execution_count": 431,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:32<00:00, 4.81it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Loss: 10.999195001687214\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:32<00:00, 4.86it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2, Loss: 10.997720451112006\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:32<00:00, 4.88it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3, Loss: 10.99624701214444\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.17it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4, Loss: 10.994744385883306\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.21it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5, Loss: 10.993266263585182\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.22it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6, Loss: 10.991843545512788\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:31<00:00, 4.92it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7, Loss: 10.990350304135852\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:28<00:00, 5.60it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8, Loss: 10.988877800619527\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:32<00:00, 4.81it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9, Loss: 10.987337306806236\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:29<00:00, 5.32it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10, Loss: 10.985873113012618\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.13it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 11, Loss: 10.98438450637137\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:28<00:00, 5.45it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 12, Loss: 10.9829175548189\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13, Loss: 10.981461263765954\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.08it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 14, Loss: 10.97996347269435\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.22it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 15, Loss: 10.978485234983408\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:31<00:00, 4.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 16, Loss: 10.977057912547117\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.23it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 17, Loss: 10.97553843601494\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:29<00:00, 5.34it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 18, Loss: 10.974108489455691\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:32<00:00, 4.82it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 19, Loss: 10.972679308265638\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 157/157 [00:30<00:00, 5.23it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 20, Loss: 10.971182902147815\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.to(device)\n",
"\n",
"for epoch in range(epochs):\n",
"    total_loss = 0\n",
"    for batch_inputs, batch_targets in tqdm(dataloader):\n",
"        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)\n",
"        \n",
"        model.zero_grad()\n",
"        output = model(batch_inputs)\n",
"\n",
"        loss = criterion(output, batch_targets.view(-1))\n",
"        total_loss += loss.item()\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    print(f\"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Write function to convert index to word"
]
},
{
"cell_type": "code",
"execution_count": 432,
"metadata": {},
"outputs": [],
"source": [
"def idx_to_word(idx):\n",
"    idx = int(idx)\n",
"    if idx not in word_to_ix.inverse:\n",
"        return '<UNK>'\n",
"    return word_to_ix.inverse[idx]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## test the model"
]
},
{
"cell_type": "code",
"execution_count": 433,
"metadata": {},
"outputs": [],
"source": [
"def predict(left_context, right_context):\n",
"    with torch.no_grad():\n",
"        context = left_context + right_context\n",
"        test_context_idxs = torch.tensor([[tokenize(x) for x in context]], device=device)\n",
"        output = model(test_context_idxs)\n",
"        top_predicted_scores, top_predicted_indices = torch.topk(output, 5)\n",
"        predictions = list(zip(top_predicted_scores[0], top_predicted_indices[0]))\n",
"        predictions = [(float(score), idx_to_word(idx)) for score, idx in predictions]\n",
"        total_score = np.sum([score for score, _ in predictions])\n",
"        predictions = ' '.join([f\"{word}:{score}\" for score, word in predictions]) + ' :' + str(1.0 - total_score)\n",
"        return predictions"
]
},
{
"cell_type": "code",
"execution_count": 434,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"the:0.210836723446846 of:0.13834647834300995 and:0.11819174885749817 to:0.09819918870925903 a:0.0662047415971756 :0.36822111904621124\n"
]
}
],
"source": [
"print(predict([\"came\", \"fiom\"], []))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generate result for dev dataset"
]
},
{
"cell_type": "code",
"execution_count": 435,
"metadata": {},
"outputs": [],
"source": [
"dataset_dir = os.path.join('..', 'dev-0', 'in.tsv.xz')\n",
"output_dir = os.path.join('..', 'dev-0', 'out.tsv')\n",
"\n",
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE)\n",
"df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)"
]
},
{
"cell_type": "code",
"execution_count": 436,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10519/10519 [02:25<00:00, 72.19it/s]\n"
]
}
],
"source": [
"final = \"\"\n",
"\n",
"for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):\n",
"    left_context = re.split(r\"\\s+\", row['LeftContext'].strip())[-left_tokens:]\n",
"    right_context = re.split(r\"\\s+\", row['RightContext'].strip())[:right_tokens]\n",
"\n",
"    final += predict(left_context, right_context) + '\\n'\n",
"\n",
"with open(output_dir, 'w', encoding=\"UTF-8\") as f:\n",
"    f.write(final)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "p311-cu121",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}