{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## RNN\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Installation of packages\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torch in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (2.3.0)\n", "Requirement already satisfied: filelock in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (3.14.0)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (4.10.0)\n", "Requirement already satisfied: sympy in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (1.12)\n", "Requirement already satisfied: networkx in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (3.2.1)\n", "Requirement already satisfied: jinja2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (3.1.3)\n", "Requirement already satisfied: fsspec in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (2024.3.1)\n", "Requirement already satisfied: mkl<=2021.4.0,>=2021.1.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch) (2021.4.0)\n", "Requirement already satisfied: intel-openmp==2021.* in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.4.0)\n", "Requirement already satisfied: tbb==2021.* in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.12.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in 
c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from jinja2->torch) (2.1.5)\n", "Requirement already satisfied: mpmath>=0.19 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from sympy->torch) (1.3.0)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: torchtext in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (0.18.0)\n", "Requirement already satisfied: tqdm in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torchtext) (4.66.2)\n", "Requirement already satisfied: requests in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torchtext) (2.31.0)\n", "Requirement already satisfied: torch>=2.3.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torchtext) (2.3.0)\n", "Requirement already satisfied: numpy in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torchtext) (1.26.3)\n", "Requirement already satisfied: filelock in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (3.14.0)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (4.10.0)\n", "Requirement already satisfied: sympy in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (1.12)\n", "Requirement already satisfied: networkx in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (3.2.1)\n", "Requirement already satisfied: jinja2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (3.1.3)\n", "Requirement already satisfied: 
fsspec in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (2024.3.1)\n", "Requirement already satisfied: mkl<=2021.4.0,>=2021.1.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from torch>=2.3.0->torchtext) (2021.4.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests->torchtext) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests->torchtext) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests->torchtext) (2.2.1)\n", "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests->torchtext) (2024.2.2)\n", "Requirement already satisfied: colorama in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from tqdm->torchtext) (0.4.6)\n", "Requirement already satisfied: intel-openmp==2021.* in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch>=2.3.0->torchtext) (2021.4.0)\n", "Requirement already satisfied: tbb==2021.* in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch>=2.3.0->torchtext) (2021.12.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from jinja2->torch>=2.3.0->torchtext) (2.1.5)\n", "Requirement already satisfied: mpmath>=0.19 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from sympy->torch>=2.3.0->torchtext) (1.3.0)\n", "Note: you may need to restart the kernel to use 
updated packages.\n", "Requirement already satisfied: datasets in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (2.19.1)\n", "Requirement already satisfied: filelock in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (3.14.0)\n", "Requirement already satisfied: numpy>=1.17 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (1.26.3)\n", "Requirement already satisfied: pyarrow>=12.0.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (15.0.2)\n", "Requirement already satisfied: pyarrow-hotfix in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (2.2.1)\n", "Requirement already satisfied: requests>=2.19.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (4.66.2)\n", "Requirement already satisfied: xxhash in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets) (2024.3.1)\n", "Requirement already satisfied: aiohttp in 
c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (3.9.5)\n", "Requirement already satisfied: huggingface-hub>=0.21.2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (0.23.1)\n", "Requirement already satisfied: packaging in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from datasets) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from huggingface-hub>=0.21.2->datasets) (4.10.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (3.6)\n", 
"Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (2.2.1)\n", "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (2024.2.2)\n", "Requirement already satisfied: colorama in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from tqdm>=4.62.1->datasets) (0.4.6)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from pandas->datasets) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: six>=1.5 in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: pandas in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (2.2.1)\n", "Requirement already satisfied: numpy<2,>=1.26.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas) (1.26.3)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from pandas) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in 
c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from pandas) (2024.1)\n", "Requirement already satisfied: six>=1.5 in c:\\users\\skype\\appdata\\roaming\\python\\python312\\site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: scikit-learn in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (1.4.1.post1)\n", "Requirement already satisfied: numpy<2.0,>=1.19.5 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from scikit-learn) (1.26.3)\n", "Requirement already satisfied: scipy>=1.6.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from scikit-learn) (1.12.0)\n", "Requirement already satisfied: joblib>=1.2.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from scikit-learn) (1.3.2)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\skype\\appdata\\local\\programs\\python\\python312\\lib\\site-packages (from scikit-learn) (3.3.0)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install torch\n", "%pip install torchtext\n", "%pip install datasets\n", "%pip install pandas\n", "%pip install scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Importing libraries\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "import torch\n", "import pandas as pd\n", "from torchtext.vocab import vocab\n", "from sklearn.model_selection import train_test_split\n", "from tqdm.notebook import tqdm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Read datasets\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def read_data():\n", " train_dataset = 
pd.read_csv(\n", " \"train/train.tsv.xz\", compression=\"xz\", sep=\"\\t\", names=[\"Label\", \"Text\"]\n", " )\n", " dev_0_dataset = pd.read_csv(\"dev-0/in.tsv\", sep=\"\\t\", names=[\"Text\"])\n", " dev_0_labels = pd.read_csv(\"dev-0/expected.tsv\", sep=\"\\t\", names=[\"Label\"])\n", " test_A_dataset = pd.read_csv(\"test-A/in.tsv\", sep=\"\\t\", names=[\"Text\"])\n", "\n", " return train_dataset, dev_0_dataset, dev_0_labels, test_A_dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "train_dataset, dev_0_dataset, dev_0_labels, test_A_dataset = read_data()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split the training data into training and validation sets\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "train_texts, val_texts, train_labels, val_labels = train_test_split(\n", " train_dataset[\"Text\"], train_dataset[\"Label\"], test_size=0.1, random_state=42\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "train_dataset = pd.DataFrame({\"Text\": train_texts, \"Label\": train_labels})\n", "val_dataset = pd.DataFrame({\"Text\": val_texts, \"Label\": val_labels})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tokenize the text and labels\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "train_dataset[\"tokenized_text\"] = train_dataset[\"Text\"].apply(lambda x: x.split())\n", "train_dataset[\"tokenized_labels\"] = train_dataset[\"Label\"].apply(lambda x: x.split())" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "val_dataset[\"tokenized_text\"] = val_dataset[\"Text\"].apply(lambda x: x.split())\n", "val_dataset[\"tokenized_labels\"] = val_dataset[\"Label\"].apply(lambda x: x.split())" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ 
"dev_0_dataset[\"tokenized_text\"] = dev_0_dataset[\"Text\"].apply(lambda x: x.split())\n", "dev_0_dataset[\"tokenized_labels\"] = dev_0_labels[\"Label\"].apply(lambda x: x.split())" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "test_A_dataset[\"tokenized_text\"] = test_A_dataset[\"Text\"].apply(lambda x: x.split())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a vocab object which maps tokens to indices\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def build_vocab(dataset):\n", "    counter = Counter()\n", "    for document in dataset:\n", "        counter.update(document)\n", "    return vocab(counter, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "v = build_vocab(train_dataset[\"tokenized_text\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Map indices to tokens\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "itos = v.get_itos()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Number of tokens in the vocabulary\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "22154" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(itos)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Index of the 'rejects' token\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9086" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v[\"rejects\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Index of the `<unk>` token\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 16, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "v[\"<unk>\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set the default index to the unknown token\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "v.set_default_index(v[\"<unk>\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Use CUDA if available\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Vectorize the data\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def data_process(dt):\n", "    return [\n", "        torch.tensor(\n", "            [v[\"<bos>\"]] + [v[token] for token in document] + [v[\"<eos>\"]],\n", "            dtype=torch.long,\n", "            device=device,\n", "        )\n", "        for document in dt\n", "    ]" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "train_tokens_ids = data_process(train_dataset[\"tokenized_text\"])\n", "val_tokens_ids = data_process(val_dataset[\"tokenized_text\"])\n", "dev_0_tokens_ids = data_process(dev_0_dataset[\"tokenized_text\"])\n", "test_A_tokens_ids = data_process(test_A_dataset[\"tokenized_text\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a mapping from label to index\n" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "labels = [\"O\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-LOC\", \"I-LOC\", \"B-MISC\", \"I-MISC\"]\n", "\n", "label_to_index = {label: idx for idx, label in enumerate(labels)}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Vectorize the labels (NER)\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def labels_process(dt, label_to_index):\n", "    return [\n", "        torch.tensor(\n", "            [0] + 
[label_to_index[label] for label in document] + [0],\n", " dtype=torch.long,\n", " device=device,\n", " )\n", " for document in dt\n", " ]" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "train_labels = labels_process(train_dataset[\"tokenized_labels\"], label_to_index)\n", "val_labels = labels_process(val_dataset[\"tokenized_labels\"], label_to_index)\n", "dev_0_labels = labels_process(dev_0_dataset[\"tokenized_labels\"], label_to_index)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Function for evaluation (returns precision, recall, and F1 score)\n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def get_scores(y_true, y_pred):\n", " acc_score = 0\n", " tp = 0\n", " fp = 0\n", " selected_items = 0\n", " relevant_items = 0\n", "\n", " for p, t in zip(y_pred, y_true):\n", " if p == t:\n", " acc_score += 1\n", "\n", " if p > 0 and p == t:\n", " tp += 1\n", "\n", " if p > 0:\n", " selected_items += 1\n", "\n", " if t > 0:\n", " relevant_items += 1\n", "\n", " if selected_items == 0:\n", " precision = 1.0\n", " else:\n", " precision = tp / selected_items\n", "\n", " if relevant_items == 0:\n", " recall = 1.0\n", " else:\n", " recall = tp / relevant_items\n", "\n", " if precision + recall == 0.0:\n", " f1 = 0.0\n", " else:\n", " f1 = 2 * precision * recall / (precision + recall)\n", "\n", " return precision, recall, f1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Calculate the number of unique tags\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9\n" ] } ], "source": [ "all_label_indices = [\n", " label_to_index[label]\n", " for document in train_dataset[\"tokenized_labels\"]\n", " for label in document\n", "]\n", "\n", "num_tags = max(all_label_indices) + 1\n", "\n", "print(num_tags)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"### Implementation of a recurrent neural network LSTM\n" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "class LSTM(torch.nn.Module):\n", "\n", " def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_tags):\n", " super(LSTM, self).__init__()\n", " self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)\n", " self.rec = torch.nn.LSTM(\n", " embedding_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True\n", " )\n", " self.fc1 = torch.nn.Linear(hidden_dim * 2, num_tags)\n", "\n", " def forward(self, x):\n", " embedding = torch.relu(self.embedding(x))\n", " lstm_output, _ = self.rec(embedding)\n", " out_weights = self.fc1(lstm_output)\n", " return out_weights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initialize the LSTM model\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "lstm = LSTM(len(v.get_itos()), 100, 100, 1, num_tags).to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the loss function\n" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "criterion = torch.nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the optimizer\n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "optimizer = torch.optim.Adam(lstm.parameters())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Function for model evaluation\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "def eval_model(dataset_tokens, dataset_labels, model):\n", " Y_true = []\n", " Y_pred = []\n", " for i in tqdm(range(len(dataset_labels))):\n", " batch_tokens = dataset_tokens[i].unsqueeze(0)\n", " tags = list(dataset_labels[i].cpu().numpy())\n", " Y_true += tags\n", "\n", " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", " 
Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n", " Y_pred += list(Y_batch_pred.cpu().numpy())\n", "\n", " return get_scores(Y_true, Y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Function for returning the predictions labels\n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "def pred_labels(dataset_tokens, model, label_to_index):\n", " Y_pred = []\n", " inv_label_to_index = {\n", " v: k for k, v in label_to_index.items()\n", " } # Create the inverse mapping\n", "\n", " for i in tqdm(range(len(dataset_tokens))):\n", " batch_tokens = dataset_tokens[i].unsqueeze(0)\n", " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", " Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n", " predicted_labels = [inv_label_to_index[label.item()] for label in Y_batch_pred]\n", " predicted_labels = predicted_labels[1:-1]\n", " Y_pred.append(\" \".join(predicted_labels))\n", "\n", " return Y_pred" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training\n" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "NUM_EPOCHS = 20" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "56ca1f77d2d843bbbf760fead298cb00", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/850 [00:00