# Trigram neural language model for the word-gap prediction task.
from torchtext.vocab import build_vocab_from_iterator
import pickle
from torch.utils.data import IterableDataset
from itertools import chain
from torch import nn
import torch.nn.functional as F
import torch
import lzma
from torch.utils.data import DataLoader
import shutil
torch.manual_seed(1)
def simple_preprocess(line):
    """Replace every literal two-character '\\n' escape in *line* with a space.

    The pattern is the backslash-n sequence embedded in the corpus text, not
    an actual newline character — real newlines pass through unchanged.
    """
    return line.replace('\\n', ' ')
def get_words_from_line(line):
    """Yield '<s>', then each whitespace-separated token of *line* (with
    literal '\\n' escapes collapsed to spaces), then '</s>'."""
    cleaned = line.strip().replace('\\n', ' ')
    yield '<s>'
    yield from cleaned.split()
    yield '</s>'
def get_word_lines_from_file(file_name, n_size=-1):
    """Yield one token generator per line of the xz-compressed *file_name*.

    Stops after *n_size* lines when n_size is positive; the default of -1
    never matches the 1-based counter, so the whole file is consumed.
    """
    with lzma.open(file_name, 'r') as fh:
        for count, raw in enumerate(fh, start=1):
            yield get_words_from_line(raw.decode('utf-8'))
            if count == n_size:
                break
def look_ahead_iterator(gen):
    """Slide a window of three consecutive items over *gen*.

    For each full window (a, b, c) yields the tuple (b, c, a): the two
    trailing items serve as model inputs and the leading item is the target
    (the model predicts a word from the two words that FOLLOW it).
    Yields nothing when *gen* has fewer than three items.
    """
    window = []
    for token in gen:
        window.append(token)
        if len(window) > 3:
            window.pop(0)
        if len(window) == 3:
            yield window[1], window[2], window[0]
def build_vocab(file, vocab_size):
    """Load a cached vocabulary for *vocab_size*, or build and cache one.

    The vocabulary is built from *file* (an xz-compressed text file) and
    pickled to vocab_{vocab_size}.pickle so later runs skip the corpus pass.

    NOTE(review): pickle.load executes code from the file — only load cache
    files this script produced itself.
    """
    cache_path = f'vocab_{vocab_size}.pickle'
    try:
        with open(cache_path, 'rb') as handle:
            vocab = pickle.load(handle)
    except (FileNotFoundError, pickle.UnpicklingError, EOFError):
        # Cache is missing or corrupt: rebuild from the corpus.  The original
        # bare `except:` also swallowed KeyboardInterrupt and SystemExit.
        vocab = build_vocab_from_iterator(
            get_word_lines_from_file(file),
            max_tokens=vocab_size,
            specials=['<unk>'])
        with open(cache_path, 'wb') as handle:
            pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return vocab
class Trigrams(IterableDataset):
    """Iterable dataset of (input1, input2, target) trigram index triples
    streamed from an xz-compressed text file."""

    def __init__(self, text_file, vocabulary=None):
        # The original silently read the module-level `vocab` global while
        # only accepting `text_file`.  The optional `vocabulary` parameter
        # removes that hidden dependency; omitting it keeps the old behavior
        # so existing callers are unaffected.
        self.vocab = vocabulary if vocabulary is not None else vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file

    def __iter__(self):
        # Flatten the per-line token generators, map tokens to indices,
        # then emit sliding-window triples.
        token_ids = (self.vocab[t]
                     for t in chain.from_iterable(
                         get_word_lines_from_file(self.text_file)))
        return look_ahead_iterator(token_ids)
class TrigramNeuralLanguageModel(nn.Module):
    """Feed-forward trigram LM: two context word-id tensors in, a probability
    distribution over the vocabulary out (rows sum to 1)."""

    def __init__(self, vocab_size, embed_size):
        super(TrigramNeuralLanguageModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        self.hidden_layer = nn.Linear(2 * embed_size, 64)
        self.output_layer = nn.Linear(64, vocab_size)

    def forward(self, x):
        # x is a pair of LongTensors of word indices, one per context slot.
        first = self.embeddings(x[0])
        second = self.embeddings(x[1])
        hidden = F.relu(self.hidden_layer(torch.concat((first, second), dim=1)))
        # Return plain probabilities; the training loop applies torch.log
        # itself before NLLLoss, so this must NOT return log-probabilities.
        return F.softmax(self.output_layer(hidden), dim=1)
# --- training -------------------------------------------------------------
max_steps = -1          # -1 never matches the step counter -> full epoch
vocab_size = 5000
embed_size = 50
batch_size = 5000
vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)
train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz')
if torch.cuda.is_available():
    device = 'cuda'
else:
    # CPU training on this corpus would be impractically slow; fail loudly.
    # RuntimeError subclasses Exception, so existing handlers still match.
    raise RuntimeError('CUDA device required for training')
model = TrigramNeuralLanguageModel(vocab_size, embed_size).to(device)
data = DataLoader(train_dataset, batch_size=batch_size)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()
model.train()
step = 0
for x1, x2, y in data:
    x = x1.to(device), x2.to(device)
    y = y.to(device)
    optimizer.zero_grad()
    ypredicted = model(x)
    # The model outputs probabilities, so take the log before NLLLoss.
    loss = criterion(torch.log(ypredicted), y)
    # Merged the two identical `step % 1000` checks from the original;
    # .item() prints a plain float instead of a tensor repr.
    if step % 1000 == 0:
        print(step, loss.item())
        torch.save(model.state_dict(),
                   f'model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}.bin')
    loss.backward()
    optimizer.step()
    if step == max_steps:
        break
    step += 1
# --- inference ------------------------------------------------------------
vocab_size = 5000
embed_size = 50
batch_size = 5000
vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)
vocab.set_default_index(vocab['<unk>'])
# NOTE(review): the first two checkpoint names are identical — presumably one
# was meant to be a different step count; confirm before a long run.
for model_name in ['model_steps-1000_vocab-5000_embed-50_batch-5000.bin',
                   'model_steps-1000_vocab-5000_embed-50_batch-5000.bin',
                   'model_steps-27000_vocab-5000_embed-50_batch-5000.bin']:
    preds = []
    device = 'cuda'
    model = TrigramNeuralLanguageModel(vocab_size, embed_size).to(device)
    # map_location keeps the load working regardless of the device layout
    # the checkpoint was saved from.
    model.load_state_dict(torch.load(model_name, map_location=device))
    model.eval()
    j = 0
    for path in ['challenging-america-word-gap-prediction/dev-0',
                 'challenging-america-word-gap-prediction/test-A']:
        with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, \
                open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:
            for line in fh:
                # First two words of the right context (last TSV column).
                right_context = simple_preprocess(
                    line.decode('utf-8').split('\t')[-1]).split()[:2]
                # Guard against lines with fewer than two right-context words
                # (the original raised IndexError on them).
                while len(right_context) < 2:
                    right_context.append('</s>')
                # no_grad: inference only, skip autograd bookkeeping.
                with torch.no_grad():
                    x = (torch.tensor(vocab.forward([right_context[0]])).to(device),
                         torch.tensor(vocab.forward([right_context[1]])).to(device))
                    out = model(x)
                top = torch.topk(out[0], 5)
                top_words = vocab.lookup_tokens(top.indices.tolist())
                top_zipped = list(zip(top_words, top.values.tolist()))
                # Split out the <unk> entry instead of popping from the list
                # while iterating it (the original anti-pattern).
                unk_prob = None
                kept = []
                for word, prob in top_zipped:
                    if word == '<unk>':
                        unk_prob = prob
                    else:
                        kept.append((word, prob))
                # Output format: "word:prob\t..." plus ":unk_prob" for the
                # unseen-word mass, or rstrip the trailing tab if no <unk>.
                pred = ''
                for word, prob in kept:
                    pred += f'{word}:{prob}\t'
                if unk_prob is not None:
                    pred += f':{unk_prob}'
                else:
                    pred = pred.rstrip()
                f_out.write(pred + '\n')
                if j % 1000 == 0:
                    print(pred)
                j += 1
        # Copy inside the path loop so BOTH splits' predictions are preserved
        # per checkpoint (otherwise dev-0/out.tsv is overwritten by the next
        # model before ever being renamed).
        src = f'{path}/out.tsv'
        dst = f"{path}/{model_name.split('.')[0]}_out.tsv"
        shutil.copy(src, dst)