import itertools
import lzma
import pickle
from os.path import exists

import regex as re
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm


def get_words_from_line(line):
    """Yield the tokens of one line: sentence markers plus lower-cased words and punctuation."""
    line = line.rstrip()
    yield "<s>"
    # A token is either a run of letters/digits/asterisks or a run of punctuation.
    for m in re.finditer(r"[\p{L}0-9\*]+|\p{P}+", line):
        yield m.group(0).lower()
    yield "</s>"


def get_word_lines_from_file(file_name):
    """Yield a token generator for every line of an xz-compressed text file."""
    with lzma.open(file_name, "r") as fh:
        for line in fh:
            yield get_words_from_line(line.decode("utf-8"))


def look_ahead_iterator(gen):
    """Slide a window of length 3 over a token stream, yielding consecutive (first, middle, last) triples."""
    first = None
    second = None
    for item in gen:
        if first is not None and second is not None:
            yield (first, second, item)
        first = second
        second = item


class Trigrams(IterableDataset):
    """Iterable dataset that streams (first, middle, last) word-index trigrams from an xz-compressed corpus."""

    def __init__(self, text_file, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=["<unk>"],
        )
        self.vocab.set_default_index(self.vocab["<unk>"])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        # Map every token in the file to its vocabulary index and stream trigrams.
        return look_ahead_iterator(
            (
                self.vocab[t]
                for t in itertools.chain.from_iterable(
                    get_word_lines_from_file(self.text_file)
                )
            )
        )


class TrigramModel(nn.Module):
    """Embed the two outer words of a trigram, sum the embeddings, and output a
    probability distribution over the vocabulary for the middle word."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, y):
        x = self.embeddings(x)
        y = self.embeddings(y)
        z = self.linear1(x + y)
        z = self.linear2(z)
        z = self.softmax(z)
        return z


vocab_size = 20000
vocab_path = "vocabulary.pickle"

# Load a cached vocabulary if one exists, otherwise build it from the training data
# and cache it. (The Trigrams dataset below builds its own copy from the same file.)
if exists(vocab_path):
    with open(vocab_path, "rb") as fh:
        vocab = pickle.load(fh)
else:
    vocab = build_vocab_from_iterator(
        get_word_lines_from_file("train/in.tsv.xz"),
        max_tokens=vocab_size,
        specials=["<unk>"],
    )

    with open(vocab_path, "wb") as fh:
        pickle.dump(vocab, fh)


device = "cpu"
train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
model = TrigramModel(vocab_size, 100, 64).to(device)
data = DataLoader(train_dataset, batch_size=5000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

model.train()
losses = []
for epoch in tqdm(range(10)):
    for x, y, z in tqdm(data):
        # Each batch holds the first, middle and last word indices of the trigrams;
        # the middle word y is predicted from the outer words x and z.
        x = x.to(device)
        y = y.to(device)
        z = z.to(device)

        optimizer.zero_grad()
        ypredicted = model(x, z)
        # The model returns probabilities, so take the log before applying NLLLoss.
        loss = criterion(torch.log(ypredicted), y)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} loss:", loss.item())


plt.plot(losses)
torch.save(model.state_dict(), "model1.bin")
plt.show()  # display the training-loss curve
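
# A minimal usage sketch (not part of the original script): load the saved
# weights and rank candidate middle words for a pair of context words. The
# context words "in" / "world" are hypothetical examples; the sketch assumes
# the same vocabulary and hyperparameters as above.
#
# model = TrigramModel(vocab_size, 100, 64)
# model.load_state_dict(torch.load("model1.bin"))
# model.eval()
# with torch.no_grad():
#     left = torch.tensor([vocab["in"]])
#     right = torch.tensor([vocab["world"]])
#     probs = model(left, right)[0]  # probability distribution over the vocabulary
#     best = torch.topk(probs, 10)
#     print(list(zip(vocab.lookup_tokens(best.indices.tolist()), best.values.tolist())))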