DL_TRANSFORMER/transformer5.ipynb

454 lines
14 KiB
Plaintext
Raw Permalink Normal View History

2024-06-08 15:52:20 +02:00
{
"cells": [
{
"cell_type": "markdown",
"source": [
"### Importy"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-09 12:16:39 +02:00
"execution_count": 4,
2024-06-08 15:52:20 +02:00
"outputs": [],
"source": [
"import bisect\n",
"import re\n",
"\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from transformers import pipeline"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
2024-06-09 12:16:39 +02:00
"start_time": "2024-06-09T12:13:28.590508Z",
"end_time": "2024-06-09T12:13:40.429636Z"
2024-06-08 15:52:20 +02:00
}
}
},
{
"cell_type": "markdown",
"source": [
"### Inicjalizacja modelu NER"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
2024-06-09 12:16:39 +02:00
"execution_count": 5,
2024-06-08 15:52:20 +02:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
"- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"nlp = pipeline(\n",
"    task=\"ner\",\n",
"    model=\"dbmdz/bert-large-cased-finetuned-conll03-english\",\n",
")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
2024-06-09 12:16:39 +02:00
"start_time": "2024-06-09T12:13:40.436629Z",
"end_time": "2024-06-09T12:13:43.520630Z"
2024-06-08 15:52:20 +02:00
}
}
},
{
"cell_type": "markdown",
"source": [
"### Metody do tokenizacji"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"def get_word_indices(string_to_search):\n",
"    \"\"\"Return the start offset of every whitespace-separated word in the string.\n",
"\n",
"    Each offset is the index of the first character of a maximal run of\n",
"    non-whitespace characters, in ascending order. An empty or all-whitespace\n",
"    string yields an empty list instead of raising IndexError.\n",
"    \"\"\"\n",
"    # Raw string: a plain \"...\" literal with a backslash-s escape is invalid.\n",
"    return [match.start() for match in re.finditer(r\"\\S+\", string_to_search)]\n",
"\n",
"def get_word_beginning(string_to_search, letter_index):\n",
"    \"\"\"Walk left from letter_index to the first character of the enclosing word.\n",
"\n",
"    Any whitespace character (as defined by str.isspace, the same set the\n",
"    regex whitespace class uses) ends the walk, so the returned offset agrees\n",
"    with the word starts produced by get_word_indices even for tabs.\n",
"    Returns 0 when no whitespace precedes letter_index.\n",
"    \"\"\"\n",
"    while letter_index > 0 and not string_to_search[letter_index - 1].isspace():\n",
"        letter_index -= 1\n",
"    return letter_index\n",
"\n",
"def wordpiece_tokenization(ner_tokenized, original_sentence):\n",
"    \"\"\"Convert token-level NER pipeline output into one IOB2 tag per word.\n",
"\n",
"    Args:\n",
"        ner_tokenized: list of dicts returned by the `ner` pipeline; only the\n",
"            \"entity\" label and the \"start\" character offset of each\n",
"            sub-token are read. Words with no sub-token in this list are\n",
"            tagged \"O\", so \"O\" tokens may be omitted.\n",
"        original_sentence: the exact string that was passed to the model.\n",
"\n",
"    Returns:\n",
"        A list with one tag (\"O\", \"B-<TYPE>\" or \"I-<TYPE>\") for every\n",
"        whitespace-separated word of original_sentence, in order.\n",
"\n",
"    Each sub-token is assigned to the word whose start offset is the last one\n",
"    at or before the sub-token's start, so a word takes the label of its first\n",
"    listed piece and can never absorb pieces of a different word. A word\n",
"    continues an entity (I-) only if the directly preceding word has the same\n",
"    entity type and the model did not emit an explicit B- label; otherwise it\n",
"    starts one (B-). The input dicts are not modified.\n",
"    \"\"\"\n",
"    word_starts = get_word_indices(original_sentence)\n",
"\n",
"    # Word position -> raw model label of the first listed sub-token in it.\n",
"    word_labels = {}\n",
"    for token in ner_tokenized:\n",
"        position = bisect.bisect_right(word_starts, token[\"start\"]) - 1\n",
"        if position >= 0:\n",
"            word_labels.setdefault(position, token[\"entity\"])\n",
"\n",
"    tags = []\n",
"    previous_type = None  # entity type of the previous word; None after an O word\n",
"    for position in range(len(word_starts)):\n",
"        label = word_labels.get(position, \"O\")\n",
"        if label == \"O\":\n",
"            tags.append(\"O\")\n",
"            previous_type = None\n",
"            continue\n",
"        entity_type = label.split(\"-\")[-1]\n",
"        # The model labels entity-initial words with I- as well (see the\n",
"        # example output), so adjacency decides where a new entity begins.\n",
"        if entity_type == previous_type and not label.startswith(\"B-\"):\n",
"            tags.append(f\"I-{entity_type}\")\n",
"        else:\n",
"            tags.append(f\"B-{entity_type}\")\n",
"        previous_type = entity_type\n",
"    return tags"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-05T22:18:22.319194Z",
"end_time": "2024-06-05T22:18:22.343447Z"
}
}
},
2024-06-09 12:16:39 +02:00
{
"cell_type": "markdown",
"source": [
"### Przykładowe użycie"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "[{'entity': 'I-ORG',\n 'score': 0.9995635,\n 'index': 1,\n 'word': 'Hu',\n 'start': 0,\n 'end': 2},\n {'entity': 'I-ORG',\n 'score': 0.99159384,\n 'index': 2,\n 'word': '##gging',\n 'start': 2,\n 'end': 7},\n {'entity': 'I-ORG',\n 'score': 0.99826705,\n 'index': 3,\n 'word': 'Face',\n 'start': 8,\n 'end': 12},\n {'entity': 'I-ORG',\n 'score': 0.9994404,\n 'index': 4,\n 'word': 'Inc',\n 'start': 13,\n 'end': 16},\n {'entity': 'I-LOC',\n 'score': 0.99943465,\n 'index': 11,\n 'word': 'New',\n 'start': 40,\n 'end': 43},\n {'entity': 'I-LOC',\n 'score': 0.99932706,\n 'index': 12,\n 'word': 'York',\n 'start': 44,\n 'end': 48},\n {'entity': 'I-LOC',\n 'score': 0.9993864,\n 'index': 13,\n 'word': 'City',\n 'start': 49,\n 'end': 53},\n {'entity': 'I-LOC',\n 'score': 0.9825622,\n 'index': 19,\n 'word': 'D',\n 'start': 79,\n 'end': 80},\n {'entity': 'I-LOC',\n 'score': 0.936983,\n 'index': 20,\n 'word': '##UM',\n 'start': 80,\n 'end': 82},\n {'entity': 'I-LOC',\n 'score': 0.89870995,\n 'index': 21,\n 'word': '##BO',\n 'start': 82,\n 'end': 84},\n {'entity': 'I-LOC',\n 'score': 0.97582406,\n 'index': 29,\n 'word': 'Manhattan',\n 'start': 113,\n 'end': 122},\n {'entity': 'I-LOC',\n 'score': 0.99024945,\n 'index': 30,\n 'word': 'Bridge',\n 'start': 123,\n 'end': 129}]"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Adjacent string literals are concatenated verbatim, so the first piece needs\n",
"# a trailing space; without it the sentence reads \"...therefore veryclose...\".\n",
"sequence = (\n",
"    \"Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very \"\n",
"    \"close to the Manhattan Bridge which is visible from the window.\"\n",
")\n",
"model_out = nlp(sequence)\n",
"model_out"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-09T12:14:36.626686Z",
"end_time": "2024-06-09T12:14:36.815685Z"
}
}
},
2024-06-08 15:52:20 +02:00
{
"cell_type": "markdown",
"source": [
"### Tokenizacja plików"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"def tokenize_file(input_file, output_file):\n",
"    \"\"\"Tag every line of input_file with the NER model, one tag line per sentence.\n",
"\n",
"    Each input line is stripped, passed through `nlp`, and converted to word-level\n",
"    tags by `wordpiece_tokenization`; the tags are joined with single spaces and\n",
"    written to output_file, one newline-terminated line per input line. The\n",
"    whole input is processed before output_file is opened for writing.\n",
"    \"\"\"\n",
"    with open(input_file, \"r\", encoding=\"utf-8\") as source:\n",
"        sentences = source.readlines()\n",
"\n",
"    tag_lines = []\n",
"    for sentence in tqdm(sentences, desc=f\"Processing {input_file}\"):\n",
"        text = sentence.strip()\n",
"        tags = wordpiece_tokenization(nlp(text), text)\n",
"        tag_lines.append(\" \".join(tags))\n",
"\n",
"    with open(output_file, \"w\", encoding=\"utf-8\") as target:\n",
"        target.writelines(f\"{tag_line}\\n\" for tag_line in tag_lines)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-05T22:18:22.339446Z",
"end_time": "2024-06-05T22:18:22.350525Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Ewaluacja"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing dev-0/in.tsv: 100%|██████████| 215/215 [03:28<00:00, 1.03it/s]\n"
]
}
],
"source": [
"tokenize_file(input_file=\"dev-0/in.tsv\", output_file=\"dev-0/out.tsv\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-05T22:18:22.354491Z",
"end_time": "2024-06-05T22:21:51.061001Z"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing test-A/in.tsv: 100%|██████████| 230/230 [03:42<00:00, 1.03it/s]\n"
]
}
],
"source": [
"tokenize_file(input_file=\"test-A/in.tsv\", output_file=\"test-A/out.tsv\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-05T22:21:51.062002Z",
"end_time": "2024-06-05T22:25:33.462085Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Poprawienie etykiet"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"def correct_labels(input_file, output_file):\n",
"    \"\"\"Rewrite a file of space-separated tags so that every entity starts with B-.\n",
"\n",
"    A tag `I-<TYPE>` becomes `B-<TYPE>` unless the preceding tag on the same\n",
"    line is `B-<TYPE>` or `I-<TYPE>`. This covers every entity type, not only\n",
"    ORG/PER/LOC/MISC; all other tags are copied unchanged.\n",
"\n",
"    Plain text I/O is used instead of pandas: read_csv skips blank lines by\n",
"    default, which would misalign output rows with input sentences, and it\n",
"    raises on an empty file. input_file may equal output_file because it is\n",
"    read completely before output_file is opened.\n",
"    \"\"\"\n",
"    with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
"        lines = f.read().splitlines()\n",
"\n",
"    corrected_lines = []\n",
"    for line in lines:\n",
"        corrected_tokens = []\n",
"        previous_token = \"O\"\n",
"        for token in line.split(\" \"):\n",
"            entity_type = token[2:]\n",
"            continues_entity = previous_token in (f\"B-{entity_type}\", f\"I-{entity_type}\")\n",
"            if token.startswith(\"I-\") and not continues_entity:\n",
"                corrected_tokens.append(f\"B-{entity_type}\")\n",
"            else:\n",
"                corrected_tokens.append(token)\n",
"            # Compare against the uncorrected tag; B-/I- of one type are equivalent here.\n",
"            previous_token = token\n",
"        corrected_lines.append(\" \".join(corrected_tokens))\n",
"\n",
"    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
"        f.writelines(f\"{corrected}\\n\" for corrected in corrected_lines)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-05T22:25:33.468992Z",
"end_time": "2024-06-05T22:25:33.507038Z"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"input_file = output_file = \"test-A/out.tsv\"  # predictions are corrected in place\n",
"correct_labels(input_file, output_file)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-05T22:25:33.480964Z",
"end_time": "2024-06-05T22:25:33.593979Z"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [],
"source": [
"input_file = output_file = \"dev-0/out.tsv\"  # predictions are corrected in place\n",
"correct_labels(input_file, output_file)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-05T22:25:33.589982Z",
"end_time": "2024-06-05T22:25:33.624147Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Obliczenie dokładności"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing dev-0/in.tsv: 100%|██████████| 215/215 [03:36<00:00, 1.01s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9236\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"def calculate_accuracy(input_file, expected_file, predictions_file=None):\n",
"    \"\"\"Print and return word-level tag accuracy against expected_file.\n",
"\n",
"    Args:\n",
"        input_file: sentences, one per line; tagged on the fly with `nlp` and\n",
"            `wordpiece_tokenization` when predictions_file is not given.\n",
"        expected_file: gold tags, one space-separated line per sentence.\n",
"        predictions_file: optional file of already-predicted tag lines (e.g.\n",
"            the corrected dev-0/out.tsv) to score instead of re-running the\n",
"            model. A missing prediction line counts all its gold tags as wrong.\n",
"\n",
"    Returns:\n",
"        The fraction of gold tags matched by the prediction at the same word\n",
"        position, or 0.0 when expected_file contains no tags.\n",
"    \"\"\"\n",
"    with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
"        original_sentences = f.readlines()\n",
"\n",
"    with open(expected_file, \"r\", encoding=\"utf-8\") as f:\n",
"        expected_tags = f.readlines()\n",
"\n",
"    predicted_lines = None\n",
"    if predictions_file is not None:\n",
"        with open(predictions_file, \"r\", encoding=\"utf-8\") as f:\n",
"            predicted_lines = f.readlines()\n",
"\n",
"    total_tags = 0\n",
"    correct_tags = 0\n",
"\n",
"    rows = zip(original_sentences, expected_tags)\n",
"    progress = tqdm(rows, desc=f\"Processing {input_file}\", total=len(original_sentences))\n",
"    for line_number, (raw_sentence, expected_line) in enumerate(progress):\n",
"        if predicted_lines is None:\n",
"            sentence = raw_sentence.strip()\n",
"            word_tokenization = wordpiece_tokenization(nlp(sentence), sentence)\n",
"        elif line_number < len(predicted_lines):\n",
"            word_tokenization = predicted_lines[line_number].split()\n",
"        else:\n",
"            word_tokenization = []\n",
"        expected_tags_list = expected_line.strip().split()\n",
"\n",
"        total_tags += len(expected_tags_list)\n",
"        correct_tags += sum(p == e for p, e in zip(word_tokenization, expected_tags_list))\n",
"\n",
"    # An empty expected file would otherwise raise ZeroDivisionError.\n",
"    accuracy = correct_tags / total_tags if total_tags else 0.0\n",
"    print(f\"Accuracy: {accuracy:.4f}\")\n",
"    return accuracy\n",
"\n",
"# Scores tags re-predicted from dev-0/in.tsv, without correct_labels applied.\n",
"# To score the corrected submission file instead, also pass\n",
"# predictions_file=\"dev-0/out.tsv\".\n",
"dev_accuracy = calculate_accuracy(\"dev-0/in.tsv\", \"dev-0/expected.tsv\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-06-05T22:25:33.625148Z",
"end_time": "2024-06-05T22:29:10.146791Z"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}