diff --git a/train/RNN.ipynb b/train/RNN.ipynb new file mode 100644 index 0000000..fad0feb --- /dev/null +++ b/train/RNN.ipynb @@ -0,0 +1,1201 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "5d0842c8-c292-41ce-a27b-73986bf43e1c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n", + "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", + "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", + " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n", + "C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n", + "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", + "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", + " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n" + ] + } + ], + "source": [ + "from collections import Counter\n", + "import torch\n", + "from datasets import load_dataset\n", + "from torchtext.vocab import vocab\n", + "from tqdm.notebook import tqdm\n", + "import pandas as pd\n", + "import nltk\n", + "from nltk.tokenize import word_tokenize\n", + "import string" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4ccf08e2-fec1-4d68-a1fd-d33af6bd54bc", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = pd.read_csv('train.tsv', sep='\\t', header=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c027b2f3-48eb-442a-a212-c5fa9d1cfba5", + "metadata": {}, + "outputs": [], + "source": [ + "X_test = pd.read_csv('../dev-0/in.tsv', sep='\\t', header=None)\n", + "Y_test = pd.read_csv('../dev-0/expected.tsv', sep='\\t', header=None)\n", + "\n", + "X_val = pd.read_csv('../test-A/in.tsv', sep='\\t', header=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9fc7a069-52ce-4343-a8a2-dd77619f5d25", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "German July car registrations up 14.2 pct yr / yr . FRANKFURT 1996-08-22 German first-time registrations of motor vehicles jumped 14.2 percent in July this year from the year-earlier period , the Federal office for motor vehicles said on Thursday . The office said 356,725 new cars were registered in July 1996 -- 304,850 passenger cars and 15,613 trucks . The figures represent a 13.6 percent increase for passenger cars and a 2.2 percent decline for trucks from July 1995 . Motor-bike registration rose 32.7 percent in the period . The growth was partly due to an increased number of Germans buying German cars abroad , while manufacturers said that domestic demand was weak , the federal office said . Almost all German car manufacturers posted gains in registration numbers in the period . Volkswagen AG won 77,719 registrations , slightly more than a quarter of the total . Opel AG together with General Motors came in second place with 49,269 registrations , 16.4 percent of the overall figure . Third was Ford with 35,563 registrations , or 11.7 percent . Only Seat and Porsche had fewer registrations in July 1996 compared to last year 's July . Seat posted 3,420 registrations compared with 5522 registrations in July a year earlier . Porsche 's registrations fell to 554 from 643 . \n", + "BASKETBALL - INTERNATIONAL TOURNAMENT RESULT . BELGRADE 1996-08-30 Result in an international basketball tournament on Friday : Red Star ( Yugoslavia ) beat Dinamo ( Russia ) 92-90 ( halftime 47-47 ) \n", + "SOCCER - ASIAN CUP GROUP C RESULTS . AL-AIN , United Arab Emirates 1996-12-06 Results of Asian Cup group C matches played on Friday : Japan 2 Syria 1 ( halftime 0-1 ) Scorers : Japan - Hassan Abbas 84 own goal , Takuya Takagi 88 . Syria - Nader Jokhadar 7 Attendance : 10,000 . China 0 Uzbekistan 2 ( halftime 0-0 ) Scorers : Shkvyrin Igor 78 , Shatskikh Oleg 90 Attendence : 3,000 Standings ( tabulate under played , won , drawn , lost , goals for , goals against , points ) : Uzbekistan 1 1 0 0 2 0 3 Japan 1 1 0 0 2 1 3 Syria 1 0 0 1 1 2 0 China 1 0 0 1 0 2 0 \n" + ] + } + ], + "source": [ + "X_train = dataset[dataset.columns[1]].replace(\"\",\"\")\n", + "Y_train = dataset[dataset.columns[0]]\n", + "\n", + "X_test = X_test[X_test.columns[0]].replace(\"\",\"\")\n", + "Y_test = Y_test[Y_test.columns[0]]\n", + "\n", + "X_val = X_val[X_val.columns[0]].replace(\"\",\"\")\n", + "\n", + "print(X_train[4])\n", + "print(X_test[4])\n", + "print(X_val[4])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7cd49893-c730-4089-9b21-8683c11fd6cc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " ['german', 'july', 'car', 'registrations', 'up', '14.2', 'pct', 'yr', '/', 'yr', '.', '', 'frankfurt', '1996-08-22', '', 'german', 'first-time', 'registrations', 'of', 'motor', 'vehicles', 'jumped', '14.2', 'percent', 'in', 'july', 'this', 'year', 'from', 'the', 'year-earlier', 'period', ',', 'the', 'federal', 'office', 'for', 'motor', 'vehicles', 'said', 'on', 'thursday', '.', '', 'the', 'office', 'said', '356,725', 'new', 'cars', 'were', 'registered', 'in', 'july', '1996', '--', '304,850', 'passenger', 'cars', 'and', '15,613', 'trucks', '.', '', 'the', 'figures', 'represent', 'a', '13.6', 'percent', 'increase', 'for', 'passenger', 'cars', 'and', 'a', '2.2', 'percent', 'decline', 'for', 'trucks', 'from', 'july', '1995', '.', '', 'motor-bike', 'registration', 'rose', '32.7', 'percent', 'in', 'the', 'period', '.', '', 'the', 'growth', 'was', 'partly', 'due', 'to', 'an', 'increased', 'number', 'of', 'germans', 'buying', 'german', 'cars', 'abroad', ',', 'while', 'manufacturers', 'said', 'that', 'domestic', 'demand', 'was', 'weak', ',', 'the', 'federal', 'office', 'said', '.', '', 'almost', 'all', 'german', 'car', 'manufacturers', 'posted', 'gains', 'in', 'registration', 'numbers', 'in', 'the', 'period', '.', '', 'volkswagen', 'ag', 'won', '77,719', 'registrations', ',', 'slightly', 'more', 'than', 'a', 'quarter', 'of', 'the', 'total', '.', '', 'opel', 'ag', 'together', 'with', 'general', 'motors', 'came', 'in', 'second', 'place', 'with', '49,269', 'registrations', ',', '16.4', 'percent', 'of', 'the', 'overall', 'figure', '.', '', 'third', 'was', 'ford', 'with', '35,563', 'registrations', ',', 'or', '11.7', 'percent', '.', '', 'only', 'seat', 'and', 'porsche', 'had', 'fewer', 'registrations', 'in', 'july', '1996', 'compared', 'to', 'last', 'year', \"'s\", 'july', '.', '', 'seat', 'posted', '3,420', 'registrations', 'compared', 'with', '5522', 'registrations', 'in', 'july', 'a', 'year', 'earlier', '.', '', 'porsche', \"'s\", 'registrations', 'fell', 'to', '554', 'from', '643', '.', '']\n", + " ['basketball', '-', 'international', 'tournament', 'result', '.', '', 'belgrade', '1996-08-30', '', 'result', 'in', 'an', 'international', '', 'basketball', 'tournament', 'on', 'friday', ':', '', 'red', 'star', '(', 'yugoslavia', ')', 'beat', 'dinamo', '(', 'russia', ')', '92-90', '(', 'halftime', '', '47-47', ')', '']\n", + " ['soccer', '-', 'asian', 'cup', 'group', 'c', 'results', '.', '', 'al-ain', ',', 'united', 'arab', 'emirates', '1996-12-06', '', 'results', 'of', 'asian', 'cup', 'group', 'c', 'matches', 'played', 'on', 'friday', ':', '', 'japan', '2', 'syria', '1', '(', 'halftime', '0-1', ')', '', 'scorers', ':', '', 'japan', '-', 'hassan', 'abbas', '84', 'own', 'goal', ',', 'takuya', 'takagi', '88', '.', '', 'syria', '-', 'nader', 'jokhadar', '7', '', 'attendance', ':', '10,000', '.', '', 'china', '0', 'uzbekistan', '2', '(', 'halftime', '0-0', ')', '', 'scorers', ':', 'shkvyrin', 'igor', '78', ',', 'shatskikh', 'oleg', '90', '', 'attendence', ':', '3,000', '', 'standings', '(', 'tabulate', 'under', 'played', ',', 'won', ',', 'drawn', ',', 'lost', ',', 'goals', '', 'for', ',', 'goals', 'against', ',', 'points', ')', ':', '', 'uzbekistan', '1', '1', '0', '0', '2', '0', '3', '', 'japan', '1', '1', '0', '0', '2', '1', '3', '', 'syria', '1', '0', '0', '1', '1', '2', '0', '', 'china', '1', '0', '0', '1', '0', '2', '0', '']\n" + ] + } + ], + "source": [ + "def preprocess(text):\n", + " text = text.lower()\n", + " return text\n", + "\n", + "X_train = [preprocess(text).split() for text in X_train]\n", + "print( type(X_train), X_train[4])\n", + "X_test = [preprocess(text).split() for text in X_test]\n", + "print( type(X_test), X_test[4])\n", + "X_val = [preprocess(text).split() for text in X_val]\n", + "print( type(X_val), X_val[4])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c24df99f-1792-4691-92c0-7e5e596237c3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['german', 'july', 'car', 'registrations', 'up', '14.2', 'pct', 'yr', '/', 'yr', '.', '', 'frankfurt', '1996-08-22', '', 'german', 'first-time', 'registrations', 'of', 'motor', 'vehicles', 'jumped', '14.2', 'percent', 'in', 'july', 'this', 'year', 'from', 'the', 'year-earlier', 'period', ',', 'the', 'federal', 'office', 'for', 'motor', 'vehicles', 'said', 'on', 'thursday', '.', '', 'the', 'office', 'said', '356,725', 'new', 'cars', 'were', 'registered', 'in', 'july', '1996', '--', '304,850', 'passenger', 'cars', 'and', '15,613', 'trucks', '.', '', 'the', 'figures', 'represent', 'a', '13.6', 'percent', 'increase', 'for', 'passenger', 'cars', 'and', 'a', '2.2', 'percent', 'decline', 'for', 'trucks', 'from', 'july', '1995', '.', '', 'motor-bike', 'registration', 'rose', '32.7', 'percent', 'in', 'the', 'period', '.', '', 'the', 'growth', 'was', 'partly', 'due', 'to', 'an', 'increased', 'number', 'of', 'germans', 'buying', 'german', 'cars', 'abroad', ',', 'while', 'manufacturers', 'said', 'that', 'domestic', 'demand', 'was', 'weak', ',', 'the', 'federal', 'office', 'said', '.', '', 'almost', 'all', 'german', 'car', 'manufacturers', 'posted', 'gains', 'in', 'registration', 'numbers', 'in', 'the', 'period', '.', '', 'volkswagen', 'ag', 'won', '77,719', 'registrations', ',', 'slightly', 'more', 'than', 'a', 'quarter', 'of', 'the', 'total', '.', '', 'opel', 'ag', 'together', 'with', 'general', 'motors', 'came', 'in', 'second', 'place', 'with', '49,269', 'registrations', ',', '16.4', 'percent', 'of', 'the', 'overall', 'figure', '.', '', 'third', 'was', 'ford', 'with', '35,563', 'registrations', ',', 'or', '11.7', 'percent', '.', '', 'only', 'seat', 'and', 'porsche', 'had', 'fewer', 'registrations', 'in', 'july', '1996', 'compared', 'to', 'last', 'year', \"'s\", 'july', '.', '', 'seat', 'posted', '3,420', 'registrations', 'compared', 'with', '5522', 'registrations', 'in', 'july', 'a', 'year', 'earlier', '.', '', 'porsche', \"'s\", 'registrations', 'fell', 'to', '554', 'from', '643', '.', ''] B-MISC O O O O O O O O O O O B-LOC O O B-MISC O O O O O O O O O O O O O O O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O B-MISC O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O B-ORG I-ORG O O O O O O O O O O O O O O B-ORG I-ORG O O B-ORG I-ORG O O O O O O O O O O O O O O O O O O B-ORG O O O O O O O O O O B-ORG O B-ORG O O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O B-ORG O O O O O O O O O ['soccer', '-', 'asian', 'cup', 'group', 'c', 'results', '.', '', 'al-ain', ',', 'united', 'arab', 'emirates', '1996-12-06', '', 'results', 'of', 'asian', 'cup', 'group', 'c', 'matches', 'played', 'on', 'friday', ':', '', 'japan', '2', 'syria', '1', '(', 'halftime', '0-1', ')', '', 'scorers', ':', '', 'japan', '-', 'hassan', 'abbas', '84', 'own', 'goal', ',', 'takuya', 'takagi', '88', '.', '', 'syria', '-', 'nader', 'jokhadar', '7', '', 'attendance', ':', '10,000', '.', '', 'china', '0', 'uzbekistan', '2', '(', 'halftime', '0-0', ')', '', 'scorers', ':', 'shkvyrin', 'igor', '78', ',', 'shatskikh', 'oleg', '90', '', 'attendence', ':', '3,000', '', 'standings', '(', 'tabulate', 'under', 'played', ',', 'won', ',', 'drawn', ',', 'lost', ',', 'goals', '', 'for', ',', 'goals', 'against', ',', 'points', ')', ':', '', 'uzbekistan', '1', '1', '0', '0', '2', '0', '3', '', 'japan', '1', '1', '0', '0', '2', '1', '3', '', 'syria', '1', '0', '0', '1', '1', '2', '0', '', 'china', '1', '0', '0', '1', '0', '2', '0', '']\n", + "235 562\n" + ] + } + ], + "source": [ + "print(X_train[4], Y_train[4], X_val[4])\n", + "print(len(X_train[4]), len(Y_train[4]))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "68f95ff5-28d4-4267-94f1-7de37d0b1a1b", + "metadata": {}, + "outputs": [], + "source": [ + "def build_vocab(dataset):\n", + " counter = Counter()\n", + " for document in dataset:\n", + " counter.update(document)\n", + " return vocab(counter, specials=[\"\", \"\", \"\", \"\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7d2e8c94-403d-4481-afe2-23ef80146b4c", + "metadata": {}, + "outputs": [], + "source": [ + "v = build_vocab(X_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c8ad750f-fa6e-4ce1-a32f-70c314f5a587", + "metadata": {}, + "outputs": [], + "source": [ + "itos = v.get_itos() # mapowanie indeksów na tokeny" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "4368cae0-a34d-47a8-9a18-8c7dc78ecf8c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['', '', '', '', 'eu', 'rejects', 'german', 'call', 'to', 'boycott', 'british']\n" + ] + } + ], + "source": [ + "print(itos[:11])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a155c4b4-5395-410e-8482-2d8f54250b44", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "21014" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(itos) # liczba różnych tokenów w słowniku" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7aa9c044-8d13-4582-a081-83837eefb76f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "v[\"\"] # indeks nieznanego tokenu" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "675d2d9a-398b-4b01-abd7-cfad1047d74c", + "metadata": {}, + "outputs": [], + "source": [ + "v.set_default_index(v[\"\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "4253750b-5c6e-4df2-8440-f190ae21418a", + "metadata": {}, + "outputs": [], + "source": [ + "def data_process(dt):\n", + " # Wektoryzacja dokumentów tekstowych.\n", + " return [\n", + " torch.tensor(\n", + " [v[\"\"]] + [v[token] for token in document] + [v[\"\"]],\n", + " dtype=torch.long,\n", + " )\n", + " for document in dt\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "4bf37b0a-ad61-41be-adc7-ad66741ef534", + "metadata": {}, + "outputs": [], + "source": [ + "def labels_process(dt):\n", + " # Wektoryzacja etykiet (NER)\n", + " return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "be36e015-4e20-485a-867e-bcb517d2e6a5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 2, 6, 438, 439, 440, 309, 441, 442, 443, 444, 443, 12, 13, 445,\n", + " 17, 13, 6, 446, 440, 163, 447, 448, 449, 441, 255, 233, 438, 363,\n", + " 244, 53, 18, 450, 451, 73, 18, 452, 453, 72, 447, 448, 21, 22,\n", + " 23, 12, 13, 18, 453, 21, 454, 455, 456, 312, 457, 233, 438, 458,\n", + " 135, 459, 460, 456, 88, 461, 462, 12, 13, 18, 463, 464, 81, 465,\n", + " 255, 466, 72, 460, 456, 88, 81, 467, 255, 468, 72, 462, 53, 438,\n", + " 469, 12, 13, 470, 471, 472, 473, 255, 233, 18, 451, 12, 13, 18,\n", + " 474, 59, 475, 174, 8, 159, 476, 477, 163, 478, 479, 6, 456, 480,\n", + " 73, 481, 482, 21, 91, 483, 484, 59, 485, 73, 18, 452, 453, 21,\n", + " 12, 13, 262, 416, 6, 439, 482, 486, 487, 233, 471, 488, 233, 18,\n", + " 451, 12, 13, 489, 490, 491, 492, 440, 73, 493, 494, 56, 81, 495,\n", + " 163, 18, 247, 12, 13, 496, 490, 497, 26, 498, 499, 500, 233, 501,\n", + " 502, 26, 503, 440, 73, 504, 255, 163, 18, 256, 505, 12, 13, 506,\n", + " 59, 507, 26, 508, 440, 73, 509, 510, 255, 12, 13, 147, 511, 88,\n", + " 512, 155, 513, 440, 233, 438, 458, 514, 8, 97, 244, 42, 438, 12,\n", + " 13, 511, 486, 515, 440, 514, 26, 516, 440, 233, 438, 81, 244, 156,\n", + " 12, 13, 512, 42, 440, 517, 8, 518, 53, 519, 12, 13, 3])\n", + "tensor([ 2, 9637, 640, 419, 1908, 1850, 12, 13, 2439, 19358,\n", + " 13, 1850, 233, 159, 419, 13, 9637, 1908, 22, 1098,\n", + " 380, 13, 1103, 2438, 132, 2440, 134, 1762, 11167, 132,\n", + " 1160, 134, 0, 132, 1767, 13, 0, 134, 13, 3])\n", + "tensor([ 2, 1759, 640, 5613, 1770, 391, 2103, 1301, 12, 13,\n", + " 0, 73, 820, 1077, 1078, 0, 13, 1301, 163, 5613,\n", + " 1770, 391, 2103, 2010, 2489, 22, 1098, 380, 13, 1677,\n", + " 657, 667, 1316, 132, 1767, 6515, 134, 13, 1775, 380,\n", + " 13, 1677, 640, 875, 18629, 5763, 961, 426, 73, 0,\n", + " 0, 9355, 12, 13, 667, 640, 0, 0, 1929, 13,\n", + " 1786, 380, 4031, 12, 13, 345, 1577, 0, 657, 132,\n", + " 1767, 2299, 134, 13, 1775, 380, 0, 0, 9302, 73,\n", + " 0, 0, 4497, 13, 17180, 380, 2354, 13, 2853, 132,\n", + " 2854, 124, 2489, 73, 491, 73, 3958, 73, 2855, 73,\n", + " 2357, 13, 72, 73, 2357, 746, 73, 1469, 134, 380,\n", + " 13, 0, 1316, 1316, 1577, 1577, 657, 1577, 1945, 13,\n", + " 1677, 1316, 1316, 1577, 1577, 657, 1316, 1945, 13, 667,\n", + " 1316, 1577, 1577, 1316, 1316, 657, 1577, 13, 345, 1316,\n", + " 1577, 1577, 1316, 1577, 657, 1577, 13, 3])\n" + ] + } + ], + "source": [ + "train_tokens_ids = data_process(X_train)\n", + "print(train_tokens_ids[4])\n", + "validation_tokens_ids = data_process(X_test)\n", + "print(validation_tokens_ids[4])\n", + "val_tokens_ids = data_process(X_val)\n", + "print(val_tokens_ids[4])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "34ec5fa0-7bff-44d0-9428-94c9b24244ec", + "metadata": {}, + "outputs": [], + "source": [ + "def ner_tags(dataset):\n", + " tags = []\n", + " for array in dataset:\n", + " for el in array.split():\n", + " tags.append(el)\n", + " tags = set(tags)\n", + " print(len(tags))\n", + " tag_to_index = {tag: idx for idx, tag in enumerate(tags)}\n", + " return tag_to_index" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "4fc9a9c9-cb3b-43b0-b2b0-0194a9f36d5d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9\n", + "{'B-LOC': 0, 'B-MISC': 1, 'I-PER': 2, 'B-PER': 3, 'I-MISC': 4, 'I-ORG': 5, 'B-ORG': 6, 'I-LOC': 7, 'O': 8}\n" + ] + } + ], + "source": [ + "ner_vocab = ner_tags(Y_train)\n", + "print(ner_vocab)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "40255ab6-2e02-4b95-b0db-06015afef0c3", + "metadata": {}, + "outputs": [], + "source": [ + "def transform_tags_to_indexes(tagged_texts, ner_vocab):\n", + " # Initialize the array to store the transformed data\n", + " indexed_texts = []\n", + "\n", + " # Iterate through each tagged text\n", + " for text in tagged_texts:\n", + " # Split the text into tags\n", + " tags = text.split()\n", + "\n", + " # Convert each tag to its corresponding index using the tag_to_index dictionary\n", + " indexed_text = [ner_vocab[tag] for tag in tags]\n", + "\n", + " # Append the indexed text to the array\n", + " indexed_texts.append(indexed_text)\n", + "\n", + " return indexed_texts\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "0eff09ee-828e-4d89-9562-10dbcb0c73ea", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "B-ORG O B-MISC O O O B-MISC O O O B-PER I-PER O B-LOC O O O B-ORG I-ORG O O O O O O B-MISC O O O O O B-MISC O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG O O O B-PER I-PER O O O O O O O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG I-ORG O O O O O O O O O B-ORG O O B-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-PER O B-MISC O O O O B-LOC O B-LOC O O O O O O O B-MISC I-MISC I-MISC O B-MISC O O O O O O O O B-PER O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O B-PER I-PER I-PER O O O B-PER O O B-ORG O O O O O O O O O O O O O O O O O O B-LOC O B-LOC O B-PER O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O B-MISC O O O O O B-LOC O O O O O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O B-ORG O O B-PER I-PER I-PER O O B-ORG I-ORG O O B-LOC O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O B-LOC O O O O B-LOC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O\n", + "235\n" + ] + } + ], + "source": [ + "print(Y_train[0])\n", + "Y_train = transform_tags_to_indexes(Y_train, ner_vocab)\n", + "print(len(Y_train[4]))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "eed228a2-55c3-454a-9760-54059d38a002", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "O O B-ORG O O O O O O O O O B-LOC O O B-MISC I-MISC O B-PER I-PER O O O O O O O B-ORG O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O B-ORG O B-ORG O O O O O O B-ORG O O O O O O O O O O B-ORG O O O O B-ORG O O O O O O O O B-LOC I-LOC O B-ORG O O O O O O O O O O O O O O B-LOC O B-PER I-PER O O O O O O O O O O B-ORG O O O O O O O O O B-PER O O O O O O O O O O B-ORG O O O O O O O O O O O B-PER I-PER O B-PER I-PER O O O O O O O O O B-ORG O B-LOC O O B-PER O O O O B-LOC O O O O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O B-ORG O O O O O O O O O B-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-LOC O B-ORG O B-PER I-PER O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O O O O O B-LOC O B-PER I-PER O O O O B-ORG O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O B-MISC B-PER I-PER O O O O O B-PER I-PER O O O O B-PER I-PER O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O O O O B-ORG O O O O O O O O O O O O O B-PER I-PER O B-MISC O O B-PER I-PER O O O O O O O O B-ORG O O O O O O O\n", + "[8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 8, 8, 1, 4, 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 6, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 6, 8, 6, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 0, 7, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3, 2, 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 0, 8, 8, 3, 8, 8, 8, 8, 0, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 8, 6, 8, 3, 2, 8, 8, 8, 8, 8, 0, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 0, 8, 3, 2, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 3, 2, 8, 8, 8, 8, 8, 3, 2, 8, 8, 8, 8, 3, 2, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3, 2, 8, 1, 8, 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8]\n" + ] + } + ], + "source": [ + "print(Y_test[0])\n", + "Y_test = transform_tags_to_indexes(Y_test, ner_vocab)\n", + "print(Y_test[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "be8c5036-d70d-4236-958e-1ae853385a61", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([0, 6, 8, 1, 8, 8, 8, 1, 8, 8, 8, 3, 2, 8, 0, 8, 8, 8, 6, 5, 8, 8, 8, 8,\n", + " 8, 8, 1, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 0, 8, 8, 8, 8, 6, 5, 8, 8, 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 6, 8, 8, 8, 3, 2, 2, 2, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 5, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 6, 8, 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 1, 8, 8, 8, 8, 0, 8, 0, 8,\n", + " 8, 8, 8, 8, 8, 8, 1, 4, 4, 8, 1, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8,\n", + " 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 1, 8, 8, 3, 2, 2, 8, 8, 8, 3, 8, 8, 6, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 8, 0, 8, 3, 8, 8, 8, 8,\n", + " 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 0, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 8, 8, 8, 8, 6, 5, 5, 5,\n", + " 5, 8, 6, 8, 8, 3, 2, 2, 8, 8, 6, 5, 8, 8, 0, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 0, 8, 8, 8, 8, 0, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0])\n", + "tensor([0, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 8, 8, 1, 4, 8, 3, 2, 8, 8, 8,\n", + " 8, 8, 8, 8, 6, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 6, 8, 6, 8, 8,\n", + " 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 6, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 0, 7, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0,\n", + " 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3, 2,\n", + " 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 0, 8, 8, 3, 8, 8, 8, 8, 0, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 8,\n", + " 6, 8, 3, 2, 8, 8, 8, 8, 8, 0, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 0, 8, 3, 2, 8,\n", + " 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 1, 3, 2, 8, 8, 8, 8, 8, 3, 2, 8, 8, 8, 8, 3, 2, 8, 8, 8, 8, 6,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 3, 2, 8, 1, 8, 8, 3, 2, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8,\n", + " 8, 0])\n" + ] + } + ], + "source": [ + "train_labels = labels_process(Y_train)\n", + "print(train_labels[0])\n", + "validation_labels = labels_process(Y_test)\n", + "print(validation_labels[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "289b9d0a-ccb1-48fa-b14e-8e3dfdcb8dea", + "metadata": {}, + "outputs": [], + "source": [ + "def get_scores(y_true, y_pred):\n", + " # Funkcja zwraca precyzję, pokrycie i F1\n", + " acc_score = 0\n", + " tp = 0\n", + " fp = 0\n", + " selected_items = 0\n", + " relevant_items = 0\n", + "\n", + " for p, t in zip(y_pred, y_true):\n", + " if p == t:\n", + " acc_score += 1\n", + "\n", + " if p > 0 and p == t:\n", + " tp += 1\n", + "\n", + " if p > 0:\n", + " selected_items += 1\n", + "\n", + " if t > 0:\n", + " relevant_items += 1\n", + "\n", + " if selected_items == 0:\n", + " precision = 1.0\n", + " else:\n", + " precision = tp / selected_items\n", + "\n", + " if relevant_items == 0:\n", + " recall = 1.0\n", + " else:\n", + " recall = tp / relevant_items\n", + "\n", + " if precision + recall == 0.0:\n", + " f1 = 0.0\n", + " else:\n", + " f1 = 2 * precision * recall / (precision + recall)\n", + " accuracy = acc_score / len(y_pred)\n", + " return precision, recall, f1, accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "29944aa1-619e-41e9-adc7-4f82bfe856fd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9\n" + ] + } + ], + "source": [ + "num_tags = max([max(x) for x in Y_train]) + 1\n", + "print(num_tags)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "a9551d9d-aee2-4e23-91be-bd1c39eb9ccd", + "metadata": {}, + "outputs": [], + "source": [ + "class LSTM(torch.nn.Module):\n", + "\n", + " def __init__(self):\n", + " super(LSTM, self).__init__()\n", + " self.emb = torch.nn.Embedding(len(v.get_itos()), 100)\n", + " self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)\n", + " self.fc1 = torch.nn.Linear(256, num_tags)\n", + "\n", + " def forward(self, x):\n", + " emb = torch.relu(self.emb(x))\n", + " lstm_output, (h_n, c_n) = self.rec(emb)\n", + " out_weights = self.fc1(lstm_output)\n", + " return out_weights" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "df7786bf-8fd4-48b3-95b4-98a4deba8c94", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['german', 'july', 'car', 'registrations', 'up', '14.2', 'pct', 'yr', '/', 'yr', '.', '', 'frankfurt', '1996-08-22', '', 'german', 'first-time', 'registrations', 'of', 'motor', 'vehicles', 'jumped', '14.2', 'percent', 'in', 'july', 'this', 'year', 'from', 'the', 'year-earlier', 'period', ',', 'the', 'federal', 'office', 'for', 'motor', 'vehicles', 'said', 'on', 'thursday', '.', '', 'the', 'office', 'said', '356,725', 'new', 'cars', 'were', 'registered', 'in', 'july', '1996', '--', '304,850', 'passenger', 'cars', 'and', '15,613', 'trucks', '.', '', 'the', 'figures', 'represent', 'a', '13.6', 'percent', 'increase', 'for', 'passenger', 'cars', 'and', 'a', '2.2', 'percent', 'decline', 'for', 'trucks', 'from', 'july', '1995', '.', '', 'motor-bike', 'registration', 'rose', '32.7', 'percent', 'in', 'the', 'period', '.', '', 'the', 'growth', 'was', 'partly', 'due', 'to', 'an', 'increased', 'number', 'of', 'germans', 'buying', 'german', 'cars', 'abroad', ',', 'while', 'manufacturers', 'said', 'that', 'domestic', 'demand', 'was', 'weak', ',', 'the', 'federal', 'office', 'said', '.', '', 'almost', 'all', 'german', 'car', 'manufacturers', 'posted', 'gains', 'in', 'registration', 'numbers', 'in', 'the', 'period', '.', '', 'volkswagen', 'ag', 'won', '77,719', 'registrations', ',', 'slightly', 'more', 'than', 'a', 'quarter', 'of', 'the', 'total', '.', '', 'opel', 'ag', 'together', 'with', 'general', 'motors', 'came', 'in', 'second', 'place', 'with', '49,269', 'registrations', ',', '16.4', 'percent', 'of', 'the', 'overall', 'figure', '.', '', 'third', 'was', 'ford', 'with', '35,563', 'registrations', ',', 'or', '11.7', 'percent', '.', '', 'only', 'seat', 'and', 'porsche', 'had', 'fewer', 'registrations', 'in', 'july', '1996', 'compared', 'to', 'last', 'year', \"'s\", 'july', '.', '', 'seat', 'posted', '3,420', 'registrations', 'compared', 'with', '5522', 'registrations', 'in', 'july', 'a', 'year', 'earlier', '.', '', 'porsche', \"'s\", 'registrations', 'fell', 'to', '554', 'from', '643', '.', ''] [1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 0, 8, 8, 1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 5, 5, 5, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 8, 1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 5, 8, 8, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 8, 8, 8, 8, 8, 8, 8, 8, 8]\n", + "235 235\n" + ] + } + ], + "source": [ + "print(X_train[4], Y_train[4])\n", + "print(len(X_train[4]), len(Y_train[4]))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "4b0695c1-d8a3-4afe-86c9-c89af8379ddf", + "metadata": {}, + "outputs": [], + "source": [ + "lstm = LSTM()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "9c52f91b-a355-49bf-9281-980307ae36c5", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = torch.nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "a0b86bca-58ae-48a5-a0b0-afbe855950d4", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(lstm.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "73f6ade4-b62e-46dd-afe9-a3538bb9f5cb", + "metadata": {}, + "outputs": [], + "source": [ + "def eval_model(dataset_tokens, dataset_labels, model):\n", + " Y_true = []\n", + " Y_pred = []\n", + " for i in tqdm(range(len(dataset_labels))):\n", + " batch_tokens = dataset_tokens[i].unsqueeze(0)\n", + " tags = list(dataset_labels[i].numpy())\n", + " Y_true += tags\n", + "\n", + " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", + " Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n", + " Y_pred += list(Y_batch_pred.numpy())\n", + "\n", + " return get_scores(Y_true, Y_pred), Y_pred" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "2fa7a94e-06da-4bf4-8c4c-59f6be205214", + "metadata": {}, + "outputs": [], + "source": [ + "NUM_EPOCHS = 5" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "59bf2468-331b-4d95-a8ab-2edb4e9038fb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "215\n" + ] + } + ], + "source": [ + "print(len(validation_labels))" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "801db2ad-0d00-4550-a8ea-37700fb3ee6e", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a4488b22b15448bd8473892f256ba528", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/945 [00:00