DL_RNN/RNN.ipynb

1665 lines
50 KiB
Plaintext
Raw Normal View History

2024-05-25 18:52:33 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RNN: named-entity recognition with a bidirectional LSTM\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation of packages\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 1,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: torch in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (2.3.0)\n",
"Requirement already satisfied: filelock in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (3.14.0)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (4.10.0)\n",
"Requirement already satisfied: sympy in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (3.2.1)\n",
"Requirement already satisfied: jinja2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (3.1.3)\n",
"Requirement already satisfied: fsspec in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (2024.3.1)\n",
"Requirement already satisfied: mkl<=2021.4.0,>=2021.1.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (2021.4.0)\n",
"Requirement already satisfied: intel-openmp==2021.* in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.4.0)\n",
"Requirement already satisfied: tbb==2021.* in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.12.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from sympy->torch) (1.3.0)\n",
"Note: you may need to restart the kernel to use updated packages.\n",
"Requirement already satisfied: torchtext in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (0.18.0)\n",
"Requirement already satisfied: tqdm in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torchtext) (4.66.2)\n",
"Requirement already satisfied: requests in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torchtext) (2.31.0)\n",
"Requirement already satisfied: torch>=2.3.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torchtext) (2.3.0)\n",
"Requirement already satisfied: numpy in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torchtext) (1.26.3)\n",
"Requirement already satisfied: filelock in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (3.14.0)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (4.10.0)\n",
"Requirement already satisfied: sympy in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (1.12)\n",
"Requirement already satisfied: networkx in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (3.2.1)\n",
"Requirement already satisfied: jinja2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (3.1.3)\n",
"Requirement already satisfied: fsspec in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (2024.3.1)\n",
"Requirement already satisfied: mkl<=2021.4.0,>=2021.1.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (2021.4.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests->torchtext) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests->torchtext) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests->torchtext) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests->torchtext) (2024.2.2)\n",
"Requirement already satisfied: colorama in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from tqdm->torchtext) (0.4.6)\n",
"Requirement already satisfied: intel-openmp==2021.* in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch>=2.3.0->torchtext) (2021.4.0)\n",
"Requirement already satisfied: tbb==2021.* in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch>=2.3.0->torchtext) (2021.12.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from jinja2->torch>=2.3.0->torchtext) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from sympy->torch>=2.3.0->torchtext) (1.3.0)\n",
"Note: you may need to restart the kernel to use updated packages.\n",
"Requirement already satisfied: datasets in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (2.19.1)\n",
"Requirement already satisfied: filelock in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (3.14.0)\n",
"Requirement already satisfied: numpy>=1.17 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (1.26.3)\n",
"Requirement already satisfied: pyarrow>=12.0.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (15.0.2)\n",
"Requirement already satisfied: pyarrow-hotfix in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (0.6)\n",
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (0.3.8)\n",
"Requirement already satisfied: pandas in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (2.2.1)\n",
"Requirement already satisfied: requests>=2.19.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (4.66.2)\n",
"Requirement already satisfied: xxhash in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (3.4.1)\n",
"Requirement already satisfied: multiprocess in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (0.70.16)\n",
"Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets) (2024.3.1)\n",
"Requirement already satisfied: aiohttp in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (3.9.5)\n",
"Requirement already satisfied: huggingface-hub>=0.21.2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (0.23.1)\n",
"Requirement already satisfied: packaging in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from datasets) (23.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (6.0.1)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from huggingface-hub>=0.21.2->datasets) (4.10.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (2024.2.2)\n",
"Requirement already satisfied: colorama in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from tqdm>=4.62.1->datasets) (0.4.6)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from pandas->datasets) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas->datasets) (2024.1)\n",
"Requirement already satisfied: tzdata>=2022.7 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas->datasets) (2024.1)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
"Note: you may need to restart the kernel to use updated packages.\n",
"Requirement already satisfied: pandas in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (2.2.1)\n",
"Requirement already satisfied: numpy<2,>=1.26.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas) (1.26.3)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from pandas) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas) (2024.1)\n",
"Requirement already satisfied: tzdata>=2022.7 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas) (2024.1)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Note: you may need to restart the kernel to use updated packages.\n",
"Requirement already satisfied: scikit-learn in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (1.4.1.post1)\n",
"Requirement already satisfied: numpy<2.0,>=1.19.5 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from scikit-learn) (1.26.3)\n",
"Requirement already satisfied: scipy>=1.6.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from scikit-learn) (1.12.0)\n",
"Requirement already satisfied: joblib>=1.2.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from scikit-learn) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from scikit-learn) (3.3.0)\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# Versions pinned to the ones recorded in this cell's output, so a fresh\n",
"# environment reproduces the run. torchtext ships compiled extensions that\n",
"# must match the installed torch release, so the two are pinned together.\n",
"%pip install torch==2.3.0 torchtext==0.18.0 datasets==2.19.1 pandas==2.2.1 scikit-learn==1.4.1.post1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Importing libraries\n"
]
},
{
"cell_type": "code",
2024-05-27 13:01:15 +02:00
"execution_count": 18,
2024-05-25 18:52:33 +02:00
"metadata": {},
2024-05-27 13:01:15 +02:00
"outputs": [],
2024-05-25 18:52:33 +02:00
"source": [
"from collections import Counter\n",
"import torch\n",
"import pandas as pd\n",
"from torchtext.vocab import vocab\n",
"from sklearn.model_selection import train_test_split\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Read datasets\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 3,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"def read_data(data_dir=\".\"):\n",
"    \"\"\"Load the train, dev-0 and test-A splits of the NER challenge.\n",
"\n",
"    Args:\n",
"        data_dir: Directory containing the ``train/``, ``dev-0/`` and ``test-A/``\n",
"            sub-directories. Defaults to the current working directory.\n",
"\n",
"    Returns:\n",
"        A tuple ``(train_dataset, dev_0_dataset, dev_0_labels, test_A_dataset)``.\n",
"        The train frame has ``Label`` and ``Text`` columns, the dev/test inputs\n",
"        have a ``Text`` column and the dev labels have a ``Label`` column.\n",
"    \"\"\"\n",
"    import csv\n",
"    from pathlib import Path\n",
"\n",
"    root = Path(data_dir)\n",
"    # The files are raw tab-separated text, not quoted CSV. A field that starts\n",
"    # with a '\"' token must not be parsed as a quoted field, which would swallow\n",
"    # the following lines, so quote handling is disabled.\n",
"    read_tsv = dict(sep=\"\\t\", quoting=csv.QUOTE_NONE)\n",
"\n",
"    train_dataset = pd.read_csv(\n",
"        root / \"train\" / \"train.tsv.xz\",\n",
"        compression=\"xz\",\n",
"        names=[\"Label\", \"Text\"],\n",
"        **read_tsv,\n",
"    )\n",
"    dev_0_dataset = pd.read_csv(root / \"dev-0\" / \"in.tsv\", names=[\"Text\"], **read_tsv)\n",
"    dev_0_labels = pd.read_csv(\n",
"        root / \"dev-0\" / \"expected.tsv\", names=[\"Label\"], **read_tsv\n",
"    )\n",
"    test_A_dataset = pd.read_csv(root / \"test-A\" / \"in.tsv\", names=[\"Text\"], **read_tsv)\n",
"\n",
"    return train_dataset, dev_0_dataset, dev_0_labels, test_A_dataset"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 4,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Load every split once; later cells reuse these four names.\n",
"(train_dataset, dev_0_dataset, dev_0_labels, test_A_dataset) = read_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Split the training data into training and validation sets\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 5,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Hold out 10% of the training documents as a validation set (fixed seed).\n",
"train_texts, val_texts, train_labels, val_labels = train_test_split(\n",
"    train_dataset[\"Text\"],\n",
"    train_dataset[\"Label\"],\n",
"    test_size=0.1,\n",
"    random_state=42,\n",
")"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 6,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Rebuild one frame per split; the paired Series share an index, so each\n",
"# Text stays aligned with its Label.\n",
"train_dataset = pd.DataFrame(dict(Text=train_texts, Label=train_labels))\n",
"val_dataset = pd.DataFrame(dict(Text=val_texts, Label=val_labels))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenize the text and labels\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 7,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Whitespace tokenization of the training texts and their tag sequences.\n",
"train_dataset[\"tokenized_text\"] = train_dataset[\"Text\"].map(str.split)\n",
"train_dataset[\"tokenized_labels\"] = train_dataset[\"Label\"].map(str.split)"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 8,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Whitespace tokenization of the validation texts and their tag sequences.\n",
"val_dataset[\"tokenized_text\"] = val_dataset[\"Text\"].map(str.split)\n",
"val_dataset[\"tokenized_labels\"] = val_dataset[\"Label\"].map(str.split)"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 9,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Dev-0 texts and gold tags come from separate files; both frames use the\n",
"# default RangeIndex, so the assignment below lines rows up one-to-one.\n",
"dev_0_dataset[\"tokenized_text\"] = dev_0_dataset[\"Text\"].map(str.split)\n",
"dev_0_dataset[\"tokenized_labels\"] = dev_0_labels[\"Label\"].map(str.split)"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 10,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Test-A has no gold tags, so only the text is tokenized.\n",
"test_A_dataset[\"tokenized_text\"] = test_A_dataset[\"Text\"].map(str.split)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a vocab object which maps tokens to indices\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 11,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
"    \"\"\"Build a torchtext vocab from an iterable of token lists.\n",
"\n",
"    Token frequencies are counted with a ``Counter`` (which keeps first-seen\n",
"    order) and handed to ``vocab`` together with the special tokens\n",
"    ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>``.\n",
"    \"\"\"\n",
"    token_counts = Counter(token for document in dataset for token in document)\n",
"    return vocab(token_counts, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 12,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# The vocabulary is built from the training split only.\n",
"v = build_vocab(dataset=train_dataset[\"tokenized_text\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Map indices to tokens\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 13,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# itos[i] is the token string stored at index i.\n",
"itos = v.get_itos()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Number of tokens in the vocabulary\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 14,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"22154"
]
},
2024-05-26 13:40:54 +02:00
"execution_count": 14,
2024-05-25 18:52:33 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Vocabulary size, including the four special tokens.\n",
"len(itos)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Index of the 'rejects' token\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 15,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9086"
]
},
2024-05-26 13:40:54 +02:00
"execution_count": 15,
2024-05-25 18:52:33 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Index of an ordinary token seen in the training data.\n",
"v[\"rejects\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Index of the '\\<unk\\>' token\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 16,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
2024-05-26 13:40:54 +02:00
"execution_count": 16,
2024-05-25 18:52:33 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Index of the unknown-token special.\n",
"v[\"<unk>\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set the default index to the unknown token\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 17,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# From here on, tokens missing from the vocabulary map to <unk> instead of raising.\n",
"v.set_default_index(v[\"<unk>\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use CUDA if available\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 18,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Prefer the GPU when torch can see one; otherwise run on the CPU.\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vectorize the data\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 19,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt):\n",
"    \"\"\"Convert tokenized documents to 1-D ``torch.long`` index tensors on ``device``.\n",
"\n",
"    Each document is wrapped in the ``<bos>`` and ``<eos>`` ids, so every tensor\n",
"    is two elements longer than its token list. Tokens are looked up in the\n",
"    global vocabulary ``v``.\n",
"    \"\"\"\n",
"    tensors = []\n",
"    for document in dt:\n",
"        ids = [v[\"<bos>\"], *(v[token] for token in document), v[\"<eos>\"]]\n",
"        tensors.append(torch.tensor(ids, dtype=torch.long, device=device))\n",
"    return tensors"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 20,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Index tensors for every split, in train / val / dev-0 / test-A order.\n",
"train_tokens_ids, val_tokens_ids, dev_0_tokens_ids, test_A_tokens_ids = (\n",
"    data_process(split[\"tokenized_text\"])\n",
"    for split in (train_dataset, val_dataset, dev_0_dataset, test_A_dataset)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a mapping from label to index\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 21,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# \"O\" (outside any entity) stays first so that it gets index 0; the scoring\n",
"# code treats index 0 as the non-entity tag.\n",
"labels = [\n",
"    \"O\",\n",
"    \"B-PER\", \"I-PER\",\n",
"    \"B-ORG\", \"I-ORG\",\n",
"    \"B-LOC\", \"I-LOC\",\n",
"    \"B-MISC\", \"I-MISC\",\n",
"]\n",
"\n",
"label_to_index = dict(zip(labels, range(len(labels))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vectorize the labels (NER)\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 22,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"def labels_process(dt, label_to_index):\n",
"    \"\"\"Convert tag sequences to 1-D ``torch.long`` index tensors on ``device``.\n",
"\n",
"    A 0 is placed at both ends of every sequence, so the tags stay aligned\n",
"    with the ``<bos>``/``<eos>``-wrapped token tensors built by ``data_process``.\n",
"    \"\"\"\n",
"    boundary_tag = 0  # index of \"O\" in the label list\n",
"    tensors = []\n",
"    for document in dt:\n",
"        ids = [boundary_tag]\n",
"        ids.extend(label_to_index[label] for label in document)\n",
"        ids.append(boundary_tag)\n",
"        tensors.append(torch.tensor(ids, dtype=torch.long, device=device))\n",
"    return tensors"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 23,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Tag tensors for the splits that have gold labels (train / val / dev-0).\n",
"train_labels, val_labels, dev_0_labels = (\n",
"    labels_process(split[\"tokenized_labels\"], label_to_index)\n",
"    for split in (train_dataset, val_dataset, dev_0_dataset)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Function for evaluation (returns precision, recall, and F1 score)\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 24,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"def get_scores(y_true, y_pred):\n",
"    \"\"\"Token-level precision, recall and F1 for entity tags.\n",
"\n",
"    Tag index 0 (``\"O\"``) means \"no entity\". A prediction is *selected* when it\n",
"    is non-zero, a gold tag is *relevant* when it is non-zero, and a true\n",
"    positive is a non-zero prediction equal to its gold tag.\n",
"\n",
"    Args:\n",
"        y_true: Sequence of gold tag indices.\n",
"        y_pred: Sequence of predicted tag indices aligned with ``y_true``\n",
"            (as with ``zip``, extra items in the longer sequence are ignored).\n",
"\n",
"    Returns:\n",
"        ``(precision, recall, f1)``. Precision is 1.0 when nothing was selected\n",
"        and recall is 1.0 when nothing was relevant; F1 is 0.0 when both\n",
"        precision and recall are 0.\n",
"    \"\"\"\n",
"    tp = selected_items = relevant_items = 0\n",
"    for p, t in zip(y_pred, y_true):\n",
"        if p > 0:\n",
"            selected_items += 1\n",
"            if p == t:\n",
"                tp += 1\n",
"        if t > 0:\n",
"            relevant_items += 1\n",
"\n",
"    precision = tp / selected_items if selected_items else 1.0\n",
"    recall = tp / relevant_items if relevant_items else 1.0\n",
"\n",
"    if precision + recall == 0.0:\n",
"        f1 = 0.0\n",
"    else:\n",
"        f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"    return precision, recall, f1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculate the number of unique tags\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 25,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9\n"
]
}
],
"source": [
"all_label_indices = [\n",
"    label_to_index[label]\n",
"    for document in train_dataset[\"tokenized_labels\"]\n",
"    for label in document\n",
"]\n",
"\n",
"# Size the output layer from the full tag inventory rather than from the tags\n",
"# observed in training. max(observed) + 1 would drop any trailing tag (e.g.\n",
"# I-MISC) that never occurs in the training split. The model could then not\n",
"# predict it, and such targets would be out of range for the loss on val/dev.\n",
"num_tags = len(label_to_index)\n",
"\n",
"print(num_tags)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implementation of a bidirectional LSTM recurrent neural network\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 26,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"class LSTM(torch.nn.Module):\n",
"    \"\"\"Bidirectional LSTM tagger: embedding -> BiLSTM -> per-token linear layer.\n",
"\n",
"    Args:\n",
"        vocab_size: Number of rows in the embedding table.\n",
"        embedding_dim: Size of each token embedding.\n",
"        hidden_dim: LSTM hidden size in each direction.\n",
"        num_layers: Number of stacked LSTM layers.\n",
"        num_tags: Number of output tag classes.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_tags):\n",
"        super().__init__()\n",
"        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)\n",
"        self.rec = torch.nn.LSTM(\n",
"            embedding_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True\n",
"        )\n",
"        # Forward and backward hidden states are concatenated, hence 2 * hidden_dim.\n",
"        self.fc1 = torch.nn.Linear(hidden_dim * 2, num_tags)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map token ids ``(batch, seq_len)`` to tag logits ``(batch, seq_len, num_tags)``.\"\"\"\n",
"        # Embeddings go into the LSTM as-is. A ReLU here would zero every\n",
"        # negative component; those components would then get no gradient\n",
"        # and stay unused for the rest of training.\n",
"        embedding = self.embedding(x)\n",
"        lstm_output, _ = self.rec(embedding)\n",
"        out_weights = self.fc1(lstm_output)\n",
"        return out_weights"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the LSTM model\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 27,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
2024-05-25 19:27:27 +02:00
"# Seed torch so the weight initialisation is reproducible across\n",
"# Restart & Run All (42 matches the train/val split's random_state).\n",
"torch.manual_seed(42)\n",
"lstm = LSTM(\n",
"    vocab_size=len(v.get_itos()),\n",
"    embedding_dim=100,\n",
"    hidden_dim=100,\n",
"    num_layers=1,\n",
"    num_tags=num_tags,\n",
").to(device)"
2024-05-25 18:52:33 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the loss function\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 28,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Cross-entropy over raw per-token logits, averaged over positions.\n",
"criterion = torch.nn.CrossEntropyLoss(reduction=\"mean\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the optimizer\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 29,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Adam with PyTorch's default hyperparameters.\n",
"optimizer = torch.optim.Adam(params=lstm.parameters())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Function for model evaluation\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 30,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model):\n",
"    \"\"\"Score ``model`` on a dataset with token-level precision, recall and F1.\n",
"\n",
"    Documents are fed one at a time (batch size 1). Gold and predicted tags\n",
"    for every position, including the ``<bos>``/``<eos>`` boundaries, are\n",
"    pooled and passed to ``get_scores``.\n",
"\n",
"    Args:\n",
"        dataset_tokens: List of 1-D token-id tensors.\n",
"        dataset_labels: List of 1-D tag-id tensors aligned with ``dataset_tokens``.\n",
"        model: Module mapping ``(1, seq_len)`` ids to ``(1, seq_len, num_tags)``\n",
"            logits.\n",
"\n",
"    Returns:\n",
"        ``(precision, recall, f1)`` as computed by ``get_scores``.\n",
"    \"\"\"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        # Inference only: no_grad avoids building an autograd graph per document.\n",
"        with torch.no_grad():\n",
"            for i in tqdm(range(len(dataset_labels))):\n",
"                batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
"                Y_true += dataset_labels[i].tolist()\n",
"\n",
"                Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
"                Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
"                Y_pred += Y_batch_pred.tolist()\n",
"    finally:\n",
"        # Restore the caller's mode so a surrounding training loop is unaffected.\n",
"        model.train(was_training)\n",
"\n",
"    return get_scores(Y_true, Y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Function returning the predicted labels\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 31,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
"def pred_labels(dataset_tokens, model, label_to_index):\n",
"    \"\"\"Predict a space-separated tag string for every document.\n",
"\n",
"    Args:\n",
"        dataset_tokens: List of 1-D token-id tensors, each wrapped in\n",
"            ``<bos>``/``<eos>`` ids as produced by ``data_process``.\n",
"        model: Module mapping ``(1, seq_len)`` ids to ``(1, seq_len, num_tags)``\n",
"            logits.\n",
"        label_to_index: Mapping from tag string to class index.\n",
"\n",
"    Returns:\n",
"        One string per document with one tag per original token; the\n",
"        predictions for the two boundary positions are dropped.\n",
"    \"\"\"\n",
"    index_to_label = {index: label for label, index in label_to_index.items()}\n",
"    Y_pred = []\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        # Inference only: no_grad avoids building an autograd graph per document.\n",
"        with torch.no_grad():\n",
"            for tokens in tqdm(dataset_tokens):\n",
"                logits = model(tokens.unsqueeze(0)).squeeze(0)\n",
"                predicted = torch.argmax(logits, 1).tolist()\n",
"                # Drop the <bos>/<eos> positions so tags align with the input tokens.\n",
"                tags = [index_to_label[index] for index in predicted[1:-1]]\n",
"                Y_pred.append(\" \".join(tags))\n",
"    finally:\n",
"        # Restore the caller's mode so a surrounding training loop is unaffected.\n",
"        model.train(was_training)\n",
"\n",
"    return Y_pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training\n"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 32,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [],
"source": [
2024-05-25 19:27:27 +02:00
"# Number of training epochs.\n",
"NUM_EPOCHS = 20"
2024-05-25 18:52:33 +02:00
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 33,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "56ca1f77d2d843bbbf760fead298cb00",
2024-05-25 18:52:33 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "98da31b3a184494ba6d6a6dc86468025",
2024-05-25 18:52:33 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.6839857651245551, 0.27638769053782, 0.3936911102007374)\n"
2024-05-25 18:52:33 +02:00
]
2024-05-25 19:27:27 +02:00
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "3a72db21f0aa4d20b328b59201ee3763",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
2024-05-25 18:52:33 +02:00
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "0e21f4a67ba74420ae620d489e5cb87d",
2024-05-25 18:52:33 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
2024-05-25 19:27:27 +02:00
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.7672373900971773, 0.4768478573482888, 0.5881518268889678)\n"
2024-05-25 19:27:27 +02:00
]
},
2024-05-25 18:52:33 +02:00
{
"data": {
2024-05-25 19:27:27 +02:00
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "4f6da9aaae844f9dbce34a463798b9ef",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
2024-05-25 18:52:33 +02:00
"text/plain": [
2024-05-25 19:27:27 +02:00
" 0%| | 0/850 [00:00<?, ?it/s]"
2024-05-25 18:52:33 +02:00
]
},
"metadata": {},
2024-05-25 19:27:27 +02:00
"output_type": "display_data"
},
2024-05-25 18:52:33 +02:00
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "780a5d7519d4464db894f04f56ca7ab4",
2024-05-25 18:52:33 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2024-05-25 19:27:27 +02:00
" 0%| | 0/95 [00:00<?, ?it/s]"
2024-05-25 18:52:33 +02:00
]
},
"metadata": {},
"output_type": "display_data"
},
2024-05-25 19:27:27 +02:00
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.7881323697223279, 0.5959160195570894, 0.678676711431379)\n"
2024-05-25 19:27:27 +02:00
]
},
2024-05-25 18:52:33 +02:00
{
"data": {
2024-05-25 19:27:27 +02:00
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "5bad58a6fc9541cab17eddf933ba90b5",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
2024-05-25 18:52:33 +02:00
"text/plain": [
2024-05-25 19:27:27 +02:00
" 0%| | 0/850 [00:00<?, ?it/s]"
2024-05-25 18:52:33 +02:00
]
},
"metadata": {},
2024-05-25 19:27:27 +02:00
"output_type": "display_data"
},
2024-05-25 18:52:33 +02:00
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "496c5c8b1e874841a0a94ae776979f36",
2024-05-25 18:52:33 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2024-05-25 19:27:27 +02:00
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.7992857142857143, 0.6436583261432269, 0.7130794965747969)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "7c18d00543da431f97462f292ac04eb3",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "f86f10fbd62a46419a174e9cad52d414",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.7935967302452316, 0.6701179177451826, 0.7266489942304692)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "599bf5f15e344e0d99d1b97a1dada9c1",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "1ca8a85bdae1450cb9a05fe3833627e7",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.7951251646903821, 0.6942766752947943, 0.7412866574543221)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "48532842807846509577a6263e9b7089",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "03cb456498b3433d98feec2f4e7469a0",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.794789321325185, 0.7106701179177451, 0.7503795930762224)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "1e4c0bf415a74a3b92435a944310b271",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "670f9768f28345c4a28f5961ebbee780",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8046014257939079, 0.7141213689962611, 0.7566661587688556)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "f74c301cbde847f8bc7f62d2d81d0194",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "ed35cacbda794c59acecbd270a3f25ad",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8059558117195005, 0.7238999137187231, 0.7627272727272727)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "226bf16a3d7e4158b7e4b406e2260ca0",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "64d8bf3bc34b4ef0937833f998c164dd",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8214634146341463, 0.72648835202761, 0.7710622710622711)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "9dff285530c143d2a2742dcd0100ba2c",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "a7b4dabfd9cd496384bab072ef005a19",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8328277721604315, 0.7106701179177451, 0.7669149596523898)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "c529aaf8f97c4fba84fda6f7bf6997b0",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "d9b377ca735b4b0eb38c1b0b2ca3651b",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8623625399077687, 0.6991659476560254, 0.7722363405336722)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "a598d0724ca4486ea8c68aa6c07a3111",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "a896d04311e846fe9253020b531c2c6a",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.83496126641967, 0.7129709519700892, 0.7691591684765747)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "8a00bd95b14745a8b2846925b813e473",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "ae225fdd11954d94ba54f21da0e64070",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8154145077720207, 0.724187517975266, 0.7670982482863671)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "ea835b4f9c35406aa521709cfb9f126b",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "0701c8301be3419a9e5b5375975eec59",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8324958123953099, 0.7146965775093471, 0.7691117301145157)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "dbbb802e734745d08a3987ee6751881b",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "6e16e721d3a143118acc873fec5e61ae",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.839452603471295, 0.7233247052056371, 0.7770739996910242)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "f2c25dd34a7a4ecbbdf7d21bff21c477",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "85f39bd1fd394e12b0c2e06561d8032d",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8292120013188262, 0.7233247052056371, 0.7726574500768049)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "4ff1a830578747cda4b35f4c87ba3d0f",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "a0ae2f3a6d4f4e4da43fff14556ec664",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8110863332271424, 0.73224043715847, 0.7696493349455865)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "1b56482ff2b74816b4048d9a3f2e7cea",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "d8165c18298a4ea99734a40ed70590c0",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8210254756530152, 0.73224043715847, 0.7740954697476436)\n"
2024-05-25 19:27:27 +02:00
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "fb7658347a184892866e5e82be91f33e",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/850 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "4cae8a4711374ab4945f15e9638bc38c",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-05-26 13:40:54 +02:00
"(0.8173884938590821, 0.727351164797239, 0.7697458529904124)\n"
2024-05-25 19:27:27 +02:00
]
}
],
"source": [
"for epoch in range(NUM_EPOCHS):\n",
"    lstm.train()\n",
"    # Each step feeds a single sequence: unsqueeze(0) adds a batch dimension of size 1.\n",
"    for idx in tqdm(range(len(train_labels))):\n",
"        batch_tokens = train_tokens_ids[idx].unsqueeze(0)\n",
"        tags = train_labels[idx]\n",
"\n",
"        predicted_tags = lstm(batch_tokens)\n",
"\n",
"        optimizer.zero_grad()\n",
"        # Drop the batch dimension so the predictions line up with the per-position tags.\n",
"        loss = criterion(predicted_tags.squeeze(0), tags)\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    # Eval mode (affects layers such as dropout, if the model has any) before scoring.\n",
"    lstm.eval()\n",
"    print(f\"epoch {epoch + 1}/{NUM_EPOCHS}:\", eval_model(val_tokens_ids, val_labels, lstm))"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 34,
2024-05-25 19:27:27 +02:00
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "0cd8fbb1f4a34b569773014519a586de",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/95 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
2024-05-26 13:40:54 +02:00
"(0.8173884938590821, 0.727351164797239, 0.7697458529904124)"
2024-05-25 19:27:27 +02:00
]
},
2024-05-26 13:40:54 +02:00
"execution_count": 34,
2024-05-25 19:27:27 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Score the trained model on the validation split (the same split evaluated after every epoch above).\n",
"# NOTE(review): the printed triples look like (precision, recall, F1) -- confirm against eval_model.\n",
"eval_model(val_tokens_ids, val_labels, lstm)"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 35,
2024-05-25 19:27:27 +02:00
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "6ac118471d25448a910260c4ffd01bf5",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/215 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
2024-05-26 13:40:54 +02:00
"(0.8368401624215578, 0.7924726171055698, 0.8140523071398648)"
2024-05-25 19:27:27 +02:00
]
},
2024-05-26 13:40:54 +02:00
"execution_count": 35,
2024-05-25 19:27:27 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Score the trained model on the dev-0 set, whose predictions are exported in the next cell.\n",
"# NOTE(review): the result looks like (precision, recall, F1) -- confirm against eval_model.\n",
"eval_model(dev_0_tokens_ids, dev_0_labels, lstm)"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 36,
2024-05-25 19:27:27 +02:00
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "79cda761e6e9434f99bbc1e2022bf9a7",
2024-05-25 19:27:27 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2024-05-26 13:40:54 +02:00
" 0%| | 0/215 [00:00<?, ?it/s]"
2024-05-25 18:52:33 +02:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict labels for every dev-0 input and save them as one headerless, index-free column\n",
"# (presumably one predicted label string per input sequence -- see pred_labels).\n",
"dev_0_predictions = pd.DataFrame(\n",
"    pred_labels(dev_0_tokens_ids, lstm, label_to_index), columns=[\"Label\"]\n",
")\n",
"dev_0_predictions.to_csv(\"dev-0/out.tsv\", index=False, header=False)"
]
},
{
"cell_type": "code",
2024-05-26 13:40:54 +02:00
"execution_count": 37,
2024-05-25 18:52:33 +02:00
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
2024-05-26 13:40:54 +02:00
"model_id": "35e9a5e3a08b4404a6dfdd5a94418028",
2024-05-25 18:52:33 +02:00
"version_major": 2,
"version_minor": 0
},
"text/plain": [
2024-05-26 13:40:54 +02:00
" 0%| | 0/230 [00:00<?, ?it/s]"
2024-05-25 18:52:33 +02:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Wrap the test-A predictions in a single \"Label\" column and write them without index or header\n",
"# (presumably one row per input sequence -- see pred_labels).\n",
"test_A_predictions = pd.DataFrame(\n",
"    pred_labels(test_A_tokens_ids, lstm, label_to_index),\n",
"    columns=[\"Label\"],\n",
")\n",
"test_A_predictions.to_csv(\"test-A/out.tsv\", index=False, header=False)"
]
2024-05-27 13:01:15 +02:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
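"## Correct BIO labels\n",
"\n",
"Post-processing of the saved prediction files. In the BIO scheme an `I-X` tag must continue an entity of type `X`, so an `I-X` that follows any tag other than `B-X` or `I-X` (e.g. `O I-PER`, or `I-ORG` at the start of a line) is rewritten to `B-X` for the entity types ORG, PER, LOC and MISC. The files `dev-0/out.tsv` and `test-A/out.tsv` are fixed in place.\n"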
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def _correct_tag_sequence(line):\n",
"    \"\"\"Return `line` with every orphan `I-X` tag rewritten to `B-X`.\n",
"\n",
"    `line` is one space-separated tag sequence. An `I-X` tag is an orphan when the\n",
"    tag before it is neither `B-X` nor `I-X` (including the very first tag, e.g.\n",
"    `O I-PER` or a line starting with `I-ORG`). BIO forbids this, so the tag is\n",
"    turned into the start of a new entity. Works for any entity type X.\n",
"    \"\"\"\n",
"    corrected_tokens = []\n",
"    previous_token = \"O\"  # the start of a sequence behaves like an `O` tag\n",
"\n",
"    for token in line.split(\" \"):\n",
"        entity_type = token[2:]\n",
"        is_orphan = (\n",
"            token.startswith(\"I-\")\n",
"            and entity_type != \"\"\n",
"            and previous_token not in (\"B-\" + entity_type, \"I-\" + entity_type)\n",
"        )\n",
"        corrected_tokens.append(\"B-\" + entity_type if is_orphan else token)\n",
"        # Tracking the original tag is enough: an orphan I-X and its B-X rewrite\n",
"        # both open an X entity, so either lets a following I-X continue it.\n",
"        previous_token = token\n",
"\n",
"    return \" \".join(corrected_tokens)\n",
"\n",
"\n",
"def correct_labels(input_file, output_file):\n",
"    \"\"\"Repair invalid BIO transitions in a prediction file.\n",
"\n",
"    `input_file` holds one space-separated tag sequence per line (read as a single\n",
"    tab-separated column without header). Every line is passed through\n",
"    `_correct_tag_sequence` and the result is written to `output_file` in the same\n",
"    format. Both paths may be the same file: it is fully read before writing.\n",
"\n",
"    NOTE(review): `pd.read_csv` skips blank lines by default, so an empty\n",
"    prediction line would be dropped and shift every later row.\n",
"    \"\"\"\n",
"    df = pd.read_csv(input_file, sep=\"\\t\", names=[\"Text\"])\n",
"    df[\"Text\"] = [_correct_tag_sequence(line) for line in df[\"Text\"]]\n",
"    df.to_csv(output_file, sep=\"\\t\", index=False, header=False)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Fix both prediction files in place: each path is used as input and output.\n",
"for out_path in [\"dev-0/out.tsv\", \"test-A/out.tsv\"]:\n",
"    correct_labels(out_path, out_path)"
]
2024-05-25 18:52:33 +02:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}