954 lines
33 KiB
Plaintext
954 lines
33 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tensorflow_addons\\utils\\tfa_eol_msg.py:23: UserWarning: \n",
|
||
"\n",
|
||
"TensorFlow Addons (TFA) has ended development and introduction of new features.\n",
|
||
"TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.\n",
|
||
"Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). \n",
|
||
"\n",
|
||
"For more information see: https://github.com/tensorflow/addons/issues/2807 \n",
|
||
"\n",
|
||
" warnings.warn(\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"import tokenization\n",
|
||
"\n",
|
||
"import tensorflow as tf\n",
|
||
"import tensorflow_hub as hub\n",
|
||
"import tensorflow_addons as tfa\n",
|
||
"\n",
|
||
"import sklearn\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"import glob\n",
|
||
"import os"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"path = os.getcwd()+'\\data'\n",
|
||
"tsv_files = glob.glob(os.path.join(path, \"*.tsv\"))\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"dfs = []\n",
|
||
"for filename in tsv_files:\n",
|
||
" df = pd.read_csv(filename, index_col=None, header=None, delimiter='\\t',names=[\"speaker\", \"sentence\", \"dialogue_act\"])\n",
|
||
" dfs.append(df)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"combined_df = pd.concat(dfs, axis=0, ignore_index=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>speaker</th>\n",
|
||
" <th>sentence</th>\n",
|
||
" <th>dialogue_act</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>user</td>\n",
|
||
" <td>Co proszę?</td>\n",
|
||
" <td>null()/hello()</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>system</td>\n",
|
||
" <td>Witam w systemie rezerwacji hotelu. Gdzie chci...</td>\n",
|
||
" <td>welcomemsg()</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>user</td>\n",
|
||
" <td>W jakim kraju/B-country mogę zarezerwować hotel?</td>\n",
|
||
" <td>help(country)</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>system</td>\n",
|
||
" <td>Mamy szeroki wybór hoteli na całym świecie.</td>\n",
|
||
" <td>expl-conf()</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>user</td>\n",
|
||
" <td>Przedstaw proszę oferty z obszaru Górnego Kara...</td>\n",
|
||
" <td>request(country=Górny Karabuch)</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>347</th>\n",
|
||
" <td>system</td>\n",
|
||
" <td>Okej w takim razie, proponuję ten sam hotel w ...</td>\n",
|
||
" <td>offer(price=110, date=02.07.2023- 08.07.2023)</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>348</th>\n",
|
||
" <td>user</td>\n",
|
||
" <td>Jak najbardziej. Proszę o zarezerwowanie/B-res...</td>\n",
|
||
" <td>confirm()</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>349</th>\n",
|
||
" <td>system</td>\n",
|
||
" <td>Dobrze, numer rezerwacji to 912312. Dokładny A...</td>\n",
|
||
" <td>inform(reservation_number=912312, address=3 ma...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>350</th>\n",
|
||
" <td>user</td>\n",
|
||
" <td>Nie, dziękuję i życzę miłego dnia</td>\n",
|
||
" <td>negate()&thankyou()&bye()</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>351</th>\n",
|
||
" <td>system</td>\n",
|
||
" <td>Dziękuję bardzo wzajemnie.</td>\n",
|
||
" <td>thankyou()</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>352 rows × 3 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" speaker sentence \\\n",
|
||
"0 user Co proszę? \n",
|
||
"1 system Witam w systemie rezerwacji hotelu. Gdzie chci... \n",
|
||
"2 user W jakim kraju/B-country mogę zarezerwować hotel? \n",
|
||
"3 system Mamy szeroki wybór hoteli na całym świecie. \n",
|
||
"4 user Przedstaw proszę oferty z obszaru Górnego Kara... \n",
|
||
".. ... ... \n",
|
||
"347 system Okej w takim razie, proponuję ten sam hotel w ... \n",
|
||
"348 user Jak najbardziej. Proszę o zarezerwowanie/B-res... \n",
|
||
"349 system Dobrze, numer rezerwacji to 912312. Dokładny A... \n",
|
||
"350 user Nie, dziękuję i życzę miłego dnia \n",
|
||
"351 system Dziękuję bardzo wzajemnie. \n",
|
||
"\n",
|
||
" dialogue_act \n",
|
||
"0 null()/hello() \n",
|
||
"1 welcomemsg() \n",
|
||
"2 help(country) \n",
|
||
"3 expl-conf() \n",
|
||
"4 request(country=Górny Karabuch) \n",
|
||
".. ... \n",
|
||
"347 offer(price=110, date=02.07.2023- 08.07.2023) \n",
|
||
"348 confirm() \n",
|
||
"349 inform(reservation_number=912312, address=3 ma... \n",
|
||
"350 negate()&thankyou()&bye() \n",
|
||
"351 thankyou() \n",
|
||
"\n",
|
||
"[352 rows x 3 columns]"
|
||
]
|
||
},
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"combined_df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def extract_labels(sentence):\n",
|
||
" tokens = sentence.split()\n",
|
||
" labels = []\n",
|
||
" for token in tokens:\n",
|
||
" parts = token.split(\"/\")\n",
|
||
" if len(parts) > 1:\n",
|
||
" label = parts[1]\n",
|
||
" if label.startswith('B-'):\n",
|
||
" labels.append(label[0:])\n",
|
||
" elif label.startswith('I-'):\n",
|
||
" labels.append(label[0:])\n",
|
||
" else:\n",
|
||
" labels.append('O')\n",
|
||
" else:\n",
|
||
" labels.append('O')\n",
|
||
" return labels\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"labels = combined_df['sentence'].apply(extract_labels)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"labels_list = [label for sentence in labels for label in sentence ]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"unique_labels = set(labels_list)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'B-alternative',\n",
|
||
" 'B-animal',\n",
|
||
" 'B-area',\n",
|
||
" 'B-area,',\n",
|
||
" 'B-area.',\n",
|
||
" 'B-available?',\n",
|
||
" 'B-beggining',\n",
|
||
" 'B-checkin?',\n",
|
||
" 'B-city',\n",
|
||
" 'B-city,',\n",
|
||
" 'B-city.',\n",
|
||
" 'B-confirmation',\n",
|
||
" 'B-country',\n",
|
||
" 'B-country,',\n",
|
||
" 'B-country?',\n",
|
||
" 'B-date',\n",
|
||
" 'B-date?',\n",
|
||
" 'B-day',\n",
|
||
" 'B-day-11',\n",
|
||
" 'B-day-28',\n",
|
||
" 'B-days',\n",
|
||
" 'B-days,',\n",
|
||
" 'B-days.',\n",
|
||
" 'B-email',\n",
|
||
" 'B-facilities',\n",
|
||
" 'B-facilities,',\n",
|
||
" 'B-facilities.',\n",
|
||
" 'B-facilities...',\n",
|
||
" 'B-facilities?',\n",
|
||
" 'B-finish',\n",
|
||
" 'B-first',\n",
|
||
" 'B-hotel',\n",
|
||
" 'B-hotel.',\n",
|
||
" 'B-hotel?',\n",
|
||
" 'B-insurance',\n",
|
||
" 'B-insurance?',\n",
|
||
" 'B-location',\n",
|
||
" 'B-month',\n",
|
||
" 'B-month.',\n",
|
||
" 'B-month?',\n",
|
||
" 'B-next',\n",
|
||
" 'B-nights',\n",
|
||
" 'B-nights,',\n",
|
||
" 'B-number_of_rooms',\n",
|
||
" 'B-payment',\n",
|
||
" 'B-payment?',\n",
|
||
" 'B-people',\n",
|
||
" 'B-people?',\n",
|
||
" 'B-per_night',\n",
|
||
" 'B-per_night.',\n",
|
||
" 'B-price',\n",
|
||
" 'B-price!',\n",
|
||
" 'B-price.',\n",
|
||
" 'B-price?',\n",
|
||
" 'B-reservation',\n",
|
||
" 'B-reservation_number',\n",
|
||
" 'B-room_size',\n",
|
||
" 'B-room_size,',\n",
|
||
" 'B-room_size.',\n",
|
||
" 'B-room_size?',\n",
|
||
" 'B-rooms',\n",
|
||
" 'B-sickness',\n",
|
||
" 'B-size',\n",
|
||
" 'B-size,',\n",
|
||
" 'B-size.',\n",
|
||
" 'B-size?',\n",
|
||
" 'B-stars',\n",
|
||
" 'B-stars?',\n",
|
||
" 'B-sum',\n",
|
||
" 'B-week',\n",
|
||
" 'B-weekend',\n",
|
||
" 'B-weekend.',\n",
|
||
" 'B-weekend?',\n",
|
||
" 'B-year',\n",
|
||
" 'B-year.',\n",
|
||
" 'I-area',\n",
|
||
" 'I-country',\n",
|
||
" 'I-country.',\n",
|
||
" 'I-day',\n",
|
||
" 'I-day.',\n",
|
||
" 'I-days',\n",
|
||
" 'I-facilities',\n",
|
||
" 'I-facilities.',\n",
|
||
" 'I-hotel',\n",
|
||
" 'I-location,',\n",
|
||
" 'I-month',\n",
|
||
" 'I-payment',\n",
|
||
" 'I-perperson',\n",
|
||
" 'I-room_size',\n",
|
||
" 'I-year',\n",
|
||
" 'O'}"
|
||
]
|
||
},
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"unique_labels"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"num_labels = unique_labels.__len__()+1"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"OTHER_LABEL = \"pad\"\n",
|
||
"label_map = {label: i for i, label in enumerate(unique_labels, start=1)}\n",
|
||
"label_map[OTHER_LABEL] = 0\n",
|
||
"\n",
|
||
"# Convert the flattened labels array to a numpy array of integers\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"0"
|
||
]
|
||
},
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"label_map[\"pad\"]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"55"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"label_map['O']"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"0 [O, O]\n",
|
||
"1 [O, O, O, O, O, O, O, O, O]\n",
|
||
"2 [O, O, B-country, O, O, O]\n",
|
||
"3 [O, O, O, O, O, O, O]\n",
|
||
"4 [O, O, O, O, O, O, B-country]\n",
|
||
" ... \n",
|
||
"347 [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...\n",
|
||
"348 [O, O, O, O, B-reservation, O, O]\n",
|
||
"349 [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...\n",
|
||
"350 [O, O, O, O, O, O]\n",
|
||
"351 [O, O, O]\n",
|
||
"Name: sentence, Length: 352, dtype: object"
|
||
]
|
||
},
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"labels"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||
"Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.sso.sso_relationship.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
|
||
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
||
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from transformers import AutoTokenizer, AutoModel\n",
|
||
"tokenizer = AutoTokenizer.from_pretrained(\"allegro/herbert-base-cased\")\n",
|
||
"model = AutoModel.from_pretrained(\"allegro/herbert-base-cased\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tokenizer = AutoTokenizer.from_pretrained(\"allegro/herbert-base-cased\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.sso.sso_relationship.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
|
||
"- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
||
"- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
||
"Some weights of BertForTokenClassification were not initialized from the model checkpoint at allegro/herbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
||
"C:\\Users\\macty\\AppData\\Local\\Temp\\ipykernel_1684\\3997969718.py:42: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
||
" torch.tensor(labels, dtype=torch.long)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 1. Preprocess the data\n",
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"from transformers import AutoTokenizer\n",
|
||
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
|
||
"import torch\n",
|
||
"from torch.utils.data import DataLoader, TensorDataset\n",
|
||
"from transformers import AdamW, get_linear_schedule_with_warmup\n",
|
||
"# Initialize the tokenizer\n",
|
||
"# Define a maximum sequence length\n",
|
||
"max_length = 128\n",
|
||
"\n",
|
||
"# Tokenize the text\n",
|
||
"tokens = tokenizer(combined_df[\"sentence\"].tolist(), padding='max_length', truncation=True, max_length=max_length, return_tensors=\"pt\")\n",
|
||
"\n",
|
||
"# Create attention masks\n",
|
||
"attention_masks = tokens[\"attention_mask\"]\n",
|
||
"\n",
|
||
"# Truncate or pad the labels to match the sequence length\n",
|
||
"labels = [[label_map.get(l, 0) for l in lab] for lab in labels]\n",
|
||
"labels = pad_sequences(labels, maxlen=max_length, value=0, dtype=np.int32, truncating='post', padding='post')\n",
|
||
"labels = torch.tensor(labels, dtype=torch.long)\n",
|
||
"# Convert the preprocessed data into a PyTorch Dataset\n",
|
||
"dataset = TensorDataset(tokens[\"input_ids\"], attention_masks, labels)\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"# 2. Define the NER model\n",
|
||
"\n",
|
||
"from transformers import AutoModelForTokenClassification\n",
|
||
"\n",
|
||
"# Load the pre-trained model\n",
|
||
"model = AutoModelForTokenClassification.from_pretrained(\"allegro/herbert-base-cased\", num_labels=num_labels)\n",
|
||
"\n",
|
||
"# 3. Train the NER model\n",
|
||
"\n",
|
||
"\n",
|
||
"# Convert the preprocessed data into a PyTorch Dataset\n",
|
||
"dataset = dataset = TensorDataset(\n",
|
||
" tokens[\"input_ids\"],\n",
|
||
" attention_masks,\n",
|
||
" torch.tensor(labels, dtype=torch.long)\n",
|
||
")\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
||
" warnings.warn(\n",
|
||
"C:\\Users\\macty\\AppData\\Local\\Temp\\ipykernel_1684\\2322694894.py:30: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
||
" train_dataset = TensorDataset(train_inputs, train_masks, torch.tensor(train_labels))\n",
|
||
"C:\\Users\\macty\\AppData\\Local\\Temp\\ipykernel_1684\\2322694894.py:33: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
||
" test_dataset = TensorDataset(test_inputs, test_masks, torch.tensor(test_labels))\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Define the training parameters\n",
|
||
"batch_size = 8\n",
|
||
"optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)\n",
|
||
"scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(dataset) // batch_size + 1)\n",
|
||
"loss_fn = torch.nn.CrossEntropyLoss()\n",
|
||
"\n",
|
||
"# Define the training loop\n",
|
||
"def train(model, dataloader, optimizer, scheduler, loss_fn):\n",
|
||
" model.train()\n",
|
||
" for batch in dataloader:\n",
|
||
" inputs = {key: value.to(model.device) for key, value in dict(zip([\"input_ids\", \"attention_mask\"], batch)).items()}\n",
|
||
" labels = batch[2].to(model.device)\n",
|
||
" optimizer.zero_grad()\n",
|
||
"\n",
|
||
" outputs = model(**inputs, labels=labels)\n",
|
||
" loss = loss_fn(outputs.logits.view(-1, num_labels), labels.view(-1))\n",
|
||
"\n",
|
||
" loss.backward()\n",
|
||
"\n",
|
||
" optimizer.step()\n",
|
||
" scheduler.step()\n",
|
||
"\n",
|
||
"# Split the dataset into training and testing sets\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"train_inputs, test_inputs, train_labels, test_labels = train_test_split(tokens[\"input_ids\"], labels, test_size=0.2)\n",
|
||
"\n",
|
||
"train_masks, test_masks, _, _ = train_test_split(attention_masks, tokens[\"input_ids\"], test_size=0.2)\n",
|
||
"\n",
|
||
"# Convert the preprocessed data into PyTorch Dataloaders\n",
|
||
"train_dataset = TensorDataset(train_inputs, train_masks, torch.tensor(train_labels))\n",
|
||
"train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
||
"\n",
|
||
"test_dataset = TensorDataset(test_inputs, test_masks, torch.tensor(test_labels))\n",
|
||
"test_dataloader = DataLoader(test_dataset, batch_size=batch_size)\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 36,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[55, 55, 0, ..., 0, 0, 0],\n",
|
||
" [55, 55, 55, ..., 0, 0, 0],\n",
|
||
" [55, 55, 73, ..., 0, 0, 0],\n",
|
||
" ...,\n",
|
||
" [55, 55, 55, ..., 0, 0, 0],\n",
|
||
" [55, 55, 55, ..., 0, 0, 0],\n",
|
||
" [55, 55, 55, ..., 0, 0, 0]])"
|
||
]
|
||
},
|
||
"execution_count": 36,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"labels"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Accuracy: 0.88\n",
|
||
"Accuracy: 0.95\n",
|
||
"Accuracy: 0.92\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.93\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.92\n",
|
||
"Accuracy: 0.90\n",
|
||
"Accuracy: 0.88\n",
|
||
"Accuracy: 0.95\n",
|
||
"Accuracy: 0.92\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.93\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.92\n",
|
||
"Accuracy: 0.90\n",
|
||
"Accuracy: 0.88\n",
|
||
"Accuracy: 0.95\n",
|
||
"Accuracy: 0.92\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.93\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.94\n",
|
||
"Accuracy: 0.92\n",
|
||
"Accuracy: 0.90\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"\n",
|
||
"# Train the model\n",
|
||
"epochs = 3\n",
|
||
"for epoch in range(epochs):\n",
|
||
" train(model, train_dataloader, optimizer, scheduler, loss_fn)\n",
|
||
" # Evaluate the model\n",
|
||
" model.eval()\n",
|
||
" with torch.no_grad():\n",
|
||
" for batch in test_dataloader:\n",
|
||
" inputs = {key: value.to(model.device) for key, value in dict(zip([\"input_ids\", \"attention_mask\"], batch)).items()}\n",
|
||
" labels = batch[2].to(model.device)\n",
|
||
" outputs = model(**inputs)\n",
|
||
" predictions = outputs.logits.argmax(dim=-1)\n",
|
||
"\n",
|
||
" # Calculate the accuracy\n",
|
||
" accuracy = (predictions == labels).float().mean().item()\n",
|
||
"\n",
|
||
" print(f\"Accuracy: {accuracy:.2f}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"ename": "IndexError",
|
||
"evalue": "index 1 is out of bounds for dimension 0 with size 1",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[1;32mIn[271], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m predictions[\u001b[39m1\u001b[39;49m]\n",
|
||
"\u001b[1;31mIndexError\u001b[0m: index 1 is out of bounds for dimension 0 with size 1"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"predictions[1]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([[25, 25, 25, 25, 25, 25, 25, 84, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0],\n",
|
||
" [25, 25, 25, 25, 84, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0],\n",
|
||
" [25, 25, 25, 25, 25, 25, 25, 25, 15, 25, 25, 32, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0],\n",
|
||
" [25, 25, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0],\n",
|
||
" [25, 25, 25, 25, 25, 25, 25, 25, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0],\n",
|
||
" [25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 75, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0],\n",
|
||
" [25, 25, 25, 25, 25, 25, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0]])"
|
||
]
|
||
},
|
||
"execution_count": 269,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"labels"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"['pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad', 'pad']\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Define the sentence\n",
|
||
"sentence = \"Hej, chciałbym zamówić pokój w Poznaniu na termin 25.03 - 17.04\"\n",
|
||
"\n",
|
||
"# Tokenize the sentence\n",
|
||
"input_ids = tokenizer.encode(sentence, add_special_tokens=True, return_tensors=\"pt\")\n",
|
||
"\n",
|
||
"# Create the attention mask\n",
|
||
"attention_mask = torch.ones_like(input_ids)\n",
|
||
"\n",
|
||
"# Make the prediction\n",
|
||
"with torch.no_grad():\n",
|
||
" outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
|
||
" predictions = outputs.logits.argmax(dim=-1)\n",
|
||
"index_label_map = {v: k for k, v in label_map.items()}\n",
|
||
"# Decode the predicted labels\n",
|
||
"predicted_labels = [index_label_map[label] for label in predictions[0].tolist()]\n",
|
||
"\n",
|
||
"# Print the predicted labels\n",
|
||
"print(predicted_labels)\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"25"
|
||
]
|
||
},
|
||
"execution_count": 285,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"label_map[\"O\"]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 39,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])"
|
||
]
|
||
},
|
||
"execution_count": 39,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"predictions[0]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[(0, 0), (0, 2), (3, 9), (9, 10), (0, 5)]\n"
|
||
]
|
||
}
|
||
],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.11.2"
|
||
},
|
||
"orig_nbformat": 4
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|