# challenging-america-word-ga.../run.py
import itertools
import lzma
import pickle
from os.path import exists

import regex as re
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm

def get_words_from_line(line):
    """Tokenize one TSV line: join the left and right context fields,
    drop escaped newlines, keep only letters, and yield the words."""
    line = line.rstrip()
    fields = line.split("\t")
    # The last two TSV fields hold the text to the left and right of the gap.
    text = fields[-2] + " " + fields[-1]
    # Collapse literal backslash-escaped newline sequences into spaces.
    text = re.sub(r"\\+n", " ", text)
    text = re.sub(r"[^A-Za-z ]+", "", text)
    for t in text.split():
        yield t

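# A minimal illustration of the tokenizer above (made-up line, assuming the
# layout id<TAB>...<TAB>left context<TAB>right context):
#   "1\t...\tIt was a\ndark\tnight and"
#   -> yields "It", "was", "a", "dark", "night", "and"
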
def get_word_lines_from_file(file_name):
    """Yield one generator of words per line of an xz-compressed TSV file."""
    with lzma.open(file_name, "r") as fh:
        for line in fh:
            yield get_words_from_line(line.decode("utf-8"))

def look_ahead_iterator(gen):
    """Slide a window of size three over a token stream, yielding
    (left, middle, right) triples."""
    first = None
    second = None
    for item in gen:
        if first is not None and second is not None:
            yield (first, second, item)
        first = second
        second = item

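# For example, the stream a, b, c, d yields (a, b, c) and then (b, c, d).
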
class Trigrams(IterableDataset):
    """Streams (left, middle, right) token-index triples from a TSV corpus."""

    def __init__(self, text_file, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=["<unk>"],
        )
        # Out-of-vocabulary words map to <unk>.
        self.vocab.set_default_index(self.vocab["<unk>"])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        return look_ahead_iterator(
            (
                self.vocab[t]
                for t in itertools.chain.from_iterable(
                    get_word_lines_from_file(self.text_file)
                )
            )
        )

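# Quick sanity check of the stream (illustrative only; building the dataset
# makes a full pass over the corpus to fit the vocabulary):
#   dataset = Trigrams("train/in.tsv.xz", 20000)
#   print(list(itertools.islice(iter(dataset), 3)))  # three index triples
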
class TrigramModel(nn.Module):
    """Predicts the middle word of a trigram from the two outer words."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.hidden = nn.Linear(embedding_dim * 2, hidden_dim)
        self.output = nn.Linear(hidden_dim, vocab_size)
        # Softmax over the vocabulary dimension of a (batch, vocab) tensor.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, y):
        x = self.embeddings(x)
        y = self.embeddings(y)
        z = self.hidden(torch.cat([x, y], dim=1))
        z = self.output(z)
        z = self.softmax(z)
        return z

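# Shape walk-through for a batch of B context-word pairs:
#   embeddings: (B,) -> (B, embedding_dim), once per input word
#   concat:     (B, 2 * embedding_dim)
#   hidden:     (B, hidden_dim); output: (B, vocab_size)
#   softmax:    (B, vocab_size), each row summing to 1
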
embed_size = 500
vocab_size = 20000
vocab_path = "vocabulary.pickle"
# Cache the vocabulary: building it requires a full pass over the corpus.
if exists(vocab_path):
    print("Loading vocabulary from file...")
    with open(vocab_path, "rb") as fh:
        vocab = pickle.load(fh)
else:
    print("Building vocabulary...")
    vocab = build_vocab_from_iterator(
        get_word_lines_from_file("train/in.tsv.xz"),
        max_tokens=vocab_size,
        specials=["<unk>"],
    )
    with open(vocab_path, "wb") as fh:
        pickle.dump(vocab, fh)

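# This vocabulary is only used below to map predicted indices back to words,
# e.g. (illustrative calls):
#   vocab["the"]             -> an integer index
#   vocab.lookup_tokens([0]) -> ["<unk>"] (specials are inserted first)
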
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
dataset_path = "train/dataset.pickle"
# Cache the dataset object too; this pickles the vocabulary it holds, not the
# token stream, which is re-read from the corpus on every iteration.
if exists(dataset_path):
    print("Loading dataset from file...")
    with open(dataset_path, "rb") as fh:
        train_dataset = pickle.load(fh)
else:
    print("Building dataset...")
    train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
    with open(dataset_path, "wb") as fh:
        pickle.dump(train_dataset, fh)

print("Building model...")
model = TrigramModel(vocab_size, embed_size, 64).to(device)
data = DataLoader(train_dataset, batch_size=10000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()
print("Training model...")
model.train()
losses = []
step = 0
max_steps = 1000
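# Training objective: this is a word-gap model, so each (x, y, z) triple is
# used to predict the middle word y from the outer words x and z, e.g. the
# word observed between "dark" and "night" in the corpus.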
for x, y, z in tqdm(data):
    x = x.to(device)
    y = y.to(device)
    z = z.to(device)
    optimizer.zero_grad()
    # Predict the middle word from the two outer words.
    ypredicted = model(x, z)
    loss = criterion(torch.log(ypredicted), y)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    step += 1
    if step > max_steps:
        break

plt.plot(losses)
plt.show()
torch.save(model.state_dict(), f"trigram_model-embed_{embed_size}.bin")
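# Each output line lists the top-10 guesses as word:probability pairs, with
# the leftover probability mass folded into the <unk> guess, written as a
# bare ":probability" entry. Abridged line with illustrative values:
#   the:0.1021 a:0.0517 his:0.0211 :0.7503
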
vocab_unique = set(train_dataset.vocab.get_stoi().keys())
model.eval()


def predict_file(in_path, out_path):
    """Predict the gap word for every line of `in_path` and write the
    formatted guesses to `out_path`."""
    output = []
    with lzma.open(in_path, encoding="utf8", mode="rt") as file:
        for line in tqdm(file):
            fields = line.split("\t")
            # Last word of the left context, first word of the right context.
            first_word = re.sub(r"\\+n", " ", fields[-2]).split()[-1]
            first_word = re.sub(r"[^A-Za-z]+", "", first_word)
            next_word = re.sub(r"\\+n", " ", fields[-1]).split()[0]
            next_word = re.sub(r"[^A-Za-z]+", "", next_word)
            if first_word not in vocab_unique:
                first_word = "<unk>"
            if next_word not in vocab_unique:
                next_word = "<unk>"
            with torch.no_grad():
                first_idx = torch.tensor(train_dataset.vocab.forward([first_word])).to(device)
                next_idx = torch.tensor(train_dataset.vocab.forward([next_word])).to(device)
                out = model(first_idx, next_idx)
            top = torch.topk(out[0], 10)
            top_indices = top.indices.tolist()
            top_probs = top.values.tolist()
            # Assign the probability mass outside the top 10 to the <unk> guess.
            unk_bonus = 1 - sum(top_probs)
            top_words = vocab.lookup_tokens(top_indices)
            res = ""
            for w, p in zip(top_words, top_probs):
                if w == "<unk>":
                    res += f":{(p + unk_bonus):.4f} "
                else:
                    res += f"{w}:{p:.4f} "
            output.append(res.rstrip() + "\n")
    with open(out_path, mode="w") as file:
        file.writelines(output)


print("Predicting dev...")
predict_file("dev-0/in.tsv.xz", f"dev-0/out-embed-{embed_size}.tsv")
print("Predicting test...")
predict_file("test-A/in.tsv.xz", f"test-A/out-embed-{embed_size}.tsv")