### Importy

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from collections import Counter
from torchtext.vocab import vocab
from tqdm import tqdm

### Wczytanie danych

In [0]:
def read_custom_dataset(path, is_train=True):
    if is_train:
        data = pd.read_csv(path, sep='\t', header=None, compression='xz')
        data.columns = ['ner', 'document']
    else:
        with open(path, 'r') as f:
            documents = f.read().splitlines()
        data = pd.DataFrame(documents, columns=['document'])
    return data


In [0]:
train_path = 'train/train.tsv.xz'
dev_in_path = 'dev-0/in.tsv'
dev_expected_path = 'dev-0/expected.tsv'
test_in_path = 'test-A/in.tsv'

train_data = read_custom_dataset(train_path, is_train=True)
dev_data = read_custom_dataset(dev_in_path, is_train=False)
dev_labels = pd.read_csv(dev_expected_path, header=None, names=['ner'])
test_data = read_custom_dataset(test_in_path, is_train=False)

### Tokenizacja

In [0]:
def tokenize_documents(data):
    return [doc.split() for doc in data['document'].tolist()]

train_tokens = tokenize_documents(train_data)
dev_tokens = tokenize_documents(dev_data)
test_tokens = tokenize_documents(test_data)

### Budowa s≈Çownika

In [20]:
def build_vocab(tokens_list):
    counter = Counter()
    for tokens in tokens_list:
        counter.update(tokens)
    return vocab(counter, specials=["<unk>", "<pad>", "<bos>", "<eos>"])

token_vocab = build_vocab(train_tokens)
token_vocab.set_default_index(token_vocab["<unk>"])

In [26]:
print(token_vocab.get_itos())

Wielko≈õƒá s≈Çownika: Vocab()


In [28]:
print("Wielko≈õƒá s≈Çownika:")
print(len(token_vocab))

Wielko≈õƒá s≈Çownika:
23628


### Bio tagging

In [0]:
bio_tags = sorted(set(tag for tags in train_data['ner'].apply(lambda x: x.split()) for tag in tags))
bio_to_int = {tag: idx for idx, tag in enumerate(bio_tags, start=1)}
bio_to_int["O"] = 0

label_mapping = {
    0: 'O',
    1: 'B-LOC',
    2: 'B-MISC',
    3: 'B-ORG',
    4: 'B-PER',
    5: 'I-LOC',
    6: 'I-MISC',
    7: 'I-ORG',
    8: 'I-PER'
}

### Przetwarzanie danych

In [0]:
def convert_bio_to_int(tags_list):
    return [bio_to_int[tag] for tag in tags_list]

def data_process(tokens_list):
    return [
        torch.tensor(
            [token_vocab["<bos>"]] + [token_vocab[token] for token in tokens] + [token_vocab["<eos>"]],
            dtype=torch.long,
        )
        for tokens in tokens_list
    ]

def labels_process(labels_list):
    return [torch.tensor([0] + labels + [0], dtype=torch.long) for labels in labels_list]

train_tokens_ids = data_process(train_tokens)
dev_tokens_ids = data_process(dev_tokens)
test_tokens_ids = data_process(test_tokens)

train_labels = labels_process(train_data['ner'].apply(lambda x: convert_bio_to_int(x.split())).tolist())
dev_labels = labels_process(dev_labels['ner'].apply(lambda x: convert_bio_to_int(x.split())).tolist())

### Model

In [0]:
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, num_labels, embedding_dim=128, lstm_units=64):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, lstm_units, num_layers=1,
                            batch_first=True, dropout=0.1)
        self.linear = nn.Linear(lstm_units, num_labels)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)
        logits = self.linear(lstm_out)
        return self.softmax(logits)

vocab_size = len(token_vocab.get_itos())
num_labels = len(bio_to_int)
model = LSTMModel(vocab_size, num_labels)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

### Trening i ewaluacja

In [0]:
def get_scores(y_true, y_pred):
    acc_score = 0
    tp = 0
    selected_items = 0
    relevant_items = 0

    for p, t in zip(y_pred, y_true):
        if p == t:
            acc_score += 1

        if p > 0 and p == t:
            tp += 1

        if p > 0:
            selected_items += 1

        if t > 0:
            relevant_items += 1

    if selected_items == 0:
        precision = 1.0
    else:
        precision = tp / selected_items

    if relevant_items == 0:
        recall = 1.0
    else:
        recall = tp / relevant_items

    if precision + recall == 0.0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    return precision, recall, f1

def eval_model(dataset_tokens, dataset_labels, model):
    Y_true = []
    Y_pred = []
    for i in tqdm(range(len(dataset_labels))):
        batch_tokens = dataset_tokens[i].unsqueeze(0)
        tags = list(dataset_labels[i].numpy())
        Y_true += tags

        Y_batch_pred_weights = model(batch_tokens).squeeze(0)
        Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
        Y_pred += list(Y_batch_pred.numpy())

    return get_scores(Y_true, Y_pred)


In [14]:
NUM_EPOCHS = 20
for epoch in range(NUM_EPOCHS):
    model.train()
    for i in tqdm(range(len(train_labels))):
        batch_tokens = train_tokens_ids[i].unsqueeze(0)
        tags = train_labels[i].unsqueeze(1)

        predicted_tags = model(batch_tokens)

        optimizer.zero_grad()
        loss = criterion(predicted_tags.view(-1, num_labels), tags.view(-1))
        loss.backward()
        optimizer.step()

    model.eval()
    print(f'Epoch {epoch + 1}')
    print(eval_model(dev_tokens_ids, dev_labels, model))

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:26<00:00, 35.31it/s]


Epoch 1


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 404.90it/s]


(0.6013011152416357, 0.2261710556979725, 0.3287044877222693)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:46<00:00, 20.34it/s]


Epoch 2


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 255.35it/s]


(0.7338551859099804, 0.48065718946632485, 0.5808631979159333)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.57it/s]


Epoch 3


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 389.49it/s]


(0.809629044988161, 0.5976462363085527, 0.6876717838707515)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.85it/s]


Epoch 4


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 423.23it/s]


(0.8388561053109397, 0.6460032626427407, 0.7299058653149891)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.96it/s]


Epoch 5


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 402.62it/s]


(0.8558545239503252, 0.6745513866231647, 0.754463703896781)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.99it/s]


Epoch 6


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 416.67it/s]


(0.8630437966896147, 0.6865532509904451, 0.7647478746187292)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.95it/s]


Epoch 7


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 411.09it/s]


(0.8659315147997678, 0.6954089955721278, 0.7713584076515446)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.96it/s]


Epoch 8


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 413.46it/s]


(0.8631503920171062, 0.7055464926590538, 0.7764313650060909)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.76it/s]


Epoch 9


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 407.20it/s]


(0.8623391158365976, 0.718247494756467, 0.7837253655435473)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.69it/s]


Epoch 10


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 375.22it/s]


(0.865633442343239, 0.7093917501747844, 0.7797630483509447)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.75it/s]


Epoch 11


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 411.88it/s]


(0.86810551558753, 0.7170822652062456, 0.7853997830387339)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.64it/s]


Epoch 12


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 404.89it/s]


(0.8732292570106968, 0.7039151712887439, 0.779483870967742)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:46<00:00, 20.54it/s]


Epoch 13


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 407.20it/s]


(0.8747655460972442, 0.706478676299231, 0.7816669889769869)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:46<00:00, 20.50it/s]


Epoch 14


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 417.06it/s]


(0.8458568868054598, 0.714868329060825, 0.7748658035996211)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:46<00:00, 20.46it/s]


Epoch 15


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 408.75it/s]


(0.8549187103885368, 0.7230249359123747, 0.783459595959596)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:46<00:00, 20.49it/s]


Epoch 16


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 401.12it/s]


(0.8629124004550626, 0.7070612910743417, 0.7772511848341231)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.58it/s]


Epoch 17


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 419.10it/s]


(0.8741012472487161, 0.6941272430668842, 0.7737871013833864)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:45<00:00, 20.73it/s]


Epoch 18


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 404.14it/s]


(0.8725857595848948, 0.7054299697040317, 0.7801546391752578)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:46<00:00, 20.28it/s]


Epoch 19


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 411.87it/s]


(0.876505586997533, 0.7037986483337217, 0.7807147935112777)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 945/945 [00:46<00:00, 20.32it/s]


Epoch 20


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 413.46it/s]


(0.8801597869507324, 0.6931950594267071, 0.7755687373704453)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:00<00:00, 377.20it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 230/230 [00:00<00:00, 352.76it/s]


### Predykcje

In [17]:
def validate_bio_sequence(labels):
    corrected_labels = []
    previous_label = 'O'
    for label in labels:
        if label.startswith('I-'):
            if previous_label == 'O' or previous_label[2:] != label[2:]:
                corrected_labels.append('B-' + label[2:])
            else:
                corrected_labels.append(label)
        else:
            corrected_labels.append(label)
        previous_label = corrected_labels[-1]
    return corrected_labels

def save_predictions(tokens_ids, model, output_path, label_mapping):
    predictions = []
    for i in tqdm(range(len(tokens_ids))):
        batch_tokens = tokens_ids[i].unsqueeze(0)
        Y_batch_pred_weights = model(batch_tokens).squeeze(0)
        Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
        bio_labels = [label_mapping[label] for label in Y_batch_pred.numpy()[1:-1]]
        bio_labels = validate_bio_sequence(bio_labels)
        predictions.append(" ".join(bio_labels))

    with open(output_path, 'w') as f:
        for line in predictions:
            f.write(line + '\n')

save_predictions(dev_tokens_ids, model, 'dev-0/out.tsv', label_mapping)
save_predictions(test_tokens_ids, model, 'test-A/out.tsv', label_mapping)

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 215/215 [00:01<00:00, 141.73it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 230/230 [00:01<00:00, 163.82it/s]
