s464953_uczenie_glebokie_se.../pl2en-seq2seq.ipynb

1630 lines
153 KiB
Plaintext
Raw Permalink Normal View History

2024-05-31 23:32:34 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Seq2Seq Polski --> Angielski\n",
"https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"execution": {
"iopub.execute_input": "2024-05-25T14:03:55.887266Z",
"iopub.status.busy": "2024-05-25T14:03:55.886451Z",
"iopub.status.idle": "2024-05-25T14:04:02.514594Z",
"shell.execute_reply": "2024-05-25T14:04:02.513697Z",
"shell.execute_reply.started": "2024-05-25T14:03:55.887232Z"
},
"trusted": true
},
"outputs": [],
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"from io import open\n",
"import unicodedata\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"import numpy as np\n",
"from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:04:09.403926Z",
"iopub.status.busy": "2024-05-25T14:04:09.403445Z",
"iopub.status.idle": "2024-05-25T14:04:09.434533Z",
"shell.execute_reply": "2024-05-25T14:04:09.433678Z",
"shell.execute_reply.started": "2024-05-25T14:04:09.403898Z"
},
"trusted": true
},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.device_count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Konwersja słów na index"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:04:14.014490Z",
"iopub.status.busy": "2024-05-25T14:04:14.014114Z",
"iopub.status.idle": "2024-05-25T14:04:14.024526Z",
"shell.execute_reply": "2024-05-25T14:04:14.023673Z",
"shell.execute_reply.started": "2024-05-25T14:04:14.014461Z"
},
"trusted": true
},
"outputs": [],
"source": [
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"class Lang:\n",
" def __init__(self, name):\n",
" self.name = name\n",
" self.word2index = {}\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" def addSentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.addWord(word)\n",
"\n",
" def addWord(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" else:\n",
" self.word2count[word] += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Normalizacja tekstu"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:04:23.432285Z",
"iopub.status.busy": "2024-05-25T14:04:23.431898Z",
"iopub.status.idle": "2024-05-25T14:04:23.438688Z",
"shell.execute_reply": "2024-05-25T14:04:23.437569Z",
"shell.execute_reply.started": "2024-05-25T14:04:23.432256Z"
},
"trusted": true
},
"outputs": [],
"source": [
"# Turn a Unicode string to plain ASCII, thanks to\n",
"# https://stackoverflow.com/a/518232/2809427\n",
"def unicodeToAscii(s):\n",
" return ''.join(\n",
" c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn'\n",
" )\n",
"\n",
"# Lowercase, trim, and remove non-letter characters\n",
"def normalizeString(s):\n",
"    \"\"\"Return ``s`` lower-cased, ASCII-folded and reduced to letters, ``!`` and ``?``.\n",
"\n",
"    A space is inserted in front of ``.``, ``!`` and ``?``; afterwards every run\n",
"    of characters other than ASCII letters, ``!`` and ``?`` collapses into a\n",
"    single space, and surrounding whitespace is removed.\n",
"    \"\"\"\n",
"    ascii_text = unicodeToAscii(s.lower().strip())\n",
"    spaced = re.sub(r\"([.!?])\", r\" \\1\", ascii_text)\n",
"    letters_only = re.sub(r\"[^a-zA-Z!?]+\", r\" \", spaced)\n",
"    return letters_only.strip()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Wczytywanie danych (zmodyfikowane ze względu na ścieżkę w kaggle)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:12:25.386029Z",
"iopub.status.busy": "2024-05-25T14:12:25.385674Z",
"iopub.status.idle": "2024-05-25T14:12:25.394103Z",
"shell.execute_reply": "2024-05-25T14:12:25.392925Z",
"shell.execute_reply.started": "2024-05-25T14:12:25.386002Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def readLangs(reverse=False, path='pol.txt'):\n",
"    \"\"\"Load the tab-separated parallel corpus and create empty vocabularies.\n",
"\n",
"    Only the first two tab-separated columns of each line are used as the\n",
"    sentence pair; any further columns are ignored.\n",
"\n",
"    Args:\n",
"        reverse: when True each pair is reversed and the input/output\n",
"            vocabularies are swapped accordingly.\n",
"        path: location of the corpus file (defaults to 'pol.txt').\n",
"\n",
"    Returns:\n",
"        (input_lang, output_lang, pairs): two still-empty ``Lang`` objects\n",
"        and the list of normalized sentence pairs.\n",
"    \"\"\"\n",
"    print(\"Reading lines...\")\n",
"    lang1 = \"en\"\n",
"    lang2 = \"pol\"\n",
"    # Read the file and split into lines; the context manager closes the handle.\n",
"    with open(path, encoding='utf-8') as corpus:\n",
"        lines = corpus.read().strip().split('\\n')\n",
"\n",
"    # Split every line into pairs and normalize. Taking the first two columns\n",
"    # (rather than dropping the last one) also works when a line has no\n",
"    # trailing extra column or has more than one.\n",
"    pairs = [[normalizeString(s) for s in l.split('\\t')[:2]] for l in lines]\n",
"\n",
"    # Reverse pairs, make Lang instances\n",
"    if reverse:\n",
"        pairs = [list(reversed(p)) for p in pairs]\n",
"        input_lang = Lang(lang2)\n",
"        output_lang = Lang(lang1)\n",
"    else:\n",
"        input_lang = Lang(lang1)\n",
"        output_lang = Lang(lang2)\n",
"\n",
"    return input_lang, output_lang, pairs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Ograniczenie do zdań krótszych niż 10 słów, zaczynających się od form I am / You are / He is itd., bez interpunkcji (poza ! i ?)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:12:29.730147Z",
"iopub.status.busy": "2024-05-25T14:12:29.729786Z",
"iopub.status.idle": "2024-05-25T14:12:29.737013Z",
"shell.execute_reply": "2024-05-25T14:12:29.735886Z",
"shell.execute_reply.started": "2024-05-25T14:12:29.730121Z"
},
"trusted": true
},
"outputs": [],
"source": [
"MAX_LENGTH = 10\n",
"\n",
"eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s \",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
")\n",
"\n",
"def filterPair(p):\n",
" return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
" len(p[1].split(' ')) < MAX_LENGTH and \\\n",
" p[1].startswith(eng_prefixes)\n",
"\n",
"\n",
"def filterPairs(pairs):\n",
"    \"\"\"Return the pairs accepted by ``filterPair``, in their original order.\"\"\"\n",
"    return list(filter(filterPair, pairs))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:12:33.204776Z",
"iopub.status.busy": "2024-05-25T14:12:33.204103Z",
"iopub.status.idle": "2024-05-25T14:12:36.889693Z",
"shell.execute_reply": "2024-05-25T14:12:36.888700Z",
"shell.execute_reply.started": "2024-05-25T14:12:33.204744Z"
},
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading lines...\n",
"Read 49943 sentence pairs\n",
"Trimmed to 3613 sentence pairs\n",
"Counting words...\n",
"Counted words:\n",
"pol 3070\n",
"en 1969\n",
"['jestem tylko mechanikiem', 'i m only the mechanic']\n"
]
}
],
"source": [
"def prepareData(reverse=False):\n",
"    \"\"\"Read the corpus, filter the pairs and fill both vocabularies.\n",
"\n",
"    Progress and the resulting vocabulary sizes are printed along the way.\n",
"\n",
"    Returns:\n",
"        (input_lang, output_lang, pairs) with the vocabularies populated from\n",
"        the filtered pairs only.\n",
"    \"\"\"\n",
"    input_lang, output_lang, all_pairs = readLangs(reverse)\n",
"    print(\"Read %s sentence pairs\" % len(all_pairs))\n",
"\n",
"    kept_pairs = filterPairs(all_pairs)\n",
"    print(\"Trimmed to %s sentence pairs\" % len(kept_pairs))\n",
"\n",
"    print(\"Counting words...\")\n",
"    for sentence_pair in kept_pairs:\n",
"        input_lang.addSentence(sentence_pair[0])\n",
"        output_lang.addSentence(sentence_pair[1])\n",
"\n",
"    print(\"Counted words:\")\n",
"    for vocabulary in (input_lang, output_lang):\n",
"        print(vocabulary.name, vocabulary.n_words)\n",
"    return input_lang, output_lang, kept_pairs\n",
"\n",
"input_lang, output_lang, pairs = prepareData(True)\n",
"print(random.choice(pairs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Definicja modelu"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:12:52.384131Z",
"iopub.status.busy": "2024-05-25T14:12:52.383787Z",
"iopub.status.idle": "2024-05-25T14:12:52.391196Z",
"shell.execute_reply": "2024-05-25T14:12:52.390316Z",
"shell.execute_reply.started": "2024-05-25T14:12:52.384104Z"
},
"trusted": true
},
"outputs": [],
"source": [
"class EncoderRNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size, dropout_p=0.1):\n",
" super(EncoderRNN, self).__init__()\n",
" self.hidden_size = hidden_size\n",
"\n",
" self.embedding = nn.Embedding(input_size, hidden_size)\n",
" self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
" self.dropout = nn.Dropout(dropout_p)\n",
"\n",
" def forward(self, input):\n",
" embedded = self.dropout(self.embedding(input))\n",
" output, hidden = self.gru(embedded)\n",
" return output, hidden"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:12:54.394808Z",
"iopub.status.busy": "2024-05-25T14:12:54.393953Z",
"iopub.status.idle": "2024-05-25T14:12:54.409000Z",
"shell.execute_reply": "2024-05-25T14:12:54.407827Z",
"shell.execute_reply.started": "2024-05-25T14:12:54.394765Z"
},
"trusted": true
},
"outputs": [],
"source": [
"class DecoderRNN(nn.Module):\n",
"    \"\"\"GRU decoder without attention.\n",
"\n",
"    Decoding starts from the SOS token and the encoder's final hidden state\n",
"    and always runs for MAX_LENGTH steps.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, hidden_size, output_size):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding(output_size, hidden_size)\n",
"        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
"        self.out = nn.Linear(hidden_size, output_size)\n",
"\n",
"    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
"        \"\"\"Decode MAX_LENGTH steps.\n",
"\n",
"        Args:\n",
"            encoder_outputs: encoder output; only its batch size (dim 0) and\n",
"                device are used here.\n",
"            encoder_hidden: initial hidden state for the GRU.\n",
"            target_tensor: optional gold token ids; when given, column i is\n",
"                fed as the input of step i + 1 (teacher forcing), otherwise\n",
"                the arg-max prediction of each step is fed back.\n",
"\n",
"        Returns:\n",
"            (log_probs, final_hidden, None); the trailing None keeps the\n",
"            return shape consistent with the attention decoder.\n",
"        \"\"\"\n",
"        batch_size = encoder_outputs.size(0)\n",
"        # Create the start tokens on the device of the incoming tensors rather\n",
"        # than relying on the notebook-level ``device`` global.\n",
"        decoder_input = torch.full(\n",
"            (batch_size, 1), SOS_token, dtype=torch.long, device=encoder_outputs.device\n",
"        )\n",
"        decoder_hidden = encoder_hidden\n",
"        decoder_outputs = []\n",
"\n",
"        for i in range(MAX_LENGTH):\n",
"            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n",
"            decoder_outputs.append(decoder_output)\n",
"\n",
"            if target_tensor is not None:\n",
"                # Teacher forcing: feed the target as the next input\n",
"                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
"            else:\n",
"                # Without teacher forcing: use its own predictions as the next input\n",
"                _, topi = decoder_output.topk(1)\n",
"                decoder_input = topi.squeeze(-1).detach()  # detach from history as input\n",
"\n",
"        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
"        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
"        return decoder_outputs, decoder_hidden, None  # `None` for consistency in the training loop\n",
"\n",
"    def forward_step(self, input, hidden):\n",
"        \"\"\"Run one decoding step; returns (vocabulary scores, new hidden state).\"\"\"\n",
"        output = self.embedding(input)\n",
"        output = F.relu(output)\n",
"        output, hidden = self.gru(output, hidden)\n",
"        output = self.out(output)\n",
"        return output, hidden"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:13:00.670758Z",
"iopub.status.busy": "2024-05-25T14:13:00.670299Z",
"iopub.status.idle": "2024-05-25T14:13:00.687695Z",
"shell.execute_reply": "2024-05-25T14:13:00.686610Z",
"shell.execute_reply.started": "2024-05-25T14:13:00.670720Z"
},
"trusted": true
},
"outputs": [],
"source": [
"class BahdanauAttention(nn.Module):\n",
" def __init__(self, hidden_size):\n",
" super(BahdanauAttention, self).__init__()\n",
" self.Wa = nn.Linear(hidden_size, hidden_size)\n",
" self.Ua = nn.Linear(hidden_size, hidden_size)\n",
" self.Va = nn.Linear(hidden_size, 1)\n",
"\n",
" def forward(self, query, keys):\n",
" scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
" scores = scores.squeeze(2).unsqueeze(1)\n",
"\n",
" weights = F.softmax(scores, dim=-1)\n",
" context = torch.bmm(weights, keys)\n",
"\n",
" return context, weights\n",
"\n",
"class AttnDecoderRNN(nn.Module):\n",
"    \"\"\"GRU decoder that attends over the encoder outputs at every step.\n",
"\n",
"    Decoding starts from the SOS token and the encoder's final hidden state\n",
"    and always runs for MAX_LENGTH steps.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, hidden_size, output_size, dropout_p=0.1):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding(output_size, hidden_size)\n",
"        self.attention = BahdanauAttention(hidden_size)\n",
"        # The GRU input is the embedded token concatenated with the context.\n",
"        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
"        self.out = nn.Linear(hidden_size, output_size)\n",
"        self.dropout = nn.Dropout(dropout_p)\n",
"\n",
"    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
"        \"\"\"Decode MAX_LENGTH steps with attention.\n",
"\n",
"        Args:\n",
"            encoder_outputs: encoder outputs used as attention keys.\n",
"            encoder_hidden: initial hidden state for the GRU.\n",
"            target_tensor: optional gold token ids; when given, column i is\n",
"                fed as the input of step i + 1 (teacher forcing), otherwise\n",
"                the arg-max prediction of each step is fed back.\n",
"\n",
"        Returns:\n",
"            (log_probs, final_hidden, attentions) where ``attentions`` holds\n",
"            the per-step attention weights concatenated along dim 1.\n",
"        \"\"\"\n",
"        batch_size = encoder_outputs.size(0)\n",
"        # Create the start tokens on the device of the incoming tensors rather\n",
"        # than relying on the notebook-level ``device`` global.\n",
"        decoder_input = torch.full(\n",
"            (batch_size, 1), SOS_token, dtype=torch.long, device=encoder_outputs.device\n",
"        )\n",
"        decoder_hidden = encoder_hidden\n",
"        decoder_outputs = []\n",
"        attentions = []\n",
"\n",
"        for i in range(MAX_LENGTH):\n",
"            decoder_output, decoder_hidden, attn_weights = self.forward_step(\n",
"                decoder_input, decoder_hidden, encoder_outputs\n",
"            )\n",
"            decoder_outputs.append(decoder_output)\n",
"            attentions.append(attn_weights)\n",
"\n",
"            if target_tensor is not None:\n",
"                # Teacher forcing: feed the target as the next input\n",
"                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
"            else:\n",
"                # Without teacher forcing: use its own predictions as the next input\n",
"                _, topi = decoder_output.topk(1)\n",
"                decoder_input = topi.squeeze(-1).detach()  # detach from history as input\n",
"\n",
"        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
"        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
"        attentions = torch.cat(attentions, dim=1)\n",
"\n",
"        return decoder_outputs, decoder_hidden, attentions\n",
"\n",
"    def forward_step(self, input, hidden, encoder_outputs):\n",
"        \"\"\"Run one step; returns (vocabulary scores, new hidden, attention weights).\"\"\"\n",
"        embedded = self.dropout(self.embedding(input))\n",
"\n",
"        # Swap the first two axes of the hidden state to form the attention query.\n",
"        query = hidden.permute(1, 0, 2)\n",
"        context, attn_weights = self.attention(query, encoder_outputs)\n",
"        input_gru = torch.cat((embedded, context), dim=2)\n",
"\n",
"        output, hidden = self.gru(input_gru, hidden)\n",
"        output = self.out(output)\n",
"\n",
"        return output, hidden, attn_weights"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:22:08.184711Z",
"iopub.status.busy": "2024-05-25T14:22:08.183866Z",
"iopub.status.idle": "2024-05-25T14:22:08.194870Z",
"shell.execute_reply": "2024-05-25T14:22:08.193965Z",
"shell.execute_reply.started": "2024-05-25T14:22:08.184675Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def indexesFromSentence(lang, sentence):\n",
" return [lang.word2index[word] for word in sentence.split(' ')]\n",
"\n",
"def tensorFromSentence(lang, sentence):\n",
"    \"\"\"Encode ``sentence`` plus a trailing EOS token as a 1 x N long tensor on ``device``.\"\"\"\n",
"    token_ids = indexesFromSentence(lang, sentence) + [EOS_token]\n",
"    ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=device)\n",
"    return ids_tensor.view(1, -1)\n",
"\n",
"def tensorsFromPair(pair):\n",
"    \"\"\"Encode a sentence pair with the notebook-level input/output vocabularies.\"\"\"\n",
"    source_sentence, target_sentence = pair[0], pair[1]\n",
"    return (\n",
"        tensorFromSentence(input_lang, source_sentence),\n",
"        tensorFromSentence(output_lang, target_sentence),\n",
"    )\n",
"\n",
"def get_dataloader(batch_size):\n",
"    \"\"\"Build the vocabularies and a shuffling DataLoader over all filtered pairs.\n",
"\n",
"    Every sentence is encoded as word indices followed by EOS and written\n",
"    into a row of width MAX_LENGTH; unused trailing positions stay 0.\n",
"\n",
"    Returns:\n",
"        (input_lang, output_lang, train_dataloader)\n",
"    \"\"\"\n",
"    input_lang, output_lang, pairs = prepareData(True)\n",
"\n",
"    shape = (len(pairs), MAX_LENGTH)\n",
"    input_ids = np.zeros(shape, dtype=np.int32)\n",
"    target_ids = np.zeros(shape, dtype=np.int32)\n",
"\n",
"    for row, (source_sentence, target_sentence) in enumerate(pairs):\n",
"        source_seq = indexesFromSentence(input_lang, source_sentence)\n",
"        target_seq = indexesFromSentence(output_lang, target_sentence)\n",
"        source_seq = source_seq + [EOS_token]\n",
"        target_seq = target_seq + [EOS_token]\n",
"        input_ids[row, :len(source_seq)] = source_seq\n",
"        target_ids[row, :len(target_seq)] = target_seq\n",
"\n",
"    train_data = TensorDataset(\n",
"        torch.LongTensor(input_ids).to(device),\n",
"        torch.LongTensor(target_ids).to(device),\n",
"    )\n",
"    train_dataloader = DataLoader(\n",
"        train_data, sampler=RandomSampler(train_data), batch_size=batch_size\n",
"    )\n",
"    return input_lang, output_lang, train_dataloader"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:16:38.895410Z",
"iopub.status.busy": "2024-05-25T14:16:38.894580Z",
"iopub.status.idle": "2024-05-25T14:16:38.902142Z",
"shell.execute_reply": "2024-05-25T14:16:38.900953Z",
"shell.execute_reply.started": "2024-05-25T14:16:38.895382Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def train_epoch(dataloader, encoder, decoder, encoder_optimizer,\n",
" decoder_optimizer, criterion):\n",
"\n",
" total_loss = 0\n",
" for data in dataloader:\n",
" input_tensor, target_tensor = data\n",
"\n",
" encoder_optimizer.zero_grad()\n",
" decoder_optimizer.zero_grad()\n",
"\n",
" encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
" decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)\n",
"\n",
" loss = criterion(\n",
" decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
" target_tensor.view(-1)\n",
" )\n",
" loss.backward()\n",
"\n",
" encoder_optimizer.step()\n",
" decoder_optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
"\n",
" return total_loss / len(dataloader)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:16:43.069953Z",
"iopub.status.busy": "2024-05-25T14:16:43.069584Z",
"iopub.status.idle": "2024-05-25T14:16:43.075972Z",
"shell.execute_reply": "2024-05-25T14:16:43.075033Z",
"shell.execute_reply.started": "2024-05-25T14:16:43.069926Z"
},
"trusted": true
},
"outputs": [],
"source": [
"import time\n",
"import math\n",
"\n",
"def asMinutes(s):\n",
" m = math.floor(s / 60)\n",
" s -= m * 60\n",
" return '%dm %ds' % (m, s)\n",
"\n",
"def timeSince(since, percent):\n",
"    \"\"\"Return 'elapsed (- remaining)' for a task started at ``since``.\n",
"\n",
"    ``percent`` is the completed fraction; the remaining time is extrapolated\n",
"    linearly from the elapsed time.\n",
"    \"\"\"\n",
"    elapsed = time.time() - since\n",
"    estimated_total = elapsed / (percent)\n",
"    remaining = estimated_total - elapsed\n",
"    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:20:58.574520Z",
"iopub.status.busy": "2024-05-25T14:20:58.574148Z",
"iopub.status.idle": "2024-05-25T14:20:58.583203Z",
"shell.execute_reply": "2024-05-25T14:20:58.582230Z",
"shell.execute_reply.started": "2024-05-25T14:20:58.574492Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,\n",
"          print_every=100, plot_every=100):\n",
"    \"\"\"Train encoder and decoder for ``n_epochs`` epochs with Adam and NLL loss.\n",
"\n",
"    Every ``print_every`` epochs the averaged loss and timing are printed;\n",
"    every ``plot_every`` epochs the averaged loss is recorded, and the\n",
"    recorded values are plotted once training has finished.\n",
"    \"\"\"\n",
"    started_at = time.time()\n",
"    loss_history = []\n",
"    print_window_sum = 0  # reset every print_every epochs\n",
"    plot_window_sum = 0   # reset every plot_every epochs\n",
"\n",
"    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n",
"    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n",
"    criterion = nn.NLLLoss()\n",
"\n",
"    for epoch in range(1, n_epochs + 1):\n",
"        epoch_loss = train_epoch(train_dataloader, encoder, decoder,\n",
"                                 encoder_optimizer, decoder_optimizer, criterion)\n",
"        print_window_sum += epoch_loss\n",
"        plot_window_sum += epoch_loss\n",
"\n",
"        if epoch % print_every == 0:\n",
"            progress = epoch / n_epochs\n",
"            window_average = print_window_sum / print_every\n",
"            print_window_sum = 0\n",
"            print('%s (%d %d%%) %.4f' % (timeSince(started_at, progress),\n",
"                                         epoch, progress * 100, window_average))\n",
"\n",
"        if epoch % plot_every == 0:\n",
"            loss_history.append(plot_window_sum / plot_every)\n",
"            plot_window_sum = 0\n",
"\n",
"    showPlot(loss_history)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:21:00.586719Z",
"iopub.status.busy": "2024-05-25T14:21:00.586018Z",
"iopub.status.idle": "2024-05-25T14:21:00.592633Z",
"shell.execute_reply": "2024-05-25T14:21:00.591636Z",
"shell.execute_reply.started": "2024-05-25T14:21:00.586683Z"
},
"trusted": true
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.switch_backend('agg')\n",
"import matplotlib.ticker as ticker\n",
"import numpy as np\n",
"%matplotlib inline\n",
"\n",
"def showPlot(points):\n",
"    \"\"\"Plot ``points`` as a line with y-axis ticks every 0.2.\"\"\"\n",
"    # plt.subplots() already creates the figure; an additional plt.figure()\n",
"    # call would leave an empty extra figure in the cell output.\n",
"    fig, ax = plt.subplots()\n",
"    # this locator puts ticks at regular intervals\n",
"    loc = ticker.MultipleLocator(base=0.2)\n",
"    ax.yaxis.set_major_locator(loc)\n",
"    ax.plot(points)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Ewaluacja"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:21:01.859612Z",
"iopub.status.busy": "2024-05-25T14:21:01.858691Z",
"iopub.status.idle": "2024-05-25T14:21:01.866857Z",
"shell.execute_reply": "2024-05-25T14:21:01.865732Z",
"shell.execute_reply.started": "2024-05-25T14:21:01.859574Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def evaluate(encoder, decoder, sentence, input_lang, output_lang):\n",
"    \"\"\"Greedily decode ``sentence`` with the given encoder/decoder pair.\n",
"\n",
"    Both modules are switched to evaluation mode for the duration of the call\n",
"    (so dropout is inactive) and put back into their previous mode afterwards.\n",
"\n",
"    Args:\n",
"        encoder, decoder: the modules to run; the decoder is called without\n",
"            targets, i.e. it feeds back its own predictions.\n",
"        sentence: space-separated input sentence; every word must be present\n",
"            in ``input_lang`` (an unknown word raises KeyError).\n",
"        input_lang, output_lang: vocabularies used for encoding / decoding.\n",
"\n",
"    Returns:\n",
"        (decoded_words, decoder_attn): the output tokens, ending with '<EOS>'\n",
"        if the end marker was produced, and the third value returned by the\n",
"        decoder.\n",
"    \"\"\"\n",
"    # Remember the current modes so they can be restored after inference.\n",
"    encoder_was_training = encoder.training\n",
"    decoder_was_training = decoder.training\n",
"    encoder.eval()\n",
"    decoder.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            input_tensor = tensorFromSentence(input_lang, sentence)\n",
"\n",
"            encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
"            decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n",
"\n",
"            _, topi = decoder_outputs.topk(1)\n",
"            decoded_ids = topi.squeeze()\n",
"\n",
"            decoded_words = []\n",
"            for idx in decoded_ids:\n",
"                if idx.item() == EOS_token:\n",
"                    decoded_words.append('<EOS>')\n",
"                    break\n",
"                decoded_words.append(output_lang.index2word[idx.item()])\n",
"    finally:\n",
"        encoder.train(encoder_was_training)\n",
"        decoder.train(decoder_was_training)\n",
"    return decoded_words, decoder_attn"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:39:52.475327Z",
"iopub.status.busy": "2024-05-25T14:39:52.474985Z",
"iopub.status.idle": "2024-05-25T14:39:52.481801Z",
"shell.execute_reply": "2024-05-25T14:39:52.480957Z",
"shell.execute_reply.started": "2024-05-25T14:39:52.475304Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def evaluateRandomly(encoder, decoder, n=10):\n",
"    \"\"\"Print input, reference and model output for ``n`` randomly drawn pairs.\"\"\"\n",
"    for _ in range(n):\n",
"        sample = random.choice(pairs)\n",
"        source_sentence, reference = sample[0], sample[1]\n",
"        print('Input sentence: ', source_sentence)\n",
"        print('Target (true) translation:' , reference)\n",
"        decoded, _attn = evaluate(encoder, decoder, source_sentence, input_lang, output_lang)\n",
"        print('Output sentence: ', ' '.join(decoded))\n",
"        print('')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wykorzystanie zdefiniowanych wyżej funkcji"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Trenowanie modelu"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:22:39.754740Z",
"iopub.status.busy": "2024-05-25T14:22:39.754370Z",
"iopub.status.idle": "2024-05-25T14:26:53.707012Z",
"shell.execute_reply": "2024-05-25T14:26:53.705969Z",
"shell.execute_reply.started": "2024-05-25T14:22:39.754714Z"
},
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading lines...\n",
"Read 49943 sentence pairs\n",
"Trimmed to 3613 sentence pairs\n",
"Counting words...\n",
"Counted words:\n",
"pol 3070\n",
"en 1969\n",
"0m 47s (- 11m 58s) (5 6%) 2.1245\n",
"1m 33s (- 10m 57s) (10 12%) 1.2482\n",
"2m 13s (- 9m 36s) (15 18%) 0.8442\n",
"2m 51s (- 8m 35s) (20 25%) 0.5612\n",
"3m 32s (- 7m 46s) (25 31%) 0.3599\n",
"4m 10s (- 6m 56s) (30 37%) 0.2216\n",
"4m 51s (- 6m 15s) (35 43%) 0.1367\n",
"5m 30s (- 5m 30s) (40 50%) 0.0894\n",
"6m 9s (- 4m 47s) (45 56%) 0.0647\n",
"6m 48s (- 4m 5s) (50 62%) 0.0489\n",
"7m 27s (- 3m 23s) (55 68%) 0.0402\n",
"8m 8s (- 2m 42s) (60 75%) 0.0345\n",
"8m 48s (- 2m 1s) (65 81%) 0.0315\n",
"9m 25s (- 1m 20s) (70 87%) 0.0278\n",
"10m 3s (- 0m 40s) (75 93%) 0.0271\n",
"10m 42s (- 0m 0s) (80 100%) 0.0253\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGeCAYAAABGlgGHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABAzUlEQVR4nO3de3gU9d3//9cekk0IyWIC5AABgkcKCiER5OihGgv8uG/uWsWbShSxt3xvFDFKNeUuWtuaaqtFSwFRkFtFGg+ItEVr7loJR4FAEAFFTSQBEiKn3RAgh935/RGyEpJANmQzOTwf1zVXsp+d2XlPvZp9MZ+ZeVsMwzAEAABgEqvZBQAAgI6NMAIAAExFGAEAAKYijAAAAFMRRgAAgKkIIwAAwFSEEQAAYCrCCAAAMBVhBAAAmMpudgGN4fV6dfDgQYWHh8tisZhdDgAAaATDMFRaWqq4uDhZrec5/2H44emnnzaSk5ONzp07G926dTP+/d//3fjiiy/Ou827775r3HzzzUbXrl2N8PBw47rrrjM+/PBDf3ZrFBYWGpJYWFhYWFhY2uBSWFh43u95i2E0vjfNj370I91555269tprVVVVpdmzZ2vnzp3avXu3wsLC6t1m5syZiouL04033qguXbro1Vdf1R/+8Ad9+umnSkxMbNR+XS6XunTposLCQkVERDS2XAAAYCK32634+HgdP35cTqezwfX8CiPn+u6779S9e3etWbNGo0ePbvR2/fv318SJEzVnzpxGre92u+V0OuVyuQgjAAC0EY39/r6oa0ZcLpckKTIystHbeL1elZaWnneb8vJylZeX+1673e6mFwkAAFq1Jt9NYxiG0tLSNHLkSA0YMKDR2z333HMqKyvTHXfc0eA6GRkZcjqdviU+Pr6pZQIAgFauydM006dP19///netW7dOPXv2bNQ2y5cv13333af3339fN998c4Pr1XdmJD4+nmkaAADakIBO0zz44INatWqVsrOzGx1EMjMzNXXqVL399tvnDSKS5HA45HA4mlIaAABoY/wKI4Zh6MEHH9R7772nTz75RAkJCY3abvny5br33nu1fPlyjRs3rkmFAgCA9smvMDJ9+nS9+eabev/99xUeHq7i4mJJktPpVGhoqCQpPT1dBw4c0GuvvSapOoikpqbqhRde0HXXXefbJjQ09Ly3+QAAgI7BrwtYFyxYIJfLpRtuuEGxsbG+JTMz07dOUVGRCgoKfK9feuklVVVVafr06bW2eeihh5rvKAAAQJt1Uc8ZaSk8ZwQAgLansd/fNMoDAACmIowAAABTEUYAAICpCCMAAMBUHTqMrNpxUD9/Z4c+23/c7FIAAOiwOnQY+WBnkd7aul8bvjlidikAAHRYHTqMDIrvIknKLThuah0AAHRkhBFJuYXHTa0DAICOrEOHkat7OmWzWlTsPq1i12mzywEAoEPq0GGkU7BdV0SHS+LsCAAAZunQYURiqgYAALMRRuKrOwfnFh4zuRIAADomwkj8JZKknftd8nhbfc9AAADanQ4fRi7r3llhwTaVVXj0VUmp2eUAANDhdPgwYrNadE3PLpJ43ggAAGbo8GFEkgb16iKJi1gBADADYUTcUQMAgJkII5ISz4SRvYdKVVZeZW4xAAB0MIQRSd0jQhTrDJHXkHYecJldDgAAHQph5AymagAAMAdh5Aw6+AIAYA6/wkhGRoauvfZahYeHq3v37powYYK+/PLLC263Zs0aJSUlKSQkRH379tXChQubXHCgcGYEAABz+BVG1qxZo+nTp2vTpk3KyspSVVWVUlJSVFZW1uA2+fn5Gjt2rEaNGqXt27frF7/4hWbMmKF33333ootvTnTwBQDAHBbDMJr8DPTvvvtO3bt315o1azR69Oh613nssce0atUq7dmzxzc2bdo07dixQxs3bmzUftxut5xOp1wulyIiIppa7gWNeWGt9hS5tfCuwfrRgNiA7QcAgI6gsd/fF3
XNiMtVfedJZGRkg+ts3LhRKSkptcZuvfVWbd26VZWVlfVuU15eLrfbXWtpCd9P1XBHDQAALaXJYcQwDKWlpWnkyJEaMGBAg+sVFxcrOjq61lh0dLSqqqp0+PDherfJyMiQ0+n0LfHx8U0t0y908AUAoOU1OYw88MAD+uyzz7R8+fILrmuxWGq9rpkZOne8Rnp6ulwul28pLCxsapl+oYMvAAAtz96UjR588EGtWrVK2dnZ6tmz53nXjYmJUXFxca2xkpIS2e12RUVF1buNw+GQw+FoSmkX5dwOvlfFBO76FAAAUM2vMyOGYeiBBx7QihUr9PHHHyshIeGC2wwbNkxZWVm1xj766CMlJycrKCjIv2oDjA6+AAC0PL/CyPTp0/XGG2/ozTffVHh4uIqLi1VcXKxTp0751klPT1dqaqrv9bRp07Rv3z6lpaVpz549WrJkiRYvXqxHH320+Y6iGdHBFwCAluVXGFmwYIFcLpduuOEGxcbG+pbMzEzfOkVFRSooKPC9TkhI0OrVq/XJJ59o0KBB+vWvf60XX3xRt912W/MdRTPi4WcAALQsv64ZacwjSZYuXVpn7Prrr9e2bdv82ZVpzu3gG+Zo0mU1AACgkehNcw46+AIA0LIII/VgqgYAgJZDGKkHHXwBAGg5hJF6cGYEAICWQxipBx18AQBoOYSRenQKtuuK6HBJ9KkBACDQCCMNqJmq2c5UDQAAAUUYaUBNB98dhBEAAAKKMNIAOvgCANAyCCMNOLeDLwAACAzCSAPo4AsAQMsgjJwHHXwBAAg8wsh58PAzAAACjzByHud28AUAAM2PMHIeZ3fw/Ww/HXwBAAgEwsgF1EzV7Nh/3NQ6AABorwgjF0AHXwAAAoswcgFcxAoAQGARRi6ADr4AAAQWYeQC6OALAEBgEUYagQ6+AAAEjt9hJDs7W+PHj1dcXJwsFotWrlx5wW2WLVumgQMHqlOnToqNjdWUKVN05MiRptRripoOvlzECgBA8/M7jJSVlWngwIGaN29eo9Zft26dUlNTNXXqVO3atUtvv/22tmzZovvuu8/vYs3i6+B7gA6+AAA0N7u/G4wZM0Zjxoxp9PqbNm1Snz59NGPGDElSQkKC7r//fj377LP+7to053bwvSomwuySAABoNwJ+zcjw4cO1f/9+rV69WoZh6NChQ3rnnXc0bty4BrcpLy+X2+2utZiJDr4AAAROi4SRZcuWaeLEiQoODlZMTIy6dOmiP/3pTw1uk5GRIafT6Vvi4+MDXeYF0cEXAIDACHgY2b17t2bMmKE5c+YoJydHH374ofLz8zVt2rQGt0lPT5fL5fIthYWFgS7zgnj4GQAAgeH3NSP+ysjI0IgRIzRr1ixJ0jXXXKOwsDCNGjVKv/nNbxQbG1tnG4fDIYfDEejS/HJuB98wR8D/pwMAoEMI+JmRkydPymqtvRubzSZJMoy2c2cKHXwBAAgMv8PIiRMnlJubq9zcXElSfn6+cnNzVVBQIKl6iiU1NdW3/vjx47VixQotWLBAeXl5Wr9+vWbMmKEhQ4YoLi6ueY6ihTBVAwBA8/N7rmHr1q268cYbfa/T0tIkSXfffbeWLl2qoqIiXzCRpHvuuUelpaWaN2+eHnnkEXXp0kU33XSTnnnmmWYov2UNiu+iDz4v1g7CCAAAzcZitIG5ErfbLafTKZfLpYgI857x8WneEU1ctEkxESHa9IsfmlYHAABtQWO/v+lN4wc6+AIA0PwII36ggy8AAM2PMOInOvgCANC8CCN+ooMvAADNizDiJzr4AgDQvAgjfqrp4HvyTAdfAABwcQgjfqKDLwAAzYsw0gR08AUAoPkQRpqAx8IDANB8CCNNcG4HXwAA0HSEkSaggy8AAM2HMNJETNUAANA8CCNN9H0Y4bHwAABcDMJIE9WEkR2FTNMAAHAxCCNNRAdfAACaB2GkiejgCwBA8yCMXAQ6+AIAcPEIIxeBDr4AAFw8wshFoIMvAAAXjzByEc7u4L
v3EB18AQBoCsLIRajVwZfrRgAAaBK/w0h2drbGjx+vuLg4WSwWrVy58oLblJeXa/bs2erdu7ccDocuvfRSLVmypCn
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Fix the random seeds so that weight initialisation and batch shuffling are\n",
"# repeatable when the notebook is re-run.\n",
"SEED = 0\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"hidden_size = 128\n",
"batch_size = 32\n",
"\n",
"input_lang, output_lang, train_dataloader = get_dataloader(batch_size)\n",
"\n",
"encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)\n",
"decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)\n",
"\n",
"train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:48:23.131948Z",
"iopub.status.busy": "2024-05-25T15:48:23.131435Z",
"iopub.status.idle": "2024-05-25T15:48:23.213007Z",
"shell.execute_reply": "2024-05-25T15:48:23.211856Z",
"shell.execute_reply.started": "2024-05-25T15:48:23.131911Z"
},
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input sentence: ciesze sie ze by em w stanie pomoc\n",
"Target (true) translation: i m glad i was able to help\n",
"Output sentence: i m glad i was able to help <EOS>\n",
"\n",
"Input sentence: to moja matka chrzestna\n",
"Target (true) translation: she s my godmother\n",
"Output sentence: she s my godmother by three <EOS>\n",
"\n",
"Input sentence: nie gram w zadna gre\n",
"Target (true) translation: i m not playing a game\n",
"Output sentence: i m not playing a game <EOS>\n",
"\n",
"Input sentence: jestem wyzszy\n",
"Target (true) translation: i am taller\n",
"Output sentence: i am taller <EOS>\n",
"\n",
"Input sentence: jestes zdesperowany\n",
"Target (true) translation: you re desperate\n",
"Output sentence: you re desperate <EOS>\n",
"\n",
"Input sentence: zostane zwolniony\n",
"Target (true) translation: i m going to get fired\n",
"Output sentence: i m going to be arrested i think <EOS>\n",
"\n",
"Input sentence: mamy dzisiaj rybe jako g owne danie\n",
"Target (true) translation: we are having fish for our main course\n",
"Output sentence: we are having fish for our main course <EOS>\n",
"\n",
"Input sentence: jestes przepracowana\n",
"Target (true) translation: you are overworked\n",
"Output sentence: you are overworked <EOS>\n",
"\n",
"Input sentence: jestes elokwentny\n",
"Target (true) translation: you re articulate\n",
"Output sentence: you re articulate <EOS>\n",
"\n",
"Input sentence: zaczynam rozumiec\n",
"Target (true) translation: i m beginning to understand\n",
"Output sentence: i m beginning to understand <EOS>\n",
"\n"
]
}
],
"source": [
"evaluateRandomly(encoder, decoder)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:30:41.253734Z",
"iopub.status.busy": "2024-05-25T15:30:41.253325Z",
"iopub.status.idle": "2024-05-25T15:30:41.264515Z",
"shell.execute_reply": "2024-05-25T15:30:41.263376Z",
"shell.execute_reply.started": "2024-05-25T15:30:41.253703Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def showAttention(input_sentence, output_words, attentions):\n",
"    \"\"\"Draw ``attentions`` as a heat map labelled with the input and output words.\n",
"\n",
"    Args:\n",
"        input_sentence: space-separated input sentence; its words plus\n",
"            '<EOS>' label the x axis.\n",
"        output_words: list of output tokens labelling the y axis.\n",
"        attentions: tensor passed to ``matshow`` after ``.cpu().numpy()``.\n",
"    \"\"\"\n",
"    fig = plt.figure()\n",
"    ax = fig.add_subplot(111)\n",
"    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')\n",
"    fig.colorbar(cax)\n",
"\n",
"    # Fix the tick positions before assigning labels: labelling ticks that are\n",
"    # still auto-placed triggers a matplotlib UserWarning.\n",
"    x_labels = input_sentence.split(' ') + ['<EOS>']\n",
"    ax.set_xticks(list(range(len(x_labels))))\n",
"    ax.set_xticklabels(x_labels, rotation=90)\n",
"    ax.set_yticks(list(range(len(output_words))))\n",
"    ax.set_yticklabels(output_words)\n",
"\n",
"    plt.show()\n",
"\n",
"\n",
"def evaluateAndShowAttention(input_sentence):\n",
"    \"\"\"Normalize and translate ``input_sentence``, print both, and plot the attention.\"\"\"\n",
"    normalized = normalizeString(input_sentence)\n",
"    output_words, attentions = evaluate(encoder, decoder, normalized, input_lang, output_lang)\n",
"    print('input =', normalized)\n",
"    print('output =', ' '.join(output_words))\n",
"    # Only the rows belonging to the emitted tokens are plotted.\n",
"    n_rows = len(output_words)\n",
"    showAttention(normalized, output_words, attentions[0, :n_rows, :])\n",
"\n",
"def translate(input_sentence, tokenized=False):\n",
"    \"\"\"Translate ``input_sentence`` with the notebook-level encoder and decoder.\n",
"\n",
"    Returns the decoded tokens joined by spaces, or - when ``tokenized`` is\n",
"    True - the token list with the first '<EOS>' entry removed.\n",
"    \"\"\"\n",
"    normalized = normalizeString(input_sentence)\n",
"    output_words, _attn = evaluate(encoder, decoder, normalized, input_lang, output_lang)\n",
"    if not tokenized:\n",
"        return ' '.join(output_words)\n",
"    if \"<EOS>\" in output_words:\n",
"        output_words.remove(\"<EOS>\")\n",
"    return output_words"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:58:53.639963Z",
"iopub.status.busy": "2024-05-25T14:58:53.639252Z",
"iopub.status.idle": "2024-05-25T14:58:53.654186Z",
"shell.execute_reply": "2024-05-25T14:58:53.653028Z",
"shell.execute_reply.started": "2024-05-25T14:58:53.639932Z"
},
"trusted": true
},
"outputs": [
{
"data": {
"text/plain": [
"['we', 'are', 'hungry']"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"translate(\"Jesteśmy głodni\", tokenized=True)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:46:05.024277Z",
"iopub.status.busy": "2024-05-25T14:46:05.023218Z",
"iopub.status.idle": "2024-05-25T14:46:05.426793Z",
"shell.execute_reply": "2024-05-25T14:46:05.424993Z",
"shell.execute_reply.started": "2024-05-25T14:46:05.024227Z"
},
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input = jestes zbyt naiwny\n",
"output = you re too naive <EOS>\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
"C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
" ax.set_yticklabels([''] + output_words)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcQAAAHHCAYAAAAhyyixAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2C0lEQVR4nO3de3iMd/7/8dcESVQk6pSDRhK1SBrntFtUHVq69GtX6aJUHKJfSluqjhdXq9aKaqW0Kg5bEVazKP3SriKtQ5GenEt0aRVBQh2axKFIZn5/ZDO/jmSmYST3TOb5cN3Xyj33Pfd7Zq/m7f3+3J/PbbJYLBYBAODhvIwOAAAAV0BCBABAJEQAACSREAEAkERCBABAEgkRAABJJEQAACSREAEAkERCBABAEgkRAABJJEQAACSREAEAkERCBABAEgkRABzKz8/XgQMHlJeXZ3QoKGUkRMCDXLlyxegQ3M7HH3+s5s2ba8WKFUaHglJGQgQ8SGBgoAYPHqwdO3YYHYrbSE5OVq1atbRkyRKjQ0EpIyECHiQlJUXZ2dl67LHH1KBBA82YMUNnzpwxOiyXdf78eX366adasmSJtm3bplOnThkdEkoRCRHwIN26ddPq1at15swZPf/880pJSVFYWJj+53/+R2vWrGGc7BYffPCBoqOj9ac//Ult27bV0qVLjQ4JpYiECHigGjVq6OWXX9b+/fuVkJCgzz77TE8//bRCQkL06quv6urVq0aH6BKSk5MVGxsrSXr22WdJiOWcyWKxWIwOAkDZysrK0tKlS5WUlKSTJ0/qqaeeUlxcnM6cOaMZM2YoODhYmzZtMjpMQx08eFAtW7bU6dOnVbNmTV2+fFmBgYHavHmz/vjHPxodHkpBRaMDAFB21qxZo6SkJG3cuFFRUVEaMWKEnn32WVWrVs16TLNmzdS8eXPjgnQRS5Ys0RNPPKGaNWtKkvz8/NS9e3clJSWREMspWqZu5pdffjE6BLixQYMGKSQkRDt37tS+ffv0wgsv2CRDSapXr54mTZpkTIAuIj8/X8uXL7e2Sws9++yzWrlypW7cuGFQZChNtExd2BtvvKHw8HD17t1bktSrVy+tXr1aQUFBWr9+vZo2bWpwhHA3V69e1T333GN0GC4vMzNTixYt0oQJE+Tt7W3dbzabNX36dMXGxqpu3boGRojSQEJ0YfXq1dM///lPtW7dWqmpqerVq5dWrFihlStX6uTJkx4/xoM7Yzab9cMPP+jcuXMym802rz366KMGRQUYjzFEF5aZmanQ0FBJ0ieffKJevXqpc+fOCg8PZwwDd+Srr75S3759deLECd36b2GTyaT8/HyDInN9J06c0JUrV9SoUSN5eTHaVB7x/6oLu/fee5WRkSFJ2rBhgx5//HFJksVi4RcX7siwYcMUExOjgwcP6uLFi7p06ZJ1u3jxotHhuYTk5GTNnj3bZt///u//ql69emrcuLGio6Ot/12ifCEhurAePXqob9++6tSpky5cuKAuXbpIkvbt26f69esbHB3c0dGjRzV9+nRFRkaqWrVqCggIsNkgzZ8/3+a72LBhg5KSkrR06VJ9++23qlatml5//XUDI0RpoWXqwt5++22Fh4crIyNDM2fOlJ+fn6SCVurw4cMNjg7u6I9//KN++OEH/kHlwJEjRxQTE2P9ee3atfrzn/+sfv36SZKmT5+uQYMGGRUeShEJ0YVVqlRJY8aMKbJ/1KhRZR8MyoUXX3xRr7zyirKystS4cWNVqlTJ5vUmTZoYFJnruHbtmvz9/a0/p6WlafDgwdaf69Wrp6ysLCNCQykjIbq4ZcuWacGCBTp27Ji+/PJLhYWFafbs2YqIiNBf/vIXo8ODm+nZs6ck2fyCN5lMslgs3FTzX2FhYdq9e7fCwsJ0/vx5HTp0SI888oj19aysLNrL5RQJ0YUlJibq1Vdf1ahRo/T3v//d+suqWrVqmj17NgkRt+2nn34yOgSXFxsbqx
EjRujQoUPavHmzGjVqpJYtW1pfT0tLU3R0tIERorSQEF3Yu+++q0WLFql79+6aMWOGdX9MTEyxrVTg94SFhRkdgssbP368rl69qjVr1igoKEirVq2yeX3nzp165plnDIoOpYmJ+S6scuXK+v777xUWFqaqVatq//79qlevno4ePaomTZro2rVrRocINxMSEqL27durffv2ateunRo2bGh0SIDLYNqFC4uIiNC+ffuK7P/0008VFRVV9gG5sMGDBys3N7fI/itXrtiMl3m6WbNmyd/fXwkJCYqMjFRwcLD69Omj+fPn6/Dhw0aH51KuXbumdevW6a233tKsWbO0bt06/hFa3lngshYvXmypU6eO5V//+pelSpUqlpSUFMu0adOsf8f/5+XlZTl79myR/T///LOlQoUKBkTk+rKysiwpKSmWfv36WSpWrGjx8vIyOiSXsXbtWkutWrUsJpPJZqtVq5Zl3bp1RoeHUsIYogsbNGiQ8vLyNG7cOF29elV9+/ZVnTp1NGfOHPXp08fo8FxCTk6OLBaLLBaLcnNz5evra30tPz9f69evV+3atQ2M0PVcvnxZO3bs0LZt27R161bt3btXjRs3Vrt27YwOzSWkpaXp6aef1p///Ge98sorioyMlCSlp6dr1qxZevrpp7V161a1atXK4EhxtzGG6CbOnz8vs9nML/dbeHl5yWQy2X3dZDLp9ddf9/jHGRX64x//qAMHDig6Olrt27fXo48+qrZt2xZ5BJQn69q1q0JDQ7VgwYJiXx86dKgyMjK0fv36Mo4MpY2E6MI6duyoNWvWFPlllZOTo+7du2vz5s3GBOZCtm3bJovFoo4dO2r16tWqXr269TVvb2+FhYUpJCTEwAhdS/Xq1WUymfT4449bb64prIBQ4N5779UXX3yhxo0bF/v6gQMH1K5dO126dKmMI0NpIyG6MC8vL2VlZRWpCs+dO6c6dero5s2bBkXmeo4fP66wsDCH1SIKHDhwQFu3btW2bdu0fft2eXl5qV27durQoYOGDRtmdHiG++3d3cU5ceKEIiMjdfXq1TKODKWNhOiCDhw4IElq1qyZNm/ebFP15Ofna8OGDVqwYIGOHz9uUISuJyIiQoMGDdLAgQN5cOtt2L17t+bOnat//vOfMpvNrFQjqWnTpho1apTd9UoXL16s2bNnW/87RfnBTTUuqFmzZjKZTDKZTOrYsWOR1ytXrqx3333XgMhc1+jRo7VkyRJNnTpVHTp0UFxcnJ566in5+PgYHZpL2bt3r7Zu3aqtW7dq+/btys3NVdOmTTVy5Eh16NDB6PBcwsCBAzVmzBgFBgaqa9euNq/9+9//1rhx4xiTLqeoEF1Q4cNb69Wrp2+++Ua1atWyvubt7a3atWurQoUKBkbouvbv36/FixcrJSVFeXl56tu3rwYPHqwWLVoYHZpLqFixopo3b6527dpZb6r57ULWkMxms3r37q3Vq1erYcOGNneZHj16VN27d9eqVat4SHA5REJEuXTz5k3NmzdP48eP182bNxUdHa2RI0dq0KBBHj3OmJOTQwIsoRUrViglJUVHjhyRJDVo0EB9+vRhylM5RkJ0YcnJyapZs6aefPJJSdK4ceO0cOFCRUVFKSUlhXUpi3Hz5k199NFHSkpKUmpqqh5++GHFxcXpzJkzmjt3rjp06KAPPvjA6DABuCASogtr2LChEhMT1bFjR3355Zd67LHHNHv2bH3yySeqWLGi1qxZY3SILmPPnj1KSkpSSkqKKlSooP79+2vIkCFq1KiR9Zhvv/1Wjz76qMctv1W9enUdOXJENWvW1L333uuwQr548WIZRuaaVq5cqe7du8vb21tSwR3MoaGh1mGKq1evau7cuRo3bpyRYaIUkBBd2D333KPvv/9edevW1fjx45WZmamlS5fq0KFDat++vX7++WejQ3QZFSpUUKdOnRQXF6fu3bsXefCtVLCu6QsvvKCkpCQDIjROcnKy+vTpIx8fHyUnJz
s8dsCAAWUUleuqUKGCMjMzrdOd/P39tW/fPtWrV0+SdPbsWYWEhHBHbjnEXaYuzM/PTxcuXFDdunW1adMmvfzyy5I
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluateAndShowAttention('Jesteś zbyt naiwny')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:46:10.395671Z",
"iopub.status.busy": "2024-05-25T14:46:10.394969Z",
"iopub.status.idle": "2024-05-25T14:46:10.793392Z",
"shell.execute_reply": "2024-05-25T14:46:10.791940Z",
"shell.execute_reply.started": "2024-05-25T14:46:10.395630Z"
},
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input = naprawde mi przykro\n",
"output = i m really sorry <EOS>\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
"C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
" ax.set_yticklabels([''] + output_words)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcQAAAHZCAYAAAAYF0taAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1gElEQVR4nO3de1yUdfr/8feAAgaCBxQkkVDxtJomdBB0s4OY22pqmYcST5RmZeRm6rd+efj2jc1KrTbdLMNDZpadtEwlD1lqlkqpaGtpLqigK5ooKggzvz8MtglmBAa852ZeTx/3I7m577mvmYqL67o/n89tsdlsNgEA4OG8jA4AAAB3QEIEAEAkRAAAJJEQAQCQREIEAEASCREAAEkkRAAAJJEQAQCQREIEAEASCREAAEkkRAAAJJEQAQCQREI0jQMHDujpp5/W4MGDdfz4cUnS6tWrlZ6ebnBkAFAzkBBN4Msvv1SHDh20bds2ffjhhzp79qwkadeuXZoyZYrB0QFAzUBCNIFJkybp2WefVWpqqnx8fEr233LLLdq6dauBkQFAzUFCNIHdu3erX79+pfY3atRIOTk5BkQEADUPCdEE6tWrp6ysrFL709LSdPXVVxsQEQDUPCREExgyZIgmTpyo7OxsWSwWWa1Wbd68WU888YQSEhKMDg8AagSLzWazGR0EnLt48aKGDx+ud999VzabTbVq1VJRUZGGDBmiBQsWyNvb2+gQAcD0SIgmcuDAAaWlpclqteq6665TVFSU0SEBQI1BQgQAQFItowNA2caPH1/uY2fOnFmNkQCAZyAhuqm0tDS7r3fs2KGioiK1bt1akrR//355e3srOjraiPAAoMYhIbqpDRs2lPx95syZqlu3rhYuXKj69etLkk6dOqURI0aoW7duRoUIADUK9xBN4Oqrr9batWv1pz/9yW7/nj17FB8fr6NHjxoUGQDUHMxDNIHc3FwdO3as1P7jx4/rzJkzBkQEADUPCdEE+vXrpxEjRmj58uU6fPiwDh8+rOXLl2vUqFHq37+/0eEBQI1Ay9QEzp07pyeeeEJvvfWWLl68KEmqVauWRo0apRdeeEH+/v4GRwgA5kdCNJG8vDwdOHBANptNLVu2JBECQBUiIZpAamqq4uLidNVVVxkdCmqAoqIiffzxx9q3b58sFovatm2ru+66iyUA4fFIiCYQGBio/Px8RUdH6+abb1b37t0VFxengIAAo0ODyfz888+68847dfjwYbVu3Vo2m0379+9XeHi4PvvsM7Vo0cLoEN1OUVGR0tPT1a5dO9WqxUy1moxBNSZw6tQpbdy4UX369FFaWpoGDBigBg0a6KabbtKkSZOMDg8mMm7cODVv3lyZmZnauXOn0tLSlJGRocjISI0bN87o8NzSypUrdd1112nZsmVGh4JqRoVoQnv27NGLL76oJUuWyGq1qqioyOiQYBL+/v765ptv1KFDB7v9P/zwg+Li4nT27FmDInNf/fr109atW9WhQwelpqYaHQ6qEfW/Cezbt09ffvmlNm7cqC+//FJFRUXq2rWrXnrpJd18881GhwcT8fX1LXPu6tmzZ+Xj42NARO7txIkT+vzzz/Xxxx+rT58+Onz4sJo2bWp0WKgmVIgm4OXlpUaNGikpKUl9+vQptWINUF4JCQnauXOn5s+frxtuuEGStG3bNj3wwAOKjo7WggULjA3QzbzyyitatGiRtm/frttuu0233Xab/ud//sfosFBNSIgmkJSUpE2bNik9PV2dOnVS9+7d1b17d3Xr1o2BNaiQX3/9VcOGDdPKlStVu3ZtSVJhYaH69OmjBQsWKCgoyOAI3Ut0dLSGDRumcePGKSUlRc8//7x+/PFHo8NCNSEhmsivv/6qr776Sl9++aW+/PJL7d69W506ddI333xjdGgwAZvNpoyMDDVq1EhHjx7Vvn37ZLPZ1K5dO7Vs2dLo8NzOnj17FB0drSNHjig4OFhnz55VSEiI1q
9frxtvvNHo8FANuIdoIlarVYWFhSooKFB+fr4uXryoQ4cOGR0WTMJmsykqKkrp6emKiooiCV7GggUL1LNnTwUHB0uSAgIC1LdvX6WkpJAQayimXZjAY489po4dO6px48YaPXq0jh49qgcffFA//PCDsrOzjQ4PJuHl5aWoqCjl5OQYHYrbKyoq0pIlS5SQkGC3//7779d7772ngoICgyJDdaJCNIEjR47ogQceUPfu3dW+fXujw4GJzZgxQxMmTNDcuXP5b8mJ48eP66GHHlKfPn3s9vfs2VPjx49Xdna2mjVrZlB0qC7cQwQ8SP369XXu3DkVFhbKx8dHderUsfv+yZMnDYoMMB4Voons3btXGRkZpdo1f/wtFnBk1qxZslgsRodhSv/+97+Vl5enNm3ayMuLu001ERWiCRw8eFD9+vXT7t27ZbFYVPyvrPgHGyvVoCqcP3++VMXoiRYuXKhTp04pKSmpZN+DDz6o+fPnS5Jat26tNWvWKDw83KAIUV34NccEHnvsMUVGRurYsWO66qqrlJ6erk2bNikmJkYbN240OjyYyMMPP1zm/ry8PPXq1esKR+Oe/vnPf9rNx1y9erVSUlK0aNEifffdd6pXr56mTZtmYISoLiREE9i6daumT5+uRo0aycvLS15eXuratauSk5NZkBkVsnbtWj399NN2+/Ly8nTHHXfQafjN/v37FRMTU/L1J598oj59+ui+++5T586d9dxzz2ndunUGRojqQkI0gaKiopIVaYKDg3X06FFJUkREhP71r38ZGRpMZu3atUpJSdGsWbMkSWfOnFGPHj1ksVi0evVqg6NzD+fPn1dgYGDJ11u2bNGf//znkq+bN2/OdKcaikE1JtC+fXvt2rVLzZs314033qgZM2bIx8dH8+bNU/PmzY0ODyYSGRmpNWvWqHv37vLy8tK7774rX19fffbZZ/L39zc6PLcQERGhHTt2KCIiQidOnFB6erq6du1a8v3s7GyWuKuhSIgm8PTTTysvL0+S9Oyzz+qvf/2runXrpoYNG/KMNlRY+/bt9emnn+r222/XjTfeqE8//ZTBNL+TkJCghx9+WOnp6Vq/fr3atGmj6Ojoku9v2bKFOZw1FKNMTerkyZOqX78+Q+hxWdddd12Z/538+9//VuPGje2S4c6dO69kaG7JarVqypQp+vTTTxUaGqqZM2eqbdu2Jd8fMGCA7rjjDo0aNcrAKFEdSIhurrCwUH5+fvr+++/5rRSVUpERkVOmTKnGSAD3RsvUzdWqVUsRERGMAESlkeQq5/z580pNTdX+/ftlsVgUFRWlHj160F6uwagQTSAlJUXvv/++3n77bTVo0MDocGBiw4cP18iRI+1GTaK0FStWKDExUSdOnLDbHxwcrPnz56t3794GRYbqxLQLE3jllVf01VdfKSwsTK1bt1bnzp3tNqC8zpw5o/j4eEVFRem5557TkSNHjA7J7WzZskX33HOP/vznP2vz5s06efKkTp48qa+//lrdunXTPffco61btxodJqoBFaIJXO4eEC0xVEROTo7efvttLViwQHv27NHtt9+uUaNG6a677lLt2rWNDs9wf/nLXxQeHq7XX3+9zO+PHj1amZmZWrVq1RWODNWNhAh4sLS0NL311lt68803FRAQoPvvv19jx45VVFSU0aEZpn79+tq0aZM6dOhQ5vd37dqlm2++WadOnbrCkaG60TI1ke3bt2vx4sV6++23tWPHDqPDgcllZWVp7dq1Wrt2rby9vfWXv/xF6enpateuXclKNp7owoULdivV/FFQUJDy8/OvYES4UhhlagKHDx/W4MGDtXnzZtWrV0+S9Ouvvyo2NlZLly5l1X2U28WLF7VixQqlpKRo7dq1uvbaa/X444/rvvvuU926dSVJ7777rh566CE9/vjjBkdrjFatWmn9+vUaMWJEmd9ft26dWrZseYWjwpVAhWgCI0eO1MWLF7Vv376SG/z79u2TzWZjcjAqpEmTJn
rggQcUERGhb7/9Vtu3b9eYMWNKkqF06anwxb94eaLhw4friSeeKPMe4WeffaYnn3zSYbKEuXEP0QTq1KmjLVu26Lr
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluateAndShowAttention('Naprawdę mi przykro')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:46:13.191025Z",
"iopub.status.busy": "2024-05-25T14:46:13.190403Z",
"iopub.status.idle": "2024-05-25T14:46:13.613486Z",
"shell.execute_reply": "2024-05-25T14:46:13.612143Z",
"shell.execute_reply.started": "2024-05-25T14:46:13.190997Z"
},
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input = jestes moim ojcem\n",
"output = you are my father <EOS>\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
"C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
" ax.set_yticklabels([''] + output_words)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcQAAAHHCAYAAAAhyyixAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA130lEQVR4nO3deVyU9fr/8feAsiTOYJggSkiZC6Gm0qLlL7PUrCyz1NIjQlKZntzS0ofV8bRI2WZlmitii5llp+WYSZknzVbSTOFkiQoaZJAy5AIK8/uDw3ybYEZwgHuGeT193I/knnu5Zqq5uD7bbbLZbDYBAODj/IwOAAAAT0BCBABAJEQAACSREAEAkERCBABAEgkRAABJJEQAACSREAEAkERCBABAEgkRAABJJEQAACSREAEAkERCBABAEgkRAFwqKyvTjh07dOrUKaNDQT0jIQKAC++//766d++u1atXGx0K6hkJEQBcSEtL0znnnKMVK1YYHQrqmYkHBANA9QoKCtS2bVv961//0o033qjs7Gy1bdvW6LBQT6gQAcCJ119/XXFxcbr22mvVp08frVy50uiQUI9IiADgRFpamhISEiRJf/vb30iIjRxNpgBQjZ07d6pnz546ePCgWrZsqT/++EPh4eHauHGjLr30UqPDQz2gQgSAaqxYsUIDBw5Uy5YtJUkhISEaMmSIUlNTDY4M9YWE6GWOHDlidAhAo1dWVqbXXnvN3lxa6W9/+5vefPNNlZaWGhQZ6hMJ0YM9+eSTDnOfhg8frrCwMLVp00bff/+9gZHBWxUWFmrChAmKjY1Vy5YtdfbZZztsqHDo0CHdc889uvHGGx32Dxw4UFOnTlV+fr5BkaE+0Yfowc477zy9+uqr6t27t9LT0zV8+HCtXr1ab775pnJycrRhwwajQ4SXGTRokPbs2aOxY8cqPDxcJpPJ4fUxY8YYFBlgPBKiBwsODtbu3bsVFRWlSZMm6cSJE1q0aJF2796tSy+9VIcPHzY6RHiZ5s2ba8uWLerWrZvRoXid/fv36+jRo+rUqZP8/Ghca4z4t+rBWrRoodzcXEnS+vXrdc0110iSbDabysrKjAzNYx06dEg7d+7Ujh07HDZU6NSpk44fP250GB4tLS1N8+bNc9h311136bzzzlOXLl0UFxdn//8SjQsJ0YMNHTpUI0eOVP/+/VVYWKhBgwZJkrZv36727dsbHJ1nycjIUFxcnFq3bq2uXbvqoosuUvfu3e3/RIUFCxZo1qxZ+s9//qPCwkJZrVaHDdLLL78si8Vi/3n9+vVKTU3VypUr9c033yg0NFT//Oc/DYwQ9aWJ0QHAueeee07t2rVTbm6u5s6dq5CQEElSXl6exo8fb3B0niUpKUkdOnTQsmXLqu0bQ4XQ0FAVFRWpX79+DvttNptMJhMtD5J2796t+Ph4+8/vvvuubrzxRo0aNUqSNGfOHCUlJRkVHuoRfYhoFJo3b65t27ZROZ/GJZdcoiZNmmjSpEnV/uJw5ZVXGhSZ5zjrrLOUlZWl6OhoSVK3bt10xx13aNKkSZKknJwcdezYkabnRogK0cO98sorWrRokbKzs/XFF18oOjpa8+bNU0xMjG666Sajw/MYV199tb7//nsS4mns3LlT27ZtU8eOHY0OxWNFR0crIyND0dHRKigo0K5du3TFFVfYX8/Pz3doUkXjQUL0YAsXLtTDDz+syZMn6/HHH7c3Z4WGhmrevHkkxD9ZunSpxowZo507dyouLk5NmzZ1eP2v88l8VXx8vHJzc0mILiQkJGjChAnatWuXNm7cqE6dOqlnz57217du3aq4uDgDI0R9ISF6sBdffFFLlizRkCFD9MQTT9j3x8fHa9q0aQZG5nm2bt2qLVu26MMPP6zyGn1j/+fee+/VpEmTNH36dHXp0qXKLw5du3Y1KDLP8cADD+jYsWNau3atIiIitGbNGofXP//8c91+++0GRYf6RB+iBwsODtZ///
tfRUdHq3nz5vr+++913nnn6aefflLXrl3pw/iTdu3a6YYbbtBDDz2k8PBwo8PxWNXNnzOZTAyqAUSF6NFiYmK0fft2e+d+pQ8//FCxsbEGReWZCgsLNWXKFJLhaezdu9foELzG8ePHlZ6ert27d8tkMumCCy5Q//79FRwcbHRoqCckRA82ffp0TZgwQSdOnJDNZtPXX3+tVatWKSUlRUuXLjU6PI8ydOhQffrppzr//PONDsWj/fWXK1TvvffeU3JysgoKChz2t2zZUsuWLdPgwYMNigz1iYTowZKSknTq1Cndf//9OnbsmEaOHKk2bdro+eef12233WZ0eB6lQ4cOmjlzprZs2VJt39jEiRMNiszzvPLKK3r55Ze1d+9eRi5XY+vWrbr11lt144036r777lPnzp0lSZmZmXrmmWd06623atOmTerVq5fBkaKu0YfoJQoKClReXq5WrVoZHYpHiomJcfqayWRSdnZ2A0bjuf46cnnnzp0677zztGLFCqWlpenTTz81OkTDXXfddYqKitKiRYuqff3uu+9Wbm6u1q1b18CRob6RED1Yv379tHbtWoWGhjrst1qtGjJkiDZu3GhMYPBasbGxmjNnjoYMGeIwUGvnzp3q27dvlSZCX9SiRQt99tln6tKlS7Wv79ixQ1deeSWL6zdCrGXqwTZt2lTtg0hPnDihzZs3GxARvN3evXurXds1MDBQR48eNSAiz3PixAmZzWanr1ssFpWUlDRgRGgo9CF6oD8/nSEzM9PhYaRlZWVav3692rRpY0RoHmXq1Kl69NFH1axZM02dOtXlsc8++2wDReXZGLl8eh06dNDGjRudrlf6ySefsCJSI0VC9EAXXXSRTCaTTCZTlUWYpYr5iS+++KIBkXmWbdu26eTJk/a/O8NC3/+Hkcunl5iYqGnTpik8PFzXXXedw2v//ve/df/992vWrFkGRYf6RB+iB9q/f79sNpvOO+88ff311zrnnHPsrwUEBKhVq1by9/c3MEJ4syVLluixxx6zP9OvTZs2mj17tsaOHWtwZJ6hvLxcI0aM0Ntvv62OHTs6jDL96aefNGTIEK1Zs4aHBDdCJEQ0OgcOHJDJZKJZ+TQYueza6tWrtWrVKu3evVtSRVPqbbfdxpSnRoyE6MHS0tLUsmVLXX/99ZKk+++/X4sXL1ZsbKxWrVrFJOs/KS8v12OPPaZnnnlGf/zxh6SKR0Ldd999mjVrFr/N/8/evXt16tQpXXDBBQ77f/rpJzVt2lTt2rUzJjDAA/At4cHmzJljXybqiy++0Pz58zV37ly1bNlSU6ZMMTg6zzJr1izNnz9fTzzxhLZt26bvvvtOc+bM0YsvvqiHHnrI6PA8RmJiorZu3Vpl/1dffaXExMSGD8gDvfnmmw6ju/ft2+ewxuuxY8c0d+5cI0JDPaNC9GBnnXWW/vvf/+rcc8/VAw88oLy8PK1cuVK7du1S37599dtvvxkdoseIjIzUyy+/XOUxT++++67Gjx+vgwcPGhSZZzGbzfruu++qjJL8+eefFR8fryNHjhgTmAfx9/dXXl6evSnZbDZr+/btOu+88yRJv/76qyIjI1kIvRGiQvRgISEhKiwslCRt2LBB11xzjSQpKCiIJ138xe+//65OnTpV2d+pUyf9/vvvBkTkmUwmk4qLi6vsLyoq4gv+f/5aI1Az+A4Sogfr37+/kpOTlZycrN27d9v7Enft2kVfz19069ZN8+fPr7J//vz56tatmwEReaY+ffooJSXFIfmVlZUpJSXF4anwgC9iHqIHe+mll/Tggw8qNzdXb7/9tsLCwiRJGRkZPKD0L+bOnavrr79eH3/8sXr16iWTyaStW7cqJyen2ocG+6q5c+fq//2//6eOHTuqT58+kqTNmzfLarWyFCB8Hn2IaDQOHjyohQsXKisrSzabTbGxsRo/frwiIyONDs2j/PLLL5o/f76+//57BQcHq2vXrvr73/+us88+2+jQPIKfn5/S0t
JksVgkSbfffrvmzZtnf9bmkSNHlJSURBNzI0RC9HCbN2/WokWLlJ2drTVr1qhNmzZ65ZVXFBMTQxPXX5w4cUI7duz
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluateAndShowAttention('Jesteś moim ojcem')"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:46:15.841005Z",
"iopub.status.busy": "2024-05-25T14:46:15.840433Z",
"iopub.status.idle": "2024-05-25T14:46:16.232003Z",
"shell.execute_reply": "2024-05-25T14:46:16.230365Z",
"shell.execute_reply.started": "2024-05-25T14:46:15.840973Z"
},
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input = on tez jest nauczycielem\n",
"output = he is a teacher too <EOS>\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:8: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
" ax.set_xticklabels([''] + input_sentence.split(' ') +\n",
"C:\\Users\\Michał\\AppData\\Local\\Temp\\ipykernel_10608\\712218569.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.\n",
" ax.set_yticklabels([''] + output_words)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAdIAAAHzCAYAAACUkJylAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABC0UlEQVR4nO3de1xU1f7/8feAAqYOeAUyRMxUDC2FOuElzdSyjmWXk6ZH06Ty0OUgx0y/dtKspMxMrTAtU+mUWekpKzJJ07yUJZGZVOYtiMALlngFZeb3h8f5NQEG7M1smXk9fexHzJq9Z3829ODDZ6291rY5nU6nAABAtfhZHQAAALUZiRQAAANIpAAAGEAiBQDAABIpAAAGkEgBADCARAoAgAEkUgAADCCRAgBgAIkUAAADSKQAABhAIgUAwIA6VgcAwHcdOXJEDofDrc1ut1sUDVA9VKQAPGr37t26/vrrVb9+fQUHB6tRo0Zq1KiRQkJC1KhRI6vDA6qMihSARw0dOlSS9Morryg0NFQ2m83iiABjbDyPFIAnNWjQQJmZmWrXrp3VoQCmoGsXgEdddtllys3NtToMwDR07QLwqJdfflmjR49WXl6eYmJiVLduXbf3O3XqZFFkQPWQSAF41P79+7Vz506NHDnS1Waz2eR0OmWz2VRaWmphdEDVMUYKwKM6dOig6OhojRs3rtybjSIjIy2KDKgeEikAj6pfv762bNmiNm3aWB0KYApuNgLgUb1799aWLVusDgMwDWOkADxqwIABGjNmjLZu3aqOHTuWudnohhtusCgyoHro2gXgUX5+FXeEcbMRaiMSKQAABjBGCsAyJ06csDoEwDASKQCPKi0t1WOPPaYWLVqoQYMG2rVrlyTp3//+t+bPn29xdEDVkUgBeNQTTzyhhQsXatq0aQoICHC1d+zYUS+//LKFkQHVQyIF4FFpaWmaN2+ehg4dKn9/f1d7p06d9P3331sYGVA9JFIAHpWXl1fuYgwOh0MnT560ICLAGBIpAI+6+OKLtW7dujLtb731ljp37mxBRIAxLMgAwKMmTZqkYcOGKS8vTw6HQ8uWLdMPP/ygtLQ0vf/++1aHB1QZ80gBeNxHH32kqVOnKjMzUw6HQ126dNEjjzyifv36WR0aUGUkUgAADGCMFAAAAxgjBVDjGjVqVOa5oxU5ePBgDUcDmItECqDGzZw50+oQgBrDGCkAAAYwRgrA43bu3KmHH35Yt99+u/bt2ydJWrFihbZt22ZxZEDVkUgBeNTatWvVsWNHbdq0ScuWLdORI0ckSd98840mTZpkcXRA1ZFIAXjU+PHj9fjjjysjI8Nt0fqrrrpKn332mYWRAdVDIgXgUVu3btVNN91Upr1Zs2YqLCy0ICLAGBIpAI8KCQlRfn5+mfasrCy1aNHCgogAY0ikADxqyJAheuihh1RQUCCbzSaHw6ENGzZo7NixGj58uNXhAVXG9BcAHnXy5EmNGDFCb7zxhpxOp+rUqaPS0lINGTJECxcudHtGKVAbkEgBWGLnzp3KysqSw+FQ586dddFFF1kdElAtJFIAAAxgiUAANS45OVmPPfaY6tevr+Tk5LPuO2PGDA9FBZiDRAqgxmVlZenkyZOurytS2YXtgXMJXbsAABjA9BcAHnXo0KFyH5V28OBBFRUVWRARYAyJFIBHDR48WG+88UaZ9jfffFODBw+2ICLAGLp2AXhU48aNtWHDBkVHR7u1f//99+rWrRvLBKLWoSIF4FHFxcU6depUmfaTJ0/q+PHjFkQEGEMiBeBRl112mebNm1em/cUXX1RsbKwFEQHGMP0FgEc98cQT6tOnj7Zs2aKrr75akrRq1Sp9+eWXWrlypcXRAVVHRQrAo7p166bPPvtMERERevPNN/Xee++pTZs2+uabb9SjRw+rw6txpaWl+uabb8rt3kbtxM1GAOBB77zzjm655RalpaVp6NChVocDE1CRAvCoXr16KS0tzWdvLF
q0aJGaNWumhQsXWh0KTEIiBeBRsbGxGjdunMLCwnTXXXfp888/tzokjzlw4IA+/PBDLVy4UGvXrtXPP/9sdUgwAYkUgEc988wzysvLU1pamvbv368rr7xSHTp00PTp07V3716rw6tRr7/+umJiYnTttdeqR48eSktLszokmIBECsDj/P39deONN+qdd95RXl6ehgwZon//+9+KiIjQwIEDtXr1aqtDrBGLFi3S8OHDJUl///vfSaRegpuNAFjmiy++0IIFC7R48WIFBwdrxIgRys/P12uvvaZ//OMfmj59utUhmubbb79VbGys8vLy1LRpUx05ckShoaFavXq1/vKXv1gdHgwgkQLwqH379unVV1/VggUL9OOPP2rAgAFKSEjQNddc43qM2scff6yBAwfqyJEjFkdrnrFjx2r79u1avny5q23o0KFq2LChXnzxRQsjg1EkUgAeFRAQoAsvvFB33nmnRowYoWbNmpXZp6ioSDfeeKM++eQTCyI0X2lpqS644AI999xzuvXWW13tH374oYYOHaqCggIFBARYGCGMIJEC8Kh169b5xMILv5efn6+XXnpJ48ePd0uYDodDU6dO1fDhw9WyZUsLI4QRJFIAHjVlyhR1795dvXv3dms/evSonnnmGT3yyCMWRQZUD4kUgEf5+fmpbt26SklJUXJysqt97969Ov/881VaWmphdJ7z008/6ejRo2rfvr38/JhAUZvx0wPgcWlpaUpJSdGIESNUUlJidTg1atGiRZo5c6Zb2913363WrVurY8eOiomJUW5urjXBwRQkUgAed9VVV+nzzz/XF198oV69enn1QgwvvviigoODXa9XrFihBQsWKC0tTV9++aVCQkL06KOPWhghjCKRAvCoM1NcLrzwQn3++eey2+2Ki4vT5s2bLY6sZmzfvl1xcXGu1++++65uuOEGDR06VF26dNHUqVO1atUqCyOEUSRSAB71+9sy7Ha70tPTddNNN2ngwIHWBVWDjh8/Lrvd7nq9ceNGXXnlla7XrVu3VkFBgRWhwSQ82BuARy1YsMCtq9PPz0+zZ89W586d9emnn1oYWc2IjIxUZmamIiMjdeDAAW3btk3du3d3vV9QUOD2/UDtw127AFCDUlJSNHv2bCUmJmr16tXav3+/vv32W9f7M2fO1Pvvv6+PP/7YwihhBBUpAI964IEH1KZNGz3wwANu7c8//7x27NhR5g7X2u6hhx7SsWPHtGzZMoWFhemtt95ye3/Dhg26/fbbLYoOZqAiBeBRLVq00PLlyxUbG+vW/tVXX+mGG27gGZ2odahIAXhUYWFhuWOCdrtdBw4csCAizzh+/LgyMjK0fft22Ww2XXTRRerbt6/q1atndWgwiEQKwKPatGmjFStW6L777nNr//DDD9W6dWuLoqpZy5cvV0JCQpk/FJo2bar58+drwIABFkUGM5BIAXhUcnKy7rvvPu3fv9+13u6qVav0zDPPeN34qHR6usutt96qG264Qf/6178UHR0tScrOztYzzzyjW2+9VWvWrFF8fLzFkaK6GCMF4HFz5szRE088oV9++UWS1KpVK02ePFnDhw+3ODLzXXfddYqIiNDcuXPLff+ee+5Rbm6u0tPTPRwZzEIiBSzQu3dvLVu2TCEhIW7tRUVFGjhwoFavXm1NYB62f/9+1atXTw0aNLA6lBrTqFEjffrpp+rYsWO573/zzTfq2bOnfv31Vw9HBrPQtQtYYM2aNeUu1n7ixAmtW7fOgoisUd5Dvb3NiRMn3FY2+qPg4GAVFxd7MCKYjUQKeNA333zj+jo7O9ttabjS0lKtWLFCLVq0sCI0j4mKinKtt1ueXbt2eTCamte2bVutXr1aI0eOLPf9VatWqU2bNh6OCmYikQIedOmll8pms8lms5V5sLUk1atXT88995wFkXlOUlKS2+uTJ08qKytLK1as0IMPPmhNUDVoxIgRGjt2rEJDQ3Xddde5vffBBx9o3LhxmjhxokXRwQ
yMkQIe9NNPP8npdKp169b64osv3Lo2AwIC1Lx5c/n7+1sYoXVeeOEFbd68WQsWLLA6FFM5HA4NGjRIS5cuVbt27dz
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluateAndShowAttention('On też jest nauczycielem')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T14:50:24.774464Z",
"iopub.status.busy": "2024-05-25T14:50:24.773506Z",
"iopub.status.idle": "2024-05-25T14:50:24.795895Z",
"shell.execute_reply": "2024-05-25T14:50:24.794979Z",
"shell.execute_reply.started": "2024-05-25T14:50:24.774430Z"
},
"trusted": true
},
"outputs": [],
"source": [
"torch.save(encoder.state_dict(), \"encoder.pt\")\n",
"torch.save(decoder.state_dict(), \"decoder.pt\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BLEU score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Jako że korzystaliśmy z okrojonej wersji zbioru danych, słownik nie zawiera wszystkich słów pojawiających się w przykładach, więc do ewaluacji wykorzystujemy część przykładów ze zbioru treningowego."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:28:54.297253Z",
"iopub.status.busy": "2024-05-25T15:28:54.296348Z",
"iopub.status.idle": "2024-05-25T15:28:59.041172Z",
"shell.execute_reply": "2024-05-25T15:28:59.040201Z",
"shell.execute_reply.started": "2024-05-25T15:28:54.297211Z"
},
"trusted": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>English</th>\n",
" <th>Polish</th>\n",
" <th>attribution</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13492</th>\n",
" <td>i m the last in line</td>\n",
" <td>jestem ostatni w kolejce</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #5...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14379</th>\n",
" <td>you are not a coward</td>\n",
" <td>nie jestes tchorzem</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #1...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31538</th>\n",
" <td>we re anxious about her health</td>\n",
" <td>martwimy sie o jej zdrowie</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #7...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33382</th>\n",
" <td>he s not at all afraid of snakes</td>\n",
" <td>on zupe nie nie boi sie wezy</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13031</th>\n",
" <td>he is a fast speaker</td>\n",
" <td>on szybko mowi</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #3...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" English Polish \\\n",
"13492 i m the last in line jestem ostatni w kolejce \n",
"14379 you are not a coward nie jestes tchorzem \n",
"31538 we re anxious about her health martwimy sie o jej zdrowie \n",
"33382 he s not at all afraid of snakes on zupe nie nie boi sie wezy \n",
"13031 he is a fast speaker on szybko mowi \n",
"\n",
" attribution \n",
"13492 CC-BY 2.0 (France) Attribution: tatoeba.org #5... \n",
"14379 CC-BY 2.0 (France) Attribution: tatoeba.org #1... \n",
"31538 CC-BY 2.0 (France) Attribution: tatoeba.org #7... \n",
"33382 CC-BY 2.0 (France) Attribution: tatoeba.org #2... \n",
"13031 CC-BY 2.0 (France) Attribution: tatoeba.org #3... "
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def filter_rows(row):\n",
"    \"\"\"Return True when a sentence pair is short enough and has an allowed English prefix.\n",
"\n",
"    Both the English and the Polish text must split (on single spaces) into\n",
"    fewer than MAX_LENGTH tokens, and the English text must start with one\n",
"    of `eng_prefixes`.\n",
"    \"\"\"\n",
"    if len(row[\"English\"].split(' ')) >= MAX_LENGTH:\n",
"        return False\n",
"    if len(row[\"Polish\"].split(' ')) >= MAX_LENGTH:\n",
"        return False\n",
"    return row[\"English\"].startswith(eng_prefixes)\n",
"# Load the tab-separated sentence pairs and normalize both text columns.\n",
"data_file = pd.read_csv(\"pol.txt\", sep='\\t', names=[\"English\",\"Polish\",\"attribution\"])\n",
"data_file[\"English\"] = data_file[\"English\"].apply(normalizeString)\n",
"data_file[\"Polish\"] = data_file[\"Polish\"].apply(normalizeString)\n",
"\n",
"# Boolean mask of the rows accepted by filter_rows.\n",
"filter_list = data_file.apply(filter_rows, axis=1)\n",
"\n",
"# Draw the evaluation subset with a fixed seed so that the sampled rows (and\n",
"# the metrics computed from them) are reproducible across kernel restarts.\n",
"# min() keeps the call valid when fewer rows than requested pass the filter.\n",
"TEST_SAMPLE_SIZE = 500\n",
"TEST_SAMPLE_SEED = 42\n",
"test_section = data_file[filter_list]\n",
"test_section = test_section.sample(\n",
"    n=min(TEST_SAMPLE_SIZE, len(test_section)), random_state=TEST_SAMPLE_SEED\n",
")\n",
"test_section.head()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:29:07.110816Z",
"iopub.status.busy": "2024-05-25T15:29:07.110416Z",
"iopub.status.idle": "2024-05-25T15:29:07.117378Z",
"shell.execute_reply": "2024-05-25T15:29:07.116136Z",
"shell.execute_reply.started": "2024-05-25T15:29:07.110786Z"
},
"trusted": true
},
"outputs": [],
"source": [
"# Reference translations as token lists (split on whitespace).\n",
"test_section[\"English_tokenized\"] = test_section[\"English\"].map(str.split)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:29:10.203540Z",
"iopub.status.busy": "2024-05-25T15:29:10.203170Z",
"iopub.status.idle": "2024-05-25T15:29:10.212993Z",
"shell.execute_reply": "2024-05-25T15:29:10.211937Z",
"shell.execute_reply.started": "2024-05-25T15:29:10.203511Z"
},
"trusted": true
},
"outputs": [
{
"data": {
"text/plain": [
"13492 [i, m, the, last, in, line]\n",
"14379 [you, are, not, a, coward]\n",
"31538 [we, re, anxious, about, her, health]\n",
"33382 [he, s, not, at, all, afraid, of, snakes]\n",
"13031 [he, is, a, fast, speaker]\n",
"Name: English_tokenized, dtype: object"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_section.head()[\"English_tokenized\"]"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:30:53.183937Z",
"iopub.status.busy": "2024-05-25T15:30:53.183117Z",
"iopub.status.idle": "2024-05-25T15:30:56.313012Z",
"shell.execute_reply": "2024-05-25T15:30:56.312202Z",
"shell.execute_reply.started": "2024-05-25T15:30:53.183902Z"
},
"trusted": true
},
"outputs": [],
"source": [
"# Decode every Polish sentence; tokenized=True yields a token list with the '<EOS>' marker removed.\n",
"test_section[\"English_translated\"] = test_section[\"Polish\"].map(\n",
"    lambda sentence: translate(sentence, tokenized=True)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:31:06.746471Z",
"iopub.status.busy": "2024-05-25T15:31:06.745381Z",
"iopub.status.idle": "2024-05-25T15:31:06.771839Z",
"shell.execute_reply": "2024-05-25T15:31:06.770679Z",
"shell.execute_reply.started": "2024-05-25T15:31:06.746417Z"
},
"trusted": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>English</th>\n",
" <th>Polish</th>\n",
" <th>attribution</th>\n",
" <th>English_tokenized</th>\n",
" <th>English_translated</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13492</th>\n",
" <td>i m the last in line</td>\n",
" <td>jestem ostatni w kolejce</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #5...</td>\n",
" <td>[i, m, the, last, in, line]</td>\n",
" <td>[i, m, the, last, in, line]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14379</th>\n",
" <td>you are not a coward</td>\n",
" <td>nie jestes tchorzem</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #1...</td>\n",
" <td>[you, are, not, a, coward]</td>\n",
" <td>[you, are, not, a, coward]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31538</th>\n",
" <td>we re anxious about her health</td>\n",
" <td>martwimy sie o jej zdrowie</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #7...</td>\n",
" <td>[we, re, anxious, about, her, health]</td>\n",
" <td>[we, are, worried, about, her, health]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33382</th>\n",
" <td>he s not at all afraid of snakes</td>\n",
" <td>on zupe nie nie boi sie wezy</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n",
" <td>[he, s, not, at, all, afraid, of, snakes]</td>\n",
" <td>[he, s, not, afraid, of, snakes, at, all, afraid]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13031</th>\n",
" <td>he is a fast speaker</td>\n",
" <td>on szybko mowi</td>\n",
" <td>CC-BY 2.0 (France) Attribution: tatoeba.org #3...</td>\n",
" <td>[he, is, a, fast, speaker]</td>\n",
" <td>[he, is, a, fast, speaker, young]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" English Polish \\\n",
"13492 i m the last in line jestem ostatni w kolejce \n",
"14379 you are not a coward nie jestes tchorzem \n",
"31538 we re anxious about her health martwimy sie o jej zdrowie \n",
"33382 he s not at all afraid of snakes on zupe nie nie boi sie wezy \n",
"13031 he is a fast speaker on szybko mowi \n",
"\n",
" attribution \\\n",
"13492 CC-BY 2.0 (France) Attribution: tatoeba.org #5... \n",
"14379 CC-BY 2.0 (France) Attribution: tatoeba.org #1... \n",
"31538 CC-BY 2.0 (France) Attribution: tatoeba.org #7... \n",
"33382 CC-BY 2.0 (France) Attribution: tatoeba.org #2... \n",
"13031 CC-BY 2.0 (France) Attribution: tatoeba.org #3... \n",
"\n",
" English_tokenized \\\n",
"13492 [i, m, the, last, in, line] \n",
"14379 [you, are, not, a, coward] \n",
"31538 [we, re, anxious, about, her, health] \n",
"33382 [he, s, not, at, all, afraid, of, snakes] \n",
"13031 [he, is, a, fast, speaker] \n",
"\n",
" English_translated \n",
"13492 [i, m, the, last, in, line] \n",
"14379 [you, are, not, a, coward] \n",
"31538 [we, are, worried, about, her, health] \n",
"33382 [he, s, not, afraid, of, snakes, at, all, afraid] \n",
"13031 [he, is, a, fast, speaker, young] "
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_section.head()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:43:29.442752Z",
"iopub.status.busy": "2024-05-25T15:43:29.441911Z",
"iopub.status.idle": "2024-05-25T15:43:29.447799Z",
"shell.execute_reply": "2024-05-25T15:43:29.446877Z",
"shell.execute_reply.started": "2024-05-25T15:43:29.442721Z"
},
"trusted": true
},
"outputs": [],
"source": [
"# Candidates: one token list per translated sentence.\n",
"candidate_corpus = test_section[\"English_translated\"].to_numpy()\n",
"x = list(candidate_corpus)\n",
"\n",
"# References: wrap each reference token list in its own list, giving one list\n",
"# of reference translations (here a single one) per candidate.\n",
"references_corpus = test_section[\"English_tokenized\"].to_numpy().tolist()\n",
"y = [[reference] for reference in references_corpus]"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:43:30.475080Z",
"iopub.status.busy": "2024-05-25T15:43:30.474463Z",
"iopub.status.idle": "2024-05-25T15:43:30.482690Z",
"shell.execute_reply": "2024-05-25T15:43:30.481686Z",
"shell.execute_reply.started": "2024-05-25T15:43:30.475039Z"
},
"trusted": true
},
"outputs": [
{
"data": {
"text/plain": [
"[[['i', 'm', 'the', 'last', 'in', 'line']],\n",
" [['you', 'are', 'not', 'a', 'coward']],\n",
" [['we', 're', 'anxious', 'about', 'her', 'health']],\n",
" [['he', 's', 'not', 'at', 'all', 'afraid', 'of', 'snakes']],\n",
" [['he', 'is', 'a', 'fast', 'speaker']]]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the first five reference entries to confirm the extra level of nesting.\n",
"y[0:5]"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"execution": {
"iopub.execute_input": "2024-05-25T15:43:36.654953Z",
"iopub.status.busy": "2024-05-25T15:43:36.654035Z",
"iopub.status.idle": "2024-05-25T15:43:36.916617Z",
"shell.execute_reply": "2024-05-25T15:43:36.915429Z",
"shell.execute_reply.started": "2024-05-25T15:43:36.654906Z"
},
"trusted": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Michał\\AppData\\Roaming\\Python\\Python310\\site-packages\\torchtext\\data\\__init__.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
]
},
{
"data": {
"text/plain": [
"0.8122377991676331"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchtext.data.metrics import bleu_score\n",
"\n",
"# BLEU of the translated token lists (x) against their wrapped references (y).\n",
"bleu_score(candidate_corpus=x, references_corpus=y)"
]
}
],
"metadata": {
"kaggle": {
"accelerator": "nvidiaTeslaT4",
"dataSources": [
{
"datasetId": 5082663,
"sourceId": 8513800,
"sourceType": "datasetVersion"
}
],
"dockerImageVersionId": 30699,
"isGpuEnabled": true,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}