{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e574fca4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\grzyb\\anaconda3\\lib\\site-packages\\gensim\\similarities\\__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
      "  warnings.warn(msg)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import csv\n",
    "import os.path\n",
    "import shutil\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from itertools import islice\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torchtext.vocab import Vocab\n",
    "from collections import Counter\n",
    "from nltk.tokenize import word_tokenize\n",
    "import gensim.downloader as api\n",
    "from gensim.models.word2vec import Word2Vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b476f295",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting gensim\n",
      "  Downloading gensim-4.0.1-cp38-cp38-win_amd64.whl (23.9 MB)\n",
      "Requirement already satisfied: scipy>=0.18.1 in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (from gensim) (1.6.2)\n",
      "Collecting Cython==0.29.21\n",
      "  Downloading Cython-0.29.21-cp38-cp38-win_amd64.whl (1.7 MB)\n",
      "Requirement already satisfied: numpy>=1.11.3 in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (from gensim) (1.20.1)\n",
      "Collecting smart-open>=1.8.1\n",
      "  Downloading smart_open-5.1.0-py3-none-any.whl (57 kB)\n",
      "Installing collected packages: smart-open, Cython, gensim\n",
      "  Attempting uninstall: Cython\n",
      "    Found existing installation: Cython 0.29.23\n",
      "    Uninstalling Cython-0.29.23:\n",
      "      Successfully uninstalled Cython-0.29.23\n",
      "Successfully installed Cython-0.29.21 gensim-4.0.1 smart-open-5.1.0\n"
     ]
    }
   ],
   "source": [
    "!pip install gensim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fbe3a657",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NERModel(torch.nn.Module):\n",
    "\n",
    "    def __init__(self,):\n",
    "        super(NERModel, self).__init__()\n",
    "        self.emb = torch.nn.Embedding(23628,200)\n",
    "        self.fc1 = torch.nn.Linear(600,9)\n",
    "        \n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.emb(x)\n",
    "        x = x.reshape(600) \n",
    "        x = self.fc1(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3497a580",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_output(lines):\n",
    "    result = []\n",
    "    for line in lines:\n",
    "        last_label = None\n",
    "        new_line = []\n",
    "        for label in line:\n",
    "            if(label != \"O\" and label[0:2] == \"I-\"):\n",
    "                if last_label == None or last_label == \"O\":\n",
    "                    label = label.replace('I-', 'B-')\n",
    "                else:\n",
    "                    label = \"I-\" + last_label[2:]\n",
    "            last_label = label\n",
    "            new_line.append(label)\n",
    "            x = (\" \".join(new_line))\n",
    "        result.append(\" \".join(new_line))\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3e78d902",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_vocab(dataset):\n",
    "    counter = Counter()\n",
    "    for document in dataset:\n",
    "        counter.update(document)\n",
    "    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ec8537cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_process(dt):\n",
    "    return [ torch.tensor([vocab['<bos>']] +[vocab[token]  for token in  document ] + [vocab['<eos>']], dtype = torch.long) for document in dt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "847c958a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def labels_process(dt):\n",
    "    return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "66bee163",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(input_tokens, labels):\n",
    "\n",
    "  results = []\n",
    "  \n",
    "  for i in range(len(input_tokens)):\n",
    "    line_results = []\n",
    "    for j in range(1, len(input_tokens[i]) - 1):\n",
    "        x = input_tokens[i][j-1: j+2].to(device_gpu)\n",
    "        predicted = ner_model(x.long())\n",
    "        result = torch.argmax(predicted)\n",
    "        label = labels[result]\n",
    "        line_results.append(label)\n",
    "    results.append(line_results)\n",
    "\n",
    "  return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "39046f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv('train/train.tsv.xz', sep='\\t', names=['a', 'b'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9b40a8b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = ['O','B-LOC', 'I-LOC','B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER'] \n",
    "train[\"a\"]=train[\"a\"].apply(lambda x: [labels.index(y) for y in  x.split()])\n",
    "train[\"b\"]=train[\"b\"].apply(lambda x: x.split())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "02a12cbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = build_vocab(train['b'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8cc6d19d",
   "metadata": {},
   "outputs": [],
   "source": [
    "  tensors = []\n",
    "\n",
    "  for sent in train[\"b\"]:\n",
    "    sent_tensor = torch.tensor(())\n",
    "    for word in sent:\n",
    "      temp = torch.tensor([word[0].isupper(), word[0].isdigit()])\n",
    "      sent_tensor = torch.cat((sent_tensor, temp))\n",
    "\n",
    "    tensors.append(sent_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "690085f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'NVIDIA GeForce RTX 2060'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cuda.get_device_name(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "64b2d751",
   "metadata": {},
   "outputs": [],
   "source": [
    "device_gpu = torch.device(\"cuda:0\")\n",
    "ner_model = NERModel().to(device_gpu)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(ner_model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "094d7e69",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_labels = labels_process(train['a'])\n",
    "train_tokens_ids = data_process(train['b'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "17291b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_tensors = [torch.cat((token, tensors[i])) for i, token in enumerate(train_tokens_ids)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "045b7186",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0\n",
      "f1: 0.6373470953763748\n",
      "acc: 0.9116419913061858\n",
      "epoch: 1\n",
      "f1: 0.7973076923076923\n",
      "acc: 0.9540771782783307\n",
      "epoch: 2\n",
      "f1: 0.8640167364016735\n",
      "acc: 0.9702287410511612\n",
      "epoch: 3\n",
      "f1: 0.9038441719055962\n",
      "acc: 0.9793820591289644\n",
      "epoch: 4\n",
      "f1: 0.928903400400047\n",
      "acc: 0.9850890978100043\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(5):\n",
    "    acc_score = 0\n",
    "    prec_score = 0\n",
    "    selected_items = 0\n",
    "    recall_score = 0\n",
    "    relevant_items = 0\n",
    "    items_total = 0\n",
    "    ner_model.train()\n",
    "    for i in range(len(train_labels)):\n",
    "        for j in range(1, len(train_labels[i]) - 1):\n",
    "            X = train_tensors[i][j - 1: j + 2].to(device_gpu)\n",
    "\n",
    "            Y = train_labels[i][j: j + 1].to(device_gpu)\n",
    "\n",
    "            Y_predictions = ner_model(X.long())\n",
    "\n",
    "            acc_score += int(torch.argmax(Y_predictions) == Y)\n",
    "            if torch.argmax(Y_predictions) != 0:\n",
    "                selected_items += 1\n",
    "            if torch.argmax(Y_predictions) != 0 and torch.argmax(Y_predictions) == Y.item():\n",
    "                prec_score += 1\n",
    "            if Y.item() != 0:\n",
    "                relevant_items += 1\n",
    "            if Y.item() != 0 and torch.argmax(Y_predictions) == Y.item():\n",
    "                recall_score += 1\n",
    "\n",
    "            items_total += 1\n",
    "            optimizer.zero_grad()\n",
    "            loss = criterion(Y_predictions.unsqueeze(0), Y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "    precision = prec_score / selected_items\n",
    "    recall = recall_score / relevant_items\n",
    "    f1_score = (2 * precision * recall) / (precision + recall)\n",
    "    print(f'epoch: {epoch}')\n",
    "    print(f'f1: {f1_score}')\n",
    "    print(f'acc: {acc_score / items_total}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "f75aa5e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_tensors_list(data):\n",
    "  tensors = []\n",
    "\n",
    "  for sent in data[\"a\"]:\n",
    "    sent_tensor = torch.tensor(())\n",
    "    for word in sent:\n",
    "      temp = torch.tensor([word[0].isupper(), word[0].isdigit()])\n",
    "      sent_tensor = torch.cat((sent_tensor, temp))\n",
    "\n",
    "    tensors.append(sent_tensor)\n",
    "\n",
    "  return tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "49215802",
   "metadata": {},
   "outputs": [],
   "source": [
    "dev = pd.read_csv('dev-0/in.tsv', sep='\\t', names=['a'])\n",
    "dev[\"a\"] = dev[\"a\"].apply(lambda x: x.split())\n",
    "\n",
    "dev_tokens_ids = data_process(dev[\"a\"])\n",
    "\n",
    "dev_extra_tensors = create_tensors_list(dev)\n",
    "\n",
    "dev_tensors = [torch.cat((token, dev_extra_tensors[i])) for i, token in enumerate(dev_tokens_ids)]\n",
    "\n",
    "results = predict(dev_tensors, labels)\n",
    "results_processed = process_output(results)\n",
    "\n",
    "with open(\"dev-0/out.tsv\", \"w\") as f:\n",
    "  for line in results_processed:\n",
    "    f.write(line + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "8c5b007e",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = pd.read_csv('test-A/in.tsv', sep='\\t', names=['a'])\n",
    "test[\"a\"] = test[\"a\"].apply(lambda x: x.split())\n",
    "\n",
    "test_tokens_ids = data_process(test[\"a\"])\n",
    "\n",
    "test_extra_tensors = create_tensors_list(test)\n",
    "\n",
    "test_tensors = [torch.cat((token, test_extra_tensors[i])) for i, token in enumerate(test_tokens_ids)]\n",
    "\n",
    "results = predict(test_tensors, labels)\n",
    "results_processed = process_output(results)\n",
    "\n",
    "with open(\"test-A/out.tsv\", \"w\") as f:\n",
    "  for line in results_processed:\n",
    "    f.write(line + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}