Paweł Skórzewski ccd30390b3 3_RNN.ipynb
2024-05-10 14:56:49 +02:00

Deep learning: text processing, lab classes

3. RNN

The softmax approach with embeddings, using NER as an example

!pip install torch torchtext  # load_dataset below also requires the datasets package: pip install datasets
Defaulting to user installation because normal site-packages is not writeable
Requirement already satisfied: torch in /home/pawel/.local/lib/python3.10/site-packages (2.3.0)
Collecting torchtext
  Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl (2.0 MB)
Requirement already satisfied: filelock in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.13.1)
Requirement already satisfied: fsspec in /home/pawel/.local/lib/python3.10/site-packages (from torch) (2024.2.0)
Requirement already satisfied: sympy in /home/pawel/.local/lib/python3.10/site-packages (from torch) (1.12)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)
Requirement already satisfied: typing-extensions>=4.8.0 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (4.10.0)
Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (8.9.2.26)
Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (10.3.2.106)
Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.0.106)
Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.105)
Requirement already satisfied: triton==2.3.0 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (2.3.0)
Requirement already satisfied: jinja2 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.1.3)
Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (2.20.5)
Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (12.1.3.1)
Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (11.0.2.54)
Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/pawel/.local/lib/python3.10/site-packages (from torch) (11.4.5.107)
Requirement already satisfied: networkx in /home/pawel/.local/lib/python3.10/site-packages (from torch) (3.3)
Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/pawel/.local/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.4.127)
Requirement already satisfied: requests in /home/pawel/.local/lib/python3.10/site-packages (from torchtext) (2.31.0)
Requirement already satisfied: numpy in /home/pawel/.local/lib/python3.10/site-packages (from torchtext) (1.26.4)
Requirement already satisfied: tqdm in /home/pawel/.local/lib/python3.10/site-packages (from torchtext) (4.66.2)
Requirement already satisfied: MarkupSafe>=2.0 in /home/pawel/.local/lib/python3.10/site-packages (from jinja2->torch) (2.1.5)
Requirement already satisfied: certifi>=2017.4.17 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (2024.2.2)
Requirement already satisfied: idna<4,>=2.5 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (3.6)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (3.3.2)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/pawel/.local/lib/python3.10/site-packages (from requests->torchtext) (2.2.1)
Requirement already satisfied: mpmath>=0.19 in /home/pawel/.local/lib/python3.10/site-packages (from sympy->torch) (1.3.0)
Installing collected packages: torchtext
Successfully installed torchtext-0.18.0
from collections import Counter

import torch
from datasets import load_dataset
from torchtext.vocab import vocab
from tqdm.notebook import tqdm

We load the conll2003 dataset (https://huggingface.co/datasets/conll2003), whose texts are annotated with part-of-speech (POS) tags, syntactic chunk tags and named-entity (NER) tags; this notebook uses the NER tags:

dataset = load_dataset("conll2003")

Below is a function that builds a vocabulary (https://pytorch.org/text/stable/vocab.html).

The specials parameter defines the special symbols:

  • <unk> unknown token
  • <pad> padding
  • <bos> beginning of sentence
  • <eos> end of sentence
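def build_vocab(dataset):
    # Count token occurrences over all documents (each document is a list of tokens)
    counter = Counter()
    for document in dataset:
        counter.update(document)
    # Build a torchtext Vocab; the special tokens get the first indices (0-3)
    return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])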
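To make the index assignment concrete, here is a minimal pure-Python sketch of what build_vocab produces. It assumes, as in torchtext's vocab() function, that the special tokens come first and the remaining tokens keep the order in which they were first counted (no sorting by frequency). The names build_simple_vocab and lookup are hypothetical helpers, not part of torchtext.

```python
from collections import Counter

SPECIALS = ['<unk>', '<pad>', '<bos>', '<eos>']


def build_simple_vocab(documents, specials=SPECIALS):
    """Hypothetical stand-in for torchtext's vocab(): returns (stoi, itos)."""
    counter = Counter()
    for document in documents:
        counter.update(document)  # Counter keeps first-occurrence order
    # Special tokens first (indices 0..3), then ordinary tokens in insertion order
    itos = list(specials) + [tok for tok in counter if tok not in specials]
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return stoi, itos


def lookup(stoi, token, default_index=0):
    # Out-of-vocabulary tokens map to the default index (here: <unk> = 0)
    return stoi.get(token, default_index)


stoi, itos = build_simple_vocab([['EU', 'rejects', 'German', 'call'], ['German', 'lamb']])
print(itos)                    # ['<unk>', '<pad>', '<bos>', '<eos>', 'EU', 'rejects', 'German', 'call', 'lamb']
print(lookup(stoi, 'German'))  # 6
print(lookup(stoi, 'Poland'))  # 0
```

This ordering is also why train_tokens_ids[0] further down starts with 2 (<bos>) followed by 4, 5, 6, ...: the first training sentence contributes the first ordinary tokens.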
v = build_vocab(dataset['train']['tokens'])
itos = v.get_itos()  # mapping from indices to tokens
len(itos)  # number of distinct tokens in the vocabulary
23627
v['on']  # index of the token `on`
21
v["<unk>"]  # index of the unknown token
0

If the analyzed text contains a token that is not in the vocabulary, it is represented by the default index. Without a default index, torchtext raises an error when such a token is looked up, so we set the default index to the index of the unknown token <unk>:

v.set_default_index(v["<unk>"])
def data_process(dt):
    # Vectorize text documents: map each token to its vocabulary index
    # and wrap the sequence in <bos> ... <eos>.
    return [
        torch.tensor([v['<bos>']] + [v[token] for token in document] + [v['<eos>']], dtype=torch.long)
        for document in dt
    ]
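def labels_process(dt):
    # Vectorize labels (NER tags); the <bos> and <eos> positions get tag 0 (O)
    # so that the labels stay aligned with the token indices.
    return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]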
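The key property of these two functions is that tokens and labels stay aligned: both sequences gain one extra position at each end. A minimal pure-Python sketch of the same transformation, using plain lists instead of tensors and a hypothetical toy dictionary toy_stoi in place of the torchtext vocabulary (vectorize_example is likewise a hypothetical helper):

```python
BOS, EOS, UNK = 2, 3, 0  # same indices as <bos>, <eos>, <unk> in the vocabulary above


def vectorize_example(tokens, tags, stoi):
    """Hypothetical helper: returns (token_ids, tag_ids) wrapped in <bos>/<eos>."""
    token_ids = [BOS] + [stoi.get(tok, UNK) for tok in tokens] + [EOS]
    tag_ids = [0] + list(tags) + [0]  # the boundary positions are labelled O (0)
    return token_ids, tag_ids


toy_stoi = {'EU': 4, 'rejects': 5, 'German': 6}  # hypothetical toy vocabulary
ids, tags = vectorize_example(['EU', 'rejects', 'German', 'lamb'], [3, 0, 7, 0], toy_stoi)
print(ids)   # [2, 4, 5, 6, 0, 3]  ('lamb' is unknown here, so it maps to 0)
print(tags)  # [0, 3, 0, 7, 0, 0]
```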

Now we vectorize all the data:

train_tokens_ids = data_process(dataset['train']['tokens'])
test_tokens_ids = data_process(dataset['test']['tokens'])
validation_tokens_ids = data_process(dataset['validation']['tokens'])
train_labels = labels_process(dataset['train']['ner_tags'])
validation_labels = labels_process(dataset['validation']['ner_tags'])
test_labels = labels_process(dataset['test']['ner_tags'])

An example of what the data look like after vectorization:

train_tokens_ids[0]
tensor([ 2,  4,  5,  6,  7,  8,  9, 10, 11, 12,  3])
dataset['train'][0]
{'id': '0',
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.'],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}
train_labels[0]
tensor([0, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0])
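The ner_tags are integer class ids. In the Hugging Face conll2003 dataset they correspond to the BIO label names below; in the notebook they can be read from dataset['train'].features['ner_tags'].feature.names, but here the list is written out by hand so the sketch runs without downloading the data. Decoding the first training sentence (decode_tags is a hypothetical helper):

```python
# BIO label names of conll2003 ner_tags, in class-id order (copied by hand)
NER_LABELS = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']


def decode_tags(tokens, tag_ids, labels=NER_LABELS):
    # Pair every token with the name of its NER class
    return [(tok, labels[t]) for tok, t in zip(tokens, tag_ids)]


tokens = ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
ner_tags = [3, 0, 7, 0, 0, 0, 7, 0, 0]
for tok, label in decode_tags(tokens, ner_tags):
    print(tok, label)
# EU is B-ORG, German and British are B-MISC, all other tokens are O
```

The nine label names also explain why num_tags below comes out as 9.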

The function we will use for evaluation:

def get_scores(y_true, y_pred):
    # Returns token-level precision, recall and F1. A token counts as "selected"
    # if the predicted tag is not O (0) and as "relevant" if the gold tag is not O.
    # Note: this is not entity-level (conlleval-style) scoring.
    acc_score = 0  # exact matches (token accuracy numerator); not returned
    tp = 0
    selected_items = 0
    relevant_items = 0

    for p, t in zip(y_pred, y_true):
        if p == t:
            acc_score += 1

        if p > 0 and p == t:
            tp += 1

        if p > 0:
            selected_items += 1

        if t > 0:
            relevant_items += 1

    if selected_items == 0:
        precision = 1.0
    else:
        precision = tp / selected_items

    if relevant_items == 0:
        recall = 1.0
    else:
        recall = tp / relevant_items

    if precision + recall == 0.0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    return precision, recall, f1
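get_scores counts exact matches in acc_score but does not return them, while Task 3 below is graded on accuracy. A minimal sketch of token-level accuracy over the same flat lists of tag ids (token_accuracy is a hypothetical helper; note that in eval_model these lists also contain the <bos>/<eos> positions, which are always labelled 0):

```python
def token_accuracy(y_true, y_pred):
    """Fraction of positions where the predicted tag id equals the gold tag id."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        return 1.0  # convention chosen here: empty input counts as perfect
    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return matches / len(y_true)


y_true = [0, 3, 0, 7, 0]
y_pred = [0, 3, 0, 0, 7]
print(token_accuracy(y_true, y_pred))  # 0.6
```

For the same input, get_scores returns (0.5, 0.5, 0.5): of the two tokens predicted as entities (positions 1 and 4) only one is correct, and of the two gold entity tokens (positions 1 and 3) only one is found. Accuracy is higher because correctly predicted O tokens also count as matches.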

How many different NER tags are there?

num_tags = max([max(x) for x in dataset['train']['ner_tags'] ]) + 1 
print(num_tags)
9

Implementation of an LSTM recurrent neural network:

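class LSTM(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Token embeddings: one 100-dimensional vector per vocabulary entry
        self.emb = torch.nn.Embedding(len(v.get_itos()), 100)
        # Single-layer LSTM, input size 100, hidden size 256; batch_first=True -> (batch, seq, feature)
        self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)
        # Maps each hidden state to unnormalized scores (logits) for the num_tags (= 9) NER tags
        self.fc1 = torch.nn.Linear(256, num_tags)

    def forward(self, x):
        # x: (batch, seq_len) token indices -> (batch, seq_len, 100)
        emb = torch.relu(self.emb(x))
        # lstm_output holds the hidden state at every time step: (batch, seq_len, 256)
        lstm_output, (h_n, c_n) = self.rec(emb)
        # Per-token logits: (batch, seq_len, num_tags); softmax is applied inside the loss
        out_weights = self.fc1(lstm_output)
        return out_weights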
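As a sanity check on the architecture, the number of trainable parameters can be worked out by hand. The sketch below assumes PyTorch's standard layouts: Embedding has vocab_size × emb_dim weights; an LSTM layer has four gates, each with input-to-hidden and hidden-to-hidden weights, plus two bias vectors (bias_ih and bias_hh) of size 4 × hidden; Linear has in × out weights plus out biases. The vocabulary size 23627 comes from the output above; lstm_tagger_param_count is a hypothetical helper.

```python
def lstm_tagger_param_count(vocab_size, emb_dim=100, hidden=256, num_tags=9):
    """Parameter count of Embedding -> 1-layer LSTM -> Linear, computed by hand."""
    embedding = vocab_size * emb_dim
    # 4 gates x (W_ih: hidden x emb_dim, W_hh: hidden x hidden) + b_ih and b_hh (4*hidden each)
    lstm = 4 * hidden * (emb_dim + hidden) + 2 * 4 * hidden
    linear = hidden * num_tags + num_tags
    return {'embedding': embedding, 'lstm': lstm, 'linear': linear,
            'total': embedding + lstm + linear}


counts = lstm_tagger_param_count(23627)
print(counts)
# {'embedding': 2362700, 'lstm': 366592, 'linear': 2313, 'total': 2731605}
```

In the notebook, sum(p.numel() for p in lstm.parameters()) should give the same total. Most of the parameters sit in the embedding table, not in the recurrent layer.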

Creating the model:

lstm = LSTM()

Defining the loss function:

criterion = torch.nn.CrossEntropyLoss()
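The model's last layer returns raw, unnormalized scores (logits). CrossEntropyLoss applies log-softmax to them internally and, with the default reduction='mean', averages the negative log-probability of the correct class over all positions; this is the "softmax" part of the approach. A pure-Python sketch of that computation for a few tokens (cross_entropy is a hypothetical helper):

```python
import math


def cross_entropy(logits_per_token, targets):
    """Mean negative log-softmax probability of the target class, per token."""
    total = 0.0
    for logits, target in zip(logits_per_token, targets):
        max_logit = max(logits)  # subtract the max for numerical stability
        log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
        total += log_sum_exp - logits[target]  # equals -log(softmax(logits)[target])
    return total / len(targets)


# Two tokens, three classes
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
print(cross_entropy(logits, targets))
```

With PyTorch, torch.nn.CrossEntropyLoss()(torch.tensor(logits), torch.tensor(targets)) should give the same value, which is why the model itself does not apply softmax.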

Defining the optimizer:

optimizer = torch.optim.Adam(lstm.parameters())

A function for evaluating the model:

def eval_model(dataset_tokens, dataset_labels, model):
    Y_true = []
    Y_pred = []
    # Gradients are not needed for evaluation
    with torch.no_grad():
        for i in tqdm(range(len(dataset_labels))):
            batch_tokens = dataset_tokens[i].unsqueeze(0)  # (1, seq_len)
            tags = list(dataset_labels[i].numpy())
            Y_true += tags

            Y_batch_pred_weights = model(batch_tokens).squeeze(0)  # (seq_len, num_tags)
            Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)  # most likely tag per token
            Y_pred += list(Y_batch_pred.numpy())

    return get_scores(Y_true, Y_pred)

Training the model:

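NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    lstm.train()
    # for i in tqdm(range(500)):  # quicker run on a subset of sentences
    for i in tqdm(range(len(train_labels))):
        # One sentence per step (batch size 1): shape (1, seq_len)
        batch_tokens = train_tokens_ids[i].unsqueeze(0)
        tags = train_labels[i]  # shape (seq_len,)

        predicted_tags = lstm(batch_tokens)  # shape (1, seq_len, num_tags)

        optimizer.zero_grad()
        # CrossEntropyLoss expects (N, C) logits and (N,) class indices
        loss = criterion(predicted_tags.squeeze(0), tags)

        loss.backward()
        optimizer.step()

    lstm.eval()
    # Validation (precision, recall, F1) after each epoch
    print(eval_model(validation_tokens_ids, validation_labels, lstm))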
(0.49656056896350703, 0.4950598628385447, 0.49580908032596044)
(0.6289105835367207, 0.6589561780774148, 0.643582902877902)
(0.7031268719300348, 0.6822038823666163, 0.6925073746312684)
(0.7354687113529558, 0.6912704870394049, 0.7126850020971898)
(0.7134837896666285, 0.7239335115657329, 0.7186706669743826)
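The loop above processes one sentence per optimizer step (batch size 1), which is simple but slow. A common alternative, not used in this notebook, is to group sentences into mini-batches and pad them to equal length: token ids with the <pad> index (1 in this vocabulary) and labels with -100, the default ignore_index of CrossEntropyLoss, so that padded positions do not contribute to the loss. A pure-Python sketch of the padding step (pad_batch is a hypothetical helper):

```python
PAD_ID = 1           # index of <pad> in the vocabulary built above
IGNORE_LABEL = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_batch(token_seqs, label_seqs, pad_id=PAD_ID, ignore_label=IGNORE_LABEL):
    """Right-pad token and label sequences to the longest sequence in the batch."""
    max_len = max(len(seq) for seq in token_seqs)
    padded_tokens = [list(seq) + [pad_id] * (max_len - len(seq)) for seq in token_seqs]
    padded_labels = [list(seq) + [ignore_label] * (max_len - len(seq)) for seq in label_seqs]
    return padded_tokens, padded_labels


tokens, labels = pad_batch([[2, 4, 5, 3], [2, 6, 3]], [[0, 3, 0, 0], [0, 7, 0]])
print(tokens)  # [[2, 4, 5, 3], [2, 6, 3, 1]]
print(labels)  # [[0, 3, 0, 0], [0, 7, 0, -100]]
```

With PyTorch, the padded lists become tensors of shape (batch, max_len), and the model output of shape (batch, max_len, num_tags) is reshaped to (batch * max_len, num_tags) before computing the loss; torch.nn.utils.rnn.pad_sequence can do the padding step directly on lists of tensors. The evaluation would then also have to skip the padded positions.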

Evaluation:

eval_model(validation_tokens_ids, validation_labels, lstm)
(0.7134837896666285, 0.7239335115657329, 0.7186706669743826)
eval_model(test_tokens_ids, test_labels, lstm)
(0.6529463280370325, 0.6433678500986193, 0.6481217013349891)

Task 3

Clone the repository https://git.wmi.amu.edu.pl/kubapok/en-ner-conll-2003

Create a sequence labelling model based on any recurrent neural network (you can use the example from class as a template).

Put the predictions for dev-0/in.tsv and test-A/in.tsv in dev-0/out.tsv and test-A/out.tsv, respectively. Use the GEval tool (https://gitlab.com/filipg/geval) for evaluation:

wget https://gonito.net/get/bin/geval
chmod u+x geval
./geval --help
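When writing out.tsv, the predictions for the <bos> and <eos> positions added by data_process have to be dropped and the remaining class ids converted back to label names. A minimal sketch, assuming that each line of in.tsv holds one whitespace-tokenized sentence and that out.tsv should contain the matching BIO tags separated by single spaces, one sentence per line; check this assumption against the repository's expected.tsv files. The helper names are hypothetical.

```python
NER_LABELS = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']


def format_prediction_line(pred_ids, n_tokens, labels=NER_LABELS):
    """Turn predictions for '<bos> w1 ... wn <eos>' into 'TAG1 ... TAGn'."""
    inner = pred_ids[1:-1]  # drop the <bos> and <eos> positions
    if len(inner) != n_tokens:
        raise ValueError(f"expected {n_tokens} predictions, got {len(inner)}")
    return ' '.join(labels[i] for i in inner)


def write_out_tsv(path, predictions, sentences):
    # predictions: one list of tag ids per sentence (including <bos>/<eos>)
    # sentences: token lists read from in.tsv, used here only for length checks
    with open(path, 'w', encoding='utf-8') as f:
        for pred_ids, tokens in zip(predictions, sentences):
            f.write(format_prediction_line(pred_ids, len(tokens)) + '\n')


print(format_prediction_line([0, 3, 0, 7, 0], 3))  # B-ORG O B-MISC
```

In the notebook, pred_ids for one sentence would come from torch.argmax(model(batch_tokens).squeeze(0), 1).tolist().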

The number of points for this task depends on the accuracy achieved on the test-A set (rounded up):

points = math.ceil(accuracy * 7.0)
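A worked example of the scoring rule with a few hypothetical accuracy values. Because the result is rounded up, ceil(accuracy * 7) equals 7 exactly when accuracy * 7 > 6, so any accuracy above 6/7 (about 0.857) earns the maximum of 7 points.

```python
import math


def points(accuracy):
    # Scoring rule from the task description: points = ceil(accuracy * 7)
    return math.ceil(accuracy * 7.0)


for acc in (0.70, 0.85, 0.86, 0.95):
    print(acc, points(acc))
# 0.7 5
# 0.85 6
# 0.86 7
# 0.95 7
```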

⚠️ In Moodle, please attach the test-A/out.tsv file and a link to the repository with your solution.