aitech-moj-2023/cw/09_Model_neuronowy_rekurencyjny.ipynb
2022-05-08 19:32:57 +02:00

60 KiB
Raw Blame History

Logo 1

Modelowanie Języka

9. Model neuronowy rekurencyjny [ćwiczenia]

Jakub Pokrywka (2022)

Logo 2

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from collections import Counter
import re
device = 'cpu'
! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 877893 (857K) [text/plain]
Saving to: potop-tom-pierwszy.txt.2

potop-tom-pierwszy. 100%[===================>] 857,32K  --.-KB/s    in 0,07s   

2022-05-08 19:27:04 (12,0 MB/s) - potop-tom-pierwszy.txt.2 saved [877893/877893]

--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1087797 (1,0M) [text/plain]
Saving to: potop-tom-drugi.txt.2

potop-tom-drugi.txt 100%[===================>]   1,04M  --.-KB/s    in 0,08s   

2022-05-08 19:27:04 (12,9 MB/s) - potop-tom-drugi.txt.2 saved [1087797/1087797]

--2022-05-08 19:27:05--  https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 788219 (770K) [text/plain]
Saving to: potop-tom-trzeci.txt.2

potop-tom-trzeci.tx 100%[===================>] 769,75K  --.-KB/s    in 0,06s   

2022-05-08 19:27:05 (12,0 MB/s) - potop-tom-trzeci.txt.2 saved [788219/788219]

!cat potop-* > potop.txt
class Dataset(torch.utils.data.Dataset):
    def __init__(
            self,
            sequence_length,
    ):
        self.sequence_length = sequence_length
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        with open('potop.txt', 'r') as f_in:
            text = [x.rstrip() for x in f_in.readlines() if x.strip()]
            text = ' '.join(text).lower()
            text = re.sub('[^a-ząćęłńóśźż ]', '', text) 
            text = text.split(' ')
        return text
    
    
    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
dataset = Dataset(5)
dataset[200]
(tensor([  551,    18,    17,   255, 10748]),
 tensor([   18,    17,   255, 10748,    34]))
[dataset.index_to_word[x] for x in [   551,    18,    17,   255, 10748]]
['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']
[dataset.index_to_word[x] for x in [   18,    17,   255, 10748,    34]]
['tak', 'jak', 'człowiek', 'zbudzony', 'ze']
input_tensor = torch.tensor([[ 551,    18,    17,   255, 10748]], dtype=torch.int32).to(device)
#input_tensor = torch.tensor([[ 551,    18]], dtype=torch.int32).to(device)
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state = None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device))
model = Model(len(dataset)).to(device)
y_pred, (state_h, state_c) = model(input_tensor)
y_pred
tensor([[[ 0.0046, -0.0113,  0.0313,  ...,  0.0198, -0.0312,  0.0223],
         [ 0.0039, -0.0110,  0.0303,  ...,  0.0213, -0.0302,  0.0230],
         [ 0.0029, -0.0133,  0.0265,  ...,  0.0204, -0.0297,  0.0219],
         [ 0.0010, -0.0120,  0.0282,  ...,  0.0241, -0.0314,  0.0241],
         [ 0.0038, -0.0106,  0.0346,  ...,  0.0230, -0.0333,  0.0232]]],
       grad_fn=<AddBackward0>)
y_pred.shape
torch.Size([1, 5, 1187998])
def train(dataset, model, max_epochs, batch_size):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)

            y_pred, (state_h, state_c) = model(x)
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })
model = Model(vocab_size = len(dataset.uniq_words)).to(device)
train(dataset, model, 1, 64)
{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}
{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}
{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}
{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}
{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}
{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}
{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}
{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}
{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}
{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}
{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}
{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}
{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}
{'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434}
{'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955}
{'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184}
{'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445}
{'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488}
{'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742}
{'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543}
{'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508}
{'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625}
{'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543}
{'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152}
{'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066}
{'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709}
{'epoch': 0, 'update in batch': 26, '/': 18563, 'loss': 8.741975784301758}
{'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541}
{'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914}
{'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336}
{'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059}
{'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172}
{'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383}
{'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258}
{'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375}
{'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719}
{'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305}
{'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125}
{'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066}
{'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336}
{'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375}
{'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434}
{'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072}
{'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566}
{'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035}
{'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999}
{'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947}
{'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137}
{'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781}
{'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586}
{'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207}
{'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369}
{'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662}
{'epoch': 0, 'update in batch': 53, '/': 18563, 'loss': 7.967227935791016}
{'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039}
{'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117}
{'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492}
{'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809}
{'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004}
{'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328}
{'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875}
{'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406}
{'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965}
{'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027}
{'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535}
{'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793}
{'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234}
{'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852}
{'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773}
{'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188}
{'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004}
{'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336}
{'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828}
{'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172}
{'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262}
{'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793}
{'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207}
{'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777}
{'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738}
{'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109}
{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746}
{'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402}
{'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074}
{'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867}
{'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508}
{'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734}
{'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527}
{'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234}
{'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082}
{'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875}
{'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266}
{'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867}
{'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062}
{'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793}
{'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168}
{'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031}
{'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863}
{'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023}
{'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422}
{'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941}
{'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035}
{'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645}
{'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555}
{'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541}
{'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023}
{'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348}
{'epoch': 0, 'update in batch': 106, '/': 18563, 'loss': 8.172578811645508}
{'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623}
{'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537}
{'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182}
{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}
{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}
{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}
{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}
{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}
{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}
{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}
{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}
{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}
{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}
{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}
{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}
{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}
{'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376}
{'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363}
{'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348}
{'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418}
{'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791}
{'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605}
{'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777}
{'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848}
{'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617}
{'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 8.807767868041992}
{'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676}
{'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707}
{'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662}
{'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953}
{'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988}
{'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459}
{'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031}
{'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133}
{'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709}
{'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195}
{'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668}
{'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975}
{'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535}
{'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027}
{'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375}
{'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438}
{'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812}
{'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219}
{'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285}
{'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512}
{'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566}
{'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695}
{'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113}
{'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918}
{'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947}
{'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195}
{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898}
{'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039}
{'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082}
{'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713}
{'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672}
{'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504}
{'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301}
{'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023}
{'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965}
{'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848}
{'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695}
{'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124}
{'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418}
{'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953}
{'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125}
{'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125}
{'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373}
{'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635}
{'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176}
{'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219}
{'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539}
{'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762}
{'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766}
{'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797}
{'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094}
{'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754}
{'epoch': 0, 'update in batch': 185, '/': 18563, 'loss': 9.206843376159668}
{'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252}
{'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375}
{'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027}
{'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443}
{'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258}
{'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582}
{'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473}
{'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945}
{'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859}
{'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465}
{'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719}
{'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807}
{'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012}
{'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434}
{'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258}
{'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459}
{'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621}
{'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461}
{'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166}
{'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738}
{'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584}
{'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848}
{'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668}
{'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484}
{'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275}
{'epoch': 0, 'update in batch': 211, '/': 18563, 'loss': 8.537101745605469}
{'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566}
{'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512}
{'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801}
{'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117}
{'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902}
{'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211}
{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}
{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}
{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}
{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}
{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}
{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}
{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}
{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}
{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}
{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}
{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}
{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}
{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}
{'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387}
{'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682}
{'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021}
{'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148}
{'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119}
{'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754}
{'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496}
{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977}
{'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376}
{'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678}
{'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537}
{'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344}
{'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406}
{'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691}
{'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422}
{'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207}
{'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598}
{'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176}
{'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684}
{'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484}
{'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523}
{'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211}
{'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582}
{'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863}
{'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402}
{'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035}
{'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182}
{'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051}
{'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975}
{'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918}
{'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249}
{'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242}
{'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156}
{'epoch': 0, 'update in batch': 264, '/': 18563, 'loss': 7.205214500427246}
{'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844}
{'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756}
{'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004}
{'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875}
{'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459}
{'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246}
{'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584}
{'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727}
{'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094}
{'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473}
{'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537}
{'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213}
{'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377}
{'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945}
{'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121}
{'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844}
{'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262}
{'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215}
{'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379}
{'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568}
{'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275}
{'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488}
{'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568}
{'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801}
{'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072}
{'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 7.297110557556152}
{'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691}
{'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324}
{'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582}
{'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695}
{'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297}
{'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168}
{'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055}
{'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001}
{'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844}
{'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572}
{'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297}
{'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129}
{'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848}
{'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297}
{'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914}
{'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059}
{'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805}
{'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584}
{'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902}
{'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984}
{'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492}
{'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383}
{'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797}
{'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355}
{'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312}
{'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691}
{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742}
{'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148}
{'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023}
{'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988}
{'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371}
{'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398}
{'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043}
{'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684}
{'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695}
{'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627}
{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}
{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}
{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}
{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}
{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}
{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}
{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}
{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}
{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}
{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}
{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}
{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}
{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}
{'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208}
{'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248}
{'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473}
{'epoch': 0, 'update in batch': 343, '/': 18563, 'loss': 8.807058334350586}
{'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002}
{'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875}
{'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852}
{'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832}
{'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627}
{'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156}
{'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188}
{'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668}
{'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445}
{'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227}
{'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578}
{'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027}
{'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473}
{'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203}
{'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541}
{'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688}
{'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176}
{'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152}
{'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664}
{'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916}
{'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805}
{'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734}
{'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098}
{'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125}
{'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207}
{'epoch': 0, 'update in batch': 369, '/': 18563, 'loss': 7.592991828918457}
{'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125}
{'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828}
{'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383}
{'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455}
{'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023}
{'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162}
{'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543}
{'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438}
{'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902}
{'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492}
{'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613}
{'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207}
{'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055}
{'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211}
{'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531}
{'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672}
{'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586}
{'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621}
{'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465}
{'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777}
{'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084}
{'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938}
{'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219}
{'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797}
{'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664}
{'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 7.6597900390625}
{'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117}
{'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035}
{'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844}
{'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836}
{'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871}
{'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875}
{'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629}
{'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078}
{'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703}
{'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785}
{'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238}
{'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225}
{'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057}
{'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184}
{'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652}
{'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992}
{'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516}
{'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059}
{'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541}
{'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996}
{'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254}
{'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398}
{'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168}
{'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086}
{'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348}
{'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506}
{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713}
{'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031}
{'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113}
{'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535}
{'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535}
{'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805}
{'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379}
{'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848}
{'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682}
{'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398}
{'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754}
{'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373}
{'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125}
{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}
{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}
{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}
{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}
{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}
{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}
{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}
{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}
{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}
{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}
{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}
{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}
{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}
{'epoch': 0, 'update in batch': 448, '/': 18563, 'loss': 8.531024932861328}
{'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305}
{'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754}
{'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785}
{'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152}
{'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537}
{'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854}
{'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711}
{'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156}
{'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098}
{'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008}
{'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082}
{'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805}
{'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367}
{'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415}
{'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145}
{'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562}
{'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184}
{'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125}
{'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176}
{'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125}
{'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605}
{'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246}
{'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441}
{'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041}
{'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164}
{'epoch': 0, 'update in batch': 474, '/': 18563, 'loss': 8.147513389587402}
{'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338}
{'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665}
{'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963}
{'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291}
{'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297}
{'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812}
{'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582}
{'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684}
{'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008}
{'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117}
{'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084}
{'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023}
{'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508}
{'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812}
{'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574}
{'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254}
{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}
{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}
{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-18-fe996a0be74b> in <module>
      1 model = Model(vocab_size = len(dataset.uniq_words)).to(device)
----> 2 train(dataset, model, 1, 64)

<ipython-input-17-8d700bc624e3> in train(dataset, model, max_epochs, batch_size)
     15             loss = criterion(y_pred.transpose(1, 2), y)
     16 
---> 17             loss.backward()
     18             optimizer.step()
     19 

~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364 
    365     def register_hook(self, hook):

~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    171     # some Python versions print out the first line of a multi-line function
    172     # calls in the traceback and some print out the last line
--> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

KeyboardInterrupt: 
def predict(dataset, model, text, next_words=5):
    model.eval()
    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words
predict(dataset, model, 'kmicic szedł')
['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']