Upload files to "/"

This commit is contained in:
s464903 2024-06-09 13:43:22 +02:00
parent ce0d4d2fba
commit 5d57b4f6da

223
Transformer.ipynb Normal file
View File

@ -0,0 +1,223 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "5d0842c8-c292-41ce-a27b-73986bf43e1c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
"C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
]
}
],
"source": [
"from collections import Counter\n",
"import torch\n",
"from datasets import load_dataset\n",
"from torchtext.vocab import vocab\n",
"from tqdm.notebook import tqdm\n",
"import pandas as pd\n",
"from nltk.tokenize import word_tokenize\n",
"import string\n",
"\n",
"from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4ccf08e2-fec1-4d68-a1fd-d33af6bd54bc",
"metadata": {},
"outputs": [],
"source": [
"dataset = pd.read_csv('train.tsv', sep='\\t', header=None)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c027b2f3-48eb-442a-a212-c5fa9d1cfba5",
"metadata": {},
"outputs": [],
"source": [
"X_test = pd.read_csv('../dev-0/in.tsv', sep='\\t', header=None)\n",
"Y_test = pd.read_csv('../dev-0/expected.tsv', sep='\\t', header=None)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9fc7a069-52ce-4343-a8a2-dd77619f5d25",
"metadata": {},
"outputs": [],
"source": [
"X_train = dataset[dataset.columns[1]]\n",
"Y_train = dataset[dataset.columns[0]]\n",
"\n",
"X_test = X_test[X_test.columns[0]]\n",
"Y_test = Y_test[Y_test.columns[0]]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7cd49893-c730-4089-9b21-8683c11fd6cc",
"metadata": {},
"outputs": [],
"source": [
"X_train = [text.split() for text in X_train]\n",
"X_test = [text.split() for text in X_test]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bd029e57-f5c3-4b2b-bcf0-ea25d71ee952",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
"- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9f4a002be4b64bfdba0a227b9a6815f2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"215\n"
]
}
],
"source": [
"model = AutoModelForTokenClassification.from_pretrained(\"dbmdz/bert-large-cased-finetuned-conll03-english\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"google-bert/bert-base-cased\")\n",
"\n",
"recognizer = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
"full_tags = []\n",
"for idx, el in tqdm(enumerate(X_test)):\n",
" out_tags = []\n",
" tags_corrected = []\n",
"\n",
" temp = []\n",
" tags_joined = []\n",
" tag = 'NULL'\n",
" t = \"\"\n",
" for r in recognizer(\" \".join(el)):\n",
" \n",
" if len(t) == 0:\n",
" t = r['word']\n",
" tag = r['entity']\n",
" continue\n",
" if \"#\" in r['word']:\n",
" t = t + str(r['word']).replace(\"#\",\"\")\n",
" continue\n",
" if \"#\" not in r['word'] and len(t) != 0:\n",
" temp.append(t)\n",
" tags_joined.append(tag)\n",
" t = r['word']\n",
" tag = r['entity']\n",
"\n",
" for tag in Y_test[idx].split():\n",
" if tag == \"O\":\n",
" out_tags.append(\"O\")\n",
" if tag != \"O\" and len(tags_joined) > 0:\n",
" out_tags.append(tags_joined[0])\n",
" tags_joined = tags_joined[1:]\n",
" continue\n",
" if tag != \"O\" and len(tags_joined) == 0:\n",
" out_tags.append(\"O\")\n",
"\n",
" #print(len(Y_test[idx].split()), len(out_tags))\n",
"\n",
" out_tags = \" \".join(out_tags).replace(\"I-\",\"B-\").split()\n",
" \n",
" last_tag = out_tags[0]\n",
" tags_corrected.append(last_tag)\n",
" \n",
" for tag in out_tags[1:]:\n",
"\n",
" if tag == last_tag:\n",
" tags_corrected.append(tag.replace(\"B-\",\"I-\"))\n",
" last_tag = tag\n",
" else:\n",
" last_tag = tag\n",
" tags_corrected.append(tag)\n",
" \n",
" #print(len(Y_test[idx].split()), len(tags_corrected))\n",
"\n",
" full_tags.append(tags_corrected)\n",
"\n",
"print(len(full_tags))\n",
"\n",
"# for idx, el in tqdm(enumerate(full_tags)):\n",
"# if len(el) != len(len(Y_test[idx].split())):\n",
"# print(\"Somethings wrong sir\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0feb1763-b28b-4cee-a996-f493c3f3d037",
"metadata": {},
"outputs": [],
"source": [
"with open(\"out.tsv\", 'w') as file:\n",
" for el in full_tags:\n",
" \n",
" file.write(\" \".join(el)) \n",
" file.write(f\"\\n\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}