en-ner-conll-2003/lab8.ipynb
Alicja Szulecka ef0e228ed8 lab8
2024-06-02 19:45:06 +02:00

20 KiB

import csv
from collections import Counter

import pandas as pd
import torch
from sklearn.metrics import accuracy_score
from torchtext.vocab import vocab
from tqdm import tqdm
#Wczytanie zbioru danych

# Load the dataset (one whole document per TSV row).
# quoting=csv.QUOTE_NONE: the raw text contains '"' tokens. With pandas'
# default quote handling a field that starts with '"' is parsed as a quoted
# string, which can swallow tab/newline separators and corrupt or merge rows.
train_set = pd.read_csv(
    './train/train.tsv', sep='\t', header=None, names=['labels', 'text'],
    quoting=csv.QUOTE_NONE,
)
val_set = pd.read_csv(
    './dev-0/expected.tsv', sep='\t', header=None, names=['labels'],
    quoting=csv.QUOTE_NONE,
)
val_text = pd.read_csv(
    './dev-0/in.tsv', sep='\t', header=None, names=['text'],
    quoting=csv.QUOTE_NONE,
)['text']
# expected.tsv and in.tsv are paired row by row; a length mismatch would
# otherwise silently leave NaN texts (or drop texts) after index alignment.
if len(val_text) != len(val_set):
    raise ValueError(
        f'dev-0: {len(val_set)} label rows but {len(val_text)} text rows'
    )
val_set['text'] = val_text
test_set = pd.read_csv(
    './test-A/in.tsv', sep='\t', header=None, names=['text'],
    quoting=csv.QUOTE_NONE,
)
#Tokenizacja danych
# Tokenisation: split every text / label cell on whitespace into a list.
# Order matters only for readability; each (frame, column) is handled once.
for frame, column in (
    (train_set, 'text'),
    (train_set, 'labels'),
    (val_set, 'text'),
    (val_set, 'labels'),
    (test_set, 'text'),
):
    frame[column] = frame[column].apply(lambda cell: cell.split())
del frame, column

train_set.head(10)
labels text
0 [B-ORG, O, B-MISC, O, O, O, B-MISC, O, O, O, B... [EU, rejects, German, call, to, boycott, Briti...
1 [O, B-PER, O, O, O, O, O, O, O, O, O, B-LOC, O... [Rare, Hendrix, song, draft, sells, for, almos...
2 [B-LOC, O, B-LOC, O, O, O, O, O, O, B-LOC, O, ... [China, says, Taiwan, spoils, atmosphere, for,...
3 [B-LOC, O, O, O, O, B-LOC, O, O, O, B-LOC, O, ... [China, says, time, right, for, Taiwan, talks,...
4 [B-MISC, O, O, O, O, O, O, O, O, O, O, O, B-LO... [German, July, car, registrations, up, 14.2, p...
5 [B-MISC, O, O, O, O, O, O, O, O, O, O, B-LOC, ... [GREEK, SOCIALISTS, GIVE, GREEN, LIGHT, TO, PM...
6 [B-ORG, O, B-MISC, O, O, O, O, O, O, B-LOC, O,... [BayerVB, sets, C$, 100, million, six-year, bo...
7 [B-ORG, O, O, O, O, O, O, O, O, O, B-LOC, O, O... [Venantius, sets, $, 300, million, January, 19...
8 [O, O, O, O, B-LOC, O, B-ORG, I-ORG, O, O, O, ... [Port, conditions, update, -, Syria, -, Lloyds...
9 [B-LOC, O, O, O, O, O, O, B-LOC, O, O, B-PER, ... [Israel, plays, down, fears, of, war, with, Sy...
#Budowanie słownika
def build_vocab(dataset):
    """Build a torchtext vocabulary from tokenised documents.

    Args:
        dataset: iterable of token lists (e.g. the tokenised ``text`` column).

    Returns:
        A torchtext vocabulary with the specials ``<unk>``, ``<pad>``,
        ``<bos>``, ``<eos>`` first, followed by the tokens in order of first
        appearance.
    """
    # Counter(iterable) inserts keys in first-seen order, exactly like
    # calling update() document by document.
    token_counts = Counter(
        token for document in dataset for token in document
    )
    special_tokens = ["<unk>", "<pad>", "<bos>", "<eos>"]
    return vocab(token_counts, specials=special_tokens)
    
# Vocabulary built from the training texts only.
v = build_vocab(train_set['text'])
# Index -> token table; the special tokens occupy the first positions.
itos = v.get_itos()
# Out-of-vocabulary lookups (e.g. dev/test words unseen in training) map to
# the "<unk>" index instead of raising.
v.set_default_index(v["<unk>"])

itos[:10]
['<unk>',
 '<pad>',
 '<bos>',
 '<eos>',
 'EU',
 'rejects',
 'German',
 'call',
 'to',
 'boycott']
def data_process(dt, vocabulary=None):
    """Convert tokenised documents into tensors of vocabulary indices.

    Each document is wrapped with the ``<bos>`` / ``<eos>`` indices, so every
    resulting tensor is two elements longer than its document.

    Args:
        dt: iterable of token lists.
        vocabulary: token -> index mapping supporting ``vocabulary[token]``.
            Defaults to the module-level vocabulary ``v``.

    Returns:
        A list of 1-D ``torch.long`` tensors, one per document.
    """
    if vocabulary is None:
        vocabulary = v
    # The boundary ids are the same for every document; look them up once.
    bos = vocabulary["<bos>"]
    eos = vocabulary["<eos>"]
    return [
        torch.tensor(
            [bos] + [vocabulary[token] for token in document] + [eos],
            dtype=torch.long,
        )
        for document in dt
    ]

def labels_process(dt):
    """Convert per-document lists of tag ids into ``torch.long`` tensors.

    A 0 label is added at both ends so the labels line up with the
    ``<bos>`` / ``<eos>`` positions that ``data_process`` adds; index 0 is
    the ``"O"`` tag in ``num_tags``.
    """
    encoded = []
    for label_ids in dt:
        padded = [0] + label_ids + [0]
        encoded.append(torch.tensor(padded, dtype=torch.long))
    return encoded
#Różne tagi NER
# NER tag -> class id. Insertion order is significant: later code recovers the
# tag for an id with list(num_tags.keys())[id].
num_tags = {
    tag: index
    for index, tag in enumerate([
        "O",
        "B-PER", "I-PER",
        "B-ORG", "I-ORG",
        "B-LOC", "I-LOC",
        "B-MISC", "I-MISC",
    ])
}
def covert_to_int(dt, tags):
    """Map every tag string of every document to its integer id.

    (The name is a typo of "convert"; it is kept so existing callers work.)

    Args:
        dt: iterable of per-document tag-string lists.
        tags: tag string -> id mapping.

    Returns:
        A list of per-document lists of ids.

    Raises:
        KeyError: if a tag is missing from ``tags``.
    """
    return [[tags[tag] for tag in document] for document in dt]
# Vectorise every split: token ids get <bos>/<eos> added and labels get
# matching 0 ("O") padding (see data_process / labels_process).
train_tokens_ids = data_process(train_set['text'])
train_labels_ids = labels_process(covert_to_int(train_set['labels'], tags=num_tags))

val_tokens_ids = data_process(val_set['text'])
val_labels_ids = labels_process(covert_to_int(val_set['labels'], tags=num_tags))

# The test split has no labels. These ids feed the test-A predictions, so
# they must come from test_set, not train_set.
test_tokens_ids = data_process(test_set['text'])
class LSTM(torch.nn.Module):
    """Single-layer LSTM tagger: embedding -> LSTM -> per-token linear scores.

    Args:
        num_tags: number of output tag classes.
        vocab_size: size of the embedding table; defaults to the size of the
            module-level vocabulary ``v``.
        embedding_dim: token embedding size (default 100).
        hidden_dim: LSTM hidden size (default 256).
    """

    def __init__(self, num_tags, vocab_size=None, embedding_dim=100, hidden_dim=256):
        super().__init__()
        if vocab_size is None:
            vocab_size = len(v.get_itos())
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rec = torch.nn.LSTM(embedding_dim, hidden_dim, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)

    def forward(self, x):
        """Return unnormalised tag scores for every token.

        Args:
            x: ``torch.long`` token ids of shape ``(batch, seq_len)``
                (the LSTM is ``batch_first``).

        Returns:
            Tensor of shape ``(batch, seq_len, num_tags)``.
        """
        # NOTE(review): ReLU on the embeddings zeroes every negative component;
        # kept to preserve the trained behaviour, but it is unusual.
        emb = torch.relu(self.emb(x))
        lstm_output, _ = self.rec(emb)  # final (h_n, c_n) states are not needed
        return self.fc1(lstm_output)
# Training hyper-parameters.
EPOCHS = 10
LR = 1e-3
NUM_TAGS = len(num_tags)  # one output class per entry of num_tags

# Model, Adam optimiser and per-token cross-entropy loss.
model = LSTM(num_tags=NUM_TAGS)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
criterion = torch.nn.CrossEntropyLoss()
def get_scores(y_true, y_pred):
    """Compute token-level precision, recall, F1 and accuracy for tag ids.

    Tag id 0 (the "O" tag in ``num_tags``) means "no entity". A prediction
    counts as selected when it is > 0, a gold tag counts as relevant when it
    is > 0, and a true positive is a selected prediction equal to the gold
    tag. These are per-token scores, not entity-level (span) scores.

    Args:
        y_true: sequence of gold tag ids (ints, numpy scalars or
            one-element tensors).
        y_pred: sequence of predicted tag ids, same length as ``y_true``.

    Returns:
        ``(precision, recall, f1, accuracy)``. Precision / recall are 1.0 when
        nothing was selected / relevant; accuracy is 1.0 for empty input.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    n = len(y_true)
    if n != len(y_pred):
        raise ValueError(
            f'y_true and y_pred differ in length: {n} != {len(y_pred)}'
        )

    correct = tp = selected_items = relevant_items = 0
    for p, t in zip(y_pred, y_true):
        # int() normalises one-element tensors and numpy scalars to plain ints.
        p, t = int(p), int(t)
        if p == t:
            correct += 1
            if p > 0:
                tp += 1
        if p > 0:
            selected_items += 1
        if t > 0:
            relevant_items += 1

    precision = tp / selected_items if selected_items else 1.0
    recall = tp / relevant_items if relevant_items else 1.0
    if precision + recall == 0.0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    acc = correct / n if n else 1.0
    return precision, recall, f1, acc
def eval_model(dataset_tokens, dataset_labels, model):
    """Tag every document of a dataset and print token-level scores.

    Element ``i`` of ``dataset_tokens`` is a 1-D tensor of token ids and
    element ``i`` of ``dataset_labels`` the 1-D tensor of tag ids expected to
    have the same length (both include the <bos>/<eos> positions).
    Documents are scored one at a time (batch size 1).

    Returns:
        The ``(precision, recall, f1, accuracy)`` tuple from ``get_scores``.
    """
    y_true = []
    y_pred = []
    # Inference only: no_grad avoids building an autograd graph per document.
    with torch.no_grad():
        for i in tqdm(range(len(dataset_labels))):
            batch_tokens = dataset_tokens[i].unsqueeze(0)  # (1, seq_len)
            y_true.extend(dataset_labels[i].tolist())

            scores = model(batch_tokens).squeeze(0)  # (seq_len, num_tags)
            y_pred.extend(torch.argmax(scores, 1).tolist())

    precision, recall, f1, acc = get_scores(y_true, y_pred)
    print(f'precision: {precision}, recall: {recall}, f1: {f1}, val accuracy: {acc}')
    return precision, recall, f1, acc
# Training loop: per-document (batch size 1) Adam updates, followed by an
# evaluation on the dev set after every epoch.
NUM_EPOCHS = EPOCHS  # single source of truth, defined with the other hyper-parameters
for epoch in range(NUM_EPOCHS):
    model.train()
    train_true = []
    train_pred = []
    for i in tqdm(range(len(train_set['labels']))):
        batch_tokens = train_tokens_ids[i].unsqueeze(0)  # (1, seq_len)
        tags = train_labels_ids[i]                       # (seq_len,)

        # One forward pass serves both the loss and the running train accuracy.
        predicted_tags = model(batch_tokens).squeeze(0)  # (seq_len, num_tags)
        train_true.extend(tags.tolist())
        train_pred.extend(torch.argmax(predicted_tags, 1).tolist())

        optimizer.zero_grad()
        loss = criterion(predicted_tags, tags)
        loss.backward()
        optimizer.step()

    model.eval()
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    eval_model(val_tokens_ids, val_labels_ids, model)
    print(f'Train accuracy: {accuracy_score(train_true, train_pred)}')
100%|██████████| 945/945 [02:21<00:00,  6.68it/s]
100%|██████████| 215/215 [00:02<00:00, 93.42it/s] 
precision: 0.8434014196726061, recall: 0.6783966441388953, f1: 0.7519535033903778, val accuracy: 0.9457621556580554
Train accuracy: 0.9983919167623316
100%|██████████| 945/945 [02:14<00:00,  7.04it/s]
100%|██████████| 215/215 [00:02<00:00, 89.87it/s] 
precision: 0.8440340076223981, recall: 0.6709391750174785, f1: 0.7475980264866269, val accuracy: 0.9454522251189587
Train accuracy: 0.9989522403833889
100%|██████████| 945/945 [02:08<00:00,  7.34it/s]
100%|██████████| 215/215 [00:02<00:00, 90.39it/s] 
precision: 0.852653120888759, recall: 0.6796783966441389, f1: 0.7564027750761848, val accuracy: 0.9472206523126288
Train accuracy: 0.9993303449406877
100%|██████████| 945/945 [02:18<00:00,  6.85it/s]
100%|██████████| 215/215 [00:03<00:00, 66.03it/s]
precision: 0.8375809935205184, recall: 0.6778140293637847, f1: 0.7492754556578862, val accuracy: 0.9455980747844159
Train accuracy: 0.9991891251662749
100%|██████████| 945/945 [02:07<00:00,  7.39it/s]
100%|██████████| 215/215 [00:03<00:00, 62.19it/s]
precision: 0.8413109098749461, recall: 0.6820088557445817, f1: 0.7533303301370746, val accuracy: 0.9462908606953383
Train accuracy: 0.9991435704003353
100%|██████████| 945/945 [02:13<00:00,  7.08it/s]
100%|██████████| 215/215 [00:03<00:00, 54.46it/s]
precision: 0.8479315263908702, recall: 0.6926124446515963, f1: 0.7624422780913289, val accuracy: 0.9478769758071868
Train accuracy: 0.998583246779278
100%|██████████| 945/945 [02:11<00:00,  7.20it/s]
100%|██████████| 215/215 [00:03<00:00, 57.36it/s]
precision: 0.8470877294406706, recall: 0.6829410393847588, f1: 0.7562092768208503, val accuracy: 0.9471294962717179
Train accuracy: 0.999180014213087
100%|██████████| 945/945 [02:15<00:00,  6.99it/s]
100%|██████████| 215/215 [00:04<00:00, 45.79it/s]
precision: 0.8728230645397337, recall: 0.6949429037520392, f1: 0.7737917612714889, val accuracy: 0.9498824087072251
Train accuracy: 0.9993212339874997
100%|██████████| 945/945 [02:38<00:00,  5.98it/s]
100%|██████████| 215/215 [00:07<00:00, 28.22it/s]
precision: 0.8691318792431028, recall: 0.7011186203682125, f1: 0.7761367300870687, val accuracy: 0.9505934258263296
Train accuracy: 0.9996310063958891
100%|██████████| 945/945 [02:03<00:00,  7.62it/s]
100%|██████████| 215/215 [00:02<00:00, 77.13it/s] 
precision: 0.8701146047605054, recall: 0.6900489396411092, f1: 0.7696906680530282, val accuracy: 0.949116697963574
Train accuracy: 0.9997540042639261
def save_prediction(test_tokens, test_pred, file_name, tags=None):
    """Write predictions as "token<TAB>tag" lines, blank line after each document.

    Args:
        test_tokens: per-document sequences of tokens (written via str()).
        test_pred: per-document sequences of tag ids aligned with ``test_tokens``.
        file_name: output path (overwritten).
        tags: tag -> id mapping whose insertion order defines id -> tag;
            defaults to the module-level ``num_tags``.
    """
    if tags is None:
        tags = num_tags
    id_to_tag = list(tags.keys())  # built once, not once per token
    with open(file_name, 'w', encoding='utf-8') as f:
        for i, tokens in enumerate(test_tokens):
            preds = test_pred[i]
            for j, token in enumerate(tokens):
                f.write(f'{token}\t{id_to_tag[preds[j]]}\n')
            f.write('\n')
# Predict tag ids for every test document. All predictions are concatenated
# into one flat list, including each document's <bos>/<eos> positions.
test_pred = []
with torch.no_grad():
    for token_ids in test_tokens_ids:
        scores = model(token_ids.unsqueeze(0)).squeeze(0)  # (seq_len, num_tags)
        test_pred.extend(scores.argmax(dim=1).numpy())
0
# Write test-A predictions in the same layout as dev-0/expected.tsv: one line
# per input document with its tags separated by spaces.
# test_pred is flat (all documents concatenated, each including its
# <bos>/<eos> positions), so it is re-split using the document lengths in
# test_tokens_ids and the two boundary positions are dropped.
if sum(len(token_ids) for token_ids in test_tokens_ids) != len(test_pred):
    raise ValueError('test_pred does not match the lengths of test_tokens_ids')

id_to_tag = list(num_tags.keys())
offset = 0
with open('test-A/out.tsv', 'w', encoding='utf-8') as f:
    for token_ids in test_tokens_ids:
        doc_len = len(token_ids)
        doc_pred = test_pred[offset + 1:offset + doc_len - 1]
        offset += doc_len
        f.write(' '.join(id_to_tag[int(tag_id)] for tag_id in doc_pred))
        f.write('\n')