DL_RNN/RNN.ipynb
2024-05-27 13:01:15 +02:00


RNN

Installation of packages

%pip install torch
%pip install torchtext
%pip install datasets
%pip install pandas
%pip install scikit-learn

Importing libraries

from collections import Counter
import torch
import pandas as pd
from torchtext.vocab import vocab
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

Read the datasets

def read_data():
    train_dataset = pd.read_csv(
        "train/train.tsv.xz", compression="xz", sep="\t", names=["Label", "Text"]
    )
    dev_0_dataset = pd.read_csv("dev-0/in.tsv", sep="\t", names=["Text"])
    dev_0_labels = pd.read_csv("dev-0/expected.tsv", sep="\t", names=["Label"])
    test_A_dataset = pd.read_csv("test-A/in.tsv", sep="\t", names=["Text"])

    return train_dataset, dev_0_dataset, dev_0_labels, test_A_dataset
train_dataset, dev_0_dataset, dev_0_labels, test_A_dataset = read_data()

Split the training data into training and validation sets

train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_dataset["Text"], train_dataset["Label"], test_size=0.1, random_state=42
)
train_dataset = pd.DataFrame({"Text": train_texts, "Label": train_labels})
val_dataset = pd.DataFrame({"Text": val_texts, "Label": val_labels})
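When `test_size` is a float, scikit-learn rounds the held-out part up (`ceil(test_size * n_samples)`) and gives the remainder to training. The sketch below checks that arithmetic. It assumes the training file holds 945 documents. The notebook never prints that count; it is inferred from the loop lengths of the original run (850 training and 95 validation iterations). `split_sizes` is a hypothetical helper.

```python
import math

def split_sizes(n_samples, test_size):
    """Mirror scikit-learn's rounding for a float test_size when train_size
    is not given: the test part is rounded up, training gets the rest."""
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test

# 945 is an assumption inferred from the original run (850 + 95).
print(split_sizes(945, 0.1))  # (850, 95)
```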

Tokenize the text and labels

train_dataset["tokenized_text"] = train_dataset["Text"].apply(lambda x: x.split())
train_dataset["tokenized_labels"] = train_dataset["Label"].apply(lambda x: x.split())
val_dataset["tokenized_text"] = val_dataset["Text"].apply(lambda x: x.split())
val_dataset["tokenized_labels"] = val_dataset["Label"].apply(lambda x: x.split())
dev_0_dataset["tokenized_text"] = dev_0_dataset["Text"].apply(lambda x: x.split())
dev_0_dataset["tokenized_labels"] = dev_0_labels["Label"].apply(lambda x: x.split())
test_A_dataset["tokenized_text"] = test_A_dataset["Text"].apply(lambda x: x.split())
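`str.split()` splits on any run of whitespace. The text and label columns therefore only line up if every token has exactly one tag. Here is a minimal sanity check under that assumption. The example row is illustrative (CoNLL-2003 style), not read from the dataset files, and `tokenize_pair` is a hypothetical helper that is not part of the notebook.

```python
def tokenize_pair(text, labels):
    """Whitespace-split a sentence and its tag string, as the notebook does,
    and check that every token received exactly one tag."""
    tokens = text.split()
    tags = labels.split()
    if len(tokens) != len(tags):
        raise ValueError(f"{len(tokens)} tokens but {len(tags)} tags")
    return tokens, tags

# Illustrative row, not taken from train.tsv.xz.
tokens, tags = tokenize_pair(
    "EU rejects German call to boycott British lamb .",
    "B-ORG O B-MISC O O O B-MISC O O",
)
print(list(zip(tokens, tags))[:3])
# [('EU', 'B-ORG'), ('rejects', 'O'), ('German', 'B-MISC')]
```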

Create a vocab object which maps tokens to indices

def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(document)
    return vocab(counter, specials=["<unk>", "<pad>", "<bos>", "<eos>"])
v = build_vocab(train_dataset["tokenized_text"])

List that maps indices to tokens

itos = v.get_itos()

Number of tokens in the vocabulary

len(itos)
22154

Index of the 'rejects' token

v["rejects"]
9086

Index of the '<unk>' token

v["<unk>"]
0

Set the default index to the unknown token

v.set_default_index(v["<unk>"])
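The vocabulary steps above can be mimicked with a plain dictionary, which makes the index layout explicit:

- The four specials take indices 0 to 3.
- The remaining tokens follow in the order the `Counter` first saw them.
- Once a default index is set, an unseen token falls back to `<unk>`. Without `set_default_index`, torchtext raises an error on unknown tokens.

This is a hypothetical stand-in for illustration, not torchtext's implementation.

```python
from collections import Counter

SPECIALS = ["<unk>", "<pad>", "<bos>", "<eos>"]

def build_token_index(documents):
    """Plain-dict stand-in for torchtext's vocab(): specials first,
    then tokens in first-seen order (Counter keeps insertion order)."""
    counter = Counter()
    for document in documents:
        counter.update(document)
    stoi = {tok: i for i, tok in enumerate(SPECIALS)}
    for tok in counter:
        if tok not in stoi:
            stoi[tok] = len(stoi)
    return stoi

def lookup(stoi, token):
    # Equivalent of set_default_index(v["<unk>"]): unseen tokens map to <unk>.
    return stoi.get(token, stoi["<unk>"])

stoi = build_token_index([["EU", "rejects", "German"], ["German", "call"]])
print(lookup(stoi, "rejects"), lookup(stoi, "Brexit"))  # 5 0
```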

Use CUDA if available

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Vectorize the data

def data_process(dt):
    return [
        torch.tensor(
            [v["<bos>"]] + [v[token] for token in document] + [v["<eos>"]],
            dtype=torch.long,
            device=device,
        )
        for document in dt
    ]
train_tokens_ids = data_process(train_dataset["tokenized_text"])
val_tokens_ids = data_process(val_dataset["tokenized_text"])
dev_0_tokens_ids = data_process(dev_0_dataset["tokenized_text"])
test_A_tokens_ids = data_process(test_A_dataset["tokenized_text"])

Create a mapping from label to index

labels = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]

label_to_index = {label: idx for idx, label in enumerate(labels)}

Vectorize the labels (NER)

def labels_process(dt, label_to_index):
    return [
        torch.tensor(
            [0] + [label_to_index[label] for label in document] + [0],
            dtype=torch.long,
            device=device,
        )
        for document in dt
    ]
train_labels = labels_process(train_dataset["tokenized_labels"], label_to_index)
val_labels = labels_process(val_dataset["tokenized_labels"], label_to_index)
dev_0_labels = labels_process(dev_0_dataset["tokenized_labels"], label_to_index)
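Both vectorizers pad each sentence the same way:

- `<bos>` and `<eos>` go around the token indices.
- Tag `O` (index 0) goes around the label indices.

As a result, input position k always lines up with target position k. The plain-list sketch below shows that alignment. The token indices are hypothetical, while the tag indices follow `label_to_index` (B-ORG = 3, B-MISC = 7).

```python
BOS, EOS = 2, 3  # indices of "<bos>" and "<eos>" (specials come first in the vocab)
O_INDEX = 0      # label_to_index["O"]

def wrap_example(token_ids, tag_ids):
    """Plain-list version of data_process + labels_process for one sentence."""
    if len(token_ids) != len(tag_ids):
        raise ValueError("every token needs exactly one tag")
    x = [BOS] + list(token_ids) + [EOS]
    y = [O_INDEX] + list(tag_ids) + [O_INDEX]
    return x, y

# Hypothetical token indices for "EU rejects German"; tags B-ORG O B-MISC.
x, y = wrap_example([4, 5, 6], [3, 0, 7])
print(x)  # [2, 4, 5, 6, 3]
print(y)  # [0, 3, 0, 7, 0]
```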

Function for evaluation (returns token-level precision, recall, and F1 score)

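def get_scores(y_true, y_pred):
    # Token-level counts: a token is "selected" when the predicted tag is not "O"
    # (index 0), "relevant" when the gold tag is not "O", and a true positive when
    # the prediction is a non-O tag equal to the gold tag.
    tp = 0
    selected_items = 0
    relevant_items = 0

    for p, t in zip(y_pred, y_true):
        if p > 0 and p == t:
            tp += 1

        if p > 0:
            selected_items += 1

        if t > 0:
            relevant_items += 1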

    if selected_items == 0:
        precision = 1.0
    else:
        precision = tp / selected_items

    if relevant_items == 0:
        recall = 1.0
    else:
        recall = tp / relevant_items

    if precision + recall == 0.0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    return precision, recall, f1
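The sketch below restates the same metric compactly, with an example small enough to check by hand (`token_prf` is a hypothetical name). The score is computed per token rather than per entity span. It is therefore not directly comparable with the entity-level F1 reported by tools such as `conlleval`.

```python
def token_prf(y_true, y_pred):
    """Same token-level metric as get_scores: selected = predicted tag is not
    "O" (0), relevant = gold tag is not "O", true positive = same non-O tag."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if p > 0 and p == t)
    selected = sum(1 for p in y_pred if p > 0)
    relevant = sum(1 for t in y_true if t > 0)
    precision = tp / selected if selected else 1.0
    recall = tp / relevant if relevant else 1.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)

# Gold: B-ORG O B-MISC O ; predicted: B-ORG B-PER O O
print(token_prf([3, 0, 7, 0], [3, 1, 0, 0]))  # (0.5, 0.5, 0.5)
```

One of the two predicted entity tokens is correct (precision 1/2), and one of the two gold entity tokens is found (recall 1/2).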

Calculate the number of tags (highest tag index in the training labels + 1)

all_label_indices = [
    label_to_index[label]
    for document in train_dataset["tokenized_labels"]
    for label in document
]

num_tags = max(all_label_indices) + 1

print(num_tags)
9

Implementation of a bidirectional LSTM recurrent neural network

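class LSTM(torch.nn.Module):
    """Bidirectional LSTM tagger: embedding -> BiLSTM -> per-token linear layer."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_tags):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rec = torch.nn.LSTM(
            embedding_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True
        )
        # Forward and backward hidden states are concatenated, hence hidden_dim * 2.
        self.fc1 = torch.nn.Linear(hidden_dim * 2, num_tags)

    def forward(self, x):
        # x: (batch, seq_len) token indices
        embedding = torch.relu(self.embedding(x))  # (batch, seq_len, embedding_dim)
        lstm_output, _ = self.rec(embedding)       # (batch, seq_len, 2 * hidden_dim)
        out_weights = self.fc1(lstm_output)        # (batch, seq_len, num_tags), raw logits
        return out_weights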

Initialize the LSTM model

lstm = LSTM(
    vocab_size=len(itos),
    embedding_dim=100,
    hidden_dim=100,
    num_layers=1,
    num_tags=num_tags,
).to(device)
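These sizes allow the parameter count to be worked out by hand: a vocabulary of 22154 tokens (from the output above), 100-dimensional embeddings, 100 hidden units per direction, one layer and 9 tags. The sketch assumes PyTorch's layout for `torch.nn.LSTM`. Per layer and direction, that layout has `weight_ih` (4H x input), `weight_hh` (4H x H), and two bias vectors of size 4H each. `lstm_tagger_params` is a hypothetical helper.

```python
def lstm_tagger_params(vocab_size, embedding_dim, hidden_dim, num_tags,
                       num_layers=1, bidirectional=True):
    """Count trainable parameters of Embedding -> (Bi)LSTM -> Linear."""
    directions = 2 if bidirectional else 1
    total = vocab_size * embedding_dim  # embedding table
    layer_input = embedding_dim
    for _ in range(num_layers):
        # weight_ih + weight_hh + bias_ih + bias_hh for one direction
        per_direction = 4 * hidden_dim * (layer_input + hidden_dim) + 2 * 4 * hidden_dim
        total += directions * per_direction
        layer_input = hidden_dim * directions  # next layer sees concatenated outputs
    total += (hidden_dim * directions) * num_tags + num_tags  # fc1 weight + bias
    return total

print(lstm_tagger_params(22154, 100, 100, 9))  # 2378809
```

The embedding table accounts for about 93% of the roughly 2.38 million parameters. If the layout assumption holds, `sum(p.numel() for p in lstm.parameters())` should report the same total.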

Define the loss function

criterion = torch.nn.CrossEntropyLoss()
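`CrossEntropyLoss` expects raw scores, and the model applies no softmax. With the default `reduction="mean"`, it averages -log softmax(scores)[target] over the rows it receives. In the training loop each row is one token of a sentence, with `<bos>` and `<eos>` included. The sketch below implements that formula in pure Python; the function names are hypothetical.

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], using log-sum-exp for numerical stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    # Default reduction="mean": average the per-row (per-token) losses.
    losses = [cross_entropy(z, t) for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

print(round(cross_entropy([2.0, 1.0, 0.1], 0), 4))  # 0.417
```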

Define the optimizer

optimizer = torch.optim.Adam(lstm.parameters())
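`torch.optim.Adam` is created here with its defaults (`lr=1e-3`, `betas=(0.9, 0.999)`, `eps=1e-8`). Below is a textbook sketch of one Adam update for a single scalar parameter, assuming no weight decay and no AMSGrad. It shows why the very first step has a size of about `lr` regardless of the gradient's scale. `adam_step` is a hypothetical helper, not PyTorch's code.

```python
import math

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t counts steps from 1."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for zero-initialised moments
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = adam_step(theta=0.0, grad=0.5, m=0.0, v=0.0, t=1)
print(theta)  # roughly -0.001
```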

Function for model evaluation

def eval_model(dataset_tokens, dataset_labels, model):
    Y_true = []
    Y_pred = []
    with torch.no_grad():  # no gradients are needed for evaluation
        for i in tqdm(range(len(dataset_labels))):
            batch_tokens = dataset_tokens[i].unsqueeze(0)  # batch of one sentence
            Y_true += dataset_labels[i].cpu().tolist()

            Y_batch_pred_weights = model(batch_tokens).squeeze(0)  # (seq_len, num_tags)
            Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
            Y_pred += Y_batch_pred.cpu().tolist()

    return get_scores(Y_true, Y_pred)

Function that returns the predicted labels

def pred_labels(dataset_tokens, model, label_to_index):
    Y_pred = []
    # Inverse mapping: tag index -> tag string
    index_to_label = {idx: label for label, idx in label_to_index.items()}

    with torch.no_grad():  # inference only
        for i in tqdm(range(len(dataset_tokens))):
            batch_tokens = dataset_tokens[i].unsqueeze(0)
            Y_batch_pred_weights = model(batch_tokens).squeeze(0)
            Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)
            predicted_labels = [index_to_label[label.item()] for label in Y_batch_pred]
            predicted_labels = predicted_labels[1:-1]  # drop the <bos>/<eos> positions
            Y_pred.append(" ".join(predicted_labels))

    return Y_pred
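The decoding step can be shown on plain lists. Each argmax index is mapped back to its tag string, the positions that correspond to `<bos>` and `<eos>` are dropped, and the tags are joined with spaces, which is the line format written to `out.tsv`. `decode` is a hypothetical helper for illustration.

```python
LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]

def decode(pred_indices, labels=LABELS):
    """Turn per-position argmax indices (including the <bos>/<eos> slots)
    into one space-separated line of tags."""
    index_to_label = dict(enumerate(labels))   # inverse of label_to_index
    tags = [index_to_label[i] for i in pred_indices]
    return " ".join(tags[1:-1])                 # drop the <bos> and <eos> positions

print(decode([0, 3, 0, 7, 0]))  # B-ORG O B-MISC
```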

Training

NUM_EPOCHS = 20
for epoch in range(NUM_EPOCHS):
    lstm.train()
    for i in tqdm(range(len(train_labels))):
        batch_tokens = train_tokens_ids[i].unsqueeze(0)  # (1, seq_len)
        tags = train_labels[i]                            # (seq_len,)

        predicted_tags = lstm(batch_tokens)               # (1, seq_len, num_tags)

        optimizer.zero_grad()
        loss = criterion(predicted_tags.squeeze(0), tags)

        loss.backward()
        optimizer.step()

    lstm.eval()
    print(eval_model(val_tokens_ids, val_labels, lstm))
(0.6839857651245551, 0.27638769053782, 0.3936911102007374)
(0.7672373900971773, 0.4768478573482888, 0.5881518268889678)
(0.7881323697223279, 0.5959160195570894, 0.678676711431379)
(0.7992857142857143, 0.6436583261432269, 0.7130794965747969)
(0.7935967302452316, 0.6701179177451826, 0.7266489942304692)
(0.7951251646903821, 0.6942766752947943, 0.7412866574543221)
(0.794789321325185, 0.7106701179177451, 0.7503795930762224)
(0.8046014257939079, 0.7141213689962611, 0.7566661587688556)
(0.8059558117195005, 0.7238999137187231, 0.7627272727272727)
(0.8214634146341463, 0.72648835202761, 0.7710622710622711)
(0.8328277721604315, 0.7106701179177451, 0.7669149596523898)
(0.8623625399077687, 0.6991659476560254, 0.7722363405336722)
(0.83496126641967, 0.7129709519700892, 0.7691591684765747)
(0.8154145077720207, 0.724187517975266, 0.7670982482863671)
(0.8324958123953099, 0.7146965775093471, 0.7691117301145157)
(0.839452603471295, 0.7233247052056371, 0.7770739996910242)
(0.8292120013188262, 0.7233247052056371, 0.7726574500768049)
(0.8110863332271424, 0.73224043715847, 0.7696493349455865)
(0.8210254756530152, 0.73224043715847, 0.7740954697476436)
(0.8173884938590821, 0.727351164797239, 0.7697458529904124)
eval_model(val_tokens_ids, val_labels, lstm)
(0.8173884938590821, 0.727351164797239, 0.7697458529904124)
eval_model(dev_0_tokens_ids, dev_0_labels, lstm)
(0.8368401624215578, 0.7924726171055698, 0.8140523071398648)
dev_0_predictions = pred_labels(dev_0_tokens_ids, lstm, label_to_index)
dev_0_predictions = pd.DataFrame(dev_0_predictions, columns=["Label"])
dev_0_predictions.to_csv("dev-0/out.tsv", index=False, header=False)
test_A_predictions = pred_labels(test_A_tokens_ids, lstm, label_to_index)
test_A_predictions = pd.DataFrame(test_A_predictions, columns=["Label"])
test_A_predictions.to_csv("test-A/out.tsv", index=False, header=False)

Correct invalid IOB sequences (an I-X tag that does not follow B-X or I-X becomes B-X)

def correct_labels(input_file, output_file):
    df = pd.read_csv(input_file, sep="\t", names=["Text"])

    corrected_lines = []

    for line in df["Text"]:
        tokens = line.split(" ")
        corrected_tokens = []
        previous_token = "O"

        for token in tokens:
            if (
                token == "I-ORG"
                and previous_token != "B-ORG"
                and previous_token != "I-ORG"
            ):
                corrected_tokens.append("B-ORG")
            elif (
                token == "I-PER"
                and previous_token != "B-PER"
                and previous_token != "I-PER"
            ):
                corrected_tokens.append("B-PER")
            elif (
                token == "I-LOC"
                and previous_token != "B-LOC"
                and previous_token != "I-LOC"
            ):
                corrected_tokens.append("B-LOC")
            elif (
                token == "I-MISC"
                and previous_token != "B-MISC"
                and previous_token != "I-MISC"
            ):
                corrected_tokens.append("B-MISC")
            else:
                corrected_tokens.append(token)

            previous_token = token

        corrected_line = " ".join(corrected_tokens)
        corrected_lines.append(corrected_line)

    df["Text"] = corrected_lines
    df.to_csv(output_file, sep="\t", index=False, header=False)
correct_labels("dev-0/out.tsv", "dev-0/out.tsv")
correct_labels("test-A/out.tsv", "test-A/out.tsv")
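The four `if`/`elif` branches in `correct_labels` apply one rule per entity type. The same rule can be stated once for any type. Below is a sketch on a list of tags using a hypothetical helper, `repair_iob`. It tracks the repaired previous tag rather than the original one. The result is the same, because a repaired tag keeps its entity type.

```python
def repair_iob(tags):
    """Replace an I-X whose previous tag is neither B-X nor I-X with B-X.
    Works for any entity type, not just PER/ORG/LOC/MISC."""
    repaired = []
    previous = "O"
    for tag in tags:
        if tag.startswith("I-"):
            entity = tag[2:]
            if previous not in (f"B-{entity}", f"I-{entity}"):
                tag = f"B-{entity}"
        repaired.append(tag)
        previous = tag
    return repaired

print(" ".join(repair_iob("O I-PER I-PER B-ORG I-LOC".split())))
# O B-PER I-PER B-ORG B-LOC
```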