{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-md",
   "metadata": {},
   "source": [
    "# BiGRU + CRF sequence tagger\n",
    "\n",
    "Trains a bidirectional GRU with a CRF output layer on `train/train.tsv`, scores it on `dev-0`, and decodes `test-A/in.tsv`. Labels follow the `B-`/`I-`/`O` scheme handled by `process_output`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce0cfa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from os import sep\n",
    "from nltk import word_tokenize\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torchcrf import CRF\n",
    "import gensim\n",
    "from torch._C import device\n",
    "from tqdm import tqdm\n",
    "from torchtext.vocab import Vocab\n",
    "from collections import Counter, OrderedDict\n",
    "\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score, f1_score, classification_report\n",
    "import csv\n",
    "import pickle\n",
    "\n",
    "import lzma\n",
    "import re\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6695751c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vocab(dataset, specials=('<unk>', '<pad>', '<bos>', '<eos>')):\n",
    "    \"\"\"Build a torchtext vocabulary from an iterable of token lists.\n",
    "\n",
    "    Args:\n",
    "        dataset: iterable of documents, each an iterable of string tokens.\n",
    "        specials: special tokens passed on to ``Vocab``. Keeping ``<unk>``\n",
    "            lets unseen tokens be looked up; pass an empty sequence for\n",
    "            label vocabularies so that their ids start at 0.\n",
    "\n",
    "    Returns:\n",
    "        A ``torchtext.vocab.Vocab`` built from the token counts.\n",
    "    \"\"\"\n",
    "    counter = Counter()\n",
    "    for document in dataset:\n",
    "        counter.update(document)\n",
    "    return Vocab(counter, specials=list(specials))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d247e4fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_process(dt, vocab):\n",
    "    \"\"\"Convert each tokenised document into a 1-D LongTensor of vocabulary ids.\"\"\"\n",
    "    return [torch.tensor([vocab[token] for token in document], dtype=torch.long) for document in dt]\n",
    "\n",
    "\n",
    "def get_scores(y_true, y_pred):\n",
    "    \"\"\"Return ``(precision, recall, f1)`` for two flat id sequences, paired by position.\n",
    "\n",
    "    Id 0 is treated as the negative class: a prediction is *selected* when\n",
    "    it is > 0 and a true positive when it also equals the target.\n",
    "    Precision / recall fall back to 1.0 when nothing is selected / relevant.\n",
    "    \"\"\"\n",
    "    tp = 0\n",
    "    selected_items = 0\n",
    "    relevant_items = 0\n",
    "    for p, t in zip(y_pred, y_true):\n",
    "        if p > 0:\n",
    "            selected_items += 1\n",
    "            if p == t:\n",
    "                tp += 1\n",
    "        if t > 0:\n",
    "            relevant_items += 1\n",
    "\n",
    "    precision = tp / selected_items if selected_items else 1.0\n",
    "    recall = tp / relevant_items if relevant_items else 1.0\n",
    "\n",
    "    if precision + recall == 0.0:\n",
    "        f1 = 0.0\n",
    "    else:\n",
    "        f1 = 2 * precision * recall / (precision + recall)\n",
    "\n",
    "    return precision, recall, f1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6061642",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_output(lines):\n",
    "    \"\"\"Repair BIO label sequences and join each one into a space-separated string.\n",
    "\n",
    "    An ``I-`` label that opens a span (first token, or right after ``O``)\n",
    "    is rewritten to ``B-``; an ``I-`` label inside a span takes over the\n",
    "    entity type of the preceding label.\n",
    "    \"\"\"\n",
    "    result = []\n",
    "    for line in lines:\n",
    "        last_label = None\n",
    "        new_line = []\n",
    "        for label in line:\n",
    "            if label.startswith('I-'):\n",
    "                if last_label is None or last_label == 'O':\n",
    "                    label = label.replace('I-', 'B-')\n",
    "                else:\n",
    "                    label = 'I-' + last_label[2:]\n",
    "            last_label = label\n",
    "            new_line.append(label)\n",
    "        result.append(' '.join(new_line))\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d7c4dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GRU(torch.nn.Module):\n",
    "    \"\"\"Embedding -> ReLU -> dropout -> 2-layer bidirectional GRU -> linear tag scores.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size=None, num_tags=9, emb_dim=100, hidden_dim=256):\n",
    "        \"\"\"Create the layers.\n",
    "\n",
    "        Args:\n",
    "            vocab_size: number of rows in the embedding table. When omitted,\n",
    "                the size of the notebook-level ``vocab_x`` is used.\n",
    "            num_tags: size of the output layer; must match the CRF.\n",
    "            emb_dim: embedding width.\n",
    "            hidden_dim: GRU hidden size per direction.\n",
    "        \"\"\"\n",
    "        super(GRU, self).__init__()\n",
    "        if vocab_size is None:\n",
    "            vocab_size = len(vocab_x.itos)\n",
    "        self.emb = torch.nn.Embedding(vocab_size, emb_dim)\n",
    "        self.dropout = torch.nn.Dropout(0.2)\n",
    "        self.rec = torch.nn.GRU(emb_dim, hidden_dim, 2, batch_first=True, bidirectional=True)\n",
    "        # Both directions are concatenated, hence 2 * hidden_dim inputs.\n",
    "        self.fc1 = torch.nn.Linear(2 * hidden_dim, num_tags)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map token ids ``(batch, seq_len)`` to tag scores ``(batch, seq_len, num_tags)``.\"\"\"\n",
    "        emb = torch.relu(self.emb(x))\n",
    "        emb = self.dropout(emb)\n",
    "        gru_output, h_n = self.rec(emb)\n",
    "        out_weights = self.fc1(gru_output)\n",
    "        return out_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd5e419d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab):\n",
    "    \"\"\"Decode every dev document with the CRF and return ``(precision, recall, f1)``.\n",
    "\n",
    "    ``vocab`` is kept for interface compatibility and is not used.\n",
    "    \"\"\"\n",
    "    Y_true = []\n",
    "    Y_pred = []\n",
    "    model.eval()\n",
    "    crf.eval()\n",
    "    with torch.no_grad():\n",
    "        for i in tqdm(range(len(dev_labels_tokens))):\n",
    "            batch_tokens = dev_tokens[i].unsqueeze(0)\n",
    "            Y_true += dev_labels_tokens[i].tolist()\n",
    "\n",
    "            # crf.decode needs the raw scores shaped (seq_len, batch, num_tags);\n",
    "            # it picks the best tag sequence itself, so no argmax here.\n",
    "            emissions = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
    "            Y_pred += crf.decode(emissions)[0]\n",
    "    return get_scores(Y_true, Y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c808bbd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv('train/train.tsv', sep='\\t',\n",
    "                    names=['labels', 'document'])\n",
    "\n",
    "Y_train = [y.split(sep=\" \") for y in train['labels'].values]\n",
    "X_train = [x.split(sep=\" \") for x in train['document'].values]\n",
    "\n",
    "dev = pd.read_csv('dev-0/in.tsv', sep='\\t', names=['document'])\n",
    "exp = pd.read_csv('dev-0/expected.tsv', sep='\\t', names=['labels'])\n",
    "X_dev = [x.split(sep=\" \") for x in dev['document'].values]\n",
    "Y_dev = [y.split(sep=\" \") for y in exp['labels'].values]\n",
    "\n",
    "test = pd.read_csv('test-A/in.tsv', sep='\\t', names=['document'])\n",
    "X_test = [x.split(sep=\" \") for x in test['document'].values]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79485c9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_x = build_vocab(X_train)\n",
    "# No specials for labels: ids then run 0..NUM_TAGS-1, and the most frequent\n",
    "# label (expected to be 'O') gets id 0, which get_scores treats as negative.\n",
    "vocab_y = build_vocab(Y_train, specials=[])\n",
    "NUM_TAGS = len(vocab_y.itos)\n",
    "\n",
    "train_tokens = data_process(X_train, vocab_x)\n",
    "labels_tokens = data_process(Y_train, vocab_y)\n",
    "dev_tokens = data_process(X_dev, vocab_x)\n",
    "dev_labels_tokens = data_process(Y_dev, vocab_y)\n",
    "test_tokens = data_process(X_test, vocab_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f29e3b63",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GRU(vocab_size=len(vocab_x.itos), num_tags=NUM_TAGS)\n",
    "crf = CRF(NUM_TAGS)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05482a7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "params = list(model.parameters()) + list(crf.parameters())\n",
    "optimizer = torch.optim.Adam(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21a5282e",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 2\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    crf.train()\n",
    "    model.train()\n",
    "    for i in tqdm(range(len(labels_tokens)), desc=f'epoch {epoch + 1}'):\n",
    "        batch_tokens = train_tokens[i].unsqueeze(0)\n",
    "        tags = labels_tokens[i].unsqueeze(1)\n",
    "\n",
    "        # (1, seq_len, num_tags) -> (seq_len, 1, num_tags), the layout the CRF expects.\n",
    "        predicted_tags = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # The CRF returns a log-likelihood; negate it to get a loss to minimise.\n",
    "        loss = -crf(predicted_tags, tags)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dev-eval-call",
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "366ab1fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_pred = []\n",
    "model.eval()\n",
    "crf.eval()\n",
    "with torch.no_grad():\n",
    "    for i in tqdm(range(len(test_tokens))):\n",
    "        batch_tokens = test_tokens[i].unsqueeze(0)\n",
    "\n",
    "        Y_batch_pred = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
    "        Y_pred += [crf.decode(Y_batch_pred)[0]]\n",
    "\n",
    "# Map tag ids back to label strings, then repair the BIO sequences.\n",
    "Y_pred_translate = process_output([[vocab_y.itos[tag] for tag in line] for line in Y_pred])\n",
    "Y_pred_translate[:5]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}