{ "cells": [ { "cell_type": "markdown", "source": [ "### Importy" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 84, "metadata": { "collapsed": true, "ExecuteTime": { "start_time": "2024-06-02T19:58:41.249607Z", "end_time": "2024-06-02T19:58:41.261609Z" } }, "outputs": [], "source": [ "from __future__ import unicode_literals, print_function, division\n", "from io import open\n", "import unicodedata\n", "import re\n", "import os\n", "import random\n", "import torch\n", "import pandas as pd\n", "import torch.nn as nn\n", "from torch import optim\n", "import torch.nn.functional as F\n", "from torchtext.data.metrics import bleu_score\n", "\n", "from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" ] }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is CUDA supported by this system? True\n", "CUDA version: 12.1\n", "ID of current CUDA device: 0\n", "Name of current CUDA device: NVIDIA GeForce GTX 1660 Ti\n" ] } ], "source": [ "print(f'Is CUDA supported by this system? {torch.cuda.is_available()}')\n", "print(f\"CUDA version: {torch.version.cuda}\")\n", "\n", "cuda_id = torch.cuda.current_device()\n", "print(f'ID of current CUDA device: {torch.cuda.current_device()}')\n", "\n", "print(f'Name of current CUDA device: {torch.cuda.get_device_name(cuda_id)}')" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:20:55.709021Z", "end_time": "2024-06-02T19:20:55.725023Z" } } }, { "cell_type": "code", "execution_count": 10, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:20:55.996605Z", "end_time": "2024-06-02T19:20:56.041138Z" } } }, { "cell_type": "markdown", "source": [ "### Konwersja słów na tensory" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 11, "outputs": [], "source": [ "SOS_token = 0\n", "EOS_token = 1\n", "\n", "class Lang:\n", " def __init__(self, name):\n", " self.name = name\n", " self.word2index = {}\n", " self.word2count = {}\n", " self.index2word = {0: \"SOS\", 1: \"EOS\"}\n", " self.n_words = 2 # Count SOS and EOS\n", "\n", " def addSentence(self, sentence):\n", " for word in sentence.split(' '):\n", " self.addWord(word)\n", "\n", " def addWord(self, word):\n", " if word not in self.word2index:\n", " self.word2index[word] = self.n_words\n", " self.word2count[word] = 1\n", " self.index2word[self.n_words] = word\n", " self.n_words += 1\n", " else:\n", " self.word2count[word] += 1" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:20:59.879666Z", "end_time": "2024-06-02T19:20:59.893667Z" } } }, { "cell_type": "markdown", "source": [ "### Przygotowanie danych" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 12, "outputs": [], "source": [ "# Turn a Unicode string to plain ASCII, thanks to\n", "# https://stackoverflow.com/a/518232/2809427\n", "def unicodeToAscii(s):\n", " return ''.join(\n", " c for c in unicodedata.normalize('NFD', s)\n", " if unicodedata.category(c) != 'Mn'\n", " )\n", "\n", "# Lowercase, trim, and remove non-letter characters\n", "def normalizeString(s):\n", " s = unicodeToAscii(s.lower().strip())\n", " s = re.sub(r\"([.!?])\", r\" \\1\", s)\n", " s = re.sub(r\"[^a-zA-Z!?]+\", r\" \", s)\n", " return s.strip()" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:21:00.877093Z", "end_time": "2024-06-02T19:21:00.892090Z" } } }, { "cell_type": "markdown", "source": [ "### Wczytanie danych" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 13, "outputs": [], "source": [ "def readLangs(lang1, lang2, reverse=False):\n", " print(\"Reading lines...\")\n", " # Read the file and split into lines\n", " lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\\\n", " read().strip().split('\\n')\n", "\n", " # Split every line into pairs and normalize\n", " pairs = [[normalizeString(s) for s in l.split('\\t')[:-1]] for l in lines]\n", "\n", " # Reverse pairs, make Lang instances\n", " if reverse:\n", " pairs = [df_filtered(reversed(p)) for p in pairs]\n", " input_lang = Lang(lang2)\n", " output_lang = Lang(lang1)\n", " else:\n", " input_lang = Lang(lang1)\n", " output_lang = Lang(lang2)\n", "\n", " return input_lang, output_lang, pairs" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:21:02.075474Z", "end_time": "2024-06-02T19:21:02.087474Z" } } }, { "cell_type": "markdown", "source": [ "### Filtracja danych" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "Ograniczenie zdań do 10 słów oraz zdań zaczynających się od prefiksów" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 14, "outputs": [], "source": [ "MAX_LENGTH = 10\n", "\n", "eng_prefixes = (\n", " \"i am \", \"i m \",\n", " \"he is\", \"he s \",\n", " \"she is\", \"she s \",\n", " \"you are\", \"you re \",\n", " \"we are\", \"we re \",\n", " \"they are\", \"they re \"\n", ")\n", "\n", "def filterPair(p):\n", " return len(p[0].split(' ')) < MAX_LENGTH and \\\n", " len(p[1].split(' ')) < MAX_LENGTH and \\\n", " p[1].startswith(eng_prefixes)\n", "\n", "\n", "def filterPairs(pairs):\n", " return [pair for pair in pairs if filterPair(pair)]" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:21:03.811303Z", "end_time": "2024-06-02T19:21:03.829054Z" } } }, { "cell_type": "code", "execution_count": 15, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reading lines...\n", "Read 49943 sentence pairs\n", "Trimmed to 3613 sentence pairs\n", "Counting words...\n", "Counted words:\n", "pol 3070\n", "eng 1969\n", "['nie umieram', 'i m not dying']\n" ] } ], "source": [ "def prepareData(lang1, lang2, reverse=False):\n", " input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)\n", " print(\"Read %s sentence pairs\" % len(pairs))\n", " pairs = filterPairs(pairs)\n", " print(\"Trimmed to %s sentence pairs\" % len(pairs))\n", " print(\"Counting words...\")\n", " for pair in pairs:\n", " input_lang.addSentence(pair[0])\n", " output_lang.addSentence(pair[1])\n", " print(\"Counted words:\")\n", " print(input_lang.name, input_lang.n_words)\n", " print(output_lang.name, output_lang.n_words)\n", " return input_lang, output_lang, pairs\n", "\n", "input_lang, output_lang, pairs = prepareData('eng', 'pol' , True)\n", "print(random.choice(pairs))" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:21:04.527025Z", "end_time": "2024-06-02T19:21:06.394023Z" } } }, { "cell_type": "markdown", "source": [ "### Model" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 16, "outputs": [], "source": [ "class EncoderRNN(nn.Module):\n", " def __init__(self, input_size, hidden_size, dropout_p=0.1):\n", " super(EncoderRNN, self).__init__()\n", " self.hidden_size = hidden_size\n", "\n", " self.embedding = nn.Embedding(input_size, hidden_size)\n", " self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n", " self.dropout = nn.Dropout(dropout_p)\n", "\n", " def forward(self, input):\n", " embedded = self.dropout(self.embedding(input))\n", " output, hidden = self.gru(embedded)\n", " return output, hidden" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:21:11.058623Z", "end_time": "2024-06-02T19:21:11.074974Z" } } }, { "cell_type": "code", "execution_count": 17, "outputs": [], "source": [ "class DecoderRNN(nn.Module):\n", " def __init__(self, hidden_size, output_size):\n", " super(DecoderRNN, self).__init__()\n", " self.embedding = nn.Embedding(output_size, hidden_size)\n", " self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n", " self.out = nn.Linear(hidden_size, output_size)\n", "\n", " def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n", " batch_size = encoder_outputs.size(0)\n", " decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)\n", " decoder_hidden = encoder_hidden\n", " decoder_outputs = []\n", "\n", " for i in range(MAX_LENGTH):\n", " decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n", " decoder_outputs.append(decoder_output)\n", "\n", " if target_tensor is not None:\n", " # Teacher forcing: Feed the target as the next input\n", " decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing\n", " else:\n", " # Without teacher forcing: use its own predictions as the next input\n", " _, topi = decoder_output.topk(1)\n", " decoder_input = topi.squeeze(-1).detach() # detach from history as input\n", "\n", " decoder_outputs = torch.cat(decoder_outputs, dim=1)\n", " decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n", " return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop\n", "\n", " def forward_step(self, input, hidden):\n", " output = self.embedding(input)\n", " output = F.relu(output)\n", " output, hidden = self.gru(output, hidden)\n", " output = self.out(output)\n", " return output, hidden" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:21:11.447213Z", "end_time": "2024-06-02T19:21:11.462232Z" } } }, { "cell_type": "code", "execution_count": 18, "outputs": [], "source": [ "class BahdanauAttention(nn.Module):\n", " def __init__(self, hidden_size):\n", " super(BahdanauAttention, self).__init__()\n", " self.Wa = nn.Linear(hidden_size, hidden_size)\n", " self.Ua = nn.Linear(hidden_size, hidden_size)\n", " self.Va = nn.Linear(hidden_size, 1)\n", "\n", " def forward(self, query, keys):\n", " scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n", " scores = scores.squeeze(2).unsqueeze(1)\n", "\n", " weights = F.softmax(scores, dim=-1)\n", " context = torch.bmm(weights, keys)\n", "\n", " return context, weights\n", "\n", "class AttnDecoderRNN(nn.Module):\n", " def __init__(self, hidden_size, output_size, dropout_p=0.1):\n", " super(AttnDecoderRNN, self).__init__()\n", " self.embedding = nn.Embedding(output_size, hidden_size)\n", " self.attention = BahdanauAttention(hidden_size)\n", " self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n", " self.out = nn.Linear(hidden_size, output_size)\n", " self.dropout = nn.Dropout(dropout_p)\n", "\n", " def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n", " batch_size = encoder_outputs.size(0)\n", " decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)\n", " decoder_hidden = encoder_hidden\n", " decoder_outputs = []\n", " attentions = []\n", "\n", " for i in range(MAX_LENGTH):\n", " decoder_output, decoder_hidden, attn_weights = self.forward_step(\n", " decoder_input, decoder_hidden, encoder_outputs\n", " )\n", " decoder_outputs.append(decoder_output)\n", " attentions.append(attn_weights)\n", "\n", " if target_tensor is not None:\n", " # Teacher forcing: Feed the target as the next input\n", " decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing\n", " else:\n", " # Without teacher forcing: use its own predictions as the next input\n", " _, topi = decoder_output.topk(1)\n", " decoder_input = topi.squeeze(-1).detach() # detach from history as input\n", "\n", " decoder_outputs = torch.cat(decoder_outputs, dim=1)\n", " decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n", " attentions = torch.cat(attentions, dim=1)\n", "\n", " return decoder_outputs, decoder_hidden, attentions\n", "\n", "\n", " def forward_step(self, input, hidden, encoder_outputs):\n", " embedded = self.dropout(self.embedding(input))\n", "\n", " query = hidden.permute(1, 0, 2)\n", " context, attn_weights = self.attention(query, encoder_outputs)\n", " input_gru = torch.cat((embedded, context), dim=2)\n", "\n", " output, hidden = self.gru(input_gru, hidden)\n", " output = self.out(output)\n", "\n", " return output, hidden, attn_weights" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:21:12.049305Z", "end_time": "2024-06-02T19:21:12.073302Z" } } }, { "cell_type": "code", "execution_count": 21, "outputs": [], "source": [ "def indexesFromSentence(lang, sentence):\n", " return [lang.word2index[word] for word in sentence.split(' ')]\n", "\n", "def tensorFromSentence(lang, sentence):\n", " indexes = indexesFromSentence(lang, sentence)\n", " indexes.append(EOS_token)\n", " return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)\n", "\n", "def tensorsFromPair(pair):\n", " input_tensor = tensorFromSentence(input_lang, pair[0])\n", " target_tensor = tensorFromSentence(output_lang, pair[1])\n", " return (input_tensor, target_tensor)\n", "\n", "def get_dataloader(batch_size):\n", " input_lang, output_lang, pairs = prepareData( 'eng', 'pol', True)\n", "\n", " n = len(pairs)\n", " input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n", " target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n", "\n", " for idx, (inp, tgt) in enumerate(pairs):\n", " inp_ids = indexesFromSentence(input_lang, inp)\n", " tgt_ids = indexesFromSentence(output_lang, tgt)\n", " inp_ids.append(EOS_token)\n", " tgt_ids.append(EOS_token)\n", " input_ids[idx, :len(inp_ids)] = inp_ids\n", " target_ids[idx, :len(tgt_ids)] = tgt_ids\n", "\n", " train_data = TensorDataset(torch.LongTensor(input_ids).to(device),\n", " torch.LongTensor(target_ids).to(device))\n", "\n", " train_sampler = RandomSampler(train_data)\n", " train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n", " return input_lang, output_lang, train_dataloader" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:23:18.301396Z", "end_time": "2024-06-02T19:23:18.321420Z" } } }, { "cell_type": "markdown", "source": [ "### Trening" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 22, "outputs": [], "source": [ "def train_epoch(dataloader, encoder, decoder, encoder_optimizer,\n", " decoder_optimizer, criterion):\n", "\n", " total_loss = 0\n", " for data in dataloader:\n", " input_tensor, target_tensor = data\n", "\n", " encoder_optimizer.zero_grad()\n", " decoder_optimizer.zero_grad()\n", "\n", " encoder_outputs, encoder_hidden = encoder(input_tensor)\n", " decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)\n", "\n", " loss = criterion(\n", " decoder_outputs.view(-1, decoder_outputs.size(-1)),\n", " target_tensor.view(-1)\n", " )\n", " loss.backward()\n", "\n", " encoder_optimizer.step()\n", " decoder_optimizer.step()\n", "\n", " total_loss += loss.item()\n", "\n", " return total_loss / len(dataloader)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:23:19.166843Z", "end_time": "2024-06-02T19:23:19.182827Z" } } }, { "cell_type": "code", "execution_count": 23, "outputs": [], "source": [ "import time\n", "import math\n", "\n", "def asMinutes(s):\n", " m = math.floor(s / 60)\n", " s -= m * 60\n", " return '%dm %ds' % (m, s)\n", "\n", "def timeSince(since, percent):\n", " now = time.time()\n", " s = now - since\n", " es = s / (percent)\n", " rs = es - s\n", " return '%s (- %s)' % (asMinutes(s), asMinutes(rs))" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:23:19.675207Z", "end_time": "2024-06-02T19:23:19.699207Z" } } }, { "cell_type": "code", "execution_count": 24, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.switch_backend('agg')\n", "import matplotlib.ticker as ticker\n", "import numpy as np\n", "\n", "def showPlot(points):\n", " plt.figure()\n", " fig, ax = plt.subplots()\n", " # this locator puts ticks at regular intervals\n", " loc = ticker.MultipleLocator(base=0.2)\n", " ax.yaxis.set_major_locator(loc)\n", " plt.plot(points)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:23:20.120325Z", "end_time": "2024-06-02T19:23:20.833674Z" } } }, { "cell_type": "code", "execution_count": 25, "outputs": [], "source": [ "def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,\n", " print_every=100, plot_every=100):\n", " start = time.time()\n", " plot_losses = []\n", " print_loss_total = 0 # Reset every print_every\n", " plot_loss_total = 0 # Reset every plot_every\n", "\n", " encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n", " decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n", " criterion = nn.NLLLoss()\n", "\n", " for epoch in range(1, n_epochs + 1):\n", " loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)\n", " print_loss_total += loss\n", " plot_loss_total += loss\n", "\n", " if epoch % print_every == 0:\n", " print_loss_avg = print_loss_total / print_every\n", " print_loss_total = 0\n", " print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),\n", " epoch, epoch / n_epochs * 100, print_loss_avg))\n", "\n", " if epoch % plot_every == 0:\n", " plot_loss_avg = plot_loss_total / plot_every\n", " plot_losses.append(plot_loss_avg)\n", " plot_loss_total = 0\n", "\n", " showPlot(plot_losses)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:23:21.920756Z", "end_time": "2024-06-02T19:23:21.949755Z" } } }, { "cell_type": "code", "execution_count": 98, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reading lines...\n", "Read 49943 sentence pairs\n", "Trimmed to 3613 sentence pairs\n", "Counting words...\n", "Counted words:\n", "pol 3070\n", "eng 1969\n", "0m 7s (- 2m 18s) (5 5%) 1.9851\n", "0m 14s (- 2m 8s) (10 10%) 1.0089\n", "0m 21s (- 1m 59s) (15 15%) 0.5189\n", "0m 28s (- 1m 52s) (20 20%) 0.2294\n", "0m 35s (- 1m 45s) (25 25%) 0.0961\n", "0m 42s (- 1m 38s) (30 30%) 0.0509\n", "0m 50s (- 1m 33s) (35 35%) 0.0355\n", "0m 57s (- 1m 25s) (40 40%) 0.0289\n", "1m 4s (- 1m 18s) (45 45%) 0.0249\n", "1m 11s (- 1m 11s) (50 50%) 0.0228\n", "1m 18s (- 1m 4s) (55 55%) 0.0207\n", "1m 25s (- 0m 57s) (60 60%) 0.0215\n", "1m 32s (- 0m 49s) (65 65%) 0.0249\n", "1m 39s (- 0m 42s) (70 70%) 0.0184\n", "1m 47s (- 0m 35s) (75 75%) 0.0172\n", "1m 55s (- 0m 28s) (80 80%) 0.0166\n", "2m 3s (- 0m 21s) (85 85%) 0.0163\n", "2m 11s (- 0m 14s) (90 90%) 0.0163\n", "2m 18s (- 0m 7s) (95 95%) 0.0176\n", "2m 27s (- 0m 0s) (100 100%) 0.0256\n" ] }, { "data": { "text/plain": "
" }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": "
", "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hidden_size = 256\n", "batch_size = 64\n", "\n", "input_lang, output_lang, train_dataloader = get_dataloader(batch_size)\n", "\n", "encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)\n", "decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)\n", "\n", "train(train_dataloader, encoder, decoder, 100, print_every=5, plot_every=5)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:00:44.619526Z", "end_time": "2024-06-02T20:03:13.180305Z" } } }, { "cell_type": "markdown", "source": [ "### Ewaluacja" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 85, "outputs": [], "source": [ "def evaluate(encoder, decoder, sentence, input_lang, output_lang):\n", " with torch.no_grad():\n", " input_tensor = tensorFromSentence(input_lang, sentence)\n", "\n", " encoder_outputs, encoder_hidden = encoder(input_tensor)\n", " decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n", "\n", " _, topi = decoder_outputs.topk(1)\n", " decoded_ids = topi.squeeze()\n", "\n", " decoded_words = []\n", " for idx in decoded_ids:\n", " if idx.item() == EOS_token:\n", " decoded_words.append('')\n", " break\n", " decoded_words.append(output_lang.index2word[idx.item()])\n", " return decoded_words, decoder_attn" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:59:01.782695Z", "end_time": "2024-06-02T19:59:01.811933Z" } } }, { "cell_type": "code", "execution_count": 86, "outputs": [], "source": [ "def evaluateRandomly(encoder, decoder, n=10):\n", " for i in range(n):\n", " pair = random.choice(pairs)\n", " print('>', pair[0])\n", " print('=', pair[1])\n", " output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)\n", " output_sentence = ' '.join(output_words)\n", " print('<', output_sentence)\n", " print('')" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:59:02.352827Z", "end_time": "2024-06-02T19:59:02.374825Z" } } }, { "cell_type": "code", "execution_count": 99, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "> utne sobie drzemke\n", "= i m going to go take a nap\n", "< i m going to go take a nap wallet \n", "\n", "> nie jestem co do tego pewny to zalezy\n", "= i m not sure about that it depends\n", "< i m not sure about that it depends \n", "\n", "> nie kupujemy\n", "= we re not buying\n", "< we re not buying \n", "\n", "> nie jestem g upi\n", "= i m not stupid\n", "< i m not stupid \n", "\n", "> jestes wymagajacy\n", "= you re demanding\n", "< you re demanding \n", "\n", "> jestem m ody ale nie az tak\n", "= i m young but i m not that young\n", "< i m young but i m not that young \n", "\n", "> nie jestem ubrana\n", "= i m not dressed\n", "< i m not dressed \n", "\n", "> jestem gotowy sie z tym pogodzic\n", "= i m ready to accept it\n", "< i m ready to accept it \n", "\n", "> jestem pewny ze ona nied ugo wroci\n", "= i m sure that she will come back soon\n", "< i m sure that she will come back soon \n", "\n", "> w niedziele mam wolne\n", "= i m free on sunday\n", "< i m free on sunday \n", "\n" ] } ], "source": [ "encoder.eval()\n", "decoder.eval()\n", "evaluateRandomly(encoder, decoder)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:03:19.348154Z", "end_time": "2024-06-02T20:03:19.572157Z" } } }, { "cell_type": "code", "execution_count": 88, "outputs": [], "source": [ "def showAttention(input_sentence, output_words, attentions):\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111)\n", " cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')\n", " fig.colorbar(cax)\n", "\n", " # Set up axes\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", " [''], rotation=90)\n", " ax.set_yticklabels([''] + output_words)\n", "\n", " # Show label at every tick\n", " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", "\n", " plt.show()\n", "\n", "\n", "def evaluateAndShowAttention(input_sentence):\n", " input_sentence = normalizeString(input_sentence)\n", " output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)\n", " print('input =', input_sentence)\n", " print('output =', ' '.join(output_words))\n", " showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T19:59:04.218821Z", "end_time": "2024-06-02T19:59:04.250855Z" } } }, { "cell_type": "code", "execution_count": 100, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input = nie jestem katoliczka\n", "output = i m not catholic \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", "C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_yticklabels([''] + output_words)\n" ] }, { "data": { "text/plain": "
", "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "evaluateAndShowAttention('Nie jestem katoliczką')" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:03:24.422192Z", "end_time": "2024-06-02T20:03:24.634214Z" } } }, { "cell_type": "code", "execution_count": 101, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input = przykro nam ze to sie zdarzy o\n", "output = we re sorry that it happened \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", "C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_yticklabels([''] + output_words)\n" ] }, { "data": { "text/plain": "
", "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "evaluateAndShowAttention('Przykro nam ze to sie zdarzyło')\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:03:25.856941Z", "end_time": "2024-06-02T20:03:26.205536Z" } } }, { "cell_type": "code", "execution_count": 102, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input = on mowi p ynnie po francusku\n", "output = he is fluent in french \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", "C:\\Users\\adamw\\AppData\\Local\\Temp\\ipykernel_17652\\691622281.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_yticklabels([''] + output_words)\n" ] }, { "data": { "text/plain": "
", "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "evaluateAndShowAttention('On mówi płynnie po francusku')\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:03:26.594026Z", "end_time": "2024-06-02T20:03:26.838018Z" } } }, { "cell_type": "markdown", "source": [ "### BLEU" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 103, "outputs": [], "source": [ "def filter_rows(row):\n", " return len(row[\"eng\"].split(' '))\" in output_words:\n", " output_words.remove(\"\")\n", " return output_words" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:03:29.868015Z", "end_time": "2024-06-02T20:03:29.884050Z" } } }, { "cell_type": "code", "execution_count": 114, "outputs": [], "source": [ "df = pd.read_csv(\"data/eng-pol.txt\", sep='\\t', names=[\"eng\", \"pol\", \"attribution\"])\n", "df[\"eng\"] = df[\"eng\"].apply(normalizeString)\n", "df[\"pol\"] = df[\"pol\"].apply(normalizeString)\n", "df_filtered = df.apply(filter_rows, axis=1)\n", "test_df = df[df_filtered].sample(frac=1)\n", "test_df[\"eng_token\"] = test_df[\"eng\"].apply(lambda x: x.split())\n", "test_df[\"eng_eval\"] = test_df[\"pol\"].apply(lambda x: evaluateWithTokenization(x))" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:07:48.707058Z", "end_time": "2024-06-02T20:08:22.246952Z" } } }, { "cell_type": "code", "execution_count": 115, "outputs": [], "source": [ "references_corpus = test_df[\"eng_token\"].values.tolist()\n", "candidate_corpus = test_df[\"eng_eval\"].values.tolist()\n", "references_corpus = [[el] for el in references_corpus]" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:08:22.248949Z", "end_time": "2024-06-02T20:08:22.262981Z" } } }, { "cell_type": "code", "execution_count": 116, "outputs": [ { "data": { "text/plain": "0.9301728010177612" }, "execution_count": 116, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bleu_score(candidate_corpus, references_corpus)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-02T20:08:22.264948Z", "end_time": "2024-06-02T20:08:23.695461Z" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }