dl_seq2seq/fin-to-en-seq2seq-v3.ipynb

1301 lines
119 KiB
Plaintext

{
"metadata": {
"kernelspec": {
"language": "python",
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.13",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kaggle": {
"accelerator": "nvidiaTeslaT4",
"dataSources": [
{
"sourceId": 8513800,
"sourceType": "datasetVersion",
"datasetId": 5082663
}
],
"dockerImageVersionId": 30699,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook",
"isGpuEnabled": true
}
},
"nbformat_minor": 4,
"nbformat": 4,
"cells": [
{
"cell_type": "markdown",
"source": [
"# Seq2Seq Fiński --> Angielski\n",
"https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"from io import open\n",
"import unicodedata\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"import numpy as np\n",
"from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
],
"metadata": {
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"execution": {
"iopub.status.busy": "2024-05-25T14:03:55.886451Z",
"iopub.execute_input": "2024-05-25T14:03:55.887266Z",
"iopub.status.idle": "2024-05-25T14:04:02.514594Z",
"shell.execute_reply.started": "2024-05-25T14:03:55.887232Z",
"shell.execute_reply": "2024-05-25T14:04:02.513697Z"
},
"trusted": true
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"source": [
"torch.cuda.device_count()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:04:09.403445Z",
"iopub.execute_input": "2024-05-25T14:04:09.403926Z",
"iopub.status.idle": "2024-05-25T14:04:09.434533Z",
"shell.execute_reply.started": "2024-05-25T14:04:09.403898Z",
"shell.execute_reply": "2024-05-25T14:04:09.433678Z"
},
"trusted": true
},
"execution_count": 2,
"outputs": [
{
"execution_count": 2,
"output_type": "execute_result",
"data": {
"text/plain": "2"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"### Konwersja słów na index"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"class Lang:\n",
" def __init__(self, name):\n",
" self.name = name\n",
" self.word2index = {}\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" def addSentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.addWord(word)\n",
"\n",
" def addWord(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" else:\n",
" self.word2count[word] += 1"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:04:14.014114Z",
"iopub.execute_input": "2024-05-25T14:04:14.014490Z",
"iopub.status.idle": "2024-05-25T14:04:14.024526Z",
"shell.execute_reply.started": "2024-05-25T14:04:14.014461Z",
"shell.execute_reply": "2024-05-25T14:04:14.023673Z"
},
"trusted": true
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Normalizacja tekstu"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Turn a Unicode string to plain ASCII, thanks to\n",
"# https://stackoverflow.com/a/518232/2809427\n",
"def unicodeToAscii(s):\n",
" return ''.join(\n",
" c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn'\n",
" )\n",
"\n",
"# Lowercase, trim, and remove non-letter characters\n",
"def normalizeString(s):\n",
"    \"\"\"Lower-case and strip `s`, drop diacritics, then keep only letters, '!' and '?'.\n",
"\n",
"    '.', '!' and '?' are first split off with a leading space; every run of\n",
"    other characters (including '.') is then collapsed into a single space.\n",
"    \"\"\"\n",
"    cleaned = s.lower().strip()\n",
"    cleaned = unicodeToAscii(cleaned)\n",
"    cleaned = re.sub(r\"([.!?])\", r\" \\1\", cleaned)\n",
"    cleaned = re.sub(r\"[^a-zA-Z!?]+\", r\" \", cleaned)\n",
"    return cleaned.strip()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:04:23.431898Z",
"iopub.execute_input": "2024-05-25T14:04:23.432285Z",
"iopub.status.idle": "2024-05-25T14:04:23.438688Z",
"shell.execute_reply.started": "2024-05-25T14:04:23.432256Z",
"shell.execute_reply": "2024-05-25T14:04:23.437569Z"
},
"trusted": true
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Wczytywanie danych (zmodyfikowane ze względu na ścieżkę w kaggle)"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Modified version: reads the single corpus file that was uploaded to Kaggle\n",
"def readLangs(reverse=False, path='/kaggle/input/anki-en-fin/fin.txt'):\n",
"    \"\"\"Read the tab-separated corpus and create the two (still empty) vocabularies.\n",
"\n",
"    Args:\n",
"        reverse: if True, swap every pair and make \"fin\" the input language.\n",
"        path: location of the corpus file; defaults to the Kaggle dataset path.\n",
"\n",
"    Returns:\n",
"        (input_lang, output_lang, pairs) where `pairs` is a list of lists of\n",
"        normalized sentences. No words are added to the Lang objects here.\n",
"    \"\"\"\n",
"    print(\"Reading lines...\")\n",
"    lang1 = \"en\"\n",
"    lang2 = \"fin\"\n",
"    # Read the file and split into lines; `with` closes the handle even on error\n",
"    with open(path, encoding='utf-8') as corpus:\n",
"        lines = corpus.read().strip().split('\\n')\n",
"\n",
"    # Split every line into pairs and normalize. Only the first two columns are\n",
"    # sentences; the trailing CC attribution column is dropped. Slicing with\n",
"    # [:2] (instead of [:-1]) keeps both sentences even if a line has no\n",
"    # attribution column.\n",
"    pairs = [[normalizeString(s) for s in l.split('\\t')[:2]] for l in lines]\n",
"\n",
"    # Reverse pairs, make Lang instances\n",
"    if reverse:\n",
"        pairs = [list(reversed(p)) for p in pairs]\n",
"        input_lang = Lang(lang2)\n",
"        output_lang = Lang(lang1)\n",
"    else:\n",
"        input_lang = Lang(lang1)\n",
"        output_lang = Lang(lang2)\n",
"\n",
"    return input_lang, output_lang, pairs"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:25.385674Z",
"iopub.execute_input": "2024-05-25T14:12:25.386029Z",
"iopub.status.idle": "2024-05-25T14:12:25.394103Z",
"shell.execute_reply.started": "2024-05-25T14:12:25.386002Z",
"shell.execute_reply": "2024-05-25T14:12:25.392925Z"
},
"trusted": true
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Ograniczenie do zdań krótszych niż 10 słów, formy I am / You are / He is etc. bez interpunkcji"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"MAX_LENGTH = 10\n",
"\n",
"eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s \",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
")\n",
"\n",
"def filterPair(p):\n",
" return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
" len(p[1].split(' ')) < MAX_LENGTH and \\\n",
" p[1].startswith(eng_prefixes)\n",
"\n",
"\n",
"def filterPairs(pairs):\n",
" return [pair for pair in pairs if filterPair(pair)]"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:29.729786Z",
"iopub.execute_input": "2024-05-25T14:12:29.730147Z",
"iopub.status.idle": "2024-05-25T14:12:29.737013Z",
"shell.execute_reply.started": "2024-05-25T14:12:29.730121Z",
"shell.execute_reply": "2024-05-25T14:12:29.735886Z"
},
"trusted": true
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def prepareData(reverse=False):\n",
"    \"\"\"Load the corpus, filter it and fill both vocabularies from the kept pairs.\n",
"\n",
"    Returns:\n",
"        (input_lang, output_lang, pairs) where `pairs` holds only the pairs\n",
"        accepted by filterPairs.\n",
"    \"\"\"\n",
"    input_lang, output_lang, pairs = readLangs(reverse)\n",
"    print(\"Read %s sentence pairs\" % len(pairs))\n",
"    pairs = filterPairs(pairs)\n",
"    print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
"    print(\"Counting words...\")\n",
"    for pair in pairs:\n",
"        source_sentence, target_sentence = pair[0], pair[1]\n",
"        input_lang.addSentence(source_sentence)\n",
"        output_lang.addSentence(target_sentence)\n",
"    print(\"Counted words:\")\n",
"    for lang in (input_lang, output_lang):\n",
"        print(lang.name, lang.n_words)\n",
"    return input_lang, output_lang, pairs\n",
"\n",
"input_lang, output_lang, pairs = prepareData(True)\n",
"print(random.choice(pairs))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:33.204103Z",
"iopub.execute_input": "2024-05-25T14:12:33.204776Z",
"iopub.status.idle": "2024-05-25T14:12:36.889693Z",
"shell.execute_reply.started": "2024-05-25T14:12:33.204744Z",
"shell.execute_reply": "2024-05-25T14:12:36.888700Z"
},
"trusted": true
},
"execution_count": 16,
"outputs": [
{
"name": "stdout",
"text": "Reading lines...\nRead 72258 sentence pairs\nTrimmed to 5005 sentence pairs\nCounting words...\nCounted words:\nfin 3686\nen 1971\n['mina odotan joulua innolla', 'i am looking forward to christmas']\n",
"output_type": "stream"
}
]
},
{
"cell_type": "markdown",
"source": [
"### Definicja modelu"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"class EncoderRNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size, dropout_p=0.1):\n",
" super(EncoderRNN, self).__init__()\n",
" self.hidden_size = hidden_size\n",
"\n",
" self.embedding = nn.Embedding(input_size, hidden_size)\n",
" self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
" self.dropout = nn.Dropout(dropout_p)\n",
"\n",
" def forward(self, input):\n",
" embedded = self.dropout(self.embedding(input))\n",
" output, hidden = self.gru(embedded)\n",
" return output, hidden"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:52.383787Z",
"iopub.execute_input": "2024-05-25T14:12:52.384131Z",
"iopub.status.idle": "2024-05-25T14:12:52.391196Z",
"shell.execute_reply.started": "2024-05-25T14:12:52.384104Z",
"shell.execute_reply": "2024-05-25T14:12:52.390316Z"
},
"trusted": true
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class DecoderRNN(nn.Module):\n",
"    \"\"\"GRU decoder without attention; always unrolls MAX_LENGTH steps.\"\"\"\n",
"\n",
"    def __init__(self, hidden_size, output_size):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding(output_size, hidden_size)\n",
"        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
"        self.out = nn.Linear(hidden_size, output_size)\n",
"\n",
"    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
"        \"\"\"Decode MAX_LENGTH steps, starting from SOS_token and the encoder's hidden state.\n",
"\n",
"        When `target_tensor` is given, column i of it is fed as the input that\n",
"        follows step i (teacher forcing); otherwise the top-scoring prediction\n",
"        of each step is fed back.\n",
"\n",
"        Returns:\n",
"            (log-probabilities of all steps, final hidden state, None)\n",
"        \"\"\"\n",
"        batch_size = encoder_outputs.size(0)\n",
"        decoder_input = torch.full((batch_size, 1), SOS_token, dtype=torch.long, device=device)\n",
"        decoder_hidden = encoder_hidden\n",
"        step_logits = []\n",
"        teacher_forcing = target_tensor is not None\n",
"\n",
"        for step in range(MAX_LENGTH):\n",
"            logits, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n",
"            step_logits.append(logits)\n",
"\n",
"            if teacher_forcing:\n",
"                # Teacher forcing: feed the target as the next input\n",
"                decoder_input = target_tensor[:, step].unsqueeze(1)\n",
"            else:\n",
"                # Without teacher forcing: feed back the model's own prediction,\n",
"                # detached from the autograd history\n",
"                decoder_input = logits.topk(1).indices.squeeze(-1).detach()\n",
"\n",
"        log_probs = F.log_softmax(torch.cat(step_logits, dim=1), dim=-1)\n",
"        return log_probs, decoder_hidden, None  # `None` keeps the 3-tuple shape the training loop unpacks\n",
"\n",
"    def forward_step(self, input, hidden):\n",
"        \"\"\"One decoding step: embedding -> ReLU -> GRU -> linear projection.\"\"\"\n",
"        activated = F.relu(self.embedding(input))\n",
"        gru_out, hidden = self.gru(activated, hidden)\n",
"        return self.out(gru_out), hidden"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:12:54.393953Z",
"iopub.execute_input": "2024-05-25T14:12:54.394808Z",
"iopub.status.idle": "2024-05-25T14:12:54.409000Z",
"shell.execute_reply.started": "2024-05-25T14:12:54.394765Z",
"shell.execute_reply": "2024-05-25T14:12:54.407827Z"
},
"trusted": true
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class BahdanauAttention(nn.Module):\n",
" def __init__(self, hidden_size):\n",
" super(BahdanauAttention, self).__init__()\n",
" self.Wa = nn.Linear(hidden_size, hidden_size)\n",
" self.Ua = nn.Linear(hidden_size, hidden_size)\n",
" self.Va = nn.Linear(hidden_size, 1)\n",
"\n",
" def forward(self, query, keys):\n",
" scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
" scores = scores.squeeze(2).unsqueeze(1)\n",
"\n",
" weights = F.softmax(scores, dim=-1)\n",
" context = torch.bmm(weights, keys)\n",
"\n",
" return context, weights\n",
"\n",
"class AttnDecoderRNN(nn.Module):\n",
"    \"\"\"GRU decoder that attends over the encoder outputs at every step.\"\"\"\n",
"\n",
"    def __init__(self, hidden_size, output_size, dropout_p=0.1):\n",
"        super().__init__()\n",
"        # Layers are created in the original order so seeded initialisation\n",
"        # and state_dict keys stay the same.\n",
"        self.embedding = nn.Embedding(output_size, hidden_size)\n",
"        self.attention = BahdanauAttention(hidden_size)\n",
"        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
"        self.out = nn.Linear(hidden_size, output_size)\n",
"        self.dropout = nn.Dropout(dropout_p)\n",
"\n",
"    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
"        \"\"\"Decode MAX_LENGTH steps, starting from SOS_token and the encoder's hidden state.\n",
"\n",
"        When `target_tensor` is given, column i of it is fed as the input that\n",
"        follows step i (teacher forcing); otherwise the top-scoring prediction\n",
"        of each step is fed back.\n",
"\n",
"        Returns:\n",
"            (log-probabilities of all steps, final hidden state, attention\n",
"            weights of all steps concatenated along dim 1)\n",
"        \"\"\"\n",
"        batch_size = encoder_outputs.size(0)\n",
"        decoder_input = torch.full((batch_size, 1), SOS_token, dtype=torch.long, device=device)\n",
"        decoder_hidden = encoder_hidden\n",
"        step_logits, step_attentions = [], []\n",
"        teacher_forcing = target_tensor is not None\n",
"\n",
"        for step in range(MAX_LENGTH):\n",
"            logits, decoder_hidden, attn_weights = self.forward_step(\n",
"                decoder_input, decoder_hidden, encoder_outputs\n",
"            )\n",
"            step_logits.append(logits)\n",
"            step_attentions.append(attn_weights)\n",
"\n",
"            if teacher_forcing:\n",
"                # Teacher forcing: feed the target as the next input\n",
"                decoder_input = target_tensor[:, step].unsqueeze(1)\n",
"            else:\n",
"                # Without teacher forcing: feed back the model's own prediction,\n",
"                # detached from the autograd history\n",
"                decoder_input = logits.topk(1).indices.squeeze(-1).detach()\n",
"\n",
"        log_probs = F.log_softmax(torch.cat(step_logits, dim=1), dim=-1)\n",
"        attentions = torch.cat(step_attentions, dim=1)\n",
"        return log_probs, decoder_hidden, attentions\n",
"\n",
"\n",
"    def forward_step(self, input, hidden, encoder_outputs):\n",
"        \"\"\"One decoding step: attend with the current hidden state, then run the GRU.\"\"\"\n",
"        embedded = self.dropout(self.embedding(input))\n",
"\n",
"        # Swap the first two axes of the hidden state to form the attention query\n",
"        query = hidden.permute(1, 0, 2)\n",
"        context, attn_weights = self.attention(query, encoder_outputs)\n",
"        gru_input = torch.cat((embedded, context), dim=2)\n",
"\n",
"        gru_out, hidden = self.gru(gru_input, hidden)\n",
"        return self.out(gru_out), hidden, attn_weights"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:13:00.670299Z",
"iopub.execute_input": "2024-05-25T14:13:00.670758Z",
"iopub.status.idle": "2024-05-25T14:13:00.687695Z",
"shell.execute_reply.started": "2024-05-25T14:13:00.670720Z",
"shell.execute_reply": "2024-05-25T14:13:00.686610Z"
},
"trusted": true
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def indexesFromSentence(lang, sentence):\n",
" return [lang.word2index[word] for word in sentence.split(' ')]\n",
"\n",
"def tensorFromSentence(lang, sentence):\n",
"    \"\"\"Encode `sentence` as a single-row long tensor on `device`, terminated by EOS_token.\"\"\"\n",
"    indexes = indexesFromSentence(lang, sentence) + [EOS_token]\n",
"    encoded = torch.tensor(indexes, dtype=torch.long, device=device)\n",
"    return encoded.view(1, -1)\n",
"\n",
"def tensorsFromPair(pair):\n",
"    \"\"\"Encode pair[0] with the module-level input_lang and pair[1] with output_lang.\"\"\"\n",
"    source_sentence, target_sentence = pair[0], pair[1]\n",
"    return (\n",
"        tensorFromSentence(input_lang, source_sentence),\n",
"        tensorFromSentence(output_lang, target_sentence),\n",
"    )\n",
"\n",
"def get_dataloader(batch_size):\n",
"    \"\"\"Prepare the reversed (Finnish -> English) data and wrap it in a shuffling DataLoader.\n",
"\n",
"    Every sentence becomes one row of MAX_LENGTH ids: its word indices, then\n",
"    EOS_token, then the zeros left over from the initial np.zeros fill.\n",
"    NOTE: that zero padding has the same id as SOS_token.\n",
"\n",
"    Returns:\n",
"        (input_lang, output_lang, train_dataloader)\n",
"    \"\"\"\n",
"    input_lang, output_lang, pairs = prepareData(True)\n",
"\n",
"    shape = (len(pairs), MAX_LENGTH)\n",
"    input_ids = np.zeros(shape, dtype=np.int32)\n",
"    target_ids = np.zeros(shape, dtype=np.int32)\n",
"\n",
"    for row, (inp, tgt) in enumerate(pairs):\n",
"        encoded_inp = indexesFromSentence(input_lang, inp) + [EOS_token]\n",
"        encoded_tgt = indexesFromSentence(output_lang, tgt) + [EOS_token]\n",
"        input_ids[row, :len(encoded_inp)] = encoded_inp\n",
"        target_ids[row, :len(encoded_tgt)] = encoded_tgt\n",
"\n",
"    train_data = TensorDataset(\n",
"        torch.LongTensor(input_ids).to(device),\n",
"        torch.LongTensor(target_ids).to(device),\n",
"    )\n",
"    train_dataloader = DataLoader(\n",
"        train_data, sampler=RandomSampler(train_data), batch_size=batch_size\n",
"    )\n",
"    return input_lang, output_lang, train_dataloader"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:22:08.183866Z",
"iopub.execute_input": "2024-05-25T14:22:08.184711Z",
"iopub.status.idle": "2024-05-25T14:22:08.194870Z",
"shell.execute_reply.started": "2024-05-25T14:22:08.184675Z",
"shell.execute_reply": "2024-05-25T14:22:08.193965Z"
},
"trusted": true
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train_epoch(dataloader, encoder, decoder, encoder_optimizer,\n",
" decoder_optimizer, criterion):\n",
"\n",
" total_loss = 0\n",
" for data in dataloader:\n",
" input_tensor, target_tensor = data\n",
"\n",
" encoder_optimizer.zero_grad()\n",
" decoder_optimizer.zero_grad()\n",
"\n",
" encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
" decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)\n",
"\n",
" loss = criterion(\n",
" decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
" target_tensor.view(-1)\n",
" )\n",
" loss.backward()\n",
"\n",
" encoder_optimizer.step()\n",
" decoder_optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
"\n",
" return total_loss / len(dataloader)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:16:38.894580Z",
"iopub.execute_input": "2024-05-25T14:16:38.895410Z",
"iopub.status.idle": "2024-05-25T14:16:38.902142Z",
"shell.execute_reply.started": "2024-05-25T14:16:38.895382Z",
"shell.execute_reply": "2024-05-25T14:16:38.900953Z"
},
"trusted": true
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import time\n",
"import math\n",
"\n",
"def asMinutes(s):\n",
" m = math.floor(s / 60)\n",
" s -= m * 60\n",
" return '%dm %ds' % (m, s)\n",
"\n",
"def timeSince(since, percent):\n",
"    \"\"\"Return 'elapsed (- remaining)', extrapolating the total time from the completed fraction `percent`.\"\"\"\n",
"    elapsed = time.time() - since\n",
"    estimated_total = elapsed / (percent)\n",
"    remaining = estimated_total - elapsed\n",
"    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:16:43.069584Z",
"iopub.execute_input": "2024-05-25T14:16:43.069953Z",
"iopub.status.idle": "2024-05-25T14:16:43.075972Z",
"shell.execute_reply.started": "2024-05-25T14:16:43.069926Z",
"shell.execute_reply": "2024-05-25T14:16:43.075033Z"
},
"trusted": true
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,\n",
"          print_every=100, plot_every=100):\n",
"    \"\"\"Train both models for `n_epochs` with separate Adam optimizers and NLLLoss.\n",
"\n",
"    Every `print_every` epochs the mean epoch loss of that window is printed\n",
"    with timing information; every `plot_every` epochs the window mean is\n",
"    stored, and all stored means are plotted once training ends.\n",
"    \"\"\"\n",
"    started_at = time.time()\n",
"    averaged_for_plot = []\n",
"    loss_since_print = 0  # Reset every print_every\n",
"    loss_since_plot = 0  # Reset every plot_every\n",
"\n",
"    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n",
"    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n",
"    criterion = nn.NLLLoss()\n",
"\n",
"    for epoch in range(1, n_epochs + 1):\n",
"        epoch_loss = train_epoch(train_dataloader, encoder, decoder,\n",
"                                 encoder_optimizer, decoder_optimizer, criterion)\n",
"        loss_since_print += epoch_loss\n",
"        loss_since_plot += epoch_loss\n",
"\n",
"        if epoch % print_every == 0:\n",
"            window_mean = loss_since_print / print_every\n",
"            loss_since_print = 0\n",
"            progress = epoch / n_epochs\n",
"            print('%s (%d %d%%) %.4f' % (timeSince(started_at, progress),\n",
"                                         epoch, progress * 100, window_mean))\n",
"\n",
"        if epoch % plot_every == 0:\n",
"            averaged_for_plot.append(loss_since_plot / plot_every)\n",
"            loss_since_plot = 0\n",
"\n",
"    showPlot(averaged_for_plot)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:20:58.574148Z",
"iopub.execute_input": "2024-05-25T14:20:58.574520Z",
"iopub.status.idle": "2024-05-25T14:20:58.583203Z",
"shell.execute_reply.started": "2024-05-25T14:20:58.574492Z",
"shell.execute_reply": "2024-05-25T14:20:58.582230Z"
},
"trusted": true
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"plt.switch_backend('agg')\n",
"import matplotlib.ticker as ticker\n",
"import numpy as np\n",
"%matplotlib inline\n",
"\n",
"def showPlot(points):\n",
"    \"\"\"Plot the values in `points` as a line, with y-axis ticks every 0.2.\n",
"\n",
"    Args:\n",
"        points: sequence of numbers to plot against their position.\n",
"    \"\"\"\n",
"    # One subplots() call creates both figure and axes. The plt.figure() call\n",
"    # that used to precede it only left an empty extra figure on every call.\n",
"    fig, ax = plt.subplots()\n",
"    # this locator puts ticks at regular intervals\n",
"    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.2))\n",
"    ax.plot(points)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:21:00.586018Z",
"iopub.execute_input": "2024-05-25T14:21:00.586719Z",
"iopub.status.idle": "2024-05-25T14:21:00.592633Z",
"shell.execute_reply.started": "2024-05-25T14:21:00.586683Z",
"shell.execute_reply": "2024-05-25T14:21:00.591636Z"
},
"trusted": true
},
"execution_count": 25,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Ewaluacja"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def evaluate(encoder, decoder, sentence, input_lang, output_lang):\n",
"    \"\"\"Greedily translate one already-normalized sentence.\n",
"\n",
"    Args:\n",
"        encoder, decoder: the trained modules.\n",
"        sentence: space-separated source sentence.\n",
"        input_lang, output_lang: vocabularies for encoding and decoding.\n",
"\n",
"    Returns:\n",
"        (decoded_words, decoder_attn); decoded_words ends with '<EOS>' when\n",
"        the decoder produced EOS_token, and decoder_attn is whatever the\n",
"        decoder returned as its third value.\n",
"\n",
"    Raises:\n",
"        KeyError: if a word of `sentence` is missing from `input_lang`.\n",
"    \"\"\"\n",
"    # Inference has to run in eval mode: with dropout still active the same\n",
"    # sentence can decode differently on every call. The previous modes are\n",
"    # remembered and restored so a later training run is unaffected.\n",
"    previous_modes = [(module, module.training) for module in (encoder, decoder)]\n",
"    for module, _ in previous_modes:\n",
"        module.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            input_tensor = tensorFromSentence(input_lang, sentence)\n",
"\n",
"            encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
"            decoder_outputs, _, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n",
"\n",
"            _, topi = decoder_outputs.topk(1)\n",
"            decoded_ids = topi.squeeze()\n",
"\n",
"            decoded_words = []\n",
"            for idx in decoded_ids:\n",
"                if idx.item() == EOS_token:\n",
"                    decoded_words.append('<EOS>')\n",
"                    break\n",
"                decoded_words.append(output_lang.index2word[idx.item()])\n",
"    finally:\n",
"        for module, was_training in previous_modes:\n",
"            module.train(was_training)\n",
"    return decoded_words, decoder_attn"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:21:01.858691Z",
"iopub.execute_input": "2024-05-25T14:21:01.859612Z",
"iopub.status.idle": "2024-05-25T14:21:01.866857Z",
"shell.execute_reply.started": "2024-05-25T14:21:01.859574Z",
"shell.execute_reply": "2024-05-25T14:21:01.865732Z"
},
"trusted": true
},
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def evaluateRandomly(encoder, decoder, n=10):\n",
"    \"\"\"Print `n` random pairs from the module-level `pairs` next to the model's translation.\"\"\"\n",
"    for _ in range(n):\n",
"        sample = random.choice(pairs)\n",
"        print('Input sentence: ', sample[0])\n",
"        print('Target (true) translation:' , sample[1])\n",
"        predicted_words, _attn = evaluate(encoder, decoder, sample[0], input_lang, output_lang)\n",
"        print('Output sentence: ', ' '.join(predicted_words))\n",
"        print('')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:39:52.474985Z",
"iopub.execute_input": "2024-05-25T14:39:52.475327Z",
"iopub.status.idle": "2024-05-25T14:39:52.481801Z",
"shell.execute_reply.started": "2024-05-25T14:39:52.475304Z",
"shell.execute_reply": "2024-05-25T14:39:52.480957Z"
},
"trusted": true
},
"execution_count": 36,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Wykorzystanie zdefiniowanych wyżej funkcji"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### Trenowanie modelu"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"hidden_size = 128\n",
"batch_size = 32\n",
"\n",
"# Seed every RNG used below (batch shuffling, weight initialisation, dropout)\n",
"# so that Restart & Run All repeats the same training run as far as the\n",
"# backend allows.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"input_lang, output_lang, train_dataloader = get_dataloader(batch_size)\n",
"\n",
"encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)\n",
"decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)\n",
"\n",
"train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:22:39.754370Z",
"iopub.execute_input": "2024-05-25T14:22:39.754740Z",
"iopub.status.idle": "2024-05-25T14:26:53.707012Z",
"shell.execute_reply.started": "2024-05-25T14:22:39.754714Z",
"shell.execute_reply": "2024-05-25T14:26:53.705969Z"
},
"trusted": true
},
"execution_count": 32,
"outputs": [
{
"name": "stdout",
"text": "Reading lines...\nRead 72258 sentence pairs\nTrimmed to 5005 sentence pairs\nCounting words...\nCounted words:\nfin 3686\nen 1971\n0m 21s (- 5m 21s) (5 6%) 1.9364\n0m 36s (- 4m 17s) (10 12%) 1.0355\n0m 51s (- 3m 45s) (15 18%) 0.6313\n1m 7s (- 3m 21s) (20 25%) 0.3787\n1m 22s (- 3m 1s) (25 31%) 0.2243\n1m 37s (- 2m 42s) (30 37%) 0.1371\n1m 52s (- 2m 25s) (35 43%) 0.0903\n2m 7s (- 2m 7s) (40 50%) 0.0668\n2m 23s (- 1m 51s) (45 56%) 0.0538\n2m 38s (- 1m 34s) (50 62%) 0.0471\n2m 53s (- 1m 18s) (55 68%) 0.0410\n3m 8s (- 1m 2s) (60 75%) 0.0381\n3m 23s (- 0m 47s) (65 81%) 0.0343\n3m 38s (- 0m 31s) (70 87%) 0.0342\n3m 54s (- 0m 15s) (75 93%) 0.0322\n4m 9s (- 0m 0s) (80 100%) 0.0307\n",
"output_type": "stream"
}
]
},
{
"cell_type": "code",
"source": [
"evaluateRandomly(encoder, decoder)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:48:23.131435Z",
"iopub.execute_input": "2024-05-25T15:48:23.131948Z",
"iopub.status.idle": "2024-05-25T15:48:23.213007Z",
"shell.execute_reply.started": "2024-05-25T15:48:23.131911Z",
"shell.execute_reply": "2024-05-25T15:48:23.211856Z"
},
"trusted": true
},
"execution_count": 121,
"outputs": [
{
"name": "stdout",
"text": "Input sentence: olen hyvin hyvin vihainen\nTarget (true) translation: i m very very angry\nOutput sentence: i am very angry today <EOS>\n\nInput sentence: han valehtelee\nTarget (true) translation: he s lying\nOutput sentence: he is telling a lie <EOS>\n\nInput sentence: olen myohassa\nTarget (true) translation: i m late\nOutput sentence: i m late <EOS>\n\nInput sentence: han on linja autonkuljettaja\nTarget (true) translation: he is a bus driver\nOutput sentence: he is a bus driver <EOS>\n\nInput sentence: mukava tavata sinut taas\nTarget (true) translation: i m glad to see you again\nOutput sentence: i m glad to see you again <EOS>\n\nInput sentence: olet kuumeessa\nTarget (true) translation: you re running a fever\nOutput sentence: you re so predictable <EOS>\n\nInput sentence: anteeksi mutta unohdin tehda laksyt\nTarget (true) translation: i m sorry i forgot to do my homework\nOutput sentence: i m sorry i forgot to do my homework <EOS>\n\nInput sentence: mina olen tyoton\nTarget (true) translation: i m unemployed\nOutput sentence: i m unemployed <EOS>\n\nInput sentence: olen taynna\nTarget (true) translation: i am full\nOutput sentence: i am full of french <EOS>\n\nInput sentence: ma kuolen nalkaan !\nTarget (true) translation: i m dying of hunger\nOutput sentence: i m dying of hunger <EOS>\n\n",
"output_type": "stream"
}
]
},
{
"cell_type": "code",
"source": [
"def showAttention(input_sentence, output_words, attentions):\n",
"    \"\"\"Draw the attention weights as a heat map.\n",
"\n",
"    Columns are labelled with the input words plus '<EOS>', rows with the\n",
"    output words.\n",
"\n",
"    Args:\n",
"        input_sentence: space-separated source sentence.\n",
"        output_words: decoded tokens used as row labels.\n",
"        attentions: tensor of attention weights passed to matshow.\n",
"    \"\"\"\n",
"    fig = plt.figure()\n",
"    ax = fig.add_subplot(111)\n",
"    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')\n",
"    fig.colorbar(cax)\n",
"\n",
"    # Fix the tick positions first, then label them. Labelling before the\n",
"    # positions were fixed made matplotlib warn \"FixedFormatter should only be\n",
"    # used together with FixedLocator\" and needed a dummy '' label to line the\n",
"    # words up with the cells.\n",
"    x_labels = input_sentence.split(' ') + ['<EOS>']\n",
"    ax.set_xticks(list(range(len(x_labels))))\n",
"    ax.set_xticklabels(x_labels, rotation=90)\n",
"    ax.set_yticks(list(range(len(output_words))))\n",
"    ax.set_yticklabels(output_words)\n",
"\n",
"    plt.show()\n",
"\n",
"\n",
"def evaluateAndShowAttention(input_sentence):\n",
"    \"\"\"Normalize and translate `input_sentence` with the module-level model, print both sentences and plot the attention.\"\"\"\n",
"    normalized = normalizeString(input_sentence)\n",
"    output_words, attentions = evaluate(encoder, decoder, normalized, input_lang, output_lang)\n",
"    print('input =', normalized)\n",
"    print('output =', ' '.join(output_words))\n",
"    # Keep only the attention rows of the words that were actually emitted\n",
"    used_rows = attentions[0, :len(output_words), :]\n",
"    showAttention(normalized, output_words, used_rows)\n",
"\n",
"def translate(input_sentence, tokenized=False):\n",
"    \"\"\"Translate `input_sentence` with the module-level model.\n",
"\n",
"    Returns:\n",
"        With tokenized=True, the list of output words with \"<EOS>\" removed;\n",
"        otherwise the words joined by spaces (any \"<EOS>\" is kept).\n",
"    \"\"\"\n",
"    normalized = normalizeString(input_sentence)\n",
"    output_words, _ = evaluate(encoder, decoder, normalized, input_lang, output_lang)\n",
"    if not tokenized:\n",
"        return ' '.join(output_words)\n",
"    if \"<EOS>\" in output_words:\n",
"        output_words.remove(\"<EOS>\")\n",
"    return output_words"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:30:41.253325Z",
"iopub.execute_input": "2024-05-25T15:30:41.253734Z",
"iopub.status.idle": "2024-05-25T15:30:41.264515Z",
"shell.execute_reply.started": "2024-05-25T15:30:41.253703Z",
"shell.execute_reply": "2024-05-25T15:30:41.263376Z"
},
"trusted": true
},
"execution_count": 99,
"outputs": []
},
{
"cell_type": "code",
"source": [
"translate(\"Meillä on nälkä\", tokenized=True)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:58:53.639252Z",
"iopub.execute_input": "2024-05-25T14:58:53.639963Z",
"iopub.status.idle": "2024-05-25T14:58:53.654186Z",
"shell.execute_reply.started": "2024-05-25T14:58:53.639932Z",
"shell.execute_reply": "2024-05-25T14:58:53.653028Z"
},
"trusted": true
},
"execution_count": 76,
"outputs": [
{
"execution_count": 76,
"output_type": "execute_result",
"data": {
"text/plain": "['we', 'are', 'hungry']"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"evaluateAndShowAttention('Olet liian naivi')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:46:05.023218Z",
"iopub.execute_input": "2024-05-25T14:46:05.024277Z",
"iopub.status.idle": "2024-05-25T14:46:05.426793Z",
"shell.execute_reply.started": "2024-05-25T14:46:05.024227Z",
"shell.execute_reply": "2024-05-25T14:46:05.424993Z"
},
"trusted": true
},
"execution_count": 44,
"outputs": [
{
"name": "stdout",
"text": "input = olet liian naivi\noutput = you re too naive <EOS>\n",
"output_type": "stream"
},
{
"name": "stderr",
"text": "/tmp/ipykernel_34/2052950992.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_xticklabels([''] + input_sentence.split(' ') +\n/tmp/ipykernel_34/2052950992.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_yticklabels([''] + output_words)\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>",
"image/png": ""
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"evaluateAndShowAttention('Olen todella pahoillani')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:46:10.394969Z",
"iopub.execute_input": "2024-05-25T14:46:10.395671Z",
"iopub.status.idle": "2024-05-25T14:46:10.793392Z",
"shell.execute_reply.started": "2024-05-25T14:46:10.395630Z",
"shell.execute_reply": "2024-05-25T14:46:10.791940Z"
},
"trusted": true
},
"execution_count": 45,
"outputs": [
{
"name": "stdout",
"text": "input = olen todella pahoillani\noutput = i am truly sorry <EOS>\n",
"output_type": "stream"
},
{
"name": "stderr",
"text": "/tmp/ipykernel_34/2052950992.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_xticklabels([''] + input_sentence.split(' ') +\n/tmp/ipykernel_34/2052950992.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_yticklabels([''] + output_words)\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>",
"image/png": ""
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"evaluateAndShowAttention('Olet minun isäni')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:46:13.190403Z",
"iopub.execute_input": "2024-05-25T14:46:13.191025Z",
"iopub.status.idle": "2024-05-25T14:46:13.613486Z",
"shell.execute_reply.started": "2024-05-25T14:46:13.190997Z",
"shell.execute_reply": "2024-05-25T14:46:13.612143Z"
},
"trusted": true
},
"execution_count": 46,
"outputs": [
{
"name": "stdout",
"text": "input = olet minun isani\noutput = you re my father <EOS>\n",
"output_type": "stream"
},
{
"name": "stderr",
"text": "/tmp/ipykernel_34/2052950992.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_xticklabels([''] + input_sentence.split(' ') +\n/tmp/ipykernel_34/2052950992.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_yticklabels([''] + output_words)\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>",
"image/png": ""
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"evaluateAndShowAttention('Hän on opettaja')"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:46:15.840433Z",
"iopub.execute_input": "2024-05-25T14:46:15.841005Z",
"iopub.status.idle": "2024-05-25T14:46:16.232003Z",
"shell.execute_reply.started": "2024-05-25T14:46:15.840973Z",
"shell.execute_reply": "2024-05-25T14:46:16.230365Z"
},
"trusted": true
},
"execution_count": 47,
"outputs": [
{
"name": "stdout",
"text": "input = han on opettaja\noutput = he is a teacher <EOS>\n",
"output_type": "stream"
},
{
"name": "stderr",
"text": "/tmp/ipykernel_34/2052950992.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_xticklabels([''] + input_sentence.split(' ') +\n/tmp/ipykernel_34/2052950992.py:10: UserWarning: FixedFormatter should only be used together with FixedLocator\n ax.set_yticklabels([''] + output_words)\n",
"output_type": "stream"
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>",
"image/png": ""
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Store the state_dict of each network in its own file\n",
"torch.save(obj=encoder.state_dict(), f=\"encoder.pt\")\n",
"torch.save(obj=decoder.state_dict(), f=\"decoder.pt\")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T14:50:24.773506Z",
"iopub.execute_input": "2024-05-25T14:50:24.774464Z",
"iopub.status.idle": "2024-05-25T14:50:24.795895Z",
"shell.execute_reply.started": "2024-05-25T14:50:24.774430Z",
"shell.execute_reply": "2024-05-25T14:50:24.794979Z"
},
"trusted": true
},
"execution_count": 48,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# BLEU score"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Jako że korzystaliśmy z okrojonej wersji zbioru danych, słownik nie zawiera wszystkich słów pojawiających się w przykładach, więc do ewaluacji wykorzystujemy część przykładów z treningu."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"\n",
"# Size and seed of the evaluation sample; the fixed seed keeps the sampled rows\n",
"# (and therefore the BLEU score computed from them) the same on every run.\n",
"BLEU_SAMPLE_SIZE = 500\n",
"BLEU_SAMPLE_SEED = 42\n",
"\n",
"\n",
"def filter_rows(row):\n",
"    \"\"\"Return True if a sentence pair passes the length and prefix filter.\n",
"\n",
"    Both the English and the Finnish sentence must have fewer than MAX_LENGTH\n",
"    space-separated tokens, and the English sentence must start with one of\n",
"    eng_prefixes. MAX_LENGTH and eng_prefixes are defined outside this cell.\n",
"    \"\"\"\n",
"    return (\n",
"        len(row[\"English\"].split(' ')) < MAX_LENGTH\n",
"        and len(row[\"Finnish\"].split(' ')) < MAX_LENGTH\n",
"        and row[\"English\"].startswith(eng_prefixes)\n",
"    )\n",
"\n",
"\n",
"data_file = pd.read_csv(\"/kaggle/input/anki-en-fin/fin.txt\", sep='\\t', names=[\"English\",\"Finnish\",\"attribution\"])\n",
"data_file[\"English\"] = data_file[\"English\"].apply(normalizeString)\n",
"data_file[\"Finnish\"] = data_file[\"Finnish\"].apply(normalizeString)\n",
"\n",
"filter_list = data_file.apply(filter_rows, axis=1)\n",
"\n",
"test_section = data_file[filter_list]\n",
"# Draw at most BLEU_SAMPLE_SIZE rows in shuffled order; random_state makes the draw repeatable.\n",
"test_section = test_section.sample(n=min(BLEU_SAMPLE_SIZE, len(test_section)), random_state=BLEU_SAMPLE_SEED)\n",
"test_section.head()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:28:54.296348Z",
"iopub.execute_input": "2024-05-25T15:28:54.297253Z",
"iopub.status.idle": "2024-05-25T15:28:59.041172Z",
"shell.execute_reply.started": "2024-05-25T15:28:54.297211Z",
"shell.execute_reply": "2024-05-25T15:28:59.040201Z"
},
"trusted": true
},
"execution_count": 95,
"outputs": [
{
"execution_count": 95,
"output_type": "execute_result",
"data": {
"text/plain": " English ... attribution\n38027 i m very serious about this ... CC-BY 2.0 (France) Attribution: tatoeba.org #2...\n3803 i m not tired ... CC-BY 2.0 (France) Attribution: tatoeba.org #1...\n26924 i m not married either ... CC-BY 2.0 (France) Attribution: tatoeba.org #6...\n32009 he s sleeping like a baby ... CC-BY 2.0 (France) Attribution: tatoeba.org #2...\n21339 i m joking of course ... CC-BY 2.0 (France) Attribution: tatoeba.org #2...\n\n[5 rows x 3 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>English</th>\n <th>Finnish</th>\n <th>attribution</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>38027</th>\n <td>i m very serious about this</td>\n <td>olen hyvin tosissani tasta</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n </tr>\n <tr>\n <th>3803</th>\n <td>i m not tired</td>\n <td>en ole vasynyt</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #1...</td>\n </tr>\n <tr>\n <th>26924</th>\n <td>i m not married either</td>\n <td>minakaan en ole naimisissa</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #6...</td>\n </tr>\n <tr>\n <th>32009</th>\n <td>he s sleeping like a baby</td>\n <td>han nukkuu kuin pikkuvauva</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n </tr>\n <tr>\n <th>21339</th>\n <td>i m joking of course</td>\n <td>se oli vitsi tietenkin</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Split every English reference sentence on whitespace into a list of tokens\n",
"test_section[\"English_tokenized\"] = test_section[\"English\"].map(str.split)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:29:07.110416Z",
"iopub.execute_input": "2024-05-25T15:29:07.110816Z",
"iopub.status.idle": "2024-05-25T15:29:07.117378Z",
"shell.execute_reply.started": "2024-05-25T15:29:07.110786Z",
"shell.execute_reply": "2024-05-25T15:29:07.116136Z"
},
"trusted": true
},
"execution_count": 96,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test_section.head()[\"English_tokenized\"]"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:29:10.203170Z",
"iopub.execute_input": "2024-05-25T15:29:10.203540Z",
"iopub.status.idle": "2024-05-25T15:29:10.212993Z",
"shell.execute_reply.started": "2024-05-25T15:29:10.203511Z",
"shell.execute_reply": "2024-05-25T15:29:10.211937Z"
},
"trusted": true
},
"execution_count": 97,
"outputs": [
{
"execution_count": 97,
"output_type": "execute_result",
"data": {
"text/plain": "38027 [i, m, very, serious, about, this]\n3803 [i, m, not, tired]\n26924 [i, m, not, married, either]\n32009 [he, s, sleeping, like, a, baby]\n21339 [i, m, joking, of, course]\nName: English_tokenized, dtype: object"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# Run translate() on every Finnish sentence, requesting tokenized output\n",
"test_section[\"English_translated\"] = test_section[\"Finnish\"].map(\n",
"    lambda sentence: translate(sentence, tokenized=True)\n",
")"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:30:53.183117Z",
"iopub.execute_input": "2024-05-25T15:30:53.183937Z",
"iopub.status.idle": "2024-05-25T15:30:56.313012Z",
"shell.execute_reply.started": "2024-05-25T15:30:53.183902Z",
"shell.execute_reply": "2024-05-25T15:30:56.312202Z"
},
"trusted": true
},
"execution_count": 100,
"outputs": []
},
{
"cell_type": "code",
"source": [
"test_section.head()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:31:06.745381Z",
"iopub.execute_input": "2024-05-25T15:31:06.746471Z",
"iopub.status.idle": "2024-05-25T15:31:06.771839Z",
"shell.execute_reply.started": "2024-05-25T15:31:06.746417Z",
"shell.execute_reply": "2024-05-25T15:31:06.770679Z"
},
"trusted": true
},
"execution_count": 101,
"outputs": [
{
"execution_count": 101,
"output_type": "execute_result",
"data": {
"text/plain": " English ... English_translated\n38027 i m very serious about this ... [i, m, in, french]\n3803 i m not tired ... [i, not, tired]\n26924 i m not married either ... [i, m, not, married, either]\n32009 he s sleeping like a baby ... [he, is, as, a, pianist]\n21339 i m joking of course ... [i, m, joking, of, course]\n\n[5 rows x 5 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>English</th>\n <th>Finnish</th>\n <th>attribution</th>\n <th>English_tokenized</th>\n <th>English_translated</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>38027</th>\n <td>i m very serious about this</td>\n <td>olen hyvin tosissani tasta</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n <td>[i, m, very, serious, about, this]</td>\n <td>[i, m, in, french]</td>\n </tr>\n <tr>\n <th>3803</th>\n <td>i m not tired</td>\n <td>en ole vasynyt</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #1...</td>\n <td>[i, m, not, tired]</td>\n <td>[i, not, tired]</td>\n </tr>\n <tr>\n <th>26924</th>\n <td>i m not married either</td>\n <td>minakaan en ole naimisissa</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #6...</td>\n <td>[i, m, not, married, either]</td>\n <td>[i, m, not, married, either]</td>\n </tr>\n <tr>\n <th>32009</th>\n <td>he s sleeping like a baby</td>\n <td>han nukkuu kuin pikkuvauva</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n <td>[he, s, sleeping, like, a, baby]</td>\n <td>[he, is, as, a, pianist]</td>\n </tr>\n <tr>\n <th>21339</th>\n <td>i m joking of course</td>\n <td>se oli vitsi tietenkin</td>\n <td>CC-BY 2.0 (France) Attribution: tatoeba.org #2...</td>\n <td>[i, m, joking, of, course]</td>\n <td>[i, m, joking, of, course]</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"references_corpus = test_section[\"English_tokenized\"].values.tolist()\n",
"candidate_corpus = test_section[\"English_translated\"].values\n",
"\n",
"# y wraps each reference token list in its own one-element list;\n",
"# x is the plain list of translated token lists.\n",
"y = [[reference] for reference in references_corpus]\n",
"x = candidate_corpus.tolist()"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:43:29.441911Z",
"iopub.execute_input": "2024-05-25T15:43:29.442752Z",
"iopub.status.idle": "2024-05-25T15:43:29.447799Z",
"shell.execute_reply.started": "2024-05-25T15:43:29.442721Z",
"shell.execute_reply": "2024-05-25T15:43:29.446877Z"
},
"trusted": true
},
"execution_count": 118,
"outputs": []
},
{
"cell_type": "code",
"source": [
"y[:5]"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:43:30.474463Z",
"iopub.execute_input": "2024-05-25T15:43:30.475080Z",
"iopub.status.idle": "2024-05-25T15:43:30.482690Z",
"shell.execute_reply.started": "2024-05-25T15:43:30.475039Z",
"shell.execute_reply": "2024-05-25T15:43:30.481686Z"
},
"trusted": true
},
"execution_count": 119,
"outputs": [
{
"execution_count": 119,
"output_type": "execute_result",
"data": {
"text/plain": "[[['i', 'm', 'very', 'serious', 'about', 'this']],\n [['i', 'm', 'not', 'tired']],\n [['i', 'm', 'not', 'married', 'either']],\n [['he', 's', 'sleeping', 'like', 'a', 'baby']],\n [['i', 'm', 'joking', 'of', 'course']]]"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"from torchtext.data.metrics import bleu_score\n",
"\n",
"# Corpus-level BLEU: x supplies the candidates, y the references\n",
"bleu_score(candidate_corpus=x, references_corpus=y)"
],
"metadata": {
"execution": {
"iopub.status.busy": "2024-05-25T15:43:36.654035Z",
"iopub.execute_input": "2024-05-25T15:43:36.654953Z",
"iopub.status.idle": "2024-05-25T15:43:36.916617Z",
"shell.execute_reply.started": "2024-05-25T15:43:36.654906Z",
"shell.execute_reply": "2024-05-25T15:43:36.915429Z"
},
"trusted": true
},
"execution_count": 120,
"outputs": [
{
"execution_count": 120,
"output_type": "execute_result",
"data": {
"text/plain": "0.5885258316993713"
},
"metadata": {}
}
]
}
]
}