en-ner-conll-2003/transformer.ipynb

1292 lines
355 KiB
Plaintext
Raw Normal View History

2024-05-30 06:22:01 +02:00
{
"cells": [
{
"cell_type": "markdown",
"source": [
"## Transformer"
],
"metadata": {
"collapsed": false
},
"id": "7dd30d84a916d9d0"
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"# Necessary imports\n",
"import csv\n",
"import os\n",
"import warnings\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"import datasets\n",
"from datasets import ClassLabel, Features, Sequence, Value\n",
"from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForTokenClassification\n",
"\n",
"warnings.filterwarnings('ignore')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.480512300Z",
"start_time": "2024-05-29T20:46:00.288032100Z"
}
},
"id": "initial_id"
},
{
"cell_type": "markdown",
"source": [
"### Prepare data"
],
"metadata": {
"collapsed": false
},
"id": "61aea5d48638d128"
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# Divide train data into sentences and labels\n",
"# QUOTE_NONE keeps double quotes in the text as literal characters; with pandas' default quoting a field that\n",
"# starts with a double quote is parsed as a quoted field, which can swallow tabs and newlines and merge rows.\n",
"train_data = pd.read_csv('train/train.tsv', sep='\\t', header=None, quoting=csv.QUOTE_NONE)\n",
"\n",
"# Column 0 is written to the labels file, column 1 to the sentences file\n",
"for column, output_path in ((0, 'train/train_labels.tsv'), (1, 'train/train_sentences.tsv')):\n",
"    with open(output_path, 'w') as f:\n",
"        # One row per line, without a trailing newline after the last row\n",
"        f.write('\\n'.join(train_data[column]))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.573476800Z",
"start_time": "2024-05-29T20:46:17.482528600Z"
}
},
"id": "b19f9ff554147bc4"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Data paths\n",
"# The two train files are produced by the cell above from train/train.tsv\n",
"train_sentences_file = 'train/train_sentences.tsv'\n",
"train_labels_file = 'train/train_labels.tsv'\n",
"# Validation split: inputs and expected tags\n",
"val_sentences_file = 'dev-0/in.tsv'\n",
"val_labels_file = 'dev-0/expected.tsv'\n",
"# Test split: inputs only, no labels file is configured\n",
"test_sentences_file = 'test-A/in.tsv'"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.587032200Z",
"start_time": "2024-05-29T20:46:17.575283500Z"
}
},
"id": "e0e8f33971087b3e"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"# Method to read tokens and labels from files\n",
"def read_sentences_and_labels(sentences_path, labels_path=None):\n",
"    \"\"\"Read whitespace-separated tokens (and optionally tags) line by line.\n",
"\n",
"    Returns a dict with a 'tokens' list (one token list per line of `sentences_path`) and, when\n",
"    `labels_path` is given, a 'ner_tags' list built the same way from `labels_path`.\n",
"    \"\"\"\n",
"    def read_split_lines(path):\n",
"        # Strip each line and split it on whitespace\n",
"        with open(path, 'r') as handle:\n",
"            return [raw_line.strip().split() for raw_line in handle]\n",
"\n",
"    result = {'tokens': read_split_lines(sentences_path)}\n",
"    if labels_path:\n",
"        result['ner_tags'] = read_split_lines(labels_path)\n",
"\n",
"    return result"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.611271600Z",
"start_time": "2024-05-29T20:46:17.589908Z"
}
},
"id": "47bcc983f537b670"
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"# Load data\n",
"# Train and validation are read together with their tags; the test split is read without a labels file\n",
"train_data = read_sentences_and_labels(train_sentences_file, labels_path=train_labels_file)\n",
"val_data = read_sentences_and_labels(val_sentences_file, labels_path=val_labels_file)\n",
"test_data = read_sentences_and_labels(sentences_path=test_sentences_file)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.757405800Z",
"start_time": "2024-05-29T20:46:17.606209100Z"
}
},
"id": "cf901e101075811e"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [
"# Split long sentences into multiple sentences\n",
"def split_long_sentences(data, max_length=128):\n",
"    \"\"\"Cut every sentence into consecutive fragments of at most `max_length` tokens.\n",
"\n",
"    Args:\n",
"        data: dict with a 'tokens' list of token lists and, optionally, a parallel 'ner_tags' list.\n",
"        max_length: maximum number of tokens per fragment; must be at least 1.\n",
"\n",
"    Returns:\n",
"        (new_data, original_sentence_indices, fragment_lengths): `new_data` has the same keys as `data`\n",
"        with one entry per fragment; the two lists give, per fragment, the index of the sentence it\n",
"        came from and its number of tokens.\n",
"\n",
"    Raises:\n",
"        ValueError: if `max_length` is smaller than 1 (a negative value used to drop long sentences silently).\n",
"    \"\"\"\n",
"    if max_length < 1:\n",
"        raise ValueError(f\"max_length must be a positive integer, got {max_length}\")\n",
"\n",
"    has_tags = 'ner_tags' in data\n",
"    new_data = {'tokens': [], 'ner_tags': []} if has_tags else {'tokens': []}\n",
"    original_sentence_indices = []\n",
"    fragment_lengths = []\n",
"\n",
"    for i, tokens in enumerate(data['tokens']):\n",
"        # max(..., 1) keeps an empty sentence as one empty fragment, so every sentence index stays represented\n",
"        for start in range(0, max(len(tokens), 1), max_length):\n",
"            fragment = tokens[start:start + max_length]\n",
"            new_data['tokens'].append(fragment)\n",
"            if has_tags:\n",
"                new_data['ner_tags'].append(data['ner_tags'][i][start:start + max_length])\n",
"            original_sentence_indices.append(i)\n",
"            fragment_lengths.append(len(fragment))\n",
"\n",
"    return new_data, original_sentence_indices, fragment_lengths"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.772977500Z",
"start_time": "2024-05-29T20:46:17.686317600Z"
}
},
"id": "f9afdda7a877f2"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"# Split long sentences\n",
"# NOTE(review): each *_data variable is overwritten with its split version, so running this cell a second time\n",
"# splits already-split data and the *_original_sentence_indices no longer refer to the original sentences.\n",
"train_data, train_original_sentence_indices, train_fragment_lengths = split_long_sentences(train_data)\n",
"val_data, val_original_sentence_indices, val_fragment_lengths = split_long_sentences(val_data)\n",
"test_data, test_original_sentence_indices, test_fragment_lengths = split_long_sentences(test_data)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.773978Z",
"start_time": "2024-05-29T20:46:17.697615700Z"
}
},
"id": "8050585694f36e46"
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"# Convert to datasets\n",
"# One Dataset per split, built from the fragment dicts in train / validation / test order\n",
"train_dataset, val_dataset, test_dataset = (\n",
"    datasets.Dataset.from_dict(split) for split in (train_data, val_data, test_data)\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.849338500Z",
"start_time": "2024-05-29T20:46:17.713989300Z"
}
},
"id": "d8cb17d6e5631db5"
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [],
"source": [
"# List of unique ner labels\n",
"# The position of a label in this list is its integer class id (ClassLabel below, unique_labels[...] lookups later)\n",
"unique_labels = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n",
"\n",
"# Create class label\n",
"ner_tags_feature = ClassLabel(names=unique_labels)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.853658400Z",
"start_time": "2024-05-29T20:46:17.839342100Z"
}
},
"id": "4baa0327c2b22f4c"
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"# Method to convert ner tags to class labels\n",
"def convert_to_classlabel(example):\n",
" example['ner_tags'] = [ner_tags_feature.str2int(tag) for tag in example['ner_tags']]\n",
" \n",
" return example"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:17.904893600Z",
"start_time": "2024-05-29T20:46:17.854695Z"
}
},
"id": "6d88730786c438c1"
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [
{
"data": {
"text/plain": "Map: 0%| | 0/2149 [00:00<?, ? examples/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "15e2dba8f69f4ca3813c25765b46cdc3"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "Map: 0%| | 0/529 [00:00<?, ? examples/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "2a3babc0133a42378718560100fc1b56"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Convert ner tags to class labels\n",
"# Only the labelled splits (train, validation) carry ner_tags\n",
"train_dataset = train_dataset.map(function=convert_to_classlabel)\n",
"val_dataset = val_dataset.map(function=convert_to_classlabel)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:18.446392600Z",
"start_time": "2024-05-29T20:46:17.869151100Z"
}
},
"id": "338c3957ef99d163"
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [],
"source": [
"# Define features\n",
"# Schema the train/validation datasets are cast to: ner_tags as a sequence of class labels, tokens as strings\n",
"features = Features({\n",
" 'ner_tags': Sequence(ner_tags_feature),\n",
" 'tokens': Sequence(Value('string'))\n",
"})"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:18.528362Z",
"start_time": "2024-05-29T20:46:18.446392600Z"
}
},
"id": "1407184b8b686559"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [
{
"data": {
"text/plain": "Casting the dataset: 0%| | 0/2149 [00:00<?, ? examples/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "59b23fde0d704112a43f63ae1bcc4489"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "Casting the dataset: 0%| | 0/529 [00:00<?, ? examples/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "0f2c8fc4e5a445a6be40171dc180b6a1"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Cast dataset\n",
"# Apply the schema defined above to both labelled splits\n",
"train_dataset = train_dataset.cast(features=features)\n",
"val_dataset = val_dataset.cast(features=features)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:18.550719700Z",
"start_time": "2024-05-29T20:46:18.462988800Z"
}
},
"id": "f014e485919353b0"
},
{
"cell_type": "markdown",
"source": [
"### Tokenization"
],
"metadata": {
"collapsed": false
},
"id": "fc391077d8929f53"
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [],
"source": [
"# Load tokenizer\n",
"# NOTE(review): return_token_type_ids receives the string \"token_type_ids\" here; it looks like a boolean flag\n",
"# was intended - confirm how from_pretrained treats this keyword.\n",
"# NOTE(review): confirm that the model loaded in the 'Load pre-trained model' section was fine-tuned from this\n",
"# same checkpoint; a tokenizer and model with different vocabularies would not fail loudly.\n",
"tokenizer = AutoTokenizer.from_pretrained(\"google-bert/bert-base-cased\", return_token_type_ids=\"token_type_ids\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:19.377299700Z",
"start_time": "2024-05-29T20:46:18.509638800Z"
}
},
"id": "18ecdf9a28e3172c"
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [],
"source": [
"# Method to tokenize and align labels\n",
"def tokenize_and_align_labels(examples, label_all_tokens=False):\n",
"    \"\"\"Tokenize pre-split words and, when 'ner_tags' is present, attach per-token labels.\n",
"\n",
"    Special tokens (word id None) get -100 so the loss ignores them. The first sub-token of a word gets\n",
"    the word's label; further sub-tokens get the same label only if `label_all_tokens` is set, else -100.\n",
"    \"\"\"\n",
"    encoded = tokenizer(examples[\"tokens\"], truncation=True, is_split_into_words=True)\n",
"\n",
"    # Unlabelled input: return the plain tokenizer output\n",
"    if 'ner_tags' not in examples:\n",
"        return encoded\n",
"\n",
"    aligned_labels = []\n",
"    for batch_idx, word_labels in enumerate(examples[\"ner_tags\"]):\n",
"        word_ids = encoded.word_ids(batch_index=batch_idx)\n",
"        token_labels = []\n",
"        # Pair every word id with the one before it (None before the first token)\n",
"        for previous_word_id, word_id in zip([None] + word_ids[:-1], word_ids):\n",
"            if word_id is None:\n",
"                token_labels.append(-100)\n",
"            elif word_id != previous_word_id or label_all_tokens:\n",
"                token_labels.append(word_labels[word_id])\n",
"            else:\n",
"                token_labels.append(-100)\n",
"        aligned_labels.append(token_labels)\n",
"\n",
"    encoded[\"labels\"] = aligned_labels\n",
"    return encoded"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:19.381308500Z",
"start_time": "2024-05-29T20:46:19.368878500Z"
}
},
"id": "3f9bf9813abb6a1f"
},
{
"cell_type": "code",
"execution_count": 16,
"outputs": [
{
"data": {
"text/plain": "Map: 0%| | 0/2149 [00:00<?, ? examples/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "1311c74351e44fcbb37e87a61c250c12"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "Map: 0%| | 0/529 [00:00<?, ? examples/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "f2a00a8e9fbb466e8be9c92351c4bbab"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Tokenize and align labels\n",
"# Batched mapping: tokenize_and_align_labels receives a batch of examples per call\n",
"tokenized_train = train_dataset.map(function=tokenize_and_align_labels, batched=True)\n",
"tokenized_val = val_dataset.map(function=tokenize_and_align_labels, batched=True)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:20.676598100Z",
"start_time": "2024-05-29T20:46:19.380307100Z"
}
},
"id": "3485e85a2d659648"
},
{
"cell_type": "code",
"execution_count": 17,
"outputs": [
{
"data": {
"text/plain": "Map: 0%| | 0/504 [00:00<?, ? examples/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "ea9d003b39584f459c31e8fb33e37f7a"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Tokenize test data\n",
"# test_dataset has no 'ner_tags' column, so tokenize_and_align_labels adds no 'labels' here\n",
"tokenized_test = test_dataset.map(tokenize_and_align_labels, batched=True)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:20.825195400Z",
"start_time": "2024-05-29T20:46:20.676598100Z"
}
},
"id": "14092535e69bcee7"
},
{
"cell_type": "markdown",
"source": [
"### Load pre-trained model"
],
"metadata": {
"collapsed": false
},
"id": "9c3a5e2044b8a5f"
},
{
"cell_type": "code",
"execution_count": 24,
"outputs": [],
"source": [
"# Load model\n",
"# Resume from the locally saved fine-tuned checkpoint when it exists. On a fresh checkout 'ner-model' is only\n",
"# created later by trainer.save_model, so fall back to the base checkpoint the tokenizer was loaded from;\n",
"# this also keeps the model's vocabulary in line with the tokenizer.\n",
"model_checkpoint = 'ner-model' if os.path.isdir('ner-model') else 'google-bert/bert-base-cased'\n",
"model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(unique_labels))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:46:51.504688200Z",
"start_time": "2024-05-29T20:46:51.212229100Z"
}
},
"id": "35286c44c2075004"
},
{
"cell_type": "markdown",
"source": [
"### Set up fine-tuning on the prepared data"
],
"metadata": {
"collapsed": false
},
"id": "c8a8e0f62d8e29ff"
},
{
"cell_type": "code",
"execution_count": 25,
"outputs": [],
"source": [
"# Define training arguments\n",
"# Checkpoints go to 'test-ner'; evaluation runs once per epoch\n",
"training_args = TrainingArguments(\n",
"    output_dir=\"test-ner\",\n",
"    num_train_epochs=5,\n",
"    learning_rate=2e-5,\n",
"    weight_decay=0.01,\n",
"    per_device_train_batch_size=1,\n",
"    per_device_eval_batch_size=1,\n",
"    evaluation_strategy=\"epoch\",\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:47:01.006126Z",
"start_time": "2024-05-29T20:47:00.943341100Z"
}
},
"id": "38df4eac0f029518"
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [],
"source": [
"# Define data collator\n",
"# Batches are padded with the same tokenizer that produced the inputs\n",
"data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:47:01.381489500Z",
"start_time": "2024-05-29T20:47:01.355676300Z"
}
},
"id": "538e65db52071f44"
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [],
"source": [
"# Define metric to compute\n",
"# NOTE(review): load_metric is deprecated in recent releases of the datasets library (metrics moved to the\n",
"# separate evaluate package) - confirm the installed datasets version still provides it.\n",
"metric = datasets.load_metric(\"seqeval\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:47:04.088558100Z",
"start_time": "2024-05-29T20:47:01.835013600Z"
}
},
"id": "9824690d2d254d92"
},
{
"cell_type": "code",
"execution_count": 28,
"outputs": [],
"source": [
"# Helper method to compute metrics\n",
"def compute_metrics(p):\n",
"    \"\"\"Turn a (predictions, labels) pair into overall precision / recall / f1 / accuracy.\n",
"\n",
"    The class with the highest score along axis 2 is taken as the prediction; positions whose label\n",
"    is -100 are left out before ids are mapped to tag names through unique_labels.\n",
"    \"\"\"\n",
"    scores, label_ids = p\n",
"    predicted_ids = np.argmax(scores, axis=2)\n",
"\n",
"    true_predictions = []\n",
"    true_labels = []\n",
"    for predicted_row, label_row in zip(predicted_ids, label_ids):\n",
"        # Keep only the positions that carry a real label\n",
"        kept = [(pred, gold) for pred, gold in zip(predicted_row, label_row) if gold != -100]\n",
"        true_predictions.append([unique_labels[pred] for pred, _ in kept])\n",
"        true_labels.append([unique_labels[gold] for _, gold in kept])\n",
"\n",
"    results = metric.compute(predictions=true_predictions, references=true_labels)\n",
"    reported = (\"precision\", \"recall\", \"f1\", \"accuracy\")\n",
"    return {name: results[\"overall_\" + name] for name in reported}"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:47:04.101555900Z",
"start_time": "2024-05-29T20:47:04.079557600Z"
}
},
"id": "263574f331ccd7fa"
},
{
"cell_type": "code",
"execution_count": 29,
"outputs": [],
"source": [
"# Define trainer\n",
"# Trains on the tokenized train split and evaluates on the tokenized validation split\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    data_collator=data_collator,\n",
"    train_dataset=tokenized_train,\n",
"    eval_dataset=tokenized_val,\n",
"    tokenizer=tokenizer,\n",
"    compute_metrics=compute_metrics,\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:47:05.152846100Z",
"start_time": "2024-05-29T20:47:04.839762Z"
}
},
"id": "7435f846555be628"
},
{
"cell_type": "markdown",
"source": [
"### Train model"
],
"metadata": {
"collapsed": false
},
"id": "36cabe4b448fcbf2"
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [
{
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "\n <div>\n \n <progress value='2' max='10745' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 2/10745 : < :, Epoch 0.00/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "TrainOutput(global_step=10745, training_loss=0.16770398169686596, metrics={'train_runtime': 715.5143, 'train_samples_per_second': 15.017, 'train_steps_per_second': 15.017, 'total_flos': 424900907413920.0, 'train_loss': 0.16770398169686596, 'epoch': 5.0})"
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train model\n",
"# Release cached CUDA memory held by PyTorch's allocator before starting the training run\n",
"torch.cuda.empty_cache()\n",
"trainer.train()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T19:28:01.082515100Z",
"start_time": "2024-05-29T19:16:05.326865100Z"
}
},
"id": "bc129922e37d3a66"
},
{
"cell_type": "code",
"execution_count": 29,
"outputs": [],
"source": [
"# Save model\n",
"# Written to 'ner-model', the directory the model-loading cell reads from\n",
"trainer.save_model('ner-model')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T19:34:21.843359600Z",
"start_time": "2024-05-29T19:34:21.416348600Z"
}
},
"id": "feb7b9d1a7361676"
},
{
"cell_type": "markdown",
"source": [
"### Evaluate model"
],
"metadata": {
"collapsed": false
},
"id": "2d35de4a67725848"
},
{
"cell_type": "code",
"execution_count": 30,
"outputs": [
{
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "\n <div>\n \n <progress value='1' max='529' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 1/529 : < :]\n </div>\n "
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "{'eval_loss': 0.17733338475227356,\n 'eval_precision': 0.7404310067545835,\n 'eval_recall': 0.7740416946872899,\n 'eval_f1': 0.7568633897747822,\n 'eval_accuracy': 0.9586372907517319,\n 'eval_runtime': 5.8684,\n 'eval_samples_per_second': 90.144,\n 'eval_steps_per_second': 90.144}"
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate\n",
"# Runs on the eval_dataset given to the Trainer (tokenized_val) and reports the compute_metrics values\n",
"trainer.evaluate()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-29T20:47:15.896453Z",
"start_time": "2024-05-29T20:47:10.004317800Z"
}
},
"id": "bb8d44f73837f564"
},
{
"cell_type": "markdown",
"source": [
"### Predict on validation data"
],
"metadata": {
"collapsed": false
},
"id": "78ea55ccc041a68"
},
{
"cell_type": "code",
"execution_count": 167,
"outputs": [],
"source": [
"# Preprocess data\n",
"def preprocess_data(tokens):\n",
"    \"\"\"Join every token list into one space-separated sentence string.\"\"\"\n",
"    return list(map(\" \".join, tokens))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:14:10.297133Z",
"start_time": "2024-05-30T04:14:10.283305Z"
}
},
"id": "bbed07b9338166a2"
},
{
"cell_type": "code",
"execution_count": 168,
"outputs": [],
"source": [
"# Plain-text sentences for every split, in train / validation / test order\n",
"train_sentences, val_sentences, test_sentences = (\n",
"    preprocess_data(split['tokens']) for split in (train_data, val_data, test_data)\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:14:11.977406700Z",
"start_time": "2024-05-30T04:14:11.961831800Z"
}
},
"id": "2a48fa6bce956435"
},
{
"cell_type": "code",
"execution_count": 169,
"outputs": [],
"source": [
"# Align predictions\n",
"def align_predictions(predictions, label_ids, sentence_indices, fragment_lengths):\n",
"    \"\"\"Reduce scores to class ids and drop every position whose label is -100.\n",
"\n",
"    Returns (aligned_preds, aligned_labels), two lists with one list per fragment. `sentence_indices`\n",
"    and `fragment_lengths` only take part in the zip, which stops at the shortest of its inputs.\n",
"    \"\"\"\n",
"    predicted_ids = np.argmax(predictions, axis=2)\n",
"    aligned_preds, aligned_labels = [], []\n",
"\n",
"    for predicted_row, label_row, _idx, _length in zip(predicted_ids, label_ids, sentence_indices, fragment_lengths):\n",
"        kept = [(pred, gold) for pred, gold in zip(predicted_row, label_row) if gold != -100]\n",
"        aligned_preds.append([pred for pred, _ in kept])\n",
"        aligned_labels.append([gold for _, gold in kept])\n",
"\n",
"    return aligned_preds, aligned_labels"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:14:13.030955900Z",
"start_time": "2024-05-30T04:14:13.011494Z"
}
},
"id": "acb1e4438f26c866"
},
{
"cell_type": "code",
"execution_count": 170,
"outputs": [
{
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "\n <div>\n \n <progress value='1' max='529' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 1/529 : < :]\n </div>\n "
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict on validation data\n",
"# Unpacks the prediction output into model scores (argmax is taken later in align_predictions),\n",
"# gold label ids and metrics\n",
"predictions_val, label_ids_val, metrics_val = trainer.predict(tokenized_val)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:14:19.969405600Z",
"start_time": "2024-05-30T04:14:13.795285400Z"
}
},
"id": "e9d716adcc29094c"
},
{
"cell_type": "code",
"execution_count": 171,
"outputs": [],
"source": [
"# Align predictions\n",
"# Keeps, per fragment, only the positions whose gold label is not -100\n",
"aligned_preds_val, aligned_labels_val = align_predictions(predictions_val, label_ids_val, val_original_sentence_indices, val_fragment_lengths)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:14:20.017766Z",
"start_time": "2024-05-30T04:14:19.970385600Z"
}
},
"id": "6e794d5403a8a7a9"
},
{
"cell_type": "code",
"execution_count": 172,
"outputs": [],
"source": [
"# Concat results based on val_original_sentence_indices\n",
"# Fragments that come from the same original sentence are merged back into one label sequence.\n",
"# Each merged sequence starts as a copy of its first fragment: extending the fragment list itself would\n",
"# modify aligned_preds_val / aligned_labels_val and duplicate labels when this cell is run again.\n",
"predicted_labels = []\n",
"true_labels = []\n",
"previous_sentence_idx = None\n",
"for sentence_idx, fragment_preds, fragment_labels in zip(val_original_sentence_indices, aligned_preds_val, aligned_labels_val):\n",
"    if sentence_idx == previous_sentence_idx:\n",
"        predicted_labels[-1].extend(fragment_preds)\n",
"        true_labels[-1].extend(fragment_labels)\n",
"    else:\n",
"        predicted_labels.append(list(fragment_preds))\n",
"        true_labels.append(list(fragment_labels))\n",
"    previous_sentence_idx = sentence_idx"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:14:20.017766Z",
"start_time": "2024-05-30T04:14:20.000190600Z"
}
},
"id": "2c5c1526bee64f34"
},
{
"cell_type": "markdown",
"source": [
"### Postprocessing"
],
"metadata": {
"collapsed": false
},
"id": "27b07318b3720194"
},
{
"cell_type": "code",
"execution_count": 174,
"outputs": [],
"source": [
"import regex as re\n",
"\n",
"# Postprocessing\n",
"# Shared driver for the three replace_* helpers below\n",
"def _retag_until_stable(df, find_matches):\n",
"    \"\"\"Turn the first character of every match into 'B' until a full pass changes nothing.\n",
"\n",
"    Args:\n",
"        df: DataFrame with a 'ner_tags' column of space-separated tag strings; updated in place.\n",
"        find_matches: callable taking one tag string and returning an iterable of regex matches.\n",
"\n",
"    Returns:\n",
"        The same DataFrame object, with the corrected 'ner_tags' column.\n",
"    \"\"\"\n",
"    iteration = 0\n",
"\n",
"    while True:\n",
"        iteration += 1\n",
"        print(f\"Iteration: {iteration}\")\n",
"\n",
"        changes = 0\n",
"        updated_rows = []\n",
"\n",
"        for tags in df['ner_tags']:\n",
"            chars = list(tags)\n",
"            # Matches are found on the row as it was at the start of the pass; the replacement keeps the\n",
"            # string length, so the match offsets stay valid.\n",
"            for match in find_matches(tags):\n",
"                chars[match.start()] = 'B'\n",
"                changes += 1\n",
"            updated_rows.append(\"\".join(chars))\n",
"\n",
"        # Assign through the DataFrame: rows yielded by iterrows() may be copies, so writing to them is\n",
"        # not guaranteed to reach df (and the loop would then never terminate).\n",
"        df['ner_tags'] = updated_rows\n",
"\n",
"        print(f\"Changes: {changes}\")\n",
"\n",
"        # A pass without changes means no new violations were created by the previous pass\n",
"        if changes == 0:\n",
"            return df\n",
"\n",
"# Regex for finding I-tags that start a sequence (should be B-tags)\n",
"def incorrect_I_as_begin_tag(text):\n",
"    \"\"\"Iterate over I-tags that are not directly preceded by a B- or I-tag.\"\"\"\n",
"    return re.finditer(r'(?<![BI]-\\w+ )I-\\w+', text)\n",
"\n",
"# Helper method for replacing I-tags that start a sequence with B-tags\n",
"def replace_incorrect_I_as_begin_tag(df):\n",
"    \"\"\"Rewrite I-tags that open a sequence as B-tags in df['ner_tags']; returns df.\"\"\"\n",
"    return _retag_until_stable(df, incorrect_I_as_begin_tag)\n",
"\n",
"# Regex for finding inconsistent I-tags after B-tags (I-tags that are not continuation of B-tags)\n",
"def inconsistent_I_after_B(text):\n",
"    \"\"\"Iterate over I-tags whose type differs from the B-tag directly before them.\"\"\"\n",
"    return re.finditer(r'(?<=B-(\\w+) )(?:I-(?!\\1)\\w+)', text)\n",
"\n",
"# Helper method for removing inconsistent I-tags after B-tags\n",
"def replace_inconsistent_I_after_B(df):\n",
"    \"\"\"Rewrite I-tags that follow a B-tag of another type as B-tags in df['ner_tags']; returns df.\"\"\"\n",
"    return _retag_until_stable(df, inconsistent_I_after_B)\n",
"\n",
"# Regex for finding inconsistent I-tags after other I-tags (I-tags that are not continuation of the same tag)\n",
"def inconsistent_I_after_I(text):\n",
"    \"\"\"Iterate over I-tags whose type differs from the I-tag directly before them.\"\"\"\n",
"    return re.finditer(r'(?<=I-(\\w+) )(?:I-(?!\\1)\\w+)', text)\n",
"\n",
"# Helper method for removing inconsistent I-tags after other I-tags\n",
"def replace_inconsistent_I_after_I(df):\n",
"    \"\"\"Rewrite I-tags that follow an I-tag of another type as B-tags in df['ner_tags']; returns df.\"\"\"\n",
"    return _retag_until_stable(df, inconsistent_I_after_I)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:14:25.303245500Z",
"start_time": "2024-05-30T04:14:25.284352300Z"
}
},
"id": "650018baa72635bd"
},
{
"cell_type": "code",
"execution_count": 175,
"outputs": [
{
"data": {
"text/plain": "[[0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 7,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 5,\n 8,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 1,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 6,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 1,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 1,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 5,\n 0,\n 0,\n 1,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 3,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 2,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 
0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 1,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0],\n [0,\n 0,\n 7,\n 8,\n 8,\n 0,\n 0,\n 0,\n 5,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 7,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 0,\n 3,\n 0,\n 0,\n 5,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 3,\n 2,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 3,\n 2,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 1,\n 0,\n 0,\n 0,\n 3,\n 0,\n 5,\n 6,\n 0,\n 0,\n 3,\n 0,\n 0,\n 3,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 0,\n 1,\n 0,\n 3,\n 0,\n 0,\n 1,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 5,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 0,\n 1,\n 0,\n 0,\n 0,\n 0,\n 0,\n 1,\n 0,\n 3,\n 0,\n 0,\n 0,\n 0,\n 1,\n 0,\n 0,\n 1,\n 2,\n 0,\n 0,\n 1,\n 0,\n 0,\n 0,\n 0,\n 1,\n 2,\n 2,\n 0,\n 1,\
},
"execution_count": 175,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview only: first 20 label ids of the first 3 sentences instead of dumping every prediction\n",
"[labels[:20] for labels in predicted_labels[:3]]"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:14:40.688964500Z",
"start_time": "2024-05-30T04:14:40.228990800Z"
}
},
"id": "d329563f196b50b8"
},
{
"cell_type": "code",
"execution_count": 127,
"outputs": [],
"source": [
"# Save predictions to .tsv file (line by line)\n",
"with open('train/out-transformer.tsv', 'w') as f:\n",
"    # One line per sentence: its predicted tags joined by single spaces\n",
"    f.writelines(\n",
"        ' '.join(unique_labels[p] for p in sentence_labels) + '\\n'\n",
"        for sentence_labels in predicted_labels\n",
"    )"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:01:53.622588100Z",
"start_time": "2024-05-30T04:01:53.604030400Z"
}
},
"id": "fddf782e15ce2ba9"
},
{
"cell_type": "code",
"execution_count": 157,
"outputs": [],
"source": [
"# Load predictions\n",
"# Tab-separated read of the file written above; the resulting column is named 'ner_tags'\n",
"predictions = pd.read_csv('train/out-transformer.tsv', sep='\\t', header=None)\n",
"predictions.columns = ['ner_tags']"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:11:08.569352800Z",
"start_time": "2024-05-30T04:11:08.552616200Z"
}
},
"id": "3f3bb54ada67b23e"
},
{
"cell_type": "code",
"execution_count": 158,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 1\n",
"Changes: 143\n",
"Iteration: 2\n",
"Changes: 0\n",
"Iteration: 1\n",
"Changes: 168\n",
"Iteration: 2\n",
"Changes: 14\n",
"Iteration: 3\n",
"Changes: 0\n",
"Iteration: 1\n",
"Changes: 17\n",
"Iteration: 2\n",
"Changes: 0\n"
]
}
],
"source": [
"# Postprocessing\n",
"# The three corrections run in this fixed order, each on the result of the previous one\n",
"for postprocess_step in (\n",
"    replace_incorrect_I_as_begin_tag,\n",
"    replace_inconsistent_I_after_B,\n",
"    replace_inconsistent_I_after_I,\n",
"):\n",
"    predictions = postprocess_step(predictions)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:11:09.003702200Z",
"start_time": "2024-05-30T04:11:08.898389300Z"
}
},
"id": "9c02e24a89ccbe4d"
},
{
"cell_type": "code",
"execution_count": 159,
"outputs": [],
"source": [
"# Save predictions to .tsv file (line by line)\n",
"# Select the tag column explicitly: a later cell adds a 'tokens' column to `predictions`, so re-running\n",
"# this cell after it would otherwise write that extra column into the output file.\n",
"predictions[['ner_tags']].to_csv('dev-0/out.tsv', header=False, index=False, sep='\\t')"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:11:18.878774100Z",
"start_time": "2024-05-30T04:11:18.853798Z"
}
},
"id": "a79119ef9e7a7451"
},
{
"cell_type": "code",
"execution_count": 160,
"outputs": [],
"source": [
"from seqeval.metrics import classification_report\n",
"\n",
"# Convert index to label\n",
"# Despite the column name, 'tokens' holds tag names here: gold tags in df_val, predicted tags in predictions\n",
"df_val = pd.DataFrame({'ner_tags': true_labels})\n",
"df_val['tokens'] = df_val['ner_tags'].map(lambda label_ids: [unique_labels[int(label_id)] for label_id in label_ids])\n",
"predictions['tokens'] = predictions['ner_tags'].map(lambda tag_line: tag_line.split())"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:11:59.336256600Z",
"start_time": "2024-05-30T04:11:59.317212100Z"
}
},
"id": "2430c7d5c422ac17"
},
{
"cell_type": "code",
"execution_count": 165,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" LOC 0.80 0.86 0.83 1835\n",
" MISC 0.72 0.70 0.71 921\n",
" ORG 0.66 0.69 0.67 1333\n",
" PER 0.74 0.78 0.76 1840\n",
"\n",
" micro avg 0.74 0.77 0.76 5929\n",
" macro avg 0.73 0.76 0.74 5929\n",
"weighted avg 0.74 0.77 0.75 5929\n"
]
}
],
"source": [
"# Classification report\n",
"# First argument: gold tag sequences from df_val; second: the tag sequences held in predictions\n",
"print(classification_report(df_val['tokens'].tolist(), predictions['tokens'].tolist()))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-05-30T04:12:34.453872200Z",
"start_time": "2024-05-30T04:12:33.984525300Z"
}
},
"id": "6cde0cee2a46f1e3"
},
{
"cell_type": "markdown",
"source": [
"GEVAL F1-BIO (dev): 0.75517"
],
"metadata": {
"collapsed": false
},
"id": "921ef224aed27e38"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}