forked from kubapok/en-ner-conll-2003
989 lines
5.0 MiB
"cells": [
"cell_type": "markdown",
"metadata": {},
"source": [
"## Uczenie głębokie – przetwarzanie tekstu – laboratoria\n",
"# 3. RNN"
"cell_type": "markdown",
"metadata": {},
"source": [
"### Podejście softmax z embeddingami na przykładzie NER"
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: torch in c:\\python312\\lib\\site-packages (2.3.0)\n",
"Requirement already satisfied: torchtext in c:\\python312\\lib\\site-packages (0.18.0)\n",
"Requirement already satisfied: filelock in c:\\python312\\lib\\site-packages (from torch) (3.14.0)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in c:\\python312\\lib\\site-packages (from torch) (4.11.0)\n",
"Requirement already satisfied: sympy in c:\\python312\\lib\\site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in c:\\python312\\lib\\site-packages (from torch) (3.3)\n",
"Requirement already satisfied: jinja2 in c:\\python312\\lib\\site-packages (from torch) (3.1.4)\n",
"Requirement already satisfied: fsspec in c:\\python312\\lib\\site-packages (from torch) (2024.3.1)\n",
"Requirement already satisfied: mkl<=2021.4.0,>=2021.1.1 in c:\\python312\\lib\\site-packages (from torch) (2021.4.0)\n",
"Requirement already satisfied: tqdm in c:\\python312\\lib\\site-packages (from torchtext) (4.66.4)\n",
"Requirement already satisfied: requests in c:\\python312\\lib\\site-packages (from torchtext) (2.32.2)\n",
"Requirement already satisfied: numpy in c:\\python312\\lib\\site-packages (from torchtext) (1.26.4)\n",
"Requirement already satisfied: intel-openmp==2021.* in c:\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.4.0)\n",
"Requirement already satisfied: tbb==2021.* in c:\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.12.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\python312\\lib\\site-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in c:\\python312\\lib\\site-packages (from requests->torchtext) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\python312\\lib\\site-packages (from requests->torchtext) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\python312\\lib\\site-packages (from requests->torchtext) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\python312\\lib\\site-packages (from requests->torchtext) (2024.2.2)\n",
"Requirement already satisfied: mpmath>=0.19 in c:\\python312\\lib\\site-packages (from sympy->torch) (1.3.0)\n",
"Requirement already satisfied: colorama in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from tqdm->torchtext) (0.4.6)\n",
"Requirement already satisfied: torch in c:\\python312\\lib\\site-packages (2.3.0)\n",
"Requirement already satisfied: datasets in c:\\python312\\lib\\site-packages (2.19.1)\n",
"Requirement already satisfied: filelock in c:\\python312\\lib\\site-packages (from torch) (3.14.0)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in c:\\python312\\lib\\site-packages (from torch) (4.11.0)\n",
"Requirement already satisfied: sympy in c:\\python312\\lib\\site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in c:\\python312\\lib\\site-packages (from torch) (3.3)\n",
"Requirement already satisfied: jinja2 in c:\\python312\\lib\\site-packages (from torch) (3.1.4)\n",
"Requirement already satisfied: fsspec in c:\\python312\\lib\\site-packages (from torch) (2024.3.1)\n",
"Requirement already satisfied: mkl<=2021.4.0,>=2021.1.1 in c:\\python312\\lib\\site-packages (from torch) (2021.4.0)\n",
"Requirement already satisfied: numpy>=1.17 in c:\\python312\\lib\\site-packages (from datasets) (1.26.4)\n",
"Requirement already satisfied: pyarrow>=12.0.0 in c:\\python312\\lib\\site-packages (from datasets) (16.1.0)\n",
"Requirement already satisfied: pyarrow-hotfix in c:\\python312\\lib\\site-packages (from datasets) (0.6)\n",
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in c:\\python312\\lib\\site-packages (from datasets) (0.3.8)\n",
"Requirement already satisfied: pandas in c:\\python312\\lib\\site-packages (from datasets) (2.2.2)\n",
"Requirement already satisfied: requests>=2.19.0 in c:\\python312\\lib\\site-packages (from datasets) (2.32.2)\n",
"Requirement already satisfied: tqdm>=4.62.1 in c:\\python312\\lib\\site-packages (from datasets) (4.66.4)\n",
"Requirement already satisfied: xxhash in c:\\python312\\lib\\site-packages (from datasets) (3.4.1)\n",
"Requirement already satisfied: multiprocess in c:\\python312\\lib\\site-packages (from datasets) (0.70.16)\n",
"Requirement already satisfied: aiohttp in c:\\python312\\lib\\site-packages (from datasets) (3.9.5)\n",
"Requirement already satisfied: huggingface-hub>=0.21.2 in c:\\python312\\lib\\site-packages (from datasets) (0.23.1)\n",
"Requirement already satisfied: packaging in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from datasets) (24.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in c:\\python312\\lib\\site-packages (from datasets) (6.0.1)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: intel-openmp==2021.* in c:\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.4.0)\n",
"Requirement already satisfied: tbb==2021.* in c:\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.12.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in c:\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (2024.2.2)\n",
"Requirement already satisfied: colorama in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from tqdm>=4.62.1->datasets) (0.4.6)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\python312\\lib\\site-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from pandas->datasets) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in c:\\python312\\lib\\site-packages (from pandas->datasets) (2024.1)\n",
"Requirement already satisfied: tzdata>=2022.7 in c:\\python312\\lib\\site-packages (from pandas->datasets) (2024.1)\n",
"Requirement already satisfied: mpmath>=0.19 in c:\\python312\\lib\\site-packages (from sympy->torch) (1.3.0)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
"Requirement already satisfied: ipywidgets in c:\\python312\\lib\\site-packages (8.1.2)\n",
"Requirement already satisfied: comm>=0.1.3 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipywidgets) (0.2.2)\n",
"Requirement already satisfied: ipython>=6.1.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipywidgets) (8.24.0)\n",
"Requirement already satisfied: traitlets>=4.3.1 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipywidgets) (5.14.3)\n",
"Requirement already satisfied: widgetsnbextension~=4.0.10 in c:\\python312\\lib\\site-packages (from ipywidgets) (4.0.10)\n",
"Requirement already satisfied: jupyterlab-widgets~=3.0.10 in c:\\python312\\lib\\site-packages (from ipywidgets) (3.0.10)\n",
"Requirement already satisfied: decorator in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (5.1.1)\n",
"Requirement already satisfied: jedi>=0.16 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (0.19.1)\n",
"Requirement already satisfied: matplotlib-inline in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (0.1.7)\n",
"Requirement already satisfied: prompt-toolkit<3.1.0,>=3.0.41 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (3.0.43)\n",
"Requirement already satisfied: pygments>=2.4.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (2.18.0)\n",
"Requirement already satisfied: stack-data in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (0.6.3)\n",
"Requirement already satisfied: colorama in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (0.4.6)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.3 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets) (0.8.4)\n",
"Requirement already satisfied: wcwidth in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from prompt-toolkit<3.1.0,>=3.0.41->ipython>=6.1.0->ipywidgets) (0.2.13)\n",
"Requirement already satisfied: executing>=1.2.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.0.1)\n",
"Requirement already satisfied: asttokens>=2.1.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.4.1)\n",
"Requirement already satisfied: pure-eval in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.2.2)\n",
"Requirement already satisfied: six>=1.12.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from asttokens>=2.1.0->stack-data->ipython>=6.1.0->ipywidgets) (1.16.0)\n"
"name": "stderr",
"output_type": "stream",
"text": [
"usage: jupyter [-h] [--version] [--config-dir] [--data-dir] [--runtime-dir]\n",
" [--paths] [--json] [--debug]\n",
" [subcommand]\n",
"Jupyter: Interactive Computing\n",
"positional arguments:\n",
" subcommand the subcommand to launch\n",
" -h, --help show this help message and exit\n",
" --version show the versions of core jupyter packages and exit\n",
" --config-dir show Jupyter config dir\n",
" --data-dir show Jupyter data dir\n",
" --runtime-dir show Jupyter runtime dir\n",
" --paths show all Jupyter paths. Add --json for machine-readable\n",
" format.\n",
" --json output paths as machine-readable json\n",
" --debug output debug information about paths\n",
"Available subcommands: kernel kernelspec migrate run troubleshoot\n",
"Jupyter command `jupyter-nbextension` not found.\n"
"source": [
"!pip install torch torchtext\n",
"!pip install torch datasets\n",
"!pip install ipywidgets\n",
"!jupyter nbextension enable --py widgetsnbextension"
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"import torch\n",
"from datasets import load_dataset\n",
"from torchtext.vocab import vocab\n",
"from tqdm import tqdm\n",
"from ipywidgets import FloatProgress"
"cell_type": "markdown",
"metadata": {},
"source": [
"Wczytujemy zbiór danych `conll2003` (, który zawiera teksty oznaczone znacznikami części mowy (*POS tags*): "
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Python312\\Lib\\site-packages\\datasets\\ FutureWarning: The repository for conll2003 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at\n",
"You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
"Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
" warnings.warn(\n"
"source": [
"dataset = load_dataset(\"conll2003\")"
"cell_type": "code",
"execution_count": 49,
"metadata": {
"scrolled": true
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
" train: Dataset({\n",
" features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
" num_rows: 14041\n",
" })\n",
" validation: Dataset({\n",
" features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
" num_rows: 3250\n",
" })\n",
" test: Dataset({\n",
" features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
" num_rows: 3453\n",
" })\n",
"[['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], ['Peter', 'Blackburn'], ['BRUSSELS', '1996-08-22'], ['The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 'disagreed', 'with', 'German', 'advice', 'to', 'consumers', 'to', 'shun', 'British', 'lamb', 'until', 'scientists', 'determine', 'whether', 'mad', 'cow', 'disease', 'can', 'be', 'transmitted', 'to', 'sheep', '.'], ['Germany', \"'s\", 'representative', 'to', 'the', 'European', 'Union', \"'s\", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.'], ['\"', 'We', 'do', \"n't\", 'support', 'any', 'such', 'recommendation', 'because', 'we', 'do', \"n't\", 'see', 'any', 'grounds', 'for', 'it', ',', '\"', 'the', 'Commission', \"'s\", 'chief', 'spokesman', 'Nikolaus', 'van', 'der', 'Pas', 'told', 'a', 'news', 'briefing', '.'], ['He', 'said', 'further', 'scientific', 'study', 'was', 'required', 'and', 'if', 'it', 'was', 'found', 'that', 'action', 'was', 'needed', 'it', 'should', 'be', 'taken', 'by', 'the', 'European', 'Union', '.'], ['He', 'said', 'a', 'proposal', 'last', 'month', 'by', 'EU', 'Farm', 'Commissioner', 'Franz', 'Fischler', 'to', 'ban', 'sheep', 'brains', ',', 'spleens', 'and', 'spinal', 'cords', 'from', 'the', 'human', 'and', 'animal', 'food', 'chains', 'was', 'a', 'highly', 'specific', 'and', 'precautionary', 'move', 'to', 'protect', 'human', 'health', '.'], ['Fischler', 'proposed', 'EU-wide', 'measures', 'after', 'reports', 'from', 'Britain', 'and', 'France', 'that', 'under', 'laboratory', 'conditions', 'sheep', 'could', 'contract', 'Bovine', 'Spongiform', 'Encephalopathy', '(', 'BSE', ')', '--', 'mad', 'cow', 'disease', '.'], ['But', 'Fischler', 'agreed', 'to', 'review', 'his', 'proposal', 'after', 'the', 'EU', \"'s\", 'standing', 'veterinary', 'committee', ',', 'mational', 'animal', 
'health', 'officials', ',', 'questioned', 'if', 'such', 'action', 'was', 'justified', 'as', 'there', 'was', 'only', 'a', 'slight', 'risk', 'to', 'human', 'health', '.'], ['Spanish', 'Farm', 'Minister', 'Loyola', 'de', 'Palacio', 'had', 'earlier', 'accused', 'Fischler', 'at', 'an', 'EU', 'farm', 'ministers', \"'\", 'meeting', 'of', 'causing', 'unjustified', 'alarm', 'through', '\"', 'dangerous', 'generalisation', '.', '\"'], ['.'], ['Only', 'France', 'and', 'Britain', 'backed', 'Fischler', \"'s\", 'proposal', '.'], ['The', 'EU', \"'s\", 'scientific', 'veterinary', 'and', 'multidisciplinary', 'committees', 'are', 'due', 'to', 're-examine', 'the', 'issue', 'early', 'next', 'month', 'and', 'make', 'recommendations', 'to', 'the', 'senior', 'veterinary', 'officials', '.'], ['Sheep', 'have', 'long', 'been', 'known', 'to', 'contract', 'scrapie', ',', 'a', 'brain-wasting', 'disease', 'similar', 'to', 'BSE', 'which', 'is', 'believed', 'to', 'have', 'been', 'transferred', 'to', 'cattle', 'through', 'feed', 'containing', 'animal', 'waste', '.'], ['British', 'farmers', 'denied', 'on', 'Thursday', 'there', 'was', 'any', 'danger', 'to', 'human', 'health', 'from', 'their', 'sheep', ',', 'but', 'expressed', 'concern', 'that', 'German', 'government', 'advice', 'to', 'consumers', 'to', 'avoid', 'British', 'lamb', 'might', 'influence', 'consumers', 'across', 'Europe', '.'], ['\"', 'What', 'we', 'have', 'to', 'be', 'extremely', 'careful', 'of', 'is', 'how', 'other', 'countries', 'are', 'going', 'to', 'take', 'Germany', \"'s\", 'lead', ',', '\"', 'Welsh', 'National', 'Farmers', \"'\", 'Union', '(', 'NFU', ')', 'chairman', 'John', 'Lloyd', 'Jones', 'said', 'on', 'BBC', 'radio', '.'], ['Bonn', 'has', 'led', 'efforts', 'to', 'protect', 'public', 'health', 'after', 'consumer', 'confidence', 'collapsed', 'in', 'March', 'after', 'a', 'British', 'report', 'suggested', 'humans', 'could', 'contract', 'an', 'illness', 'similar', 'to', 'mad', 'cow', 'disease', 'by', 'eating', 'contaminated', 
'beef', '.'], ['Germany', 'imported', '47,600', 'sheep', 'from', 'Britain', 'last', 'year', ',', 'near
"source": [
"cell_type": "markdown",
"metadata": {},
"source": [
"Poiżej funkcja, która tworzy słownik (\n",
"Parametr `special` określa symbole specjalne:\n",
"* `<unk>` – nieznany token\n",
"* `<pad>` – wypełnienie\n",
"* `<bos>` – początek zdania\n",
"* `<eos>` – koniec zdania"
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[[11, 21, 11, 12, 21, 22, 11, 12, 0], [11, 12], [11, 12], [11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 13, 11, 21, 22, 11, 12, 17, 11, 21, 17, 11, 12, 12, 21, 22, 22, 13, 11, 0], [11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 12, 21, 13, 11, 12, 21, 22, 11, 13, 11, 1, 13, 11, 17, 11, 12, 12, 21, 1, 0], [0, 11, 21, 22, 22, 11, 12, 12, 17, 11, 21, 22, 22, 11, 12, 13, 11, 0, 0, 11, 12, 11, 12, 12, 12, 12, 12, 12, 21, 11, 12, 12, 0], [11, 21, 11, 12, 12, 21, 22, 0, 17, 11, 21, 22, 17, 11, 21, 22, 11, 21, 22, 22, 13, 11, 12, 12, 0], [11, 21, 11, 12, 11, 12, 13, 11, 12, 12, 12, 12, 21, 22, 11, 12, 0, 11, 0, 11, 12, 13, 11, 12, 12, 12, 12, 12, 21, 11, 12, 1, 2, 2, 11, 21, 22, 11, 12, 0], [11, 12, 12, 21, 13, 11, 13, 11, 12, 12, 11, 13, 11, 11, 12, 21, 22, 11, 12, 12, 0, 11, 0, 0, 11, 12, 12, 0], [0, 11, 21, 22, 22, 11, 12, 13, 11, 12, 11, 12, 12, 12, 0, 11, 12, 12, 12, 0, 21, 17, 11, 12, 21, 22, 13, 3, 21, 3, 11, 12, 12, 13, 11, 12, 0], [11, 12, 12, 12, 12, 12, 21, 22, 22, 11, 13, 11, 12, 12, 12, 11, 12, 13, 21, 1, 11, 13, 0, 11, 12, 0, 0], [0], [11, 12, 12, 12, 21, 11, 11, 12, 0], [11, 12, 11, 12, 12, 12, 12, 12, 21, 1, 21, 22, 11, 12, 11, 12, 12, 0, 21, 11, 13, 11, 12, 12, 12, 0], [11, 21, 22, 22, 22, 13, 11, 12, 0, 11, 12, 12, 1, 13, 11, 11, 21, 22, 22, 22, 22, 22, 13, 11, 13, 11, 21, 11, 12, 0], [11, 12, 21, 13, 11, 11, 21, 11, 12, 13, 11, 12, 13, 11, 12, 0, 0, 21, 11, 17, 11, 12, 12, 13, 11, 21, 22, 11, 12, 21, 22, 11, 13, 11, 0], [0, 11, 11, 21, 22, 22, 1, 2, 13, 21, 3, 11, 12, 21, 22, 22, 22, 11, 11, 12, 0, 0, 11, 12, 12, 11, 12, 0, 11, 0, 11, 12, 12, 12, 21, 13, 11, 12, 0], [11, 21, 22, 11, 21, 22, 11, 12, 13, 11, 12, 21, 13, 11, 13, 11, 12, 12, 21, 11, 21, 22, 11, 12, 1, 2, 2, 11, 12, 13, 21, 22, 11, 0], [11, 21, 11, 12, 13, 11, 11, 12, 0, 11, 12, 13, 11, 12, 0], [11, 21, 13, 11, 12, 13, 11, 12, 0, 11, 12, 12, 13, 11, 12, 0], [11, 12, 12, 12, 21, 13, 11, 12, 12, 0], [11, 12], [11, 12, 12, 12, 12, 13, 11, 12, 13, 11, 12, 12, 12, 12, 21, 22, 13, 11, 12, 12, 13, 11, 
13, 11, 12, 13, 11, 13, 11, 12, 12, 11, 12, 12, 0], [11, 12, 12, 21, 11, 12, 0, 11, 12, 0, 13, 11, 12, 13, 0, 21, 0, 11, 21, 0, 0, 11, 11, 21, 13, 11, 12, 13, 11, 12, 12, 13, 11, 12, 0], [13, 11, 12, 13, 11, 12, 12, 12, 13, 11, 12, 12, 13, 11, 11, 21, 11, 12, 13, 11, 13, 11, 12, 0, 3, 11, 21, 22, 13, 11, 12, 0], [11, 3, 21, 3, 11, 12, 12, 11, 21, 22, 15, 13, 11, 13, 11, 11, 12, 12, 12, 12, 0, 11, 21, 13, 11, 13, 11, 13, 11, 0], [11, 21, 11, 12, 12, 12, 12, 13, 11, 21, 11, 21, 13, 11, 21, 22, 11, 12, 0, 11, 11, 12, 12, 12, 21, 13, 11, 12, 0, 11, 12, 0, 0], [11, 12, 21, 13, 11, 12, 12, 13, 11, 1, 11, 0], [11, 21, 11, 21, 11, 13, 11, 0], [21, 11], [11, 13, 11, 21, 11, 13, 21, 11, 12, 13, 11, 12, 13, 11, 13, 11, 12, 12, 13, 11, 12, 13, 11, 13, 11, 12, 12, 12, 12, 11, 12, 11, 21, 11, 0], [21, 3, 11, 13, 11, 12, 12, 21, 11, 12, 21, 1, 21, 22, 13, 11, 12, 13, 11, 0, 11, 12, 12, 12, 12, 21, 11, 0, 0, 11, 12, 12, 13, 11, 12, 13, 11, 12, 21, 22, 22, 13, 11, 12, 12, 0, 0], [11, 12, 21, 11, 11, 12, 12, 13, 11, 0, 11, 12, 0, 13, 21, 11, 12, 12, 13, 11, 13, 11, 17, 11, 21, 11, 13, 11, 12, 21, 22, 11, 12, 0], [0, 3, 21, 11, 12, 13, 11, 12, 12, 21, 22, 13, 11, 12, 0], [11, 21, 21, 22, 11, 12, 13, 11, 0, 0, 11, 11, 12, 12, 13, 11, 12, 11, 12, 21, 11, 13, 21, 0], [11, 12, 12, 11, 12, 21, 11, 12, 13, 11, 12, 11, 21, 22, 11, 13, 11, 11, 12, 0, 21, 11, 12, 13, 3, 11, 12, 21, 22, 11, 12, 3, 13, 11, 13, 11, 0, 11, 11, 21, 11, 12, 12, 0], [11, 0, 11, 21, 22, 22, 11, 12, 12, 21, 22, 11, 12, 12, 0, 21, 22, 13, 11, 12, 13, 11, 11, 12, 13, 11, 12, 12, 12, 0], [11, 21, 11, 3, 13, 11, 12, 0], [21, 11], [11, 21, 22, 11, 21, 11, 13, 11, 12, 13, 11, 0, 17, 11, 12, 12, 21, 22, 11, 12, 13, 11, 12, 0], [11, 21, 22, 22, 21, 22, 11, 12, 12, 12, 13, 11, 12, 0, 11, 12, 12, 12, 12, 21, 11, 12, 0, 11, 12, 12, 13, 11, 12, 13, 11, 13, 11, 12, 12, 0, 13, 21, 11, 3, 11, 0], [11, 12, 12, 12, 3, 11, 12, 12, 0, 11, 0], [11, 12], [11, 12, 12, 13, 11, 12, 21, 11, 12, 13, 11, 11, 12, 13, 11, 12, 12, 0, 11, 12, 12, 
13, 11, 12, 21, 13, 11, 0], [11, 12, 21, 11, 12, 12, 21, 22, 13, 11, 12, 0, 11, 12, 12, 0, 1
"source": [
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
" counter = Counter()\n",
" for document in dataset:\n",
" counter.update(document)\n",
" return vocab(counter, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])"
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"v = build_vocab(dataset[\"train\"][\"tokens\"])"
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"itos = v.get_itos() # mapowanie indeksów na tokeny"
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"['<unk>', '<pad>', '<bos>', '<eos>', 'EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.', 'Peter', 'Blackburn', 'BRUSSELS', '1996-08-22', 'The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 'disagreed', 'with', 'advice', 'consumers', 'shun', 'until', 'scientists', 'determine', 'whether', 'mad', 'cow', 'disease', 'can', 'be', 'transmitted', 'sheep', 'Germany', \"'s\", 'representative', 'the', 'Union', 'veterinary', 'committee', 'Werner', 'Zwingmann', 'Wednesday', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'scientific', 'was', 'clearer', '\"', 'We', 'do', \"n't\", 'support', 'any', 'such', 'recommendation', 'because', 'we', 'see', 'grounds', 'for', ',', 'chief', 'spokesman', 'Nikolaus', 'van', 'der', 'Pas', 'told', 'a', 'news', 'briefing', 'He', 'further', 'study', 'required', 'and', 'if', 'found', 'that', 'action', 'needed', 'taken', 'by', 'proposal', 'last', 'month', 'Farm', 'Commissioner', 'Franz', 'Fischler', 'ban', 'brains', 'spleens', 'spinal', 'cords', 'human', 'animal', 'food', 'chains', 'highly', 'specific', 'precautionary', 'move', 'protect', 'health', 'proposed', 'EU-wide', 'measures', 'after', 'reports', 'France', 'under', 'laboratory', 'conditions', 'could', 'contract', 'Bovine', 'Spongiform', 'Encephalopathy', '(', 'BSE', ')', '--', 'But', 'agreed', 'review', 'his', 'standing', 'mational', 'officials', 'questioned', 'justified', 'as', 'there', 'only', 'slight', 'risk', 'Spanish', 'Minister', 'Loyola', 'de', 'Palacio', 'had', 'earlier', 'accused', 'at', 'an', 'farm', 'ministers', \"'\", 'meeting', 'of', 'causing', 'unjustified', 'alarm', 'through', 'dangerous', 'generalisation', 'Only', 'backed', 'multidisciplinary', 'committees', 'are', 'due', 're-examine', 'issue', 'early', 'next', 'make', 'recommendations', 'senior', 'Sheep', 'have', 'long', 'been', 'known', 'scrapie', 'brain-wasting', 'similar', 'which', 'is', 'believed', 'transferred', 'cattle', 'feed', 'containing', 'waste', 
'farmers', 'denied', 'danger', 'their', 'but', 'expressed', 'concern', 'government', 'avoid', 'might', 'influence', 'across', 'Europe', 'What', 'extremely', 'careful', 'how', 'going', 'take', 'lead', 'Welsh', 'National', 'Farmers', 'NFU', 'chairman', 'John', 'Lloyd', 'Jones', 'BBC', 'radio', 'Bonn', 'has', 'led', 'efforts', 'public', 'consumer', 'confidence', 'collapsed', 'in', 'March', 'report', 'suggested', 'humans', 'illness', 'eating', 'contaminated', 'beef', 'imported', '47,600', 'year', 'nearly', 'half', 'total', 'imports', 'It', 'brought', '4,275', 'tonnes', 'mutton', 'some', '10', 'percent', 'overall', 'Rare', 'Hendrix', 'song', 'draft', 'sells', 'almost', '$', '17,000', 'LONDON', 'A', 'rare', 'handwritten', 'U.S.', 'guitar', 'legend', 'Jimi', 'sold', 'auction', 'late', 'musician', 'favourite', 'possessions', 'Florida', 'restaurant', 'paid', '10,925', 'pounds', '16,935', 'Ai', 'no', 'telling', 'penned', 'piece', 'London', 'hotel', 'stationery', '1966', 'At', 'end', 'January', '1967', 'concert', 'English', 'city', 'Nottingham', 'he', 'threw', 'sheet', 'paper', 'into', 'audience', 'where', 'retrieved', 'fan', 'Buyers', 'also', 'snapped', 'up', '16', 'items', 'were', 'put', 'former', 'girlfriend', 'Kathy', 'Etchingham', 'who', 'lived', 'him', '1969', 'They', 'included', 'black', 'lacquer', 'mother', 'pearl', 'inlaid', 'box', 'used', 'store', 'drugs', 'anonymous', 'Australian', 'purchaser', 'bought', '5,060', '7,845', 'guitarist', 'died', 'overdose', '1970', 'aged', '27', 'China', 'says', 'Taiwan', 'spoils', 'atmosphere', 'talks', 'BEIJING', 'Taipei', 'spoiling', 'resumption', 'Strait', 'visit', 'Ukraine', 'Taiwanese', 'Vice', 'President', 'Lien', 'Chan', 'this', 'week', 'infuriated', 'Beijing', 'Speaking', 'hours', 'Chinese', 'state', 'media', 'time', 'right', 'engage', 'political', 'Foreign', 'Ministry', 'Shen', 'Guofang', 'Reuters', ':', 'necessary', 'opening', 'disrupted', 'authorities', 'State', 'quoted', 'top', 'negotiator', 'Tang', 'Shubei', 'visiting', 
'group', 'rivals', 'hold', 'Now', 'two', 'sides', '...', 'hostility', 'overseas', 'edition',
"source": [
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
"source": [
"len(itos) # liczba różnych tokenów w słowniku"
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
"source": [
"v[\"rejects\"] # indeks tokenu `on`"
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
"source": [
"v[\"<unk>\"] # indeks nieznanego tokenu"
"cell_type": "markdown",
"metadata": {},
"source": [
"W przypadku, gdy w analizowanym tekście znajdzie się token, którego nie ma w słowniku, będzie reprezentowany przez indeks domyślny (*default index*). Ustawiamy, żeby był taki sam, jak indeks „nieznanego tokenu”:"
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt):\n",
" # Wektoryzacja dokumentów tekstowych.\n",
" return [\n",
" torch.tensor(\n",
" [v[\"<bos>\"]] + [v[token] for token in document] + [v[\"<eos>\"]],\n",
" dtype=torch.long,\n",
" )\n",
" for document in dt\n",
" ]"
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"def labels_process(dt):\n",
" # Wektoryzacja etykiet (NER)\n",
" return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]"
"cell_type": "markdown",
"metadata": {},
"source": [
"Teraz wektoryzujemy wszystkie dane:"
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], ['Peter', 'Blackburn'], ['BRUSSELS', '1996-08-22'], ['The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 'disagreed', 'with', 'German', 'advice', 'to', 'consumers', 'to', 'shun', 'British', 'lamb', 'until', 'scientists', 'determine', 'whether', 'mad', 'cow', 'disease', 'can', 'be', 'transmitted', 'to', 'sheep', '.'], ['Germany', \"'s\", 'representative', 'to', 'the', 'European', 'Union', \"'s\", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.'], ['\"', 'We', 'do', \"n't\", 'support', 'any', 'such', 'recommendation', 'because', 'we', 'do', \"n't\", 'see', 'any', 'grounds', 'for', 'it', ',', '\"', 'the', 'Commission', \"'s\", 'chief', 'spokesman', 'Nikolaus', 'van', 'der', 'Pas', 'told', 'a', 'news', 'briefing', '.'], ['He', 'said', 'further', 'scientific', 'study', 'was', 'required', 'and', 'if', 'it', 'was', 'found', 'that', 'action', 'was', 'needed', 'it', 'should', 'be', 'taken', 'by', 'the', 'European', 'Union', '.'], ['He', 'said', 'a', 'proposal', 'last', 'month', 'by', 'EU', 'Farm', 'Commissioner', 'Franz', 'Fischler', 'to', 'ban', 'sheep', 'brains', ',', 'spleens', 'and', 'spinal', 'cords', 'from', 'the', 'human', 'and', 'animal', 'food', 'chains', 'was', 'a', 'highly', 'specific', 'and', 'precautionary', 'move', 'to', 'protect', 'human', 'health', '.'], ['Fischler', 'proposed', 'EU-wide', 'measures', 'after', 'reports', 'from', 'Britain', 'and', 'France', 'that', 'under', 'laboratory', 'conditions', 'sheep', 'could', 'contract', 'Bovine', 'Spongiform', 'Encephalopathy', '(', 'BSE', ')', '--', 'mad', 'cow', 'disease', '.'], ['But', 'Fischler', 'agreed', 'to', 'review', 'his', 'proposal', 'after', 'the', 'EU', \"'s\", 'standing', 'veterinary', 'committee', ',', 'mational', 'animal', 
'health', 'officials', ',', 'questioned', 'if', 'such', 'action', 'was', 'justified', 'as', 'there', 'was', 'only', 'a', 'slight', 'risk', 'to', 'human', 'health', '.'], ['Spanish', 'Farm', 'Minister', 'Loyola', 'de', 'Palacio', 'had', 'earlier', 'accused', 'Fischler', 'at', 'an', 'EU', 'farm', 'ministers', \"'\", 'meeting', 'of', 'causing', 'unjustified', 'alarm', 'through', '\"', 'dangerous', 'generalisation', '.', '\"'], ['.'], ['Only', 'France', 'and', 'Britain', 'backed', 'Fischler', \"'s\", 'proposal', '.'], ['The', 'EU', \"'s\", 'scientific', 'veterinary', 'and', 'multidisciplinary', 'committees', 'are', 'due', 'to', 're-examine', 'the', 'issue', 'early', 'next', 'month', 'and', 'make', 'recommendations', 'to', 'the', 'senior', 'veterinary', 'officials', '.'], ['Sheep', 'have', 'long', 'been', 'known', 'to', 'contract', 'scrapie', ',', 'a', 'brain-wasting', 'disease', 'similar', 'to', 'BSE', 'which', 'is', 'believed', 'to', 'have', 'been', 'transferred', 'to', 'cattle', 'through', 'feed', 'containing', 'animal', 'waste', '.'], ['British', 'farmers', 'denied', 'on', 'Thursday', 'there', 'was', 'any', 'danger', 'to', 'human', 'health', 'from', 'their', 'sheep', ',', 'but', 'expressed', 'concern', 'that', 'German', 'government', 'advice', 'to', 'consumers', 'to', 'avoid', 'British', 'lamb', 'might', 'influence', 'consumers', 'across', 'Europe', '.'], ['\"', 'What', 'we', 'have', 'to', 'be', 'extremely', 'careful', 'of', 'is', 'how', 'other', 'countries', 'are', 'going', 'to', 'take', 'Germany', \"'s\", 'lead', ',', '\"', 'Welsh', 'National', 'Farmers', \"'\", 'Union', '(', 'NFU', ')', 'chairman', 'John', 'Lloyd', 'Jones', 'said', 'on', 'BBC', 'radio', '.'], ['Bonn', 'has', 'led', 'efforts', 'to', 'protect', 'public', 'health', 'after', 'consumer', 'confidence', 'collapsed', 'in', 'March', 'after', 'a', 'British', 'report', 'suggested', 'humans', 'could', 'contract', 'an', 'illness', 'similar', 'to', 'mad', 'cow', 'disease', 'by', 'eating', 'contaminated', 
'beef', '.'], ['Germany', 'imported', '47,600', 'sheep', 'from', 'Britain', 'last', 'year', ',', 'near
"source": [
"train_tokens_ids = data_process(dataset[\"train\"][\"tokens\"])"
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"test_tokens_ids = data_process(dataset[\"test\"][\"tokens\"])"
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"validation_tokens_ids = data_process(dataset[\"validation\"][\"tokens\"])"
"cell_type": "code",
"execution_count": 64,
"metadata": {
"scrolled": true
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[[3, 0, 7, 0, 0, 0, 7, 0, 0], [1, 2], [5, 0], [0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0], [0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 7, 0, 0, 0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [7, 0, 0, 1, 2, 2, 0, 0, 0, 1, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0], [0, 5, 0, 5, 0, 1, 0, 0, 0], [0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 5, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 3, 4, 4, 4, 4, 0, 3, 0, 0, 1, 2, 2, 0, 0, 3, 4, 0], [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [5, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 8, 0, 0, 0, 1, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [5, 0, 5, 0, 0, 0, 0, 0], [5, 0], [5, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 5, 0, 7, 0, 0, 1, 2, 0, 0, 0, 0, 5, 0], [0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 3, 4, 0, 1, 2, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0], [0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 4, 0, 1, 0, 0, 0], [0, 0, 0, 0, 3, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0, 0, 0, 0, 0], [5, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 7, 0, 0, 1, 0], [5, 0, 0, 0, 0, 5, 0, 0], [5, 0], [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 3, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0], [7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [5, 0], [7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 4, 4, 4, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 4, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0], [0, 3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 0, 0, 0, 0, 0, 0, 0, 0], [7, 0, 0, 0, 0, 0, 0, 0, 0, 0], [5, 0], [0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0], [0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 1, 2, 0, 3, 4, 0, 0], [3, 0, 7, 0, 0, 0, 0, 0], [5, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0], [0, 3, 4], 
[0, 7, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0
"source": [
"train_labels = labels_process(dataset[\"train\"][\"ner_tags\"])"
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"validation_labels = labels_process(dataset[\"validation\"][\"ner_tags\"])"
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"test_labels = labels_process(dataset[\"test\"][\"ner_tags\"])"
"cell_type": "markdown",
"metadata": {},
"source": [
"Przykład, jak wyglądają dane po zwektoryzowaniu:"
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([ 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 3])"
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"{'id': '0',\n",
" 'tokens': ['EU',\n",
" 'rejects',\n",
" 'German',\n",
" 'call',\n",
" 'to',\n",
" 'boycott',\n",
" 'British',\n",
" 'lamb',\n",
" '.'],\n",
" 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],\n",
" 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],\n",
" 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}"
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 69,
"metadata": {
"scrolled": true
"outputs": [
"data": {
"text/plain": [
"tensor([0, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0])"
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "markdown",
"metadata": {},
"source": [
"Funkcja, której użyjemy do ewaluacji:"
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"def get_scores(y_true, y_pred):\n",
" # Funkcja zwraca precyzję, pokrycie i F1\n",
" acc_score = 0\n",
" tp = 0\n",
" fp = 0\n",
" selected_items = 0\n",
" relevant_items = 0\n",
" for p, t in zip(y_pred, y_true):\n",
" if p == t:\n",
" acc_score += 1\n",
" if p > 0 and p == t:\n",
" tp += 1\n",
" if p > 0:\n",
" selected_items += 1\n",
" if t > 0:\n",
" relevant_items += 1\n",
" if selected_items == 0:\n",
" precision = 1.0\n",
" else:\n",
" precision = tp / selected_items\n",
" if relevant_items == 0:\n",
" recall = 1.0\n",
" else:\n",
" recall = tp / relevant_items\n",
" if precision + recall == 0.0:\n",
" f1 = 0.0\n",
" else:\n",
" f1 = 2 * precision * recall / (precision + recall)\n",
" return precision, recall, f1"
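To see how `get_scores` behaves, it helps to run it on a handful of tokens by hand. The sketch below repeats the function so that it runs on its own; the toy label lists are made up for illustration and do not come from the dataset.

```python
def get_scores(y_true, y_pred):
    # Token-level precision, recall and F1; tag 0 is the negative class
    tp = 0
    selected_items = 0
    relevant_items = 0
    for p, t in zip(y_pred, y_true):
        if p > 0 and p == t:
            tp += 1
        if p > 0:
            selected_items += 1
        if t > 0:
            relevant_items += 1
    precision = 1.0 if selected_items == 0 else tp / selected_items
    recall = 1.0 if relevant_items == 0 else tp / relevant_items
    if precision + recall == 0.0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Hypothetical gold and predicted tags for six tokens
y_true = [0, 3, 0, 7, 0, 0]
y_pred = [0, 3, 7, 0, 0, 1]

# One correct non-zero prediction, three non-zero predictions, two non-zero gold tags
precision, recall, f1 = get_scores(y_true, y_pred)
print(precision, recall, f1)
```

Correctly predicted zeros do not affect any of the three values, so a model that predicts 0 everywhere gets precision 1.0 (by the function's convention) but recall 0.0.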
"cell_type": "markdown",
"metadata": {},
"source": [
"Ile mamy różnych tagów NER?"
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"source": [
"num_tags = max([max(x) for x in dataset[\"train\"][\"ner_tags\"]]) + 1\n",
"cell_type": "markdown",
"metadata": {},
"source": [
"Implementacja rekurencyjnej sieci neuronowej LSTM:"
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"class LSTM(torch.nn.Module):\n",
" def __init__(self):\n",
" super(LSTM, self).__init__()\n",
" self.emb = torch.nn.Embedding(len(v.get_itos()), 100)\n",
" self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)\n",
" self.fc1 = torch.nn.Linear(256, num_tags)\n",
" def forward(self, x):\n",
" emb = torch.relu(self.emb(x))\n",
" lstm_output, (h_n, c_n) = self.rec(emb)\n",
" out_weights = self.fc1(lstm_output)\n",
" return out_weights"
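A quick way to check the model is to follow the tensor shapes through one forward pass. The sketch below is a standalone variant of the class above: `TinyTagger`, `vocab_size=50` and `num_tags=9` are illustrative assumptions that replace the notebook's globals `v` and `num_tags`, and the weights are untrained.

```python
import torch


class TinyTagger(torch.nn.Module):
    # Same layer layout as the LSTM class, with sizes passed in explicitly
    def __init__(self, vocab_size, num_tags, emb_dim=100, hidden_dim=256):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, emb_dim)
        self.rec = torch.nn.LSTM(emb_dim, hidden_dim, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)

    def forward(self, x):
        emb = torch.relu(self.emb(x))      # (batch, seq_len, emb_dim)
        lstm_output, _ = self.rec(emb)     # (batch, seq_len, hidden_dim)
        return self.fc1(lstm_output)       # (batch, seq_len, num_tags)


model = TinyTagger(vocab_size=50, num_tags=9)

# One sentence of 11 token ids with a leading batch dimension
token_ids = torch.tensor([[2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 3]])

with torch.no_grad():
    logits = model(token_ids)

output_shape = tuple(logits.shape)
# One predicted tag id per token
predicted = torch.argmax(logits.squeeze(0), dim=1)
predicted_shape = tuple(predicted.shape)
print(output_shape, predicted_shape)
```

The model emits one score vector per token, so the loss and the `argmax` are both applied along the last dimension after the batch dimension is squeezed away.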
"cell_type": "markdown",
"metadata": {},
"source": [
"Stworzenie modelu:"
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"lstm = LSTM()"
"cell_type": "markdown",
"metadata": {},
"source": [
"Definicja funkcji kosztu:"
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()"
"cell_type": "markdown",
"metadata": {},
"source": [
"Definicja optymalizatora:"
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(lstm.parameters())"
"cell_type": "markdown",
"metadata": {},
"source": [
"Funkcja do ewaluacji modelu:"
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model):\n",
" Y_true = []\n",
" Y_pred = []\n",
" for i in tqdm(range(len(dataset_labels))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
" tags = list(dataset_labels[i].numpy())\n",
" Y_true += tags\n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
" Y_pred += list(Y_batch_pred.numpy())\n",
" return get_scores(Y_true, Y_pred)"
"cell_type": "markdown",
"metadata": {},
"source": [
"Uczenie modelu:"
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"cell_type": "code",
"execution_count": 78,
"metadata": {
"scrolled": false
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 14041/14041 [05:54<00:00, 39.57it/s]\n",
"100%|██████████| 3250/3250 [00:01<00:00, 1678.69it/s]\n"
"name": "stdout",
"output_type": "stream",
"text": [
"(0.5988246210949583, 0.4500755550389399, 0.513902714181432)\n"
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 14041/14041 [07:01<00:00, 33.29it/s]\n",
"100%|██████████| 3250/3250 [00:01<00:00, 1652.85it/s]\n"
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7379187666765491, 0.5786353597582239, 0.6486416053163073)\n"
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 14041/14041 [06:35<00:00, 35.49it/s]\n",
"100%|██████████| 3250/3250 [00:02<00:00, 1513.42it/s]\n"
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7980072463768116, 0.6144368243635941, 0.6942930321140081)\n"
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 14041/14041 [06:34<00:00, 35.58it/s]\n",
"100%|██████████| 3250/3250 [00:02<00:00, 1468.00it/s]\n"
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8167669945676113, 0.646634894804138, 0.7218113403399506)\n"
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 14041/14041 [06:28<00:00, 36.11it/s]\n",
"100%|██████████| 3250/3250 [00:02<00:00, 1558.26it/s]"
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8325018896447468, 0.6401255376031617, 0.7237481929294256)\n"
"name": "stderr",
"output_type": "stream",
"text": [
"source": [
"for i in range(NUM_EPOCHS):\n",
" lstm.train()\n",
" # for i in tqdm(range(500)):\n",
" for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0)\n",
" tags = train_labels[i].unsqueeze(1)\n",
" predicted_tags = lstm(batch_tokens)\n",
" optimizer.zero_grad()\n",
" loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1))\n",
" loss.backward()\n",
" optimizer.step()\n",
" lstm.eval()\n",
" print(eval_model(validation_tokens_ids, validation_labels, lstm))"
"cell_type": "markdown",
"metadata": {},
"source": [
"cell_type": "code",
"execution_count": 79,
"metadata": {
"scrolled": true
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 3250/3250 [00:02<00:00, 1603.66it/s]\n"
"data": {
"text/plain": [
"(0.8325018896447468, 0.6401255376031617, 0.7237481929294256)"
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
"source": [
"eval_model(validation_tokens_ids, validation_labels, lstm)"
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 3453/3453 [00:02<00:00, 1517.54it/s]\n"
"data": {
"text/plain": [
"(0.7690643591130341, 0.525887573964497, 0.6246430924665056)"
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
"source": [
"eval_model(test_tokens_ids, test_labels, lstm)"
"cell_type": "markdown",
"metadata": {},
"source": [
"## Zadanie 3\n",
"Sklonuj repozytorium\n",
"Stwórz model *sequence labelling* realizujący zadanie NER, oparty o dowolną rekurencyjną sieć neuronową (możesz wzorować się na przykładzie z zajęć).\n",
"W plikach dev-0/out.tsv oraz test-A/out.tsv umieść wyniki predykcji dla dev-0/in.tsv i test-A/in.tsv odpowiednio.\n",
"Do ewaluacji wykorzystaj narzędzie GEval (\n",
" wget\n",
" chmod u+x geval\n",
" ./geval --help\n",
"Liczba punktów uzyskanych za zadanie zależy od uzyskanej wartości accuracy na zbiorze `test-A` (wynik zaokrąglony w górę):\n",
" points = math.ceil(accuracy * 7.0)\n",
"⚠️ W systemie Moodle proszę załączyć plik `test-A/out.tsv` oraz link do repozytorium z rozwiązaniem zadania.\n",
" "
"metadata": {
"author": "Jakub Pokrywka",
"email": "",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"subtitle": "11.NER RNN[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
"nbformat": 4,
"nbformat_minor": 4