8.4 KiB
8.4 KiB
import collections
import itertools
import lzma

import numpy as np
import regex as re
import torch
from torch import nn
from torch.utils.data import IterableDataset, DataLoader
from torchtext.vocab import build_vocab_from_iterator
from google.colab import drive
# Colab-only setup: mount Google Drive and change into the project folder so
# the relative paths used later in this file ('train/in.tsv.xz', 'dev-0',
# 'test-A', 'model1.bin') resolve inside it.
# NOTE(review): `%cd` is IPython magic, so this file only runs inside a
# notebook / IPython kernel, not as a plain `python` script.
drive.mount('/content/drive')
%cd /content/drive/MyDrive/america
def get_line(line: str):
    """Return the 7th and 8th tab-separated fields of ``line`` joined by a space.

    In both fields every literal two-character sequence backslash + 'n' is
    replaced with a single space before joining. The 8th field keeps any
    trailing newline that came with ``line``.
    """
    fields = line.split('\t')
    escaped_newline = r'\n'
    left_context = fields[6].replace(escaped_newline, ' ')
    right_context = fields[7].replace(escaped_newline, ' ')
    return ' '.join((left_context, right_context))
def read_words(line):
    """Lazily yield the whitespace-separated words of a row's context text.

    The row is first reduced to its joined context fields by ``get_line``;
    nothing is parsed until the generator is first advanced.
    """
    yield from get_line(line).split()
def get_words_from_file(path):
    """Stream an xz-compressed UTF-8 text file, yielding one word generator per line.

    Each yielded item is the (unstarted) ``read_words`` generator for one
    line, so the result is an iterator of token iterators. The file is held
    open inside this generator until it is exhausted or closed.
    """
    with lzma.open(path, mode='rt', encoding='utf-8') as handle:
        yield from map(read_words, handle)
class SimpleTrigramNeuralLanguageModel(nn.Module):
    """Feed-forward trigram language model.

    Embeds two context word indices, concatenates the two embeddings, applies
    one ReLU hidden layer and returns a probability distribution over the
    vocabulary for the next word.
    """

    def __init__(self, vocabulary_size, embedding_size, hidden_size):
        super().__init__()
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.lin1 = nn.Linear(2 * embedding_size, hidden_size)
        self.rel = nn.ReLU()
        self.lin2 = nn.Linear(hidden_size, vocabulary_size)
        # Normalise over the vocabulary axis explicitly: nn.Softmax() without
        # `dim` is deprecated and warns on every forward call. forward() always
        # produces a 2-D (contexts, vocabulary) tensor, so dim=1 is the axis the
        # implicit choice picked anyway.
        self.sm = nn.Softmax(dim=1)

    def forward(self, x):
        """Return next-word probabilities for each pair of context indices.

        x: integer word indices. After embedding, the result is reshaped so that
        every consecutive pair of indices (in flattened order) forms one
        context, e.g. a (batch, 2) tensor or a flat tensor of 2 indices.
        Returns a (num_contexts, vocabulary_size) tensor whose rows sum to 1.
        """
        x = self.embedding(x).view((-1, 2 * self.embedding_size))
        x = self.lin1(x)
        x = self.rel(x)
        x = self.lin2(x)
        return self.sm(x)
def get_context(gen):
    """Yield every run of three consecutive items from ``gen`` as a numpy array.

    Uses a bounded sliding window, so memory stays constant however long
    ``gen`` is and the first trigram is available immediately. This matters
    because ``gen`` can be the token stream of an entire training file, and
    collecting it into a list first would hold the whole corpus in RAM.
    Inputs with fewer than three items yield nothing.
    """
    window = collections.deque(maxlen=3)
    for item in gen:
        window.append(item)
        if len(window) == 3:
            yield np.asarray(list(window))
class Trigrams(IterableDataset):
    """Iterable dataset of consecutive word-index trigrams read from a text file.

    Construction builds a torchtext vocabulary of at most ``vocabulary_size``
    tokens (including the ``'<unk>'`` special) from ``text_file``; unknown
    words map to the ``'<unk>'`` index. Every iteration re-reads the file and
    yields 3-element index arrays produced by ``get_context``.
    """

    def __init__(self, text_file, vocabulary_size):
        token_vocab = build_vocab_from_iterator(
            get_words_from_file(text_file),
            specials=['<unk>'],
            max_tokens=vocabulary_size,
        )
        token_vocab.set_default_index(token_vocab['<unk>'])
        self.vocab = token_vocab
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        # Flatten per-line word generators into a single token stream.
        tokens = itertools.chain.from_iterable(get_words_from_file(self.text_file))
        indices = (self.vocab[token] for token in tokens)
        return get_context(indices)
def train_model(lr):
    """Train a fresh trigram model for one pass over ``train_dataset``.

    Reads the module-level settings ``vocab_size``, ``embed_size``,
    ``hidden_size``, ``batch_size``, ``device``, ``train_dataset`` and
    ``model_path``. Optimises with Adam at learning rate ``lr`` under NLL
    loss, clips gradient norm to 10, prints the loss every 100 batches and
    saves the final ``state_dict`` to ``model_path``.

    Returns the trained model (callers that ignore the result are unaffected).
    """
    model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
    data = DataLoader(train_dataset, batch_size=batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.NLLLoss()
    model.train()
    for step, batch in enumerate(data):
        # Each row is one trigram: two context indices, then the target index.
        x = batch[:, :2].to(device)
        y = batch[:, 2].to(device)
        optimizer.zero_grad()
        probs = model(x)
        # The model returns probabilities, not log-probabilities. Clamp before
        # taking the log so a probability that underflowed to exactly 0 cannot
        # turn into -inf and poison the loss/gradients with inf/NaN.
        log_probs = torch.log(probs.clamp_min(torch.finfo(probs.dtype).tiny))
        loss = criterion(log_probs, y)
        if step % 100 == 0:
            print(step, loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
    torch.save(model.state_dict(), model_path)
    return model
def prediction(words, model, top) -> str:
    """Format the model's ``top`` most probable next words after ``words``.

    words: context words. Only the last two are used. If fewer than two are
        given, the missing leading positions are filled with '<unk>'.
        Previously a 0- or 1-word prefix crashed the model's 2-word reshape.
    model: expected to return next-word probabilities, as
        SimpleTrigramNeuralLanguageModel does.
    top: number of candidates taken from the distribution.

    Returns space-separated ``word:prob`` pairs. If '<unk>' is among the
    candidates it is removed and its probability is emitted last under the
    empty word; otherwise the last candidate's word is replaced by the empty
    word (its probability is kept).
    """
    # Encode the input and decode the output with the SAME vocabulary the
    # model was trained on, so indices always mean the same tokens.
    token_vocab = train_dataset.vocab
    context = list(words)[-2:]
    # NOTE(review): '<unk>' padding is a neutral filler choice; the model never
    # saw explicit start-of-text padding during training.
    context = ['<unk>'] * (2 - len(context)) + context
    ixs = torch.tensor(token_vocab.forward(context)).to(device)
    # Inference only: skip building an autograd graph for every row.
    with torch.no_grad():
        out = model(ixs)
    top_values, top_indices = torch.topk(out[0], top)
    top_probs = top_values.tolist()
    top_words = token_vocab.lookup_tokens(top_indices.tolist())
    if '<unk>' in top_words:
        unk_index = top_words.index('<unk>')
        unk_prob = top_probs.pop(unk_index)
        top_words.pop(unk_index)
        top_words.append('')
        top_probs.append(unk_prob)
    else:
        top_words[-1] = ''
    return ' '.join(f'{word}:{prob}' for word, prob in zip(top_words, top_probs))
def save_outputs(folder_name, model, top):
    """Write one prediction line per row of ``{folder_name}/in.tsv.xz``.

    The context for each row is the last (up to) two words of its 7th
    tab-separated field, after literal backslash-n sequences are turned into
    spaces. Lines are written to ``{folder_name}/out-top={top}.tsv``.
    """
    source_path = f'{folder_name}/in.tsv.xz'
    target_path = f'{folder_name}/out-top={top}.tsv'
    with lzma.open(source_path, mode='rt', encoding='utf-8') as source, \
            open(target_path, 'w', encoding='utf-8', newline='\n') as target:
        for row in source:
            left_context = row.split('\t')[6].replace(r'\n', ' ')
            last_two = left_context.split()[-2:]
            target.write(prediction(last_two, model, top) + '\n')
# Hyperparameters and paths (relative to the project folder).
vocab_size = 15000
embed_size = 200
hidden_size = 100
batch_size = 3000
learning_rate = 0.0001
# Fall back to CPU so the script still runs when no GPU is attached.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_path = 'train/in.tsv.xz'
model_path = 'model1.bin'

# Trigrams builds the vocabulary from train_path itself; reuse that object
# instead of streaming and counting the whole compressed training file a
# second time. The module-level `vocab` name is kept because prediction()
# in the original code reads it.
train_dataset = Trigrams(train_path, vocab_size)
vocab = train_dataset.vocab

train_model(lr=learning_rate)

# Reload the saved weights and generate outputs for both evaluation splits.
model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
for top in [100, 200, 300]:
    save_outputs('dev-0', model, top)
    save_outputs('test-A', model, top)