This commit is contained in:
s434695 2021-06-24 19:07:58 +02:00
parent 142eed56c0
commit 17184e30e8
4 changed files with 491 additions and 115 deletions

215
dev-0/out.tsv Normal file

File diff suppressed because one or more lines are too long

View File

@ -5,28 +5,18 @@
"execution_count": 1, "execution_count": 1,
"id": "bce0cfa7", "id": "bce0cfa7",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\grzyb\\anaconda3\\lib\\site-packages\\gensim\\similarities\\__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [ "source": [
"from os import sep\n", "from os import sep\n",
"from nltk import word_tokenize\n", "from nltk import word_tokenize\n",
"import pandas as pd\n", "import pandas as pd\n",
"import torch\n", "import torch\n",
"from TorchCRF import CRF\n", "from torchcrf import CRF\n",
"import gensim\n", "import gensim\n",
"from torch._C import device\n", "from torch._C import device\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"from torchtext.vocab import Vocab\n", "from torchtext.vocab import Vocab\n",
"from collections import Counter, OrderedDict\n", "from collections import Counter, OrderedDict\n",
"import spacy\n",
"\n", "\n",
"\n", "\n",
"from torch.utils.data import DataLoader\n", "from torch.utils.data import DataLoader\n",
@ -43,65 +33,6 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"id": "67ace382",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pytorch-crf in c:\\users\\grzyb\\anaconda3\\lib\\site-packages (0.7.2)\n"
]
}
],
"source": [
"!pip3 install pytorch-crf"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "adc9a4de",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'torchcrf'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-3-2a643b4fc1bb>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mtorchcrf\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mCRF\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torchcrf'"
]
}
],
"source": [
"import numpy as np\n",
"import gensim\n",
"import torch\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from torchtext.vocab import Vocab\n",
"from collections import Counter\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html\n",
"\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import torch\n",
"from torchcrf import CRF"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6695751c", "id": "6695751c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -115,7 +46,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"id": "d247e4fe", "id": "d247e4fe",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -160,7 +91,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"id": "b6061642", "id": "b6061642",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -185,7 +116,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"id": "3d7c4dd3", "id": "3d7c4dd3",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -208,7 +139,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"id": "cd5e419d", "id": "cd5e419d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -230,7 +161,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 7,
"id": "c808bbd5", "id": "c808bbd5",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -252,7 +183,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 8,
"id": "79485c9a", "id": "79485c9a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -265,39 +196,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 12,
"id": "3726c82a",
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.get_device_name(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f29e3b63", "id": "f29e3b63",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"device_gpu = torch.device(\"cuda:0\")\n",
"model = GRU()\n", "model = GRU()\n",
"crf = CRF(9)\n" "crf = CRF(9)\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 13,
"id": "9c321d58",
"metadata": {},
"outputs": [],
"source": [
"mask = torch.ByteTensor([1, 1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05482a7c", "id": "05482a7c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -309,10 +219,32 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 14,
"id": "21a5282e", "id": "21a5282e",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/945 [00:00<?, ?it/s]\n"
]
},
{
"ename": "ValueError",
"evalue": "expected last dimension of emissions is 10, got 9",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-14-6dc1a1c63d46>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcrf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredicted_tags\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtags\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchcrf/__init__.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, emissions, tags, mask, reduction)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;32mis\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mnone\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0motherwise\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \"\"\"\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0memissions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtags\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtags\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m'none'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'sum'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'mean'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'token_mean'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'invalid reduction: {reduction}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchcrf/__init__.py\u001b[0m in \u001b[0;36m_validate\u001b[0;34m(self, emissions, tags, mask)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'emissions must have dimension of 3, got {emissions.dim()}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0memissions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_tags\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34mf'expected last dimension of emissions is {self.num_tags}, '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m f'got {emissions.size(2)}')\n",
"\u001b[0;31mValueError\u001b[0m: expected last dimension of emissions is 10, got 9"
]
}
],
"source": [ "source": [
"for i in range(2):\n", "for i in range(2):\n",
" crf.train()\n", " crf.train()\n",
@ -333,22 +265,21 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "cec14c35", "id": "366ab1fe",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip3 install pytorch-crf" "Y_pred = []\n",
] "model.eval()\n",
}, "crf.eval()\n",
{ "for i in tqdm(range(len(test_tokens))):\n",
"cell_type": "code", " batch_tokens = test_tokens[i].unsqueeze(0)\n",
"execution_count": null, "\n",
"id": "1ee634f7", " Y_batch_pred = model(batch_tokens).squeeze(0).unsqueeze(1)\n",
"metadata": {}, " Y_pred += [crf.decode(Y_batch_pred)[0]]\n",
"outputs": [], "\n",
"source": [ "Y_pred_translate = translate(Y_pred, vocab)\n",
"import torch\n", "return Y_pred_translate"
"from torchcrf import CRF"
] ]
} }
], ],

230
test-A/out.tsv Normal file

File diff suppressed because one or more lines are too long