en-ner-conll-2003/solution.ipynb

341 lines
11 KiB
Plaintext
Raw Normal View History

2024-06-09 00:14:55 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import pandas as pd\n",
"from transformers import pipeline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def correct_labels(input_file, output_file, entity_types=(\"ORG\", \"PER\", \"LOC\", \"MISC\")):\n",
"    \"\"\"Repair tag sequences in which an entity starts with an ``I-`` tag.\n",
"\n",
"    Every line of ``input_file`` is treated as a space-separated tag\n",
"    sequence. An ``I-X`` tag that does not directly follow ``B-X`` or\n",
"    ``I-X`` is rewritten to ``B-X``; every other token is copied unchanged.\n",
"\n",
"    Args:\n",
"        input_file: Path of the UTF-8 file with one tag sequence per line.\n",
"        output_file: Path the corrected sequences are written to. It may be\n",
"            the same path as ``input_file``: the input is read completely\n",
"            before anything is written.\n",
"        entity_types: Entity types whose ``I-`` tags are repaired.\n",
"    \"\"\"\n",
"    # Plain line reading keeps blank lines, so the output has exactly as many\n",
"    # lines as the input and stays aligned with it line by line.\n",
"    with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
"        lines = [line.rstrip(\"\\n\") for line in f]\n",
"\n",
"    repairable = {f\"I-{entity_type}\" for entity_type in entity_types}\n",
"    corrected_lines = []\n",
"\n",
"    for line in lines:\n",
"        corrected_tokens = []\n",
"        previous_token = \"O\"\n",
"\n",
"        for token in line.split(\" \"):\n",
"            entity_type = token[2:]\n",
"            continues_entity = previous_token in (f\"B-{entity_type}\", token)\n",
"            if token in repairable and not continues_entity:\n",
"                corrected_tokens.append(f\"B-{entity_type}\")\n",
"            else:\n",
"                corrected_tokens.append(token)\n",
"\n",
"            # Track the tag as read, not as corrected: an I-X that was just\n",
"            # rewritten to B-X must still let the following I-X pass.\n",
"            previous_token = token\n",
"\n",
"        corrected_lines.append(\" \".join(corrected_tokens))\n",
"\n",
"    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
"        for corrected_line in corrected_lines:\n",
"            f.write(f\"{corrected_line}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From C:\\Users\\48690\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.9_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python39\\site-packages\\tf_keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6128bb3df9d45e1b7ad703507a8e9ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/1.33G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
"- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1099fa0386244b9fb276e1b988f12c30",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/60.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8c80c0068e2d41d78c4fe3edf003a841",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vocab.txt: 0%| | 0.00/213k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Token-classification pipeline; tokenize_file calls it once per sentence.\n",
"ner_model = pipeline(\n",
"    task=\"ner\",\n",
"    model=\"dbmdz/bert-large-cased-finetuned-conll03-english\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def get_word_indices(string_to_search):\n",
"    \"\"\"Return the start index of every whitespace-separated word.\n",
"\n",
"    Args:\n",
"        string_to_search: Text to scan; may be empty.\n",
"\n",
"    Returns:\n",
"        list[int]: Ascending indices of every non-whitespace character that\n",
"        either begins the string or directly follows a whitespace character.\n",
"        An empty or all-whitespace string yields an empty list.\n",
"    \"\"\"\n",
"    # Raw string so the regex escapes reach the re module untouched. The\n",
"    # negative lookbehind also succeeds at index 0, so a leading word needs no\n",
"    # special case and empty input cannot raise.\n",
"    word_start = r\"(?<!\\S)\\S\"\n",
"    return [match.start() for match in re.finditer(word_start, string_to_search)]\n",
"\n",
"def get_word_beginning(string_to_search, letter_index):\n",
"    \"\"\"Return the index at which the word containing ``letter_index`` starts.\n",
"\n",
"    Walks left from ``letter_index`` until the preceding character is\n",
"    whitespace or the start of the string is reached.\n",
"\n",
"    Args:\n",
"        string_to_search: Text the index refers to.\n",
"        letter_index: Index of a character inside the word of interest.\n",
"\n",
"    Returns:\n",
"        int: Index of the first character of that word.\n",
"    \"\"\"\n",
"    # Any whitespace ends a word, not only a plain space, so that the result\n",
"    # agrees with the word starts reported by get_word_indices.\n",
"    while letter_index > 0 and not string_to_search[letter_index - 1].isspace():\n",
"        letter_index -= 1\n",
"    return letter_index\n",
"\n",
"def wordpiece_tokenization(ner_tokenized, original_sentence):\n",
"    \"\"\"Collapse word-piece level NER predictions into one IOB2 tag per word.\n",
"\n",
"    Args:\n",
"        ner_tokenized: Predictions for ``original_sentence``; each item is a\n",
"            dict with at least the keys ``word``, ``start`` and ``entity``.\n",
"            The items are not modified.\n",
"        original_sentence: The text the predictions refer to.\n",
"\n",
"    Returns:\n",
"        list[str]: Tags in sentence order, one per word. A word takes the\n",
"        label of the piece that opens it; words without such a piece are\n",
"        tagged ``O``. An entity starts with ``B-`` and continues with ``I-``\n",
"        only on directly adjacent words of the same type.\n",
"    \"\"\"\n",
"    # Stage 1: label of each word, keyed by the index at which the word starts.\n",
"    word_start_index_to_label = {}\n",
"    seen_piece = False\n",
"\n",
"    for piece in ner_tokenized:\n",
"        start = piece[\"start\"]\n",
"        word_start = get_word_beginning(original_sentence, start)\n",
"        # A piece continues a word when it carries the ## marker or does not\n",
"        # sit on a word boundary. Comparing against word_start avoids reading\n",
"        # original_sentence[-1] when start is 0.\n",
"        continues_word = piece[\"word\"].startswith(\"##\") or word_start != start\n",
"\n",
"        # Continuation pieces never change the label of their word; only the\n",
"        # very first reported piece is always accepted.\n",
"        if seen_piece and continues_word:\n",
"            continue\n",
"\n",
"        word_start_index_to_label[word_start] = piece[\"entity\"]\n",
"        seen_piece = True\n",
"\n",
"    for index in get_word_indices(original_sentence):\n",
"        word_start_index_to_label.setdefault(index, \"O\")\n",
"\n",
"    # Stage 2: walk the words in sentence order and derive B-/I- prefixes from\n",
"    # the directly preceding word, so an O word in between ends the entity.\n",
"    tags = []\n",
"    previous_type = None  # entity type of the preceding word, None after O\n",
"\n",
"    for index in sorted(word_start_index_to_label):\n",
"        label = word_start_index_to_label[index]\n",
"\n",
"        if label == \"O\":\n",
"            tags.append(\"O\")\n",
"            previous_type = None\n",
"            continue\n",
"\n",
"        entity_type = label.split(\"-\")[-1]\n",
"        # An explicit B- label from the model always opens a new entity.\n",
"        starts_entity = label.startswith(\"B-\") or entity_type != previous_type\n",
"        prefix = \"B\" if starts_entity else \"I\"\n",
"        tags.append(f\"{prefix}-{entity_type}\")\n",
"        previous_type = entity_type\n",
"\n",
"    return tags"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"def get_input_file(input_file):\n",
"    \"\"\"Read ``input_file`` as UTF-8 and return its lines, line endings included.\"\"\"\n",
"    with open(input_file, mode=\"r\", encoding=\"utf-8\") as handle:\n",
"        return handle.readlines()\n",
"\n",
"def save_output_file(output_file, processed_data):\n",
"    \"\"\"Write every item of ``processed_data`` to ``output_file`` on its own line (UTF-8).\"\"\"\n",
"    with open(output_file, mode=\"w\", encoding=\"utf-8\") as handle:\n",
"        handle.writelines(f\"{entry}\\n\" for entry in processed_data)\n",
"\n",
"def tokenize_file(input_file, output_file):\n",
"    \"\"\"Tag every sentence of ``input_file`` and write one tag line per sentence.\n",
"\n",
"    Args:\n",
"        input_file: Path of a UTF-8 text file with one sentence per line.\n",
"        output_file: Path that receives one space-separated tag sequence per\n",
"            input line, in the same order.\n",
"    \"\"\"\n",
"    original_sentences = get_input_file(input_file)\n",
"\n",
"    processed_data = []\n",
"    for raw_sentence in tqdm(original_sentences, desc=f\"Processing {input_file}\"):\n",
"        sentence = raw_sentence.strip()\n",
"\n",
"        if not sentence:\n",
"            # Nothing to tag. Still emit a line so the output stays aligned\n",
"            # with the input, and skip the model call for the blank sentence.\n",
"            processed_data.append(\"\")\n",
"            continue\n",
"\n",
"        model_out = ner_model(sentence)\n",
"        word_tokenization = wordpiece_tokenization(model_out, sentence)\n",
"        processed_data.append(\" \".join(word_tokenization))\n",
"\n",
"    save_output_file(output_file, processed_data)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing dev-0/in.tsv: 100%|██████████| 215/215 [11:57<00:00, 3.34s/it]\n"
]
}
],
"source": [
"# Tag the dev-0 sentences; the predictions are written to dev-0/out.tsv.\n",
"tokenize_file(input_file=\"dev-0/in.tsv\", output_file=\"dev-0/out.tsv\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing test-A/in.tsv: 100%|██████████| 230/230 [12:39<00:00, 3.30s/it]\n"
]
}
],
"source": [
"# Tag the test-A sentences; the predictions are written to test-A/out.tsv.\n",
"tokenize_file(input_file=\"test-A/in.tsv\", output_file=\"test-A/out.tsv\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Correct the dev-0 predictions in place: the same file is read and overwritten.\n",
"input_file = output_file = \"dev-0/out.tsv\"\n",
"correct_labels(input_file, output_file)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Correct the test-A predictions in place: the same file is read and overwritten.\n",
"input_file = output_file = \"test-A/out.tsv\"\n",
"correct_labels(input_file, output_file)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy for dev0: 0.9566\n"
]
}
],
"source": [
"def dev0_accuracy(out_file=\"dev-0/out.tsv\", expected_file=\"dev-0/expected.tsv\"):\n",
"    \"\"\"Print the share of expected tags that the predictions reproduce.\n",
"\n",
"    Lines are compared position by position, tags within a line likewise.\n",
"    Expected tags without a predicted counterpart (missing line or line that\n",
"    is too short) count as wrong. Predicted lines beyond the end of the\n",
"    expected file are ignored.\n",
"\n",
"    Args:\n",
"        out_file: Path of the predicted tag sequences, one per line.\n",
"        expected_file: Path of the expected tag sequences, one per line.\n",
"    \"\"\"\n",
"    with open(out_file, \"r\", encoding=\"utf-8\") as f:\n",
"        out_lines = f.readlines()\n",
"\n",
"    with open(expected_file, \"r\", encoding=\"utf-8\") as f:\n",
"        expected_lines = f.readlines()\n",
"\n",
"    all_tags = 0\n",
"    correct_tags = 0\n",
"\n",
"    # Drive the loop from the expected file so the denominator covers every\n",
"    # expected tag even when the prediction file is shorter.\n",
"    for i, expected_line in enumerate(expected_lines):\n",
"        expected_tags = expected_line.split()\n",
"        out_tags = out_lines[i].split() if i < len(out_lines) else []\n",
"\n",
"        all_tags += len(expected_tags)\n",
"        correct_tags += sum(a == b for a, b in zip(out_tags, expected_tags))\n",
"\n",
"    # No expected tags means nothing to score; report 0 instead of dividing by zero.\n",
"    accuracy = correct_tags / all_tags if all_tags else 0.0\n",
"    print(f\"Accuracy for dev0: {accuracy:.4f}\")\n",
"\n",
"# Print the tag accuracy of dev-0/out.tsv measured against dev-0/expected.tsv.\n",
"dev0_accuracy()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}