modelowanie-jezykowe-aitech-cw/cw/09_Model_neuronowy_rekurencyjny.ipynb

1017 lines
61 KiB
Plaintext
Raw Normal View History

2022-05-09 09:59:47 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Modelowanie Języka</h1>\n",
"<h2> 9. <i>Model neuronowy rekurencyjny</i> [ćwiczenia]</h2> \n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
"from torch.utils.data import DataLoader\n",
"import numpy as np\n",
"from collections import Counter\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 877893 (857K) [text/plain]\n",
"Saving to: potop-tom-pierwszy.txt.2\n",
"\n",
"potop-tom-pierwszy. 100%[===================>] 857,32K --.-KB/s in 0,07s \n",
"\n",
"2022-05-08 19:27:04 (12,0 MB/s) - potop-tom-pierwszy.txt.2 saved [877893/877893]\n",
"\n",
"--2022-05-08 19:27:04-- https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1087797 (1,0M) [text/plain]\n",
"Saving to: potop-tom-drugi.txt.2\n",
"\n",
"potop-tom-drugi.txt 100%[===================>] 1,04M --.-KB/s in 0,08s \n",
"\n",
"2022-05-08 19:27:04 (12,9 MB/s) - potop-tom-drugi.txt.2 saved [1087797/1087797]\n",
"\n",
"--2022-05-08 19:27:05-- https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt\n",
"Resolving wolnelektury.pl (wolnelektury.pl)... 51.83.143.148, 2001:41d0:602:3294::\n",
"Connecting to wolnelektury.pl (wolnelektury.pl)|51.83.143.148|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 788219 (770K) [text/plain]\n",
"Saving to: potop-tom-trzeci.txt.2\n",
"\n",
"potop-tom-trzeci.tx 100%[===================>] 769,75K --.-KB/s in 0,06s \n",
"\n",
"2022-05-08 19:27:05 (12,0 MB/s) - potop-tom-trzeci.txt.2 saved [788219/788219]\n",
"\n"
]
}
],
"source": [
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-pierwszy.txt\n",
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-drugi.txt\n",
"! wget https://wolnelektury.pl/media/book/txt/potop-tom-trzeci.txt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"!cat potop-* > potop.txt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(\n",
" self,\n",
" sequence_length,\n",
" ):\n",
" self.sequence_length = sequence_length\n",
" self.words = self.load()\n",
" self.uniq_words = self.get_uniq_words()\n",
"\n",
" self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}\n",
" self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}\n",
"\n",
" self.words_indexes = [self.word_to_index[w] for w in self.words]\n",
"\n",
" def load(self):\n",
" with open('potop.txt', 'r') as f_in:\n",
" text = [x.rstrip() for x in f_in.readlines() if x.strip()]\n",
" text = ' '.join(text).lower()\n",
" text = re.sub('[^a-ząćęłńóśźż ]', '', text) \n",
" text = text.split(' ')\n",
" return text\n",
" \n",
" \n",
" def get_uniq_words(self):\n",
" word_counts = Counter(self.words)\n",
" return sorted(word_counts, key=word_counts.get, reverse=True)\n",
"\n",
" def __len__(self):\n",
" return len(self.words_indexes) - self.sequence_length\n",
"\n",
" def __getitem__(self, index):\n",
" return (\n",
" torch.tensor(self.words_indexes[index:index+self.sequence_length]),\n",
" torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"dataset = Dataset(5)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 551, 18, 17, 255, 10748]),\n",
" tensor([ 18, 17, 255, 10748, 34]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset[200]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['patrzył', 'tak', 'jak', 'człowiek', 'zbudzony']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[dataset.index_to_word[x] for x in [ 551, 18, 17, 255, 10748]]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['tak', 'jak', 'człowiek', 'zbudzony', 'ze']"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[dataset.index_to_word[x] for x in [ 18, 17, 255, 10748, 34]]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"input_tensor = torch.tensor([[ 551, 18, 17, 255, 10748]], dtype=torch.int32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"#input_tensor = torch.tensor([[ 551, 18]], dtype=torch.int32).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, vocab_size):\n",
" super(Model, self).__init__()\n",
" self.lstm_size = 128\n",
" self.embedding_dim = 128\n",
" self.num_layers = 3\n",
"\n",
" self.embedding = nn.Embedding(\n",
" num_embeddings=vocab_size,\n",
" embedding_dim=self.embedding_dim,\n",
" )\n",
" self.lstm = nn.LSTM(\n",
" input_size=self.lstm_size,\n",
" hidden_size=self.lstm_size,\n",
" num_layers=self.num_layers,\n",
" dropout=0.2,\n",
" )\n",
" self.fc = nn.Linear(self.lstm_size, vocab_size)\n",
"\n",
" def forward(self, x, prev_state = None):\n",
" embed = self.embedding(x)\n",
" output, state = self.lstm(embed, prev_state)\n",
" logits = self.fc(output)\n",
" return logits, state\n",
"\n",
" def init_state(self, sequence_length):\n",
" return (torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device),\n",
" torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"model = Model(len(dataset)).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"y_pred, (state_h, state_c) = model(input_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.0046, -0.0113, 0.0313, ..., 0.0198, -0.0312, 0.0223],\n",
" [ 0.0039, -0.0110, 0.0303, ..., 0.0213, -0.0302, 0.0230],\n",
" [ 0.0029, -0.0133, 0.0265, ..., 0.0204, -0.0297, 0.0219],\n",
" [ 0.0010, -0.0120, 0.0282, ..., 0.0241, -0.0314, 0.0241],\n",
" [ 0.0038, -0.0106, 0.0346, ..., 0.0230, -0.0333, 0.0232]]],\n",
" grad_fn=<AddBackward0>)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 5, 1187998])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def train(dataset, model, max_epochs, batch_size):\n",
" model.train()\n",
"\n",
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
" criterion = nn.CrossEntropyLoss()\n",
" optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
" for epoch in range(max_epochs):\n",
" for batch, (x, y) in enumerate(dataloader):\n",
" optimizer.zero_grad()\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
"\n",
" y_pred, (state_h, state_c) = model(x)\n",
" loss = criterion(y_pred.transpose(1, 2), y)\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 0, '/': 18563, 'loss': 10.717817306518555}\n",
"{'epoch': 0, 'update in batch': 1, '/': 18563, 'loss': 10.699922561645508}\n",
"{'epoch': 0, 'update in batch': 2, '/': 18563, 'loss': 10.701103210449219}\n",
"{'epoch': 0, 'update in batch': 3, '/': 18563, 'loss': 10.700254440307617}\n",
"{'epoch': 0, 'update in batch': 4, '/': 18563, 'loss': 10.69465160369873}\n",
"{'epoch': 0, 'update in batch': 5, '/': 18563, 'loss': 10.681333541870117}\n",
"{'epoch': 0, 'update in batch': 6, '/': 18563, 'loss': 10.668376922607422}\n",
"{'epoch': 0, 'update in batch': 7, '/': 18563, 'loss': 10.675261497497559}\n",
"{'epoch': 0, 'update in batch': 8, '/': 18563, 'loss': 10.665823936462402}\n",
"{'epoch': 0, 'update in batch': 9, '/': 18563, 'loss': 10.655462265014648}\n",
"{'epoch': 0, 'update in batch': 10, '/': 18563, 'loss': 10.591516494750977}\n",
"{'epoch': 0, 'update in batch': 11, '/': 18563, 'loss': 10.580559730529785}\n",
"{'epoch': 0, 'update in batch': 12, '/': 18563, 'loss': 10.524133682250977}\n",
"{'epoch': 0, 'update in batch': 13, '/': 18563, 'loss': 10.480895042419434}\n",
"{'epoch': 0, 'update in batch': 14, '/': 18563, 'loss': 10.33996295928955}\n",
"{'epoch': 0, 'update in batch': 15, '/': 18563, 'loss': 10.345580101013184}\n",
"{'epoch': 0, 'update in batch': 16, '/': 18563, 'loss': 10.200639724731445}\n",
"{'epoch': 0, 'update in batch': 17, '/': 18563, 'loss': 10.030133247375488}\n",
"{'epoch': 0, 'update in batch': 18, '/': 18563, 'loss': 10.046720504760742}\n",
"{'epoch': 0, 'update in batch': 19, '/': 18563, 'loss': 10.00318717956543}\n",
"{'epoch': 0, 'update in batch': 20, '/': 18563, 'loss': 9.588350296020508}\n",
"{'epoch': 0, 'update in batch': 21, '/': 18563, 'loss': 9.780914306640625}\n",
"{'epoch': 0, 'update in batch': 22, '/': 18563, 'loss': 9.36646842956543}\n",
"{'epoch': 0, 'update in batch': 23, '/': 18563, 'loss': 9.306387901306152}\n",
"{'epoch': 0, 'update in batch': 24, '/': 18563, 'loss': 9.150574684143066}\n",
"{'epoch': 0, 'update in batch': 25, '/': 18563, 'loss': 8.89719295501709}\n",
"{'epoch': 0, 'update in batch': 26, '/': 18563, 'loss': 8.741975784301758}\n",
"{'epoch': 0, 'update in batch': 27, '/': 18563, 'loss': 9.36513614654541}\n",
"{'epoch': 0, 'update in batch': 28, '/': 18563, 'loss': 8.840768814086914}\n",
"{'epoch': 0, 'update in batch': 29, '/': 18563, 'loss': 8.356801986694336}\n",
"{'epoch': 0, 'update in batch': 30, '/': 18563, 'loss': 8.274016380310059}\n",
"{'epoch': 0, 'update in batch': 31, '/': 18563, 'loss': 8.944927215576172}\n",
"{'epoch': 0, 'update in batch': 32, '/': 18563, 'loss': 8.923280715942383}\n",
"{'epoch': 0, 'update in batch': 33, '/': 18563, 'loss': 8.479402542114258}\n",
"{'epoch': 0, 'update in batch': 34, '/': 18563, 'loss': 8.42425537109375}\n",
"{'epoch': 0, 'update in batch': 35, '/': 18563, 'loss': 9.487113952636719}\n",
"{'epoch': 0, 'update in batch': 36, '/': 18563, 'loss': 8.314191818237305}\n",
"{'epoch': 0, 'update in batch': 37, '/': 18563, 'loss': 8.0274658203125}\n",
"{'epoch': 0, 'update in batch': 38, '/': 18563, 'loss': 8.725769996643066}\n",
"{'epoch': 0, 'update in batch': 39, '/': 18563, 'loss': 8.67934799194336}\n",
"{'epoch': 0, 'update in batch': 40, '/': 18563, 'loss': 8.872161865234375}\n",
"{'epoch': 0, 'update in batch': 41, '/': 18563, 'loss': 7.883971214294434}\n",
"{'epoch': 0, 'update in batch': 42, '/': 18563, 'loss': 7.682810306549072}\n",
"{'epoch': 0, 'update in batch': 43, '/': 18563, 'loss': 7.880677223205566}\n",
"{'epoch': 0, 'update in batch': 44, '/': 18563, 'loss': 7.807427406311035}\n",
"{'epoch': 0, 'update in batch': 45, '/': 18563, 'loss': 7.93829870223999}\n",
"{'epoch': 0, 'update in batch': 46, '/': 18563, 'loss': 7.718912601470947}\n",
"{'epoch': 0, 'update in batch': 47, '/': 18563, 'loss': 8.309863090515137}\n",
"{'epoch': 0, 'update in batch': 48, '/': 18563, 'loss': 9.091133117675781}\n",
"{'epoch': 0, 'update in batch': 49, '/': 18563, 'loss': 9.317312240600586}\n",
"{'epoch': 0, 'update in batch': 50, '/': 18563, 'loss': 8.517735481262207}\n",
"{'epoch': 0, 'update in batch': 51, '/': 18563, 'loss': 7.697592258453369}\n",
"{'epoch': 0, 'update in batch': 52, '/': 18563, 'loss': 6.838181972503662}\n",
"{'epoch': 0, 'update in batch': 53, '/': 18563, 'loss': 7.967227935791016}\n",
"{'epoch': 0, 'update in batch': 54, '/': 18563, 'loss': 8.47049331665039}\n",
"{'epoch': 0, 'update in batch': 55, '/': 18563, 'loss': 8.958921432495117}\n",
"{'epoch': 0, 'update in batch': 56, '/': 18563, 'loss': 8.316679000854492}\n",
"{'epoch': 0, 'update in batch': 57, '/': 18563, 'loss': 8.997099876403809}\n",
"{'epoch': 0, 'update in batch': 58, '/': 18563, 'loss': 8.608811378479004}\n",
"{'epoch': 0, 'update in batch': 59, '/': 18563, 'loss': 9.377460479736328}\n",
"{'epoch': 0, 'update in batch': 60, '/': 18563, 'loss': 8.6201171875}\n",
"{'epoch': 0, 'update in batch': 61, '/': 18563, 'loss': 8.821510314941406}\n",
"{'epoch': 0, 'update in batch': 62, '/': 18563, 'loss': 8.915961265563965}\n",
"{'epoch': 0, 'update in batch': 63, '/': 18563, 'loss': 8.222617149353027}\n",
"{'epoch': 0, 'update in batch': 64, '/': 18563, 'loss': 9.266777992248535}\n",
"{'epoch': 0, 'update in batch': 65, '/': 18563, 'loss': 8.749354362487793}\n",
"{'epoch': 0, 'update in batch': 66, '/': 18563, 'loss': 8.311641693115234}\n",
"{'epoch': 0, 'update in batch': 67, '/': 18563, 'loss': 8.553888320922852}\n",
"{'epoch': 0, 'update in batch': 68, '/': 18563, 'loss': 8.790258407592773}\n",
"{'epoch': 0, 'update in batch': 69, '/': 18563, 'loss': 9.090133666992188}\n",
"{'epoch': 0, 'update in batch': 70, '/': 18563, 'loss': 8.893723487854004}\n",
"{'epoch': 0, 'update in batch': 71, '/': 18563, 'loss': 8.844594955444336}\n",
"{'epoch': 0, 'update in batch': 72, '/': 18563, 'loss': 7.771625518798828}\n",
"{'epoch': 0, 'update in batch': 73, '/': 18563, 'loss': 8.536479949951172}\n",
"{'epoch': 0, 'update in batch': 74, '/': 18563, 'loss': 7.300860404968262}\n",
"{'epoch': 0, 'update in batch': 75, '/': 18563, 'loss': 8.62000846862793}\n",
"{'epoch': 0, 'update in batch': 76, '/': 18563, 'loss': 8.67784309387207}\n",
"{'epoch': 0, 'update in batch': 77, '/': 18563, 'loss': 7.319235801696777}\n",
"{'epoch': 0, 'update in batch': 78, '/': 18563, 'loss': 8.322186470031738}\n",
"{'epoch': 0, 'update in batch': 79, '/': 18563, 'loss': 7.767421722412109}\n",
"{'epoch': 0, 'update in batch': 80, '/': 18563, 'loss': 8.817885398864746}\n",
"{'epoch': 0, 'update in batch': 81, '/': 18563, 'loss': 8.133109092712402}\n",
"{'epoch': 0, 'update in batch': 82, '/': 18563, 'loss': 7.822054862976074}\n",
"{'epoch': 0, 'update in batch': 83, '/': 18563, 'loss': 8.055540084838867}\n",
"{'epoch': 0, 'update in batch': 84, '/': 18563, 'loss': 8.053682327270508}\n",
"{'epoch': 0, 'update in batch': 85, '/': 18563, 'loss': 8.018306732177734}\n",
"{'epoch': 0, 'update in batch': 86, '/': 18563, 'loss': 8.371909141540527}\n",
"{'epoch': 0, 'update in batch': 87, '/': 18563, 'loss': 8.057979583740234}\n",
"{'epoch': 0, 'update in batch': 88, '/': 18563, 'loss': 8.340703010559082}\n",
"{'epoch': 0, 'update in batch': 89, '/': 18563, 'loss': 8.7703857421875}\n",
"{'epoch': 0, 'update in batch': 90, '/': 18563, 'loss': 9.714847564697266}\n",
"{'epoch': 0, 'update in batch': 91, '/': 18563, 'loss': 8.621702194213867}\n",
"{'epoch': 0, 'update in batch': 92, '/': 18563, 'loss': 9.406997680664062}\n",
"{'epoch': 0, 'update in batch': 93, '/': 18563, 'loss': 9.29774284362793}\n",
"{'epoch': 0, 'update in batch': 94, '/': 18563, 'loss': 8.649836540222168}\n",
"{'epoch': 0, 'update in batch': 95, '/': 18563, 'loss': 8.441780090332031}\n",
"{'epoch': 0, 'update in batch': 96, '/': 18563, 'loss': 7.991406440734863}\n",
"{'epoch': 0, 'update in batch': 97, '/': 18563, 'loss': 9.314489364624023}\n",
"{'epoch': 0, 'update in batch': 98, '/': 18563, 'loss': 8.368816375732422}\n",
"{'epoch': 0, 'update in batch': 99, '/': 18563, 'loss': 8.771149635314941}\n",
"{'epoch': 0, 'update in batch': 100, '/': 18563, 'loss': 7.8758111000061035}\n",
"{'epoch': 0, 'update in batch': 101, '/': 18563, 'loss': 8.341328620910645}\n",
"{'epoch': 0, 'update in batch': 102, '/': 18563, 'loss': 8.413129806518555}\n",
"{'epoch': 0, 'update in batch': 103, '/': 18563, 'loss': 7.372011661529541}\n",
"{'epoch': 0, 'update in batch': 104, '/': 18563, 'loss': 8.170934677124023}\n",
"{'epoch': 0, 'update in batch': 105, '/': 18563, 'loss': 8.109993934631348}\n",
"{'epoch': 0, 'update in batch': 106, '/': 18563, 'loss': 8.172578811645508}\n",
"{'epoch': 0, 'update in batch': 107, '/': 18563, 'loss': 8.33222484588623}\n",
"{'epoch': 0, 'update in batch': 108, '/': 18563, 'loss': 7.997575283050537}\n",
"{'epoch': 0, 'update in batch': 109, '/': 18563, 'loss': 7.847937107086182}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 110, '/': 18563, 'loss': 7.351314544677734}\n",
"{'epoch': 0, 'update in batch': 111, '/': 18563, 'loss': 8.472936630249023}\n",
"{'epoch': 0, 'update in batch': 112, '/': 18563, 'loss': 7.855953216552734}\n",
"{'epoch': 0, 'update in batch': 113, '/': 18563, 'loss': 8.163175582885742}\n",
"{'epoch': 0, 'update in batch': 114, '/': 18563, 'loss': 8.208657264709473}\n",
"{'epoch': 0, 'update in batch': 115, '/': 18563, 'loss': 8.781523704528809}\n",
"{'epoch': 0, 'update in batch': 116, '/': 18563, 'loss': 8.449674606323242}\n",
"{'epoch': 0, 'update in batch': 117, '/': 18563, 'loss': 8.176030158996582}\n",
"{'epoch': 0, 'update in batch': 118, '/': 18563, 'loss': 8.415689468383789}\n",
"{'epoch': 0, 'update in batch': 119, '/': 18563, 'loss': 8.645845413208008}\n",
"{'epoch': 0, 'update in batch': 120, '/': 18563, 'loss': 8.160420417785645}\n",
"{'epoch': 0, 'update in batch': 121, '/': 18563, 'loss': 8.117982864379883}\n",
"{'epoch': 0, 'update in batch': 122, '/': 18563, 'loss': 9.099283218383789}\n",
"{'epoch': 0, 'update in batch': 123, '/': 18563, 'loss': 7.98253870010376}\n",
"{'epoch': 0, 'update in batch': 124, '/': 18563, 'loss': 8.112133979797363}\n",
"{'epoch': 0, 'update in batch': 125, '/': 18563, 'loss': 8.479134559631348}\n",
"{'epoch': 0, 'update in batch': 126, '/': 18563, 'loss': 8.92817497253418}\n",
"{'epoch': 0, 'update in batch': 127, '/': 18563, 'loss': 8.38918399810791}\n",
"{'epoch': 0, 'update in batch': 128, '/': 18563, 'loss': 9.000529289245605}\n",
"{'epoch': 0, 'update in batch': 129, '/': 18563, 'loss': 8.525534629821777}\n",
"{'epoch': 0, 'update in batch': 130, '/': 18563, 'loss': 9.055428504943848}\n",
"{'epoch': 0, 'update in batch': 131, '/': 18563, 'loss': 8.818662643432617}\n",
"{'epoch': 0, 'update in batch': 132, '/': 18563, 'loss': 8.807767868041992}\n",
"{'epoch': 0, 'update in batch': 133, '/': 18563, 'loss': 8.398343086242676}\n",
"{'epoch': 0, 'update in batch': 134, '/': 18563, 'loss': 8.435093879699707}\n",
"{'epoch': 0, 'update in batch': 135, '/': 18563, 'loss': 7.877000331878662}\n",
"{'epoch': 0, 'update in batch': 136, '/': 18563, 'loss': 8.197925567626953}\n",
"{'epoch': 0, 'update in batch': 137, '/': 18563, 'loss': 8.655011177062988}\n",
"{'epoch': 0, 'update in batch': 138, '/': 18563, 'loss': 7.786923885345459}\n",
"{'epoch': 0, 'update in batch': 139, '/': 18563, 'loss': 8.338996887207031}\n",
"{'epoch': 0, 'update in batch': 140, '/': 18563, 'loss': 8.607789993286133}\n",
"{'epoch': 0, 'update in batch': 141, '/': 18563, 'loss': 8.52219295501709}\n",
"{'epoch': 0, 'update in batch': 142, '/': 18563, 'loss': 8.436418533325195}\n",
"{'epoch': 0, 'update in batch': 143, '/': 18563, 'loss': 7.999323844909668}\n",
"{'epoch': 0, 'update in batch': 144, '/': 18563, 'loss': 7.543336391448975}\n",
"{'epoch': 0, 'update in batch': 145, '/': 18563, 'loss': 7.3255791664123535}\n",
"{'epoch': 0, 'update in batch': 146, '/': 18563, 'loss': 7.993613243103027}\n",
"{'epoch': 0, 'update in batch': 147, '/': 18563, 'loss': 8.8505859375}\n",
"{'epoch': 0, 'update in batch': 148, '/': 18563, 'loss': 8.146835327148438}\n",
"{'epoch': 0, 'update in batch': 149, '/': 18563, 'loss': 8.532424926757812}\n",
"{'epoch': 0, 'update in batch': 150, '/': 18563, 'loss': 8.323905944824219}\n",
"{'epoch': 0, 'update in batch': 151, '/': 18563, 'loss': 7.8726677894592285}\n",
"{'epoch': 0, 'update in batch': 152, '/': 18563, 'loss': 7.912005424499512}\n",
"{'epoch': 0, 'update in batch': 153, '/': 18563, 'loss': 8.010560035705566}\n",
"{'epoch': 0, 'update in batch': 154, '/': 18563, 'loss': 7.9417009353637695}\n",
"{'epoch': 0, 'update in batch': 155, '/': 18563, 'loss': 7.991711616516113}\n",
"{'epoch': 0, 'update in batch': 156, '/': 18563, 'loss': 8.27558708190918}\n",
"{'epoch': 0, 'update in batch': 157, '/': 18563, 'loss': 7.736246585845947}\n",
"{'epoch': 0, 'update in batch': 158, '/': 18563, 'loss': 7.4755754470825195}\n",
"{'epoch': 0, 'update in batch': 159, '/': 18563, 'loss': 8.023443222045898}\n",
"{'epoch': 0, 'update in batch': 160, '/': 18563, 'loss': 8.130350112915039}\n",
"{'epoch': 0, 'update in batch': 161, '/': 18563, 'loss': 7.770634651184082}\n",
"{'epoch': 0, 'update in batch': 162, '/': 18563, 'loss': 7.775434970855713}\n",
"{'epoch': 0, 'update in batch': 163, '/': 18563, 'loss': 7.965312957763672}\n",
"{'epoch': 0, 'update in batch': 164, '/': 18563, 'loss': 7.977341651916504}\n",
"{'epoch': 0, 'update in batch': 165, '/': 18563, 'loss': 7.703671455383301}\n",
"{'epoch': 0, 'update in batch': 166, '/': 18563, 'loss': 8.027135848999023}\n",
"{'epoch': 0, 'update in batch': 167, '/': 18563, 'loss': 7.7673773765563965}\n",
"{'epoch': 0, 'update in batch': 168, '/': 18563, 'loss': 8.654549598693848}\n",
"{'epoch': 0, 'update in batch': 169, '/': 18563, 'loss': 7.8060808181762695}\n",
"{'epoch': 0, 'update in batch': 170, '/': 18563, 'loss': 7.33704137802124}\n",
"{'epoch': 0, 'update in batch': 171, '/': 18563, 'loss': 7.971919059753418}\n",
"{'epoch': 0, 'update in batch': 172, '/': 18563, 'loss': 7.450611114501953}\n",
"{'epoch': 0, 'update in batch': 173, '/': 18563, 'loss': 7.978057861328125}\n",
"{'epoch': 0, 'update in batch': 174, '/': 18563, 'loss': 8.264434814453125}\n",
"{'epoch': 0, 'update in batch': 175, '/': 18563, 'loss': 8.47761058807373}\n",
"{'epoch': 0, 'update in batch': 176, '/': 18563, 'loss': 7.643885135650635}\n",
"{'epoch': 0, 'update in batch': 177, '/': 18563, 'loss': 8.696805000305176}\n",
"{'epoch': 0, 'update in batch': 178, '/': 18563, 'loss': 9.144462585449219}\n",
"{'epoch': 0, 'update in batch': 179, '/': 18563, 'loss': 8.582620620727539}\n",
"{'epoch': 0, 'update in batch': 180, '/': 18563, 'loss': 8.495562553405762}\n",
"{'epoch': 0, 'update in batch': 181, '/': 18563, 'loss': 9.259647369384766}\n",
"{'epoch': 0, 'update in batch': 182, '/': 18563, 'loss': 8.286632537841797}\n",
"{'epoch': 0, 'update in batch': 183, '/': 18563, 'loss': 8.378074645996094}\n",
"{'epoch': 0, 'update in batch': 184, '/': 18563, 'loss': 8.404892921447754}\n",
"{'epoch': 0, 'update in batch': 185, '/': 18563, 'loss': 9.206843376159668}\n",
"{'epoch': 0, 'update in batch': 186, '/': 18563, 'loss': 8.97215747833252}\n",
"{'epoch': 0, 'update in batch': 187, '/': 18563, 'loss': 8.281005859375}\n",
"{'epoch': 0, 'update in batch': 188, '/': 18563, 'loss': 7.638144493103027}\n",
"{'epoch': 0, 'update in batch': 189, '/': 18563, 'loss': 7.991082668304443}\n",
"{'epoch': 0, 'update in batch': 190, '/': 18563, 'loss': 8.207674026489258}\n",
"{'epoch': 0, 'update in batch': 191, '/': 18563, 'loss': 8.16801643371582}\n",
"{'epoch': 0, 'update in batch': 192, '/': 18563, 'loss': 7.827309608459473}\n",
"{'epoch': 0, 'update in batch': 193, '/': 18563, 'loss': 8.387285232543945}\n",
"{'epoch': 0, 'update in batch': 194, '/': 18563, 'loss': 7.990261077880859}\n",
"{'epoch': 0, 'update in batch': 195, '/': 18563, 'loss': 7.7953925132751465}\n",
"{'epoch': 0, 'update in batch': 196, '/': 18563, 'loss': 7.252983093261719}\n",
"{'epoch': 0, 'update in batch': 197, '/': 18563, 'loss': 7.806585788726807}\n",
"{'epoch': 0, 'update in batch': 198, '/': 18563, 'loss': 7.871600151062012}\n",
"{'epoch': 0, 'update in batch': 199, '/': 18563, 'loss': 7.639830589294434}\n",
"{'epoch': 0, 'update in batch': 200, '/': 18563, 'loss': 8.108308792114258}\n",
"{'epoch': 0, 'update in batch': 201, '/': 18563, 'loss': 7.41513729095459}\n",
"{'epoch': 0, 'update in batch': 202, '/': 18563, 'loss': 8.103743553161621}\n",
"{'epoch': 0, 'update in batch': 203, '/': 18563, 'loss': 8.82174301147461}\n",
"{'epoch': 0, 'update in batch': 204, '/': 18563, 'loss': 8.34859561920166}\n",
"{'epoch': 0, 'update in batch': 205, '/': 18563, 'loss': 7.890545845031738}\n",
"{'epoch': 0, 'update in batch': 206, '/': 18563, 'loss': 7.679532527923584}\n",
"{'epoch': 0, 'update in batch': 207, '/': 18563, 'loss': 7.810311317443848}\n",
"{'epoch': 0, 'update in batch': 208, '/': 18563, 'loss': 8.342585563659668}\n",
"{'epoch': 0, 'update in batch': 209, '/': 18563, 'loss': 8.253597259521484}\n",
"{'epoch': 0, 'update in batch': 210, '/': 18563, 'loss': 7.963072299957275}\n",
"{'epoch': 0, 'update in batch': 211, '/': 18563, 'loss': 8.537101745605469}\n",
"{'epoch': 0, 'update in batch': 212, '/': 18563, 'loss': 8.503724098205566}\n",
"{'epoch': 0, 'update in batch': 213, '/': 18563, 'loss': 8.568987846374512}\n",
"{'epoch': 0, 'update in batch': 214, '/': 18563, 'loss': 7.760678291320801}\n",
"{'epoch': 0, 'update in batch': 215, '/': 18563, 'loss': 8.302183151245117}\n",
"{'epoch': 0, 'update in batch': 216, '/': 18563, 'loss': 7.427420616149902}\n",
"{'epoch': 0, 'update in batch': 217, '/': 18563, 'loss': 8.05746078491211}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 218, '/': 18563, 'loss': 8.82285213470459}\n",
"{'epoch': 0, 'update in batch': 219, '/': 18563, 'loss': 7.948827266693115}\n",
"{'epoch': 0, 'update in batch': 220, '/': 18563, 'loss': 8.164112091064453}\n",
"{'epoch': 0, 'update in batch': 221, '/': 18563, 'loss': 7.721047401428223}\n",
"{'epoch': 0, 'update in batch': 222, '/': 18563, 'loss': 7.668707370758057}\n",
"{'epoch': 0, 'update in batch': 223, '/': 18563, 'loss': 8.576696395874023}\n",
"{'epoch': 0, 'update in batch': 224, '/': 18563, 'loss': 8.253091812133789}\n",
"{'epoch': 0, 'update in batch': 225, '/': 18563, 'loss': 8.303543090820312}\n",
"{'epoch': 0, 'update in batch': 226, '/': 18563, 'loss': 8.069855690002441}\n",
"{'epoch': 0, 'update in batch': 227, '/': 18563, 'loss': 8.57229232788086}\n",
"{'epoch': 0, 'update in batch': 228, '/': 18563, 'loss': 8.904585838317871}\n",
"{'epoch': 0, 'update in batch': 229, '/': 18563, 'loss': 8.485595703125}\n",
"{'epoch': 0, 'update in batch': 230, '/': 18563, 'loss': 8.22756290435791}\n",
"{'epoch': 0, 'update in batch': 231, '/': 18563, 'loss': 8.281603813171387}\n",
"{'epoch': 0, 'update in batch': 232, '/': 18563, 'loss': 7.591467380523682}\n",
"{'epoch': 0, 'update in batch': 233, '/': 18563, 'loss': 7.8028883934021}\n",
"{'epoch': 0, 'update in batch': 234, '/': 18563, 'loss': 8.079168319702148}\n",
"{'epoch': 0, 'update in batch': 235, '/': 18563, 'loss': 7.578390598297119}\n",
"{'epoch': 0, 'update in batch': 236, '/': 18563, 'loss': 7.865830421447754}\n",
"{'epoch': 0, 'update in batch': 237, '/': 18563, 'loss': 7.105422019958496}\n",
"{'epoch': 0, 'update in batch': 238, '/': 18563, 'loss': 8.034143447875977}\n",
"{'epoch': 0, 'update in batch': 239, '/': 18563, 'loss': 7.23009729385376}\n",
"{'epoch': 0, 'update in batch': 240, '/': 18563, 'loss': 7.221669673919678}\n",
"{'epoch': 0, 'update in batch': 241, '/': 18563, 'loss': 7.118913173675537}\n",
"{'epoch': 0, 'update in batch': 242, '/': 18563, 'loss': 7.690147399902344}\n",
"{'epoch': 0, 'update in batch': 243, '/': 18563, 'loss': 7.676979064941406}\n",
"{'epoch': 0, 'update in batch': 244, '/': 18563, 'loss': 8.231537818908691}\n",
"{'epoch': 0, 'update in batch': 245, '/': 18563, 'loss': 8.212566375732422}\n",
"{'epoch': 0, 'update in batch': 246, '/': 18563, 'loss': 9.095616340637207}\n",
"{'epoch': 0, 'update in batch': 247, '/': 18563, 'loss': 8.249703407287598}\n",
"{'epoch': 0, 'update in batch': 248, '/': 18563, 'loss': 9.082058906555176}\n",
"{'epoch': 0, 'update in batch': 249, '/': 18563, 'loss': 8.530516624450684}\n",
"{'epoch': 0, 'update in batch': 250, '/': 18563, 'loss': 8.979915618896484}\n",
"{'epoch': 0, 'update in batch': 251, '/': 18563, 'loss': 8.667882919311523}\n",
"{'epoch': 0, 'update in batch': 252, '/': 18563, 'loss': 8.804525375366211}\n",
"{'epoch': 0, 'update in batch': 253, '/': 18563, 'loss': 8.67729377746582}\n",
"{'epoch': 0, 'update in batch': 254, '/': 18563, 'loss': 8.580761909484863}\n",
"{'epoch': 0, 'update in batch': 255, '/': 18563, 'loss': 7.724173545837402}\n",
"{'epoch': 0, 'update in batch': 256, '/': 18563, 'loss': 7.7925591468811035}\n",
"{'epoch': 0, 'update in batch': 257, '/': 18563, 'loss': 7.731482028961182}\n",
"{'epoch': 0, 'update in batch': 258, '/': 18563, 'loss': 7.644040107727051}\n",
"{'epoch': 0, 'update in batch': 259, '/': 18563, 'loss': 7.947877407073975}\n",
"{'epoch': 0, 'update in batch': 260, '/': 18563, 'loss': 7.649043083190918}\n",
"{'epoch': 0, 'update in batch': 261, '/': 18563, 'loss': 7.40912389755249}\n",
"{'epoch': 0, 'update in batch': 262, '/': 18563, 'loss': 8.199918746948242}\n",
"{'epoch': 0, 'update in batch': 263, '/': 18563, 'loss': 7.272132873535156}\n",
"{'epoch': 0, 'update in batch': 264, '/': 18563, 'loss': 7.205214500427246}\n",
"{'epoch': 0, 'update in batch': 265, '/': 18563, 'loss': 8.999595642089844}\n",
"{'epoch': 0, 'update in batch': 266, '/': 18563, 'loss': 7.851510524749756}\n",
"{'epoch': 0, 'update in batch': 267, '/': 18563, 'loss': 7.748948097229004}\n",
"{'epoch': 0, 'update in batch': 268, '/': 18563, 'loss': 7.96875}\n",
"{'epoch': 0, 'update in batch': 269, '/': 18563, 'loss': 7.627255916595459}\n",
"{'epoch': 0, 'update in batch': 270, '/': 18563, 'loss': 7.719862937927246}\n",
"{'epoch': 0, 'update in batch': 271, '/': 18563, 'loss': 7.58780574798584}\n",
"{'epoch': 0, 'update in batch': 272, '/': 18563, 'loss': 8.386865615844727}\n",
"{'epoch': 0, 'update in batch': 273, '/': 18563, 'loss': 8.708396911621094}\n",
"{'epoch': 0, 'update in batch': 274, '/': 18563, 'loss': 7.853432655334473}\n",
"{'epoch': 0, 'update in batch': 275, '/': 18563, 'loss': 7.818131923675537}\n",
"{'epoch': 0, 'update in batch': 276, '/': 18563, 'loss': 7.714521884918213}\n",
"{'epoch': 0, 'update in batch': 277, '/': 18563, 'loss': 8.75371265411377}\n",
"{'epoch': 0, 'update in batch': 278, '/': 18563, 'loss': 7.6992998123168945}\n",
"{'epoch': 0, 'update in batch': 279, '/': 18563, 'loss': 7.652693748474121}\n",
"{'epoch': 0, 'update in batch': 280, '/': 18563, 'loss': 7.364585876464844}\n",
"{'epoch': 0, 'update in batch': 281, '/': 18563, 'loss': 7.742022514343262}\n",
"{'epoch': 0, 'update in batch': 282, '/': 18563, 'loss': 7.6205573081970215}\n",
"{'epoch': 0, 'update in batch': 283, '/': 18563, 'loss': 7.475846290588379}\n",
"{'epoch': 0, 'update in batch': 284, '/': 18563, 'loss': 7.302148342132568}\n",
"{'epoch': 0, 'update in batch': 285, '/': 18563, 'loss': 7.524351596832275}\n",
"{'epoch': 0, 'update in batch': 286, '/': 18563, 'loss': 7.755963325500488}\n",
"{'epoch': 0, 'update in batch': 287, '/': 18563, 'loss': 7.620995998382568}\n",
"{'epoch': 0, 'update in batch': 288, '/': 18563, 'loss': 7.289975166320801}\n",
"{'epoch': 0, 'update in batch': 289, '/': 18563, 'loss': 7.470652103424072}\n",
"{'epoch': 0, 'update in batch': 290, '/': 18563, 'loss': 7.297110557556152}\n",
"{'epoch': 0, 'update in batch': 291, '/': 18563, 'loss': 7.907563209533691}\n",
"{'epoch': 0, 'update in batch': 292, '/': 18563, 'loss': 8.051852226257324}\n",
"{'epoch': 0, 'update in batch': 293, '/': 18563, 'loss': 6.691899299621582}\n",
"{'epoch': 0, 'update in batch': 294, '/': 18563, 'loss': 7.9747819900512695}\n",
"{'epoch': 0, 'update in batch': 295, '/': 18563, 'loss': 7.415904998779297}\n",
"{'epoch': 0, 'update in batch': 296, '/': 18563, 'loss': 7.479670524597168}\n",
"{'epoch': 0, 'update in batch': 297, '/': 18563, 'loss': 7.9454755783081055}\n",
"{'epoch': 0, 'update in batch': 298, '/': 18563, 'loss': 7.79656457901001}\n",
"{'epoch': 0, 'update in batch': 299, '/': 18563, 'loss': 7.644859313964844}\n",
"{'epoch': 0, 'update in batch': 300, '/': 18563, 'loss': 7.649240970611572}\n",
"{'epoch': 0, 'update in batch': 301, '/': 18563, 'loss': 7.497203826904297}\n",
"{'epoch': 0, 'update in batch': 302, '/': 18563, 'loss': 7.169632911682129}\n",
"{'epoch': 0, 'update in batch': 303, '/': 18563, 'loss': 7.124764442443848}\n",
"{'epoch': 0, 'update in batch': 304, '/': 18563, 'loss': 7.728893280029297}\n",
"{'epoch': 0, 'update in batch': 305, '/': 18563, 'loss': 8.029245376586914}\n",
"{'epoch': 0, 'update in batch': 306, '/': 18563, 'loss': 7.361662864685059}\n",
"{'epoch': 0, 'update in batch': 307, '/': 18563, 'loss': 8.070173263549805}\n",
"{'epoch': 0, 'update in batch': 308, '/': 18563, 'loss': 7.55655574798584}\n",
"{'epoch': 0, 'update in batch': 309, '/': 18563, 'loss': 7.713553428649902}\n",
"{'epoch': 0, 'update in batch': 310, '/': 18563, 'loss': 8.333553314208984}\n",
"{'epoch': 0, 'update in batch': 311, '/': 18563, 'loss': 8.089872360229492}\n",
"{'epoch': 0, 'update in batch': 312, '/': 18563, 'loss': 8.951356887817383}\n",
"{'epoch': 0, 'update in batch': 313, '/': 18563, 'loss': 8.920665740966797}\n",
"{'epoch': 0, 'update in batch': 314, '/': 18563, 'loss': 8.811259269714355}\n",
"{'epoch': 0, 'update in batch': 315, '/': 18563, 'loss': 8.719802856445312}\n",
"{'epoch': 0, 'update in batch': 316, '/': 18563, 'loss': 8.700776100158691}\n",
"{'epoch': 0, 'update in batch': 317, '/': 18563, 'loss': 8.846036911010742}\n",
"{'epoch': 0, 'update in batch': 318, '/': 18563, 'loss': 8.553533554077148}\n",
"{'epoch': 0, 'update in batch': 319, '/': 18563, 'loss': 9.257116317749023}\n",
"{'epoch': 0, 'update in batch': 320, '/': 18563, 'loss': 8.487042427062988}\n",
"{'epoch': 0, 'update in batch': 321, '/': 18563, 'loss': 8.743330955505371}\n",
"{'epoch': 0, 'update in batch': 322, '/': 18563, 'loss': 8.377813339233398}\n",
"{'epoch': 0, 'update in batch': 323, '/': 18563, 'loss': 8.41798210144043}\n",
"{'epoch': 0, 'update in batch': 324, '/': 18563, 'loss': 7.884764671325684}\n",
"{'epoch': 0, 'update in batch': 325, '/': 18563, 'loss': 8.827409744262695}\n",
"{'epoch': 0, 'update in batch': 326, '/': 18563, 'loss': 8.21721363067627}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 327, '/': 18563, 'loss': 8.522723197937012}\n",
"{'epoch': 0, 'update in batch': 328, '/': 18563, 'loss': 7.387178897857666}\n",
"{'epoch': 0, 'update in batch': 329, '/': 18563, 'loss': 8.58663558959961}\n",
"{'epoch': 0, 'update in batch': 330, '/': 18563, 'loss': 8.539435386657715}\n",
"{'epoch': 0, 'update in batch': 331, '/': 18563, 'loss': 8.35865592956543}\n",
"{'epoch': 0, 'update in batch': 332, '/': 18563, 'loss': 8.55555248260498}\n",
"{'epoch': 0, 'update in batch': 333, '/': 18563, 'loss': 7.9116950035095215}\n",
"{'epoch': 0, 'update in batch': 334, '/': 18563, 'loss': 8.424735069274902}\n",
"{'epoch': 0, 'update in batch': 335, '/': 18563, 'loss': 8.383890151977539}\n",
"{'epoch': 0, 'update in batch': 336, '/': 18563, 'loss': 8.145454406738281}\n",
"{'epoch': 0, 'update in batch': 337, '/': 18563, 'loss': 8.014772415161133}\n",
"{'epoch': 0, 'update in batch': 338, '/': 18563, 'loss': 8.532005310058594}\n",
"{'epoch': 0, 'update in batch': 339, '/': 18563, 'loss': 8.979973793029785}\n",
"{'epoch': 0, 'update in batch': 340, '/': 18563, 'loss': 8.3964204788208}\n",
"{'epoch': 0, 'update in batch': 341, '/': 18563, 'loss': 8.34205150604248}\n",
"{'epoch': 0, 'update in batch': 342, '/': 18563, 'loss': 7.861489295959473}\n",
"{'epoch': 0, 'update in batch': 343, '/': 18563, 'loss': 8.807058334350586}\n",
"{'epoch': 0, 'update in batch': 344, '/': 18563, 'loss': 8.14976978302002}\n",
"{'epoch': 0, 'update in batch': 345, '/': 18563, 'loss': 8.212860107421875}\n",
"{'epoch': 0, 'update in batch': 346, '/': 18563, 'loss': 8.323419570922852}\n",
"{'epoch': 0, 'update in batch': 347, '/': 18563, 'loss': 9.06071662902832}\n",
"{'epoch': 0, 'update in batch': 348, '/': 18563, 'loss': 8.79192066192627}\n",
"{'epoch': 0, 'update in batch': 349, '/': 18563, 'loss': 8.717201232910156}\n",
"{'epoch': 0, 'update in batch': 350, '/': 18563, 'loss': 8.149703979492188}\n",
"{'epoch': 0, 'update in batch': 351, '/': 18563, 'loss': 7.990046501159668}\n",
"{'epoch': 0, 'update in batch': 352, '/': 18563, 'loss': 7.8197221755981445}\n",
"{'epoch': 0, 'update in batch': 353, '/': 18563, 'loss': 8.022729873657227}\n",
"{'epoch': 0, 'update in batch': 354, '/': 18563, 'loss': 8.339923858642578}\n",
"{'epoch': 0, 'update in batch': 355, '/': 18563, 'loss': 7.867880821228027}\n",
"{'epoch': 0, 'update in batch': 356, '/': 18563, 'loss': 8.161782264709473}\n",
"{'epoch': 0, 'update in batch': 357, '/': 18563, 'loss': 7.711170196533203}\n",
"{'epoch': 0, 'update in batch': 358, '/': 18563, 'loss': 8.46279239654541}\n",
"{'epoch': 0, 'update in batch': 359, '/': 18563, 'loss': 8.327804565429688}\n",
"{'epoch': 0, 'update in batch': 360, '/': 18563, 'loss': 8.184597969055176}\n",
"{'epoch': 0, 'update in batch': 361, '/': 18563, 'loss': 8.126212120056152}\n",
"{'epoch': 0, 'update in batch': 362, '/': 18563, 'loss': 8.122446060180664}\n",
"{'epoch': 0, 'update in batch': 363, '/': 18563, 'loss': 7.730257511138916}\n",
"{'epoch': 0, 'update in batch': 364, '/': 18563, 'loss': 7.7179059982299805}\n",
"{'epoch': 0, 'update in batch': 365, '/': 18563, 'loss': 7.557857513427734}\n",
"{'epoch': 0, 'update in batch': 366, '/': 18563, 'loss': 8.614083290100098}\n",
"{'epoch': 0, 'update in batch': 367, '/': 18563, 'loss': 8.0489501953125}\n",
"{'epoch': 0, 'update in batch': 368, '/': 18563, 'loss': 8.355381965637207}\n",
"{'epoch': 0, 'update in batch': 369, '/': 18563, 'loss': 7.592991828918457}\n",
"{'epoch': 0, 'update in batch': 370, '/': 18563, 'loss': 7.674102783203125}\n",
"{'epoch': 0, 'update in batch': 371, '/': 18563, 'loss': 7.818256378173828}\n",
"{'epoch': 0, 'update in batch': 372, '/': 18563, 'loss': 8.510438919067383}\n",
"{'epoch': 0, 'update in batch': 373, '/': 18563, 'loss': 8.02087116241455}\n",
"{'epoch': 0, 'update in batch': 374, '/': 18563, 'loss': 8.206090927124023}\n",
"{'epoch': 0, 'update in batch': 375, '/': 18563, 'loss': 7.645677089691162}\n",
"{'epoch': 0, 'update in batch': 376, '/': 18563, 'loss': 8.241236686706543}\n",
"{'epoch': 0, 'update in batch': 377, '/': 18563, 'loss': 8.581649780273438}\n",
"{'epoch': 0, 'update in batch': 378, '/': 18563, 'loss': 9.361258506774902}\n",
"{'epoch': 0, 'update in batch': 379, '/': 18563, 'loss': 9.097440719604492}\n",
"{'epoch': 0, 'update in batch': 380, '/': 18563, 'loss': 8.081677436828613}\n",
"{'epoch': 0, 'update in batch': 381, '/': 18563, 'loss': 8.761143684387207}\n",
"{'epoch': 0, 'update in batch': 382, '/': 18563, 'loss': 7.9429121017456055}\n",
"{'epoch': 0, 'update in batch': 383, '/': 18563, 'loss': 8.05648422241211}\n",
"{'epoch': 0, 'update in batch': 384, '/': 18563, 'loss': 7.316658020019531}\n",
"{'epoch': 0, 'update in batch': 385, '/': 18563, 'loss': 8.597393035888672}\n",
"{'epoch': 0, 'update in batch': 386, '/': 18563, 'loss': 9.393728256225586}\n",
"{'epoch': 0, 'update in batch': 387, '/': 18563, 'loss': 8.225081443786621}\n",
"{'epoch': 0, 'update in batch': 388, '/': 18563, 'loss': 7.9958319664001465}\n",
"{'epoch': 0, 'update in batch': 389, '/': 18563, 'loss': 8.390036582946777}\n",
"{'epoch': 0, 'update in batch': 390, '/': 18563, 'loss': 7.745572566986084}\n",
"{'epoch': 0, 'update in batch': 391, '/': 18563, 'loss': 8.403060913085938}\n",
"{'epoch': 0, 'update in batch': 392, '/': 18563, 'loss': 8.703788757324219}\n",
"{'epoch': 0, 'update in batch': 393, '/': 18563, 'loss': 8.516857147216797}\n",
"{'epoch': 0, 'update in batch': 394, '/': 18563, 'loss': 8.078744888305664}\n",
"{'epoch': 0, 'update in batch': 395, '/': 18563, 'loss': 7.6597900390625}\n",
"{'epoch': 0, 'update in batch': 396, '/': 18563, 'loss': 8.454282760620117}\n",
"{'epoch': 0, 'update in batch': 397, '/': 18563, 'loss': 7.7727837562561035}\n",
"{'epoch': 0, 'update in batch': 398, '/': 18563, 'loss': 8.222984313964844}\n",
"{'epoch': 0, 'update in batch': 399, '/': 18563, 'loss': 8.369619369506836}\n",
"{'epoch': 0, 'update in batch': 400, '/': 18563, 'loss': 8.542525291442871}\n",
"{'epoch': 0, 'update in batch': 401, '/': 18563, 'loss': 7.9681854248046875}\n",
"{'epoch': 0, 'update in batch': 402, '/': 18563, 'loss': 8.842118263244629}\n",
"{'epoch': 0, 'update in batch': 403, '/': 18563, 'loss': 7.958454132080078}\n",
"{'epoch': 0, 'update in batch': 404, '/': 18563, 'loss': 7.084095001220703}\n",
"{'epoch': 0, 'update in batch': 405, '/': 18563, 'loss': 7.8765130043029785}\n",
"{'epoch': 0, 'update in batch': 406, '/': 18563, 'loss': 7.639691352844238}\n",
"{'epoch': 0, 'update in batch': 407, '/': 18563, 'loss': 7.440125942230225}\n",
"{'epoch': 0, 'update in batch': 408, '/': 18563, 'loss': 7.928472995758057}\n",
"{'epoch': 0, 'update in batch': 409, '/': 18563, 'loss': 8.704710960388184}\n",
"{'epoch': 0, 'update in batch': 410, '/': 18563, 'loss': 8.214713096618652}\n",
"{'epoch': 0, 'update in batch': 411, '/': 18563, 'loss': 8.115629196166992}\n",
"{'epoch': 0, 'update in batch': 412, '/': 18563, 'loss': 9.357975006103516}\n",
"{'epoch': 0, 'update in batch': 413, '/': 18563, 'loss': 7.756926536560059}\n",
"{'epoch': 0, 'update in batch': 414, '/': 18563, 'loss': 8.93007755279541}\n",
"{'epoch': 0, 'update in batch': 415, '/': 18563, 'loss': 8.929518699645996}\n",
"{'epoch': 0, 'update in batch': 416, '/': 18563, 'loss': 7.646470069885254}\n",
"{'epoch': 0, 'update in batch': 417, '/': 18563, 'loss': 8.457891464233398}\n",
"{'epoch': 0, 'update in batch': 418, '/': 18563, 'loss': 7.377375602722168}\n",
"{'epoch': 0, 'update in batch': 419, '/': 18563, 'loss': 8.03713607788086}\n",
"{'epoch': 0, 'update in batch': 420, '/': 18563, 'loss': 8.125130653381348}\n",
"{'epoch': 0, 'update in batch': 421, '/': 18563, 'loss': 6.818246364593506}\n",
"{'epoch': 0, 'update in batch': 422, '/': 18563, 'loss': 7.220259189605713}\n",
"{'epoch': 0, 'update in batch': 423, '/': 18563, 'loss': 7.800910949707031}\n",
"{'epoch': 0, 'update in batch': 424, '/': 18563, 'loss': 8.175793647766113}\n",
"{'epoch': 0, 'update in batch': 425, '/': 18563, 'loss': 7.588067054748535}\n",
"{'epoch': 0, 'update in batch': 426, '/': 18563, 'loss': 7.2054619789123535}\n",
"{'epoch': 0, 'update in batch': 427, '/': 18563, 'loss': 7.6552839279174805}\n",
"{'epoch': 0, 'update in batch': 428, '/': 18563, 'loss': 8.851090431213379}\n",
"{'epoch': 0, 'update in batch': 429, '/': 18563, 'loss': 8.768563270568848}\n",
"{'epoch': 0, 'update in batch': 430, '/': 18563, 'loss': 7.926184177398682}\n",
"{'epoch': 0, 'update in batch': 431, '/': 18563, 'loss': 8.663213729858398}\n",
"{'epoch': 0, 'update in batch': 432, '/': 18563, 'loss': 8.386338233947754}\n",
"{'epoch': 0, 'update in batch': 433, '/': 18563, 'loss': 8.77399730682373}\n",
"{'epoch': 0, 'update in batch': 434, '/': 18563, 'loss': 8.385528564453125}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'epoch': 0, 'update in batch': 435, '/': 18563, 'loss': 7.742388725280762}\n",
"{'epoch': 0, 'update in batch': 436, '/': 18563, 'loss': 8.363179206848145}\n",
"{'epoch': 0, 'update in batch': 437, '/': 18563, 'loss': 9.262784004211426}\n",
"{'epoch': 0, 'update in batch': 438, '/': 18563, 'loss': 9.236469268798828}\n",
"{'epoch': 0, 'update in batch': 439, '/': 18563, 'loss': 8.904603958129883}\n",
"{'epoch': 0, 'update in batch': 440, '/': 18563, 'loss': 8.675701141357422}\n",
"{'epoch': 0, 'update in batch': 441, '/': 18563, 'loss': 8.811418533325195}\n",
"{'epoch': 0, 'update in batch': 442, '/': 18563, 'loss': 8.002241134643555}\n",
"{'epoch': 0, 'update in batch': 443, '/': 18563, 'loss': 9.04414176940918}\n",
"{'epoch': 0, 'update in batch': 444, '/': 18563, 'loss': 7.8904008865356445}\n",
"{'epoch': 0, 'update in batch': 445, '/': 18563, 'loss': 8.524297714233398}\n",
"{'epoch': 0, 'update in batch': 446, '/': 18563, 'loss': 8.615904808044434}\n",
"{'epoch': 0, 'update in batch': 447, '/': 18563, 'loss': 8.201675415039062}\n",
"{'epoch': 0, 'update in batch': 448, '/': 18563, 'loss': 8.531024932861328}\n",
"{'epoch': 0, 'update in batch': 449, '/': 18563, 'loss': 7.8379621505737305}\n",
"{'epoch': 0, 'update in batch': 450, '/': 18563, 'loss': 8.416367530822754}\n",
"{'epoch': 0, 'update in batch': 451, '/': 18563, 'loss': 7.4990715980529785}\n",
"{'epoch': 0, 'update in batch': 452, '/': 18563, 'loss': 7.984610557556152}\n",
"{'epoch': 0, 'update in batch': 453, '/': 18563, 'loss': 7.719987392425537}\n",
"{'epoch': 0, 'update in batch': 454, '/': 18563, 'loss': 7.9333176612854}\n",
"{'epoch': 0, 'update in batch': 455, '/': 18563, 'loss': 8.619344711303711}\n",
"{'epoch': 0, 'update in batch': 456, '/': 18563, 'loss': 7.849525451660156}\n",
"{'epoch': 0, 'update in batch': 457, '/': 18563, 'loss': 7.700997352600098}\n",
"{'epoch': 0, 'update in batch': 458, '/': 18563, 'loss': 8.065767288208008}\n",
"{'epoch': 0, 'update in batch': 459, '/': 18563, 'loss': 7.489628791809082}\n",
"{'epoch': 0, 'update in batch': 460, '/': 18563, 'loss': 8.036481857299805}\n",
"{'epoch': 0, 'update in batch': 461, '/': 18563, 'loss': 8.227537155151367}\n",
"{'epoch': 0, 'update in batch': 462, '/': 18563, 'loss': 7.66103982925415}\n",
"{'epoch': 0, 'update in batch': 463, '/': 18563, 'loss': 8.481343269348145}\n",
"{'epoch': 0, 'update in batch': 464, '/': 18563, 'loss': 8.711318969726562}\n",
"{'epoch': 0, 'update in batch': 465, '/': 18563, 'loss': 7.549925804138184}\n",
"{'epoch': 0, 'update in batch': 466, '/': 18563, 'loss': 8.020782470703125}\n",
"{'epoch': 0, 'update in batch': 467, '/': 18563, 'loss': 7.784451484680176}\n",
"{'epoch': 0, 'update in batch': 468, '/': 18563, 'loss': 7.7545928955078125}\n",
"{'epoch': 0, 'update in batch': 469, '/': 18563, 'loss': 8.484171867370605}\n",
"{'epoch': 0, 'update in batch': 470, '/': 18563, 'loss': 8.291640281677246}\n",
"{'epoch': 0, 'update in batch': 471, '/': 18563, 'loss': 7.873322486877441}\n",
"{'epoch': 0, 'update in batch': 472, '/': 18563, 'loss': 7.891420841217041}\n",
"{'epoch': 0, 'update in batch': 473, '/': 18563, 'loss': 8.376962661743164}\n",
"{'epoch': 0, 'update in batch': 474, '/': 18563, 'loss': 8.147513389587402}\n",
"{'epoch': 0, 'update in batch': 475, '/': 18563, 'loss': 7.739943027496338}\n",
"{'epoch': 0, 'update in batch': 476, '/': 18563, 'loss': 7.52395486831665}\n",
"{'epoch': 0, 'update in batch': 477, '/': 18563, 'loss': 7.962507724761963}\n",
"{'epoch': 0, 'update in batch': 478, '/': 18563, 'loss': 7.61989688873291}\n",
"{'epoch': 0, 'update in batch': 479, '/': 18563, 'loss': 8.628551483154297}\n",
"{'epoch': 0, 'update in batch': 480, '/': 18563, 'loss': 10.344924926757812}\n",
"{'epoch': 0, 'update in batch': 481, '/': 18563, 'loss': 9.189457893371582}\n",
"{'epoch': 0, 'update in batch': 482, '/': 18563, 'loss': 9.283202171325684}\n",
"{'epoch': 0, 'update in batch': 483, '/': 18563, 'loss': 8.036226272583008}\n",
"{'epoch': 0, 'update in batch': 484, '/': 18563, 'loss': 8.949888229370117}\n",
"{'epoch': 0, 'update in batch': 485, '/': 18563, 'loss': 9.32779598236084}\n",
"{'epoch': 0, 'update in batch': 486, '/': 18563, 'loss': 9.554967880249023}\n",
"{'epoch': 0, 'update in batch': 487, '/': 18563, 'loss': 8.438692092895508}\n",
"{'epoch': 0, 'update in batch': 488, '/': 18563, 'loss': 8.015823364257812}\n",
"{'epoch': 0, 'update in batch': 489, '/': 18563, 'loss': 8.621005058288574}\n",
"{'epoch': 0, 'update in batch': 490, '/': 18563, 'loss': 8.432602882385254}\n",
"{'epoch': 0, 'update in batch': 491, '/': 18563, 'loss': 8.659430503845215}\n",
"{'epoch': 0, 'update in batch': 492, '/': 18563, 'loss': 8.693103790283203}\n",
"{'epoch': 0, 'update in batch': 493, '/': 18563, 'loss': 8.895064353942871}\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-18-fe996a0be74b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniq_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-17-8d700bc624e3>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(dataset, model, max_epochs, batch_size)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m inputs=inputs)\n\u001b[0;32m--> 363\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 364\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 172\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 173\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 174\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 175\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"model = Model(vocab_size = len(dataset.uniq_words)).to(device)\n",
"train(dataset, model, 1, 64)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def predict(dataset, model, text, next_words=5):\n",
" model.eval()\n",
" words = text.split(' ')\n",
" state_h, state_c = model.init_state(len(words))\n",
"\n",
" for i in range(0, next_words):\n",
" x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)\n",
" y_pred, (state_h, state_c) = model(x, (state_h, state_c))\n",
"\n",
" last_word_logits = y_pred[0][-1]\n",
" p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=p)\n",
" words.append(dataset.index_to_word[word_index])\n",
"\n",
" return words"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['kmicic', 'szedł', 'zwycięzco', 'po', 'do', 'zlituj', 'i']"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict(dataset, model, 'kmicic szedł')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ZADANIE 1\n",
"\n",
"Stworzyć sieć rekurencyjną GRU dla Challenging America word-gap prediction. Wymogi takie jak zawsze, zadanie widoczne na gonito"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ZADANIE 2\n",
"\n",
"Podjąć wyzwanie na https://gonito.net/challenge/precipitation-pl i/lub https://gonito.net/challenge/book-dialogues-pl\n",
"\n",
"\n",
"**KONIECZNIE** należy je zgłosić do końca następnego piątku, czyli 20 maja!. Za późniejsze zgłoszenia (nawet minutę) nieprzyznaję punktów.\n",
" \n",
"Za każde zgłoszenie lepsze niż baseline przyznaję 40 punktów.\n",
"\n",
"Zamiast tych 40 punktów za najlepsze miejsca:\n",
"- 1. miejsce 150 punktów\n",
"- 2. miejsce 100 punktów\n",
"- 3. miejsce 70 punktów\n",
"\n",
"Można brać udział w 2 wyzwaniach jednocześnie.\n",
"\n",
"Zadania nie będą widoczne w gonito w achievements. Nie trzeba udostępniać kodu, należy jednak przestrzegać regulaminu wyzwań."
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}