{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Modelowanie Języka</h1>\n",
"<h2> 10. <i>Model rekurencyjny z atencją</i> [ćwiczenia]</h2> \n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"notebook based on:\n",
"\n",
"https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"from io import open\n",
"import unicodedata\n",
"import string\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"SOS_token = 0\n",
"EOS_token = 1\n",
"\n",
"class Lang:\n",
" def __init__(self):\n",
" self.word2index = {}\n",
" self.word2count = {}\n",
" self.index2word = {0: \"SOS\", 1: \"EOS\"}\n",
" self.n_words = 2 # Count SOS and EOS\n",
"\n",
" def addSentence(self, sentence):\n",
" for word in sentence.split(' '):\n",
" self.addWord(word)\n",
"\n",
" def addWord(self, word):\n",
" if word not in self.word2index:\n",
" self.word2index[word] = self.n_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.n_words] = word\n",
" self.n_words += 1\n",
" else:\n",
" self.word2count[word] += 1"
]
},
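{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (an illustrative addition, not part of the original notebook): build a toy vocabulary with the `Lang` class defined above and inspect its mappings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: a toy vocabulary built with the Lang class defined above\n",
"demo_lang = Lang()\n",
"demo_lang.addSentence('i am cold .')\n",
"demo_lang.addSentence('i am tired .')\n",
"print(demo_lang.word2index)   # word -> index (0 and 1 are reserved for SOS/EOS)\n",
"print(demo_lang.word2count)   # word -> number of occurrences\n",
"print(demo_lang.n_words)      # vocabulary size including SOS and EOS"
]
},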
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pairs = []\n",
"with open('data/eng-pol.txt') as f:\n",
" for line in f:\n",
" eng_line, pol_line = line.lower().rstrip().split('\\t')\n",
"\n",
" eng_line = re.sub(r\"([.!?])\", r\" \\1\", eng_line)\n",
" eng_line = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", eng_line)\n",
"\n",
" pol_line = re.sub(r\"([.!?])\", r\" \\1\", pol_line)\n",
" pol_line = re.sub(r\"[^a-zA-Z.!?ąćęłńóśźżĄĆĘŁŃÓŚŹŻ]+\", r\" \", pol_line)\n",
"\n",
" pairs.append([eng_line, pol_line])\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pairs[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MAX_LENGTH = 10\n",
"eng_prefixes = (\n",
" \"i am \", \"i m \",\n",
" \"he is\", \"he s \",\n",
" \"she is\", \"she s \",\n",
" \"you are\", \"you re \",\n",
" \"we are\", \"we re \",\n",
" \"they are\", \"they re \"\n",
")\n",
"\n",
"pairs = [p for p in pairs if len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH]\n",
"pairs = [p for p in pairs if p[0].startswith(eng_prefixes)]\n",
"\n",
"eng_lang = Lang()\n",
"pol_lang = Lang()\n",
"\n",
"for pair in pairs:\n",
" eng_lang.addSentence(pair[0])\n",
" pol_lang.addSentence(pair[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pairs[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pairs[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pairs[2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eng_lang.n_words"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pol_lang.n_words"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class EncoderRNN(nn.Module):\n",
" def __init__(self, input_size, embedding_size, hidden_size):\n",
" super(EncoderRNN, self).__init__()\n",
" self.embedding_size = embedding_size\n",
" self.hidden_size = hidden_size\n",
"\n",
" self.embedding = nn.Embedding(input_size, self.embedding_size)\n",
" self.gru = nn.GRU(self.embedding_size, hidden_size)\n",
"\n",
" def forward(self, input, hidden):\n",
" embedded = self.embedding(input).view(1, 1, -1)\n",
" output = embedded\n",
" output, hidden = self.gru(output, hidden)\n",
" return output, hidden\n",
"\n",
" def initHidden(self):\n",
" return torch.zeros(1, 1, self.hidden_size, device=device)"
]
},
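{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal shape check (an illustrative addition, not from the original tutorial): push a single token index through a freshly created encoder; both the output and the hidden state should have shape `(1, 1, hidden_size)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative shape check; the sizes 200/256 are assumptions matching the training setup used later on\n",
"_enc = EncoderRNN(eng_lang.n_words, 200, 256).to(device)\n",
"_tok = torch.tensor([2], device=device)  # index of some vocabulary word (0 and 1 are SOS/EOS)\n",
"_out, _hid = _enc(_tok, _enc.initHidden())\n",
"_out.shape, _hid.shape  # both (1, 1, hidden_size)"
]
},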
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DecoderRNN(nn.Module):\n",
" def __init__(self, embedding_size, hidden_size, output_size):\n",
" super(DecoderRNN, self).__init__()\n",
" self.embedding_size = embedding_size\n",
" self.hidden_size = hidden_size\n",
"\n",
" self.embedding = nn.Embedding(output_size, self.embedding_size)\n",
" self.gru = nn.GRU(self.embedding_size, hidden_size)\n",
" self.out = nn.Linear(hidden_size, output_size)\n",
" self.softmax = nn.LogSoftmax(dim=1)\n",
"\n",
" def forward(self, input, hidden):\n",
" output = self.embedding(input).view(1, 1, -1)\n",
" output = F.relu(output)\n",
" output, hidden = self.gru(output, hidden)\n",
" output = self.softmax(self.out(output[0]))\n",
" return output, hidden\n",
"\n",
" def initHidden(self):\n",
" return torch.zeros(1, 1, self.hidden_size, device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AttnDecoderRNN(nn.Module):\n",
" def __init__(self, embedding_size, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\n",
" super(AttnDecoderRNN, self).__init__()\n",
" self.embedding_size = embedding_size\n",
" self.hidden_size = hidden_size\n",
" self.output_size = output_size\n",
" self.dropout_p = dropout_p\n",
" self.max_length = max_length\n",
"\n",
" self.embedding = nn.Embedding(self.output_size, self.embedding_size)\n",
" self.attn = nn.Linear(self.hidden_size + self.embedding_size, self.max_length)\n",
" self.attn_combine = nn.Linear(self.hidden_size + self.embedding_size, self.embedding_size)\n",
" self.dropout = nn.Dropout(self.dropout_p)\n",
" self.gru = nn.GRU(self.embedding_size, self.hidden_size)\n",
" self.out = nn.Linear(self.hidden_size, self.output_size)\n",
"\n",
" def forward(self, input, hidden, encoder_outputs):\n",
" embedded = self.embedding(input).view(1, 1, -1)\n",
" embedded = self.dropout(embedded)\n",
"\n",
" attn_weights = F.softmax(\n",
" self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)\n",
" attn_applied = torch.bmm(attn_weights.unsqueeze(0),\n",
" encoder_outputs.unsqueeze(0))\n",
"\n",
" output = torch.cat((embedded[0], attn_applied[0]), 1)\n",
" output = self.attn_combine(output).unsqueeze(0)\n",
"\n",
" output = F.relu(output)\n",
" output, hidden = self.gru(output, hidden)\n",
"\n",
" output = F.log_softmax(self.out(output[0]), dim=1)\n",
" return output, hidden, attn_weights\n",
"\n",
" def initHidden(self):\n",
" return torch.zeros(1, 1, self.hidden_size, device=device)"
]
},
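{
"cell_type": "markdown",
"metadata": {},
"source": [
"An analogous illustrative check for the attention decoder (again an addition, not part of the original tutorial): a single decoding step on zero-filled encoder outputs, just to look at the shapes of the returned log-probabilities and attention weights."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative single decoding step; the sizes 200/256 are assumptions matching the setup used for training below\n",
"_dec = AttnDecoderRNN(200, 256, pol_lang.n_words).to(device)\n",
"_enc_outs = torch.zeros(MAX_LENGTH, 256, device=device)\n",
"_inp = torch.tensor([[SOS_token]], device=device)\n",
"_out, _hid, _attn = _dec(_inp, _dec.initHidden(), _enc_outs)\n",
"_out.shape, _attn.shape  # (1, output_size) log-probabilities and (1, MAX_LENGTH) attention weights"
]
},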
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tensorFromSentence(sentence, lang):\n",
" indexes = [lang.word2index[word] for word in sentence.split(' ')]\n",
" indexes.append(EOS_token)\n",
" return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)\n"
]
},
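{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustrative look (not part of the original notebook) at how a sentence becomes a column tensor of word indices terminated with the `EOS` token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(pairs[0][0])\n",
"tensorFromSentence(pairs[0][0], eng_lang)"
]
},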
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"teacher_forcing_ratio = 0.5\n",
"\n",
"def train_one_batch(input_tensor, target_tensor, encoder, decoder, optimizer, criterion, max_length=MAX_LENGTH):\n",
" encoder_hidden = encoder.initHidden()\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" input_length = input_tensor.size(0)\n",
" target_length = target_tensor.size(0)\n",
"\n",
" encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n",
"\n",
" loss = 0\n",
"\n",
" for ei in range(input_length):\n",
" encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n",
" encoder_outputs[ei] = encoder_output[0, 0]\n",
"\n",
" decoder_input = torch.tensor([[SOS_token]], device=device)\n",
"\n",
" decoder_hidden = encoder_hidden\n",
"\n",
" use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False\n",
"\n",
" if use_teacher_forcing:\n",
" for di in range(target_length):\n",
" decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)\n",
" loss += criterion(decoder_output, target_tensor[di])\n",
" decoder_input = target_tensor[di] # Teacher forcing\n",
"\n",
" else:\n",
" for di in range(target_length):\n",
" decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)\n",
" topv, topi = decoder_output.topk(1)\n",
" decoder_input = topi.squeeze().detach() # detach from history as input\n",
"\n",
" loss += criterion(decoder_output, target_tensor[di])\n",
" if decoder_input.item() == EOS_token:\n",
" break\n",
"\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"\n",
" return loss.item() / target_length"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def trainIters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):\n",
" print_loss_total = 0 # Reset every print_every\n",
" encoder.train()\n",
" decoder.train()\n",
"\n",
" optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)\n",
" \n",
" training_pairs = [random.choice(pairs) for _ in range(n_iters)]\n",
" training_pairs = [(tensorFromSentence(p[0], eng_lang), tensorFromSentence(p[1], pol_lang)) for p in training_pairs]\n",
" \n",
" criterion = nn.NLLLoss()\n",
"\n",
" for i in range(1, n_iters + 1):\n",
" training_pair = training_pairs[i - 1]\n",
" input_tensor = training_pair[0]\n",
" target_tensor = training_pair[1]\n",
"\n",
" loss = train_one_batch(input_tensor,\n",
" target_tensor,\n",
" encoder,\n",
" decoder,\n",
" optimizer,\n",
"\n",
2022-05-29 18:14:19 +02:00
" criterion)\n",
" \n",
" print_loss_total += loss\n",
"\n",
" if i % print_every == 0:\n",
" print_loss_avg = print_loss_total / print_every\n",
" print_loss_total = 0\n",
" print(f'iter: {i}, loss: {print_loss_avg}')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):\n",
" encoder.eval()\n",
" decoder.eval()\n",
" with torch.no_grad():\n",
" input_tensor = tensorFromSentence(sentence, eng_lang)\n",
" input_length = input_tensor.size()[0]\n",
" encoder_hidden = encoder.initHidden()\n",
"\n",
" encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\n",
"\n",
" for ei in range(input_length):\n",
" encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\n",
" encoder_outputs[ei] += encoder_output[0, 0]\n",
"\n",
" decoder_input = torch.tensor([[SOS_token]], device=device)\n",
"\n",
" decoder_hidden = encoder_hidden\n",
"\n",
" decoded_words = []\n",
" decoder_attentions = torch.zeros(max_length, max_length)\n",
"\n",
" for di in range(max_length):\n",
" decoder_output, decoder_hidden, decoder_attention = decoder(\n",
" decoder_input, decoder_hidden, encoder_outputs)\n",
" decoder_attentions[di] = decoder_attention.data\n",
" topv, topi = decoder_output.data.topk(1)\n",
" if topi.item() == EOS_token:\n",
" decoded_words.append('<EOS>')\n",
" break\n",
" else:\n",
" decoded_words.append(pol_lang.index2word[topi.item()])\n",
"\n",
" decoder_input = topi.squeeze().detach()\n",
"\n",
" return decoded_words, decoder_attentions[:di + 1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluateRandomly(encoder, decoder, n=10):\n",
" for i in range(n):\n",
" pair = random.choice(pairs)\n",
" print('>', pair[0])\n",
" print('=', pair[1])\n",
" output_words, attentions = evaluate(encoder, decoder, pair[0])\n",
" output_sentence = ' '.join(output_words)\n",
" print('<', output_sentence)\n",
" print('')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embedding_size = 200\n",
"hidden_size = 256\n",
"encoder1 = EncoderRNN(eng_lang.n_words, embedding_size, hidden_size).to(device)\n",
"attn_decoder1 = AttnDecoderRNN(embedding_size, hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainIters(encoder1, attn_decoder1, 10_000, print_every=50)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"evaluateRandomly(encoder1, attn_decoder1)"
]
},
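{
"cell_type": "markdown",
"metadata": {},
"source": [
"`evaluate` also returns the attention weights, so they can be inspected directly. Below is a minimal sketch (an addition to the notebook; it assumes `matplotlib` is available) that plots the attention matrix for a single translated sentence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt  # assumption: matplotlib is installed\n",
"\n",
"sentence = pairs[0][0]\n",
"output_words, attentions = evaluate(encoder1, attn_decoder1, sentence)\n",
"\n",
"plt.figure()\n",
"plt.imshow(attentions.cpu().numpy(), cmap='bone')  # rows: output tokens, columns: encoder positions (padded to MAX_LENGTH)\n",
"plt.xlabel('input position')\n",
"plt.ylabel('output position')\n",
"plt.title(' '.join(output_words))\n",
"plt.colorbar()\n",
"plt.show()"
]
},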
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
2022-05-30 09:32:20 +02:00
"source": [
"## ZADANIE\n",
"\n",
"Gonito \"WMT2017 Czech-English machine translation challenge for news \"\n",
"\n",
"Proszę wytrenować najpierw model german -> english, a później dotrenować na czech-> english.\n",
"Można wziąć inicjalizować enkoder od nowa lub nie. Proszę w każdym razie użyć wytrenowanego dekodera."
]
}
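,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of one possible workflow (the `de_lang`/`cs_lang`/`en_lang` vocabularies, the corresponding pair lists and the data files are assumptions, not part of this notebook; note that `trainIters` above reads the global `pairs`, `eng_lang` and `pol_lang`, so it would first have to be parametrized to accept them):\n",
"\n",
"```python\n",
"# German -> English pretraining (hypothetical Lang objects built from the German-English data)\n",
"encoder_de = EncoderRNN(de_lang.n_words, embedding_size, hidden_size).to(device)\n",
"decoder_en = AttnDecoderRNN(embedding_size, hidden_size, en_lang.n_words).to(device)\n",
"trainIters(encoder_de, decoder_en, 10_000, print_every=50)\n",
"\n",
"# Czech -> English fine-tuning: a new (or the reused) encoder, the same trained English decoder\n",
"encoder_cs = EncoderRNN(cs_lang.n_words, embedding_size, hidden_size).to(device)\n",
"trainIters(encoder_cs, decoder_en, 10_000, print_every=50)\n",
"```"
]
}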
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}