import itertools
import lzma
from collections import deque

import numpy as np
import regex as re
import torch
from torch import nn
from torch.utils.data import IterableDataset, DataLoader
from torchtext.vocab import build_vocab_from_iterator
# Mount Google Drive and switch to the project directory (Colab-specific).
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/america
def get_line(line: str):
    # Columns 7 and 8 (0-indexed 6 and 7) of the challenge TSV hold the
    # left and right context; literal '\n' markers stand in for line breaks.
    parts = line.split('\t')
    prefix = parts[6].replace(r'\n', ' ')
    suffix = parts[7].replace(r'\n', ' ')
    return prefix + ' ' + suffix
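
# A quick sanity check on a hypothetical row (the real files share this
# tab-separated layout, with the contexts in columns 7 and 8):
sample_row = '\t'.join(['id1', 'x', 'x', 'x', 'x', 'x', r'left\ncontext', 'right context'])
assert get_line(sample_row) == 'left context right context'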

def read_words(line):
    # Tokenize a row by whitespace after joining its two context fields.
    line = get_line(line)
    yield from line.split()

def get_words_from_file(path):
    # Yield one token generator per line -- the shape that
    # build_vocab_from_iterator expects.
    with lzma.open(path, mode='rt', encoding='utf-8') as f:
        for line in f:
            yield read_words(line)
class SimpleTrigramNeuralLanguageModel(nn.Module):
    def __init__(self, vocabulary_size, embedding_size, hidden_size):
        super().__init__()
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.lin1 = nn.Linear(2 * embedding_size, hidden_size)
        self.rel = nn.ReLU()
        self.lin2 = nn.Linear(hidden_size, vocabulary_size)
        # dim=1 normalizes over the vocabulary axis; leaving the dimension
        # implicit triggers a deprecation warning in recent PyTorch.
        self.sm = nn.Softmax(dim=1)

    def forward(self, x):
        # x holds two context-word ids per example; concatenate their
        # embeddings into one feature vector per example.
        x = self.embedding(x).view((-1, 2 * self.embedding_size))
        x = self.lin1(x)
        x = self.rel(x)
        x = self.lin2(x)
        # Returns probabilities; the training loop applies torch.log
        # before NLLLoss.
        return self.sm(x)
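
# A minimal shape check (hypothetical sizes, runs on CPU): two context-word
# ids per example go in, one probability distribution over the vocabulary
# comes out, each row summing to 1.
_m = SimpleTrigramNeuralLanguageModel(vocabulary_size=10, embedding_size=4, hidden_size=3)
_probs = _m(torch.tensor([[1, 2], [3, 4]]))
assert _probs.shape == (2, 10)
assert torch.allclose(_probs.sum(dim=1), torch.ones(2))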
def get_context(gen):
    # Slide a window of three consecutive token ids over the stream and
    # yield each trigram; a deque keeps memory constant instead of
    # materializing the whole corpus in a list.
    window = deque(maxlen=3)
    for item in gen:
        window.append(item)
        if len(window) == 3:
            yield np.asarray(window)
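
# Example: a stream of five ids yields three overlapping trigrams.
assert [t.tolist() for t in get_context(iter([1, 2, 3, 4, 5]))] == \
    [[1, 2, 3], [2, 3, 4], [3, 4, 5]]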
class Trigrams(IterableDataset):
    def __init__(self, text_file, vocabulary_size):
        # Build the vocabulary from the training file itself;
        # out-of-vocabulary words map to '<unk>'.
        self.vocab = build_vocab_from_iterator(
            get_words_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=['<unk>'])
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        # Stream the file as token ids and yield (w1, w2, w3) trigrams.
        return get_context(
            (self.vocab[t] for t in itertools.chain.from_iterable(get_words_from_file(self.text_file))))
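
# To eyeball the pipeline, one could peek at a single example (assumes the
# training file from the configuration cell below exists):
# first = next(iter(Trigrams('train/in.tsv.xz', 1000)))
# print(first)  # array of three token ids, e.g. [w1, w2, w3]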
def train_model(lr):
    model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
    data = DataLoader(train_dataset, batch_size=batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # The model outputs probabilities, so NLLLoss is applied to their log;
    # LogSoftmax + NLLLoss (or CrossEntropyLoss on raw logits) would be the
    # numerically safer equivalent.
    criterion = torch.nn.NLLLoss()

    model.train()
    step = 0
    for batch in data:
        # First two columns are the context ids, the third is the target.
        x = batch[:, :2].to(device)
        y = batch[:, 2].to(device)
        optimizer.zero_grad()
        ypredicted = model(x)
        loss = criterion(torch.log(ypredicted), y)
        if step % 100 == 0:
            print(step, loss.item())
        step += 1
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

    torch.save(model.state_dict(), model_path)
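
# A rough held-out check -- a sketch assuming a dataset in the same format
# (e.g. a Trigrams over a hypothetical dev file): exp of the mean negative
# log-likelihood approximates perplexity.
def evaluate_model(model, dataset, max_batches=50):
    data = DataLoader(dataset, batch_size=batch_size)
    criterion = torch.nn.NLLLoss()
    model.eval()
    total_loss, batches = 0.0, 0
    with torch.no_grad():
        for i, batch in enumerate(data):
            if i >= max_batches:
                break
            x = batch[:, :2].to(device)
            y = batch[:, 2].to(device)
            total_loss += criterion(torch.log(model(x)), y).item()
            batches += 1
    return np.exp(total_loss / max(batches, 1))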
def prediction(words, model, top) -> str:
    # Map the two context words to ids and run a single forward pass.
    words_tensor = [train_dataset.vocab.forward([word]) for word in words]
    ixs = torch.tensor(words_tensor).view(-1).to(device)
    with torch.no_grad():
        out = model(ixs)
    top_values, top_indices = torch.topk(out[0], top)
    top_probs = top_values.tolist()
    top_words = vocab.lookup_tokens(top_indices.tolist())
    # Route the '<unk>' probability mass to an empty token (the fallback
    # entry in the output format); if '<unk>' is absent, repurpose the
    # least probable candidate instead.
    unk_index = top_words.index('<unk>') if '<unk>' in top_words else -1
    if unk_index != -1:
        unk_prob = top_probs[unk_index]
        top_words.pop(unk_index)
        top_probs.pop(unk_index)
        top_words.append('')
        top_probs.append(unk_prob)
    else:
        top_words[-1] = ''
    return ' '.join(f'{word}:{prob}' for word, prob in zip(top_words, top_probs))
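
# Example call (after the model is trained and loaded below); the output
# pairs each candidate with its probability, the empty token carrying the
# unknown-word mass:
# prediction(['of', 'the'], model, top=10)
# -> 'same:0.05 first:0.03 ... :0.01'  (illustrative values)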
def save_outputs(folder_name, model, top):
    input_file_path = f'{folder_name}/in.tsv.xz'
    output_file_path = f'{folder_name}/out-top={top}.tsv'
    with lzma.open(input_file_path, mode='rt', encoding='utf-8') as input_file:
        with open(output_file_path, 'w', encoding='utf-8', newline='\n') as output_file:
            for line in input_file:
                separated = line.split('\t')
                # Use the last two words of the left context as the prompt;
                # pad very short prefixes so the model always sees two ids.
                prefix = separated[6].replace(r'\n', ' ').split()[-2:]
                if len(prefix) < 2:
                    prefix = ['<unk>'] * (2 - len(prefix)) + prefix
                output_line = prediction(prefix, model, top)
                output_file.write(output_line + '\n')
vocab_size = 15000
embed_size = 200
hidden_size = 100
batch_size = 3000
learning_rate = 0.0001
# Fall back to CPU when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_path = 'train/in.tsv.xz'
model_path = 'model1.bin'
vocab = build_vocab_from_iterator(
    get_words_from_file(train_path),
    max_tokens=vocab_size,
    specials=['<unk>']
)
vocab.set_default_index(vocab['<unk>'])

# Trigrams builds an identical vocabulary internally; the module-level
# `vocab` above is kept for token lookups in prediction().
train_dataset = Trigrams(train_path, vocab_size)
train_model(lr=learning_rate)
model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
# Generate predictions for both evaluation splits at several top-k cutoffs.
for top in [100, 200, 300]:
    save_outputs('dev-0', model, top)
    save_outputs('test-A', model, top)