Transformer/transformer.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english and revision f2482bf (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english).\n",
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
"c:\\Users\\walcz\\Desktop\\studia\\uczenie\\RNN\\en-ner-conll-2003\\myenv\\Lib\\site-packages\\huggingface_hub\\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n",
"Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
"- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"c:\\Users\\walcz\\Desktop\\studia\\uczenie\\RNN\\en-ner-conll-2003\\myenv\\Lib\\site-packages\\transformers\\pipelines\\token_classification.py:168: UserWarning: `grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy=\"AggregationStrategy.SIMPLE\"` instead.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy for dev0: 94.88%\n"
]
}
],
"source": [
"import pandas as pd\n",
"from transformers import pipeline\n",
"\n",
"# Wczytanie danych\n",
"train_data = pd.read_csv('en-ner-conll-2003/train/train.tsv.xz', delimiter='\\t', header=None, compression='xz')\n",
"in_data_dev0 = pd.read_csv('en-ner-conll-2003/dev-0/in.tsv', delimiter='\\t', header=None)\n",
"expected_data = pd.read_csv('en-ner-conll-2003/dev-0/expected.tsv', delimiter='\\t', header=None)\n",
"in_data_testA = pd.read_csv('en-ner-conll-2003/test-A/in.tsv', delimiter='\\t', header=None)\n",
"\n",
"# Inicjalizacja pipeline NER\n",
"ner_pipeline = pipeline(\"ner\", grouped_entities=True)\n",
"\n",
"# Przetwarzanie danych z pliku in.tsv\n",
"sentences_dev0 = in_data_dev0[0].tolist()\n",
"senteces_testA = in_data_testA[0].tolist()\n",
"\n",
"ner_results_dev = ner_pipeline(sentences_dev0)\n",
"ner_results_testA = ner_pipeline(senteces_testA)\n",
"\n",
"# Funkcja do mapowania wyników NER na format B-XXX, I-XXX, O\n",
"def map_ner_results(ner_results, sentences):\n",
" ner_labels = []\n",
" for sentence, entities in zip(sentences, ner_results):\n",
" words = sentence.split()\n",
" labels = ['O'] * len(words)\n",
" for entity in entities:\n",
" start_idx = entity['start']\n",
" end_idx = entity['end']\n",
" entity_label = entity['entity_group']\n",
" entity_words = sentence[start_idx:end_idx].split()\n",
" start_word_idx = len(sentence[:start_idx].split())\n",
" end_word_idx = start_word_idx + len(entity_words)\n",
" if start_word_idx < len(labels) and end_word_idx <= len(labels):\n",
" labels[start_word_idx] = f'B-{entity_label}'\n",
" for i in range(start_word_idx + 1, end_word_idx):\n",
" labels[i] = f'I-{entity_label}'\n",
" ner_labels.append(labels)\n",
" return ner_labels\n",
"\n",
"# Mapowanie wyników NER na odpowiednie etykiety\n",
"predicted_labels_dev0 = map_ner_results(ner_results_dev, sentences_dev0)\n",
"predicted_labels_testA = map_ner_results(ner_results_testA, senteces_testA)\n",
"\n",
"# Konwersja listy etykiet na format ciągu znaków\n",
"predicted_strings_dev0 = [' '.join(labels) for labels in predicted_labels_dev0]\n",
"predicted_strings_testA = [' '.join(labels) for labels in predicted_labels_testA]\n",
"\n",
"expected_strings = expected_data[0].tolist()\n",
"\n",
"# Zapisanie wyników do pliku out.tsv\n",
"with open('en-ner-conll-2003/dev-0/out.tsv', 'w') as f:\n",
" for line in predicted_strings_dev0:\n",
" f.write(line + '\\n')\n",
"\n",
"with open('en-ner-conll-2003/test-A/out.tsv', 'w') as f:\n",
" for line in predicted_strings_testA:\n",
" f.write(line + '\\n')\n",
"\n",
"\n",
"# Sprawdzenie zgodności wyników\n",
"correct = 0\n",
"total = 0\n",
"for pred, exp in zip(predicted_strings_dev0, expected_strings):\n",
" pred_labels = pred.split()\n",
" exp_labels = exp.split()\n",
" for p, e in zip(pred_labels, exp_labels):\n",
" if p == e:\n",
" correct += 1\n",
" total += 1\n",
"\n",
"accuracy = correct / total\n",
"print(f\"Accuracy for dev0: {accuracy:.2%}\")\n"
]
}
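,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Token-level accuracy is dominated by the majority `O` class, so it overstates NER quality. Below is a minimal sketch of span-level evaluation, assuming the `seqeval` package is installed (it is not used anywhere above); it reuses `predicted_labels_dev0` and `expected_strings` from the cell above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumes `pip install seqeval`): entity-level precision/recall/F1,\n",
"# the standard CoNLL-2003 metric, scored over whole BIO spans rather than\n",
"# individual tokens.\n",
"from seqeval.metrics import classification_report, f1_score\n",
"\n",
"expected_labels_dev0 = [line.split() for line in expected_strings]\n",
"\n",
"# seqeval expects parallel lists of label sequences; truncate each pair to a\n",
"# common length in case a line was split into a different number of words.\n",
"y_true, y_pred = [], []\n",
"for exp, pred in zip(expected_labels_dev0, predicted_labels_dev0):\n",
"    n = min(len(exp), len(pred))\n",
"    y_true.append(exp[:n])\n",
"    y_pred.append(pred[:n])\n",
"\n",
"print(f\"Span-level F1 for dev0: {f1_score(y_true, y_pred):.2%}\")\n",
"print(classification_report(y_true, y_pred))\n"
]
}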
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}