1156 lines
121 KiB
Plaintext
1156 lines
121 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### Importy"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 84,
|
||
|
"metadata": {
|
||
|
"collapsed": true,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:58:41.249607Z",
|
||
|
"end_time": "2024-06-02T19:58:41.261609Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from __future__ import unicode_literals, print_function, division\n",
|
||
|
"from io import open\n",
|
||
|
"import unicodedata\n",
|
||
|
"import re\n",
|
||
|
"import os\n",
|
||
|
"import random\n",
|
||
|
"import torch\n",
|
||
|
"import pandas as pd\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"from torch import optim\n",
|
||
|
"import torch.nn.functional as F\n",
|
||
|
"from torchtext.data.metrics import bleu_score\n",
|
||
|
"\n",
|
||
|
"from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n",
|
||
|
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Is CUDA supported by this system? True\n",
|
||
|
"CUDA version: 12.1\n",
|
||
|
"ID of current CUDA device: 0\n",
|
||
|
"Name of current CUDA device: NVIDIA GeForce GTX 1660 Ti\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(f'Is CUDA supported by this system? {torch.cuda.is_available()}')\n",
|
||
|
"print(f\"CUDA version: {torch.version.cuda}\")\n",
|
||
|
"\n",
|
||
|
"# torch.cuda.current_device() and get_device_name() raise when no CUDA device\n",
|
||
|
"# is available, so only query device details when a GPU is actually present.\n",
|
||
|
"if torch.cuda.is_available():\n",
|
||
|
"    cuda_id = torch.cuda.current_device()\n",
|
||
|
"    print(f'ID of current CUDA device: {cuda_id}')\n",
|
||
|
"    print(f'Name of current CUDA device: {torch.cuda.get_device_name(cuda_id)}')"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:20:55.709021Z",
|
||
|
"end_time": "2024-06-02T19:20:55.725023Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cuda\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
|
||
|
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
||
|
"print(device)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:20:55.996605Z",
|
||
|
"end_time": "2024-06-02T19:20:56.041138Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### Konwersja słów na tensory"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"SOS_token = 0\n",
|
||
|
"EOS_token = 1\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class Lang:\n",
|
||
|
"    \"\"\"Vocabulary of one language: word <-> index mappings plus word counts.\n",
|
||
|
"\n",
|
||
|
"    Indices SOS_token (0) and EOS_token (1) are reserved, so the first real\n",
|
||
|
"    word added receives index 2.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"\n",
|
||
|
"    def __init__(self, name):\n",
|
||
|
"        self.name = name\n",
|
||
|
"        self.word2index = {}\n",
|
||
|
"        self.word2count = {}\n",
|
||
|
"        self.index2word = {SOS_token: \"SOS\", EOS_token: \"EOS\"}\n",
|
||
|
"        self.n_words = len(self.index2word)  # the two reserved markers\n",
|
||
|
"\n",
|
||
|
"    def addSentence(self, sentence):\n",
|
||
|
"        \"\"\"Register every space-separated token of ``sentence``.\"\"\"\n",
|
||
|
"        for token in sentence.split(' '):\n",
|
||
|
"            self.addWord(token)\n",
|
||
|
"\n",
|
||
|
"    def addWord(self, word):\n",
|
||
|
"        \"\"\"Give an unseen ``word`` the next free index, or bump its count.\"\"\"\n",
|
||
|
"        if word in self.word2index:\n",
|
||
|
"            self.word2count[word] += 1\n",
|
||
|
"            return\n",
|
||
|
"        new_index = self.n_words\n",
|
||
|
"        self.word2index[word] = new_index\n",
|
||
|
"        self.word2count[word] = 1\n",
|
||
|
"        self.index2word[new_index] = word\n",
|
||
|
"        self.n_words = new_index + 1"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:20:59.879666Z",
|
||
|
"end_time": "2024-06-02T19:20:59.893667Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### Przygotowanie danych"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Turn a Unicode string to plain ASCII, thanks to\n",
|
||
|
"# https://stackoverflow.com/a/518232/2809427\n",
|
||
|
"# NFD decomposition cannot handle letters such as the Polish \"ł\", which have\n",
|
||
|
"# no decomposed form, so they are mapped explicitly first; otherwise\n",
|
||
|
"# normalizeString would turn e.g. \"głupi\" into \"g upi\".\n",
|
||
|
"_NON_DECOMPOSABLE = str.maketrans({'ł': 'l', 'Ł': 'L'})\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def unicodeToAscii(s):\n",
|
||
|
"    \"\"\"Strip diacritics from ``s`` (e.g. 'żółw' -> 'zolw').\n",
|
||
|
"\n",
|
||
|
"    Letters listed in _NON_DECOMPOSABLE are transliterated, then the string is\n",
|
||
|
"    NFD-normalized and combining marks (category 'Mn') are dropped. Other\n",
|
||
|
"    non-ASCII characters are left in place.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    s = s.translate(_NON_DECOMPOSABLE)\n",
|
||
|
"    return ''.join(\n",
|
||
|
"        c for c in unicodedata.normalize('NFD', s)\n",
|
||
|
"        if unicodedata.category(c) != 'Mn'\n",
|
||
|
"    )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# Lowercase, trim, and remove non-letter characters\n",
|
||
|
"def normalizeString(s):\n",
|
||
|
"    \"\"\"Lowercase ``s``, strip diacritics and keep only ASCII letters, '!' and '?'.\n",
|
||
|
"\n",
|
||
|
"    A space is inserted before '.', '!' and '?'; afterwards every run of\n",
|
||
|
"    characters other than ASCII letters, '!' and '?' (so also '.') collapses\n",
|
||
|
"    into a single space, and the result is trimmed.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    s = unicodeToAscii(s.lower().strip())\n",
|
||
|
"    s = re.sub(r\"([.!?])\", r\" \\1\", s)\n",
|
||
|
"    s = re.sub(r\"[^a-zA-Z!?]+\", r\" \", s)\n",
|
||
|
"    return s.strip()"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:21:00.877093Z",
|
||
|
"end_time": "2024-06-02T19:21:00.892090Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### Wczytanie danych"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def readLangs(lang1, lang2, reverse=False):\n",
|
||
|
"    \"\"\"Read sentence pairs from data/<lang1>-<lang2>.txt.\n",
|
||
|
"\n",
|
||
|
"    Each line is expected to be tab-separated with the lang1 sentence first and\n",
|
||
|
"    the lang2 sentence second; any further columns are ignored. Both sentences\n",
|
||
|
"    are passed through normalizeString.\n",
|
||
|
"\n",
|
||
|
"    Returns (input_lang, output_lang, pairs) with still-empty Lang objects.\n",
|
||
|
"    With reverse=True every pair is swapped and input_lang is lang2.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    print(\"Reading lines...\")\n",
|
||
|
"    # Read the file and split into lines; `with` closes the file handle.\n",
|
||
|
"    with open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as f:\n",
|
||
|
"        lines = f.read().strip().split('\\n')\n",
|
||
|
"\n",
|
||
|
"    # Split every line into pairs and normalize. Keep the first two columns so\n",
|
||
|
"    # an extra trailing column is dropped but a two-column file still works.\n",
|
||
|
"    pairs = [[normalizeString(s) for s in l.split('\\t')[:2]] for l in lines]\n",
|
||
|
"\n",
|
||
|
"    # Reverse pairs, make Lang instances\n",
|
||
|
"    if reverse:\n",
|
||
|
"        pairs = [list(reversed(p)) for p in pairs]\n",
|
||
|
"        input_lang = Lang(lang2)\n",
|
||
|
"        output_lang = Lang(lang1)\n",
|
||
|
"    else:\n",
|
||
|
"        input_lang = Lang(lang1)\n",
|
||
|
"        output_lang = Lang(lang2)\n",
|
||
|
"\n",
|
||
|
"    return input_lang, output_lang, pairs"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:21:02.075474Z",
|
||
|
"end_time": "2024-06-02T19:21:02.087474Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### Filtracja danych"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
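"Ograniczenie danych do par, w których oba zdania mają mniej niż 10 słów (`MAX_LENGTH`), a zdanie angielskie zaczyna się od jednego z prefiksów zdefiniowanych w `eng_prefixes`."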
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Only pairs with fewer than MAX_LENGTH words on both sides are kept.\n",
|
||
|
"MAX_LENGTH = 10\n",
|
||
|
"\n",
|
||
|
"# Required starts of the English sentence, written as they look after\n",
|
||
|
"# normalizeString (apostrophes became spaces: \"i'm\" -> \"i m\").\n",
|
||
|
"eng_prefixes = (\n",
|
||
|
"    \"i am \", \"i m \",\n",
|
||
|
"    \"he is\", \"he s \",\n",
|
||
|
"    \"she is\", \"she s \",\n",
|
||
|
"    \"you are\", \"you re \",\n",
|
||
|
"    \"we are\", \"we re \",\n",
|
||
|
"    \"they are\", \"they re \"\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def filterPair(p):\n",
|
||
|
"    \"\"\"True if both sentences of ``p`` have fewer than MAX_LENGTH words and\n",
|
||
|
"    the second one starts with one of ``eng_prefixes``.\"\"\"\n",
|
||
|
"    if len(p[0].split(' ')) >= MAX_LENGTH:\n",
|
||
|
"        return False\n",
|
||
|
"    if len(p[1].split(' ')) >= MAX_LENGTH:\n",
|
||
|
"        return False\n",
|
||
|
"    return p[1].startswith(eng_prefixes)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def filterPairs(pairs):\n",
|
||
|
"    \"\"\"Keep only the pairs accepted by filterPair, preserving their order.\"\"\"\n",
|
||
|
"    return list(filter(filterPair, pairs))"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:21:03.811303Z",
|
||
|
"end_time": "2024-06-02T19:21:03.829054Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Reading lines...\n",
|
||
|
"Read 49943 sentence pairs\n",
|
||
|
"Trimmed to 3613 sentence pairs\n",
|
||
|
"Counting words...\n",
|
||
|
"Counted words:\n",
|
||
|
"pol 3070\n",
|
||
|
"eng 1969\n",
|
||
|
"['nie umieram', 'i m not dying']\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"def prepareData(lang1, lang2, reverse=False):\n",
|
||
|
"    \"\"\"Read, filter and index the sentence pairs for lang1/lang2.\n",
|
||
|
"\n",
|
||
|
"    Returns (input_lang, output_lang, pairs); both vocabularies are built\n",
|
||
|
"    from the filtered pairs only.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n",
|
||
|
"    print(\"Read %s sentence pairs\" % len(pairs))\n",
|
||
|
"\n",
|
||
|
"    pairs = filterPairs(pairs)\n",
|
||
|
"    print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
|
||
|
"\n",
|
||
|
"    print(\"Counting words...\")\n",
|
||
|
"    for pair in pairs:\n",
|
||
|
"        input_lang.addSentence(pair[0])\n",
|
||
|
"        output_lang.addSentence(pair[1])\n",
|
||
|
"\n",
|
||
|
"    print(\"Counted words:\")\n",
|
||
|
"    for lang in (input_lang, output_lang):\n",
|
||
|
"        print(lang.name, lang.n_words)\n",
|
||
|
"    return input_lang, output_lang, pairs\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"input_lang, output_lang, pairs = prepareData('eng', 'pol', reverse=True)\n",
|
||
|
"print(random.choice(pairs))"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:21:04.527025Z",
|
||
|
"end_time": "2024-06-02T19:21:06.394023Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### Model"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class EncoderRNN(nn.Module):\n",
|
||
|
"    \"\"\"Embeds token ids (with dropout) and runs them through a batch-first GRU.\"\"\"\n",
|
||
|
"\n",
|
||
|
"    def __init__(self, input_size, hidden_size, dropout_p=0.1):\n",
|
||
|
"        super(EncoderRNN, self).__init__()\n",
|
||
|
"        self.hidden_size = hidden_size\n",
|
||
|
"        # Submodules are registered in this order (embedding, gru, dropout) so\n",
|
||
|
"        # parameter initialisation draws from the RNG in a fixed sequence.\n",
|
||
|
"        self.embedding = nn.Embedding(input_size, hidden_size)\n",
|
||
|
"        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
|
||
|
"        self.dropout = nn.Dropout(dropout_p)\n",
|
||
|
"\n",
|
||
|
"    def forward(self, input):\n",
|
||
|
"        \"\"\"Encode token ids; returns the GRU's (output, h_n) tuple.\n",
|
||
|
"\n",
|
||
|
"        For (batch, seq_len) ids, output is (batch, seq_len, hidden_size) and\n",
|
||
|
"        h_n is (1, batch, hidden_size).\n",
|
||
|
"        \"\"\"\n",
|
||
|
"        vectors = self.dropout(self.embedding(input))\n",
|
||
|
"        return self.gru(vectors)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:21:11.058623Z",
|
||
|
"end_time": "2024-06-02T19:21:11.074974Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class DecoderRNN(nn.Module):\n",
|
||
|
"    \"\"\"GRU decoder without attention that unrolls MAX_LENGTH decoding steps.\"\"\"\n",
|
||
|
"\n",
|
||
|
"    def __init__(self, hidden_size, output_size):\n",
|
||
|
"        super(DecoderRNN, self).__init__()\n",
|
||
|
"        self.embedding = nn.Embedding(output_size, hidden_size)\n",
|
||
|
"        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
|
||
|
"        self.out = nn.Linear(hidden_size, output_size)\n",
|
||
|
"\n",
|
||
|
"    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
|
||
|
"        \"\"\"Decode MAX_LENGTH tokens, starting from SOS_token.\n",
|
||
|
"\n",
|
||
|
"        Args:\n",
|
||
|
"            encoder_outputs: batch-first encoder output; only its batch size\n",
|
||
|
"                and device are used here.\n",
|
||
|
"            encoder_hidden: initial hidden state of the decoder GRU.\n",
|
||
|
"            target_tensor: optional (batch, >= MAX_LENGTH) token ids. When\n",
|
||
|
"                given, step i is fed target_tensor[:, i] (teacher forcing);\n",
|
||
|
"                otherwise the previous greedy prediction is fed back.\n",
|
||
|
"\n",
|
||
|
"        Returns:\n",
|
||
|
"            (log_probs, hidden, None) with log_probs of shape\n",
|
||
|
"            (batch, MAX_LENGTH, output_size). None takes the place of the\n",
|
||
|
"            attention weights returned by AttnDecoderRNN.\n",
|
||
|
"        \"\"\"\n",
|
||
|
"        batch_size = encoder_outputs.size(0)\n",
|
||
|
"        # Create the SOS input on the encoder output's device rather than the\n",
|
||
|
"        # global `device`, so the decoder works wherever model and inputs live.\n",
|
||
|
"        decoder_input = torch.full((batch_size, 1), SOS_token, dtype=torch.long,\n",
|
||
|
"                                   device=encoder_outputs.device)\n",
|
||
|
"        decoder_hidden = encoder_hidden\n",
|
||
|
"        decoder_outputs = []\n",
|
||
|
"\n",
|
||
|
"        for i in range(MAX_LENGTH):\n",
|
||
|
"            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n",
|
||
|
"            decoder_outputs.append(decoder_output)\n",
|
||
|
"\n",
|
||
|
"            if target_tensor is not None:\n",
|
||
|
"                # Teacher forcing: feed the ground-truth token as the next input.\n",
|
||
|
"                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
|
||
|
"            else:\n",
|
||
|
"                # Greedy decoding: feed back the most likely token, detached so\n",
|
||
|
"                # no gradient flows through the choice.\n",
|
||
|
"                _, topi = decoder_output.topk(1)\n",
|
||
|
"                decoder_input = topi.squeeze(-1).detach()\n",
|
||
|
"\n",
|
||
|
"        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
|
||
|
"        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
|
||
|
"        return decoder_outputs, decoder_hidden, None\n",
|
||
|
"\n",
|
||
|
"    def forward_step(self, input, hidden):\n",
|
||
|
"        \"\"\"One decoding step; returns unnormalised scores (batch, 1, output_size).\"\"\"\n",
|
||
|
"        output = self.embedding(input)\n",
|
||
|
"        output = F.relu(output)\n",
|
||
|
"        output, hidden = self.gru(output, hidden)\n",
|
||
|
"        output = self.out(output)\n",
|
||
|
"        return output, hidden"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:21:11.447213Z",
|
||
|
"end_time": "2024-06-02T19:21:11.462232Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class BahdanauAttention(nn.Module):\n",
|
||
|
"    \"\"\"Additive attention: score = Va(tanh(Wa(query) + Ua(keys))).\"\"\"\n",
|
||
|
"\n",
|
||
|
"    def __init__(self, hidden_size):\n",
|
||
|
"        super(BahdanauAttention, self).__init__()\n",
|
||
|
"        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
|
||
|
"        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
|
||
|
"        self.Va = nn.Linear(hidden_size, 1)\n",
|
||
|
"\n",
|
||
|
"    def forward(self, query, keys):\n",
|
||
|
"        \"\"\"Attend from ``query`` over ``keys``.\n",
|
||
|
"\n",
|
||
|
"        As called from AttnDecoderRNN.forward_step, query is (batch, 1, hidden)\n",
|
||
|
"        and keys is (batch, src_len, hidden). Returns (context, weights) with\n",
|
||
|
"        context (batch, 1, hidden) and weights (batch, 1, src_len), softmax-\n",
|
||
|
"        normalised over src_len.\n",
|
||
|
"        \"\"\"\n",
|
||
|
"        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))  # (batch, src_len, 1)\n",
|
||
|
"        scores = scores.squeeze(2).unsqueeze(1)  # (batch, 1, src_len)\n",
|
||
|
"\n",
|
||
|
"        # NOTE(review): no mask is applied, so padded source positions also\n",
|
||
|
"        # receive attention weight.\n",
|
||
|
"        weights = F.softmax(scores, dim=-1)\n",
|
||
|
"        context = torch.bmm(weights, keys)\n",
|
||
|
"\n",
|
||
|
"        return context, weights\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class AttnDecoderRNN(nn.Module):\n",
|
||
|
"    \"\"\"GRU decoder with Bahdanau attention that unrolls MAX_LENGTH steps.\"\"\"\n",
|
||
|
"\n",
|
||
|
"    def __init__(self, hidden_size, output_size, dropout_p=0.1):\n",
|
||
|
"        super(AttnDecoderRNN, self).__init__()\n",
|
||
|
"        self.embedding = nn.Embedding(output_size, hidden_size)\n",
|
||
|
"        self.attention = BahdanauAttention(hidden_size)\n",
|
||
|
"        # GRU input is [embedded token; attention context] -> 2 * hidden_size.\n",
|
||
|
"        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
|
||
|
"        self.out = nn.Linear(hidden_size, output_size)\n",
|
||
|
"        self.dropout = nn.Dropout(dropout_p)\n",
|
||
|
"\n",
|
||
|
"    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
|
||
|
"        \"\"\"Decode MAX_LENGTH tokens, starting from SOS_token.\n",
|
||
|
"\n",
|
||
|
"        Args:\n",
|
||
|
"            encoder_outputs: (batch, src_len, hidden) encoder output, used as\n",
|
||
|
"                attention keys/values.\n",
|
||
|
"            encoder_hidden: initial hidden state of the decoder GRU.\n",
|
||
|
"            target_tensor: optional (batch, >= MAX_LENGTH) token ids. When\n",
|
||
|
"                given, step i is fed target_tensor[:, i] (teacher forcing);\n",
|
||
|
"                otherwise the previous greedy prediction is fed back.\n",
|
||
|
"\n",
|
||
|
"        Returns:\n",
|
||
|
"            (log_probs, hidden, attentions): log_probs is\n",
|
||
|
"            (batch, MAX_LENGTH, output_size), attentions is\n",
|
||
|
"            (batch, MAX_LENGTH, src_len).\n",
|
||
|
"        \"\"\"\n",
|
||
|
"        batch_size = encoder_outputs.size(0)\n",
|
||
|
"        # Create the SOS input on the encoder output's device rather than the\n",
|
||
|
"        # global `device`, so the decoder works wherever model and inputs live.\n",
|
||
|
"        decoder_input = torch.full((batch_size, 1), SOS_token, dtype=torch.long,\n",
|
||
|
"                                   device=encoder_outputs.device)\n",
|
||
|
"        decoder_hidden = encoder_hidden\n",
|
||
|
"        decoder_outputs = []\n",
|
||
|
"        attentions = []\n",
|
||
|
"\n",
|
||
|
"        for i in range(MAX_LENGTH):\n",
|
||
|
"            decoder_output, decoder_hidden, attn_weights = self.forward_step(\n",
|
||
|
"                decoder_input, decoder_hidden, encoder_outputs\n",
|
||
|
"            )\n",
|
||
|
"            decoder_outputs.append(decoder_output)\n",
|
||
|
"            attentions.append(attn_weights)\n",
|
||
|
"\n",
|
||
|
"            if target_tensor is not None:\n",
|
||
|
"                # Teacher forcing: feed the ground-truth token as the next input.\n",
|
||
|
"                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
|
||
|
"            else:\n",
|
||
|
"                # Greedy decoding: feed back the most likely token, detached so\n",
|
||
|
"                # no gradient flows through the choice.\n",
|
||
|
"                _, topi = decoder_output.topk(1)\n",
|
||
|
"                decoder_input = topi.squeeze(-1).detach()\n",
|
||
|
"\n",
|
||
|
"        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
|
||
|
"        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
|
||
|
"        attentions = torch.cat(attentions, dim=1)\n",
|
||
|
"\n",
|
||
|
"        return decoder_outputs, decoder_hidden, attentions\n",
|
||
|
"\n",
|
||
|
"    def forward_step(self, input, hidden, encoder_outputs):\n",
|
||
|
"        \"\"\"One attention-decoding step.\n",
|
||
|
"\n",
|
||
|
"        Returns (scores, hidden, attn_weights) where scores are unnormalised\n",
|
||
|
"        (batch, 1, output_size).\n",
|
||
|
"        \"\"\"\n",
|
||
|
"        embedded = self.dropout(self.embedding(input))\n",
|
||
|
"\n",
|
||
|
"        # GRU hidden is (num_layers=1, batch, hidden); attention wants (batch, 1, hidden).\n",
|
||
|
"        query = hidden.permute(1, 0, 2)\n",
|
||
|
"        context, attn_weights = self.attention(query, encoder_outputs)\n",
|
||
|
"        input_gru = torch.cat((embedded, context), dim=2)\n",
|
||
|
"\n",
|
||
|
"        output, hidden = self.gru(input_gru, hidden)\n",
|
||
|
"        output = self.out(output)\n",
|
||
|
"\n",
|
||
|
"        return output, hidden, attn_weights"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:21:12.049305Z",
|
||
|
"end_time": "2024-06-02T19:21:12.073302Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def indexesFromSentence(lang, sentence):\n",
|
||
|
"    \"\"\"Map each space-separated word of ``sentence`` to its index in ``lang``.\n",
|
||
|
"\n",
|
||
|
"    Raises KeyError for words that are not in the vocabulary.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    return [lang.word2index[word] for word in sentence.split(' ')]\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def tensorFromSentence(lang, sentence):\n",
|
||
|
"    \"\"\"Word indices plus EOS_token as a (1, n) long tensor on ``device``.\"\"\"\n",
|
||
|
"    indexes = indexesFromSentence(lang, sentence)\n",
|
||
|
"    indexes.append(EOS_token)\n",
|
||
|
"    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def tensorsFromPair(pair):\n",
|
||
|
"    \"\"\"Convert a (source, target) pair using the global input_lang/output_lang.\"\"\"\n",
|
||
|
"    input_tensor = tensorFromSentence(input_lang, pair[0])\n",
|
||
|
"    target_tensor = tensorFromSentence(output_lang, pair[1])\n",
|
||
|
"    return (input_tensor, target_tensor)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def get_dataloader(batch_size):\n",
|
||
|
"    \"\"\"Build the vocabularies and a shuffled DataLoader of padded id matrices.\n",
|
||
|
"\n",
|
||
|
"    Every sentence becomes MAX_LENGTH ids: its word ids, EOS_token, then zeros.\n",
|
||
|
"    Note that the zero padding coincides with SOS_token's index (0).\n",
|
||
|
"\n",
|
||
|
"    Returns (input_lang, output_lang, train_dataloader).\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    # numpy is otherwise only imported in a later plotting cell; importing it\n",
|
||
|
"    # here removes the dependency on cell execution order.\n",
|
||
|
"    import numpy as np\n",
|
||
|
"\n",
|
||
|
"    input_lang, output_lang, pairs = prepareData('eng', 'pol', True)\n",
|
||
|
"\n",
|
||
|
"    n = len(pairs)\n",
|
||
|
"    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n",
|
||
|
"    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n",
|
||
|
"\n",
|
||
|
"    for idx, (inp, tgt) in enumerate(pairs):\n",
|
||
|
"        inp_ids = indexesFromSentence(input_lang, inp)\n",
|
||
|
"        tgt_ids = indexesFromSentence(output_lang, tgt)\n",
|
||
|
"        inp_ids.append(EOS_token)\n",
|
||
|
"        tgt_ids.append(EOS_token)\n",
|
||
|
"        input_ids[idx, :len(inp_ids)] = inp_ids\n",
|
||
|
"        target_ids[idx, :len(tgt_ids)] = tgt_ids\n",
|
||
|
"\n",
|
||
|
"    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),\n",
|
||
|
"                               torch.LongTensor(target_ids).to(device))\n",
|
||
|
"\n",
|
||
|
"    train_sampler = RandomSampler(train_data)\n",
|
||
|
"    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
|
||
|
"    return input_lang, output_lang, train_dataloader"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:23:18.301396Z",
|
||
|
"end_time": "2024-06-02T19:23:18.321420Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### Trening"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def train_epoch(dataloader, encoder, decoder, encoder_optimizer,\n",
|
||
|
"                decoder_optimizer, criterion):\n",
|
||
|
"    \"\"\"Run one pass over ``dataloader``, updating both models after each batch.\n",
|
||
|
"\n",
|
||
|
"    Returns the sum of per-batch losses divided by ``len(dataloader)``.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    epoch_loss = 0\n",
|
||
|
"    for input_tensor, target_tensor in dataloader:\n",
|
||
|
"        encoder_optimizer.zero_grad()\n",
|
||
|
"        decoder_optimizer.zero_grad()\n",
|
||
|
"\n",
|
||
|
"        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
|
||
|
"        # The targets are passed to the decoder (teacher forcing in this notebook).\n",
|
||
|
"        log_probs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)\n",
|
||
|
"\n",
|
||
|
"        # Flatten to (positions, vocab) / (positions,) so the criterion scores\n",
|
||
|
"        # every decoded position independently.\n",
|
||
|
"        vocab_size = log_probs.size(-1)\n",
|
||
|
"        batch_loss = criterion(log_probs.view(-1, vocab_size), target_tensor.view(-1))\n",
|
||
|
"        batch_loss.backward()\n",
|
||
|
"\n",
|
||
|
"        encoder_optimizer.step()\n",
|
||
|
"        decoder_optimizer.step()\n",
|
||
|
"\n",
|
||
|
"        epoch_loss += batch_loss.item()\n",
|
||
|
"\n",
|
||
|
"    return epoch_loss / len(dataloader)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:23:19.166843Z",
|
||
|
"end_time": "2024-06-02T19:23:19.182827Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import time\n",
|
||
|
"import math\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def asMinutes(s):\n",
|
||
|
"    \"\"\"Format a duration in seconds as '<m>m <s>s' (both parts truncated).\"\"\"\n",
|
||
|
"    minutes = math.floor(s / 60)\n",
|
||
|
"    return '%dm %ds' % (minutes, s - minutes * 60)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def timeSince(since, percent):\n",
|
||
|
"    \"\"\"Elapsed time since ``since`` and a linear estimate of the time left.\n",
|
||
|
"\n",
|
||
|
"    ``percent`` is the completed fraction of the work (must be non-zero).\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    elapsed = time.time() - since\n",
|
||
|
"    remaining = elapsed / percent - elapsed\n",
|
||
|
"    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:23:19.675207Z",
|
||
|
"end_time": "2024-06-02T19:23:19.699207Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"# NOTE(review): 'agg' is a non-interactive backend; confirm this switch is\n",
|
||
|
"# still wanted when running inside Jupyter.\n",
|
||
|
"plt.switch_backend('agg')\n",
|
||
|
"import matplotlib.ticker as ticker\n",
|
||
|
"import numpy as np\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def showPlot(points):\n",
|
||
|
"    \"\"\"Plot ``points`` (e.g. averaged training losses) as a line on a new figure.\n",
|
||
|
"\n",
|
||
|
"    Y-axis ticks are placed every 0.2.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    # plt.subplots() creates the figure itself; an extra plt.figure() call\n",
|
||
|
"    # before it would leave an empty \"0 Axes\" figure in the cell output.\n",
|
||
|
"    fig, ax = plt.subplots()\n",
|
||
|
"    # this locator puts ticks at regular intervals\n",
|
||
|
"    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.2))\n",
|
||
|
"    ax.plot(points)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:23:20.120325Z",
|
||
|
"end_time": "2024-06-02T19:23:20.833674Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,\n",
|
||
|
"          print_every=100, plot_every=100):\n",
|
||
|
"    \"\"\"Train encoder/decoder with Adam and NLLLoss for ``n_epochs`` epochs.\n",
|
||
|
"\n",
|
||
|
"    Every ``print_every`` epochs prints elapsed/remaining time, epoch, percent\n",
|
||
|
"    done and the mean epoch loss since the last print; every ``plot_every``\n",
|
||
|
"    epochs records the mean loss, and finally plots the recorded values.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    start = time.time()\n",
|
||
|
"    plot_losses = []\n",
|
||
|
"    print_loss_total = 0  # Reset every print_every\n",
|
||
|
"    plot_loss_total = 0  # Reset every plot_every\n",
|
||
|
"\n",
|
||
|
"    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n",
|
||
|
"    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n",
|
||
|
"    criterion = nn.NLLLoss()\n",
|
||
|
"\n",
|
||
|
"    for epoch in range(1, n_epochs + 1):\n",
|
||
|
"        epoch_loss = train_epoch(train_dataloader, encoder, decoder,\n",
|
||
|
"                                 encoder_optimizer, decoder_optimizer, criterion)\n",
|
||
|
"        print_loss_total += epoch_loss\n",
|
||
|
"        plot_loss_total += epoch_loss\n",
|
||
|
"\n",
|
||
|
"        if epoch % print_every == 0:\n",
|
||
|
"            progress = epoch / n_epochs\n",
|
||
|
"            print('%s (%d %d%%) %.4f' % (timeSince(start, progress),\n",
|
||
|
"                                         epoch, progress * 100,\n",
|
||
|
"                                         print_loss_total / print_every))\n",
|
||
|
"            print_loss_total = 0\n",
|
||
|
"\n",
|
||
|
"        if epoch % plot_every == 0:\n",
|
||
|
"            plot_losses.append(plot_loss_total / plot_every)\n",
|
||
|
"            plot_loss_total = 0\n",
|
||
|
"\n",
|
||
|
"    showPlot(plot_losses)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:23:21.920756Z",
|
||
|
"end_time": "2024-06-02T19:23:21.949755Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 98,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Reading lines...\n",
|
||
|
"Read 49943 sentence pairs\n",
|
||
|
"Trimmed to 3613 sentence pairs\n",
|
||
|
"Counting words...\n",
|
||
|
"Counted words:\n",
|
||
|
"pol 3070\n",
|
||
|
"eng 1969\n",
|
||
|
"0m 7s (- 2m 18s) (5 5%) 1.9851\n",
|
||
|
"0m 14s (- 2m 8s) (10 10%) 1.0089\n",
|
||
|
"0m 21s (- 1m 59s) (15 15%) 0.5189\n",
|
||
|
"0m 28s (- 1m 52s) (20 20%) 0.2294\n",
|
||
|
"0m 35s (- 1m 45s) (25 25%) 0.0961\n",
|
||
|
"0m 42s (- 1m 38s) (30 30%) 0.0509\n",
|
||
|
"0m 50s (- 1m 33s) (35 35%) 0.0355\n",
|
||
|
"0m 57s (- 1m 25s) (40 40%) 0.0289\n",
|
||
|
"1m 4s (- 1m 18s) (45 45%) 0.0249\n",
|
||
|
"1m 11s (- 1m 11s) (50 50%) 0.0228\n",
|
||
|
"1m 18s (- 1m 4s) (55 55%) 0.0207\n",
|
||
|
"1m 25s (- 0m 57s) (60 60%) 0.0215\n",
|
||
|
"1m 32s (- 0m 49s) (65 65%) 0.0249\n",
|
||
|
"1m 39s (- 0m 42s) (70 70%) 0.0184\n",
|
||
|
"1m 47s (- 0m 35s) (75 75%) 0.0172\n",
|
||
|
"1m 55s (- 0m 28s) (80 80%) 0.0166\n",
|
||
|
"2m 3s (- 0m 21s) (85 85%) 0.0163\n",
|
||
|
"2m 11s (- 0m 14s) (90 90%) 0.0163\n",
|
||
|
"2m 18s (- 0m 7s) (95 95%) 0.0176\n",
|
||
|
"2m 27s (- 0m 0s) (100 100%) 0.0256\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 640x480 with 0 Axes>"
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 640x480 with 1 Axes>",
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA9OElEQVR4nO3de3yU5Z3///cckkmCyXAIOUEgEcUTGChKjKigpmJ0qfy6W6m6onhoa7HFpu1qaoW17ZraqutWEZSK2K8HUFexVYvFKFAxSjlkKx5QJJAImUBQZnIgp5n790cyQwaSkAmZ3JnM6/l43I9k7rmumc/t7TDv3Nd93bfFMAxDAAAAJrGaXQAAAIhuhBEAAGAqwggAADAVYQQAAJiKMAIAAExFGAEAAKYijAAAAFMRRgAAgKnsZhfQEz6fT/v27VNiYqIsFovZ5QAAgB4wDEO1tbXKyMiQ1dr18Y+ICCP79u1TZmam2WUAAIBeqKys1OjRo7t8PiLCSGJioqS2jUlKSjK5GgAA0BMej0eZmZmB7/GuREQY8Q/NJCUlEUYAAIgwxzvFghNYAQCAqQgjAADAVIQRAABgKsIIAAAwFWEEAACYijACAABMRRgBAACmCimMFBcX69xzz1ViYqJSUlI0e/Zs7dix47j9XnzxRZ1++umKi4vTxIkT9cYbb/S6YAAAMLiEFEbWr1+v+fPn6/3339fatWvV0tKiyy67TPX19V32ee+993TNNdfo5ptv1rZt2zR79mzNnj1b27dvP+HiAQBA5LMYhmH0tvOBAweUkpKi9evX66KLLuq0zZw5c1RfX6/XXnstsO68887TpEmTtHTp0h69j8fjkdPplNvt5gqsAABEiJ5+f5/QOSNut1uSNHz48C7blJaWKj8/P2jdzJkzVVpa2mWfpqYmeTyeoAUAAAxOvQ4jPp9Pd9xxh6ZNm6YJEyZ02c7lcik1NTVoXWpqqlwuV5d9iouL5XQ6Awt37AUAYPDqdRiZP3++tm/frpUrV/ZlPZKkoqIiud3uwFJZWdnn7yFJq7ft1d2vfKgte74Oy+sDAIDj69Vde2+//Xa99tpr2rBhg0aPHt1t27S0NFVXVwetq66uVlpaWpd9HA6HHA5Hb0oLydpPqvX6P6uUnTxEU8YOC/v7AQCAY4V0ZMQwDN1+++165ZVX9Pbbbys7O/u4ffLy8lRSUhK0bu3atcrLywut0jDIHjFEklRe0/VsIAAAEF4hHRmZP3++nnvuOb366qtKTEwMnPfhdDoVHx8vSZo7d65GjRql4uJiSdKCBQs0ffp0Pfjgg7ryyiu1cuVKbd68WU888UQfb0roxo5IkCTtOdhgciUAAESvkI6MLFmyRG63WzNmzFB6enpgWbVqVaBNRUWFqqqqAo/PP/98Pffcc3riiSeUk5Ojl156SatXr+72pNf+kp3MkREAAMwW0pGRnlySZN26dces+853vqPvfOc7obxVvxjbPkyzz31YTa1eOew2kysCACD6RPW9aZJPitVJDrsMQ6r8iqEaAADMENVhxGKxBM4b2V1DGAEAwAxRHUYkKav9vJHdBzlvBAAAMxBG/EdGCCMAAJiCMNJ+EivDNAAAmIMwwjANAACmivow4j+Bdd+htum9AACgf0V9GBl5kkNDYm3yGVLlV4fNLgcAgKgT9WGkbXqv/7wRhmoAAOhvUR9GpCOXhee8EQAA+h9hRNwwDwAAMxFGxIwaAADMRBhRh2uNEEYAAOh3hBFJWcltwzR7vz6s5lafydUAABBdCCM6anrv15w3AgBAfyKMiOm9AACYiTDSzj9Us5sZNQAA9CvCSLssjowAAGAKwkg7ZtQAAGAOwkg7rjUCAIA5CCPtskYwvRcAADMQRtqNTHQogem9AAD0O8JIu47Te/cwVAMAQL8JOYxs2LBBs2bNUkZGhiwWi1avXn3cPs8++6xycnKUkJCg9PR03XTTTTp48GBv6g0r/1BNeQ1HRgAA6C
8hh5H6+nrl5ORo8eLFPWq/ceNGzZ07VzfffLM++ugjvfjii9q0aZNuvfXWkIsNN/9JrBwZAQCg/9hD7VBQUKCCgoIety8tLVVWVpZ+/OMfS5Kys7P1/e9/X/fff3+obx12/iMjXPgMAID+E/ZzRvLy8lRZWak33nhDhmGourpaL730kq644oou+zQ1Ncnj8QQt/YELnwEA0P/CHkamTZumZ599VnPmzFFsbKzS0tLkdDq7HeYpLi6W0+kMLJmZmeEuU9KRYZovv25gei8AAP0k7GHk448/1oIFC7Rw4UJt2bJFa9as0e7du/WDH/ygyz5FRUVyu92BpbKyMtxlSpJSEh2Kj2mb3vsl03sBAOgXIZ8zEqri4mJNmzZNP//5zyVJZ599toYMGaILL7xQv/nNb5Senn5MH4fDIYfDEe7SjtE2vTdBn7pqtedgg04eeVK/1wAAQLQJ+5GRhoYGWa3Bb2Oz2SRJhmGE++1Dlt0+VFPOeSMAAPSLkMNIXV2dysrKVFZWJkkqLy9XWVmZKioqJLUNscydOzfQftasWXr55Ze1ZMkS7dq1Sxs3btSPf/xjTZ06VRkZGX2zFX2IC58BANC/Qh6m2bx5sy6++OLA48LCQknSDTfcoBUrVqiqqioQTCTpxhtvVG1trR599FH99Kc/1dChQ3XJJZcMyKm9kpSd3H7hM6b3AgDQLyzGQBwrOYrH45HT6ZTb7VZSUlJY3+v9XQf13Sfe19gRCVr/84uP3wEAAHSqp9/f3JvmKNmB6b2H1eJlei8AAOFGGDmKf3qv12foy68Pm10OAACDHmHkKP7pvRJXYgUAoD8QRjoRuCw8M2oAAAg7wkgnxiZzZAQAgP5CGOlEduDICNN7AQAIN8JIJ7jwGQAA/Ycw0gn/9N5KpvcCABB2hJFOpCQ6FBdjlddnaC/TewEACCvCSCesVktgRk05QzUAAIQVYaQL/muN7GFGDQAAYUUY6UJWMjNqAADoD4SRLnDhMwAA+gdhpAuBMMIwDQAAYUUY6UJW+1VYuXsvAADhRRjpQmpinOJirGplei8AAGFFGOmC1WrR2OGcNwIAQLgRRrqRxQ3zAAAIO8JIN7K4YR4AAGFHGOnGkWuNcGQEAIBwIYx0I3AVVo6MAAAQNoSRbviHaSq/alAr03sBAAgLwkg30pLi5LC3T+89xPReAADCIeQwsmHDBs2aNUsZGRmyWCxavXr1cfs0NTXp7rvv1tixY+VwOJSVlaXly5f3pt5+ZbVaAkM1nMQKAEB42EPtUF9fr5ycHN1000369re/3aM+V199taqrq/Xkk0/qlFNOUVVVlXy+yBj2yBoxRJ9V12l3Tb2mjx9pdjkAAAw6IYeRgoICFRQU9Lj9mjVrtH79eu3atUvDhw+XJGVlZYX6tqZhRg0AAOEV9nNG/vznP+ucc87R7373O40aNUrjx4/Xz372Mx0+3PU5GE1NTfJ4PEGLWbhhHgAA4RXykZFQ7dq1S++++67i4uL0yiuvqKamRj/84Q918OBBPfXUU532KS4u1r333hvu0noki+m9AACEVdiPjPh8PlksFj377LOaOnWqrrjiCj300EN6+umnuzw6UlRUJLfbHVgqKyvDXWaX/MM0FUzvBQAgLMIeRtLT0zVq1Cg5nc7AujPOOEOGYejLL7/stI/D4VBSUlLQYpaO03v3HWo0rQ4AAAarsIeRadOmad++faqrqwus++yzz2S1WjV69Ohwv/0J6zi9t5yTWAEA6HMhh5G6ujqVlZWprKxMklReXq6ysjJVVFRIahtimTt3bqD9tddeqxEjRmjevHn6+OOPtWHDBv385z/XTTfdpPj4+L7ZijAb234S6x7CCAAAfS7kMLJ582ZNnjxZkydPliQVFhZq8uTJWrhwoSSpqqoqEEwk6aSTTtLatWt16NAhnXPOObruuus0a9Ys/eEPf+ijTQi/7PbzRsqZUQMAQJ8LeTbNjBkzZBhGl8+vWLHimHWnn3661q5dG+
pbDRjcMA8AgPDh3jQ9kM21RgAACBvCSA+MbR+mqfya6b0AAPQ1wkgPpCfFKdZuVYvXUJWb6b0AAPQlwkgPWK0WjR3
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Seed the RNGs behind weight initialisation, dropout, RandomSampler shuffling\n",
|
||
|
"# and random.choice so Restart & Run All reproduces the results (cuDNN GRU\n",
|
||
|
"# kernels may still cause small run-to-run differences on GPU).\n",
|
||
|
"SEED = 42\n",
|
||
|
"random.seed(SEED)\n",
|
||
|
"torch.manual_seed(SEED)\n",
|
||
|
"\n",
|
||
|
"hidden_size = 256\n",
|
||
|
"batch_size = 64\n",
|
||
|
"n_epochs = 100\n",
|
||
|
"\n",
|
||
|
"input_lang, output_lang, train_dataloader = get_dataloader(batch_size)\n",
|
||
|
"\n",
|
||
|
"encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)\n",
|
||
|
"decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)\n",
|
||
|
"\n",
|
||
|
"train(train_dataloader, encoder, decoder, n_epochs, print_every=5, plot_every=5)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:00:44.619526Z",
|
||
|
"end_time": "2024-06-02T20:03:13.180305Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### Ewaluacja"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 85,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def evaluate(encoder, decoder, sentence, input_lang, output_lang):\n",
|
||
|
"    \"\"\"Greedily translate one already-normalized ``sentence``.\n",
|
||
|
"\n",
|
||
|
"    Returns (decoded_words, decoder_attn). Decoding output stops at the first\n",
|
||
|
"    EOS_token, which is kept as the literal '<EOS>'; otherwise every decoded\n",
|
||
|
"    step is returned. Unknown input words raise KeyError.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    with torch.no_grad():\n",
|
||
|
"        input_tensor = tensorFromSentence(input_lang, sentence)\n",
|
||
|
"        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
|
||
|
"        # No target_tensor: the decoder feeds back its own predictions.\n",
|
||
|
"        decoder_outputs, _decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n",
|
||
|
"\n",
|
||
|
"        _, topi = decoder_outputs.topk(1)\n",
|
||
|
"        decoded_ids = topi.squeeze()\n",
|
||
|
"\n",
|
||
|
"        decoded_words = []\n",
|
||
|
"        for token in decoded_ids.tolist():\n",
|
||
|
"            if token == EOS_token:\n",
|
||
|
"                decoded_words.append('<EOS>')\n",
|
||
|
"                break\n",
|
||
|
"            decoded_words.append(output_lang.index2word[token])\n",
|
||
|
"        return decoded_words, decoder_attn"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:59:01.782695Z",
|
||
|
"end_time": "2024-06-02T19:59:01.811933Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 86,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def evaluateRandomly(encoder, decoder, n=10):\n",
|
||
|
"    \"\"\"Print ``n`` random training pairs together with the model's translation.\n",
|
||
|
"\n",
|
||
|
"    Lines are prefixed '>' (source), '=' (reference) and '<' (model output),\n",
|
||
|
"    followed by a blank line.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    for _ in range(n):\n",
|
||
|
"        pair = random.choice(pairs)\n",
|
||
|
"        print('>', pair[0])\n",
|
||
|
"        print('=', pair[1])\n",
|
||
|
"        translated, _attn = evaluate(encoder, decoder, pair[0], input_lang, output_lang)\n",
|
||
|
"        print('<', ' '.join(translated))\n",
|
||
|
"        print('')"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:59:02.352827Z",
|
||
|
"end_time": "2024-06-02T19:59:02.374825Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 99,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"> utne sobie drzemke\n",
|
||
|
"= i m going to go take a nap\n",
|
||
|
"< i m going to go take a nap wallet <EOS>\n",
|
||
|
"\n",
|
||
|
"> nie jestem co do tego pewny to zalezy\n",
|
||
|
"= i m not sure about that it depends\n",
|
||
|
"< i m not sure about that it depends <EOS>\n",
|
||
|
"\n",
|
||
|
"> nie kupujemy\n",
|
||
|
"= we re not buying\n",
|
||
|
"< we re not buying <EOS>\n",
|
||
|
"\n",
|
||
|
"> nie jestem g upi\n",
|
||
|
"= i m not stupid\n",
|
||
|
"< i m not stupid <EOS>\n",
|
||
|
"\n",
|
||
|
"> jestes wymagajacy\n",
|
||
|
"= you re demanding\n",
|
||
|
"< you re demanding <EOS>\n",
|
||
|
"\n",
|
||
|
"> jestem m ody ale nie az tak\n",
|
||
|
"= i m young but i m not that young\n",
|
||
|
"< i m young but i m not that young <EOS>\n",
|
||
|
"\n",
|
||
|
"> nie jestem ubrana\n",
|
||
|
"= i m not dressed\n",
|
||
|
"< i m not dressed <EOS>\n",
|
||
|
"\n",
|
||
|
"> jestem gotowy sie z tym pogodzic\n",
|
||
|
"= i m ready to accept it\n",
|
||
|
"< i m ready to accept it <EOS>\n",
|
||
|
"\n",
|
||
|
"> jestem pewny ze ona nied ugo wroci\n",
|
||
|
"= i m sure that she will come back soon\n",
|
||
|
"< i m sure that she will come back soon <EOS>\n",
|
||
|
"\n",
|
||
|
"> w niedziele mam wolne\n",
|
||
|
"= i m free on sunday\n",
|
||
|
"< i m free on sunday <EOS>\n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Switch both networks to inference mode before sampling translations.\n",
"for _net in (encoder, decoder):\n",
"    _net.eval()\n",
"evaluateRandomly(encoder, decoder)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:03:19.348154Z",
|
||
|
"end_time": "2024-06-02T20:03:19.572157Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 88,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def showAttention(input_sentence, output_words, attentions):\n",
"    \"\"\"Plot an attention matrix: input tokens on x, output tokens on y.\n",
"\n",
"    Args:\n",
"        input_sentence: normalized source sentence; it is split on single\n",
"            spaces and an '<EOS>' label is appended, one label per column.\n",
"        output_words: decoded tokens, one label per row.\n",
"        attentions: 2-D tensor of attention weights, presumably\n",
"            (output steps, input positions) - see evaluateAndShowAttention.\n",
"    \"\"\"\n",
"    input_labels = input_sentence.split(' ') + ['<EOS>']\n",
"    weights = attentions.detach().cpu().numpy()\n",
"\n",
"    fig, ax = plt.subplots()\n",
"    cax = ax.matshow(weights, cmap='bone')\n",
"    fig.colorbar(cax)\n",
"\n",
"    # Fix exactly one tick per column/row *before* assigning labels. Labelling an\n",
"    # automatic locator triggers matplotlib's \"set_ticklabels() should only be used\n",
"    # with a fixed number of ticks\" warning and can shift labels off their cells.\n",
"    ax.set_xticks(range(len(input_labels)))\n",
"    ax.set_xticklabels(input_labels, rotation=90)\n",
"    ax.set_yticks(range(len(output_words)))\n",
"    ax.set_yticklabels(output_words)\n",
"    ax.set_xlabel('input')\n",
"    ax.set_ylabel('output')\n",
"\n",
"    plt.show()\n",
"\n",
"\n",
"def evaluateAndShowAttention(input_sentence):\n",
"    \"\"\"Normalize and translate one sentence, print it, and plot its attention.\"\"\"\n",
"    input_sentence = normalizeString(input_sentence)\n",
"    output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)\n",
"    print('input =', input_sentence)\n",
"    print('output =', ' '.join(output_words))\n",
"    # Keep the first batch entry and only the rows for tokens actually produced\n",
"    # (assumes attentions is [batch, output step, input position] - TODO confirm).\n",
"    showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T19:59:04.218821Z",
|
||
|
"end_time": "2024-06-02T19:59:04.250855Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 100,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"input = nie jestem katoliczka\n",
|
||
|
"output = i m not catholic <EOS>\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
|
||
|
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
|
||
|
"C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
|
||
|
" ax.set_yticklabels([''] + output_words)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 640x480 with 2 Axes>",
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcYAAAHZCAYAAAAc4ptnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA8xklEQVR4nO3dfVxUZf7/8feA3HgDqKHgDUreS95gokZ2Y4Vp7s+y2hbNgqjctqStqM38mqDVSpkZbVmUYVhbyprdWBqVFLYlK6W55r2mCabclYpiwgrz+8NldibRDg5wZpjX08d5rJw5Z85nxta313Wu6zoWq9VqFQAAkCR5mV0AAACuhGAEAMAOwQgAgB2CEQAAOwQjAAB2CEYAAOwQjAAA2CEYAQCwQzACAGCHYAQAwA7BCACAHYIRAAA7BCMAAHZamF0AAPNUVFRozZo1KigoUFVVlcNrf/7zn02qCjCXhcdOAZ7p22+/1bhx43T8+HFVVFSoffv2KisrU6tWrdSxY0ft2bPH7BIBU9CVCnioBx54QOPHj9ehQ4fUsmVL/etf/9K+ffs0dOhQzZs3z+zyANPQYgQ8VNu2bbVu3Tr17dtXbdu2VV5envr3769169YpPj5e27dvN7tEwBS0GAEP5ePjIy+vU38FdOzYUQUFBZKkoKAgFRYWmlkaYCoG3wAeasiQIfr666/Vu3dvXX755UpOTlZZWZneeOMNDRgwwOzyANPQlQp4qG+++UZHjx7VFVdcoZKSEsXFxWnt2rXq3bu3MjIyFBkZaXaJgCkIRgCn+eWXX9SyZUuzywBMwT1GwEOdaZ5iRUWFxo0b18TVAK6DYAQ81MqVK5WSkuKw79ixYxo7dqxOnjxpUlWA+Rh8A3ioTz75RJdeeqnatWun+++/X0ePHtWYMWPUokULffTRR2aXB5iGYAQ8VM+ePZWdna0rrrhCXl5eWrJkifz8/LRy5Uq1bt3a7PIA0zD4BvBweXl5Gj16tEaMGKEPP/yQQTfweAQj4EGGDBkii8Vy2v59+/apY8eODqG4YcOGpiwNcBl0pQIeZMKECWaXALg8WowAANhhugbgoW6//XYtXrz4tP3l5eW6/fbbTagIcA20GAEP5eXlpZYtW+qOO+5QWlqabUHx4uJide7cWdXV1SZXCJiDFiPgwVauXKlVq1ZpzJgxOnTokNnluLTq6mpt2rSJxQ88AMEIeLCIiAitW7dO//nPfzR8+HBt27bN7JJc1gcffKAhQ4YoKyvL7FLQyAhGwEPVTts477zztHr1al1++eWKjo7WihUrTK7MNS1evFgdOnRQZmam2aWgkXGPEc3KTz/9pOTkZH3++ecqKSlRTU2Nw+s///yzSZW5Hi8vLxUVFaljx462ffPnz9e0adNUU1PDPUY7ZWVl6tq1q9577z1de+212rNnj7p27Wp2WWgkzGNEs3Lrrbdq9+7duuOOOxQSElLnZHac8vnnn6t9+/YO+5KSkjRo0CB99dVXJlXlmpYsWaIBAwZo7NixuvTSS/XGG29o+vTpZpeFRkKLEc1KQECAvvzySw0ePNjsUtCMDB06VPHx8frzn/+s1157TXPnzuV+bDPGPUY0K/369dMvv/xidhlu4cYbb9RTTz112v65c+fqD3/4gwkVuabNmzdr8+bNuvnmmyVJN910kwoKCrRu3TqTK0NjIRjRrLz44ouaMWOG1qxZo59++knl5eUOG/7niy++qPOBxNdcc43WrFljQkWuafHixbr66qsVHBwsSWrTpo0mTJjAIJxmjGBEs9K2bVuVl5fryiuvVMeOHdWuXTu1a9dObdu2Vbt27cwuz6UcO3ZMvr6+p+338fHhHxH/VV1drb///e+Ki4tz2H/LLbcoKytLVVVVJlWGxsTgGzQrkydPlo+Pj9566y0G3/yGgQMHKisrS8nJyQ77ly5dqoiICJOqci0lJSW6++67dd111znsHzNmjJKSklRUVKRu3bqZVB0aC4
Nv0Ky0atVK3377rfr27Wt2KS7vgw8+0A033KCbb75ZV155pSQpJydHS5Ys0bJly3gSBzwWXaloVqKiolRYWGh2GW5h/Pjxeu+997R7927dc889evDBB7V//36tXr2aUDyLffv2aevWrafNkUXzQYsRzcqyZcs0a9Ys/eUvf9HAgQPl4+Pj8PqgQYNMqgzuZtGiRTp8+LCSkpJs+/74xz8qIyNDktS3b199/PHHCgsLM6tENBKCEc1K7RMi7FksFlmtVlksFlZzgWEXXXSR7rrrLiUkJEiSsrOzNX78eGVmZqp///5KTExURESEXn31VZMrRUNj8A2alb1795pdgktr3769du7cqeDgYLVr1+6sg5M8ffm8Xbt2KSoqyvbz+++/r+uuu06TJ0+WJM2ZM8cWmmheCEY0K927dze7BJf27LPPKiAgQJKUlpZmbjEu7pdfflFgYKDt57Vr1+qOO+6w/dyjRw8VFRWZURoaGcGIZueNN95Qenq69u7dq7y8PHXv3l1paWk6//zzTxt272ni4+Pr/D1O1717d61fv17du3dXWVmZtmzZopEjR9peLyoqUlBQkIkVorEwKhXNyksvvaSkpCSNGzdOhw8ftt1TbNu2LS0k6bSVgM62ebr4+HhNnTpVjz/+uG666Sb169dPQ4cOtb2+du1aDRgwwMQK0VhoMaJZef7557Vw4UJNmDBBTz75pG1/VFSUHnroIRMrcw1t27b9zUUPGKh0ysMPP6zjx4/rnXfeUWhoqJYtW+bw+ldffaVJkyaZVB0aE6NS0ay0bNlS27dvV/fu3RUQEKB///vf6tGjh3bt2qVBgwZ5/ALj9VkD9fLLL2/ESgDXRYsRzcr555+vjRs3njYIJzs7W/379zepKtdB2NXfL7/8ok8//VQ7d+6UJPXp00ejR49Wy5YtTa4MjYVgRLOSlJSkqVOn6sSJE7JarcrPz9eSJUuUmprKfLM6HD58WBkZGbZnC15wwQW6/fbbGVTyXytWrNCdd96psrIyh/3BwcHKyMjQ+PHjTaoMjYmuVDQ7b775pmbNmqXvv/9ektS5c2fNnj3bYag9pG+++UZjxoxRy5YtNXz4cEnS119/rV9++UWffPKJLrzwQpMrNNfatWs1atQoXXvttXrwwQdtPQ5bt27VM888ow8//FBr1qzRRRddZHKlaGgEI5qt48eP69ixY+rYsaPZpbikSy+9VL169dLChQvVosWpzqOTJ0/qzjvv1J49e/TFF1+YXKG5xo0bp7CwML388st1vn7XXXepsLBQq1atauLK0NgIRjQrV155pd555x21bdvWYX95ebkmTJigzz77zJzCXFDLli317bffql+/fg77t27dqqioKB0/ftykylxD+/bttWbNGg0cOLDO1zdt2qTLL79chw4dauLK0NiYx4hmJTc3t86Hx544cUL//Oc/TajIdQUGBqqgoOC0/YWFhbbVcTzZr1e++bWgoCCdOHGiCStCU2HwDZqFTZs22X6/detWh6W6qqurlZ2drS5duphRmsuKjY3VHXfcoXnz5uniiy+WdGpu3l/+8hfm50nq3bu3PvvsszOuh5qTk6PevXs3cVVoCgQjmoXIyEhZLBZZLBbbQ3fttWzZUs8//7wJlbmuefPmyWKxKC4uTidPnpQk+fj46O6773ZYHMFTJSQk6KGHHlJISIjGjRvn8NrKlSv18MMP6//+7/9Mqg6NiXuMaBb27dsnq9WqHj16KD8/Xx06dLC95uvrq44dO8rb29vECl3X8ePHbSN4e/bsqVatWplckWuoqalRbGysli9frr59+6p///6yWq3atm2bdu3apQkTJmjZsmV1PuoM7o1gBDzU7bffrueee+60+4kVFRW69957tWjRIpMqcy1ZWVlasmSJwwT/iRMnauLEiSZXhsZCMLqZkydPKjc3V99//71uvvlmBQQE6MCBAwoMDFSbNm3MLs90ixcvVnBwsH73u99JOrXe5SuvvKKIiAgtWb
KEx1LZ8fb21sGDB0+bzlJWVqbQ0FBb9yrgaegDcCP79u3TwIEDdd1112nq1KkqLS2VJD311FMskP1fc+bMsS3VlZe
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# The printed input shows the normalized form (lower-cased, diacritics removed).\n",
"evaluateAndShowAttention(input_sentence='Nie jestem katoliczką')"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:03:24.422192Z",
|
||
|
"end_time": "2024-06-02T20:03:24.634214Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 101,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"input = przykro nam ze to sie zdarzy o\n",
|
||
|
"output = we re sorry that it happened <EOS>\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
|
||
|
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
|
||
|
"C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
|
||
|
" ax.set_yticklabels([''] + output_words)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 640x480 with 2 Axes>",
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjsAAAG4CAYAAACjGiawAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABLY0lEQVR4nO3dfVyN9/8H8Nc53aqc3LRKRG6imu5EiflqWxY2jNkwlCazWbPEhp+RGrKNtOE75i6MLTPbbHxz02RDY4sMpbmvochNR5ninOv3h2/n66yi0zm6Oud6PT0+j+lzrnNd76ua3r0/n8/1kQmCIICIiIjIRMnFDoCIiIjocWKyQ0RERCaNyQ4RERGZNCY7REREZNKY7BAREZFJY7JDREREJo3JDhEREZk0JjtERERk0pjsEBERkUljskNEREQmjckOERERmTQmO0RERGTSzMUOgPSjUqnw3XffITc3FwDw5JNPYuDAgTAzMxM5MiIiooZBxl3Pjdfp06fx/PPP46+//kKnTp0AAHl5eXB1dcW2bdvQvn17kSMkIjIdKpUKJ06cgJeXF8zNWSswJhzGMmITJ05Eu3btUFBQgMOHD+Pw4cPIz89H27ZtMXHiRLHDIyIyKT/88AP8/f2RmpoqdiikI1Z2jJitrS1+/fVXeHt7a/UfPXoUPXv2RGlpqUiRERGZnsGDByMzMxPe3t7YtWuX2OGQDljZMWJWVla4detWlf7S0lJYWlqKEBERkWkqLi7Gf/7zH6SkpGDv3r3466+/xA6JdMBkx4i98MILeP3113Hw4EEIggBBEPDrr7/ijTfewMCBA8UOj4jqKCIiAj///LPYYdADvvzyS3Tu3Bl9+/ZFr169sH79erFDIh0w2TFin376Kdq3b4/g4GBYW1vD2toaPXv2RIcOHfDJJ5+IHR4R1VFJSQlCQ0Ph7u6OefPm4eLFi2KHJHkpKSkIDw8HAIwaNQrr1q0TOSLSBefsGClBEFBQUIAnnngCFy9e1Cw99/T0RIcOHUSOjoj0dfXqVaxfvx5r165FTk4OQkNDMXbsWAwaNAgWFhZihycpx48fR0BAAC5evAgHBweUlpbCyckJP/30E4KCgsQOj2qByY6RUqvVsLa2xokTJ+Du7i52OET0GB0+fBhr1qzBypUrYWdnh1GjRmHChAn8f7+evPvuuzh58iR++OEHTd/IkSOhUCjw2WefiRgZ1RaHsYyUXC6Hu7s7rl27JnYoRPQYXb58Gbt27cKuXbtgZmaG/v3749ixY/Dy8sKiRYvEDs/kqVQqfPHFF5ohrEqjRo1CamoqKioqRIqMdMFkx4jNnz8f7777Lo4fPy52KERkQHfv3sU333yDF154AW3atMHXX3+NmJgYXLp0CWvXrsXu3buxadMmJCQkiB2qybty5QrefPNNDBo0SKs/LCwMsbGxKCwsFCky0gWHsYxY06ZNcfv2bdy7dw+WlpZo1KiR1uvXr18XKTIi0oeDgwPUajVGjBiBcePGwc/Pr8oxN2/ehL+/P86dO1f/ARIZGT7v2ogtWrQIMplM7DCIyMAWLVqEl19+GdbW1jUe06RJEyY6Irlw4QLKysrg4eEBuZwDJMaAyY4RGzNmTI2v/f333/UXCBEZlCAIOHfuHDw9PbX679y5g02bNlWZP0KPx+rVq3Hz5k3ExsZq+l5//XWsWrUKANCpUyfs2LEDrq6uYoVItcSU1IjVtP9VWVkZ+vfvX8/REJGhjBkzBoGBgfjmm2+0+ktKShAZGSlSVNLz+eefo2nTppqP09LSsGbNGqxbtw6//fYbmjRpgvj4eBEjpNpismPEtm3bhri4OK2+srIy9O3bF/fu3RMpKiIyhPj4eIwePRqzZ88WOxTJOnXqFLp27ar5+Pvvv8egQYMwcuRIdOnSBfPmzUN6erqIEVJtMdkxYjt37sSKFSuQnJwMALh16xb69OkDmUyGtLQ0cYMjIr2MGjUKP/30E5YvX46hQ4
dyaFoEf//9NxQKhebjAwcO4F//+pfm43bt2nE1lpHgnB0j1r59e6SlpeHpp5+GXC7Hl19+CSsrK2zbtg22trZih0dEdVS58KB79+44ePAgBg4ciB49emDZsmUiRyYtbdq0QVZWFtq0aYPi4mKcOHECPXv21LxeWFgIe3t7ESOk2mKyY+R8fHzw448/ok+fPggKCsKPP/5YZQk6ERmXB58I0rp1axw4cAAjR45Enz59RIxKeiIiIvDWW2/hxIkT+Omnn+Dh4YGAgADN6wcOHEDnzp1FjJBqi8mOkfH39692ubmVlRUuXbqk9VvH4cOH6zM0IjKQuLg42NnZaT62sbHBt99+i7i4OO6GXo/ee+893L59G1u2bIGzszO+/vprrdf379+PESNGiBQd6YIPFTQyusz8/+fkZSJq+O7evYvx48dj5syZaNu2rdjhEJkEJjtERA2Mvb09srOzmew0EH///Td27dqFP//8EwDQsWNH9OnTh1MGjAiTHSMWERGBsWPHaq0OMDaCIGDz5s3Ys2cPrly5ArVarfX6li1bRIqMSDwRERHw8/PDpEmTxA5F8rZu3YqoqCgUFxdr9Ts4OGDVqlUYMGCASJGRLjhnx4iVlJQgNDQUbdq0QWRkJCIiItCyZUuxw9JJTEwMli9fjqeffhpOTk7c/oIIgLu7OxISErB//34EBARUWV1Z0wNFybAOHDiAoUOHYuDAgZg8ebLmidY5OTlYuHAhhg4dir1796J79+4iR0qPwsqOkbt69SrWr1+PtWvXIicnB6GhoRg7diwGDRoECwsLscN7pGbNmuGLL77gE5+JHvCw4SuZTIazZ8/WYzTS1b9/f7i6umL58uXVvj5+/HgUFBRg+/bt9RwZ6YrJjgk5fPgw1qxZg5UrV8LOzg6jRo3ChAkT4O7uLnZoNWrbti3+85//wMPDQ+xQiIi0NGvWDHv37oW3t3e1r//xxx/o3bs3bty4Uc+Rka74BGUTcfnyZezatQu7du2CmZkZ+vfvj2PHjsHLywuLFi0SO7wazZ49G/Hx8Xw6LBE1OP98gvI/2dvb486dO/UYEdUV5+wYsbt372Lr1q1Ys2YNdu7cCR8fH8TExODVV1/V/A/67bff4rXXXmuwEx1feeUVfPnll3B0dISbm1uVoTc+K4ik6q+//sLWrVuRn5+PiooKrdeSkpJEikpa3N3d8dNPP9W4+Wp6enqDrpzT/zDZMWItWrSAWq3GiBEjcOjQIfj5+VU55umnn0aTJk3qPbbaioiIQFZWFkaNGsUJykT/lZ6ejoEDB6Jdu3Y4efIkOnfujPPnz0MQBHTp0kXs8CQjMjISU6ZMgZOTU5V5hdu2bcN7772H//u//xMpOtIF5+wYsfXr1+Pll1+GtbW12KHUma2tLXbs2IGnnnpK7FCIGozAwED069cP8fHxaNy4MY4ePQpHR0eMHDkSffv2xZtvvil2iJKgVqsxbNgwfPPNN+jUqRM8PT0hCAJyc3Nx6tQpvPjii/j6668hl3NGSEPHr5AREwQB586dq9J/584drFu3ToSIdOfq6vrQMXEiKcrNzUV4eDgAwNzcHH///Tfs7OyQkJCADz/8UOTopEMul+Prr7/Gl19+iU6dOuHkyZPIy8uDh4cHNmzYgG+++YaJjpFgZceIyeVy2NraIiUlBS+99JKmv6ioCC4uLlCpVCJGVzvbtm3D4sWLsWzZMri5uYkdDlGD4OzsjD179sDT0xNeXl6YP38+Bg4ciKNHj6Jnz54oLS0VO0Qio8KU1MjFx8dj9OjRmD17ttih1MmoUaOwZ88etG/fHo0bN0azZs20GpGufvnlF4waNQrBwcG4ePEigPtDvvv27RM5strr3r27Jt7+/ftj8uTJmDt3Ll577TU+wK4ebdq0SWty+F9//aX1lPfbt2/jo48+EiM00hErO0ZMLpejsLAQZ8+exeDBg9GzZ0+sX78eSqXSaCo7a9eufejrERER9RQJmYJvvvkGo0ePxs
iRI7F+/Xrk5OSgXbt2WLJkCbZv3240D387e/YsSktL4ePjg7KyMkyePBkHDhyAu7s7kpKS0KZNG7FDlAQzMzNcvnw
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# NOTE(review): the printed input reads 'zdarzy o' - normalizeString appears to drop 'ł'\n",
"# instead of mapping it to 'l'; worth fixing there (and retraining).\n",
"evaluateAndShowAttention(input_sentence='Przykro nam ze to sie zdarzyło')"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:03:25.856941Z",
|
||
|
"end_time": "2024-06-02T20:03:26.205536Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 102,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"input = on mowi p ynnie po francusku\n",
|
||
|
"output = he is fluent in french <EOS>\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
|
||
|
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
|
||
|
"C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
|
||
|
" ax.set_yticklabels([''] + output_words)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 640x480 with 2 Axes>",
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAikAAAHECAYAAAD8obrfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA7FElEQVR4nO3df3zNdf/H8efZ2Cb7ZcY2LEsIZX5NokRdfqVLXPohuWwUlSxGXSKFXGUuV37UNVEi+q3SVQoTYxFKFPqh/IqNbH5l83Nj53z/cO18O23TtrOzz/mc87i7fW45n/P5nM/rY7HXXq/3+/2x2Gw2mwAAANyMj9EBAAAAFIckBQAAuCWSFAAA4JZIUgAAgFsiSQEAAG6JJAUAALglkhQAAOCWSFIAAIBbIkkBAABuiSQFAAC4JZIUAADglkhSAACAW6pidAAAANe66qqrZLFYSnx/3759lRgNUHokKQDg4ZKSkhxeX7hwQd9++61SU1P1j3/8w5iggFKw2Gw2m9FBAAAq3+zZs7Vlyxa99tprRocCFIskBQC81L59+9SyZUvl5uYaHQpQLAbOAoCX+uCDDxQWFmZ0GECJGJMCAB6uVatWDgNnbTabsrKydPToUb300ksGRgZcHkkKAHi4Pn36OLz28fFRrVq11LlzZzVp0sSYoIBSYEwKAHi4gwcPql69esW+9+WXX+qGG26o5IiA0mFMCgB4uG7duunEiRNF9m/YsEE9evQwICKgdEhSAMDD3XDDDerWrZtOnTpl37du3Tr17NlTEydONDAy4PJo9wCAh7Narbrrrrt04sQJrVy5Uhs3btQdd9yhZ599ViNHjjQ6PKBEJCkA4AXy8/N1++236+zZs9qxY4eSk5OVmJhodFjAZZGkAIAH2rFjR5F9p06dUv/+/XX77bdr2LBh9v2xsbGVGRpQaiQpAOCBfHx8ZLFY9Pt/4n//uvD3FotFBQUFRoUJXBbrpACAB/rll1+MDgFwGpUUAADglpiCDAAebtGiRVq2bJn99ZgxYxQaGqoOHTrowIEDBkYGXB5JCgB4uClTpqhatWqSpE2bNiklJUXTpk1TeHi4Ro0aZXB0Fa+goEA7duzQxYsXjQ4FTiJJAQAPl5mZqYYNG0qSPvroI91111168MEHlZycrPXr1xscXcX75JNP1KpVKy1evNjoUOAkkhQA8HCBgYE6fvy4JOmzzz5T165dJUkBAQE6d+6ckaG5xKJFi1SrVi0tXLjQ6FDgJGb3AICH69q1q4YMGaJWrVpp165d6tmzpyTphx9+UExMjLHBVbBjx45pxYoV+uijj3THHXdc9uGKcH9UUgDAw82ePVvt27fX0aNHtWTJEtWsWVOStHXrVvXv39/g6CrWO++8o+uuu049evRQx44d9cYbbxgdEpzAFGQAgMdo06aNEhISNGLECL322muaNm2adu7caXRYKCeSFADwcOvWrbvs+zfffHMlReJa33//vdq0aaNDhw4pPDxcp0+fVkREhNasWaN27doZHR7KgSQFADycj0/Rzr7FYrH/3lOWxf/HP/6hn376SZ988ol934ABAxQcHKw5c+YYGBnKizEpAODhfvvtN4ftyJEjSk1NVdu2bfXZZ58ZHV6FKCgo0Jtvvqn4+HiH/X//+9+1ePFi5efnGxQZnMHsHgDwcCEhIUX2de3aVX5+fho9erS2bt1qQFQV68iRIxo2bJh69+7tsL979+4aPXq0srKydOWVVxoUHcqLdg8AeKmffvpJcXFxOn36tNGhAMWikgIAHm7Hjh0Or202mw4fPqypU6eqZcuWxgRVCQ4cOKAzZ86oSZMmxY7LgfujkgLgT128eFHp6enau3ev7rvvPgUFBenXX39VcHCwAgMDjQ4Pf8LHx0cWi0V//Of+hhtu0IIFC9SkSRODIqsYCxYs0MmTJzV69Gj7vgcffFDz58+XJF1zzTVauXKloqOjjQoR5USSAuCyDhw4oB49eigjI0
N5eXnatWuXGjRooJEjRyovL09z5841OkT8iT8+6djHx0e1atVSQECAQRFVrBtuuEEPPfSQBg8eLElKTU1Vr169tHDhQjVt2lSJiYlq1qyZXn31VYMjRVnR7gFwWSNHjlRcXJy2b99uX6lUkv72t79p6NChBkaG0qpfv77RIbjU7t27FRcXZ3/98ccfq3fv3howYICkS0+BLkxgYC4kKXCJsLAw7dq1S+Hh4apRo4bDmgx/dOLEiUqMDGW1fv16bdy4UX5+fg77Y2JidOjQIYOiQlmMGDFCDRs21IgRIxz2p6SkaM+ePZo1a5YxgVWQc+fOKTg42P5648aNeuCBB+yvGzRooKysLCNCg5NIUuASM2fOVFBQkP33l0tS4N6sVmuxi30dPHjQ/jWGe1uyZImWLl1aZH+HDh00depU0ycp9evX19atW1W/fn0dO3ZMP/zwg2688Ub7+1lZWcVOw4b7I0mBSyQkJNh/P2jQIOMCgdO6deumWbNm6ZVXXpF0aaXS06dPa+LEifan6cK9HT9+vNhv0sHBwTp27JgBEVWshIQEDR8+XD/88IPWrFmjJk2aqE2bNvb3N27cqOuuu87ACFFezMmCy8XHx+u1117T3r17jQ4F5TB9+nRt2LBBzZo10/nz53XffffZWz3/+te/jA4PpdCwYUOlpqYW2b9ixQo1aNDAgIgq1pgxYzR06FB9+OGHCggI0Pvvv+/w/oYNGzzuac/egtk9cLkhQ4Zo3bp12rNnj+rWratOnTqpc+fO6tSpkxo1amR0eCiFixcv6t1339WOHTt0+vRptW7dWgMGDFC1atWMDg2lsGDBAiUmJuof//iHbr31VklSWlqapk+frlmzZjEAGm6LJAWV5tChQ1q3bp0+//xzff7559q1a5eioqJ08OBBo0MDPN6cOXP03HPP6ddff5V0aeDzpEmTijzrxszOnTunVatWadeuXZKkxo0bq2vXriTTJsaYFFSaGjVqqGbNmqpRo4ZCQ0NVpUoV1apVy+iwUIylS5fqtttuU9WqVYsdcPl7d9xxRyVFBWcMGzZMw4YN09GjR1WtWjWPW4Rv6dKlGjJkSJExNuHh4Zo/f7569eplUGRwBpUUuNyTTz6p9PR0ffvtt2ratKm93XPzzTerRo0aRoeHYvj4+CgrK0u1a9e+7HLiFoul2Jk/QGXauHGjOnfurDvuuEOPPfaYmjZtKkn68ccfNX36dH366af6/PPPdcMNNxgcKcqKJAUuV7i65ahRo9S3b181btzY6JAAr5Kdna3HH39caWlpOnLkSJHl8c2eaPbs2VPR0dF6+eWXi33/oYceUmZmppYvX17JkcFZJClwue3bt+vzzz9Xenq61q9fLz8/P3s1pXPnziQtgIvddtttysjIUGJioqKiooqsW9S7d2+DIqsYYWFh+vzzz9W8efNi39+xY4c6deqk3377rZIjg7NIUlDptm/frpkzZ+qtt94qcaEwuJe0tDT7T+FWq9XhvQULFhgUFUorKChI69ev99gnHlerVk0//fRTicv/HzhwQE2aNNG5c+cqOTI4i4GzcDmbzaZvv/1W6enpSk9P1xdffKHc3FzFxsaqU6dORoeHP/HMM89o8uTJiouLK/ancLi/6OjoIi0eT9KoUSOtWbOmxOfzpKWlsdyBSZGkwOXCwsJ0+vRptWjRQp06ddLQoUPVsWNHhYaGGh0aSmHu3LlauHChBg4caHQoKKdZs2Zp7NixevnllxUTE2N0OBVu8ODBevzxxxUREVFkFeRly5ZpzJgxevLJJw2KDs6g3QOXW7ZsmTp27OjwADCYR82aNbV582ZdffXVRoeCcqpRo4bOnj2rixcv6oorrlDVqlUd3jf7Qz6tVqv69eunJUuW6JprrlHTpk1ls9m0c+dO7d69W3369NH7779/2ZlqcE8kKahUhQu31atXz+BIUFpPPPGEAgMD9fTTTxsdCspp0aJFl33/98/aMrPFixfrnXfecVjM7d5779
W9995rcGQoL5IUuJzVatWzzz6r6dOn6/Tp05IuDeR77LHHNH78eH66cXMjR47U66+/rtjYWMXGxhb5KXzGjBkGRQb
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# NOTE(review): the printed input reads 'p ynnie' - 'ł' appears to be dropped by normalizeString.\n",
"evaluateAndShowAttention(input_sentence='On mówi płynnie po francusku')"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:03:26.594026Z",
|
||
|
"end_time": "2024-06-02T20:03:26.838018Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"### BLEU"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 103,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def filter_rows(row):\n",
"    \"\"\"Keep a row only if both sentences are short and the English side has a known prefix.\n",
"\n",
"    Both `row[\"eng\"]` and `row[\"pol\"]` must have fewer than MAX_LENGTH\n",
"    space-separated tokens, and `row[\"eng\"]` must start with one of\n",
"    `eng_prefixes`. Used with `DataFrame.apply(..., axis=1)` to build a mask.\n",
"    \"\"\"\n",
"    eng = row[\"eng\"]\n",
"    if len(eng.split(' ')) >= MAX_LENGTH:\n",
"        return False\n",
"    if len(row[\"pol\"].split(' ')) >= MAX_LENGTH:\n",
"        return False\n",
"    return eng.startswith(eng_prefixes)\n",
"\n",
"def evaluateWithTokenization(input_sentence):\n",
"    \"\"\"Translate one raw sentence and return the output tokens minus '<EOS>'.\n",
"\n",
"    Only the first '<EOS>' is removed; the list returned by `evaluate` is\n",
"    modified in place and returned.\n",
"    \"\"\"\n",
"    normalized = normalizeString(input_sentence)\n",
"    tokens, _attentions = evaluate(encoder, decoder, normalized, input_lang, output_lang)\n",
"    if \"<EOS>\" in tokens:\n",
"        tokens.remove(\"<EOS>\")\n",
"    return tokens"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:03:29.868015Z",
|
||
|
"end_time": "2024-06-02T20:03:29.884050Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 114,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import csv\n",
"\n",
"# Tab-separated (eng, pol, attribution) rows. QUOTE_NONE makes pandas split\n",
"# purely on tabs; with the default quoting, a sentence that starts with '\"'\n",
"# is parsed as a quoted CSV field, which can strip the quotes or merge lines.\n",
"df = pd.read_csv(\"data/eng-pol.txt\", sep='\\t', names=[\"eng\", \"pol\", \"attribution\"],\n",
"                 quoting=csv.QUOTE_NONE)\n",
"df[\"eng\"] = df[\"eng\"].apply(normalizeString)\n",
"df[\"pol\"] = df[\"pol\"].apply(normalizeString)\n",
"df_filtered = df.apply(filter_rows, axis=1)  # boolean mask, one value per row\n",
"\n",
"# NOTE(review): filter_rows applies a MAX_LENGTH / eng_prefixes filter to the\n",
"# whole file, so this set most likely overlaps the training pairs and the BLEU\n",
"# below measures fit rather than generalisation - confirm against training.\n",
"# A fixed random_state keeps the (order-only) shuffle reproducible.\n",
"test_df = df[df_filtered].sample(frac=1, random_state=0).copy()\n",
"test_df[\"eng_token\"] = test_df[\"eng\"].apply(str.split)\n",
"test_df[\"eng_eval\"] = test_df[\"pol\"].apply(evaluateWithTokenization)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:07:48.707058Z",
|
||
|
"end_time": "2024-06-02T20:08:22.246952Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 115,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# bleu_score takes one token list per candidate and, per candidate, a list of\n",
"# reference token lists - here exactly one reference each.\n",
"candidate_corpus = test_df[\"eng_eval\"].tolist()\n",
"references_corpus = [[tokens] for tokens in test_df[\"eng_token\"]]"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:08:22.248949Z",
|
||
|
"end_time": "2024-06-02T20:08:22.262981Z"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 116,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "0.9301728010177612"
|
||
|
},
|
||
|
"execution_count": 116,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Corpus-level BLEU with torchtext's defaults spelled out (up to 4-grams, uniform weights).\n",
"bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"start_time": "2024-06-02T20:08:22.264948Z",
|
||
|
"end_time": "2024-06-02T20:08:23.695461Z"
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 2
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython2",
|
||
|
"version": "2.7.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 0
|
||
|
}
|