![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)
<div class="alert alert-block alert-info">
<h1> Modelowanie Języka</h1>
<h2> 9. <i>Model neuronowy rekurencyjny</i>  [ćwiczenia]</h2> 
<h3> Jakub Pokrywka (2022)</h3>
</div>

![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from collections import Counter
import re

In [2]:
device = 'cpu'

In [3]:
! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt

--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 877893 (857K) [text/plain]
Saving to: ‘potop-tom-pierwszy.txt.2’


2022-05-08 19:27:04 (12,0 MB/s) - ‘potop-tom-pierwszy.txt.2’ saved [877893/877893]

--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1087797 (1,0M) [text/plain]
Saving to: ‘potop-tom-drugi.txt.2’


2022-05-08 19:27:04 (12,9 MB/s) - ‘potop-tom-drugi.txt.2’ saved [1087797/1087797]

--2022-05-08 19:27:05--  https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
Resolvi

In [4]:
!cat potop-* > potop.txt

In [5]:
class Dataset(torch.utils.data.Dataset):
    def __init__(
            self,
            sequence_length,
    ):
        self.sequence_length = sequence_length
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        with open('potop.txt', 'r') as f_in:
            text = [x.rstrip() for x in f_in.readlines() if x.strip()]
            text = ' '.join(text).lower()
            text = re.sub('[^a-ząćęłńóśźż ]', '', text) 
            text = text.split(' ')
        return text
    
    
    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )

In [6]:
dataset = Dataset(5)

In [7]:
dataset[200]

(tensor([  551,    18,    17,   255, 10748]),
 tensor([   18,    17,   255, 10748,    34]))

In [8]:
[dataset.index_to_word[x] for x in [   551,    18,    17,   255, 10748]]

['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']

In [9]:
[dataset.index_to_word[x] for x in [   18,    17,   255, 10748,    34]]

['tak', 'jak', 'człowiek', 'zbudzony', 'ze']

In [10]:
input_tensor = torch.tensor([[ 551,    18,    17,   255, 10748]], dtype=torch.int32).to(device)

In [11]:
#input_tensor = torch.tensor([[ 551,    18]], dtype=torch.int32).to(device)

In [12]:
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state = None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device))

In [13]:
model = Model(len(dataset)).to(device)

In [14]:
y_pred, (state_h, state_c) = model(input_tensor)

In [15]:
y_pred

tensor([[[ 0.0046, -0.0113,  0.0313,  ...,  0.0198, -0.0312,  0.0223],
         [ 0.0039, -0.0110,  0.0303,  ...,  0.0213, -0.0302,  0.0230],
         [ 0.0029, -0.0133,  0.0265,  ...,  0.0204, -0.0297,  0.0219],
         [ 0.0010, -0.0120,  0.0282,  ...,  0.0241, -0.0314,  0.0241],
         [ 0.0038, -0.0106,  0.0346,  ...,  0.0230, -0.0333,  0.0232]]],
       grad_fn=<AddBackward0>)

In [16]:
y_pred.shape

torch.Size([1, 5, 1187998])

In [17]:
def train(dataset, model, max_epochs, batch_size):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)

            y_pred, (state_h, state_c) = model(x)
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })


In [18]:
model = Model(vocab_size = len(dataset.uniq_words)).to(device)
train(dataset, model, 1, 64)

{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}
{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}
{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}
{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}
{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}
{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}
{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}
{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}
{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}
{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}
{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}
{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}
{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}
{'epoch': 0, 'update in

{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}
{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}
{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}
{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}
{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}
{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}
{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}
{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}
{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}
{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}
{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}
{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}
{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}
{'epoch': 0,

{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}
{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}
{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}
{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}
{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}
{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}
{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}
{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}
{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}
{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}
{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}
{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}
{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}
{'epoch': 0, 'upda

{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}
{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}
{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}
{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}
{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}
{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}
{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}
{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}
{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}
{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}
{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}
{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}
{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}
{'epoch': 0, '

{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}
{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}
{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}
{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}
{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}
{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}
{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}
{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}
{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}
{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}
{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}
{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}
{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}
{'epoch': 0,

KeyboardInterrupt: 

In [19]:
def predict(dataset, model, text, next_words=5):
    model.eval()
    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words

In [22]:
predict(dataset, model, 'kmicic szedł')

['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']