{ "cells": [ { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import pandas as pd\n", "\n", "from collections import Counter\n", "from torchtext.vocab import vocab\n", "from sklearn.metrics import accuracy_score\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#Wczytanie zbioru danych\n", "\n", "train_set = pd.read_csv('./train/train.tsv', sep='\\t', header=None, names=['labels', 'text'])\n", "val_set = pd.read_csv('./dev-0/expected.tsv', sep='\\t', header=None, names=['labels'])\n", "val_set['text'] = pd.read_csv('./dev-0/in.tsv', sep='\\t', header=None, names=['text'])\n", "test_set = pd.read_csv('./test-A/in.tsv', sep='\\t', header=None, names=['text'])" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "#Tokenizacja danych\n", "train_set['text'] = train_set[\"text\"].apply(lambda x : x.split())\n", "train_set['labels'] = train_set[\"labels\"].apply(lambda x : x.split())\n", "\n", "val_set['text'] = val_set[\"text\"].apply(lambda x : x.split())\n", "val_set['labels'] = val_set[\"labels\"].apply(lambda x : x.split())\n", "\n", "test_set['text'] = test_set[\"text\"].apply(lambda x : x.split())" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
labelstext
0[B-ORG, O, B-MISC, O, O, O, B-MISC, O, O, O, B...[EU, rejects, German, call, to, boycott, Briti...
1[O, B-PER, O, O, O, O, O, O, O, O, O, B-LOC, O...[Rare, Hendrix, song, draft, sells, for, almos...
2[B-LOC, O, B-LOC, O, O, O, O, O, O, B-LOC, O, ...[China, says, Taiwan, spoils, atmosphere, for,...
3[B-LOC, O, O, O, O, B-LOC, O, O, O, B-LOC, O, ...[China, says, time, right, for, Taiwan, talks,...
4[B-MISC, O, O, O, O, O, O, O, O, O, O, O, B-LO...[German, July, car, registrations, up, 14.2, p...
5[B-MISC, O, O, O, O, O, O, O, O, O, O, B-LOC, ...[GREEK, SOCIALISTS, GIVE, GREEN, LIGHT, TO, PM...
6[B-ORG, O, B-MISC, O, O, O, O, O, O, B-LOC, O,...[BayerVB, sets, C$, 100, million, six-year, bo...
7[B-ORG, O, O, O, O, O, O, O, O, O, B-LOC, O, O...[Venantius, sets, $, 300, million, January, 19...
8[O, O, O, O, B-LOC, O, B-ORG, I-ORG, O, O, O, ...[Port, conditions, update, -, Syria, -, Lloyds...
9[B-LOC, O, O, O, O, O, O, B-LOC, O, O, B-PER, ...[Israel, plays, down, fears, of, war, with, Sy...
\n", "
" ], "text/plain": [ " labels \\\n", "0 [B-ORG, O, B-MISC, O, O, O, B-MISC, O, O, O, B... \n", "1 [O, B-PER, O, O, O, O, O, O, O, O, O, B-LOC, O... \n", "2 [B-LOC, O, B-LOC, O, O, O, O, O, O, B-LOC, O, ... \n", "3 [B-LOC, O, O, O, O, B-LOC, O, O, O, B-LOC, O, ... \n", "4 [B-MISC, O, O, O, O, O, O, O, O, O, O, O, B-LO... \n", "5 [B-MISC, O, O, O, O, O, O, O, O, O, O, B-LOC, ... \n", "6 [B-ORG, O, B-MISC, O, O, O, O, O, O, B-LOC, O,... \n", "7 [B-ORG, O, O, O, O, O, O, O, O, O, B-LOC, O, O... \n", "8 [O, O, O, O, B-LOC, O, B-ORG, I-ORG, O, O, O, ... \n", "9 [B-LOC, O, O, O, O, O, O, B-LOC, O, O, B-PER, ... \n", "\n", " text \n", "0 [EU, rejects, German, call, to, boycott, Briti... \n", "1 [Rare, Hendrix, song, draft, sells, for, almos... \n", "2 [China, says, Taiwan, spoils, atmosphere, for,... \n", "3 [China, says, time, right, for, Taiwan, talks,... \n", "4 [German, July, car, registrations, up, 14.2, p... \n", "5 [GREEK, SOCIALISTS, GIVE, GREEN, LIGHT, TO, PM... \n", "6 [BayerVB, sets, C$, 100, million, six-year, bo... \n", "7 [Venantius, sets, $, 300, million, January, 19... \n", "8 [Port, conditions, update, -, Syria, -, Lloyds... \n", "9 [Israel, plays, down, fears, of, war, with, Sy... " ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_set.head(10)\n" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "#Budowanie słownika\n", "def build_vocab(dataset):\n", " counter = Counter()\n", " for document in dataset:\n", " counter.update(document)\n", " return vocab(counter, specials=[\"\", \"\", \"\", \"\"])\n", " " ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['',\n", " '',\n", " '',\n", " '',\n", " 'EU',\n", " 'rejects',\n", " 'German',\n", " 'call',\n", " 'to',\n", " 'boycott']" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v = build_vocab(train_set['text'])\n", "v.set_default_index(v[\"\"])\n", "\n", "itos = v.get_itos()\n", "\n", "itos[:10]" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "def data_process(dt):\n", " # Wektoryzacja dokumentów tekstowych.\n", " return [\n", " torch.tensor(\n", " [v[\"\"]] + [v[token] for token in document] + [v[\"\"]],\n", " dtype=torch.long,\n", " )\n", " for document in dt\n", " ]\n", "\n", "def labels_process(dt):\n", " # Wektoryzacja etykiet (NER)\n", " return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "#Różne tagi NER\n", "num_tags = {\n", " \"O\" : 0,\n", " \"B-PER\" : 1,\n", " \"I-PER\" : 2,\n", " \"B-ORG\" : 3,\n", " \"I-ORG\" : 4,\n", " \"B-LOC\" : 5,\n", " \"I-LOC\" : 6,\n", " \"B-MISC\" : 7,\n", " \"I-MISC\" : 8,\n", "}" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def covert_to_int(dt, tags):\n", " labels = []\n", " for label in dt:\n", " labels.append([tags[i] for i in label])\n", " return labels" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "train_tokens_ids = data_process(train_set['text'])\n", "train_labels_ids = labels_process(covert_to_int(train_set['labels'], tags=num_tags))\n", "\n", "val_tokens_ids = data_process(val_set['text'])\n", "val_labels_ids = labels_process(covert_to_int(val_set['labels'], tags=num_tags))\n", "\n", "test_tokens_ids = data_process(train_set['text'])" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "class LSTM(torch.nn.Module):\n", "\n", " def __init__(self, num_tags):\n", " super(LSTM, self).__init__()\n", " self.emb = torch.nn.Embedding(len(v.get_itos()), 100)\n", " self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)\n", " self.fc1 = torch.nn.Linear(256, num_tags)\n", " self.hidden2tag = torch.nn.Linear(20, num_tags)\n", "\n", " def forward(self, x):\n", " emb = torch.relu(self.emb(x))\n", " lstm_output, (h_n, c_n) = self.rec(emb)\n", " out_weights = self.fc1(lstm_output)\n", " return out_weights" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "EPOCHS = 10\n", "LR = 0.001\n", "NUM_TAGS = len(num_tags)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "model = LSTM(num_tags=NUM_TAGS)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n", "criterion = torch.nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "def get_scores(y_true, y_pred):\n", " # Funkcja zwraca precyzję, pokrycie i F1\n", " acc_score = 0\n", " tp = 0\n", " fp = 0\n", " selected_items = 0\n", " relevant_items = 0\n", "\n", " for p, t in zip(y_pred, y_true):\n", " if p == t:\n", " acc_score += 1\n", "\n", " if p > 0 and p == t:\n", " tp += 1\n", "\n", " if p > 0:\n", " selected_items += 1\n", "\n", " if t > 0:\n", " relevant_items += 1\n", "\n", " if selected_items == 0:\n", " precision = 1.0\n", " else:\n", " precision = tp / selected_items\n", "\n", " if relevant_items == 0:\n", " recall = 1.0\n", " else:\n", " recall = tp / relevant_items\n", "\n", " if precision + recall == 0.0:\n", " f1 = 0.0\n", " else:\n", " f1 = 2 * precision * recall / (precision + recall)\n", "\n", " acc = accuracy_score(y_true, y_pred)\n", " return precision, recall, f1, acc" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "def eval_model(dataset_tokens, dataset_labels, model):\n", " Y_true = []\n", " Y_pred = []\n", " for i in tqdm(range(len(dataset_labels))):\n", " batch_tokens = dataset_tokens[i].unsqueeze(0)\n", " tags = dataset_labels[i].unsqueeze(1)\n", " Y_true += tags\n", "\n", " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", " Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n", " Y_pred += list(Y_batch_pred.numpy())\n", "\n", " precision, recall, f1, acc = get_scores(Y_true, Y_pred)\n", " print(f'precision: {precision}, recall: {recall}, f1: {f1}, val accuracy: {acc}')" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:21<00:00, 6.68it/s]\n", "100%|██████████| 215/215 [00:02<00:00, 93.42it/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8434014196726061, recall: 0.6783966441388953, f1: 0.7519535033903778, val accuracy: 0.9457621556580554\n", "Train accuracy: 0.9983919167623316\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:14<00:00, 7.04it/s]\n", "100%|██████████| 215/215 [00:02<00:00, 89.87it/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8440340076223981, recall: 0.6709391750174785, f1: 0.7475980264866269, val accuracy: 0.9454522251189587\n", "Train accuracy: 0.9989522403833889\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:08<00:00, 7.34it/s]\n", "100%|██████████| 215/215 [00:02<00:00, 90.39it/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.852653120888759, recall: 0.6796783966441389, f1: 0.7564027750761848, val accuracy: 0.9472206523126288\n", "Train accuracy: 0.9993303449406877\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:18<00:00, 6.85it/s]\n", "100%|██████████| 215/215 [00:03<00:00, 66.03it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8375809935205184, recall: 0.6778140293637847, f1: 0.7492754556578862, val accuracy: 0.9455980747844159\n", "Train accuracy: 0.9991891251662749\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:07<00:00, 7.39it/s]\n", "100%|██████████| 215/215 [00:03<00:00, 62.19it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8413109098749461, recall: 0.6820088557445817, f1: 0.7533303301370746, val accuracy: 0.9462908606953383\n", "Train accuracy: 0.9991435704003353\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:13<00:00, 7.08it/s]\n", "100%|██████████| 215/215 [00:03<00:00, 54.46it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8479315263908702, recall: 0.6926124446515963, f1: 0.7624422780913289, val accuracy: 0.9478769758071868\n", "Train accuracy: 0.998583246779278\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:11<00:00, 7.20it/s]\n", "100%|██████████| 215/215 [00:03<00:00, 57.36it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8470877294406706, recall: 0.6829410393847588, f1: 0.7562092768208503, val accuracy: 0.9471294962717179\n", "Train accuracy: 0.999180014213087\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:15<00:00, 6.99it/s]\n", "100%|██████████| 215/215 [00:04<00:00, 45.79it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8728230645397337, recall: 0.6949429037520392, f1: 0.7737917612714889, val accuracy: 0.9498824087072251\n", "Train accuracy: 0.9993212339874997\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:38<00:00, 5.98it/s]\n", "100%|██████████| 215/215 [00:07<00:00, 28.22it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8691318792431028, recall: 0.7011186203682125, f1: 0.7761367300870687, val accuracy: 0.9505934258263296\n", "Train accuracy: 0.9996310063958891\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 945/945 [02:03<00:00, 7.62it/s]\n", "100%|██████████| 215/215 [00:02<00:00, 77.13it/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "precision: 0.8701146047605054, recall: 0.6900489396411092, f1: 0.7696906680530282, val accuracy: 0.949116697963574\n", "Train accuracy: 0.9997540042639261\n" ] } ], "source": [ "NUM_EPOCHS = 10\n", "for i in range(NUM_EPOCHS):\n", " model.train()\n", " train_true = []\n", " train_pred = []\n", " for i in tqdm(range(len(train_set['labels']))):\n", " batch_tokens = train_tokens_ids[i].unsqueeze(0)\n", " tags = train_labels_ids[i].unsqueeze(1)\n", " train_true += tags\n", "\n", " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", " Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n", " train_pred += list(Y_batch_pred.numpy())\n", "\n", " predicted_tags = model(batch_tokens)\n", "\n", " optimizer.zero_grad()\n", " loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1))\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " model.eval()\n", " eval_model(val_tokens_ids, val_labels_ids, model)\n", " print(f'Train accuracy: {accuracy_score(train_true, train_pred)}')" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "def save_prediction(test_tokens, test_pred, file_name):\n", " with open(file_name, 'w') as f:\n", " for i in range(len(test_tokens)):\n", " for j in range(len(test_tokens[i])):\n", " print(i, j)\n", " print(test_pred[i][j])\n", " f.write(f'{test_tokens[i][j]}\\t{list(num_tags.keys())[test_pred[i][j]]}\\n')\n", " f.write('\\n')" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n" ] } ], "source": [ "test_pred = []\n", "\n", "with torch.no_grad():\n", " for i in range(len(test_tokens_ids)):\n", " batch_tokens = test_tokens_ids[i].unsqueeze(0)\n", "\n", " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", " Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n", " test_pred += list(Y_batch_pred.numpy())\n" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "with open('test-A/out.tsv', 'w') as f:\n", " for i in range(len(test_pred)):\n", " tag = list(num_tags.keys())[test_pred[i]]\n", " f.write(tag)\n", " f.write('\\n')\n" ] } ], "metadata": { "kernelspec": { "display_name": "dl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 2 }