Upload files to "/"
This commit is contained in:
parent
ce0d4d2fba
commit
5d57b4f6da
223
Transformer.ipynb
Normal file
@@ -0,0 +1,223 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5d0842c8-c292-41ce-a27b-73986bf43e1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
      "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
      "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
      " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
      "C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
      "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
      "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
      " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from torchtext.vocab import vocab\n",
    "from tqdm.notebook import tqdm\n",
    "import pandas as pd\n",
    "from nltk.tokenize import word_tokenize\n",
    "import string\n",
    "\n",
    "from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer"
   ]
  },
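  {
   "cell_type": "markdown",
   "id": "torchtext-silence-note",
   "metadata": {},
   "source": [
    "*Optional sketch, not part of the original run:* the deprecation warning above names its own silencer; the call below is quoted from the torchtext message itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "torchtext-silence",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: silence the torchtext deprecation notice with the exact call\n",
    "# the warning text above recommends (assumes torchtext 0.18, per that warning).\n",
    "import torchtext\n",
    "torchtext.disable_torchtext_deprecation_warning()"
   ]
  },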
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4ccf08e2-fec1-4d68-a1fd-d33af6bd54bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = pd.read_csv('train.tsv', sep='\\t', header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c027b2f3-48eb-442a-a212-c5fa9d1cfba5",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = pd.read_csv('../dev-0/in.tsv', sep='\\t', header=None)\n",
    "Y_test = pd.read_csv('../dev-0/expected.tsv', sep='\\t', header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9fc7a069-52ce-4343-a8a2-dd77619f5d25",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = dataset[dataset.columns[1]]\n",
    "Y_train = dataset[dataset.columns[0]]\n",
    "\n",
    "X_test = X_test[X_test.columns[0]]\n",
    "Y_test = Y_test[Y_test.columns[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7cd49893-c730-4089-9b21-8683c11fd6cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = [text.split() for text in X_train]\n",
    "X_test = [text.split() for text in X_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bd029e57-f5c3-4b2b-bcf0-ea25d71ee952",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f4a002be4b64bfdba0a227b9a6815f2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "215\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForTokenClassification.from_pretrained(\"dbmdz/bert-large-cased-finetuned-conll03-english\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google-bert/bert-base-cased\")\n",
    "\n",
    "recognizer = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
    "full_tags = []\n",
    "for idx, el in tqdm(enumerate(X_test)):\n",
    "    out_tags = []\n",
    "    tags_corrected = []\n",
    "\n",
    "    temp = []\n",
    "    tags_joined = []\n",
    "    tag = 'NULL'\n",
    "    t = \"\"\n",
    "    # Merge WordPiece pieces (\"##...\") back into whole words, keeping the\n",
    "    # entity tag of the first piece of each word.\n",
    "    for r in recognizer(\" \".join(el)):\n",
    "        if len(t) == 0:\n",
    "            t = r['word']\n",
    "            tag = r['entity']\n",
    "            continue\n",
    "        if \"#\" in r['word']:\n",
    "            t = t + str(r['word']).replace(\"#\", \"\")\n",
    "            continue\n",
    "        if \"#\" not in r['word'] and len(t) != 0:\n",
    "            temp.append(t)\n",
    "            tags_joined.append(tag)\n",
    "            t = r['word']\n",
    "            tag = r['entity']\n",
    "    if len(t) != 0:\n",
    "        # Flush the last accumulated word; without this the final entity of\n",
    "        # each sentence was silently dropped.\n",
    "        temp.append(t)\n",
    "        tags_joined.append(tag)\n",
    "\n",
    "    # Align predictions with the gold token positions: each non-\"O\" slot\n",
    "    # consumes the next merged prediction, everything else stays \"O\".\n",
    "    for tag in Y_test[idx].split():\n",
    "        if tag == \"O\":\n",
    "            out_tags.append(\"O\")\n",
    "        if tag != \"O\" and len(tags_joined) > 0:\n",
    "            out_tags.append(tags_joined[0])\n",
    "            tags_joined = tags_joined[1:]\n",
    "            continue\n",
    "        if tag != \"O\" and len(tags_joined) == 0:\n",
    "            out_tags.append(\"O\")\n",
    "\n",
    "    # Rewrite into valid BIO: every entity starts with \"B-\"; a tag repeated\n",
    "    # on consecutive tokens continues the previous span as \"I-\".\n",
    "    out_tags = \" \".join(out_tags).replace(\"I-\", \"B-\").split()\n",
    "\n",
    "    last_tag = out_tags[0]\n",
    "    tags_corrected.append(last_tag)\n",
    "\n",
    "    for tag in out_tags[1:]:\n",
    "        if tag == last_tag:\n",
    "            tags_corrected.append(tag.replace(\"B-\", \"I-\"))\n",
    "        else:\n",
    "            tags_corrected.append(tag)\n",
    "        last_tag = tag\n",
    "\n",
    "    full_tags.append(tags_corrected)\n",
    "\n",
    "print(len(full_tags))\n",
    "\n",
    "# Sanity check: each prediction row must match its gold token count.\n",
    "# for idx, el in tqdm(enumerate(full_tags)):\n",
    "#     if len(el) != len(Y_test[idx].split()):\n",
    "#         print(\"Something's wrong, sir\")"
   ]
  },
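  {
   "cell_type": "markdown",
   "id": "ner-aggregation-note",
   "metadata": {},
   "source": [
    "*Hedged alternative, not part of the original run:* the manual `##` merging above can be delegated to the pipeline itself via `aggregation_strategy=\"simple\"`, which groups sub-word pieces into whole entities; result dicts then carry `entity_group` instead of `entity`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ner-aggregation-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: let the Hugging Face pipeline merge sub-word pieces itself\n",
    "# instead of the manual loop above. Reuses `model` and `tokenizer` from the\n",
    "# previous cell; each result carries \"entity_group\" and the merged \"word\".\n",
    "recognizer_grouped = pipeline(\n",
    "    \"ner\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    aggregation_strategy=\"simple\",\n",
    ")\n",
    "# recognizer_grouped(\" \".join(X_test[0]))"
   ]
  },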
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0feb1763-b28b-4cee-a996-f493c3f3d037",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one space-separated tag sequence per input line.\n",
    "with open(\"out.tsv\", 'w') as file:\n",
    "    for el in full_tags:\n",
    "        file.write(\" \".join(el))\n",
    "        file.write(\"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}