DL_RNN/RNN_6.ipynb

1050 lines
275 KiB
Plaintext
Raw Permalink Normal View History

2024-05-26 21:42:42 +02:00
{
"cells": [
{
"cell_type": "markdown",
"source": [
"### Importy"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 0,
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import pandas as pd\n",
"from collections import Counter\n",
"from torchtext.vocab import vocab\n",
"from tqdm import tqdm"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-05-26T20:32:41.574922Z",
"end_time": "2024-05-26T20:47:52.025382Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Wczytanie danych"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 0,
"outputs": [],
"source": [
"def read_custom_dataset(path, is_train=True):\n",
" if is_train:\n",
" data = pd.read_csv(path, sep='\\t', header=None, compression='xz')\n",
" data.columns = ['ner', 'document']\n",
" else:\n",
" with open(path, 'r') as f:\n",
" documents = f.read().splitlines()\n",
" data = pd.DataFrame(documents, columns=['document'])\n",
" return data\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 0,
"outputs": [],
"source": [
"train_path = 'train/train.tsv.xz'\n",
"dev_in_path = 'dev-0/in.tsv'\n",
"dev_expected_path = 'dev-0/expected.tsv'\n",
"test_in_path = 'test-A/in.tsv'\n",
"\n",
"train_data = read_custom_dataset(train_path, is_train=True)\n",
"dev_data = read_custom_dataset(dev_in_path, is_train=False)\n",
"dev_labels = pd.read_csv(dev_expected_path, header=None, names=['ner'])\n",
"test_data = read_custom_dataset(test_in_path, is_train=False)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Tokenizacja"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 0,
"outputs": [],
"source": [
"def tokenize_documents(data):\n",
" return [doc.split() for doc in data['document'].tolist()]\n",
"\n",
"train_tokens = tokenize_documents(train_data)\n",
"dev_tokens = tokenize_documents(dev_data)\n",
"test_tokens = tokenize_documents(test_data)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Budowa słownika"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"def build_vocab(tokens_list):\n",
" counter = Counter()\n",
" for tokens in tokens_list:\n",
" counter.update(tokens)\n",
" return vocab(counter, specials=[\"<unk>\", \"<pad>\", \"<bos>\", \"<eos>\"])\n",
"\n",
"token_vocab = build_vocab(train_tokens)\n",
"token_vocab.set_default_index(token_vocab[\"<unk>\"])"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-05-26T21:35:40.831018Z",
"end_time": "2024-05-26T21:35:40.926990Z"
}
}
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<unk>', '<pad>', '<bos>', '<eos>', 'EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.', '</S>', 'Peter', 'Blackburn', 'BRUSSELS', '1996-08-22', 'The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 'disagreed', 'with', 'advice', 'consumers', 'shun', 'until', 'scientists', 'determine', 'whether', 'mad', 'cow', 'disease', 'can', 'be', 'transmitted', 'sheep', 'Germany', \"'s\", 'representative', 'the', 'Union', 'veterinary', 'committee', 'Werner', 'Zwingmann', 'Wednesday', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'scientific', 'was', 'clearer', '\"', 'We', 'do', \"n't\", 'support', 'any', 'such', 'recommendation', 'because', 'we', 'see', 'grounds', 'for', ',', 'chief', 'spokesman', 'Nikolaus', 'van', 'der', 'Pas', 'told', 'a', 'news', 'briefing', 'He', 'further', 'study', 'required', 'and', 'if', 'found', 'that', 'action', 'needed', 'taken', 'by', 'proposal', 'last', 'month', 'Farm', 'Commissioner', 'Franz', 'Fischler', 'ban', 'brains', 'spleens', 'spinal', 'cords', 'human', 'animal', 'food', 'chains', 'highly', 'specific', 'precautionary', 'move', 'protect', 'health', 'proposed', 'EU-wide', 'measures', 'after', 'reports', 'France', 'under', 'laboratory', 'conditions', 'could', 'contract', 'Bovine', 'Spongiform', 'Encephalopathy', '(', 'BSE', ')', '--', 'But', 'agreed', 'review', 'his', 'standing', 'mational', 'officials', 'questioned', 'justified', 'as', 'there', 'only', 'slight', 'risk', 'Spanish', 'Minister', 'Loyola', 'de', 'Palacio', 'had', 'earlier', 'accused', 'at', 'an', 'farm', 'ministers', \"'\", 'meeting', 'of', 'causing', 'unjustified', 'alarm', 'through', 'dangerous', 'generalisation', 'Only', 'backed', 'multidisciplinary', 'committees', 'are', 'due', 're-examine', 'issue', 'early', 'next', 'make', 'recommendations', 'senior', 'Sheep', 'have', 'long', 'been', 'known', 'scrapie', 'brain-wasting', 'similar', 'which', 'is', 'believed', 'transferred', 'cattle', 'feed', 'containing', 'waste', 'farmers', 'denied', 'danger', 'their', 'but', 'expressed', 'concern', 'government', 'avoid', 'might', 'influence', 'across', 'Europe', 'What', 'extremely', 'careful', 'how', 'going', 'take', 'lead', 'Welsh', 'National', 'Farmers', 'NFU', 'chairman', 'John', 'Lloyd', 'Jones', 'BBC', 'radio', 'Bonn', 'has', 'led', 'efforts', 'public', 'consumer', 'confidence', 'collapsed', 'in', 'March', 'report', 'suggested', 'humans', 'illness', 'eating', 'contaminated', 'beef', 'imported', '47,600', 'year', 'nearly', 'half', 'total', 'imports', 'It', 'brought', '4,275', 'tonnes', 'mutton', 'some', '10', 'percent', 'overall', 'Rare', 'Hendrix', 'song', 'draft', 'sells', 'almost', '$', '17,000', 'LONDON', 'A', 'rare', 'handwritten', 'U.S.', 'guitar', 'legend', 'Jimi', 'sold', 'auction', 'late', 'musician', 'favourite', 'possessions', 'Florida', 'restaurant', 'paid', '10,925', 'pounds', '16,935', 'Ai', 'no', 'telling', 'penned', 'piece', 'London', 'hotel', 'stationery', '1966', 'At', 'end', 'January', '1967', 'concert', 'English', 'city', 'Nottingham', 'he', 'threw', 'sheet', 'paper', 'into', 'audience', 'where', 'retrieved', 'fan', 'Buyers', 'also', 'snapped', 'up', '16', 'items', 'were', 'put', 'former', 'girlfriend', 'Kathy', 'Etchingham', 'who', 'lived', 'him', '1969', 'They', 'included', 'black', 'lacquer', 'mother', 'pearl', 'inlaid', 'box', 'used', 'store', 'drugs', 'anonymous', 'Australian', 'purchaser', 'bought', '5,060', '7,845', 'guitarist', 'died', 'overdose', '1970', 'aged', '27', 'China', 'says', 'Taiwan', 'spoils', 'atmosphere', 'talks', 'BEIJING', 'Taipei', 'spoiling', 'resumption', 'Strait', 'visit', 'Ukraine', 'Taiwanese', 'Vice', 'President', 'Lien', 'Chan', 'this', 'week', 'infuriated', 'Beijing', 'Speaking', 'hours', 'Chinese', 'state', 'media', 'time', 'right', 'engage', 'political', 'Foreign', 'Ministry', 'Shen', 'Guofang', 'Reuters', ':', 'necessary', 'opening', 'disrupted', 'authorities', 'State', 'quoted', 'top', 'negotiator', 'Tang', 'Shubei', 'visiting', 'group', 'rivals', 'hold', 'Now', 'two', 'sides', '...', 'hostility', 'overseas', 'e
"Wielkość słownika: Vocab()\n"
]
}
],
"source": [
"print(token_vocab.get_itos())"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-05-26T21:36:57.983096Z",
"end_time": "2024-05-26T21:36:58.054097Z"
}
}
},
{
"cell_type": "code",
"execution_count": 28,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wielkość słownika:\n",
"23628\n"
]
}
],
"source": [
"print(\"Wielkość słownika:\")\n",
"print(len(token_vocab))"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-05-26T21:37:15.421051Z",
"end_time": "2024-05-26T21:37:15.449862Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Bio tagging"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 0,
"outputs": [],
"source": [
"bio_tags = sorted(set(tag for tags in train_data['ner'].apply(lambda x: x.split()) for tag in tags))\n",
"bio_to_int = {tag: idx for idx, tag in enumerate(bio_tags, start=1)}\n",
"bio_to_int[\"O\"] = 0\n",
"\n",
"label_mapping = {\n",
" 0: 'O',\n",
" 1: 'B-LOC',\n",
" 2: 'B-MISC',\n",
" 3: 'B-ORG',\n",
" 4: 'B-PER',\n",
" 5: 'I-LOC',\n",
" 6: 'I-MISC',\n",
" 7: 'I-ORG',\n",
" 8: 'I-PER'\n",
"}"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Przetwarzanie danych"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 0,
"outputs": [],
"source": [
"def convert_bio_to_int(tags_list):\n",
" return [bio_to_int[tag] for tag in tags_list]\n",
"\n",
"def data_process(tokens_list):\n",
" return [\n",
" torch.tensor(\n",
" [token_vocab[\"<bos>\"]] + [token_vocab[token] for token in tokens] + [token_vocab[\"<eos>\"]],\n",
" dtype=torch.long,\n",
" )\n",
" for tokens in tokens_list\n",
" ]\n",
"\n",
"def labels_process(labels_list):\n",
" return [torch.tensor([0] + labels + [0], dtype=torch.long) for labels in labels_list]\n",
"\n",
"train_tokens_ids = data_process(train_tokens)\n",
"dev_tokens_ids = data_process(dev_tokens)\n",
"test_tokens_ids = data_process(test_tokens)\n",
"\n",
"train_labels = labels_process(train_data['ner'].apply(lambda x: convert_bio_to_int(x.split())).tolist())\n",
"dev_labels = labels_process(dev_labels['ner'].apply(lambda x: convert_bio_to_int(x.split())).tolist())"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Model"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 0,
"outputs": [],
"source": [
"class LSTMModel(nn.Module):\n",
" def __init__(self, vocab_size, num_labels, embedding_dim=128, lstm_units=64):\n",
" super(LSTMModel, self).__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
" self.lstm = nn.LSTM(embedding_dim, lstm_units, num_layers=1,\n",
" batch_first=True, dropout=0.1)\n",
" self.linear = nn.Linear(lstm_units, num_labels)\n",
" self.softmax = nn.LogSoftmax(dim=2)\n",
"\n",
" def forward(self, x):\n",
" embedded = self.embedding(x)\n",
" lstm_out, _ = self.lstm(embedded)\n",
" logits = self.linear(lstm_out)\n",
" return self.softmax(logits)\n",
"\n",
"vocab_size = len(token_vocab.get_itos())\n",
"num_labels = len(bio_to_int)\n",
"model = LSTMModel(vocab_size, num_labels)\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters())"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Trening i ewaluacja"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 0,
"outputs": [],
"source": [
"def get_scores(y_true, y_pred):\n",
" acc_score = 0\n",
" tp = 0\n",
" selected_items = 0\n",
" relevant_items = 0\n",
"\n",
" for p, t in zip(y_pred, y_true):\n",
" if p == t:\n",
" acc_score += 1\n",
"\n",
" if p > 0 and p == t:\n",
" tp += 1\n",
"\n",
" if p > 0:\n",
" selected_items += 1\n",
"\n",
" if t > 0:\n",
" relevant_items += 1\n",
"\n",
" if selected_items == 0:\n",
" precision = 1.0\n",
" else:\n",
" precision = tp / selected_items\n",
"\n",
" if relevant_items == 0:\n",
" recall = 1.0\n",
" else:\n",
" recall = tp / relevant_items\n",
"\n",
" if precision + recall == 0.0:\n",
" f1 = 0.0\n",
" else:\n",
" f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
" return precision, recall, f1\n",
"\n",
"def eval_model(dataset_tokens, dataset_labels, model):\n",
" Y_true = []\n",
" Y_pred = []\n",
" for i in tqdm(range(len(dataset_labels))):\n",
" batch_tokens = dataset_tokens[i].unsqueeze(0)\n",
" tags = list(dataset_labels[i].numpy())\n",
" Y_true += tags\n",
"\n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
" Y_pred += list(Y_batch_pred.numpy())\n",
"\n",
" return get_scores(Y_true, Y_pred)\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\adamw\\PycharmProjects\\pythonProject\\venv\\lib\\site-packages\\torch\\nn\\modules\\rnn.py:83: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.1 and num_layers=1\n",
" warnings.warn(\"dropout option adds dropout after all but last \"\n",
"100%|██████████| 945/945 [00:26<00:00, 35.31it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 404.90it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.6013011152416357, 0.2261710556979725, 0.3287044877222693)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:46<00:00, 20.34it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 255.35it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.7338551859099804, 0.48065718946632485, 0.5808631979159333)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.57it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 389.49it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.809629044988161, 0.5976462363085527, 0.6876717838707515)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.85it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 423.23it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8388561053109397, 0.6460032626427407, 0.7299058653149891)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.96it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 402.62it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8558545239503252, 0.6745513866231647, 0.754463703896781)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.99it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 416.67it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8630437966896147, 0.6865532509904451, 0.7647478746187292)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.95it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 411.09it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8659315147997678, 0.6954089955721278, 0.7713584076515446)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.96it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 413.46it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8631503920171062, 0.7055464926590538, 0.7764313650060909)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.76it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 407.20it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8623391158365976, 0.718247494756467, 0.7837253655435473)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.69it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 375.22it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.865633442343239, 0.7093917501747844, 0.7797630483509447)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 11\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 411.88it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.86810551558753, 0.7170822652062456, 0.7853997830387339)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.64it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 404.89it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8732292570106968, 0.7039151712887439, 0.779483870967742)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:46<00:00, 20.54it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 407.20it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8747655460972442, 0.706478676299231, 0.7816669889769869)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:46<00:00, 20.50it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 14\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 417.06it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8458568868054598, 0.714868329060825, 0.7748658035996211)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:46<00:00, 20.46it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 15\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 408.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8549187103885368, 0.7230249359123747, 0.783459595959596)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:46<00:00, 20.49it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 16\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 401.12it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8629124004550626, 0.7070612910743417, 0.7772511848341231)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.58it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 17\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 419.10it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8741012472487161, 0.6941272430668842, 0.7737871013833864)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:45<00:00, 20.73it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 18\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 404.14it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8725857595848948, 0.7054299697040317, 0.7801546391752578)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:46<00:00, 20.28it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 19\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 411.87it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.876505586997533, 0.7037986483337217, 0.7807147935112777)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 945/945 [00:46<00:00, 20.32it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 413.46it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.8801597869507324, 0.6931950594267071, 0.7755687373704453)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:00<00:00, 377.20it/s]\n",
"100%|██████████| 230/230 [00:00<00:00, 352.76it/s]\n"
]
}
],
"source": [
"NUM_EPOCHS = 20\n",
"for epoch in range(NUM_EPOCHS):\n",
" model.train()\n",
" for i in tqdm(range(len(train_labels))):\n",
" batch_tokens = train_tokens_ids[i].unsqueeze(0)\n",
" tags = train_labels[i].unsqueeze(1)\n",
"\n",
" predicted_tags = model(batch_tokens)\n",
"\n",
" optimizer.zero_grad()\n",
" loss = criterion(predicted_tags.view(-1, num_labels), tags.view(-1))\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" model.eval()\n",
" print(f'Epoch {epoch + 1}')\n",
" print(eval_model(dev_tokens_ids, dev_labels, model))"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"### Predykcje"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 17,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 215/215 [00:01<00:00, 141.73it/s]\n",
"100%|██████████| 230/230 [00:01<00:00, 163.82it/s]\n"
]
}
],
"source": [
2024-05-27 14:16:53 +02:00
"def validate_bio_sequence(labels):\n",
" corrected_labels = []\n",
" previous_label = 'O'\n",
" for label in labels:\n",
" if label.startswith('I-'):\n",
" if previous_label == 'O' or previous_label[2:] != label[2:]:\n",
" corrected_labels.append('B-' + label[2:])\n",
" else:\n",
" corrected_labels.append(label)\n",
" else:\n",
" corrected_labels.append(label)\n",
" previous_label = corrected_labels[-1]\n",
" return corrected_labels\n",
"\n",
2024-05-26 21:42:42 +02:00
"def save_predictions(tokens_ids, model, output_path, label_mapping):\n",
" predictions = []\n",
" for i in tqdm(range(len(tokens_ids))):\n",
" batch_tokens = tokens_ids[i].unsqueeze(0)\n",
" Y_batch_pred_weights = model(batch_tokens).squeeze(0)\n",
" Y_batch_pred = torch.argmax(Y_batch_pred_weights, 1)\n",
" bio_labels = [label_mapping[label] for label in Y_batch_pred.numpy()[1:-1]]\n",
2024-05-27 14:16:53 +02:00
" bio_labels = validate_bio_sequence(bio_labels)\n",
2024-05-26 21:42:42 +02:00
" predictions.append(\" \".join(bio_labels))\n",
"\n",
" with open(output_path, 'w') as f:\n",
" for line in predictions:\n",
" f.write(line + '\\n')\n",
"\n",
"save_predictions(dev_tokens_ids, model, 'dev-0/out.tsv', label_mapping)\n",
"save_predictions(test_tokens_ids, model, 'test-A/out.tsv', label_mapping)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2024-05-26T21:29:28.194003Z",
"end_time": "2024-05-26T21:29:31.148001Z"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}