challenging-america-word-ga.../solution.ipynb

from torchtext.vocab import build_vocab_from_iterator
import pickle
from torch.utils.data import IterableDataset
from torch import nn
import torch
import lzma
from torch.utils.data import DataLoader
import shutil
torch.manual_seed(1)
def simple_preprocess(line):
    # Replace literal '\n' escape sequences (backslash + 'n') with spaces
    return line.replace(r'\n', ' ')

def get_max_left_context_len(file_name):
    print('Getting max left context length...')
    max_len = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            line = line.decode('utf-8')
            line = line.strip()
            line = line.split('\t')[-2]
            line = simple_preprocess(line)
            curr_len = len(line.split())
            if curr_len > max_len:
                max_len = curr_len
    print(f'max_len={max_len}')
    return max_len

def get_words_from_line(line):
    for t in line:
        yield t

def get_word_lines_from_file(file_name, max_left_context_len, return_gen, n_size=-1):
    with lzma.open(file_name, 'r') as fh:
        n = 0
        for line in fh:
            n += 1
            line = line.decode('utf-8')
            line = line.strip()
            padding = '<pad> ' * (max_left_context_len - 1)  # one position is reserved for '<s>'
            left_context = padding + '<s> ' + simple_preprocess(line.split('\t')[-2])
            right_context = simple_preprocess(line.split('\t')[-1]) + ' </s> <pad> <pad>'
            line = left_context + ' ' + right_context
            line = line.split()
            if return_gen:
                yield get_words_from_line(line)
            else:
                yield line
            if n == n_size:
                break
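# Each yielded line has the layout:
#   <pad> ... <pad> <s> <left context> <right context> </s> <pad> <pad>
# with (max_left_context_len - 1) leading pads, so '<s>' always sits at the same position.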

def look_ahead_iterator(gen, vocab, max_left_context_len):
    for item in gen:
        start_pos = item.index('<s>') + 1
        item = [vocab[t] for t in item]
        for i in range(start_pos, len(item) - 4):
            yield [item[:i-3][-max_left_context_len+3:], item[i-3:i], item[i], item[i+1:i+4]]
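# Each yielded training example has the form:
#   [whole_left_ids, left_trigram_ids, target_id, right_trigram_ids]
# where whole_left_ids are the (max_left_context_len - 3) token ids preceding the left
# trigram, left_trigram_ids are the three tokens directly before the gap, target_id is
# the word to predict, and right_trigram_ids are the three tokens directly after it.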

def build_vocab(file, vocab_size, max_left_context_len):
    try:
        with open(f'vocab_{vocab_size}_padded.pickle', 'rb') as handle:
            print('Loading vocab...')
            vocab = pickle.load(handle)
    except FileNotFoundError:
        print('Building vocab...')
        vocab = build_vocab_from_iterator(
            get_word_lines_from_file(file, max_left_context_len, return_gen=True),
            max_tokens = vocab_size,
            specials = ['<unk>'])
        with open(f'vocab_{vocab_size}_padded.pickle', 'wb') as handle:
            pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return vocab
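# The vocabulary is cached to vocab_<vocab_size>_padded.pickle, so repeated runs can
# skip the (slow) pass over the full training file.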

class Ngrams(IterableDataset):
    def __init__(self, text_file, vocab, max_left_context_len):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file
        self.max_left_context_len = max_left_context_len

    def __iter__(self):
        return look_ahead_iterator(
            get_word_lines_from_file(self.text_file, self.max_left_context_len, return_gen=False),
            self.vocab, self.max_left_context_len)

# Dropout, norm layers adjusted on a case-by-case basis. Also gradual hidden layer size reduction vs. no reduction
class NeuralLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(NeuralLanguageModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        self.hidden_1 = nn.Linear(7*embed_size, hidden_size)
        self.hidden_2 = nn.Linear(hidden_size, int(hidden_size/2))
        self.hidden_3 = nn.Linear(int(hidden_size/2), int(hidden_size/4))
        self.output = nn.Linear(int(hidden_size/4), vocab_size)

        self.softmax = nn.Softmax(dim=1)
        self.norm_input = nn.LayerNorm(7*embed_size)
        self.norm_1 = nn.LayerNorm(int(hidden_size))
        self.norm_2 = nn.LayerNorm(int(hidden_size/2))
        self.norm_3 = nn.LayerNorm(int(hidden_size/4))
        self.activation = nn.LeakyReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x_whole_left, x_left_trigram, x_right_trigram = x
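        # The whole left context is collapsed into one vector by averaging its token
        # embeddings; the two trigrams keep positional information by concatenating
        # their embeddings along the feature dimension below.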
        x_whole_left_embed = [self.embeddings(t) for t in x_whole_left]
        x_whole_left_embed_len = len(x_whole_left_embed)
        x_whole_left_embed = torch.stack(x_whole_left_embed)
        x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0) / x_whole_left_embed_len
        #x_whole_left_embed = torch.sum(x_whole_left_embed, dim=0)
        x_left_trigram_embed = torch.concat([self.embeddings(t) for t in x_left_trigram], dim=1)
        x_right_trigram_embed = torch.concat([self.embeddings(t) for t in x_right_trigram], dim=1)
        concat_embed = torch.concat((x_whole_left_embed, x_left_trigram_embed, x_right_trigram_embed), dim=1)
        if torch.isnan(concat_embed).any():
            raise ValueError('NaN detected in concatenated embeddings')
        concat_embed = self.norm_input(concat_embed)
        z = self.hidden_1(concat_embed)
        z = self.norm_1(z)
        z = self.activation(z)
        #z = self.dropout(z)
        z = self.hidden_2(z)
        z = self.norm_2(z)
        z = self.activation(z)
        #z = self.dropout(z)
        z = self.hidden_3(z)
        z = self.norm_3(z)
        z = self.activation(z)
        #z = self.dropout(z)
        z = self.output(z)
        y = self.softmax(z)
        return y
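
# Optional shape sanity check (illustrative only; the tiny sizes and the underscore-
# prefixed names below are arbitrary and unrelated to the actual training configuration).
_check_model = NeuralLanguageModel(vocab_size=100, embed_size=8, hidden_size=32)
_whole_left = [torch.randint(0, 100, (2,)) for _ in range(5)]   # 5 left-context tokens, batch of 2
_left_trigram = [torch.randint(0, 100, (2,)) for _ in range(3)]
_right_trigram = [torch.randint(0, 100, (2,)) for _ in range(3)]
_y = _check_model((_whole_left, _left_trigram, _right_trigram))
assert _y.shape == (2, 100)   # one probability distribution over the vocabulary per batch element
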
# Sample parameters
max_steps = -1
vocab_size = 20000
embed_size = 150
batch_size = 4096
hidden_size = 1024
learning_rate = 0.001 # < 0.1
epochs = 1
#max_left_context_len = get_max_left_context_len('challenging-america-word-gap-prediction/train/in.tsv.xz')
max_left_context_len = 291
torch.manual_seed(1)
vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)
train_dataset = Ngrams('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab, max_left_context_len)
if torch.cuda.is_available():
    device = 'cuda'
else:
    raise RuntimeError('CUDA device required for training')
model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
#model.load_state_dict(torch.load(model_name))
data = DataLoader(train_dataset, batch_size=batch_size)
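# Note: with an IterableDataset the DataLoader streams examples in file order (no
# shuffling); the default collate function turns every token position of each n-gram
# component into a single tensor of shape (batch_size,).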
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.NLLLoss()
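# Note: the model outputs softmax probabilities and the loss is NLLLoss applied to
# torch.log(probabilities). A numerically safer, equivalent alternative would be to
# return raw logits and use nn.CrossEntropyLoss (or nn.LogSoftmax inside the model);
# the training loop below keeps the original formulation.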

with torch.autograd.set_detect_anomaly(True):
    model.train()
    epoch = 0
    for i in range(epochs):
        step = 0
        epoch += 1
        print(f'--------epoch {epoch}--------')
        for x_whole_left, x_left_trigram, y, x_right_trigram in data:
            x = [t.to(device) for t in x_whole_left], [t.to(device) for t in x_left_trigram], [t.to(device) for t in x_right_trigram]
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(torch.log(y_pred), y)
            if step % 1000 == 0:
                print(f'steps: {step}, loss: {loss.item()}')
                if step != 0:
                    name = f'loss-{loss.item()}_model_steps-{step}_epoch-{epoch}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin'
                    torch.save(model.state_dict(), 'models/' + name)
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
            optimizer.step()
            if step == max_steps:
                break
            step += 1
vocab_size = 20000
embed_size = 150
batch_size = 4096
hidden_size = 1024
max_left_context_len = 291
vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size, max_left_context_len)
vocab.set_default_index(vocab['<unk>'])
model_name = 'models/best_model_mod_arch.bin'
topk = 10
preds = []
device = 'cuda'
model = NeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
model.load_state_dict(torch.load(model_name))
model.eval()
j = 0
for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:
    with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:
        for line in fh:
            j += 1
            left_context = simple_preprocess(line.decode('utf-8')).split('\t')[-2].strip()
            right_context = simple_preprocess(line.decode('utf-8')).split('\t')[-1].strip()
            padding = '<pad> ' * (max_left_context_len - 1) # <s>
            left_context = padding + '<s> ' + left_context
            right_context = right_context + ' </s> <pad> <pad>'
            x_left_trigram, x_right_trigram = left_context.split()[-3:], right_context.split()[:3]
            x = ([torch.tensor(vocab.forward([w])).to(device) for w in left_context.split()],
                 [torch.tensor(vocab.forward([w])).to(device) for w in x_left_trigram],
                 [torch.tensor(vocab.forward([w])).to(device) for w in x_right_trigram])
            out = model(x)
            top = torch.topk(out[0], topk)
            top_indices = top.indices.tolist()
            top_word = vocab.lookup_token(top_indices[0])
            if top_word == '<unk>':
                top_word = vocab.lookup_token(top_indices[1])
            print(j, ' '.join(x_left_trigram), '[[[', top_word, ']]]', ' '.join(x_right_trigram))
            top_probs = top.values.tolist()
            top_words = vocab.lookup_tokens(top_indices)
            top_zipped = zip(top_words, top_probs)
            pred = ''
            total_prob = 0
            for word, prob in top_zipped:
                if word != '<unk>':
                    pred += f'{word}:{prob} '
                    total_prob += prob
            unk_prob = 1 - total_prob
            pred += f':{unk_prob}'
            f_out.write(pred + '\n')
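# Each output line contains space-separated "word:probability" pairs for the top
# non-<unk> predictions, followed by ":<leftover probability>" assigned to all
# remaining words.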