11
This commit is contained in:
parent
a31ef88426
commit
8f20a86ea7
830
cw/11_Model_rekurencyjny_z_atencją.ipynb
Normal file
830
cw/11_Model_rekurencyjny_z_atencją.ipynb
Normal file
@ -0,0 +1,830 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
|
||||
"<div class=\"alert alert-block alert-info\">\n",
|
||||
"<h1> Modelowanie Języka</h1>\n",
|
||||
"<h2> 10. <i>Model rekurencyjny z atencją</i> [ćwiczenia]</h2> \n",
|
||||
"<h3> Jakub Pokrywka (2022)</h3>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"notebook na podstawie:\n",
|
||||
"\n",
|
||||
"# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import unicode_literals, print_function, division\n",
|
||||
"from io import open\n",
|
||||
"import unicodedata\n",
|
||||
"import string\n",
|
||||
"import re\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch import optim\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
SOS_token = 0
EOS_token = 1


class Lang:
    """Vocabulary for one language: maps words to indices and back."""

    def __init__(self):
        # Indices 0 and 1 are reserved for the special SOS/EOS tokens.
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        """Register every space-separated token of `sentence`."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Add `word` to the vocabulary, or bump its count if already known."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        index = self.n_words
        self.word2index[word] = index
        self.word2count[word] = 1
        self.index2word[index] = word
        self.n_words += 1
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def _clean_pair(raw_line):
    """Split one tab-separated corpus line into a normalized [eng, pol] pair."""
    # Lowercase, drop the trailing newline, split on the tab separator.
    eng, pol = raw_line.lower().rstrip().split('\t')

    # Detach sentence-final punctuation and collapse everything that is not
    # a letter or . ! ? into single spaces.
    eng = re.sub(r"([.!?])", r" \1", eng)
    eng = re.sub(r"[^a-zA-Z.!?]+", r" ", eng)

    # Same for Polish, additionally keeping the Polish diacritic letters.
    pol = re.sub(r"([.!?])", r" \1", pol)
    pol = re.sub(r"[^a-zA-Z.!?ąćęłńóśźżĄĆĘŁŃÓŚŹŻ]+", r" ", pol)

    return [eng, pol]


with open('data/eng-pol.txt') as f:
    pairs = [_clean_pair(line) for line in f]
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['hi .', 'cześć .']"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pairs[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
MAX_LENGTH = 10
# Same reduced training set as the PyTorch seq2seq tutorial: short sentences
# whose English side starts with one of these prefixes.
eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def _short_enough(pair):
    # Both sides must have fewer than MAX_LENGTH space-separated tokens.
    return len(pair[0].split(' ')) < MAX_LENGTH and len(pair[1].split(' ')) < MAX_LENGTH


pairs = [p for p in pairs if _short_enough(p) and p[0].startswith(eng_prefixes)]

# Build one vocabulary per language from the surviving pairs.
eng_lang = Lang()
pol_lang = Lang()
for eng_sentence, pol_sentence in pairs:
    eng_lang.addSentence(eng_sentence)
    pol_lang.addSentence(pol_sentence)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['i m ok .', 'ze mną wszystko w porządku .']"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pairs[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['i m up .', 'wstałem .']"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pairs[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['i m tom .', 'jestem tom .']"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pairs[2]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1828"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"eng_lang.n_words"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"2883"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pol_lang.n_words"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
class EncoderRNN(nn.Module):
    """Single-layer GRU encoder that consumes one token index per call.

    Args:
        input_size: source-vocabulary size (number of embeddings).
        embedding_size: dimensionality of the token embeddings.
        hidden_size: GRU hidden-state size.
    """

    def __init__(self, input_size, embedding_size, hidden_size):
        super(EncoderRNN, self).__init__()
        # BUG FIX: the original hard-coded `self.embedding_size = 200`,
        # silently ignoring the `embedding_size` argument. Honor the
        # parameter instead (the notebook's only caller passes 200, so
        # existing behavior is unchanged).
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, self.embedding_size)
        self.gru = nn.GRU(self.embedding_size, hidden_size)

    def forward(self, input, hidden):
        """One step: embed a single token and advance the GRU.

        Returns (output, hidden), each shaped (1, 1, hidden_size).
        """
        # One token at a time: reshape to (seq_len=1, batch=1, embedding_size).
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        # Fresh all-zero hidden state on the module-level `device`.
        return torch.zeros(1, 1, self.hidden_size, device=device)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
class DecoderRNN(nn.Module):
    """Plain (attention-free) GRU decoder emitting log-probabilities.

    Args:
        embedding_size: dimensionality of the target-token embeddings.
        hidden_size: GRU hidden-state size.
        output_size: target-vocabulary size.
    """

    def __init__(self, embedding_size, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, self.embedding_size)
        self.gru = nn.GRU(self.embedding_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """One step: embed the previous token, advance the GRU, project to vocab.

        Returns (log_probs, hidden) with log_probs shaped (1, output_size).
        """
        # (seq_len=1, batch=1, embedding_size), then ReLU as in the tutorial.
        embedded = F.relu(self.embedding(input).view(1, 1, -1))
        gru_out, hidden = self.gru(embedded, hidden)
        log_probs = self.softmax(self.out(gru_out[0]))
        return log_probs, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
class AttnDecoderRNN(nn.Module):
    """GRU decoder with position-based attention over the encoder outputs.

    Attention weights over source positions are predicted from the embedded
    input token concatenated with the current hidden state (the variant used
    in the PyTorch seq2seq tutorial).
    """

    def __init__(self, embedding_size, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.embedding_size)
        # Scores one weight per source position (up to max_length).
        self.attn = nn.Linear(self.hidden_size + self.embedding_size, self.max_length)
        # Mixes the embedded token with the attended context vector.
        self.attn_combine = nn.Linear(self.hidden_size + self.embedding_size, self.embedding_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.embedding_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """One step. Returns (log_probs, hidden, attn_weights)."""
        embedded = self.dropout(self.embedding(input).view(1, 1, -1))

        # Softmax over source positions, conditioned on [embedded ; hidden].
        attn_input = torch.cat((embedded[0], hidden[0]), 1)
        attn_weights = F.softmax(self.attn(attn_input), dim=1)
        # Weighted sum of encoder outputs -> (1, 1, hidden_size) context.
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        combined = torch.cat((embedded[0], attn_applied[0]), 1)
        gru_input = F.relu(self.attn_combine(combined).unsqueeze(0))
        gru_out, hidden = self.gru(gru_input, hidden)

        log_probs = F.log_softmax(self.out(gru_out[0]), dim=1)
        return log_probs, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def tensorFromSentence(sentence, lang):
    """Encode a space-separated sentence as a (len+1, 1) LongTensor of word
    indices, terminated with EOS_token, on the module-level `device`.

    Raises KeyError for words absent from `lang.word2index`.
    """
    indexes = [lang.word2index[w] for w in sentence.split(' ')] + [EOS_token]
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
teacher_forcing_ratio = 0.5  # probability of feeding the gold token back into the decoder

def train_one_batch(input_tensor, target_tensor, encoder, decoder, optimizer, criterion, max_length=MAX_LENGTH):
    """Run one optimization step on a single (source, target) sentence pair.

    input_tensor / target_tensor: (seq_len, 1) LongTensors of token indices
        (presumably produced by tensorFromSentence — confirm at call site).
    encoder / decoder: the seq2seq modules; the decoder is expected to return
        (output, hidden, attention) as AttnDecoderRNN does.
    Returns the accumulated loss divided by target_length, as a Python float.
    """
    encoder_hidden = encoder.initHidden()


    optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # One encoder output vector per source position, zero-padded to max_length.
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    # Encode the source sentence one token at a time.
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    # Decoding starts from SOS and the final encoder hidden state.
    decoder_input = torch.tensor([[SOS_token]], device=device)

    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        # Teacher forcing: the next decoder input is always the gold token.
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di] # Teacher forcing

    else:
        # Free running: feed the decoder its own most likely prediction.
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach() # detach from history as input

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    optimizer.step()

    # NOTE(review): the sum is divided by the full target_length even when the
    # free-running branch breaks early at EOS, slightly deflating that loss.
    return loss.item() / target_length
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def trainIters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):
    """Run `n_iters` single-sentence training steps on randomly drawn pairs,
    printing the running average loss every `print_every` iterations.
    """
    print_loss_total = 0  # running sum, reset after every report
    encoder.train()
    decoder.train()

    # One SGD optimizer over the joint encoder+decoder parameter list.
    optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

    # Pre-sample the whole training schedule, then tensorize it.
    sampled = [random.choice(pairs) for _ in range(n_iters)]
    training_pairs = [(tensorFromSentence(p[0], eng_lang), tensorFromSentence(p[1], pol_lang)) for p in sampled]

    criterion = nn.NLLLoss()

    for i in range(1, n_iters + 1):
        input_tensor, target_tensor = training_pairs[i - 1]

        loss = train_one_batch(input_tensor,
                               target_tensor,
                               encoder,
                               decoder,
                               optimizer,
                               criterion)

        print_loss_total += loss

        if i % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print(f'iter: {i}, loss: {print_loss_avg}')
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedily decode `sentence` with the trained encoder/decoder.

    Returns (decoded_words, attentions): the predicted target words (ending
    with '<EOS>' if the decoder emitted it) and a (steps, max_length) tensor
    of the decoder's attention weights for the steps actually taken.
    Raises KeyError if `sentence` contains a word unseen by eng_lang.
    """
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        input_tensor = tensorFromSentence(sentence, eng_lang)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        # One encoder output vector per source position, zero-padded.
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        # Decoding starts from SOS and the final encoder hidden state.
        decoder_input = torch.tensor([[SOS_token]], device=device)

        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        # Greedy decode: always take the top-1 token, up to max_length steps.
        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(pol_lang.index2word[topi.item()])

            # Feed the prediction back as the next input.
            decoder_input = topi.squeeze().detach()

        # `di` is the loop index of the last step taken (break or exhaustion),
        # so this slice keeps exactly the rows that were filled in.
        return decoded_words, decoder_attentions[:di + 1]
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def evaluateRandomly(encoder, decoder, n=10):
    """Translate `n` random training pairs and print source (>), reference (=)
    and model output (<) for a quick qualitative check."""
    for _ in range(n):
        source, reference = random.choice(pairs)
        print('>', source)
        print('=', reference)
        output_words, attentions = evaluate(encoder, decoder, source)
        print('<', ' '.join(output_words))
        print('')
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
# Hyperparameters; note EncoderRNN as written ignores embedding_size and
# hard-codes 200 internally, so these happen to agree.
embedding_size = 200
hidden_size = 256
encoder1 = EncoderRNN(eng_lang.n_words, embedding_size, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(embedding_size, hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"iter: 50, loss: 5.042555550272503\n",
|
||||
"iter: 100, loss: 4.143612308138894\n",
|
||||
"iter: 150, loss: 4.258466395877656\n",
|
||||
"iter: 200, loss: 4.078979822052849\n",
|
||||
"iter: 250, loss: 3.9038650802657715\n",
|
||||
"iter: 300, loss: 4.07207449336279\n",
|
||||
"iter: 350, loss: 3.940484183538527\n",
|
||||
"iter: 400, loss: 4.425489738524906\n",
|
||||
"iter: 450, loss: 3.9398847290826224\n",
|
||||
"iter: 500, loss: 4.264409653027852\n",
|
||||
"iter: 550, loss: 4.323172234974209\n",
|
||||
"iter: 600, loss: 4.22224827657427\n",
|
||||
"iter: 650, loss: 4.204052018634857\n",
|
||||
"iter: 700, loss: 3.9438682432023295\n",
|
||||
"iter: 750, loss: 4.001692515509468\n",
|
||||
"iter: 800, loss: 4.054982795352028\n",
|
||||
"iter: 850, loss: 4.119050166281443\n",
|
||||
"iter: 900, loss: 3.908679961704073\n",
|
||||
"iter: 950, loss: 4.136870030266898\n",
|
||||
"iter: 1000, loss: 3.8147727276938297\n",
|
||||
"iter: 1050, loss: 4.026022962623171\n",
|
||||
"iter: 1100, loss: 3.9598817706335154\n",
|
||||
"iter: 1150, loss: 3.848097898089696\n",
|
||||
"iter: 1200, loss: 4.01016833985041\n",
|
||||
"iter: 1250, loss: 3.7720014858472917\n",
|
||||
"iter: 1300, loss: 4.059876484976874\n",
|
||||
"iter: 1350, loss: 3.8380891363658605\n",
|
||||
"iter: 1400, loss: 4.013203263676356\n",
|
||||
"iter: 1450, loss: 4.067137318686833\n",
|
||||
"iter: 1500, loss: 4.020450985673874\n",
|
||||
"iter: 1550, loss: 3.7160321428662244\n",
|
||||
"iter: 1600, loss: 3.8411714478977137\n",
|
||||
"iter: 1650, loss: 3.7125136051177985\n",
|
||||
"iter: 1700, loss: 3.705152728769514\n",
|
||||
"iter: 1750, loss: 3.9118153427441915\n",
|
||||
"iter: 1800, loss: 3.857195938375262\n",
|
||||
"iter: 1850, loss: 3.9566935270703025\n",
|
||||
"iter: 1900, loss: 3.9394864430957375\n",
|
||||
"iter: 1950, loss: 3.636212232317243\n",
|
||||
"iter: 2000, loss: 3.847666795261321\n",
|
||||
"iter: 2050, loss: 3.787096965411352\n",
|
||||
"iter: 2100, loss: 3.4702608700933912\n",
|
||||
"iter: 2150, loss: 3.727882717624543\n",
|
||||
"iter: 2200, loss: 3.6961711362884153\n",
|
||||
"iter: 2250, loss: 3.870331466848889\n",
|
||||
"iter: 2300, loss: 3.8506508341743837\n",
|
||||
"iter: 2350, loss: 3.803002176814609\n",
|
||||
"iter: 2400, loss: 3.5700957290558586\n",
|
||||
"iter: 2450, loss: 3.5328896935326712\n",
|
||||
"iter: 2500, loss: 3.810194352997674\n",
|
||||
"iter: 2550, loss: 3.713556599700262\n",
|
||||
"iter: 2600, loss: 3.6131167711303345\n",
|
||||
"iter: 2650, loss: 3.433012700254954\n",
|
||||
"iter: 2700, loss: 3.7313271602903084\n",
|
||||
"iter: 2750, loss: 3.5837062497366037\n",
|
||||
"iter: 2800, loss: 3.6265894929265214\n",
|
||||
"iter: 2850, loss: 3.5165250884616186\n",
|
||||
"iter: 2900, loss: 3.8752988719410366\n",
|
||||
"iter: 2950, loss: 3.709828086020455\n",
|
||||
"iter: 3000, loss: 3.742527751090035\n",
|
||||
"iter: 3050, loss: 3.5926183513232646\n",
|
||||
"iter: 3100, loss: 3.6629667194003157\n",
|
||||
"iter: 3150, loss: 3.7953110780715944\n",
|
||||
"iter: 3200, loss: 3.4833724756770663\n",
|
||||
"iter: 3250, loss: 3.5239689500066977\n",
|
||||
"iter: 3300, loss: 3.552185758560423\n",
|
||||
"iter: 3350, loss: 3.342997217700594\n",
|
||||
"iter: 3400, loss: 3.7131163925897512\n",
|
||||
"iter: 3450, loss: 3.2172264359110874\n",
|
||||
"iter: 3500, loss: 3.1694674255961464\n",
|
||||
"iter: 3550, loss: 3.5181667824548386\n",
|
||||
"iter: 3600, loss: 3.552696303821745\n",
|
||||
"iter: 3650, loss: 3.5465369727573703\n",
|
||||
"iter: 3700, loss: 3.3895190108844213\n",
|
||||
"iter: 3750, loss: 3.55357305569119\n",
|
||||
"iter: 3800, loss: 3.618841464133489\n",
|
||||
"iter: 3850, loss: 3.631707963504488\n",
|
||||
"iter: 3900, loss: 3.705602922939119\n",
|
||||
"iter: 3950, loss: 3.1555525365556987\n",
|
||||
"iter: 4000, loss: 3.423284879676879\n",
|
||||
"iter: 4050, loss: 3.74216214027859\n",
|
||||
"iter: 4100, loss: 3.273874522224304\n",
|
||||
"iter: 4150, loss: 3.9754231488666836\n",
|
||||
"iter: 4200, loss: 3.255707532473973\n",
|
||||
"iter: 4250, loss: 3.622867019956075\n",
|
||||
"iter: 4300, loss: 3.3847267730198216\n",
|
||||
"iter: 4350, loss: 3.6832511274095565\n",
|
||||
"iter: 4400, loss: 3.265418997968946\n",
|
||||
"iter: 4450, loss: 3.53306358509972\n",
|
||||
"iter: 4500, loss: 3.2655868359520333\n",
|
||||
"iter: 4550, loss: 3.579948601419965\n",
|
||||
"iter: 4600, loss: 3.554656519799005\n",
|
||||
"iter: 4650, loss: 3.324159849643708\n",
|
||||
"iter: 4700, loss: 3.357913894865249\n",
|
||||
"iter: 4750, loss: 3.048288846031067\n",
|
||||
"iter: 4800, loss: 3.185154194937811\n",
|
||||
"iter: 4850, loss: 2.9646709245159513\n",
|
||||
"iter: 4900, loss: 3.4766449508288546\n",
|
||||
"iter: 4950, loss: 3.1528075372302338\n",
|
||||
"iter: 5000, loss: 3.12558690051427\n",
|
||||
"iter: 5050, loss: 3.6565875165273276\n",
|
||||
"iter: 5100, loss: 3.113538140228817\n",
|
||||
"iter: 5150, loss: 3.0463946421638366\n",
|
||||
"iter: 5200, loss: 3.384180574084086\n",
|
||||
"iter: 5250, loss: 3.3104316232090913\n",
|
||||
"iter: 5300, loss: 2.9496352179807332\n",
|
||||
"iter: 5350, loss: 3.1814023027722804\n",
|
||||
"iter: 5400, loss: 2.9286732437345724\n",
|
||||
"iter: 5450, loss: 3.4691178646617464\n",
|
||||
"iter: 5500, loss: 3.373944672122834\n",
|
||||
"iter: 5550, loss: 3.213332776455653\n",
|
||||
"iter: 5600, loss: 3.3247368506931116\n",
|
||||
"iter: 5650, loss: 3.2702379176957272\n",
|
||||
"iter: 5700, loss: 3.4554740653038025\n",
|
||||
"iter: 5750, loss: 3.281306777431851\n",
|
||||
"iter: 5800, loss: 2.9936736260368706\n",
|
||||
"iter: 5850, loss: 3.277740831851959\n",
|
||||
"iter: 5900, loss: 3.120459364088754\n",
|
||||
"iter: 5950, loss: 3.387252744160001\n",
|
||||
"iter: 6000, loss: 3.238504883735898\n",
|
||||
"iter: 6050, loss: 2.738152531003195\n",
|
||||
"iter: 6100, loss: 3.231002421265556\n",
|
||||
"iter: 6150, loss: 3.0410601262819195\n",
|
||||
"iter: 6200, loss: 3.093445486522856\n",
|
||||
"iter: 6250, loss: 2.877119398207891\n",
|
||||
"iter: 6300, loss: 3.006740029849703\n",
|
||||
"iter: 6350, loss: 2.8918780979504657\n",
|
||||
"iter: 6400, loss: 3.3124666434015553\n",
|
||||
"iter: 6450, loss: 3.170363757602752\n",
|
||||
"iter: 6500, loss: 3.1445780278387527\n",
|
||||
"iter: 6550, loss: 3.0042706321610346\n",
|
||||
"iter: 6600, loss: 2.94450242013023\n",
|
||||
"iter: 6650, loss: 3.1747314814840046\n",
|
||||
"iter: 6700, loss: 3.325715871651966\n",
|
||||
"iter: 6750, loss: 3.1039765825120225\n",
|
||||
"iter: 6800, loss: 3.260562201068516\n",
|
||||
"iter: 6850, loss: 2.95558365320024\n",
|
||||
"iter: 6900, loss: 3.1284036347071327\n",
|
||||
"iter: 6950, loss: 3.161784927746607\n",
|
||||
"iter: 7000, loss: 3.083566860369275\n",
|
||||
"iter: 7050, loss: 3.1606678485643296\n",
|
||||
"iter: 7100, loss: 3.39304134529356\n",
|
||||
"iter: 7150, loss: 3.05389289476001\n",
|
||||
"iter: 7200, loss: 3.171286074725408\n",
|
||||
"iter: 7250, loss: 3.307133579034654\n",
|
||||
"iter: 7300, loss: 2.987511603022379\n",
|
||||
"iter: 7350, loss: 3.1221464098370264\n",
|
||||
"iter: 7400, loss: 2.9686622249966574\n",
|
||||
"iter: 7450, loss: 2.874706161885035\n",
|
||||
"iter: 7500, loss: 2.759323406164608\n",
|
||||
"iter: 7550, loss: 2.835318256658221\n",
|
||||
"iter: 7600, loss: 2.896953154404958\n",
|
||||
"iter: 7650, loss: 2.8871691599497717\n",
|
||||
"iter: 7700, loss: 3.049550093332927\n",
|
||||
"iter: 7750, loss: 2.9703013692507665\n",
|
||||
"iter: 7800, loss: 2.8142153175671893\n",
|
||||
"iter: 7850, loss: 2.8352768955987604\n",
|
||||
"iter: 7900, loss: 2.863677294496506\n",
|
||||
"iter: 7950, loss: 3.031682641491057\n",
|
||||
"iter: 8000, loss: 2.9286883136809814\n",
|
||||
"iter: 8050, loss: 2.9240697879488504\n",
|
||||
"iter: 8100, loss: 3.0172221147900546\n",
|
||||
"iter: 8150, loss: 2.8361169849426027\n",
|
||||
"iter: 8200, loss: 2.9860127468676803\n",
|
||||
"iter: 8250, loss: 2.9495567634294906\n",
|
||||
"iter: 8300, loss: 2.793946119104113\n",
|
||||
"iter: 8350, loss: 3.2106793221594785\n",
|
||||
"iter: 8400, loss: 2.736634517018757\n",
|
||||
"iter: 8450, loss: 2.8962079345536615\n",
|
||||
"iter: 8500, loss: 2.906407202516283\n",
|
||||
"iter: 8550, loss: 2.6900012663281148\n",
|
||||
"iter: 8600, loss: 2.8905927643056897\n",
|
||||
"iter: 8650, loss: 2.950769727600945\n",
|
||||
"iter: 8700, loss: 2.884238138978443\n",
|
||||
"iter: 8750, loss: 2.7154052526648083\n",
|
||||
"iter: 8800, loss: 2.8823739119030183\n",
|
||||
"iter: 8850, loss: 2.93061117755799\n",
|
||||
"iter: 8900, loss: 2.658344201617771\n",
|
||||
"iter: 8950, loss: 2.5747124820644887\n",
|
||||
"iter: 9000, loss: 2.8281182004307954\n",
|
||||
"iter: 9050, loss: 2.6702445936959895\n",
|
||||
"iter: 9100, loss: 2.8030708763485865\n",
|
||||
"iter: 9150, loss: 3.0742075329053966\n",
|
||||
"iter: 9200, loss: 2.7834522392787635\n",
|
||||
"iter: 9250, loss: 2.9308865650949025\n",
|
||||
"iter: 9300, loss: 2.776913931453039\n",
|
||||
"iter: 9350, loss: 2.7998796779011923\n",
|
||||
"iter: 9400, loss: 3.1615792548088795\n",
|
||||
"iter: 9450, loss: 3.2742855516539673\n",
|
||||
"iter: 9500, loss: 2.981044085154457\n",
|
||||
"iter: 9550, loss: 2.4407524968101866\n",
|
||||
"iter: 9600, loss: 2.624275121037923\n",
|
||||
"iter: 9650, loss: 2.4893303714971697\n",
|
||||
"iter: 9700, loss: 2.7211539438906183\n",
|
||||
"iter: 9750, loss: 2.8714180671828133\n",
|
||||
"iter: 9800, loss: 2.7188037380396373\n",
|
||||
"iter: 9850, loss: 2.4101966271173385\n",
|
||||
"iter: 9900, loss: 2.9492219283542926\n",
|
||||
"iter: 9950, loss: 2.547067801430112\n",
|
||||
"iter: 10000, loss: 2.8521263429191372\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"trainIters(encoder1, attn_decoder1, 10_000, print_every=50)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"> he is a tennis player .\n",
|
||||
"= on jest tenisistą .\n",
|
||||
"< jest tenisistą . <EOS>\n",
|
||||
"\n",
|
||||
"> i m not going to change my mind .\n",
|
||||
"= nie zamierzam zmieniać zdania .\n",
|
||||
"< nie idę do . <EOS>\n",
|
||||
"\n",
|
||||
"> i m totally confused .\n",
|
||||
"= jestem kompletnie zmieszany .\n",
|
||||
"< jestem dziś . . <EOS>\n",
|
||||
"\n",
|
||||
"> he is a pioneer in this field .\n",
|
||||
"= jest pionierem w tej dziedzinie .\n",
|
||||
"< on jest w w . . <EOS>\n",
|
||||
"\n",
|
||||
"> i m so excited .\n",
|
||||
"= jestem taki podekscytowany !\n",
|
||||
"< jestem jestem głodny . <EOS>\n",
|
||||
"\n",
|
||||
"> they are a party of six .\n",
|
||||
"= jest ich sześć osób .\n",
|
||||
"< oni nie są . . <EOS>\n",
|
||||
"\n",
|
||||
"> he is the father of two children .\n",
|
||||
"= on jest ojcem dwójki dzieci .\n",
|
||||
"< on jest na do . . <EOS>\n",
|
||||
"\n",
|
||||
"> i am leaving at four .\n",
|
||||
"= wychodzę o czwartej .\n",
|
||||
"< jestem na . <EOS>\n",
|
||||
"\n",
|
||||
"> i m not much of a writer .\n",
|
||||
"= pisarz ze mnie żaden .\n",
|
||||
"< nie jestem mnie . . <EOS>\n",
|
||||
"\n",
|
||||
"> you re disgusting !\n",
|
||||
"= jesteś obrzydliwy !\n",
|
||||
"< jesteś obrzydliwy . <EOS>\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"evaluateRandomly(encoder1, attn_decoder1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"author": "Jakub Pokrywka",
|
||||
"email": "kubapok@wmi.amu.edu.pl",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"lang": "pl",
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
},
|
||||
"subtitle": "10. Model rekurencyjny z atencją [ćwiczenia]",
|
||||
"title": "Modelowanie Języka",
|
||||
"year": "2021"
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
46211
cw/data/eng-pol.txt
Normal file
46211
cw/data/eng-pol.txt
Normal file
File diff suppressed because it is too large
Load Diff
1
cw/data/eng-pol.txt.README
Normal file
1
cw/data/eng-pol.txt.README
Normal file
@ -0,0 +1 @@
|
||||
Plik eng-pol.txt pochodzi z https://www.manythings.org/anki/
|
Loading…
Reference in New Issue
Block a user