Initial commit
This commit is contained in:
parent
a2c225a80c
commit
1f7b84fa29
956
lstm.ipynb
Normal file
956
lstm.ipynb
Normal file
@ -0,0 +1,956 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "initial_id",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-16T18:21:49.572131300Z",
|
||||||
|
"start_time": "2024-05-16T18:21:43.423852800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"\n",
|
||||||
|
"import warnings\n",
|
||||||
|
"warnings.filterwarnings('ignore')\n",
|
||||||
|
"\n",
|
||||||
|
"import torchtext\n",
|
||||||
|
"from torchtext.vocab import vocab\n",
|
||||||
|
"\n",
|
||||||
|
"from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report\n",
|
||||||
|
"\n",
|
||||||
|
"from tqdm.notebook import tqdm\n",
|
||||||
|
"\n",
|
||||||
|
"from collections import Counter"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load the data\n",
|
||||||
|
"train_data = pd.read_csv('train/train.tsv', delimiter='\\t', header=None)\n",
|
||||||
|
"\n",
|
||||||
|
"valid_data_in = pd.read_csv('dev-0/in.tsv', delimiter='\\t', header=None)\n",
|
||||||
|
"valid_data_expected = pd.read_csv('dev-0/expected.tsv', delimiter='\\t', header=None)\n",
|
||||||
|
"valid_data = pd.concat([valid_data_expected, valid_data_in], axis=1)\n",
|
||||||
|
"\n",
|
||||||
|
"test_data = pd.read_csv('test-A/in.tsv', delimiter='\\t', header=None)\n",
|
||||||
|
"\n",
|
||||||
|
"# Label the columns\n",
|
||||||
|
"train_data.columns = ['ner_tags', 'text']\n",
|
||||||
|
"valid_data.columns = ['ner_tags', 'text']\n",
|
||||||
|
"test_data.columns = ['text']\n",
|
||||||
|
"\n",
|
||||||
|
"# Split the text into tokens\n",
|
||||||
|
"train_data['text_tokens'] = train_data['text'].apply(lambda x: x.split())\n",
|
||||||
|
"valid_data['text_tokens'] = valid_data['text'].apply(lambda x: x.split())\n",
|
||||||
|
"test_data['text_tokens'] = test_data['text'].apply(lambda x: x.split())\n",
|
||||||
|
"\n",
|
||||||
|
"# Split the NER tags into tokens\n",
|
||||||
|
"train_data['ner_tags_tokens'] = train_data['ner_tags'].apply(lambda x: x.split())\n",
|
||||||
|
"valid_data['ner_tags_tokens'] = valid_data['ner_tags'].apply(lambda x: x.split())"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:23.174336100Z",
|
||||||
|
"start_time": "2024-05-14T07:11:23.080690300Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "9e5c5c1083e3f387"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Method for building the vocabulary from DataFrame dataset\n",
|
||||||
|
"# Special tokens:\n",
|
||||||
|
"# <unk> - unknown token\n",
|
||||||
|
"# <pad> - padding token\n",
|
||||||
|
"# <bos> - beginning of sentence token\n",
|
||||||
|
"# <eos> - end of sentence token\n",
|
||||||
|
"def build_vocab(dataset):\n",
|
||||||
|
" # Initialize the counter\n",
|
||||||
|
" counter = Counter()\n",
|
||||||
|
" \n",
|
||||||
|
" # Iterate over the dataset and update the counter\n",
|
||||||
|
" for idx, document in dataset.iterrows():\n",
|
||||||
|
" counter.update(document['text_tokens'])\n",
|
||||||
|
" \n",
|
||||||
|
" # Return the vocabulary\n",
|
||||||
|
" return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:23.647897500Z",
|
||||||
|
"start_time": "2024-05-14T07:11:23.640148800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "56a8833a05334060"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Build the vocabulary\n",
|
||||||
|
"v = build_vocab(train_data)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:24.169912Z",
|
||||||
|
"start_time": "2024-05-14T07:11:24.081356500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "eacfbc15230adc2e"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Mapping from index to token\n",
|
||||||
|
"itos = v.get_itos()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:24.484522400Z",
|
||||||
|
"start_time": "2024-05-14T07:11:24.470356200Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "c9c7ce32ebd5a3c2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Set default index for unknown tokens\n",
|
||||||
|
"v.set_default_index(v[\"<unk>\"])"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:24.842556700Z",
|
||||||
|
"start_time": "2024-05-14T07:11:24.823442400Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "ce8d899162dcc776"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Get the unique ner tags\n",
|
||||||
|
"ner_tags = set([tag for tags in train_data['ner_tags_tokens'] for tag in tags])"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:25.201567900Z",
|
||||||
|
"start_time": "2024-05-14T07:11:25.180831600Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "2e9f2dc469b6025d"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Mapping from tag to index (https://huggingface.co/datasets/conll2003)\n",
|
||||||
|
"ner_tag2idx = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}\n",
|
||||||
|
"\n",
|
||||||
|
"# reverse mapping\n",
|
||||||
|
"ner_idx2tag = {idx: tag for tag, idx in ner_tag2idx.items()}"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:26.534701200Z",
|
||||||
|
"start_time": "2024-05-14T07:11:26.526620300Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "5271fd04bd9f16e3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "{'O': 0,\n 'B-PER': 1,\n 'I-PER': 2,\n 'B-ORG': 3,\n 'I-ORG': 4,\n 'B-LOC': 5,\n 'I-LOC': 6,\n 'B-MISC': 7,\n 'I-MISC': 8}"
|
||||||
|
},
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"ner_tag2idx"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:27.854314700Z",
|
||||||
|
"start_time": "2024-05-14T07:11:27.844315700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "8bf1e9961daa4bd8"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "{0: 'O',\n 1: 'B-PER',\n 2: 'I-PER',\n 3: 'B-ORG',\n 4: 'I-ORG',\n 5: 'B-LOC',\n 6: 'I-LOC',\n 7: 'B-MISC',\n 8: 'I-MISC'}"
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"ner_idx2tag"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:28.332071700Z",
|
||||||
|
"start_time": "2024-05-14T07:11:28.286070800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "12571d646796d21b"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Method for vectorizing text data using the vocabulary mapping\n",
|
||||||
|
"def text_to_vec(data):\n",
|
||||||
|
" return [torch.tensor([v['<bos>']] + [v[token] for token in document] + [v['<eos>']], dtype=torch.long) for document in data]"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:29.032865100Z",
|
||||||
|
"start_time": "2024-05-14T07:11:29.012730500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "da795a7fd000b135"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Method for vectorizing NER tags data using the NER tags mapping\n",
|
||||||
|
"def ner_tags_to_vec(data):\n",
|
||||||
|
" return [torch.tensor([0] + [ner_tag2idx[tag] for tag in document] + [0], dtype=torch.long) for document in data]"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:29.824074700Z",
|
||||||
|
"start_time": "2024-05-14T07:11:29.812059800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "f9c2bb1f0bb0e480"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Vectorize the text data (input)\n",
|
||||||
|
"X_train = text_to_vec(train_data['text_tokens'])\n",
|
||||||
|
"X_dev = text_to_vec(valid_data['text_tokens'])\n",
|
||||||
|
"X_test = text_to_vec(test_data['text_tokens'])"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:30.896086700Z",
|
||||||
|
"start_time": "2024-05-14T07:11:30.610066Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "2f851f63cedacf6c"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Vectorize the NER tags data (output, labels)\n",
|
||||||
|
"y_train = ner_tags_to_vec(train_data['ner_tags_tokens'])\n",
|
||||||
|
"y_dev = ner_tags_to_vec(valid_data['ner_tags_tokens'])"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:31.468671200Z",
|
||||||
|
"start_time": "2024-05-14T07:11:31.415476500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "30e8c488d3b9d11a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Model definition\n",
|
||||||
|
"class LSTM(nn.Module):\n",
|
||||||
|
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):\n",
|
||||||
|
" super(LSTM, self).__init__()\n",
|
||||||
|
" \n",
|
||||||
|
" # Embedding layer\n",
|
||||||
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
||||||
|
" \n",
|
||||||
|
" # LSTM layer\n",
|
||||||
|
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first = True)\n",
|
||||||
|
" \n",
|
||||||
|
" # Fully connected layer\n",
|
||||||
|
" self.fc = nn.Linear(hidden_dim, output_dim)\n",
|
||||||
|
" \n",
|
||||||
|
" self.relu = nn.ReLU()\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" # Embedding\n",
|
||||||
|
" embedding = self.relu(self.embedding(x))\n",
|
||||||
|
" \n",
|
||||||
|
" # LSTM\n",
|
||||||
|
" output, (hidden, cell) = self.lstm(embedding)\n",
|
||||||
|
" \n",
|
||||||
|
" # Fully connected\n",
|
||||||
|
" output = self.fc(output)\n",
|
||||||
|
" \n",
|
||||||
|
" return output\n",
|
||||||
|
" "
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:32.454622200Z",
|
||||||
|
"start_time": "2024-05-14T07:11:32.422201500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "6a86649248c384b5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 77,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Segeval evaluation\n",
|
||||||
|
"def evaluate_model(model, X, y):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Method for evaluating the model\n",
|
||||||
|
" :param model: model\n",
|
||||||
|
" :param X: input data\n",
|
||||||
|
" :param y: output data \n",
|
||||||
|
" :return: dictionary with metrics values\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # No gradients\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" # Predict the labels\n",
|
||||||
|
" y_pred = [torch.argmax(model(x.unsqueeze(0)).squeeze(0), 1) for x in X]\n",
|
||||||
|
" \n",
|
||||||
|
" # Convert the labels to ner tags\n",
|
||||||
|
" y_pred = [[ner_idx2tag[int(idx)] for idx in y] for y in y_pred]\n",
|
||||||
|
" y_tags = [[ner_idx2tag[int(idx)] for idx in y] for y in y]\n",
|
||||||
|
" \n",
|
||||||
|
" # Calculate the metrics\n",
|
||||||
|
" accuracy = accuracy_score(y_tags, y_pred)\n",
|
||||||
|
" precision = precision_score(y_tags, y_pred)\n",
|
||||||
|
" recall = recall_score(y_tags, y_pred)\n",
|
||||||
|
" f1 = f1_score(y_tags, y_pred)\n",
|
||||||
|
" \n",
|
||||||
|
" return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T08:26:31.612231Z",
|
||||||
|
"start_time": "2024-05-14T08:26:31.599603300Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "b18d26ac9fbc590e"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Use GPU if available\n",
|
||||||
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:11:49.825835200Z",
|
||||||
|
"start_time": "2024-05-14T07:11:49.817343100Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "badf288796646abe"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 39,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Model parameters\n",
|
||||||
|
"vocab_size = len(v)\n",
|
||||||
|
"embedding_dim = 64\n",
|
||||||
|
"hidden_dim = 256\n",
|
||||||
|
"output_dim = len(ner_tags)\n",
|
||||||
|
"epochs = 20"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:22:20.730379Z",
|
||||||
|
"start_time": "2024-05-14T07:22:20.724143500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "65beded501220882"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 44,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Seed for reproducibility\n",
|
||||||
|
"torch.manual_seed(1234)\n",
|
||||||
|
"\n",
|
||||||
|
"import random\n",
|
||||||
|
"random.seed(1234)\n",
|
||||||
|
"\n",
|
||||||
|
"np.random.seed(1234)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:28:18.248713300Z",
|
||||||
|
"start_time": "2024-05-14T07:28:18.188830400Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "63b68885d93d5fce"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 40,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Initialize the model\n",
|
||||||
|
"model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:22:21.068317100Z",
|
||||||
|
"start_time": "2024-05-14T07:22:21.044162900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "29116c705decf395"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 41,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Loss function and optimizer\n",
|
||||||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||||||
|
"optimizer = optim.Adam(model.parameters())"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:22:21.705555900Z",
|
||||||
|
"start_time": "2024-05-14T07:22:21.675608300Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "617bec2a8a8b56b3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 65,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Move training to GPU\n",
|
||||||
|
"model = model.to(device)\n",
|
||||||
|
"X_train = [x.to(device) for x in X_train]\n",
|
||||||
|
"y_train = [y.to(device) for y in y_train]\n",
|
||||||
|
"X_dev = [x.to(device) for x in X_dev]\n",
|
||||||
|
"y_dev = [y.to(device) for y in y_dev]"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:44:19.353471700Z",
|
||||||
|
"start_time": "2024-05-14T07:44:19.317384100Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "dfa0d6b3bdca6853"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 67,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": " 0%| | 0/945 [00:00<?, ?it/s]",
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0,
|
||||||
|
"model_id": "415ba9a191bf4ff0993115b428e604ec"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch: 1, Accuracy: 0.9545313667936774, Precision: 0.7780607604147717, Recall: 0.7213695395513577, F1: 0.7486434447750743\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Training loop\n",
|
||||||
|
"model.train()\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(epochs):\n",
|
||||||
|
" \n",
|
||||||
|
" for idx in tqdm(range(len(X_train))):\n",
|
||||||
|
" # Zero the gradients\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" \n",
|
||||||
|
" # Forward pass\n",
|
||||||
|
" output = model(X_train[idx].unsqueeze(0))\n",
|
||||||
|
"\n",
|
||||||
|
" # Calculate the loss\n",
|
||||||
|
" loss = criterion(output.squeeze(0), y_train[idx])\n",
|
||||||
|
" \n",
|
||||||
|
" # Backward pass\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" \n",
|
||||||
|
" # Update the weights\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
" \n",
|
||||||
|
" # Evaluate the model on the dev set\n",
|
||||||
|
" metrics = evaluate_model(model, X_dev, y_dev)\n",
|
||||||
|
" \n",
|
||||||
|
" print(f'Epoch: {epoch+1}, Accuracy: {metrics[\"accuracy\"]}, Precision: {metrics[\"precision\"]}, Recall: {metrics[\"recall\"]}, F1: {metrics[\"f1\"]}')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:45:10.782583200Z",
|
||||||
|
"start_time": "2024-05-14T07:44:53.579284100Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "7a77d0ac6fce81fd"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 78,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "{'accuracy': 0.9545313667936774,\n 'precision': 0.7780607604147717,\n 'recall': 0.7213695395513577,\n 'f1': 0.7486434447750743}"
|
||||||
|
},
|
||||||
|
"execution_count": 78,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"evaluate_model(model, X_dev, y_dev)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T08:26:38.050024900Z",
|
||||||
|
"start_time": "2024-05-14T08:26:36.980050200Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "956b180f74abd5d4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 69,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Move to CPU\n",
|
||||||
|
"model = model.to('cpu')\n",
|
||||||
|
"X_dev = [x.to('cpu') for x in X_dev]\n",
|
||||||
|
"y_dev = [y.to('cpu') for y in y_dev]"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:45:46.733747700Z",
|
||||||
|
"start_time": "2024-05-14T07:45:46.673385500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "b019c1d995100ef5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 70,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Predict the labels for the validation and test sets\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" y_dev_pred = [torch.argmax(model(x.unsqueeze(0)).squeeze(0), 1) for x in X_dev]\n",
|
||||||
|
" y_test_pred = [torch.argmax(model(x.unsqueeze(0)).squeeze(0), 1) for x in X_test]\n",
|
||||||
|
"\n",
|
||||||
|
"# Convert the labels to ner tags\n",
|
||||||
|
"y_dev_pred = [[ner_idx2tag[int(idx)] for idx in y] for y in y_dev_pred]\n",
|
||||||
|
"y_test_pred = [[ner_idx2tag[int(idx)] for idx in y] for y in y_test_pred]"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:45:49.354440Z",
|
||||||
|
"start_time": "2024-05-14T07:45:47.244447100Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "523f8444e9a73a05"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 71,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Concatenate predicted labels (skip the special tokens <bos> and <eos>)\n",
|
||||||
|
"y_dev_pred_con = [' '.join(y[1:-1]) for y in y_dev_pred]\n",
|
||||||
|
"y_test_pred_con = [' '.join(y[1:-1]) for y in y_test_pred]"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:45:49.368345900Z",
|
||||||
|
"start_time": "2024-05-14T07:45:49.355900900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "1a9dc8188e83e5e9"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 72,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Save the predictions (without postprocessing)\n",
|
||||||
|
"pd.DataFrame(y_dev_pred_con).to_csv('dev-0/out-model.tsv', header=False, index=False, sep='\\t')\n",
|
||||||
|
"pd.DataFrame(y_test_pred_con).to_csv('test-A/out-model.tsv', header=False, index=False, sep='\\t')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-14T07:45:49.397283400Z",
|
||||||
|
"start_time": "2024-05-14T07:45:49.370434300Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "66daac87feaa2f66"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 42,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Postprocessing\n",
|
||||||
|
"# Regex for finding I-tags that start a sequence (should be B-tags)\n",
|
||||||
|
"def incorrect_I_as_begin_tag(text):\n",
|
||||||
|
" return re.finditer(r'(?<![BI]-\\w+ )I-\\w+', text)\n",
|
||||||
|
"\n",
|
||||||
|
"# Helper method for replacing I-tags that start a sequence with B-tags\n",
|
||||||
|
"def replace_incorrect_I_as_begin_tag(df):\n",
|
||||||
|
" # Iterate until no more changes\n",
|
||||||
|
" i = 0\n",
|
||||||
|
" \n",
|
||||||
|
" while True:\n",
|
||||||
|
" outer_counter_old = 0\n",
|
||||||
|
" outer_counter = 0\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"Iteration: {i+1}\")\n",
|
||||||
|
" \n",
|
||||||
|
" for idx, row in df.iterrows():\n",
|
||||||
|
" x = incorrect_I_as_begin_tag(row['ner_tags'])\n",
|
||||||
|
" \n",
|
||||||
|
" inner_counter = 0\n",
|
||||||
|
" \n",
|
||||||
|
" for match in x:\n",
|
||||||
|
" inner_counter += 1\n",
|
||||||
|
" hp = list(row['ner_tags'])\n",
|
||||||
|
" hp[match.start()] = 'B'\n",
|
||||||
|
" row['ner_tags'] = \"\".join(hp)\n",
|
||||||
|
" \n",
|
||||||
|
" outer_counter += inner_counter\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"Changes: {outer_counter - outer_counter_old}\")\n",
|
||||||
|
" \n",
|
||||||
|
" i += 1\n",
|
||||||
|
" \n",
|
||||||
|
" if outer_counter_old == outer_counter:\n",
|
||||||
|
" break\n",
|
||||||
|
" else:\n",
|
||||||
|
" outer_counter_old = outer_counter\n",
|
||||||
|
" \n",
|
||||||
|
" return df\n",
|
||||||
|
"\n",
|
||||||
|
"# Regex for finding inconsistent I-tags after B-tags (I-tags that are not continuation of B-tags)\n",
|
||||||
|
"def inconsistent_I_after_B(text):\n",
|
||||||
|
" return re.finditer(r'(?<=B-(\\w+) )(?:I-(?!\\1)\\w+)', text)\n",
|
||||||
|
"\n",
|
||||||
|
"# Helper method for removing inconsistent I-tags after B-tags\n",
|
||||||
|
"def replace_inconsistent_I_after_B(df):\n",
|
||||||
|
" # Iterate until no more changes\n",
|
||||||
|
" i = 0\n",
|
||||||
|
" \n",
|
||||||
|
" while True:\n",
|
||||||
|
" outer_counter_old = 0\n",
|
||||||
|
" outer_counter = 0\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"Iteration: {i+1}\")\n",
|
||||||
|
" \n",
|
||||||
|
" for idx, row in df.iterrows():\n",
|
||||||
|
" matches = inconsistent_I_after_B(row['ner_tags'])\n",
|
||||||
|
" \n",
|
||||||
|
" inner_counter = 0\n",
|
||||||
|
" \n",
|
||||||
|
" for match in matches:\n",
|
||||||
|
" inner_counter += 1\n",
|
||||||
|
" hp = list(row['ner_tags'])\n",
|
||||||
|
" hp[match.start()] = 'B'\n",
|
||||||
|
" row['ner_tags'] = \"\".join(hp)\n",
|
||||||
|
" \n",
|
||||||
|
" outer_counter += inner_counter\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"Changes: {outer_counter - outer_counter_old}\")\n",
|
||||||
|
" \n",
|
||||||
|
" i += 1\n",
|
||||||
|
" \n",
|
||||||
|
" if outer_counter_old == outer_counter:\n",
|
||||||
|
" break\n",
|
||||||
|
" else:\n",
|
||||||
|
" outer_counter_old = outer_counter\n",
|
||||||
|
" \n",
|
||||||
|
" return df\n",
|
||||||
|
"\n",
|
||||||
|
"# Regex for finding inconsistent I-tags after other I-tags (I-tags that are not continuation of the same tag)\n",
|
||||||
|
"def inconsistent_I_after_I(text):\n",
|
||||||
|
" return re.finditer(r'(?<=I-(\\w+) )(?:I-(?!\\1)\\w+)', text)\n",
|
||||||
|
"\n",
|
||||||
|
"# Helper method for removing inconsistent I-tags after other I-tags\n",
|
||||||
|
"def replace_inconsistent_I_after_I(df):\n",
|
||||||
|
" # Iterate until no more changes\n",
|
||||||
|
" i = 0\n",
|
||||||
|
" \n",
|
||||||
|
" while True:\n",
|
||||||
|
" outer_counter_old = 0\n",
|
||||||
|
" outer_counter = 0\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"Iteration: {i+1}\")\n",
|
||||||
|
" \n",
|
||||||
|
" for idx, row in df.iterrows():\n",
|
||||||
|
" matches = inconsistent_I_after_I(row['ner_tags'])\n",
|
||||||
|
" \n",
|
||||||
|
" inner_counter = 0\n",
|
||||||
|
" \n",
|
||||||
|
" for match in matches:\n",
|
||||||
|
" inner_counter += 1\n",
|
||||||
|
" hp = list(row['ner_tags'])\n",
|
||||||
|
" hp[match.start()] = 'B'\n",
|
||||||
|
" row['ner_tags'] = \"\".join(hp)\n",
|
||||||
|
" \n",
|
||||||
|
" outer_counter += inner_counter\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"Changes: {outer_counter - outer_counter_old}\")\n",
|
||||||
|
" \n",
|
||||||
|
" i += 1\n",
|
||||||
|
" \n",
|
||||||
|
" if outer_counter_old == outer_counter:\n",
|
||||||
|
" break\n",
|
||||||
|
" else:\n",
|
||||||
|
" outer_counter_old = outer_counter\n",
|
||||||
|
" \n",
|
||||||
|
" return df"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-16T19:18:22.287969600Z",
|
||||||
|
"start_time": "2024-05-16T19:18:22.272058200Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "e3f5c71b5b231d5e"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 43,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load the predictions\n",
|
||||||
|
"out_dev = pd.read_csv('dev-0/out-model.tsv', delimiter='\\t', header=None)\n",
|
||||||
|
"out_dev.columns = ['ner_tags']\n",
|
||||||
|
"\n",
|
||||||
|
"out_test = pd.read_csv('test-A/out-model.tsv', delimiter='\\t', header=None)\n",
|
||||||
|
"out_test.columns = ['ner_tags']"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-16T19:18:25.353082700Z",
|
||||||
|
"start_time": "2024-05-16T19:18:25.341655500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "cef273b10f1fc169"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 44,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Iteration: 1\n",
|
||||||
|
"Changes: 100\n",
|
||||||
|
"Iteration: 2\n",
|
||||||
|
"Changes: 0\n",
|
||||||
|
"Iteration: 1\n",
|
||||||
|
"Changes: 113\n",
|
||||||
|
"Iteration: 2\n",
|
||||||
|
"Changes: 4\n",
|
||||||
|
"Iteration: 3\n",
|
||||||
|
"Changes: 0\n",
|
||||||
|
"Iteration: 1\n",
|
||||||
|
"Changes: 18\n",
|
||||||
|
"Iteration: 2\n",
|
||||||
|
"Changes: 0\n",
|
||||||
|
"Iteration: 1\n",
|
||||||
|
"Changes: 105\n",
|
||||||
|
"Iteration: 2\n",
|
||||||
|
"Changes: 0\n",
|
||||||
|
"Iteration: 1\n",
|
||||||
|
"Changes: 111\n",
|
||||||
|
"Iteration: 2\n",
|
||||||
|
"Changes: 5\n",
|
||||||
|
"Iteration: 3\n",
|
||||||
|
"Changes: 0\n",
|
||||||
|
"Iteration: 1\n",
|
||||||
|
"Changes: 22\n",
|
||||||
|
"Iteration: 2\n",
|
||||||
|
"Changes: 0\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Postprocessing\n",
|
||||||
|
"out_dev = replace_incorrect_I_as_begin_tag(out_dev)\n",
|
||||||
|
"out_dev = replace_inconsistent_I_after_B(out_dev)\n",
|
||||||
|
"out_dev = replace_inconsistent_I_after_I(out_dev)\n",
|
||||||
|
"\n",
|
||||||
|
"out_test = replace_incorrect_I_as_begin_tag(out_test)\n",
|
||||||
|
"out_test = replace_inconsistent_I_after_B(out_test)\n",
|
||||||
|
"out_test = replace_inconsistent_I_after_I(out_test)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-16T19:18:32.884479700Z",
|
||||||
|
"start_time": "2024-05-16T19:18:32.705259700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "a845573affc53a38"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 45,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Save the predictions (with postprocessing)\n",
|
||||||
|
"out_dev.to_csv('dev-0/out.tsv', header=False, index=False, sep='\\t')\n",
|
||||||
|
"out_test.to_csv('test-A/out.tsv', header=False, index=False, sep='\\t')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-16T19:18:35.475503700Z",
|
||||||
|
"start_time": "2024-05-16T19:18:35.453433800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "d8cf3a8cdbe2de9a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 47,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Evaluation\n",
|
||||||
|
"in_dev = pd.read_csv('dev-0/expected.tsv', delimiter='\\t', header=None)\n",
|
||||||
|
"in_dev.columns = ['ner_tags']"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-16T19:21:21.442871Z",
|
||||||
|
"start_time": "2024-05-16T19:21:21.423414300Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "341015fd66bc6573"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"GEVAL F1-BIO (dev): 0.74864"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "d0c015ab8e55873c"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user