234 lines
6.6 KiB
Python
234 lines
6.6 KiB
Python
import lzma
|
|
import regex as re
|
|
from torchtext.vocab import build_vocab_from_iterator
|
|
from torch import nn
|
|
import pickle
|
|
from os.path import exists
|
|
from torch.utils.data import IterableDataset
|
|
import itertools
|
|
from torch.utils.data import DataLoader
|
|
import torch
|
|
from matplotlib import pyplot as plt
|
|
from tqdm import tqdm
|
|
|
|
|
|
def get_words_from_line(line):
|
|
line = line.rstrip()
|
|
line = line.split("\t")
|
|
text = line[-2] + " " + line[-1]
|
|
text = re.sub(r"\\\\+n", " ", text)
|
|
text = re.sub('[^A-Za-z ]+', '', text)
|
|
for t in text.split():
|
|
yield t
|
|
|
|
|
|
def get_word_lines_from_file(file_name):
    """Lazily yield one word generator per line of an xz-compressed file.

    Lines are read as raw bytes, decoded as UTF-8 and passed to
    ``get_words_from_line``. The file is closed once this generator is
    exhausted or closed.
    """
    with lzma.open(file_name, "r") as compressed:
        yield from (
            get_words_from_line(raw_line.decode("utf-8"))
            for raw_line in compressed
        )
|
|
|
|
|
|
def look_ahead_iterator(gen):
    """Yield every window of three consecutive items of *gen* as a tuple.

    ``None`` is used internally as the "slot not filled yet" marker. As a
    consequence, any window whose first or second position holds ``None`` is
    skipped rather than emitted.
    """
    window = (None, None)
    for current in gen:
        older, newer = window
        if older is not None and newer is not None:
            yield older, newer, current
        window = (newer, current)
|
|
|
|
|
|
class Trigrams(IterableDataset):
    """Iterable dataset of consecutive word-index triples from an xz text file.

    The vocabulary is built once, at construction time, by reading the whole
    file with ``max_tokens=vocabulary_size`` and a ``"<unk>"`` special. Words
    missing from it map to the ``"<unk>"`` index. Each iteration re-reads the
    file from the start.

    NOTE: all lines are flattened into one word stream before windowing, so a
    triple may span the end of one line and the start of the next.
    """

    def __init__(self, text_file, vocabulary_size):
        word_lines = get_word_lines_from_file(text_file)
        self.vocab = build_vocab_from_iterator(
            word_lines, max_tokens=vocabulary_size, specials=["<unk>"]
        )
        unk_index = self.vocab["<unk>"]
        # Out-of-vocabulary lookups fall back to the "<unk>" index.
        self.vocab.set_default_index(unk_index)
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        words = itertools.chain.from_iterable(
            get_word_lines_from_file(self.text_file)
        )
        indices = (self.vocab[word] for word in words)
        return look_ahead_iterator(indices)
|
|
|
|
|
|
class TrigramModel(nn.Module):
    """Score every vocabulary word given two context word indices.

    Both context words are embedded with a shared embedding table. The two
    embeddings are concatenated, passed through a hidden linear layer and
    projected onto the vocabulary. ``forward`` returns a probability
    distribution over the vocabulary for each row of the batch.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.hidden = nn.Linear(embedding_dim * 2, hidden_dim)
        # NOTE(review): there is no activation between `hidden` and `output`,
        # so the two layers compose to a single low-rank linear map.
        self.output = nn.Linear(hidden_dim, vocab_size)
        # Normalize explicitly over the last (vocabulary) axis; nn.Softmax()
        # without `dim` relies on a deprecated implicit dimension choice.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        """Return vocabulary probabilities for context word indices *x*, *y*.

        x, y: 1-D index tensors of equal length (batch,). The embeddings are
        concatenated along dim 1 and fed to a layer expecting
        ``2 * embedding_dim`` features, so higher-rank inputs do not fit.

        Returns a (batch, vocab_size) tensor whose rows sum to 1.
        """
        x = self.embeddings(x)
        y = self.embeddings(y)
        z = self.hidden(torch.cat([x, y], dim=1))
        z = self.output(z)
        return self.softmax(z)
|
|
|
|
|
|
# Hyper-parameters shared by the rest of the script.
embed_size = 500
vocab_size = 20000

# The vocabulary is cached as a pickle in the working directory.
# NOTE(review): the cache is keyed only by file name, so a changed
# `vocab_size` is not picked up until "vocabulary.pickle" is deleted.
vocab_path = "vocabulary.pickle"

if not exists(vocab_path):
    print("Building vocabulary...")
    vocab = build_vocab_from_iterator(
        get_word_lines_from_file("train/in.tsv.xz"),
        max_tokens=vocab_size,
        specials=["<unk>"],
    )
    with open(vocab_path, "wb") as vocab_file:
        pickle.dump(vocab, vocab_file)
else:
    print("Loading vocabulary from file...")
    # pickle is acceptable only because this is a local cache the script
    # wrote itself; never point it at untrusted data.
    with open(vocab_path, "rb") as vocab_file:
        vocab = pickle.load(vocab_file)
|
|
|
|
# Run on the GPU when torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# The Trigrams dataset object (with the vocabulary it built) is cached as a
# pickle so the vocabulary pass over the training file is not repeated.
dataset_path = 'train/dataset.pickle'
if not exists(dataset_path):
    print("Building dataset...")
    train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
    with open(dataset_path, "wb") as dataset_file:
        pickle.dump(train_dataset, dataset_file)
else:
    print("Loading dataset from file...")
    with open(dataset_path, "rb") as dataset_file:
        train_dataset = pickle.load(dataset_file)
|
|
|
|
print("Building model...")
model = TrigramModel(vocab_size, embed_size, 64).to(device)
data = DataLoader(train_dataset, batch_size=10000)
optimizer = torch.optim.Adam(model.parameters())
# NLLLoss expects log-probabilities; the model returns softmax probabilities,
# so they are log-transformed before scoring.
criterion = torch.nn.NLLLoss()

print("Training model...")
model.train()
losses = []
step = 0
max_steps = 1000
# Added to probabilities before the log: a softmax output that underflowed
# to 0 would otherwise give log(0) = -inf, an infinite loss and NaN gradients.
log_eps = 1e-12

# Each batch holds three consecutive words (x, y, z). The model sees the two
# outer words and is trained to predict the middle one.
for x, y, z in tqdm(data):
    x = x.to(device)
    y = y.to(device)
    z = z.to(device)

    optimizer.zero_grad()
    ypredicted = model(x, z)
    loss = criterion(torch.log(ypredicted + log_eps), y)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()

    step += 1
    # Stop after exactly `max_steps` optimizer updates.
    if step >= max_steps:
        break

plt.plot(losses)
plt.show()

torch.save(model.state_dict(), f"trigram_model-embed_{embed_size}.bin")
|
|
|
|
# Every token of the training vocabulary; context words outside it are
# mapped to "<unk>" before lookup. (Also read by the test-set prediction.)
vocab_unique = set(train_dataset.vocab.get_stoi().keys())

output = []
print('Predicting dev...')
# Inference only: leave training mode and skip autograd graph construction.
model.eval()
with torch.no_grad(), lzma.open(
    "dev-0/in.tsv.xz", encoding='utf8', mode="rt"
) as in_file:
    for line in tqdm(in_file):
        fields = line.split("\t")

        # Last word of the left context (second-to-last column); an empty
        # context yields "" which is then treated as unknown.
        left_words = re.sub(r"\\\\+n", " ", fields[-2]).split()
        first_word = re.sub('[^A-Za-z]+', '', left_words[-1] if left_words else "")

        # First word of the right context (last column).
        right_words = re.sub(r"\\\\+n", " ", fields[-1]).split()
        next_word = re.sub('[^A-Za-z]+', '', right_words[0] if right_words else "")

        if first_word not in vocab_unique:
            first_word = "<unk>"
        if next_word not in vocab_unique:
            next_word = "<unk>"

        first_word = torch.tensor(train_dataset.vocab.forward([first_word])).to(device)
        next_word = torch.tensor(train_dataset.vocab.forward([next_word])).to(device)

        out = model(first_word, next_word)

        top = torch.topk(out[0], 10)
        top_indices = top.indices.tolist()
        top_probs = top.values.tolist()
        # Decode with the vocabulary that produced the model's indices.
        top_words = train_dataset.vocab.lookup_tokens(top_indices)

        entries = []
        known_mass = 0.0
        for w, p in zip(top_words, top_probs):
            if w == "<unk>":
                continue
            entries.append(f"{w}:{p:.4f}")
            known_mass += p
        # Everything not assigned to a named word (including "<unk>" and the
        # words outside the top 10) goes to the trailing wildcard entry.
        rest = max(0.0, 1.0 - known_mass)
        entries.append(f":{rest:.4f}")
        output.append(" ".join(entries) + "\n")

with open(f"dev-0/out-embed-{embed_size}.tsv", mode="w") as out_file:
    out_file.writelines(output)
|
|
|
|
|
|
# Inference only: leave training mode; no_grad below also skips building an
# autograd graph for every prediction.
model.eval()

output = []
print('Predicting test...')
with torch.no_grad(), lzma.open(
    "test-A/in.tsv.xz", encoding='utf8', mode="rt"
) as in_file:
    for line in tqdm(in_file):
        fields = line.split("\t")

        # Last word of the left context (second-to-last column); an empty
        # context yields "" which is then treated as unknown.
        left_words = re.sub(r"\\\\+n", " ", fields[-2]).split()
        first_word = re.sub('[^A-Za-z]+', '', left_words[-1] if left_words else "")

        # First word of the right context (last column).
        right_words = re.sub(r"\\\\+n", " ", fields[-1]).split()
        next_word = re.sub('[^A-Za-z]+', '', right_words[0] if right_words else "")

        if first_word not in vocab_unique:
            first_word = "<unk>"
        if next_word not in vocab_unique:
            next_word = "<unk>"

        first_word = torch.tensor(train_dataset.vocab.forward([first_word])).to(device)
        next_word = torch.tensor(train_dataset.vocab.forward([next_word])).to(device)

        out = model(first_word, next_word)

        top = torch.topk(out[0], 10)
        top_indices = top.indices.tolist()
        top_probs = top.values.tolist()
        # Decode with the vocabulary that produced the model's indices.
        top_words = train_dataset.vocab.lookup_tokens(top_indices)

        entries = []
        known_mass = 0.0
        for w, p in zip(top_words, top_probs):
            if w == "<unk>":
                continue
            entries.append(f"{w}:{p:.4f}")
            known_mass += p
        # Everything not assigned to a named word (including "<unk>" and the
        # words outside the top 10) goes to the trailing wildcard entry.
        rest = max(0.0, 1.0 - known_mass)
        entries.append(f":{rest:.4f}")
        output.append(" ".join(entries) + "\n")

with open(f"test-A/out-embed-{embed_size}.tsv", mode="w") as out_file:
    out_file.writelines(output)
|