From 897a1490c7df1402433fecf06ff8f63a0f20223f Mon Sep 17 00:00:00 2001 From: s464863 Date: Thu, 30 May 2024 06:22:01 +0200 Subject: [PATCH] Add transformer model notebook --- transformer.ipynb | 1291 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1291 insertions(+) create mode 100644 transformer.ipynb diff --git a/transformer.ipynb b/transformer.ipynb new file mode 100644 index 0000000..223c6ba --- /dev/null +++ b/transformer.ipynb @@ -0,0 +1,1291 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## Transformer" + ], + "metadata": { + "collapsed": false + }, + "id": "7dd30d84a916d9d0" + }, + { + "cell_type": "code", + "execution_count": 1, + "outputs": [], + "source": [ + "# Necessary imports\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "import torch\n", + "\n", + "import datasets\n", + "from datasets import ClassLabel, Features, Sequence, Value\n", + "from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForTokenClassification\n", + "\n", + "import warnings\n", + "warnings.filterwarnings('ignore')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.480512300Z", + "start_time": "2024-05-29T20:46:00.288032100Z" + } + }, + "id": "initial_id" + }, + { + "cell_type": "markdown", + "source": [ + "### Prepare data" + ], + "metadata": { + "collapsed": false + }, + "id": "61aea5d48638d128" + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "# Divide train data into sentences and labels\n", + "train_data = pd.read_csv('train/train.tsv', sep='\\t', header=None)\n", + "\n", + "with open('train/train_labels.tsv', 'w') as f:\n", + " for i in range(len(train_data)):\n", + " if i == len(train_data) - 1:\n", + " f.write(train_data.iloc[i][0])\n", + " else:\n", + " f.write(train_data.iloc[i][0] + '\\n')\n", + " \n", + "with open('train/train_sentences.tsv', 'w') as f:\n", + " for i in range(len(train_data)):\n", + " if i == len(train_data) - 1:\n", + " f.write(train_data.iloc[i][1])\n", + " else:\n", + " f.write(train_data.iloc[i][1] + '\\n')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.573476800Z", + "start_time": "2024-05-29T20:46:17.482528600Z" + } + }, + "id": "b19f9ff554147bc4" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "# Data paths\n", + "train_sentences_file = 'train/train_sentences.tsv'\n", + "train_labels_file = 'train/train_labels.tsv'\n", + "val_sentences_file = 'dev-0/in.tsv'\n", + "val_labels_file = 'dev-0/expected.tsv'\n", + "test_sentences_file = 'test-A/in.tsv'" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.587032200Z", + "start_time": "2024-05-29T20:46:17.575283500Z" + } + }, + "id": "e0e8f33971087b3e" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [], + "source": [ + "# Method to read tokens and labels from files\n", + "def read_sentences_and_labels(sentences_path, labels_path=None):\n", + " tokens = []\n", + " ner_tags = []\n", + " \n", + " with open(sentences_path, 'r') as f:\n", + " for line in f:\n", + " tokens.append(line.strip().split())\n", + " \n", + " if labels_path:\n", + " with open(labels_path, 'r') as f:\n", + " for line in f:\n", + " ner_tags.append(line.strip().split())\n", + " \n", + " if labels_path:\n", + " return {'tokens': tokens, 'ner_tags': ner_tags}\n", + " else:\n", + " return {'tokens': tokens}" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.611271600Z", + "start_time": "2024-05-29T20:46:17.589908Z" + } + }, + "id": "47bcc983f537b670" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [], + "source": [ + "# Load data\n", + "train_data = read_sentences_and_labels(train_sentences_file, train_labels_file)\n", + "val_data = read_sentences_and_labels(val_sentences_file, val_labels_file)\n", + "test_data = read_sentences_and_labels(test_sentences_file)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.757405800Z", + "start_time": "2024-05-29T20:46:17.606209100Z" + } + }, + "id": "cf901e101075811e" + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [], + "source": [ + "# Split long sentences into multiple sentences\n", + "def split_long_sentences(data, max_length=128):\n", + " if 'ner_tags' in data:\n", + " new_data = {'tokens': [], 'ner_tags': []}\n", + " else:\n", + " new_data = {'tokens': []}\n", + " \n", + " original_sentence_indices = []\n", + " fragment_lengths = []\n", + " \n", + " for i in range(len(data['tokens'])):\n", + " tokens = data['tokens'][i]\n", + " if 'ner_tags' in data:\n", + " ner_tags = data['ner_tags'][i]\n", + " \n", + " if len(tokens) > max_length:\n", + " for j in range(0, len(tokens), max_length):\n", + " new_data['tokens'].append(tokens[j:j+max_length])\n", + " if 'ner_tags' in data:\n", + " new_data['ner_tags'].append(ner_tags[j:j+max_length])\n", + " original_sentence_indices.append(i)\n", + " fragment_lengths.append(len(tokens[j:j+max_length]))\n", + " else:\n", + " new_data['tokens'].append(tokens)\n", + " if 'ner_tags' in data:\n", + " new_data['ner_tags'].append(ner_tags)\n", + " original_sentence_indices.append(i)\n", + " fragment_lengths.append(len(tokens))\n", + " \n", + " return new_data, original_sentence_indices, fragment_lengths" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.772977500Z", + "start_time": "2024-05-29T20:46:17.686317600Z" + } + }, + "id": "f9afdda7a877f2" + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [], + "source": [ + "# Split long sentences\n", + "train_data, train_original_sentence_indices, train_fragment_lengths = split_long_sentences(train_data)\n", + "val_data, val_original_sentence_indices, val_fragment_lengths = split_long_sentences(val_data)\n", + "test_data, test_original_sentence_indices, test_fragment_lengths = split_long_sentences(test_data)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.773978Z", + "start_time": "2024-05-29T20:46:17.697615700Z" + } + }, + "id": "8050585694f36e46" + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "# Convert to datasets\n", + "train_dataset = datasets.Dataset.from_dict(train_data)\n", + "val_dataset = datasets.Dataset.from_dict(val_data)\n", + "test_dataset = datasets.Dataset.from_dict(test_data)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.849338500Z", + "start_time": "2024-05-29T20:46:17.713989300Z" + } + }, + "id": "d8cb17d6e5631db5" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "# List of unique ner labels\n", + "unique_labels = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n", + "\n", + "# Create class label\n", + "ner_tags_feature = ClassLabel(names=unique_labels)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.853658400Z", + "start_time": "2024-05-29T20:46:17.839342100Z" + } + }, + "id": "4baa0327c2b22f4c" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "# Method to convert ner tags to class labels\n", + "def convert_to_classlabel(example):\n", + " example['ner_tags'] = [ner_tags_feature.str2int(tag) for tag in example['ner_tags']]\n", + " \n", + " return example" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:46:17.904893600Z", + "start_time": "2024-05-29T20:46:17.854695Z" + } + }, + "id": "6d88730786c438c1" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [ + { + "data": { + "text/plain": "Map: 0%| | 0/2149 [00:00", + "text/html": "\n
\n \n \n [ 2/10745 : < :, Epoch 0.00/5]\n
\n \n \n \n \n \n \n \n \n \n \n
EpochTraining LossValidation Loss

" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "TrainOutput(global_step=10745, training_loss=0.16770398169686596, metrics={'train_runtime': 715.5143, 'train_samples_per_second': 15.017, 'train_steps_per_second': 15.017, 'total_flos': 424900907413920.0, 'train_loss': 0.16770398169686596, 'epoch': 5.0})" + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Train model\n", + "torch.cuda.empty_cache()\n", + "trainer.train()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T19:28:01.082515100Z", + "start_time": "2024-05-29T19:16:05.326865100Z" + } + }, + "id": "bc129922e37d3a66" + }, + { + "cell_type": "code", + "execution_count": 29, + "outputs": [], + "source": [ + "# Save model\n", + "trainer.save_model('ner-model')" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T19:34:21.843359600Z", + "start_time": "2024-05-29T19:34:21.416348600Z" + } + }, + "id": "feb7b9d1a7361676" + }, + { + "cell_type": "markdown", + "source": [ + "### Evaluate model" + ], + "metadata": { + "collapsed": false + }, + "id": "2d35de4a67725848" + }, + { + "cell_type": "code", + "execution_count": 30, + "outputs": [ + { + "data": { + "text/plain": "", + "text/html": "\n

\n \n \n [ 1/529 : < :]\n
\n " + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "{'eval_loss': 0.17733338475227356,\n 'eval_precision': 0.7404310067545835,\n 'eval_recall': 0.7740416946872899,\n 'eval_f1': 0.7568633897747822,\n 'eval_accuracy': 0.9586372907517319,\n 'eval_runtime': 5.8684,\n 'eval_samples_per_second': 90.144,\n 'eval_steps_per_second': 90.144}" + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Evaluate\n", + "trainer.evaluate()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-29T20:47:15.896453Z", + "start_time": "2024-05-29T20:47:10.004317800Z" + } + }, + "id": "bb8d44f73837f564" + }, + { + "cell_type": "markdown", + "source": [ + "### Predict on validation data" + ], + "metadata": { + "collapsed": false + }, + "id": "78ea55ccc041a68" + }, + { + "cell_type": "code", + "execution_count": 167, + "outputs": [], + "source": [ + "# Preprocess data\n", + "def preprocess_data(tokens):\n", + " sentences = [\" \".join(token_list) for token_list in tokens]\n", + " return sentences" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T04:14:10.297133Z", + "start_time": "2024-05-30T04:14:10.283305Z" + } + }, + "id": "bbed07b9338166a2" + }, + { + "cell_type": "code", + "execution_count": 168, + "outputs": [], + "source": [ + "train_sentences = preprocess_data(train_data['tokens'])\n", + "val_sentences = preprocess_data(val_data['tokens'])\n", + "test_sentences = preprocess_data(test_data['tokens'])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T04:14:11.977406700Z", + "start_time": "2024-05-30T04:14:11.961831800Z" + } + }, + "id": "2a48fa6bce956435" + }, + { + "cell_type": "code", + "execution_count": 169, + "outputs": [], + "source": [ + "# Align predictions\n", + "def align_predictions(predictions, label_ids, sentence_indices, fragment_lengths):\n", + " preds = np.argmax(predictions, axis=2)\n", + " aligned_preds = []\n", + " aligned_labels = []\n", + "\n", + " for pred, label, idx, length in zip(preds, label_ids, sentence_indices, fragment_lengths):\n", + " aligned_pred = []\n", + " aligned_label = []\n", + " for p, l in zip(pred, label):\n", + " if l != -100:\n", + " aligned_pred.append(p)\n", + " aligned_label.append(l)\n", + " aligned_preds.append(aligned_pred)\n", + " aligned_labels.append(aligned_label)\n", + "\n", + " return aligned_preds, aligned_labels" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T04:14:13.030955900Z", + "start_time": "2024-05-30T04:14:13.011494Z" + } + }, + "id": "acb1e4438f26c866" + }, + { + "cell_type": "code", + "execution_count": 170, + "outputs": [ + { + "data": { + "text/plain": "", + "text/html": "\n
\n \n \n [ 1/529 : < :]\n
\n " + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Predict on validation data\n", + "predictions_val, label_ids_val, metrics_val = trainer.predict(tokenized_val)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T04:14:19.969405600Z", + "start_time": "2024-05-30T04:14:13.795285400Z" + } + }, + "id": "e9d716adcc29094c" + }, + { + "cell_type": "code", + "execution_count": 171, + "outputs": [], + "source": [ + "# Align predictions\n", + "aligned_preds_val, aligned_labels_val = align_predictions(predictions_val, label_ids_val, val_original_sentence_indices, val_fragment_lengths)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T04:14:20.017766Z", + "start_time": "2024-05-30T04:14:19.970385600Z" + } + }, + "id": "6e794d5403a8a7a9" + }, + { + "cell_type": "code", + "execution_count": 172, + "outputs": [], + "source": [ + "# Concat results based on val_original_sentence_indices\n", + "predicted_labels = []\n", + "true_labels = []\n", + "for i in range(len(aligned_preds_val)):\n", + " if i == 0:\n", + " predicted_labels.append(aligned_preds_val[i])\n", + " true_labels.append(aligned_labels_val[i])\n", + " elif val_original_sentence_indices[i] == val_original_sentence_indices[i-1]:\n", + " predicted_labels[-1] += aligned_preds_val[i]\n", + " true_labels[-1] += aligned_labels_val[i]\n", + " else:\n", + " predicted_labels.append(aligned_preds_val[i])\n", + " true_labels.append(aligned_labels_val[i])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T04:14:20.017766Z", + "start_time": "2024-05-30T04:14:20.000190600Z" + } + }, + "id": "2c5c1526bee64f34" + }, + { + "cell_type": "markdown", + "source": [ + "### Postprocessing" + ], + "metadata": { + "collapsed": false + }, + "id": "27b07318b3720194" + }, + { + "cell_type": "code", + "execution_count": 174, + "outputs": [], + "source": [ + "import regex as re\n", + "\n", + "# Postprocessing\n", + "# Regex for finding I-tags that start a sequence (should be B-tags)\n", + "def incorrect_I_as_begin_tag(text):\n", + " return re.finditer(r'(?