diff --git a/cw/09_Model_neuronowy_rekurencyjny.ipynb b/cw/09_Model_neuronowy_rekurencyjny.ipynb new file mode 100644 index 0000000..210630e --- /dev/null +++ b/cw/09_Model_neuronowy_rekurencyjny.ipynb @@ -0,0 +1,1016 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n", + "
\n", + "

Modelowanie Języka

\n", + "

9. Model neuronowy rekurencyjny [ćwiczenia]

\n", + "

Jakub Pokrywka (2022)

\n", + "
\n", + "\n", + "![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn, optim\n", + "from torch.utils.data import DataLoader\n", + "import numpy as np\n", + "from collections import Counter\n", + "import re" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device = 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n", + "Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n", + "Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 877893 (857K) [text/plain]\n", + "Saving to: ‘potop-tom-pierwszy.txt.2’\n", + "\n", + "potop-tom-pierwszy. 100%[===================>] 857,32K --.-KB/s in 0,07s \n", + "\n", + "2022-05-08 19:27:04 (12,0 MB/s) - ‘potop-tom-pierwszy.txt.2’ saved [877893/877893]\n", + "\n", + "--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n", + "Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n", + "Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1087797 (1,0M) [text/plain]\n", + "Saving to: ‘potop-tom-drugi.txt.2’\n", + "\n", + "potop-tom-drugi.txt 100%[===================>] 1,04M --.-KB/s in 0,08s \n", + "\n", + "2022-05-08 19:27:04 (12,9 MB/s) - ‘potop-tom-drugi.txt.2’ saved [1087797/1087797]\n", + "\n", + "--2022-05-08 19:27:05-- https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt\n", + "Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n", + "Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 788219 (770K) [text/plain]\n", + "Saving to: ‘potop-tom-trzeci.txt.2’\n", + "\n", + "potop-tom-trzeci.tx 100%[===================>] 769,75K --.-KB/s in 0,06s \n", + "\n", + "2022-05-08 19:27:05 (12,0 MB/s) - ‘potop-tom-trzeci.txt.2’ saved [788219/788219]\n", + "\n" + ] + } + ], + "source": [ + "! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n", + "! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n", + "! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "!cat potop-* > potop.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class Dataset(torch.utils.data.Dataset):\n", + " def __init__(\n", + " self,\n", + " sequence_length,\n", + " ):\n", + " self.sequence_length = sequence_length\n", + " self.words = self.load()\n", + " self.uniq_words = self.get_uniq_words()\n", + "\n", + " self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n", + " self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n", + "\n", + " self.words_indexes = [self.word_to_index[w] for w in self.words]\n", + "\n", + " def load(self):\n", + " with open('potop.txt', 'r') as f_in:\n", + " text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n", + " text = ' '.join(text).lower()\n", + " text = re.sub('[^a-ząćęłńóśźż ]', '', text) \n", + " text = text.split(' ')\n", + " return text\n", + " \n", + " \n", + " def get_uniq_words(self):\n", + " word_counts = Counter(self.words)\n", + " return sorted(word_counts, key=word_counts.get, reverse=True)\n", + "\n", + " def __len__(self):\n", + " return len(self.words_indexes) - self.sequence_length\n", + "\n", + " def __getitem__(self, index):\n", + " return (\n", + " torch.tensor(self.words_indexes[index:index+self.sequence_length]),\n", + " torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset(5)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([ 551, 18, 17, 255, 10748]),\n", + " tensor([ 18, 17, 255, 10748, 34]))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[200]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[dataset.index_to_word[x] for x in [ 551, 18, 17, 255, 10748]]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['tak', 'jak', 'człowiek', 'zbudzony', 'ze']" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[dataset.index_to_word[x] for x in [ 18, 17, 255, 10748, 34]]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "input_tensor = torch.tensor([[ 551, 18, 17, 255, 10748]], dtype=torch.int32).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "#input_tensor = torch.tensor([[ 551, 18]], dtype=torch.int32).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " def __init__(self, vocab_size):\n", + " super(Model, self).__init__()\n", + " self.lstm_size = 128\n", + " self.embedding_dim = 128\n", + " self.num_layers = 3\n", + "\n", + " self.embedding = nn.Embedding(\n", + " num_embeddings=vocab_size,\n", + " embedding_dim=self.embedding_dim,\n", + " )\n", + " self.lstm = nn.LSTM(\n", + " input_size=self.lstm_size,\n", + " hidden_size=self.lstm_size,\n", + " num_layers=self.num_layers,\n", + " dropout=0.2,\n", + " )\n", + " self.fc = nn.Linear(self.lstm_size, vocab_size)\n", + "\n", + " def forward(self, x, prev_state = None):\n", + " embed = self.embedding(x)\n", + " output, state = self.lstm(embed, prev_state)\n", + " logits = self.fc(output)\n", + " return logits, state\n", + "\n", + " def init_state(self, sequence_length):\n", + " return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device),\n", + " torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "model = Model(len(dataset)).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred, (state_h, state_c) = model(input_tensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.0046, -0.0113, 0.0313, ..., 0.0198, -0.0312, 0.0223],\n", + " [ 0.0039, -0.0110, 0.0303, ..., 0.0213, -0.0302, 0.0230],\n", + " [ 0.0029, -0.0133, 0.0265, ..., 0.0204, -0.0297, 0.0219],\n", + " [ 0.0010, -0.0120, 0.0282, ..., 0.0241, -0.0314, 0.0241],\n", + " [ 0.0038, -0.0106, 0.0346, ..., 0.0230, -0.0333, 0.0232]]],\n", + " grad_fn=)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pred" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 5, 1187998])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pred.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def train(dataset, model, max_epochs, batch_size):\n", + " model.train()\n", + "\n", + " dataloader = DataLoader(dataset, batch_size=batch_size)\n", + " criterion = nn.CrossEntropyLoss()\n", + " optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "\n", + " for epoch in range(max_epochs):\n", + " for batch, (x, y) in enumerate(dataloader):\n", + " optimizer.zero_grad()\n", + " x = x.to(device)\n", + " y = y.to(device)\n", + "\n", + " y_pred, (state_h, state_c) = model(x)\n", + " loss = criterion(y_pred.transpose(1, 2), y)\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}\n", + "{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}\n", + "{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}\n", + "{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}\n", + "{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}\n", + "{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}\n", + "{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}\n", + "{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}\n", + "{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}\n", + "{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}\n", + "{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}\n", + "{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}\n", + "{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}\n", + "{'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434}\n", + "{'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955}\n", + "{'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184}\n", + "{'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445}\n", + "{'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488}\n", + "{'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742}\n", + "{'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543}\n", + "{'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508}\n", + "{'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625}\n", + "{'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543}\n", + "{'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152}\n", + "{'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066}\n", + "{'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709}\n", + "{'epoch': 0, 'update in batch': 26, '/': 18563, 'loss': 8.741975784301758}\n", + "{'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541}\n", + "{'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914}\n", + "{'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336}\n", + "{'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059}\n", + "{'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172}\n", + "{'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383}\n", + "{'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258}\n", + "{'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375}\n", + "{'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719}\n", + "{'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305}\n", + "{'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125}\n", + "{'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066}\n", + "{'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336}\n", + "{'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375}\n", + "{'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434}\n", + "{'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072}\n", + "{'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566}\n", + "{'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035}\n", + "{'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999}\n", + "{'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947}\n", + "{'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137}\n", + "{'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781}\n", + "{'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586}\n", + "{'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207}\n", + "{'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369}\n", + "{'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662}\n", + "{'epoch': 0, 'update in batch': 53, '/': 18563, 'loss': 7.967227935791016}\n", + "{'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039}\n", + "{'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117}\n", + "{'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492}\n", + "{'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809}\n", + "{'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004}\n", + "{'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328}\n", + "{'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875}\n", + "{'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406}\n", + "{'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965}\n", + "{'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027}\n", + "{'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535}\n", + "{'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793}\n", + "{'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234}\n", + "{'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852}\n", + "{'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773}\n", + "{'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188}\n", + "{'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004}\n", + "{'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336}\n", + "{'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828}\n", + "{'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172}\n", + "{'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262}\n", + "{'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793}\n", + "{'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207}\n", + "{'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777}\n", + "{'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738}\n", + "{'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109}\n", + "{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746}\n", + "{'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402}\n", + "{'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074}\n", + "{'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867}\n", + "{'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508}\n", + "{'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734}\n", + "{'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527}\n", + "{'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234}\n", + "{'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082}\n", + "{'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875}\n", + "{'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266}\n", + "{'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867}\n", + "{'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062}\n", + "{'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793}\n", + "{'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168}\n", + "{'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031}\n", + "{'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863}\n", + "{'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023}\n", + "{'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422}\n", + "{'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941}\n", + "{'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035}\n", + "{'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645}\n", + "{'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555}\n", + "{'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541}\n", + "{'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023}\n", + "{'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348}\n", + "{'epoch': 0, 'update in batch': 106, '/': 18563, 'loss': 8.172578811645508}\n", + "{'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623}\n", + "{'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537}\n", + "{'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}\n", + "{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}\n", + "{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}\n", + "{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}\n", + "{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}\n", + "{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}\n", + "{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}\n", + "{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}\n", + "{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}\n", + "{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}\n", + "{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}\n", + "{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}\n", + "{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}\n", + "{'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376}\n", + "{'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363}\n", + "{'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348}\n", + "{'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418}\n", + "{'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791}\n", + "{'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605}\n", + "{'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777}\n", + "{'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848}\n", + "{'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617}\n", + "{'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 8.807767868041992}\n", + "{'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676}\n", + "{'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707}\n", + "{'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662}\n", + "{'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953}\n", + "{'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988}\n", + "{'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459}\n", + "{'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031}\n", + "{'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133}\n", + "{'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709}\n", + "{'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195}\n", + "{'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668}\n", + "{'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975}\n", + "{'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535}\n", + "{'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027}\n", + "{'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375}\n", + "{'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438}\n", + "{'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812}\n", + "{'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219}\n", + "{'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285}\n", + "{'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512}\n", + "{'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566}\n", + "{'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695}\n", + "{'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113}\n", + "{'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918}\n", + "{'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947}\n", + "{'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195}\n", + "{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898}\n", + "{'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039}\n", + "{'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082}\n", + "{'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713}\n", + "{'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672}\n", + "{'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504}\n", + "{'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301}\n", + "{'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023}\n", + "{'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965}\n", + "{'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848}\n", + "{'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695}\n", + "{'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124}\n", + "{'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418}\n", + "{'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953}\n", + "{'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125}\n", + "{'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125}\n", + "{'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373}\n", + "{'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635}\n", + "{'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176}\n", + "{'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219}\n", + "{'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539}\n", + "{'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762}\n", + "{'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766}\n", + "{'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797}\n", + "{'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094}\n", + "{'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754}\n", + "{'epoch': 0, 'update in batch': 185, '/': 18563, 'loss': 9.206843376159668}\n", + "{'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252}\n", + "{'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375}\n", + "{'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027}\n", + "{'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443}\n", + "{'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258}\n", + "{'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582}\n", + "{'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473}\n", + "{'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945}\n", + "{'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859}\n", + "{'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465}\n", + "{'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719}\n", + "{'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807}\n", + "{'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012}\n", + "{'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434}\n", + "{'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258}\n", + "{'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459}\n", + "{'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621}\n", + "{'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461}\n", + "{'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166}\n", + "{'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738}\n", + "{'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584}\n", + "{'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848}\n", + "{'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668}\n", + "{'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484}\n", + "{'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275}\n", + "{'epoch': 0, 'update in batch': 211, '/': 18563, 'loss': 8.537101745605469}\n", + "{'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566}\n", + "{'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512}\n", + "{'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801}\n", + "{'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117}\n", + "{'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902}\n", + "{'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}\n", + "{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}\n", + "{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}\n", + "{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}\n", + "{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}\n", + "{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}\n", + "{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}\n", + "{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}\n", + "{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}\n", + "{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}\n", + "{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}\n", + "{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}\n", + "{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}\n", + "{'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387}\n", + "{'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682}\n", + "{'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021}\n", + "{'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148}\n", + "{'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119}\n", + "{'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754}\n", + "{'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496}\n", + "{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977}\n", + "{'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376}\n", + "{'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678}\n", + "{'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537}\n", + "{'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344}\n", + "{'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406}\n", + "{'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691}\n", + "{'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422}\n", + "{'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207}\n", + "{'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598}\n", + "{'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176}\n", + "{'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684}\n", + "{'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484}\n", + "{'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523}\n", + "{'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211}\n", + "{'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582}\n", + "{'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863}\n", + "{'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402}\n", + "{'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035}\n", + "{'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182}\n", + "{'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051}\n", + "{'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975}\n", + "{'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918}\n", + "{'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249}\n", + "{'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242}\n", + "{'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156}\n", + "{'epoch': 0, 'update in batch': 264, '/': 18563, 'loss': 7.205214500427246}\n", + "{'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844}\n", + "{'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756}\n", + "{'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004}\n", + "{'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875}\n", + "{'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459}\n", + "{'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246}\n", + "{'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584}\n", + "{'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727}\n", + "{'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094}\n", + "{'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473}\n", + "{'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537}\n", + "{'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213}\n", + "{'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377}\n", + "{'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945}\n", + "{'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121}\n", + "{'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844}\n", + "{'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262}\n", + "{'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215}\n", + "{'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379}\n", + "{'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568}\n", + "{'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275}\n", + "{'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488}\n", + "{'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568}\n", + "{'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801}\n", + "{'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072}\n", + "{'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 7.297110557556152}\n", + "{'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691}\n", + "{'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324}\n", + "{'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582}\n", + "{'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695}\n", + "{'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297}\n", + "{'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168}\n", + "{'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055}\n", + "{'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001}\n", + "{'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844}\n", + "{'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572}\n", + "{'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297}\n", + "{'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129}\n", + "{'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848}\n", + "{'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297}\n", + "{'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914}\n", + "{'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059}\n", + "{'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805}\n", + "{'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584}\n", + "{'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902}\n", + "{'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984}\n", + "{'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492}\n", + "{'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383}\n", + "{'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797}\n", + "{'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355}\n", + "{'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312}\n", + "{'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691}\n", + "{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742}\n", + "{'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148}\n", + "{'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023}\n", + "{'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988}\n", + "{'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371}\n", + "{'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398}\n", + "{'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043}\n", + "{'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684}\n", + "{'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695}\n", + "{'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}\n", + "{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}\n", + "{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}\n", + "{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}\n", + "{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}\n", + "{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}\n", + "{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}\n", + "{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}\n", + "{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}\n", + "{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}\n", + "{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}\n", + "{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}\n", + "{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}\n", + "{'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208}\n", + "{'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248}\n", + "{'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473}\n", + "{'epoch': 0, 'update in batch': 343, '/': 18563, 'loss': 8.807058334350586}\n", + "{'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002}\n", + "{'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875}\n", + "{'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852}\n", + "{'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832}\n", + "{'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627}\n", + "{'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156}\n", + "{'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188}\n", + "{'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668}\n", + "{'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445}\n", + "{'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227}\n", + "{'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578}\n", + "{'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027}\n", + "{'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473}\n", + "{'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203}\n", + "{'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541}\n", + "{'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688}\n", + "{'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176}\n", + "{'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152}\n", + "{'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664}\n", + "{'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916}\n", + "{'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805}\n", + "{'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734}\n", + "{'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098}\n", + "{'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125}\n", + "{'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207}\n", + "{'epoch': 0, 'update in batch': 369, '/': 18563, 'loss': 7.592991828918457}\n", + "{'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125}\n", + "{'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828}\n", + "{'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383}\n", + "{'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455}\n", + "{'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023}\n", + "{'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162}\n", + "{'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543}\n", + "{'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438}\n", + "{'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902}\n", + "{'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492}\n", + "{'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613}\n", + "{'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207}\n", + "{'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055}\n", + "{'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211}\n", + "{'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531}\n", + "{'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672}\n", + "{'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586}\n", + "{'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621}\n", + "{'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465}\n", + "{'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777}\n", + "{'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084}\n", + "{'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938}\n", + "{'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219}\n", + "{'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797}\n", + "{'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664}\n", + "{'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 7.6597900390625}\n", + "{'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117}\n", + "{'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035}\n", + "{'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844}\n", + "{'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836}\n", + "{'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871}\n", + "{'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875}\n", + "{'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629}\n", + "{'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078}\n", + "{'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703}\n", + "{'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785}\n", + "{'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238}\n", + "{'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225}\n", + "{'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057}\n", + "{'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184}\n", + "{'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652}\n", + "{'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992}\n", + "{'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516}\n", + "{'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059}\n", + "{'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541}\n", + "{'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996}\n", + "{'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254}\n", + "{'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398}\n", + "{'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168}\n", + "{'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086}\n", + "{'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348}\n", + "{'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506}\n", + "{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713}\n", + "{'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031}\n", + "{'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113}\n", + "{'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535}\n", + "{'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535}\n", + "{'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805}\n", + "{'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379}\n", + "{'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848}\n", + "{'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682}\n", + "{'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398}\n", + "{'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754}\n", + "{'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373}\n", + "{'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}\n", + "{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}\n", + "{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}\n", + "{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}\n", + "{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}\n", + "{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}\n", + "{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}\n", + "{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}\n", + "{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}\n", + "{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}\n", + "{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}\n", + "{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}\n", + "{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}\n", + "{'epoch': 0, 'update in batch': 448, '/': 18563, 'loss': 8.531024932861328}\n", + "{'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305}\n", + "{'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754}\n", + "{'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785}\n", + "{'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152}\n", + "{'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537}\n", + "{'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854}\n", + "{'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711}\n", + "{'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156}\n", + "{'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098}\n", + "{'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008}\n", + "{'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082}\n", + "{'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805}\n", + "{'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367}\n", + "{'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415}\n", + "{'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145}\n", + "{'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562}\n", + "{'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184}\n", + "{'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125}\n", + "{'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176}\n", + "{'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125}\n", + "{'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605}\n", + "{'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246}\n", + "{'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441}\n", + "{'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041}\n", + "{'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164}\n", + "{'epoch': 0, 'update in batch': 474, '/': 18563, 'loss': 8.147513389587402}\n", + "{'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338}\n", + "{'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665}\n", + "{'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963}\n", + "{'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291}\n", + "{'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297}\n", + "{'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812}\n", + "{'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582}\n", + "{'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684}\n", + "{'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008}\n", + "{'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117}\n", + "{'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084}\n", + "{'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023}\n", + "{'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508}\n", + "{'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812}\n", + "{'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574}\n", + "{'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254}\n", + "{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}\n", + "{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}\n", + "{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniq_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(dataset, model, max_epochs, batch_size)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m inputs=inputs)\n\u001b[0;32m--> 363\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 364\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 173\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 174\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 175\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "model = Model(vocab_size = len(dataset.uniq_words)).to(device)\n", + "train(dataset, model, 1, 64)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "def predict(dataset, model, text, next_words=5):\n", + " model.eval()\n", + " words = text.split(' ')\n", + " state_h, state_c = model.init_state(len(words))\n", + "\n", + " for i in range(0, next_words):\n", + " x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)\n", + " y_pred, (state_h, state_c) = model(x, (state_h, state_c))\n", + "\n", + " last_word_logits = y_pred[0][-1]\n", + " p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n", + " word_index = np.random.choice(len(last_word_logits), p=p)\n", + " words.append(dataset.index_to_word[word_index])\n", + "\n", + " return words" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predict(dataset, model, 'kmicic szedł')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ZADANIE 1\n", + "\n", + "Stworzyć sieć rekurencyjną GRU dla Challenging America word-gap prediction. Wymogi takie jak zawsze, zadanie widoczne na gonito" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ZADANIE 2\n", + "\n", + "Podjąć wyzwanie na https://gonito.net/challenge/precipitation-pl i/lub https://gonito.net/challenge/book-dialogues-pl\n", + "\n", + "\n", + "**KONIECZNIE** należy je zgłosić do końca następnego piątku, czyli 20 maja!. Za późniejsze zgłoszenia (nawet minutę) nieprzyznaję punktów.\n", + " \n", + "Za każde zgłoszenie lepsze niż baseline przyznaję 40 punktów.\n", + "\n", + "Zamiast tych 40 punktów za najlepsze miejsca:\n", + "- 1. miejsce 150 punktów\n", + "- 2. miejsce 100 punktów\n", + "- 3. miejsce 70 punktów\n", + "\n", + "Można brać udział w 2 wyzwaniach jednocześnie.\n", + "\n", + "Zadania nie będą widoczne w gonito w achievements. Nie trzeba udostępniać kodu, należy jednak przestrzegać regulaminu wyzwań." + ] + } + ], + "metadata": { + "author": "Jakub Pokrywka", + "email": "kubapok@wmi.amu.edu.pl", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "lang": "pl", + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]", + "title": "Ekstrakcja informacji", + "year": "2021" + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/wyk/09_Rekurencyjny_model_jezyka.ipynb b/wyk/09_Rekurencyjny_model_jezyka.ipynb new file mode 100644 index 0000000..b4c5b5c --- /dev/null +++ b/wyk/09_Rekurencyjny_model_jezyka.ipynb @@ -0,0 +1,292 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model języka oparty na rekurencyjnej sieci neuronowej\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Podejście rekurencyjne\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Na poprzednim wykładzie rozpatrywaliśmy różne funkcje\n", + "$A(w_1,\\dots,w_{i-1})$, dzięki którym możliwe było „skompresowanie” ciągu słów\n", + "(a właściwie ich zanurzeń) o dowolnej długości w wektor o stałej długości.\n", + "\n", + "Funkcję $A$ moglibyśmy zdefiniować w inny sposób, w sposób ****rekurencyjny****.\n", + "\n", + "Otóż moglibyśmy zdekomponować funkcję $A$ do\n", + "\n", + "- pewnego stanu początkowego $\\vec{s_0} \\in \\mathcal{R}^p$,\n", + "- pewnej funkcji rekurencyjnej $R : \\mathcal{R}^p \\times \\mathcal{R}^m \\rightarrow \\mathcal{R}^p$.\n", + "\n", + "Wówczas funkcję $A$ można będzie zdefiniować rekurencyjnie jako:\n", + "\n", + "$$A(w_1,\\dots,w_t) = R(A(w_1,\\dots,w_{t-1}), E(w_t)),$$\n", + "\n", + "przy czym dla ciągu pustego:\n", + "\n", + "$$A(\\epsilon) = \\vec{s_0}$$\n", + "\n", + "Przypomnijmy, że $m$ to rozmiar zanurzenia (embeddingu). Z kolei $p$ to rozmiar wektora stanu\n", + "(często $p=m$, ale nie jest to konieczne).\n", + "\n", + "Przy takim podejściu rekurencyjnym wprowadzamy niejako „strzałkę\n", + "czasu”, możemy mówić o przetwarzaniu krok po kroku.\n", + "\n", + "W wypadku modelowania języka możemy końcowy wektor stanu zrzutować do wektora o rozmiarze słownika\n", + "i zastosować softmax:\n", + "\n", + "$$\\vec{y} = \\operatorname{softmax}(CA(w_1,\\dots,w_{i-1})),$$\n", + "\n", + "gdzie $C$ jest wyuczalną macierzą o rozmiarze $|V| \\times p$.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Worek słów zdefiniowany rekurencyjnie\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Nietrudno zdefiniować model „worka słów” w taki rekurencyjny sposób:\n", + "\n", + "- $p=m$,\n", + "- $\\vec{s_0} = [0,\\dots,0]$,\n", + "- $R(\\vec{s}, \\vec{x}) = \\vec{s} + \\vec{x}.$\n", + "\n", + "Dodawanie (również wektorowe) jest operacją przemienną i łączną, więc\n", + "to rekurencyjne spojrzenie niewiele tu wnosi. Można jednak zastosować\n", + "inną funkcję $R$, która nie jest przemienna — w ten sposób wyjdziemy poza\n", + "nieuporządkowany worek słów.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Związek z programowaniem funkcyjnym\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Zauważmy, że stosowane tutaj podejście jest tożsame z zastosowaniem funkcji typu `fold`\n", + "w językach funkcyjnych:\n", + "\n", + "![img](./09_Rekurencyjny_model_jezyka/fold.png \"Opis funkcji foldl w języku Haskell\")\n", + "\n", + "W Pythonie odpowiednik `fold` jest funkcja `reduce` z pakietu `functools`:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18" + ] + } + ], + "source": [ + "from functools import reduce\n", + "\n", + "def product(ns):\n", + " return reduce(lambda a, b: a * b, ns, 1)\n", + "\n", + "product([2, 3, 1, 3])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sieci rekurencyjne\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "W jaki sposób „złamać” przemienność i wprowadzić porządek? Jedną z\n", + "najprostszych operacji nieprzemiennych jest konkatenacja — możemy\n", + "dokonać konkatenacji wektora stanu i bieżącego stanu, a następnie\n", + "zastosować jakąś prostą operację (na wyjściu musimy mieć wektor o\n", + "rozmiarze $p$, nie $p + m$!), dobrze przy okazji „złamać” też\n", + "liniowość operacji. Możemy po prostu zastosować rzutowanie (mnożenie\n", + "przez macierz) i jakąś prostą funkcję aktywacji (na przykład sigmoidę):\n", + "\n", + "$$R(\\vec{s}, \\vec{e}) = \\sigma(W[\\vec{s},\\vec{e}] + \\vec{b}).$$\n", + "\n", + "Dodatkowo jeszcze wprowadziliśmy wektor obciążeń $\\vec{b}$, a zatem wyuczalne wagi obejmują:\n", + "\n", + "- macierz $W \\in \\mathcal{R}^p \\times \\mathcal{R}^{p+m}$,\n", + "- wektor obciążeń $b \\in \\mathcal{R}^p$.\n", + "\n", + "Olbrzymią zaletą sieci rekurencyjnych jest fakt, że liczba wag nie zależy od rozmiaru wejścia!\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Zwykła sieć rekurencyjna\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Wyżej zdefiniową sieć nazywamy „zwykłą” siecią rekurencyjną (*Vanilla RNN*).\n", + "\n", + "**Uwaga**: przez RNN czasami rozumie się taką „zwykłą” sieć\n", + "rekurencyjną, a czasami szerszą klasę sieci rekurencyjnych\n", + "obejmujących również sieci GRU czy LSTM (zob. poniżej).\n", + "\n", + "![img](./09_Rekurencyjny_model_jezyka/rnn.drawio.png \"Schemat prostego modelu języka opartego na zwykłej sieci rekurencyjnych\")\n", + "\n", + "**Uwaga**: powyższy schemat nie obejmuje już „całego” działania sieci,\n", + " tylko pojedynczy krok czasowy.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Praktyczna niestosowalność prostych sieci RNN\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Niestety w praktyce proste sieci RNN sprawiają duże trudności jeśli\n", + "chodzi o propagację wsteczną — pojawia się zjawisko zanikającego\n", + "(rzadziej: eksplodującego) gradientu. Dlatego zaproponowano różne\n", + "modyfikacje sieci RNN. Zacznijmy od omówienia stosunkowo prostej sieci GRU.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sieć GRU\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "GRU (*Gated Recurrent Unit*) to sieć z dwiema ****bramkami**** (*gates*):\n", + "\n", + "- bramką resetu (*reset gate*) $\\Gamma_\\gamma \\in \\mathcal{R}^p$ — która określa, w jakim\n", + " stopniu sieć ma pamiętać albo zapominać stan z poprzedniego kroku,\n", + "- bramką aktualizacji (*update gate*) $\\Gamma_u \\in \\mathcal{R}^p$ — która określa wpływ\n", + " bieżącego wyrazu na zmianę stanu.\n", + "\n", + "Tak więc w skrajnym przypadku:\n", + "\n", + "- jeśli $\\Gamma_\\gamma = [0,\\dots,0]$, sieć całkowicie zapomina\n", + " informację płynącą z poprzednich wyrazów,\n", + "- jeśli $\\Gamma_u = [0,\\dots,0]$, sieć nie bierze pod uwagę\n", + " bieżącego wyrazu.\n", + "\n", + "Zauważmy, że bramki mogą selektywnie, na każdej pozycji wektora stanu,\n", + "sterować przepływem informacji. Na przykład $\\Gamma_\\gamma =\n", + "[0,1,\\dots,1]$ oznacza, że pierwsza pozycja wektora stanu jest\n", + "zapominana, a pozostałe — wnoszą wkład w całości.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Wzory\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Najpierw zdefiniujmy pośredni stan $\\vec{\\xi} \\in \\mathcal{R}^p$:\n", + "\n", + "$$\\vec{\\xi_t} = \\operatorname{tanh}(W_{\\xi}[\\Gamma_\\gamma \\bullet c_{t-1}, E(w_t)] + b_{\\xi}),$$\n", + "\n", + "gdzie $\\bullet$ oznacza iloczyn Hadamarda (nie iloczyn skalarny!) dwóch wektorów:\n", + "\n", + "$$[x_1,\\dots,x_n] \\bullet [y_1,\\dots,y_n] = [x_1 y_1,\\dots,x_n y_n].$$\n", + "\n", + "Obliczanie $\\vec{\\xi_t}$ bardzo przypomina zwykłą sieć rekurencyjną,\n", + "jedyna różnica polega na tym, że za pomocą bramki $\\Gamma_\\gamma$\n", + "modulujemy wpływ poprzedniego stanu.\n", + "\n", + "Ostateczna wartość stanu jest średnią ważoną poprzedniego stanu i bieżącego stanu pośredniego:\n", + "\n", + "$$\\vec{c_t} = \\Gamma_u \\bullet \\vec{\\xi_t} + (1 - \\Gamma_u) \\bullet \\vec{c_{t-1}}.$$\n", + "\n", + "Skąd się biorą bramki $\\Gamma_\\gamma$ i $\\Gamma_u$? Również z poprzedniego stanu i z biężacego wyrazu.\n", + "\n", + "$$\\Gamma_\\gamma = \\sigma(W_\\gamma[\\vec{c_{t-1}},E(w_t)] + b_\\gamma),$$\n", + "\n", + "$$\\Gamma_u = \\sigma(W_u[\\vec{c_{t-1}},E(w_t)] + b_u),$$\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + }, + "org": null + }, + "nbformat": 4, + "nbformat_minor": 1 +}