{ "cells": [ { "cell_type": "markdown", "source": [ "### Importy" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "from transformers import pipeline\n", "import re\n", "from tqdm import tqdm\n", "import pandas as pd" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-09T12:13:28.590508Z", "end_time": "2024-06-09T12:13:40.429636Z" } } }, { "cell_type": "markdown", "source": [ "### Initializacja modelu NER" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n", "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "nlp = pipeline(\"ner\", model = 'dbmdz/bert-large-cased-finetuned-conll03-english')" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-09T12:13:40.436629Z", "end_time": "2024-06-09T12:13:43.520630Z" } } }, { "cell_type": "markdown", "source": [ "### Metody do tokenizacji" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "def get_word_indices(string_to_search):\n", " pattern = \"\\s\\S\"\n", " matches = re.finditer(pattern, string_to_search)\n", " indices = [m.start(0) + 1 for m in matches]\n", " if not string_to_search[0].isspace():\n", " indices.insert(0, 0)\n", " return sorted(indices)\n", "\n", "def get_word_beginning(string_to_search, letter_index):\n", " while letter_index > 0 and string_to_search[letter_index - 1] != \" \":\n", " letter_index -= 1\n", " return letter_index\n", "\n", "def wordpiece_tokenization(ner_tokenized, original_sentence):\n", " word_start_index_to_tag = {}\n", " formatted_results = []\n", " previous_tag = \"O\"\n", "\n", " for result in ner_tokenized:\n", " word = result[\"word\"].replace(\"##\", \"\")\n", " start, end = result[\"start\"], result[\"start\"] + len(word)\n", "\n", " if formatted_results and (original_sentence[result[\"start\"] - 1] != \" \" or result[\"word\"].startswith(\"##\")):\n", " formatted_results[-1][\"end\"] = end\n", " formatted_results[-1][\"word\"] += word\n", " else:\n", " result[\"word\"] = word\n", " result[\"start\"] = get_word_beginning(original_sentence, start)\n", " result[\"end\"] = end\n", " formatted_results.append(result)\n", "\n", " for result in formatted_results:\n", " start_index = result[\"start\"]\n", " tag = result[\"entity\"]\n", "\n", " if tag != \"O\":\n", " if previous_tag != tag:\n", " tag = f\"B-{tag.split('-')[-1]}\"\n", " else:\n", " tag = f\"I-{tag.split('-')[-1]}\"\n", " word_start_index_to_tag[start_index] = tag\n", " previous_tag = result[\"entity\"]\n", "\n", " for index in get_word_indices(original_sentence):\n", " word_start_index_to_tag.setdefault(index, \"O\")\n", "\n", " return [word_start_index_to_tag[index] for index in sorted(word_start_index_to_tag.keys())]" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-05T22:18:22.319194Z", "end_time": "2024-06-05T22:18:22.343447Z" } } }, { "cell_type": "markdown", "source": [ "### Przykładowe użycie" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "data": { "text/plain": "[{'entity': 'I-ORG',\n 'score': 0.9995635,\n 'index': 1,\n 'word': 'Hu',\n 'start': 0,\n 'end': 2},\n {'entity': 'I-ORG',\n 'score': 0.99159384,\n 'index': 2,\n 'word': '##gging',\n 'start': 2,\n 'end': 7},\n {'entity': 'I-ORG',\n 'score': 0.99826705,\n 'index': 3,\n 'word': 'Face',\n 'start': 8,\n 'end': 12},\n {'entity': 'I-ORG',\n 'score': 0.9994404,\n 'index': 4,\n 'word': 'Inc',\n 'start': 13,\n 'end': 16},\n {'entity': 'I-LOC',\n 'score': 0.99943465,\n 'index': 11,\n 'word': 'New',\n 'start': 40,\n 'end': 43},\n {'entity': 'I-LOC',\n 'score': 0.99932706,\n 'index': 12,\n 'word': 'York',\n 'start': 44,\n 'end': 48},\n {'entity': 'I-LOC',\n 'score': 0.9993864,\n 'index': 13,\n 'word': 'City',\n 'start': 49,\n 'end': 53},\n {'entity': 'I-LOC',\n 'score': 0.9825622,\n 'index': 19,\n 'word': 'D',\n 'start': 79,\n 'end': 80},\n {'entity': 'I-LOC',\n 'score': 0.936983,\n 'index': 20,\n 'word': '##UM',\n 'start': 80,\n 'end': 82},\n {'entity': 'I-LOC',\n 'score': 0.89870995,\n 'index': 21,\n 'word': '##BO',\n 'start': 82,\n 'end': 84},\n {'entity': 'I-LOC',\n 'score': 0.97582406,\n 'index': 29,\n 'word': 'Manhattan',\n 'start': 113,\n 'end': 122},\n {'entity': 'I-LOC',\n 'score': 0.99024945,\n 'index': 30,\n 'word': 'Bridge',\n 'start': 123,\n 'end': 129}]" }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sequence = \"Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very\" \\\n", " \"close to the Manhattan Bridge which is visible from the window.\"\n", "model_out = nlp(sequence)\n", "model_out" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-09T12:14:36.626686Z", "end_time": "2024-06-09T12:14:36.815685Z" } } }, { "cell_type": "markdown", "source": [ "### Tokenizacja plików" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "def tokenize_file(input_file, output_file):\n", " with open(input_file, \"r\", encoding=\"utf-8\") as f:\n", " original_sentences = f.readlines()\n", "\n", " processed_data = []\n", " for raw_sentence in tqdm(original_sentences, desc=f\"Processing {input_file}\"):\n", " model_out = nlp(raw_sentence.strip())\n", " word_tokenization = wordpiece_tokenization(model_out, raw_sentence.strip())\n", " processed_line = \" \".join(word_tokenization)\n", " processed_data.append(processed_line)\n", "\n", " with open(output_file, \"w\", encoding=\"utf-8\") as f:\n", " for line in processed_data:\n", " f.write(f\"{line}\\n\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-05T22:18:22.339446Z", "end_time": "2024-06-05T22:18:22.350525Z" } } }, { "cell_type": "markdown", "source": [ "### Ewaluacja" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing dev-0/in.tsv: 100%|██████████| 215/215 [03:28<00:00, 1.03it/s]\n" ] } ], "source": [ "tokenize_file(\"dev-0/in.tsv\", \"dev-0/out.tsv\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-05T22:18:22.354491Z", "end_time": "2024-06-05T22:21:51.061001Z" } } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing test-A/in.tsv: 100%|██████████| 230/230 [03:42<00:00, 1.03it/s]\n" ] } ], "source": [ "tokenize_file(\"test-A/in.tsv\", \"test-A/out.tsv\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-05T22:21:51.062002Z", "end_time": "2024-06-05T22:25:33.462085Z" } } }, { "cell_type": "markdown", "source": [ "### Poprawienie etykiet" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 7, "outputs": [], "source": [ "def correct_labels(input_file, output_file):\n", " df = pd.read_csv(input_file, sep=\"\\t\", names=[\"Text\"])\n", "\n", " corrected_lines = []\n", "\n", " for line in df[\"Text\"]:\n", " tokens = line.split(\" \")\n", " corrected_tokens = []\n", " previous_token = \"O\"\n", "\n", " for token in tokens:\n", " if (\n", " token == \"I-ORG\"\n", " and previous_token != \"B-ORG\"\n", " and previous_token != \"I-ORG\"\n", " ):\n", " corrected_tokens.append(\"B-ORG\")\n", " elif (\n", " token == \"I-PER\"\n", " and previous_token != \"B-PER\"\n", " and previous_token != \"I-PER\"\n", " ):\n", " corrected_tokens.append(\"B-PER\")\n", " elif (\n", " token == \"I-LOC\"\n", " and previous_token != \"B-LOC\"\n", " and previous_token != \"I-LOC\"\n", " ):\n", " corrected_tokens.append(\"B-LOC\")\n", " elif (\n", " token == \"I-MISC\"\n", " and previous_token != \"B-MISC\"\n", " and previous_token != \"I-MISC\"\n", " ):\n", " corrected_tokens.append(\"B-MISC\")\n", " else:\n", " corrected_tokens.append(token)\n", "\n", " previous_token = token\n", "\n", " corrected_line = \" \".join(corrected_tokens)\n", " corrected_lines.append(corrected_line)\n", "\n", " df[\"Text\"] = corrected_lines\n", " df.to_csv(output_file, sep=\"\\t\", index=False, header=False)\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-05T22:25:33.468992Z", "end_time": "2024-06-05T22:25:33.507038Z" } } }, { "cell_type": "code", "execution_count": 8, "outputs": [], "source": [ "input_file = \"test-A/out.tsv\"\n", "output_file = \"test-A/out.tsv\"\n", "correct_labels(input_file, output_file)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-05T22:25:33.480964Z", "end_time": "2024-06-05T22:25:33.593979Z" } } }, { "cell_type": "code", "execution_count": 9, "outputs": [], "source": [ "input_file = \"dev-0/out.tsv\"\n", "output_file = \"dev-0/out.tsv\"\n", "correct_labels(input_file, output_file)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-05T22:25:33.589982Z", "end_time": "2024-06-05T22:25:33.624147Z" } } }, { "cell_type": "markdown", "source": [ "### Obliczenie dokładności" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 10, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing dev-0/in.tsv: 100%|██████████| 215/215 [03:36<00:00, 1.01s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.9236\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "def calculate_accuracy(input_file, expected_file):\n", " with open(input_file, \"r\", encoding=\"utf-8\") as f:\n", " original_sentences = f.readlines()\n", "\n", " with open(expected_file, \"r\", encoding=\"utf-8\") as f:\n", " expected_tags = f.readlines()\n", "\n", " total_tags = 0\n", " correct_tags = 0\n", "\n", " for raw_sentence, expected_line in tqdm(zip(original_sentences, expected_tags), desc=f\"Processing {input_file}\", total=len(original_sentences)):\n", " model_out = nlp(raw_sentence.strip())\n", " word_tokenization = wordpiece_tokenization(model_out, raw_sentence.strip())\n", " expected_tags_list = expected_line.strip().split()\n", "\n", " total_tags += len(expected_tags_list)\n", " correct_tags += sum(p == e for p, e in zip(word_tokenization, expected_tags_list))\n", "\n", " accuracy = correct_tags / total_tags\n", " print(f\"Accuracy: {accuracy:.4f}\")\n", "\n", "calculate_accuracy(\"dev-0/in.tsv\", \"dev-0/expected.tsv\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2024-06-05T22:25:33.625148Z", "end_time": "2024-06-05T22:29:10.146791Z" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }