aitech-moj/cw/11_Model_rekurencyjny_z_atencją.ipynb

817 lines
27 KiB
Plaintext
Raw Normal View History

2022-05-29 18:14:19 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Modelowanie Języka</h1>\n",
"<h2> 10. <i>Model rekurencyjny z atencją</i> [ćwiczenia]</h2> \n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
2022-05-29 20:00:36 +02:00
"cell_type": "markdown",
2022-05-29 18:14:19 +02:00
"metadata": {},
"source": [
2022-05-29 20:00:36 +02:00
"notebook na podstawie:\n",
"\n",
2022-05-29 18:14:19 +02:00
"https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html"
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 1,
2022-05-29 18:14:19 +02:00
"metadata": {},
"outputs": [],
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"from io import open\n",
"import unicodedata\n",
"import string\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 2,
2022-05-29 18:14:19 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Special token indices shared by both vocabularies.\n",
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"# Vocabulary helper: maps words to integer indices and back, and\n",
"# keeps a per-word occurrence count.\n",
"class Lang:\n",
" def __init__(self):\n",
" self.word2index = {}\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" # Register every space-separated word of `sentence`.\n",
" def addSentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.addWord(word)\n",
"\n",
" # Add `word` if unseen (assigning it the next free index),\n",
" # otherwise just bump its occurrence count.\n",
" def addWord(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" else:\n",
" self.word2count[word] += 1"
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 3,
2022-05-29 18:14:19 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Read the tab-separated English-Polish corpus and normalise both\n",
"# sides: lowercase, pad sentence-final punctuation with a space, and\n",
"# collapse every run of characters outside the kept alphabet into a\n",
"# single space (Polish diacritics are preserved on the Polish side).\n",
"pairs = []\n",
2022-05-29 19:05:03 +02:00
"with open('data/eng-pol.txt') as f:\n",
2022-05-29 18:14:19 +02:00
" for line in f:\n",
2022-05-29 19:05:03 +02:00
" eng_line, pol_line = line.lower().rstrip().split('\\t')\n",
2022-05-29 18:14:19 +02:00
"\n",
" # Separate ., ! and ? from the preceding word.\n",
" eng_line = re.sub(r\"([.!?])\", r\" \\1\", eng_line)\n",
" eng_line = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", eng_line)\n",
"\n",
2022-05-29 19:05:03 +02:00
" pol_line = re.sub(r\"([.!?])\", r\" \\1\", pol_line)\n",
2022-05-29 20:00:36 +02:00
" pol_line = re.sub(r\"[^a-zA-Z.!?ąćęłńóśźżĄĆĘŁŃÓŚŹŻ]+\", r\" \", pol_line)\n",
2022-05-29 18:14:19 +02:00
"\n",
2022-05-29 19:05:03 +02:00
" pairs.append([eng_line, pol_line])\n",
2022-05-29 18:14:19 +02:00
"\n",
"\n"
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 4,
2022-05-29 18:14:19 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-05-29 20:00:36 +02:00
"['hi .', 'cześć .']"
2022-05-29 18:14:19 +02:00
]
},
2022-05-29 21:24:53 +02:00
"execution_count": 4,
2022-05-29 18:14:19 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pairs[1]"
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 5,
2022-05-29 18:14:19 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Maximum sentence length (in tokens) kept for training; also the\n",
"# fixed attention width used by AttnDecoderRNN below.\n",
"MAX_LENGTH = 10\n",
"# Keep only pairs whose English side starts with one of these\n",
"# simple-present prefixes, to shrink the dataset.\n",
"eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s \",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
")\n",
"\n",
"pairs = [p for p in pairs if len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH]\n",
"pairs = [p for p in pairs if p[0].startswith(eng_prefixes)]\n",
"\n",
"# Build one vocabulary per language from the filtered pairs.\n",
"eng_lang = Lang()\n",
2022-05-29 19:05:03 +02:00
"pol_lang = Lang()\n",
2022-05-29 18:14:19 +02:00
"\n",
"for pair in pairs:\n",
" eng_lang.addSentence(pair[0])\n",
2022-05-29 19:05:03 +02:00
" pol_lang.addSentence(pair[1])"
2022-05-29 18:14:19 +02:00
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 6,
2022-05-29 18:14:19 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-05-29 20:00:36 +02:00
"['i m ok .', 'ze mną wszystko w porządku .']"
2022-05-29 18:14:19 +02:00
]
},
2022-05-29 21:24:53 +02:00
"execution_count": 6,
2022-05-29 18:14:19 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pairs[0]"
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 7,
2022-05-29 18:14:19 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-05-29 20:00:36 +02:00
"['i m up .', 'wstałem .']"
2022-05-29 18:14:19 +02:00
]
},
2022-05-29 21:24:53 +02:00
"execution_count": 7,
2022-05-29 18:14:19 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pairs[1]"
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 8,
2022-05-29 18:14:19 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-05-29 19:05:03 +02:00
"['i m tom .', 'jestem tom .']"
2022-05-29 18:14:19 +02:00
]
},
2022-05-29 21:24:53 +02:00
"execution_count": 8,
2022-05-29 18:14:19 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pairs[2]"
]
},
2022-05-29 21:24:53 +02:00
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1828"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eng_lang.n_words"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2883"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pol_lang.n_words"
]
},
2022-05-29 18:14:19 +02:00
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Single-layer GRU encoder: embeds one source token at a time and\n",
"# updates the recurrent hidden state.\n",
"class EncoderRNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size):\n",
" super(EncoderRNN, self).__init__()\n",
" self.hidden_size = hidden_size\n",
"\n",
" self.embedding = nn.Embedding(input_size, hidden_size)\n",
" self.gru = nn.GRU(hidden_size, hidden_size)\n",
"\n",
" # One encoding step: `input` is a single token index tensor,\n",
" # `hidden` the previous GRU state. Returns (output, new hidden).\n",
" def forward(self, input, hidden):\n",
" embedded = self.embedding(input).view(1, 1, -1)\n",
" output = embedded\n",
" output, hidden = self.gru(output, hidden)\n",
" return output, hidden\n",
"\n",
" # Fresh all-zero hidden state of shape (1, 1, hidden_size).\n",
" def initHidden(self):\n",
" return torch.zeros(1, 1, self.hidden_size, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Plain (no-attention) GRU decoder, kept for reference; the notebook\n",
"# actually trains AttnDecoderRNN below.\n",
"class DecoderRNN(nn.Module):\n",
" def __init__(self, hidden_size, output_size):\n",
" super(DecoderRNN, self).__init__()\n",
" self.hidden_size = hidden_size\n",
"\n",
" self.embedding = nn.Embedding(output_size, hidden_size)\n",
" self.gru = nn.GRU(hidden_size, hidden_size)\n",
" self.out = nn.Linear(hidden_size, output_size)\n",
" self.softmax = nn.LogSoftmax(dim=1)\n",
"\n",
" # One decoding step: embed the previous output token, run the\n",
" # GRU, and return log-probabilities over the output vocabulary.\n",
" def forward(self, input, hidden):\n",
" output = self.embedding(input).view(1, 1, -1)\n",
" output = F.relu(output)\n",
" output, hidden = self.gru(output, hidden)\n",
" output = self.softmax(self.out(output[0]))\n",
" return output, hidden\n",
"\n",
" # Fresh all-zero hidden state of shape (1, 1, hidden_size).\n",
" def initHidden(self):\n",
" return torch.zeros(1, 1, self.hidden_size, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Attention decoder: at each step an attention distribution over the\n",
"# max_length encoder-output slots is computed from the embedded\n",
"# previous token and the current hidden state, and the resulting\n",
"# context vector is combined with the embedding before the GRU.\n",
"class AttnDecoderRNN(nn.Module):\n",
" def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\n",
" super(AttnDecoderRNN, self).__init__()\n",
" self.hidden_size = hidden_size\n",
" self.output_size = output_size\n",
" self.dropout_p = dropout_p\n",
" self.max_length = max_length\n",
"\n",
" self.embedding = nn.Embedding(self.output_size, self.hidden_size)\n",
" # Scores one attention weight per encoder slot from the\n",
" # concatenation [embedding ; hidden].\n",
" self.attn = nn.Linear(self.hidden_size * 2, self.max_length)\n",
" self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)\n",
" self.dropout = nn.Dropout(self.dropout_p)\n",
" self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n",
" self.out = nn.Linear(self.hidden_size, self.output_size)\n",
"\n",
" # One decoding step. Returns (log-probs over the vocabulary,\n",
" # new hidden state, attention weights over encoder slots).\n",
" def forward(self, input, hidden, encoder_outputs):\n",
" embedded = self.embedding(input).view(1, 1, -1)\n",
" embedded = self.dropout(embedded)\n",
"\n",
" attn_weights = F.softmax(\n",
" self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)\n",
" # Weighted sum of encoder outputs: the context vector.\n",
" attn_applied = torch.bmm(attn_weights.unsqueeze(0),\n",
" encoder_outputs.unsqueeze(0))\n",
"\n",
" output = torch.cat((embedded[0], attn_applied[0]), 1)\n",
" output = self.attn_combine(output).unsqueeze(0)\n",
"\n",
" output = F.relu(output)\n",
" output, hidden = self.gru(output, hidden)\n",
"\n",
" output = F.log_softmax(self.out(output[0]), dim=1)\n",
" return output, hidden, attn_weights\n",
"\n",
" # Fresh all-zero hidden state of shape (1, 1, hidden_size).\n",
" def initHidden(self):\n",
" return torch.zeros(1, 1, self.hidden_size, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Convert a sentence into a column tensor of word indices terminated\n",
"# by EOS. NOTE(review): words absent from `lang` raise KeyError —\n",
"# callers must pass sentences built from the training vocabulary.\n",
"def tensorFromSentence(sentence, lang):\n",
" indexes = [lang.word2index[word] for word in sentence.split(' ')]\n",
" indexes.append(EOS_token)\n",
" return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Probability of feeding the gold target token (instead of the\n",
"# model's own prediction) as the next decoder input while training.\n",
"teacher_forcing_ratio = 0.5\n",
"\n",
2022-05-29 21:24:53 +02:00
"# Train on a single (input, target) sentence pair: encode the source,\n",
"# decode the target accumulating the loss, backpropagate and step the\n",
"# shared optimizer. Returns the average loss per target token.\n",
"def train_one_batch(input_tensor, target_tensor, encoder, decoder, optimizer, criterion, max_length=MAX_LENGTH):\n",
2022-05-29 18:14:19 +02:00
" encoder_hidden = encoder.initHidden()\n",
"\n",
2022-05-29 21:24:53 +02:00
"\n",
" optimizer.zero_grad()\n",
2022-05-29 18:14:19 +02:00
"\n",
" input_length = input_tensor.size(0)\n",
" target_length = target_tensor.size(0)\n",
"\n",
" # Encoder outputs are collected here for the attention decoder;\n",
" # rows past input_length stay zero.\n",
" encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n",
"\n",
" loss = 0\n",
"\n",
" # Feed the source sentence through the encoder token by token.\n",
" for ei in range(input_length):\n",
2022-05-29 19:05:03 +02:00
" encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n",
2022-05-29 18:14:19 +02:00
" encoder_outputs[ei] = encoder_output[0, 0]\n",
"\n",
" decoder_input = torch.tensor([[SOS_token]], device=device)\n",
"\n",
" # The decoder starts from the encoder's final hidden state.\n",
" decoder_hidden = encoder_hidden\n",
"\n",
" use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False\n",
"\n",
" if use_teacher_forcing:\n",
" # Teacher forcing: condition each step on the gold token.\n",
" for di in range(target_length):\n",
2022-05-29 19:05:03 +02:00
" decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)\n",
2022-05-29 18:14:19 +02:00
" loss += criterion(decoder_output, target_tensor[di])\n",
" decoder_input = target_tensor[di] # Teacher forcing\n",
"\n",
" else:\n",
" # Free running: condition on the model's own greedy choice,\n",
" # stopping early if it emits EOS.\n",
" for di in range(target_length):\n",
2022-05-29 19:05:03 +02:00
" decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)\n",
2022-05-29 18:14:19 +02:00
" topv, topi = decoder_output.topk(1)\n",
" decoder_input = topi.squeeze().detach() # detach from history as input\n",
"\n",
" loss += criterion(decoder_output, target_tensor[di])\n",
" if decoder_input.item() == EOS_token:\n",
" break\n",
"\n",
" loss.backward()\n",
"\n",
2022-05-29 21:24:53 +02:00
" optimizer.step()\n",
2022-05-29 18:14:19 +02:00
"\n",
" return loss.item() / target_length"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Training driver: samples n_iters random sentence pairs, trains with\n",
"# SGD over the joint encoder+decoder parameters, and prints the\n",
"# average per-token loss every print_every iterations.\n",
"def trainIters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):\n",
" print_loss_total = 0 # Reset every print_every\n",
2022-05-29 21:24:53 +02:00
" encoder.train()\n",
" decoder.train()\n",
2022-05-29 18:14:19 +02:00
"\n",
2022-05-29 21:24:53 +02:00
" # One optimizer over both models so a single step updates both.\n",
" optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)\n",
2022-05-29 18:14:19 +02:00
" \n",
" # Pre-sample the training pairs and tensorise them up front.\n",
" training_pairs = [random.choice(pairs) for _ in range(n_iters)]\n",
2022-05-29 19:05:03 +02:00
" training_pairs = [(tensorFromSentence(p[0], eng_lang), tensorFromSentence(p[1], pol_lang)) for p in training_pairs]\n",
2022-05-29 18:14:19 +02:00
" \n",
" criterion = nn.NLLLoss()\n",
"\n",
" for i in range(1, n_iters + 1):\n",
" training_pair = training_pairs[i - 1]\n",
" input_tensor = training_pair[0]\n",
" target_tensor = training_pair[1]\n",
"\n",
" loss = train_one_batch(input_tensor,\n",
" target_tensor,\n",
" encoder,\n",
2022-05-29 19:05:03 +02:00
" decoder,\n",
2022-05-29 21:24:53 +02:00
" optimizer,\n",
"\n",
2022-05-29 18:14:19 +02:00
" criterion)\n",
" \n",
" print_loss_total += loss\n",
"\n",
" if i % print_every == 0:\n",
" print_loss_avg = print_loss_total / print_every\n",
" print_loss_total = 0\n",
" print(f'iter: {i}, loss: {print_loss_avg}')\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Greedily translate a single English sentence (no gradients).\n",
"# Returns the decoded Polish words (ending with '<EOS>' if produced)\n",
"# and the attention matrix for the steps actually taken.\n",
"def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):\n",
2022-05-29 21:24:53 +02:00
" encoder.eval()\n",
" decoder.eval()\n",
2022-05-29 18:14:19 +02:00
" with torch.no_grad():\n",
" input_tensor = tensorFromSentence(sentence, eng_lang)\n",
" input_length = input_tensor.size()[0]\n",
" encoder_hidden = encoder.initHidden()\n",
"\n",
" encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n",
"\n",
" # Encode the source sentence token by token.\n",
" for ei in range(input_length):\n",
2022-05-29 21:24:53 +02:00
" encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n",
2022-05-29 18:14:19 +02:00
" encoder_outputs[ei] += encoder_output[0, 0]\n",
"\n",
2022-05-29 21:24:53 +02:00
" decoder_input = torch.tensor([[SOS_token]], device=device)\n",
2022-05-29 18:14:19 +02:00
"\n",
" decoder_hidden = encoder_hidden\n",
"\n",
" decoded_words = []\n",
" decoder_attentions = torch.zeros(max_length, max_length)\n",
"\n",
" # Greedy decoding, capped at max_length steps.\n",
" for di in range(max_length):\n",
" decoder_output, decoder_hidden, decoder_attention = decoder(\n",
" decoder_input, decoder_hidden, encoder_outputs)\n",
" decoder_attentions[di] = decoder_attention.data\n",
" topv, topi = decoder_output.data.topk(1)\n",
" if topi.item() == EOS_token:\n",
" decoded_words.append('<EOS>')\n",
" break\n",
" else:\n",
2022-05-29 19:05:03 +02:00
" decoded_words.append(pol_lang.index2word[topi.item()])\n",
2022-05-29 18:14:19 +02:00
"\n",
" decoder_input = topi.squeeze().detach()\n",
"\n",
" return decoded_words, decoder_attentions[:di + 1]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Print translations of n random training pairs for inspection:\n",
"# '>' source, '=' reference, '<' model output.\n",
"def evaluateRandomly(encoder, decoder, n=10):\n",
" for i in range(n):\n",
" pair = random.choice(pairs)\n",
" print('>', pair[0])\n",
" print('=', pair[1])\n",
" output_words, attentions = evaluate(encoder, decoder, pair[0])\n",
" output_sentence = ' '.join(output_words)\n",
" print('<', output_sentence)\n",
" print('')"
]
},
{
"cell_type": "code",
2022-05-29 20:00:36 +02:00
"execution_count": 19,
2022-05-29 18:14:19 +02:00
"metadata": {},
2022-05-29 21:24:53 +02:00
"outputs": [],
"source": [
"hidden_size = 256\n",
"encoder1 = EncoderRNN(eng_lang.n_words, hidden_size).to(device)\n",
"attn_decoder1 = AttnDecoderRNN(hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
2022-05-29 18:14:19 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-05-29 21:24:53 +02:00
"iter: 50, loss: 4.699110437711081\n",
"iter: 100, loss: 4.241607124086411\n",
"iter: 150, loss: 4.14866822333563\n",
"iter: 200, loss: 4.175457921709334\n",
"iter: 250, loss: 4.304153789429438\n",
"iter: 300, loss: 4.304717092377798\n",
"iter: 350, loss: 4.316578052808368\n",
"iter: 400, loss: 4.379952565056937\n",
"iter: 450, loss: 4.086811531929743\n",
"iter: 500, loss: 4.252370147765628\n",
"iter: 550, loss: 4.02257244164603\n",
"iter: 600, loss: 4.271288591505989\n",
"iter: 650, loss: 4.037527732379852\n",
"iter: 700, loss: 3.808401109422956\n",
"iter: 750, loss: 4.01287091629089\n",
"iter: 800, loss: 4.185342459905715\n",
"iter: 850, loss: 3.8268170519934763\n",
"iter: 900, loss: 3.9197384970074607\n",
"iter: 950, loss: 4.225208856279888\n",
"iter: 1000, loss: 4.128686094178094\n",
"iter: 1050, loss: 3.9167927505553712\n",
"iter: 1100, loss: 4.015269571940103\n",
"iter: 1150, loss: 4.168424199830918\n",
"iter: 1200, loss: 4.302581990559896\n",
"iter: 1250, loss: 3.7335942743392225\n",
"iter: 1300, loss: 3.9526881422315334\n",
"iter: 1350, loss: 3.8640213389169604\n",
"iter: 1400, loss: 4.101886716827512\n",
"iter: 1450, loss: 3.6106392740067985\n",
"iter: 1500, loss: 4.0689067233857665\n",
"iter: 1550, loss: 4.02288844353812\n",
"iter: 1600, loss: 3.572508715992883\n",
"iter: 1650, loss: 3.972692446489183\n",
"iter: 1700, loss: 3.8709554294404525\n",
"iter: 1750, loss: 3.9830204631714583\n",
"iter: 1800, loss: 3.7999766263961794\n",
"iter: 1850, loss: 3.7026816112578858\n",
"iter: 1900, loss: 3.833205360775902\n",
"iter: 1950, loss: 3.650638633606925\n",
"iter: 2000, loss: 3.748746382418133\n",
"iter: 2050, loss: 3.762590566922748\n",
"iter: 2100, loss: 3.5997376789214117\n",
"iter: 2150, loss: 3.919283335610041\n",
"iter: 2200, loss: 3.8638847478684912\n",
"iter: 2250, loss: 3.4960837801675946\n",
"iter: 2300, loss: 3.685049927688782\n",
"iter: 2350, loss: 3.5716699722759313\n",
"iter: 2400, loss: 3.8988636863874997\n",
"iter: 2450, loss: 3.752788569586617\n",
"iter: 2500, loss: 3.802307117961702\n",
"iter: 2550, loss: 3.6420236970432227\n",
"iter: 2600, loss: 3.6925315249912325\n",
"iter: 2650, loss: 3.8897219879059572\n",
"iter: 2700, loss: 3.6327851654537153\n",
"iter: 2750, loss: 3.396957855118645\n",
"iter: 2800, loss: 3.5258935768112307\n",
"iter: 2850, loss: 3.605109554866003\n",
"iter: 2900, loss: 3.533288128330594\n",
"iter: 2950, loss: 3.4583421086054\n",
"iter: 3000, loss: 3.403592811425526\n",
"iter: 3050, loss: 3.5225157889411567\n",
"iter: 3100, loss: 3.4702517202846592\n",
"iter: 3150, loss: 3.4234997159185867\n",
"iter: 3200, loss: 3.5447632862348404\n",
"iter: 3250, loss: 3.1799173504133074\n",
"iter: 3300, loss: 3.7154814013905\n",
"iter: 3350, loss: 3.4188442155444445\n",
"iter: 3400, loss: 3.6557525696527393\n",
"iter: 3450, loss: 3.52880564416401\n",
"iter: 3500, loss: 3.4842312318408295\n",
"iter: 3550, loss: 3.5256399853570115\n",
"iter: 3600, loss: 3.70226228499034\n",
"iter: 3650, loss: 3.2043497113424633\n",
"iter: 3700, loss: 3.4575287022439256\n",
"iter: 3750, loss: 3.4197605448374664\n",
"iter: 3800, loss: 3.290345760890417\n",
"iter: 3850, loss: 3.300158274309976\n",
"iter: 3900, loss: 3.3362661438139645\n",
"iter: 3950, loss: 3.4947717628630373\n",
"iter: 4000, loss: 3.5624450731353154\n",
"iter: 4050, loss: 3.438600626892514\n",
"iter: 4100, loss: 3.142976412258451\n",
"iter: 4150, loss: 3.332818130595344\n",
"iter: 4200, loss: 3.31952378733196\n",
"iter: 4250, loss: 3.5315058948123252\n",
"iter: 4300, loss: 3.6603812535074023\n",
"iter: 4350, loss: 3.35295347692853\n",
"iter: 4400, loss: 3.374297706498041\n",
"iter: 4450, loss: 3.09948105843105\n",
"iter: 4500, loss: 3.16787886763376\n",
"iter: 4550, loss: 3.455794033330583\n",
"iter: 4600, loss: 3.1263191164258926\n",
"iter: 4650, loss: 3.3723485524995\n",
"iter: 4700, loss: 3.147410953930445\n",
"iter: 4750, loss: 3.4546711923281346\n",
"iter: 4800, loss: 3.449277176016852\n",
"iter: 4850, loss: 3.197799104531606\n",
"iter: 4900, loss: 3.239384971149383\n",
"iter: 4950, loss: 3.696369633697328\n",
"iter: 5000, loss: 3.2114706332191587\n",
"iter: 5050, loss: 3.400943172795432\n",
"iter: 5100, loss: 3.298932059106372\n",
"iter: 5150, loss: 3.3697974183445907\n",
"iter: 5200, loss: 3.31293656670858\n",
"iter: 5250, loss: 3.1415378823658773\n",
"iter: 5300, loss: 3.1587839283867494\n",
"iter: 5350, loss: 3.3505903312440903\n",
"iter: 5400, loss: 3.247191356802744\n",
"iter: 5450, loss: 3.236625145200699\n",
"iter: 5500, loss: 3.19994143747148\n",
"iter: 5550, loss: 3.2911239544626265\n",
"iter: 5600, loss: 3.1855649600483122\n",
"iter: 5650, loss: 3.157031875163789\n",
"iter: 5700, loss: 3.2652817099586366\n",
"iter: 5750, loss: 3.3272896775593837\n",
"iter: 5800, loss: 3.3162626687458583\n",
"iter: 5850, loss: 3.1342987139338536\n",
"iter: 5900, loss: 3.29665669613036\n",
"iter: 5950, loss: 3.232995939807286\n",
"iter: 6000, loss: 3.0922561403758935\n",
"iter: 6050, loss: 3.1034776155835107\n",
"iter: 6100, loss: 3.1502840874081564\n",
"iter: 6150, loss: 2.915993771098909\n",
"iter: 6200, loss: 2.994096033270397\n",
"iter: 6250, loss: 3.1102042265392487\n",
"iter: 6300, loss: 2.8244728108587718\n",
"iter: 6350, loss: 3.117810124692462\n",
"iter: 6400, loss: 3.0742526639529637\n",
"iter: 6450, loss: 2.8390014954218787\n",
"iter: 6500, loss: 3.1032223067510687\n",
"iter: 6550, loss: 2.912433739840038\n",
"iter: 6600, loss: 2.9158696003490023\n",
"iter: 6650, loss: 3.2617745389030093\n",
"iter: 6700, loss: 3.295657290466248\n",
"iter: 6750, loss: 2.975928121767347\n",
"iter: 6800, loss: 3.0057779382069914\n",
"iter: 6850, loss: 2.85224422507059\n",
"iter: 6900, loss: 3.0329934195336836\n",
"iter: 6950, loss: 3.1322296761255415\n",
"iter: 7000, loss: 2.893814939192363\n",
"iter: 7050, loss: 2.934597730205173\n",
"iter: 7100, loss: 3.267660904082041\n",
"iter: 7150, loss: 3.1199153114651867\n",
"iter: 7200, loss: 2.8414319788160776\n",
"iter: 7250, loss: 3.1128779797251256\n",
"iter: 7300, loss: 3.1182169116565155\n",
"iter: 7350, loss: 3.101384938853128\n",
"iter: 7400, loss: 2.9836614183395627\n",
"iter: 7450, loss: 2.7261425285036602\n",
"iter: 7500, loss: 2.7323913456977356\n",
"iter: 7550, loss: 3.284201001443559\n",
"iter: 7600, loss: 2.9473503636405587\n",
"iter: 7650, loss: 2.861012626541986\n",
"iter: 7700, loss: 2.6726747900872003\n",
"iter: 7750, loss: 2.760957624162947\n",
"iter: 7800, loss: 2.647666095211393\n",
"iter: 7850, loss: 2.7921250426428657\n",
"iter: 7900, loss: 2.9527213778495787\n",
"iter: 7950, loss: 2.790506172891647\n",
"iter: 8000, loss: 2.8376009529431663\n",
"iter: 8050, loss: 3.0387913953690298\n",
"iter: 8100, loss: 2.908381733046637\n",
"iter: 8150, loss: 2.7374484727761104\n",
"iter: 8200, loss: 2.84610585779614\n",
"iter: 8250, loss: 2.8532650649736793\n",
"iter: 8300, loss: 2.856347685723078\n",
"iter: 8350, loss: 2.6641267998710503\n",
"iter: 8400, loss: 2.7541870554590973\n",
"iter: 8450, loss: 2.814719854824126\n",
"iter: 8500, loss: 2.6979909611694395\n",
"iter: 8550, loss: 2.577483120327904\n",
"iter: 8600, loss: 2.7884950113561415\n",
"iter: 8650, loss: 3.0236114144552317\n",
"iter: 8700, loss: 2.5850161893329924\n",
"iter: 8750, loss: 2.992550043756999\n",
"iter: 8800, loss: 2.581544444644262\n",
"iter: 8850, loss: 2.7955539315276674\n",
"iter: 8900, loss: 2.583812619288763\n",
"iter: 8950, loss: 2.6446591711649825\n",
"iter: 9000, loss: 2.577330000854674\n",
"iter: 9050, loss: 2.4657566853288615\n",
"iter: 9100, loss: 2.800543680138058\n",
"iter: 9150, loss: 2.8939966171544707\n",
"iter: 9200, loss: 2.484702325525738\n",
"iter: 9250, loss: 2.9708456475469807\n",
"iter: 9300, loss: 2.8829837035148858\n",
"iter: 9350, loss: 2.451061187414896\n",
"iter: 9400, loss: 3.144906068983533\n",
"iter: 9450, loss: 2.4527184899950787\n",
"iter: 9500, loss: 2.665944624832698\n",
"iter: 9550, loss: 2.5468089370273406\n",
"iter: 9600, loss: 2.51169423552165\n",
"iter: 9650, loss: 2.916568091210864\n",
"iter: 9700, loss: 2.8149766059640853\n",
"iter: 9750, loss: 2.6544064010362773\n",
"iter: 9800, loss: 2.300161985658464\n",
"iter: 9850, loss: 2.5070087575912483\n",
"iter: 9900, loss: 2.617770311056621\n",
"iter: 9950, loss: 2.756971993983738\n",
"iter: 10000, loss: 2.629019902910504\n"
2022-05-29 18:14:19 +02:00
]
}
],
"source": [
2022-05-29 21:24:53 +02:00
"trainIters(encoder1, attn_decoder1, 10_000, print_every=50)"
2022-05-29 18:14:19 +02:00
]
},
{
"cell_type": "code",
2022-05-29 21:24:53 +02:00
"execution_count": 21,
2022-05-29 18:14:19 +02:00
"metadata": {},
2022-05-29 20:00:36 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-05-29 21:24:53 +02:00
"> we re both in the same class .\n",
"= jesteśmy oboje w tej samej klasie .\n",
"< jesteśmy w w . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> you re telling lies again .\n",
"= znowu kłamiesz .\n",
"< znowu mi . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> i m glad you re back .\n",
"= cieszę się że wróciliście .\n",
"< cieszę się że . . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> i m not going to have any fun .\n",
"= nie będę się bawił .\n",
"< nie wolno się . . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> i m practising judo .\n",
"= trenuję dżudo .\n",
"< jestem . . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> you re wasting our time .\n",
"= marnujesz nasz czas .\n",
"< masz ci na . . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> he is anxious about her health .\n",
"= on martwi się o jej zdrowie .\n",
"< jest bardzo z niej . . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> you re introverted .\n",
"= jesteś zamknięty w sobie .\n",
"< masz . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> she s correct for sure .\n",
"= ona z pewnością ma rację .\n",
"< ona jest z z . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n",
2022-05-29 21:24:53 +02:00
"> they re armed .\n",
"= są uzbrojeni .\n",
"< są . . <EOS>\n",
2022-05-29 20:00:36 +02:00
"\n"
]
}
],
2022-05-29 18:14:19 +02:00
"source": [
"evaluateRandomly(encoder1, attn_decoder1)"
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"subtitle": "10.Model rekurencyjny z atencją[ćwiczenia]",
"title": "Modelowanie Języka",
"year": "2022"
},
"nbformat": 4,
"nbformat_minor": 4
}