challenging-america-word-ga.../solution.ipynb

from torchtext.vocab import build_vocab_from_iterator
import pickle
from torch.utils.data import IterableDataset
from itertools import chain
from torch import nn
import torch.nn.functional as F
import torch
import lzma
from torch.utils.data import DataLoader
import shutil
torch.manual_seed(1)
def simple_preprocess(line):
    # Replace the literal '\n' escape sequences that appear in the corpus with spaces.
    return line.replace(r'\n', ' ')

def get_words_from_line(line):
    # Whitespace-tokenize a line and wrap it in sentence-boundary markers.
    line = line.strip()
    line = simple_preprocess(line)
    yield '<s>'
    for t in line.split():
        yield t
    yield '</s>'
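
A quick sanity check (illustrative, not part of the pipeline): every line is whitespace-tokenized and wrapped in sentence-boundary markers.

print(list(get_words_from_line('It was a dark night.')))
# -> ['<s>', 'It', 'was', 'a', 'dark', 'night.', '</s>']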

def get_word_lines_from_file(file_name, n_size=-1):
    # Stream token generators line by line from an xz-compressed file;
    # n_size > 0 limits the number of lines read (-1 reads the whole file).
    with lzma.open(file_name, 'r') as fh:
        n = 0
        for line in fh:
            n += 1
            yield get_words_from_line(line.decode('utf-8'))
            if n == n_size:
                break

def look_ahead_iterator(gen):
    # Slide a window of three consecutive tokens over the stream and yield
    # (second, third, first): the two-word right context followed by the
    # word to be predicted, matching the word-gap task.
    ngram = []
    for item in gen:
        if len(ngram) < 3:
            ngram.append(item)
            if len(ngram) == 3:
                yield ngram[1], ngram[2], ngram[0]
        else:
            ngram = ngram[1:]
            ngram.append(item)
            yield ngram[1], ngram[2], ngram[0]
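
A minimal example of the windowing (illustrative only): each yielded triple is the two-word right context followed by the word to be predicted.

print(list(look_ahead_iterator(iter(['<s>', 'the', 'cat', 'sat', '</s>']))))
# -> [('the', 'cat', '<s>'), ('cat', 'sat', 'the'), ('sat', '</s>', 'cat')]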

def build_vocab(file, vocab_size):
    # Load a cached vocabulary if one exists; otherwise build it from the
    # training file and cache it as a pickle.
    try:
        with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'rb') as handle:
            vocab = pickle.load(handle)
    except FileNotFoundError:
        vocab = build_vocab_from_iterator(
            get_word_lines_from_file(file),
            max_tokens=vocab_size,
            specials=['<unk>'])
        with open(f'trigram_nn_vocab_{vocab_size}.pickle', 'wb') as handle:
            pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return vocab
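
For reference (illustrative only; the pipeline uses the cached pickle above), build_vocab_from_iterator behaves the same way on a tiny in-memory corpus.

_toy = build_vocab_from_iterator([['the', 'cat', 'sat'], ['the', 'dog']],
                                 max_tokens=5, specials=['<unk>'])
_toy.set_default_index(_toy['<unk>'])
print(_toy['the'], _toy['unseen-word'])  # a regular index, then the <unk> index (0)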

class Trigrams(IterableDataset):
    def __init__(self, text_file, vocab):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file

    def __iter__(self):
        # Map every token to its vocabulary index and yield
        # (context word 1, context word 2, target word) triples.
        return look_ahead_iterator(
            (self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))

class TrigramNeuralLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super(TrigramNeuralLanguageModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        self.hidden_layer = nn.Linear(2 * embed_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, vocab_size)
        # Softmax here means the model outputs probabilities directly; the
        # training loop takes torch.log before NLLLoss (LogSoftmax would be
        # more numerically stable, but probabilities are convenient at
        # inference time).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x is a pair of index tensors: the two right-context words.
        embeds = self.embeddings(x[0]), self.embeddings(x[1])
        concat_embed = torch.cat(embeds, dim=1)
        z = F.relu(self.hidden_layer(concat_embed))
        y = self.softmax(self.output_layer(z))
        return y
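
A small shape check (illustrative sizes, not the training configuration): the forward pass maps a pair of context-index tensors to a (batch, vocab_size) matrix of probabilities.

_m = TrigramNeuralLanguageModel(vocab_size=100, embed_size=8, hidden_size=16)
_x = (torch.randint(0, 100, (4,)), torch.randint(0, 100, (4,)))
_out = _m(_x)
print(_out.shape, _out.sum(dim=1))  # torch.Size([4, 100]); each row sums to ~1.0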

# Hyperparameters
max_steps = -1  # -1: iterate over the entire DataLoader
vocab_size = 20000
embed_size = 150
batch_size = 1024
hidden_size = 1024
learning_rate = 0.001
vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)
train_dataset = Trigrams('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab)
if torch.cuda.is_available():
    device = 'cuda'
else:
    raise RuntimeError('CUDA device required for training')
model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
data = DataLoader(train_dataset, batch_size=batch_size)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.NLLLoss()
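
A side note on the loss (illustrative, not part of the pipeline): NLLLoss applied to log-probabilities is equivalent to CrossEntropyLoss applied to raw logits, which is why the training loop below wraps the model's softmax output in torch.log.

_logits = torch.randn(3, 5)
_targets = torch.tensor([1, 0, 4])
_nll = torch.nn.NLLLoss()(torch.log(torch.softmax(_logits, dim=1)), _targets)
_ce = torch.nn.CrossEntropyLoss()(_logits, _targets)
print(torch.allclose(_nll, _ce))  # True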

model.train()
step = 0
for x1, x2, y in data:
    # x1, x2: the two right-context words; y: the word to be predicted.
    x = x1.to(device), x2.to(device)
    y = y.to(device)
    optimizer.zero_grad()
    ypredicted = model(x)
    # The model outputs probabilities, so take the log before NLLLoss.
    loss = criterion(torch.log(ypredicted), y)
    if step % 1000 == 0:
        print(f'steps: {step}, loss: {loss.item()}')
        if step != 0:
            torch.save(model.state_dict(), f'trigram_nn_model_steps-{step}_vocab-{vocab_size}_embed-{embed_size}_batch-{batch_size}_hidden-{hidden_size}_lr-{learning_rate}.bin')
    loss.backward()
    optimizer.step()
    if step == max_steps:
        break
    step += 1

# Inference: rebuild the vocabulary and generate predictions with each checkpoint.
vocab_size = 20000
embed_size = 150
vocab = build_vocab('challenging-america-word-gap-prediction/train/in.tsv.xz', vocab_size)
vocab.set_default_index(vocab['<unk>'])
for model_name in ['trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin', 'trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin', 'trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin']:
    print(model_name)
    # Recover batch size and hidden size from the checkpoint file name.
    batch_size = int(model_name.split('_')[-3].split('-')[1])
    print(batch_size)
    hidden_size = int(model_name.split('_')[-2].split('-')[1])
    print(hidden_size)
    topk = 10
    device = 'cuda'
    model = TrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
    model.load_state_dict(torch.load(model_name))
    model.eval()
    for path in ['challenging-america-word-gap-prediction/dev-0', 'challenging-america-word-gap-prediction/test-A']:
        with lzma.open(f'{path}/in.tsv.xz', 'r') as fh, open(f'{path}/out.tsv', 'w', encoding='utf-8') as f_out:
            for line in fh:
                # Predict the gap from the first two words of the right context
                # (the last tab-separated field of the input line).
                right_context = simple_preprocess(line.decode('utf-8').split('\t')[-1].strip()).split()[:2]
                x = torch.tensor(vocab.forward([right_context[0]])).to(device), \
                    torch.tensor(vocab.forward([right_context[1]])).to(device)
                out = model(x)
                top = torch.topk(out[0], topk)
                top_indices = top.indices.tolist()
                top_probs = top.values.tolist()
                top_words = vocab.lookup_tokens(top_indices)
                top_zipped = zip(top_words, top_probs)
                pred = ''
                total_prob = 0
                for word, prob in top_zipped:
                    if word != '<unk>':
                        pred += f'{word}:{prob} '
                        total_prob += prob
                # The remaining probability mass goes to the catch-all ':<prob>' entry.
                unk_prob = 1 - total_prob
                pred += f':{unk_prob}'
                f_out.write(pred + '\n')
        # Keep a copy of the predictions under a model-specific name;
        # rsplit so the dot in the learning rate does not truncate the file name.
        src = f'{path}/out.tsv'
        dst = f"{path}/{model_name.rsplit('.', 1)[0]}_out.tsv"
        shutil.copy(src, dst)
trigram_nn_model_steps-13000_vocab-20000_embed-150_batch-512_hidden-256_lr-0.0001.bin
512
256
trigram_nn_model_steps-7000_vocab-20000_embed-150_batch-1024_hidden-1024_lr-0.001.bin
1024
1024
trigram_nn_model_steps-6000_vocab-20000_embed-150_batch-4096_hidden-256_lr-0.001.bin
4096
256
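
For reference, each line written to out.tsv above is one prediction made of space-separated word:probability pairs for the top candidates, followed by a bare :probability entry that carries the remaining probability mass, e.g. (values purely illustrative): the:0.21 a:0.14 of:0.05 :0.6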
%cd challenging-america-word-gap-prediction/
!./geval --test-name dev-0
%cd ../
/home/ked/PycharmProjects/mj9/challenging-america-word-gap-prediction
300.66
/home/ked/PycharmProjects/mj9