aitech-moj/cw/11_Model_rekurencyjny_z_atencją.ipynb


Language Modeling

10. Recurrent model with attention [exercises]

Jakub Pokrywka (2022)


from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self):
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
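A quick usage sketch (not part of the original notebook) shows how Lang assigns indices, starting after the reserved SOS and EOS entries:

demo_lang = Lang()
demo_lang.addSentence('ala ma kota ma')
print(demo_lang.n_words)            # 5: SOS, EOS, ala, ma, kota
print(demo_lang.word2index)         # {'ala': 2, 'ma': 3, 'kota': 4}
print(demo_lang.word2count['ma'])   # 2 occurrences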
pairs = []
with open('data/eng-pol.txt') as f:
    for line in f:
        eng_line, pol_line = line.lower().rstrip().split('\t')

        eng_line = re.sub(r"([.!?])", r" \1", eng_line)
        eng_line = re.sub(r"[^a-zA-Z.!?]+", r" ", eng_line)

        pol_line = re.sub(r"([.!?])", r" \1", pol_line)
        pol_line = re.sub(r"[^a-zA-Z.!?ąćęłńóśźżĄĆĘŁŃÓŚŹŻ]+", r" ", pol_line)

        pairs.append([eng_line, pol_line])


pairs[1]
['hi .', 'cześć .']
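The two re.sub calls first insert a space before sentence-final punctuation and then collapse every other non-letter character into a single space. A small demonstration on a made-up line (a sketch, not from the notebook):

demo = "I'm fine, thanks!".lower().rstrip()
demo = re.sub(r"([.!?])", r" \1", demo)      # "i'm fine, thanks !"
demo = re.sub(r"[^a-zA-Z.!?]+", r" ", demo)  # "i m fine thanks !"
print(demo)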
MAX_LENGTH = 10
eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

pairs = [p for p in pairs if len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH]
pairs = [p for p in pairs if p[0].startswith(eng_prefixes)]

eng_lang = Lang()
pol_lang = Lang()

for pair in pairs:
    eng_lang.addSentence(pair[0])
    pol_lang.addSentence(pair[1])
pairs[0]
['i m ok .', 'ze mną wszystko w porządku .']
pairs[1]
['i m up .', 'wstałem .']
pairs[2]
['i m tom .', 'jestem tom .']
eng_lang.n_words
1828
pol_lang.n_words
2883
class EncoderRNN(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, self.embedding_size)
        self.gru = nn.GRU(self.embedding_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
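A minimal shape check of a single encoder step (a sketch that assumes the cells above have already been run):

enc = EncoderRNN(eng_lang.n_words, 200, 256).to(device)
hidden = enc.initHidden()
token = torch.tensor([eng_lang.word2index['i']], device=device)
output, hidden = enc(token, hidden)
print(output.shape, hidden.shape)   # torch.Size([1, 1, 256]) torch.Size([1, 1, 256])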
class DecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, self.embedding_size)
        self.gru = nn.GRU(self.embedding_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
class AttnDecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.embedding_size)
        self.attn = nn.Linear(self.hidden_size + self.embedding_size, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size + self.embedding_size, self.embedding_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.embedding_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
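The attention here is the same simple variant used in the PyTorch seq2seq tutorial: a linear layer scores all MAX_LENGTH encoder positions from the current embedding and hidden state, the scores are softmax-normalized, and torch.bmm uses them to form a weighted average of the encoder outputs. A shape sketch of one decoding step (assumed sizes, not from the original notebook):

dec = AttnDecoderRNN(200, 256, pol_lang.n_words).to(device)
dec.eval()  # disable dropout for the check
hidden = dec.initHidden()
encoder_outputs = torch.zeros(MAX_LENGTH, 256, device=device)
token = torch.tensor([[SOS_token]], device=device)
output, hidden, attn_weights = dec(token, hidden, encoder_outputs)
print(output.shape)        # [1, pol_lang.n_words]: log-probabilities over the Polish vocabulary
print(attn_weights.shape)  # [1, MAX_LENGTH]: one weight per (padded) encoder position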
def tensorFromSentence(sentence, lang):
    indexes = [lang.word2index[word] for word in sentence.split(' ')]
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
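For example (a sketch), a sentence from the corpus becomes a column vector of word indices terminated with EOS_token:

t = tensorFromSentence('i m ok .', eng_lang)
print(t.shape)               # torch.Size([5, 1]): four words plus EOS
print(t.squeeze(1).tolist()) # word indices, ending with EOS_token == 1

Note that a word absent from the vocabulary raises a KeyError, which is acceptable here because training and evaluation reuse sentences from pairs.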
teacher_forcing_ratio = 0.5

def train_one_batch(input_tensor, target_tensor, encoder, decoder, optimizer, criterion, max_length=MAX_LENGTH):
    encoder_hidden = encoder.initHidden()


    optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)

    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]  # Teacher forcing

    else:
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    optimizer.step()

    return loss.item() / target_length
def trainIters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):
    print_loss_total = 0  # Reset every print_every
    encoder.train()
    decoder.train()

    optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)
    
    training_pairs = [random.choice(pairs) for _ in range(n_iters)]
    training_pairs = [(tensorFromSentence(p[0], eng_lang), tensorFromSentence(p[1], pol_lang)) for p in training_pairs]
    
    criterion = nn.NLLLoss()

    for i in range(1, n_iters + 1):
        training_pair = training_pairs[i - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train_one_batch(input_tensor,
                               target_tensor,
                               encoder,
                               decoder,
                               optimizer,
                               criterion)
        
        print_loss_total += loss

        if i % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print(f'iter: {i}, loss: {print_loss_avg}')
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        input_tensor = tensorFromSentence(sentence, eng_lang)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)

        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(pol_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]
def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')
embedding_size = 200
hidden_size = 256
encoder1 = EncoderRNN(eng_lang.n_words, embedding_size, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(embedding_size, hidden_size, pol_lang.n_words, dropout_p=0.1).to(device)
trainIters(encoder1, attn_decoder1, 10_000, print_every=50)
iter: 50, loss: 5.042555550272503
iter: 100, loss: 4.143612308138894
iter: 150, loss: 4.258466395877656
iter: 200, loss: 4.078979822052849
iter: 250, loss: 3.9038650802657715
iter: 300, loss: 4.07207449336279
iter: 350, loss: 3.940484183538527
iter: 400, loss: 4.425489738524906
iter: 450, loss: 3.9398847290826224
iter: 500, loss: 4.264409653027852
iter: 550, loss: 4.323172234974209
iter: 600, loss: 4.22224827657427
iter: 650, loss: 4.204052018634857
iter: 700, loss: 3.9438682432023295
iter: 750, loss: 4.001692515509468
iter: 800, loss: 4.054982795352028
iter: 850, loss: 4.119050166281443
iter: 900, loss: 3.908679961704073
iter: 950, loss: 4.136870030266898
iter: 1000, loss: 3.8147727276938297
iter: 1050, loss: 4.026022962623171
iter: 1100, loss: 3.9598817706335154
iter: 1150, loss: 3.848097898089696
iter: 1200, loss: 4.01016833985041
iter: 1250, loss: 3.7720014858472917
iter: 1300, loss: 4.059876484976874
iter: 1350, loss: 3.8380891363658605
iter: 1400, loss: 4.013203263676356
iter: 1450, loss: 4.067137318686833
iter: 1500, loss: 4.020450985673874
iter: 1550, loss: 3.7160321428662244
iter: 1600, loss: 3.8411714478977137
iter: 1650, loss: 3.7125136051177985
iter: 1700, loss: 3.705152728769514
iter: 1750, loss: 3.9118153427441915
iter: 1800, loss: 3.857195938375262
iter: 1850, loss: 3.9566935270703025
iter: 1900, loss: 3.9394864430957375
iter: 1950, loss: 3.636212232317243
iter: 2000, loss: 3.847666795261321
iter: 2050, loss: 3.787096965411352
iter: 2100, loss: 3.4702608700933912
iter: 2150, loss: 3.727882717624543
iter: 2200, loss: 3.6961711362884153
iter: 2250, loss: 3.870331466848889
iter: 2300, loss: 3.8506508341743837
iter: 2350, loss: 3.803002176814609
iter: 2400, loss: 3.5700957290558586
iter: 2450, loss: 3.5328896935326712
iter: 2500, loss: 3.810194352997674
iter: 2550, loss: 3.713556599700262
iter: 2600, loss: 3.6131167711303345
iter: 2650, loss: 3.433012700254954
iter: 2700, loss: 3.7313271602903084
iter: 2750, loss: 3.5837062497366037
iter: 2800, loss: 3.6265894929265214
iter: 2850, loss: 3.5165250884616186
iter: 2900, loss: 3.8752988719410366
iter: 2950, loss: 3.709828086020455
iter: 3000, loss: 3.742527751090035
iter: 3050, loss: 3.5926183513232646
iter: 3100, loss: 3.6629667194003157
iter: 3150, loss: 3.7953110780715944
iter: 3200, loss: 3.4833724756770663
iter: 3250, loss: 3.5239689500066977
iter: 3300, loss: 3.552185758560423
iter: 3350, loss: 3.342997217700594
iter: 3400, loss: 3.7131163925897512
iter: 3450, loss: 3.2172264359110874
iter: 3500, loss: 3.1694674255961464
iter: 3550, loss: 3.5181667824548386
iter: 3600, loss: 3.552696303821745
iter: 3650, loss: 3.5465369727573703
iter: 3700, loss: 3.3895190108844213
iter: 3750, loss: 3.55357305569119
iter: 3800, loss: 3.618841464133489
iter: 3850, loss: 3.631707963504488
iter: 3900, loss: 3.705602922939119
iter: 3950, loss: 3.1555525365556987
iter: 4000, loss: 3.423284879676879
iter: 4050, loss: 3.74216214027859
iter: 4100, loss: 3.273874522224304
iter: 4150, loss: 3.9754231488666836
iter: 4200, loss: 3.255707532473973
iter: 4250, loss: 3.622867019956075
iter: 4300, loss: 3.3847267730198216
iter: 4350, loss: 3.6832511274095565
iter: 4400, loss: 3.265418997968946
iter: 4450, loss: 3.53306358509972
iter: 4500, loss: 3.2655868359520333
iter: 4550, loss: 3.579948601419965
iter: 4600, loss: 3.554656519799005
iter: 4650, loss: 3.324159849643708
iter: 4700, loss: 3.357913894865249
iter: 4750, loss: 3.048288846031067
iter: 4800, loss: 3.185154194937811
iter: 4850, loss: 2.9646709245159513
iter: 4900, loss: 3.4766449508288546
iter: 4950, loss: 3.1528075372302338
iter: 5000, loss: 3.12558690051427
iter: 5050, loss: 3.6565875165273276
iter: 5100, loss: 3.113538140228817
iter: 5150, loss: 3.0463946421638366
iter: 5200, loss: 3.384180574084086
iter: 5250, loss: 3.3104316232090913
iter: 5300, loss: 2.9496352179807332
iter: 5350, loss: 3.1814023027722804
iter: 5400, loss: 2.9286732437345724
iter: 5450, loss: 3.4691178646617464
iter: 5500, loss: 3.373944672122834
iter: 5550, loss: 3.213332776455653
iter: 5600, loss: 3.3247368506931116
iter: 5650, loss: 3.2702379176957272
iter: 5700, loss: 3.4554740653038025
iter: 5750, loss: 3.281306777431851
iter: 5800, loss: 2.9936736260368706
iter: 5850, loss: 3.277740831851959
iter: 5900, loss: 3.120459364088754
iter: 5950, loss: 3.387252744160001
iter: 6000, loss: 3.238504883735898
iter: 6050, loss: 2.738152531003195
iter: 6100, loss: 3.231002421265556
iter: 6150, loss: 3.0410601262819195
iter: 6200, loss: 3.093445486522856
iter: 6250, loss: 2.877119398207891
iter: 6300, loss: 3.006740029849703
iter: 6350, loss: 2.8918780979504657
iter: 6400, loss: 3.3124666434015553
iter: 6450, loss: 3.170363757602752
iter: 6500, loss: 3.1445780278387527
iter: 6550, loss: 3.0042706321610346
iter: 6600, loss: 2.94450242013023
iter: 6650, loss: 3.1747314814840046
iter: 6700, loss: 3.325715871651966
iter: 6750, loss: 3.1039765825120225
iter: 6800, loss: 3.260562201068516
iter: 6850, loss: 2.95558365320024
iter: 6900, loss: 3.1284036347071327
iter: 6950, loss: 3.161784927746607
iter: 7000, loss: 3.083566860369275
iter: 7050, loss: 3.1606678485643296
iter: 7100, loss: 3.39304134529356
iter: 7150, loss: 3.05389289476001
iter: 7200, loss: 3.171286074725408
iter: 7250, loss: 3.307133579034654
iter: 7300, loss: 2.987511603022379
iter: 7350, loss: 3.1221464098370264
iter: 7400, loss: 2.9686622249966574
iter: 7450, loss: 2.874706161885035
iter: 7500, loss: 2.759323406164608
iter: 7550, loss: 2.835318256658221
iter: 7600, loss: 2.896953154404958
iter: 7650, loss: 2.8871691599497717
iter: 7700, loss: 3.049550093332927
iter: 7750, loss: 2.9703013692507665
iter: 7800, loss: 2.8142153175671893
iter: 7850, loss: 2.8352768955987604
iter: 7900, loss: 2.863677294496506
iter: 7950, loss: 3.031682641491057
iter: 8000, loss: 2.9286883136809814
iter: 8050, loss: 2.9240697879488504
iter: 8100, loss: 3.0172221147900546
iter: 8150, loss: 2.8361169849426027
iter: 8200, loss: 2.9860127468676803
iter: 8250, loss: 2.9495567634294906
iter: 8300, loss: 2.793946119104113
iter: 8350, loss: 3.2106793221594785
iter: 8400, loss: 2.736634517018757
iter: 8450, loss: 2.8962079345536615
iter: 8500, loss: 2.906407202516283
iter: 8550, loss: 2.6900012663281148
iter: 8600, loss: 2.8905927643056897
iter: 8650, loss: 2.950769727600945
iter: 8700, loss: 2.884238138978443
iter: 8750, loss: 2.7154052526648083
iter: 8800, loss: 2.8823739119030183
iter: 8850, loss: 2.93061117755799
iter: 8900, loss: 2.658344201617771
iter: 8950, loss: 2.5747124820644887
iter: 9000, loss: 2.8281182004307954
iter: 9050, loss: 2.6702445936959895
iter: 9100, loss: 2.8030708763485865
iter: 9150, loss: 3.0742075329053966
iter: 9200, loss: 2.7834522392787635
iter: 9250, loss: 2.9308865650949025
iter: 9300, loss: 2.776913931453039
iter: 9350, loss: 2.7998796779011923
iter: 9400, loss: 3.1615792548088795
iter: 9450, loss: 3.2742855516539673
iter: 9500, loss: 2.981044085154457
iter: 9550, loss: 2.4407524968101866
iter: 9600, loss: 2.624275121037923
iter: 9650, loss: 2.4893303714971697
iter: 9700, loss: 2.7211539438906183
iter: 9750, loss: 2.8714180671828133
iter: 9800, loss: 2.7188037380396373
iter: 9850, loss: 2.4101966271173385
iter: 9900, loss: 2.9492219283542926
iter: 9950, loss: 2.547067801430112
iter: 10000, loss: 2.8521263429191372
evaluateRandomly(encoder1, attn_decoder1)
> he is a tennis player .
= on jest tenisistą .
< jest tenisistą . <EOS>

> i m not going to change my mind .
= nie zamierzam zmieniać zdania .
< nie idę do . <EOS>

> i m totally confused .
= jestem kompletnie zmieszany .
< jestem dziś . . <EOS>

> he is a pioneer in this field .
= jest pionierem w tej dziedzinie .
< on jest w w . . <EOS>

> i m so excited .
= jestem taki podekscytowany !
< jestem jestem głodny . <EOS>

> they are a party of six .
= jest ich sześć osób .
< oni nie są . . <EOS>

> he is the father of two children .
= on jest ojcem dwójki dzieci .
< on jest na do . . <EOS>

> i am leaving at four .
= wychodzę o czwartej .
< jestem na . <EOS>

> i m not much of a writer .
= pisarz ze mnie żaden .
< nie jestem mnie . . <EOS>

> you re disgusting !
= jesteś obrzydliwy !
< jesteś obrzydliwy . <EOS>
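evaluate also returns the attention weights for every generated word, so they can be visualized as a heat map over the source tokens; a minimal sketch, assuming matplotlib is available:

import matplotlib.pyplot as plt

sentence = 'i m ok .'
output_words, attentions = evaluate(encoder1, attn_decoder1, sentence)
src_tokens = sentence.split(' ') + ['<EOS>']

fig, ax = plt.subplots()
ax.matshow(attentions[:, :len(src_tokens)].numpy(), cmap='bone')
ax.set_xticks(range(len(src_tokens)))
ax.set_xticklabels(src_tokens, rotation=90)
ax.set_yticks(range(len(output_words)))
ax.set_yticklabels(output_words)
plt.show()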