wmt-2020-pl-en/gru_attention.ipynb

602 lines
49 KiB
Plaintext
Raw Normal View History

2021-02-08 14:36:14 +01:00
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "gru_attention.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "WtPfLDxTLoFn"
},
"source": [
"from __future__ import unicode_literals, print_function, division\r\n",
"from io import open\r\n",
"import math\r\n",
"import random\r\n",
"import re\r\n",
"import string\r\n",
"import time\r\n",
"import unicodedata\r\n",
"\r\n",
"import torch\r\n",
"import torch.nn as nn\r\n",
"from torch import optim\r\n",
"import torch.nn.functional as F\r\n",
"\r\n",
"# Run on the GPU when one is available, otherwise on the CPU.\r\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n",
"\r\n",
"# Seed the RNGs the notebook draws from (pair sampling, teacher forcing,\r\n",
"# out-of-vocabulary substitution, weight initialisation, dropout) so that\r\n",
"# Restart & Run All repeats the same run, up to nondeterministic GPU kernels.\r\n",
"SEED = 42\r\n",
"random.seed(SEED)\r\n",
"torch.manual_seed(SEED)"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "X3xChOaALwwA"
},
"source": [
"SOS_token = 0  # id of the start-of-sentence marker\r\n",
"EOS_token = 1  # id of the end-of-sentence marker\r\n",
"\r\n",
"\r\n",
"class Lang:\r\n",
"    \"\"\"Vocabulary for one side of the parallel corpus.\r\n",
"\r\n",
"    Maps words to integer ids and back, and counts how often each word has\r\n",
"    been added. Ids 0 and 1 are reserved for the SOS and EOS markers, so the\r\n",
"    first real word gets id 2 and ``n_words`` is always the next free id.\r\n",
"    \"\"\"\r\n",
"\r\n",
"    def __init__(self, name):\r\n",
"        self.name = name\r\n",
"        self.word2index = {}\r\n",
"        self.word2count = {}\r\n",
"        self.index2word = {0: \"SOS\", 1: \"EOS\"}\r\n",
"        self.n_words = 2  # the two reserved markers\r\n",
"\r\n",
"    def addSentence(self, sentence):\r\n",
"        \"\"\"Add every token of ``sentence`` (split on single spaces).\"\"\"\r\n",
"        tokens = sentence.split(' ')\r\n",
"        for token in tokens:\r\n",
"            self.addWord(token)\r\n",
"\r\n",
"    def addWord(self, word):\r\n",
"        \"\"\"Give ``word`` the next free id, or bump its count if already known.\"\"\"\r\n",
"        if word in self.word2index:\r\n",
"            self.word2count[word] += 1\r\n",
"            return\r\n",
"        new_id = self.n_words\r\n",
"        self.word2index[word] = new_id\r\n",
"        self.word2count[word] = 1\r\n",
"        self.index2word[new_id] = word\r\n",
"        self.n_words = new_id + 1"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1rra860ILy48"
},
"source": [
"# Letters with no Unicode decomposition, so the NFD step below cannot strip\r\n",
"# their diacritic. Polish l-with-stroke (U+0142 / U+0141) is the one that\r\n",
"# matters here: left alone it is later replaced by a space, which splits\r\n",
"# words (e.g. 's uzba' and 'by a' in the printed training pairs).\r\n",
"_NO_DECOMPOSITION = {ord('\\u0142'): 'l', ord('\\u0141'): 'L'}\r\n",
"\r\n",
"\r\n",
"# Turn a Unicode string to plain ASCII, thanks to\r\n",
"# https://stackoverflow.com/a/518232/2809427\r\n",
"def unicodeToAscii(s):\r\n",
"    \"\"\"Remove diacritics from ``s``.\r\n",
"\r\n",
"    Letters listed in ``_NO_DECOMPOSITION`` are transliterated first. Every\r\n",
"    other accented letter is NFD-decomposed into a base letter plus\r\n",
"    combining marks (category 'Mn'), and the marks are dropped.\r\n",
"    \"\"\"\r\n",
"    s = s.translate(_NO_DECOMPOSITION)\r\n",
"    return ''.join(\r\n",
"        c for c in unicodedata.normalize('NFD', s)\r\n",
"        if unicodedata.category(c) != 'Mn'\r\n",
"    )\r\n",
"\r\n",
"\r\n",
"def normalizeString(s):\r\n",
"    \"\"\"Lowercase, trim, and reduce ``s`` to space-separated word tokens.\r\n",
"\r\n",
"    Diacritics are removed, '.', '!' and '?' become separate tokens, and every\r\n",
"    other run of non-letters (digits included) collapses to one space. The\r\n",
"    result is stripped once more: a leading or trailing non-letter otherwise\r\n",
"    leaves a stray space that ``split(' ')`` turns into an empty-string token.\r\n",
"    \"\"\"\r\n",
"    s = unicodeToAscii(s.lower().strip())\r\n",
"    s = re.sub(r\"([.!?])\", r\" \\1\", s)\r\n",
"    s = re.sub(r\"[^a-zA-Z.!?]+\", r\" \", s)\r\n",
"    return s.strip()"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sEcey4mxL3We"
},
"source": [
"def readLangs(in_path='in_40k.tsv', out_path='exp_40k.tsv', limit=100):\r\n",
"    \"\"\"Read the parallel corpus and return two empty vocabularies plus the pairs.\r\n",
"\r\n",
"    Args:\r\n",
"        in_path: file with one source sentence per line.\r\n",
"        out_path: file with the matching target sentence on the same line.\r\n",
"        limit: keep only the first ``limit`` lines of each file; ``None``\r\n",
"            keeps them all. The default keeps the previous hard-coded 100.\r\n",
"\r\n",
"    Returns:\r\n",
"        ``(input_lang, output_lang, pairs)``: two still-empty ``Lang``\r\n",
"        objects and a list of ``[normalized_source, normalized_target]``.\r\n",
"    \"\"\"\r\n",
"    print(\"Reading lines...\")\r\n",
"\r\n",
"    # Read as UTF-8 explicitly: the corpus contains Polish diacritics and the\r\n",
"    # platform default encoding is not UTF-8 everywhere. ``with`` closes the files.\r\n",
"    with open(in_path, encoding='utf-8') as f:\r\n",
"        linesIn = f.read().strip().split('\\n')[:limit]\r\n",
"    with open(out_path, encoding='utf-8') as f:\r\n",
"        linesOut = f.read().strip().split('\\n')[:limit]\r\n",
"    if len(linesIn) != len(linesOut):\r\n",
"        # zip() below silently drops the surplus lines; make that visible.\r\n",
"        print(\"Warning: %d source lines but %d target lines; extra lines ignored\"\r\n",
"              % (len(linesIn), len(linesOut)))\r\n",
"\r\n",
"    # Split every line into pairs and normalize. The pairs are not printed:\r\n",
"    # dumping the whole corpus floods the cell output.\r\n",
"    pairs = [[normalizeString(a), normalizeString(b)]\r\n",
"             for a, b in zip(linesIn, linesOut)]\r\n",
"\r\n",
"    input_lang = Lang('in')\r\n",
"    output_lang = Lang('out')\r\n",
"    return input_lang, output_lang, pairs"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "59dvTVlsL9dK"
},
"source": [
"# Sentences must have fewer than MAX_LENGTH words. With the appended EOS they\r\n",
"# then fit the MAX_LENGTH encoder-output rows the attention decoder attends over.\r\n",
"MAX_LENGTH = 80\r\n",
"\r\n",
"\r\n",
"def filterPair(p):\r\n",
"    \"\"\"True when both sentences of pair ``p`` have fewer than MAX_LENGTH words.\"\"\"\r\n",
"    if len(p[0].split(' ')) >= MAX_LENGTH:\r\n",
"        return False\r\n",
"    return len(p[1].split(' ')) < MAX_LENGTH\r\n",
"\r\n",
"\r\n",
"def filterPairs(pairs):\r\n",
"    \"\"\"Return the pairs accepted by ``filterPair``, in their original order.\"\"\"\r\n",
"    return list(filter(filterPair, pairs))"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4pKs9H5_ST8P",
"outputId": "2d07edf9-1bbd-4676-9577-411dc09f28b7"
},
"source": [
"def prepareData(lang1, lang2, reverse=False):\r\n",
"    \"\"\"Load, length-filter and index the parallel corpus.\r\n",
"\r\n",
"    ``lang1``, ``lang2`` and ``reverse`` are accepted but currently unused:\r\n",
"    the file names and vocabulary names all come from ``readLangs``.\r\n",
"\r\n",
"    Returns:\r\n",
"        ``(input_lang, output_lang, pairs)`` with both vocabularies filled\r\n",
"        from the pairs that survived ``filterPairs``.\r\n",
"    \"\"\"\r\n",
"    input_lang, output_lang, raw_pairs = readLangs()\r\n",
"    print(\"Read %s sentence pairs\" % len(raw_pairs))\r\n",
"    kept_pairs = filterPairs(raw_pairs)\r\n",
"    print(\"Trimmed to %s sentence pairs\" % len(kept_pairs))\r\n",
"\r\n",
"    print(\"Counting words...\")\r\n",
"    for pair in kept_pairs:\r\n",
"        input_lang.addSentence(pair[0])\r\n",
"        output_lang.addSentence(pair[1])\r\n",
"\r\n",
"    print(\"Counted words:\")\r\n",
"    for lang in (input_lang, output_lang):\r\n",
"        print(lang.name, lang.n_words)\r\n",
"    return input_lang, output_lang, kept_pairs\r\n",
"\r\n",
"\r\n",
"input_lang, output_lang, pairs = prepareData('pl', 'en', True)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"Reading lines...\n",
"[['naprawde wazne jest by wzrost gospodarczy nie powodowa automatycznie proporcjonalnego zwiekszonego zuzycia energii .', 'it is really important that growth should not automatically generate a proportionate rise in energy consumption .'], [' bg pani przewodniczaca panie premierze ! rok bedzie pierwszym w ktorym unii europejskiej beda przewodzic dwa kraje z europy srodkowej i wschodniej wegry oraz polska .', ' bg madam president prime minister will be the first year in which the european union will be headed by two countries from central and eastern europe hungary and poland .'], ['w dodatku odsetek ludzi w wieku ponad lat wzrosnie z w roku do w roku .', 'in addition the proportion of people aged over will rise from . in to . in .'], ['na pismie . sv w sprawozdaniu stwierdzono ze w wiekszosci panstw cz onkowskich spo eczenstwo starzeje sie co obciazy systemy zabezpieczenia spo ecznego i systemy emerytalne .', 'in writing . sv this report observes that in most member states the population is getting older and that the social security and pension systems will therefore be put under strain .'], ['oswiadczenia pisemne art . 
regulaminu ', 'written statements rule '], ['jestesmy na przyk ad za przeprowadzeniem wspolnych badan z zakresu bezpieczenstwa jadrowego ale obawiamy sie ze wiele punktow sprawozdania wyraza zbyt mocne poparcie dla kwestgii energii jadrowej .', 'we are in favour of common research into nuclear safety for example but we feel that in several cases the report is far too pro nuclear energy .'], ['kolejna kwestia wspomniana w trakcie debaty by a kwestia sprzeciwu wobec protekcjonizmu .', 'another point that was mentioned during the debate was the issue of resistance to protectionism .'], ['s uzba zewnetrzna musi w pe ni dotrzymywac kroku komisji .', 'the external service must be wholly in step with the commission .'], ['z zadowoleniem przyjmujemy propozycje wysuniete celem zagwarantowania usprawnienia krajowych ram budzetowych i zachecenia panstw cz onkowskich do podejmowania bardziej wywazonych decyzji budzetowych w przysz osci .', 'we welcome the proposal put forward to ensure improvements in national fiscal frameworks and to encourage member states to make better fiscal decisions in the future .'], ['po pierwsze nie sadze aby umniejszanie grekow w sposob w jaki uczyni to pan pose soini by o wartosciowe czy nawet w asciwe .', 'first of all i do not believe that belittling the greeks in the manner that mr soini did is very useful or even professionally appropriate .'], ['dopuscilismy na przyk ad do spadku naszych wolnych mocy produkcyjnych o ok . w skali roku a to powoduje niepewnosc .', 'we have for example allowed our spare capacity to fall by around every year and that is creating insecurity .'], ['na pismie . chce podziekowac sprawozdawcy za wspania a prace .', 'in writing . 
i would like to thank the rapporteur for his excellent work .'], ['chcia abym sie odniesc przede wszystkim do misji obserwacyjnych w afryce ze wzgledu na szczegolna wspo prace ue z panstwami afryki karaibow i pacyfiku .', 'i would like to speak with particular reference to observation missions in africa because there is a special partnership between the eu and african caribbean and pacific countries .'], ['g osowa am za przyjeciem przedmiotowej rezolucji .', 'i voted in favour of this resolution .'], ['wniosek komisji ktory w stylu wielkiego brata stwierdza ze nalezy dostarczac owoce sezonowe podkreslajac roznorodnosc owocow tak aby dzieci mog y odkrywac ich smaki jest ca kowicie absurdalny .', 'the committee s proposal which in a big brother like manner states that seasonal fruit should be distributed giving preference to a varied range of fruits so as to enable children to discover different tastes is completely ridiculous .'], ['w kazdym razie przyjelismy to do wiadomosci i wezmiemy je pod uwage .', 'we have in any case taken note of them and shall take them into consideration .'], ['chcia abym skomentowac tresc w zakresie czterech czy pieciu konkretnych spraw .', 'i would like to comment on the content in relation to
"Read 100 sentence pairs\n",
"Trimmed to 100 sentence pairs\n",
"Counting words...\n",
"Counted words:\n",
"in 1155\n",
"out 877\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1lEImRDtSYK1",
"outputId": "0d345a19-5f00-40cf-98fe-ebb846ef9a74"
},
"source": [
"input_lang.n_words"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1155"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "jbMReYBuMBUy"
},
"source": [
"class EncoderRNN(nn.Module):\r\n",
"    \"\"\"Single-layer GRU encoder that consumes one token per call.\r\n",
"\r\n",
"    Args:\r\n",
"        input_size: source vocabulary size (rows of the embedding table).\r\n",
"        hidden_size: width of both the embeddings and the GRU state.\r\n",
"    \"\"\"\r\n",
"\r\n",
"    def __init__(self, input_size, hidden_size):\r\n",
"        super().__init__()\r\n",
"        self.hidden_size = hidden_size\r\n",
"        # Creation order (embedding, then GRU) is kept so the random weight\r\n",
"        # initialisation draws from the RNG in the same order.\r\n",
"        self.embedding = nn.Embedding(input_size, hidden_size)\r\n",
"        self.gru = nn.GRU(hidden_size, hidden_size)\r\n",
"\r\n",
"    def forward(self, input, hidden):\r\n",
"        \"\"\"Encode one token id; returns the GRU's ``(output, hidden)`` pair.\"\"\"\r\n",
"        step = self.embedding(input).view(1, 1, -1)  # (seq=1, batch=1, hidden)\r\n",
"        return self.gru(step, hidden)\r\n",
"\r\n",
"    def initHidden(self):\r\n",
"        \"\"\"All-zero initial state of shape (1, 1, hidden_size) on ``device``.\"\"\"\r\n",
"        return torch.zeros(1, 1, self.hidden_size, device=device)"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FTGPW7-AMC5R"
},
"source": [
"class AttnDecoderRNN(nn.Module):\r\n",
"    \"\"\"GRU decoder that attends over a fixed number of encoder-output rows.\r\n",
"\r\n",
"    At every step the attention weights are predicted from the current input\r\n",
"    embedding and the previous hidden state alone (a linear layer with\r\n",
"    ``max_length`` outputs, then softmax). They are then used to mix the\r\n",
"    rows of ``encoder_outputs``. No masking is applied, so padding rows in\r\n",
"    ``encoder_outputs`` can still receive weight.\r\n",
"\r\n",
"    Args:\r\n",
"        hidden_size: width of the embeddings, the GRU state and the encoder outputs.\r\n",
"        output_size: target vocabulary size.\r\n",
"        dropout_p: dropout probability applied to the input embedding.\r\n",
"        max_length: number of encoder-output rows attended over.\r\n",
"    \"\"\"\r\n",
"\r\n",
"    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\r\n",
"        super().__init__()\r\n",
"        self.hidden_size = hidden_size\r\n",
"        self.output_size = output_size\r\n",
"        self.dropout_p = dropout_p\r\n",
"        self.max_length = max_length\r\n",
"\r\n",
"        # Keep this creation order: it fixes the order in which the random\r\n",
"        # weight initialisation draws from the RNG.\r\n",
"        self.embedding = nn.Embedding(output_size, hidden_size)\r\n",
"        self.attn = nn.Linear(2 * hidden_size, max_length)\r\n",
"        self.attn_combine = nn.Linear(2 * hidden_size, hidden_size)\r\n",
"        self.dropout = nn.Dropout(dropout_p)\r\n",
"        self.gru = nn.GRU(hidden_size, hidden_size)\r\n",
"        self.out = nn.Linear(hidden_size, output_size)\r\n",
"\r\n",
"    def forward(self, input, hidden, encoder_outputs):\r\n",
"        \"\"\"Run one decoding step.\r\n",
"\r\n",
"        Args:\r\n",
"            input: a single target token id.\r\n",
"            hidden: previous GRU state, shape (1, 1, hidden_size).\r\n",
"            encoder_outputs: (max_length, hidden_size) tensor.\r\n",
"\r\n",
"        Returns:\r\n",
"            ``(log_probs, hidden, attn_weights)`` with shapes\r\n",
"            (1, output_size), (1, 1, hidden_size) and (1, max_length).\r\n",
"        \"\"\"\r\n",
"        embedded = self.dropout(self.embedding(input).view(1, 1, -1))\r\n",
"\r\n",
"        query = torch.cat((embedded[0], hidden[0]), 1)    # (1, 2 * hidden)\r\n",
"        attn_weights = F.softmax(self.attn(query), dim=1)  # (1, max_length)\r\n",
"        context = torch.bmm(attn_weights.unsqueeze(0),\r\n",
"                            encoder_outputs.unsqueeze(0))[0]  # (1, hidden)\r\n",
"\r\n",
"        combined = torch.cat((embedded[0], context), 1)\r\n",
"        rnn_input = F.relu(self.attn_combine(combined).unsqueeze(0))\r\n",
"        rnn_output, hidden = self.gru(rnn_input, hidden)\r\n",
"\r\n",
"        log_probs = F.log_softmax(self.out(rnn_output[0]), dim=1)\r\n",
"        return log_probs, hidden, attn_weights\r\n",
"\r\n",
"    def initHidden(self):\r\n",
"        \"\"\"All-zero initial state of shape (1, 1, hidden_size) on ``device``.\"\"\"\r\n",
"        return torch.zeros(1, 1, self.hidden_size, device=device)"
],
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LuTQI2G3MEpk"
},
"source": [
"# Ids 0 and 1 are the SOS/EOS markers (see Lang.__init__); real words start here.\r\n",
"_FIRST_WORD_ID = 2\r\n",
"\r\n",
"\r\n",
"def indexesFromSentence(lang, sentence):\r\n",
"    \"\"\"Map each space-separated word of ``sentence`` to its id in ``lang``.\r\n",
"\r\n",
"    A word missing from the vocabulary is replaced by the id of a random\r\n",
"    known word, drawn from [_FIRST_WORD_ID, lang.n_words). The previous range,\r\n",
"    [0, len(lang.word2index)), could inject SOS/EOS mid-sentence and never\r\n",
"    picked the two highest word ids.\r\n",
"    \"\"\"\r\n",
"    res = []\r\n",
"    for word in sentence.split(' '):\r\n",
"        if word in lang.word2index:\r\n",
"            res.append(lang.word2index[word])\r\n",
"        else:\r\n",
"            # NOTE(review): a dedicated UNK token would be more principled than\r\n",
"            # random substitution, but it would change the vocabulary layout.\r\n",
"            res.append(random.randrange(_FIRST_WORD_ID, lang.n_words))\r\n",
"    return res\r\n",
"\r\n",
"\r\n",
"def tensorFromSentence(lang, sentence):\r\n",
"    \"\"\"Word ids of ``sentence`` plus EOS, as an (n_tokens + 1, 1) long tensor on ``device``.\"\"\"\r\n",
"    indexes = indexesFromSentence(lang, sentence)\r\n",
"    indexes.append(EOS_token)\r\n",
"    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)\r\n",
"\r\n",
"\r\n",
"def tensorsFromPair(pair):\r\n",
"    \"\"\"``(input_tensor, target_tensor)`` for a ``[source, target]`` pair.\r\n",
"\r\n",
"    Uses the notebook-level ``input_lang`` / ``output_lang`` vocabularies.\r\n",
"    \"\"\"\r\n",
"    input_tensor = tensorFromSentence(input_lang, pair[0])\r\n",
"    target_tensor = tensorFromSentence(output_lang, pair[1])\r\n",
"    return (input_tensor, target_tensor)"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XRaHJN_5MGzk"
},
"source": [
"# Probability that a training step feeds the decoder the ground-truth tokens.\r\n",
"teacher_forcing_ratio = 0.5\r\n",
"\r\n",
"\r\n",
"def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):\r\n",
"    \"\"\"Run one optimisation step on a single sentence pair.\r\n",
"\r\n",
"    The input is encoded token by token into a (max_length, hidden_size)\r\n",
"    buffer. With probability ``teacher_forcing_ratio`` the decoder is then fed\r\n",
"    the target tokens; otherwise it is fed its own greedy predictions and\r\n",
"    stops as soon as it emits EOS.\r\n",
"\r\n",
"    Args:\r\n",
"        input_tensor, target_tensor: token-id tensors indexed one step at a\r\n",
"            time, e.g. from ``tensorFromSentence``; the input may hold at\r\n",
"            most ``max_length`` ids.\r\n",
"        encoder, decoder: the models being trained.\r\n",
"        encoder_optimizer, decoder_optimizer: optimisers for their parameters.\r\n",
"        criterion: loss applied to every decoder step.\r\n",
"        max_length: number of rows in the encoder-output buffer.\r\n",
"\r\n",
"    Returns:\r\n",
"        float: summed loss divided by the number of decoder steps actually\r\n",
"        run. Dividing by the full target length instead under-reports the\r\n",
"        loss whenever free-running decoding stops early at EOS.\r\n",
"\r\n",
"    Raises:\r\n",
"        ValueError: if the input is longer than ``max_length``.\r\n",
"    \"\"\"\r\n",
"    input_length = input_tensor.size(0)\r\n",
"    target_length = target_tensor.size(0)\r\n",
"    if input_length > max_length:\r\n",
"        raise ValueError(\r\n",
"            \"input has %d tokens but the encoder buffer holds only %d\"\r\n",
"            % (input_length, max_length))\r\n",
"\r\n",
"    encoder_hidden = encoder.initHidden()\r\n",
"    encoder_optimizer.zero_grad()\r\n",
"    decoder_optimizer.zero_grad()\r\n",
"\r\n",
"    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\r\n",
"    for ei in range(input_length):\r\n",
"        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\r\n",
"        encoder_outputs[ei] = encoder_output[0, 0]\r\n",
"\r\n",
"    decoder_input = torch.tensor([[SOS_token]], device=device)\r\n",
"    decoder_hidden = encoder_hidden\r\n",
"    use_teacher_forcing = random.random() < teacher_forcing_ratio\r\n",
"\r\n",
"    loss = 0\r\n",
"    steps = 0\r\n",
"    for di in range(target_length):\r\n",
"        decoder_output, decoder_hidden, _ = decoder(\r\n",
"            decoder_input, decoder_hidden, encoder_outputs)\r\n",
"        loss += criterion(decoder_output, target_tensor[di])\r\n",
"        steps += 1\r\n",
"        if use_teacher_forcing:\r\n",
"            decoder_input = target_tensor[di]  # feed the ground-truth token\r\n",
"        else:\r\n",
"            # Feed back the greedy prediction; detach it so it acts as plain data.\r\n",
"            decoder_input = decoder_output.topk(1)[1].squeeze().detach()\r\n",
"            if decoder_input.item() == EOS_token:\r\n",
"                break\r\n",
"\r\n",
"    loss.backward()\r\n",
"    encoder_optimizer.step()\r\n",
"    decoder_optimizer.step()\r\n",
"\r\n",
"    return loss.item() / steps"
],
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "O67rLt62MJST"
},
"source": [
"import time\r\n",
"import math\r\n",
"\r\n",
"\r\n",
"def asMinutes(s):\r\n",
"    \"\"\"Format a duration of ``s`` seconds as 'Xm Ys' (whole minutes and seconds).\"\"\"\r\n",
"    minutes = math.floor(s / 60)\r\n",
"    seconds = s - minutes * 60\r\n",
"    return '%dm %ds' % (minutes, seconds)\r\n",
"\r\n",
"\r\n",
"def timeSince(since, percent):\r\n",
"    \"\"\"Elapsed time since ``since`` and a linear estimate of the time remaining.\r\n",
"\r\n",
"    ``percent`` is the completed fraction of the work (a value in (0, 1]).\r\n",
"    Returns a string of the form 'elapsed (- remaining)'.\r\n",
"    \"\"\"\r\n",
"    elapsed = time.time() - since\r\n",
"    remaining = elapsed / percent - elapsed\r\n",
"    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gRBh9zz-MLjh"
},
"source": [
"def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):\r\n",
"    \"\"\"Train ``encoder``/``decoder`` with plain SGD on randomly drawn pairs.\r\n",
"\r\n",
"    Each of the ``n_iters`` iterations samples one pair, with replacement,\r\n",
"    from the notebook-level ``pairs`` and calls ``train`` on it. Every\r\n",
"    ``print_every`` iterations the elapsed and remaining time and the mean\r\n",
"    loss are printed. Every ``plot_every`` iterations the mean loss is\r\n",
"    appended to the history.\r\n",
"\r\n",
"    Returns:\r\n",
"        list of float: mean loss per ``plot_every`` window. The history used to\r\n",
"        be passed to ``showPlot``, which is not defined in this notebook, so\r\n",
"        every run ended in a NameError; plot the returned list instead.\r\n",
"    \"\"\"\r\n",
"    start = time.time()\r\n",
"    plot_losses = []\r\n",
"    print_loss_total = 0  # Reset every print_every\r\n",
"    plot_loss_total = 0  # Reset every plot_every\r\n",
"\r\n",
"    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\r\n",
"    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\r\n",
"    criterion = nn.NLLLoss()\r\n",
"\r\n",
"    for it in range(1, n_iters + 1):\r\n",
"        # Tensorise one pair per step instead of building all n_iters up front.\r\n",
"        input_tensor, target_tensor = tensorsFromPair(random.choice(pairs))\r\n",
"\r\n",
"        loss = train(input_tensor, target_tensor, encoder,\r\n",
"                     decoder, encoder_optimizer, decoder_optimizer, criterion)\r\n",
"        print_loss_total += loss\r\n",
"        plot_loss_total += loss\r\n",
"\r\n",
"        if it % print_every == 0:\r\n",
"            print_loss_avg = print_loss_total / print_every\r\n",
"            print_loss_total = 0\r\n",
"            print('%s (%d %d%%) %.4f' % (timeSince(start, it / n_iters),\r\n",
"                                         it, it / n_iters * 100, print_loss_avg))\r\n",
"\r\n",
"        if it % plot_every == 0:\r\n",
"            plot_losses.append(plot_loss_total / plot_every)\r\n",
"            plot_loss_total = 0\r\n",
"\r\n",
"    return plot_losses"
],
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Qqkc5IsEMOfW"
},
"source": [
"def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):\r\n",
"    \"\"\"Greedily translate ``sentence``; return the words and the attention weights.\r\n",
"\r\n",
"    ``sentence`` goes through ``normalizeString`` first, like the training\r\n",
"    pairs in ``readLangs``. Raw input (capitals, diacritics, punctuation\r\n",
"    attached to words) otherwise maps almost entirely to out-of-vocabulary\r\n",
"    ids. Both models are switched to eval mode for the call, so dropout\r\n",
"    layers such as AttnDecoderRNN's are inactive, and their previous mode is\r\n",
"    restored afterwards. Inputs longer than ``max_length`` tokens (EOS\r\n",
"    included) are truncated to fit the encoder buffer instead of raising\r\n",
"    IndexError.\r\n",
"\r\n",
"    Returns:\r\n",
"        ``(decoded_words, attentions)``: the predicted target words, ending\r\n",
"        in '<EOS>' if EOS was produced within ``max_length`` steps, and a\r\n",
"        (steps, max_length) tensor of attention weights.\r\n",
"    \"\"\"\r\n",
"    was_training = (encoder.training, decoder.training)\r\n",
"    encoder.eval()\r\n",
"    decoder.eval()\r\n",
"    try:\r\n",
"        with torch.no_grad():\r\n",
"            input_tensor = tensorFromSentence(input_lang, normalizeString(sentence))\r\n",
"            if input_tensor.size(0) > max_length:\r\n",
"                # Keep the first max_length - 1 word ids plus the final EOS.\r\n",
"                input_tensor = torch.cat(\r\n",
"                    (input_tensor[:max_length - 1], input_tensor[-1:]))\r\n",
"            input_length = input_tensor.size(0)\r\n",
"            encoder_hidden = encoder.initHidden()\r\n",
"\r\n",
"            encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\r\n",
"            for ei in range(input_length):\r\n",
"                encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)\r\n",
"                encoder_outputs[ei] = encoder_output[0, 0]\r\n",
"\r\n",
"            decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS\r\n",
"            decoder_hidden = encoder_hidden\r\n",
"\r\n",
"            decoded_words = []\r\n",
"            decoder_attentions = torch.zeros(max_length, max_length)\r\n",
"\r\n",
"            for di in range(max_length):\r\n",
"                decoder_output, decoder_hidden, decoder_attention = decoder(\r\n",
"                    decoder_input, decoder_hidden, encoder_outputs)\r\n",
"                decoder_attentions[di] = decoder_attention.data\r\n",
"                topi = decoder_output.data.topk(1)[1]\r\n",
"                if topi.item() == EOS_token:\r\n",
"                    decoded_words.append('<EOS>')\r\n",
"                    break\r\n",
"                decoded_words.append(output_lang.index2word[topi.item()])\r\n",
"                decoder_input = topi.squeeze().detach()\r\n",
"\r\n",
"            return decoded_words, decoder_attentions[:di + 1]\r\n",
"    finally:\r\n",
"        encoder.train(was_training[0])\r\n",
"        decoder.train(was_training[1])"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "u_aEPNTQMRQc"
},
"source": [
"hidden_size = 256\r\n",
"N_ITERS = 25000\r\n",
"# One progress line every PRINT_EVERY iterations. The previous value of 20\r\n",
"# printed 1,250 lines and buried the rest of the notebook output.\r\n",
"PRINT_EVERY = 1000\r\n",
"\r\n",
"encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)\r\n",
"attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)\r\n",
"\r\n",
"# The trailing ';' keeps the cell from displaying any return value.\r\n",
"trainIters(encoder1, attn_decoder1, N_ITERS, print_every=PRINT_EVERY);"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sZH-wZjyRd9V"
},
"source": [
"evaluate(encoder1, attn_decoder1, \"Co tam u ciebie\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GgeEWwJAZAfE"
},
"source": [
"def evaluateAndShow(input_sentence):\r\n",
"    \"\"\"Translate ``input_sentence`` with ``encoder1``/``attn_decoder1``.\r\n",
"\r\n",
"    Returns the predicted words joined by single spaces; the attention\r\n",
"    weights returned by ``evaluate`` are discarded.\r\n",
"    \"\"\"\r\n",
"    words, _attentions = evaluate(encoder1, attn_decoder1, input_sentence)\r\n",
"    return \" \".join(words)"
],
"execution_count": 36,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "OvpxV5Sz19Wg",
"outputId": "758386a3-7365-4297-bd2c-809b5732ef6b"
},
"source": [
"evaluateAndShow(\"Co tam u cbie\")"
],
"execution_count": 37,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'let us be able to live up to it because these are real problems and real people and we have to deal with them now . <EOS>'"
]
},
"metadata": {
"tags": []
},
"execution_count": 37
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "tak-qrYjyjws"
},
"source": [
"# Translate every line of in.tsv and write one hypothesis per line to out.tsv.\r\n",
"with open('in.tsv', 'r', encoding='utf-8') as source_file:\r\n",
"    source_sentences = [line.rstrip('\\r\\n') for line in source_file]\r\n",
"\r\n",
"with open('out.tsv', 'w', encoding='utf-8') as hypothesis_file:\r\n",
"    for sentence in source_sentences:\r\n",
"        # Drop the end marker and the space left in front of it.\r\n",
"        translation = evaluateAndShow(sentence).replace('<EOS>', '').strip()\r\n",
"        hypothesis_file.write(translation + '\\n')"
],
"execution_count": 38,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qpuTVdo12O5y"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}