import lzma
import pickle

import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from torchtext.vocab import build_vocab_from_iterator

torch.manual_seed(1)
def simple_preprocess(line):
    # Replace literal '\n' escape sequences left in the TSV text with spaces.
    return line.replace(r'\n', ' ')
def get_max_left_context_len(file_name):
    print('Getting max left context length...')
    max_len = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            line = line.decode('utf-8')
            line = line.strip()
            line = line.split('\t')[-2]
            line = simple_preprocess(line)
            curr_len = len(line.split())
            if curr_len > max_len:
                max_len = curr_len
    print(f'max_len={max_len}')
    return max_len
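# Used offline to measure the longest left context in the training data;
# the resulting value is hard-coded as max_left_context_len further below.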
def get_words_from_line(line):
    for t in line:
        yield t
def get_word_lines_from_file(file_name, max_left_context_len, return_gen, n_size=-1):
    with lzma.open(file_name, 'r') as fh:
        n = 0
        for line in fh:
            n += 1
            line = line.decode('utf-8')
            line = line.strip()
            padding = '<pad> ' * (max_left_context_len - 1)  # the remaining slot is taken by <s>
            left_context = padding + '<s> ' + simple_preprocess(line.split('\t')[-2])
            right_context = simple_preprocess(line.split('\t')[-1]) + ' </s> <pad> <pad>'
            line = left_context + ' ' + right_context
            line = line.split()
            if return_gen:
                yield get_words_from_line(line)
            else:
                yield line
            if n == n_size:
                break
def look_ahead_iterator(gen, vocab, max_left_context_len):
    for item in gen:
        start_pos = item.index('<s>') + 1
        item = [vocab[t] for t in item]
        for i in range(start_pos, len(item) - 4):
            yield [item[:i-3][-max_left_context_len+3:], item[i-3:i], item[i], item[i+1:i+4]]
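# Each yielded sample is a list of token ids:
# [whole left context (last max_left_context_len - 3 tokens before the left trigram),
#  left trigram, target word, right trigram].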
def build_vocab(file, vocab_size, max_left_context_len):
    try:
        with open(f'vocab_{vocab_size}_padded.pickle', 'rb') as handle:
            print('Loading vocab...')
            vocab = pickle.load(handle)
    except FileNotFoundError:
        print('Building vocab...')
        vocab = build_vocab_from_iterator(
            get_word_lines_from_file(file, max_left_context_len, return_gen=True),
            max_tokens=vocab_size,
            specials=['<unk>'])
        with open(f'vocab_{vocab_size}_padded.pickle', 'wb') as handle:
            pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return vocab
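# The vocabulary is cached in vocab_<vocab_size>_padded.pickle; delete the file to force a rebuild.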
class Ngrams(IterableDataset):
    def __init__(self, text_file, max_left_context_len):
        # Relies on the module-level `vocab` built by build_vocab() below.
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file
        self.max_left_context_len = max_left_context_len

    def __iter__(self):
        return look_ahead_iterator(
            get_word_lines_from_file(self.text_file, self.max_left_context_len, return_gen=False),
            self.vocab, self.max_left_context_len)
# Dropout and norm layers adjusted on a case-by-case basis; also gradual hidden-layer size reduction vs. no reduction.
class NeuralLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(NeuralLanguageModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        # 7 embeddings are concatenated: averaged whole left context + 3 left-trigram + 3 right-trigram tokens.
        self.hidden_1 = nn.Linear(7 * embed_size, hidden_size)
        self.hidden_2 = nn.Linear(hidden_size, hidden_size // 2)
        self.hidden_3 = nn.Linear(hidden_size // 2, hidden_size // 4)
        self.output = nn.Linear(hidden_size // 4, vocab_size)
        self.softmax = nn.Softmax(dim=1)
        self.norm_input = nn.LayerNorm(7 * embed_size)
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.norm_2 = nn.LayerNorm(hidden_size // 2)
        self.norm_3 = nn.LayerNorm(hidden_size // 4)
        self.activation = nn.LeakyReLU()
        self.dropout = nn.Dropout(0.1)
    def forward(self, x):
        x_whole_left, x_left_trigram, x_right_trigram = x
        # Average the embeddings of the whole left context into a single vector per example.
        x_whole_left_embed = [self.embeddings(t) for t in x_whole_left]
        x_whole_left_embed_len = len(x_whole_left_embed)
        x_whole_left_embed = torch.stack(x_whole_left_embed)
        x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0) / x_whole_left_embed_len
        #x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0)
        x_left_trigram_embed = torch.concat([self.embeddings(t) for t in x_left_trigram], dim=1)
        x_right_trigram_embed = torch.concat([self.embeddings(t) for t in x_right_trigram], dim=1)
        concat_embed = torch.concat((x_whole_left_embed, x_left_trigram_embed, x_right_trigram_embed), dim=1)
        if torch.isnan(concat_embed).any():
            print('NaN!')
            raise Exception('NaN in concatenated embeddings')
        concat_embed = self.norm_input(concat_embed)
        z = self.hidden_1(concat_embed)
        z = self.norm_1(z)
        z = self.activation(z)
        #z = self.dropout(z)
        z = self.hidden_2(z)
        z = self.norm_2(z)
        z = self.activation(z)
        #z = self.dropout(z)
        z = self.hidden_3(z)
        z = self.norm_3(z)
        z = self.activation(z)
        #z = self.dropout(z)
        z = self.output(z)
        y = self.softmax(z)
        return y
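# forward() expects x as a 3-tuple:
# (list of (batch,) LongTensors for the padded whole left context,
#  list of 3 (batch,) LongTensors for the left trigram,
#  list of 3 (batch,) LongTensors for the right trigram).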
# Sample parameters
max_steps = -1
vocab_size = 20000
embed_size = 150
batch_size = 4096
hidden_size = 1024
learning_rate = 0.001 # < 0.1
epochs = 1
#max_left_context_len = get_max_left_context_len('challenging-america-word-gap-prediction/train/in.tsv.xz')
max_left_context_len = 291
torch.manual_seed(1)
vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)
train_dataset = Ngrams('challenging-america-word-gap-prediction/train/in.tsv.xz', max_left_context_len)
if torch.cuda.is_available():
    device = 'cuda'
else:
    raise Exception('CUDA device required for training')
model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
#model.load_state_dict(torch.load(model_name))
data = DataLoader(train_dataset, batch_size=batch_size)
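# With the default collate_fn, each sample's lists of token ids are transposed into
# lists of (batch,)-shaped tensors, which is the structure unpacked in the training loop below.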
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.NLLLoss()
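# Softmax + torch.log() + NLLLoss is equivalent to CrossEntropyLoss on raw logits but
# numerically less stable, which is presumably why anomaly detection and gradient-value
# clipping are enabled below.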
with torch.autograd.set_detect_anomaly(True):
    model.train()
    epoch = 0
    for i in range(epochs):
        step = 0
        epoch += 1
        print(f'--------epoch {epoch}--------')
        for x_whole_left, x_left_trigram, y, x_right_trigram in data:
            x = [t.to(device) for t in x_whole_left], [t.to(device) for t in x_left_trigram], [t.to(device) for t in x_right_trigram]
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(torch.log(y_pred), y)
            if step % 1000 == 0:
                print(f'steps: {step}, loss: {loss.item()}')
                if step != 0:
                    # Checkpoint every 1000 steps; the models/ directory must already exist.
                    name = f'loss-{loss.item()}_model_steps-{step}_epoch-{epoch}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin'
                    torch.save(model.state_dict(), 'models/' + name)
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
            optimizer.step()
            if step == max_steps:
                break
            step += 1
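# Prediction on dev-0 and test-A: rebuild the vocabulary, load a trained checkpoint,
# and write top-k guesses for each gap in the "word:prob ... :remaining_prob" format
# produced by the loop below.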
vocab_size = 20000
embed_size = 150
batch_size = 4096
hidden_size = 1024
max_left_context_len = 291
vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)
vocab.set_default_index(vocab['<unk>'])
model_name = 'models/' + 'best_model_mod_arch.bin'
topk = 10
preds = []
device = 'cuda'
model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
model.load_state_dict(torch.load(model_name))
model.eval()
j = 0
for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:
    with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:
        for line in fh:
            j += 1
            left_context = simple_preprocess(line.decode('utf-8')).split('\t')[-2].strip()
            right_context = simple_preprocess(line.decode('utf-8')).split('\t')[-1].strip()
            padding = '<pad> ' * (max_left_context_len - 1)  # <s>
            left_context = padding + '<s> ' + left_context
            right_context = right_context + ' </s> <pad> <pad>'
            x_left_trigram, x_right_trigram = left_context.split()[-3:], right_context.split()[:3]
            # Token ids for the whole padded left context (per token, not per character), the left trigram and the right trigram.
            x = ([torch.tensor(vocab.forward([w])).to(device) for w in left_context.split()],
                 [torch.tensor(vocab.forward([w])).to(device) for w in x_left_trigram],
                 [torch.tensor(vocab.forward([w])).to(device) for w in x_right_trigram])
            out = model(x)
            top = torch.topk(out[0], topk)
            top_indices = top.indices.tolist()
            print(j, ' '.join(x_left_trigram), '[[[', vocab.lookup_token(top_indices[0]) if vocab.lookup_token(top_indices[0]) != '<unk>' else vocab.lookup_token(top_indices[1]), ']]]', ' '.join(x_right_trigram))
            top_probs = top.values.tolist()
            top_words = vocab.lookup_tokens(top_indices)
            top_zipped = zip(top_words, top_probs)
            pred = ''
            total_prob = 0
            for word, prob in top_zipped:
                if word != '<unk>':
                    pred += f'{word}:{prob} '
                    total_prob += prob
            unk_prob = 1 - total_prob
            pred += f':{unk_prob}'
            f_out.write(pred + '\n')