dl_seq2seq/fin-to-en-seq2seq-v3.ipynb

1301 lines
119 KiB
Plaintext
Raw Normal View History

{
"metadata": {
"kernelspec": {
"language": "python",
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.13",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kaggle": {
"accelerator": "nvidiaTeslaT4",
"dataSources": [
{
"sourceId": 8513800,
"sourceType": "datasetVersion",
"datasetId": 5082663
}
],
"dockerImageVersionId": 30699,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook",
"isGpuEnabled": true
}
},
"nbformat_minor": 4,
"nbformat": 4,
"cells": [
{
"cell_type": "markdown",
"source": [
"# Seq2Seq Fiński --> Angielski\n",
"https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"from io import open\n",
"import unicodedata\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"import numpy as np\n",
"from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
],
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"execution": {
"iopub.status.busy": "2024-05-25T14:03:55.886451Z",
"iopub.execute_input": "2024-05-25T14:03:55.887266Z",
"iopub.status.idle": "2024-05-25T14:04:02.514594Z",
"shell.execute_reply.started": "2024-05-25T14:03:55.887232Z",
"shell.execute_reply": "2024-05-25T14:04:02.513697Z"
},
"trusted": true
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"source": [
"torch.cuda.device_count()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:04:09.403445Z",
"iopub.execute_input": "2024-05-25T14:04:09.403926Z",
"iopub.status.idle": "2024-05-25T14:04:09.434533Z",
"shell.execute_reply.started": "2024-05-25T14:04:09.403898Z",
"shell.execute_reply": "2024-05-25T14:04:09.433678Z"
},
"trusted": true
},
"execution_count": 2,
"outputs": [
{
"execution_count": 2,
"output_type": "execute_result",
"data": {
"text/plain": "2"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"### Konwersja słów na index"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"class Lang:\n",
" def __init__(self, name):\n",
" self.name = name\n",
" self.word2index = {}\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" def addSentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.addWord(word)\n",
"\n",
" def addWord(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" else:\n",
" self.word2count[word] += 1"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:04:14.014114Z",
"iopub.execute_input": "2024-05-25T14:04:14.014490Z",
"iopub.status.idle": "2024-05-25T14:04:14.024526Z",
"shell.execute_reply.started": "2024-05-25T14:04:14.014461Z",
"shell.execute_reply": "2024-05-25T14:04:14.023673Z"
},
"trusted": true
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Normalizacja tekstu"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Turn a Unicode string to plain ASCII, thanks to\n",
"# https://stackoverflow.com/a/518232/2809427\n",
"def unicodeToAscii(s):\n",
" return ''.join(\n",
" c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn'\n",
" )\n",
"\n",
"# Lowercase, trim, and remove non-letter characters\n",
"def normalizeString(s):\n",
" s = unicodeToAscii(s.lower().strip())\n",
" s = re.sub(r\"([.!?])\", r\" \\1\", s)\n",
" s = re.sub(r\"[^a-zA-Z!?]+\", r\" \", s)\n",
" return s.strip()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:04:23.431898Z",
"iopub.execute_input": "2024-05-25T14:04:23.432285Z",
"iopub.status.idle": "2024-05-25T14:04:23.438688Z",
"shell.execute_reply.started": "2024-05-25T14:04:23.432256Z",
"shell.execute_reply": "2024-05-25T14:04:23.437569Z"
},
"trusted": true
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Wczytywanie danych (zmodyfikowane ze względu na ścieżkę w kaggle)"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Zmodyfikowana wersja ze względu na użycie pojedynczego pliku przesłanego na Kaggle\n",
"def readLangs(reverse=False):\n",
" print(\"Reading lines...\")\n",
" lang1=\"en\"\n",
" lang2=\"fin\"\n",
" # Read the file and split into lines\n",
" lines = open('/kaggle/input/anki-en-fin/fin.txt', encoding='utf-8').\\\n",
" read().strip().split('\\n')\n",
"\n",
" # Split every line into pairs and normalize\n",
" pairs = [[normalizeString(s) for s in l.split('\\t')[:-1]] for l in lines] # +Usuwanie licencji CC z linii\n",
"\n",
" # Reverse pairs, make Lang instances\n",
" if reverse:\n",
" pairs = [list(reversed(p)) for p in pairs]\n",
" input_lang = Lang(lang2)\n",
" output_lang = Lang(lang1)\n",
" else:\n",
" input_lang = Lang(lang1)\n",
" output_lang = Lang(lang2)\n",
"\n",
" return input_lang, output_lang, pairs"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:25.385674Z",
"iopub.execute_input": "2024-05-25T14:12:25.386029Z",
"iopub.status.idle": "2024-05-25T14:12:25.394103Z",
"shell.execute_reply.started": "2024-05-25T14:12:25.386002Z",
"shell.execute_reply": "2024-05-25T14:12:25.392925Z"
},
"trusted": true
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Ograniczenie do zdań max 10 słów, formy I am / You are / He is etc. bez interpunkcji"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"MAX_LENGTH = 10\n",
"\n",
"eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s \",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
")\n",
"\n",
"def filterPair(p):\n",
" return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
" len(p[1].split(' ')) < MAX_LENGTH and \\\n",
" p[1].startswith(eng_prefixes)\n",
"\n",
"\n",
"def filterPairs(pairs):\n",
" return [pair for pair in pairs if filterPair(pair)]"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:29.729786Z",
"iopub.execute_input": "2024-05-25T14:12:29.730147Z",
"iopub.status.idle": "2024-05-25T14:12:29.737013Z",
"shell.execute_reply.started": "2024-05-25T14:12:29.730121Z",
"shell.execute_reply": "2024-05-25T14:12:29.735886Z"
},
"trusted": true
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def prepareData(reverse=False):\n",
" input_lang, output_lang, pairs = readLangs(reverse)\n",
" print(\"Read %s sentence pairs\" % len(pairs))\n",
" pairs = filterPairs(pairs)\n",
" print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
" print(\"Counting words...\")\n",
" for pair in pairs:\n",
" input_lang.addSentence(pair[0])\n",
" output_lang.addSentence(pair[1])\n",
" print(\"Counted words:\")\n",
" print(input_lang.name, input_lang.n_words)\n",
" print(output_lang.name, output_lang.n_words)\n",
" return input_lang, output_lang, pairs\n",
"\n",
"input_lang, output_lang, pairs = prepareData(True)\n",
"print(random.choice(pairs))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:33.204103Z",
"iopub.execute_input": "2024-05-25T14:12:33.204776Z",
"iopub.status.idle": "2024-05-25T14:12:36.889693Z",
"shell.execute_reply.started": "2024-05-25T14:12:33.204744Z",
"shell.execute_reply": "2024-05-25T14:12:36.888700Z"
},
"trusted": true
},
"execution_count": 16,
"outputs": [
{
"name": "stdout",
"text": "Reading lines...\nRead 72258 sentence pairs\nTrimmed to 5005 sentence pairs\nCounting words...\nCounted words:\nfin 3686\nen 1971\n['mina odotan joulua innolla', 'i am looking forward to christmas']\n",
"output_type": "stream"
}
]
},
{
"cell_type": "markdown",
"source": [
"### Definicja modelu"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"class EncoderRNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size, dropout_p=0.1):\n",
" super(EncoderRNN, self).__init__()\n",
" self.hidden_size = hidden_size\n",
"\n",
" self.embedding = nn.Embedding(input_size, hidden_size)\n",
" self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
" self.dropout = nn.Dropout(dropout_p)\n",
"\n",
" def forward(self, input):\n",
" embedded = self.dropout(self.embedding(input))\n",
" output, hidden = self.gru(embedded)\n",
" return output, hidden"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:52.383787Z",
"iopub.execute_input": "2024-05-25T14:12:52.384131Z",
"iopub.status.idle": "2024-05-25T14:12:52.391196Z",
"shell.execute_reply.started": "2024-05-25T14:12:52.384104Z",
"shell.execute_reply": "2024-05-25T14:12:52.390316Z"
},
"trusted": true
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class DecoderRNN(nn.Module):\n",
" def __init__(self, hidden_size, output_size):\n",
" super(DecoderRNN, self).__init__()\n",
" self.embedding = nn.Embedding(output_size, hidden_size)\n",
" self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
" self.out = nn.Linear(hidden_size, output_size)\n",
"\n",
" def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
" batch_size = encoder_outputs.size(0)\n",
" decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)\n",
" decoder_hidden = encoder_hidden\n",
" decoder_outputs = []\n",
"\n",
" for i in range(MAX_LENGTH):\n",
" decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n",
" decoder_outputs.append(decoder_output)\n",
"\n",
" if target_tensor is not None:\n",
" # Teacher forcing: Feed the target as the next input\n",
" decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing\n",
" else:\n",
" # Without teacher forcing: use its own predictions as the next input\n",
" _, topi = decoder_output.topk(1)\n",
" decoder_input = topi.squeeze(-1).detach() # detach from history as input\n",
"\n",
" decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
" decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
" return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop\n",
"\n",
" def forward_step(self, input, hidden):\n",
" output = self.embedding(input)\n",
" output = F.relu(output)\n",
" output, hidden = self.gru(output, hidden)\n",
" output = self.out(output)\n",
" return output, hidden"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:54.393953Z",
"iopub.execute_input": "2024-05-25T14:12:54.394808Z",
"iopub.status.idle": "2024-05-25T14:12:54.409000Z",
"shell.execute_reply.started": "2024-05-25T14:12:54.394765Z",
"shell.execute_reply": "2024-05-25T14:12:54.407827Z"
},
"trusted": true
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class BahdanauAttention(nn.Module):\n",
" def __init__(self, hidden_size):\n",
" super(BahdanauAttention, self).__init__()\n",
" self.Wa = nn.Linear(hidden_size, hidden_size)\n",
" self.Ua = nn.Linear(hidden_size, hidden_size)\n",
" self.Va = nn.Linear(hidden_size, 1)\n",
"\n",
" def forward(self, query, keys):\n",
" scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
" scores = scores.squeeze(2).unsqueeze(1)\n",
"\n",
" weights = F.softmax(scores, dim=-1)\n",
" context = torch.bmm(weights, keys)\n",
"\n",
" return context, weights\n",
"\n",
"class AttnDecoderRNN(nn.Module):\n",
" def __init__(self, hidden_size, output_size, dropout_p=0.1):\n",
" super(AttnDecoderRNN, self).__init__()\n",
" self.embedding = nn.Embedding(output_size, hidden_size)\n",
" self.attention = BahdanauAttention(hidden_size)\n",
" self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
" self.out = nn.Linear(hidden_size, output_size)\n",
" self.dropout = nn.Dropout(dropout_p)\n",
"\n",
" def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
" batch_size = encoder_outputs.size(0)\n",
" decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)\n",
" decoder_hidden = encoder_hidden\n",
" decoder_outputs = []\n",
" attentions = []\n",
"\n",
" for i in range(MAX_LENGTH):\n",
" decoder_output, decoder_hidden, attn_weights = self.forward_step(\n",
" decoder_input, decoder_hidden, encoder_outputs\n",
" )\n",
" decoder_outputs.append(decoder_output)\n",
" attentions.append(attn_weights)\n",
"\n",
" if target_tensor is not None:\n",
" # Teacher forcing: Feed the target as the next input\n",
" decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing\n",
" else:\n",
" # Without teacher forcing: use its own predictions as the next input\n",
" _, topi = decoder_output.topk(1)\n",
" decoder_input = topi.squeeze(-1).detach() # detach from history as input\n",
"\n",
" decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
" decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
" attentions = torch.cat(attentions, dim=1)\n",
"\n",
" return decoder_outputs, decoder_hidden, attentions\n",
"\n",
"\n",
" def forward_step(self, input, hidden, encoder_outputs):\n",
" embedded = self.dropout(self.embedding(input))\n",
"\n",
" query = hidden.permute(1, 0, 2)\n",
" context, attn_weights = self.attention(query, encoder_outputs)\n",
" input_gru = torch.cat((embedded, context), dim=2)\n",
"\n",
" output, hidden = self.gru(input_gru, hidden)\n",
" output = self.out(output)\n",
"\n",
" return output, hidden, attn_weights"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:13:00.670299Z",
"iopub.execute_input": "2024-05-25T14:13:00.670758Z",
"iopub.status.idle": "2024-05-25T14:13:00.687695Z",
"shell.execute_reply.started": "2024-05-25T14:13:00.670720Z",
"shell.execute_reply": "2024-05-25T14:13:00.686610Z"
},
"trusted": true
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def indexesFromSentence(lang, sentence):\n",
" return [lang.word2index[word] for word in sentence.split(' ')]\n",
"\n",
"def tensorFromSentence(lang, sentence):\n",
" indexes = indexesFromSentence(lang, sentence)\n",
" indexes.append(EOS_token)\n",
" return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)\n",
"\n",
"def tensorsFromPair(pair):\n",
" input_tensor = tensorFromSentence(input_lang, pair[0])\n",
" target_tensor = tensorFromSentence(output_lang, pair[1])\n",
" return (input_tensor, target_tensor)\n",
"\n",
"def get_dataloader(batch_size):\n",
" input_lang, output_lang, pairs = prepareData(True)\n",
"\n",
" n = len(pairs)\n",
" input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n",
" target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n",
"\n",
" for idx, (inp, tgt) in enumerate(pairs):\n",
" inp_ids = indexesFromSentence(input_lang, inp)\n",
" tgt_ids = indexesFromSentence(output_lang, tgt)\n",
" inp_ids.append(EOS_token)\n",
" tgt_ids.append(EOS_token)\n",
" input_ids[idx, :len(inp_ids)] = inp_ids\n",
" target_ids[idx, :len(tgt_ids)] = tgt_ids\n",
"\n",
" train_data = TensorDataset(torch.LongTensor(input_ids).to(device),\n",
" torch.LongTensor(target_ids).to(device))\n",
"\n",
" train_sampler = RandomSampler(train_data)\n",
" train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
" return input_lang, output_lang, train_dataloader"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:22:08.183866Z",
"iopub.execute_input": "2024-05-25T14:22:08.184711Z",
"iopub.status.idle": "2024-05-25T14:22:08.194870Z",
"shell.execute_reply.started": "2024-05-25T14:22:08.184675Z",
"shell.execute_reply": "2024-05-25T14:22:08.193965Z"
},
"trusted": true
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train_epoch(dataloader, encoder, decoder, encoder_optimizer,\n",
" decoder_optimizer, criterion):\n",
"\n",
" total_loss = 0\n",
" for data in dataloader:\n",
" input_tensor, target_tensor = data\n",
"\n",
" encoder_optimizer.zero_grad()\n",
" decoder_optimizer.zero_grad()\n",
"\n",
" encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
" decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)\n",
"\n",
" loss = criterion(\n",
" decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
" target_tensor.view(-1)\n",
" )\n",
" loss.backward()\n",
"\n",
" encoder_optimizer.step()\n",
" decoder_optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
"\n",
" return total_loss / len(dataloader)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:16:38.894580Z",
"iopub.execute_input": "2024-05-25T14:16:38.895410Z",
"iopub.status.idle": "2024-05-25T14:16:38.902142Z",
"shell.execute_reply.started": "2024-05-25T14:16:38.895382Z",
"shell.execute_reply": "2024-05-25T14:16:38.900953Z"
},
"trusted": true
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import time\n",
"import math\n",
"\n",
"def asMinutes(s):\n",
" m = math.floor(s / 60)\n",
" s -= m * 60\n",
" return '%dm %ds' % (m, s)\n",
"\n",
"def timeSince(since, percent):\n",
" now = time.time()\n",
" s = now - since\n",
" es = s / (percent)\n",
" rs = es - s\n",
" return '%s (- %s)' % (asMinutes(s), asMinutes(rs))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:16:43.069584Z",
"iopub.execute_input": "2024-05-25T14:16:43.069953Z",
"iopub.status.idle": "2024-05-25T14:16:43.075972Z",
"shell.execute_reply.started": "2024-05-25T14:16:43.069926Z",
"shell.execute_reply": "2024-05-25T14:16:43.075033Z"
},
"trusted": true
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,\n",
" print_every=100, plot_every=100):\n",
" start = time.time()\n",
" plot_losses = []\n",
" print_loss_total = 0 # Reset every print_every\n",
" plot_loss_total = 0 # Reset every plot_every\n",
"\n",
" encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n",
" decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n",
" criterion = nn.NLLLoss()\n",
"\n",
" for epoch in range(1, n_epochs + 1):\n",
" loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)\n",
" print_loss_total += loss\n",
" plot_loss_total += loss\n",
"\n",
" if epoch % print_every == 0:\n",
" print_loss_avg = print_loss_total / print_every\n",
" print_loss_total = 0\n",
" print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),\n",
" epoch, epoch / n_epochs * 100, print_loss_avg))\n",
"\n",
" if epoch % plot_every == 0:\n",
" plot_loss_avg = plot_loss_total / plot_every\n",
" plot_losses.append(plot_loss_avg)\n",
" plot_loss_total = 0\n",
"\n",
" showPlot(plot_losses)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:20:58.574148Z",
"iopub.execute_input": "2024-05-25T14:20:58.574520Z",
"iopub.status.idle": "2024-05-25T14:20:58.583203Z",
"shell.execute_reply.started": "2024-05-25T14:20:58.574492Z",
"shell.execute_reply": "2024-05-25T14:20:58.582230Z"
},
"trusted": true
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"plt.switch_backend('agg')\n",
"import matplotlib.ticker as ticker\n",
"import numpy as np\n",
"%matplotlib inline\n",
"\n",
"def showPlot(points):\n",
" plt.figure()\n",
" fig, ax = plt.subplots()\n",
" # this locator puts ticks at regular intervals\n",
" loc = ticker.MultipleLocator(base=0.2)\n",
" ax.yaxis.set_major_locator(loc)\n",
" plt.plot(points)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:21:00.586018Z",
"iopub.execute_input": "2024-05-25T14:21:00.586719Z",
"iopub.status.idle": "2024-05-25T14:21:00.592633Z",
"shell.execute_reply.started": "2024-05-25T14:21:00.586683Z",
"shell.execute_reply": "2024-05-25T14:21:00.591636Z"
},
"trusted": true
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Ewaluacja"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def evaluate(encoder, decoder, sentence, input_lang, output_lang):\n",
" with torch.no_grad():\n",
" input_tensor = tensorFromSentence(input_lang, sentence)\n",
"\n",
" encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
" decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n",
"\n",
" _, topi = decoder_outputs.topk(1)\n",
" decoded_ids = topi.squeeze()\n",
"\n",
" decoded_words = []\n",
" for idx in decoded_ids:\n",
" if idx.item() == EOS_token:\n",
" decoded_words.append('<EOS>')\n",
" break\n",
" decoded_words.append(output_lang.index2word[idx.item()])\n",
" return decoded_words, decoder_attn"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:21:01.858691Z",
"iopub.execute_input": "2024-05-25T14:21:01.859612Z",
"iopub.status.idle": "2024-05-25T14:21:01.866857Z",
"shell.execute_reply.started": "2024-05-25T14:21:01.859574Z",
"shell.execute_reply": "2024-05-25T14:21:01.865732Z"
},
"trusted": true
},
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def evaluateRandomly(encoder, decoder, n=10):\n",
" for i in range(n):\n",
" pair = random.choice(pairs)\n",
" print('Input sentence: ', pair[0])\n",
" print('Target (true) translation:' , pair[1])\n",
" output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)\n",
" output_sentence = ' '.join(output_words)\n",
" print('Output sentence: ', output_sentence)\n",
" print('')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:39:52.474985Z",
"iopub.execute_input": "2024-05-25T14:39:52.475327Z",
"iopub.status.idle": "2024-05-25T14:39:52.481801Z",
"shell.execute_reply.started": "2024-05-25T14:39:52.475304Z",
"shell.execute_reply": "2024-05-25T14:39:52.480957Z"
},
"trusted": true
},
"execution_count": 36,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Wykorzystanie zdefiniowanych wyżej funkcji"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### Trenowanie modelu"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"hidden_size = 128\n",
"batch_size = 32\n",
"\n",
"input_lang, output_lang, train_dataloader = get_dataloader(batch_size)\n",
"\n",
"encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)\n",
"decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)\n",
"\n",
"train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:22:39.754370Z",
"iopub.execute_input": "2024-05-25T14:22:39.754740Z",
"iopub.status.idle": "2024-05-25T14:26:53.707012Z",
"shell.execute_reply.started": "2024-05-25T14:22:39.754714Z",
"shell.execute_reply": "2024-05-25T14:26:53.705969Z"
},
"trusted": true
},
"execution_count": 32,
"outputs": [
{
"name": "stdout",
"text": "Reading lines...\nRead 72258 sentence pairs\nTrimmed to 5005 sentence pairs\nCounting words...\nCounted words:\nfin 3686\nen 1971\n0m 21s (- 5m 21s) (5 6%) 1.9364\n0m 36s (- 4m 17s) (10 12%) 1.0355\n0m 51s (- 3m 45s) (15 18%) 0.6313\n1m 7s (- 3m 21s) (20 25%) 0.3787\n1m 22s (- 3m 1s) (25 31%) 0.2243\n1m 37s (- 2m 42s) (30 37%) 0.1371\n1m 52s (- 2m 25s) (35 43%) 0.0903\n2m 7s (- 2m 7s) (40 50%) 0.0668\n2m 23s (- 1m 51s) (45 56%) 0.0538\n2m 38s (- 1m 34s) (50 62%) 0.0471\n2m 53s (- 1m 18s) (55 68%) 0.0410\n3m 8s (- 1m 2s) (60 75%) 0.0381\n3m 23s (- 0m 47s) (65 81%) 0.0343\n3m 38s (- 0m 31s) (70 87%) 0.0342\n3m 54s (- 0m 15s) (75 93%) 0.0322\n4m 9s (- 0m 0s) (80 100%) 0.0307\n",
"output_type": "stream"
}
]
},
{
"cell_type": "code",
"source": [
"evaluateRandomly(encoder, decoder)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:48:23.131435Z",
"iopub.execute_input": "2024-05-25T15:48:23.131948Z",
"iopub.status.idle": "2024-05-25T15:48:23.213007Z",
"shell.execute_reply.started": "2024-05-25T15:48:23.131911Z",
"shell.execute_reply": "2024-05-25T15:48:23.211856Z"
},
"trusted": true
},
"execution_count": 121,
"outputs": [
{
"name": "stdout",
"text": "Input sentence: olen hyvin hyvin vihainen\nTarget (true) translation: i m very very angry\nOutput sentence: i am very angry today <EOS>\n\nInput sentence: han valehtelee\nTarget (true) translation: he s lying\nOutput sentence: he is telling a lie <EOS>\n\nInput sentence: olen myohassa\nTarget (true) translation: i m late\nOutput sentence: i m late <EOS>\n\nInput sentence: han on linja autonkuljettaja\nTarget (true) translation: he is a bus driver\nOutput sentence: he is a bus driver <EOS>\n\nInput sentence: mukava tavata sinut taas\nTarget (true) translation: i m glad to see you again\nOutput sentence: i m glad to see you again <EOS>\n\nInput sentence: olet kuumeessa\nTarget (true) translation: you re running a fever\nOutput sentence: you re so predictable <EOS>\n\nInput sentence: anteeksi mutta unohdin tehda laksyt\nTarget (true) translation: i m sorry i forgot to do my homework\nOutput sentence: i m sorry i forgot to do my homework <EOS>\n\nInput sentence: mina olen tyoton\nTarget (true) translation: i m unemployed\nOutput sentence: i m unemployed <EOS>\n\nInput sentence: olen taynna\nTarget (true) translation: i am full\nOutput sentence: i am full of french <EOS>\n\nInput sentence: ma kuolen nalkaan !\nTarget (true) translation: i m dying of hunger\nOutput sentence: i m dying of hunger <EOS>\n\n",
"output_type": "stream"
}
]
},
{
"cell_type": "code",
"source": [
"def showAttention(input_sentence, output_words, attentions):\n",
" fig = plt.figure()\n",
" ax = fig.add_subplot(111)\n",
" cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')\n",
" fig.colorbar(cax)\n",
"\n",
" # Set up axes\n",
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
" ['<EOS>'], rotation=90)\n",
" ax.set_yticklabels([''] + output_words)\n",
"\n",
" # Show label at every tick\n",
" ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
" ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
"\n",
" plt.show()\n",
"\n",
"\n",
"def evaluateAndShowAttention(input_sentence):\n",
" input_sentence = normalizeString(input_sentence)\n",
" output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)\n",
" print('input =', input_sentence)\n",
" print('output =', ' '.join(output_words))\n",
" showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])\n",
"\n",
"def translate(input_sentence, tokenized=False):\n",
" input_sentence = normalizeString(input_sentence)\n",
" output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)\n",
" if tokenized:\n",
" if \"<EOS>\" in output_words:\n",
" output_words.remove(\"<EOS>\")\n",
" return output_words\n",
" return ' '.join(output_words)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:30:41.253325Z",
"iopub.execute_input": "2024-05-25T15:30:41.253734Z",
"iopub.status.idle": "2024-05-25T15:30:41.264515Z",
"shell.execute_reply.started": "2024-05-25T15:30:41.253703Z",
"shell.execute_reply": "2024-05-25T15:30:41.263376Z"
},
"trusted": true
},
"execution_count": 99,
"outputs": []
},
{
"cell_type": "code",
"source": [
"translate(\"Meillä on nälkä\", tokenized=True)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:58:53.639252Z",
"iopub.execute_input": "2024-05-25T14:58:53.639963Z",
"iopub.status.idle": "2024-05-25T14:58:53.654186Z",
"shell.execute_reply.started": "2024-05-25T14:58:53.639932Z",
"shell.execute_reply": "2024-05-25T14:58:53.653028Z"
},
"trusted": true
},
"execution_count": 76,
"outputs": [
{
"execution_count": 76,
"output_type": "execute_result",
"data": {
"text/plain": "['we', 'are', 'hungry']"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"evaluateAndShowAttention('Olet liian naivi')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:46:05.023218Z",
"iopub.execute_input": "2024-05-25T14:46:05.024277Z",
"iopub.status.idle": "2024-05-25T14:46:05.426793Z",
"shell.execute_reply.started": "2024-05-25T14:46:05.024227Z",
"shell.execute_reply": "2024-05-25T14:46:05.424993Z"
},
"trusted": true
},
"execution_count": 44,
"outputs": [
{
"name": "stdout",
"text": "input = olet liian naivi\noutput = you re too naive <EOS>\n",
"output_type": "stream"
},
{
"name": "stderr",
"text": "/tmp/ipykernel_34/2052950992.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_xticklabels([''] + input_sentence.split(' ') +\n/tmp/ipykernel_34/2052950992.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_yticklabels([''] + output_words)\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcQAAAHHCAYAAAAhyyixAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxC0lEQVR4nO3deXgUZbbH8V8HSIKEDiAmYQlE1sCwCESYiChiBPEZHC4q+2IUVAQVoyIoElAvoCKbICiLgIoGRWfwQUEJRGcABeGCsskuEQk7NGsC6b5/MOmhJdUmdDrVnfp+eOoZU13VdbpHczjnfestm8vlcgkAAIsLMTsAAAACAQkRAACREAEAkERCBABAEgkRAABJJEQAACSREAEAkERCBABAEgkRAABJJEQAACSREAEAkERCBABAEgkRAABJJEQA8Co3N1c//fSTLl26ZHYo8DMSIgB48cUXX6hZs2ZKS0szOxT4GQkRALyYN2+ebrjhBs2dO9fsUOBnNh4QDAD5O3r0qKpXr65//OMfuvfee7Vnzx5Vr17d7LDgJ1SIAGDgo48+UqNGjXT33XerTZs2ev/9980OCX5EQgQAA3PnzlXfvn0lSb1799b8+fNNjgj+RMsUAPKxefNmtWjRQgcOHFDlypV15swZRUdHa8WKFWrVqpXZ4cEPqBABIB/z5s1T+/btVblyZUlSRESEOnfuzOSaEoyECAB/kJubqw8++MDdLs3Tu3dvpaWlKScnx6TI4E8kRAD4g8OHD2vgwIH6+9//7rG/Q4cOSklJUVZWlkmRwZ8YQwQAQFSIAFAgv/76q7Zu3Sqn02l2KPATEiIAXGHOnDmaMGGCx75HHnlEtWrVUuPGjdWoUSNlZmaaFB38iYQIAFd49913VbFiRffPS5cu1Xvvvaf58+dr3bp1qlChgkaPHm1ihPAXxhAB4ArXX3+9MjIy1LhxY0nSwIEDdeTIEX366aeSpIyMDCUnJ2vv3r1mhgk/oEIEgCucP39edrvd/fPq1at12223uX+uVasWs0xLKBIiAFyhZs2aWr9+vaTLi3tv2bJFrVu3dr+elZWlyMhIs8KDH5U2OwAACCT9+vXToEGDtGXLFq1YsULx8fFq0aKF+/XVq1erUaNGJkYIfyEhAiWcw+FwtwAdDofXY69sFVrV0KFDde7cOX322WeKiYnRJ5984vH6qlWr1KNHD5Oigz8xqQYo4UqVKqWDBw8qKipKISEhstlsVx3jcrlks9mUm5trQoRAYKBCBEq4FStWqFKlSu5/zi8h4mrnz5/XN998ox07dkiS6tWrp7vuuktly5Y1OTL4CxUiAPzB4sWL1b9/fx09etRjf+XKlTV79mx16tTJpMjgT8wyBSykbt26GjVqlHbu3Gl2KAFr9erVuv/++3Xbbbdp1apVOn78uI4fP65///vfatOmje6//359//33ZocJP6BCRIlw9uxZjRs3Tunp6Tp8+PBV603u2bPHpMgCy8SJE7VgwQJt2LBBzZs3V+/evdWtWzfFxMSYHVrAuOeeexQbG6t33nkn39cfffRRZWZm6ssvvyzmyOBvJMQAd+WEiCsdO3ZMUVFRTIL4jx49eujbb79Vnz59VKVKlavGyZ566imTIgtMO3bs0IcffqiPPvpIe/fu1R133KHevXtf9fw/K6pUqZK+/fZb90o1f/TTTz/p9ttv14kTJ4o5MvgbCTHAhYSEKCsr66qE+Pvvv6t27do6f/68SZEFlgoVKmjJkiUeN1CjYL7//nsNHDhQP/30E3/BklS2bFlt375dNWvWzPf1X3/9VfHx8fy3VwIxyzRATZkyRZJks9k0a9YsRUREuF/Lzc3Vd999p/j4eLPCCzgVK1Z0z6REwaxdu1YLFixQWlqaHA6HHnjgAbNDCgh169bVihUrlJycnO/r6enpqlu3bjFHheJAQgxQEydOlHT5/rAZM2aoVKlS7tdCQ0MVFxenGTNmmBVewHnllVc0cu
RIzZs3T9ddd53Z4QSsP7ZK27Vrp9dee01dunTx+EuXlSUnJ+vZZ59VdHS07rnnHo/XlixZoqFDh+qFF14wKTr4Ey3TAHfHHXfos88+83gcDa7WrFkz7d69Wy6XS3FxcSpTpozH6xs2bDApssASEhKim2++WT179lT37t0VHR1tdkgBx+l0qlu3blq0aJHq16+vBg0ayOVyadu2bdq5c6c6d+6sTz75RCEhTNIvaUiIQSInJ0d79+5V7dq1Vbo0hf0f/dnz6VJTU4spksC2c+dO2n0FlJaWpo8++sjjxvzu3bure/fuJkcGfyEhBrjz589r8ODBmjdvnqTLLa9atWrpiSeeULVq1TRs2DCTIwSAkoGaP8ANGzZMmzZtUkZGhsLDw937k5KSlJaWZmJkCBaVKlVyr7iSN/nIaIO0cOFC5eTkuH/+7bffPO5rPXfunF5//XUzQoOfUSEGuJo1ayotLU1//etfVb58eW3atEm1atXSrl271Lx58z99eoFV5ObmauLEiVq4cKH279/v8QtNko4fP25SZOabN2+eunfvrrCwMHenwUi/fv2KKarA9cd7f+12uzZu3KhatWpJkg4dOqSqVatyi0oJxGBUgDty5MhV9yBKl1dmYZHm/xo9erRmzZqlZ555RiNGjNCLL76offv26R//+IdGjhxpdnimujLJkfD+3B9rBGoG66BlGuASEhK0ZMkS9895SXDWrFlKTEw0K6yA8+GHH2rmzJl65plnVLp0afXo0UOzZs3SyJEjWXfSwIULF+RwODw2wMqoEAPcmDFj1LFjR23dulWXLl3S5MmTtXXrVq1evVrffvut2eEFjKysLPdSWxERETp16pQk6W9/+5teeuklM0MLKGfPntXzzz+vhQsX6tixY1e9ThsQVkaFGOBuvfVWbdy4UZcuXVLjxo319ddfKyoqSmvWrFGLFi3MDi9gVK9eXQcPHpQk1a5dW19//bUkad26dQoLCzMztIAydOhQrVixQtOnT1dYWJhmzZql0aNHq2rVqpo/f77Z4QWMZcuWafHixVq8eLGcTqfS09PdPy9btszs8OAnTKpBiTBs2DDZ7Xa98MILSktLU+/evRUXF6f9+/fr6aef1rhx48wOMSDUqFFD8+fPV9u2bWW327VhwwbVqVNH77//vj766COe4CAV6IZ7m81GNV0CkRADUGHGcux2ux8jCV5r1qzRmjVrVLduXR7meoWIiAht3bpVNWrUUPXq1fXZZ5+pZcuW2rt3rxo3bqwzZ86YHSJgGsYQA1CFChX+dAapy+Xib6leJCYmMukoH7Vq1dLevXtVo0YNxcfHa+HChWrZsqW++OILVahQwezwAsa5c+e0e/fufB8BtWXLFtWsWZO1X0sgEmIAWrlypdkhBIXFixerY8eOKlOmjBYvXuz12HvvvbeYogpsycnJ2rRpk26//XYNGzZMnTp10tSpU3Xx4kVNmDDB7PACRk5Ojlq1aqWMjAy1bNnSvX/r1q1q1qyZ9u/fT0IsgWiZBoGTJ09q9uzZ2rZtmySpYcOGevjhhxUZGWlyZOa68lmR3sZ9qKSN/frrr1q/fr3q1KmjJk2amB1OQOnatauioqI0depU977hw4dr48aN+uqrr0yMDP5CQgxwP/74o+6++26Fh4e7/6a6bt06nT9/Xl9//bWaN29ucoQINunp6UpPT9fhw4c9liSTpDlz5pgUVeBZsmSJHnzwQR08eFClS5eWy+VSzZo1NX78eHXt2tXs8OAHJMQA16ZNG9WpU0czZ850P+Xi0qVL6t+/v/bs2aPvvvvO5AgRTEaPHq2XX35ZCQkJqlKlylVj1Z9//rlJkQWe3NxcVa9eXTNmzNDf//53rVy5Uvfdd5+ysrIUGhpqdnjwAxJigCtbtqz+7//+T/Hx8R77t27dqoSEBJ07d86kyMw3ZcoUPfLIIwoPD9eUKVO8Hvvkk08WU1SBrUqVKnr99dfVp08fs0MJCs8++6z27t2rRYsW6a
GHHlJYWJimT59udljwEybVBDi73a79+/dflRAzMzNVvnx5k6IKDBMnTlSvXr0UHh6uiRMnGh5ns9lIiP+Rk5OjW26
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Translate a sample Finnish sentence and display the attention plot.\n",
"evaluateAndShowAttention(\"Olen todella pahoillani\")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:46:10.394969Z",
"iopub.execute_input": "2024-05-25T14:46:10.395671Z",
"iopub.status.idle": "2024-05-25T14:46:10.793392Z",
"shell.execute_reply.started": "2024-05-25T14:46:10.395630Z",
"shell.execute_reply": "2024-05-25T14:46:10.791940Z"
},
"trusted": true
},
"execution_count": 45,
"outputs": [
{
"name": "stdout",
"text": "input = olen todella pahoillani\noutput = i am truly sorry <EOS>\n",
"output_type": "stream"
},
{
"name": "stderr",
"text": "/tmp/ipykernel_34/2052950992.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_xticklabels([''] + input_sentence.split(' ') +\n/tmp/ipykernel_34/2052950992.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_yticklabels([''] + output_words)\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcQAAAHXCAYAAAAiHSoqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA01ElEQVR4nO3deVyU9fr/8feALCqCmCIuqFkukGuSpB7TyjI7Xz1mltoCYlonzVS00lLRTqknV0rLJLfqmLbZsa9lCy6d1Cy3nxsuaQYnRTGXIUxQZn5/EPNtkpnAAe8Z7tfTx/1I7rmXa0aby+uz3Ra73W4XAAAm52d0AAAAeAMSIgAAIiECACCJhAgAgCQSIgAAkkiIAABIIiECACCJhAgAgCQSIgAAkkiIAABIIiECACCJhAgAgCQSIgAAkqRKRgcAoHytWrVKPXr0UEBAgFatWuX22F69el2lqADvY+HxT0DF5ufnp6ysLEVERMjPz3WjkMViUUFBwVWMDPAuJEQAAEQfIgAAkuhDBEwnLS1NaWlpOnnypGw2m9NrixYtMigqwHgkRMBEJk+erOeff16xsbGqU6eOLBaL0SEBXoM+RMBE6tSpo5deekkPP/yw0aEAXoc+RMBE8vPz1bFjR6PDALwSCREwkcGDB2vZsmVGhwF4JfoQARO5cOGCFixYoC+//FKtWrVSQECA0+uzZs0yKDLAePQhAiZy6623unzNYrFo7dq1VzEawLuQEAEAEH2IAABIog8RMJ2tW7fq3XffVUZGhvLz851e+/DDDw2KCjAeFSJgIsuXL1fHjh2Vnp6ulStX6uLFi9q7d6/Wrl2rsLAwo8MDDEVCBExkypQpmj17tj7++GMFBgYqJSVF+/fv1/33368GDRoYHR5gKBIiYCKHDx/WX//6V0lSYGCgcnNzZbFYNGrUKC1YsMDg6LxTQUGBdu3apUuXLhkdCsoZCREwkfDwcOXk5EiS6tWrpz179kiSzp49q/PnzxsZmtf6+OOP1bZtW61YscLoUFDOSIiAidxyyy364osvJEn33XefRowYoSFDhmjAgAG6/fbbDY7OOy1dulS1atXSkiVLjA4F5Yx5iKgwcnNztWHDhmJHTz755JMGReVdTp8+rQsXLqhu3bqy2Wx66aWXtGnTJjVp0kTjx49XeHi40SF6lVOnTql+/fr66KOP1KtXLx05ckT169c3OiyUExIiKoQdO3bo7rvv1vnz55Wbm6saNWro1KlTqlKliiIiInTkyBGjQ4QPeuWVV7R06VJt3bpVt99+u7p166Zx48YZHRbKCU2mqBBGjRqlnj176syZM6pcubK++eYb/fjjj2rXrp1mzJhhdHiGslqtJd7gbMmSJYqPj5ckPfTQQ3rzzTcNjgjliQoRFUL16tW1ZcsWNWvWTNWrV9fmzZsVHR2tLVu2KCEhQfv37zc6RMP4+fn96YOA7Xa7LBaLCgoKrlJU3m/Pnj1q166dfvrpJ9WsWVO//PKLateurbVr1youLs7o8FAOWKkGFUJAQID8/AobPCIiIpSRkaHo6GiFhYUpMzPT4OiMtW7dOqND8ElLly7VnXfeqZo1a0qSQkJC1Lt3by1ZsoSEWEGREFEhtG3bVt99952aNGmiLl26aOLEiTp16pTeeusttWjRwujwDNWlSxejQ/A5BQUFevvtt/Xyyy877X/ooYf04IMPKiUlRYGBgQZFh/JCkykqhK1btyonJ0e33nqrTp48qfj4eMfoyUWLFql169ZGh2iYXbt2lfjYVq1alWMkvuP48eNKTU3V2LFjnRKfzWbTlClTFB8fz8o+FRAJEajgivoQ/+x/dfoQYXY0mQIV3A8//GB0CBXCjz/+qNzcXDVv3tzRX42KhQoRPqtt27Z/OnqyyPbt28s5GlQUixYt0tmzZ5WUlOTY9+ijj2rhwoWSpGbNmumzzz5TVFSUUSGinFAhwmf17t3b6BB8wqpVq9SjRw8FBARo1apVbo
/t1avXVYrKey1YsECPPfaY4+c1a9Zo8eLFevPNNxUdHa0nnnhCkydP1htvvGFglCgPVIhABefn56esrCxFRES4beqjD7HQNddco/Xr16tly5aSpMcff1zZ2dl6//33JUnr169XYmIiTdEVEA3hQAVns9kUERHh+L2rjWRY6Ndff1VoaKjj502bNumWW25x/Ny4cWNlZWUZERrKGU2m8Fnh4eEl7kM8ffp0OUeDiqJhw4batm2bGjZsqFOnTmnv3r3q1KmT4/WsrCyFhYUZGCHKCwkRPmvOnDlGh+CTNmzYoBkzZig9PV2SFBMTo6eeekqdO3c2ODLvkJCQoGHDhmnv3r1au3atmjdvrnbt2jle37Rpk+kXe6ioSIjwWQkJCUaH4HPefvttJSYmqk+fPo5HYm3cuFG33367lixZogceeMDgCI339NNP6/z58/rwww8VGRmp9957z+n1jRs3asCAAQZFh/LEoBpUGIcPH9bixYt1+PBhpaSkKCIiQp9++qkaNGigG264wejwvEJ0dLQeffRRjRo1ymn/rFmzlJqa6qgaATMiIaJC2LBhg3r06KFOnTrpq6++Unp6uho3bqxp06Zp69atjhGCZhcUFKS9e/fq+uuvd9r//fffq0WLFrpw4YJBkXmfX3/9VV988YUOHjwoSWratKnuuOMOVa5c2eDIUF5oMvUBhw4d0rp163Ty5EnZbDan1yZOnGhQVN5l7NixeuGFF5SUlKRq1ao59t92222aO3eugZF5l6ioKKWlpV2WEL/88ksmmv/OqlWrNHjwYJ06dcppf82aNbVw4UL17NnToMhQnkiIXi41NVWPP/64atasqcjISKdRlRaLhYT4m927d2vZsmWX7Y+IiLjsS83MRo8erSeffFI7d+5Ux44dJRX2iS1ZskQpKSkGR+cdNm3apL59+6pXr14aPXq0oqOjJUn79u3TzJkz1bdvX23YsEE333yzwZGirNFk6uUaNmyooUOH6plnnjE6FK9Wv359vfvuu+rYsaOqVaum//f//p8aN26slStXasyYMTp8+LDRIXqNlStXaubMmY7+wujoaD311FP629/+ZnBk3uHuu+9WVFSUXn/99WJff+yxx5SZmalPPvnkKkeG8kZC9HKhoaHauXOnGjdubHQoXm3MmDHasmWL3nvvPTVt2lTbt2/XiRMnFB8fr/j4eCUnJxsdInxEjRo1tGHDBsdKNX+0a9cudenSRWfOnLnKkaG8sVKNl7vvvvv0+eefGx2G15syZYqaN2+uqKgo/fLLL4qJidEtt9yijh07avz48UaH53Xy8/P13//+VxkZGU4bLl+p5o/CwsIYfFRB0Yfo5a6//npNmDBB33zzjVq2bKmAgACn14vmkpldYGCgUlNTNWHCBO3Zs0e//PKL2rZtqyZNmhgdmlc5dOiQBg0apE2bNjntt9vtrGX6myZNmmjt2rVKTEws9vW0tDT+XlVQNJl6uWuvvdblaxaLRUeOHLmK0cDXderUSZUqVdLYsWNVp06dy5a+a926tUGReY/Zs2frhRde0FtvvaW7777b6bXVq1crISFBzz77rNPjoVAxkBDhs0rzhTRr1qxyjMR3VK1aVdu2bVPz5s2NDsVr2Ww29evXTx988IGaNWum6Oho2e12paen69ChQ+rdu7fee+89HhJcAdFk6iPy8/P1ww8/6LrrrlOlSvyxSdKOHTucft6+fbsuXbqkZs2aSZIOHjwof39/p3UozS4mJoZpKH/Cz89P7733nlasWKF33nlH+/fvlyQ1b95ckyZNUv/+/Q2OEOWFCtHLnT9/XsOHD9fSpUslFX7JN27cWMOHD1e9evU0duxYgyP0DrNmzdL69eu1dOlShYeHS5LOnDmjxMREde7cWaNHjzY4QuNYrVbH77du3arx48drypQpxfZJuxtMAlR4dni1J5980t6uXTv7f/7zH3vVqlXthw8fttvtdvtHH31kb9OmjcHReY+6deva9+zZc9n+3b
t32+vUqWNARN7DYrHY/fz8HNsff/79PtjtK1assOfl5Tl+zszMtBcUFDh+zs3Ntf/zn/80IjSUM9revNxHH32kFSt
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Translate a sample Finnish sentence and display the attention plot.\n",
"evaluateAndShowAttention(\"Olet minun isäni\")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:46:13.190403Z",
"iopub.execute_input": "2024-05-25T14:46:13.191025Z",
"iopub.status.idle": "2024-05-25T14:46:13.613486Z",
"shell.execute_reply.started": "2024-05-25T14:46:13.190997Z",
"shell.execute_reply": "2024-05-25T14:46:13.612143Z"
},
"trusted": true
},
"execution_count": 46,
"outputs": [
{
"name": "stdout",
"text": "input = olet minun isani\noutput = you re my father <EOS>\n",
"output_type": "stream"
},
{
"name": "stderr",
"text": "/tmp/ipykernel_34/2052950992.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_xticklabels([''] + input_sentence.split(' ') +\n/tmp/ipykernel_34/2052950992.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_yticklabels([''] + output_words)\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcQAAAHHCAYAAAAhyyixAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5XUlEQVR4nO3de1yUZf7/8feAAiqCBxQ8kORZ1gMEK2lZbkvaVpata5itIKVbJpVRW7qldvqKlZm1a2kqHmpLttZ+a2thhbLuKrsWZh7XYwqZHCx1Cg10Zn5/sExNcE/AADPDvJ77uB/r3HPf9/WZedh8/FzXdV+3yWaz2QQAgI/zc3cAAAB4AhIiAAAiIQIAIImECACAJBIiAACSSIgAAEgiIQIAIImECACAJBIiAACSSIgAAEgiIQIAIImECACAJBIiAACSSIgA4JTFYtGuXbt08eJFd4eCRkZCBAAn3n33XcXGxiorK8vdoaCRkRABwInVq1erU6dOWrVqlbtDQSMz8YBgAKjZqVOn1L17d/2///f/dNNNN+no0aPq3r27u8NCI6FCBAADb775pgYOHKjrrrtOI0aM0GuvvebukNCISIgAYGDVqlVKTk6WJP32t7/VmjVr3BwRGhNdpgBQgz179iguLk4nTpxQWFiYvv32W4WHh2vTpk1KSEhwd3hoBFSIAFCD1atXa9SoUQoLC5MkBQcHa+zYsUyuacaoENFsHDp0SJs3b1ZJSYmsVqvDe3PmzHFTVPBGFotF3bt310svvaTx48fb97///vu6/fbbVVRUpICAADdGiMZAQkSzsGzZMk2bNk1hYWGKiIiQyWSyv2cymbRjxw43Rgdvc/LkSS1btkwzZ850SHxWq1Xz5s1TcnKyLrnkEjdGiMZAQkSz0KNHD91zzz165JFH3B0KAC/FGCKahdOnTzt0bQEN7fjx49q3b1+17ng0HyRENAvjx4/XBx984O4w0AxkZmZq4cKFDvt+97vfqWfPnho0aJAGDhyowsJCN0WHxtTC3QEADaF3796aPXu2/v3vf2vQoEFq2bKlw/v33XefmyKDt3n11Vd111132V9nZ2dr5cqVWrNmjQYMGKC0tDQ98cQTWr58uRujRGNgDBHNwqWXXmr4nslk0tGjR5swGnizjh07Kjc3V4MGDZIkTZs2TaWlpXr77bclSbm5uUpNTdXnn3/uzjDRCKgQ0Szw44SGcv78eYWEhNhfb9u2TXfeeaf9dc+ePVVUVOSO0NDIGEMEgB/o0aOH8vPzJVUu7r13715dccUV9veLiooUGhrqrvDQiKgQ0SzccccdTt/PzMxsokjg7VJSUjR9+nTt3btXmzZtUv/+/RUXF2d/f9u2bRo4cKAbI0RjISGiWTh9+rTD6wsXLmjPnj06c+aMrrnmGjdFBW/08MMP69y5c1q3bp0iIiL01ltvOby/detW3XbbbW6KDo2JSTVotqxWq6ZNm6ZevXrp4Ycfdnc4ADwcCRHN2oEDBzRy5EidPHnS3aG4TYcOHXTw4EGFhYWpffv2Dsva/djXX3/dhJF5tvPnz+vDDz/UwYMHJUl9+/bVtddeq1atWrk5MjQWukzRrB05ckQXL150dxhu9cILL6ht27aSpEWLFrk3GC+xfv16TZkyRadOnXLYHxYWphUrVmjMmDFuigyNiQoRzUJ6errDa5vNppMnT2rDhg1KSUnRn/70JzdFBm+zbds2jRw5UjfddJMefPBBDRgwQJK0b98+Pf/88/r73/+uf/zjH7r88svdHCkaGgkRzcIvfvELh9d+fn7q1KmTrrnmGt1xxx1q0YLOkCpWq1WHDx+u8TFZV111lZui8hzXX3+9IiMjtXTp0hrfv+uuu1RYWKj33nuviSNDYyMhejh/f3+dPHlSnTt3dtj/1VdfqXPnzrJYLG6KDN7o3//+tyZOnKjjx4/rx//pm0wm/j6pcsz1H//4h32lmh/btWuXrr766mozm+H9+Gezhz
P690p5eTkPKEWd3X333YqPj9eGDRvUpUsXpxNsfNWPV6r5sdDQUH333XdNGBGaCgnRQ7300kuSKv/Vvnz5cgUHB9vfs1gs2rJli/r37++u8DxOcXGxHnroIeXk5KikpKTaPySofCodOnRIb7/9tnr37u3uUDxWnz59tGnTJqWmptb4fk5Ojvr06dPEUaEpkBA91AsvvCCpskJcsmSJ/P397e8FBAQoKipKS5YscVd4Hmfy5MkqKCjQ7NmzqXycSEhI0OHDh0mITqSmpuqhhx5SeHi4rr/+eof3NmzYoIcfflh/+MMf3BQdGhNjiB7uF7/4hdatW6f27du7OxSP1rZtW/3zn/9UTEyMu0PxaO+8844ee+wx/f73v6/xMVmDBw92U2Sew2q1KikpSX/961/Vr18/DRgwQDabTfv379ehQ4c0duxYvfXWW/LzYyno5oaE6CUqKir0+eefq1evXsyYrEF0dLT+/Oc/KzY21t2heLSafsRNJpNsNhuTan4kKytLb775psON+RMmTNCECRPcHBkaCwnRw50/f15paWlavXq1JOngwYPq2bOn7r33XnXr1k0zZ850c4Se4YMPPtDzzz+vpUuXKioqyt3heKzjx487fb9Hjx5NFAngeaj5PdzMmTP12WefKTc3V0FBQfb9iYmJysrKcmNkniUpKUm5ubnq1auX2rZtqw4dOjhsqNSjRw+nG6S//OUvqqiosL/+4osvHO7XPHfunJ599ll3hIZGRoXo4Xr06KGsrCxdfvnlatu2rT777DP17NlThw8f1mWXXSaz2ezuED1CVQVtJCUlpYki8Q779u1TQUGBww+/JN10001uishz/Pje35CQEO3cuVM9e/aUVDmjuWvXrnQvN0MMRnm40tLSajflS1JZWRkzKX+AhFc7R48e1S233KLdu3fbxw4l2f8u8SNf/d5fagbfQZeph6u6ibpK1Q/X8uXLNWzYMHeF5RF+WB2bzWanGyrdf//9uvTSS1VSUqLWrVtr79692rJli+Lj45Wbm+vu8AC3okL0cPPmzdOvfvUr7du3TxcvXtSLL76offv2adu2bfrHP/7h7vDcqn379vaurXbt2tVYMTN70lFeXp42bdqksLAw+fn5yc/PT1deeaUyMjJ033336dNPP3V3iIDbkBA93JVXXqmdO3dq/vz5GjRokD744ANddtllysvLM1xr0Vds2rTJPmFm8+bNbo7GO1gsFvujoMLCwvTll1+qX79+6tGjhw4cOODm6DzHxo0bFRoaKqnyvsScnBzt2bNHknTmzBk3RobGxKQaNBvfffeddu3aVeNTHJgsUmnEiBF68MEHNXbsWE2cOFGnT5/WY489pldffVX5+fn2H31fVpsb7ul1aJ5IiB6oLmNezhYh9iXZ2dlKTk6u9kBXiR+vH9q4caPKysr061//WocPH9aNN96ogwcPqmPHjsrKytI111zj7hABtyEheiA/P7+fnEHK2JijPn36aNSoUZozZ47Cw8PdHY5X+frrr9W+fXtmLf/AuXPndOTIkRqHJfbu3asePXo4LLiP5oGE6IHqMlnm6quvbsRIvEdISIg+/fRT9erVy92heBWz2axNmzapf//+PD3lB86cOaOuXbsqNzdXQ4cOte/ft2+fYmJiVFBQoIiICDdGiMbApBoP9OMkd+bMGa1YsUL79++XVLlu55133mkf9If0m9/8xr5SDYzdeuutuuqqq5SWlqbz588rPj5ex44dk81m09q1azVu3Dh3h+gR2rVrpxtvvFFr1qxxSIivvfaafvnLX5IMmykqRA/3ySef6LrrrlNQUJD9P8yPP/5Y58+ft884RWUX1/jx49WpU6can+Jw3333uSkyzxIREaGNGzdqyJAheuONNzR37lx99tlnWr16tV599VVuu/iBDRs2aPLkyTp58qRatGghm82mHj16aMGCBbr11lvdHR4aAQnRw40YMUK9e/fWsmXL7E+5uHjxoqZMmaKjR49qy5Ytbo
7QM6xYsUJ33323goKC1LFjR4fxMJPJpKNHj7oxOs/RqlUrHTx4UJGRkUpOTlbXrl01f/58FRQUKDo6Wt9++627Q/Q
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Translate a sample Finnish sentence and display the attention plot.\n",
"evaluateAndShowAttention(\"Hän on opettaja\")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:46:15.840433Z",
"iopub.execute_input": "2024-05-25T14:46:15.841005Z",
"iopub.status.idle": "2024-05-25T14:46:16.232003Z",
"shell.execute_reply.started": "2024-05-25T14:46:15.840973Z",
"shell.execute_reply": "2024-05-25T14:46:16.230365Z"
},
"trusted": true
},
"execution_count": 47,
"outputs": [
{
"name": "stdout",
"text": "input = han on opettaja\noutput = he is a teacher <EOS>\n",
"output_type": "stream"
},
{
"name": "stderr",
"text": "/tmp/ipykernel_34/2052950992.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_xticklabels([''] + input_sentence.split(' ') +\n/tmp/ipykernel_34/2052950992.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_yticklabels([''] + output_words)\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcUAAAHOCAYAAADpBhJHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAx5klEQVR4nO3deVxU9f7H8feAAioMriwqgpZL7lsalmmFWT2ybDUqF8qysl8p9TC95VqJWZlZloqa2i21bLmWSwuKXdMs8Wq5pOYGWihKikJCMvP7g2Fuc8UjBnhmOK8nj/O4zpmzfObE5TOf7/l+v8fmdDqdAgAA8jM7AAAAvAVJEQAAF5IiAAAuJEUAAFxIigAAuJAUAQBwISkCAOBCUgQAwIWkCACAC0kRAAAXkiIAAC4kRQAAXKqYHQAAcyxZskQffPCB0tPTVVBQ4PHepk2bTIoKMBeVImBB06ZNU0JCgsLDw/Wf//xHXbp0UZ06dbR3717deOONZocHmMbGo6MA62nRooXGjh2r+Ph4hYSEaMuWLWrSpInGjBmj7Oxsvfnmm2aHCJiCShGwoPT0dHXr1k2SVK1aNZ08eVKS1L9/fy1cuNDM0ABTkRQBC4qIiFB2drYkqVGjRvruu+8kSfv27RONR7AykiJgQddee62WLl0qSUpISNDw4cPVq1cv9evXT7fddpvJ0QHm4Z4iYEEOh0MOh0NVqhR1QF+0aJHWrVunpk2basiQIQoICDA5QsAcJEUAOI/CwkJt27ZNLVu2dH+RQOXEf13AIn788Ue1bt1afn5++vHHHw23DQ4OVlRUlKpWrXqRovNun332me644w4tWLBA9913n9nhoAJRKQIW4efnp8zMTIWFhcnPz082m82wU01oaKhmzJihfv36XcQovdNtt92m9evXq02bNvrqq6/MDgcViKQIWMSBAwfUqFEj2Ww2HThwwHDb/Px8ffjhh0pOTtb+/fsvToBe6ujRo2rYsKE+/fRT3XLLLdq7d68aNmxodlioIDSfAhYRHR1d4r/P5bHHHlNaWlpFhuQTFi5cqNatW+uGG25Q9+7d9e6772rUqFFmh4UKQqUIWFheXl6Jc5+2bdvWpIi8T6dOnTRw4EA98cQTeueddzR58mTt2LHD7LBQQUiKgAVlZWUpISFBK1asKPH9wsLCixyRd9q6das6deqkQ4cOqW7dujp16pTCw8O1atUqde3a1ezwUAEYvA9Y0LBhw3T8+HFt2LBB1apV08qVKzV//nw1bdrUPagf0vz583X99derbt26kop65fbt21fz5s0zNzBUGCpFwIIiIyP1r3/9S126dJHdbtfGjRvVrFkzLV26VJMnT9batWvNDtF0hYWFatiwoaZNm6a77rrLvX7FihW67777lJmZySQHlRCVImBBubm5CgsLkyTVqlVLWVlZkqQ2bdrwLEWXI0eO6NFHH9Wtt97qsb53795KTExUZmamSZGhIlEpAhZ0+eWX64UXXlDv3r11yy23qGbNmkpKStK0adO0ZMkS7dmzx+wQAVOQFAEL+uc//6kzZ85o0KBBSktL0w033KDs7GwFBARo3rx5DNg/hwMHDig3N1ctWrSQnx8NbZURSRGA8vLy9PPPP6tRo0buTiVWNnfuXB0/flyJiYnudQ8//LDmzJkjSWrevLm++OILRUVFmRUiKghfdQALmjBhgvLy8tyvq1evro4dO6pGjRqaMGGCiZF5h1mzZqlWrVru1ytXrtQ777yjBQsW6IcfflDNmjU1fvx4EyNERaFSBCzI399fv/32m7uzTbFjx44pLCzM8uMU69Spo9TUVLVp00aS9OijjyorK0tLliyRJKWmpiohIUH79u0zM0xUACpFwIKcTqdsNttZ67ds2aLatWubEJF3+eOPP2S3292v161bp6uvvtr9ukmTJvQ+raSY+xSwkFq1aslms8lms6lZs2YeibGwsFCnTp3SI488YmKE3iE6OlppaWmKjo
7W0aNHtW3bNl155ZXu9zMzMxUaGmpihKgoJEXAQqZOnSqn06kHHnhA48eP9/jDHhAQoJiYGMXGxpoYoXcYOHCghg4dqm3btmnVqlVq0aKFOnXq5H5/3bp1at26tYkRoqKQFAELGThwoCSpcePGuvLKK3mK/DmMGDFCeXl5+vjjjxUREaEPP/zQ4/1vv/1W8fHxJkWHikRHG8CC6GgDlIyviYAFneu7cH5+PvN5/sUff/yhr776Srt27ZIkNWvWTL169VK1atVMjgwVhaQIWMi0adMkSTabTbNnz1ZwcLD7vcLCQn3zzTdq0aKFWeF5laVLl2rw4ME6evSox/q6detqzpw56tOnj0mRoSLRfApYSOPGjSUVTVfWsGFD+fv7u98r7mgzYcIEyz8rcN26derZs6duueUWPfXUU7rsssskSdu3b9err76qzz//XGvWrNEVV1xhcqQobyRFwIKuueYaffzxxx6ztuC/brrpJkVFRWnmzJklvj9kyBBlZGRo+fLlFzkyVDSSImBhBQUF2rdvny655BJ6ov5F7dq1tWbNGveMNv/rxx9/VI8ePfT7779f5MhQ0ZjRBrCgP/74Qw8++KCqV6+uVq1aKT09XZL0f//3f5o0aZLJ0Znvf2e0+V+hoaE6ffr0RYwIFwtJ0UccPnxY/fv3V/369VWlShX5+/t7LMCFGDlypLZs2aLU1FQFBQW518fFxWnx4sUmRuYdmjZtqlWrVp3z/ZSUFDVt2vQiRoSLhfYSHzFo0CClp6dr9OjRioyMLHHeSqC0Pv30Uy1evFhXXHGFx+9Sq1ateMCwpISEBD399NMKDw/XTTfd5PHesmXLNGLECP3jH/8wKTpUJJKij1i7dq3+/e9/q3379maHgkogKyvrrIH7kpSbm8sXLklPPvmk1q1bp5tvvlnNmzfXZZddJqfTqR07dmj37t3q27evhg0bZnaYqAA0n/qIqKiocw64Bi5U586dtWzZMvfr4kQ4e/Zs5j6V5Ofnpw8//FALFy5U8+bN9fPPP2vnzp1q0aKF3nvvPX300Ufy8+PPZ2VE71Mf8eWXX+rVV1/VzJkzFRMTY3Y48HFr167VjTfeqPvvv1/z5s3TkCFDtH37dq1bt05r1qzxmPwasBK+6viIfv36KTU1VZdccolCQkJUu3ZtjwW4EFdddZU2b96sM2fOqE2bNvryyy8VFham9evXkxAlffDBByooKHC/PnjwoBwOh/t1Xl6eJk+ebEZoqGBUij5i/vz5hu8XP/0AQNn974TpdrtdmzdvVpMmTSQV9QavX78+E6dXQnS08REkvQtTUFCgI0eOeHy7l6RGjRqZFJH3KSws1CeffKIdO3ZIklq2bKlbb72VQfw6e8J0agfr4LffB50+fdqjaUeS4UBjK9m9e7ceeOABrVu3zmO90+mUzWbjm73Ltm3bdMsttygzM1PNmzeXJL300kuqV6+ePvvsMx6gC8siKfqI3NxcPfPMM/rggw907Nixs97nj32RQYMGqUqVKvr8888Zz2lg8ODBatWqlTZu3Oie//T333/XoEGD9PDDD5/1pQKwCpKijxgxYoRWr16tt99+W/3799f06dN16NAhzZw5k2m5/mLz5s1KS0vj8UfnsXnzZo+EKEm1atXSiy++qMsvv9zEyLzHF198odDQUEmSw+FQSkqKtm7dKkk6fvy4iZGhIpEUfcRnn32mBQsWqGfPnkpISFD37t116aWXKjo6Wu+9957uu+8+s0P0Ci1btjzr+Xc4W7NmzXT48GG1atXKY/2RI0d06aWXmhSVd/nf+/hDhgzxeE0rROXEkAwfkZ2d7e75ZrfblZ2dLamoa/0333xjZmhe5aWXXtKIESOUmpqqY8eOKScnx2NBkaSkJD3xxBNasmSJDh48qIMHD2rJkiUaNmyYXnrpJctfM4fDcd6FWxaVE0MyfETbtm31xhtvqEePHoqLi1P79u31yiuvaNq0aZo8ebIOHjxodohe4a+zjPz1mz
wdbTyVdJ2K/xT89bWVr1leXp727NlT4uOjtm3bpujoaAUHB5sQGSoSzac+IiEhQVu2bFGPHj00cuRI9enTR2+++ab
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Persist only the learned parameters (state_dicts) of the encoder and the decoder.\n",
"torch.save(obj=encoder.state_dict(), f=\"encoder.pt\")\n",
"torch.save(obj=decoder.state_dict(), f=\"decoder.pt\")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:50:24.773506Z",
"iopub.execute_input": "2024-05-25T14:50:24.774464Z",
"iopub.status.idle": "2024-05-25T14:50:24.795895Z",
"shell.execute_reply.started": "2024-05-25T14:50:24.774430Z",
"shell.execute_reply": "2024-05-25T14:50:24.794979Z"
},
"trusted": true
},
"execution_count": 48,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# BLEU score"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
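"Jako że korzystaliśmy z okrojonej wersji zbioru danych, słownik nie zawiera wszystkich słów występujących w przykładach, więc do ewaluacji wykorzystujemy część przykładów ze zbioru treningowego. Uzyskany wynik BLEU należy zatem traktować jako optymistyczny, ponieważ model widział te pary zdań podczas treningu."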
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import csv\n",
"\n",
"import pandas as pd\n",
"\n",
"# Size of the evaluation subset and the seed that makes its selection reproducible.\n",
"EVAL_SAMPLE_SIZE = 500\n",
"EVAL_SEED = 42\n",
"\n",
"\n",
"def filter_rows(row):\n",
"    \"\"\"Return True for pairs short enough for the model and with an accepted English prefix.\n",
"\n",
"    Both sentences must have fewer than MAX_LENGTH space-separated words and the\n",
"    English sentence must start with one of `eng_prefixes`.\n",
"    \"\"\"\n",
"    # NOTE(review): presumably mirrors the filter used for the training pairs - keep them in sync.\n",
"    return (\n",
"        len(row[\"English\"].split(' ')) < MAX_LENGTH\n",
"        and len(row[\"Finnish\"].split(' ')) < MAX_LENGTH\n",
"        and row[\"English\"].startswith(eng_prefixes)\n",
"    )\n",
"\n",
"\n",
"# QUOTE_NONE: fin.txt is plain tab-separated text whose sentences may contain '\"'.\n",
"# With the default quoting, pandas treats a field starting with '\"' as a quoted CSV\n",
"# field instead of reading the character literally, which can alter or merge fields.\n",
"data_file = pd.read_csv(\n",
"    \"/kaggle/input/anki-en-fin/fin.txt\",\n",
"    sep='\\t',\n",
"    names=[\"English\", \"Finnish\", \"attribution\"],\n",
"    quoting=csv.QUOTE_NONE,\n",
")\n",
"data_file[\"English\"] = data_file[\"English\"].apply(normalizeString)\n",
"data_file[\"Finnish\"] = data_file[\"Finnish\"].apply(normalizeString)\n",
"\n",
"filter_list = data_file.apply(filter_rows, axis=1)\n",
"\n",
"# These pairs come from the training data (see the note above), so BLEU on them is optimistic.\n",
"test_section = data_file[filter_list]\n",
"# A fixed random_state makes the evaluation subset, and hence the BLEU score, reproducible.\n",
"test_section = test_section.sample(frac=1, random_state=EVAL_SEED).head(EVAL_SAMPLE_SIZE)\n",
"test_section.head()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:28:54.296348Z",
"iopub.execute_input": "2024-05-25T15:28:54.297253Z",
"iopub.status.idle": "2024-05-25T15:28:59.041172Z",
"shell.execute_reply.started": "2024-05-25T15:28:54.297211Z",
"shell.execute_reply": "2024-05-25T15:28:59.040201Z"
},
"trusted": true
},
"execution_count": 95,
"outputs": [
{
"execution_count": 95,
"output_type": "execute_result",
"data": {
"text/plain": " English ... attribution\n38027 i m very serious about this ... CC-BY 2.0 (France) Attribution: tatoeba.org #2...\n3803 i m not tired ... CC-BY 2.0 (France) Attribution: tatoeba.org #1...\n26924 i m not married either ... CC-BY 2.0 (France) Attribution: tatoeba.org #6...\n32009 he s sleeping like a baby ... CC-BY 2.0 (France) Attribution: tatoeba.org #2...\n21339 i m joking of course ... CC-BY 2.0 (France) Attribution: tatoeba.org #2...\n\n[5 rows x 3 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>English</th>\n <th>Finnish</th>\n <th>attribution</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>38027</th>\n <td>i m very serious about this</td>\n <td>olen hyvin tosissani tasta</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n </tr>\n <tr>\n <th>3803</th>\n <td>i m not tired</td>\n <td>en ole vasynyt</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #1...</td>\n </tr>\n <tr>\n <th>26924</th>\n <td>i m not married either</td>\n <td>minakaan en ole naimisissa</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #6...</td>\n </tr>\n <tr>\n <th>32009</th>\n <td>he s sleeping like a baby</td>\n <td>han nukkuu kuin pikkuvauva</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n </tr>\n <tr>\n <th>21339</th>\n <td>i m joking of course</td>\n <td>se oli vitsi tietenkin</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Whitespace-tokenize the reference sentences into token lists for BLEU.\n",
"test_section[\"English_tokenized\"] = test_section[\"English\"].apply(\n",
"    lambda reference_sentence: reference_sentence.split()\n",
")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:29:07.110416Z",
"iopub.execute_input": "2024-05-25T15:29:07.110816Z",
"iopub.status.idle": "2024-05-25T15:29:07.117378Z",
"shell.execute_reply.started": "2024-05-25T15:29:07.110786Z",
"shell.execute_reply": "2024-05-25T15:29:07.116136Z"
},
"trusted": true
},
"execution_count": 96,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test_section[\"English_tokenized\"].head()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:29:10.203170Z",
"iopub.execute_input": "2024-05-25T15:29:10.203540Z",
"iopub.status.idle": "2024-05-25T15:29:10.212993Z",
"shell.execute_reply.started": "2024-05-25T15:29:10.203511Z",
"shell.execute_reply": "2024-05-25T15:29:10.211937Z"
},
"trusted": true
},
"execution_count": 97,
"outputs": [
{
"execution_count": 97,
"output_type": "execute_result",
"data": {
"text/plain": "38027 [i, m, very, serious, about, this]\n3803 [i, m, not, tired]\n26924 [i, m, not, married, either]\n32009 [he, s, sleeping, like, a, baby]\n21339 [i, m, joking, of, course]\nName: English_tokenized, dtype: object"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Model translation of each Finnish sentence; each stored value is a token list.\n",
"test_section[\"English_translated\"] = test_section[\"Finnish\"].apply(\n",
"    lambda finnish_sentence: translate(finnish_sentence, tokenized=True)\n",
")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:30:53.183117Z",
"iopub.execute_input": "2024-05-25T15:30:53.183937Z",
"iopub.status.idle": "2024-05-25T15:30:56.313012Z",
"shell.execute_reply.started": "2024-05-25T15:30:53.183902Z",
"shell.execute_reply": "2024-05-25T15:30:56.312202Z"
},
"trusted": true
},
"execution_count": 100,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Compare reference and model translations side by side.\n",
"test_section.head(5)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:31:06.745381Z",
"iopub.execute_input": "2024-05-25T15:31:06.746471Z",
"iopub.status.idle": "2024-05-25T15:31:06.771839Z",
"shell.execute_reply.started": "2024-05-25T15:31:06.746417Z",
"shell.execute_reply": "2024-05-25T15:31:06.770679Z"
},
"trusted": true
},
"execution_count": 101,
"outputs": [
{
"execution_count": 101,
"output_type": "execute_result",
"data": {
"text/plain": " English ... English_translated\n38027 i m very serious about this ... [i, m, in, french]\n3803 i m not tired ... [i, not, tired]\n26924 i m not married either ... [i, m, not, married, either]\n32009 he s sleeping like a baby ... [he, is, as, a, pianist]\n21339 i m joking of course ... [i, m, joking, of, course]\n\n[5 rows x 5 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>English</th>\n <th>Finnish</th>\n <th>attribution</th>\n <th>English_tokenized</th>\n <th>English_translated</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>38027</th>\n <td>i m very serious about this</td>\n <td>olen hyvin tosissani tasta</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n <td>[i, m, very, serious, about, this]</td>\n <td>[i, m, in, french]</td>\n </tr>\n <tr>\n <th>3803</th>\n <td>i m not tired</td>\n <td>en ole vasynyt</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #1...</td>\n <td>[i, m, not, tired]</td>\n <td>[i, not, tired]</td>\n </tr>\n <tr>\n <th>26924</th>\n <td>i m not married either</td>\n <td>minakaan en ole naimisissa</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #6...</td>\n <td>[i, m, not, married, either]</td>\n <td>[i, m, not, married, either]</td>\n </tr>\n <tr>\n <th>32009</th>\n <td>he s sleeping like a baby</td>\n <td>han nukkuu kuin pikkuvauva</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n <td>[he, s, sleeping, like, a, baby]</td>\n <td>[he, is, as, a, pianist]</td>\n </tr>\n <tr>\n <th>21339</th>\n <td>i m joking of course</td>\n <td>se oli vitsi tietenkin</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n <td>[i, m, joking, of, course]</td>\n <td>[i, m, joking, of, course]</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"candidate_corpus = test_section[\"English_translated\"].values\n",
"references_corpus = test_section[\"English_tokenized\"].values.tolist()\n",
"\n",
"# Candidates: one token list per translated sentence.\n",
"x = candidate_corpus.tolist()\n",
"# References: each token list is wrapped in its own list, because bleu_score accepts\n",
"# several references per candidate (here there is exactly one).\n",
"y = list(map(lambda reference: [reference], references_corpus))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:43:29.441911Z",
"iopub.execute_input": "2024-05-25T15:43:29.442752Z",
"iopub.status.idle": "2024-05-25T15:43:29.447799Z",
"shell.execute_reply.started": "2024-05-25T15:43:29.442721Z",
"shell.execute_reply": "2024-05-25T15:43:29.446877Z"
},
"trusted": true
},
"execution_count": 118,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Peek at the first few wrapped references.\n",
"y[0:5]"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:43:30.474463Z",
"iopub.execute_input": "2024-05-25T15:43:30.475080Z",
"iopub.status.idle": "2024-05-25T15:43:30.482690Z",
"shell.execute_reply.started": "2024-05-25T15:43:30.475039Z",
"shell.execute_reply": "2024-05-25T15:43:30.481686Z"
},
"trusted": true
},
"execution_count": 119,
"outputs": [
{
"execution_count": 119,
"output_type": "execute_result",
"data": {
"text/plain": "[[['i', 'm', 'very', 'serious', 'about', 'this']],\n [['i', 'm', 'not', 'tired']],\n [['i', 'm', 'not', 'married', 'either']],\n [['he', 's', 'sleeping', 'like', 'a', 'baby']],\n [['i', 'm', 'joking', 'of', 'course']]]"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"from torchtext.data.metrics import bleu_score\n",
"\n",
"# Corpus-level BLEU of the model translations (x) against their single references (y).\n",
"bleu_score(candidate_corpus=x, references_corpus=y)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:43:36.654035Z",
"iopub.execute_input": "2024-05-25T15:43:36.654953Z",
"iopub.status.idle": "2024-05-25T15:43:36.916617Z",
"shell.execute_reply.started": "2024-05-25T15:43:36.654906Z",
"shell.execute_reply": "2024-05-25T15:43:36.915429Z"
},
"trusted": true
},
"execution_count": 120,
"outputs": [
{
"execution_count": 120,
"output_type": "execute_result",
"data": {
"text/plain": "0.5885258316993713"
},
"metadata": {}
}
]
}
]
}