Language Modeling
10. Recurrent neural model [exercises]
Jakub Pokrywka (2022)
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from collections import Counter
import re
device = 'cpu'
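The notebook is run on the CPU here; when a GPU is available, a common pattern (an assumption, not part of the original run) is to select it automatically:

# Assumption: use the GPU when one is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'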
! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt
Length: 877893 (857K) [text/plain]
2022-05-08 19:27:04 (12,0 MB/s) - ‘potop-tom-pierwszy.txt.2’ saved [877893/877893]

--2022-05-08 19:27:04--  https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt
Length: 1087797 (1,0M) [text/plain]
2022-05-08 19:27:04 (12,9 MB/s) - ‘potop-tom-drugi.txt.2’ saved [1087797/1087797]

--2022-05-08 19:27:05--  https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt
Length: 788219 (770K) [text/plain]
2022-05-08 19:27:05 (12,0 MB/s) - ‘potop-tom-trzeci.txt.2’ saved [788219/788219]
!cat potop-* > potop.txt
class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        sequence_length,
    ):
        self.sequence_length = sequence_length
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        # Read the concatenated text, lower-case it and keep only Polish letters and spaces.
        with open('potop.txt', 'r') as f_in:
            text = [x.rstrip() for x in f_in.readlines() if x.strip()]
        text = ' '.join(text).lower()
        text = re.sub('[^a-ząćęłńóśźż ]', '', text)
        text = text.split(' ')
        return text

    def get_uniq_words(self):
        # Vocabulary sorted by frequency (the most frequent words get the lowest indices).
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        # Each item is a window of word indices and the same window shifted by one word.
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
dataset = Dataset(5)
dataset[200]
(tensor([ 551, 18, 17, 255, 10748]), tensor([ 18, 17, 255, 10748, 34]))
[dataset.index_to_word[x] for x in [ 551, 18, 17, 255, 10748]]
['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']
[dataset.index_to_word[x] for x in [ 18, 17, 255, 10748, 34]]
['tak', 'jak', 'człowiek', 'zbudzony', 'ze']
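As a quick sanity check (a minimal sketch; the exact counts depend on the downloaded text), the vocabulary size and the number of training windows can be inspected directly:

# Vocabulary size and number of (input, target) windows in the corpus.
len(dataset.uniq_words), len(dataset)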
input_tensor = torch.tensor([[ 551, 18, 17, 255, 10748]], dtype=torch.int32).to(device)
#input_tensor = torch.tensor([[ 551, 18]], dtype=torch.int32).to(device)
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        # Word embeddings: one 128-dimensional vector per vocabulary entry.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        # Three-layer LSTM; batch_first is left at its default (False).
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        # Projection from the hidden state to logits over the vocabulary.
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device))
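Note that nn.LSTM is left with its default batch_first=False, so it interprets its input as (seq_len, batch, features), while the DataLoader below yields (batch, seq_len) tensors; the code still runs and learns, but the two dimensions are effectively swapped inside the LSTM. A batch-first variant (a sketch under that assumption, not the notebook's original code) could look like this:

# Sketch: an LSTM that expects (batch, seq_len, features) inputs.
lstm_batch_first = nn.LSTM(
    input_size=128,
    hidden_size=128,
    num_layers=3,
    dropout=0.2,
    batch_first=True,
)
# Regardless of batch_first, the initial (h_0, c_0) state has shape
# (num_layers, batch_size, hidden_size).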
# Note: len(dataset) is the number of training windows, not the vocabulary size,
# so this model's output layer is oversized (1187998 units in the shape below); the
# corrected instantiation with len(dataset.uniq_words) appears right before train().
model = Model(len(dataset)).to(device)
y_pred, (state_h, state_c) = model(input_tensor)
y_pred
tensor([[[ 0.0046, -0.0113,  0.0313,  ...,  0.0198, -0.0312,  0.0223],
         [ 0.0039, -0.0110,  0.0303,  ...,  0.0213, -0.0302,  0.0230],
         [ 0.0029, -0.0133,  0.0265,  ...,  0.0204, -0.0297,  0.0219],
         [ 0.0010, -0.0120,  0.0282,  ...,  0.0241, -0.0314,  0.0241],
         [ 0.0038, -0.0106,  0.0346,  ...,  0.0230, -0.0333,  0.0232]]],
       grad_fn=<AddBackward0>)
y_pred.shape
torch.Size([1, 5, 1187998])
def train(dataset, model, max_epochs, batch_size):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)

            y_pred, (state_h, state_c) = model(x)
            # CrossEntropyLoss expects (batch, vocab_size, seq_len), hence the transpose.
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            print({'epoch': epoch, 'update in batch': batch, '/': len(dataloader), 'loss': loss.item()})
model = Model(vocab_size = len(dataset.uniq_words)).to(device)
train(dataset, model, 1, 64)
{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}
{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}
{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}
{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}
{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}
{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}
[... losses for updates 6-490 omitted; they oscillate while drifting down to roughly 7-9 ...]
{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}
{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}
{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-18-fe996a0be74b> in <module>
      1 model = Model(vocab_size = len(dataset.uniq_words)).to(device)
----> 2 train(dataset, model, 1, 64)

<ipython-input-17-8d700bc624e3> in train(dataset, model, max_epochs, batch_size)
     15             loss = criterion(y_pred.transpose(1, 2), y)
     16
---> 17             loss.backward()
     18             optimizer.step()
     19

~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364
    365     def register_hook(self, hook):

~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    171     # some Python versions print out the first line of a multi-line function
    172     # calls in the traceback and some print out the last line
--> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

KeyboardInterrupt:
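Training was stopped manually after a few hundred updates, hence the KeyboardInterrupt above. The printed cross-entropy losses are often easier to interpret as perplexities; a minimal sketch (using the last loss value printed above):

import math
# Perplexity corresponding to a cross-entropy loss value, e.g. the last one printed above.
math.exp(8.895064353942871)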
def predict(dataset, model, text, next_words=5):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        # Sample the next word from the softmax distribution over the last position.
        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words
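A common refinement when sampling (an assumption, not part of the original notebook) is to rescale the logits by a temperature, which controls how conservative the generated text is:

def sample_with_temperature(logits, temperature=0.8):
    # Lower temperature -> sharper distribution -> more conservative word choices.
    p = torch.nn.functional.softmax(logits / temperature, dim=0).detach().cpu().numpy()
    return np.random.choice(len(logits), p=p)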
predict(dataset, model, 'kmicic szedł')
['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']
EXERCISE 1
Build a GRU recurrent network for the Challenging America word-gap prediction task. The requirements are the same as always; the task is visible on Gonito.
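As a starting point (a minimal sketch with assumed hyperparameters, not a ready-made solution), the LSTM in the Model class above can be replaced with a GRU, which carries only a single hidden state:

class GRUModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_size=128, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # GRU has no cell state, so only one hidden tensor is passed between calls.
        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers=num_layers,
                          dropout=0.2, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.gru(embed, prev_state)
        return self.fc(output), state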
EXERCISE 2
Take on the challenge at https://gonito.net/challenge/precipitation-pl and/or https://gonito.net/challenge/book-dialogues-pl
Submissions MUST be made by the end of next Friday, i.e. 20 May! No points are awarded for later submissions (even a minute late).
Each submission that beats the baseline earns 40 points.
Instead of those 40 points, the top places earn:
- 1st place: 150 points
- 2nd place: 100 points
- 3rd place: 70 points
You can take part in both challenges at the same time.
These tasks will not be visible in Gonito under achievements. You do not have to share your code, but you must follow the challenge rules.