{ "cells": [ { "cell_type": "markdown", "source": [ "## Transformer" ], "metadata": { "collapsed": false }, "id": "7dd30d84a916d9d0" }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "# Necessary imports\n", "import pandas as pd\n", "import numpy as np\n", "\n", "import torch\n", "\n", "import datasets\n", "from datasets import ClassLabel, Features, Sequence, Value\n", "from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForTokenClassification\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.480512300Z", "start_time": "2024-05-29T20:46:00.288032100Z" } }, "id": "initial_id" }, { "cell_type": "markdown", "source": [ "### Prepare data" ], "metadata": { "collapsed": false }, "id": "61aea5d48638d128" }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "# Divide train data into sentences and labels\n", "train_data = pd.read_csv('train/train.tsv', sep='\\t', header=None)\n", "\n", "with open('train/train_labels.tsv', 'w') as f:\n", " for i in range(len(train_data)):\n", " if i == len(train_data) - 1:\n", " f.write(train_data.iloc[i][0])\n", " else:\n", " f.write(train_data.iloc[i][0] + '\\n')\n", " \n", "with open('train/train_sentences.tsv', 'w') as f:\n", " for i in range(len(train_data)):\n", " if i == len(train_data) - 1:\n", " f.write(train_data.iloc[i][1])\n", " else:\n", " f.write(train_data.iloc[i][1] + '\\n')" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.573476800Z", "start_time": "2024-05-29T20:46:17.482528600Z" } }, "id": "b19f9ff554147bc4" }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "# Data paths\n", "train_sentences_file = 'train/train_sentences.tsv'\n", "train_labels_file = 'train/train_labels.tsv'\n", "val_sentences_file = 'dev-0/in.tsv'\n", "val_labels_file = 'dev-0/expected.tsv'\n", "test_sentences_file = 'test-A/in.tsv'" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.587032200Z", "start_time": "2024-05-29T20:46:17.575283500Z" } }, "id": "e0e8f33971087b3e" }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "# Method to read tokens and labels from files\n", "def read_sentences_and_labels(sentences_path, labels_path=None):\n", " tokens = []\n", " ner_tags = []\n", " \n", " with open(sentences_path, 'r') as f:\n", " for line in f:\n", " tokens.append(line.strip().split())\n", " \n", " if labels_path:\n", " with open(labels_path, 'r') as f:\n", " for line in f:\n", " ner_tags.append(line.strip().split())\n", " \n", " if labels_path:\n", " return {'tokens': tokens, 'ner_tags': ner_tags}\n", " else:\n", " return {'tokens': tokens}" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.611271600Z", "start_time": "2024-05-29T20:46:17.589908Z" } }, "id": "47bcc983f537b670" }, { "cell_type": "code", "execution_count": 5, "outputs": [], "source": [ "# Load data\n", "train_data = read_sentences_and_labels(train_sentences_file, train_labels_file)\n", "val_data = read_sentences_and_labels(val_sentences_file, val_labels_file)\n", "test_data = read_sentences_and_labels(test_sentences_file)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.757405800Z", "start_time": "2024-05-29T20:46:17.606209100Z" } }, "id": "cf901e101075811e" }, { "cell_type": "code", "execution_count": 6, "outputs": [], "source": [ "# Split long sentences into multiple sentences\n", "def split_long_sentences(data, max_length=128):\n", " if 'ner_tags' in data:\n", " new_data = {'tokens': [], 'ner_tags': []}\n", " else:\n", " new_data = {'tokens': []}\n", " \n", " original_sentence_indices = []\n", " fragment_lengths = []\n", " \n", " for i in range(len(data['tokens'])):\n", " tokens = data['tokens'][i]\n", " if 'ner_tags' in data:\n", " ner_tags = data['ner_tags'][i]\n", " \n", " if len(tokens) > max_length:\n", " for j in range(0, len(tokens), max_length):\n", " new_data['tokens'].append(tokens[j:j+max_length])\n", " if 'ner_tags' in data:\n", " new_data['ner_tags'].append(ner_tags[j:j+max_length])\n", " original_sentence_indices.append(i)\n", " fragment_lengths.append(len(tokens[j:j+max_length]))\n", " else:\n", " new_data['tokens'].append(tokens)\n", " if 'ner_tags' in data:\n", " new_data['ner_tags'].append(ner_tags)\n", " original_sentence_indices.append(i)\n", " fragment_lengths.append(len(tokens))\n", " \n", " return new_data, original_sentence_indices, fragment_lengths" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.772977500Z", "start_time": "2024-05-29T20:46:17.686317600Z" } }, "id": "f9afdda7a877f2" }, { "cell_type": "code", "execution_count": 7, "outputs": [], "source": [ "# Split long sentences\n", "train_data, train_original_sentence_indices, train_fragment_lengths = split_long_sentences(train_data)\n", "val_data, val_original_sentence_indices, val_fragment_lengths = split_long_sentences(val_data)\n", "test_data, test_original_sentence_indices, test_fragment_lengths = split_long_sentences(test_data)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.773978Z", "start_time": "2024-05-29T20:46:17.697615700Z" } }, "id": "8050585694f36e46" }, { "cell_type": "code", "execution_count": 8, "outputs": [], "source": [ "# Convert to datasets\n", "train_dataset = datasets.Dataset.from_dict(train_data)\n", "val_dataset = datasets.Dataset.from_dict(val_data)\n", "test_dataset = datasets.Dataset.from_dict(test_data)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.849338500Z", "start_time": "2024-05-29T20:46:17.713989300Z" } }, "id": "d8cb17d6e5631db5" }, { "cell_type": "code", "execution_count": 9, "outputs": [], "source": [ "# List of unique ner labels\n", "unique_labels = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n", "\n", "# Create class label\n", "ner_tags_feature = ClassLabel(names=unique_labels)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.853658400Z", "start_time": "2024-05-29T20:46:17.839342100Z" } }, "id": "4baa0327c2b22f4c" }, { "cell_type": "code", "execution_count": 10, "outputs": [], "source": [ "# Method to convert ner tags to class labels\n", "def convert_to_classlabel(example):\n", " example['ner_tags'] = [ner_tags_feature.str2int(tag) for tag in example['ner_tags']]\n", " \n", " return example" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:46:17.904893600Z", "start_time": "2024-05-29T20:46:17.854695Z" } }, "id": "6d88730786c438c1" }, { "cell_type": "code", "execution_count": 11, "outputs": [ { "data": { "text/plain": "Map: 0%| | 0/2149 [00:00", "text/html": "\n
\n \n \n [ 2/10745 : < :, Epoch 0.00/5]\n
\n \n \n \n \n \n \n \n \n \n \n
EpochTraining LossValidation Loss

" }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": "TrainOutput(global_step=10745, training_loss=0.16770398169686596, metrics={'train_runtime': 715.5143, 'train_samples_per_second': 15.017, 'train_steps_per_second': 15.017, 'total_flos': 424900907413920.0, 'train_loss': 0.16770398169686596, 'epoch': 5.0})" }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train model\n", "torch.cuda.empty_cache()\n", "trainer.train()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T19:28:01.082515100Z", "start_time": "2024-05-29T19:16:05.326865100Z" } }, "id": "bc129922e37d3a66" }, { "cell_type": "code", "execution_count": 29, "outputs": [], "source": [ "# Save model\n", "trainer.save_model('ner-model')" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T19:34:21.843359600Z", "start_time": "2024-05-29T19:34:21.416348600Z" } }, "id": "feb7b9d1a7361676" }, { "cell_type": "markdown", "source": [ "### Evaluate model" ], "metadata": { "collapsed": false }, "id": "2d35de4a67725848" }, { "cell_type": "code", "execution_count": 30, "outputs": [ { "data": { "text/plain": "", "text/html": "\n

\n \n \n [ 1/529 : < :]\n
\n " }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": "{'eval_loss': 0.17733338475227356,\n 'eval_precision': 0.7404310067545835,\n 'eval_recall': 0.7740416946872899,\n 'eval_f1': 0.7568633897747822,\n 'eval_accuracy': 0.9586372907517319,\n 'eval_runtime': 5.8684,\n 'eval_samples_per_second': 90.144,\n 'eval_steps_per_second': 90.144}" }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate\n", "trainer.evaluate()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-29T20:47:15.896453Z", "start_time": "2024-05-29T20:47:10.004317800Z" } }, "id": "bb8d44f73837f564" }, { "cell_type": "markdown", "source": [ "### Predict on validation data" ], "metadata": { "collapsed": false }, "id": "78ea55ccc041a68" }, { "cell_type": "code", "execution_count": 167, "outputs": [], "source": [ "# Preprocess data\n", "def preprocess_data(tokens):\n", " sentences = [\" \".join(token_list) for token_list in tokens]\n", " return sentences" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-30T04:14:10.297133Z", "start_time": "2024-05-30T04:14:10.283305Z" } }, "id": "bbed07b9338166a2" }, { "cell_type": "code", "execution_count": 168, "outputs": [], "source": [ "train_sentences = preprocess_data(train_data['tokens'])\n", "val_sentences = preprocess_data(val_data['tokens'])\n", "test_sentences = preprocess_data(test_data['tokens'])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-30T04:14:11.977406700Z", "start_time": "2024-05-30T04:14:11.961831800Z" } }, "id": "2a48fa6bce956435" }, { "cell_type": "code", "execution_count": 169, "outputs": [], "source": [ "# Align predictions\n", "def align_predictions(predictions, label_ids, sentence_indices, fragment_lengths):\n", " preds = np.argmax(predictions, axis=2)\n", " aligned_preds = []\n", " aligned_labels = []\n", "\n", " for pred, label, idx, length in zip(preds, label_ids, sentence_indices, fragment_lengths):\n", " aligned_pred = []\n", " aligned_label = []\n", " for p, l in zip(pred, label):\n", " if l != -100:\n", " aligned_pred.append(p)\n", " aligned_label.append(l)\n", " aligned_preds.append(aligned_pred)\n", " aligned_labels.append(aligned_label)\n", "\n", " return aligned_preds, aligned_labels" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-30T04:14:13.030955900Z", "start_time": "2024-05-30T04:14:13.011494Z" } }, "id": "acb1e4438f26c866" }, { "cell_type": "code", "execution_count": 170, "outputs": [ { "data": { "text/plain": "", "text/html": "\n
\n \n \n [ 1/529 : < :]\n
\n " }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Predict on validation data\n", "predictions_val, label_ids_val, metrics_val = trainer.predict(tokenized_val)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-30T04:14:19.969405600Z", "start_time": "2024-05-30T04:14:13.795285400Z" } }, "id": "e9d716adcc29094c" }, { "cell_type": "code", "execution_count": 171, "outputs": [], "source": [ "# Align predictions\n", "aligned_preds_val, aligned_labels_val = align_predictions(predictions_val, label_ids_val, val_original_sentence_indices, val_fragment_lengths)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-30T04:14:20.017766Z", "start_time": "2024-05-30T04:14:19.970385600Z" } }, "id": "6e794d5403a8a7a9" }, { "cell_type": "code", "execution_count": 172, "outputs": [], "source": [ "# Concat results based on val_original_sentence_indices\n", "predicted_labels = []\n", "true_labels = []\n", "for i in range(len(aligned_preds_val)):\n", " if i == 0:\n", " predicted_labels.append(aligned_preds_val[i])\n", " true_labels.append(aligned_labels_val[i])\n", " elif val_original_sentence_indices[i] == val_original_sentence_indices[i-1]:\n", " predicted_labels[-1] += aligned_preds_val[i]\n", " true_labels[-1] += aligned_labels_val[i]\n", " else:\n", " predicted_labels.append(aligned_preds_val[i])\n", " true_labels.append(aligned_labels_val[i])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-30T04:14:20.017766Z", "start_time": "2024-05-30T04:14:20.000190600Z" } }, "id": "2c5c1526bee64f34" }, { "cell_type": "markdown", "source": [ "### Postprocessing" ], "metadata": { "collapsed": false }, "id": "27b07318b3720194" }, { "cell_type": "code", "execution_count": 174, "outputs": [], "source": [ "import regex as re\n", "\n", "# Postprocessing\n", "# Regex for finding I-tags that start a sequence (should be B-tags)\n", "def incorrect_I_as_begin_tag(text):\n", " return re.finditer(r'(?