challenging-america-word-ga.../run.py
Jakub Kaczmarek a7ec11ca27 434624
2023-05-10 00:46:03 +02:00

120 lines
3.2 KiB
Python

from itertools import islice
import sys
import lzma
import regex as re
from torchtext.vocab import build_vocab_from_iterator
from torch import nn
import pickle
from os.path import exists
from torch.utils.data import IterableDataset
import itertools
from torch.utils.data import DataLoader
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
def get_words_from_line(line):
    """Tokenize one line of text, wrapped in sentence boundary markers.

    Yields the literal "<s>" first, then every match of the token pattern
    in order of appearance (lower-cased), and finally the literal "</s>".
    A token is either a run of letters, digits and "*" characters, or a run
    of punctuation characters; anything else (e.g. whitespace) is skipped.

    The pattern relies on the \\p{...} property classes of the `regex`
    package, which this file imports under the name `re`.
    """
    stripped = line.rstrip()
    yield "<s>"
    yield from (
        match.group(0).lower()
        for match in re.finditer(r"[\p{L}0-9\*]+|\p{P}+", stripped)
    )
    yield "</s>"
def get_word_lines_from_file(file_name):
    """Yield one token generator per line of an xz-compressed text file.

    The file is opened in binary mode and each raw line is decoded as UTF-8
    before being handed to get_words_from_line. The file stays open until
    this generator is exhausted or closed.
    """
    with lzma.open(file_name, "r") as compressed:
        yield from (
            get_words_from_line(raw_line.decode("utf-8"))
            for raw_line in compressed
        )
def look_ahead_iterator(gen):
    """Yield every window of three consecutive items from an iterable.

    For the input a, b, c, d the generator produces (a, b, c) and (b, c, d).
    Inputs with fewer than three items produce nothing.

    The first two items are pulled explicitly instead of using None as a
    "slot not filled yet" marker, so streams that legitimately contain None
    are windowed correctly.

    Args:
        gen: any iterable; it is consumed lazily, one item per yielded triple.

    Yields:
        Tuples (first, second, third) of consecutive items.
    """
    stream = iter(gen)
    try:
        first = next(stream)
        second = next(stream)
    except StopIteration:
        # Fewer than two items: no triple can be formed. Returning (rather
        # than letting StopIteration escape) ends the generator cleanly.
        return
    for third in stream:
        yield (first, second, third)
        first, second = second, third
class Trigrams(IterableDataset):
    """Iterable dataset of consecutive token-id triples read from a text file.

    The vocabulary is built from the same file at construction time, and the
    file is re-read from the start on every iteration.
    """

    def __init__(self, text_file, vocabulary_size):
        """Build the vocabulary from text_file and store the settings.

        Args:
            text_file: path of the xz-compressed corpus, as accepted by
                get_word_lines_from_file.
            vocabulary_size: passed to build_vocab_from_iterator as
                max_tokens.
        """
        vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=["<unk>"],
        )
        # Lookups of tokens missing from the vocabulary fall back to "<unk>".
        vocab.set_default_index(vocab["<unk>"])
        self.vocab = vocab
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        """Return a generator of (id, id, id) triples over the whole file.

        All lines are flattened into a single token stream, so triples also
        span line boundaries.
        """
        tokens = itertools.chain.from_iterable(
            get_word_lines_from_file(self.text_file)
        )
        token_ids = (self.vocab[token] for token in tokens)
        return look_ahead_iterator(token_ids)
class TrigramModel(nn.Module):
    """Predict a word distribution from the embeddings of two context words.

    The two context embeddings (shared embedding table) are summed, passed
    through two linear layers and normalised with a softmax over the
    vocabulary axis.

    NOTE(review): there is no non-linearity between linear1 and linear2, so
    the two layers compose into a single (low-rank) linear map. Left as is
    to keep the trained behaviour and saved checkpoints unchanged.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """Create the layers.

        Args:
            vocab_size: number of rows in the embedding table and number of
                output classes.
            embedding_dim: size of each word embedding.
            hidden_dim: output size of the first linear layer.
        """
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)
        # Normalise explicitly over the last (vocabulary) axis. Omitting
        # `dim` is deprecated, warns on every call, and picks axis 0 for
        # 3-D activations, which would normalise across the batch.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        """Return probabilities over the vocabulary for each context pair.

        Args:
            x: integer tensor of token ids for the first context word.
            y: integer tensor of token ids for the second context word,
                same shape as x.

        Returns:
            Tensor of shape x.shape + (vocab_size,) whose last axis sums
            to 1.
        """
        context = self.embeddings(x) + self.embeddings(y)
        hidden = self.linear1(context)
        logits = self.linear2(hidden)
        return self.softmax(logits)
vocab_size = 20000
vocab_path = "vocabulary.pickle"

# Build the vocabulary once and cache it on disk; later runs reuse the pickle.
# NOTE(review): `vocab` is not read again in the visible part of this script
# (Trigrams builds its own vocabulary from the same file) - presumably the
# pickle is consumed elsewhere; confirm before removing.
if exists(vocab_path):
    # NOTE: unpickling executes code embedded in the file. Only load a cache
    # that this script wrote itself; never a file from an untrusted source.
    with open(vocab_path, "rb") as fh:
        vocab = pickle.load(fh)
else:
    vocab = build_vocab_from_iterator(
        get_word_lines_from_file("train/in.tsv.xz"),
        max_tokens=vocab_size,
        specials=["<unk>"],
    )
    with open(vocab_path, "wb") as fh:
        pickle.dump(vocab, fh)

device = "cpu"
train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
model = TrigramModel(vocab_size, 100, 64).to(device)
data = DataLoader(train_dataset, batch_size=5000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

model.train()
losses = []  # one Python float per training step, for plotting
for epoch in tqdm(range(10)):
    # Each batch holds (left, middle, right) token ids of consecutive words;
    # the model sees the two outer words and is trained on the middle one.
    for x, y, z in tqdm(data):
        x = x.to(device)
        y = y.to(device)
        z = z.to(device)
        optimizer.zero_grad()
        ypredicted = model(x, z)
        # The model returns probabilities while NLLLoss expects
        # log-probabilities. Clamp before the log so a probability that
        # underflowed to 0 cannot produce an infinite loss / NaN gradients.
        smallest = torch.finfo(ypredicted.dtype).tiny
        loss = criterion(torch.log(ypredicted.clamp_min(smallest)), y)
        # Store a plain float: keeping the tensor would retain autograd
        # state for every step and cannot be plotted by matplotlib.
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    # Loss of the last batch of the epoch.
    print(f"Epoch {epoch} loss:", loss.item())

# NOTE(review): the figure is neither shown nor saved in the visible code.
plt.plot(losses)
torch.save(model.state_dict(), "model1.bin")