60 KiB
60 KiB
Modelowanie Języka
9. Model neuronowy rekurencyjny [ćwiczenia]
Jakub Pokrywka (2022)
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from collections import Counter
import re
# Device on which this notebook places its tensors and models.
device = 'cpu'
! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294:: Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 877893 (857K) [text/plain] Saving to: ‘potop-tom-pierwszy.txt.2’ potop-tom-pierwszy. 100%[===================>] 857,32K --.-KB/s in 0,07s 2022-05-08 19:27:04 (12,0 MB/s) - ‘potop-tom-pierwszy.txt.2’ saved [877893/877893] --2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294:: Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 1087797 (1,0M) [text/plain] Saving to: ‘potop-tom-drugi.txt.2’ potop-tom-drugi.txt 100%[===================>] 1,04M --.-KB/s in 0,08s 2022-05-08 19:27:04 (12,9 MB/s) - ‘potop-tom-drugi.txt.2’ saved [1087797/1087797] --2022-05-08 19:27:05-- https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294:: Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 788219 (770K) [text/plain] Saving to: ‘potop-tom-trzeci.txt.2’ potop-tom-trzeci.tx 100%[===================>] 769,75K --.-KB/s in 0,06s 2022-05-08 19:27:05 (12,0 MB/s) - ‘potop-tom-trzeci.txt.2’ saved [788219/788219]
!cat potop-* > potop.txt
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a plain-text corpus.

    The corpus is lower-cased, stripped of every character except Latin and
    Polish lower-case letters and spaces, and split into words.  Each item is
    a pair of index tensors ``(x, y)`` of length ``sequence_length`` in which
    ``y`` is ``x`` shifted one word to the right.
    """

    def __init__(
        self,
        sequence_length,
        path='potop.txt',
    ):
        """Read the corpus and build the vocabulary mappings.

        Args:
            sequence_length: number of words in every input/target window.
            path: text file to read (defaults to the notebook's corpus file).
        """
        self.sequence_length = sequence_length
        self.path = path
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        """Return the corpus as a list of normalised, non-empty words."""
        # Explicit UTF-8: the platform default encoding is not guaranteed to
        # decode Polish diacritics correctly.
        with open(self.path, 'r', encoding='utf-8') as f_in:
            text = ' '.join(line.strip() for line in f_in)
        text = re.sub(r'[^a-ząćęłńóśźż ]', '', text.lower())
        # split() without an argument collapses runs of spaces, so punctuation
        # removed between two spaces never produces an empty-string "word".
        return text.split()

    def get_uniq_words(self):
        """Return the distinct words ordered from most to least frequent."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Number of full windows that still have a shifted target; clamp at
        # zero because len() must not be negative for a too-short corpus.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(input, target)`` index tensors starting at ``index``."""
        stop = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:stop]),
            torch.tensor(self.words_indexes[index + 1:stop + 1]),
        )
# Build the dataset with 5-word windows and look at one (input, target) pair.
dataset = Dataset(sequence_length=5)
dataset[200]
(tensor([ 551, 18, 17, 255, 10748]), tensor([ 18, 17, 255, 10748, 34]))
# Decode the input half of sample 200; reading the indices from the dataset
# keeps this cell valid when the vocabulary (and thus the numbering) changes.
[dataset.index_to_word[x] for x in dataset[200][0].tolist()]
['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']
# Decode the target half of sample 200; reading the indices from the dataset
# keeps this cell valid when the vocabulary (and thus the numbering) changes.
[dataset.index_to_word[x] for x in dataset[200][1].tolist()]
['tak', 'jak', 'człowiek', 'zbudzony', 'ze']
# A single 5-word window of vocabulary indices, created directly on `device`.
input_tensor = torch.tensor(
    [[551, 18, 17, 255, 10748]],
    dtype=torch.int32,
    device=device,
)
# Shorter alternative input kept for experimentation:
#input_tensor = torch.tensor([[ 551, 18]], dtype=torch.int32).to(device)
class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear layer.

    The final linear layer produces one score per vocabulary entry for every
    position of the input.
    """

    def __init__(self, vocab_size, lstm_size=128, embedding_dim=128,
                 num_layers=3, dropout=0.2):
        """Create the layers.

        Args:
            vocab_size: number of distinct words (embedding rows and outputs).
            lstm_size: hidden size of every LSTM layer.
            embedding_dim: width of the word embeddings.
            num_layers: number of stacked LSTM layers.
            dropout: dropout applied between LSTM layers.
        """
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width is the
            # embedding dimension (not the hidden size, which merely happens
            # to be equal with the default settings).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state=None):
        """Return ``(logits, state)`` for a tensor of word indices.

        Args:
            x: integer tensor of vocabulary indices.
            prev_state: optional ``(h, c)`` pair handed to the LSTM; when
                ``None`` the LSTM starts from zeros.
        """
        # NOTE(review): the LSTM is built without batch_first, so it reads
        # dim 0 of its input as time and dim 1 as batch.  If callers pass
        # (batch, seq) tensors the recurrence runs across the batch - confirm
        # whether batch_first=True was intended.  Switching it would also
        # change the state shape that init_state has to return.
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h, c)`` pair usable as ``prev_state``.

        Args:
            sequence_length: size of the state's middle dimension; it has to
                match dim 1 of the tensor later passed to ``forward``.
        """
        # Allocate on the model's own device so the state always matches the
        # weights, wherever the model has been moved.
        state_device = self.fc.weight.device
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=state_device),
                torch.zeros(shape, device=state_device))
# Size the model by the vocabulary (number of distinct words).  len(dataset)
# is the number of training windows, which would make the embedding table and
# the output layer far larger than the vocabulary.
model = Model(vocab_size=len(dataset.uniq_words)).to(device)
y_pred, (state_h, state_c) = model(input_tensor)
y_pred
tensor([[[ 0.0046, -0.0113, 0.0313, ..., 0.0198, -0.0312, 0.0223], [ 0.0039, -0.0110, 0.0303, ..., 0.0213, -0.0302, 0.0230], [ 0.0029, -0.0133, 0.0265, ..., 0.0204, -0.0297, 0.0219], [ 0.0010, -0.0120, 0.0282, ..., 0.0241, -0.0314, 0.0241], [ 0.0038, -0.0106, 0.0346, ..., 0.0230, -0.0333, 0.0232]]], grad_fn=<AddBackward0>)
# Dimensions of the logits: the two input dimensions followed by the width of
# the model's output layer.
y_pred.size()
torch.Size([1, 5, 1187998])
def train(dataset, model, max_epochs, batch_size, log_every=1):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Args:
        dataset: yields ``(x, y)`` pairs of index tensors.
        model: module whose call returns ``(logits, state)``.
        max_epochs: number of passes over the dataset.
        batch_size: number of samples per optimisation step.
        log_every: print the loss every this many batches; the default of 1
            logs every batch, and 0 or None disables logging.
    """
    model.train()
    # Follow the model's own device instead of a module-level global, so the
    # batches always land where the weights are.
    model_device = next(model.parameters()).device
    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(model_device)
            y = y.to(model_device)
            # The recurrent state is not carried between batches.
            y_pred, _ = model(x)
            # CrossEntropyLoss expects the class dimension at index 1, so move
            # the vocabulary axis there before comparing with the targets.
            loss = criterion(y_pred.transpose(1, 2), y)
            loss.backward()
            optimizer.step()
            if log_every and batch % log_every == 0:
                print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })
# Fresh model sized to the vocabulary, trained for one pass in batches of 64.
vocabulary_size = len(dataset.uniq_words)
model = Model(vocab_size=vocabulary_size).to(device)
train(dataset, model, max_epochs=1, batch_size=64)
{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555} {'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508} {'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219} {'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617} {'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873} {'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117} {'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422} {'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559} {'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402} {'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648} {'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977} {'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785} {'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977} {'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434} {'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955} {'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184} {'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445} {'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488} {'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742} {'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543} {'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508} {'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625} {'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543} {'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152} {'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066} {'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709} {'epoch': 0, 'update in batch': 26, '/': 
18563, 'loss': 8.741975784301758} {'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541} {'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914} {'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336} {'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059} {'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172} {'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383} {'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258} {'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375} {'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719} {'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305} {'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125} {'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066} {'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336} {'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375} {'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434} {'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072} {'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566} {'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035} {'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999} {'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947} {'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137} {'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781} {'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586} {'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207} {'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369} {'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662} {'epoch': 0, 'update 
in batch': 53, '/': 18563, 'loss': 7.967227935791016} {'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039} {'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117} {'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492} {'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809} {'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004} {'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328} {'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875} {'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406} {'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965} {'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027} {'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535} {'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793} {'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234} {'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852} {'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773} {'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188} {'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004} {'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336} {'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828} {'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172} {'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262} {'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793} {'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207} {'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777} {'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738} {'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109} 
{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746} {'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402} {'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074} {'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867} {'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508} {'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734} {'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527} {'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234} {'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082} {'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875} {'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266} {'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867} {'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062} {'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793} {'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168} {'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031} {'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863} {'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023} {'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422} {'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941} {'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035} {'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645} {'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555} {'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541} {'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023} {'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348} {'epoch': 0, 'update in batch': 106, '/': 
18563, 'loss': 8.172578811645508} {'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623} {'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537} {'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182} {'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734} {'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023} {'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734} {'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742} {'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473} {'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809} {'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242} {'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582} {'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789} {'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008} {'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645} {'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883} {'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789} {'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376} {'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363} {'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348} {'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418} {'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791} {'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605} {'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777} {'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848} {'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617} {'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 
8.807767868041992} {'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676} {'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707} {'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662} {'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953} {'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988} {'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459} {'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031} {'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133} {'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709} {'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195} {'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668} {'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975} {'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535} {'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027} {'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375} {'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438} {'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812} {'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219} {'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285} {'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512} {'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566} {'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695} {'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113} {'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918} {'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947} {'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195} 
{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898} {'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039} {'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082} {'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713} {'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672} {'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504} {'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301} {'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023} {'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965} {'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848} {'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695} {'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124} {'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418} {'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953} {'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125} {'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125} {'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373} {'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635} {'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176} {'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219} {'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539} {'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762} {'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766} {'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797} {'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094} {'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754} {'epoch': 0, 'update in 
batch': 185, '/': 18563, 'loss': 9.206843376159668} {'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252} {'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375} {'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027} {'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443} {'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258} {'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582} {'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473} {'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945} {'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859} {'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465} {'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719} {'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807} {'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012} {'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434} {'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258} {'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459} {'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621} {'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461} {'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166} {'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738} {'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584} {'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848} {'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668} {'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484} {'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275} {'epoch': 0, 'update in batch': 211, '/': 18563, 
'loss': 8.537101745605469} {'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566} {'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512} {'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801} {'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117} {'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902} {'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211} {'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459} {'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115} {'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453} {'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223} {'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057} {'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023} {'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789} {'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312} {'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441} {'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086} {'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871} {'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125} {'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791} {'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387} {'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682} {'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021} {'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148} {'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119} {'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754} {'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496} 
{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977} {'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376} {'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678} {'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537} {'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344} {'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406} {'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691} {'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422} {'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207} {'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598} {'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176} {'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684} {'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484} {'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523} {'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211} {'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582} {'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863} {'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402} {'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035} {'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182} {'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051} {'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975} {'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918} {'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249} {'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242} {'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156} {'epoch': 0, 'update in 
batch': 264, '/': 18563, 'loss': 7.205214500427246} {'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844} {'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756} {'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004} {'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875} {'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459} {'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246} {'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584} {'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727} {'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094} {'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473} {'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537} {'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213} {'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377} {'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945} {'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121} {'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844} {'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262} {'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215} {'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379} {'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568} {'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275} {'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488} {'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568} {'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801} {'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072} {'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 
7.297110557556152} {'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691} {'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324} {'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582} {'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695} {'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297} {'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168} {'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055} {'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001} {'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844} {'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572} {'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297} {'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129} {'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848} {'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297} {'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914} {'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059} {'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805} {'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584} {'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902} {'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984} {'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492} {'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383} {'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797} {'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355} {'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312} {'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691} 
{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742} {'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148} {'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023} {'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988} {'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371} {'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398} {'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043} {'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684} {'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695} {'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627} {'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012} {'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666} {'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961} {'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715} {'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543} {'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498} {'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215} {'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902} {'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539} {'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281} {'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133} {'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594} {'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785} {'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208} {'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248} {'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473} {'epoch': 0, 'update in 
batch': 343, '/': 18563, 'loss': 8.807058334350586} {'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002} {'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875} {'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852} {'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832} {'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627} {'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156} {'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188} {'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668} {'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445} {'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227} {'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578} {'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027} {'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473} {'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203} {'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541} {'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688} {'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176} {'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152} {'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664} {'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916} {'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805} {'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734} {'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098} {'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125} {'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207} {'epoch': 0, 'update in batch': 369, '/': 18563, 
'loss': 7.592991828918457} {'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125} {'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828} {'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383} {'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455} {'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023} {'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162} {'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543} {'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438} {'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902} {'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492} {'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613} {'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207} {'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055} {'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211} {'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531} {'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672} {'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586} {'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621} {'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465} {'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777} {'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084} {'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938} {'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219} {'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797} {'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664} {'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 
7.6597900390625} {'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117} {'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035} {'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844} {'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836} {'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871} {'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875} {'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629} {'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078} {'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703} {'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785} {'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238} {'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225} {'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057} {'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184} {'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652} {'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992} {'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516} {'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059} {'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541} {'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996} {'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254} {'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398} {'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168} {'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086} {'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348} {'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506} 
{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713} {'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031} {'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113} {'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535} {'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535} {'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805} {'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379} {'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848} {'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682} {'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398} {'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754} {'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373} {'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125} {'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762} {'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145} {'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426} {'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828} {'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883} {'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422} {'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195} {'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555} {'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918} {'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445} {'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398} {'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434} {'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062} {'epoch': 0, 'update 
in batch': 448, '/': 18563, 'loss': 8.531024932861328} {'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305} {'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754} {'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785} {'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152} {'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537} {'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854} {'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711} {'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156} {'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098} {'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008} {'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082} {'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805} {'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367} {'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415} {'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145} {'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562} {'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184} {'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125} {'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176} {'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125} {'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605} {'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246} {'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441} {'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041} {'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164} {'epoch': 0, 'update in batch': 474, '/': 
18563, 'loss': 8.147513389587402} {'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338} {'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665} {'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963} {'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291} {'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297} {'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812} {'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582} {'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684} {'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008} {'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117} {'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084} {'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023} {'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508} {'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812} {'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574} {'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254} {'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215} {'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203} {'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-18-fe996a0be74b> in <module>
      1 model = Model(vocab_size = len(dataset.uniq_words)).to(device)
----> 2 train(dataset, model, 1, 64)

<ipython-input-17-8d700bc624e3> in train(dataset, model, max_epochs, batch_size)
     15             loss = criterion(y_pred.transpose(1, 2), y)
     16
---> 17             loss.backward()
     18             optimizer.step()
     19

~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364
    365     def register_hook(self, hook):

~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    171     # some Python versions print out the first line of a multi-line function
    172     # calls in the traceback and some print out the last line
--> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

KeyboardInterrupt:
def predict(dataset, model, text, next_words=5):
    """Extend a prompt by sampling words from the model one at a time.

    Args:
        dataset: Object exposing ``word_to_index`` (word -> id) and
            ``index_to_word`` (id -> word) lookups.
        model: Object callable as ``model(x, (state_h, state_c))`` that
            returns ``(logits, (state_h, state_c))`` and also provides
            ``eval()`` and ``init_state(n)``.
        text: Whitespace-separated prompt. Every word must be a key of
            ``dataset.word_to_index``.
        next_words: Number of words to sample and append to the prompt.

    Returns:
        The list of prompt words followed by the sampled words.

    Raises:
        ValueError: If ``text`` contains no words.
        KeyError: If a prompt word is missing from ``dataset.word_to_index``.
    """
    model.eval()

    # split() with no argument collapses runs of whitespace and ignores
    # leading/trailing whitespace, so no empty-string "words" reach the
    # vocabulary lookup below.
    words = text.split()
    if not words:
        raise ValueError("text must contain at least one word")

    state_h, state_c = model.init_state(len(words))

    # Pure inference: no autograd graph is needed for sampling.
    with torch.no_grad():
        for i in range(next_words):
            # One word is appended per step while the window start advances
            # by one, so words[i:] always has as many words as the prompt,
            # i.e. the size that was passed to init_state above.
            window = [dataset.word_to_index[w] for w in words[i:]]
            # Build the input on the same device as the recurrent state
            # instead of relying on a module-level device variable.
            x = torch.tensor([window], device=state_h.device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            # Scores for the position after the last word of the window.
            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
            # Sample rather than take the argmax, so output varies per call.
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words
# Sample a five-word continuation of a short prompt with the trained model.
# predict() draws every word at random, so the result differs between runs.
predict(dataset, model, text='kmicic szedł', next_words=5)
['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']