uczenie_glebokie_rnn/RNN.ipynb

563 lines
15 KiB
Plaintext
Raw Permalink Normal View History

2024-05-27 00:29:45 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cfcbab0f-15cd-4357-ba23-9160a592f2f1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Szpil\\anaconda3\\envs\\py310\\lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n",
"C:\\Users\\Szpil\\anaconda3\\envs\\py310\\lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n",
"/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n",
"Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n",
" warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n"
]
}
],
"source": [
"import csv\n",
"from collections import Counter\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"import gensim\n",
"\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense\n",
"import matplotlib.pyplot as plt\n",
"from keras.regularizers import l2\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import torch\n",
"from torchtext.vocab import vocab"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "49faae52-6c7b-415f-ba56-a244cb9e5c9f",
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
"    \"\"\"Build a torchtext vocabulary from an iterable of token lists.\n",
"\n",
"    Tokens keep the order in which they are first seen; the special tokens\n",
"    <unk>, <pad>, <bos> and <eos> are placed first (ids 0-3).\n",
"    \"\"\"\n",
"    token_counts = Counter()\n",
"    for tokens in dataset:\n",
"        token_counts.update(tokens)\n",
"    special_tokens = [\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"]\n",
"    return vocab(token_counts, specials=special_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "691f40e5-d976-42e9-afbd-c51ce06b9077",
"metadata": {},
"outputs": [],
"source": [
"def fit_data_Y(column, label_map=None):\n",
"    \"\"\"Convert rows of space-separated NER tags into label-id tensors.\n",
"\n",
"    Every row is padded with label id 0 at both ends (the 'O' tag in the\n",
"    default ``ner_dict``), so its length matches the <bos>/<eos>-wrapped\n",
"    token tensor that ``fit_data_X`` builds for the same sentence.\n",
"\n",
"    Args:\n",
"        column: iterable of tag strings, e.g. \"B-PER I-PER O\".\n",
"        label_map: tag -> id mapping; defaults to the global ``ner_dict``.\n",
"\n",
"    Returns:\n",
"        list of 1-D ``torch.long`` tensors, one per row.\n",
"\n",
"    Raises:\n",
"        KeyError: if a tag is not present in ``label_map``.\n",
"    \"\"\"\n",
"    if label_map is None:\n",
"        label_map = ner_dict\n",
"    boundary_label = 0  # label used for the <bos>/<eos> positions\n",
"    return [\n",
"        torch.tensor(\n",
"            [boundary_label] + [label_map[tag] for tag in row.split()] + [boundary_label],\n",
"            dtype=torch.long,\n",
"        )\n",
"        for row in column\n",
"    ]\n",
"\n",
"def fit_data_X(dt, vocabulary=None):\n",
"    \"\"\"Convert whitespace-tokenized sentences into token-id tensors.\n",
"\n",
"    Each sentence becomes [<bos>, token ids..., <eos>], i.e. one element per\n",
"    whitespace-separated token plus the two boundary markers.\n",
"\n",
"    Args:\n",
"        dt: iterable of sentence strings.\n",
"        vocabulary: token -> id lookup supporting ``[]``; defaults to the\n",
"            global vocabulary ``v`` (which returns the <unk> id for unknown\n",
"            tokens once ``set_default_index`` has been called on it).\n",
"\n",
"    Returns:\n",
"        list of 1-D ``torch.long`` tensors, one per sentence.\n",
"    \"\"\"\n",
"    if vocabulary is None:\n",
"        vocabulary = v\n",
"    bos_id = vocabulary[\"<bos>\"]  # looked up once instead of once per sentence\n",
"    eos_id = vocabulary[\"<eos>\"]\n",
"    return [\n",
"        torch.tensor(\n",
"            [bos_id] + [vocabulary[token] for token in document.split()] + [eos_id],\n",
"            dtype=torch.long,\n",
"        )\n",
"        for document in dt\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2e357038-013d-4887-804a-c3718ab82d4f",
"metadata": {},
"outputs": [],
"source": [
"def predict(X, model=None):\n",
"    \"\"\"Tag every sentence in ``X`` and return the predicted tags as strings.\n",
"\n",
"    Args:\n",
"        X: list of 1-D token-id tensors as produced by ``fit_data_X``.\n",
"        model: tagger to run; defaults to the global ``lstm``. It is switched\n",
"            to eval mode before predicting.\n",
"\n",
"    Returns:\n",
"        list with one string of space-separated tags per sentence; the\n",
"        predictions for the <bos>/<eos> positions are dropped.\n",
"    \"\"\"\n",
"    if model is None:\n",
"        model = lstm\n",
"    model.eval()\n",
"    Y_predicted = []\n",
"    with torch.no_grad():  # inference only: no autograd graph needed\n",
"        for i in tqdm(range(len(X))):\n",
"            batch_tokens = X[i].unsqueeze(0)\n",
"            Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
"            Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
"            tags = (reversed_ner_dict[item] for item in Y_batch_pred.tolist()[1:-1])\n",
"            Y_predicted.append(\" \".join(tags))\n",
"    return Y_predicted\n",
"\n",
"def save_to_csv(filename, data):\n",
"    \"\"\"Write ``data`` to ``filename`` as a tab-separated file without header or index.\n",
"\n",
"    Each element of ``data`` becomes one row of the output file.\n",
"    \"\"\"\n",
"    pd.DataFrame(data).to_csv(filename, sep='\\t', index=False, header=None)"
]
},
{
"cell_type": "markdown",
"id": "32a1584c-557c-4001-857d-af01ca13b291",
"metadata": {},
"source": [
"# Preparing training data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1d6f1d64-0474-41dc-af79-939a209d81c3",
"metadata": {},
"outputs": [],
"source": [
"# Reading the train dataset: column 0 holds the space-separated tags, column 1 the sentence.\n",
"# quoting=csv.QUOTE_NONE makes '\"' an ordinary token; with default quoting a field\n",
"# starting with '\"' could be parsed as a quoted CSV field and lose or merge tokens.\n",
"train_data = pd.read_csv(\n",
"    './train/train.tsv', sep='\\t', usecols=[0, 1], header=None,\n",
"    names=['label', 'sentence'], quoting=csv.QUOTE_NONE,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7ef28765-603d-4a61-be5a-6c40c2b3b80e",
"metadata": {},
"outputs": [],
"source": [
"# Build the vocabulary from exactly the tokens fit_data_X looks up (plain\n",
"# whitespace split). gensim.utils.simple_preprocess lowercases and drops\n",
"# punctuation, numbers and very short tokens, so a vocabulary built with it\n",
"# does not match `document.split()` and sends e.g. every capitalised word to <unk>.\n",
"train_tokens = train_data['sentence'].str.split()\n",
"v = build_vocab(train_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ab5a4196-b760-4d6a-ba6e-bf5153328f83",
"metadata": {},
"outputs": [],
"source": [
"# Index -> token list (entry i is the token with id i); used to inspect the vocabulary below.\n",
"itos = v.get_itos()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cc42fa7d-d05d-4161-9a62-9e94ddfd43d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<unk>', '<pad>', '<bos>', '<eos>', 'eu', 'rejects', 'german', 'call', 'to', 'boycott', 'british', 'lamb', 'peter', 'blackburn', 'brussels', 'the', 'european', 'commission', 'said', 'on', 'thursday', 'it', 'disagreed', 'with', 'advice']\n"
]
}
],
"source": [
"# Peek at the first 25 vocabulary entries (the special tokens come first).\n",
"print(itos[:25])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f545c5b0-0e5f-45e3-927e-4b9393c34416",
"metadata": {},
"outputs": [],
"source": [
"# From now on, lookups of out-of-vocabulary tokens return the <unk> id.\n",
"v.set_default_index(v.get_stoi()[\"<unk>\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "261cc7f7-52e3-49b1-8d3a-73615a257650",
"metadata": {},
"outputs": [],
"source": [
"# Creating a mapping for label to index conversion (and its inverse for decoding predictions)\n",
"ner_dict = {\n",
"    tag: index\n",
"    for index, tag in enumerate(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n",
"}\n",
"reversed_ner_dict = {index: tag for tag, index in ner_dict.items()}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "6ddc11a7-8c6d-4fcd-b540-3e3c0fce50f4",
"metadata": {},
"outputs": [],
"source": [
"# Number of tag classes the model scores (size of its output layer).\n",
"num_tags = len(ner_dict)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "3fd1b6c4-893e-4b13-8e06-e50f065e7d5d",
"metadata": {},
"outputs": [],
"source": [
"train_X = fit_data_X(train_data['sentence'])\n",
"train_Y = fit_data_Y(train_data['label'])\n",
"\n",
"# Each sentence needs exactly one tag per token (both sides carry the same\n",
"# <bos>/<eos> padding); a mismatch would break the per-token loss in training.\n",
"assert all(len(x) == len(y) for x, y in zip(train_X, train_Y)), \"token/tag count mismatch in train data\""
]
},
{
"cell_type": "markdown",
"id": "b4237a35-0952-4533-9a23-70d7a299b937",
"metadata": {},
"source": [
"# Preparing dev data"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "caff8949-d52a-4604-bf34-614b02527a38",
"metadata": {},
"outputs": [],
"source": [
"# quoting=csv.QUOTE_NONE: treat '\"' as an ordinary token rather than a CSV quote character.\n",
"dev_texts_data = pd.read_csv('./dev-0/in.tsv', sep='\\t', usecols=[0], header=None, names=['sentence'], quoting=csv.QUOTE_NONE)\n",
"dev_labels_data = pd.read_csv('./dev-0/expected.tsv', sep='\\t', usecols=[0], header=None, names=['label'], quoting=csv.QUOTE_NONE)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5458123d-37d0-48f4-b83c-fc9c3d87ff21",
"metadata": {},
"outputs": [],
"source": [
"dev_X = fit_data_X(dev_texts_data['sentence'])\n",
"dev_Y = fit_data_Y(dev_labels_data['label'])\n",
"\n",
"# Sentences and tags come from two different files, so verify they line up:\n",
"# eval_model pairs them by position and would otherwise compare the wrong tokens.\n",
"assert len(dev_X) == len(dev_Y), \"dev in.tsv / expected.tsv row count mismatch\"\n",
"assert all(len(x) == len(y) for x, y in zip(dev_X, dev_Y)), \"token/tag count mismatch in dev data\""
]
},
{
"cell_type": "markdown",
"id": "5022ee1a-e660-49d7-b462-7218e11f6e5b",
"metadata": {},
"source": [
"# Preparing test data"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "08fb43bc-218a-4a69-9c6d-6df3201a9fe1",
"metadata": {},
"outputs": [],
"source": [
"# quoting=csv.QUOTE_NONE: treat '\"' as an ordinary token rather than a CSV quote character.\n",
"test_texts_data = pd.read_csv('./test-A/in.tsv', sep='\\t', usecols=[0], header=None, names=['sentence'], quoting=csv.QUOTE_NONE)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "38c9c66d-1a03-48e7-b1df-72c007bc7969",
"metadata": {},
"outputs": [],
"source": [
"# Only the input sentences are read for test-A, so there are no labels to convert.\n",
"test_X = fit_data_X(test_texts_data['sentence'])"
]
},
{
"cell_type": "markdown",
"id": "44c5bc6e-6ff8-49a1-8efe-9e60d0354769",
"metadata": {},
"source": [
"# Model implementation"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d3eb0ebd-c3f6-4832-83df-f8d5110cb7bd",
"metadata": {},
"outputs": [],
"source": [
"class LSTM(torch.nn.Module):\n",
"    \"\"\"Sequence tagger: token embedding -> single-layer LSTM -> linear tag scorer.\n",
"\n",
"    ``forward`` takes token ids shaped (batch, seq_len) (the LSTM is built\n",
"    with ``batch_first=True``) and returns one unnormalised score per tag\n",
"    for every position, shaped (batch, seq_len, tagset_size).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size=None, embedding_dim=100, hidden_dim=256, tagset_size=None):\n",
"        \"\"\"\n",
"        Args:\n",
"            vocab_size: number of token ids; defaults to the size of the global vocabulary ``v``.\n",
"            embedding_dim: size of each token embedding.\n",
"            hidden_dim: size of the LSTM hidden state.\n",
"            tagset_size: number of output tags; defaults to the global ``num_tags``.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        if vocab_size is None:\n",
"            vocab_size = len(v.get_itos())\n",
"        if tagset_size is None:\n",
"            tagset_size = num_tags\n",
"        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)\n",
"        self.rec = torch.nn.LSTM(embedding_dim, hidden_dim, 1, batch_first=True)\n",
"        self.fc1 = torch.nn.Linear(hidden_dim, tagset_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return unnormalised tag scores for every token position in ``x``.\"\"\"\n",
"        # NOTE(review): ReLU on the raw embeddings zeroes all negative components;\n",
"        # kept so results stay comparable, but worth re-evaluating.\n",
"        emb = torch.relu(self.emb(x))\n",
"        lstm_output, _ = self.rec(emb)  # final (h_n, c_n) states are not needed\n",
"        return self.fc1(lstm_output)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "451ff3f2-5204-4c76-b691-ffd90e01d472",
"metadata": {},
"outputs": [],
"source": [
"# Seed PyTorch's RNG so the weight initialisation is reproducible across runs.\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"\n",
"lstm = LSTM()\n",
"# CrossEntropyLoss consumes the raw (unnormalised) scores returned by the model.\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(lstm.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "53475557-ad89-499b-bdd6-26473a908af7",
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(y_true, y_pred):\n",
"    \"\"\"Return the fraction of positions where ``y_pred`` equals ``y_true``.\n",
"\n",
"    Items are paired with ``zip``, so extra items in the longer sequence are\n",
"    ignored. Empty input yields 0.0 instead of a ZeroDivisionError.\n",
"    \"\"\"\n",
"    pairs = list(zip(y_true, y_pred))\n",
"    if not pairs:\n",
"        return 0.0\n",
"    hits = sum(1 for true_tag, pred_tag in pairs if pred_tag == true_tag)\n",
"    return hits / len(pairs)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "d16ba920-5c26-4edc-a85a-27e310e3508a",
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model):\n",
"    \"\"\"Compute token-level accuracy of ``model`` on a labelled dataset.\n",
"\n",
"    Args:\n",
"        dataset_tokens: list of 1-D token-id tensors (from ``fit_data_X``).\n",
"        dataset_labels: list of 1-D label-id tensors (from ``fit_data_Y``),\n",
"            aligned position-by-position with ``dataset_tokens``.\n",
"        model: tagger returning per-position tag scores.\n",
"\n",
"    Returns:\n",
"        Accuracy over the real tokens only. The first and last positions\n",
"        (<bos>/<eos>, always labelled 0 by ``fit_data_Y``) are excluded so\n",
"        they do not inflate the score.\n",
"    \"\"\"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    with torch.no_grad():  # evaluation only: skip building the autograd graph\n",
"        for i in tqdm(range(len(dataset_labels))):\n",
"            batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
"            Y_true += dataset_labels[i][1:-1].tolist()\n",
"\n",
"            Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
"            Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
"            Y_pred += Y_batch_pred[1:-1].tolist()\n",
"\n",
"    return get_accuracy(Y_true, Y_pred)"
]
},
{
"cell_type": "markdown",
"id": "bba09c20-424a-4676-a916-85e317f4beb7",
"metadata": {},
"source": [
"# Model training\n",
"In preliminary runs with this preprocessing, the model reached about 84% token-level accuracy on the dev set after 3 epochs, and accuracy stayed flat afterwards.\n",
"Training for more than 3 epochs is therefore unnecessary."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a53ccf52-50a3-4d13-91ad-1e0b5b27e84b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "466858f9cdfa4a0288feb4adfea692f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/945 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5e7fb33641f747988f76f659a439d345",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/215 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8424276676815372\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6bc1567690d47f3a51a555bab23d969",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/945 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "630618f46c454304b2bdaaaaac49688a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/215 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8431933784251883\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "83356f848e77467faae7f524537799e6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/945 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aa893cb86fe94264ae1cb3e5631abc4d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/215 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8448706495779476\n"
]
}
],
"source": [
"# Dev accuracy plateaued after 3 epochs in earlier runs (see the note above).\n",
"N_EPOCHS = 3\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"    lstm.train()\n",
"    # One sentence per optimisation step (batch size 1, so no padding is needed).\n",
"    for i in tqdm(range(len(train_Y))):\n",
"        batch_tokens = train_X[i].unsqueeze(0)  # (1, seq_len)\n",
"        tags = train_Y[i]  # (seq_len,)\n",
"        predicted_tags = lstm(batch_tokens)  # (1, seq_len, num_tags)\n",
"        optimizer.zero_grad()\n",
"        loss = criterion(predicted_tags.squeeze(0), tags)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"    lstm.eval()\n",
"    accuracy = eval_model(dev_X, dev_Y, lstm)\n",
"    print(f\"Epoch {epoch + 1}/{N_EPOCHS} - dev accuracy: {accuracy:.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "63cd38d8-89a6-47f4-9b3f-0fe92d475ab6",
"metadata": {},
"source": [
"# Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "4d1b743d-6e65-478c-95e1-dec60d650f90",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "56d958be54fc48a6b4263f0277c7d916",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/215 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict tags for the dev set and write one line of space-separated tags per sentence.\n",
"dev_predicted = predict(dev_X)\n",
"save_to_csv('./dev-0/out.tsv', dev_predicted)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "3c03d339-cb94-4a71-803f-cc26d3ea5f5d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb9b4d0f3d664609ad9332829549f51e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/230 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict tags for test-A and write one line of space-separated tags per sentence.\n",
"test_predicted = predict(test_X)\n",
"save_to_csv('./test-A/out.tsv', test_predicted)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}