Implement seq2seq
parent 9a0d82c191
commit 42c89699db

seq2seq.ipynb (new file, 640 lines added)

@@ -0,0 +1,640 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import unicode_literals, print_function, division\n",
    "from io import open\n",
    "import unicodedata\n",
    "import re\n",
    "import random\n",
    "import time\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import numpy as np\n",
    "from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading lines...\n",
      "Read 49943 sentence pairs\n",
      "Trimmed to 3613 sentence pairs\n",
      "Counting words...\n",
      "Counted words:\n",
      "pol 3070\n",
      "en 1969\n",
      "['jestes sumienny', 'you re conscientious']\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "SOS_token = 0\n",
    "EOS_token = 1\n",
    "\n",
    "class Lang:\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.word2index = {}\n",
    "        self.word2count = {}\n",
    "        self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
    "        self.n_words = 2\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        for word in sentence.split(' '):\n",
    "            self.addWord(word)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        if word not in self.word2index:\n",
    "            self.word2index[word] = self.n_words\n",
    "            self.word2count[word] = 1\n",
    "            self.index2word[self.n_words] = word\n",
    "            self.n_words += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1\n",
    "\n",
    "def unicodeToAscii(s):\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "def normalizeString(s):\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    s = re.sub(r\"([.!?])\", r\" \\1\", s)\n",
    "    s = re.sub(r\"[^a-zA-Z!?]+\", r\" \", s)\n",
    "    return s.strip()\n",
    "\n",
    "def readLangs(reverse=False):\n",
    "    print(\"Reading lines...\")\n",
    "    lang1 = \"en\"\n",
    "    lang2 = \"pol\"\n",
    "    # Read the file and split into lines\n",
    "    lines = open('pol.txt', encoding='utf-8').\\\n",
    "        read().strip().split('\\n')\n",
    "\n",
    "    # Split every line into pairs and normalize\n",
    "    pairs = [[normalizeString(s) for s in l.split('\\t')[:-1]] for l in lines]\n",
    "\n",
    "    # Reverse pairs, make Lang instances\n",
    "    if reverse:\n",
    "        pairs = [list(reversed(p)) for p in pairs]\n",
    "        input_lang = Lang(lang2)\n",
    "        output_lang = Lang(lang1)\n",
    "    else:\n",
    "        input_lang = Lang(lang1)\n",
    "        output_lang = Lang(lang2)\n",
    "\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "MAX_LENGTH = 10\n",
    "\n",
    "eng_prefixes = (\n",
    "    \"i am \", \"i m \",\n",
    "    \"he is\", \"he s \",\n",
    "    \"she is\", \"she s \",\n",
    "    \"you are\", \"you re \",\n",
    "    \"we are\", \"we re \",\n",
    "    \"they are\", \"they re \"\n",
    ")\n",
    "\n",
    "def filterPair(p):\n",
    "    return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
    "        len(p[1].split(' ')) < MAX_LENGTH and \\\n",
    "        p[1].startswith(eng_prefixes)\n",
    "\n",
    "\n",
    "def filterPairs(pairs):\n",
    "    return [pair for pair in pairs if filterPair(pair)]\n",
    "\n",
    "def prepareData(reverse=False):\n",
    "    input_lang, output_lang, pairs = readLangs(reverse)\n",
    "    print(\"Read %s sentence pairs\" % len(pairs))\n",
    "    pairs = filterPairs(pairs)\n",
    "    print(\"Trimmed to %s sentence pairs\" % len(pairs))\n",
    "    print(\"Counting words...\")\n",
    "    for pair in pairs:\n",
    "        input_lang.addSentence(pair[0])\n",
    "        output_lang.addSentence(pair[1])\n",
    "    print(\"Counted words:\")\n",
    "    print(input_lang.name, input_lang.n_words)\n",
    "    print(output_lang.name, output_lang.n_words)\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "input_lang, output_lang, pairs = prepareData(True)\n",
    "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderRNN(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, dropout_p=0.1):\n",
    "        super(EncoderRNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        self.embedding = nn.Embedding(input_size, hidden_size)\n",
    "        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
    "        self.dropout = nn.Dropout(dropout_p)\n",
    "\n",
    "    def forward(self, input):\n",
    "        embedded = self.dropout(self.embedding(input))\n",
    "        output, hidden = self.gru(embedded)\n",
    "        return output, hidden\n",
    "\n",
    "class DecoderRNN(nn.Module):\n",
    "    def __init__(self, hidden_size, output_size):\n",
    "        super(DecoderRNN, self).__init__()\n",
    "        self.embedding = nn.Embedding(output_size, hidden_size)\n",
    "        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
    "        self.out = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "\n",
    "        for i in range(MAX_LENGTH):\n",
    "            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # Teacher forcing: Feed the target as the next input\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing\n",
    "            else:\n",
    "                # Without teacher forcing: use its own predictions as the next input\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                decoder_input = topi.squeeze(-1).detach()  # detach from history as input\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        return decoder_outputs, decoder_hidden, None  # We return `None` for consistency in the training loop\n",
    "\n",
    "    def forward_step(self, input, hidden):\n",
    "        output = self.embedding(input)\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BahdanauAttention(nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Va = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, query, keys):\n",
    "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
    "        scores = scores.squeeze(2).unsqueeze(1)\n",
    "\n",
    "        weights = F.softmax(scores, dim=-1)\n",
    "        context = torch.bmm(weights, keys)\n",
    "\n",
    "        return context, weights\n",
    "\n",
    "class AttnDecoderRNN(nn.Module):\n",
    "    def __init__(self, hidden_size, output_size, dropout_p=0.1):\n",
    "        super(AttnDecoderRNN, self).__init__()\n",
    "        self.embedding = nn.Embedding(output_size, hidden_size)\n",
    "        self.attention = BahdanauAttention(hidden_size)\n",
    "        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
    "        self.out = nn.Linear(hidden_size, output_size)\n",
    "        self.dropout = nn.Dropout(dropout_p)\n",
    "\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        attentions = []\n",
    "\n",
    "        for i in range(MAX_LENGTH):\n",
    "            decoder_output, decoder_hidden, attn_weights = self.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            attentions.append(attn_weights)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        attentions = torch.cat(attentions, dim=1)\n",
    "\n",
    "        return decoder_outputs, decoder_hidden, attentions\n",
    "\n",
    "    def forward_step(self, input, hidden, encoder_outputs):\n",
    "        embedded = self.dropout(self.embedding(input))\n",
    "\n",
    "        query = hidden.permute(1, 0, 2)\n",
    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
    "        input_gru = torch.cat((embedded, context), dim=2)\n",
    "\n",
    "        output, hidden = self.gru(input_gru, hidden)\n",
    "        output = self.out(output)\n",
    "\n",
    "        return output, hidden, attn_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def indexesFromSentence(lang, sentence):\n",
    "    return [lang.word2index[word] for word in sentence.split(' ')]\n",
    "\n",
    "def tensorFromSentence(lang, sentence):\n",
    "    indexes = indexesFromSentence(lang, sentence)\n",
    "    indexes.append(EOS_token)\n",
    "    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)\n",
    "\n",
    "def tensorsFromPair(pair):\n",
    "    input_tensor = tensorFromSentence(input_lang, pair[0])\n",
    "    target_tensor = tensorFromSentence(output_lang, pair[1])\n",
    "    return (input_tensor, target_tensor)\n",
    "\n",
    "def get_dataloader(batch_size):\n",
    "    input_lang, output_lang, pairs = prepareData(True)\n",
    "\n",
    "    n = len(pairs)\n",
    "    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n",
    "    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)\n",
    "\n",
    "    for idx, (inp, tgt) in enumerate(pairs):\n",
    "        inp_ids = indexesFromSentence(input_lang, inp)\n",
    "        tgt_ids = indexesFromSentence(output_lang, tgt)\n",
    "        inp_ids.append(EOS_token)\n",
    "        tgt_ids.append(EOS_token)\n",
    "        input_ids[idx, :len(inp_ids)] = inp_ids\n",
    "        target_ids[idx, :len(tgt_ids)] = tgt_ids\n",
    "\n",
    "    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),\n",
    "                               torch.LongTensor(target_ids).to(device))\n",
    "\n",
    "    train_sampler = RandomSampler(train_data)\n",
    "    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
    "    return input_lang, output_lang, train_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(dataloader, encoder, decoder, encoder_optimizer,\n",
    "                decoder_optimizer, criterion):\n",
    "\n",
    "    total_loss = 0\n",
    "    for data in dataloader:\n",
    "        input_tensor, target_tensor = data\n",
    "\n",
    "        encoder_optimizer.zero_grad()\n",
    "        decoder_optimizer.zero_grad()\n",
    "\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)\n",
    "\n",
    "        loss = criterion(\n",
    "            decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
    "            target_tensor.view(-1)\n",
    "        )\n",
    "        loss.backward()\n",
    "\n",
    "        encoder_optimizer.step()\n",
    "        decoder_optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    return total_loss / len(dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def asMinutes(s):\n",
    "    m = math.floor(s / 60)\n",
    "    s -= m * 60\n",
    "    return '%dm %ds' % (m, s)\n",
    "\n",
    "def timeSince(since, percent):\n",
    "    now = time.time()\n",
    "    s = now - since\n",
    "    es = s / (percent)\n",
    "    rs = es - s\n",
    "    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,\n",
    "          print_every=100, plot_every=100):\n",
    "    start = time.time()\n",
    "    plot_losses = []\n",
    "    print_loss_total = 0\n",
    "    plot_loss_total = 0\n",
    "\n",
    "    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)\n",
    "        print_loss_total += loss\n",
    "        plot_loss_total += loss\n",
    "\n",
    "        if epoch % print_every == 0:\n",
    "            print_loss_avg = print_loss_total / print_every\n",
    "            print_loss_total = 0\n",
    "            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),\n",
    "                                         epoch, epoch / n_epochs * 100, print_loss_avg))\n",
    "\n",
    "        if epoch % plot_every == 0:\n",
    "            plot_loss_avg = plot_loss_total / plot_every\n",
    "            plot_losses.append(plot_loss_avg)\n",
    "            plot_loss_total = 0\n",
    "\n",
    "    showPlot(plot_losses)\n",
    "\n",
    "plt.switch_backend('agg')\n",
    "\n",
    "def showPlot(points):\n",
    "    plt.figure()\n",
    "    fig, ax = plt.subplots()\n",
    "    loc = ticker.MultipleLocator(base=0.2)\n",
    "    ax.yaxis.set_major_locator(loc)\n",
    "    plt.plot(points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(encoder, decoder, sentence, input_lang, output_lang):\n",
    "    with torch.no_grad():\n",
    "        input_tensor = tensorFromSentence(input_lang, sentence)\n",
    "\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n",
    "\n",
    "        _, topi = decoder_outputs.topk(1)\n",
    "        decoded_ids = topi.squeeze()\n",
    "\n",
    "        decoded_words = []\n",
    "        for idx in decoded_ids:\n",
    "            if idx.item() == EOS_token:\n",
    "                decoded_words.append('<EOS>')\n",
    "                break\n",
    "            decoded_words.append(output_lang.index2word[idx.item()])\n",
    "    return decoded_words, decoder_attn\n",
    "\n",
    "def evaluateRandomly(encoder, decoder, n=10):\n",
    "    for i in range(n):\n",
    "        pair = random.choice(pairs)\n",
    "        print('>', pair[0])\n",
    "        print('=', pair[1])\n",
    "        output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)\n",
    "        output_sentence = ' '.join(output_words)\n",
    "        print('<', output_sentence)\n",
    "        print('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading lines...\n",
      "Read 49943 sentence pairs\n",
      "Trimmed to 3613 sentence pairs\n",
      "Counting words...\n",
      "Counted words:\n",
      "pol 3070\n",
      "en 1969\n",
      "0m 44s (- 11m 8s) (5 6%) 2.0979\n",
      "1m 26s (- 10m 5s) (10 12%) 1.2611\n",
      "2m 7s (- 9m 14s) (15 18%) 0.8754\n",
      "2m 48s (- 8m 26s) (20 25%) 0.5951\n",
      "3m 29s (- 7m 41s) (25 31%) 0.3932\n",
      "4m 10s (- 6m 57s) (30 37%) 0.2515\n",
      "4m 51s (- 6m 14s) (35 43%) 0.1600\n",
      "5m 32s (- 5m 32s) (40 50%) 0.1037\n",
      "6m 15s (- 4m 51s) (45 56%) 0.0701\n",
      "6m 55s (- 4m 9s) (50 62%) 0.0530\n",
      "7m 36s (- 3m 27s) (55 68%) 0.0424\n",
      "8m 16s (- 2m 45s) (60 75%) 0.0374\n",
      "8m 58s (- 2m 4s) (65 81%) 0.0318\n",
      "9m 39s (- 1m 22s) (70 87%) 0.0287\n",
      "10m 20s (- 0m 41s) (75 93%) 0.0279\n",
      "11m 1s (- 0m 0s) (80 100%) 0.0246\n"
     ]
    }
   ],
   "source": [
    "hidden_size = 128\n",
    "batch_size = 32\n",
    "\n",
    "input_lang, output_lang, train_dataloader = get_dataloader(batch_size)\n",
    "\n",
    "encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)\n",
    "decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)\n",
    "\n",
    "train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> wchodze w to\n",
      "= i m game\n",
      "< i m game <EOS>\n",
      "\n",
      "> on jest o dwa lata starszy od ciebie\n",
      "= he is two years older than you\n",
      "< he is two years older than you is questions <EOS>\n",
      "\n",
      "> wstydze sie za siebie\n",
      "= i m ashamed of myself\n",
      "< i m ashamed of myself <EOS>\n",
      "\n",
      "> nie wchodze w to\n",
      "= i am not getting involved\n",
      "< i am not getting involved <EOS>\n",
      "\n",
      "> jestes moja przyjacio ka\n",
      "= you are my friend\n",
      "< you are my friend <EOS>\n",
      "\n",
      "> jestem naga\n",
      "= i m naked\n",
      "< i m naked <EOS>\n",
      "\n",
      "> naprawde nie jestem az tak zajety\n",
      "= i m really not all that busy\n",
      "< i m really not all that busy that <EOS>\n",
      "\n",
      "> pracuje dla firmy handlowej\n",
      "= i m working for a trading firm\n",
      "< i m working for a trading firm <EOS>\n",
      "\n",
      "> jestem rysownikiem\n",
      "= i m a cartoonist\n",
      "< i m a cartoonist <EOS>\n",
      "\n",
      "> wyjezdzasz dopiero jutro prawda ?\n",
      "= you aren t leaving until tomorrow right ?\n",
      "< you aren t leaving until tomorrow right ? aren t\n",
      "\n"
     ]
    }
   ],
   "source": [
    "evaluateRandomly(encoder, decoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BLEU"
   ]
  },
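  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: the BLEU cell below calls a `translate` helper that is not defined anywhere in this notebook. A minimal sketch is assumed here, wrapping `evaluate` and dropping the trailing `<EOS>` token; the helper and its `tokenized` flag are assumptions, not part of the original code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumed helper (not in the original notebook): the BLEU cell below calls\n",
    "# translate(sentence, tokenized=True), so a minimal wrapper around evaluate() is sketched here.\n",
    "def translate(sentence, tokenized=False):\n",
    "    output_words, _ = evaluate(encoder, decoder, sentence, input_lang, output_lang)\n",
    "    # Drop the trailing <EOS> marker so it does not count toward BLEU\n",
    "    if output_words and output_words[-1] == '<EOS>':\n",
    "        output_words = output_words[:-1]\n",
    "    return output_words if tokenized else ' '.join(output_words)"
   ]
  },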
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to\n",
      "[nltk_data]     C:\\Users\\mateu\\AppData\\Roaming\\nltk_data...\n",
      "[nltk_data]   Unzipping tokenizers\\punkt.zip.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU score: 0.7677458355439187\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction\n",
    "import nltk\n",
    "nltk.download('punkt')\n",
    "\n",
    "def filter_data(data, max_length, prefixes):\n",
    "    filtered_data = data[\n",
    "        data.apply(lambda row: len(row[\"English\"].split()) < max_length and\n",
    "                   len(row[\"Polish\"].split()) < max_length and\n",
    "                   row[\"English\"].startswith(tuple(prefixes)), axis=1)\n",
    "    ]\n",
    "    return filtered_data\n",
    "\n",
    "# Load and normalize data\n",
    "data_file = pd.read_csv(\"pol.txt\", sep='\\t', names=[\"English\", \"Polish\", \"attribution\"])\n",
    "data_file[\"English\"] = data_file[\"English\"].apply(normalizeString)\n",
    "data_file[\"Polish\"] = data_file[\"Polish\"].apply(normalizeString)\n",
    "\n",
    "# Filter data\n",
    "filtered_data = filter_data(data_file, MAX_LENGTH, eng_prefixes)\n",
    "test_section = filtered_data.sample(frac=1).head(500)\n",
    "\n",
    "# Tokenize and translate\n",
    "test_section[\"English_tokenized\"] = test_section[\"English\"].apply(nltk.word_tokenize)\n",
    "test_section[\"English_translated\"] = test_section[\"Polish\"].apply(lambda x: translate(x, tokenized=True))\n",
    "\n",
    "# Prepare corpus for BLEU calculation\n",
    "candidate_corpus = test_section[\"English_translated\"].tolist()\n",
    "references_corpus = [[ref] for ref in test_section[\"English_tokenized\"].tolist()]\n",
    "\n",
    "# Calculate BLEU score\n",
    "smooth_fn = SmoothingFunction().method4\n",
    "bleu = corpus_bleu(references_corpus, candidate_corpus, smoothing_function=smooth_fn)\n",
    "print(\"BLEU score:\", bleu)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aienv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}