en-ner-conll-2003/lstm.ipynb

957 lines
28 KiB
Plaintext
Raw Normal View History

2024-05-16 21:28:49 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-05-16T18:21:49.572131300Z",
"start_time": "2024-05-16T18:21:43.423852800Z"
}
},
"outputs": [],
"source": [
"import random\n",
"import re\n",
"import warnings\n",
"from collections import Counter\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"import torchtext\n",
"from torchtext.vocab import vocab\n",
"\n",
"from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report\n",
"\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"# Load the data (every split is a header-less, tab-separated file)\n",
"def read_tsv(path, columns):\n",
"    frame = pd.read_csv(path, delimiter='\\t', header=None)\n",
"    frame.columns = columns\n",
"    return frame\n",
"\n",
"train_data = read_tsv('train/train.tsv', ['ner_tags', 'text'])\n",
"\n",
"# The dev split ships its text (in.tsv) and gold tags (expected.tsv) in separate files\n",
"valid_data_in = read_tsv('dev-0/in.tsv', ['text'])\n",
"valid_data_expected = read_tsv('dev-0/expected.tsv', ['ner_tags'])\n",
"valid_data = pd.concat([valid_data_expected, valid_data_in], axis=1)\n",
"\n",
"test_data = read_tsv('test-A/in.tsv', ['text'])\n",
"\n",
"# Whitespace-tokenize the text of every split ...\n",
"for frame in (train_data, valid_data, test_data):\n",
"    frame['text_tokens'] = frame['text'].apply(lambda x: x.split())\n",
"\n",
"# ... and the tag strings of the labelled splits\n",
"for frame in (train_data, valid_data):\n",
"    frame['ner_tags_tokens'] = frame['ner_tags'].apply(lambda x: x.split())"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:23.174336100Z",
"start_time": "2024-05-14T07:11:23.080690300Z"
}
},
"id": "9e5c5c1083e3f387"
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"# Build a torchtext vocabulary from the 'text_tokens' column of a DataFrame.\n",
"# Special tokens (prepended by torchtext, in this order):\n",
"# <unk> - unknown token\n",
"# <pad> - padding token\n",
"# <bos> - beginning of sentence token\n",
"# <eos> - end of sentence token\n",
"def build_vocab(dataset):\n",
"    # Counter keeps first-seen insertion order, which vocab() respects when assigning indices\n",
"    token_counts = Counter()\n",
"    for tokens in dataset['text_tokens']:\n",
"        token_counts.update(tokens)\n",
"    return vocab(token_counts, specials=['<unk>', '<pad>', '<bos>', '<eos>'])"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:23.647897500Z",
"start_time": "2024-05-14T07:11:23.640148800Z"
}
},
"id": "56a8833a05334060"
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [],
"source": [
"# Build the vocabulary from the training split only\n",
"v = build_vocab(dataset=train_data)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:24.169912Z",
"start_time": "2024-05-14T07:11:24.081356500Z"
}
},
"id": "eacfbc15230adc2e"
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"# Mapping from index to token: itos[i] is the token with index i\n",
"itos = list(v.get_itos())"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:24.484522400Z",
"start_time": "2024-05-14T07:11:24.470356200Z"
}
},
"id": "c9c7ce32ebd5a3c2"
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"# Out-of-vocabulary lookups (e.g. dev/test tokens unseen in training) fall back to <unk>\n",
"unk_index = v[\"<unk>\"]\n",
"v.set_default_index(unk_index)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:24.842556700Z",
"start_time": "2024-05-14T07:11:24.823442400Z"
}
},
"id": "ce8d899162dcc776"
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [],
"source": [
"# Get the unique NER tags that occur in the training labels\n",
"ner_tags = {tag for tags in train_data['ner_tags_tokens'] for tag in tags}"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:25.201567900Z",
"start_time": "2024-05-14T07:11:25.180831600Z"
}
},
"id": "2e9f2dc469b6025d"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"# Mapping from tag to index, in the label order of https://huggingface.co/datasets/conll2003\n",
"NER_TAGS = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n",
"ner_tag2idx = {tag: idx for idx, tag in enumerate(NER_TAGS)}\n",
"\n",
"# reverse mapping: index -> tag\n",
"ner_idx2tag = dict(enumerate(NER_TAGS))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:26.534701200Z",
"start_time": "2024-05-14T07:11:26.526620300Z"
}
},
"id": "5271fd04bd9f16e3"
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"data": {
"text/plain": "{'O': 0,\n 'B-PER': 1,\n 'I-PER': 2,\n 'B-ORG': 3,\n 'I-ORG': 4,\n 'B-LOC': 5,\n 'I-LOC': 6,\n 'B-MISC': 7,\n 'I-MISC': 8}"
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Forward mapping: tag -> class index\n",
"ner_tag2idx"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:27.854314700Z",
"start_time": "2024-05-14T07:11:27.844315700Z"
}
},
"id": "8bf1e9961daa4bd8"
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"data": {
"text/plain": "{0: 'O',\n 1: 'B-PER',\n 2: 'I-PER',\n 3: 'B-ORG',\n 4: 'I-ORG',\n 5: 'B-LOC',\n 6: 'I-LOC',\n 7: 'B-MISC',\n 8: 'I-MISC'}"
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Reverse mapping: class index -> tag\n",
"ner_idx2tag"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:28.332071700Z",
"start_time": "2024-05-14T07:11:28.286070800Z"
}
},
"id": "12571d646796d21b"
},
{
"cell_type": "code",
"execution_count": 16,
"outputs": [],
"source": [
"# Vectorize tokenized documents with the vocabulary `v`, wrapping each one in <bos> ... <eos>.\n",
"# Returns a list with one 1-D LongTensor per document.\n",
"def text_to_vec(data):\n",
"    bos_id, eos_id = v['<bos>'], v['<eos>']\n",
"    vectors = []\n",
"    for document in data:\n",
"        token_ids = [bos_id, *(v[token] for token in document), eos_id]\n",
"        vectors.append(torch.tensor(token_ids, dtype=torch.long))\n",
"    return vectors"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:29.032865100Z",
"start_time": "2024-05-14T07:11:29.012730500Z"
}
},
"id": "da795a7fd000b135"
},
{
"cell_type": "code",
"execution_count": 17,
"outputs": [],
"source": [
"# Vectorize tag sequences with ner_tag2idx. Index 0 ('O') is added at both ends to label\n",
"# the <bos>/<eos> positions that text_to_vec adds, keeping inputs and labels aligned.\n",
"def ner_tags_to_vec(data):\n",
"    vectors = []\n",
"    for document in data:\n",
"        label_ids = [ner_tag2idx[tag] for tag in document]\n",
"        vectors.append(torch.tensor([0, *label_ids, 0], dtype=torch.long))\n",
"    return vectors"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:29.824074700Z",
"start_time": "2024-05-14T07:11:29.812059800Z"
}
},
"id": "f9c2bb1f0bb0e480"
},
{
"cell_type": "code",
"execution_count": 18,
"outputs": [],
"source": [
"# Vectorize the text data (input) of the train, dev and test splits\n",
"X_train, X_dev, X_test = [\n",
"    text_to_vec(split['text_tokens']) for split in (train_data, valid_data, test_data)\n",
"]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:30.896086700Z",
"start_time": "2024-05-14T07:11:30.610066Z"
}
},
"id": "2f851f63cedacf6c"
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [],
"source": [
"# Vectorize the NER tags data (output, labels) of the labelled splits\n",
"y_train, y_dev = [ner_tags_to_vec(split['ner_tags_tokens']) for split in (train_data, valid_data)]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:31.468671200Z",
"start_time": "2024-05-14T07:11:31.415476500Z"
}
},
"id": "30e8c488d3b9d11a"
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"# Model definition: embedding -> ReLU -> unidirectional LSTM -> per-position linear tag classifier\n",
"class LSTM(nn.Module):\n",
"    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):\n",
"        super().__init__()\n",
"        # Keep this creation order (embedding, lstm, fc): with a fixed seed it determines the initial weights\n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
"        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
"        self.fc = nn.Linear(hidden_dim, output_dim)\n",
"        self.relu = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map token ids x of shape (batch, seq_len) to tag logits of shape (batch, seq_len, output_dim).\"\"\"\n",
"        embedded = self.relu(self.embedding(x))\n",
"        lstm_out, _ = self.lstm(embedded)\n",
"        return self.fc(lstm_out)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:32.454622200Z",
"start_time": "2024-05-14T07:11:32.422201500Z"
}
},
"id": "6a86649248c384b5"
},
{
"cell_type": "code",
"execution_count": 77,
"outputs": [],
"source": [
"# Seqeval evaluation\n",
"def evaluate_model(model, X, y, skip_special_tokens=True):\n",
"    \"\"\"\n",
"    Evaluate the model on a labelled split with seqeval metrics.\n",
"    :param model: model mapping a (1, seq_len) id tensor to (1, seq_len, n_tags) logits\n",
"    :param X: list of 1-D input id tensors (expected to come from text_to_vec, i.e. with <bos>/<eos>)\n",
"    :param y: list of 1-D gold tag index tensors aligned with X\n",
"    :param skip_special_tokens: drop the first and last position (<bos>/<eos>) before scoring,\n",
"        so the metrics cover the same tokens that are exported to the out-*.tsv files\n",
"    :return: dictionary with metrics values ('accuracy', 'precision', 'recall', 'f1')\n",
"    \"\"\"\n",
"    # Predict in eval mode without gradients, then restore the caller's train/eval mode\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            pred_indices = [torch.argmax(model(x.unsqueeze(0)).squeeze(0), 1) for x in X]\n",
"    finally:\n",
"        model.train(was_training)\n",
"\n",
"    def to_tags(indices):\n",
"        # Convert label indices to ner tags, optionally trimming the <bos>/<eos> positions\n",
"        tags = [ner_idx2tag[int(idx)] for idx in indices]\n",
"        return tags[1:-1] if skip_special_tokens else tags\n",
"\n",
"    y_pred = [to_tags(indices) for indices in pred_indices]\n",
"    y_tags = [to_tags(indices) for indices in y]\n",
"\n",
"    # Calculate the metrics\n",
"    return {\n",
"        'accuracy': accuracy_score(y_tags, y_pred),\n",
"        'precision': precision_score(y_tags, y_pred),\n",
"        'recall': recall_score(y_tags, y_pred),\n",
"        'f1': f1_score(y_tags, y_pred),\n",
"    }"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T08:26:31.612231Z",
"start_time": "2024-05-14T08:26:31.599603300Z"
}
},
"id": "b18d26ac9fbc590e"
},
{
"cell_type": "code",
"execution_count": 23,
"outputs": [],
"source": [
"# Run on the GPU when CUDA is available, otherwise on the CPU\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:11:49.825835200Z",
"start_time": "2024-05-14T07:11:49.817343100Z"
}
},
"id": "badf288796646abe"
},
{
"cell_type": "code",
"execution_count": 39,
"outputs": [],
"source": [
"# Model parameters\n",
"vocab_size = len(v)\n",
"embedding_dim = 64\n",
"hidden_dim = 256\n",
"# Size the classifier from the label mapping used to encode y (ner_tag2idx), not from the tags that\n",
"# happen to occur in train_data: a tag missing from training would otherwise shrink the output layer\n",
"# below the largest label index and break the cross-entropy loss\n",
"output_dim = len(ner_tag2idx)\n",
"epochs = 20"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:22:20.730379Z",
"start_time": "2024-05-14T07:22:20.724143500Z"
}
},
"id": "65beded501220882"
},
{
"cell_type": "code",
"execution_count": 44,
"outputs": [],
"source": [
"# Seed every random number generator in use (torch, Python, NumPy) for reproducibility\n",
"import random\n",
"\n",
"SEED = 1234\n",
"for seed_rng in (torch.manual_seed, random.seed, np.random.seed):\n",
"    seed_rng(SEED)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:28:18.248713300Z",
"start_time": "2024-05-14T07:28:18.188830400Z"
}
},
"id": "63b68885d93d5fce"
},
{
"cell_type": "code",
"execution_count": 40,
"outputs": [],
"source": [
"# Initialize the model\n",
"model = LSTM(vocab_size=vocab_size, embedding_dim=embedding_dim, hidden_dim=hidden_dim, output_dim=output_dim)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:22:21.068317100Z",
"start_time": "2024-05-14T07:22:21.044162900Z"
}
},
"id": "29116c705decf395"
},
{
"cell_type": "code",
"execution_count": 41,
"outputs": [],
"source": [
"# Loss function (cross-entropy over per-token tag logits) and optimizer (Adam with default hyperparameters)\n",
"optimizer = optim.Adam(params=model.parameters())\n",
"criterion = nn.CrossEntropyLoss()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:22:21.705555900Z",
"start_time": "2024-05-14T07:22:21.675608300Z"
}
},
"id": "617bec2a8a8b56b3"
},
{
"cell_type": "code",
"execution_count": 65,
"outputs": [],
"source": [
"# Move the model and the train/dev tensors to the selected device (X_test is not moved)\n",
"model = model.to(device)\n",
"X_train, y_train, X_dev, y_dev = [\n",
"    [tensor.to(device) for tensor in tensors] for tensors in (X_train, y_train, X_dev, y_dev)\n",
"]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:44:19.353471700Z",
"start_time": "2024-05-14T07:44:19.317384100Z"
}
},
"id": "dfa0d6b3bdca6853"
},
{
"cell_type": "code",
"execution_count": 67,
"outputs": [
{
"data": {
"text/plain": " 0%| | 0/945 [00:00<?, ?it/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "415ba9a191bf4ff0993115b428e604ec"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1, Accuracy: 0.9545313667936774, Precision: 0.7780607604147717, Recall: 0.7213695395513577, F1: 0.7486434447750743\n"
]
}
],
"source": [
"# Training loop: one document per optimizer step\n",
"for epoch in range(epochs):\n",
"    # (Re-)enter training mode every epoch, since evaluation may switch the model to eval mode\n",
"    model.train()\n",
"    # Accumulate the loss as a detached tensor to avoid a host/device sync on every step\n",
"    epoch_loss = 0.0\n",
"\n",
"    for idx in tqdm(range(len(X_train))):\n",
"        # Zero the gradients\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Forward pass on a batch of one document: (1, seq_len) -> (1, seq_len, output_dim)\n",
"        output = model(X_train[idx].unsqueeze(0))\n",
"\n",
"        # Per-token cross-entropy against the gold tag indices\n",
"        loss = criterion(output.squeeze(0), y_train[idx])\n",
"\n",
"        # Backward pass and weight update\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        epoch_loss += loss.detach()\n",
"\n",
"    mean_loss = float(epoch_loss) / max(len(X_train), 1)\n",
"\n",
"    # Evaluate the model on the dev set\n",
"    metrics = evaluate_model(model, X_dev, y_dev)\n",
"\n",
"    print(f'Epoch: {epoch+1}, Loss: {mean_loss}, Accuracy: {metrics[\"accuracy\"]}, Precision: {metrics[\"precision\"]}, Recall: {metrics[\"recall\"]}, F1: {metrics[\"f1\"]}')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:45:10.782583200Z",
"start_time": "2024-05-14T07:44:53.579284100Z"
}
},
"id": "7a77d0ac6fce81fd"
},
{
"cell_type": "code",
"execution_count": 78,
"outputs": [
{
"data": {
"text/plain": "{'accuracy': 0.9545313667936774,\n 'precision': 0.7780607604147717,\n 'recall': 0.7213695395513577,\n 'f1': 0.7486434447750743}"
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dev-set metrics of the trained model\n",
"evaluate_model(model=model, X=X_dev, y=y_dev)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T08:26:38.050024900Z",
"start_time": "2024-05-14T08:26:36.980050200Z"
}
},
"id": "956b180f74abd5d4"
},
{
"cell_type": "code",
"execution_count": 69,
"outputs": [],
"source": [
"# Move the model and the dev tensors back to the CPU for inference\n",
"model = model.to('cpu')\n",
"X_dev, y_dev = [[tensor.to('cpu') for tensor in tensors] for tensors in (X_dev, y_dev)]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:45:46.733747700Z",
"start_time": "2024-05-14T07:45:46.673385500Z"
}
},
"id": "b019c1d995100ef5"
},
{
"cell_type": "code",
"execution_count": 70,
"outputs": [],
"source": [
"# Predict the labels for the validation and test sets\n",
"def predict_tags(documents):\n",
"    # Argmax tag per position, including the <bos>/<eos> positions\n",
"    with torch.no_grad():\n",
"        index_seqs = [model(x.unsqueeze(0)).squeeze(0).argmax(dim=1) for x in documents]\n",
"    # Convert the label indices to ner tags\n",
"    return [[ner_idx2tag[int(idx)] for idx in seq] for seq in index_seqs]\n",
"\n",
"y_dev_pred = predict_tags(X_dev)\n",
"y_test_pred = predict_tags(X_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:45:49.354440Z",
"start_time": "2024-05-14T07:45:47.244447100Z"
}
},
"id": "523f8444e9a73a05"
},
{
"cell_type": "code",
"execution_count": 71,
"outputs": [],
"source": [
"# Concatenate predicted labels into one string per document (skip the special tokens <bos> and <eos>)\n",
"y_dev_pred_con, y_test_pred_con = [\n",
"    [' '.join(tags[1:-1]) for tags in predictions] for predictions in (y_dev_pred, y_test_pred)\n",
"]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:45:49.368345900Z",
"start_time": "2024-05-14T07:45:49.355900900Z"
}
},
"id": "1a9dc8188e83e5e9"
},
{
"cell_type": "code",
"execution_count": 72,
"outputs": [],
"source": [
"# Save the predictions (without postprocessing)\n",
"for predictions, path in [(y_dev_pred_con, 'dev-0/out-model.tsv'), (y_test_pred_con, 'test-A/out-model.tsv')]:\n",
"    pd.DataFrame(predictions).to_csv(path, header=False, index=False, sep='\\t')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-14T07:45:49.397283400Z",
"start_time": "2024-05-14T07:45:49.370434300Z"
}
},
"id": "66daac87feaa2f66"
},
{
"cell_type": "code",
"execution_count": 42,
"outputs": [],
"source": [
"# Postprocessing: repair BIO-inconsistent tag sequences predicted by the model.\n",
"# Each row holds space-separated tags (e.g. O B-PER I-PER O). Every repair rewrites the\n",
"# leading 'I' of an offending tag to 'B', so character positions never shift.\n",
"# Tags are inspected token by token, so the standard-library `re` module is sufficient\n",
"# (it rejects the variable-width look-behinds a single-regex formulation would need).\n",
"\n",
"def _tags_with_prev(text):\n",
"    # Yield (match, previous tag or None) for every space-separated tag in text\n",
"    prev = None\n",
"    for match in re.finditer(r'\\S+', text):\n",
"        yield match, prev\n",
"        prev = match.group()\n",
"\n",
"# Find I-tags that start a sequence, i.e. are not preceded by a B- or I-tag (should be B-tags)\n",
"def incorrect_I_as_begin_tag(text):\n",
"    for match, prev in _tags_with_prev(text):\n",
"        continues_entity = prev is not None and re.fullmatch(r'[BI]-\\w+', prev)\n",
"        if re.fullmatch(r'I-\\w+', match.group()) and not continues_entity:\n",
"            yield match\n",
"\n",
"def _inconsistent_I_after(prev_prefix, text):\n",
"    # Yield I-tags whose type differs from that of a directly preceding `prev_prefix`-tag\n",
"    for match, prev in _tags_with_prev(text):\n",
"        prev_tag = re.fullmatch(prev_prefix + r'-(\\w+)', prev) if prev is not None else None\n",
"        cur_tag = re.fullmatch(r'I-(\\w+)', match.group())\n",
"        if prev_tag and cur_tag and not cur_tag.group(1).startswith(prev_tag.group(1)):\n",
"            yield match\n",
"\n",
"# Find inconsistent I-tags after B-tags (I-tags that are not continuation of B-tags)\n",
"def inconsistent_I_after_B(text):\n",
"    return _inconsistent_I_after('B', text)\n",
"\n",
"# Find inconsistent I-tags after other I-tags (I-tags that are not continuation of the same tag)\n",
"def inconsistent_I_after_I(text):\n",
"    return _inconsistent_I_after('I', text)\n",
"\n",
"def _replace_until_stable(df, find_matches):\n",
"    # Turn every 'I' found by find_matches into 'B' in all rows, repeating until a pass makes\n",
"    # no change (a repair can expose a new violation on the following tag).\n",
"    # All matches of a pass are located on the row as it was at the start of that pass.\n",
"    i = 0\n",
"    while True:\n",
"        print(f\"Iteration: {i+1}\")\n",
"        changes = 0\n",
"        fixed_rows = []\n",
"        for text in df['ner_tags']:\n",
"            chars = list(text)\n",
"            for match in find_matches(text):\n",
"                chars[match.start()] = 'B'\n",
"                changes += 1\n",
"            fixed_rows.append(\"\".join(chars))\n",
"        # Assign the whole column instead of mutating rows yielded by iterrows(),\n",
"        # which pandas does not guarantee to write back to the DataFrame\n",
"        df['ner_tags'] = fixed_rows\n",
"        print(f\"Changes: {changes}\")\n",
"        i += 1\n",
"        if changes == 0:\n",
"            return df\n",
"\n",
"# Helper method for replacing I-tags that start a sequence with B-tags\n",
"def replace_incorrect_I_as_begin_tag(df):\n",
"    return _replace_until_stable(df, incorrect_I_as_begin_tag)\n",
"\n",
"# Helper method for replacing inconsistent I-tags after B-tags with B-tags\n",
"def replace_inconsistent_I_after_B(df):\n",
"    return _replace_until_stable(df, inconsistent_I_after_B)\n",
"\n",
"# Helper method for replacing inconsistent I-tags after other I-tags with B-tags\n",
"def replace_inconsistent_I_after_I(df):\n",
"    return _replace_until_stable(df, inconsistent_I_after_I)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-16T19:18:22.287969600Z",
"start_time": "2024-05-16T19:18:22.272058200Z"
}
},
"id": "e3f5c71b5b231d5e"
},
{
"cell_type": "code",
"execution_count": 43,
"outputs": [],
"source": [
"# Load the raw model predictions (one space-separated tag string per document)\n",
"def read_predictions(path):\n",
"    predictions = pd.read_csv(path, delimiter='\\t', header=None)\n",
"    predictions.columns = ['ner_tags']\n",
"    return predictions\n",
"\n",
"out_dev = read_predictions('dev-0/out-model.tsv')\n",
"out_test = read_predictions('test-A/out-model.tsv')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-16T19:18:25.353082700Z",
"start_time": "2024-05-16T19:18:25.341655500Z"
}
},
"id": "cef273b10f1fc169"
},
{
"cell_type": "code",
"execution_count": 44,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 1\n",
"Changes: 100\n",
"Iteration: 2\n",
"Changes: 0\n",
"Iteration: 1\n",
"Changes: 113\n",
"Iteration: 2\n",
"Changes: 4\n",
"Iteration: 3\n",
"Changes: 0\n",
"Iteration: 1\n",
"Changes: 18\n",
"Iteration: 2\n",
"Changes: 0\n",
"Iteration: 1\n",
"Changes: 105\n",
"Iteration: 2\n",
"Changes: 0\n",
"Iteration: 1\n",
"Changes: 111\n",
"Iteration: 2\n",
"Changes: 5\n",
"Iteration: 3\n",
"Changes: 0\n",
"Iteration: 1\n",
"Changes: 22\n",
"Iteration: 2\n",
"Changes: 0\n"
]
}
],
"source": [
"# Postprocessing: apply the three BIO repairs in order, first to dev, then to test\n",
"repairs = [replace_incorrect_I_as_begin_tag, replace_inconsistent_I_after_B, replace_inconsistent_I_after_I]\n",
"\n",
"for repair in repairs:\n",
"    out_dev = repair(out_dev)\n",
"\n",
"for repair in repairs:\n",
"    out_test = repair(out_test)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-16T19:18:32.884479700Z",
"start_time": "2024-05-16T19:18:32.705259700Z"
}
},
"id": "a845573affc53a38"
},
{
"cell_type": "code",
"execution_count": 45,
"outputs": [],
"source": [
"# Save the predictions (with postprocessing)\n",
"for predictions, path in [(out_dev, 'dev-0/out.tsv'), (out_test, 'test-A/out.tsv')]:\n",
"    predictions.to_csv(path, header=False, index=False, sep='\\t')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-16T19:18:35.475503700Z",
"start_time": "2024-05-16T19:18:35.453433800Z"
}
},
"id": "d8cf3a8cdbe2de9a"
},
{
"cell_type": "code",
"execution_count": 47,
"outputs": [],
"source": [
"# Evaluation: score the post-processed dev predictions against the gold tags with seqeval\n",
"in_dev = pd.read_csv('dev-0/expected.tsv', delimiter='\\t', header=None)\n",
"in_dev.columns = ['ner_tags']\n",
"\n",
"gold_dev_tags = [tags.split() for tags in in_dev['ner_tags']]\n",
"pred_dev_tags = [tags.split() for tags in out_dev['ner_tags']]\n",
"assert len(gold_dev_tags) == len(pred_dev_tags), 'dev predictions and gold tags differ in document count'\n",
"\n",
"print(classification_report(gold_dev_tags, pred_dev_tags))\n",
"f1_score(gold_dev_tags, pred_dev_tags)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-16T19:21:21.442871Z",
"start_time": "2024-05-16T19:21:21.423414300Z"
}
},
"id": "341015fd66bc6573"
},
{
"cell_type": "markdown",
"source": [
"GEVAL F1-BIO (dev): 0.74864"
],
"metadata": {
"collapsed": false
},
"id": "d0c015ab8e55873c"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}