From 6f19489a1324d37d1f02e9401b7058b32bb3e318 Mon Sep 17 00:00:00 2001 From: Jakub Pokrywka Date: Sun, 29 May 2022 18:14:19 +0200 Subject: [PATCH] 11 in progress --- cw/11_Model_rekurencyjny_z_atencją.ipynb | 764 ++++++++++++++++++++++ 1 file changed, 764 insertions(+) create mode 100644 cw/11_Model_rekurencyjny_z_atencją.ipynb diff --git a/cw/11_Model_rekurencyjny_z_atencją.ipynb b/cw/11_Model_rekurencyjny_z_atencją.ipynb new file mode 100644 index 0000000..7a901b7 --- /dev/null +++ b/cw/11_Model_rekurencyjny_z_atencją.ipynb @@ -0,0 +1,764 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n", + "
\n", + "

Modelowanie Języka

\n", + "

10. Model rekurencyjny z atencją [ćwiczenia]

\n", + "

Jakub Pokrywka (2022)

\n", + "
\n", + "\n", + "![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import unicode_literals, print_function, division\n", + "from io import open\n", + "import unicodedata\n", + "import string\n", + "import re\n", + "import random\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch import optim\n", + "import torch.nn.functional as F\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "SOS_token = 0\n", + "EOS_token = 1\n", + "\n", + "class Lang:\n", + " def __init__(self):\n", + " self.word2index = {}\n", + " self.word2count = {}\n", + " self.index2word = {0: \"SOS\", 1: \"EOS\"}\n", + " self.n_words = 2 # Count SOS and EOS\n", + "\n", + " def addSentence(self, sentence):\n", + " for word in sentence.split(' '):\n", + " self.addWord(word)\n", + "\n", + " def addWord(self, word):\n", + " if word not in self.word2index:\n", + " self.word2index[word] = self.n_words\n", + " self.word2count[word] = 1\n", + " self.index2word[self.n_words] = word\n", + " self.n_words += 1\n", + " else:\n", + " self.word2count[word] += 1" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def unicodeToAscii(s):\n", + " return ''.join(\n", + " c for c in unicodedata.normalize('NFD', s)\n", + " if unicodedata.category(c) != 'Mn'\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "pairs = []\n", + "with open('data/eng-fra.txt') as f:\n", + " for line in f:\n", + " eng_line, fra_line = line.lower().rstrip().split('\\t')\n", + "\n", + " eng_line = re.sub(r\"([.!?])\", r\" \\1\", eng_line)\n", + " eng_line = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", eng_line)\n", + "\n", + " fra_line = re.sub(r\"([.!?])\", r\" \\1\", fra_line)\n", + " fra_line = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", fra_line)\n", + " \n", + " eng_line = unicodeToAscii(eng_line)\n", + " fra_line = unicodeToAscii(fra_line)\n", + "\n", + " pairs.append([eng_line, fra_line])\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['run !', 'cours !']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pairs[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "MAX_LENGTH = 10\n", + "eng_prefixes = (\n", + " \"i am \", \"i m \",\n", + " \"he is\", \"he s \",\n", + " \"she is\", \"she s \",\n", + " \"you are\", \"you re \",\n", + " \"we are\", \"we re \",\n", + " \"they are\", \"they re \"\n", + ")\n", + "\n", + "pairs = [p for p in pairs if len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH]\n", + "pairs = [p for p in pairs if p[0].startswith(eng_prefixes)]\n", + "\n", + "eng_lang = Lang()\n", + "fra_lang = Lang()\n", + "\n", + "for pair in pairs:\n", + " eng_lang.addSentence(pair[0])\n", + " fra_lang.addSentence(pair[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['i m .', 'j ai ans .']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pairs[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['i m ok .', 'je vais bien .']" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pairs[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['i m ok .', ' a va .']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pairs[2]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class EncoderRNN(nn.Module):\n", + " def __init__(self, input_size, hidden_size):\n", + " super(EncoderRNN, self).__init__()\n", + " self.hidden_size = hidden_size\n", + "\n", + " self.embedding = nn.Embedding(input_size, hidden_size)\n", + " self.gru = nn.GRU(hidden_size, hidden_size)\n", + "\n", + " def forward(self, input, hidden):\n", + " embedded = self.embedding(input).view(1, 1, -1)\n", + " output = embedded\n", + " output, hidden = self.gru(output, hidden)\n", + " return output, hidden\n", + "\n", + " def initHidden(self):\n", + " return torch.zeros(1, 1, self.hidden_size, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "class DecoderRNN(nn.Module):\n", + " def __init__(self, hidden_size, output_size):\n", + " super(DecoderRNN, self).__init__()\n", + " self.hidden_size = hidden_size\n", + "\n", + " self.embedding = nn.Embedding(output_size, hidden_size)\n", + " self.gru = nn.GRU(hidden_size, hidden_size)\n", + " self.out = nn.Linear(hidden_size, output_size)\n", + " self.softmax = nn.LogSoftmax(dim=1)\n", + "\n", + " def forward(self, input, hidden):\n", + " output = self.embedding(input).view(1, 1, -1)\n", + " output = F.relu(output)\n", + " output, hidden = self.gru(output, hidden)\n", + " output = self.softmax(self.out(output[0]))\n", + " return output, hidden\n", + "\n", + " def initHidden(self):\n", + " return torch.zeros(1, 1, self.hidden_size, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "class AttnDecoderRNN(nn.Module):\n", + " def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\n", + " super(AttnDecoderRNN, self).__init__()\n", + " self.hidden_size = hidden_size\n", + " self.output_size = output_size\n", + " self.dropout_p = dropout_p\n", + " self.max_length = max_length\n", + "\n", + " self.embedding = nn.Embedding(self.output_size, self.hidden_size)\n", + " self.attn = nn.Linear(self.hidden_size * 2, self.max_length)\n", + " self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)\n", + " self.dropout = nn.Dropout(self.dropout_p)\n", + " self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n", + " self.out = nn.Linear(self.hidden_size, self.output_size)\n", + "\n", + " def forward(self, input, hidden, encoder_outputs):\n", + " embedded = self.embedding(input).view(1, 1, -1)\n", + " embedded = self.dropout(embedded)\n", + "\n", + " attn_weights = F.softmax(\n", + " self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)\n", + " attn_applied = torch.bmm(attn_weights.unsqueeze(0),\n", + " encoder_outputs.unsqueeze(0))\n", + "\n", + " output = torch.cat((embedded[0], attn_applied[0]), 1)\n", + " output = self.attn_combine(output).unsqueeze(0)\n", + "\n", + " output = F.relu(output)\n", + " output, hidden = self.gru(output, hidden)\n", + "\n", + " output = F.log_softmax(self.out(output[0]), dim=1)\n", + " return output, hidden, attn_weights\n", + "\n", + " def initHidden(self):\n", + " return torch.zeros(1, 1, self.hidden_size, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def tensorFromSentence(sentence, lang):\n", + " indexes = [lang.word2index[word] for word in sentence.split(' ')]\n", + " indexes.append(EOS_token)\n", + " return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "teacher_forcing_ratio = 0.5\n", + "\n", + "def train_one_batch(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):\n", + " encoder_hidden = encoder.initHidden()\n", + "\n", + " encoder_optimizer.zero_grad()\n", + " decoder_optimizer.zero_grad()\n", + "\n", + " input_length = input_tensor.size(0)\n", + " target_length = target_tensor.size(0)\n", + "\n", + " encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n", + "\n", + " loss = 0\n", + "\n", + " for ei in range(input_length):\n", + " encoder_output, encoder_hidden = encoder(\n", + " input_tensor[ei], encoder_hidden)\n", + " encoder_outputs[ei] = encoder_output[0, 0]\n", + "\n", + " decoder_input = torch.tensor([[SOS_token]], device=device)\n", + "\n", + " decoder_hidden = encoder_hidden\n", + "\n", + " use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False\n", + "\n", + " if use_teacher_forcing:\n", + " # Teacher forcing: Feed the target as the next input\n", + " for di in range(target_length):\n", + " decoder_output, decoder_hidden, decoder_attention = decoder(\n", + " decoder_input, decoder_hidden, encoder_outputs)\n", + " loss += criterion(decoder_output, target_tensor[di])\n", + " decoder_input = target_tensor[di] # Teacher forcing\n", + "\n", + " else:\n", + " # Without teacher forcing: use its own predictions as the next input\n", + " for di in range(target_length):\n", + " decoder_output, decoder_hidden, decoder_attention = decoder(\n", + " decoder_input, decoder_hidden, encoder_outputs)\n", + " topv, topi = decoder_output.topk(1)\n", + " decoder_input = topi.squeeze().detach() # detach from history as input\n", + "\n", + " loss += criterion(decoder_output, target_tensor[di])\n", + " if decoder_input.item() == EOS_token:\n", + " break\n", + "\n", + " loss.backward()\n", + "\n", + " encoder_optimizer.step()\n", + " decoder_optimizer.step()\n", + "\n", + " return loss.item() / target_length" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "def trainIters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):\n", + " print_loss_total = 0 # Reset every print_every\n", + "\n", + " encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\n", + " decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\n", + " \n", + " training_pairs = [random.choice(pairs) for _ in range(n_iters)]\n", + " training_pairs = [(tensorFromSentence(p[0], eng_lang), tensorFromSentence(p[1], fra_lang)) for p in training_pairs]\n", + " \n", + " criterion = nn.NLLLoss()\n", + "\n", + " for i in range(1, n_iters + 1):\n", + " training_pair = training_pairs[i - 1]\n", + " input_tensor = training_pair[0]\n", + " target_tensor = training_pair[1]\n", + "\n", + " loss = train_one_batch(input_tensor,\n", + " target_tensor,\n", + " encoder,\n", + " encoder,\n", + " encoder_optimizer,\n", + " decoder_optimizer,\n", + " criterion)\n", + " \n", + " print_loss_total += loss\n", + "\n", + " if i % print_every == 0:\n", + " print_loss_avg = print_loss_total / print_every\n", + " print_loss_total = 0\n", + " print(f'iter: {i}, loss: {print_loss_avg}')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):\n", + " with torch.no_grad():\n", + " input_tensor = tensorFromSentence(sentence, eng_lang)\n", + " input_length = input_tensor.size()[0]\n", + " encoder_hidden = encoder.initHidden()\n", + "\n", + " encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n", + "\n", + " for ei in range(input_length):\n", + " encoder_output, encoder_hidden = encoder(input_tensor[ei],\n", + " encoder_hidden)\n", + " encoder_outputs[ei] += encoder_output[0, 0]\n", + "\n", + " decoder_input = torch.tensor([[SOS_token]], device=device) # SOS\n", + "\n", + " decoder_hidden = encoder_hidden\n", + "\n", + " decoded_words = []\n", + " decoder_attentions = torch.zeros(max_length, max_length)\n", + "\n", + " for di in range(max_length):\n", + " decoder_output, decoder_hidden, decoder_attention = decoder(\n", + " decoder_input, decoder_hidden, encoder_outputs)\n", + " decoder_attentions[di] = decoder_attention.data\n", + " topv, topi = decoder_output.data.topk(1)\n", + " if topi.item() == EOS_token:\n", + " decoded_words.append('')\n", + " break\n", + " else:\n", + " decoded_words.append(fra_lang.index2word[topi.item()])\n", + "\n", + " decoder_input = topi.squeeze().detach()\n", + "\n", + " return decoded_words, decoder_attentions[:di + 1]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluateRandomly(encoder, decoder, n=10):\n", + " for i in range(n):\n", + " pair = random.choice(pairs)\n", + " print('>', pair[0])\n", + " print('=', pair[1])\n", + " output_words, attentions = evaluate(encoder, decoder, pair[0])\n", + " output_sentence = ' '.join(output_words)\n", + " print('<', output_sentence)\n", + " print('')" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "iter: 50, loss: 4.78930813773473\n", + "iter: 100, loss: 4.554949267220875\n", + "iter: 150, loss: 4.238516052685087\n", + "iter: 200, loss: 4.279887475513276\n", + "iter: 250, loss: 4.1802274973884455\n", + "iter: 300, loss: 4.2113521892305394\n", + "iter: 350, loss: 4.266180963228619\n", + "iter: 400, loss: 4.225914733432588\n", + "iter: 450, loss: 4.1369073431075565\n", + "iter: 500, loss: 3.9906799076019768\n", + "iter: 550, loss: 3.842005534717016\n", + "iter: 600, loss: 4.081443620484972\n", + "iter: 650, loss: 4.030401878296383\n", + "iter: 700, loss: 3.869014380984837\n", + "iter: 750, loss: 3.8505467753031906\n", + "iter: 800, loss: 3.855170104072209\n", + "iter: 850, loss: 3.675745445599631\n", + "iter: 900, loss: 3.9147777624584386\n", + "iter: 950, loss: 3.766264297788106\n", + "iter: 1000, loss: 3.6813155986997814\n", + "iter: 1050, loss: 3.9307321495934144\n", + "iter: 1100, loss: 3.9047770059525027\n", + "iter: 1150, loss: 3.655722749588981\n", + "iter: 1200, loss: 3.540693810886806\n", + "iter: 1250, loss: 3.790360960324605\n", + "iter: 1300, loss: 3.7472636015907153\n", + "iter: 1350, loss: 3.641857419574072\n", + "iter: 1400, loss: 3.717327400631375\n", + "iter: 1450, loss: 3.4848567311423166\n", + "iter: 1500, loss: 3.56774485397339\n", + "iter: 1550, loss: 3.460277635226175\n", + "iter: 1600, loss: 3.241899683013796\n", + "iter: 1650, loss: 3.50151977614751\n", + "iter: 1700, loss: 3.621569488313462\n", + "iter: 1750, loss: 3.3851226735947626\n", + "iter: 1800, loss: 3.346289497057597\n", + "iter: 1850, loss: 3.5180823354569695\n", + "iter: 1900, loss: 3.433616197676886\n", + "iter: 1950, loss: 3.6162788327080864\n", + "iter: 2000, loss: 3.4990604458763492\n", + "iter: 2050, loss: 3.3144700173423405\n", + "iter: 2100, loss: 3.2962356294980135\n", + "iter: 2150, loss: 3.1448448797861728\n", + "iter: 2200, loss: 3.6958242581534018\n", + "iter: 2250, loss: 3.5269318538241925\n", + "iter: 2300, loss: 3.180744191850934\n", + "iter: 2350, loss: 3.317159715145354\n", + "iter: 2400, loss: 3.638545340795366\n", + "iter: 2450, loss: 3.7591161967988995\n", + "iter: 2500, loss: 3.3513535446742218\n", + "iter: 2550, loss: 3.4554441847271393\n", + "iter: 2600, loss: 2.9394915195343994\n", + "iter: 2650, loss: 3.370902210848673\n", + "iter: 2700, loss: 3.4259227318839423\n", + "iter: 2750, loss: 3.4058353806904393\n", + "iter: 2800, loss: 3.467306881359647\n", + "iter: 2850, loss: 3.222254538074372\n", + "iter: 2900, loss: 3.3392559226808087\n", + "iter: 2950, loss: 3.4203980594362533\n", + "iter: 3000, loss: 3.3507530433563955\n", + "iter: 3050, loss: 3.4326547555317966\n", + "iter: 3100, loss: 3.1755515496390205\n", + "iter: 3150, loss: 3.3925877854634847\n", + "iter: 3200, loss: 3.223531436912598\n", + "iter: 3250, loss: 3.3089625614862603\n", + "iter: 3300, loss: 3.367763715501815\n", + "iter: 3350, loss: 3.4278301871163497\n", + "iter: 3400, loss: 3.373292277381534\n", + "iter: 3450, loss: 3.3497054475829717\n", + "iter: 3500, loss: 3.402910869681646\n", + "iter: 3550, loss: 3.072571641732776\n", + "iter: 3600, loss: 3.2611226563832116\n", + "iter: 3650, loss: 3.231520605495998\n", + "iter: 3700, loss: 3.3788801974569043\n", + "iter: 3750, loss: 3.176644308181036\n", + "iter: 3800, loss: 3.2255533708693496\n", + "iter: 3850, loss: 3.2362594686387083\n", + "iter: 3900, loss: 3.095807164230044\n", + "iter: 3950, loss: 3.2343999077024916\n", + "iter: 4000, loss: 3.3681417366512245\n", + "iter: 4050, loss: 3.0732023419879737\n", + "iter: 4100, loss: 3.0663742440617283\n", + "iter: 4150, loss: 3.396770855048347\n", + "iter: 4200, loss: 3.4262332421522292\n", + "iter: 4250, loss: 3.060121847773354\n", + "iter: 4300, loss: 2.895130627753243\n", + "iter: 4350, loss: 3.017712699065133\n", + "iter: 4400, loss: 3.1289404028559487\n", + "iter: 4450, loss: 3.163725920904249\n", + "iter: 4500, loss: 3.3627441662606743\n", + "iter: 4550, loss: 3.409984823173947\n", + "iter: 4600, loss: 2.8944704760899618\n", + "iter: 4650, loss: 3.0016444209568083\n", + "iter: 4700, loss: 2.8574393688837683\n", + "iter: 4750, loss: 3.1946328716656525\n", + "iter: 4800, loss: 2.768447057353125\n", + "iter: 4850, loss: 3.075327144675784\n", + "iter: 4900, loss: 3.268370175997416\n", + "iter: 4950, loss: 3.1798231331053235\n", + "iter: 5000, loss: 3.3217560536218063\n", + "iter: 5050, loss: 3.006732604223585\n", + "iter: 5100, loss: 3.3575944598061698\n", + "iter: 5150, loss: 2.9057663469655175\n", + "iter: 5200, loss: 2.8928466574502374\n", + "iter: 5250, loss: 3.061066797528948\n", + "iter: 5300, loss: 3.35562970057745\n", + "iter: 5350, loss: 2.9118076042901895\n", + "iter: 5400, loss: 2.9514354321918783\n", + "iter: 5450, loss: 2.9334804391406832\n", + "iter: 5500, loss: 3.204634138440329\n", + "iter: 5550, loss: 2.8140748963961526\n", + "iter: 5600, loss: 3.011708143741365\n", + "iter: 5650, loss: 3.323859388586074\n", + "iter: 5700, loss: 2.8442912295810756\n", + "iter: 5750, loss: 2.80684267281729\n", + "iter: 5800, loss: 3.1174840584860903\n", + "iter: 5850, loss: 2.6991389470478837\n", + "iter: 5900, loss: 2.9698236653237116\n", + "iter: 5950, loss: 3.0238281039586137\n", + "iter: 6000, loss: 2.8812837354947645\n", + "iter: 6050, loss: 3.1709352504639394\n", + "iter: 6100, loss: 2.937920509209709\n", + "iter: 6150, loss: 3.178728113076043\n", + "iter: 6200, loss: 2.8974244089429337\n", + "iter: 6250, loss: 2.809626478180052\n", + "iter: 6300, loss: 2.781241159703996\n", + "iter: 6350, loss: 2.9004218400395105\n", + "iter: 6400, loss: 2.9118271145669246\n", + "iter: 6450, loss: 2.8842602037096787\n", + "iter: 6500, loss: 2.9489114957536966\n", + "iter: 6550, loss: 2.9503131193130736\n", + "iter: 6600, loss: 2.8961831474304187\n", + "iter: 6650, loss: 3.002027267266834\n", + "iter: 6700, loss: 3.0047303264103236\n", + "iter: 6750, loss: 2.958453589060949\n", + "iter: 6800, loss: 2.9524990789852446\n", + "iter: 6850, loss: 2.935619188210321\n", + "iter: 6900, loss: 2.9734530233807033\n", + "iter: 6950, loss: 2.785320390822396\n", + "iter: 7000, loss: 3.1911680922054106\n", + "iter: 7050, loss: 2.7732513120363635\n", + "iter: 7100, loss: 2.7432456348282948\n", + "iter: 7150, loss: 2.823985375283256\n", + "iter: 7200, loss: 2.927504679808541\n", + "iter: 7250, loss: 3.0693400076760184\n", + "iter: 7300, loss: 2.666468213043515\n", + "iter: 7350, loss: 2.808132514378382\n", + "iter: 7400, loss: 2.558679431067573\n", + "iter: 7450, loss: 2.6974468813850763\n", + "iter: 7500, loss: 2.8497490201223457\n", + "iter: 7550, loss: 2.7490190564337236\n", + "iter: 7600, loss: 2.8300208840067427\n", + "iter: 7650, loss: 2.793417969741518\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Input \u001b[0;32mIn [19]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m encoder1 \u001b[38;5;241m=\u001b[39m EncoderRNN(eng_lang\u001b[38;5;241m.\u001b[39mn_words, hidden_size)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 3\u001b[0m attn_decoder1 \u001b[38;5;241m=\u001b[39m AttnDecoderRNN(hidden_size, fra_lang\u001b[38;5;241m.\u001b[39mn_words, dropout_p\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.1\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainIters\u001b[49m\u001b[43m(\u001b[49m\u001b[43mencoder1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_decoder1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m75000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprint_every\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\u001b[43m)\u001b[49m\n", + "Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36mtrainIters\u001b[0;34m(encoder, decoder, n_iters, print_every, plot_every, learning_rate)\u001b[0m\n\u001b[1;32m 16\u001b[0m input_tensor \u001b[38;5;241m=\u001b[39m training_pair[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 17\u001b[0m target_tensor \u001b[38;5;241m=\u001b[39m training_pair[\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m---> 19\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_tensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_tensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecoder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoder_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdecoder_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcriterion\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m print_loss_total \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\n\u001b[1;32m 22\u001b[0m plot_loss_total \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\n", + "Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m decoder_input\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m==\u001b[39m EOS_token:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[0;32m---> 48\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 50\u001b[0m encoder_optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 51\u001b[0m decoder_optimizer\u001b[38;5;241m.\u001b[39mstep()\n", + "File \u001b[0;32m~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/torch/_tensor.py:363\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 356\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 357\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 361\u001b[0m create_graph\u001b[38;5;241m=\u001b[39mcreate_graph,\n\u001b[1;32m 362\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs)\n\u001b[0;32m--> 363\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/envs/zajeciaei/lib/python3.10/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 168\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 170\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "hidden_size = 256\n", + "encoder1 = EncoderRNN(eng_lang.n_words, hidden_size).to(device)\n", + "attn_decoder1 = AttnDecoderRNN(hidden_size, fra_lang.n_words, dropout_p=0.1).to(device)\n", + "\n", + "trainIters(encoder1, attn_decoder1, 75000, print_every=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> you re sad .\n", + "= tu es triste .\n", + "< vous tes . . \n", + "\n", + "> she is sewing a dress .\n", + "= elle coud une robe .\n", + "< elle est une une . . \n", + "\n", + "> he is suffering from a headache .\n", + "= il souffre d un mal de t te .\n", + "< il est un un un un . \n", + "\n", + "> i m glad to see you .\n", + "= je suis heureux de vous voir .\n", + "< je suis content de vous voir . \n", + "\n", + "> you are only young once .\n", + "= on n est jeune qu une fois .\n", + "< vous tes trop plus une enfant . \n", + "\n", + "> you re so sweet .\n", + "= vous tes si gentille !\n", + "< vous tes trop si . \n", + "\n", + "> i m running out of closet space .\n", + "= je manque d espace dans mon placard .\n", + "< je suis un de de \n", + "\n", + "> i m sort of an extrovert .\n", + "= je suis en quelque sorte extraverti .\n", + "< je suis un un . . \n", + "\n", + "> i m out of practice .\n", + "= je manque de pratique .\n", + "< j ai ai pas de \n", + "\n", + "> you re the last hope for humanity .\n", + "= tu es le dernier espoir de l humanit .\n", + "< vous tes le la la . . \n", + "\n" + ] + } + ], + "source": [ + "evaluateRandomly(encoder1, attn_decoder1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "author": "Jakub Pokrywka", + "email": "kubapok@wmi.amu.edu.pl", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "lang": "pl", + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]", + "title": "Ekstrakcja informacji", + "year": "2021" + }, + "nbformat": 4, + "nbformat_minor": 4 +}