Merge branch 'master' of git.wmi.amu.edu.pl:filipg/aitech-moj

This commit is contained in:
Filip Gralinski 2022-05-14 14:33:06 +02:00
commit 846a20111c
2 changed files with 1318 additions and 0 deletions


@@ -0,0 +1,984 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Modelowanie Języka</h1>\n",
"<h2> 9. <i>Model neuronowy rekurencyjny</i> [ćwiczenia]</h2> \n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
"from torch.utils.data import DataLoader\n",
"import numpy as np\n",
"from collections import Counter\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 877893 (857K) [text/plain]\n",
"Saving to: potop-tom-pierwszy.txt.2\n",
"\n",
"potop-tom-pierwszy. 100%[===================>] 857,32K --.-KB/s in 0,07s \n",
"\n",
"2022-05-08 19:27:04 (12,0 MB/s) - potop-tom-pierwszy.txt.2 saved [877893/877893]\n",
"\n",
"--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1087797 (1,0M) [text/plain]\n",
"Saving to: potop-tom-drugi.txt.2\n",
"\n",
"potop-tom-drugi.txt 100%[===================>] 1,04M --.-KB/s in 0,08s \n",
"\n",
"2022-05-08 19:27:04 (12,9 MB/s) - potop-tom-drugi.txt.2 saved [1087797/1087797]\n",
"\n",
"--2022-05-08 19:27:05-- https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 788219 (770K) [text/plain]\n",
"Saving to: potop-tom-trzeci.txt.2\n",
"\n",
"potop-tom-trzeci.tx 100%[===================>] 769,75K --.-KB/s in 0,06s \n",
"\n",
"2022-05-08 19:27:05 (12,0 MB/s) - potop-tom-trzeci.txt.2 saved [788219/788219]\n",
"\n"
]
}
],
"source": [
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n",
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n",
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"!cat potop-* > potop.txt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(\n",
" self,\n",
" sequence_length,\n",
" ):\n",
" self.sequence_length = sequence_length\n",
" self.words = self.load()\n",
" self.uniq_words = self.get_uniq_words()\n",
"\n",
" self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
" self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
"\n",
" self.words_indexes = [self.word_to_index[w] for w in self.words]\n",
"\n",
" def load(self):\n",
" with open('potop.txt', 'r') as f_in:\n",
" text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n",
" text = ' '.join(text).lower()\n",
" text = re.sub('[^a-ząćęłńóśźż ]', '', text) \n",
" text = text.split(' ')\n",
" return text\n",
" \n",
" \n",
" def get_uniq_words(self):\n",
" word_counts = Counter(self.words)\n",
" return sorted(word_counts, key=word_counts.get, reverse=True)\n",
"\n",
" def __len__(self):\n",
" return len(self.words_indexes) - self.sequence_length\n",
"\n",
" def __getitem__(self, index):\n",
" return (\n",
" torch.tensor(self.words_indexes[index:index+self.sequence_length]),\n",
" torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"dataset = Dataset(5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 551, 18, 17, 255, 10748]),\n",
" tensor([ 18, 17, 255, 10748, 34]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[200]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[dataset.index_to_word[x] for x in [ 551, 18, 17, 255, 10748]]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['tak', 'jak', 'człowiek', 'zbudzony', 'ze']"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[dataset.index_to_word[x] for x in [ 18, 17, 255, 10748, 34]]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"input_tensor = torch.tensor([[ 551, 18, 17, 255, 10748]], dtype=torch.int32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"#input_tensor = torch.tensor([[ 551, 18]], dtype=torch.int32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, vocab_size):\n",
" super(Model, self).__init__()\n",
" self.lstm_size = 128\n",
" self.embedding_dim = 128\n",
" self.num_layers = 3\n",
"\n",
" self.embedding = nn.Embedding(\n",
" num_embeddings=vocab_size,\n",
" embedding_dim=self.embedding_dim,\n",
" )\n",
" self.lstm = nn.LSTM(\n",
" input_size=self.lstm_size,\n",
" hidden_size=self.lstm_size,\n",
" num_layers=self.num_layers,\n",
" dropout=0.2,\n",
" )\n",
" self.fc = nn.Linear(self.lstm_size, vocab_size)\n",
"\n",
" def forward(self, x, prev_state = None):\n",
" embed = self.embedding(x)\n",
" output, state = self.lstm(embed, prev_state)\n",
" logits = self.fc(output)\n",
" return logits, state\n",
"\n",
" def init_state(self, sequence_length):\n",
" return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device),\n",
" torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"model = Model(len(dataset)).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"y_pred, (state_h, state_c) = model(input_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.0046, -0.0113, 0.0313, ..., 0.0198, -0.0312, 0.0223],\n",
" [ 0.0039, -0.0110, 0.0303, ..., 0.0213, -0.0302, 0.0230],\n",
" [ 0.0029, -0.0133, 0.0265, ..., 0.0204, -0.0297, 0.0219],\n",
" [ 0.0010, -0.0120, 0.0282, ..., 0.0241, -0.0314, 0.0241],\n",
" [ 0.0038, -0.0106, 0.0346, ..., 0.0230, -0.0333, 0.0232]]],\n",
" grad_fn=<AddBackward0>)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 5, 1187998])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def train(dataset, model, max_epochs, batch_size):\n",
" model.train()\n",
"\n",
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
" criterion = nn.CrossEntropyLoss()\n",
" optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
" for epoch in range(max_epochs):\n",
" for batch, (x, y) in enumerate(dataloader):\n",
" optimizer.zero_grad()\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
"\n",
" y_pred, (state_h, state_c) = model(x)\n",
" loss = criterion(y_pred.transpose(1, 2), y)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}\n",
"{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}\n",
"{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}\n",
"{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}\n",
"{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}\n",
"{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}\n",
"{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}\n",
"{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}\n",
"{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}\n",
"{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}\n",
"{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}\n",
"{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}\n",
"{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}\n",
"{'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434}\n",
"{'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955}\n",
"{'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184}\n",
"{'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445}\n",
"{'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488}\n",
"{'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742}\n",
"{'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543}\n",
"{'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508}\n",
"{'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625}\n",
"{'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543}\n",
"{'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152}\n",
"{'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066}\n",
"{'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709}\n",
"{'epoch': 0, 'update in batch': 26, '/': 18563, 'loss': 8.741975784301758}\n",
"{'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541}\n",
"{'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914}\n",
"{'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336}\n",
"{'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059}\n",
"{'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172}\n",
"{'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383}\n",
"{'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258}\n",
"{'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375}\n",
"{'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719}\n",
"{'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305}\n",
"{'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125}\n",
"{'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066}\n",
"{'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336}\n",
"{'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375}\n",
"{'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434}\n",
"{'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072}\n",
"{'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566}\n",
"{'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035}\n",
"{'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999}\n",
"{'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947}\n",
"{'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137}\n",
"{'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781}\n",
"{'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586}\n",
"{'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207}\n",
"{'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369}\n",
"{'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662}\n",
"{'epoch': 0, 'update in batch': 53, '/': 18563, 'loss': 7.967227935791016}\n",
"{'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039}\n",
"{'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117}\n",
"{'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492}\n",
"{'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809}\n",
"{'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004}\n",
"{'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328}\n",
"{'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875}\n",
"{'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406}\n",
"{'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965}\n",
"{'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027}\n",
"{'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535}\n",
"{'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793}\n",
"{'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234}\n",
"{'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852}\n",
"{'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773}\n",
"{'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188}\n",
"{'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004}\n",
"{'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336}\n",
"{'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828}\n",
"{'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172}\n",
"{'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262}\n",
"{'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793}\n",
"{'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207}\n",
"{'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777}\n",
"{'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738}\n",
"{'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109}\n",
"{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746}\n",
"{'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402}\n",
"{'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074}\n",
"{'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867}\n",
"{'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508}\n",
"{'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734}\n",
"{'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527}\n",
"{'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234}\n",
"{'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082}\n",
"{'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875}\n",
"{'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266}\n",
"{'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867}\n",
"{'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062}\n",
"{'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793}\n",
"{'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168}\n",
"{'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031}\n",
"{'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863}\n",
"{'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023}\n",
"{'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422}\n",
"{'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941}\n",
"{'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035}\n",
"{'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645}\n",
"{'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555}\n",
"{'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541}\n",
"{'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023}\n",
"{'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348}\n",
"{'epoch': 0, 'update in batch': 106, '/': 18563, 'loss': 8.172578811645508}\n",
"{'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623}\n",
"{'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537}\n",
"{'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}\n",
"{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}\n",
"{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}\n",
"{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}\n",
"{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}\n",
"{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}\n",
"{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}\n",
"{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}\n",
"{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}\n",
"{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}\n",
"{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}\n",
"{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}\n",
"{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}\n",
"{'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376}\n",
"{'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363}\n",
"{'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348}\n",
"{'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418}\n",
"{'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791}\n",
"{'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605}\n",
"{'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777}\n",
"{'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848}\n",
"{'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617}\n",
"{'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 8.807767868041992}\n",
"{'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676}\n",
"{'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707}\n",
"{'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662}\n",
"{'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953}\n",
"{'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988}\n",
"{'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459}\n",
"{'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031}\n",
"{'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133}\n",
"{'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709}\n",
"{'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195}\n",
"{'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668}\n",
"{'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975}\n",
"{'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535}\n",
"{'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027}\n",
"{'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375}\n",
"{'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438}\n",
"{'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812}\n",
"{'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219}\n",
"{'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285}\n",
"{'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512}\n",
"{'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566}\n",
"{'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695}\n",
"{'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113}\n",
"{'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918}\n",
"{'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947}\n",
"{'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195}\n",
"{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898}\n",
"{'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039}\n",
"{'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082}\n",
"{'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713}\n",
"{'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672}\n",
"{'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504}\n",
"{'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301}\n",
"{'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023}\n",
"{'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965}\n",
"{'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848}\n",
"{'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695}\n",
"{'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124}\n",
"{'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418}\n",
"{'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953}\n",
"{'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125}\n",
"{'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125}\n",
"{'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373}\n",
"{'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635}\n",
"{'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176}\n",
"{'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219}\n",
"{'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539}\n",
"{'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762}\n",
"{'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766}\n",
"{'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797}\n",
"{'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094}\n",
"{'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754}\n",
"{'epoch': 0, 'update in batch': 185, '/': 18563, 'loss': 9.206843376159668}\n",
"{'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252}\n",
"{'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375}\n",
"{'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027}\n",
"{'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443}\n",
"{'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258}\n",
"{'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582}\n",
"{'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473}\n",
"{'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945}\n",
"{'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859}\n",
"{'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465}\n",
"{'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719}\n",
"{'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807}\n",
"{'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012}\n",
"{'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434}\n",
"{'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258}\n",
"{'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459}\n",
"{'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621}\n",
"{'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461}\n",
"{'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166}\n",
"{'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738}\n",
"{'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584}\n",
"{'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848}\n",
"{'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668}\n",
"{'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484}\n",
"{'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275}\n",
"{'epoch': 0, 'update in batch': 211, '/': 18563, 'loss': 8.537101745605469}\n",
"{'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566}\n",
"{'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512}\n",
"{'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801}\n",
"{'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117}\n",
"{'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902}\n",
"{'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}\n",
"{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}\n",
"{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}\n",
"{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}\n",
"{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}\n",
"{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}\n",
"{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}\n",
"{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}\n",
"{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}\n",
"{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}\n",
"{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}\n",
"{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}\n",
"{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}\n",
"{'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387}\n",
"{'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682}\n",
"{'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021}\n",
"{'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148}\n",
"{'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119}\n",
"{'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754}\n",
"{'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496}\n",
"{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977}\n",
"{'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376}\n",
"{'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678}\n",
"{'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537}\n",
"{'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344}\n",
"{'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406}\n",
"{'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691}\n",
"{'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422}\n",
"{'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207}\n",
"{'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598}\n",
"{'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176}\n",
"{'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684}\n",
"{'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484}\n",
"{'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523}\n",
"{'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211}\n",
"{'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582}\n",
"{'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863}\n",
"{'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402}\n",
"{'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035}\n",
"{'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182}\n",
"{'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051}\n",
"{'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975}\n",
"{'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918}\n",
"{'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249}\n",
"{'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242}\n",
"{'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156}\n",
"{'epoch': 0, 'update in batch': 264, '/': 18563, 'loss': 7.205214500427246}\n",
"{'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844}\n",
"{'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756}\n",
"{'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004}\n",
"{'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875}\n",
"{'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459}\n",
"{'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246}\n",
"{'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584}\n",
"{'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727}\n",
"{'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094}\n",
"{'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473}\n",
"{'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537}\n",
"{'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213}\n",
"{'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377}\n",
"{'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945}\n",
"{'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121}\n",
"{'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844}\n",
"{'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262}\n",
"{'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215}\n",
"{'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379}\n",
"{'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568}\n",
"{'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275}\n",
"{'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488}\n",
"{'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568}\n",
"{'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801}\n",
"{'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072}\n",
"{'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 7.297110557556152}\n",
"{'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691}\n",
"{'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324}\n",
"{'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582}\n",
"{'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695}\n",
"{'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297}\n",
"{'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168}\n",
"{'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055}\n",
"{'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001}\n",
"{'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844}\n",
"{'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572}\n",
"{'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297}\n",
"{'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129}\n",
"{'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848}\n",
"{'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297}\n",
"{'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914}\n",
"{'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059}\n",
"{'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805}\n",
"{'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584}\n",
"{'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902}\n",
"{'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984}\n",
"{'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492}\n",
"{'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383}\n",
"{'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797}\n",
"{'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355}\n",
"{'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312}\n",
"{'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691}\n",
"{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742}\n",
"{'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148}\n",
"{'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023}\n",
"{'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988}\n",
"{'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371}\n",
"{'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398}\n",
"{'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043}\n",
"{'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684}\n",
"{'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695}\n",
"{'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}\n",
"{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}\n",
"{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}\n",
"{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}\n",
"{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}\n",
"{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}\n",
"{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}\n",
"{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}\n",
"{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}\n",
"{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}\n",
"{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}\n",
"{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}\n",
"{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}\n",
"{'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208}\n",
"{'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248}\n",
"{'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473}\n",
"{'epoch': 0, 'update in batch': 343, '/': 18563, 'loss': 8.807058334350586}\n",
"{'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002}\n",
"{'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875}\n",
"{'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852}\n",
"{'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832}\n",
"{'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627}\n",
"{'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156}\n",
"{'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188}\n",
"{'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668}\n",
"{'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445}\n",
"{'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227}\n",
"{'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578}\n",
"{'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027}\n",
"{'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473}\n",
"{'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203}\n",
"{'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541}\n",
"{'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688}\n",
"{'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176}\n",
"{'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152}\n",
"{'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664}\n",
"{'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916}\n",
"{'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805}\n",
"{'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734}\n",
"{'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098}\n",
"{'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125}\n",
"{'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207}\n",
"{'epoch': 0, 'update in batch': 369, '/': 18563, 'loss': 7.592991828918457}\n",
"{'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125}\n",
"{'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828}\n",
"{'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383}\n",
"{'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455}\n",
"{'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023}\n",
"{'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162}\n",
"{'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543}\n",
"{'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438}\n",
"{'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902}\n",
"{'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492}\n",
"{'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613}\n",
"{'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207}\n",
"{'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055}\n",
"{'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211}\n",
"{'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531}\n",
"{'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672}\n",
"{'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586}\n",
"{'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621}\n",
"{'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465}\n",
"{'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777}\n",
"{'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084}\n",
"{'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938}\n",
"{'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219}\n",
"{'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797}\n",
"{'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664}\n",
"{'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 7.6597900390625}\n",
"{'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117}\n",
"{'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035}\n",
"{'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844}\n",
"{'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836}\n",
"{'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871}\n",
"{'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875}\n",
"{'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629}\n",
"{'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078}\n",
"{'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703}\n",
"{'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785}\n",
"{'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238}\n",
"{'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225}\n",
"{'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057}\n",
"{'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184}\n",
"{'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652}\n",
"{'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992}\n",
"{'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516}\n",
"{'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059}\n",
"{'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541}\n",
"{'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996}\n",
"{'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254}\n",
"{'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398}\n",
"{'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168}\n",
"{'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086}\n",
"{'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348}\n",
"{'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506}\n",
"{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713}\n",
"{'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031}\n",
"{'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113}\n",
"{'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535}\n",
"{'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535}\n",
"{'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805}\n",
"{'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379}\n",
"{'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848}\n",
"{'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682}\n",
"{'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398}\n",
"{'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754}\n",
"{'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373}\n",
"{'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}\n",
"{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}\n",
"{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}\n",
"{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}\n",
"{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}\n",
"{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}\n",
"{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}\n",
"{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}\n",
"{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}\n",
"{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}\n",
"{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}\n",
"{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}\n",
"{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}\n",
"{'epoch': 0, 'update in batch': 448, '/': 18563, 'loss': 8.531024932861328}\n",
"{'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305}\n",
"{'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754}\n",
"{'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785}\n",
"{'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152}\n",
"{'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537}\n",
"{'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854}\n",
"{'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711}\n",
"{'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156}\n",
"{'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098}\n",
"{'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008}\n",
"{'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082}\n",
"{'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805}\n",
"{'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367}\n",
"{'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415}\n",
"{'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145}\n",
"{'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562}\n",
"{'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184}\n",
"{'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125}\n",
"{'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176}\n",
"{'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125}\n",
"{'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605}\n",
"{'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246}\n",
"{'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441}\n",
"{'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041}\n",
"{'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164}\n",
"{'epoch': 0, 'update in batch': 474, '/': 18563, 'loss': 8.147513389587402}\n",
"{'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338}\n",
"{'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665}\n",
"{'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963}\n",
"{'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291}\n",
"{'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297}\n",
"{'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812}\n",
"{'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582}\n",
"{'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684}\n",
"{'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008}\n",
"{'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117}\n",
"{'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084}\n",
"{'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023}\n",
"{'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508}\n",
"{'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812}\n",
"{'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574}\n",
"{'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254}\n",
"{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}\n",
"{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}\n",
"{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-18-fe996a0be74b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniq_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-17-8d700bc624e3>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(dataset, model, max_epochs, batch_size)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m inputs=inputs)\n\u001b[0;32m--> 363\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 364\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 173\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 174\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 175\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"model = Model(vocab_size = len(dataset.uniq_words)).to(device)\n",
"train(dataset, model, 1, 64)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def predict(dataset, model, text, next_words=5):\n",
" model.eval()\n",
" words = text.split(' ')\n",
" state_h, state_c = model.init_state(len(words))\n",
"\n",
" for i in range(0, next_words):\n",
" x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)\n",
" y_pred, (state_h, state_c) = model(x, (state_h, state_c))\n",
"\n",
" last_word_logits = y_pred[0][-1]\n",
" p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=p)\n",
" words.append(dataset.index_to_word[word_index])\n",
"\n",
" return words"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict(dataset, model, 'kmicic szedł')"
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -0,0 +1,334 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Modelowanie Języka</h1>\n",
"<h2> 10. <i>Model neuronowy rekurencyjny</i> [ćwiczenia]</h2> \n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ensemble modeli"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"W jaki sposób można podnieść wynik predykcji dla zadania uczenia maszynowego?"
]
},
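{
"cell_type": "markdown",
"metadata": {},
"source": [
"One common answer is to ensemble models: average the probability distributions predicted by several independently trained models. The cell below is a minimal, hypothetical sketch added for illustration (it is not part of the challenge solutions); the candidate words and the two distributions are invented for demonstration only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical distributions over the same candidate words,\n",
"# as if produced by two independently trained language models\n",
"words = ['to', 'of', 'in', 'for']\n",
"p_model1 = np.array([0.5, 0.2, 0.2, 0.1])\n",
"p_model2 = np.array([0.3, 0.4, 0.2, 0.1])\n",
"\n",
"# simple ensemble: average the probabilities and renormalize\n",
"p_ensemble = (p_model1 + p_model2) / 2\n",
"p_ensemble /= p_ensemble.sum()\n",
"\n",
"list(zip(words, p_ensemble))"
]
},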
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mkdir: cannot create directory dev-0-ireland-news: File exists\r\n"
]
}
],
"source": [
"!mkdir dev-0-ireland-news"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Mamy wyzwanie:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://gonito.net/challenge-all-submissions/ireland-news-headlines-word-gap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://github.com/kubapok/ireland-news-word-gap/tree/0c6557c8a3cd6d8c77f64618850b2ae82c19476a"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-05-13 13:23:05-- https://github.com/kubapok/ireland-news-word-gap/raw/11c72875023c5c01c9d0c0ca39d72c90c840aeb3/dev-0/out.tsv\n",
"Resolving github.com (github.com)... 140.82.121.4\n",
"Connecting to github.com (github.com)|140.82.121.4|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://raw.githubusercontent.com/kubapok/ireland-news-word-gap/11c72875023c5c01c9d0c0ca39d72c90c840aeb3/dev-0/out.tsv [following]\n",
"--2022-05-13 13:23:06-- https://raw.githubusercontent.com/kubapok/ireland-news-word-gap/11c72875023c5c01c9d0c0ca39d72c90c840aeb3/dev-0/out.tsv\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 63249692 (60M) [text/plain]\n",
"Saving to: out.tsv\n",
"\n",
"out.tsv 100%[===================>] 60,32M 26,7MB/s in 2,3s \n",
"\n",
"2022-05-13 13:23:08 (26,7 MB/s) - out.tsv saved [63249692/63249692]\n",
"\n",
"--2022-05-13 13:23:09-- https://github.com/kubapok/ireland-news-word-gap/raw/0c6557c8a3cd6d8c77f64618850b2ae82c19476a/dev-0/out.tsv\n",
"Resolving github.com (github.com)... 140.82.121.4\n",
"Connecting to github.com (github.com)|140.82.121.4|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://raw.githubusercontent.com/kubapok/ireland-news-word-gap/0c6557c8a3cd6d8c77f64618850b2ae82c19476a/dev-0/out.tsv [following]\n",
"--2022-05-13 13:23:09-- https://raw.githubusercontent.com/kubapok/ireland-news-word-gap/0c6557c8a3cd6d8c77f64618850b2ae82c19476a/dev-0/out.tsv\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 63271863 (60M) [text/plain]\n",
"Saving to: out.tsv\n",
"\n",
"out.tsv 100%[===================>] 60,34M 45,1MB/s in 1,3s \n",
"\n",
"2022-05-13 13:23:10 (45,1 MB/s) - out.tsv saved [63271863/63271863]\n",
"\n",
"--2022-05-13 13:23:11-- https://git.wmi.amu.edu.pl/kubapok/ireland-news-word-gap-prediction/raw/branch/master/dev-0/expected.tsv\n",
"Resolving git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)... 150.254.78.40\n",
"Connecting to git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)|150.254.78.40|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 866583 (846K) [text/plain]\n",
"Saving to: expected.tsv.1\n",
"\n",
"expected.tsv.1 100%[===================>] 846,27K 1,91MB/s in 0,4s \n",
"\n",
"2022-05-13 13:23:11 (1,91 MB/s) - expected.tsv.1 saved [866583/866583]\n",
"\n"
]
}
],
"source": [
"!wget https://github.com/kubapok/ireland-news-word-gap/raw/11c72875023c5c01c9d0c0ca39d72c90c840aeb3/dev-0/out.tsv\n",
"!mv out.tsv ./dev-0/out-solution1.tsv\n",
"!wget https://github.com/kubapok/ireland-news-word-gap/raw/0c6557c8a3cd6d8c77f64618850b2ae82c19476a/dev-0/out.tsv\n",
"!mv out.tsv ./dev-0/out-solution2.tsv\n",
"! ( cd dev-0 ; wget https://git.wmi.amu.edu.pl/kubapok/ireland-news-word-gap-prediction/raw/branch/master/dev-0/expected.tsv)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-05-13 13:23:12-- https://gonito.net/get/bin/geval\n",
"Resolving gonito.net (gonito.net)... 150.254.78.126\n",
"Connecting to gonito.net (gonito.net)|150.254.78.126|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 12860136 (12M) [application/octet-stream]\n",
"Saving to: geval.1\n",
"\n",
"geval.1 100%[===================>] 12,26M 2,67MB/s in 4,1s \n",
"\n",
"2022-05-13 13:23:16 (2,97 MB/s) - geval.1 saved [12860136/12860136]\n",
"\n"
]
}
],
"source": [
"!wget https://gonito.net/get/bin/geval\n",
"!chmod u+x geval"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"35.05218788086649\r\n"
]
}
],
"source": [
"!./geval --metric PerplexityHashed -o ./dev-0/out-solution1.tsv -e dev-0/expected.tsv"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"33.47429048442195\r\n"
]
}
],
"source": [
"!./geval --metric PerplexityHashed -o ./dev-0/out-solution2.tsv -e dev-0/expected.tsv"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"with open('./dev-0/out-solution1.tsv') as s1, open('./dev-0/out-solution2.tsv') as s2, open('./dev-0/out-merge.tsv','w') as f_merge:\n",
" for l1, l2 in zip(s1, s2):\n",
" dir1 = {''.join(x.split(':')[:-1]): float(x.split(':')[-1]) for x in l1.rstrip().split(' ')}\n",
" dir2 = {''.join(x.split(':')[:-1]): float(x.split(':')[-1]) for x in l2.rstrip().split(' ')}\n",
" newdir = dict()\n",
" for k in dir1.keys() | dir2.keys():\n",
" newdir[k] = dir1[k] if k in dir1 else 0.0\n",
" newdir[k] += dir2[k] if k in dir2 else 0.0\n",
" newdir[k] /= 2\n",
" merge_line = ' '.join([k + ':' + str(v) for k,v in newdir.items()]) + '\\n'\n",
" f_merge.write(merge_line)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"29.054162509715063\r\n"
]
}
],
"source": [
"!./geval --metric PerplexityHashed -o ./dev-0/out-merge.tsv -e dev-0/expected.tsv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Złożenie:\n",
"\n",
"- kilku dobrych niezależnych modeli \n",
"- kilku modeli wytrenowanych dla różnego seeda\n",
"- kilku ostatnich checkpointów"
]
},
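{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal, illustrative sketch of the last idea from the list above: averaging the next-word distributions predicted by the last few checkpoints of a single training run. It assumes a `Model` class with an `init_state` method and a `dataset` vocabulary like in the previous notebook; the checkpoint file names are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# hypothetical checkpoint files saved after the last few epochs of one training run\n",
"CHECKPOINT_PATHS = ['checkpoint_08.pt', 'checkpoint_09.pt', 'checkpoint_10.pt']\n",
"\n",
"def ensemble_next_word_probs(models, x):\n",
"    # x: LongTensor of shape (1, sequence_length) with word indices\n",
"    probs = []\n",
"    for model in models:\n",
"        model.eval()\n",
"        with torch.no_grad():\n",
"            state = model.init_state(x.shape[1])\n",
"            y_pred, _ = model(x, state)\n",
"            # distribution over the vocabulary for the last position\n",
"            probs.append(torch.softmax(y_pred[0][-1], dim=0))\n",
"    # arithmetic mean of the per-checkpoint distributions\n",
"    return torch.stack(probs).mean(dim=0)\n",
"\n",
"# usage sketch (assumes Model, dataset and device as in the previous notebook):\n",
"# models = []\n",
"# for path in CHECKPOINT_PATHS:\n",
"#     m = Model(vocab_size=len(dataset.uniq_words)).to(device)\n",
"#     m.load_state_dict(torch.load(path, map_location=device))\n",
"#     models.append(m)\n",
"# x = torch.tensor([[dataset.word_to_index[w] for w in 'kmicic szedł'.split(' ')]]).to(device)\n",
"# p = ensemble_next_word_probs(models, x)  # tensor of vocab-sized probabilities"
]
},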
{
"cell_type": "markdown",
"metadata": {},
"source": [
"W jaki sposób składać różne modele:\n",
"\n",
"- średnia ważona\n",
"- średnia geometryczna\n",
"- inna średnia\n",
"- trenowanie osobnego prostego modelu, którego zadanie to składanie modeli (np. regresja liniowa)\n",
"\n",
"\n",
"Można też trenować wspólnie kilka modeli jednocześnie ze wspólnym backpropagation."
]
},
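{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of two of the schemes listed above, applied to the same `out-solution1.tsv` / `out-solution2.tsv` files as before, writing to a new file `out-merge-weighted.tsv`. The weights 0.4 / 0.6 and the `eps` value are arbitrary example choices, not tuned values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_dist(line):\n",
"    # parse a 'token:prob token:prob ...' line into a dict\n",
"    return {':'.join(x.split(':')[:-1]): float(x.split(':')[-1]) for x in line.rstrip().split(' ')}\n",
"\n",
"def weighted_mean(d1, d2, w1=0.4, w2=0.6):\n",
"    # weighted arithmetic mean; a token missing from one file counts as probability 0.0 there\n",
"    return {k: w1 * d1.get(k, 0.0) + w2 * d2.get(k, 0.0) for k in d1.keys() | d2.keys()}\n",
"\n",
"def geometric_mean(d1, d2, eps=1e-9):\n",
"    # geometric mean; eps keeps tokens missing from one file from being zeroed out entirely\n",
"    # (the result may need renormalising so that the probabilities sum to at most 1)\n",
"    return {k: (d1.get(k, eps) * d2.get(k, eps)) ** 0.5 for k in d1.keys() | d2.keys()}\n",
"\n",
"with open('./dev-0/out-solution1.tsv') as s1, open('./dev-0/out-solution2.tsv') as s2, open('./dev-0/out-merge-weighted.tsv', 'w') as f_out:\n",
"    for l1, l2 in zip(s1, s2):\n",
"        merged = weighted_mean(read_dist(l1), read_dist(l2))\n",
"        f_out.write(' '.join(k + ':' + str(v) for k, v in merged.items()) + '\\n')"
]
},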
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Jakie są minusy ensemble?\n",
"\n",
"- wyższy stopień skomplikowania modelu\n",
"- dłuższy czas inferencji\n",
"- zużycie większych zasobów komputera\n",
"- gorsza interpretowalność"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"W praktyce jeżeli startujemy w konkursie uczenia maszynowego, zawsze warto robić ensemble.\n",
"\n",
"\n",
"W komercji jeżeli mamy ograniczenia czasowe lub zasobów, model jest ciężki, wynik modelu nie jest bardzo ważny, to często nie korzysta się z ensembli.\n",
"W nauce albo kiedy chcemy porównać kilka różnych metod, to składanie modeli zaburza nam niepotrzebnie obraz."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Warto mieć na uwadze, że niektóre metody z założenia są ensemblami. Np. las losowy albo boostowane drzewa decyzyjne."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ZADANIE\n",
"\n",
"przykładowy tekst: \"ala\" \"ma\" \"kota\" \"MASK\" \"2\" \"psy\" \"i\" \"chomika\"\n",
"\n",
"Stworzyć 2 sieci rekurencyjne LSTM dla Challenging America word-gap prediction. \n",
"- jedna sieć powinna działać do przodu (czyli jak zwyczajna sieć): \"ala\" \"ma\" \"kota\"\n",
"- druga siec powinna działać do tyłu: \"chomika\" \"i\" \"psy\" \"2\"\n",
" \n",
"Zrobienie sieci odwrotnej jest bardzo proste. Wystarczy odwrócić kolejność słów, nie ma potrzeby ingerować w architekturę modeli.\n",
"\n",
"Następnie należy zrobić jakiś ensemble tych modeli. W najprostszej wersji może to być średnia arytmetyczna. Warto jednak spróbować innych sposobów, np. można trenować te sieci łącznie.\n",
"\n",
"\n",
"Wymogi takie jak zawsze, zadanie widoczne na gonito."
]
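},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of one possible way to combine the forward and the backward network (only an illustration of the idea, not a reference solution; the helper names are made up):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def backward_input(words):\n",
"    # the backward model sees exactly the same data, only with the word order reversed\n",
"    return list(reversed(words))\n",
"\n",
"def ensemble_word_probs(p_forward, p_backward):\n",
"    # p_forward, p_backward: dicts mapping a candidate word to its predicted probability\n",
"    # simplest ensemble: arithmetic mean, with missing words treated as probability 0.0\n",
"    return {w: 0.5 * p_forward.get(w, 0.0) + 0.5 * p_backward.get(w, 0.0)\n",
"            for w in p_forward.keys() | p_backward.keys()}\n",
"\n",
"# usage sketch for the example text above:\n",
"# left_context = ['ala', 'ma', 'kota']            # fed to the forward model as-is\n",
"# right_context = ['2', 'psy', 'i', 'chomika']    # fed to the backward model after reversing:\n",
"# backward_input(right_context)  # -> ['chomika', 'i', 'psy', '2']"
]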
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}