diff --git a/lstm.ipynb b/lstm.ipynb new file mode 100644 index 0000000..7e2ce9b --- /dev/null +++ b/lstm.ipynb @@ -0,0 +1,956 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-05-16T18:21:49.572131300Z", + "start_time": "2024-05-16T18:21:43.423852800Z" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import torchtext\n", + "from torchtext.vocab import vocab\n", + "\n", + "from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report\n", + "\n", + "from tqdm.notebook import tqdm\n", + "\n", + "from collections import Counter" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [], + "source": [ + "# Load the data\n", + "train_data = pd.read_csv('train/train.tsv', delimiter='\\t', header=None)\n", + "\n", + "valid_data_in = pd.read_csv('dev-0/in.tsv', delimiter='\\t', header=None)\n", + "valid_data_expected = pd.read_csv('dev-0/expected.tsv', delimiter='\\t', header=None)\n", + "valid_data = pd.concat([valid_data_expected, valid_data_in], axis=1)\n", + "\n", + "test_data = pd.read_csv('test-A/in.tsv', delimiter='\\t', header=None)\n", + "\n", + "# Label the columns\n", + "train_data.columns = ['ner_tags', 'text']\n", + "valid_data.columns = ['ner_tags', 'text']\n", + "test_data.columns = ['text']\n", + "\n", + "# Split the text into tokens\n", + "train_data['text_tokens'] = train_data['text'].apply(lambda x: x.split())\n", + "valid_data['text_tokens'] = valid_data['text'].apply(lambda x: x.split())\n", + "test_data['text_tokens'] = test_data['text'].apply(lambda x: x.split())\n", + "\n", + "# Split the NER tags into tokens\n", + "train_data['ner_tags_tokens'] = 
# Method for building the vocabulary from DataFrame dataset
# Special tokens:
# - unknown token
# - padding token
# - beginning of sentence token
# - end of sentence token
def build_vocab(dataset):
    """
    Build the vocabulary from the tokenized documents of a DataFrame dataset.

    :param dataset: DataFrame with a 'text_tokens' column holding one list of tokens per document
    :return: vocabulary object produced by ``vocab`` from the token counts and the special entries
    """
    # Token frequencies, accumulated in document order
    counter = Counter()

    # Read the token column directly; iterrows() would materialise a full Series
    # for every row only to pick this single field out of it
    for tokens in dataset['text_tokens']:
        counter.update(tokens)

    # NOTE(review): the four special entries below are identical empty strings, although the
    # header comment lists four distinct special tokens - confirm the intended marker strings.
    return vocab(counter, specials=['', '', '', ''])
# Tag -> index lookup (indices as listed at https://huggingface.co/datasets/conll2003)
ner_tag2idx = {
    tag: idx
    for idx, tag in enumerate(
        ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
    )
}

# Index -> tag lookup, obtained by inverting the mapping above
ner_idx2tag = dict(zip(ner_tag2idx.values(), ner_tag2idx.keys()))
# Method for vectorizing text data using the vocabulary mapping
def text_to_vec(data):
    """
    Turn tokenized documents into index tensors through the vocabulary ``v``.

    Every document is framed by one extra ``v['']`` lookup in front of its
    tokens and one more behind them.

    :param data: iterable of documents, each an iterable of tokens
    :return: list with one ``torch.long`` tensor per document
    """
    vectors = []
    for document in data:
        indices = [v['']]
        indices.extend(v[token] for token in document)
        indices.append(v[''])
        vectors.append(torch.tensor(indices, dtype=torch.long))
    return vectors


# Method for vectorizing NER tags data using the NER tags mapping
def ner_tags_to_vec(data):
    """
    Turn per-document tag sequences into label tensors through ``ner_tag2idx``.

    A label 0 is placed before and after each document's labels, matching the
    two extra positions that text_to_vec adds around the tokens.

    :param data: iterable of documents, each an iterable of NER tag strings
    :return: list with one ``torch.long`` tensor per document
    """
    vectors = []
    for document in data:
        labels = [0]
        for tag in document:
            labels.append(ner_tag2idx[tag])
        labels.append(0)
        vectors.append(torch.tensor(labels, dtype=torch.long))
    return vectors
# Model definition
class LSTM(nn.Module):
    """
    Sequence tagger: embedding -> ReLU -> single LSTM layer -> linear layer applied at every position.

    :param vocab_size: number of rows in the embedding table
    :param embedding_dim: size of each embedding vector (input size of the LSTM)
    :param hidden_dim: size of the LSTM hidden state
    :param output_dim: number of scores produced for every position
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM layer; batch_first=True means tensors are laid out as (batch, sequence, feature)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

        # Fully connected layer
        self.fc = nn.Linear(hidden_dim, output_dim)

        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Score every position of the input sequences.

        :param x: tensor of token indices; evaluate_model passes shape (1, sequence_length)
        :return: tensor of scores whose last dimension has size output_dim
        """
        # Embedding (followed by ReLU)
        embedded = self.relu(self.embedding(x))

        # LSTM: only the per-position outputs are used, the final hidden/cell states are dropped
        lstm_out, _ = self.lstm(embedded)

        # Fully connected
        return self.fc(lstm_out)


# Seqeval evaluation
def evaluate_model(model, X, y):
    """
    Method for evaluating the model
    :param model: model
    :param X: input data (iterable of 1-D index tensors, one per document)
    :param y: output data (iterable of 1-D label-index tensors, aligned with X)
    :return: dictionary with metrics values (keys: accuracy, precision, recall, f1)
    """
    # Evaluate in eval mode and put the model back into its previous mode afterwards,
    # so calling this between training steps does not change how training behaves
    was_training = model.training
    model.eval()
    try:
        # No gradients
        with torch.no_grad():
            # Predict the labels: add a batch dimension, take the best-scoring class per position
            predicted_indices = [torch.argmax(model(x.unsqueeze(0)).squeeze(0), 1) for x in X]
    finally:
        model.train(was_training)

    # Convert the labels to ner tags (loop names no longer shadow the parameter y)
    predicted_tags = [[ner_idx2tag[int(idx)] for idx in sequence] for sequence in predicted_indices]
    expected_tags = [[ner_idx2tag[int(idx)] for idx in sequence] for sequence in y]

    # Calculate the metrics
    accuracy = accuracy_score(expected_tags, predicted_tags)
    precision = precision_score(expected_tags, predicted_tags)
    recall = recall_score(expected_tags, predicted_tags)
    f1 = f1_score(expected_tags, predicted_tags)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}
= LSTM(vocab_size, embedding_dim, hidden_dim, output_dim)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-14T07:22:21.068317100Z", + "start_time": "2024-05-14T07:22:21.044162900Z" + } + }, + "id": "29116c705decf395" + }, + { + "cell_type": "code", + "execution_count": 41, + "outputs": [], + "source": [ + "# Loss function and optimizer\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters())" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-14T07:22:21.705555900Z", + "start_time": "2024-05-14T07:22:21.675608300Z" + } + }, + "id": "617bec2a8a8b56b3" + }, + { + "cell_type": "code", + "execution_count": 65, + "outputs": [], + "source": [ + "# Move training to GPU\n", + "model = model.to(device)\n", + "X_train = [x.to(device) for x in X_train]\n", + "y_train = [y.to(device) for y in y_train]\n", + "X_dev = [x.to(device) for x in X_dev]\n", + "y_dev = [y.to(device) for y in y_dev]" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-14T07:44:19.353471700Z", + "start_time": "2024-05-14T07:44:19.317384100Z" + } + }, + "id": "dfa0d6b3bdca6853" + }, + { + "cell_type": "code", + "execution_count": 67, + "outputs": [ + { + "data": { + "text/plain": " 0%| | 0/945 [00:00 and )\n", + "y_dev_pred_con = [' '.join(y[1:-1]) for y in y_dev_pred]\n", + "y_test_pred_con = [' '.join(y[1:-1]) for y in y_test_pred]" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-14T07:45:49.368345900Z", + "start_time": "2024-05-14T07:45:49.355900900Z" + } + }, + "id": "1a9dc8188e83e5e9" + }, + { + "cell_type": "code", + "execution_count": 72, + "outputs": [], + "source": [ + "# Save the predictions (without postprocessing)\n", + "pd.DataFrame(y_dev_pred_con).to_csv('dev-0/out-model.tsv', header=False, index=False, sep='\\t')\n", + "pd.DataFrame(y_test_pred_con).to_csv('test-A/out-model.tsv', header=False, index=False, sep='\\t')" 
+ ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-14T07:45:49.397283400Z", + "start_time": "2024-05-14T07:45:49.370434300Z" + } + }, + "id": "66daac87feaa2f66" + }, + { + "cell_type": "code", + "execution_count": 42, + "outputs": [], + "source": [ + "# Postprocessing\n", + "# Regex for finding I-tags that start a sequence (should be B-tags)\n", + "def incorrect_I_as_begin_tag(text):\n", + " return re.finditer(r'(?