import torch
import lzma
import re
import itertools
from torchtext.vocab import build_vocab_from_iterator
from torch import nn
from torch.utils.data import IterableDataset, DataLoader
import matplotlib.pyplot as plt
Parameters
VOCAB_SIZE = 2_000
EMBED_SIZE = 500
Functions
def get_words_from_line(line):
    """Tokenize one TSV line: join the two text columns and keep letters only."""
    line = line.rstrip()
    fields = line.split("\t")
    text = fields[-2] + " " + fields[-1]
    text = re.sub(r"\\+n", " ", text)       # turn escaped newlines (backslash + n) into spaces
    text = re.sub("[^A-Za-z ]+", "", text)  # strip everything that is not a letter
    for t in text.split():
        yield t
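A quick check of the tokenizer on a made-up line (the columns here are invented, not taken from the corpus):

list(get_words_from_line("id\tmeta\tleft context\\nhere\tright context"))
# ['left', 'context', 'here', 'right', 'context']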
def get_word_lines_from_file(file_name):
    """Yield one token generator per line of an xz-compressed TSV file."""
    with lzma.open(file_name, mode="rt", encoding="utf8") as fh:
        for line in fh:
            yield get_words_from_line(line)
def look_ahead_iterator(gen):
    """Turn a token stream w1, w2, w3, ... into training pairs
    ((w_before, w_after), w_gap): the model sees the two surrounding
    words and predicts the word in the gap between them."""
    first = None
    second = None
    for item in gen:
        if first is not None and second is not None:
            yield ((first, item), second)
        first = second
        second = item
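For example, a plain sketch of what the iterator produces:

list(look_ahead_iterator(iter([1, 2, 3, 4])))
# [((1, 3), 2), ((2, 4), 3)]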
Create Vocab
vocab = build_vocab_from_iterator(
get_word_lines_from_file("train/in.tsv.xz"),
max_tokens = VOCAB_SIZE,
specials = ['<unk>'])
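Since specials are placed first by default, '<unk>' gets index 0; a quick lookup check:

print(vocab['<unk>'])            # 0
print(vocab.lookup_tokens([0]))  # ['<unk>']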
Trigram class
class Trigrams(IterableDataset):
    """Streams ((prev_id, next_id), target_id) training examples from disk."""
    def __init__(self, text_file, vocabulary_size):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        return look_ahead_iterator(
            (self.vocab[t] for t in itertools.chain.from_iterable(
                get_word_lines_from_file(self.text_file))))
train_dataset = Trigrams("train/in.tsv.xz", VOCAB_SIZE)
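Peeking at the first streamed example (the exact ids depend on the corpus):

next(iter(train_dataset))
# e.g. ((14, 27), 5): ids of the word before the gap, the word after it, and the gap word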
class TrigramNNModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(TrigramNNModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        self.hidden_layer = nn.Linear(embed_size * 2, 1200)
        self.output_layer = nn.Linear(1200, vocab_size)
        self.softmax = nn.Softmax(dim=1)  # explicit dim avoids a PyTorch warning

    def forward(self, x):
        emb_2 = self.embeddings(x[0])  # embedding of the word before the gap
        emb_1 = self.embeddings(x[1])  # embedding of the word after the gap
        x = torch.cat([emb_2, emb_1], dim=1)
        x = self.hidden_layer(x)   # note: no nonlinearity between the two linear layers
        x = self.output_layer(x)
        x = self.softmax(x)
        return x
vocab.set_default_index(vocab['<unk>'])
Training
device = 'cpu'
model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE).to(device)
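A shape sanity check with random ids (illustrative only, no real data involved):

dummy = [torch.randint(0, VOCAB_SIZE, (4,)), torch.randint(0, VOCAB_SIZE, (4,))]
out = model(dummy)
print(out.shape)       # torch.Size([4, 2000]): one distribution per example
print(out.sum(dim=1))  # rows sum to ~1 thanks to the softmax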
data = DataLoader(train_dataset, batch_size=1_000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()  # expects log-probabilities, hence torch.log below
loss_track = []
last_loss = 1_000
trigger_count = 0
model.train()
step = 0
for x, y in data:
    x[0] = x[0].to(device)
    x[1] = x[1].to(device)
    y = y.to(device)
    optimizer.zero_grad()
    ypredicted = model(x)
    # The model emits probabilities, so take the log to feed NLLLoss.
    # (A LogSoftmax output layer would be more numerically stable.)
    loss = criterion(torch.log(ypredicted), y)
    if step % 100 == 0:
        print(step, loss.item())
    step += 1
    loss.backward()
    optimizer.step()
    # Crude early stopping: stop once the loss has failed to improve 500 times.
    if loss.item() > last_loss:
        trigger_count += 1
        print(trigger_count, 'LOSS DIFF:', loss.item(), last_loss)
    if trigger_count >= 500:
        break
    # .item() detaches the scalars so the computation graph can be freed.
    loss_track.append(loss.item())
    last_loss = loss.item()
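The matplotlib import can now be put to use; a minimal loss curve (the output filename is just an example):

plt.plot(loss_track)
plt.xlabel('step')
plt.ylabel('training loss')
plt.savefig(f'loss_trigram-EMBED_SIZE={EMBED_SIZE}.png')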
torch.save(model.state_dict(), f'model_trigram-EMBED_SIZE={EMBED_SIZE}.bin')
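To reload the saved weights later without retraining (a sketch; the architecture must match the one defined above):

model = TrigramNNModel(VOCAB_SIZE, EMBED_SIZE).to(device)
model.load_state_dict(torch.load(f'model_trigram-EMBED_SIZE={EMBED_SIZE}.bin'))
model.eval()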
Inference
vocab_unique = set(vocab.get_stoi().keys())
pattern = re.compile('[^A-Za-z]+')

def predict_file(input_path, output_path):
    """Write top-10 gap-word predictions for every line of an xz-compressed TSV file."""
    output = []
    with lzma.open(input_path, mode="rt", encoding="utf8") as file:
        for line in file:
            fields = line.split("\t")
            # Last word of the left context, first word of the right context.
            first_word = pattern.sub(' ', fields[-2]).split()[-1]
            second_word = pattern.sub(' ', fields[-1]).split(maxsplit=1)[0]
            first_word = first_word if first_word in vocab_unique else "<unk>"
            second_word = second_word if second_word in vocab_unique else "<unk>"
            input_tokens = torch.tensor([vocab.forward([first_word]),
                                         vocab.forward([second_word])]).to(device)
            with torch.no_grad():
                out = model(input_tokens)
            top = torch.topk(out[0], 10)
            top_indices = top.indices.tolist()
            top_probs = top.values.tolist()
            # Fold the leftover probability mass into the "anything else" slot
            # so each line still describes a full distribution.
            unk_bonus = 1 - sum(top_probs)
            top_words = vocab.lookup_tokens(top_indices)
            res = " ".join([f"{w}:{p:.4f}" if w != "<unk>" else f":{(p + unk_bonus):.4f}"
                            for w, p in zip(top_words, top_probs)])
            output.append(res + "\n")
    with open(output_path, mode="w") as file:
        file.writelines(output)

model.eval()
predict_file("dev-0/in.tsv.xz", f"dev-0/out-EMBED_SIZE={EMBED_SIZE}.tsv")
predict_file("test-A/in.tsv.xz", f"test-A/out-EMBED_SIZE={EMBED_SIZE}.tsv")