modelowanie-jezykowe-aitech-cw/cw/09_Model_neuronowy_rekurencyjny.ipynb
Jakub Pokrywka 7bf28acbf4 09
2022-05-09 09:59:47 +02:00

1017 lines
61 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Modelowanie Języka</h1>\n",
"<h2> 9. <i>Model neuronowy rekurencyjny</i> [ćwiczenia]</h2> \n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
"from torch.utils.data import DataLoader\n",
"import numpy as np\n",
"from collections import Counter\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 877893 (857K) [text/plain]\n",
"Saving to: potop-tom-pierwszy.txt.2\n",
"\n",
"potop-tom-pierwszy. 100%[===================>] 857,32K --.-KB/s in 0,07s \n",
"\n",
"2022-05-08 19:27:04 (12,0 MB/s) - potop-tom-pierwszy.txt.2 saved [877893/877893]\n",
"\n",
"--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1087797 (1,0M) [text/plain]\n",
"Saving to: potop-tom-drugi.txt.2\n",
"\n",
"potop-tom-drugi.txt 100%[===================>] 1,04M --.-KB/s in 0,08s \n",
"\n",
"2022-05-08 19:27:04 (12,9 MB/s) - potop-tom-drugi.txt.2 saved [1087797/1087797]\n",
"\n",
"--2022-05-08 19:27:05-- https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 788219 (770K) [text/plain]\n",
"Saving to: potop-tom-trzeci.txt.2\n",
"\n",
"potop-tom-trzeci.tx 100%[===================>] 769,75K --.-KB/s in 0,06s \n",
"\n",
"2022-05-08 19:27:05 (12,0 MB/s) - potop-tom-trzeci.txt.2 saved [788219/788219]\n",
"\n"
]
}
],
"source": [
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n",
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n",
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"!cat potop-* > potop.txt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(\n",
" self,\n",
" sequence_length,\n",
" ):\n",
" self.sequence_length = sequence_length\n",
" self.words = self.load()\n",
" self.uniq_words = self.get_uniq_words()\n",
"\n",
" self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
" self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
"\n",
" self.words_indexes = [self.word_to_index[w] for w in self.words]\n",
"\n",
" def load(self):\n",
" with open('potop.txt', 'r') as f_in:\n",
" text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n",
" text = ' '.join(text).lower()\n",
" text = re.sub('[^a-ząćęłńóśźż ]', '', text) \n",
" text = text.split(' ')\n",
" return text\n",
" \n",
" \n",
" def get_uniq_words(self):\n",
" word_counts = Counter(self.words)\n",
" return sorted(word_counts, key=word_counts.get, reverse=True)\n",
"\n",
" def __len__(self):\n",
" return len(self.words_indexes) - self.sequence_length\n",
"\n",
" def __getitem__(self, index):\n",
" return (\n",
" torch.tensor(self.words_indexes[index:index+self.sequence_length]),\n",
" torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"dataset = Dataset(5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 551, 18, 17, 255, 10748]),\n",
" tensor([ 18, 17, 255, 10748, 34]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[200]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[dataset.index_to_word[x] for x in [ 551, 18, 17, 255, 10748]]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['tak', 'jak', 'człowiek', 'zbudzony', 'ze']"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[dataset.index_to_word[x] for x in [ 18, 17, 255, 10748, 34]]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"input_tensor = torch.tensor([[ 551, 18, 17, 255, 10748]], dtype=torch.int32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"#input_tensor = torch.tensor([[ 551, 18]], dtype=torch.int32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, vocab_size):\n",
" super(Model, self).__init__()\n",
" self.lstm_size = 128\n",
" self.embedding_dim = 128\n",
" self.num_layers = 3\n",
"\n",
" self.embedding = nn.Embedding(\n",
" num_embeddings=vocab_size,\n",
" embedding_dim=self.embedding_dim,\n",
" )\n",
" self.lstm = nn.LSTM(\n",
" input_size=self.lstm_size,\n",
" hidden_size=self.lstm_size,\n",
" num_layers=self.num_layers,\n",
" dropout=0.2,\n",
" )\n",
" self.fc = nn.Linear(self.lstm_size, vocab_size)\n",
"\n",
" def forward(self, x, prev_state = None):\n",
" embed = self.embedding(x)\n",
" output, state = self.lstm(embed, prev_state)\n",
" logits = self.fc(output)\n",
" return logits, state\n",
"\n",
" def init_state(self, sequence_length):\n",
" return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device),\n",
" torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"model = Model(len(dataset)).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"y_pred, (state_h, state_c) = model(input_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.0046, -0.0113, 0.0313, ..., 0.0198, -0.0312, 0.0223],\n",
" [ 0.0039, -0.0110, 0.0303, ..., 0.0213, -0.0302, 0.0230],\n",
" [ 0.0029, -0.0133, 0.0265, ..., 0.0204, -0.0297, 0.0219],\n",
" [ 0.0010, -0.0120, 0.0282, ..., 0.0241, -0.0314, 0.0241],\n",
" [ 0.0038, -0.0106, 0.0346, ..., 0.0230, -0.0333, 0.0232]]],\n",
" grad_fn=<AddBackward0>)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 5, 1187998])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def train(dataset, model, max_epochs, batch_size):\n",
" model.train()\n",
"\n",
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
" criterion = nn.CrossEntropyLoss()\n",
" optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
" for epoch in range(max_epochs):\n",
" for batch, (x, y) in enumerate(dataloader):\n",
" optimizer.zero_grad()\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
"\n",
" y_pred, (state_h, state_c) = model(x)\n",
" loss = criterion(y_pred.transpose(1, 2), y)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}\n",
"{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}\n",
"{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}\n",
"{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}\n",
"{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}\n",
"{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}\n",
"{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}\n",
"{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}\n",
"{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}\n",
"{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}\n",
"{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}\n",
"{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}\n",
"{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}\n",
"{'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434}\n",
"{'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955}\n",
"{'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184}\n",
"{'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445}\n",
"{'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488}\n",
"{'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742}\n",
"{'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543}\n",
"{'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508}\n",
"{'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625}\n",
"{'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543}\n",
"{'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152}\n",
"{'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066}\n",
"{'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709}\n",
"{'epoch': 0, 'update in batch': 26, '/': 18563, 'loss': 8.741975784301758}\n",
"{'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541}\n",
"{'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914}\n",
"{'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336}\n",
"{'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059}\n",
"{'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172}\n",
"{'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383}\n",
"{'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258}\n",
"{'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375}\n",
"{'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719}\n",
"{'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305}\n",
"{'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125}\n",
"{'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066}\n",
"{'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336}\n",
"{'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375}\n",
"{'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434}\n",
"{'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072}\n",
"{'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566}\n",
"{'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035}\n",
"{'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999}\n",
"{'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947}\n",
"{'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137}\n",
"{'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781}\n",
"{'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586}\n",
"{'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207}\n",
"{'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369}\n",
"{'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662}\n",
"{'epoch': 0, 'update in batch': 53, '/': 18563, 'loss': 7.967227935791016}\n",
"{'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039}\n",
"{'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117}\n",
"{'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492}\n",
"{'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809}\n",
"{'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004}\n",
"{'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328}\n",
"{'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875}\n",
"{'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406}\n",
"{'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965}\n",
"{'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027}\n",
"{'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535}\n",
"{'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793}\n",
"{'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234}\n",
"{'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852}\n",
"{'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773}\n",
"{'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188}\n",
"{'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004}\n",
"{'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336}\n",
"{'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828}\n",
"{'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172}\n",
"{'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262}\n",
"{'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793}\n",
"{'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207}\n",
"{'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777}\n",
"{'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738}\n",
"{'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109}\n",
"{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746}\n",
"{'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402}\n",
"{'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074}\n",
"{'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867}\n",
"{'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508}\n",
"{'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734}\n",
"{'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527}\n",
"{'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234}\n",
"{'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082}\n",
"{'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875}\n",
"{'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266}\n",
"{'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867}\n",
"{'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062}\n",
"{'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793}\n",
"{'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168}\n",
"{'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031}\n",
"{'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863}\n",
"{'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023}\n",
"{'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422}\n",
"{'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941}\n",
"{'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035}\n",
"{'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645}\n",
"{'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555}\n",
"{'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541}\n",
"{'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023}\n",
"{'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348}\n",
"{'epoch': 0, 'update in batch': 106, '/': 18563, 'loss': 8.172578811645508}\n",
"{'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623}\n",
"{'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537}\n",
"{'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}\n",
"{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}\n",
"{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}\n",
"{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}\n",
"{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}\n",
"{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}\n",
"{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}\n",
"{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}\n",
"{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}\n",
"{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}\n",
"{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}\n",
"{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}\n",
"{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}\n",
"{'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376}\n",
"{'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363}\n",
"{'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348}\n",
"{'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418}\n",
"{'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791}\n",
"{'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605}\n",
"{'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777}\n",
"{'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848}\n",
"{'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617}\n",
"{'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 8.807767868041992}\n",
"{'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676}\n",
"{'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707}\n",
"{'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662}\n",
"{'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953}\n",
"{'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988}\n",
"{'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459}\n",
"{'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031}\n",
"{'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133}\n",
"{'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709}\n",
"{'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195}\n",
"{'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668}\n",
"{'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975}\n",
"{'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535}\n",
"{'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027}\n",
"{'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375}\n",
"{'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438}\n",
"{'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812}\n",
"{'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219}\n",
"{'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285}\n",
"{'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512}\n",
"{'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566}\n",
"{'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695}\n",
"{'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113}\n",
"{'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918}\n",
"{'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947}\n",
"{'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195}\n",
"{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898}\n",
"{'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039}\n",
"{'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082}\n",
"{'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713}\n",
"{'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672}\n",
"{'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504}\n",
"{'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301}\n",
"{'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023}\n",
"{'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965}\n",
"{'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848}\n",
"{'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695}\n",
"{'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124}\n",
"{'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418}\n",
"{'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953}\n",
"{'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125}\n",
"{'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125}\n",
"{'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373}\n",
"{'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635}\n",
"{'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176}\n",
"{'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219}\n",
"{'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539}\n",
"{'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762}\n",
"{'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766}\n",
"{'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797}\n",
"{'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094}\n",
"{'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754}\n",
"{'epoch': 0, 'update in batch': 185, '/': 18563, 'loss': 9.206843376159668}\n",
"{'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252}\n",
"{'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375}\n",
"{'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027}\n",
"{'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443}\n",
"{'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258}\n",
"{'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582}\n",
"{'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473}\n",
"{'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945}\n",
"{'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859}\n",
"{'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465}\n",
"{'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719}\n",
"{'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807}\n",
"{'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012}\n",
"{'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434}\n",
"{'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258}\n",
"{'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459}\n",
"{'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621}\n",
"{'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461}\n",
"{'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166}\n",
"{'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738}\n",
"{'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584}\n",
"{'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848}\n",
"{'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668}\n",
"{'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484}\n",
"{'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275}\n",
"{'epoch': 0, 'update in batch': 211, '/': 18563, 'loss': 8.537101745605469}\n",
"{'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566}\n",
"{'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512}\n",
"{'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801}\n",
"{'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117}\n",
"{'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902}\n",
"{'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}\n",
"{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}\n",
"{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}\n",
"{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}\n",
"{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}\n",
"{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}\n",
"{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}\n",
"{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}\n",
"{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}\n",
"{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}\n",
"{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}\n",
"{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}\n",
"{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}\n",
"{'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387}\n",
"{'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682}\n",
"{'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021}\n",
"{'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148}\n",
"{'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119}\n",
"{'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754}\n",
"{'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496}\n",
"{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977}\n",
"{'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376}\n",
"{'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678}\n",
"{'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537}\n",
"{'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344}\n",
"{'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406}\n",
"{'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691}\n",
"{'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422}\n",
"{'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207}\n",
"{'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598}\n",
"{'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176}\n",
"{'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684}\n",
"{'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484}\n",
"{'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523}\n",
"{'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211}\n",
"{'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582}\n",
"{'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863}\n",
"{'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402}\n",
"{'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035}\n",
"{'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182}\n",
"{'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051}\n",
"{'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975}\n",
"{'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918}\n",
"{'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249}\n",
"{'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242}\n",
"{'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156}\n",
"{'epoch': 0, 'update in batch': 264, '/': 18563, 'loss': 7.205214500427246}\n",
"{'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844}\n",
"{'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756}\n",
"{'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004}\n",
"{'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875}\n",
"{'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459}\n",
"{'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246}\n",
"{'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584}\n",
"{'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727}\n",
"{'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094}\n",
"{'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473}\n",
"{'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537}\n",
"{'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213}\n",
"{'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377}\n",
"{'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945}\n",
"{'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121}\n",
"{'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844}\n",
"{'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262}\n",
"{'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215}\n",
"{'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379}\n",
"{'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568}\n",
"{'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275}\n",
"{'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488}\n",
"{'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568}\n",
"{'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801}\n",
"{'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072}\n",
"{'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 7.297110557556152}\n",
"{'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691}\n",
"{'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324}\n",
"{'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582}\n",
"{'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695}\n",
"{'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297}\n",
"{'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168}\n",
"{'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055}\n",
"{'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001}\n",
"{'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844}\n",
"{'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572}\n",
"{'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297}\n",
"{'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129}\n",
"{'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848}\n",
"{'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297}\n",
"{'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914}\n",
"{'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059}\n",
"{'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805}\n",
"{'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584}\n",
"{'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902}\n",
"{'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984}\n",
"{'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492}\n",
"{'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383}\n",
"{'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797}\n",
"{'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355}\n",
"{'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312}\n",
"{'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691}\n",
"{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742}\n",
"{'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148}\n",
"{'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023}\n",
"{'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988}\n",
"{'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371}\n",
"{'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398}\n",
"{'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043}\n",
"{'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684}\n",
"{'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695}\n",
"{'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}\n",
"{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}\n",
"{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}\n",
"{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}\n",
"{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}\n",
"{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}\n",
"{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}\n",
"{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}\n",
"{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}\n",
"{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}\n",
"{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}\n",
"{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}\n",
"{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}\n",
"{'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208}\n",
"{'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248}\n",
"{'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473}\n",
"{'epoch': 0, 'update in batch': 343, '/': 18563, 'loss': 8.807058334350586}\n",
"{'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002}\n",
"{'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875}\n",
"{'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852}\n",
"{'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832}\n",
"{'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627}\n",
"{'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156}\n",
"{'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188}\n",
"{'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668}\n",
"{'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445}\n",
"{'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227}\n",
"{'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578}\n",
"{'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027}\n",
"{'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473}\n",
"{'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203}\n",
"{'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541}\n",
"{'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688}\n",
"{'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176}\n",
"{'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152}\n",
"{'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664}\n",
"{'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916}\n",
"{'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805}\n",
"{'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734}\n",
"{'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098}\n",
"{'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125}\n",
"{'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207}\n",
"{'epoch': 0, 'update in batch': 369, '/': 18563, 'loss': 7.592991828918457}\n",
"{'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125}\n",
"{'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828}\n",
"{'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383}\n",
"{'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455}\n",
"{'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023}\n",
"{'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162}\n",
"{'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543}\n",
"{'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438}\n",
"{'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902}\n",
"{'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492}\n",
"{'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613}\n",
"{'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207}\n",
"{'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055}\n",
"{'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211}\n",
"{'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531}\n",
"{'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672}\n",
"{'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586}\n",
"{'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621}\n",
"{'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465}\n",
"{'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777}\n",
"{'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084}\n",
"{'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938}\n",
"{'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219}\n",
"{'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797}\n",
"{'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664}\n",
"{'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 7.6597900390625}\n",
"{'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117}\n",
"{'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035}\n",
"{'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844}\n",
"{'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836}\n",
"{'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871}\n",
"{'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875}\n",
"{'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629}\n",
"{'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078}\n",
"{'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703}\n",
"{'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785}\n",
"{'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238}\n",
"{'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225}\n",
"{'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057}\n",
"{'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184}\n",
"{'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652}\n",
"{'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992}\n",
"{'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516}\n",
"{'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059}\n",
"{'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541}\n",
"{'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996}\n",
"{'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254}\n",
"{'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398}\n",
"{'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168}\n",
"{'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086}\n",
"{'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348}\n",
"{'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506}\n",
"{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713}\n",
"{'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031}\n",
"{'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113}\n",
"{'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535}\n",
"{'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535}\n",
"{'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805}\n",
"{'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379}\n",
"{'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848}\n",
"{'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682}\n",
"{'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398}\n",
"{'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754}\n",
"{'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373}\n",
"{'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}\n",
"{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}\n",
"{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}\n",
"{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}\n",
"{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}\n",
"{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}\n",
"{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}\n",
"{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}\n",
"{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}\n",
"{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}\n",
"{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}\n",
"{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}\n",
"{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}\n",
"{'epoch': 0, 'update in batch': 448, '/': 18563, 'loss': 8.531024932861328}\n",
"{'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305}\n",
"{'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754}\n",
"{'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785}\n",
"{'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152}\n",
"{'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537}\n",
"{'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854}\n",
"{'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711}\n",
"{'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156}\n",
"{'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098}\n",
"{'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008}\n",
"{'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082}\n",
"{'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805}\n",
"{'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367}\n",
"{'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415}\n",
"{'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145}\n",
"{'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562}\n",
"{'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184}\n",
"{'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125}\n",
"{'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176}\n",
"{'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125}\n",
"{'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605}\n",
"{'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246}\n",
"{'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441}\n",
"{'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041}\n",
"{'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164}\n",
"{'epoch': 0, 'update in batch': 474, '/': 18563, 'loss': 8.147513389587402}\n",
"{'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338}\n",
"{'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665}\n",
"{'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963}\n",
"{'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291}\n",
"{'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297}\n",
"{'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812}\n",
"{'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582}\n",
"{'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684}\n",
"{'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008}\n",
"{'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117}\n",
"{'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084}\n",
"{'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023}\n",
"{'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508}\n",
"{'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812}\n",
"{'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574}\n",
"{'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254}\n",
"{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}\n",
"{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}\n",
"{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-18-fe996a0be74b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniq_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-17-8d700bc624e3>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(dataset, model, max_epochs, batch_size)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m inputs=inputs)\n\u001b[0;32m--> 363\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 364\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 173\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 174\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 175\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"model = Model(vocab_size = len(dataset.uniq_words)).to(device)\n",
"train(dataset, model, 1, 64)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def predict(dataset, model, text, next_words=5):\n",
" model.eval()\n",
" words = text.split(' ')\n",
" state_h, state_c = model.init_state(len(words))\n",
"\n",
" for i in range(0, next_words):\n",
" x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)\n",
" y_pred, (state_h, state_c) = model(x, (state_h, state_c))\n",
"\n",
" last_word_logits = y_pred[0][-1]\n",
" p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=p)\n",
" words.append(dataset.index_to_word[word_index])\n",
"\n",
" return words"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict(dataset, model, 'kmicic szedł')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ZADANIE 1\n",
"\n",
"Stworzyć sieć rekurencyjną GRU dla Challenging America word-gap prediction. Wymogi takie jak zawsze, zadanie widoczne na gonito"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ZADANIE 2\n",
"\n",
"Podjąć wyzwanie na https://gonito.net/challenge/precipitation-pl i/lub https://gonito.net/challenge/book-dialogues-pl\n",
"\n",
"\n",
"**KONIECZNIE** należy je zgłosić do końca następnego piątku, czyli 20 maja!. Za późniejsze zgłoszenia (nawet minutę) nieprzyznaję punktów.\n",
" \n",
"Za każde zgłoszenie lepsze niż baseline przyznaję 40 punktów.\n",
"\n",
"Zamiast tych 40 punktów za najlepsze miejsca:\n",
"- 1. miejsce 150 punktów\n",
"- 2. miejsce 100 punktów\n",
"- 3. miejsce 70 punktów\n",
"\n",
"Można brać udział w 2 wyzwaniach jednocześnie.\n",
"\n",
"Zadania nie będą widoczne w gonito w achievements. Nie trzeba udostępniać kodu, należy jednak przestrzegać regulaminu wyzwań."
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}