forked from kubapok/en-ner-conll-2003
377 lines
11 KiB
Plaintext
377 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "bce0cfa7",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"C:\\Users\\grzyb\\anaconda3\\lib\\site-packages\\gensim\\similarities\\__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
|
|
" warnings.warn(msg)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from os import sep\n",
|
|
"from nltk import word_tokenize\n",
|
|
"import pandas as pd\n",
|
|
"import torch\n",
|
|
"from TorchCRF import CRF\n",
|
|
"import gensim\n",
|
|
"from torch._C import device\n",
|
|
"from tqdm import tqdm\n",
|
|
"from torchtext.vocab import Vocab\n",
|
|
"from collections import Counter, OrderedDict\n",
|
|
"import spacy\n",
|
|
"\n",
|
|
"\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"import numpy as np\n",
|
|
"from sklearn.metrics import accuracy_score, f1_score, classification_report\n",
|
|
"import csv\n",
|
|
"import pickle\n",
|
|
"\n",
|
|
"import lzma\n",
|
|
"import re\n",
|
|
"import itertools"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "67ace382",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Requirement already satisfied: pytorch-crf in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (0.7.2)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!pip3 install pytorch-crf"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "adc9a4de",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "ModuleNotFoundError",
|
|
"evalue": "No module named 'torchcrf'",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[1;32m<ipython-input-3-2a643b4fc1bb>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mtorchcrf\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mCRF\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torchcrf'"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import gensim\n",
|
|
"import torch\n",
|
|
"import pandas as pd\n",
|
|
"import seaborn as sns\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"\n",
|
|
"from torchtext.vocab import Vocab\n",
|
|
"from collections import Counter\n",
|
|
"\n",
|
|
"from sklearn.datasets import fetch_20newsgroups\n",
|
|
"# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n",
|
|
"\n",
|
|
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
|
"from sklearn.metrics import accuracy_score\n",
|
|
"\n",
|
|
"from tqdm.notebook import tqdm\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"from torchcrf import CRF"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6695751c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def build_vocab(dataset):
    """Count every token across ``dataset`` and wrap the counts in a ``Vocab``.

    ``dataset`` is an iterable of token sequences (one per document). The
    counts are handed to ``Vocab`` together with the four special symbols
    ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>``, in that order.
    """
    specials = ['<unk>', '<pad>', '<bos>', '<eos>']
    token_counts = Counter()
    for tokens in dataset:
        token_counts.update(tokens)
    return Vocab(token_counts, specials=specials)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d247e4fe",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def data_process(dt, vocab):
    """Map each token sequence in ``dt`` to a 1-D ``torch.long`` tensor of ids.

    ``vocab[token]`` supplies the id of every token. The result is a list
    holding one tensor per input sequence, in input order.
    """
    encoded = []
    for document in dt:
        ids = [vocab[token] for token in document]
        encoded.append(torch.tensor(ids, dtype=torch.long))
    return encoded
|
|
"\n",
|
|
"\n",
|
|
def get_scores(y_true, y_pred, outside_label=0):
    """Token-level precision, recall and F1 over the non-"outside" labels.

    A position is "selected" when its prediction differs from
    ``outside_label`` and "relevant" when its gold label differs from it.
    A true positive is a selected position whose prediction equals the gold
    label. Pairs are formed with ``zip``, so the shorter sequence bounds how
    many positions are scored.

    Args:
        y_true: flat sequence of gold labels (ids or strings).
        y_pred: flat sequence of predicted labels, aligned with ``y_true``.
        outside_label: the "no entity" label. The default 0 matches the
            original ``> 0`` test for non-negative integer ids.

    Returns:
        ``(precision, recall, f1)``. Precision is 1.0 when nothing was
        selected. Recall is 1.0 when nothing was relevant. F1 is 0.0 when
        precision + recall is 0.
    """
    true_positives = 0
    selected_items = 0
    relevant_items = 0
    for predicted, gold in zip(y_pred, y_true):
        if predicted != outside_label:
            selected_items += 1
            if predicted == gold:
                true_positives += 1
        if gold != outside_label:
            relevant_items += 1

    precision = true_positives / selected_items if selected_items else 1.0
    recall = true_positives / relevant_items if relevant_items else 1.0
    if precision + recall == 0.0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b6061642",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def process_output(lines):
    """Repair BIO tag sequences and join each one into a space-separated string.

    Within every sequence:
    - an ``I-`` tag that starts the sequence or follows ``O`` becomes ``B-``
      with the same type;
    - an ``I-`` tag that follows any other tag takes the type of that
      preceding tag (e.g. ``B-ORG I-PER`` -> ``B-ORG I-ORG``);
    - every other tag passes through unchanged.

    Args:
        lines: iterable of tag sequences, each an iterable of tag strings.

    Returns:
        A list with one string per input sequence, tags joined by single spaces.
    """
    result = []
    for line in lines:
        previous = None
        repaired = []
        for label in line:
            if label.startswith("I-"):
                if previous is None or previous == "O":
                    # Only the prefix is rewritten; the entity type is kept verbatim.
                    label = "B-" + label[2:]
                else:
                    # Continue whatever entity is already open, regardless of
                    # the type predicted for this token.
                    label = "I-" + previous[2:]
            previous = label
            repaired.append(label)
        result.append(" ".join(repaired))
    return result
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3d7c4dd3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
class GRU(torch.nn.Module):
    """Embedding -> dropout -> bidirectional GRU -> linear per-token tag scores.

    The constructor defaults reproduce the original hard-coded sizes
    (100-d embeddings, 2-layer 256-unit bidirectional GRU, 0.2 dropout,
    9 output tags). With ``vocab_size=None`` the embedding table is sized
    from the notebook-global ``vocab_x`` (``len(vocab_x.itos)``), as before.
    """

    def __init__(self, vocab_size=None, embedding_dim=100, hidden_size=256,
                 num_layers=2, num_tags=9, dropout=0.2):
        super().__init__()
        if vocab_size is None:
            vocab_size = len(vocab_x.itos)
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        self.dropout = torch.nn.Dropout(dropout)
        self.rec = torch.nn.GRU(embedding_dim, hidden_size, num_layers,
                                batch_first=True, bidirectional=True)
        # Forward and backward states are concatenated, hence 2 * hidden_size.
        self.fc1 = torch.nn.Linear(2 * hidden_size, num_tags)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq_len) to scores of shape (batch, seq_len, num_tags)."""
        # NOTE(review): ReLU on embedding lookups zeroes every negative
        # coordinate; kept for behavioural compatibility — confirm it is intended.
        emb = torch.relu(self.emb(x))
        emb = self.dropout(emb)
        gru_output, _ = self.rec(emb)
        return self.fc1(gru_output)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "cd5e419d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab):
    """Decode every dev sentence with ``model`` + ``crf`` and collect the tags.

    Each sentence is processed as a batch of one. The model's emission scores
    are reshaped to ``(seq_len, 1, num_tags)``, the same layout the training
    loop passes to ``crf``, and decoded with ``crf.decode``.

    Args:
        model: module mapping ``(1, seq_len)`` token ids to
            ``(1, seq_len, num_tags)`` scores.
        crf: CRF layer whose ``decode(emissions)`` returns one tag list per
            batch element.
        dev_tokens: list of 1-D token-id tensors.
        dev_labels_tokens: list of 1-D gold tag-id tensors aligned with
            ``dev_tokens``.
        vocab: unused; kept for signature compatibility.

    Returns:
        ``(Y_true, Y_pred)``: flat, position-aligned lists of gold and
        predicted tag ids over all sentences, e.g. for
        ``get_scores(Y_true, Y_pred)``.
    """
    Y_true = []
    Y_pred = []
    model.eval()
    crf.eval()
    with torch.no_grad():
        for i in tqdm(range(len(dev_labels_tokens))):
            batch_tokens = dev_tokens[i].unsqueeze(0)
            # tolist() also works for CUDA tensors, unlike .numpy().
            Y_true += dev_labels_tokens[i].tolist()

            # (1, seq_len, num_tags) -> (seq_len, 1, num_tags): sequence-first,
            # batch of one. decode() needs the scores, not argmax ids.
            emissions = model(batch_tokens).squeeze(0).unsqueeze(1)
            # One tag list per batch element; extend so Y_pred stays flat like Y_true.
            Y_pred += crf.decode(emissions)[0]
    return Y_true, Y_pred
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c808bbd5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Load the train / dev / test splits.
# - QUOTE_NONE: a literal `"` at the start of a field is not treated as CSV
#   quoting, which would otherwise merge or split rows.
# - na_filter=False: a field that reads e.g. "NA" or "null" stays a string
#   instead of becoming NaN.
TSV_OPTIONS = dict(sep='\t', quoting=csv.QUOTE_NONE, na_filter=False)

train = pd.read_csv('train/train.tsv', names=['labels', 'document'], **TSV_OPTIONS)
Y_train = [y.split(sep=" ") for y in train['labels'].values]
X_train = [x.split(sep=" ") for x in train['document'].values]

dev = pd.read_csv('dev-0/in.tsv', names=['document'], **TSV_OPTIONS)
exp = pd.read_csv('dev-0/expected.tsv', names=['labels'], **TSV_OPTIONS)
X_dev = [x.split(sep=" ") for x in dev['document'].values]
Y_dev = [y.split(sep=" ") for y in exp['labels'].values]

test = pd.read_csv('test-A/in.tsv', names=['document'], **TSV_OPTIONS)
# Tokenized like X_dev so it can go straight through data_process.
X_test = [x.split(sep=" ") for x in test['document'].values]

# Tokens and tags must line up one-to-one for the CRF.
assert len(X_train) == len(Y_train)
assert len(X_dev) == len(Y_dev)
assert all(len(x) == len(y) for x, y in zip(X_train, Y_train)), "token/label count mismatch in train"
assert all(len(x) == len(y) for x, y in zip(X_dev, Y_dev)), "token/label count mismatch in dev"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "79485c9a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Token vocabulary: build_vocab adds <unk>/<pad>/<bos>/<eos> as specials
# (presumably so unseen dev/test words map to <unk> — confirm for the
# installed torchtext).
vocab_x = build_vocab(X_train)

# Label vocabulary: built WITHOUT specials. The GRU output layer and CRF(9)
# score exactly 9 tags, so label ids must stay within 0..8; the four specials
# would push the vocabulary to 13 entries. get_scores also treats id 0 as
# "no entity", which presumably relies on the frequency-sorted vocab putting
# "O" first — verify.
vocab_y = Vocab(Counter(itertools.chain.from_iterable(Y_train)), specials=[])
assert len(vocab_y.itos) == 9, f"expected 9 tags, got {len(vocab_y.itos)}: {vocab_y.itos}"

train_tokens = data_process(X_train, vocab_x)
labels_tokens = data_process(Y_train, vocab_y)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3726c82a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Show which accelerator is available. get_device_name(0) raises when no CUDA
# device is present, so fall back to a plain "cpu" label.
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
device_name
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f29e3b63",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Emission network plus a CRF over the same 9 tags the GRU's output layer produces.
# NOTE(review): the install cell adds pytorch-crf, which other cells import as
# `torchcrf`, while the first import cell binds CRF from `TorchCRF` — confirm
# which package's CRF this name resolves to.
model = GRU()
crf = CRF(9)

# NOTE(review): nothing visible moves the model, the CRF or any tensor to this device.
device_gpu = torch.device("cuda:0")
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9c321d58",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# NOTE(review): fixed length-2 uint8 mask; no visible cell passes it to the CRF.
mask = torch.tensor([1, 1], dtype=torch.uint8)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "05482a7c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# The training loop below takes its loss from the CRF (negated crf(...)),
# so this cross-entropy criterion is not called there.
criterion = torch.nn.CrossEntropyLoss()

# A single Adam optimizer (default learning rate) over the GRU and CRF parameters.
params = [*model.parameters(), *crf.parameters()]
optimizer = torch.optim.Adam(params)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "21a5282e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
EPOCHS = 2


def train_epoch(model, crf, optimizer, token_seqs, tag_seqs):
    """Run one pass of per-sentence updates and return the mean loss.

    Each sentence is a batch of one. Emissions are reshaped to
    ``(seq_len, 1, num_tags)`` and tags to ``(seq_len, 1)`` before being
    passed to ``crf``; the negated CRF output is the loss that is minimised.
    """
    crf.train()
    model.train()
    total_loss = 0.0
    for tokens, tags in tqdm(zip(token_seqs, tag_seqs), total=len(tag_seqs)):
        emissions = model(tokens.unsqueeze(0)).squeeze(0).unsqueeze(1)

        optimizer.zero_grad()
        # crf(...) is presumably the log-likelihood of the gold tags
        # (pytorch-crf convention) — negate it to get a loss to minimise.
        loss = -crf(emissions, tags.unsqueeze(1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / max(len(tag_seqs), 1)


for epoch in range(EPOCHS):
    mean_loss = train_epoch(model, crf, optimizer, train_tokens, labels_tokens)
    print(f"epoch {epoch + 1}/{EPOCHS}: mean loss {mean_loss:.4f}")
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "cec14c35",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip3 install pytorch-crf"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1ee634f7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torchcrf import CRF"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|