From b4a2b227c112e4a93bcb71079f2c1fe8acb20a9a Mon Sep 17 00:00:00 2001 From: s464903 Date: Mon, 27 May 2024 00:35:23 +0200 Subject: [PATCH] Delete train/RNN.ipynb --- train/RNN.ipynb | 1878 ----------------------------------------------- 1 file changed, 1878 deletions(-) delete mode 100644 train/RNN.ipynb diff --git a/train/RNN.ipynb b/train/RNN.ipynb deleted file mode 100644 index 29a3222..0000000 --- a/train/RNN.ipynb +++ /dev/null @@ -1,1878 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "5d0842c8-c292-41ce-a27b-73986bf43e1c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\vocab\\__init__.py:4: UserWarning: \n", - "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", - "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", - " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n", - "C:\\Users\\obses\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torchtext\\utils.py:4: UserWarning: \n", - "/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n", - "Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()`\n", - " warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)\n", - "[nltk_data] Downloading package stopwords to\n", - "[nltk_data] C:\\Users\\obses\\AppData\\Roaming\\nltk_data...\n", - "[nltk_data] Package stopwords is already up-to-date!\n" - ] - } - ], - "source": [ - "from collections import Counter\n", - "\n", - "import torch\n", - "from datasets import load_dataset\n", - "from torchtext.vocab import vocab\n", - "from tqdm.notebook import tqdm\n", - "\n", - "import pandas as pd\n", - "\n", - "import nltk\n", - "nltk.download('stopwords')\n", - "from nltk.corpus import stopwords\n", - "from nltk.tokenize import word_tokenize\n", - "import string" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "c96f78af-f38a-49c1-b235-b15b89baf0ae", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = pd.read_csv('train.tsv', sep='\\t', header=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "c027b2f3-48eb-442a-a212-c5fa9d1cfba5", - "metadata": {}, - "outputs": [], - "source": [ - "X_test = pd.read_csv('../dev-0/in.tsv', sep='\\t', header=None)\n", - "Y_test = pd.read_csv('../dev-0/expected.tsv', sep='\\t', header=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "9fc7a069-52ce-4343-a8a2-dd77619f5d25", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "German July car registrations up 14.2 pct yr / yr . FRANKFURT 1996-08-22 German first-time registrations of motor vehicles jumped 14.2 percent in July this year from the year-earlier period , the Federal office for motor vehicles said on Thursday . The office said 356,725 new cars were registered in July 1996 -- 304,850 passenger cars and 15,613 trucks . The figures represent a 13.6 percent increase for passenger cars and a 2.2 percent decline for trucks from July 1995 . Motor-bike registration rose 32.7 percent in the period . The growth was partly due to an increased number of Germans buying German cars abroad , while manufacturers said that domestic demand was weak , the federal office said . Almost all German car manufacturers posted gains in registration numbers in the period . Volkswagen AG won 77,719 registrations , slightly more than a quarter of the total . Opel AG together with General Motors came in second place with 49,269 registrations , 16.4 percent of the overall figure . Third was Ford with 35,563 registrations , or 11.7 percent . Only Seat and Porsche had fewer registrations in July 1996 compared to last year 's July . Seat posted 3,420 registrations compared with 5522 registrations in July a year earlier . Porsche 's registrations fell to 554 from 643 . \n", - "BASKETBALL - INTERNATIONAL TOURNAMENT RESULT . BELGRADE 1996-08-30 Result in an international basketball tournament on Friday : Red Star ( Yugoslavia ) beat Dinamo ( Russia ) 92-90 ( halftime 47-47 ) \n" - ] - } - ], - "source": [ - "X_train = dataset[dataset.columns[1]].replace(\"\",\"\")\n", - "Y_train = dataset[dataset.columns[0]]\n", - "\n", - "X_test = X_test[X_test.columns[0]].replace(\"\",\"\")\n", - "Y_test = Y_test[Y_test.columns[0]]\n", - "\n", - "print(X_train[4])\n", - "print(X_test[4])" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7cd49893-c730-4089-9b21-8683c11fd6cc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " ['german', 'july', 'car', 'registrations', 'up', '14.2', 'pct', 'yr', '/', 'yr', '.', '', 'frankfurt', '1996-08-22', '', 'german', 'first-time', 'registrations', 'of', 'motor', 'vehicles', 'jumped', '14.2', 'percent', 'in', 'july', 'this', 'year', 'from', 'the', 'year-earlier', 'period', ',', 'the', 'federal', 'office', 'for', 'motor', 'vehicles', 'said', 'on', 'thursday', '.', '', 'the', 'office', 'said', '356,725', 'new', 'cars', 'were', 'registered', 'in', 'july', '1996', '--', '304,850', 'passenger', 'cars', 'and', '15,613', 'trucks', '.', '', 'the', 'figures', 'represent', 'a', '13.6', 'percent', 'increase', 'for', 'passenger', 'cars', 'and', 'a', '2.2', 'percent', 'decline', 'for', 'trucks', 'from', 'july', '1995', '.', '', 'motor-bike', 'registration', 'rose', '32.7', 'percent', 'in', 'the', 'period', '.', '', 'the', 'growth', 'was', 'partly', 'due', 'to', 'an', 'increased', 'number', 'of', 'germans', 'buying', 'german', 'cars', 'abroad', ',', 'while', 'manufacturers', 'said', 'that', 'domestic', 'demand', 'was', 'weak', ',', 'the', 'federal', 'office', 'said', '.', '', 'almost', 'all', 'german', 'car', 'manufacturers', 'posted', 'gains', 'in', 'registration', 'numbers', 'in', 'the', 'period', '.', '', 'volkswagen', 'ag', 'won', '77,719', 'registrations', ',', 'slightly', 'more', 'than', 'a', 'quarter', 'of', 'the', 'total', '.', '', 'opel', 'ag', 'together', 'with', 'general', 'motors', 'came', 'in', 'second', 'place', 'with', '49,269', 'registrations', ',', '16.4', 'percent', 'of', 'the', 'overall', 'figure', '.', '', 'third', 'was', 'ford', 'with', '35,563', 'registrations', ',', 'or', '11.7', 'percent', '.', '', 'only', 'seat', 'and', 'porsche', 'had', 'fewer', 'registrations', 'in', 'july', '1996', 'compared', 'to', 'last', 'year', \"'s\", 'july', '.', '', 'seat', 'posted', '3,420', 'registrations', 'compared', 'with', '5522', 'registrations', 'in', 'july', 'a', 'year', 'earlier', '.', '', 'porsche', \"'s\", 'registrations', 'fell', 'to', '554', 'from', '643', '.', '']\n", - " ['basketball', '-', 'international', 'tournament', 'result', '.', '', 'belgrade', '1996-08-30', '', 'result', 'in', 'an', 'international', '', 'basketball', 'tournament', 'on', 'friday', ':', '', 'red', 'star', '(', 'yugoslavia', ')', 'beat', 'dinamo', '(', 'russia', ')', '92-90', '(', 'halftime', '', '47-47', ')', '']\n" - ] - } - ], - "source": [ - "def preprocess(text):\n", - " text = text.lower()\n", - " #text = ''.join([word for word in text if word not in string.punctuation])\n", - " #tokens = word_tokenize(text)\n", - " #tokens = [word for word in tokens if word not in stopwords.words('english')]\n", - " return text\n", - "\n", - "X_train = [preprocess(text).split() for text in X_train]\n", - "print( type(X_train), X_train[4])\n", - "X_test = [preprocess(text).split() for text in X_test]\n", - "print( type(X_test), X_test[4])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "c24df99f-1792-4691-92c0-7e5e596237c3", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['german', 'july', 'car', 'registrations', 'up', '14.2', 'pct', 'yr', '/', 'yr', '.', '', 'frankfurt', '1996-08-22', '', 'german', 'first-time', 'registrations', 'of', 'motor', 'vehicles', 'jumped', '14.2', 'percent', 'in', 'july', 'this', 'year', 'from', 'the', 'year-earlier', 'period', ',', 'the', 'federal', 'office', 'for', 'motor', 'vehicles', 'said', 'on', 'thursday', '.', '', 'the', 'office', 'said', '356,725', 'new', 'cars', 'were', 'registered', 'in', 'july', '1996', '--', '304,850', 'passenger', 'cars', 'and', '15,613', 'trucks', '.', '', 'the', 'figures', 'represent', 'a', '13.6', 'percent', 'increase', 'for', 'passenger', 'cars', 'and', 'a', '2.2', 'percent', 'decline', 'for', 'trucks', 'from', 'july', '1995', '.', '', 'motor-bike', 'registration', 'rose', '32.7', 'percent', 'in', 'the', 'period', '.', '', 'the', 'growth', 'was', 'partly', 'due', 'to', 'an', 'increased', 'number', 'of', 'germans', 'buying', 'german', 'cars', 'abroad', ',', 'while', 'manufacturers', 'said', 'that', 'domestic', 'demand', 'was', 'weak', ',', 'the', 'federal', 'office', 'said', '.', '', 'almost', 'all', 'german', 'car', 'manufacturers', 'posted', 'gains', 'in', 'registration', 'numbers', 'in', 'the', 'period', '.', '', 'volkswagen', 'ag', 'won', '77,719', 'registrations', ',', 'slightly', 'more', 'than', 'a', 'quarter', 'of', 'the', 'total', '.', '', 'opel', 'ag', 'together', 'with', 'general', 'motors', 'came', 'in', 'second', 'place', 'with', '49,269', 'registrations', ',', '16.4', 'percent', 'of', 'the', 'overall', 'figure', '.', '', 'third', 'was', 'ford', 'with', '35,563', 'registrations', ',', 'or', '11.7', 'percent', '.', '', 'only', 'seat', 'and', 'porsche', 'had', 'fewer', 'registrations', 'in', 'july', '1996', 'compared', 'to', 'last', 'year', \"'s\", 'july', '.', '', 'seat', 'posted', '3,420', 'registrations', 'compared', 'with', '5522', 'registrations', 'in', 'july', 'a', 'year', 'earlier', '.', '', 'porsche', \"'s\", 'registrations', 'fell', 'to', '554', 'from', '643', '.', ''] B-MISC O O O O O O O O O O O B-LOC O O B-MISC O O O O O O O O O O O O O O O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O B-MISC O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O B-ORG I-ORG O O O O O O O O O O O O O O B-ORG I-ORG O O B-ORG I-ORG O O O O O O O O O O O O O O O O O O B-ORG O O O O O O O O O O B-ORG O B-ORG O O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O B-ORG O O O O O O O O O\n", - "235 562\n" - ] - } - ], - "source": [ - "print(X_train[4], Y_train[4])\n", - "print(len(X_train[4]), len(Y_train[4]))" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "68f95ff5-28d4-4267-94f1-7de37d0b1a1b", - "metadata": {}, - "outputs": [], - "source": [ - "def build_vocab(dataset):\n", - " counter = Counter()\n", - " for document in dataset:\n", - " counter.update(document)\n", - " return vocab(counter, specials=[\"\", \"\", \"\", \"\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "7d2e8c94-403d-4481-afe2-23ef80146b4c", - "metadata": {}, - "outputs": [], - "source": [ - "v = build_vocab(X_train)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c8ad750f-fa6e-4ce1-a32f-70c314f5a587", - "metadata": {}, - "outputs": [], - "source": [ - "itos = v.get_itos() # mapowanie indeksów na tokeny" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4368cae0-a34d-47a8-9a18-8c7dc78ecf8c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['', '', '', '', 'eu', 'rejects', 'german', 'call', 'to', 'boycott', 'british']\n" - ] - } - ], - "source": [ - "print(itos[:11])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "a155c4b4-5395-410e-8482-2d8f54250b44", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "21014" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(itos) # liczba różnych tokenów w słowniku" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "7aa9c044-8d13-4582-a081-83837eefb76f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "v[\"\"] # indeks nieznanego tokenu" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "675d2d9a-398b-4b01-abd7-cfad1047d74c", - "metadata": {}, - "outputs": [], - "source": [ - "v.set_default_index(v[\"\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "4253750b-5c6e-4df2-8440-f190ae21418a", - "metadata": {}, - "outputs": [], - "source": [ - "def data_process(dt):\n", - " # Wektoryzacja dokumentów tekstowych.\n", - " return [\n", - " torch.tensor(\n", - " [v[\"\"]] + [v[token] for token in document] + [v[\"\"]],\n", - " dtype=torch.long,\n", - " )\n", - " for document in dt\n", - " ]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "4bf37b0a-ad61-41be-adc7-ad66741ef534", - "metadata": {}, - "outputs": [], - "source": [ - "def labels_process(dt):\n", - " # Wektoryzacja etykiet (NER)\n", - " return [torch.tensor([0] + document + [0], dtype=torch.long) for document in dt]" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "be36e015-4e20-485a-867e-bcb517d2e6a5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([ 2, 6, 438, 439, 440, 309, 441, 442, 443, 444, 443, 12, 13, 445,\n", - " 17, 13, 6, 446, 440, 163, 447, 448, 449, 441, 255, 233, 438, 363,\n", - " 244, 53, 18, 450, 451, 73, 18, 452, 453, 72, 447, 448, 21, 22,\n", - " 23, 12, 13, 18, 453, 21, 454, 455, 456, 312, 457, 233, 438, 458,\n", - " 135, 459, 460, 456, 88, 461, 462, 12, 13, 18, 463, 464, 81, 465,\n", - " 255, 466, 72, 460, 456, 88, 81, 467, 255, 468, 72, 462, 53, 438,\n", - " 469, 12, 13, 470, 471, 472, 473, 255, 233, 18, 451, 12, 13, 18,\n", - " 474, 59, 475, 174, 8, 159, 476, 477, 163, 478, 479, 6, 456, 480,\n", - " 73, 481, 482, 21, 91, 483, 484, 59, 485, 73, 18, 452, 453, 21,\n", - " 12, 13, 262, 416, 6, 439, 482, 486, 487, 233, 471, 488, 233, 18,\n", - " 451, 12, 13, 489, 490, 491, 492, 440, 73, 493, 494, 56, 81, 495,\n", - " 163, 18, 247, 12, 13, 496, 490, 497, 26, 498, 499, 500, 233, 501,\n", - " 502, 26, 503, 440, 73, 504, 255, 163, 18, 256, 505, 12, 13, 506,\n", - " 59, 507, 26, 508, 440, 73, 509, 510, 255, 12, 13, 147, 511, 88,\n", - " 512, 155, 513, 440, 233, 438, 458, 514, 8, 97, 244, 42, 438, 12,\n", - " 13, 511, 486, 515, 440, 514, 26, 516, 440, 233, 438, 81, 244, 156,\n", - " 12, 13, 512, 42, 440, 517, 8, 518, 53, 519, 12, 13, 3])\n", - "tensor([ 2, 9637, 640, 419, 1908, 1850, 12, 13, 2439, 19358,\n", - " 13, 1850, 233, 159, 419, 13, 9637, 1908, 22, 1098,\n", - " 380, 13, 1103, 2438, 132, 2440, 134, 1762, 11167, 132,\n", - " 1160, 134, 0, 132, 1767, 13, 0, 134, 13, 3])\n" - ] - } - ], - "source": [ - "train_tokens_ids = data_process(X_train)\n", - "print(train_tokens_ids[4])\n", - "validation_tokens_ids = data_process(X_test)\n", - "print(validation_tokens_ids[4])" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "34ec5fa0-7bff-44d0-9428-94c9b24244ec", - "metadata": {}, - "outputs": [], - "source": [ - "def ner_tags(dataset):\n", - " tags = []\n", - " for array in dataset:\n", - " for el in array.split():\n", - " tags.append(el)\n", - " tags = set(tags)\n", - " print(len(tags))\n", - " tag_to_index = {tag: idx for idx, tag in enumerate(tags)}\n", - " return tag_to_index" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "4fc9a9c9-cb3b-43b0-b2b0-0194a9f36d5d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "9\n", - "{'I-MISC': 0, 'B-MISC': 1, 'I-LOC': 2, 'I-ORG': 3, 'I-PER': 4, 'O': 5, 'B-ORG': 6, 'B-PER': 7, 'B-LOC': 8}\n" - ] - } - ], - "source": [ - "ner_vocab = ner_tags(Y_train)\n", - "print(ner_vocab)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "40255ab6-2e02-4b95-b0db-06015afef0c3", - "metadata": {}, - "outputs": [], - "source": [ - "def transform_tags_to_indexes(tagged_texts, ner_vocab):\n", - " # Initialize the array to store the transformed data\n", - " indexed_texts = []\n", - "\n", - " # Iterate through each tagged text\n", - " for text in tagged_texts:\n", - " # Split the text into tags\n", - " tags = text.split()\n", - "\n", - " # Convert each tag to its corresponding index using the tag_to_index dictionary\n", - " indexed_text = [ner_vocab[tag] for tag in tags]\n", - "\n", - " # Append the indexed text to the array\n", - " indexed_texts.append(indexed_text)\n", - "\n", - " return indexed_texts\n" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "0eff09ee-828e-4d89-9562-10dbcb0c73ea", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "B-ORG O B-MISC O O O B-MISC O O O B-PER I-PER O B-LOC O O O B-ORG I-ORG O O O O O O B-MISC O O O O O B-MISC O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG O O O B-PER I-PER O O O O O O O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O B-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG I-ORG O O O O O O O O O B-ORG O O B-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-PER O B-MISC O O O O B-LOC O B-LOC O O O O O O O B-MISC I-MISC I-MISC O B-MISC O O O O O O O O B-PER O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O B-PER I-PER I-PER O O O B-PER O O B-ORG O O O O O O O O O O O O O O O O O O B-LOC O B-LOC O B-PER O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O B-MISC O O O O O B-LOC O O O O O O O O O O O O O O O O O O O B-LOC O O O O B-ORG I-ORG I-ORG I-ORG I-ORG O B-ORG O O B-PER I-PER I-PER O O B-ORG I-ORG O O B-LOC O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O O O O O O O O O B-LOC O O O O B-LOC O O O O O O O O O O O O O O O O B-MISC O O O O O O O O O O\n", - "235\n" - ] - } - ], - "source": [ - "print(Y_train[0])\n", - "Y_train = transform_tags_to_indexes(Y_train, ner_vocab)\n", - "print(len(Y_train[4]))" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "eed228a2-55c3-454a-9760-54059d38a002", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "O O B-ORG O O O O O O O O O B-LOC O O B-MISC I-MISC O B-PER I-PER O O O O O O O B-ORG O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O B-ORG O B-ORG O O O O O O B-ORG O O O O O O O O O O B-ORG O O O O B-ORG O O O O O O O O B-LOC I-LOC O B-ORG O O O O O O O O O O O O O O B-LOC O B-PER I-PER O O O O O O O O O O B-ORG O O O O O O O O O B-PER O O O O O O O O O O B-ORG O O O O O O O O O O O B-PER I-PER O B-PER I-PER O O O O O O O O O B-ORG O B-LOC O O B-PER O O O O B-LOC O O O O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O B-ORG O O O O O O O O O B-PER O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-LOC O B-ORG O B-PER I-PER O O O O O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O O O O O B-LOC O B-PER I-PER O O O O B-ORG O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O B-ORG O O O O O O O O O O O O O O O O O O O O B-MISC B-PER I-PER O O O O O B-PER I-PER O O O O B-PER I-PER O O O O B-ORG O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-ORG O O O O O O B-ORG O O O O O O O O O O O O O B-PER I-PER O B-MISC O O B-PER I-PER O O O O O O O O B-ORG O O O O O O O\n", - "[5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5, 5, 1, 0, 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 6, 5, 6, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 8, 2, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7, 4, 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 8, 5, 5, 7, 5, 5, 5, 5, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5, 6, 5, 7, 4, 5, 5, 5, 5, 5, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 8, 5, 7, 4, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 7, 4, 5, 5, 5, 5, 5, 7, 4, 5, 5, 5, 5, 7, 4, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7, 4, 5, 1, 5, 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5]\n" - ] - } - ], - "source": [ - "print(Y_test[0])\n", - "Y_test = transform_tags_to_indexes(Y_test, ner_vocab)\n", - "print(Y_test[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "be8c5036-d70d-4236-958e-1ae853385a61", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([0, 6, 5, 1, 5, 5, 5, 1, 5, 5, 5, 7, 4, 5, 8, 5, 5, 5, 6, 3, 5, 5, 5, 5,\n", - " 5, 5, 1, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 8, 5, 5, 5, 5, 6, 3, 5, 5, 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 6, 5, 5, 5, 7, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 3, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 6, 5, 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 1, 5, 5, 5, 5, 8, 5, 8, 5,\n", - " 5, 5, 5, 5, 5, 5, 1, 0, 0, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 5, 5, 5,\n", - " 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 5, 7, 4, 4, 5, 5, 5, 7, 5, 5, 6, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5, 8, 5, 7, 5, 5, 5, 5,\n", - " 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 8, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5, 5, 5, 5, 6, 3, 3, 3,\n", - " 3, 5, 6, 5, 5, 7, 4, 4, 5, 5, 6, 3, 5, 5, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 8, 5, 5, 5, 5, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 0])\n", - "tensor([0, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5, 5, 1, 0, 5, 7, 4, 5, 5, 5,\n", - " 5, 5, 5, 5, 6, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 6, 5, 6, 5, 5,\n", - " 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 6, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 8, 2, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8,\n", - " 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7, 4,\n", - " 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 8, 5, 5, 7, 5, 5, 5, 5, 8, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 7, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5,\n", - " 6, 5, 7, 4, 5, 5, 5, 5, 5, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 8, 5, 7, 4, 5,\n", - " 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 1, 7, 4, 5, 5, 5, 5, 5, 7, 4, 5, 5, 5, 5, 7, 4, 5, 5, 5, 5, 6,\n", - " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", - " 5, 7, 4, 5, 1, 5, 5, 7, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5,\n", - " 5, 0])\n" - ] - } - ], - "source": [ - "train_labels = labels_process(Y_train)\n", - "print(train_labels[0])\n", - "validation_labels = labels_process(Y_test)\n", - "print(validation_labels[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "289b9d0a-ccb1-48fa-b14e-8e3dfdcb8dea", - "metadata": {}, - "outputs": [], - "source": [ - "def get_scores(y_true, y_pred):\n", - " # Funkcja zwraca precyzję, pokrycie i F1\n", - " acc_score = 0\n", - " tp = 0\n", - " fp = 0\n", - " selected_items = 0\n", - " relevant_items = 0\n", - "\n", - " for p, t in zip(y_pred, y_true):\n", - " if p == t:\n", - " acc_score += 1\n", - "\n", - " if p > 0 and p == t:\n", - " tp += 1\n", - "\n", - " if p > 0:\n", - " selected_items += 1\n", - "\n", - " if t > 0:\n", - " relevant_items += 1\n", - "\n", - " if selected_items == 0:\n", - " precision = 1.0\n", - " else:\n", - " precision = tp / selected_items\n", - "\n", - " if relevant_items == 0:\n", - " recall = 1.0\n", - " else:\n", - " recall = tp / relevant_items\n", - "\n", - " if precision + recall == 0.0:\n", - " f1 = 0.0\n", - " else:\n", - " f1 = 2 * precision * recall / (precision + recall)\n", - " accuracy = acc_score / len(y_pred)\n", - " return precision, recall, f1, accuracy" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "29944aa1-619e-41e9-adc7-4f82bfe856fd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "9\n" - ] - } - ], - "source": [ - "num_tags = max([max(x) for x in Y_train]) + 1\n", - "print(num_tags)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "a9551d9d-aee2-4e23-91be-bd1c39eb9ccd", - "metadata": {}, - "outputs": [], - "source": [ - "class LSTM(torch.nn.Module):\n", - "\n", - " def __init__(self):\n", - " super(LSTM, self).__init__()\n", - " self.emb = torch.nn.Embedding(len(v.get_itos()), 100)\n", - " self.rec = torch.nn.LSTM(100, 256, 1, batch_first=True)\n", - " self.fc1 = torch.nn.Linear(256, num_tags)\n", - "\n", - " def forward(self, x):\n", - " emb = torch.relu(self.emb(x))\n", - " lstm_output, (h_n, c_n) = self.rec(emb)\n", - " out_weights = self.fc1(lstm_output)\n", - " return out_weights" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "df7786bf-8fd4-48b3-95b4-98a4deba8c94", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['german', 'july', 'car', 'registrations', 'up', '14.2', 'pct', 'yr', '/', 'yr', '.', '', 'frankfurt', '1996-08-22', '', 'german', 'first-time', 'registrations', 'of', 'motor', 'vehicles', 'jumped', '14.2', 'percent', 'in', 'july', 'this', 'year', 'from', 'the', 'year-earlier', 'period', ',', 'the', 'federal', 'office', 'for', 'motor', 'vehicles', 'said', 'on', 'thursday', '.', '', 'the', 'office', 'said', '356,725', 'new', 'cars', 'were', 'registered', 'in', 'july', '1996', '--', '304,850', 'passenger', 'cars', 'and', '15,613', 'trucks', '.', '', 'the', 'figures', 'represent', 'a', '13.6', 'percent', 'increase', 'for', 'passenger', 'cars', 'and', 'a', '2.2', 'percent', 'decline', 'for', 'trucks', 'from', 'july', '1995', '.', '', 'motor-bike', 'registration', 'rose', '32.7', 'percent', 'in', 'the', 'period', '.', '', 'the', 'growth', 'was', 'partly', 'due', 'to', 'an', 'increased', 'number', 'of', 'germans', 'buying', 'german', 'cars', 'abroad', ',', 'while', 'manufacturers', 'said', 'that', 'domestic', 'demand', 'was', 'weak', ',', 'the', 'federal', 'office', 'said', '.', '', 'almost', 'all', 'german', 'car', 'manufacturers', 'posted', 'gains', 'in', 'registration', 'numbers', 'in', 'the', 'period', '.', '', 'volkswagen', 'ag', 'won', '77,719', 'registrations', ',', 'slightly', 'more', 'than', 'a', 'quarter', 'of', 'the', 'total', '.', '', 'opel', 'ag', 'together', 'with', 'general', 'motors', 'came', 'in', 'second', 'place', 'with', '49,269', 'registrations', ',', '16.4', 'percent', 'of', 'the', 'overall', 'figure', '.', '', 'third', 'was', 'ford', 'with', '35,563', 'registrations', ',', 'or', '11.7', 'percent', '.', '', 'only', 'seat', 'and', 'porsche', 'had', 'fewer', 'registrations', 'in', 'july', '1996', 'compared', 'to', 'last', 'year', \"'s\", 'july', '.', '', 'seat', 'posted', '3,420', 'registrations', 'compared', 'with', '5522', 'registrations', 'in', 'july', 'a', 'year', 'earlier', '.', '', 'porsche', \"'s\", 'registrations', 'fell', 'to', '554', 'from', '643', '.', ''] [1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 3, 5, 5, 6, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5]\n", - "235 235\n" - ] - } - ], - "source": [ - "print(X_train[4], Y_train[4])\n", - "print(len(X_train[4]), len(Y_train[4]))" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "4b0695c1-d8a3-4afe-86c9-c89af8379ddf", - "metadata": {}, - "outputs": [], - "source": [ - "lstm = LSTM()" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "9c52f91b-a355-49bf-9281-980307ae36c5", - "metadata": {}, - "outputs": [], - "source": [ - "criterion = torch.nn.CrossEntropyLoss()" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "a0b86bca-58ae-48a5-a0b0-afbe855950d4", - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.Adam(lstm.parameters())" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "73f6ade4-b62e-46dd-afe9-a3538bb9f5cb", - "metadata": {}, - "outputs": [], - "source": [ - "def eval_model(dataset_tokens, dataset_labels, model):\n", - " Y_true = []\n", - " Y_pred = []\n", - " for i in tqdm(range(len(dataset_labels))):\n", - " batch_tokens = dataset_tokens[i].unsqueeze(0)\n", - " tags = list(dataset_labels[i].numpy())\n", - " Y_true += tags\n", - "\n", - " Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n", - " Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n", - " Y_pred += list(Y_batch_pred.numpy())\n", - " return get_scores(Y_true, Y_pred), Y_pred" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "2fa7a94e-06da-4bf4-8c4c-59f6be205214", - "metadata": {}, - "outputs": [], - "source": [ - "NUM_EPOCHS = 2" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "59bf2468-331b-4d95-a8ab-2edb4e9038fb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "215\n" - ] - } - ], - "source": [ - "print(len(validation_labels))" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "id": "801db2ad-0d00-4550-a8ea-37700fb3ee6e", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d2819b80390b45cf8a37b6b8125744a4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/945 [00:00