PerplexityHashed: 990
This commit is contained in:
parent
d380959afc
commit
337d2ffc42
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
148
run.py
148
run.py
@ -18,9 +18,32 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
from bidict import bidict
|
from bidict import bidict
|
||||||
import torchtext.vocab as vocab
|
|
||||||
import math
|
import math
|
||||||
|
from sklearn.utils import shuffle
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
import random
|
||||||
|
|
||||||
|
# %%
|
||||||
|
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
||||||
|
os.environ['TORCH_USE_CUDA_DSA'] = '1'
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Global configuration variables
|
||||||
|
|
||||||
|
# %%
|
||||||
|
vocab_size = 60_000
|
||||||
|
batch_size = 64
|
||||||
|
embedding_dim = 64
|
||||||
|
hidden_dim = 1024
|
||||||
|
learning_rate = 0.001
|
||||||
|
epochs = 20
|
||||||
|
|
||||||
|
output_size = vocab_size
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
device = torch.device("cpu")
|
||||||
|
print(device)
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# ## Load train data corpus
|
# ## Load train data corpus
|
||||||
@ -32,21 +55,37 @@ expected_dir = os.path.join('..', 'train', 'expected.tsv')
|
|||||||
df = pd.read_csv(dataset_dir, sep='\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)
|
df = pd.read_csv(dataset_dir, sep='\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)
|
||||||
expected_df = pd.read_csv(expected_dir, sep='\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)
|
expected_df = pd.read_csv(expected_dir, sep='\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)
|
||||||
|
|
||||||
corpus = []
|
|
||||||
|
input_corpus = []
|
||||||
|
target_corpus = []
|
||||||
|
|
||||||
|
left_tokens = 1
|
||||||
|
right_tokens = 1
|
||||||
|
|
||||||
for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):
|
for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):
|
||||||
df = df.replace(r'\\r+|\\n+|\\t+', ' ', regex=True)
|
df = df.replace(r'\\r+|\\n+|\\t+', ' ', regex=True)
|
||||||
|
|
||||||
for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['LeftContext'].to_list()):
|
for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):
|
||||||
corpus.extend(re.split(r"\s+", left_context.strip()) + [str(word).strip()] + re.split(r"\s+", right_context.strip()))
|
target_corpus.append([str(word).strip()])
|
||||||
|
input_corpus.append(re.split(r"\s+", left_context.strip())[-left_tokens:] + re.split(r"\s+", right_context.strip())[:right_tokens])
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# ## Create dictionaries for mapping words to indices
|
# ## Create dictionaries for mapping words to indices
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
word_to_ix = bidict({})
|
def flatten(matrix):
|
||||||
counts = Counter(corpus)
|
flat_list = []
|
||||||
|
for row in matrix:
|
||||||
|
flat_list += row
|
||||||
|
return flat_list
|
||||||
|
|
||||||
for word, _ in tqdm(counts.most_common(1_500_000)):
|
# %%
|
||||||
|
word_to_ix = bidict({})
|
||||||
|
words_corpus = flatten(input_corpus) + flatten(target_corpus)
|
||||||
|
|
||||||
|
counts = Counter(words_corpus)
|
||||||
|
|
||||||
|
for word, _ in tqdm(counts.most_common(vocab_size - 1)):
|
||||||
if word not in word_to_ix:
|
if word not in word_to_ix:
|
||||||
word_to_ix[word] = len(word_to_ix) + 1
|
word_to_ix[word] = len(word_to_ix) + 1
|
||||||
|
|
||||||
@ -54,39 +93,58 @@ for word, _ in tqdm(counts.most_common(1_500_000)):
|
|||||||
# ## Tokenize entire corpus
|
# ## Tokenize entire corpus
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
def tokenize(w):
|
def tokenize(w):
|
||||||
if w in word_to_ix:
|
if w in word_to_ix:
|
||||||
return word_to_ix[w]
|
return word_to_ix[w]
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
tokenized_corpus = []
|
tokenized_input_corpus = []
|
||||||
|
tokenized_target_corpus = []
|
||||||
|
|
||||||
for word in tqdm(corpus):
|
for words in tqdm(input_corpus):
|
||||||
tokenized_corpus.append(tokenize(word))
|
tokenized_input_corpus.append([tokenize(word) for word in words])
|
||||||
|
|
||||||
|
for words in tqdm(target_corpus):
|
||||||
|
tokenized_target_corpus.append([tokenize(word) for word in words])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
tokenized_input_corpus, tokenized_target_corpus = shuffle(tokenized_input_corpus, tokenized_target_corpus)
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# ## Create n-grams
|
# ## Create dataset
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
tokenized_training_corpus = []
|
indices = np.nonzero(np.array(tokenized_target_corpus).flatten())
|
||||||
ngrams = list(nltk.ngrams(tokenized_corpus, n=7))
|
|
||||||
np.random.shuffle(ngrams)
|
|
||||||
ngrams = ngrams[:100_000]
|
|
||||||
ngrams_tensor = torch.tensor(ngrams, dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
indices = torch.any(ngrams_tensor == 0, dim=1)
|
tokenized_input_corpus = np.take(tokenized_input_corpus, indices, axis=0)
|
||||||
ngrams_tensor = ngrams_tensor[~indices]
|
tokenized_target_corpus = np.take(tokenized_target_corpus, indices, axis=0)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
target_tensor = ngrams_tensor[:, 3].reshape(-1, 1).to(device)
|
input_corpus_tensor = torch.flatten(torch.tensor(tokenized_input_corpus, dtype=torch.long, device=device), end_dim=-2)
|
||||||
input_tensor = torch.cat((ngrams_tensor[:, :3], ngrams_tensor[:, 4:]), dim=1).to(device)
|
target_corpus_tensor = torch.flatten(torch.tensor(tokenized_target_corpus, dtype=torch.long, device=device)).reshape(-1, 1)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
batched_input_tensor = torch.split(input_tensor, 512)
|
print(input_corpus_tensor.size())
|
||||||
batched_target_tensor = torch.split(target_tensor, 512)
|
print(target_corpus_tensor.size())
|
||||||
|
|
||||||
|
# %%
|
||||||
|
random_index = random.randint(0, len(input_corpus_tensor) - 1)
|
||||||
|
|
||||||
|
# Get random element from input corpus
|
||||||
|
random_input_element = input_corpus_tensor[random_index]
|
||||||
|
|
||||||
|
# Get corresponding element from target corpus
|
||||||
|
random_target_element = target_corpus_tensor[random_index]
|
||||||
|
|
||||||
|
print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_input_element])
|
||||||
|
print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_target_element])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
dataset = TensorDataset(input_corpus_tensor[:10_000], target_corpus_tensor[:10_000])
|
||||||
|
|
||||||
|
# %%
|
||||||
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# ## Define the trigram neural network model
|
# ## Define the trigram neural network model
|
||||||
@ -95,28 +153,17 @@ batched_target_tensor = torch.split(target_tensor, 512)
|
|||||||
class TrigramNN(nn.Module):
|
class TrigramNN(nn.Module):
|
||||||
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):
|
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):
|
||||||
super(TrigramNN, self).__init__()
|
super(TrigramNN, self).__init__()
|
||||||
self.embedding = nn.Embedding(vocab_size, 50)
|
self.embedding = nn.Embedding(vocab_size, embedding_dim)
|
||||||
self.linear1 = nn.Linear(50 * 6, output_size)
|
self.linear1 = nn.Linear(embedding_dim * (left_tokens + right_tokens), hidden_dim)
|
||||||
|
self.linear2 = nn.Linear(hidden_dim, output_size)
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
out = self.embedding(inputs)
|
out = self.embedding(inputs)
|
||||||
out = out.view(inputs.size(0), -1)
|
out = out.view(inputs.size(0), -1)
|
||||||
out = self.linear1(out)
|
out = torch.softmax(self.linear1(out), dim=1)
|
||||||
out = torch.softmax(out, dim=1)
|
out = self.linear2(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# %% [markdown]
|
|
||||||
# ## Define training parameters
|
|
||||||
|
|
||||||
# %%
|
|
||||||
batch_size = 512
|
|
||||||
vocab_size = len(word_to_ix) + 1
|
|
||||||
embedding_dim = 10
|
|
||||||
hidden_dim = 64
|
|
||||||
output_size = vocab_size
|
|
||||||
learning_rate = 0.005
|
|
||||||
epochs = 1
|
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# ## Initialize the model, loss function, and optimizer
|
# ## Initialize the model, loss function, and optimizer
|
||||||
|
|
||||||
@ -131,11 +178,11 @@ optimizer = optim.SGD(model.parameters(), lr=learning_rate)
|
|||||||
# %%
|
# %%
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
batches = list(zip(batched_input_tensor, batched_target_tensor))
|
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
for batch_inputs, batch_targets in tqdm(batches):
|
for batch_inputs, batch_targets in tqdm(dataloader):
|
||||||
|
batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
|
||||||
|
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
output = model(batch_inputs)
|
output = model(batch_inputs)
|
||||||
|
|
||||||
@ -145,7 +192,7 @@ for epoch in range(epochs):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
print(f"Epoch {epoch+1}, Loss: {total_loss/len(batches)}")
|
print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}")
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# ## Write function to convert index to word
|
# ## Write function to convert index to word
|
||||||
@ -168,14 +215,13 @@ def predict(left_context, right_context):
|
|||||||
output = model(test_context_idxs)
|
output = model(test_context_idxs)
|
||||||
top_predicted_scores, top_predicted_indices = torch.topk(output, 5)
|
top_predicted_scores, top_predicted_indices = torch.topk(output, 5)
|
||||||
predictions = list(zip(top_predicted_scores[0], top_predicted_indices[0]))
|
predictions = list(zip(top_predicted_scores[0], top_predicted_indices[0]))
|
||||||
predictions = [(round(float(score), 2), idx_to_word(idx)) for score, idx in predictions]
|
predictions = [(float(score), idx_to_word(idx)) for score, idx in predictions]
|
||||||
total_score = np.sum([score for score, _ in predictions])
|
total_score = np.sum([score for score, _ in predictions])
|
||||||
predictions = ' '.join([f"{word}:{round(score/total_score, 2)}" for score, word in predictions]) + ' :0.01'
|
predictions = ' '.join([f"{word}:{score}" for score, word in predictions]) + ' :' + str(1.0 - total_score)
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
test_context = ["came", "fiom", "the", "place", "to", "this"]
|
print(predict(["came", "fiom"], []))
|
||||||
print(predict(test_context[:3], test_context[3:]))
|
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# # Generate result for dev dataset
|
# # Generate result for dev dataset
|
||||||
@ -191,8 +237,8 @@ df = df.replace(r'\\r+|\\n+|\\t+', ' ', regex=True)
|
|||||||
final = ""
|
final = ""
|
||||||
|
|
||||||
for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):
|
for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):
|
||||||
left_context = re.split(r"\s+", row['LeftContext'].strip())[-3:]
|
left_context = re.split(r"\s+", row['LeftContext'].strip())[-left_tokens:]
|
||||||
right_context = re.split(r"\s+", row['RightContext'].strip())[:3]
|
right_context = re.split(r"\s+", row['RightContext'].strip())[:right_tokens]
|
||||||
|
|
||||||
final += predict(left_context, right_context) + '\n'
|
final += predict(left_context, right_context) + '\n'
|
||||||
|
|
||||||
|
@ -16,24 +16,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 414,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
|
|
||||||
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
|
|
||||||
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
|
|
||||||
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
|
|
||||||
"c:\\Users\\Marcin\\.conda\\envs\\p311-cu121\\Lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
|
|
||||||
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
|
|
||||||
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
|
|
||||||
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from tqdm import tqdm\n",
|
"from tqdm import tqdm\n",
|
||||||
"import re\n",
|
"import re\n",
|
||||||
@ -48,9 +33,62 @@
|
|||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"from torch.utils.data import DataLoader, TensorDataset\n",
|
"from torch.utils.data import DataLoader, TensorDataset\n",
|
||||||
"from bidict import bidict\n",
|
"from bidict import bidict\n",
|
||||||
"import torchtext.vocab as vocab\n",
|
|
||||||
"import math\n",
|
"import math\n",
|
||||||
"from collections import Counter"
|
"from sklearn.utils import shuffle\n",
|
||||||
|
"from collections import Counter\n",
|
||||||
|
"import random"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 415,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
|
||||||
|
"os.environ['TORCH_USE_CUDA_DSA'] = '1'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Global configuration variables"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 416,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"vocab_size = 60_000\n",
|
||||||
|
"batch_size = 64\n",
|
||||||
|
"embedding_dim = 64\n",
|
||||||
|
"hidden_dim = 1024\n",
|
||||||
|
"learning_rate = 0.001\n",
|
||||||
|
"epochs = 20\n",
|
||||||
|
"\n",
|
||||||
|
"output_size = vocab_size"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 417,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"cpu\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
|
"device = torch.device(\"cpu\")\n",
|
||||||
|
"print(device)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -62,14 +100,21 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 418,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"100%|██████████| 433/433 [01:10<00:00, 6.12it/s]\n"
|
" 0%| | 0/433 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 433/433 [01:03<00:00, 6.77it/s]\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -80,12 +125,19 @@
|
|||||||
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
"df = pd.read_csv(dataset_dir, sep='\\t', header=None, names=['FileId', 'Year', 'LeftContext', 'RightContext'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
"expected_df = pd.read_csv(expected_dir, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
"expected_df = pd.read_csv(expected_dir, sep='\\t', header=None, names=['Word'], quoting=csv.QUOTE_NONE, dtype=str, chunksize=1000)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"corpus = []\n",
|
"\n",
|
||||||
|
"input_corpus = []\n",
|
||||||
|
"target_corpus = []\n",
|
||||||
|
"\n",
|
||||||
|
"left_tokens = 1\n",
|
||||||
|
"right_tokens = 1\n",
|
||||||
|
"\n",
|
||||||
"for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):\n",
|
"for j, (df, expected_df) in tqdm(enumerate(zip(df, expected_df)), total=433):\n",
|
||||||
" df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)\n",
|
" df = df.replace(r'\\\\r+|\\\\n+|\\\\t+', ' ', regex=True)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['LeftContext'].to_list()):\n",
|
" for left_context, word, right_context in zip(df['LeftContext'].to_list(), expected_df['Word'].to_list(), df['RightContext'].to_list()):\n",
|
||||||
" corpus.extend(re.split(r\"\\s+\", left_context.strip()) + [str(word).strip()] + re.split(r\"\\s+\", right_context.strip()))"
|
" target_corpus.append([str(word).strip()])\n",
|
||||||
|
" input_corpus.append(re.split(r\"\\s+\", left_context.strip())[-left_tokens:] + re.split(r\"\\s+\", right_context.strip())[:right_tokens])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -97,22 +149,37 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 419,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def flatten(matrix):\n",
|
||||||
|
" flat_list = []\n",
|
||||||
|
" for row in matrix:\n",
|
||||||
|
" flat_list += row\n",
|
||||||
|
" return flat_list"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 420,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"100%|██████████| 1500000/1500000 [00:11<00:00, 128039.35it/s]\n"
|
"100%|██████████| 59999/59999 [00:00<00:00, 131034.12it/s]\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"word_to_ix = bidict({})\n",
|
"word_to_ix = bidict({})\n",
|
||||||
"counts = Counter(corpus)\n",
|
"words_corpus = flatten(input_corpus) + flatten(target_corpus)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for word, _ in tqdm(counts.most_common(1_500_000)):\n",
|
"counts = Counter(words_corpus)\n",
|
||||||
|
"\n",
|
||||||
|
"for word, _ in tqdm(counts.most_common(vocab_size - 1)):\n",
|
||||||
" if word not in word_to_ix:\n",
|
" if word not in word_to_ix:\n",
|
||||||
" word_to_ix[word] = len(word_to_ix) + 1"
|
" word_to_ix[word] = len(word_to_ix) + 1"
|
||||||
]
|
]
|
||||||
@ -126,73 +193,135 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 421,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"100%|██████████| 139456816/139456816 [01:28<00:00, 1569462.31it/s]\n"
|
"100%|██████████| 432022/432022 [00:01<00:00, 255044.26it/s]\n",
|
||||||
|
"100%|██████████| 432022/432022 [00:01<00:00, 348618.53it/s]\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
||||||
"\n",
|
|
||||||
"def tokenize(w):\n",
|
"def tokenize(w):\n",
|
||||||
" if w in word_to_ix:\n",
|
" if w in word_to_ix:\n",
|
||||||
" return word_to_ix[w]\n",
|
" return word_to_ix[w]\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" return 0\n",
|
" return 0\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tokenized_corpus = []\n",
|
"tokenized_input_corpus = []\n",
|
||||||
|
"tokenized_target_corpus = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for word in tqdm(corpus):\n",
|
"for words in tqdm(input_corpus):\n",
|
||||||
" tokenized_corpus.append(tokenize(word))"
|
" tokenized_input_corpus.append([tokenize(word) for word in words])\n",
|
||||||
|
"\n",
|
||||||
|
"for words in tqdm(target_corpus):\n",
|
||||||
|
" tokenized_target_corpus.append([tokenize(word) for word in words])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 422,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tokenized_input_corpus, tokenized_target_corpus = shuffle(tokenized_input_corpus, tokenized_target_corpus)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Create n-grams"
|
"## Create dataset"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 423,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"tokenized_training_corpus = []\n",
|
"indices = np.nonzero(np.array(tokenized_target_corpus).flatten())\n",
|
||||||
"ngrams = list(nltk.ngrams(tokenized_corpus, n=7))\n",
|
|
||||||
"np.random.shuffle(ngrams)\n",
|
|
||||||
"ngrams = ngrams[:100_000]\n",
|
|
||||||
"ngrams_tensor = torch.tensor(ngrams, dtype=torch.long, device=device)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"indices = torch.any(ngrams_tensor == 0, dim=1)\n",
|
"tokenized_input_corpus = np.take(tokenized_input_corpus, indices, axis=0)\n",
|
||||||
"ngrams_tensor = ngrams_tensor[~indices]"
|
"tokenized_target_corpus = np.take(tokenized_target_corpus, indices, axis=0)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 424,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"target_tensor = ngrams_tensor[:, 3].reshape(-1, 1).to(device)\n",
|
"input_corpus_tensor = torch.flatten(torch.tensor(tokenized_input_corpus, dtype=torch.long, device=device), end_dim=-2)\n",
|
||||||
"input_tensor = torch.cat((ngrams_tensor[:, :3], ngrams_tensor[:, 4:]), dim=1).to(device)"
|
"target_corpus_tensor = torch.flatten(torch.tensor(tokenized_target_corpus, dtype=torch.long, device=device)).reshape(-1, 1)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 425,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"torch.Size([389892, 2])\n",
|
||||||
|
"torch.Size([389892, 1])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(input_corpus_tensor.size())\n",
|
||||||
|
"print(target_corpus_tensor.size())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 426,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['end', 'the']\n",
|
||||||
|
"['of']\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"random_index = random.randint(0, len(input_corpus_tensor) - 1)\n",
|
||||||
|
"\n",
|
||||||
|
"# Get random element from input corpus\n",
|
||||||
|
"random_input_element = input_corpus_tensor[random_index]\n",
|
||||||
|
"\n",
|
||||||
|
"# Get corresponding element from target corpus\n",
|
||||||
|
"random_target_element = target_corpus_tensor[random_index]\n",
|
||||||
|
"\n",
|
||||||
|
"print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_input_element])\n",
|
||||||
|
"print([word_to_ix.inverse[int(idx)] if int(idx) > 0 else '<UNK>' for idx in random_target_element])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 427,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"batched_input_tensor = torch.split(input_tensor, 512)\n",
|
"dataset = TensorDataset(input_corpus_tensor[:10_000], target_corpus_tensor[:10_000])"
|
||||||
"batched_target_tensor = torch.split(target_tensor, 512)"
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 428,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -204,46 +333,25 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 429,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"class TrigramNN(nn.Module):\n",
|
"class TrigramNN(nn.Module):\n",
|
||||||
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):\n",
|
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):\n",
|
||||||
" super(TrigramNN, self).__init__()\n",
|
" super(TrigramNN, self).__init__()\n",
|
||||||
" self.embedding = nn.Embedding(vocab_size, 50)\n",
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
||||||
" self.linear1 = nn.Linear(50 * 6, output_size)\n",
|
" self.linear1 = nn.Linear(embedding_dim * (left_tokens + right_tokens), hidden_dim)\n",
|
||||||
" \n",
|
" self.linear2 = nn.Linear(hidden_dim, output_size)\n",
|
||||||
|
" \n",
|
||||||
" def forward(self, inputs):\n",
|
" def forward(self, inputs):\n",
|
||||||
" out = self.embedding(inputs)\n",
|
" out = self.embedding(inputs)\n",
|
||||||
" out = out.view(inputs.size(0), -1)\n",
|
" out = out.view(inputs.size(0), -1)\n",
|
||||||
" out = self.linear1(out)\n",
|
" out = torch.softmax(self.linear1(out), dim=1)\n",
|
||||||
" out = torch.softmax(out, dim=1)\n",
|
" out = self.linear2(out)\n",
|
||||||
" return out"
|
" return out"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Define training parameters"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 15,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"batch_size = 512\n",
|
|
||||||
"vocab_size = len(word_to_ix) + 1\n",
|
|
||||||
"embedding_dim = 10\n",
|
|
||||||
"hidden_dim = 64\n",
|
|
||||||
"output_size = vocab_size\n",
|
|
||||||
"learning_rate = 0.005\n",
|
|
||||||
"epochs = 1"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -253,7 +361,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 430,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -271,28 +379,287 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 431,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
" 0%| | 0/164 [00:00<?, ?it/s]"
|
"100%|██████████| 157/157 [00:32<00:00, 4.81it/s]\n"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 164/164 [29:56<00:00, 10.95s/it]"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Epoch 1, Loss: 14.220980655856248\n"
|
"Epoch 1, Loss: 10.999195001687214\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.86it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 2, Loss: 10.997720451112006\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.88it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 3, Loss: 10.99624701214444\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.17it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 4, Loss: 10.994744385883306\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.21it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 5, Loss: 10.993266263585182\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.22it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 6, Loss: 10.991843545512788\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:31<00:00, 4.92it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 7, Loss: 10.990350304135852\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:28<00:00, 5.60it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 8, Loss: 10.988877800619527\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.81it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 9, Loss: 10.987337306806236\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:29<00:00, 5.32it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 10, Loss: 10.985873113012618\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.13it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 11, Loss: 10.98438450637137\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:28<00:00, 5.45it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 12, Loss: 10.9829175548189\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.11it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 13, Loss: 10.981461263765954\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.08it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 14, Loss: 10.97996347269435\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.22it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 15, Loss: 10.978485234983408\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:31<00:00, 4.98it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 16, Loss: 10.977057912547117\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.23it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 17, Loss: 10.97553843601494\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:29<00:00, 5.34it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 18, Loss: 10.974108489455691\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:32<00:00, 4.82it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 19, Loss: 10.972679308265638\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 157/157 [00:30<00:00, 5.23it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 20, Loss: 10.971182902147815\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -306,11 +673,11 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"model.to(device)\n",
|
"model.to(device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"batches = list(zip(batched_input_tensor, batched_target_tensor))\n",
|
|
||||||
"\n",
|
|
||||||
"for epoch in range(epochs):\n",
|
"for epoch in range(epochs):\n",
|
||||||
" total_loss = 0\n",
|
" total_loss = 0\n",
|
||||||
" for batch_inputs, batch_targets in tqdm(batches):\n",
|
" for batch_inputs, batch_targets in tqdm(dataloader):\n",
|
||||||
|
" batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)\n",
|
||||||
|
" \n",
|
||||||
" model.zero_grad()\n",
|
" model.zero_grad()\n",
|
||||||
" output = model(batch_inputs)\n",
|
" output = model(batch_inputs)\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -320,7 +687,7 @@
|
|||||||
" loss.backward()\n",
|
" loss.backward()\n",
|
||||||
" optimizer.step()\n",
|
" optimizer.step()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" print(f\"Epoch {epoch+1}, Loss: {total_loss/len(batches)}\")"
|
" print(f\"Epoch {epoch+1}, Loss: {total_loss/len(dataloader)}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -332,7 +699,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 432,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -352,7 +719,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 433,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -363,36 +730,27 @@
|
|||||||
" output = model(test_context_idxs)\n",
|
" output = model(test_context_idxs)\n",
|
||||||
" top_predicted_scores, top_predicted_indices = torch.topk(output, 5)\n",
|
" top_predicted_scores, top_predicted_indices = torch.topk(output, 5)\n",
|
||||||
" predictions = list(zip(top_predicted_scores[0], top_predicted_indices[0]))\n",
|
" predictions = list(zip(top_predicted_scores[0], top_predicted_indices[0]))\n",
|
||||||
" predictions = [(round(float(score), 2), idx_to_word(idx)) for score, idx in predictions]\n",
|
" predictions = [(float(score), idx_to_word(idx)) for score, idx in predictions]\n",
|
||||||
" total_score = np.sum([score for score, _ in predictions])\n",
|
" total_score = np.sum([score for score, _ in predictions])\n",
|
||||||
" predictions = ' '.join([f\"{word}:{round(score/total_score, 2)}\" for score, word in predictions]) + ' :0.01'\n",
|
" predictions = ' '.join([f\"{word}:{score}\" for score, word in predictions]) + ' :' + str(1.0 - total_score)\n",
|
||||||
" return predictions"
|
" return predictions"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 434,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"AmiTlceaa.:nan Allentown.:nan thereuntoi:nan Jugo-Slav:nan Sallie,:nan :0.01\n"
|
"the:0.210836723446846 of:0.13834647834300995 and:0.11819174885749817 to:0.09819918870925903 a:0.0662047415971756 :0.36822111904621124\n"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"C:\\Users\\Marcin\\AppData\\Local\\Temp\\ipykernel_36872\\3389363719.py:10: RuntimeWarning: invalid value encountered in scalar divide\n",
|
|
||||||
" predictions = ' '.join([f\"{word}:{round(score/total_score, 2)}\" for score, word in predictions]) + ' :0.01'\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"test_context = [\"came\", \"fiom\", \"the\", \"place\", \"to\", \"this\"]\n",
|
"print(predict([\"came\", \"fiom\"], []))"
|
||||||
"print(predict(test_context[:3], test_context[3:]))"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -404,7 +762,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 435,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -417,15 +775,23 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 436,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"100%|██████████| 10519/10519 [02:25<00:00, 72.19it/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"final = \"\"\n",
|
"final = \"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):\n",
|
"for i, (_, row) in tqdm(enumerate(df.iterrows()), total=len(df)):\n",
|
||||||
" left_context = re.split(r\"\\s+\", row['LeftContext'].strip())[-3:]\n",
|
" left_context = re.split(r\"\\s+\", row['LeftContext'].strip())[-left_tokens:]\n",
|
||||||
" right_context = re.split(r\"\\s+\", row['RightContext'].strip())[:3]\n",
|
" right_context = re.split(r\"\\s+\", row['RightContext'].strip())[:right_tokens]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" final += predict(left_context, right_context) + '\\n'\n",
|
" final += predict(left_context, right_context) + '\\n'\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user