{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2024-05-16T18:21:49.572131300Z", "start_time": "2024-05-16T18:21:43.423852800Z" } }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "import torchtext\n", "from torchtext.vocab import vocab\n", "\n", "from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report\n", "\n", "from tqdm.notebook import tqdm\n", "\n", "from collections import Counter" ] }, { "cell_type": "code", "execution_count": 7, "outputs": [], "source": [ "# Load the data\n", "train_data = pd.read_csv('train/train.tsv', delimiter='\\t', header=None)\n", "\n", "valid_data_in = pd.read_csv('dev-0/in.tsv', delimiter='\\t', header=None)\n", "valid_data_expected = pd.read_csv('dev-0/expected.tsv', delimiter='\\t', header=None)\n", "valid_data = pd.concat([valid_data_expected, valid_data_in], axis=1)\n", "\n", "test_data = pd.read_csv('test-A/in.tsv', delimiter='\\t', header=None)\n", "\n", "# Label the columns\n", "train_data.columns = ['ner_tags', 'text']\n", "valid_data.columns = ['ner_tags', 'text']\n", "test_data.columns = ['text']\n", "\n", "# Split the text into tokens\n", "train_data['text_tokens'] = train_data['text'].apply(lambda x: x.split())\n", "valid_data['text_tokens'] = valid_data['text'].apply(lambda x: x.split())\n", "test_data['text_tokens'] = test_data['text'].apply(lambda x: x.split())\n", "\n", "# Split the NER tags into tokens\n", "train_data['ner_tags_tokens'] = train_data['ner_tags'].apply(lambda x: x.split())\n", "valid_data['ner_tags_tokens'] = valid_data['ner_tags'].apply(lambda x: x.split())" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:23.174336100Z", "start_time": "2024-05-14T07:11:23.080690300Z" } }, "id": "9e5c5c1083e3f387" }, { "cell_type": "code", "execution_count": 8, "outputs": [], "source": [ "# Method for building the vocabulary from DataFrame dataset\n", "# Special tokens:\n", "# - unknown token\n", "# - padding token\n", "# - beginning of sentence token\n", "# - end of sentence token\n", "def build_vocab(dataset):\n", " # Initialize the counter\n", " counter = Counter()\n", " \n", " # Iterate over the dataset and update the counter\n", " for idx, document in dataset.iterrows():\n", " counter.update(document['text_tokens'])\n", " \n", " # Return the vocabulary\n", " return vocab(counter, specials=['', '', '', ''])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:23.647897500Z", "start_time": "2024-05-14T07:11:23.640148800Z" } }, "id": "56a8833a05334060" }, { "cell_type": "code", "execution_count": 9, "outputs": [], "source": [ "# Build the vocabulary\n", "v = build_vocab(train_data)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:24.169912Z", "start_time": "2024-05-14T07:11:24.081356500Z" } }, "id": "eacfbc15230adc2e" }, { "cell_type": "code", "execution_count": 10, "outputs": [], "source": [ "# Mapping from index to token\n", "itos = v.get_itos()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:24.484522400Z", "start_time": "2024-05-14T07:11:24.470356200Z" } }, "id": "c9c7ce32ebd5a3c2" }, { "cell_type": "code", "execution_count": 11, "outputs": [], "source": [ "# Set default index 
"# Set default index for unknown tokens\n", "v.set_default_index(v[\"<unk>\"])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:24.842556700Z", "start_time": "2024-05-14T07:11:24.823442400Z" } }, "id": "ce8d899162dcc776" }, { "cell_type": "code", "execution_count": 12, "outputs": [], "source": [ "# Get the unique NER tags\n", "ner_tags = set([tag for tags in train_data['ner_tags_tokens'] for tag in tags])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:25.201567900Z", "start_time": "2024-05-14T07:11:25.180831600Z" } }, "id": "2e9f2dc469b6025d" }, { "cell_type": "code", "execution_count": 13, "outputs": [], "source": [ "# Mapping from tag to index (https://huggingface.co/datasets/conll2003)\n", "ner_tag2idx = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}\n", "\n", "# Reverse mapping\n", "ner_idx2tag = {idx: tag for tag, idx in ner_tag2idx.items()}" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:26.534701200Z", "start_time": "2024-05-14T07:11:26.526620300Z" } }, "id": "5271fd04bd9f16e3" }, { "cell_type": "code", "execution_count": 14, "outputs": [ { "data": { "text/plain": "{'O': 0,\n 'B-PER': 1,\n 'I-PER': 2,\n 'B-ORG': 3,\n 'I-ORG': 4,\n 'B-LOC': 5,\n 'I-LOC': 6,\n 'B-MISC': 7,\n 'I-MISC': 8}" }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ner_tag2idx" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:27.854314700Z", "start_time": "2024-05-14T07:11:27.844315700Z" } }, "id": "8bf1e9961daa4bd8" }, { "cell_type": "code", "execution_count": 15, "outputs": [ { "data": { "text/plain": "{0: 'O',\n 1: 'B-PER',\n 2: 'I-PER',\n 3: 'B-ORG',\n 4: 'I-ORG',\n 5: 'B-LOC',\n 6: 'I-LOC',\n 7: 'B-MISC',\n 8: 'I-MISC'}" }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ner_idx2tag" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:28.332071700Z", "start_time": "2024-05-14T07:11:28.286070800Z" } }, "id": "12571d646796d21b" }, { "cell_type": "code", "execution_count": 16, "outputs": [], "source": [ "# Method for vectorizing text data using the vocabulary mapping\n", "def text_to_vec(data):\n", "    return [torch.tensor([v['<bos>']] + [v[token] for token in document] + [v['<eos>']], dtype=torch.long) for document in data]" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:29.032865100Z", "start_time": "2024-05-14T07:11:29.012730500Z" } }, "id": "da795a7fd000b135" }, { "cell_type": "code", "execution_count": 17, "outputs": [], "source": [ "# Method for vectorizing NER tags data using the NER tags mapping\n", "def ner_tags_to_vec(data):\n", "    return [torch.tensor([0] + [ner_tag2idx[tag] for tag in document] + [0], dtype=torch.long) for document in data]" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:29.824074700Z", "start_time": "2024-05-14T07:11:29.812059800Z" } }, "id": "f9c2bb1f0bb0e480" }, { "cell_type": "code", "execution_count": 18, "outputs": [], "source": [ "# Vectorize the text data (input)\n", "X_train = text_to_vec(train_data['text_tokens'])\n", "X_dev = text_to_vec(valid_data['text_tokens'])\n", "X_test = text_to_vec(test_data['text_tokens'])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:30.896086700Z", "start_time": "2024-05-14T07:11:30.610066Z" } }, "id": "2f851f63cedacf6c" }, { "cell_type": "code",
"execution_count": 19, "outputs": [], "source": [ "# Vectorize the NER tags data (output, labels)\n", "y_train = ner_tags_to_vec(train_data['ner_tags_tokens'])\n", "y_dev = ner_tags_to_vec(valid_data['ner_tags_tokens'])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:31.468671200Z", "start_time": "2024-05-14T07:11:31.415476500Z" } }, "id": "30e8c488d3b9d11a" }, { "cell_type": "code", "execution_count": 20, "outputs": [], "source": [ "# Model definition\n", "class LSTM(nn.Module):\n", " def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):\n", " super(LSTM, self).__init__()\n", " \n", " # Embedding layer\n", " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", " \n", " # LSTM layer\n", " self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first = True)\n", " \n", " # Fully connected layer\n", " self.fc = nn.Linear(hidden_dim, output_dim)\n", " \n", " self.relu = nn.ReLU()\n", " \n", " def forward(self, x):\n", " # Embedding\n", " embedding = self.relu(self.embedding(x))\n", " \n", " # LSTM\n", " output, (hidden, cell) = self.lstm(embedding)\n", " \n", " # Fully connected\n", " output = self.fc(output)\n", " \n", " return output\n", " " ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:32.454622200Z", "start_time": "2024-05-14T07:11:32.422201500Z" } }, "id": "6a86649248c384b5" }, { "cell_type": "code", "execution_count": 77, "outputs": [], "source": [ "# Segeval evaluation\n", "def evaluate_model(model, X, y):\n", " \"\"\"\n", " Method for evaluating the model\n", " :param model: model\n", " :param X: input data\n", " :param y: output data \n", " :return: dictionary with metrics values\n", " \"\"\"\n", " # No gradients\n", " with torch.no_grad():\n", " # Predict the labels\n", " y_pred = [torch.argmax(model(x.unsqueeze(0)).squeeze(0), 1) for x in X]\n", " \n", " # Convert the labels to ner tags\n", " y_pred = [[ner_idx2tag[int(idx)] for idx in y] for y in y_pred]\n", " y_tags = [[ner_idx2tag[int(idx)] for idx in y] for y in y]\n", " \n", " # Calculate the metrics\n", " accuracy = accuracy_score(y_tags, y_pred)\n", " precision = precision_score(y_tags, y_pred)\n", " recall = recall_score(y_tags, y_pred)\n", " f1 = f1_score(y_tags, y_pred)\n", " \n", " return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T08:26:31.612231Z", "start_time": "2024-05-14T08:26:31.599603300Z" } }, "id": "b18d26ac9fbc590e" }, { "cell_type": "code", "execution_count": 23, "outputs": [], "source": [ "# Use GPU if available\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:11:49.825835200Z", "start_time": "2024-05-14T07:11:49.817343100Z" } }, "id": "badf288796646abe" }, { "cell_type": "code", "execution_count": 39, "outputs": [], "source": [ "# Model parameters\n", "vocab_size = len(v)\n", "embedding_dim = 64\n", "hidden_dim = 256\n", "output_dim = len(ner_tags)\n", "epochs = 20" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:22:20.730379Z", "start_time": "2024-05-14T07:22:20.724143500Z" } }, "id": "65beded501220882" }, { "cell_type": "code", "execution_count": 44, "outputs": [], "source": [ "# Seed for reproducibility\n", "torch.manual_seed(1234)\n", "\n", "import random\n", "random.seed(1234)\n", "\n", "np.random.seed(1234)" ], "metadata": { "collapsed": false, "ExecuteTime": { 
"end_time": "2024-05-14T07:28:18.248713300Z", "start_time": "2024-05-14T07:28:18.188830400Z" } }, "id": "63b68885d93d5fce" }, { "cell_type": "code", "execution_count": 40, "outputs": [], "source": [ "# Initialize the model\n", "model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:22:21.068317100Z", "start_time": "2024-05-14T07:22:21.044162900Z" } }, "id": "29116c705decf395" }, { "cell_type": "code", "execution_count": 41, "outputs": [], "source": [ "# Loss function and optimizer\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters())" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:22:21.705555900Z", "start_time": "2024-05-14T07:22:21.675608300Z" } }, "id": "617bec2a8a8b56b3" }, { "cell_type": "code", "execution_count": 65, "outputs": [], "source": [ "# Move training to GPU\n", "model = model.to(device)\n", "X_train = [x.to(device) for x in X_train]\n", "y_train = [y.to(device) for y in y_train]\n", "X_dev = [x.to(device) for x in X_dev]\n", "y_dev = [y.to(device) for y in y_dev]" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:44:19.353471700Z", "start_time": "2024-05-14T07:44:19.317384100Z" } }, "id": "dfa0d6b3bdca6853" }, { "cell_type": "code", "execution_count": 67, "outputs": [ { "data": { "text/plain": " 0%| | 0/945 [00:00 and )\n", "y_dev_pred_con = [' '.join(y[1:-1]) for y in y_dev_pred]\n", "y_test_pred_con = [' '.join(y[1:-1]) for y in y_test_pred]" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:45:49.368345900Z", "start_time": "2024-05-14T07:45:49.355900900Z" } }, "id": "1a9dc8188e83e5e9" }, { "cell_type": "code", "execution_count": 72, "outputs": [], "source": [ "# Save the predictions (without postprocessing)\n", "pd.DataFrame(y_dev_pred_con).to_csv('dev-0/out-model.tsv', header=False, index=False, sep='\\t')\n", "pd.DataFrame(y_test_pred_con).to_csv('test-A/out-model.tsv', header=False, index=False, sep='\\t')" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T07:45:49.397283400Z", "start_time": "2024-05-14T07:45:49.370434300Z" } }, "id": "66daac87feaa2f66" }, { "cell_type": "code", "execution_count": 42, "outputs": [], "source": [ "# Postprocessing\n", "# Regex for finding I-tags that start a sequence (should be B-tags)\n", "def incorrect_I_as_begin_tag(text):\n", " return re.finditer(r'(?