s464953_uczenie_glebokie_RNN/RNN.ipynb

1079 lines
28 KiB
Plaintext
Raw Permalink Normal View History

2024-05-26 20:27:00 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"import torch\n",
"import pandas as pd\n",
"from torchtext.vocab import vocab\n",
"from sklearn.model_selection import train_test_split\n",
"from tqdm.notebook import tqdm\n",
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA nie jest dostępna. Model będzie uruchomiony na CPU.\n"
]
}
],
"source": [
"if torch.cuda.is_available():\n",
" print(\"CUDA jest dostępna!\")\n",
" print(f\"Nazwa urządzenia: {torch.cuda.get_device_name(0)}\")\n",
" device = torch.device(\"cuda\")\n",
"else:\n",
" print(\"CUDA nie jest dostępna. Model będzie uruchomiony na CPU.\")\n",
" device = torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_data = pd.read_csv(\"train.tsv\", sep='\\t', header=None, names=['labels', 'documents'])\n",
"train_data[\"tokenized_documents\"] = train_data[\"documents\"].apply(lambda x: x.split())\n",
"train_data[\"tokenized_labels\"] = train_data[\"labels\"].apply(lambda x: x.split())\n",
"\n",
"X_train, X_val, y_train, y_val = train_test_split(\n",
" train_data[\"tokenized_documents\"], train_data[\"tokenized_labels\"], test_size=0.2, random_state=42\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
" counter = Counter()\n",
" for document in dataset:\n",
" counter.update(document)\n",
" return vocab(counter, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])\n",
"\n",
"train_vocab = build_vocab(X_train)\n",
"itos = train_vocab.get_itos()\n",
"train_vocab.set_default_index(train_vocab[\"<unk>\"])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt):\n",
" return [\n",
" torch.tensor(\n",
" [train_vocab[\"<bos>\"]] + [train_vocab[token] for token in document] + [train_vocab[\"<eos>\"]],\n",
" dtype=torch.long,\n",
" )\n",
" for document in dt\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"train_tokens_ids = data_process(X_train)\n",
"val_tokens_ids = data_process(X_val)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"labels = [\"O\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-LOC\", \"I-LOC\", \"B-MISC\", \"I-MISC\"]\n",
"\n",
"label_to_index = {label: idx for idx, label in enumerate(labels)}\n",
"\n",
"def labels_process(dt, label_to_index):\n",
" return [\n",
" torch.tensor(\n",
" [0] + [label_to_index[label] for label in document] + [0],\n",
" dtype=torch.long,\n",
" device=device,\n",
" )\n",
" for document in dt\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"train_labels = labels_process(y_train, label_to_index)\n",
"val_labels = labels_process(y_val, label_to_index)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"all_label_indices = [\n",
" label_to_index[label]\n",
" for document in y_train\n",
" for label in document\n",
"]\n",
"\n",
"num_tags = max(all_label_indices) + 1"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class LSTM(torch.nn.Module):\n",
"\n",
" def __init__(self):\n",
" super(LSTM, self).__init__()\n",
" self.emb = torch.nn.Embedding(len(train_vocab.get_itos()), 100)\n",
" self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)\n",
" self.fc1 = torch.nn.Linear(256, num_tags)\n",
"\n",
" def forward(self, x):\n",
" emb = torch.relu(self.emb(x))\n",
" lstm_output, (h_n, c_n) = self.rec(emb)\n",
" out_weights = self.fc1(lstm_output)\n",
" return out_weights"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"lstm = LSTM()\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(lstm.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def get_scores(y_true, y_pred):\n",
" acc_score = 0\n",
" tp = 0\n",
" fp = 0\n",
" selected_items = 0\n",
" relevant_items = 0\n",
"\n",
" for p, t in zip(y_pred, y_true):\n",
" if p == t:\n",
" acc_score += 1\n",
"\n",
" if p > 0 and p == t:\n",
" tp += 1\n",
"\n",
" if p > 0:\n",
" selected_items += 1\n",
"\n",
" if t > 0:\n",
" relevant_items += 1\n",
"\n",
" if selected_items == 0:\n",
" precision = 1.0\n",
" else:\n",
" precision = tp / selected_items\n",
"\n",
" if relevant_items == 0:\n",
" recall = 1.0\n",
" else:\n",
" recall = tp / relevant_items\n",
"\n",
" if precision + recall == 0.0:\n",
" f1 = 0.0\n",
" else:\n",
" f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
" return precision, recall, f1\n",
"\n",
"def eval_model(dataset_tokens, dataset_labels, model):\n",
" Y_true = []\n",
" Y_pred = []\n",
" for i in tqdm(range(len(dataset_labels))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
" tags = list(dataset_labels[i].numpy())\n",
" Y_true += tags\n",
"\n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
" Y_pred += list(Y_batch_pred.numpy())\n",
"\n",
" return get_scores(Y_true, Y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e740a20b87d346f6ab948dcec6f6b1e7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3dd1c217de0746878d8472a3c3eea9a7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5042095416276894, 0.07628078120577413, 0.1325138291333743)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a76ffb596f4b48f99bfcbddcc415ab23",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b4413f5591148dea6a1f4f3807a3e87",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.6831020812685827, 0.3901783187093122, 0.4966672671590705)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "00d34c09954741a5b78a9916a44b2421",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "91b3bb7aece64e6ea936eaa8d1637d07",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7253631723596388, 0.5229266911972827, 0.6077302631578947)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e600073771134839a82c1508cd7189f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fab86a7d276b4da09953c8d48bea1841",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7647388059701492, 0.5801018964053213, 0.6597456945115081)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0fac8f61e79548d9a954cc349dc3733f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "532f0eb1fd264a6394b8e2658f5647ec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7872690689592098, 0.6091140673648457, 0.6868267773079072)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dc66b035585544c9b23af527cff6ae1c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "40ba6295fe754d86991e776607f5b15b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8001462790272444, 0.6193037078969714, 0.6982050259274033)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "96709b4563da4a20a75a473438fd73da",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3793118578964b9baa1de4e9d0223962",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8068851251840943, 0.6202943673931502, 0.7013922227556408)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "436a61533a384a32b76b07ade1be82e3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "661b28ede62740d79eb612de217bd495",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8154405086285196, 0.635295782621002, 0.7141834380717526)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8885ddd13fae40c9b5cb617904147916",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bc38c641533a4e28acc9de58c4222f69",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8557729190640583, 0.6314746674214549, 0.7267100977198697)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5b13fcef3dc42da9ed0a974ccf62765",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "016e85f9a5cd4a2c91e472082aa28004",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.837495475931958, 0.6549674497594112, 0.7350698856416772)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3e4fd25b8a2b4b769aea944cb1ca9833",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "329ebf139d3845a58c4370381f5e38d7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8592408926187297, 0.6375601471836966, 0.7319847266227963)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d5a5f4d0c1da4124a1b1bee53c5b28c7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "191cea33b9d74a8ea1a0ad9055aa4907",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8497704315886134, 0.6548259269742428, 0.7396690911997442)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "95b07cf2c14c447b9214f7dd29480196",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1dffd7fa50484950ad9ba54d2f33876c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8601201652271874, 0.6483158788564959, 0.7393479664299548)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d11bf24242d041c5b2d91555adfd7802",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7d390ff27b9b462b818a3b7016065a44",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8630500758725341, 0.6439286725162752, 0.737558761549684)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ce869cd0928b49b4b5a50d00f0e3f27f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "99713f107d0141b28e3006c34ed3ede2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8547055586130985, 0.6593546560996321, 0.7444275784932493)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9d1c0bdf59634489856f576546815294",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "10b57a15e1ed465283a3a0f42272e73e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8580931263858093, 0.6572318143221059, 0.7443500560987337)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "47bf91db0aaf47dda0dc463fd9f681c8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d2de98af28c44819a35ff6a4aca290e9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.857722827089869, 0.6577979054627795, 0.744573488185823)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1ae8d518a0314821962af0915dea3949",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc5c608c3f4546b2b16ee85d2e494e58",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8601476014760148, 0.6597792244551373, 0.7467563671311869)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "117eed9ce0064c6ea9e96c73b460b2ad",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "611e53b7eb5949b1a3ba714a8746a56a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8474636395885066, 0.6761958675346731, 0.7522040302267002)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ee428ed7f48448fe934a1fca1a462348",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/756 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f86e193fb5f74cb99c119a6a198f1045",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/189 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.855834829443447, 0.6746391168978205, 0.7545109211775879)\n"
]
}
],
"source": [
"NUM_EPOCHS = 20\n",
"for i in range(NUM_EPOCHS):\n",
" lstm.train()\n",
" for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0)\n",
" tags = train_labels[i].unsqueeze(1)\n",
"\n",
" predicted_tags = lstm(batch_tokens)\n",
"\n",
" optimizer.zero_grad()\n",
" loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1))\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" lstm.eval()\n",
" print(eval_model(val_tokens_ids, val_labels, lstm))"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7966898b2ca64115a34fad154c0518dc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/215 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2574e1adda9946e8a2fb0201726beba7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/230 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def pred_labels(dataset_tokens, model, label_to_index):\n",
" Y_pred = []\n",
" inv_label_to_index = {\n",
" v: k for k, v in label_to_index.items()\n",
" }\n",
"\n",
" for i in tqdm(range(len(dataset_tokens))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
" predicted_labels = [inv_label_to_index[label.item()] for label in Y_batch_pred]\n",
" predicted_labels = predicted_labels[1:-1]\n",
" Y_pred.append(\" \".join(predicted_labels))\n",
"\n",
" return Y_pred\n",
"\n",
"dev_data = pd.read_csv(\"dev-0/in.tsv\", sep=\"\\t\", names=[\"Text\"])\n",
"dev_labels = pd.read_csv(\"dev-0/expected.tsv\", sep=\"\\t\", names=[\"Label\"])\n",
"test_A_data = pd.read_csv(\"test-A/in.tsv\", sep=\"\\t\", names=[\"Text\"])\n",
"dev_data[\"tokenized_text\"] = dev_data[\"Text\"].apply(lambda x: x.split())\n",
"dev_labels[\"tokenized_labels\"] = dev_labels[\"Label\"].apply(lambda x: x.split())\n",
"test_A_data[\"tokenized_text\"] = test_A_data[\"Text\"].apply(lambda x: x.split())\n",
"dev_0_tokens_ids = data_process(dev_data[\"tokenized_text\"])\n",
"test_A_tokens_ids = data_process(test_A_data[\"tokenized_text\"])\n",
"dev_0_labels = labels_process(dev_labels[\"tokenized_labels\"], label_to_index)\n",
"dev_0_predictons = pred_labels(dev_0_tokens_ids, lstm, label_to_index)\n",
"dev_0_predictons = pd.DataFrame(dev_0_predictons, columns=[\"Label\"])\n",
"dev_0_predictons.to_csv(\"dev-0/out.tsv\", index=False, header=False)\n",
"test_A_predictions = pred_labels(test_A_tokens_ids, lstm, label_to_index)\n",
"test_A_predictions = pd.DataFrame(test_A_predictions, columns=[\"Label\"])\n",
"test_A_predictions.to_csv(\"test-A/out.tsv\", index=False, header=False)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9441390252001624\n"
]
}
],
"source": [
"with open('dev-0/out.tsv', 'r') as file:\n",
" predicted_labels = [line.strip().split()[1:] for line in file]\n",
"\n",
"with open('dev-0/expected.tsv', 'r') as file:\n",
" true_labels = [line.strip().split()[1:] for line in file]\n",
"\n",
"predicted_labels = [label for sublist in predicted_labels for label in sublist]\n",
"true_labels = [label for sublist in true_labels for label in sublist]\n",
"\n",
"accuracy = accuracy_score(true_labels, predicted_labels)\n",
"print(\"Accuracy:\", accuracy)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}