fix
This commit is contained in:
parent
142eed56c0
commit
17184e30e8
215
dev-0/out.tsv
Normal file
215
dev-0/out.tsv
Normal file
File diff suppressed because one or more lines are too long
@ -5,28 +5,18 @@
|
|||||||
"execution_count": 1,
|
"execution_count": 1,
|
||||||
"id": "bce0cfa7",
|
"id": "bce0cfa7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"C:\\Users\\grzyb\\anaconda3\\lib\\site-packages\\gensim\\similarities\\__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
|
|
||||||
" warnings.warn(msg)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from os import sep\n",
|
"from os import sep\n",
|
||||||
"from nltk import word_tokenize\n",
|
"from nltk import word_tokenize\n",
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"from TorchCRF import CRF\n",
|
"from torchcrf import CRF\n",
|
||||||
"import gensim\n",
|
"import gensim\n",
|
||||||
"from torch._C import device\n",
|
"from torch._C import device\n",
|
||||||
"from tqdm import tqdm\n",
|
"from tqdm import tqdm\n",
|
||||||
"from torchtext.vocab import Vocab\n",
|
"from torchtext.vocab import Vocab\n",
|
||||||
"from collections import Counter, OrderedDict\n",
|
"from collections import Counter, OrderedDict\n",
|
||||||
"import spacy\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from torch.utils.data import DataLoader\n",
|
"from torch.utils.data import DataLoader\n",
|
||||||
@ -43,65 +33,6 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"id": "67ace382",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Requirement already satisfied: pytorch-crf in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (0.7.2)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"!pip3 install pytorch-crf"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "adc9a4de",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"ename": "ModuleNotFoundError",
|
|
||||||
"evalue": "No module named 'torchcrf'",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"\u001b[1;32m<ipython-input-3-2a643b4fc1bb>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mtorchcrf\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mCRF\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
||||||
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torchcrf'"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import numpy as np\n",
|
|
||||||
"import gensim\n",
|
|
||||||
"import torch\n",
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"import seaborn as sns\n",
|
|
||||||
"from sklearn.model_selection import train_test_split\n",
|
|
||||||
"\n",
|
|
||||||
"from torchtext.vocab import Vocab\n",
|
|
||||||
"from collections import Counter\n",
|
|
||||||
"\n",
|
|
||||||
"from sklearn.datasets import fetch_20newsgroups\n",
|
|
||||||
"# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n",
|
|
||||||
"\n",
|
|
||||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
|
||||||
"from sklearn.metrics import accuracy_score\n",
|
|
||||||
"\n",
|
|
||||||
"from tqdm.notebook import tqdm\n",
|
|
||||||
"\n",
|
|
||||||
"import torch\n",
|
|
||||||
"from torchcrf import CRF"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "6695751c",
|
"id": "6695751c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -115,7 +46,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 3,
|
||||||
"id": "d247e4fe",
|
"id": "d247e4fe",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -160,7 +91,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 4,
|
||||||
"id": "b6061642",
|
"id": "b6061642",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -185,7 +116,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 5,
|
||||||
"id": "3d7c4dd3",
|
"id": "3d7c4dd3",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -208,7 +139,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 6,
|
||||||
"id": "cd5e419d",
|
"id": "cd5e419d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -230,7 +161,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 7,
|
||||||
"id": "c808bbd5",
|
"id": "c808bbd5",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -252,7 +183,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 8,
|
||||||
"id": "79485c9a",
|
"id": "79485c9a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -265,39 +196,18 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 12,
|
||||||
"id": "3726c82a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"torch.cuda.get_device_name(0)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "f29e3b63",
|
"id": "f29e3b63",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"device_gpu = torch.device(\"cuda:0\")\n",
|
|
||||||
"model = GRU()\n",
|
"model = GRU()\n",
|
||||||
"crf = CRF(9)\n"
|
"crf = CRF(9)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 13,
|
||||||
"id": "9c321d58",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"mask = torch.ByteTensor([1, 1])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "05482a7c",
|
"id": "05482a7c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -309,10 +219,32 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 14,
|
||||||
"id": "21a5282e",
|
"id": "21a5282e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" 0%| | 0/945 [00:00<?, ?it/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "ValueError",
|
||||||
|
"evalue": "expected last dimension of emissions is 10, got 9",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"\u001b[0;32m<ipython-input-14-6dc1a1c63d46>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcrf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredicted_tags\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtags\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchcrf/__init__.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, emissions, tags, mask, reduction)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;32mis\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mnone\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0motherwise\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \"\"\"\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0memissions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtags\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtags\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'none'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'sum'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'mean'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'token_mean'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'invalid reduction: {reduction}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||||
|
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchcrf/__init__.py\u001b[0m in \u001b[0;36m_validate\u001b[0;34m(self, emissions, tags, mask)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'emissions must have dimension of 3, got {emissions.dim()}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0memissions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_tags\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34mf'expected last dimension of emissions is {self.num_tags}, '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m f'got {emissions.size(2)}')\n",
|
||||||
|
"\u001b[0;31mValueError\u001b[0m: expected last dimension of emissions is 10, got 9"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"for i in range(2):\n",
|
"for i in range(2):\n",
|
||||||
" crf.train()\n",
|
" crf.train()\n",
|
||||||
@ -333,22 +265,21 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "cec14c35",
|
"id": "366ab1fe",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip3 install pytorch-crf"
|
"Y_pred = []\n",
|
||||||
]
|
"model.eval()\n",
|
||||||
},
|
"crf.eval()\n",
|
||||||
{
|
"for i in tqdm(range(len(test_tokens))):\n",
|
||||||
"cell_type": "code",
|
" batch_tokens = test_tokens[i].unsqueeze(0)\n",
|
||||||
"execution_count": null,
|
"\n",
|
||||||
"id": "1ee634f7",
|
" Y_batch_pred = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
|
||||||
"metadata": {},
|
" Y_pred += [crf.decode(Y_batch_pred)[0]]\n",
|
||||||
"outputs": [],
|
"\n",
|
||||||
"source": [
|
"Y_pred_translate = translate(Y_pred, vocab)\n",
|
||||||
"import torch\n",
|
"return Y_pred_translate"
|
||||||
"from torchcrf import CRF"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
230
test-A/out.tsv
Normal file
230
test-A/out.tsv
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user