{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Uczenie głębokie przetwarzanie tekstu laboratoria\n",
"# 3. RNN"
]
},
{
"cell_type": "code",
"execution_count": 187,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: torch in c:\\python312\\lib\\site-packages (2.3.0)\n",
"Requirement already satisfied: torchtext in c:\\python312\\lib\\site-packages (0.18.0)\n",
"Requirement already satisfied: filelock in c:\\python312\\lib\\site-packages (from torch) (3.14.0)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in c:\\python312\\lib\\site-packages (from torch) (4.11.0)\n",
"Requirement already satisfied: sympy in c:\\python312\\lib\\site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in c:\\python312\\lib\\site-packages (from torch) (3.3)\n",
"Requirement already satisfied: jinja2 in c:\\python312\\lib\\site-packages (from torch) (3.1.4)\n",
"Requirement already satisfied: fsspec in c:\\python312\\lib\\site-packages (from torch) (2024.3.1)\n",
"Requirement already satisfied: mkl<=2021.4.0,>=2021.1.1 in c:\\python312\\lib\\site-packages (from torch) (2021.4.0)\n",
"Requirement already satisfied: tqdm in c:\\python312\\lib\\site-packages (from torchtext) (4.66.4)\n",
"Requirement already satisfied: requests in c:\\python312\\lib\\site-packages (from torchtext) (2.32.2)\n",
"Requirement already satisfied: numpy in c:\\python312\\lib\\site-packages (from torchtext) (1.26.4)\n",
"Requirement already satisfied: intel-openmp==2021.* in c:\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.4.0)\n",
"Requirement already satisfied: tbb==2021.* in c:\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.12.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\python312\\lib\\site-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in c:\\python312\\lib\\site-packages (from requests->torchtext) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\python312\\lib\\site-packages (from requests->torchtext) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\python312\\lib\\site-packages (from requests->torchtext) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\python312\\lib\\site-packages (from requests->torchtext) (2024.2.2)\n",
"Requirement already satisfied: mpmath>=0.19 in c:\\python312\\lib\\site-packages (from sympy->torch) (1.3.0)\n",
"Requirement already satisfied: colorama in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from tqdm->torchtext) (0.4.6)\n",
"Requirement already satisfied: torch in c:\\python312\\lib\\site-packages (2.3.0)\n",
"Requirement already satisfied: datasets in c:\\python312\\lib\\site-packages (2.19.1)\n",
"Requirement already satisfied: filelock in c:\\python312\\lib\\site-packages (from torch) (3.14.0)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in c:\\python312\\lib\\site-packages (from torch) (4.11.0)\n",
"Requirement already satisfied: sympy in c:\\python312\\lib\\site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in c:\\python312\\lib\\site-packages (from torch) (3.3)\n",
"Requirement already satisfied: jinja2 in c:\\python312\\lib\\site-packages (from torch) (3.1.4)\n",
"Requirement already satisfied: fsspec in c:\\python312\\lib\\site-packages (from torch) (2024.3.1)\n",
"Requirement already satisfied: mkl<=2021.4.0,>=2021.1.1 in c:\\python312\\lib\\site-packages (from torch) (2021.4.0)\n",
"Requirement already satisfied: numpy>=1.17 in c:\\python312\\lib\\site-packages (from datasets) (1.26.4)\n",
"Requirement already satisfied: pyarrow>=12.0.0 in c:\\python312\\lib\\site-packages (from datasets) (16.1.0)\n",
"Requirement already satisfied: pyarrow-hotfix in c:\\python312\\lib\\site-packages (from datasets) (0.6)\n",
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in c:\\python312\\lib\\site-packages (from datasets) (0.3.8)\n",
"Requirement already satisfied: pandas in c:\\python312\\lib\\site-packages (from datasets) (2.2.2)\n",
"Requirement already satisfied: requests>=2.19.0 in c:\\python312\\lib\\site-packages (from datasets) (2.32.2)\n",
"Requirement already satisfied: tqdm>=4.62.1 in c:\\python312\\lib\\site-packages (from datasets) (4.66.4)\n",
"Requirement already satisfied: xxhash in c:\\python312\\lib\\site-packages (from datasets) (3.4.1)\n",
"Requirement already satisfied: multiprocess in c:\\python312\\lib\\site-packages (from datasets) (0.70.16)\n",
"Requirement already satisfied: aiohttp in c:\\python312\\lib\\site-packages (from datasets) (3.9.5)\n",
"Requirement already satisfied: huggingface-hub>=0.21.2 in c:\\python312\\lib\\site-packages (from datasets) (0.23.1)\n",
"Requirement already satisfied: packaging in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from datasets) (24.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in c:\\python312\\lib\\site-packages (from datasets) (6.0.1)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in c:\\python312\\lib\\site-packages (from aiohttp->datasets) (1.9.4)\n",
"Requirement already satisfied: intel-openmp==2021.* in c:\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.4.0)\n",
"Requirement already satisfied: tbb==2021.* in c:\\python312\\lib\\site-packages (from mkl<=2021.4.0,>=2021.1.1->torch) (2021.12.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in c:\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in c:\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in c:\\python312\\lib\\site-packages (from requests>=2.19.0->datasets) (2024.2.2)\n",
"Requirement already satisfied: colorama in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from tqdm>=4.62.1->datasets) (0.4.6)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\python312\\lib\\site-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from pandas->datasets) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in c:\\python312\\lib\\site-packages (from pandas->datasets) (2024.1)\n",
"Requirement already satisfied: tzdata>=2022.7 in c:\\python312\\lib\\site-packages (from pandas->datasets) (2024.1)\n",
"Requirement already satisfied: mpmath>=0.19 in c:\\python312\\lib\\site-packages (from sympy->torch) (1.3.0)\n",
"Requirement already satisfied: six>=1.5 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
"Requirement already satisfied: ipywidgets in c:\\python312\\lib\\site-packages (8.1.2)\n",
"Requirement already satisfied: comm>=0.1.3 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipywidgets) (0.2.2)\n",
"Requirement already satisfied: ipython>=6.1.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipywidgets) (8.24.0)\n",
"Requirement already satisfied: traitlets>=4.3.1 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipywidgets) (5.14.3)\n",
"Requirement already satisfied: widgetsnbextension~=4.0.10 in c:\\python312\\lib\\site-packages (from ipywidgets) (4.0.10)\n",
"Requirement already satisfied: jupyterlab-widgets~=3.0.10 in c:\\python312\\lib\\site-packages (from ipywidgets) (3.0.10)\n",
"Requirement already satisfied: decorator in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (5.1.1)\n",
"Requirement already satisfied: jedi>=0.16 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (0.19.1)\n",
"Requirement already satisfied: matplotlib-inline in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (0.1.7)\n",
"Requirement already satisfied: prompt-toolkit<3.1.0,>=3.0.41 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (3.0.43)\n",
"Requirement already satisfied: pygments>=2.4.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (2.18.0)\n",
"Requirement already satisfied: stack-data in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (0.6.3)\n",
"Requirement already satisfied: colorama in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from ipython>=6.1.0->ipywidgets) (0.4.6)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.3 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets) (0.8.4)\n",
"Requirement already satisfied: wcwidth in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from prompt-toolkit<3.1.0,>=3.0.41->ipython>=6.1.0->ipywidgets) (0.2.13)\n",
"Requirement already satisfied: executing>=1.2.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.0.1)\n",
"Requirement already satisfied: asttokens>=2.1.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.4.1)\n",
"Requirement already satisfied: pure-eval in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.2.2)\n",
"Requirement already satisfied: six>=1.12.0 in c:\\users\\dominik\\appdata\\roaming\\python\\python312\\site-packages (from asttokens>=2.1.0->stack-data->ipython>=6.1.0->ipywidgets) (1.16.0)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"usage: jupyter [-h] [--version] [--config-dir] [--data-dir] [--runtime-dir]\n",
" [--paths] [--json] [--debug]\n",
" [subcommand]\n",
"\n",
"Jupyter: Interactive Computing\n",
"\n",
"positional arguments:\n",
" subcommand the subcommand to launch\n",
"\n",
"options:\n",
" -h, --help show this help message and exit\n",
" --version show the versions of core jupyter packages and exit\n",
" --config-dir show Jupyter config dir\n",
" --data-dir show Jupyter data dir\n",
" --runtime-dir show Jupyter runtime dir\n",
" --paths show all Jupyter paths. Add --json for machine-readable\n",
" format.\n",
" --json output paths as machine-readable json\n",
" --debug output debug information about paths\n",
"\n",
"Available subcommands: kernel kernelspec migrate run troubleshoot\n",
"\n",
"Jupyter command `jupyter-nbextension` not found.\n"
]
}
],
"source": [
"!pip install torch torchtext\n",
"!pip install torch datasets\n",
"!pip install ipywidgets\n",
"!jupyter nbextension enable --py widgetsnbextension\n",
"\n",
"from collections import Counter\n",
"import torch\n",
"from datasets import load_dataset\n",
"from torchtext.vocab import vocab\n",
"from tqdm import tqdm\n",
"from ipywidgets import FloatProgress\n",
"\n",
"import pandas as pd\n",
"from nltk.tokenize import word_tokenize\n",
"from unidecode import unidecode"
]
},
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"wczytano dane treningowe\n",
"O B-PER O O O O O O O O O B-LOC O O O O O O O O O O O B-LOC O O B-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O B-LOC O O O O O O O O O O O O O B-MISC I-MISC I-MISC I-MISC O O O B-PER O O O O O B-LOC O O O O O O O O O O O O O O O O O B-MISC O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-PER O O O B-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O B-PER O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O O O O O O O Rare Hendrix song draft sells for almost $ 17,000 . </S> LONDON 1996-08-22 </S> A rare early handwritten draft of a song by U.S. guitar legend Jimi Hendrix was sold for almost $ 17,000 on Thursday at an auction of some of the late musician 's favourite possessions . </S> A Florida restaurant paid 10,925 pounds ( $ 16,935 ) for the draft of \" Ai n't no telling \" , which Hendrix penned on a piece of London hotel stationery in late 1966 . </S> At the end of a January 1967 concert in the English city of Nottingham he threw the sheet of paper into the audience , where it was retrieved by a fan . </S> Buyers also snapped up 16 other items that were put up for auction by Hendrix 's former girlfriend Kathy Etchingham , who lived with him from 1966 to 1969 . </S> They included a black lacquer and mother of pearl inlaid box used by Hendrix to store his drugs , which an anonymous Australian purchaser bought for 5,060 pounds ( $ 7,845 ) . </S> The guitarist died of a drugs overdose in 1970 aged 27 . </S>\n",
"podzielono dane treningowe na słowa\n",
"['rare', 'hendrix', 'song', 'draft', 'sells', 'for', 'almost', '$', '17,000', '.', '</s>', 'london', '1996-08-22', '</s>', 'a', 'rare', 'early', 'handwritten', 'draft', 'of', 'a', 'song', 'by', 'u.s.', 'guitar', 'legend', 'jimi', 'hendrix', 'was', 'sold', 'for', 'almost', '$', '17,000', 'on', 'thursday', 'at', 'an', 'auction', 'of', 'some', 'of', 'the', 'late', 'musician', \"'s\", 'favourite', 'possessions', '.', '</s>', 'a', 'florida', 'restaurant', 'paid', '10,925', 'pounds', '(', '$', '16,935', ')', 'for', 'the', 'draft', 'of', '\"', 'ai', \"n't\", 'no', 'telling', '\"', ',', 'which', 'hendrix', 'penned', 'on', 'a', 'piece', 'of', 'london', 'hotel', 'stationery', 'in', 'late', '1966', '.', '</s>', 'at', 'the', 'end', 'of', 'a', 'january', '1967', 'concert', 'in', 'the', 'english', 'city', 'of', 'nottingham', 'he', 'threw', 'the', 'sheet', 'of', 'paper', 'into', 'the', 'audience', ',', 'where', 'it', 'was', 'retrieved', 'by', 'a', 'fan', '.', '</s>', 'buyers', 'also', 'snapped', 'up', '16', 'other', 'items', 'that', 'were', 'put', 'up', 'for', 'auction', 'by', 'hendrix', \"'s\", 'former', 'girlfriend', 'kathy', 'etchingham', ',', 'who', 'lived', 'with', 'him', 'from', '1966', 'to', '1969', '.', '</s>', 'they', 'included', 'a', 'black', 'lacquer', 'and', 'mother', 'of', 'pearl', 'inlaid', 'box', 'used', 'by', 'hendrix', 'to', 'store', 'his', 'drugs', ',', 'which', 'an', 'anonymous', 'australian', 'purchaser', 'bought', 'for', '5,060', 'pounds', '(', '$', '7,845', ')', '.', '</s>', 'the', 'guitarist', 'died', 'of', 'a', 'drugs', 'overdose', 'in', '1970', 'aged', '27', '.', '</s>']\n"
]
}
],
"source": [
"# odczytaj dane treningowe\n",
"train = pd.read_csv('train/train.tsv', sep='\\t')\n",
"train.columns = [\"y\", \"x\"]\n",
"print(\"wczytano dane treningowe\")\n",
"print(train[\"y\"][0], train[\"x\"][0])\n",
"\n",
"# podziel dane treningowe na słowa\n",
"# https://www.geeksforgeeks.org/python-word-embedding-using-word2vec/\n",
"slowa_train = []\n",
"for tekst in train[\"x\"]:\n",
" pom = []\n",
" for slowo in tekst.split(\" \"):\n",
" #if slowo not in (\"<\",\"/s\",\">\",\"/S\",\"``\"):\n",
" pom.append(slowo.lower())\n",
" slowa_train.append(pom)\n",
"print(\"podzielono dane treningowe na słowa\")\n",
"print(slowa_train[0])"
]
},
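{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal illustration of the whitespace tokenization used above, on the opening of the first training sentence:\n",
"\n",
"```python\n",
"sentence = 'Rare Hendrix song draft sells for almost $ 17,000 .'\n",
"tokens = [w.lower() for w in sentence.split(' ')]\n",
"print(tokens)\n",
"# ['rare', 'hendrix', 'song', 'draft', 'sells', 'for', 'almost', '$', '17,000', '.']\n",
"```"
]
},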
{
"cell_type": "code",
"execution_count": 189,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"wczytano dane testowe dev-0\n",
"CRICKET - ENGLISH COUNTY CHAMPIONSHIP SCORES . </S> LONDON 1996-08-30 </S> Result and close of play scores in English county championship matches on Friday : </S> Leicester : Leicestershire beat Somerset by an innings and 39 runs . </S> Somerset 83 and 174 ( P. Simmons 4-38 ) , Leicestershire 296 . </S> Leicestershire 22 points , Somerset 4 . </S> Chester-le-Street : Glamorgan 259 and 207 ( A. Dale 69 , H. Morris 69 ; D. Blenkiron 4-43 ) , Durham 114 ( S. Watkin 4-28 ) and 81-3 . </S> Tunbridge Wells : Nottinghamshire 214 ( P. Johnson 84 ; M. McCague 4-55 ) , Kent 108-3 . </S> London ( The Oval ) : Warwickshire 195 , Surrey 429-7 ( C. Lewis 80 not out , M. Butcher 70 , G. Kersey 63 , J. Ratcliffe 63 , D. Bicknell 55 ) . </S> Hove : Sussex 363 ( W. Athey 111 , V. Drakes 52 ; I. Austin 4-37 ) , Lancashire 197-8 ( W. Hegg 54 ) </S> Portsmouth : Middlesex 199 and 426 ( J. Pooley 111 , M. Ramprakash 108 , M. Gatting 83 ) , Hampshire 232 and 109-5 . </S> Chesterfield : Worcestershire 238 and 133-5 , Derbyshire 471 ( J. Adams 123 , T.O'Gorman 109 not out , K. Barnett 87 ; T. Moody 6-82 ) </S> Bristol : Gloucestershire 183 and 185-6 ( J. Russell 56 not out ) , Northamptonshire 190 ( K. Curran 52 ; A. Smith 5-68 ) . </S>\n",
"podzielono dane treningowe na słowa\n",
"['cricket', '-', 'english', 'county', 'championship', 'scores', '.', '</s>', 'london', '1996-08-30', '</s>', 'result', 'and', 'close', 'of', 'play', 'scores', 'in', 'english', 'county', 'championship', 'matches', 'on', 'friday', ':', '</s>', 'leicester', ':', 'leicestershire', 'beat', 'somerset', 'by', 'an', 'innings', 'and', '39', 'runs', '.', '</s>', 'somerset', '83', 'and', '174', '(', 'p.', 'simmons', '4-38', ')', ',', 'leicestershire', '296', '.', '</s>', 'leicestershire', '22', 'points', ',', 'somerset', '4', '.', '</s>', 'chester-le-street', ':', 'glamorgan', '259', 'and', '207', '(', 'a.', 'dale', '69', ',', 'h.', 'morris', '69', ';', 'd.', 'blenkiron', '4-43', ')', ',', 'durham', '114', '(', 's.', 'watkin', '4-28', ')', 'and', '81-3', '.', '</s>', 'tunbridge', 'wells', ':', 'nottinghamshire', '214', '(', 'p.', 'johnson', '84', ';', 'm.', 'mccague', '4-55', ')', ',', 'kent', '108-3', '.', '</s>', 'london', '(', 'the', 'oval', ')', ':', 'warwickshire', '195', ',', 'surrey', '429-7', '(', 'c.', 'lewis', '80', 'not', 'out', ',', 'm.', 'butcher', '70', ',', 'g.', 'kersey', '63', ',', 'j.', 'ratcliffe', '63', ',', 'd.', 'bicknell', '55', ')', '.', '</s>', 'hove', ':', 'sussex', '363', '(', 'w.', 'athey', '111', ',', 'v.', 'drakes', '52', ';', 'i.', 'austin', '4-37', ')', ',', 'lancashire', '197-8', '(', 'w.', 'hegg', '54', ')', '</s>', 'portsmouth', ':', 'middlesex', '199', 'and', '426', '(', 'j.', 'pooley', '111', ',', 'm.', 'ramprakash', '108', ',', 'm.', 'gatting', '83', ')', ',', 'hampshire', '232', 'and', '109-5', '.', '</s>', 'chesterfield', ':', 'worcestershire', '238', 'and', '133-5', ',', 'derbyshire', '471', '(', 'j.', 'adams', '123', ',', \"t.o'gorman\", '109', 'not', 'out', ',', 'k.', 'barnett', '87', ';', 't.', 'moody', '6-82', ')', '</s>', 'bristol', ':', 'gloucestershire', '183', 'and', '185-6', '(', 'j.', 'russell', '56', 'not', 'out', ')', ',', 'northamptonshire', '190', '(', 'k.', 'curran', '52', ';', 'a.', 'smith', '5-68', ')', '.', '</s>']\n"
]
}
],
"source": [
"# odczytaj dane testowe dev-0\n",
"test_dev0 = pd.read_csv('dev-0/in.tsv', sep='\\t')\n",
"test_dev0.columns = [\"x\"]\n",
"print(\"wczytano dane testowe dev-0\")\n",
"print(test_dev0[\"x\"][0])\n",
"\n",
"# podziel dane testowe na słowa\n",
"# https://www.geeksforgeeks.org/python-word-embedding-using-word2vec/\n",
"slowa_test_dev0 = []\n",
"for tekst in test_dev0[\"x\"]:\n",
" pom = []\n",
" for slowo in tekst.split(\" \"):\n",
" #if slowo not in (\"<\",\"/s\",\">\",\"/S\",\"``\"):\n",
" pom.append(slowo.lower())\n",
" slowa_test_dev0.append(pom)\n",
"print(\"podzielono dane treningowe na słowa\")\n",
"print(slowa_test_dev0[0])"
]
},
{
"cell_type": "code",
"execution_count": 190,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"wczytano dane testowe A\n",
"RUGBY UNION - CUTTITTA BACK FOR ITALY AFTER A YEAR . </S> ROME 1996-12-06 </S> Italy recalled Marcello Cuttitta </S> on Friday for their friendly against Scotland at Murrayfield more than a year after the 30-year-old wing announced he was retiring following differences over selection . </S> Cuttitta , who trainer George Coste said was certain to play on Saturday week , was named in a 21-man squad lacking only two of the team beaten 54-21 by England at Twickenham last month . </S> Stefano Bordon is out through illness and Coste said he had dropped back row Corrado Covi , who had been recalled for the England game after five years out of the national team . </S> Cuttitta announced his retirement after the 1995 World Cup , where he took issue with being dropped from the Italy side that faced England in the pool stages . </S> Coste said he had approached the player two months ago about a comeback . </S> \" He ended the World Cup on the wrong note , \" Coste said . </S> \" I thought it would be useful to have him back and he said he would be available . </S> I think now is the right time for him to return . \" </S> Squad : Javier Pertile , Paolo Vaccari , Marcello Cuttitta , Ivan Francescato , Leandro Manteri , Diego Dominguez , Francesco Mazzariol , Alessandro Troncon , Orazio Arancio , Andrea Sgorlon , Massimo Giovanelli , Carlo Checchinato , Walter Cristofoletto , Franco Properzi Curti , Carlo Orlandi , Massimo Cuttitta , Giambatista Croci , Gianluca Guidi , Nicola Mazzucato , Alessandro Moscardi , Andrea Castellani . </S>\n",
"podzielono dane treningowe na słowa\n",
"['rugby', 'union', '-', 'cuttitta', 'back', 'for', 'italy', 'after', 'a', 'year', '.', '</s>', 'rome', '1996-12-06', '</s>', 'italy', 'recalled', 'marcello', 'cuttitta', '</s>', 'on', 'friday', 'for', 'their', 'friendly', 'against', 'scotland', 'at', 'murrayfield', 'more', 'than', 'a', 'year', 'after', 'the', '30-year-old', 'wing', 'announced', 'he', 'was', 'retiring', 'following', 'differences', 'over', 'selection', '.', '</s>', 'cuttitta', ',', 'who', 'trainer', 'george', 'coste', 'said', 'was', 'certain', 'to', 'play', 'on', 'saturday', 'week', ',', 'was', 'named', 'in', 'a', '21-man', 'squad', 'lacking', 'only', 'two', 'of', 'the', 'team', 'beaten', '54-21', 'by', 'england', 'at', 'twickenham', 'last', 'month', '.', '</s>', 'stefano', 'bordon', 'is', 'out', 'through', 'illness', 'and', 'coste', 'said', 'he', 'had', 'dropped', 'back', 'row', 'corrado', 'covi', ',', 'who', 'had', 'been', 'recalled', 'for', 'the', 'england', 'game', 'after', 'five', 'years', 'out', 'of', 'the', 'national', 'team', '.', '</s>', 'cuttitta', 'announced', 'his', 'retirement', 'after', 'the', '1995', 'world', 'cup', ',', 'where', 'he', 'took', 'issue', 'with', 'being', 'dropped', 'from', 'the', 'italy', 'side', 'that', 'faced', 'england', 'in', 'the', 'pool', 'stages', '.', '</s>', 'coste', 'said', 'he', 'had', 'approached', 'the', 'player', 'two', 'months', 'ago', 'about', 'a', 'comeback', '.', '</s>', '\"', 'he', 'ended', 'the', 'world', 'cup', 'on', 'the', 'wrong', 'note', ',', '\"', 'coste', 'said', '.', '</s>', '\"', 'i', 'thought', 'it', 'would', 'be', 'useful', 'to', 'have', 'him', 'back', 'and', 'he', 'said', 'he', 'would', 'be', 'available', '.', '</s>', 'i', 'think', 'now', 'is', 'the', 'right', 'time', 'for', 'him', 'to', 'return', '.', '\"', '</s>', 'squad', ':', 'javier', 'pertile', ',', 'paolo', 'vaccari', ',', 'marcello', 'cuttitta', ',', 'ivan', 'francescato', ',', 'leandro', 'manteri', ',', 'diego', 'dominguez', ',', 'francesco', 'mazzariol', ',', 'alessandro', 'troncon', ',', 'orazio', 'arancio', ',', 'andrea', 'sgorlon', ',', 'massimo', 'giovanelli', ',', 'carlo', 'checchinato', ',', 'walter', 'cristofoletto', ',', 'franco', 'properzi', 'curti', ',', 'carlo', 'orlandi', ',', 'massimo', 'cuttitta', ',', 'giambatista', 'croci', ',', 'gianluca', 'guidi', ',', 'nicola', 'mazzucato', ',', 'alessandro', 'moscardi', ',', 'andrea', 'castellani', '.', '</s>']\n"
]
}
],
"source": [
"# odczytaj dane testowe A\n",
"test_A = pd.read_csv('test-A/in.tsv', sep='\\t')\n",
"test_A.columns = [\"x\"]\n",
"print(\"wczytano dane testowe A\")\n",
"print(test_A[\"x\"][0])\n",
"\n",
"# podziel dane testowe na słowa\n",
"# https://www.geeksforgeeks.org/python-word-embedding-using-word2vec/\n",
"slowa_test_A = []\n",
"for tekst in test_A[\"x\"]:\n",
" pom = []\n",
" for slowo in tekst.split(\" \"):\n",
" #if slowo not in (\"<\",\"/s\",\">\",\"/S\",\"``\"):\n",
" pom.append(slowo.lower())\n",
" slowa_test_A.append(pom)\n",
"print(\"podzielono dane treningowe na słowa\")\n",
"print(slowa_test_A[0])"
]
},
{
"cell_type": "code",
"execution_count": 191,
"metadata": {},
"outputs": [],
"source": [
"def build_vocab(dataset):\n",
" counter = Counter()\n",
" for document in dataset:\n",
" counter.update(document)\n",
" return vocab(counter, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])"
]
},
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20998\n",
"['<unk>', '<pad>', '<bos>', '<eos>', 'rare', 'hendrix', 'song', 'draft', 'sells', 'for', 'almost', '$', '17,000', '.', '</s>', 'london', '1996-08-22', 'a', 'early', 'handwritten', 'of', 'by', 'u.s.', 'guitar', 'legend', 'jimi', 'was', 'sold', 'on', 'thursday', 'at', 'an', 'auction', 'some', 'the', 'late', 'musician', \"'s\", 'favourite', 'possessions', 'florida', 'restaurant', 'paid', '10,925', 'pounds', '(', '16,935', ')', '\"', 'ai', \"n't\", 'no', 'telling', ',', 'which', 'penned', 'piece', 'hotel', 'stationery', 'in', '1966', 'end', 'january', '1967', 'concert', 'english', 'city', 'nottingham', 'he', 'threw', 'sheet', 'paper', 'into', 'audience', 'where', 'it', 'retrieved', 'fan', 'buyers', 'also', 'snapped', 'up', '16', 'other', 'items', 'that', 'were', 'put', 'former', 'girlfriend', 'kathy', 'etchingham', 'who', 'lived', 'with', 'him', 'from', 'to', '1969', 'they', 'included', 'black', 'lacquer', 'and', 'mother', 'pearl', 'inlaid', 'box', 'used', 'store', 'his', 'drugs', 'anonymous', 'australian', 'purchaser', 'bought', '5,060', '7,845', 'guitarist', 'died', 'overdose', '1970', 'aged', '27', 'china', 'says', 'taiwan', 'spoils', 'atmosphere', 'talks', 'beijing', 'accused', 'taipei', 'spoiling', 'resumption', 'across', 'strait', 'visit', 'ukraine', 'taiwanese', 'vice', 'president', 'lien', 'chan', 'this', 'week', 'infuriated', 'speaking', 'only', 'hours', 'after', 'chinese', 'state', 'media', 'said', 'time', 'right', 'engage', 'political', 'foreign', 'ministry', 'spokesman', 'shen', 'guofang', 'told', 'reuters', ':', 'necessary', 'opening', 'has', 'been', 'disrupted', 'authorities', 'quoted', 'top', 'negotiator', 'tang', 'shubei', 'as', 'visiting', 'group', 'wednesday', 'rivals', 'hold', 'now', 'is', 'two', 'sides', '...', 'hostility', 'overseas', 'edition', 'people', 'daily', 'saying', 'television', 'interview', 'had', 'read', 'reports', 'comments', 'but', 'gave', 'details', 'why', 'considered', 'considers', 'renegade', 'province', 'long', 'opposed', 'all', 'efforts', 'gain', 'greater', 'international', 'recognition', 'rival', 'island', 'should', 'take', 'practical', 'steps', 'towards', 'goal', 'consultations', 'be', 'held', 'set', 'format', 'official', 'xinhua', 'news', 'agency', 'executive', 'chairman', 'association', 'relations', 'straits', 'german', 'july', 'car', 'registrations', '14.2', 'pct', 'yr', '/', 'frankfurt', 'first-time', 'motor', 'vehicles', 'jumped', 'percent', 'year', 'year-earlier', 'period', 'federal', 'office', '356,725', 'new', 'cars', 'registered', '1996', '--', '304,850', 'passenger', '15,613', 'trucks', 'figures', 'represent', '13.6', 'increase', '2.2', 'decline', '1995', 'motor-bike', 'registration', 'rose', '32.7', 'growth', 'partly', 'due', 'increased', 'number', 'germans', 'buying', 'abroad', 'while', 'manufacturers', 'domestic', 'demand', 'weak', 'posted', 'gains', 'numbers', 'volkswagen', 'ag', 'won', '77,719', 'slightly', 'more', 'than', 'quarter', 'total', 'opel', 'together', 'general', 'motors', 'came', 'second', 'place', '49,269', '16.4', 'overall', 'figure', 'third', 'ford', '35,563', 'or', '11.7', 'seat', 'porsche', 'fewer', 'compared', 'last', '3,420', '5522', 'earlier', 'fell', '554', '643', 'greek', 'socialists', 'give', 'green', 'light', 'pm', 'elections', 'athens', 'socialist', 'party', 'bureau', 'prime', 'minister', 'costas', 'simitis', 'call', 'snap', 'its', 'secretary', 'skandalidis', 'reporters', 'going', 'make', 'announcement', 'cabinet', 'meeting', 'later', 'dimitris', 'kontogiannis', 'newsroom', '+301', '3311812-4', 'bayervb', 'sets', 
'c$', '100', 'million', 'six-year', 'bond', 'following', 'announced', 'lead', 'manager', 'toronto', 'dominion', 'borrower', 'bayerische', 'vereinsbank', 'amt', 'mln', 'coupon', '6.625', 'maturity', '24.sep.02', 'type', 'straight', 'iss', 'price', '100.92', 'pay', 'date', '24.sep.96', 'full', 'fees', '1.875', 'reoffer', '99.32', 'spread', '+20', 'bp', 'moody', 'aa1', 'listing', 'lux', 'freq', '=', 's&p', 'denoms', 'k', '1-10-100', 'sale', 'limits', 'us', 'uk', 'ca', 'neg', 'plg', 'crs', 'deflt', 'force', 'maj', 'gov', 'law', ...]\n"
]
}
],
"source": [
"v = build_vocab(slowa_train)\n",
"v.set_default_index(v[\"<unk>\"])\n",
"itos = v.get_itos() # mapowanie indeksów na tokeny\n",
"print(len(itos)) # liczba różnych tokenów w słowniku\n",
"print(itos)"
]
},
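{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of how the vocabulary behaves (`qwertyuiop` is a made-up out-of-vocabulary example):\n",
"\n",
"```python\n",
"print(v['london'])        # index of a known token (the specials occupy 0-3)\n",
"print(v['qwertyuiop'])    # 0, i.e. '<unk>', thanks to set_default_index\n",
"print(itos[v['london']])  # round-trips back to 'london'\n",
"```"
]
},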
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'O': 0, 'B-PER': 1, 'B-LOC': 2, 'I-PER': 3, 'B-MISC': 4, 'I-MISC': 5, 'I-LOC': 6, 'B-ORG': 7, 'I-ORG': 8}\n"
]
}
],
"source": [
"# slownik etykiety - kody etykiet\n",
"etykieta_na_kod = {}\n",
"licznik = 0\n",
"for tekst in train[\"y\"]:\n",
" for etykieta in tekst.split(\" \"):\n",
" if etykieta not in etykieta_na_kod:\n",
" etykieta_na_kod[etykieta] = licznik\n",
" licznik+=1\n",
"print(etykieta_na_kod)"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 5, 5, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
]
}
],
"source": [
"# podziel etykiety\n",
"kody_etykiet_train = []\n",
"for tekst in train[\"y\"]:\n",
" pom = []\n",
" for etykieta in tekst.split(\" \"):\n",
" pom.append(etykieta_na_kod[etykieta])\n",
" kody_etykiet_train.append(pom)\n",
"print(kody_etykiet_train[0])"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"O O B-MISC I-MISC I-MISC O O O B-LOC O O O O O O O O O B-MISC O O O O O O O B-LOC O B-ORG O B-ORG O O O O O O O O B-ORG O O O O B-PER I-PER O O O B-ORG O O O B-ORG O O O B-ORG O O O B-LOC O B-ORG O O O O B-PER I-PER O O B-PER I-PER O O B-PER I-PER O O O B-ORG O O B-PER I-PER O O O O O O B-LOC I-LOC O B-ORG O O B-PER I-PER O O B-PER I-PER O O O B-ORG O O O B-LOC O B-LOC I-LOC O O B-ORG O O B-ORG O O B-PER I-PER O O O O B-PER I-PER O O B-PER I-PER O O B-PER I-PER O O B-PER I-PER O O O O B-LOC O B-ORG O O B-PER I-PER O O B-PER I-PER O O B-PER I-PER O O O B-ORG O O B-PER I-PER O O O B-LOC O B-ORG O O O O B-PER I-PER O O B-PER I-PER O O B-PER I-PER O O O B-ORG O O O O O B-LOC O B-ORG O O O O B-ORG O O B-PER I-PER O O B-PER O O O O B-PER I-PER O O B-PER I-PER O O O B-LOC O B-ORG O O O O B-PER I-PER O O O O O B-ORG O O B-PER I-PER O O B-PER I-PER O O O O\n",
"[0, 0, 4, 5, 5, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 2, 0, 7, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 0, 7, 0, 0, 0, 7, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 2, 6, 0, 7, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 0, 2, 0, 2, 6, 0, 0, 7, 0, 0, 7, 0, 0, 1, 3, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 0, 2, 0, 7, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 1, 3, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 7, 0, 0, 1, 3, 0, 0, 1, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 7, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 0]\n"
]
}
],
"source": [
"# odczytaj etykiety dev-0\n",
"labels_dev0 = pd.read_csv('dev-0/expected.tsv', sep='\\t')\n",
"labels_dev0.columns = [\"y\"]\n",
"print(labels_dev0[\"y\"][0])\n",
"\n",
"# podziel etykiety\n",
"kody_etykiet_dev0 = []\n",
"for tekst in labels_dev0[\"y\"]:\n",
" pom = []\n",
" for etykieta in tekst.split(\" \"):\n",
" pom.append(etykieta_na_kod[etykieta])\n",
" kody_etykiet_dev0.append(pom)\n",
"print(kody_etykiet_dev0[0])"
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {},
"outputs": [],
"source": [
"def data_process(dt):\n",
" # Wektoryzacja dokumentów tekstowych.\n",
" return [\n",
" torch.tensor(\n",
" [v[\"<bos>\"]] + [v[token] for token in document] + [v[\"<eos>\"]],\n",
" dtype=torch.long,\n",
" )\n",
" for document in dt\n",
" ]\n",
"\n",
"def labels_process(dt):\n",
" # Wektoryzacja etykiet (NER)\n",
" return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]"
]
},
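{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two helpers keep tokens and labels aligned: each document is wrapped in `<bos>`/`<eos>`, and the label sequence gets a matching `0` (`O`) at both ends. A tiny check on a hypothetical two-token document:\n",
"\n",
"```python\n",
"ids = data_process([['london', '1996-08-22']])[0]  # tensor of length 4\n",
"lab = labels_process([[2, 0]])[0]                  # tensor([0, 2, 0, 0])\n",
"assert len(ids) == len(lab)\n",
"```"
]
},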
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [],
"source": [
"train_tokens_ids = data_process(slowa_train)\n",
"test_dev0_tokens_ids = data_process(slowa_test_dev0)\n",
"test_A_tokens_ids = data_process(slowa_test_A)\n",
"\n",
"train_labels = labels_process(kody_etykiet_train)\n",
"test_dev0_labels = labels_process(kody_etykiet_dev0)"
]
},
{
"cell_type": "code",
"execution_count": 198,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"944 199\n",
"214 256\n",
"229 283\n",
"tensor([ 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,\n",
" 14, 17, 4, 18, 19, 7, 20, 17, 6, 21, 22, 23, 24, 25,\n",
" 5, 26, 27, 9, 10, 11, 12, 28, 29, 30, 31, 32, 20, 33,\n",
" 20, 34, 35, 36, 37, 38, 39, 13, 14, 17, 40, 41, 42, 43,\n",
" 44, 45, 11, 46, 47, 9, 34, 7, 20, 48, 49, 50, 51, 52,\n",
" 48, 53, 54, 5, 55, 28, 17, 56, 20, 15, 57, 58, 59, 35,\n",
" 60, 13, 14, 30, 34, 61, 20, 17, 62, 63, 64, 59, 34, 65,\n",
" 66, 20, 67, 68, 69, 34, 70, 20, 71, 72, 34, 73, 53, 74,\n",
" 75, 26, 76, 21, 17, 77, 13, 14, 78, 79, 80, 81, 82, 83,\n",
" 84, 85, 86, 87, 81, 9, 32, 21, 5, 37, 88, 89, 90, 91,\n",
" 53, 92, 93, 94, 95, 96, 60, 97, 98, 13, 14, 99, 100, 17,\n",
" 101, 102, 103, 104, 20, 105, 106, 107, 108, 21, 5, 97, 109, 110,\n",
" 111, 53, 54, 31, 112, 113, 114, 115, 9, 116, 44, 45, 11, 117,\n",
" 47, 13, 14, 34, 118, 119, 20, 17, 111, 120, 59, 121, 122, 123,\n",
" 13, 14, 3])\n",
"tensor([ 2, 1949, 459, 65, 1950, 1951, 1592, 13, 14, 15,\n",
" 19342, 14, 1793, 103, 1465, 20, 1952, 1592, 59, 65,\n",
" 1950, 1951, 1954, 28, 947, 166, 14, 1992, 166, 1993,\n",
" 1703, 1965, 21, 31, 2038, 103, 3671, 2932, 13, 14,\n",
" 1965, 6226, 103, 16331, 45, 1995, 1996, 0, 47, 53,\n",
" 1993, 0, 13, 14, 1993, 1055, 1330, 53, 1965, 1864,\n",
" 13, 14, 17021, 166, 1991, 19322, 103, 14088, 45, 1977,\n",
" 0, 1620, 53, 10801, 12466, 1620, 1962, 1958, 0, 0,\n",
" 47, 53, 1956, 19326, 45, 1960, 19327, 19328, 47, 103,\n",
" 16667, 13, 14, 19313, 1363, 166, 2012, 0, 45, 1995,\n",
" 2752, 5725, 1962, 1967, 0, 0, 47, 53, 1985, 0,\n",
" 13, 14, 15, 45, 34, 2037, 47, 166, 2020, 14779,\n",
" 53, 2018, 0, 45, 2030, 2059, 5455, 620, 618, 53,\n",
" 1967, 0, 1602, 53, 1963, 0, 1976, 53, 1974, 0,\n",
" 1976, 53, 1958, 0, 3843, 47, 13, 14, 19318, 166,\n",
" 2002, 16329, 45, 2024, 19320, 9379, 53, 2007, 2008, 1979,\n",
" 1962, 2061, 10865, 0, 47, 53, 2034, 0, 45, 2024,\n",
" 0, 2054, 47, 14, 6206, 166, 12584, 11568, 103, 11269,\n",
" 45, 1974, 19334, 9379, 53, 1967, 19335, 1997, 53, 1967,\n",
" 17052, 6226, 47, 53, 2000, 15584, 103, 0, 13, 14,\n",
" 9493, 166, 2026, 10970, 103, 0, 53, 9314, 0, 45,\n",
" 1974, 2717, 0, 53, 0, 6237, 620, 618, 53, 6223,\n",
" 19332, 11058, 1962, 6227, 401, 0, 47, 14, 9488, 166,\n",
" 1972, 19340, 103, 0, 45, 1974, 1975, 4451, 620, 618,\n",
" 47, 53, 2010, 14739, 45, 6223, 6224, 1979, 1962, 1977,\n",
" 4839, 1981, 47, 13, 14, 3])\n",
"tensor([ 2, 6342, 769, 459, 0, 960, 9, 1681, 150, 17,\n",
" 253, 13, 14, 5474, 0, 14, 1681, 3063, 0, 0,\n",
" 14, 28, 947, 9, 701, 7189, 572, 2124, 30, 0,\n",
" 300, 301, 17, 253, 150, 34, 14863, 6363, 371, 68,\n",
" 26, 3333, 370, 3631, 608, 11618, 13, 14, 0, 53,\n",
" 92, 1738, 1753, 0, 154, 26, 3388, 97, 1952, 28,\n",
" 3978, 145, 53, 26, 2116, 59, 17, 0, 2099, 14403,\n",
" 148, 186, 20, 34, 695, 2519, 0, 21, 1208, 30,\n",
" 0, 324, 729, 13, 14, 2725, 0, 185, 618, 863,\n",
" 1521, 103, 0, 154, 68, 197, 2954, 960, 2955, 0,\n",
" 0, 53, 92, 197, 170, 3063, 9, 34, 1208, 2154,\n",
" 150, 1824, 1053, 618, 20, 34, 457, 695, 13, 14,\n",
" 0, 371, 110, 12530, 150, 34, 274, 1593, 1711, 53,\n",
" 74, 68, 596, 452, 94, 1458, 2954, 96, 34, 1681,\n",
" 749, 85, 2517, 1208, 59, 34, 8797, 9174, 13, 14,\n",
" 0, 154, 68, 197, 4705, 34, 6392, 186, 836, 2521,\n",
" 700, 17, 17097, 13, 14, 48, 68, 1240, 34, 1593,\n",
" 1711, 28, 34, 4645, 3370, 53, 48, 0, 154, 13,\n",
" 14, 48, 1500, 1798, 75, 693, 226, 5612, 97, 606,\n",
" 95, 960, 103, 68, 154, 68, 693, 226, 1221, 13,\n",
" 14, 1500, 4604, 184, 185, 34, 156, 155, 9, 95,\n",
" 97, 671, 13, 48, 14, 2099, 166, 2718, 0, 53,\n",
" 0, 16807, 53, 0, 0, 53, 2886, 0, 53, 0,\n",
" 0, 53, 2854, 0, 53, 10959, 0, 53, 11542, 0,\n",
" 53, 0, 0, 53, 2219, 0, 53, 0, 0, 53,\n",
" 4036, 0, 53, 17118, 0, 53, 11460, 0, 0, 53,\n",
" 4036, 0, 53, 0, 0, 53, 0, 0, 53, 9462,\n",
" 0, 53, 13541, 0, 53, 11542, 0, 53, 2219, 0,\n",
" 13, 14, 3])\n"
]
}
],
"source": [
"print(len(train_tokens_ids), len(train_tokens_ids[0]))\n",
"print(len(test_dev0_tokens_ids), len(test_dev0_tokens_ids[0]))\n",
"print(len(test_A_tokens_ids), len(test_A_tokens_ids[0]))\n",
"\n",
"print(train_tokens_ids[0])\n",
"print(test_dev0_tokens_ids[0])\n",
"print(test_A_tokens_ids[0])"
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"944 199\n",
"214 256\n",
"tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 2, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 5, 5, 0, 0,\n",
" 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 4, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 3, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0])\n",
"tensor([0, 0, 0, 4, 5, 5, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0,\n",
" 0, 0, 0, 2, 0, 7, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 1, 3, 0,\n",
" 0, 0, 7, 0, 0, 0, 7, 0, 0, 0, 7, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 1, 3, 0,\n",
" 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 2, 6, 0,\n",
" 7, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 0, 2, 0, 2, 6, 0, 0, 7, 0,\n",
" 0, 7, 0, 0, 1, 3, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3,\n",
" 0, 0, 0, 0, 2, 0, 7, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0,\n",
" 0, 1, 3, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0,\n",
" 0, 0, 7, 0, 0, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 7, 0, 0, 1, 3, 0, 0, 1, 0,\n",
" 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 1, 3, 0, 0, 0,\n",
" 0, 0, 7, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 0, 0])\n"
]
}
],
"source": [
"print(len(train_labels), len(train_labels[0]))\n",
"print(len(test_dev0_labels), len(test_dev0_labels[0]))\n",
"\n",
"print(train_labels[0])\n",
"print(test_dev0_labels[0])"
]
},
{
"cell_type": "code",
"execution_count": 200,
"metadata": {},
"outputs": [],
"source": [
"def get_scores(y_true, y_pred):\n",
" # Funkcja zwraca precyzję, pokrycie i F1\n",
" acc_score = 0\n",
" tp = 0\n",
" fp = 0\n",
" selected_items = 0\n",
" relevant_items = 0\n",
"\n",
" for p, t in zip(y_pred, y_true):\n",
" if p == t:\n",
" acc_score += 1\n",
"\n",
" if p > 0 and p == t:\n",
" tp += 1\n",
"\n",
" if p > 0:\n",
" selected_items += 1\n",
"\n",
" if t > 0:\n",
" relevant_items += 1\n",
"\n",
" if selected_items == 0:\n",
" precision = 1.0\n",
" else:\n",
" precision = tp / selected_items\n",
"\n",
" if relevant_items == 0:\n",
" recall = 1.0\n",
" else:\n",
" recall = tp / relevant_items\n",
"\n",
" if precision + recall == 0.0:\n",
" f1 = 0.0\n",
" else:\n",
" f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
" return precision, recall, f1"
]
},
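{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hand-checkable example of the metric: label code `0` is `O` (non-entity), everything else counts as an entity tag.\n",
"\n",
"```python\n",
"y_true = [0, 1, 2, 0, 3]\n",
"y_pred = [0, 1, 0, 2, 3]\n",
"# tp = 2 (codes 1 and 3 match), 3 non-zero predictions, 3 non-zero gold labels,\n",
"# so precision = recall = f1 = 2/3\n",
"print(get_scores(y_true, y_pred))\n",
"```"
]
},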
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [],
"source": [
"num_tags = len(etykieta_na_kod.keys())\n",
"\n",
"class LSTM(torch.nn.Module):\n",
"\n",
" def __init__(self):\n",
" super(LSTM, self).__init__()\n",
" self.emb = torch.nn.Embedding(len(v.get_itos()), 100)\n",
" self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)\n",
" self.fc1 = torch.nn.Linear(256, num_tags)\n",
"\n",
" def forward(self, x):\n",
" emb = torch.relu(self.emb(x))\n",
" lstm_output, (h_n, c_n) = self.rec(emb)\n",
" out_weights = self.fc1(lstm_output)\n",
" return out_weights"
]
},
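{
"cell_type": "markdown",
"metadata": {},
"source": [
"Shapes through the network for one document of length T (batch size 1), using hypothetical token ids:\n",
"\n",
"```python\n",
"model = LSTM()\n",
"x = torch.tensor([[2, 15, 3]])    # (1, 3): <bos>, some token, <eos>\n",
"e = torch.relu(model.emb(x))      # (1, 3, 100) embedding vectors\n",
"o, (h_n, c_n) = model.rec(e)      # (1, 3, 256) one hidden state per token\n",
"print(model.fc1(o).shape)         # torch.Size([1, 3, 9]): a score per tag\n",
"```"
]
},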
{
"cell_type": "code",
"execution_count": 202,
"metadata": {},
"outputs": [],
"source": [
"def eval_model(dataset_tokens, dataset_labels, model):\n",
" Y_true = []\n",
" Y_pred = []\n",
" for i in tqdm(range(len(dataset_labels))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
" tags = list(dataset_labels[i].numpy())\n",
" Y_true += tags\n",
"\n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
" Y_pred += list(Y_batch_pred.numpy())\n",
"\n",
" return get_scores(Y_true, Y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [],
"source": [
"lstm = LSTM()\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(lstm.parameters())\n",
"NUM_EPOCHS = 50"
]
},
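{
"cell_type": "markdown",
"metadata": {},
"source": [
"`CrossEntropyLoss` takes raw logits of shape `(N, C)` and integer targets of shape `(N,)`; in the training loop below `N` is the sequence length and `C = num_tags`, which is why the tensors are squeezed. A standalone example:\n",
"\n",
"```python\n",
"logits = torch.randn(5, num_tags)       # scores for 5 tokens\n",
"target = torch.tensor([0, 1, 0, 2, 0])  # gold tag codes\n",
"print(criterion(logits, target))        # scalar loss\n",
"```"
]
},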
{
"cell_type": "code",
"execution_count": 204,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 944/944 [00:22<00:00, 41.98it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 32.45it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 32.18it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 32.23it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 32.32it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 32.00it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 32.22it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.84it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 32.12it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.82it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.61it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.92it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.71it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.94it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.64it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.78it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.58it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.78it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.32it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.74it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.53it/s]\n",
"100%|██████████| 944/944 [00:31<00:00, 30.43it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 30.87it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.45it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.46it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.31it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.47it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.34it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.39it/s]\n",
"100%|██████████| 944/944 [00:31<00:00, 30.17it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.39it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.27it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.27it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.42it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.40it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.52it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.13it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.25it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.12it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.31it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.29it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.41it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.19it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.54it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.28it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.42it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.49it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.41it/s]\n",
"100%|██████████| 944/944 [00:29<00:00, 31.50it/s]\n",
"100%|██████████| 944/944 [00:30<00:00, 31.45it/s]\n"
]
}
],
"source": [
"for i in range(NUM_EPOCHS):\n",
" lstm.train()\n",
" # for i in tqdm(range(500)):\n",
" for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0)\n",
" tags = train_labels[i].unsqueeze(1)\n",
"\n",
" predicted_tags = lstm(batch_tokens)\n",
"\n",
" optimizer.zero_grad()\n",
" loss = criterion(predicted_tags.squeeze(0), tags.squeeze(1))\n",
"\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" lstm.eval()"
]
},
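{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above feeds one document per optimizer step. A hedged sketch of mini-batching with `pad_sequence` (not what was run here; it pads token ids with `<pad>` and masks the padded label positions via `ignore_index`):\n",
"\n",
"```python\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"batch_x = pad_sequence(train_tokens_ids[:8], batch_first=True, padding_value=v['<pad>'])\n",
"batch_y = pad_sequence(train_labels[:8], batch_first=True, padding_value=-100)\n",
"out = lstm(batch_x)  # (8, T_max, num_tags)\n",
"loss = torch.nn.CrossEntropyLoss(ignore_index=-100)(\n",
"    out.reshape(-1, num_tags), batch_y.reshape(-1))\n",
"```"
]
},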
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 214/214 [00:00<00:00, 277.61it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7493175614194723, 0.7732394366197183, 0.7610905730129389)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(eval_model(test_dev0_tokens_ids, test_dev0_labels, lstm))"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [],
"source": [
"def zwroc_przewidywania(tokeny):\n",
" Y_pred = []\n",
" for i in tqdm(range(len(tokeny))):\n",
" pom1 = lstm(tokeny[i])\n",
" #print(pom1)\n",
" pom2 = torch.argmax(pom1,1)\n",
" #print(pom2)\n",
" pom3 = list(pom2.numpy())\n",
" #print(pom3)\n",
" Y_pred.append(pom3)\n",
" return Y_pred"
]
},
{
"cell_type": "code",
"execution_count": 207,
"metadata": {},
"outputs": [],
"source": [
"def zamien_przewidziane_kody_na_etykiety(przewidywania):\n",
" etykiety = []\n",
" for lista in przewidywania:\n",
" pom = []\n",
" for kod in lista:\n",
" etykieta = None\n",
" for e, k in etykieta_na_kod.items():\n",
" if kod == k:\n",
" etykieta = e\n",
" pom.append(etykieta)\n",
" etykiety.append(pom)\n",
" return etykiety"
]
},
{
"cell_type": "code",
"execution_count": 208,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 214/214 [00:00<00:00, 280.93it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7493175614194723, 0.7732394366197183, 0.7610905730129389)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 214/214 [00:00<00:00, 310.62it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 0, 0, 4, 4, 5, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 7, 0, 7, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 0, 7, 0, 0, 0, 7, 0, 0, 0, 2, 0, 7, 0, 0, 1, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 3, 0, 0, 7, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 2, 0, 0, 7, 8, 0, 1, 3, 0, 0, 1, 3, 1, 0, 0, 7, 1, 0, 0, 2, 0, 0, 6, 0, 0, 7, 0, 0, 7, 0, 0, 1, 3, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 0, 2, 0, 7, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 2, 0, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 7, 0, 0, 1, 0, 0, 2, 0, 7, 0, 0, 1, 0, 7, 1, 0, 1, 3, 0, 0, 1, 0, 0, 0, 0, 1, 3, 0, 0, 1, 3, 1, 0, 0, 2, 0, 7, 0, 0, 1, 0, 1, 3, 0, 0, 0, 0, 0, 7, 0, 0, 1, 3, 0, 0, 1, 3, 0, 0, 0, 0, 0]\n",
"['O', 'O', 'O', 'B-MISC', 'B-MISC', 'I-MISC', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'B-ORG', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'B-LOC', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'O', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'B-ORG', 'I-ORG', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'B-PER', 'O', 'O', 'B-ORG', 'B-PER', 'O', 'O', 'B-LOC', 'O', 'O', 'I-LOC', 'O', 'O', 'B-ORG', 'O', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'O', 'O', 'B-LOC', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'O', 'B-ORG', 'B-PER', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'B-PER', 'O', 'O', 'B-LOC', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O']\n"
]
}
],
"source": [
"print(eval_model(test_dev0_tokens_ids, test_dev0_labels, lstm))\n",
"przewidywania_kody_test_dev0 = zwroc_przewidywania(test_dev0_tokens_ids)\n",
"print(przewidywania_kody_test_dev0[0])\n",
"przewidywania_etykiety_test_dev0 = zamien_przewidziane_kody_na_etykiety(przewidywania_kody_test_dev0)\n",
"print(przewidywania_etykiety_test_dev0[0])"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {},
"outputs": [],
"source": [
"with open(\"dev-0/out.tsv\", \"w\", encoding=\"utf-8\") as uwu:\n",
" for lista in przewidywania_etykiety_test_dev0:\n",
" for etykieta in lista:\n",
" uwu.write(str(etykieta) + \" \")\n",
" uwu.write(str(\"\\n\"))"
]
},
{
"cell_type": "code",
"execution_count": 210,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 229/229 [00:00<00:00, 339.90it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 7, 8, 0, 1, 0, 0, 2, 0, 0, 0, 0, 0, 2, 5, 0, 2, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 1, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 2, 0, 2, 0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 4, 3, 0, 4, 5, 0, 4, 3, 0, 4, 5, 0, 4, 3, 0, 1, 3, 0, 7, 5, 0, 4, 5, 0, 1, 3, 0, 0, 5, 0, 1, 3, 0, 1, 3, 0, 0, 5, 5, 0, 1, 3, 0, 4, 5, 0, 4, 5, 0, 1, 5, 0, 4, 5, 0, 1, 0, 0, 1, 3, 0, 0, 0]\n",
"['O', 'B-ORG', 'I-ORG', 'O', 'B-PER', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-MISC', 'O', 'B-LOC', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'B-LOC', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'B-MISC', 'I-PER', 'O', 'B-MISC', 'I-MISC', 'O', 'B-MISC', 'I-PER', 'O', 'B-MISC', 'I-MISC', 'O', 'B-MISC', 'I-PER', 'O', 'B-PER', 'I-PER', 'O', 'B-ORG', 'I-MISC', 'O', 'B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O', 'O', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O', 'B-PER', 'I-PER', 'O', 'O', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O', 'B-MISC', 'I-MISC', 'O', 'B-MISC', 'I-MISC', 'O', 'B-PER', 'I-MISC', 'O', 'B-MISC', 'I-MISC', 'O', 'B-PER', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O']\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"przewidywania_kody_test_A = zwroc_przewidywania(test_A_tokens_ids)\n",
"print(przewidywania_kody_test_A[0])\n",
"przewidywania_etykiety_test_A = zamien_przewidziane_kody_na_etykiety(przewidywania_kody_test_A)\n",
"print(przewidywania_etykiety_test_A[0])"
]
},
{
"cell_type": "code",
"execution_count": 211,
"metadata": {},
"outputs": [],
"source": [
"with open(\"test-A/out.tsv\", \"w\", encoding=\"utf-8\") as uwu:\n",
" for lista in przewidywania_etykiety_test_A:\n",
" for etykieta in lista:\n",
" uwu.write(str(etykieta) + \" \")\n",
" uwu.write(str(\"\\n\"))"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 944/944 [00:03<00:00, 264.86it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.9969414463429698, 0.9979686764013189, 0.9974547968986774)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 944/944 [00:03<00:00, 278.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 5, 5, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"['O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n"
]
}
],
"source": [
"print(eval_model(train_tokens_ids, train_labels, lstm))\n",
"przewidywania_kody_test_train = zwroc_przewidywania(train_tokens_ids)\n",
"print(przewidywania_kody_test_train[0])\n",
"przewidywania_etykiety_test_train = zamien_przewidziane_kody_na_etykiety(przewidywania_kody_test_train)\n",
"print(przewidywania_etykiety_test_train[0])"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [],
"source": [
"with open(\"train/out.tsv\", \"w\", encoding=\"utf-8\") as uwu:\n",
" for lista in przewidywania_etykiety_test_train:\n",
" for etykieta in lista:\n",
" uwu.write(str(etykieta) + \" \")\n",
" uwu.write(str(\"\\n\"))"
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
},
"subtitle": "11.NER RNN[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}