import itertools
import lzma
import pickle
from os.path import exists

import regex as re
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm


def get_words_from_line(line):
    line = line.rstrip()
    line = line.split("\t")
    # The last two fields are the left and right text contexts.
    text = line[-2] + " " + line[-1]
    text = re.sub(r"\\\\+n", " ", text)  # drop escaped newline markers left in the raw text
    text = re.sub("[^A-Za-z ]+", "", text)  # keep letters and spaces only
    for t in text.split():
        yield t


def get_word_lines_from_file(file_name):
    with lzma.open(file_name, "r") as fh:
        for line in fh:
            yield get_words_from_line(line.decode("utf-8"))


def look_ahead_iterator(gen):
    # Turn a flat stream of word ids into (left, middle, right) trigrams.
    first = None
    second = None
    for item in gen:
        if first is not None and second is not None:
            yield (first, second, item)
        first = second
        second = item


class Trigrams(IterableDataset):
    def __init__(self, text_file, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=["<unk>"],
        )
        self.vocab.set_default_index(self.vocab["<unk>"])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        return look_ahead_iterator(
            (
                self.vocab[t]
                for t in itertools.chain.from_iterable(
                    get_word_lines_from_file(self.text_file)
                )
            )
        )


class TrigramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.hidden = nn.Linear(embedding_dim * 2, hidden_dim)
        self.output = nn.Linear(hidden_dim, vocab_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, y):
        # Predict the middle word from the embeddings of its two neighbours.
        x = self.embeddings(x)
        y = self.embeddings(y)
        z = self.hidden(torch.cat([x, y], dim=1))
        z = self.output(z)
        z = self.softmax(z)
        return z


embed_size = 500
vocab_size = 20000

vocab_path = "vocabulary.pickle"
if exists(vocab_path):
    print("Loading vocabulary from file...")
    with open(vocab_path, "rb") as fh:
        vocab = pickle.load(fh)
else:
    print("Building vocabulary...")
    vocab = build_vocab_from_iterator(
        get_word_lines_from_file("train/in.tsv.xz"),
        max_tokens=vocab_size,
        specials=["<unk>"],
    )
    with open(vocab_path, "wb") as fh:
        pickle.dump(vocab, fh)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

dataset_path = "train/dataset.pickle"
if exists(dataset_path):
    print("Loading dataset from file...")
    with open(dataset_path, "rb") as fh:
        train_dataset = pickle.load(fh)
else:
    print("Building dataset...")
    train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
    with open(dataset_path, "wb") as fh:
        pickle.dump(train_dataset, fh)

print("Building model...")
model = TrigramModel(vocab_size, embed_size, 64).to(device)
data = DataLoader(train_dataset, batch_size=10000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

print("Training model...")
model.train()
losses = []
step = 0
max_steps = 1000
for x, y, z in tqdm(data):
    x = x.to(device)
    y = y.to(device)
    z = z.to(device)
    optimizer.zero_grad()
    # Gap filling: given the outer words (x, z), predict the middle word y.
    ypredicted = model(x, z)
    loss = criterion(torch.log(ypredicted), y)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    step += 1
    if step > max_steps:
        break

plt.plot(losses)
plt.show()

torch.save(model.state_dict(), f"trigram_model-embed_{embed_size}.bin")

vocab_unique = set(train_dataset.vocab.get_stoi().keys())
model.eval()


def predict(input_path, output_path):
    # Run the trained model over one split and write the top-10 guesses per line.
    output = []
    with lzma.open(input_path, encoding="utf8", mode="rt") as file:
        for line in tqdm(file):
            line = line.split("\t")
            # Last word of the left context, first word of the right context.
            first_word = re.sub(r"\\\\+n", " ", line[-2]).split()[-1]
            first_word = re.sub("[^A-Za-z]+", "", first_word)
            next_word = re.sub(r"\\\\+n", " ", line[-1]).split()[0]
            next_word = re.sub("[^A-Za-z]+", "", next_word)
            if first_word not in vocab_unique:
                first_word = "<unk>"
            if next_word not in vocab_unique:
                next_word = "<unk>"
            first_word = torch.tensor(train_dataset.vocab.forward([first_word])).to(device)
            next_word = torch.tensor(train_dataset.vocab.forward([next_word])).to(device)
            out = model(first_word, next_word)
            top = torch.topk(out[0], 10)
            top_indices = top.indices.tolist()
            top_probs = top.values.tolist()
            # Probability mass not covered by the top 10 goes to the catch-all guess.
            unk_bonus = 1 - sum(top_probs)
            top_words = vocab.lookup_tokens(top_indices)
            res = ""
            for w, p in zip(top_words, top_probs):
                if w == "<unk>":
                    res += f":{(p + unk_bonus):.4f} "
                else:
                    res += f"{w}:{p:.4f} "
            res = res[:-1] + "\n"
            output.append(res)
    with open(output_path, mode="w") as file:
        file.writelines(output)


print("Predicting dev...")
predict("dev-0/in.tsv.xz", f"dev-0/out-embed-{embed_size}.tsv")

print("Predicting test...")
predict("test-A/in.tsv.xz", f"test-A/out-embed-{embed_size}.tsv")