modelowanie-jezykowe-aitech-cw/cw/09_Model_neuronowy_rekurencyjny.ipynb
Jakub Pokrywka 7bf28acbf4 09
2022-05-09 09:59:47 +02:00

61 KiB
Raw Blame History

Logo 1

Modelowanie Języka

9. Model neuronowy rekurencyjny [ćwiczenia]

Jakub Pokrywka (2022)

Logo 2

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from collections import Counter
import re
device = 'cpu'
! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 877893 (857K) [text/plain]
Saving to: potop-tom-pierwszy.txt.2

potop-tom-pierwszy. 100%[===================>] 857,32K  --.-KB/s    in 0,07s   

2022-05-08 19:27:04 (12,0 MB/s) - potop-tom-pierwszy.txt.2 saved [877893/877893]

--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1087797 (1,0M) [text/plain]
Saving to: potop-tom-drugi.txt.2

potop-tom-drugi.txt 100%[===================>]   1,04M  --.-KB/s    in 0,08s   

2022-05-08 19:27:04 (12,9 MB/s) - potop-tom-drugi.txt.2 saved [1087797/1087797]

--2022-05-08 19:27:05--  https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 788219 (770K) [text/plain]
Saving to: potop-tom-trzeci.txt.2

potop-tom-trzeci.tx 100%[===================>] 769,75K  --.-KB/s    in 0,06s   

2022-05-08 19:27:05 (12,0 MB/s) - potop-tom-trzeci.txt.2 saved [788219/788219]

!cat potop-* > potop.txt
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a plain-text corpus.

    The corpus is lower-cased, every character other than ``a-z``, the Polish
    diacritics and the space is removed, and the result is split on
    whitespace.  Words are numbered by descending frequency, so index 0 is the
    most frequent word.

    Item ``i`` is a pair of 1-D tensors of length ``sequence_length``: the
    word indexes ``i .. i+sequence_length-1`` and the same window shifted one
    word to the right (the prediction targets).

    Args:
        sequence_length: number of words in each input/target window.
        path: corpus file to read (UTF-8); defaults to ``'potop.txt'``.
    """

    # Matches every character that is not a kept letter or a space.
    _NON_LETTERS = re.compile('[^a-ząćęłńóśźż ]')

    def __init__(
            self,
            sequence_length,
            path='potop.txt',
    ):
        self.sequence_length = sequence_length
        self.path = path
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in self.index_to_word.items()}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        """Read the corpus and return it as a flat list of normalised words."""
        # Explicit encoding: the locale default is not UTF-8 on every platform
        # and the corpus contains Polish diacritics.
        with open(self.path, 'r', encoding='utf-8') as f_in:
            lines = [line.rstrip() for line in f_in if line.strip()]
        text = ' '.join(lines).lower()
        text = self._NON_LETTERS.sub('', text)
        # split() without an argument collapses runs of spaces; split(' ')
        # would emit empty-string "words" wherever punctuation standing
        # between two spaces was removed.
        return text.split()

    def get_uniq_words(self):
        """Return the distinct words, most frequent first (ties keep first-seen order)."""
        return [word for word, _ in Counter(self.words).most_common()]

    def __len__(self):
        # Each item needs sequence_length inputs plus one extra target word.
        # Clamp at zero so a corpus shorter than the window is simply empty
        # instead of making len() raise on a negative value.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(inputs, targets)`` where targets are inputs shifted by one word."""
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )
# Build the corpus with 5-word windows and look at one (inputs, targets) pair.
dataset = Dataset(sequence_length=5)
dataset[200]
(tensor([  551,    18,    17,   255, 10748]),
 tensor([   18,    17,   255, 10748,    34]))
# Map a list of word indexes back to their words.
[dataset.index_to_word[x] for x in [   551,    18,    17,   255, 10748]]
['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']
# Same lookup for a second list of word indexes.
[dataset.index_to_word[x] for x in [   18,    17,   255, 10748,    34]]
['tak', 'jak', 'człowiek', 'zbudzony', 'ze']
# A 2-D tensor holding one row of five word indexes, moved to `device`.
input_tensor = torch.tensor([[ 551,    18,    17,   255, 10748]], dtype=torch.int32).to(device)
# Alternative input with only two word indexes (kept for experimentation):
#input_tensor = torch.tensor([[ 551,    18]], dtype=torch.int32).to(device)
class Model(nn.Module):
    """Embedding -> stacked LSTM -> linear layer producing next-word logits.

    Args:
        vocab_size: number of distinct words; sizes both the embedding table
            and the output layer.
        embedding_dim: width of the word embeddings (default 128).
        lstm_size: hidden size of every LSTM layer (default 128).
        num_layers: number of stacked LSTM layers (default 3).
        dropout: dropout applied between LSTM layers (default 0.2).

    NOTE(review): the LSTM is created with the default ``batch_first=False``,
    so it treats dimension 0 of its input as time and dimension 1 as batch.
    ``init_state`` is written to match that layout.  Callers in this file
    pass ``(batch, sequence)`` index tensors - confirm the intended layout
    before switching to ``batch_first=True``, since that changes the shape
    of the recurrent state.
    """

    def __init__(self, vocab_size, embedding_dim=128, lstm_size=128,
                 num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embeddings, so its input width is the
            # embedding width (previously wired to lstm_size, which only
            # worked because both were 128).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state = None):
        """Return ``(logits, state)`` for a tensor of word indexes.

        Args:
            x: integer tensor of word indexes.
            prev_state: optional ``(h, c)`` pair from a previous call or from
                ``init_state``; ``None`` lets the LSTM start from zeros.

        Returns:
            ``logits`` with the same leading dimensions as ``x`` and a last
            dimension of ``vocab_size``, plus the LSTM's ``(h, c)`` state.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h, c)`` pair shaped ``(num_layers, sequence_length, lstm_size)``.

        The tensors are created on the device that holds the model's weights,
        so they can be passed straight to ``forward``.
        """
        state_device = self.fc.weight.device
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=state_device),
                torch.zeros(shape, device=state_device))
# The model's vocabulary size is the number of distinct words.  len(dataset)
# is the number of training windows, which would size the embedding table and
# the output layer by the corpus length instead.
model = Model(len(dataset.uniq_words)).to(device)
y_pred, (state_h, state_c) = model(input_tensor)
y_pred
tensor([[[ 0.0046, -0.0113,  0.0313,  ...,  0.0198, -0.0312,  0.0223],
         [ 0.0039, -0.0110,  0.0303,  ...,  0.0213, -0.0302,  0.0230],
         [ 0.0029, -0.0133,  0.0265,  ...,  0.0204, -0.0297,  0.0219],
         [ 0.0010, -0.0120,  0.0282,  ...,  0.0241, -0.0314,  0.0241],
         [ 0.0038, -0.0106,  0.0346,  ...,  0.0230, -0.0333,  0.0232]]],
       grad_fn=<AddBackward0>)
# The last dimension of the logits equals the vocab_size given to Model.
y_pred.shape
torch.Size([1, 5, 1187998])
def train(dataset, model, max_epochs, batch_size):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Batches are drawn in dataset order (no shuffling).  After every optimizer
    step a dict with the epoch, batch number, batch count and loss is printed.

    Args:
        dataset: map-style dataset yielding ``(inputs, targets)`` index tensors.
        model: module whose call returns ``(logits, state)`` with the class
            scores in the last dimension of ``logits``.
        max_epochs: number of full passes over ``dataset``.
        batch_size: number of items per batch.

    Returns:
        None; the model is updated in place.
    """
    model.train()
    # Move batches to wherever the model's weights live instead of relying on
    # a module-level `device` that may disagree with the model.
    model_device = next(model.parameters()).device

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    num_batches = len(dataloader)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(model_device)
            y = y.to(model_device)

            # The recurrent state is not carried across batches.
            y_pred, _ = model(x)
            # CrossEntropyLoss expects the class dimension at index 1.
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'update in batch': batch, '/' : num_batches, 'loss': loss.item() })
# Fresh model sized to the vocabulary; one epoch with batches of 64 windows.
model = Model(vocab_size=len(dataset.uniq_words)).to(device)
train(dataset, model, max_epochs=1, batch_size=64)
{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}
{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}
{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}
{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}
{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}
{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}
{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}
{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}
{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}
{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}
{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}
{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}
{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}
{'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434}
{'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955}
{'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184}
{'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445}
{'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488}
{'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742}
{'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543}
{'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508}
{'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625}
{'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543}
{'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152}
{'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066}
{'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709}
{'epoch': 0, 'update in batch': 26, '/': 18563, 'loss': 8.741975784301758}
{'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541}
{'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914}
{'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336}
{'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059}
{'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172}
{'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383}
{'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258}
{'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375}
{'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719}
{'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305}
{'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125}
{'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066}
{'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336}
{'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375}
{'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434}
{'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072}
{'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566}
{'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035}
{'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999}
{'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947}
{'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137}
{'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781}
{'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586}
{'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207}
{'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369}
{'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662}
{'epoch': 0, 'update in batch': 53, '/': 18563, 'loss': 7.967227935791016}
{'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039}
{'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117}
{'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492}
{'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809}
{'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004}
{'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328}
{'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875}
{'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406}
{'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965}
{'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027}
{'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535}
{'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793}
{'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234}
{'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852}
{'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773}
{'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188}
{'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004}
{'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336}
{'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828}
{'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172}
{'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262}
{'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793}
{'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207}
{'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777}
{'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738}
{'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109}
{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746}
{'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402}
{'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074}
{'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867}
{'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508}
{'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734}
{'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527}
{'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234}
{'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082}
{'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875}
{'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266}
{'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867}
{'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062}
{'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793}
{'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168}
{'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031}
{'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863}
{'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023}
{'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422}
{'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941}
{'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035}
{'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645}
{'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555}
{'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541}
{'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023}
{'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348}
{'epoch': 0, 'update in batch': 106, '/': 18563, 'loss': 8.172578811645508}
{'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623}
{'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537}
{'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182}
{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}
{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}
{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}
{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}
{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}
{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}
{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}
{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}
{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}
{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}
{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}
{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}
{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}
{'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376}
{'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363}
{'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348}
{'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418}
{'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791}
{'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605}
{'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777}
{'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848}
{'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617}
{'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 8.807767868041992}
{'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676}
{'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707}
{'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662}
{'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953}
{'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988}
{'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459}
{'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031}
{'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133}
{'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709}
{'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195}
{'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668}
{'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975}
{'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535}
{'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027}
{'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375}
{'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438}
{'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812}
{'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219}
{'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285}
{'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512}
{'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566}
{'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695}
{'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113}
{'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918}
{'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947}
{'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195}
{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898}
{'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039}
{'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082}
{'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713}
{'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672}
{'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504}
{'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301}
{'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023}
{'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965}
{'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848}
{'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695}
{'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124}
{'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418}
{'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953}
{'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125}
{'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125}
{'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373}
{'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635}
{'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176}
{'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219}
{'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539}
{'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762}
{'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766}
{'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797}
{'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094}
{'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754}
{'epoch': 0, 'update in batch': 185, '/': 18563, 'loss': 9.206843376159668}
{'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252}
{'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375}
{'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027}
{'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443}
{'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258}
{'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582}
{'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473}
{'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945}
{'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859}
{'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465}
{'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719}
{'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807}
{'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012}
{'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434}
{'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258}
{'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459}
{'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621}
{'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461}
{'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166}
{'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738}
{'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584}
{'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848}
{'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668}
{'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484}
{'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275}
{'epoch': 0, 'update in batch': 211, '/': 18563, 'loss': 8.537101745605469}
{'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566}
{'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512}
{'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801}
{'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117}
{'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902}
{'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211}
{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}
{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}
{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}
{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}
{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}
{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}
{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}
{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}
{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}
{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}
{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}
{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}
{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}
{'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387}
{'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682}
{'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021}
{'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148}
{'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119}
{'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754}
{'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496}
{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977}
{'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376}
{'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678}
{'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537}
{'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344}
{'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406}
{'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691}
{'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422}
{'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207}
{'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598}
{'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176}
{'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684}
{'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484}
{'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523}
{'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211}
{'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582}
{'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863}
{'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402}
{'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035}
{'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182}
{'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051}
{'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975}
{'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918}
{'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249}
{'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242}
{'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156}
{'epoch': 0, 'update in batch': 264, '/': 18563, 'loss': 7.205214500427246}
{'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844}
{'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756}
{'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004}
{'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875}
{'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459}
{'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246}
{'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584}
{'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727}
{'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094}
{'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473}
{'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537}
{'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213}
{'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377}
{'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945}
{'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121}
{'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844}
{'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262}
{'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215}
{'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379}
{'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568}
{'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275}
{'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488}
{'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568}
{'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801}
{'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072}
{'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 7.297110557556152}
{'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691}
{'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324}
{'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582}
{'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695}
{'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297}
{'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168}
{'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055}
{'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001}
{'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844}
{'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572}
{'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297}
{'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129}
{'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848}
{'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297}
{'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914}
{'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059}
{'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805}
{'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584}
{'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902}
{'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984}
{'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492}
{'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383}
{'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797}
{'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355}
{'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312}
{'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691}
{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742}
{'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148}
{'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023}
{'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988}
{'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371}
{'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398}
{'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043}
{'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684}
{'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695}
{'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627}
{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}
{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}
{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}
{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}
{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}
{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}
{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}
{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}
{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}
{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}
{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}
{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}
{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}
{'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208}
{'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248}
{'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473}
{'epoch': 0, 'update in batch': 343, '/': 18563, 'loss': 8.807058334350586}
{'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002}
{'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875}
{'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852}
{'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832}
{'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627}
{'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156}
{'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188}
{'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668}
{'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445}
{'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227}
{'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578}
{'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027}
{'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473}
{'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203}
{'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541}
{'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688}
{'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176}
{'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152}
{'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664}
{'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916}
{'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805}
{'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734}
{'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098}
{'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125}
{'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207}
{'epoch': 0, 'update in batch': 369, '/': 18563, 'loss': 7.592991828918457}
{'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125}
{'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828}
{'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383}
{'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455}
{'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023}
{'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162}
{'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543}
{'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438}
{'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902}
{'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492}
{'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613}
{'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207}
{'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055}
{'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211}
{'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531}
{'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672}
{'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586}
{'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621}
{'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465}
{'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777}
{'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084}
{'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938}
{'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219}
{'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797}
{'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664}
{'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 7.6597900390625}
{'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117}
{'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035}
{'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844}
{'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836}
{'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871}
{'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875}
{'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629}
{'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078}
{'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703}
{'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785}
{'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238}
{'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225}
{'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057}
{'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184}
{'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652}
{'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992}
{'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516}
{'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059}
{'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541}
{'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996}
{'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254}
{'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398}
{'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168}
{'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086}
{'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348}
{'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506}
{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713}
{'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031}
{'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113}
{'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535}
{'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535}
{'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805}
{'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379}
{'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848}
{'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682}
{'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398}
{'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754}
{'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373}
{'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125}
{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}
{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}
{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}
{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}
{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}
{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}
{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}
{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}
{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}
{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}
{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}
{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}
{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}
{'epoch': 0, 'update in batch': 448, '/': 18563, 'loss': 8.531024932861328}
{'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305}
{'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754}
{'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785}
{'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152}
{'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537}
{'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854}
{'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711}
{'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156}
{'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098}
{'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008}
{'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082}
{'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805}
{'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367}
{'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415}
{'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145}
{'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562}
{'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184}
{'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125}
{'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176}
{'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125}
{'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605}
{'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246}
{'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441}
{'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041}
{'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164}
{'epoch': 0, 'update in batch': 474, '/': 18563, 'loss': 8.147513389587402}
{'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338}
{'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665}
{'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963}
{'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291}
{'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297}
{'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812}
{'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582}
{'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684}
{'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008}
{'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117}
{'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084}
{'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023}
{'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508}
{'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812}
{'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574}
{'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254}
{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}
{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}
{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-18-fe996a0be74b> in <module>
      1 model = Model(vocab_size = len(dataset.uniq_words)).to(device)
----> 2 train(dataset, model, 1, 64)

<ipython-input-17-8d700bc624e3> in train(dataset, model, max_epochs, batch_size)
     15             loss = criterion(y_pred.transpose(1, 2), y)
     16 
---> 17             loss.backward()
     18             optimizer.step()
     19 

~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364 
    365     def register_hook(self, hook):

~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    171     # some Python versions print out the first line of a multi-line function
    172     # calls in the traceback and some print out the last line
--> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

KeyboardInterrupt: 
def predict(dataset, model, text, next_words=5):
    """Sample a continuation of ``text`` from the language model.

    The prompt is split on single spaces and extended one word at a time.
    Each new word is drawn at random from the softmax distribution computed
    from the model output at the last position of the current window.

    Args:
        dataset: object exposing ``word_to_index`` (word -> id) and
            ``index_to_word`` (id -> word) lookups.
        model: network exposing ``eval()``, ``init_state(n)`` and callable as
            ``model(x, (state_h, state_c))``, returning
            ``(y_pred, (state_h, state_c))``.
        text: prompt; every space-separated token must be a key of
            ``dataset.word_to_index``.
        next_words: number of words to generate.

    Returns:
        list of str: the prompt tokens followed by the sampled words.

    Raises:
        KeyError: if a token is missing from ``dataset.word_to_index``.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    model.eval()
    words = text.split(' ')
    # NOTE(review): init_state receives the prompt length; presumably the
    # state shape depends on it -- confirm against the Model definition.
    state_h, state_c = model.init_state(len(words))

    # Generation needs no gradients. Without no_grad() the recurrent state
    # carried across iterations keeps every step's autograd graph reachable,
    # so memory use grows with next_words.
    with torch.no_grad():
        for i in range(next_words):
            # One word is appended per iteration, so words[i:] is a sliding
            # window whose length always equals the prompt length.
            window = [dataset.word_to_index[w] for w in words[i:]]
            x = torch.tensor([window]).to(device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            # Scores for the word following the last position of the window.
            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words
# Generate a continuation of the prompt (predict's default is 5 words); words are sampled at random.
predict(dataset, model, 'kmicic szedł')
['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']

ZADANIE 1

Stworzyć sieć rekurencyjną GRU dla Challenging America word-gap prediction. Wymogi takie jak zawsze, zadanie jest widoczne na gonito.

ZADANIE 2

Podjąć wyzwanie na https://gonito.net/challenge/precipitation-pl i/lub https://gonito.net/challenge/book-dialogues-pl

KONIECZNIE należy je zgłosić do końca następnego piątku, czyli 20 maja! Za późniejsze zgłoszenia (nawet o minutę) nie przyznaję punktów.

Za każde zgłoszenie lepsze niż baseline przyznaję 40 punktów.

Zamiast tych 40 punktów za najlepsze miejsca:

    1. miejsce 150 punktów
    2. miejsce 100 punktów
    3. miejsce 70 punktów

Można brać udział w 2 wyzwaniach jednocześnie.

Zadania nie będą widoczne w gonito w achievements. Nie trzeba udostępniać kodu, należy jednak przestrzegać regulaminu wyzwań.