454 lines
14 KiB
Plaintext
454 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### Importy"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers import pipeline\n",
|
|
"import re\n",
|
|
"from tqdm import tqdm\n",
|
|
"import pandas as pd"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-09T12:13:28.590508Z",
|
|
"end_time": "2024-06-09T12:13:40.429636Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### Initializacja modelu NER"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
|
|
"- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
|
"- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"nlp = pipeline(\"ner\", model = 'dbmdz/bert-large-cased-finetuned-conll03-english')"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-09T12:13:40.436629Z",
|
|
"end_time": "2024-06-09T12:13:43.520630Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### Metody do tokenizacji"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_word_indices(string_to_search):\n",
|
|
" pattern = \"\\s\\S\"\n",
|
|
" matches = re.finditer(pattern, string_to_search)\n",
|
|
" indices = [m.start(0) + 1 for m in matches]\n",
|
|
" if not string_to_search[0].isspace():\n",
|
|
" indices.insert(0, 0)\n",
|
|
" return sorted(indices)\n",
|
|
"\n",
|
|
"def get_word_beginning(string_to_search, letter_index):\n",
|
|
" while letter_index > 0 and string_to_search[letter_index - 1] != \" \":\n",
|
|
" letter_index -= 1\n",
|
|
" return letter_index\n",
|
|
"\n",
|
|
"def wordpiece_tokenization(ner_tokenized, original_sentence):\n",
|
|
" word_start_index_to_tag = {}\n",
|
|
" formatted_results = []\n",
|
|
" previous_tag = \"O\"\n",
|
|
"\n",
|
|
" for result in ner_tokenized:\n",
|
|
" word = result[\"word\"].replace(\"##\", \"\")\n",
|
|
" start, end = result[\"start\"], result[\"start\"] + len(word)\n",
|
|
"\n",
|
|
" if formatted_results and (original_sentence[result[\"start\"] - 1] != \" \" or result[\"word\"].startswith(\"##\")):\n",
|
|
" formatted_results[-1][\"end\"] = end\n",
|
|
" formatted_results[-1][\"word\"] += word\n",
|
|
" else:\n",
|
|
" result[\"word\"] = word\n",
|
|
" result[\"start\"] = get_word_beginning(original_sentence, start)\n",
|
|
" result[\"end\"] = end\n",
|
|
" formatted_results.append(result)\n",
|
|
"\n",
|
|
" for result in formatted_results:\n",
|
|
" start_index = result[\"start\"]\n",
|
|
" tag = result[\"entity\"]\n",
|
|
"\n",
|
|
" if tag != \"O\":\n",
|
|
" if previous_tag != tag:\n",
|
|
" tag = f\"B-{tag.split('-')[-1]}\"\n",
|
|
" else:\n",
|
|
" tag = f\"I-{tag.split('-')[-1]}\"\n",
|
|
" word_start_index_to_tag[start_index] = tag\n",
|
|
" previous_tag = result[\"entity\"]\n",
|
|
"\n",
|
|
" for index in get_word_indices(original_sentence):\n",
|
|
" word_start_index_to_tag.setdefault(index, \"O\")\n",
|
|
"\n",
|
|
" return [word_start_index_to_tag[index] for index in sorted(word_start_index_to_tag.keys())]"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-05T22:18:22.319194Z",
|
|
"end_time": "2024-06-05T22:18:22.343447Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### Przykładowe użycie"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "[{'entity': 'I-ORG',\n 'score': 0.9995635,\n 'index': 1,\n 'word': 'Hu',\n 'start': 0,\n 'end': 2},\n {'entity': 'I-ORG',\n 'score': 0.99159384,\n 'index': 2,\n 'word': '##gging',\n 'start': 2,\n 'end': 7},\n {'entity': 'I-ORG',\n 'score': 0.99826705,\n 'index': 3,\n 'word': 'Face',\n 'start': 8,\n 'end': 12},\n {'entity': 'I-ORG',\n 'score': 0.9994404,\n 'index': 4,\n 'word': 'Inc',\n 'start': 13,\n 'end': 16},\n {'entity': 'I-LOC',\n 'score': 0.99943465,\n 'index': 11,\n 'word': 'New',\n 'start': 40,\n 'end': 43},\n {'entity': 'I-LOC',\n 'score': 0.99932706,\n 'index': 12,\n 'word': 'York',\n 'start': 44,\n 'end': 48},\n {'entity': 'I-LOC',\n 'score': 0.9993864,\n 'index': 13,\n 'word': 'City',\n 'start': 49,\n 'end': 53},\n {'entity': 'I-LOC',\n 'score': 0.9825622,\n 'index': 19,\n 'word': 'D',\n 'start': 79,\n 'end': 80},\n {'entity': 'I-LOC',\n 'score': 0.936983,\n 'index': 20,\n 'word': '##UM',\n 'start': 80,\n 'end': 82},\n {'entity': 'I-LOC',\n 'score': 0.89870995,\n 'index': 21,\n 'word': '##BO',\n 'start': 82,\n 'end': 84},\n {'entity': 'I-LOC',\n 'score': 0.97582406,\n 'index': 29,\n 'word': 'Manhattan',\n 'start': 113,\n 'end': 122},\n {'entity': 'I-LOC',\n 'score': 0.99024945,\n 'index': 30,\n 'word': 'Bridge',\n 'start': 123,\n 'end': 129}]"
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"sequence = \"Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very\" \\\n",
|
|
" \"close to the Manhattan Bridge which is visible from the window.\"\n",
|
|
"model_out = nlp(sequence)\n",
|
|
"model_out"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-09T12:14:36.626686Z",
|
|
"end_time": "2024-06-09T12:14:36.815685Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### Tokenizacja plików"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"outputs": [],
|
|
"source": [
|
|
"def tokenize_file(input_file, output_file):\n",
|
|
" with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
|
|
" original_sentences = f.readlines()\n",
|
|
"\n",
|
|
" processed_data = []\n",
|
|
" for raw_sentence in tqdm(original_sentences, desc=f\"Processing {input_file}\"):\n",
|
|
" model_out = nlp(raw_sentence.strip())\n",
|
|
" word_tokenization = wordpiece_tokenization(model_out, raw_sentence.strip())\n",
|
|
" processed_line = \" \".join(word_tokenization)\n",
|
|
" processed_data.append(processed_line)\n",
|
|
"\n",
|
|
" with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
|
|
" for line in processed_data:\n",
|
|
" f.write(f\"{line}\\n\")"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-05T22:18:22.339446Z",
|
|
"end_time": "2024-06-05T22:18:22.350525Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### Ewaluacja"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Processing dev-0/in.tsv: 100%|██████████| 215/215 [03:28<00:00, 1.03it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"tokenize_file(\"dev-0/in.tsv\", \"dev-0/out.tsv\")"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-05T22:18:22.354491Z",
|
|
"end_time": "2024-06-05T22:21:51.061001Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Processing test-A/in.tsv: 100%|██████████| 230/230 [03:42<00:00, 1.03it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"tokenize_file(\"test-A/in.tsv\", \"test-A/out.tsv\")"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-05T22:21:51.062002Z",
|
|
"end_time": "2024-06-05T22:25:33.462085Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### Poprawienie etykiet"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"outputs": [],
|
|
"source": [
|
|
"def correct_labels(input_file, output_file):\n",
|
|
" df = pd.read_csv(input_file, sep=\"\\t\", names=[\"Text\"])\n",
|
|
"\n",
|
|
" corrected_lines = []\n",
|
|
"\n",
|
|
" for line in df[\"Text\"]:\n",
|
|
" tokens = line.split(\" \")\n",
|
|
" corrected_tokens = []\n",
|
|
" previous_token = \"O\"\n",
|
|
"\n",
|
|
" for token in tokens:\n",
|
|
" if (\n",
|
|
" token == \"I-ORG\"\n",
|
|
" and previous_token != \"B-ORG\"\n",
|
|
" and previous_token != \"I-ORG\"\n",
|
|
" ):\n",
|
|
" corrected_tokens.append(\"B-ORG\")\n",
|
|
" elif (\n",
|
|
" token == \"I-PER\"\n",
|
|
" and previous_token != \"B-PER\"\n",
|
|
" and previous_token != \"I-PER\"\n",
|
|
" ):\n",
|
|
" corrected_tokens.append(\"B-PER\")\n",
|
|
" elif (\n",
|
|
" token == \"I-LOC\"\n",
|
|
" and previous_token != \"B-LOC\"\n",
|
|
" and previous_token != \"I-LOC\"\n",
|
|
" ):\n",
|
|
" corrected_tokens.append(\"B-LOC\")\n",
|
|
" elif (\n",
|
|
" token == \"I-MISC\"\n",
|
|
" and previous_token != \"B-MISC\"\n",
|
|
" and previous_token != \"I-MISC\"\n",
|
|
" ):\n",
|
|
" corrected_tokens.append(\"B-MISC\")\n",
|
|
" else:\n",
|
|
" corrected_tokens.append(token)\n",
|
|
"\n",
|
|
" previous_token = token\n",
|
|
"\n",
|
|
" corrected_line = \" \".join(corrected_tokens)\n",
|
|
" corrected_lines.append(corrected_line)\n",
|
|
"\n",
|
|
" df[\"Text\"] = corrected_lines\n",
|
|
" df.to_csv(output_file, sep=\"\\t\", index=False, header=False)\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-05T22:25:33.468992Z",
|
|
"end_time": "2024-06-05T22:25:33.507038Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"outputs": [],
|
|
"source": [
|
|
"input_file = \"test-A/out.tsv\"\n",
|
|
"output_file = \"test-A/out.tsv\"\n",
|
|
"correct_labels(input_file, output_file)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-05T22:25:33.480964Z",
|
|
"end_time": "2024-06-05T22:25:33.593979Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"outputs": [],
|
|
"source": [
|
|
"input_file = \"dev-0/out.tsv\"\n",
|
|
"output_file = \"dev-0/out.tsv\"\n",
|
|
"correct_labels(input_file, output_file)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-05T22:25:33.589982Z",
|
|
"end_time": "2024-06-05T22:25:33.624147Z"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"### Obliczenie dokładności"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Processing dev-0/in.tsv: 100%|██████████| 215/215 [03:36<00:00, 1.01s/it]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Accuracy: 0.9236\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def calculate_accuracy(input_file, expected_file):\n",
|
|
" with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
|
|
" original_sentences = f.readlines()\n",
|
|
"\n",
|
|
" with open(expected_file, \"r\", encoding=\"utf-8\") as f:\n",
|
|
" expected_tags = f.readlines()\n",
|
|
"\n",
|
|
" total_tags = 0\n",
|
|
" correct_tags = 0\n",
|
|
"\n",
|
|
" for raw_sentence, expected_line in tqdm(zip(original_sentences, expected_tags), desc=f\"Processing {input_file}\", total=len(original_sentences)):\n",
|
|
" model_out = nlp(raw_sentence.strip())\n",
|
|
" word_tokenization = wordpiece_tokenization(model_out, raw_sentence.strip())\n",
|
|
" expected_tags_list = expected_line.strip().split()\n",
|
|
"\n",
|
|
" total_tags += len(expected_tags_list)\n",
|
|
" correct_tags += sum(p == e for p, e in zip(word_tokenization, expected_tags_list))\n",
|
|
"\n",
|
|
" accuracy = correct_tags / total_tags\n",
|
|
" print(f\"Accuracy: {accuracy:.4f}\")\n",
|
|
"\n",
|
|
"calculate_accuracy(\"dev-0/in.tsv\", \"dev-0/expected.tsv\")"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"start_time": "2024-06-05T22:25:33.625148Z",
|
|
"end_time": "2024-06-05T22:29:10.146791Z"
|
|
}
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|