1403 lines
38 KiB
Plaintext
1403 lines
38 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## POS Tagging using LSTM"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "d03db3876ae84fdc"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.optim as optim\n",
|
|
"\n",
|
|
"import warnings\n",
|
|
"warnings.filterwarnings('ignore')\n",
|
|
"\n",
|
|
"import torchtext\n",
|
|
"from torchtext.vocab import vocab\n",
|
|
"\n",
|
|
"from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report\n",
|
|
"\n",
|
|
"from tqdm.notebook import tqdm\n",
|
|
"\n",
|
|
"import datasets\n",
|
|
"\n",
|
|
"from collections import Counter"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T04:56:09.952503200Z",
|
|
"start_time": "2024-05-30T04:56:06.967530400Z"
|
|
}
|
|
},
|
|
"id": "583c93622c61177b"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load the dataset\n",
|
|
"dataset = datasets.load_dataset('batterydata/pos_tagging')"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T04:56:22.604270500Z",
|
|
"start_time": "2024-05-30T04:56:14.602312200Z"
|
|
}
|
|
},
|
|
"id": "9a73f4af39424a1f"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Convert the dataset to pandas DataFrame\n",
|
|
"train_dataset = dataset['train']\n",
|
|
"test_dataset = dataset['test']\n",
|
|
"\n",
|
|
"train_dataset.set_format(type='pandas')\n",
|
|
"test_dataset.set_format(type='pandas')\n",
|
|
"\n",
|
|
"df_train = pd.concat([train_dataset['words'], train_dataset['labels']], axis=1)\n",
|
|
"df_test = pd.concat([test_dataset['words'], test_dataset['labels']], axis=1)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:02:39.173066900Z",
|
|
"start_time": "2024-05-30T05:02:39.117326300Z"
|
|
}
|
|
},
|
|
"id": "f2d1e260eb9cad0"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 81,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "1451"
|
|
},
|
|
"execution_count": 81,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(df_test)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:34:37.254376Z",
|
|
"start_time": "2024-05-30T05:34:37.240989300Z"
|
|
}
|
|
},
|
|
"id": "60c16f74d5df36b0"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 83,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "13054"
|
|
},
|
|
"execution_count": 83,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(df_train)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:34:45.153855800Z",
|
|
"start_time": "2024-05-30T05:34:45.137542300Z"
|
|
}
|
|
},
|
|
"id": "184cfb64cddd5c51"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Method for building the vocabulary from DataFrame dataset\n",
|
|
"# Special tokens:\n",
|
|
"# <unk> - unknown token\n",
|
|
"# <pad> - padding token\n",
|
|
"# <bos> - beginning of sentence token\n",
|
|
"# <eos> - end of sentence token\n",
|
|
"def build_vocab(dataset):\n",
|
|
" # Initialize the counter\n",
|
|
" counter = Counter()\n",
|
|
" \n",
|
|
" # Iterate over the dataset and update the counter\n",
|
|
" for idx, document in dataset.iterrows():\n",
|
|
" counter.update(document['words'])\n",
|
|
" \n",
|
|
" # Return the vocabulary\n",
|
|
" return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:03:08.896211Z",
|
|
"start_time": "2024-05-30T05:03:08.891565300Z"
|
|
}
|
|
},
|
|
"id": "d0ab581622dec851"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Build the vocabulary\n",
|
|
"v = build_vocab(df_train)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:03:23.852859100Z",
|
|
"start_time": "2024-05-30T05:03:23.410789500Z"
|
|
}
|
|
},
|
|
"id": "cfac7f6325c6bc0a"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 84,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "24851"
|
|
},
|
|
"execution_count": 84,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(v)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:35:56.342196300Z",
|
|
"start_time": "2024-05-30T05:35:56.326491100Z"
|
|
}
|
|
},
|
|
"id": "2a599cdb42e1dd7e"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Mapping from index to token\n",
|
|
"itos = v.get_itos()"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:03:31.890910700Z",
|
|
"start_time": "2024-05-30T05:03:31.877808400Z"
|
|
}
|
|
},
|
|
"id": "1669b13ea4c7e3d7"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Set default index for unknown tokens\n",
|
|
"v.set_default_index(v[\"<unk>\"])"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:03:37.066493400Z",
|
|
"start_time": "2024-05-30T05:03:37.058550900Z"
|
|
}
|
|
},
|
|
"id": "4a5612b9816daf0d"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 54,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Get unique POS tags\n",
|
|
"pos_tags = df_train['labels'].explode().unique().tolist()"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:18:33.019728200Z",
|
|
"start_time": "2024-05-30T05:18:32.985438100Z"
|
|
}
|
|
},
|
|
"id": "e205a3f2fa7468a9"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 55,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Mapping from POS tag to index\n",
|
|
"label2idx = {label: idx for idx, label in enumerate(pos_tags)}\n",
|
|
"\n",
|
|
"# Mapping from index to POS tag\n",
|
|
"idx2label = {idx: label for label, idx in label2idx.items()}"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:18:37.168881400Z",
|
|
"start_time": "2024-05-30T05:18:37.163076800Z"
|
|
}
|
|
},
|
|
"id": "c39568b2c58a89e5"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Method for vectorizing text data using the vocabulary mapping\n",
|
|
"def text_to_vec(data):\n",
|
|
" return [torch.tensor([v['<bos>']] + [v[token] for token in document] + [v['<eos>']], dtype=torch.long) for document in data]"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:18:38.256181300Z",
|
|
"start_time": "2024-05-30T05:18:38.247671200Z"
|
|
}
|
|
},
|
|
"id": "a65bdf3264844e78"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Method for vectorizing POS tags data using the POS tags mapping\n",
|
|
"def pos_tags_to_vec(data):\n",
|
|
" return [torch.tensor([20] + [label2idx[tag] for tag in document] + [20], dtype=torch.long) for document in data]"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:18:38.646303500Z",
|
|
"start_time": "2024-05-30T05:18:38.637786100Z"
|
|
}
|
|
},
|
|
"id": "90ceb0f6639c23f6"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 96,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Vectorize the text data (input)\n",
|
|
"X_train = text_to_vec(df_train['words'])\n",
|
|
"X_test = text_to_vec(df_test['words'])"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:45:44.884726Z",
|
|
"start_time": "2024-05-30T05:45:44.390728500Z"
|
|
}
|
|
},
|
|
"id": "c32f2310d38b3442"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 97,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Vectorize the POS tags data (output)\n",
|
|
"y_train = pos_tags_to_vec(df_train['labels'])\n",
|
|
"y_test = pos_tags_to_vec(df_test['labels'])"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:45:46.004430800Z",
|
|
"start_time": "2024-05-30T05:45:45.746219600Z"
|
|
}
|
|
},
|
|
"id": "8255ae3faf474132"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## LSTM Models"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "add07dd1d8b699f9"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 86,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Basic LSTM model\n",
|
|
"class LSTM(nn.Module):\n",
|
|
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):\n",
|
|
" super(LSTM, self).__init__()\n",
|
|
" \n",
|
|
" # Embedding layer\n",
|
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
|
" \n",
|
|
" # LSTM layer\n",
|
|
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first = True)\n",
|
|
" \n",
|
|
" # Fully connected layer\n",
|
|
" self.fc = nn.Linear(hidden_dim, output_dim)\n",
|
|
" \n",
|
|
" self.relu = nn.ReLU()\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" # Embedding\n",
|
|
" embedding = self.relu(self.embedding(x))\n",
|
|
" \n",
|
|
" # LSTM\n",
|
|
" output, (hidden, cell) = self.lstm(embedding)\n",
|
|
" \n",
|
|
" # Fully connected\n",
|
|
" output = self.fc(output)\n",
|
|
" \n",
|
|
" return output"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:42:11.781180Z",
|
|
"start_time": "2024-05-30T05:42:11.762152200Z"
|
|
}
|
|
},
|
|
"id": "d9d9e2b81dca3e47"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 88,
|
|
"outputs": [],
|
|
"source": [
|
|
"# LSTM model with dropout\n",
|
|
"class LSTMWithDropout(nn.Module):\n",
|
|
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, dropout_prob=0.5):\n",
|
|
" super(LSTMWithDropout, self).__init__()\n",
|
|
" \n",
|
|
" # Embedding layer\n",
|
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
|
" \n",
|
|
" # LSTM layer\n",
|
|
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
|
|
" \n",
|
|
" # Dropout layer\n",
|
|
" self.dropout = nn.Dropout(dropout_prob)\n",
|
|
" \n",
|
|
" # Fully connected layer\n",
|
|
" self.fc = nn.Linear(hidden_dim, output_dim)\n",
|
|
" \n",
|
|
" self.relu = nn.ReLU()\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" # Embedding\n",
|
|
" embedding = self.relu(self.embedding(x))\n",
|
|
" \n",
|
|
" # LSTM\n",
|
|
" output, (hidden, cell) = self.lstm(embedding)\n",
|
|
" \n",
|
|
" # Dropout\n",
|
|
" output = self.dropout(output)\n",
|
|
" \n",
|
|
" # Fully connected\n",
|
|
" output = self.fc(output)\n",
|
|
" \n",
|
|
" return output"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:42:38.928687500Z",
|
|
"start_time": "2024-05-30T05:42:38.922120700Z"
|
|
}
|
|
},
|
|
"id": "d8e190f59de1e675"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 89,
|
|
"outputs": [],
|
|
"source": [
|
|
"class StackedLSTM(nn.Module):\n",
|
|
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=2):\n",
|
|
" super(StackedLSTM, self).__init__()\n",
|
|
" \n",
|
|
" # Embedding layer\n",
|
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
|
" \n",
|
|
" # Stacked LSTM layers\n",
|
|
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)\n",
|
|
" \n",
|
|
" # Fully connected layer\n",
|
|
" self.fc = nn.Linear(hidden_dim, output_dim)\n",
|
|
" \n",
|
|
" self.relu = nn.ReLU()\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" # Embedding\n",
|
|
" embedding = self.relu(self.embedding(x))\n",
|
|
" \n",
|
|
" # LSTM\n",
|
|
" output, (hidden, cell) = self.lstm(embedding)\n",
|
|
" \n",
|
|
" # Fully connected\n",
|
|
" output = self.fc(output)\n",
|
|
" \n",
|
|
" return output"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:42:41.562489200Z",
|
|
"start_time": "2024-05-30T05:42:41.542254700Z"
|
|
}
|
|
},
|
|
"id": "c38a934d939afecf"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 91,
|
|
"outputs": [],
|
|
"source": [
|
|
"class BidirectionalLSTM(nn.Module):\n",
|
|
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):\n",
|
|
" super(BidirectionalLSTM, self).__init__()\n",
|
|
" \n",
|
|
" # Embedding layer\n",
|
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
|
|
" \n",
|
|
" # Bidirectional LSTM layer\n",
|
|
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)\n",
|
|
" \n",
|
|
" # Fully connected layer\n",
|
|
" self.fc = nn.Linear(hidden_dim * 2, output_dim)\n",
|
|
" \n",
|
|
" self.relu = nn.ReLU()\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" # Embedding\n",
|
|
" embedding = self.relu(self.embedding(x))\n",
|
|
" \n",
|
|
" # LSTM\n",
|
|
" output, (hidden, cell) = self.lstm(embedding)\n",
|
|
" \n",
|
|
" # Concatenate the outputs from both directions\n",
|
|
" output = self.fc(output)\n",
|
|
" \n",
|
|
" return output"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:42:59.916550600Z",
|
|
"start_time": "2024-05-30T05:42:59.908460800Z"
|
|
}
|
|
},
|
|
"id": "211c084868ad07ac"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## Training and Evaluation Methods"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "4b3e5007817bcf05"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 146,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Segeval evaluation\n",
|
|
"def evaluate_model(model, X_test, y_test):\n",
|
|
" \"\"\"\n",
|
|
" Method for evaluating the model\n",
|
|
" :param model: model\n",
|
|
" :param X: input data\n",
|
|
" :param y: output data \n",
|
|
" :return: dictionary with metrics values\n",
|
|
" \"\"\"\n",
|
|
" # Use GPU if available\n",
|
|
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
" \n",
|
|
" # Move the model to the device\n",
|
|
" model = model.to(device)\n",
|
|
" \n",
|
|
" # Move the data to the device\n",
|
|
" X = [x.to(device) for x in X_test]\n",
|
|
" y = [y.to(device) for y in y_test]\n",
|
|
" \n",
|
|
" # No gradients\n",
|
|
" with torch.no_grad():\n",
|
|
" # Predict the labels\n",
|
|
" y_pred = [torch.argmax(model(x.unsqueeze(0)).squeeze(0), 1) for x in X]\n",
|
|
" \n",
|
|
" # Convert the labels to ner tags\n",
|
|
" y_pred = [[idx2label[int(idx)] for idx in y] for y in y_pred]\n",
|
|
" y_tags = [[idx2label[int(idx)] for idx in y] for y in y]\n",
|
|
" \n",
|
|
" # Calculate the metrics\n",
|
|
" accuracy = accuracy_score(y_tags, y_pred)\n",
|
|
" precision = precision_score(y_tags, y_pred)\n",
|
|
" recall = recall_score(y_tags, y_pred)\n",
|
|
" f1 = f1_score(y_tags, y_pred)\n",
|
|
" \n",
|
|
" return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:19:54.397919700Z",
|
|
"start_time": "2024-05-30T06:19:54.382149800Z"
|
|
}
|
|
},
|
|
"id": "481abca2f316793c"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 112,
|
|
"outputs": [],
|
|
"source": [
|
|
"import random\n",
|
|
"\n",
|
|
"# Train model\n",
|
|
"def train(model, X_train, y_train, X_test, y_test, epochs = 5, seed=1234):\n",
|
|
" \"\"\"\n",
|
|
" Method for training the model\n",
|
|
" :param model: model\n",
|
|
" :param X_train: input data for training\n",
|
|
" :param y_train: output data for training\n",
|
|
" :param X_test: input data for testing\n",
|
|
" :param y_test: output data for testing\n",
|
|
" :param epochs: number of epochs\n",
|
|
" \"\"\"\n",
|
|
" # Seed for reproducibility\n",
|
|
" torch.manual_seed(seed)\n",
|
|
" random.seed(seed)\n",
|
|
" np.random.seed(seed)\n",
|
|
"\n",
|
|
" # Use GPU if available\n",
|
|
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
" \n",
|
|
" # Loss function and optimizer\n",
|
|
" criterion = nn.CrossEntropyLoss()\n",
|
|
" optimizer = optim.Adam(model.parameters())\n",
|
|
" \n",
|
|
" # Move training to GPU\n",
|
|
" model = model.to(device)\n",
|
|
" X_train_device = [x.to(device) for x in X_train]\n",
|
|
" y_train_device = [y.to(device) for y in y_train]\n",
|
|
" X_test_device = [x.to(device) for x in X_test]\n",
|
|
" y_test_device = [y.to(device) for y in y_test]\n",
|
|
" \n",
|
|
" # Training loop\n",
|
|
" model.train()\n",
|
|
"\n",
|
|
" for epoch in range(epochs):\n",
|
|
" for idx in tqdm(range(len(X_train_device))):\n",
|
|
" # Zero the gradients\n",
|
|
" optimizer.zero_grad()\n",
|
|
" \n",
|
|
" # Forward pass\n",
|
|
" output = model(X_train_device[idx].unsqueeze(0))\n",
|
|
" \n",
|
|
" # Calculate the loss\n",
|
|
" loss = criterion(output.squeeze(0), y_train_device[idx])\n",
|
|
" \n",
|
|
" # Backward pass\n",
|
|
" loss.backward()\n",
|
|
" \n",
|
|
" # Update the weights\n",
|
|
" optimizer.step()\n",
|
|
" \n",
|
|
" # Evaluate the model on the dev set\n",
|
|
" metrics = evaluate_model(model, X_test_device, y_test_device)\n",
|
|
" \n",
|
|
" print(f'Epoch: {epoch+1}, Accuracy: {metrics[\"accuracy\"]}, Precision: {metrics[\"precision\"]}, Recall: {metrics[\"recall\"]}, F1: {metrics[\"f1\"]}')"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:59:17.968008900Z",
|
|
"start_time": "2024-05-30T05:59:17.949318600Z"
|
|
}
|
|
},
|
|
"id": "c9c6e4a60baaf950"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## Basic LSTM Model"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "9a2ef38c4595d331"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 107,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Model parameters\n",
|
|
"vocab_size = len(v)\n",
|
|
"embedding_dim = 64\n",
|
|
"hidden_dim = 128\n",
|
|
"output_dim = len(pos_tags)\n",
|
|
"epochs = 7"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:51:54.574424200Z",
|
|
"start_time": "2024-05-30T05:51:54.559314300Z"
|
|
}
|
|
},
|
|
"id": "e6df7aaff2d06e84"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 108,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Initialize the model\n",
|
|
"model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:51:54.908172500Z",
|
|
"start_time": "2024-05-30T05:51:54.889921400Z"
|
|
}
|
|
},
|
|
"id": "a8abf9db9958387f"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 109,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "515291b7575c48ff9f8da5ba5d91c1db"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 1, Accuracy: 0.8818851395991876, Precision: 0.8438832404066907, Recall: 0.8296350578924226, F1: 0.8366984952848316\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "9374ce8f07d348e29f8188f20486f59d"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 2, Accuracy: 0.9159957912251939, Precision: 0.8855693514613502, Recall: 0.8801700131906786, F1: 0.8828614271853225\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "abaa2b488a2f4c59b9e5f67e8aa2f7c6"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 3, Accuracy: 0.9277656789096337, Precision: 0.9007945850500294, Recall: 0.8972299574967023, F1: 0.8990087377927893\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "782784eb278b44129332bc0f027703be"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 4, Accuracy: 0.9327330119656446, Precision: 0.9065371180321132, Recall: 0.9052616151253114, F1: 0.9058989176028865\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "2f903f740cf848b8b8f2b574877328f1"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 5, Accuracy: 0.9348129297477182, Precision: 0.9091175694301886, Recall: 0.9086911915579657, F1: 0.9089043304893424\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "3212513399c4473c948267cbd763f914"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 6, Accuracy: 0.9357427753444099, Precision: 0.9102924799249751, Recall: 0.9104792613219991, F1: 0.910385861043129\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "bc5ea13dcdea458a856548abe7906d92"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 7, Accuracy: 0.9371864829813786, Precision: 0.912434017595308, Recall: 0.9120328301333724, F1: 0.9122333797551859\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Train the model\n",
|
|
"train(model, X_train, y_train, X_test, y_test, epochs)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:55:25.168469300Z",
|
|
"start_time": "2024-05-30T05:51:55.610716600Z"
|
|
}
|
|
},
|
|
"id": "5aceca013c905053"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 113,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "{'accuracy': 0.9371864829813786,\n 'precision': 0.912434017595308,\n 'recall': 0.9120328301333724,\n 'f1': 0.9122333797551859}"
|
|
},
|
|
"execution_count": 113,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Evaluate the model\n",
|
|
"evaluate_model(model, X_test, y_test)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T05:59:24.940767100Z",
|
|
"start_time": "2024-05-30T05:59:22.139764600Z"
|
|
}
|
|
},
|
|
"id": "779cfabb336a5c80"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## LSTM Model with Dropout"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "96f067682bec3e24"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 121,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Model parameters\n",
|
|
"vocab_size = len(v)\n",
|
|
"embedding_dim = 64\n",
|
|
"hidden_dim = 128\n",
|
|
"output_dim = len(pos_tags)\n",
|
|
"epochs = 7\n",
|
|
"p = 0.2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:04:25.501735500Z",
|
|
"start_time": "2024-05-30T06:04:25.489907600Z"
|
|
}
|
|
},
|
|
"id": "d8cbd5ad588688ce"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 119,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Initialize the model\n",
|
|
"model_dropout = LSTMWithDropout(vocab_size, embedding_dim, hidden_dim, output_dim, dropout_prob=p)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:02:15.269474400Z",
|
|
"start_time": "2024-05-30T06:02:15.233094500Z"
|
|
}
|
|
},
|
|
"id": "d2a5077a0951d358"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 122,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "f8d725fae33a4a14959420f056fe2aaa"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 1, Accuracy: 0.925294247192111, Precision: 0.8976234540700919, Recall: 0.8956763886853291, F1: 0.8966488643699748\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "7b77ef5608e045fea539b0a64f49d19b"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 2, Accuracy: 0.9295274916191548, Precision: 0.902512680681385, Recall: 0.9023010405979774, F1: 0.902406848230776\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "cfc7628f16a546f5998d2d25b27e1ac7"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 3, Accuracy: 0.9336873271833019, Precision: 0.9082087364409264, Recall: 0.9080756265572328, F1: 0.908142176621473\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "dccaad28d648438fa62c1baa32bbeb1c"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 4, Accuracy: 0.9367949690459295, Precision: 0.9122472897743921, Recall: 0.9126483951341052, F1: 0.9124477983735073\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "53d58076bfe842f0ad0638d4edaa192c"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 5, Accuracy: 0.9367215601830328, Precision: 0.9114772328221936, Recall: 0.9130001465630954, F1: 0.9122380540952157\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Train the model\n",
|
|
"train(model_dropout, X_train, y_train, X_test, y_test, epochs)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:07:13.576862500Z",
|
|
"start_time": "2024-05-30T06:04:28.703023800Z"
|
|
}
|
|
},
|
|
"id": "c6e9223a90172dd1"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 123,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "{'accuracy': 0.9370396652555852,\n 'precision': 0.9120747203841424,\n 'recall': 0.9131173970394255,\n 'f1': 0.912595760887079}"
|
|
},
|
|
"execution_count": 123,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Evaluate the model\n",
|
|
"evaluate_model(model_dropout, X_test, y_test)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:07:50.143491200Z",
|
|
"start_time": "2024-05-30T06:07:47.040875200Z"
|
|
}
|
|
},
|
|
"id": "3531050a9322cc1b"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## Stacked LSTM Model"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "9d337673511c4d2"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 124,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Model parameters\n",
|
|
"vocab_size = len(v)\n",
|
|
"embedding_dim = 64\n",
|
|
"hidden_dim = 128\n",
|
|
"output_dim = len(pos_tags)\n",
|
|
"epochs = 7\n",
|
|
"num_layers = 2"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:08:20.161370100Z",
|
|
"start_time": "2024-05-30T06:08:20.144182900Z"
|
|
}
|
|
},
|
|
"id": "6127a7918a21199a"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 125,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Initialize the model\n",
|
|
"model_stacked = StackedLSTM(vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=num_layers)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:08:21.257363800Z",
|
|
"start_time": "2024-05-30T06:08:21.216761300Z"
|
|
}
|
|
},
|
|
"id": "d32c10fbbed9d7bc"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 128,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "cc09cf0c032e4c4291691ca317656728"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 1, Accuracy: 0.9349597474735116, Precision: 0.9083610673206048, Recall: 0.9120621427524549, F1: 0.9102078427357427\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Train the model\n",
|
|
"train(model_stacked, X_train, y_train, X_test, y_test, epochs)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:12:42.768876800Z",
|
|
"start_time": "2024-05-30T06:12:08.498248900Z"
|
|
}
|
|
},
|
|
"id": "c57fb6465d990e49"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 129,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "{'accuracy': 0.9349597474735116,\n 'precision': 0.9083610673206048,\n 'recall': 0.9120621427524549,\n 'f1': 0.9102078427357427}"
|
|
},
|
|
"execution_count": 129,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Evaluate the model\n",
|
|
"evaluate_model(model_stacked, X_test, y_test)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:13:06.044600600Z",
|
|
"start_time": "2024-05-30T06:13:02.712430800Z"
|
|
}
|
|
},
|
|
"id": "322ce510fd4fa3ea"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"## Bidirectional LSTM Model"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"id": "f9a8a6d869257542"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 148,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Model parameters\n",
|
|
"vocab_size = len(v)\n",
|
|
"embedding_dim = 64\n",
|
|
"hidden_dim = 128\n",
|
|
"output_dim = len(pos_tags)\n",
|
|
"epochs = 5"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:21:47.339847400Z",
|
|
"start_time": "2024-05-30T06:21:47.331346100Z"
|
|
}
|
|
},
|
|
"id": "4da1cb8f7cd0cfd7"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 149,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Initialize the model\n",
|
|
"model_bidirectional = BidirectionalLSTM(vocab_size, embedding_dim, hidden_dim, output_dim)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:21:47.664889Z",
|
|
"start_time": "2024-05-30T06:21:47.632812600Z"
|
|
}
|
|
},
|
|
"id": "8990124b8eb7086d"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 150,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "27b3ef20d9304948b4c27b37efb7ec58"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 1, Accuracy: 0.9022438642425429, Precision: 0.8723924915694291, Recall: 0.8568957936391617, F1: 0.8645747072045428\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "d79b2fae6ece4b1884a4c34d06161525"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 2, Accuracy: 0.9338586145300609, Precision: 0.9106548443161399, Recall: 0.9061703063168695, F1: 0.908407040639417\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "46f439ac39ce42be9dd9b3400bbd1e23"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 3, Accuracy: 0.940587760295593, Precision: 0.9189085996240601, Recall: 0.9171039132346475, F1: 0.9180053694819769\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "d31d5fe35d624205802f91793d160826"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 4, Accuracy: 0.9427900261824944, Precision: 0.9211352763347128, Recall: 0.9199472372856514, F1: 0.9205408734930924\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": " 0%| | 0/13054 [00:00<?, ?it/s]",
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"version_major": 2,
|
|
"version_minor": 0,
|
|
"model_id": "078c2c9033a34a71970dd27366e15138"
|
|
}
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 5, Accuracy: 0.9433772970856682, Precision: 0.9217263652378156, Recall: 0.9202403634764766, F1: 0.920982764943161\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Train the model\n",
|
|
"train(model_bidirectional, X_train, y_train, X_test, y_test, epochs)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:24:23.582057100Z",
|
|
"start_time": "2024-05-30T06:21:48.130817800Z"
|
|
}
|
|
},
|
|
"id": "2c735cf20d0f87f8"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 151,
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "{'accuracy': 0.9433772970856682,\n 'precision': 0.9217263652378156,\n 'recall': 0.9202403634764766,\n 'f1': 0.920982764943161}"
|
|
},
|
|
"execution_count": 151,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Evaluate the model\n",
|
|
"evaluate_model(model_bidirectional, X_test, y_test)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-05-30T06:24:37.419868900Z",
|
|
"start_time": "2024-05-30T06:24:34.082828Z"
|
|
}
|
|
},
|
|
"id": "6e439a4811a3f20e"
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|