{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def correct_labels(input_file, output_file):\n",
    "    # Repair IOB tags: promote any I-X tag that does not continue a B-X/I-X span to B-X.\n",
    "    df = pd.read_csv(input_file, sep=\"\\t\", names=[\"Text\"])\n",
    "\n",
    "    corrected_lines = []\n",
    "\n",
    "    for line in df[\"Text\"]:\n",
    "        tokens = line.split(\" \")\n",
    "        corrected_tokens = []\n",
    "        previous_token = \"O\"\n",
    "\n",
    "        for token in tokens:\n",
    "            if (\n",
    "                token == \"I-ORG\"\n",
    "                and previous_token != \"B-ORG\"\n",
    "                and previous_token != \"I-ORG\"\n",
    "            ):\n",
    "                corrected_tokens.append(\"B-ORG\")\n",
    "            elif (\n",
    "                token == \"I-PER\"\n",
    "                and previous_token != \"B-PER\"\n",
    "                and previous_token != \"I-PER\"\n",
    "            ):\n",
    "                corrected_tokens.append(\"B-PER\")\n",
    "            elif (\n",
    "                token == \"I-LOC\"\n",
    "                and previous_token != \"B-LOC\"\n",
    "                and previous_token != \"I-LOC\"\n",
    "            ):\n",
    "                corrected_tokens.append(\"B-LOC\")\n",
    "            elif (\n",
    "                token == \"I-MISC\"\n",
    "                and previous_token != \"B-MISC\"\n",
    "                and previous_token != \"I-MISC\"\n",
    "            ):\n",
    "                corrected_tokens.append(\"B-MISC\")\n",
    "            else:\n",
    "                corrected_tokens.append(token)\n",
    "\n",
    "            previous_token = token\n",
    "\n",
    "        corrected_line = \" \".join(corrected_tokens)\n",
    "        corrected_lines.append(corrected_line)\n",
    "\n",
    "    df[\"Text\"] = corrected_lines\n",
    "    df.to_csv(output_file, sep=\"\\t\", index=False, header=False)"
   ]
  },
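  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of `correct_labels` on a hand-made line of tags (a hypothetical toy example, not part of the challenge data): an `I-X` tag that opens a span should come back as `B-X`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy example (not challenge data): write one line of tags to a temporary\n",
    "# file and run correct_labels on it; the span-opening I-ORG / I-PER tags should become B-ORG / B-PER.\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    toy_in = os.path.join(tmp, \"toy_in.tsv\")\n",
    "    toy_out = os.path.join(tmp, \"toy_out.tsv\")\n",
    "    with open(toy_in, \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(\"O I-ORG I-ORG O I-PER O\\n\")\n",
    "    correct_labels(toy_in, toy_out)\n",
    "    with open(toy_out, encoding=\"utf-8\") as f:\n",
    "        print(f.read().strip())  # expected: O B-ORG I-ORG O B-PER O"
   ]
  },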
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\48690\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.9_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python39\\site-packages\\tf_keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6128bb3df9d45e1b7ad703507a8e9ef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors: 0%| | 0.00/1.33G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1099fa0386244b9fb276e1b988f12c30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json: 0%| | 0.00/60.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8c80c0068e2d41d78c4fe3edf003a841",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.txt: 0%| | 0.00/213k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ner_model = pipeline(\"ner\", model=\"dbmdz/bert-large-cased-finetuned-conll03-english\")"
   ]
  },
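  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional smoke test on a short, made-up sentence (illustrative only): the raw `ner` pipeline returns one entry per word piece, which is what the re-alignment step below compensates for."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: inspect the raw pipeline output for a short, made-up sentence.\n",
    "# Each entry covers a single word piece, so split-up words appear as several entries.\n",
    "for entity in ner_model(\"John Smith works for Acme Corporation in Warsaw.\"):\n",
    "    print(entity)"
   ]
  },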
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_word_indices(string_to_search):\n",
    "    # Start offsets of all whitespace-separated words in the string.\n",
    "    pattern = r\"\\s\\S\"\n",
    "    matches = re.finditer(pattern, string_to_search)\n",
    "    indices = [m.start(0) + 1 for m in matches]\n",
    "    if not string_to_search[0].isspace():\n",
    "        indices.insert(0, 0)\n",
    "    return sorted(indices)\n",
    "\n",
    "def get_word_beginning(string_to_search, letter_index):\n",
    "    # Walk left from letter_index to the first character of the enclosing word.\n",
    "    while letter_index > 0 and string_to_search[letter_index - 1] != \" \":\n",
    "        letter_index -= 1\n",
    "    return letter_index\n",
    "\n",
    "def wordpiece_tokenization(ner_tokenized, original_sentence):\n",
    "    # Merge word-piece (\"##\") predictions back into whole words, then emit one IOB tag per word.\n",
    "    word_start_index_to_tag = {}\n",
    "    formatted_results = []\n",
    "    previous_tag = \"O\"\n",
    "\n",
    "    for result in ner_tokenized:\n",
    "        word = result[\"word\"].replace(\"##\", \"\")\n",
    "        start, end = result[\"start\"], result[\"start\"] + len(word)\n",
    "\n",
    "        if formatted_results and (original_sentence[result[\"start\"] - 1] != \" \" or result[\"word\"].startswith(\"##\")):\n",
    "            formatted_results[-1][\"end\"] = end\n",
    "            formatted_results[-1][\"word\"] += word\n",
    "        else:\n",
    "            result[\"word\"] = word\n",
    "            result[\"start\"] = get_word_beginning(original_sentence, start)\n",
    "            result[\"end\"] = end\n",
    "            formatted_results.append(result)\n",
    "\n",
    "    for result in formatted_results:\n",
    "        start_index = result[\"start\"]\n",
    "        tag = result[\"entity\"]\n",
    "\n",
    "        if tag != \"O\":\n",
    "            if previous_tag != tag:\n",
    "                tag = f\"B-{tag.split('-')[-1]}\"\n",
    "            else:\n",
    "                tag = f\"I-{tag.split('-')[-1]}\"\n",
    "            word_start_index_to_tag[start_index] = tag\n",
    "        previous_tag = result[\"entity\"]\n",
    "\n",
    "    for index in get_word_indices(original_sentence):\n",
    "        word_start_index_to_tag.setdefault(index, \"O\")\n",
    "\n",
    "    return [word_start_index_to_tag[index] for index in sorted(word_start_index_to_tag.keys())]"
   ]
  },
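  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small hand-written illustration of the re-alignment (hypothetical pipeline-style output, no model call): the word piece `##aw` is merged back into `Warsaw`, and every whitespace-separated word gets exactly one tag."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical, hand-crafted pipeline-style output for the toy sentence below\n",
    "# (\"Warsaw\" split into the word pieces \"Wars\" and \"##aw\"); no model call is made.\n",
    "toy_sentence = \"John lives in Warsaw\"\n",
    "toy_ner_output = [\n",
    "    {\"word\": \"John\", \"entity\": \"I-PER\", \"start\": 0, \"end\": 4},\n",
    "    {\"word\": \"Wars\", \"entity\": \"I-LOC\", \"start\": 14, \"end\": 18},\n",
    "    {\"word\": \"##aw\", \"entity\": \"I-LOC\", \"start\": 18, \"end\": 20},\n",
    "]\n",
    "print(wordpiece_tokenization(toy_ner_output, toy_sentence))  # expected: ['B-PER', 'O', 'O', 'B-LOC']"
   ]
  },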
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def get_input_file(input_file):\n",
    "    with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        original_sentences = f.readlines()\n",
    "    return original_sentences\n",
    "\n",
    "def save_output_file(output_file, processed_data):\n",
    "    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
    "        for line in processed_data:\n",
    "            f.write(f\"{line}\\n\")\n",
    "\n",
    "def tokenize_file(input_file, output_file):\n",
    "    original_sentences = get_input_file(input_file)\n",
    "\n",
    "    processed_data = []\n",
    "    for raw_sentence in tqdm(original_sentences, desc=f\"Processing {input_file}\"):\n",
    "        model_out = ner_model(raw_sentence.strip())\n",
    "        word_tokenization = wordpiece_tokenization(model_out, raw_sentence.strip())\n",
    "        processed_line = \" \".join(word_tokenization)\n",
    "        processed_data.append(processed_line)\n",
    "\n",
    "    save_output_file(output_file, processed_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing dev-0/in.tsv: 100%|██████████| 215/215 [11:57<00:00, 3.34s/it]\n"
     ]
    }
   ],
   "source": [
    "tokenize_file(\"dev-0/in.tsv\", \"dev-0/out.tsv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing test-A/in.tsv: 100%|██████████| 230/230 [12:39<00:00, 3.30s/it]\n"
     ]
    }
   ],
   "source": [
    "tokenize_file(\"test-A/in.tsv\", \"test-A/out.tsv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_file = \"dev-0/out.tsv\"\n",
    "output_file = \"dev-0/out.tsv\"\n",
    "correct_labels(input_file, output_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_file = \"test-A/out.tsv\"\n",
    "output_file = \"test-A/out.tsv\"\n",
    "correct_labels(input_file, output_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy for dev0: 0.9566\n"
     ]
    }
   ],
   "source": [
    "def dev0_accuracy():\n",
    "    # Per-token accuracy of the dev-0 predictions against the gold labels.\n",
    "    out_file = \"dev-0/out.tsv\"\n",
    "    expected_file = \"dev-0/expected.tsv\"\n",
    "\n",
    "    with open(out_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        out_lines = f.readlines()\n",
    "\n",
    "    with open(expected_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        expected_lines = f.readlines()\n",
    "\n",
    "    all_tags = 0\n",
    "    correct_tags = 0\n",
    "\n",
    "    for i in range(len(out_lines)):\n",
    "        out_tags = out_lines[i].split()\n",
    "        expected_tags = expected_lines[i].split()\n",
    "\n",
    "        all_tags += len(expected_tags)\n",
    "        correct_tags += sum(a == b for a, b in zip(out_tags, expected_tags))\n",
    "\n",
    "    accuracy = correct_tags / all_tags\n",
    "    print(f\"Accuracy for dev0: {accuracy:.4f}\")\n",
    "\n",
    "dev0_accuracy()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}