{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "5d0842c8-c292-41ce-a27b-73986bf43e1c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n", "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n", "C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n", "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n" ] } ], "source": [ "from collections import Counter\n", "import torch\n", "from datasets import load_dataset\n", "from torchtext.vocab import vocab\n", "from tqdm.notebook import tqdm\n", "import pandas as pd\n", "from nltk.tokenize import word_tokenize\n", "import string\n", "\n", "from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer" ] }, { "cell_type": "code", "execution_count": 2, "id": "4ccf08e2-fec1-4d68-a1fd-d33af6bd54bc", "metadata": {}, "outputs": [], "source": [ "dataset = pd.read_csv('train.tsv', sep='\\t', header=None)" ] }, { "cell_type": "code", "execution_count": 3, "id": "c027b2f3-48eb-442a-a212-c5fa9d1cfba5", "metadata": {}, "outputs": [], "source": [ "X_test = pd.read_csv('../dev-0/in.tsv', sep='\\t', header=None)\n", "Y_test = pd.read_csv('../dev-0/expected.tsv', sep='\\t', header=None)" ] }, { "cell_type": "code", "execution_count": 4, "id": "9fc7a069-52ce-4343-a8a2-dd77619f5d25", "metadata": {}, "outputs": [], "source": [ "X_train = dataset[dataset.columns[1]]\n", "Y_train = dataset[dataset.columns[0]]\n", "\n", "X_test = X_test[X_test.columns[0]]\n", "Y_test = Y_test[Y_test.columns[0]]" ] }, { "cell_type": "code", "execution_count": 5, "id": "7cd49893-c730-4089-9b21-8683c11fd6cc", "metadata": {}, "outputs": [], "source": [ "X_train = [text.split() for text in X_train]\n", "X_test = [text.split() for text in X_test]" ] }, { "cell_type": "code", "execution_count": 6, "id": "bd029e57-f5c3-4b2b-bcf0-ea25d71ee952", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n", "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f4a002be4b64bfdba0a227b9a6815f2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "215\n"
     ]
    }
   ],
   "source": [
    "# Load a BERT model fine-tuned for NER on CoNLL-2003, together with the tokenizer\n",
    "# from the same checkpoint, so the sub-word pieces match the model's vocabulary.\n",
    "model = AutoModelForTokenClassification.from_pretrained(\"dbmdz/bert-large-cased-finetuned-conll03-english\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"dbmdz/bert-large-cased-finetuned-conll03-english\")\n",
    "\n",
    "recognizer = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
    "\n",
    "full_tags = []\n",
    "for idx, el in tqdm(enumerate(X_test), total=len(X_test)):\n",
    "    out_tags = []\n",
    "    tags_corrected = []\n",
    "\n",
    "    # Merge WordPiece sub-tokens back into whole words, keeping one entity tag per word.\n",
    "    temp = []\n",
    "    tags_joined = []\n",
    "    tag = 'NULL'\n",
    "    t = \"\"\n",
    "    for r in recognizer(\" \".join(el)):\n",
    "        if len(t) == 0:\n",
    "            # First recognised piece: start a new word.\n",
    "            t = r['word']\n",
    "            tag = r['entity']\n",
    "            continue\n",
    "        if r['word'].startswith(\"##\"):\n",
    "            # Continuation piece: glue it onto the current word.\n",
    "            t = t + r['word'].replace(\"#\", \"\")\n",
    "            continue\n",
    "        # A new word starts: flush the previous one.\n",
    "        temp.append(t)\n",
    "        tags_joined.append(tag)\n",
    "        t = r['word']\n",
    "        tag = r['entity']\n",
    "    if len(t) != 0:\n",
    "        # Flush the last word as well, otherwise its tag would be lost.\n",
    "        temp.append(t)\n",
    "        tags_joined.append(tag)\n",
    "\n",
    "    # Align predictions with the reference token positions: keep \"O\" where the\n",
    "    # reference is \"O\", otherwise consume the next predicted tag (falling back to \"O\").\n",
    "    for tag in Y_test[idx].split():\n",
    "        if tag == \"O\":\n",
    "            out_tags.append(\"O\")\n",
    "        elif len(tags_joined) > 0:\n",
    "            out_tags.append(tags_joined[0])\n",
    "            tags_joined = tags_joined[1:]\n",
    "        else:\n",
    "            out_tags.append(\"O\")\n",
    "\n",
    "    # Re-apply the BIO scheme: the first token of an entity keeps B-, every directly\n",
    "    # following token with the same label is turned into I-.\n",
    "    out_tags = \" \".join(out_tags).replace(\"I-\", \"B-\").split()\n",
    "\n",
    "    last_tag = out_tags[0]\n",
    "    tags_corrected.append(last_tag)\n",
    "\n",
    "    for tag in out_tags[1:]:\n",
    "        if tag == last_tag:\n",
    "            tags_corrected.append(tag.replace(\"B-\", \"I-\"))\n",
    "        else:\n",
    "            tags_corrected.append(tag)\n",
    "        last_tag = tag\n",
    "\n",
    "    full_tags.append(tags_corrected)\n",
    "\n",
    "print(len(full_tags))\n",
    "\n",
    "# Sanity check: every prediction must be as long as its reference tag sequence.\n",
    "# for idx, el in enumerate(full_tags):\n",
    "#     if len(el) != len(Y_test[idx].split()):\n",
    "#         print(\"Something's wrong\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0feb1763-b28b-4cee-a996-f493c3f3d037",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one space-separated tag sequence per line.\n",
    "with open(\"out.tsv\", 'w') as file:\n",
    "    for el in full_tags:\n",
    "        file.write(\" \".join(el))\n",
    "        file.write(\"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}