aitech-moj-2023/cw/09_Model_neuronowy_rekurencyjny.ipynb
2022-05-08 19:32:57 +02:00

985 lines
60 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Modelowanie Języka</h1>\n",
"<h2> 9. <i>Model neuronowy rekurencyjny</i> [ćwiczenia]</h2> \n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
"from torch.utils.data import DataLoader\n",
"import numpy as np\n",
"from collections import Counter\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Target device for tensors and the model; the later `.to(device)` calls in this notebook read this global.\n",
"device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 877893 (857K) [text/plain]\n",
"Saving to: potop-tom-pierwszy.txt.2\n",
"\n",
"potop-tom-pierwszy. 100%[===================>] 857,32K --.-KB/s in 0,07s \n",
"\n",
"2022-05-08 19:27:04 (12,0 MB/s) - potop-tom-pierwszy.txt.2 saved [877893/877893]\n",
"\n",
"--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1087797 (1,0M) [text/plain]\n",
"Saving to: potop-tom-drugi.txt.2\n",
"\n",
"potop-tom-drugi.txt 100%[===================>] 1,04M --.-KB/s in 0,08s \n",
"\n",
"2022-05-08 19:27:04 (12,9 MB/s) - potop-tom-drugi.txt.2 saved [1087797/1087797]\n",
"\n",
"--2022-05-08 19:27:05-- https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 788219 (770K) [text/plain]\n",
"Saving to: potop-tom-trzeci.txt.2\n",
"\n",
"potop-tom-trzeci.tx 100%[===================>] 769,75K --.-KB/s in 0,06s \n",
"\n",
"2022-05-08 19:27:05 (12,0 MB/s) - potop-tom-trzeci.txt.2 saved [788219/788219]\n",
"\n"
]
}
],
"source": [
"# -nc (no-clobber): skip a file that already exists, so re-running this cell does not\n",
"# leave numbered duplicates such as potop-tom-pierwszy.txt.1, .2, ...\n",
"! wget -nc https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n",
"! wget -nc https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n",
"! wget -nc https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# List the volumes explicitly and in reading order: the `potop-*` glob sorts them\n",
"# alphabetically (drugi, pierwszy, trzeci) and also matches numbered re-download\n",
"# copies (e.g. potop-tom-pierwszy.txt.2), which would duplicate the corpus.\n",
"!cat potop-tom-pierwszy.txt potop-tom-drugi.txt potop-tom-trzeci.txt > potop.txt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
"    \"\"\"Next-word prediction windows over a lower-cased, letters-only text corpus.\n",
"\n",
"    Item i is a pair of index tensors: words i .. i+sequence_length-1 and the same\n",
"    window shifted one word to the right.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        sequence_length,\n",
"        path='potop.txt',\n",
"    ):\n",
"        \"\"\"Read the corpus at `path` and build the vocabulary.\n",
"\n",
"        sequence_length: number of words in every input/target window.\n",
"        path: text file to read; defaults to the file the notebook creates.\n",
"        \"\"\"\n",
"        self.sequence_length = sequence_length\n",
"        self.path = path\n",
"        self.words = self.load()\n",
"        self.uniq_words = self.get_uniq_words()\n",
"\n",
"        self.index_to_word = dict(enumerate(self.uniq_words))\n",
"        self.word_to_index = {word: index for index, word in self.index_to_word.items()}\n",
"\n",
"        self.words_indexes = [self.word_to_index[w] for w in self.words]\n",
"\n",
"    def load(self):\n",
"        \"\"\"Return the corpus as a list of lower-cased tokens.\"\"\"\n",
"        # Decode explicitly as UTF-8 (assumed for this corpus): the platform default\n",
"        # encoding is not guaranteed to handle the Polish diacritics matched below.\n",
"        with open(self.path, 'r', encoding='utf-8') as f_in:\n",
"            lines = [line.rstrip() for line in f_in if line.strip()]\n",
"        text = ' '.join(lines).lower()\n",
"        text = re.sub('[^a-ząćęłńóśźż ]', '', text)\n",
"        # NOTE: split(' ') yields an empty-string token wherever the substitution above\n",
"        # leaves two spaces in a row; kept as is so existing word indices stay unchanged.\n",
"        return text.split(' ')\n",
"\n",
"    def get_uniq_words(self):\n",
"        \"\"\"Return the distinct words, most frequent first (index 0 = most frequent).\"\"\"\n",
"        word_counts = Counter(self.words)\n",
"        return sorted(word_counts, key=word_counts.get, reverse=True)\n",
"\n",
"    def __len__(self):\n",
"        # The last usable start index must leave room for the shifted target window.\n",
"        return len(self.words_indexes) - self.sequence_length\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return (input window, target window) as two 1-D index tensors.\"\"\"\n",
"        stop = index + self.sequence_length\n",
"        return (\n",
"            torch.tensor(self.words_indexes[index:stop]),\n",
"            torch.tensor(self.words_indexes[index + 1:stop + 1]),\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# sequence_length=5: every item is a 5-word input window plus the same window shifted one word ahead.\n",
"dataset = Dataset(5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 551, 18, 17, 255, 10748]),\n",
" tensor([ 18, 17, 255, 10748, 34]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[200]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[dataset.index_to_word[x] for x in [ 551, 18, 17, 255, 10748]]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['tak', 'jak', 'człowiek', 'zbudzony', 'ze']"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[dataset.index_to_word[x] for x in [ 18, 17, 255, 10748, 34]]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# One hand-built example of shape (1, 5); the indices are copied from the recorded output of dataset[200].\n",
"input_tensor = torch.tensor([[ 551, 18, 17, 255, 10748]], dtype=torch.int32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"#input_tensor = torch.tensor([[ 551, 18]], dtype=torch.int32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    \"\"\"Word-level language model: embedding -> 3-layer LSTM -> linear layer over the vocabulary.\n",
"\n",
"    NOTE(review): the LSTM is created without batch_first=True, so per the PyTorch\n",
"    documentation it reads dim 0 of its input as time and dim 1 as batch, while train()\n",
"    in this notebook feeds (batch, sequence) index tensors. Switching the layout would\n",
"    also change which size init_state() has to receive, so it is only flagged here.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size):\n",
"        \"\"\"vocab_size: number of distinct word indices (embedding rows and output logits).\"\"\"\n",
"        super().__init__()\n",
"        self.lstm_size = 128\n",
"        self.embedding_dim = 128\n",
"        self.num_layers = 3\n",
"\n",
"        self.embedding = nn.Embedding(\n",
"            num_embeddings=vocab_size,\n",
"            embedding_dim=self.embedding_dim,\n",
"        )\n",
"        self.lstm = nn.LSTM(\n",
"            # The LSTM consumes embedding vectors, so its input width is embedding_dim\n",
"            # (this used to be lstm_size, which only worked because both are 128).\n",
"            input_size=self.embedding_dim,\n",
"            hidden_size=self.lstm_size,\n",
"            num_layers=self.num_layers,\n",
"            dropout=0.2,\n",
"        )\n",
"        self.fc = nn.Linear(self.lstm_size, vocab_size)\n",
"\n",
"    def forward(self, x, prev_state=None):\n",
"        \"\"\"Return (logits, state) for a tensor of word indices.\n",
"\n",
"        prev_state is handed to nn.LSTM unchanged; with None the LSTM starts from zeros.\n",
"        \"\"\"\n",
"        embed = self.embedding(x)\n",
"        output, state = self.lstm(embed, prev_state)\n",
"        logits = self.fc(output)\n",
"        return logits, state\n",
"\n",
"    def init_state(self, sequence_length):\n",
"        \"\"\"Return zero (h_0, c_0) tensors of shape (num_layers, sequence_length, lstm_size).\n",
"\n",
"        The tensors are created on the device that holds the model's weights instead of\n",
"        relying on a notebook-level `device` global.\n",
"        \"\"\"\n",
"        shape = (self.num_layers, sequence_length, self.lstm_size)\n",
"        weights_device = self.embedding.weight.device\n",
"        return (torch.zeros(shape, device=weights_device),\n",
"                torch.zeros(shape, device=weights_device))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# vocab_size is the number of distinct words; len(dataset) is the number of training\n",
"# windows and would make the embedding and output layers far larger than the vocabulary.\n",
"model = Model(len(dataset.uniq_words)).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Forward pass without a previous state: forward() then passes prev_state=None on to the LSTM.\n",
"y_pred, (state_h, state_c) = model(input_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.0046, -0.0113, 0.0313, ..., 0.0198, -0.0312, 0.0223],\n",
" [ 0.0039, -0.0110, 0.0303, ..., 0.0213, -0.0302, 0.0230],\n",
" [ 0.0029, -0.0133, 0.0265, ..., 0.0204, -0.0297, 0.0219],\n",
" [ 0.0010, -0.0120, 0.0282, ..., 0.0241, -0.0314, 0.0241],\n",
" [ 0.0038, -0.0106, 0.0346, ..., 0.0230, -0.0333, 0.0232]]],\n",
" grad_fn=<AddBackward0>)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 5, 1187998])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def train(dataset, model, max_epochs, batch_size, lr=0.001, log_every=1):\n",
"    \"\"\"Train `model` on `dataset` with Adam and cross-entropy loss.\n",
"\n",
"    dataset: yields (x, y) pairs of word-index tensors.\n",
"    model: called as model(x) and expected to return (logits, state).\n",
"    max_epochs: number of full passes over the dataset.\n",
"    batch_size: number of items per optimisation step.\n",
"    lr: Adam learning rate.\n",
"    log_every: print the loss every `log_every` batches (1 = every batch);\n",
"        0 or None turns the per-batch printing off.\n",
"    \"\"\"\n",
"    model.train()\n",
"    # Send batches to wherever the model's weights live rather than to a global.\n",
"    model_device = next(model.parameters()).device\n",
"\n",
"    dataloader = DataLoader(dataset, batch_size=batch_size)\n",
"    n_batches = len(dataloader)\n",
"    criterion = nn.CrossEntropyLoss()\n",
"    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
"\n",
"    for epoch in range(max_epochs):\n",
"        for batch, (x, y) in enumerate(dataloader):\n",
"            optimizer.zero_grad()\n",
"            x = x.to(model_device)\n",
"            y = y.to(model_device)\n",
"\n",
"            # The returned LSTM state is not needed during training.\n",
"            y_pred, _ = model(x)\n",
"            # CrossEntropyLoss expects the class scores on dim 1, hence the transpose.\n",
"            loss = criterion(y_pred.transpose(1, 2), y)\n",
"\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            if log_every and batch % log_every == 0:\n",
"                print({ 'epoch': epoch, 'update in batch': batch, '/' : n_batches, 'loss': loss.item() })\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}\n",
"{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}\n",
"{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}\n",
"{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}\n",
"{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}\n",
"{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}\n",
"{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}\n",
"{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}\n",
"{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}\n",
"{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}\n",
"{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}\n",
"{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}\n",
"{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}\n",
"{'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434}\n",
"{'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955}\n",
"{'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184}\n",
"{'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445}\n",
"{'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488}\n",
"{'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742}\n",
"{'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543}\n",
"{'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508}\n",
"{'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625}\n",
"{'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543}\n",
"{'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152}\n",
"{'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066}\n",
"{'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709}\n",
"{'epoch': 0, 'update in batch': 26, '/': 18563, 'loss': 8.741975784301758}\n",
"{'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541}\n",
"{'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914}\n",
"{'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336}\n",
"{'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059}\n",
"{'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172}\n",
"{'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383}\n",
"{'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258}\n",
"{'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375}\n",
"{'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719}\n",
"{'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305}\n",
"{'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125}\n",
"{'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066}\n",
"{'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336}\n",
"{'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375}\n",
"{'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434}\n",
"{'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072}\n",
"{'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566}\n",
"{'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035}\n",
"{'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999}\n",
"{'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947}\n",
"{'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137}\n",
"{'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781}\n",
"{'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586}\n",
"{'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207}\n",
"{'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369}\n",
"{'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662}\n",
"{'epoch': 0, 'update in batch': 53, '/': 18563, 'loss': 7.967227935791016}\n",
"{'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039}\n",
"{'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117}\n",
"{'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492}\n",
"{'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809}\n",
"{'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004}\n",
"{'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328}\n",
"{'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875}\n",
"{'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406}\n",
"{'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965}\n",
"{'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027}\n",
"{'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535}\n",
"{'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793}\n",
"{'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234}\n",
"{'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852}\n",
"{'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773}\n",
"{'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188}\n",
"{'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004}\n",
"{'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336}\n",
"{'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828}\n",
"{'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172}\n",
"{'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262}\n",
"{'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793}\n",
"{'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207}\n",
"{'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777}\n",
"{'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738}\n",
"{'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109}\n",
"{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746}\n",
"{'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402}\n",
"{'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074}\n",
"{'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867}\n",
"{'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508}\n",
"{'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734}\n",
"{'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527}\n",
"{'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234}\n",
"{'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082}\n",
"{'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875}\n",
"{'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266}\n",
"{'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867}\n",
"{'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062}\n",
"{'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793}\n",
"{'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168}\n",
"{'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031}\n",
"{'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863}\n",
"{'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023}\n",
"{'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422}\n",
"{'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941}\n",
"{'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035}\n",
"{'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645}\n",
"{'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555}\n",
"{'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541}\n",
"{'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023}\n",
"{'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348}\n",
"{'epoch': 0, 'update in batch': 106, '/': 18563, 'loss': 8.172578811645508}\n",
"{'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623}\n",
"{'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537}\n",
"{'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}\n",
"{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}\n",
"{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}\n",
"{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}\n",
"{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}\n",
"{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}\n",
"{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}\n",
"{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}\n",
"{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}\n",
"{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}\n",
"{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}\n",
"{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}\n",
"{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}\n",
"{'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376}\n",
"{'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363}\n",
"{'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348}\n",
"{'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418}\n",
"{'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791}\n",
"{'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605}\n",
"{'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777}\n",
"{'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848}\n",
"{'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617}\n",
"{'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 8.807767868041992}\n",
"{'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676}\n",
"{'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707}\n",
"{'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662}\n",
"{'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953}\n",
"{'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988}\n",
"{'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459}\n",
"{'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031}\n",
"{'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133}\n",
"{'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709}\n",
"{'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195}\n",
"{'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668}\n",
"{'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975}\n",
"{'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535}\n",
"{'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027}\n",
"{'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375}\n",
"{'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438}\n",
"{'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812}\n",
"{'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219}\n",
"{'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285}\n",
"{'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512}\n",
"{'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566}\n",
"{'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695}\n",
"{'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113}\n",
"{'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918}\n",
"{'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947}\n",
"{'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195}\n",
"{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898}\n",
"{'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039}\n",
"{'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082}\n",
"{'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713}\n",
"{'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672}\n",
"{'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504}\n",
"{'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301}\n",
"{'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023}\n",
"{'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965}\n",
"{'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848}\n",
"{'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695}\n",
"{'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124}\n",
"{'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418}\n",
"{'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953}\n",
"{'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125}\n",
"{'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125}\n",
"{'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373}\n",
"{'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635}\n",
"{'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176}\n",
"{'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219}\n",
"{'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539}\n",
"{'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762}\n",
"{'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766}\n",
"{'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797}\n",
"{'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094}\n",
"{'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754}\n",
"{'epoch': 0, 'update in batch': 185, '/': 18563, 'loss': 9.206843376159668}\n",
"{'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252}\n",
"{'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375}\n",
"{'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027}\n",
"{'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443}\n",
"{'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258}\n",
"{'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582}\n",
"{'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473}\n",
"{'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945}\n",
"{'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859}\n",
"{'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465}\n",
"{'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719}\n",
"{'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807}\n",
"{'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012}\n",
"{'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434}\n",
"{'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258}\n",
"{'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459}\n",
"{'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621}\n",
"{'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461}\n",
"{'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166}\n",
"{'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738}\n",
"{'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584}\n",
"{'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848}\n",
"{'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668}\n",
"{'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484}\n",
"{'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275}\n",
"{'epoch': 0, 'update in batch': 211, '/': 18563, 'loss': 8.537101745605469}\n",
"{'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566}\n",
"{'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512}\n",
"{'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801}\n",
"{'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117}\n",
"{'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902}\n",
"{'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}\n",
"{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}\n",
"{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}\n",
"{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}\n",
"{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}\n",
"{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}\n",
"{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}\n",
"{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}\n",
"{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}\n",
"{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}\n",
"{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}\n",
"{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}\n",
"{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}\n",
"{'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387}\n",
"{'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682}\n",
"{'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021}\n",
"{'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148}\n",
"{'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119}\n",
"{'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754}\n",
"{'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496}\n",
"{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977}\n",
"{'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376}\n",
"{'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678}\n",
"{'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537}\n",
"{'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344}\n",
"{'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406}\n",
"{'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691}\n",
"{'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422}\n",
"{'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207}\n",
"{'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598}\n",
"{'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176}\n",
"{'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684}\n",
"{'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484}\n",
"{'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523}\n",
"{'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211}\n",
"{'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582}\n",
"{'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863}\n",
"{'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402}\n",
"{'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035}\n",
"{'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182}\n",
"{'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051}\n",
"{'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975}\n",
"{'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918}\n",
"{'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249}\n",
"{'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242}\n",
"{'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156}\n",
"{'epoch': 0, 'update in batch': 264, '/': 18563, 'loss': 7.205214500427246}\n",
"{'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844}\n",
"{'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756}\n",
"{'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004}\n",
"{'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875}\n",
"{'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459}\n",
"{'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246}\n",
"{'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584}\n",
"{'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727}\n",
"{'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094}\n",
"{'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473}\n",
"{'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537}\n",
"{'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213}\n",
"{'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377}\n",
"{'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945}\n",
"{'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121}\n",
"{'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844}\n",
"{'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262}\n",
"{'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215}\n",
"{'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379}\n",
"{'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568}\n",
"{'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275}\n",
"{'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488}\n",
"{'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568}\n",
"{'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801}\n",
"{'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072}\n",
"{'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 7.297110557556152}\n",
"{'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691}\n",
"{'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324}\n",
"{'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582}\n",
"{'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695}\n",
"{'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297}\n",
"{'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168}\n",
"{'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055}\n",
"{'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001}\n",
"{'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844}\n",
"{'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572}\n",
"{'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297}\n",
"{'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129}\n",
"{'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848}\n",
"{'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297}\n",
"{'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914}\n",
"{'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059}\n",
"{'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805}\n",
"{'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584}\n",
"{'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902}\n",
"{'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984}\n",
"{'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492}\n",
"{'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383}\n",
"{'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797}\n",
"{'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355}\n",
"{'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312}\n",
"{'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691}\n",
"{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742}\n",
"{'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148}\n",
"{'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023}\n",
"{'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988}\n",
"{'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371}\n",
"{'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398}\n",
"{'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043}\n",
"{'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684}\n",
"{'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695}\n",
"{'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}\n",
"{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}\n",
"{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}\n",
"{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}\n",
"{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}\n",
"{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}\n",
"{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}\n",
"{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}\n",
"{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}\n",
"{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}\n",
"{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}\n",
"{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}\n",
"{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}\n",
"{'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208}\n",
"{'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248}\n",
"{'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473}\n",
"{'epoch': 0, 'update in batch': 343, '/': 18563, 'loss': 8.807058334350586}\n",
"{'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002}\n",
"{'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875}\n",
"{'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852}\n",
"{'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832}\n",
"{'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627}\n",
"{'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156}\n",
"{'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188}\n",
"{'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668}\n",
"{'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445}\n",
"{'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227}\n",
"{'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578}\n",
"{'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027}\n",
"{'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473}\n",
"{'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203}\n",
"{'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541}\n",
"{'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688}\n",
"{'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176}\n",
"{'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152}\n",
"{'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664}\n",
"{'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916}\n",
"{'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805}\n",
"{'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734}\n",
"{'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098}\n",
"{'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125}\n",
"{'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207}\n",
"{'epoch': 0, 'update in batch': 369, '/': 18563, 'loss': 7.592991828918457}\n",
"{'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125}\n",
"{'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828}\n",
"{'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383}\n",
"{'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455}\n",
"{'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023}\n",
"{'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162}\n",
"{'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543}\n",
"{'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438}\n",
"{'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902}\n",
"{'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492}\n",
"{'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613}\n",
"{'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207}\n",
"{'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055}\n",
"{'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211}\n",
"{'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531}\n",
"{'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672}\n",
"{'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586}\n",
"{'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621}\n",
"{'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465}\n",
"{'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777}\n",
"{'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084}\n",
"{'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938}\n",
"{'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219}\n",
"{'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797}\n",
"{'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664}\n",
"{'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 7.6597900390625}\n",
"{'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117}\n",
"{'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035}\n",
"{'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844}\n",
"{'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836}\n",
"{'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871}\n",
"{'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875}\n",
"{'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629}\n",
"{'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078}\n",
"{'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703}\n",
"{'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785}\n",
"{'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238}\n",
"{'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225}\n",
"{'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057}\n",
"{'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184}\n",
"{'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652}\n",
"{'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992}\n",
"{'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516}\n",
"{'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059}\n",
"{'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541}\n",
"{'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996}\n",
"{'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254}\n",
"{'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398}\n",
"{'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168}\n",
"{'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086}\n",
"{'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348}\n",
"{'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506}\n",
"{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713}\n",
"{'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031}\n",
"{'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113}\n",
"{'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535}\n",
"{'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535}\n",
"{'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805}\n",
"{'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379}\n",
"{'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848}\n",
"{'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682}\n",
"{'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398}\n",
"{'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754}\n",
"{'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373}\n",
"{'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}\n",
"{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}\n",
"{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}\n",
"{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}\n",
"{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}\n",
"{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}\n",
"{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}\n",
"{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}\n",
"{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}\n",
"{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}\n",
"{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}\n",
"{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}\n",
"{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}\n",
"{'epoch': 0, 'update in batch': 448, '/': 18563, 'loss': 8.531024932861328}\n",
"{'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305}\n",
"{'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754}\n",
"{'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785}\n",
"{'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152}\n",
"{'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537}\n",
"{'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854}\n",
"{'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711}\n",
"{'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156}\n",
"{'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098}\n",
"{'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008}\n",
"{'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082}\n",
"{'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805}\n",
"{'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367}\n",
"{'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415}\n",
"{'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145}\n",
"{'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562}\n",
"{'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184}\n",
"{'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125}\n",
"{'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176}\n",
"{'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125}\n",
"{'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605}\n",
"{'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246}\n",
"{'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441}\n",
"{'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041}\n",
"{'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164}\n",
"{'epoch': 0, 'update in batch': 474, '/': 18563, 'loss': 8.147513389587402}\n",
"{'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338}\n",
"{'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665}\n",
"{'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963}\n",
"{'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291}\n",
"{'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297}\n",
"{'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812}\n",
"{'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582}\n",
"{'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684}\n",
"{'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008}\n",
"{'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117}\n",
"{'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084}\n",
"{'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023}\n",
"{'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508}\n",
"{'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812}\n",
"{'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574}\n",
"{'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254}\n",
"{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}\n",
"{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}\n",
"{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-18-fe996a0be74b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniq_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-17-8d700bc624e3>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(dataset, model, max_epochs, batch_size)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m inputs=inputs)\n\u001b[0;32m--> 363\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 364\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 173\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 174\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 175\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"model = Model(vocab_size = len(dataset.uniq_words)).to(device)\n",
"train(dataset, model, 1, 64)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def predict(dataset, model, text, next_words=5):\n",
" model.eval()\n",
" words = text.split(' ')\n",
" state_h, state_c = model.init_state(len(words))\n",
"\n",
" for i in range(0, next_words):\n",
" x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)\n",
" y_pred, (state_h, state_c) = model(x, (state_h, state_c))\n",
"\n",
" last_word_logits = y_pred[0][-1]\n",
" p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=p)\n",
" words.append(dataset.index_to_word[word_index])\n",
"\n",
" return words"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict(dataset, model, 'kmicic szedł')"
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}