{
"cells": [
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
"- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sentence: CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . </S> LONDON 1996-08-30 </S> West Indian all-rounder Phil\n",
"Tokens: ['L', 'LONDON', 'West', 'Indian', 'Phil']\n",
"Labels: ['I-PER', 'I-LOC', 'I-MISC', 'I-MISC', 'I-PER']\n"
]
}
],
"source": [
"from transformers import pipeline\n",
"import pandas as pd\n",
"import re\n",
"from transformers import pipeline\n",
"\n",
"ner_pipeline = pipeline(\"ner\", model=\"dbmdz/bert-large-cased-finetuned-conll03-english\")\n",
"\n",
"input_text = \"CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . </S> LONDON 1996-08-30 </S> West Indian all-rounder Phil\"\n",
"\n",
"def predict_and_combine(text):\n",
" ner_results = ner_pipeline(text)\n",
" combined_tokens = []\n",
" combined_labels = []\n",
" current_word = \"\"\n",
" current_label = None\n",
"\n",
" for result in ner_results:\n",
" token = result['word']\n",
" label = result['entity']\n",
" if token.startswith(\"##\"):\n",
" current_word += token[2:]\n",
" else:\n",
" if current_word:\n",
" combined_tokens.append(current_word)\n",
" combined_labels.append(current_label)\n",
" current_word = token\n",
" current_label = label\n",
"\n",
" if current_word:\n",
" combined_tokens.append(current_word)\n",
" combined_labels.append(current_label)\n",
"\n",
" return combined_tokens, combined_labels\n",
"\n",
"tokens, labels = predict_and_combine(input_text)\n",
"\n",
"print(f\"Sentence: {input_text}\")\n",
"print(\"Tokens:\", tokens)\n",
"print(\"Labels:\", labels)\n"
]
},
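{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pipeline emits one entry per WordPiece sub-token: a dict with `word`, `entity`, `score`, `start`, and `end` fields, where continuation pieces carry a `##` prefix. The next cell is an illustrative sketch (not part of the original run) that prints these raw entries for a short snippet, to make the merging logic below easier to follow."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: inspect the raw, unmerged pipeline output.\n",
"for entry in ner_pipeline(\"West Indian all-rounder Phil\"):\n",
"    print(entry[\"word\"], entry[\"entity\"], entry[\"start\"], entry[\"end\"])"
]
},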
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def find_word_starts(text):\n",
" indices = [match.start() + 1 for match in re.finditer(r\"\\s\\S\", text)]\n",
" if not text[0].isspace():\n",
" indices.insert(0, 0)\n",
" return sorted(indices)\n",
"\n",
"def find_word_start(text, index):\n",
" while index > 0 and text[index - 1] != \" \":\n",
" index -= 1\n",
" return index\n",
"\n",
"def merge_wordpieces(ner_tokens, original_sentence):\n",
" results = []\n",
" for token in ner_tokens:\n",
" if token['word'].startswith(\"##\") and results and token['start'] == results[-1]['end']:\n",
" results[-1]['end'] = token['end']\n",
" results[-1]['word'] += token['word'][2:]\n",
" else:\n",
" if results and not original_sentence[token['start'] - 1].isspace():\n",
" results[-1]['end'] = token['end']\n",
" results[-1]['word'] += token['word']\n",
" else:\n",
" token['start'] = find_word_start(original_sentence, token['start'])\n",
" results.append(token)\n",
" \n",
" word_start_to_tag = {result['start']: result['entity'] for result in results}\n",
" for index in find_word_starts(original_sentence):\n",
" if index not in word_start_to_tag:\n",
" word_start_to_tag[index] = \"O\"\n",
" \n",
" return [word_start_to_tag[index] for index in sorted(word_start_to_tag.keys())]\n",
"\n",
"def predict_and_merge(text):\n",
" return ner_pipeline(text)"
]
},
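{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (illustrative, not from the original run): apply `merge_wordpieces` to the raw pipeline output for the sample sentence and pair each word with its tag; words without a predicted entity come back as `O`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative usage: one tag per whitespace-separated word.\n",
"tags = merge_wordpieces(ner_pipeline(input_text), input_text)\n",
"print(list(zip(input_text.split(), tags)))"
]
},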
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"dev_data = pd.read_csv(\"dev-0/in.tsv\", sep=\"\\t\", names=[\"Text\"])\n",
"dev_labels = pd.read_csv(\"dev-0/expected.tsv\", sep=\"\\t\", names=[\"Label\"])\n",
"\n",
"dev_data[\"NER_Results\"] = dev_data[\"Text\"].apply(predict_and_merge)\n",
"processed_data = []\n",
"\n",
"for i, (model_out, raw_sentence) in enumerate(zip(dev_data[\"NER_Results\"], dev_data[\"Text\"])):\n",
" merged_tokens = merge_wordpieces(model_out, raw_sentence)\n",
" processed_line = \" \".join(merged_tokens)\n",
" processed_data.append(processed_line)\n",
" \n",
" if len(merged_tokens) != len(raw_sentence.split()):\n",
" raise AssertionError\n",
"\n",
"with open(\"dev-0/out_unprocessed.tsv\", \"w\", encoding=\"utf-8\") as f:\n",
" for line in processed_data:\n",
" f.write(f\"{line}\\n\")"
]
},
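{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small verification step (illustrative, not from the original notebook): print the first lines of the file just written to confirm each line is a space-separated tag sequence. Note that the evaluation cell below reads `dev-0/out.tsv`, so some post-processing step outside this notebook presumably turns `out_unprocessed.tsv` into `out.tsv`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the written file's format.\n",
"with open(\"dev-0/out_unprocessed.tsv\", encoding=\"utf-8\") as f:\n",
"    for _ in range(3):\n",
"        print(f.readline().rstrip())"
]
},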
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8418625244437885\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"with open('dev-0/out.tsv', 'r') as file:\n",
" predicted_labels = [line.strip().split()[1:] for line in file]\n",
"\n",
"with open('dev-0/expected.tsv', 'r') as file:\n",
" true_labels = [line.strip().split()[1:] for line in file]\n",
"\n",
"predicted_labels = [label for sublist in predicted_labels for label in sublist]\n",
"true_labels = [label for sublist in true_labels for label in sublist]\n",
"\n",
"accuracy = accuracy_score(true_labels, predicted_labels)\n",
"print(\"Accuracy:\", accuracy)"
]
},
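{
"cell_type": "markdown",
"metadata": {},
"source": [
"Token-level accuracy is dominated by the majority `O` class, so it overstates NER quality; entity-level F1 is the standard CoNLL-2003 metric. Below is a hedged sketch using the `seqeval` library, which is an assumption: it is not used in the original notebook and must be installed separately (`pip install seqeval`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical extra evaluation: entity-level F1 via seqeval (assumed installed).\n",
"from seqeval.metrics import f1_score\n",
"\n",
"with open('dev-0/out.tsv', 'r') as file:\n",
"    pred_seqs = [line.strip().split() for line in file]\n",
"with open('dev-0/expected.tsv', 'r') as file:\n",
"    true_seqs = [line.strip().split() for line in file]\n",
"\n",
"print(\"Entity-level F1:\", f1_score(true_seqs, pred_seqs))"
]
},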
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"dev_data = pd.read_csv(\"test-A/in.tsv\", sep=\"\\t\", names=[\"Text\"])\n",
"\n",
"dev_data[\"NER_Results\"] = dev_data[\"Text\"].apply(predict_and_merge)\n",
"processed_data = []\n",
"\n",
"for i, (model_out, raw_sentence) in enumerate(zip(dev_data[\"NER_Results\"], dev_data[\"Text\"])):\n",
" merged_tokens = merge_wordpieces(model_out, raw_sentence)\n",
" processed_line = \" \".join(merged_tokens)\n",
" processed_data.append(processed_line)\n",
" \n",
" if len(merged_tokens) != len(raw_sentence.split()):\n",
" raise AssertionError\n",
"\n",
"with open(\"test-A/out_unprocessed.tsv\", \"w\", encoding=\"utf-8\") as f:\n",
" for line in processed_data:\n",
" f.write(f\"{line}\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}