{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import re\n", "import pandas as pd\n", "from transformers import pipeline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def correct_labels(input_file, output_file):\n", " df = pd.read_csv(input_file, sep=\"\\t\", names=[\"Text\"])\n", "\n", " corrected_lines = []\n", "\n", " for line in df[\"Text\"]:\n", " tokens = line.split(\" \")\n", " corrected_tokens = []\n", " previous_token = \"O\"\n", "\n", " for token in tokens:\n", " if (\n", " token == \"I-ORG\"\n", " and previous_token != \"B-ORG\"\n", " and previous_token != \"I-ORG\"\n", " ):\n", " corrected_tokens.append(\"B-ORG\")\n", " elif (\n", " token == \"I-PER\"\n", " and previous_token != \"B-PER\"\n", " and previous_token != \"I-PER\"\n", " ):\n", " corrected_tokens.append(\"B-PER\")\n", " elif (\n", " token == \"I-LOC\"\n", " and previous_token != \"B-LOC\"\n", " and previous_token != \"I-LOC\"\n", " ):\n", " corrected_tokens.append(\"B-LOC\")\n", " elif (\n", " token == \"I-MISC\"\n", " and previous_token != \"B-MISC\"\n", " and previous_token != \"I-MISC\"\n", " ):\n", " corrected_tokens.append(\"B-MISC\")\n", " else:\n", " corrected_tokens.append(token)\n", "\n", " previous_token = token\n", "\n", " corrected_line = \" \".join(corrected_tokens)\n", " corrected_lines.append(corrected_line)\n", "\n", " df[\"Text\"] = corrected_lines\n", " df.to_csv(output_file, sep=\"\\t\", index=False, header=False)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From C:\\Users\\48690\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.9_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python39\\site-packages\\tf_keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. 
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [
  { "name": "stdout", "output_type": "stream", "text": [
   "WARNING:tensorflow:From C:\\Users\\48690\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.9_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python39\\site-packages\\tf_keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
   "\n" ] },
  { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e6128bb3df9d45e1b7ad703507a8e9ef", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/1.33G [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "# NOTE: the model argument originally passed here was not preserved in this copy of the notebook;\n",
  "# pipeline(\"ner\") falls back to the library's default English NER checkpoint (an assumption consistent with the ~1.33 GB download above).\n",
  "ner_model = pipeline(\"ner\")" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
  "def get_word_indices(string_to_search):\n",
  "    # Reconstructed helper (original definition not preserved): start offsets of all whitespace-separated words.\n",
  "    return [match.start() for match in re.finditer(r\"\\S+\", string_to_search)]\n",
  "\n",
  "def get_word_beginning(string_to_search, letter_index):\n",
  "    # Walk left until the previous character is a space, i.e. to the first letter of the surrounding word.\n",
  "    while letter_index > 0 and string_to_search[letter_index - 1] != \" \":\n",
  "        letter_index -= 1\n",
  "    return letter_index\n",
  "\n",
  "def wordpiece_tokenization(ner_tokenized, original_sentence):\n",
  "    word_start_index_to_tag = {}\n",
  "    formatted_results = []\n",
  "    previous_tag = \"O\"\n",
  "\n",
  "    # Merge word-piece predictions back into whole words.\n",
  "    for result in ner_tokenized:\n",
  "        word = result[\"word\"].replace(\"##\", \"\")\n",
  "        start, end = result[\"start\"], result[\"start\"] + len(word)\n",
  "\n",
  "        if formatted_results and (original_sentence[result[\"start\"] - 1] != \" \" or result[\"word\"].startswith(\"##\")):\n",
  "            formatted_results[-1][\"end\"] = end\n",
  "            formatted_results[-1][\"word\"] += word\n",
  "        else:\n",
  "            result[\"word\"] = word\n",
  "            result[\"start\"] = get_word_beginning(original_sentence, start)\n",
  "            result[\"end\"] = end\n",
  "            formatted_results.append(result)\n",
  "\n",
  "    # Re-emit the model's tags in IOB form at the word level.\n",
  "    for result in formatted_results:\n",
  "        start_index = result[\"start\"]\n",
  "        tag = result[\"entity\"]\n",
  "\n",
  "        if tag != \"O\":\n",
  "            if previous_tag != tag:\n",
  "                tag = f\"B-{tag.split('-')[-1]}\"\n",
  "            else:\n",
  "                tag = f\"I-{tag.split('-')[-1]}\"\n",
  "            word_start_index_to_tag[start_index] = tag\n",
  "        previous_tag = result[\"entity\"]\n",
  "\n",
  "    # Every word without a prediction gets the outside tag.\n",
  "    for index in get_word_indices(original_sentence):\n",
  "        word_start_index_to_tag.setdefault(index, \"O\")\n",
  "\n",
  "    return [word_start_index_to_tag[index] for index in sorted(word_start_index_to_tag.keys())]" ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
  "from tqdm import tqdm\n",
  "\n",
  "def get_input_file(input_file):\n",
  "    with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
  "        original_sentences = f.readlines()\n",
  "    return original_sentences\n",
  "\n",
  "def save_output_file(output_file, processed_data):\n",
  "    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
  "        for line in processed_data:\n",
  "            f.write(f\"{line}\\n\")\n",
  "\n",
  "def tokenize_file(input_file, output_file):\n",
  "    # Tag every sentence in input_file and write one line of word-level IOB tags per sentence.\n",
  "    original_sentences = get_input_file(input_file)\n",
  "\n",
  "    processed_data = []\n",
  "    for raw_sentence in tqdm(original_sentences, desc=f\"Processing {input_file}\"):\n",
  "        model_out = ner_model(raw_sentence.strip())\n",
  "        word_tokenization = wordpiece_tokenization(model_out, raw_sentence.strip())\n",
  "        processed_line = \" \".join(word_tokenization)\n",
  "        processed_data.append(processed_line)\n",
  "\n",
  "    save_output_file(output_file, processed_data)" ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [
  { "name": "stderr", "output_type": "stream", "text": [ "Processing dev-0/in.tsv: 100%|██████████| 215/215 [11:57<00:00, 3.34s/it]\n" ] } ], "source": [
  "tokenize_file(\"dev-0/in.tsv\", \"dev-0/out.tsv\")" ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [
  { "name": "stderr", "output_type": "stream", "text": [ "Processing test-A/in.tsv: 100%|██████████| 230/230 [12:39<00:00, 3.30s/it]\n" ] } ], "source": [
  "tokenize_file(\"test-A/in.tsv\", \"test-A/out.tsv\")" ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
  "input_file = \"dev-0/out.tsv\"\n",
  "output_file = \"dev-0/out.tsv\"\n",
  "correct_labels(input_file, output_file)" ] },
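 { "cell_type": "markdown", "metadata": {}, "source": [
  "Illustrative sanity check (a sketch, not part of the timed runs above): tag one made-up sentence with `ner_model` and map the word-piece predictions back to word-level IOB tags with `wordpiece_tokenization`." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Illustrative only: the sentence is invented; the number of tags should equal the number of words.\n",
  "sample_sentence = \"George Washington lived in Washington .\"\n",
  "sample_tags = wordpiece_tokenization(ner_model(sample_sentence), sample_sentence)\n",
  "print(list(zip(sample_sentence.split(\" \"), sample_tags)))" ] },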
"source": [ "input_file = \"test-A/out.tsv\"\n", "output_file = \"test-A/out.tsv\"\n", "correct_labels(input_file, output_file)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy for dev0: 0.9566\n" ] } ], "source": [ "def dev0_accuracy():\n", " out_file = \"dev-0/out.tsv\"\n", " expected_file = \"dev-0/expected.tsv\"\n", "\n", " with open(out_file, \"r\", encoding=\"utf-8\") as f:\n", " out_lines = f.readlines()\n", " \n", " with open(expected_file, \"r\", encoding=\"utf-8\") as f:\n", " expected_lines = f.readlines()\n", " \n", " all_tags = 0\n", " correct_tags = 0\n", "\n", " for i in range(len(out_lines)):\n", " out_tags = out_lines[i].split()\n", " expected_tags = expected_lines[i].split()\n", "\n", " all_tags += len(expected_tags)\n", " correct_tags += sum(a == b for a, b in zip(out_tags, expected_tags))\n", " \n", " accuracy = correct_tags / all_tags\n", " print(f\"Accuracy for dev0: {accuracy:.4f}\")\n", "\n", "dev0_accuracy()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }