# --- Data preparation for the transformer NER tagging script ---------------
# FIX: removed unused imports (Counter, torch, datasets.load_dataset,
# torchtext.vocab.vocab, nltk word_tokenize, string) — none were referenced
# anywhere, and the torchtext import only triggered deprecation warnings.
import pandas as pd
from tqdm.notebook import tqdm

from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

# train.tsv layout: column 0 = gold tag sequence, column 1 = sentence text.
dataset = pd.read_csv('train.tsv', sep='\t', header=None)

# dev set: sentences in in.tsv, gold tag sequences in expected.tsv
# (one whitespace-separated sequence per row, aligned by row index).
X_test = pd.read_csv('../dev-0/in.tsv', sep='\t', header=None)
Y_test = pd.read_csv('../dev-0/expected.tsv', sep='\t', header=None)

X_train = dataset[dataset.columns[1]]
Y_train = dataset[dataset.columns[0]]

X_test = X_test[X_test.columns[0]]
Y_test = Y_test[Y_test.columns[0]]

# Whitespace-tokenize each sentence into a list of tokens.
X_train = [text.split() for text in X_train]
X_test = [text.split() for text in X_test]
def merge_subword_entities(ner_results):
    """Rejoin WordPiece fragments emitted by a HF token-classification pipeline.

    ner_results: list of dicts with at least 'word' and 'entity' keys, in
    sentence order.  A fragment containing '#' (a WordPiece continuation such
    as '##ark') is glued onto the preceding word with the markers stripped.

    Returns (words, tags): merged surface words and, for each, the entity tag
    of its first fragment.
    """
    words, tags = [], []
    current_word, current_tag = "", "NULL"
    for result in ner_results:
        piece = str(result["word"])
        if not current_word:
            current_word, current_tag = piece, result["entity"]
        elif "#" in piece:
            # Continuation piece: strip the WordPiece marker and append.
            current_word += piece.replace("#", "")
        else:
            words.append(current_word)
            tags.append(current_tag)
            current_word, current_tag = piece, result["entity"]
    # BUG FIX: the original loop never flushed the final word/tag pair,
    # silently dropping the last recognized entity of every sentence.
    if current_word:
        words.append(current_word)
        tags.append(current_tag)
    return words, tags


def align_tags(gold_tags, predicted_tags):
    """Project predicted entity tags onto the gold token positions.

    Gold 'O' slots stay 'O'; each non-'O' gold slot consumes the next
    predicted tag, falling back to 'O' once predictions run out.  The result
    always has len(gold_tags) entries.
    """
    aligned = []
    next_pred = 0
    for gold in gold_tags:
        if gold != "O" and next_pred < len(predicted_tags):
            aligned.append(predicted_tags[next_pred])
            next_pred += 1
        else:
            aligned.append("O")
    return aligned


def bio_correct(tags):
    """Normalize a tag sequence to BIO form.

    Every 'I-' prefix is first rewritten to 'B-'; then any tag equal to the
    immediately preceding (pre-demotion) tag is demoted to its 'I-' form so a
    run of identical entity tags becomes B-, I-, I-, ...

    BUG FIX: returns [] for empty input (the original indexed tags[0]
    unguarded, which would raise IndexError on an empty gold row).
    """
    normalized = [t.replace("I-", "B-") for t in tags]
    if not normalized:
        return []
    corrected = [normalized[0]]
    previous = normalized[0]
    for tag in normalized[1:]:
        corrected.append(tag.replace("B-", "I-") if tag == previous else tag)
        previous = tag
    return corrected


if __name__ == "__main__":
    # The dbmdz CoNLL-03 checkpoint is paired with the plain bert-base-cased
    # tokenizer, exactly as in the original notebook.
    model = AutoModelForTokenClassification.from_pretrained(
        "dbmdz/bert-large-cased-finetuned-conll03-english"
    )
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
    recognizer = pipeline("ner", model=model, tokenizer=tokenizer)

    full_tags = []
    for idx, tokens in tqdm(enumerate(X_test)):
        gold = Y_test[idx].split()
        _, predicted = merge_subword_entities(recognizer(" ".join(tokens)))
        corrected = bio_correct(align_tags(gold, predicted))
        # Sanity check — the original had this commented out, and with a
        # double-len() bug (`len(len(...))`); every row must keep gold length.
        if len(corrected) != len(gold):
            print("length mismatch at row", idx)
        full_tags.append(corrected)

    print(len(full_tags))

    # One whitespace-joined tag sequence per line, aligned with in.tsv rows.
    with open("out.tsv", 'w') as out_file:
        for row_tags in full_tags:
            out_file.write(" ".join(row_tags))
            out_file.write("\n")