challenging-america-word-ga.../RNN.ipynb

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from collections import Counter
import regex as re
import itertools
from itertools import islice

cuda_available = torch.cuda.is_available()
print(f"CUDA Available: {cuda_available}")
if cuda_available:
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
CUDA Available: True
CUDA Device Name: NVIDIA GeForce RTX 3050
device = 'cuda'
train_path = "C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/train/train.txt"
class Dataset(torch.utils.data.Dataset):
    def __init__(
            self,
            sequence_length,
            train_path,
            max_vocab_size=20000
    ):
        self.sequence_length = sequence_length
        self.train_path = train_path
        self.max_vocab_size = max_vocab_size

        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.index_to_word[len(self.index_to_word)] = '<UNK>'
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.word_to_index['<UNK>'] = len(self.word_to_index)

        self.words_indexes = [self.word_to_index.get(w, self.word_to_index['<UNK>']) for w in self.words]

    def load(self):
        with open(self.train_path, 'r', encoding='utf-8') as f_in:
            text = [x.rstrip() for x in f_in.readlines() if x.strip()]
            text = ' '.join(text).lower()
            # Normalise escaped line breaks and tabs left over from the TSV export.
            text = text.replace('-\\n', '').replace('\\n', ' ').replace('\\t', ' ')
            text = re.sub(r'\n', ' ', text)
            # Join words split by an intra-word comma or hyphen, collapse whitespace
            # and strip the remaining punctuation.
            text = re.sub(r'(?<=\w)[,-](?=\w)', '', text)
            text = re.sub(r'\s+', ' ', text)
            text = re.sub(r'\p{P}', '', text)
            text = text.split(' ')
        return text

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        most_common_words = word_counts.most_common(self.max_vocab_size)
        return [word for word, _ in most_common_words]

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        # Take a window of `sequence_length` consecutive word indices.
        sequence = self.words_indexes[index:index+self.sequence_length]
        # Context x: two words on each side of the gap; target y: the middle word.
        x = sequence[:2] + sequence[-2:]
        y = sequence[len(sequence) // 2]
        return torch.tensor(x), torch.tensor(y)
train_dataset = Dataset(5, train_path)
train_dataset[420]
(tensor([ 14, 110,   3,  28]), tensor(208))
[train_dataset.index_to_word[x] for x in [ 14, 110, 3,  28]]
['at', 'last', 'to', 'tho']
[train_dataset.index_to_word[208]]
['come']
train_dataset[21237]
(tensor([ 218,  104, 8207, 3121]), tensor(20000))
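The target index 20000 in the item above is the '<UNK>' id, i.e. the middle word of that window falls outside the 20,000-word vocabulary:

print(train_dataset.index_to_word[20000])   # '<UNK>'
print(len(train_dataset.word_to_index))     # 20001 entries: 20,000 words + '<UNK>'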
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_size, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
        super(Model, self).__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc1 = nn.Linear(self.lstm_size, 256) 
        self.fc2 = nn.Linear(256, vocab_size)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x, prev_state=None):
        x = x.to(self.device)
        embed = self.embedding(x)
        # nn.LSTM defaults to (seq_len, batch, features), so move the batch dimension second.
        embed = embed.transpose(0, 1)

        if prev_state is None:
            prev_state = self.init_state(x.size(0))

        output, state = self.lstm(embed, prev_state)
        # Predict the gap word from the hidden state of the last time step only.
        logits = self.fc1(output[-1])
        logits = self.fc2(logits)
        # The model returns probabilities rather than raw logits; the training loop
        # below takes their log before handing them to the loss.
        probabilities = self.softmax(logits)
        return probabilities

    def init_state(self, batch_size):
        return (torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device),
                torch.zeros(self.num_layers, batch_size, self.lstm_size).to(self.device))
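A quick shape check of the forward pass, using a throwaway model with random weights (not part of the original run): a batch of three 4-token contexts should come back as a (3, vocab_size) matrix of probabilities.

sanity_model = Model(vocab_size=20001).to(device)
dummy_batch = torch.randint(0, 20001, (3, 4)).to(device)
probs = sanity_model(dummy_batch)
print(probs.shape)        # torch.Size([3, 20001])
print(probs.sum(dim=1))   # each row sums to ~1, since the model ends in a softmax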
def train(dataset, model, max_epochs, batch_size):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            y_pred = model(x)
            # The model outputs probabilities; note that nn.CrossEntropyLoss expects raw
            # logits, so taking the log here applies a log-softmax a second time.
            loss = criterion(torch.log(y_pred), y)

            loss.backward()
            optimizer.step()

            if batch % 500 == 0:
                print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })
model = Model(vocab_size = len(train_dataset.uniq_words) + 1).to(device)
train(train_dataset, model, 1, 8192)
{'epoch': 0, 'update in batch': 0, '/': 16679, 'loss': 9.917818069458008}
{'epoch': 0, 'update in batch': 500, '/': 16679, 'loss': 6.078440189361572}
{'epoch': 0, 'update in batch': 1000, '/': 16679, 'loss': 5.651369571685791}
{'epoch': 0, 'update in batch': 1500, '/': 16679, 'loss': 5.4341654777526855}
{'epoch': 0, 'update in batch': 2000, '/': 16679, 'loss': 5.383695602416992}
{'epoch': 0, 'update in batch': 2500, '/': 16679, 'loss': 5.225739479064941}
{'epoch': 0, 'update in batch': 3000, '/': 16679, 'loss': 5.282474517822266}
{'epoch': 0, 'update in batch': 3500, '/': 16679, 'loss': 5.092397689819336}
{'epoch': 0, 'update in batch': 4000, '/': 16679, 'loss': 4.940906047821045}
{'epoch': 0, 'update in batch': 4500, '/': 16679, 'loss': 4.908115863800049}
{'epoch': 0, 'update in batch': 5000, '/': 16679, 'loss': 5.092423439025879}
{'epoch': 0, 'update in batch': 5500, '/': 16679, 'loss': 4.979565620422363}
{'epoch': 0, 'update in batch': 6000, '/': 16679, 'loss': 4.8268022537231445}
{'epoch': 0, 'update in batch': 6500, '/': 16679, 'loss': 4.7172017097473145}
{'epoch': 0, 'update in batch': 7000, '/': 16679, 'loss': 4.781315326690674}
{'epoch': 0, 'update in batch': 7500, '/': 16679, 'loss': 5.0033040046691895}
{'epoch': 0, 'update in batch': 8000, '/': 16679, 'loss': 4.663774013519287}
{'epoch': 0, 'update in batch': 8500, '/': 16679, 'loss': 4.710158348083496}
{'epoch': 0, 'update in batch': 9000, '/': 16679, 'loss': 4.817586898803711}
{'epoch': 0, 'update in batch': 9500, '/': 16679, 'loss': 4.655371189117432}
{'epoch': 0, 'update in batch': 10000, '/': 16679, 'loss': 4.679412841796875}
{'epoch': 0, 'update in batch': 10500, '/': 16679, 'loss': 4.544621467590332}
{'epoch': 0, 'update in batch': 11000, '/': 16679, 'loss': 4.816493511199951}
{'epoch': 0, 'update in batch': 11500, '/': 16679, 'loss': 4.627770900726318}
{'epoch': 0, 'update in batch': 12000, '/': 16679, 'loss': 4.525866985321045}
{'epoch': 0, 'update in batch': 12500, '/': 16679, 'loss': 4.739295959472656}
{'epoch': 0, 'update in batch': 13000, '/': 16679, 'loss': 4.6095709800720215}
{'epoch': 0, 'update in batch': 13500, '/': 16679, 'loss': 4.7243266105651855}
{'epoch': 0, 'update in batch': 14000, '/': 16679, 'loss': 4.557321071624756}
{'epoch': 0, 'update in batch': 14500, '/': 16679, 'loss': 4.830319404602051}
{'epoch': 0, 'update in batch': 15000, '/': 16679, 'loss': 4.536618709564209}
{'epoch': 0, 'update in batch': 15500, '/': 16679, 'loss': 4.605734825134277}
{'epoch': 0, 'update in batch': 16000, '/': 16679, 'loss': 4.605676651000977}
{'epoch': 0, 'update in batch': 16500, '/': 16679, 'loss': 4.614283084869385}
torch.save(model.state_dict(), 'model.pth')
model = Model(20001).to(device)
model.load_state_dict(torch.load('model.pth'))
<All keys matched successfully>
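If the checkpoint is later loaded on a machine without CUDA, torch.load accepts a map_location argument; a minimal sketch (not part of the original run):

map_loc = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model(20001).to(map_loc)
model.load_state_dict(torch.load('model.pth', map_location=map_loc))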
def clean(text):
    # Apply the same normalisation as Dataset.load to the dev/test context fragments.
    text = text.replace('-\\n', '').replace('\\n', ' ').replace('\\t', ' ')
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'(?<=\w)[,-](?=\w)', '', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\p{P}', '', text)
    text = text.lower().strip()  # lowercase to match the training vocabulary
    return text
def get_words(words, model, dataset, n=20):
    ixs = [dataset.word_to_index.get(word, dataset.word_to_index['<UNK>']) for word in words]
    ixs = torch.tensor(ixs).unsqueeze(0).to(model.device)

    # Run in eval mode (disables the LSTM dropout) and without building a graph.
    model.eval()
    with torch.no_grad():
        out = model(ixs)
    top = torch.topk(out[0], n)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = [dataset.index_to_word[idx] for idx in top_indices]
    return list(zip(top_words, top_probs))
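For a quick sanity check, get_words can be queried directly with a hand-picked context (the four words below are illustrative, not taken from the data):

print(get_words(['at', 'the', 'of', 'the'], model, train_dataset, n=5))
# -> five (word, probability) pairs, sorted by descending probability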
def f_out(left, right, model, dataset):
    left = clean(left)
    right = clean(right)
    # Context: the last two words of the left fragment and the first two of the right.
    words = left.split(' ')[-2:] + right.split(' ')[:2]
    words = get_words(words, model, dataset)

    # Build the "word:prob ... :rest" line expected by the challenge, skipping <UNK>
    # and assigning the remaining probability mass to the final ":" entry.
    probs_sum = 0
    output = ''
    for word, prob in words:
        if word == "<UNK>":
            continue
        probs_sum += prob
        output += f"{word}:{prob} "
    output += f":{1-probs_sum}"

    return output
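Called on two illustrative context fragments (not taken from the data), f_out yields one line in that format:

line = f_out("he went", "the store yesterday", model, train_dataset)
print(line)  # top candidate words as "word:prob" pairs, leftover mass after the final ":"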
def create_out(input_path, model, dataset, output_path):
    lines = []
    with open(input_path, encoding='utf-8') as f:
        for line in f:
            columns = line.split('\t')
            # Columns 6 and 7 of in.tsv hold the left and right context of the gap.
            left = columns[6]
            right = columns[7]
            lines.append((left, right))

    with open(output_path, 'w', encoding='utf-8') as output_file:
        for left, right in lines:
            result = f_out(left, right, model, dataset)
            output_file.write(result + '\n')
dev_path = "C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/dev-0/in.tsv"
create_out(dev_path, model, train_dataset, output_path='C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/dev-0/out.tsv')
test_path = "C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/test-A/in.tsv"
create_out(test_path, model, train_dataset, output_path='C:/Users/Mauri/Desktop/UAM - 3 semestr/modelowanie języka/gap_pred/challenging-america-word-gap-prediction/test-A/out.tsv')