{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Seq2Seq Polski --> Angielski\n", "https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", "execution": { "iopub.execute_input": "2024-05-25T14:03:55.887266Z", "iopub.status.busy": "2024-05-25T14:03:55.886451Z", "iopub.status.idle": "2024-05-25T14:04:02.514594Z", "shell.execute_reply": "2024-05-25T14:04:02.513697Z", "shell.execute_reply.started": "2024-05-25T14:03:55.887232Z" }, "trusted": true }, "outputs": [], "source": [ "from __future__ import unicode_literals, print_function, division\n", "from io import open\n", "import unicodedata\n", "import re\n", "import random\n", "\n", "import torch\n", "import torch.nn as nn\n", "from torch import optim\n", "import torch.nn.functional as F\n", "\n", "import numpy as np\n", "from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:04:09.403926Z", "iopub.status.busy": "2024-05-25T14:04:09.403445Z", "iopub.status.idle": "2024-05-25T14:04:09.434533Z", "shell.execute_reply": "2024-05-25T14:04:09.433678Z", "shell.execute_reply.started": "2024-05-25T14:04:09.403898Z" }, "trusted": true }, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.cuda.device_count()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Konwersja słów na index" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:04:14.014490Z", "iopub.status.busy": "2024-05-25T14:04:14.014114Z", "iopub.status.idle": 
"2024-05-25T14:04:14.024526Z", "shell.execute_reply": "2024-05-25T14:04:14.023673Z", "shell.execute_reply.started": "2024-05-25T14:04:14.014461Z" }, "trusted": true }, "outputs": [], "source": [ "SOS_token = 0\n", "EOS_token = 1\n", "\n", "class Lang:\n", " def __init__(self, name):\n", " self.name = name\n", " self.word2index = {}\n", " self.word2count = {}\n", " self.index2word = {0: \"SOS\", 1: \"EOS\"}\n", " self.n_words = 2 # Count SOS and EOS\n", "\n", " def addSentence(self, sentence):\n", " for word in sentence.split(' '):\n", " self.addWord(word)\n", "\n", " def addWord(self, word):\n", " if word not in self.word2index:\n", " self.word2index[word] = self.n_words\n", " self.word2count[word] = 1\n", " self.index2word[self.n_words] = word\n", " self.n_words += 1\n", " else:\n", " self.word2count[word] += 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Normalizacja tekstu" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:04:23.432285Z", "iopub.status.busy": "2024-05-25T14:04:23.431898Z", "iopub.status.idle": "2024-05-25T14:04:23.438688Z", "shell.execute_reply": "2024-05-25T14:04:23.437569Z", "shell.execute_reply.started": "2024-05-25T14:04:23.432256Z" }, "trusted": true }, "outputs": [], "source": [ "# Turn a Unicode string to plain ASCII, thanks to\n", "# https://stackoverflow.com/a/518232/2809427\n", "def unicodeToAscii(s):\n", " return ''.join(\n", " c for c in unicodedata.normalize('NFD', s)\n", " if unicodedata.category(c) != 'Mn'\n", " )\n", "\n", "# Lowercase, trim, and remove non-letter characters\n", "def normalizeString(s):\n", " s = unicodeToAscii(s.lower().strip())\n", " s = re.sub(r\"([.!?])\", r\" \\1\", s)\n", " s = re.sub(r\"[^a-zA-Z!?]+\", r\" \", s)\n", " return s.strip()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Wczytywanie danych (zmodyfikowane ze względu na ścieżkę w kaggle)" ] }, { "cell_type": "code", "execution_count": 7, 
"metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:12:25.386029Z", "iopub.status.busy": "2024-05-25T14:12:25.385674Z", "iopub.status.idle": "2024-05-25T14:12:25.394103Z", "shell.execute_reply": "2024-05-25T14:12:25.392925Z", "shell.execute_reply.started": "2024-05-25T14:12:25.386002Z" }, "trusted": true }, "outputs": [], "source": [ "def readLangs(reverse=False):\n", " print(\"Reading lines...\")\n", " lang1=\"en\"\n", " lang2=\"pol\"\n", " # Read the file and split into lines\n", " lines = open('pol.txt', encoding='utf-8').\\\n", " read().strip().split('\\n')\n", "\n", " # Split every line into pairs and normalize\n", " pairs = [[normalizeString(s) for s in l.split('\\t')[:-1]] for l in lines]\n", "\n", " # Reverse pairs, make Lang instances\n", " if reverse:\n", " pairs = [list(reversed(p)) for p in pairs]\n", " input_lang = Lang(lang2)\n", " output_lang = Lang(lang1)\n", " else:\n", " input_lang = Lang(lang1)\n", " output_lang = Lang(lang2)\n", "\n", " return input_lang, output_lang, pairs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Ograniczenie do zdań max 10 słów, formy I am / You are / He is etc. 
# Width of every padded id row and number of decoding steps; a sentence must
# have fewer words than this so that the appended EOS still fits.
MAX_LENGTH = 10

# Prefixes are written in normalized form (lower case, apostrophes replaced by
# spaces), i.e. the form produced by normalizeString.
eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    """Return True if both sentences are shorter than MAX_LENGTH words and
    ``p[1]`` starts with one of ``eng_prefixes``.

    The prefix test is applied to ``p[1]``, so English must be the second
    element of the pair.
    """
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    """Keep only the pairs accepted by ``filterPair``."""
    return [pair for pair in pairs if filterPair(pair)]


def prepareData(reverse=False):
    """Load the corpus, filter it and build both vocabularies.

    Returns ``(input_lang, output_lang, pairs)`` with the vocabularies filled
    from the filtered pairs; progress is printed along the way.
    """
    input_lang, output_lang, pairs = readLangs(reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


input_lang, output_lang, pairs = prepareData(True)
print(random.choice(pairs))


class EncoderRNN(nn.Module):
    """Encoder: embedding -> dropout -> batch-first GRU."""

    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        """Return the GRU outputs for every step and the final hidden state."""
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden


class DecoderRNN(nn.Module):
    """Plain GRU decoder (no attention) that always runs MAX_LENGTH steps."""

    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        """Decode MAX_LENGTH steps starting from SOS.

        With ``target_tensor`` the gold token of step ``i`` is fed as the next
        input (teacher forcing); without it the arg-max prediction is fed back.
        Returns ``(log_probs, last_hidden, None)``.
        """
        batch_size = encoder_outputs.size(0)
        # Allocate the SOS column next to the encoder outputs rather than on
        # the notebook-level `device`, so the module follows its inputs.
        decoder_input = torch.empty(
            batch_size, 1, dtype=torch.long, device=encoder_outputs.device
        ).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden, None  # We return `None` for consistency in the training loop

    def forward_step(self, input, hidden):
        """Run one decoding step: embedding -> ReLU -> GRU -> linear projection."""
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden


class BahdanauAttention(nn.Module):
    """Additive attention: score = Va(tanh(Wa(query) + Ua(keys)))."""

    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        """Return ``(context, weights)``; ``weights`` is a softmax over the keys
        and ``context`` is the weighted sum of ``keys`` (via ``torch.bmm``)."""
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        # Drop the trailing size-1 score axis and re-insert an axis at position
        # 1 so the weights can be batch-multiplied with the keys.
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights


class AttnDecoderRNN(nn.Module):
    """GRU decoder that attends over the encoder outputs at every step."""

    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        # The GRU input is the token embedding concatenated with the context.
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        """Decode MAX_LENGTH steps starting from SOS.

        With ``target_tensor`` the gold token of step ``i`` is fed as the next
        input (teacher forcing); without it the arg-max prediction is fed back.
        Returns ``(log_probs, last_hidden, attentions)`` where ``attentions``
        concatenates the per-step attention weights along dim 1.
        """
        batch_size = encoder_outputs.size(0)
        # Allocate the SOS column next to the encoder outputs rather than on
        # the notebook-level `device`, so the module follows its inputs.
        decoder_input = torch.empty(
            batch_size, 1, dtype=torch.long, device=encoder_outputs.device
        ).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def forward_step(self, input, hidden, encoder_outputs):
        """Run one step: attend with the current hidden state, then GRU and projection."""
        embedded = self.dropout(self.embedding(input))

        # The GRU hidden state has the batch on axis 1; the attention query
        # needs it on axis 0, like the batch-first encoder outputs.
        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights


def indexesFromSentence(lang, sentence):
    """Map every space-separated word of ``sentence`` to its index in ``lang``.

    Raises:
        KeyError: if a word is not in the vocabulary; the message names the
            word and the vocabulary instead of showing only the bare key.
    """
    try:
        return [lang.word2index[word] for word in sentence.split(' ')]
    except KeyError as err:
        raise KeyError(
            "word %r is not in the %r vocabulary" % (err.args[0], lang.name)
        ) from None


def tensorFromSentence(lang, sentence):
    """Return a ``(1, n_words + 1)`` long tensor of word ids terminated by EOS."""
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)


def tensorsFromPair(pair):
    """Convert one pair using the notebook-level ``input_lang`` / ``output_lang``."""
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)


def get_dataloader(batch_size):
    """Build the training ``DataLoader`` of zero-padded id matrices.

    The corpus is loaded again through ``prepareData(True)``. Each sentence is
    converted to ids, terminated with EOS and written into a row of width
    MAX_LENGTH; unused positions keep the initial value 0.

    Returns ``(input_lang, output_lang, train_dataloader)``.
    """
    input_lang, output_lang, pairs = prepareData(True)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        inp_ids.append(EOS_token)
        tgt_ids.append(EOS_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return input_lang, output_lang, train_dataloader


def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
                decoder_optimizer, criterion):
    """Run one teacher-forced pass over ``dataloader`` and return the mean batch loss."""
    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        # Flatten batch and time so the criterion sees one row per token.
        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


import time
import math


def asMinutes(s):
    """Format a duration given in seconds as ``'<m>m <s>s'``."""
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    """Return ``'elapsed (- remaining)'`` for a job started at ``since``.

    ``percent`` is the completed fraction and must be greater than 0; the
    remaining time is extrapolated linearly from the elapsed time.
    """
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
"cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:20:58.574520Z", "iopub.status.busy": "2024-05-25T14:20:58.574148Z", "iopub.status.idle": "2024-05-25T14:20:58.583203Z", "shell.execute_reply": "2024-05-25T14:20:58.582230Z", "shell.execute_reply.started": "2024-05-25T14:20:58.574492Z" }, "trusted": true }, "outputs": [], "source": [ "def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,\n", " print_every=100, plot_every=100):\n", " start = time.time()\n", " plot_losses = []\n", " print_loss_total = 0 # Reset every print_every\n", " plot_loss_total = 0 # Reset every plot_every\n", "\n", " encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n", " decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n", " criterion = nn.NLLLoss()\n", "\n", " for epoch in range(1, n_epochs + 1):\n", " loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)\n", " print_loss_total += loss\n", " plot_loss_total += loss\n", "\n", " if epoch % print_every == 0:\n", " print_loss_avg = print_loss_total / print_every\n", " print_loss_total = 0\n", " print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),\n", " epoch, epoch / n_epochs * 100, print_loss_avg))\n", "\n", " if epoch % plot_every == 0:\n", " plot_loss_avg = plot_loss_total / plot_every\n", " plot_losses.append(plot_loss_avg)\n", " plot_loss_total = 0\n", "\n", " showPlot(plot_losses)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:21:00.586719Z", "iopub.status.busy": "2024-05-25T14:21:00.586018Z", "iopub.status.idle": "2024-05-25T14:21:00.592633Z", "shell.execute_reply": "2024-05-25T14:21:00.591636Z", "shell.execute_reply.started": "2024-05-25T14:21:00.586683Z" }, "trusted": true }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "plt.switch_backend('agg')\n", "import 
matplotlib.ticker as ticker\n", "import numpy as np\n", "%matplotlib inline\n", "\n", "def showPlot(points):\n", " plt.figure()\n", " fig, ax = plt.subplots()\n", " # this locator puts ticks at regular intervals\n", " loc = ticker.MultipleLocator(base=0.2)\n", " ax.yaxis.set_major_locator(loc)\n", " plt.plot(points)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Ewaluacja" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:21:01.859612Z", "iopub.status.busy": "2024-05-25T14:21:01.858691Z", "iopub.status.idle": "2024-05-25T14:21:01.866857Z", "shell.execute_reply": "2024-05-25T14:21:01.865732Z", "shell.execute_reply.started": "2024-05-25T14:21:01.859574Z" }, "trusted": true }, "outputs": [], "source": [ "def evaluate(encoder, decoder, sentence, input_lang, output_lang):\n", " with torch.no_grad():\n", " input_tensor = tensorFromSentence(input_lang, sentence)\n", "\n", " encoder_outputs, encoder_hidden = encoder(input_tensor)\n", " decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n", "\n", " _, topi = decoder_outputs.topk(1)\n", " decoded_ids = topi.squeeze()\n", "\n", " decoded_words = []\n", " for idx in decoded_ids:\n", " if idx.item() == EOS_token:\n", " decoded_words.append('')\n", " break\n", " decoded_words.append(output_lang.index2word[idx.item()])\n", " return decoded_words, decoder_attn" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:39:52.475327Z", "iopub.status.busy": "2024-05-25T14:39:52.474985Z", "iopub.status.idle": "2024-05-25T14:39:52.481801Z", "shell.execute_reply": "2024-05-25T14:39:52.480957Z", "shell.execute_reply.started": "2024-05-25T14:39:52.475304Z" }, "trusted": true }, "outputs": [], "source": [ "def evaluateRandomly(encoder, decoder, n=10):\n", " for i in range(n):\n", " pair = random.choice(pairs)\n", " print('Input sentence: ', pair[0])\n", " 
print('Target (true) translation:' , pair[1])\n", " output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)\n", " output_sentence = ' '.join(output_words)\n", " print('Output sentence: ', output_sentence)\n", " print('')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Wykorzystanie zdefiniowanych wyżej funkcji" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Trenowanie modelu" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:22:39.754740Z", "iopub.status.busy": "2024-05-25T14:22:39.754370Z", "iopub.status.idle": "2024-05-25T14:26:53.707012Z", "shell.execute_reply": "2024-05-25T14:26:53.705969Z", "shell.execute_reply.started": "2024-05-25T14:22:39.754714Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reading lines...\n", "Read 49943 sentence pairs\n", "Trimmed to 3613 sentence pairs\n", "Counting words...\n", "Counted words:\n", "pol 3070\n", "en 1969\n", "0m 47s (- 11m 58s) (5 6%) 2.1245\n", "1m 33s (- 10m 57s) (10 12%) 1.2482\n", "2m 13s (- 9m 36s) (15 18%) 0.8442\n", "2m 51s (- 8m 35s) (20 25%) 0.5612\n", "3m 32s (- 7m 46s) (25 31%) 0.3599\n", "4m 10s (- 6m 56s) (30 37%) 0.2216\n", "4m 51s (- 6m 15s) (35 43%) 0.1367\n", "5m 30s (- 5m 30s) (40 50%) 0.0894\n", "6m 9s (- 4m 47s) (45 56%) 0.0647\n", "6m 48s (- 4m 5s) (50 62%) 0.0489\n", "7m 27s (- 3m 23s) (55 68%) 0.0402\n", "8m 8s (- 2m 42s) (60 75%) 0.0345\n", "8m 48s (- 2m 1s) (65 81%) 0.0315\n", "9m 25s (- 1m 20s) (70 87%) 0.0278\n", "10m 3s (- 0m 40s) (75 93%) 0.0271\n", "10m 42s (- 0m 0s) (80 100%) 0.0253\n" ] }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hidden_size = 128\n", "batch_size = 32\n", "\n", "input_lang, output_lang, train_dataloader = get_dataloader(batch_size)\n", "\n", "encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)\n", "decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)\n", "\n", "train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T15:48:23.131948Z", "iopub.status.busy": "2024-05-25T15:48:23.131435Z", "iopub.status.idle": "2024-05-25T15:48:23.213007Z", "shell.execute_reply": "2024-05-25T15:48:23.211856Z", "shell.execute_reply.started": "2024-05-25T15:48:23.131911Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input sentence: ciesze sie ze by em w stanie pomoc\n", "Target (true) translation: i m glad i was able to help\n", "Output sentence: i m glad i was able to help \n", "\n", "Input sentence: to moja matka chrzestna\n", "Target (true) translation: she s my godmother\n", "Output sentence: she s my godmother by three \n", "\n", "Input sentence: nie gram w zadna gre\n", "Target (true) translation: i m not playing a game\n", "Output sentence: i m not playing a game \n", "\n", "Input sentence: jestem wyzszy\n", "Target (true) translation: i am taller\n", "Output sentence: i am taller \n", "\n", "Input sentence: jestes zdesperowany\n", "Target (true) translation: you re desperate\n", "Output sentence: you re desperate \n", "\n", "Input sentence: zostane zwolniony\n", "Target (true) translation: i m going to get fired\n", "Output sentence: i m going to be arrested i think \n", "\n", "Input sentence: mamy dzisiaj rybe jako g owne danie\n", "Target (true) translation: we are having fish for our main course\n", "Output sentence: we are having fish for our main course \n", "\n", "Input sentence: jestes przepracowana\n", 
"Target (true) translation: you are overworked\n", "Output sentence: you are overworked \n", "\n", "Input sentence: jestes elokwentny\n", "Target (true) translation: you re articulate\n", "Output sentence: you re articulate \n", "\n", "Input sentence: zaczynam rozumiec\n", "Target (true) translation: i m beginning to understand\n", "Output sentence: i m beginning to understand \n", "\n" ] } ], "source": [ "evaluateRandomly(encoder, decoder)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T15:30:41.253734Z", "iopub.status.busy": "2024-05-25T15:30:41.253325Z", "iopub.status.idle": "2024-05-25T15:30:41.264515Z", "shell.execute_reply": "2024-05-25T15:30:41.263376Z", "shell.execute_reply.started": "2024-05-25T15:30:41.253703Z" }, "trusted": true }, "outputs": [], "source": [ "def showAttention(input_sentence, output_words, attentions):\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111)\n", " cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')\n", " fig.colorbar(cax)\n", "\n", " # Set up axes\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", " [''], rotation=90)\n", " ax.set_yticklabels([''] + output_words)\n", "\n", " # Show label at every tick\n", " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", "\n", " plt.show()\n", "\n", "\n", "def evaluateAndShowAttention(input_sentence):\n", " input_sentence = normalizeString(input_sentence)\n", " output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)\n", " print('input =', input_sentence)\n", " print('output =', ' '.join(output_words))\n", " showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])\n", "\n", "def translate(input_sentence, tokenized=False):\n", " input_sentence = normalizeString(input_sentence)\n", " output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, 
output_lang)\n", " if tokenized:\n", " if \"\" in output_words:\n", " output_words.remove(\"\")\n", " return output_words\n", " return ' '.join(output_words)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:58:53.639963Z", "iopub.status.busy": "2024-05-25T14:58:53.639252Z", "iopub.status.idle": "2024-05-25T14:58:53.654186Z", "shell.execute_reply": "2024-05-25T14:58:53.653028Z", "shell.execute_reply.started": "2024-05-25T14:58:53.639932Z" }, "trusted": true }, "outputs": [ { "data": { "text/plain": [ "['we', 'are', 'hungry']" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "translate(\"Jesteśmy głodni\", tokenized=True)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:46:05.024277Z", "iopub.status.busy": "2024-05-25T14:46:05.023218Z", "iopub.status.idle": "2024-05-25T14:46:05.426793Z", "shell.execute_reply": "2024-05-25T14:46:05.424993Z", "shell.execute_reply.started": "2024-05-25T14:46:05.024227Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input = jestes zbyt naiwny\n", "output = you re too naive \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", "C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_yticklabels([''] + output_words)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "evaluateAndShowAttention('Jesteś zbyt naiwny')" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:46:10.395671Z", "iopub.status.busy": "2024-05-25T14:46:10.394969Z", "iopub.status.idle": "2024-05-25T14:46:10.793392Z", "shell.execute_reply": "2024-05-25T14:46:10.791940Z", "shell.execute_reply.started": "2024-05-25T14:46:10.395630Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input = naprawde mi przykro\n", "output = i m really sorry \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", "C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_yticklabels([''] + output_words)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "evaluateAndShowAttention('Naprawdę mi przykro')" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:46:13.191025Z", "iopub.status.busy": "2024-05-25T14:46:13.190403Z", "iopub.status.idle": "2024-05-25T14:46:13.613486Z", "shell.execute_reply": "2024-05-25T14:46:13.612143Z", "shell.execute_reply.started": "2024-05-25T14:46:13.190997Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input = jestes moim ojcem\n", "output = you are my father \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", "C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_yticklabels([''] + output_words)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "evaluateAndShowAttention('Jesteś moim ojcem')" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:46:15.841005Z", "iopub.status.busy": "2024-05-25T14:46:15.840433Z", "iopub.status.idle": "2024-05-25T14:46:16.232003Z", "shell.execute_reply": "2024-05-25T14:46:16.230365Z", "shell.execute_reply.started": "2024-05-25T14:46:15.840973Z" }, "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input = on tez jest nauczycielem\n", "output = he is a teacher too \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_xticklabels([''] + input_sentence.split(' ') +\n", "C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n", " ax.set_yticklabels([''] + output_words)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "evaluateAndShowAttention('On też jest nauczycielem')" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T14:50:24.774464Z", "iopub.status.busy": "2024-05-25T14:50:24.773506Z", "iopub.status.idle": "2024-05-25T14:50:24.795895Z", "shell.execute_reply": "2024-05-25T14:50:24.794979Z", "shell.execute_reply.started": "2024-05-25T14:50:24.774430Z" }, "trusted": true }, "outputs": [], "source": [ "torch.save(encoder.state_dict(), \"encoder.pt\")\n", "torch.save(decoder.state_dict(), \"decoder.pt\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# BLEU score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Jako że korzystaliśmy z okrojonej wersji zbioru danych, słownik nie zawiera wszystkich słów pojawiających się w przykładach więc do ewaluacji wykorzystujemy część przykładów z treningu" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T15:28:54.297253Z", "iopub.status.busy": "2024-05-25T15:28:54.296348Z", "iopub.status.idle": "2024-05-25T15:28:59.041172Z", "shell.execute_reply": "2024-05-25T15:28:59.040201Z", "shell.execute_reply.started": "2024-05-25T15:28:54.297211Z" }, "trusted": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EnglishPolishattribution
13492i m the last in linejestem ostatni w kolejceCC-BY 2.0 (France) Attribution: tatoeba.org #5...
14379you are not a cowardnie jestes tchorzemCC-BY 2.0 (France) Attribution: tatoeba.org #1...
31538we re anxious about her healthmartwimy sie o jej zdrowieCC-BY 2.0 (France) Attribution: tatoeba.org #7...
33382he s not at all afraid of snakeson zupe nie nie boi sie wezyCC-BY 2.0 (France) Attribution: tatoeba.org #2...
13031he is a fast speakeron szybko mowiCC-BY 2.0 (France) Attribution: tatoeba.org #3...
\n", "
" ], "text/plain": [ " English Polish \\\n", "13492 i m the last in line jestem ostatni w kolejce \n", "14379 you are not a coward nie jestes tchorzem \n", "31538 we re anxious about her health martwimy sie o jej zdrowie \n", "33382 he s not at all afraid of snakes on zupe nie nie boi sie wezy \n", "13031 he is a fast speaker on szybko mowi \n", "\n", " attribution \n", "13492 CC-BY 2.0 (France) Attribution: tatoeba.org #5... \n", "14379 CC-BY 2.0 (France) Attribution: tatoeba.org #1... \n", "31538 CC-BY 2.0 (France) Attribution: tatoeba.org #7... \n", "33382 CC-BY 2.0 (France) Attribution: tatoeba.org #2... \n", "13031 CC-BY 2.0 (France) Attribution: tatoeba.org #3... " ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "\n", "def filter_rows(row):\n", " return len(row[\"English\"].split(' '))\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EnglishPolishattributionEnglish_tokenizedEnglish_translated
13492i m the last in linejestem ostatni w kolejceCC-BY 2.0 (France) Attribution: tatoeba.org #5...[i, m, the, last, in, line][i, m, the, last, in, line]
14379you are not a cowardnie jestes tchorzemCC-BY 2.0 (France) Attribution: tatoeba.org #1...[you, are, not, a, coward][you, are, not, a, coward]
31538we re anxious about her healthmartwimy sie o jej zdrowieCC-BY 2.0 (France) Attribution: tatoeba.org #7...[we, re, anxious, about, her, health][we, are, worried, about, her, health]
33382he s not at all afraid of snakeson zupe nie nie boi sie wezyCC-BY 2.0 (France) Attribution: tatoeba.org #2...[he, s, not, at, all, afraid, of, snakes][he, s, not, afraid, of, snakes, at, all, afraid]
13031he is a fast speakeron szybko mowiCC-BY 2.0 (France) Attribution: tatoeba.org #3...[he, is, a, fast, speaker][he, is, a, fast, speaker, young]
\n", "" ], "text/plain": [ " English Polish \\\n", "13492 i m the last in line jestem ostatni w kolejce \n", "14379 you are not a coward nie jestes tchorzem \n", "31538 we re anxious about her health martwimy sie o jej zdrowie \n", "33382 he s not at all afraid of snakes on zupe nie nie boi sie wezy \n", "13031 he is a fast speaker on szybko mowi \n", "\n", " attribution \\\n", "13492 CC-BY 2.0 (France) Attribution: tatoeba.org #5... \n", "14379 CC-BY 2.0 (France) Attribution: tatoeba.org #1... \n", "31538 CC-BY 2.0 (France) Attribution: tatoeba.org #7... \n", "33382 CC-BY 2.0 (France) Attribution: tatoeba.org #2... \n", "13031 CC-BY 2.0 (France) Attribution: tatoeba.org #3... \n", "\n", " English_tokenized \\\n", "13492 [i, m, the, last, in, line] \n", "14379 [you, are, not, a, coward] \n", "31538 [we, re, anxious, about, her, health] \n", "33382 [he, s, not, at, all, afraid, of, snakes] \n", "13031 [he, is, a, fast, speaker] \n", "\n", " English_translated \n", "13492 [i, m, the, last, in, line] \n", "14379 [you, are, not, a, coward] \n", "31538 [we, are, worried, about, her, health] \n", "33382 [he, s, not, afraid, of, snakes, at, all, afraid] \n", "13031 [he, is, a, fast, speaker, young] " ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_section.head()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T15:43:29.442752Z", "iopub.status.busy": "2024-05-25T15:43:29.441911Z", "iopub.status.idle": "2024-05-25T15:43:29.447799Z", "shell.execute_reply": "2024-05-25T15:43:29.446877Z", "shell.execute_reply.started": "2024-05-25T15:43:29.442721Z" }, "trusted": true }, "outputs": [], "source": [ "candidate_corpus = test_section[\"English_translated\"].values\n", "references_corpus = test_section[\"English_tokenized\"].values.tolist()\n", "x = candidate_corpus.tolist()\n", "y = [[el] for el in references_corpus]" ] }, { "cell_type": "code", "execution_count": 
38, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T15:43:30.475080Z", "iopub.status.busy": "2024-05-25T15:43:30.474463Z", "iopub.status.idle": "2024-05-25T15:43:30.482690Z", "shell.execute_reply": "2024-05-25T15:43:30.481686Z", "shell.execute_reply.started": "2024-05-25T15:43:30.475039Z" }, "trusted": true }, "outputs": [ { "data": { "text/plain": [ "[[['i', 'm', 'the', 'last', 'in', 'line']],\n", " [['you', 'are', 'not', 'a', 'coward']],\n", " [['we', 're', 'anxious', 'about', 'her', 'health']],\n", " [['he', 's', 'not', 'at', 'all', 'afraid', 'of', 'snakes']],\n", " [['he', 'is', 'a', 'fast', 'speaker']]]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y[:5]" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "execution": { "iopub.execute_input": "2024-05-25T15:43:36.654953Z", "iopub.status.busy": "2024-05-25T15:43:36.654035Z", "iopub.status.idle": "2024-05-25T15:43:36.916617Z", "shell.execute_reply": "2024-05-25T15:43:36.915429Z", "shell.execute_reply.started": "2024-05-25T15:43:36.654906Z" }, "trusted": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Michał\\AppData\\Roaming\\Python\\Python310\\site-packages\\torchtext\\data\\__init__.py:4: UserWarning: \n", "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", "Torchtext is deprecated and the last released version will be 0.18 (this one). 
You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n" ] }, { "data": { "text/plain": [ "0.8122377991676331" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torchtext.data.metrics import bleu_score\n", "\n", "bleu_score(x, y)" ] } ], "metadata": { "kaggle": { "accelerator": "nvidiaTeslaT4", "dataSources": [ { "datasetId": 5082663, "sourceId": 8513800, "sourceType": "datasetVersion" } ], "dockerImageVersionId": 30699, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 4 }