forked from kubapok/en-ner-conll-2003
delete script copy
This commit is contained in:
parent
499702ff9c
commit
9e5716b67e
@ -1,546 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Zadanie domowe\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"- sklonować repozytorium https://git.wmi.amu.edu.pl/kubapok/en-ner-conll-2003\n",
|
|
||||||
"- stworzyć model seq labelling bazujący na sieci neuronowej opisanej w punkcie niżej (można bazować na tym jupyterze lub nie).\n",
|
|
||||||
"- model sieci to GRU (o dowolnych parametrach) + CRF w pytorchu korzystając z modułu CRF z poprzednich zajęć- - stworzyć predykcje w plikach dev-0/out.tsv oraz test-A/out.tsv\n",
|
|
||||||
"- wynik fscore sprawdzony za pomocą narzędzia geval (patrz poprzednie zadanie) powinien wynosić conajmniej 0.65\n",
|
|
||||||
"- proszę umieścić predykcję oraz skrypty generujące (w postaci tekstowej a nie jupyter) w repo, a w MS TEAMS umieścić link do swojego repo\n",
|
|
||||||
"termin 22.06, 60 punktów, za najlepszy wynik- 100 punktów\n",
|
|
||||||
" "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import numpy as np\n",
|
|
||||||
"import torch\n",
|
|
||||||
"from torchtext.vocab import Vocab\n",
|
|
||||||
"from collections import Counter\n",
|
|
||||||
"from tqdm.notebook import tqdm\n",
|
|
||||||
"import lzma\n",
|
|
||||||
"import itertools\n",
|
|
||||||
"from torchcrf import CRF"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def read_data(filename):\n",
|
|
||||||
" all_data = lzma.open(filename).read().decode('UTF-8').split('\\n')\n",
|
|
||||||
" return [line.split('\\t') for line in all_data][:-1]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def data_process(dt):\n",
|
|
||||||
" return [torch.tensor([vocab['<bos>']] + [vocab[token] for token in document] + [vocab['<eos>']], dtype = torch.long) for document in dt]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def labels_process(dt):\n",
|
|
||||||
" return [ torch.tensor([0] + document + [0], dtype = torch.long) for document in dt]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def build_vocab(dataset):\n",
|
|
||||||
" counter = Counter()\n",
|
|
||||||
" for document in dataset:\n",
|
|
||||||
" counter.update(document)\n",
|
|
||||||
" return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"train_data = read_data('train/train.tsv.xz')\n",
|
|
||||||
"\n",
|
|
||||||
"tokens, ner_tags = [], []\n",
|
|
||||||
"for i in train_data:\n",
|
|
||||||
" ner_tags.append(i[0].split())\n",
|
|
||||||
" tokens.append(i[1].split())"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"vocab = build_vocab(tokens)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"train_tokens_ids = data_process(tokens)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"ner_tags_set = list(set(itertools.chain(*ner_tags)))\n",
|
|
||||||
"ner_tags_set.sort()\n",
|
|
||||||
"print(ner_tags_set)\n",
|
|
||||||
"train_labels = labels_process([[ner_tags_set.index(token) for token in doc] for doc in ner_tags])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"num_tags = max([max(x) for x in train_labels]) + 1 "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"class GRU(torch.nn.Module):\n",
|
|
||||||
"\n",
|
|
||||||
" def __init__(self):\n",
|
|
||||||
" super(GRU, self).__init__()\n",
|
|
||||||
" self.emb = torch.nn.Embedding(len(vocab.itos),100)\n",
|
|
||||||
" self.dropout = torch.nn.Dropout(0.2)\n",
|
|
||||||
" self.rec = torch.nn.GRU(100, 256, 2, batch_first = True, bidirectional = True)\n",
|
|
||||||
" self.fc1 = torch.nn.Linear(2* 256 , 9)\n",
|
|
||||||
" \n",
|
|
||||||
" def forward(self, x):\n",
|
|
||||||
" emb = torch.relu(self.emb(x))\n",
|
|
||||||
" emb = self.dropout(emb)\n",
|
|
||||||
" gru_output, h_n = self.rec(emb)\n",
|
|
||||||
" out_weights = self.fc1(gru_output)\n",
|
|
||||||
" return out_weights"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 13,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def get_scores(y_true, y_pred):\n",
|
|
||||||
" acc_score = 0\n",
|
|
||||||
" tp = 0\n",
|
|
||||||
" fp = 0\n",
|
|
||||||
" selected_items = 0\n",
|
|
||||||
" relevant_items = 0 \n",
|
|
||||||
"\n",
|
|
||||||
" for p,t in zip(y_pred, y_true):\n",
|
|
||||||
" if p == t:\n",
|
|
||||||
" acc_score +=1\n",
|
|
||||||
"\n",
|
|
||||||
" if p > 0 and p == t:\n",
|
|
||||||
" tp +=1\n",
|
|
||||||
"\n",
|
|
||||||
" if p > 0:\n",
|
|
||||||
" selected_items += 1\n",
|
|
||||||
"\n",
|
|
||||||
" if t > 0 :\n",
|
|
||||||
" relevant_items +=1\n",
|
|
||||||
" \n",
|
|
||||||
" if selected_items == 0:\n",
|
|
||||||
" precision = 1.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" precision = tp / selected_items\n",
|
|
||||||
" \n",
|
|
||||||
" if relevant_items == 0:\n",
|
|
||||||
" recall = 1.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" recall = tp / relevant_items\n",
|
|
||||||
" \n",
|
|
||||||
" if precision + recall == 0.0 :\n",
|
|
||||||
" f1 = 0.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" f1 = 2* precision * recall / (precision + recall)\n",
|
|
||||||
"\n",
|
|
||||||
" return precision, recall, f1"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 68,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def eval_model(dataset_tokens, dataset_labels, model):\n",
|
|
||||||
" Y_true = []\n",
|
|
||||||
" Y_pred = []\n",
|
|
||||||
" gru.eval()\n",
|
|
||||||
" crf.eval()\n",
|
|
||||||
" for i in tqdm(range(len(dataset_labels))):\n",
|
|
||||||
" batch_tokens = dataset_tokens[i]\n",
|
|
||||||
" tags = list(dataset_labels[i].numpy())\n",
|
|
||||||
" emissions = ff(batch_tokens).unsqueeze(1)\n",
|
|
||||||
" Y_pred += crf.decode(emissions)[0]\n",
|
|
||||||
" Y_true += tags\n",
|
|
||||||
" return get_scores(Y_true, Y_pred)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 69,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"gru = GRU()\n",
|
|
||||||
"crf = CRF(num_tags)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 70,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"params = list(gru.parameters()) + list(crf.parameters())\n",
|
|
||||||
"criterion = torch.nn.CrossEntropyLoss()\n",
|
|
||||||
"optimizer = torch.optim.Adam(params)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 71,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"NUM_EPOCHS = 2"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 77,
|
|
||||||
"metadata": {
|
|
||||||
"scrolled": true
|
|
||||||
},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "946d7c19bfdd4671a7c4f5fbff7cc735",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"HBox(children=(FloatProgress(value=0.0, max=945.0), HTML(value='')))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"ename": "ValueError",
|
|
||||||
"evalue": "emissions must have dimension of 3, got 4",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"\u001b[0;32m<ipython-input-77-3305460cbe70>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mpredicted_tags\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgru\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_tokens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcrf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredicted_tags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
||||||
"\u001b[0;32m~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
||||||
"\u001b[0;32m~/.local/lib/python3.8/site-packages/torchcrf/__init__.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, emissions, tags, mask, reduction)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;32mis\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mnone\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0motherwise\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \"\"\"\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0memissions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtags\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtags\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'none'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'sum'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'mean'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'token_mean'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'invalid reduction: {reduction}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
||||||
"\u001b[0;32m~/.local/lib/python3.8/site-packages/torchcrf/__init__.py\u001b[0m in \u001b[0;36m_validate\u001b[0;34m(self, emissions, tags, mask)\u001b[0m\n\u001b[1;32m 145\u001b[0m mask: Optional[torch.ByteTensor] = None) -> None:\n\u001b[1;32m 146\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0memissions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'emissions must have dimension of 3, got {emissions.dim()}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0memissions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_tags\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m raise ValueError(\n",
|
|
||||||
"\u001b[0;31mValueError\u001b[0m: emissions must have dimension of 3, got 4"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"for i in range(NUM_EPOCHS):\n",
|
|
||||||
" gru.train()\n",
|
|
||||||
" crf.train()\n",
|
|
||||||
" for i in tqdm(range(len(train_labels))):\n",
|
|
||||||
" batch_tokens = train_tokens_ids[i].unsqueeze(0)\n",
|
|
||||||
" tags = train_labels[i].unsqueeze(1)\n",
|
|
||||||
" predicted_tags = gru(batch_tokens)\n",
|
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
" loss = -crf(predicted_tags.unsqueeze(1),tags.squeeze(1))\n",
|
|
||||||
" loss.backward()\n",
|
|
||||||
" optimizer.step()\n",
|
|
||||||
" gru.eval()\n",
|
|
||||||
" crf.eval()\n",
|
|
||||||
" print(eval_model(train_tokens_ids, train_labels, gru))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## dev-0"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 55,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def predict_labels(dataset_tokens, model):\n",
|
|
||||||
" Y_true = []\n",
|
|
||||||
" Y_pred = []\n",
|
|
||||||
" result = []\n",
|
|
||||||
" for i in tqdm(range(len(dataset_tokens))):\n",
|
|
||||||
" batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
|
|
||||||
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
|
|
||||||
" Y_batch_pred = torch.argmax(Y_batch_pred_weights,1)\n",
|
|
||||||
" Y_pred += list(Y_batch_pred.numpy())\n",
|
|
||||||
" result += [list(Y_batch_pred.numpy())]\n",
|
|
||||||
" return result"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 56,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"with open('dev-0/in.tsv', \"r\", encoding=\"utf-8\") as f:\n",
|
|
||||||
" dev_0_data = [line.rstrip() for line in f]\n",
|
|
||||||
" \n",
|
|
||||||
"dev_0_data = [i.split() for i in dev_0_data]\n",
|
|
||||||
"dev_0_tokens_ids = data_process(dev_0_data)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 57,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"with open('dev-0/expected.tsv', \"r\", encoding=\"utf-8\") as f:\n",
|
|
||||||
" dev_0_labels = [line.rstrip() for line in f]\n",
|
|
||||||
" \n",
|
|
||||||
"dev_0_labels = [i.split() for i in dev_0_labels]\n",
|
|
||||||
"dev_0_labels = labels_process([[ner_tags_set.index(token) for token in doc] for doc in dev_0_labels])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 58,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "2e335038c15f4e68bbbd184e5f8dded2",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"HBox(children=(FloatProgress(value=0.0, max=215.0), HTML(value='')))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"tmp = predict_labels(dev_0_tokens_ids, gru)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 65,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"r = [[ner_tags_set[i] for i in tmp2] for tmp2 in tmp]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 66,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# for doc in r:\n",
|
|
||||||
"# if doc[0] != 'O':\n",
|
|
||||||
"# doc[0] = 'B' + doc[0][1:]\n",
|
|
||||||
"# for i in range(len(doc))[:-1]:\n",
|
|
||||||
"# if doc[i] == 'O':\n",
|
|
||||||
"# if doc[i + 1] != 'O':\n",
|
|
||||||
"# doc[i + 1] = 'B' + doc[i + 1][1:]\n",
|
|
||||||
"# elif doc[i + 1] != 'O':\n",
|
|
||||||
"# if doc[i][1:] == doc[i + 1][1:]:\n",
|
|
||||||
"# doc[i + 1] = 'I' + doc[i + 1][1:]\n",
|
|
||||||
"# else:\n",
|
|
||||||
"# doc[i + 1] = 'B' + doc[i + 1][1:]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 67,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"f = open(\"dev-0/out.tsv\", \"a\")\n",
|
|
||||||
"for i in r:\n",
|
|
||||||
" f.write(' '.join(i) + '\\n')\n",
|
|
||||||
"f.close()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## test-A"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 62,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"with open('test-A/in.tsv', \"r\", encoding=\"utf-8\") as f:\n",
|
|
||||||
" test_A_data = [line.rstrip() for line in f]\n",
|
|
||||||
" \n",
|
|
||||||
"test_A_data = [i.split() for i in test_A_data]\n",
|
|
||||||
"test_A_tokens_ids = data_process(test_A_data)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 63,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "83cd31864a29458f81be3d79cf43d1ca",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"HBox(children=(FloatProgress(value=0.0, max=215.0), HTML(value='')))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"tmp = predict_labels(dev_0_tokens_ids, gru)\n",
|
|
||||||
"r = [[ner_tags_set[i] for i in tmp2] for tmp2 in tmp]\n",
|
|
||||||
"for doc in r:\n",
|
|
||||||
" if doc[0] != 'O':\n",
|
|
||||||
" doc[0] = 'B' + doc[0][1:]\n",
|
|
||||||
" for i in range(len(doc))[:-1]:\n",
|
|
||||||
" if doc[i] == 'O':\n",
|
|
||||||
" if doc[i + 1] != 'O':\n",
|
|
||||||
" doc[i + 1] = 'B' + doc[i + 1][1:]\n",
|
|
||||||
" elif doc[i + 1] != 'O':\n",
|
|
||||||
" if doc[i][1:] == doc[i + 1][1:]:\n",
|
|
||||||
" doc[i + 1] = 'I' + doc[i + 1][1:]\n",
|
|
||||||
" else:\n",
|
|
||||||
" doc[i + 1] = 'B' + doc[i + 1][1:]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 64,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"f = open(\"test-A/out.tsv\", \"a\")\n",
|
|
||||||
"for i in r:\n",
|
|
||||||
" f.write(' '.join(i) + '\\n')\n",
|
|
||||||
"f.close()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.8.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 4
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user