moj-2024/lab/09_Model_neuronowy_rekurencyjny.ipynb
Paweł Skórzewski 1c2f3bf500 Lab. 10
2024-05-15 10:41:40 +02:00

60 KiB
Raw Blame History

Modelowanie języka laboratoria

24 kwietnia 2024

9. Model neuronowy rekurencyjny

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from collections import Counter
import re
device = 'cpu'
! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 877893 (857K) [text/plain]
Saving to: potop-tom-pierwszy.txt.2

potop-tom-pierwszy. 100%[===================>] 857,32K  --.-KB/s    in 0,07s   

2022-05-08 19:27:04 (12,0 MB/s) - potop-tom-pierwszy.txt.2 saved [877893/877893]

--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1087797 (1,0M) [text/plain]
Saving to: potop-tom-drugi.txt.2

potop-tom-drugi.txt 100%[===================>]   1,04M  --.-KB/s    in 0,08s   

2022-05-08 19:27:04 (12,9 MB/s) - potop-tom-drugi.txt.2 saved [1087797/1087797]

--2022-05-08 19:27:05--  https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::
Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 788219 (770K) [text/plain]
Saving to: potop-tom-trzeci.txt.2

potop-tom-trzeci.tx 100%[===================>] 769,75K  --.-KB/s    in 0,06s   

2022-05-08 19:27:05 (12,0 MB/s) - potop-tom-trzeci.txt.2 saved [788219/788219]

!cat potop-* > potop.txt
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a plain-text corpus.

    The corpus is lower-cased, every character other than ``a-z``,
    ``ąćęłńóśźż`` and the space is removed, and the result is split into
    words. Words are numbered by descending frequency. Item ``i`` is a pair
    of index tensors: the ``sequence_length`` words starting at position
    ``i`` and the same window shifted one word to the right.
    """

    def __init__(
            self,
            sequence_length,
            path='potop.txt',
    ):
        """Read the corpus and build the vocabulary.

        Args:
            sequence_length: number of words in every input/target window.
            path: text file to read; defaults to the file produced earlier
                in the notebook.
        """
        self.sequence_length = sequence_length
        self.path = path
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        """Return the corpus as a list of normalized, non-empty words."""
        # Explicit encoding: the text contains Polish diacritics and must not
        # depend on the platform's default codec.
        with open(self.path, 'r', encoding='utf-8') as f_in:
            lines = [line.rstrip() for line in f_in if line.strip()]
        text = ' '.join(lines).lower()
        text = re.sub(r'[^a-ząćęłńóśźż ]', '', text)
        # split() without an argument collapses runs of spaces (left behind by
        # removed punctuation and by leading indentation), so no empty-string
        # "word" ends up in the vocabulary.
        return text.split()

    def get_uniq_words(self):
        """Return the distinct words ordered from most to least frequent."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Return the number of windows that still have a shifted target."""
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        """Return ``(input_window, target_window)`` as two index tensors."""
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )
# Build the dataset with five-word windows and show one (input, target) pair.
dataset = Dataset(sequence_length=5)
dataset[200]
(tensor([  551,    18,    17,   255, 10748]),
 tensor([   18,    17,   255, 10748,    34]))
# Map the input window's indices back to words.
[dataset.index_to_word[word_id] for word_id in (551, 18, 17, 255, 10748)]
['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']
# Map the target window's indices (input shifted by one word) back to words.
[dataset.index_to_word[word_id] for word_id in (18, 17, 255, 10748, 34)]
['tak', 'jak', 'człowiek', 'zbudzony', 'ze']
# A 1 x 5 tensor holding the word indices of the window shown above.
input_tensor = torch.tensor(
    [[551, 18, 17, 255, 10748]],
    dtype=torch.int32,
).to(device)
# Shorter two-word alternative, kept for experimentation:
#input_tensor = torch.tensor([[ 551,    18]], dtype=torch.int32).to(device)
class Model(nn.Module):
    """Word-level language model: embedding -> 3-layer LSTM -> linear logits.

    NOTE(review): ``nn.LSTM`` is created without ``batch_first=True``, so it
    reads dimension 0 of its input as the time axis and dimension 1 as the
    batch axis. ``init_state`` follows the same convention (its middle
    dimension must equal ``x.shape[1]``). Confirm this layout is intended
    before changing either piece.
    """

    def __init__(self, vocab_size):
        """Create the layers.

        Args:
            vocab_size: number of distinct word indices; sizes both the
                embedding table and the output layer.
        """
        super().__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width is the
            # embedding dimension (not the hidden size, even though the two
            # are currently equal).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state = None):
        """Return ``(logits, state)`` for a tensor of word indices.

        Args:
            x: integer tensor of word indices.
            prev_state: optional ``(h, c)`` pair passed to the LSTM; when
                ``None`` the LSTM starts from its default zero state.

        Returns:
            ``logits`` with the shape of ``x`` plus a trailing ``vocab_size``
            dimension, and the LSTM's final ``(h, c)`` state.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h, c)`` pair shaped for this LSTM.

        Each tensor has shape ``(num_layers, sequence_length, lstm_size)``
        and is allocated on the device that holds the model's parameters.
        """
        target_device = self.embedding.weight.device
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=target_device),
                torch.zeros(shape, device=target_device))
# Model's first argument is the vocabulary size. len(dataset) is the number of
# training windows, which would allocate one embedding row and one output
# logit per window instead of per distinct word.
model = Model(len(dataset.uniq_words)).to(device)
# Forward pass on the sample window with an untrained model.
y_pred, (state_h, state_c) = model(input_tensor)
y_pred
tensor([[[ 0.0046, -0.0113,  0.0313,  ...,  0.0198, -0.0312,  0.0223],
         [ 0.0039, -0.0110,  0.0303,  ...,  0.0213, -0.0302,  0.0230],
         [ 0.0029, -0.0133,  0.0265,  ...,  0.0204, -0.0297,  0.0219],
         [ 0.0010, -0.0120,  0.0282,  ...,  0.0241, -0.0314,  0.0241],
         [ 0.0038, -0.0106,  0.0346,  ...,  0.0230, -0.0333,  0.0232]]],
       grad_fn=<AddBackward0>)
# Dimensions of the logits returned by the forward pass above.
y_pred.size()
torch.Size([1, 5, 1187998])
def train(dataset, model, max_epochs, batch_size):
    """Fit ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Batches are drawn in dataset order (the DataLoader is created without
    shuffling). After every optimizer step a progress dict with the epoch,
    the batch number, the number of batches and the loss is printed.

    Args:
        dataset: source of ``(input, target)`` index-tensor pairs.
        model: module whose call returns ``(logits, (h, c))``.
        max_epochs: number of passes over the data.
        batch_size: number of items per batch.
    """
    loader = DataLoader(dataset, batch_size=batch_size)
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(max_epochs):
        step = 0
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            adam.zero_grad()
            logits, (_hidden, _cell) = model(inputs)
            # CrossEntropyLoss wants the class dimension second, so swap the
            # last two axes of the logits before comparing with the targets.
            loss = loss_fn(logits.transpose(1, 2), targets)
            loss.backward()
            adam.step()

            print({ 'epoch': epoch, 'update in batch': step, '/' : len(loader), 'loss': loss.item() })
            step += 1
# Fresh model sized to the vocabulary, trained for one epoch in batches of 64.
vocab_size = len(dataset.uniq_words)
model = Model(vocab_size=vocab_size).to(device)
train(dataset, model, max_epochs=1, batch_size=64)
{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}
{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}
{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}
{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}
{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}
{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}
{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}
{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}
{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}
{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}
{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}
{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}
{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}
{'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434}
{'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955}
{'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184}
{'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445}
{'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488}
{'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742}
{'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543}
{'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508}
{'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625}
{'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543}
{'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152}
{'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066}
{'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709}
{'epoch': 0, 'update in batch': 26, '/': 18563, 'loss': 8.741975784301758}
{'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541}
{'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914}
{'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336}
{'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059}
{'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172}
{'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383}
{'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258}
{'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375}
{'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719}
{'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305}
{'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125}
{'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066}
{'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336}
{'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375}
{'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434}
{'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072}
{'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566}
{'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035}
{'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999}
{'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947}
{'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137}
{'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781}
{'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586}
{'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207}
{'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369}
{'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662}
{'epoch': 0, 'update in batch': 53, '/': 18563, 'loss': 7.967227935791016}
{'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039}
{'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117}
{'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492}
{'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809}
{'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004}
{'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328}
{'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875}
{'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406}
{'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965}
{'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027}
{'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535}
{'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793}
{'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234}
{'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852}
{'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773}
{'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188}
{'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004}
{'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336}
{'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828}
{'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172}
{'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262}
{'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793}
{'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207}
{'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777}
{'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738}
{'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109}
{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746}
{'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402}
{'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074}
{'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867}
{'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508}
{'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734}
{'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527}
{'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234}
{'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082}
{'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875}
{'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266}
{'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867}
{'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062}
{'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793}
{'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168}
{'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031}
{'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863}
{'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023}
{'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422}
{'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941}
{'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035}
{'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645}
{'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555}
{'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541}
{'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023}
{'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348}
{'epoch': 0, 'update in batch': 106, '/': 18563, 'loss': 8.172578811645508}
{'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623}
{'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537}
{'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182}
{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}
{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}
{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}
{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}
{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}
{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}
{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}
{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}
{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}
{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}
{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}
{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}
{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}
{'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376}
{'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363}
{'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348}
{'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418}
{'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791}
{'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605}
{'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777}
{'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848}
{'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617}
{'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 8.807767868041992}
{'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676}
{'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707}
{'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662}
{'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953}
{'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988}
{'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459}
{'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031}
{'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133}
{'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709}
{'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195}
{'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668}
{'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975}
{'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535}
{'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027}
{'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375}
{'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438}
{'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812}
{'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219}
{'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285}
{'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512}
{'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566}
{'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695}
{'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113}
{'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918}
{'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947}
{'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195}
{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898}
{'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039}
{'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082}
{'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713}
{'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672}
{'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504}
{'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301}
{'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023}
{'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965}
{'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848}
{'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695}
{'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124}
{'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418}
{'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953}
{'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125}
{'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125}
{'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373}
{'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635}
{'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176}
{'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219}
{'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539}
{'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762}
{'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766}
{'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797}
{'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094}
{'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754}
{'epoch': 0, 'update in batch': 185, '/': 18563, 'loss': 9.206843376159668}
{'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252}
{'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375}
{'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027}
{'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443}
{'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258}
{'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582}
{'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473}
{'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945}
{'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859}
{'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465}
{'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719}
{'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807}
{'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012}
{'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434}
{'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258}
{'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459}
{'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621}
{'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461}
{'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166}
{'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738}
{'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584}
{'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848}
{'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668}
{'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484}
{'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275}
{'epoch': 0, 'update in batch': 211, '/': 18563, 'loss': 8.537101745605469}
{'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566}
{'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512}
{'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801}
{'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117}
{'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902}
{'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211}
{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}
{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}
{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}
{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}
{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}
{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}
{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}
{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}
{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}
{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}
{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}
{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}
{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}
{'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387}
{'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682}
{'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021}
{'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148}
{'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119}
{'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754}
{'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496}
{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977}
{'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376}
{'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678}
{'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537}
{'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344}
{'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406}
{'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691}
{'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422}
{'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207}
{'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598}
{'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176}
{'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684}
{'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484}
{'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523}
{'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211}
{'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582}
{'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863}
{'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402}
{'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035}
{'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182}
{'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051}
{'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975}
{'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918}
{'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249}
{'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242}
{'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156}
{'epoch': 0, 'update in batch': 264, '/': 18563, 'loss': 7.205214500427246}
{'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844}
{'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756}
{'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004}
{'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875}
{'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459}
{'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246}
{'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584}
{'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727}
{'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094}
{'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473}
{'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537}
{'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213}
{'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377}
{'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945}
{'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121}
{'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844}
{'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262}
{'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215}
{'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379}
{'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568}
{'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275}
{'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488}
{'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568}
{'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801}
{'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072}
{'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 7.297110557556152}
{'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691}
{'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324}
{'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582}
{'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695}
{'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297}
{'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168}
{'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055}
{'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001}
{'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844}
{'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572}
{'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297}
{'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129}
{'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848}
{'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297}
{'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914}
{'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059}
{'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805}
{'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584}
{'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902}
{'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984}
{'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492}
{'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383}
{'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797}
{'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355}
{'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312}
{'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691}
{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742}
{'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148}
{'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023}
{'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988}
{'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371}
{'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398}
{'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043}
{'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684}
{'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695}
{'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627}
{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}
{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}
{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}
{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}
{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}
{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}
{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}
{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}
{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}
{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}
{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}
{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}
{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}
{'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208}
{'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248}
{'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473}
{'epoch': 0, 'update in batch': 343, '/': 18563, 'loss': 8.807058334350586}
{'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002}
{'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875}
{'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852}
{'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832}
{'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627}
{'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156}
{'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188}
{'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668}
{'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445}
{'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227}
{'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578}
{'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027}
{'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473}
{'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203}
{'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541}
{'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688}
{'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176}
{'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152}
{'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664}
{'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916}
{'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805}
{'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734}
{'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098}
{'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125}
{'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207}
{'epoch': 0, 'update in batch': 369, '/': 18563, 'loss': 7.592991828918457}
{'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125}
{'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828}
{'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383}
{'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455}
{'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023}
{'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162}
{'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543}
{'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438}
{'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902}
{'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492}
{'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613}
{'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207}
{'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055}
{'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211}
{'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531}
{'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672}
{'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586}
{'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621}
{'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465}
{'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777}
{'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084}
{'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938}
{'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219}
{'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797}
{'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664}
{'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 7.6597900390625}
{'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117}
{'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035}
{'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844}
{'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836}
{'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871}
{'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875}
{'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629}
{'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078}
{'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703}
{'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785}
{'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238}
{'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225}
{'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057}
{'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184}
{'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652}
{'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992}
{'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516}
{'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059}
{'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541}
{'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996}
{'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254}
{'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398}
{'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168}
{'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086}
{'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348}
{'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506}
{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713}
{'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031}
{'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113}
{'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535}
{'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535}
{'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805}
{'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379}
{'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848}
{'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682}
{'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398}
{'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754}
{'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373}
{'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125}
{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}
{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}
{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}
{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}
{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}
{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}
{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}
{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}
{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}
{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}
{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}
{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}
{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}
{'epoch': 0, 'update in batch': 448, '/': 18563, 'loss': 8.531024932861328}
{'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305}
{'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754}
{'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785}
{'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152}
{'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537}
{'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854}
{'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711}
{'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156}
{'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098}
{'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008}
{'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082}
{'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805}
{'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367}
{'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415}
{'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145}
{'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562}
{'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184}
{'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125}
{'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176}
{'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125}
{'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605}
{'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246}
{'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441}
{'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041}
{'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164}
{'epoch': 0, 'update in batch': 474, '/': 18563, 'loss': 8.147513389587402}
{'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338}
{'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665}
{'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963}
{'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291}
{'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297}
{'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812}
{'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582}
{'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684}
{'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008}
{'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117}
{'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084}
{'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023}
{'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508}
{'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812}
{'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574}
{'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254}
{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}
{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}
{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-18-fe996a0be74b> in <module>
      1 model = Model(vocab_size = len(dataset.uniq_words)).to(device)
----> 2 train(dataset, model, 1, 64)

<ipython-input-17-8d700bc624e3> in train(dataset, model, max_epochs, batch_size)
     15             loss = criterion(y_pred.transpose(1, 2), y)
     16 
---> 17             loss.backward()
     18             optimizer.step()
     19 

~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364 
    365     def register_hook(self, hook):

~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    171     # some Python versions print out the first line of a multi-line function
    172     # calls in the traceback and some print out the last line
--> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

KeyboardInterrupt: 
def predict(dataset, model, text, next_words=5):
    """Extend a text prompt with words sampled from the model.

    The prompt is split on whitespace and every word is mapped to its index
    through ``dataset.word_to_index``. On each step the model receives a
    sliding window holding the most recent words (the window is as long as
    the prompt), the logits of the last position are converted to a
    probability distribution with softmax, and the next word is drawn at
    random from that distribution.

    Args:
        dataset: Object exposing the ``word_to_index`` and ``index_to_word``
            lookups.
        model: Network exposing ``eval()``, ``init_state(n)`` returning a
            ``(state_h, state_c)`` pair, and a call ``model(x, state)``
            returning ``(y_pred, (state_h, state_c))``.
        text: Prompt; words are separated by whitespace.
        next_words: Number of words to generate.

    Returns:
        A list with the prompt words followed by the sampled words.

    Raises:
        ValueError: If ``text`` contains no words.
        KeyError: If a prompt word is missing from ``dataset.word_to_index``.
    """
    model.eval()  # NOTE: the model is left in eval mode after the call.

    # split() without arguments ignores repeated, leading and trailing
    # whitespace, so stray spaces do not produce empty "words" that are
    # missing from the vocabulary.
    words = text.split()
    if not words:
        raise ValueError('text must contain at least one word')

    window = len(words)
    state_h, state_c = model.init_state(window)

    # Inference only: without no_grad() the recurrent state carried between
    # iterations would keep the autograd graph of every earlier step alive.
    with torch.no_grad():
        for _ in range(next_words):
            # The list grows by one word per step, so the last `window`
            # words keep the input the same size as the one given to
            # init_state().
            indices = [dataset.word_to_index[w] for w in words[-window:]]
            # `device` is a module-level global defined earlier in the file.
            x = torch.tensor([indices]).to(device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words
# Sample a continuation for a two-word prompt; the next word is drawn at
# random on every step, so the output differs between runs.
predict(dataset, model, text='kmicic szedł', next_words=5)
['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']

ZADANIE

Stworzyć sieć rekurencyjną GRU dla Challenging America word-gap prediction. Wymogi takie jak zawsze, zadanie widoczne na Gonito:

https://gonito.csi.wmi.amu.edu.pl/challenge/challenging-america-word-gap-prediction

Punktacja: 100 punktów

Deadline: 29 maja 2024 przed zajęciami