This commit is contained in:
Tomasz Grzybowski 2021-06-22 19:27:08 +02:00
parent c6aaaf6544
commit 142eed56c0
8 changed files with 790 additions and 506 deletions

View File

@ -1,6 +0,0 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,376 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bce0cfa7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\grzyb\\anaconda3\\lib\\site-packages\\gensim\\similarities\\__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"from os import sep\n",
"from nltk import word_tokenize\n",
"import pandas as pd\n",
"import torch\n",
"from TorchCRF import CRF\n",
"import gensim\n",
"from torch._C import device\n",
"from tqdm import tqdm\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter, OrderedDict\n",
"import spacy\n",
"\n",
"\n",
"from torch.utils.data import DataLoader\n",
"import numpy as np\n",
"from sklearn.metrics import accuracy_score, f1_score, classification_report\n",
"import csv\n",
"import pickle\n",
"\n",
"import lzma\n",
"import re\n",
"import itertools"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "67ace382",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pytorch-crf in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (0.7.2)\n"
]
}
],
"source": [
"# %pip installs into the environment of the running kernel; pinned to the version the recorded run reported.\n",
"%pip install pytorch-crf==0.7.2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "adc9a4de",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'torchcrf'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-3-2a643b4fc1bb>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mtorchcrf\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mCRF\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torchcrf'"
]
}
],
"source": [
"import numpy as np\n",
"import gensim\n",
"import torch\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n",
"\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import torch\n",
"from torchcrf import CRF"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6695751c",
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset, specials=('<unk>', '<pad>', '<bos>', '<eos>')):\n",
"    \"\"\"Build a torchtext ``Vocab`` from an iterable of tokenised documents.\n",
"\n",
"    Args:\n",
"        dataset: iterable of documents, each an iterable of tokens.\n",
"        specials: special symbols handed to ``Vocab``. The default keeps the\n",
"            original four symbols; pass ``()`` for a vocabulary that should\n",
"            hold only the observed symbols (e.g. a tag set, so that ids are\n",
"            not shifted by the specials).\n",
"\n",
"    Returns:\n",
"        A ``Vocab`` built from the token counts of ``dataset``.\n",
"    \"\"\"\n",
"    counter = Counter()\n",
"    for document in dataset:\n",
"        counter.update(document)\n",
"    return Vocab(counter, specials=list(specials))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d247e4fe",
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt, vocab):\n",
"    \"\"\"Encode every document of ``dt`` as a 1-D ``torch.long`` tensor of ``vocab`` ids.\"\"\"\n",
"    encoded = []\n",
"    for document in dt:\n",
"        ids = [vocab[token] for token in document]\n",
"        encoded.append(torch.tensor(ids, dtype=torch.long))\n",
"    return encoded\n",
"\n",
"\n",
"def get_scores(y_true, y_pred):\n",
"    \"\"\"Compute precision, recall and F1 over the labels greater than zero.\n",
"\n",
"    A prediction counts as a true positive when it is greater than zero and\n",
"    equal to the gold label at the same position. Pairs are formed with\n",
"    ``zip``, so surplus items of the longer sequence are ignored.\n",
"\n",
"    Args:\n",
"        y_true: gold label ids.\n",
"        y_pred: predicted label ids.\n",
"\n",
"    Returns:\n",
"        ``(precision, recall, f1)``. Precision is 1.0 when no label greater\n",
"        than zero was predicted, recall is 1.0 when no gold label is greater\n",
"        than zero, and F1 is 0.0 when precision and recall are both zero.\n",
"    \"\"\"\n",
"    tp = 0\n",
"    selected_items = 0\n",
"    relevant_items = 0\n",
"    for p, t in zip(y_pred, y_true):\n",
"        if p > 0:\n",
"            selected_items += 1\n",
"            if p == t:\n",
"                tp += 1\n",
"        if t > 0:\n",
"            relevant_items += 1\n",
"\n",
"    precision = tp / selected_items if selected_items else 1.0\n",
"    recall = tp / relevant_items if relevant_items else 1.0\n",
"\n",
"    if precision + recall == 0.0:\n",
"        f1 = 0.0\n",
"    else:\n",
"        f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"    return precision, recall, f1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6061642",
"metadata": {},
"outputs": [],
"source": [
"def process_output(lines):\n",
"    \"\"\"Repair IOB tag sequences and join each one into a space-separated string.\n",
"\n",
"    An ``I-`` tag that opens a sequence or follows ``O`` is rewritten to\n",
"    ``B-``; an ``I-`` tag that follows another entity tag takes over that\n",
"    tag's entity type.\n",
"\n",
"    Args:\n",
"        lines: iterable of tag sequences (each an iterable of strings).\n",
"\n",
"    Returns:\n",
"        A list with one space-joined tag string per input sequence.\n",
"    \"\"\"\n",
"    result = []\n",
"    for line in lines:\n",
"        last_label = None\n",
"        new_line = []\n",
"        for label in line:\n",
"            if label.startswith(\"I-\"):\n",
"                if last_label is None or last_label == \"O\":\n",
"                    # Swap only the prefix; str.replace would also rewrite an 'I-' inside the entity type.\n",
"                    label = \"B-\" + label[2:]\n",
"                else:\n",
"                    label = \"I-\" + last_label[2:]\n",
"            last_label = label\n",
"            new_line.append(label)\n",
"        result.append(\" \".join(new_line))\n",
"    return result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d7c4dd3",
"metadata": {},
"outputs": [],
"source": [
"class GRU(torch.nn.Module):\n",
"    \"\"\"Bidirectional GRU tagger that returns one score per tag for every token.\n",
"\n",
"    Args:\n",
"        vocab_size: number of rows of the embedding table. When omitted it is\n",
"            taken from the notebook-level ``vocab_x`` (the original behaviour).\n",
"        num_tags: size of the output layer; should match the CRF's tag count.\n",
"        emb_dim: embedding width.\n",
"        hidden_dim: GRU hidden size per direction.\n",
"        num_layers: number of stacked GRU layers.\n",
"        dropout: dropout probability applied to the embeddings.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size=None, num_tags=9, emb_dim=100, hidden_dim=256, num_layers=2, dropout=0.2):\n",
"        super(GRU, self).__init__()\n",
"        if vocab_size is None:\n",
"            # Backward-compatible default: size the embedding from the global vocabulary.\n",
"            vocab_size = len(vocab_x.itos)\n",
"        self.emb = torch.nn.Embedding(vocab_size, emb_dim)\n",
"        self.dropout = torch.nn.Dropout(dropout)\n",
"        self.rec = torch.nn.GRU(emb_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)\n",
"        # Forward and backward states are concatenated, hence 2 * hidden_dim inputs.\n",
"        self.fc1 = torch.nn.Linear(2 * hidden_dim, num_tags)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map token ids of shape (batch, seq_len) to scores of shape (batch, seq_len, num_tags).\"\"\"\n",
"        emb = torch.relu(self.emb(x))\n",
"        emb = self.dropout(emb)\n",
"        gru_output, _ = self.rec(emb)\n",
"        return self.fc1(gru_output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd5e419d",
"metadata": {},
"outputs": [],
"source": [
"def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab):\n",
"    \"\"\"Decode every dev document with the model and CRF, then score the result.\n",
"\n",
"    Args:\n",
"        model: network returning per-token tag scores for a (1, seq_len) batch.\n",
"        crf: CRF layer exposing ``decode(emissions)``.\n",
"            NOTE(review): presumably pytorch-crf's ``CRF`` with the default\n",
"            ``batch_first=False`` - confirm.\n",
"        dev_tokens: list of 1-D tensors of token ids, one per document.\n",
"        dev_labels_tokens: list of 1-D tensors of gold tag ids, one per document.\n",
"        vocab: unused; kept so existing calls keep working.\n",
"\n",
"    Returns:\n",
"        The ``(precision, recall, f1)`` tuple produced by ``get_scores``.\n",
"    \"\"\"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    model.eval()\n",
"    crf.eval()\n",
"    # Evaluation needs no gradients.\n",
"    with torch.no_grad():\n",
"        for i in tqdm(range(len(dev_labels_tokens))):\n",
"            batch_tokens = dev_tokens[i].unsqueeze(0)\n",
"            Y_true += dev_labels_tokens[i].tolist()\n",
"\n",
"            # (1, seq_len, tags) -> (seq_len, 1, tags), the layout the training loop feeds the CRF.\n",
"            # The CRF decodes the raw scores itself, so no argmax is taken beforehand.\n",
"            emissions = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
"            Y_pred += crf.decode(emissions)[0]\n",
"    return get_scores(Y_true, Y_pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c808bbd5",
"metadata": {},
"outputs": [],
"source": [
"train = pd.read_csv('train/train.tsv', sep='\\t',\n",
" names=['labels', 'document'])\n",
"\n",
"Y_train = [y.split(sep=\" \") for y in train['labels'].values]\n",
"X_train = [x.split(sep=\" \") for x in train['document'].values]\n",
"\n",
"dev = pd.read_csv('dev-0/in.tsv', sep='\\t', names=['document'])\n",
"exp = pd.read_csv('dev-0/expected.tsv', sep='\\t', names=['labels'])\n",
"X_dev = [x.split(sep=\" \") for x in dev['document'].values]\n",
"Y_dev = [y.split(sep=\" \") for y in exp['labels'].values]\n",
"\n",
"test = pd.read_csv('test-A/in.tsv', sep='\\t', names=['document'])\n",
"X_test = test['document'].values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79485c9a",
"metadata": {},
"outputs": [],
"source": [
"vocab_x = build_vocab(X_train)\n",
"vocab_y = build_vocab(Y_train)\n",
"train_tokens = data_process(X_train, vocab_x)\n",
"labels_tokens = data_process(Y_train, vocab_y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3726c82a",
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.get_device_name(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f29e3b63",
"metadata": {},
"outputs": [],
"source": [
"device_gpu = torch.device(\"cuda:0\")\n",
"model = GRU()\n",
"crf = CRF(9)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c321d58",
"metadata": {},
"outputs": [],
"source": [
"mask = torch.ByteTensor([1, 1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05482a7c",
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()\n",
"params = list(model.parameters()) + list(crf.parameters())\n",
"optimizer = torch.optim.Adam(params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21a5282e",
"metadata": {},
"outputs": [],
"source": [
"# Number of passes over the training set.\n",
"NUM_EPOCHS = 2\n",
"\n",
"for epoch in range(NUM_EPOCHS):\n",
"    crf.train()\n",
"    model.train()\n",
"    # One document per step; `i` indexes documents and no longer shadows the epoch counter.\n",
"    for i in tqdm(range(len(labels_tokens))):\n",
"        batch_tokens = train_tokens[i].unsqueeze(0)\n",
"        tags = labels_tokens[i].unsqueeze(1)\n",
"\n",
"        # (1, seq_len, tags) -> (seq_len, 1, tags) to line up with `tags` of shape (seq_len, 1).\n",
"        emissions = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
"\n",
"        optimizer.zero_grad()\n",
"        # The CRF call is negated to obtain a loss to minimise.\n",
"        loss = -crf(emissions, tags)\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cec14c35",
"metadata": {},
"outputs": [],
"source": [
"# %pip installs into the environment of the running kernel; pinned to the version the recorded run reported.\n",
"%pip install pytorch-crf==0.7.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ee634f7",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchcrf import CRF"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -32,37 +32,6 @@
"from gensim.models.word2vec import Word2Vec" "from gensim.models.word2vec import Word2Vec"
] ]
}, },
{
"cell_type": "code",
"execution_count": 6,
"id": "b476f295",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting gensim\n",
" Downloading gensim-4.0.1-cp38-cp38-win_amd64.whl (23.9 MB)\n",
"Requirement already satisfied: scipy>=0.18.1 in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (from gensim) (1.6.2)\n",
"Collecting Cython==0.29.21\n",
" Downloading Cython-0.29.21-cp38-cp38-win_amd64.whl (1.7 MB)\n",
"Requirement already satisfied: numpy>=1.11.3 in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (from gensim) (1.20.1)\n",
"Collecting smart-open>=1.8.1\n",
" Downloading smart_open-5.1.0-py3-none-any.whl (57 kB)\n",
"Installing collected packages: smart-open, Cython, gensim\n",
" Attempting uninstall: Cython\n",
" Found existing installation: Cython 0.29.23\n",
" Uninstalling Cython-0.29.23:\n",
" Successfully uninstalled Cython-0.29.23\n",
"Successfully installed Cython-0.29.21 gensim-4.0.1 smart-open-5.1.0\n"
]
}
],
"source": [
"!pip install gensim"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
@ -148,7 +117,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 7,
"id": "66bee163", "id": "66bee163",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -172,7 +141,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 8,
"id": "39046f3f", "id": "39046f3f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -182,7 +151,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 9,
"id": "9b40a8b6", "id": "9b40a8b6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -194,7 +163,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 10,
"id": "02a12cbd", "id": "02a12cbd",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -204,7 +173,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 11,
"id": "8cc6d19d", "id": "8cc6d19d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -222,7 +191,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 12,
"id": "690085f6", "id": "690085f6",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -232,7 +201,7 @@
"'NVIDIA GeForce RTX 2060'" "'NVIDIA GeForce RTX 2060'"
] ]
}, },
"execution_count": 15, "execution_count": 12,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -243,7 +212,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 13,
"id": "64b2d751", "id": "64b2d751",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -256,7 +225,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 14,
"id": "094d7e69", "id": "094d7e69",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -267,7 +236,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 15,
"id": "17291b41", "id": "17291b41",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -277,7 +246,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 16,
"id": "045b7186", "id": "045b7186",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -286,20 +255,20 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch: 0\n", "epoch: 0\n",
"f1: 0.6373470953763748\n", "f1: 0.6310260230881535\n",
"acc: 0.9116419913061858\n", "acc: 0.9099004714510215\n",
"epoch: 1\n", "epoch: 1\n",
"f1: 0.7973076923076923\n", "f1: 0.7977381727751791\n",
"acc: 0.9540771782783307\n", "acc: 0.9539025667888947\n",
"epoch: 2\n", "epoch: 2\n",
"f1: 0.8640167364016735\n", "f1: 0.8635445687583837\n",
"acc: 0.9702287410511612\n", "acc: 0.9699162783858546\n",
"epoch: 3\n", "epoch: 3\n",
"f1: 0.9038441719055962\n", "f1: 0.9047002002591589\n",
"acc: 0.9793820591289644\n", "acc: 0.9794417946385082\n",
"epoch: 4\n", "epoch: 4\n",
"f1: 0.928903400400047\n", "f1: 0.9300697243387956\n",
"acc: 0.9850890978100043\n" "acc: 0.9852774944170274\n"
] ]
} }
], ],
@ -346,7 +315,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 17,
"id": "f75aa5e2", "id": "f75aa5e2",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -367,7 +336,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 18,
"id": "49215802", "id": "49215802",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -391,7 +360,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 19,
"id": "8c5b007e", "id": "8c5b007e",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -412,6 +381,17 @@
" for line in results_processed:\n", " for line in results_processed:\n",
" f.write(line + \"\\n\")" " f.write(line + \"\\n\")"
] ]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "000dd425",
"metadata": {},
"outputs": [],
"source": [
"model_path = \"seq_labeling.model\"\n",
"torch.save(ner_model.state_dict(), model_path)"
]
} }
], ],
"metadata": { "metadata": {

376
Zad11.ipynb Normal file
View File

@ -0,0 +1,376 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bce0cfa7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\grzyb\\anaconda3\\lib\\site-packages\\gensim\\similarities\\__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"from os import sep\n",
"from nltk import word_tokenize\n",
"import pandas as pd\n",
"import torch\n",
"from TorchCRF import CRF\n",
"import gensim\n",
"from torch._C import device\n",
"from tqdm import tqdm\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter, OrderedDict\n",
"import spacy\n",
"\n",
"\n",
"from torch.utils.data import DataLoader\n",
"import numpy as np\n",
"from sklearn.metrics import accuracy_score, f1_score, classification_report\n",
"import csv\n",
"import pickle\n",
"\n",
"import lzma\n",
"import re\n",
"import itertools"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "67ace382",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pytorch-crf in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (0.7.2)\n"
]
}
],
"source": [
"# %pip installs into the environment of the running kernel; pinned to the version the recorded run reported.\n",
"%pip install pytorch-crf==0.7.2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "adc9a4de",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'torchcrf'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-3-2a643b4fc1bb>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mtorchcrf\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mCRF\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torchcrf'"
]
}
],
"source": [
"import numpy as np\n",
"import gensim\n",
"import torch\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n",
"\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import torch\n",
"from torchcrf import CRF"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6695751c",
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset, specials=('<unk>', '<pad>', '<bos>', '<eos>')):\n",
"    \"\"\"Build a torchtext ``Vocab`` from an iterable of tokenised documents.\n",
"\n",
"    Args:\n",
"        dataset: iterable of documents, each an iterable of tokens.\n",
"        specials: special symbols handed to ``Vocab``. The default keeps the\n",
"            original four symbols; pass ``()`` for a vocabulary that should\n",
"            hold only the observed symbols (e.g. a tag set, so that ids are\n",
"            not shifted by the specials).\n",
"\n",
"    Returns:\n",
"        A ``Vocab`` built from the token counts of ``dataset``.\n",
"    \"\"\"\n",
"    counter = Counter()\n",
"    for document in dataset:\n",
"        counter.update(document)\n",
"    return Vocab(counter, specials=list(specials))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d247e4fe",
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt, vocab):\n",
"    \"\"\"Encode every document of ``dt`` as a 1-D ``torch.long`` tensor of ``vocab`` ids.\"\"\"\n",
"    encoded = []\n",
"    for document in dt:\n",
"        ids = [vocab[token] for token in document]\n",
"        encoded.append(torch.tensor(ids, dtype=torch.long))\n",
"    return encoded\n",
"\n",
"\n",
"def get_scores(y_true, y_pred):\n",
"    \"\"\"Compute precision, recall and F1 over the labels greater than zero.\n",
"\n",
"    A prediction counts as a true positive when it is greater than zero and\n",
"    equal to the gold label at the same position. Pairs are formed with\n",
"    ``zip``, so surplus items of the longer sequence are ignored.\n",
"\n",
"    Args:\n",
"        y_true: gold label ids.\n",
"        y_pred: predicted label ids.\n",
"\n",
"    Returns:\n",
"        ``(precision, recall, f1)``. Precision is 1.0 when no label greater\n",
"        than zero was predicted, recall is 1.0 when no gold label is greater\n",
"        than zero, and F1 is 0.0 when precision and recall are both zero.\n",
"    \"\"\"\n",
"    tp = 0\n",
"    selected_items = 0\n",
"    relevant_items = 0\n",
"    for p, t in zip(y_pred, y_true):\n",
"        if p > 0:\n",
"            selected_items += 1\n",
"            if p == t:\n",
"                tp += 1\n",
"        if t > 0:\n",
"            relevant_items += 1\n",
"\n",
"    precision = tp / selected_items if selected_items else 1.0\n",
"    recall = tp / relevant_items if relevant_items else 1.0\n",
"\n",
"    if precision + recall == 0.0:\n",
"        f1 = 0.0\n",
"    else:\n",
"        f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"    return precision, recall, f1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6061642",
"metadata": {},
"outputs": [],
"source": [
"def process_output(lines):\n",
"    \"\"\"Repair IOB tag sequences and join each one into a space-separated string.\n",
"\n",
"    An ``I-`` tag that opens a sequence or follows ``O`` is rewritten to\n",
"    ``B-``; an ``I-`` tag that follows another entity tag takes over that\n",
"    tag's entity type.\n",
"\n",
"    Args:\n",
"        lines: iterable of tag sequences (each an iterable of strings).\n",
"\n",
"    Returns:\n",
"        A list with one space-joined tag string per input sequence.\n",
"    \"\"\"\n",
"    result = []\n",
"    for line in lines:\n",
"        last_label = None\n",
"        new_line = []\n",
"        for label in line:\n",
"            if label.startswith(\"I-\"):\n",
"                if last_label is None or last_label == \"O\":\n",
"                    # Swap only the prefix; str.replace would also rewrite an 'I-' inside the entity type.\n",
"                    label = \"B-\" + label[2:]\n",
"                else:\n",
"                    label = \"I-\" + last_label[2:]\n",
"            last_label = label\n",
"            new_line.append(label)\n",
"        result.append(\" \".join(new_line))\n",
"    return result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d7c4dd3",
"metadata": {},
"outputs": [],
"source": [
"class GRU(torch.nn.Module):\n",
"    \"\"\"Bidirectional GRU tagger that returns one score per tag for every token.\n",
"\n",
"    Args:\n",
"        vocab_size: number of rows of the embedding table. When omitted it is\n",
"            taken from the notebook-level ``vocab_x`` (the original behaviour).\n",
"        num_tags: size of the output layer; should match the CRF's tag count.\n",
"        emb_dim: embedding width.\n",
"        hidden_dim: GRU hidden size per direction.\n",
"        num_layers: number of stacked GRU layers.\n",
"        dropout: dropout probability applied to the embeddings.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size=None, num_tags=9, emb_dim=100, hidden_dim=256, num_layers=2, dropout=0.2):\n",
"        super(GRU, self).__init__()\n",
"        if vocab_size is None:\n",
"            # Backward-compatible default: size the embedding from the global vocabulary.\n",
"            vocab_size = len(vocab_x.itos)\n",
"        self.emb = torch.nn.Embedding(vocab_size, emb_dim)\n",
"        self.dropout = torch.nn.Dropout(dropout)\n",
"        self.rec = torch.nn.GRU(emb_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)\n",
"        # Forward and backward states are concatenated, hence 2 * hidden_dim inputs.\n",
"        self.fc1 = torch.nn.Linear(2 * hidden_dim, num_tags)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map token ids of shape (batch, seq_len) to scores of shape (batch, seq_len, num_tags).\"\"\"\n",
"        emb = torch.relu(self.emb(x))\n",
"        emb = self.dropout(emb)\n",
"        gru_output, _ = self.rec(emb)\n",
"        return self.fc1(gru_output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd5e419d",
"metadata": {},
"outputs": [],
"source": [
"def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab):\n",
"    \"\"\"Decode every dev document with the model and CRF, then score the result.\n",
"\n",
"    Args:\n",
"        model: network returning per-token tag scores for a (1, seq_len) batch.\n",
"        crf: CRF layer exposing ``decode(emissions)``.\n",
"            NOTE(review): presumably pytorch-crf's ``CRF`` with the default\n",
"            ``batch_first=False`` - confirm.\n",
"        dev_tokens: list of 1-D tensors of token ids, one per document.\n",
"        dev_labels_tokens: list of 1-D tensors of gold tag ids, one per document.\n",
"        vocab: unused; kept so existing calls keep working.\n",
"\n",
"    Returns:\n",
"        The ``(precision, recall, f1)`` tuple produced by ``get_scores``.\n",
"    \"\"\"\n",
"    Y_true = []\n",
"    Y_pred = []\n",
"    model.eval()\n",
"    crf.eval()\n",
"    # Evaluation needs no gradients.\n",
"    with torch.no_grad():\n",
"        for i in tqdm(range(len(dev_labels_tokens))):\n",
"            batch_tokens = dev_tokens[i].unsqueeze(0)\n",
"            Y_true += dev_labels_tokens[i].tolist()\n",
"\n",
"            # (1, seq_len, tags) -> (seq_len, 1, tags), the layout the training loop feeds the CRF.\n",
"            # The CRF decodes the raw scores itself, so no argmax is taken beforehand.\n",
"            emissions = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
"            Y_pred += crf.decode(emissions)[0]\n",
"    return get_scores(Y_true, Y_pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c808bbd5",
"metadata": {},
"outputs": [],
"source": [
"train = pd.read_csv('train/train.tsv', sep='\\t',\n",
" names=['labels', 'document'])\n",
"\n",
"Y_train = [y.split(sep=\" \") for y in train['labels'].values]\n",
"X_train = [x.split(sep=\" \") for x in train['document'].values]\n",
"\n",
"dev = pd.read_csv('dev-0/in.tsv', sep='\\t', names=['document'])\n",
"exp = pd.read_csv('dev-0/expected.tsv', sep='\\t', names=['labels'])\n",
"X_dev = [x.split(sep=\" \") for x in dev['document'].values]\n",
"Y_dev = [y.split(sep=\" \") for y in exp['labels'].values]\n",
"\n",
"test = pd.read_csv('test-A/in.tsv', sep='\\t', names=['document'])\n",
"X_test = test['document'].values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79485c9a",
"metadata": {},
"outputs": [],
"source": [
"vocab_x = build_vocab(X_train)\n",
"vocab_y = build_vocab(Y_train)\n",
"train_tokens = data_process(X_train, vocab_x)\n",
"labels_tokens = data_process(Y_train, vocab_y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3726c82a",
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.get_device_name(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f29e3b63",
"metadata": {},
"outputs": [],
"source": [
"device_gpu = torch.device(\"cuda:0\")\n",
"model = GRU()\n",
"crf = CRF(9)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c321d58",
"metadata": {},
"outputs": [],
"source": [
"mask = torch.ByteTensor([1, 1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05482a7c",
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()\n",
"params = list(model.parameters()) + list(crf.parameters())\n",
"optimizer = torch.optim.Adam(params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21a5282e",
"metadata": {},
"outputs": [],
"source": [
"# Number of passes over the training set.\n",
"NUM_EPOCHS = 2\n",
"\n",
"for epoch in range(NUM_EPOCHS):\n",
"    crf.train()\n",
"    model.train()\n",
"    # One document per step; `i` indexes documents and no longer shadows the epoch counter.\n",
"    for i in tqdm(range(len(labels_tokens))):\n",
"        batch_tokens = train_tokens[i].unsqueeze(0)\n",
"        tags = labels_tokens[i].unsqueeze(1)\n",
"\n",
"        # (1, seq_len, tags) -> (seq_len, 1, tags) to line up with `tags` of shape (seq_len, 1).\n",
"        emissions = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
"\n",
"        optimizer.zero_grad()\n",
"        # The CRF call is negated to obtain a loss to minimise.\n",
"        loss = -crf(emissions, tags)\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cec14c35",
"metadata": {},
"outputs": [],
"source": [
"# %pip installs into the environment of the running kernel; pinned to the version the recorded run reported.\n",
"%pip install pytorch-crf==0.7.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ee634f7",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torchcrf import CRF"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long

BIN
seq_labeling.model Normal file

Binary file not shown.

View File

@ -184,3 +184,6 @@ with open("test-A/out.tsv", "w") as f:
for line in results_processed: for line in results_processed:
f.write(line + "\n") f.write(line + "\n")
model_path = "seq_labeling.model"
torch.save(ner_model.state_dict(), model_path)

File diff suppressed because one or more lines are too long