import itertools
import lzma
import pickle
from os.path import exists

import regex as re
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm


def get_words_from_line(line):
    """Tokenize a line into lowercased words and punctuation runs,
    wrapped in sentence-boundary markers."""
    line = line.rstrip()
    yield "<s>"
    for m in re.finditer(r"[\p{L}0-9\*]+|\p{P}+", line):
        yield m.group(0).lower()
    yield "</s>"


def get_word_lines_from_file(file_name):
    """Yield a token generator for every line of an xz-compressed file."""
    with lzma.open(file_name, "r") as fh:
        for line in fh:
            yield get_words_from_line(line.decode("utf-8"))


def look_ahead_iterator(gen):
    """Turn a flat token stream into a stream of consecutive trigrams."""
    first = None
    second = None
    for item in gen:
        if first is not None and second is not None:
            yield (first, second, item)
        first = second
        second = item


class Trigrams(IterableDataset):
    def __init__(self, text_file, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=["<unk>"],
        )
        self.vocab.set_default_index(self.vocab["<unk>"])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        return look_ahead_iterator(
            (
                self.vocab[t]
                for t in itertools.chain.from_iterable(
                    get_word_lines_from_file(self.text_file)
                )
            )
        )


class TrigramModel(nn.Module):
    """Predict the middle word of a trigram from its two neighbours."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(TrigramModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, y):
        # Sum the embeddings of the left and right context words.
        x = self.embeddings(x)
        y = self.embeddings(y)
        z = self.linear1(x + y)
        z = self.linear2(z)
        z = self.softmax(z)
        return z


vocab_size = 20000
vocab_path = "vocabulary.pickle"

# Build the vocabulary once and cache it on disk.
if exists(vocab_path):
    with open(vocab_path, "rb") as fh:
        vocab = pickle.load(fh)
else:
    vocab = build_vocab_from_iterator(
        get_word_lines_from_file("train/in.tsv.xz"),
        max_tokens=vocab_size,
        specials=["<unk>"],
    )
    with open(vocab_path, "wb") as fh:
        pickle.dump(vocab, fh)

device = "cpu"
train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
model = TrigramModel(vocab_size, 100, 64).to(device)
data = DataLoader(train_dataset, batch_size=5000)
optimizer = torch.optim.Adam(model.parameters())
# log(softmax) + NLLLoss is equivalent to cross-entropy over the vocabulary.
criterion = torch.nn.NLLLoss()

model.train()
losses = []
for epoch in tqdm(range(10)):
    for x, y, z in tqdm(data):
        x = x.to(device)
        y = y.to(device)
        z = z.to(device)
        optimizer.zero_grad()
        # Predict the middle word (y) from the outer words (x, z).
        ypredicted = model(x, z)
        loss = criterion(torch.log(ypredicted), y)
        # Store a plain float rather than the loss tensor, so the
        # computation graph is not kept alive across iterations.
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} loss:", loss.item())

plt.plot(losses)
plt.show()

torch.save(model.state_dict(), "model1.bin")
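
# --- Usage sketch (an assumption, not part of the original script) ---
# A minimal example of querying the trained model for likely middle words
# given a left and right context word. The tokens "went" and "to" are
# arbitrary placeholders; train_dataset.vocab is reused so that the
# indices match the ones seen during training.
model.eval()
with torch.no_grad():
    train_vocab = train_dataset.vocab
    left = torch.tensor([train_vocab["went"]], device=device)
    right = torch.tensor([train_vocab["to"]], device=device)
    probs = model(left, right)  # shape: (1, vocab_size)
    top = torch.topk(probs[0], 10)
    itos = train_vocab.get_itos()
    for prob, idx in zip(top.values, top.indices):
        print(itos[idx.item()], prob.item())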