systemy_dialogowe/NLU.ipynb

774 lines
34 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tensorflow_addons\\utils\\tfa_eol_msg.py:23: UserWarning: \n",
"\n",
"TensorFlow Addons (TFA) has ended development and introduction of new features.\n",
"TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.\n",
"Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). \n",
"\n",
"For more information see: https://github.com/tensorflow/addons/issues/2807 \n",
"\n",
" warnings.warn(\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import tokenization\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"import tensorflow_addons as tfa\n",
"\n",
"import sklearn\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"acts=pd.read_csv('data-only_dialogue_acts/user_acts.csv',index_col=None)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Agent</th>\n",
" <th>text</th>\n",
" <th>Act</th>\n",
" <th>ack</th>\n",
" <th>affirm</th>\n",
" <th>bye</th>\n",
" <th>hello</th>\n",
" <th>help</th>\n",
" <th>negate</th>\n",
" <th>null</th>\n",
" <th>repeat</th>\n",
" <th>reqalts</th>\n",
" <th>reqmore</th>\n",
" <th>restart</th>\n",
" <th>silence</th>\n",
" <th>thankyou</th>\n",
" <th>confirm</th>\n",
" <th>deny</th>\n",
" <th>inform</th>\n",
" <th>request</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>user</td>\n",
" <td>W jakim kraju mogę zarezerwować hotel?</td>\n",
" <td>help</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>user</td>\n",
" <td>Przedstaw proszę oferty z obszaru Górnego Kara...</td>\n",
" <td>request</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>user</td>\n",
" <td>3</td>\n",
" <td>inform</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>user</td>\n",
" <td>1000 USD na osobę</td>\n",
" <td>inform</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>user</td>\n",
" <td>Ostatni tydzień maja 2023 na 6 dni</td>\n",
" <td>inform</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Agent text Act ack \\\n",
"0 user W jakim kraju mogę zarezerwować hotel? help NaN \n",
"1 user Przedstaw proszę oferty z obszaru Górnego Kara... request NaN \n",
"2 user 3 inform NaN \n",
"3 user 1000 USD na osobę inform NaN \n",
"4 user Ostatni tydzień maja 2023 na 6 dni inform NaN \n",
"\n",
" affirm bye hello help negate null repeat reqalts reqmore restart \\\n",
"0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"1 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"2 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"3 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"4 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
"\n",
" silence thankyou confirm deny inform request \n",
"0 NaN NaN NaN NaN NaN NaN \n",
"1 NaN NaN NaN NaN NaN NaN \n",
"2 NaN NaN NaN NaN NaN NaN \n",
"3 NaN NaN NaN NaN NaN NaN \n",
"4 NaN NaN NaN NaN NaN NaN "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"acts.head()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"for column in acts.columns[3:]:\n",
" acts[column]=acts[\"Act\"].str.contains(column).astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"acts.to_csv('user_acts_one_hot.csv')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"acts=acts.drop([\"Agent\"],axis=1)\n",
"acts=acts.drop([\"Act\"],axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"text = acts[\"text\"]\n",
"labels = acts.drop([\"text\"],axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"label_cols=labels.columns\n",
"labels=labels[labels.columns].values"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"text = np.array(text)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['W jakim kraju mogę zarezerwować hotel?',\n",
" 'Przedstaw proszę oferty z obszaru Górnego Karabachu', '3',\n",
" '1000 USD na osobę', 'Ostatni tydzień maja 2023 na 6 dni',\n",
" 'Gorąca woda i bezpieczna okolica',\n",
" 'Podaj proszę kosztorys dla hotelu YY',\n",
" 'Czy jest to cena łączna dla 3 osób?',\n",
" 'Czy oferta zawiera ubezpieczenie?',\n",
" 'Ile wynosi łączna cena z ubezpieczeniem dla 3 osób?',\n",
" 'Rezygnuję z rezerwacji. Potrzebuję więcej czasu do namysłu.',\n",
" 'chciałbym zwiedzić coś egzotycznego. Co możesz mi zaproponować?',\n",
" 'tam jest trochę za zimno',\n",
" 'brzmi bardzo fajnie! Jakie są ceny hoteli w Brazylii?',\n",
" 'Chciałbym aby to był pokój dla 2 osób i doga',\n",
" 'Dzień dobry. Chciałbym wybrać się do warszawy na przyszły weekend. Szukam pokoi dla dwóch osób w cenie do 500 zł za noc.',\n",
" 'Zależałoby mi na śniadaniu oraz na tym żeby hotel był blisko centrum miasta. Miło by było jakby był dostępny basen.',\n",
" 'Poproszę w takim wypadku o rezerwację w hotelu YY.',\n",
" '31.03-02.04', 'Tak, dziękuję', 'Dziękuję!',\n",
" 'Dzień dobry, szukam ofert wycieczek wakacyjnych last minute, czy mają państwo jakieś w ofercie?',\n",
" 'Interesowałaby mnie Chorwacja, Grecja lub Cypr. Czy któryś z tych krajów wchodzi w grę?',\n",
" 'Szukam pokoju dla dwóch osób. Proszę o polecenie kilku z powyższych krajów',\n",
" 'Hotel w zabytkowym centrum miasta położonego nad morzem',\n",
" 'Maksymalna cena za jedną dobę to 300 złotych',\n",
" 'Znalazłaby się jakaś alternatywa dla hotelu na Cyprze z wycieczką w cenie?',\n",
" 'A co z wycieczką?', 'Hmmm, to może ten ze szpitalem jednak',\n",
" 'Od przyszłego poniedziałku do piątku', 'na 9',\n",
" 'jeśli jest możliwość to 1 osobowe, a jak nie to 2',\n",
" 'budżet to nie problem',\n",
" 'chciałbym żeby był all inclusive i przede wszystkim był blisko kasyna',\n",
" 'Poproszę zarezerwować pokój xxx w takim razie',\n",
" 'Dziękuję! Pobyt na pewno będzie miły',\n",
" 'Dzień dobry, chciałbym zarezerwować nocleg w jakimś tanim hotelu we Włoskich Alpach',\n",
" '1 osoba',\n",
" 'Chciałbym mieć nocleg na 5 dni, maksymalnie 200 euro za noc.',\n",
" 'Przydała by się sauna oraz jacuzzi.',\n",
" 'Czy w cenę hotelu wchodzą może skipassy?',\n",
" 'O super! Który by Pani hotel wybrała? Wolę miejsca gdzie wieczorem jest się gdzie pobawić...',\n",
" 'Super! Proszę o rezerwację i życzę miłego dnia',\n",
" 'A no tak, to też ważna kwestia... interesuje mnie pierwszy tydzień grudnia tego roku.',\n",
" 'tak, czy sa wolne pokoje z ładnym widokiem dla 2 osob w przyszły weekend?',\n",
" 'Warszawa', 'a mozna troche dalej, na rynku straszne halasy sa',\n",
" 'tak, to jest to', 'potwierdzam',\n",
" 'chciałbym zarezerwować pokój dwuosobowy na dni 25-28 marca 2023',\n",
" 'w Poznaniu', 'balkon w pokoju, bar w hotelu',\n",
" 'zależy ile kosztuje', 'brzmi dobrze', 'dziękuję',\n",
" 'Chciałbym zarezerowować pokój w jakimś hotelu w Warszawie, możliwie jak najbliżej lotniska.',\n",
" 'To zależy od tego jaka jest cena. Mogę poprosić o szerszy wybór hoteli?',\n",
" 'W tej sytuacji jestem zainteresowany rezerwacją tego hotelu, który jest najbliżej',\n",
" '31.03.2023 - 02.04.2023', 'Tak, dla jednej', 'dziękuję bardzo',\n",
" 'Witam, chcialbym zarezerwowac pokoj dla 3 osob, 2 os na jeden pokoj a ta trzecia osobno',\n",
" 'W Sosnowcu, na 9-11 września', 'Poproszę',\n",
" 'Dziękuję, kiedy mogę zapłacić?',\n",
" 'Ok, to zapłacę przy odebraniu kluczy', 'Do zobaczenia',\n",
" 'Chciałabym zamówić pokój dla dwóch osób. ',\n",
" 'Sopot. Czy jest możliwość z widokiem na molo? ', 'Nie mam ',\n",
" 'od 1 do 3 maja tego roku ', 'piękna sprawa ', 'Dziękuję pięknie ',\n",
" 'Do zobaczenia! ', 'podaj menu, dań głównych',\n",
" 'tak jaki hotel jest najbliżej Poznania', 'czy w zxy jest basen?',\n",
" 'a xyz?', 'czy sa wolne 2 osobowe pokoje na 25.03?',\n",
" 'tak, o której zaczyna się doba hotelowa?',\n",
" 'potwierdzam rezerwację',\n",
" 'chcialabym dokonac rezerwacji pokoju w warszawie',\n",
" 'wazny jest dla mnie jedyny parking oraz sniadania w formie bufetu',\n",
" 'w jakiej cenie jest hotel yzx',\n",
" 'dobrze to chcialabym zarezerwowac ten hotel',\n",
" 'na DD.MM.RR do DD.MM.RR', 'dla 2', 'dziekuje',\n",
" 'Chciałbym zarezerwować hotel w Zakopanem na najbliższy weekend.',\n",
" 'Czy hotel posiada przynajmniej 3 gwiazdki oraz basen?',\n",
" 'Poproszę o znalezienie hotelu z basenem.',\n",
" 'Tak proszę o dokonanie rezerwacji pokoju dla trzech osób.',\n",
" 'Dziękuję bardzo',\n",
" 'Chciałbym zarezerwować pokój na dwie doby, najlepiej z widokiem na miasto. Rezerwacja od 25.03 do 27.03. Czy jest może u Państwa dostępny pokój?',\n",
" 'Tak, to ten hotel w Poznaniu', 'Tak, poproszę dla dwóch osób',\n",
" 'Bardzo dziękuję, będziemy za 3 dni o godzinie 15',\n",
" 'Do zobaczenia', '8976098',\n",
" 'Bardzo dziękuję za sprawne załatwienie tematu. Powodem jest choroba mojej żony i niestety nie jesteśmy w stanie udać się w podróż ze względu na jej stan fizyczny.',\n",
" 'Jeszcze raz bardzo dziękuję',\n",
" 'chiałem zarezerwować dwa pokoje na dni 28-30 kwiecień',\n",
" 'w karpaczu, oba pokoje dwuosobowe',\n",
" 'czy w ofercie jest wyżywienie?',\n",
" 'czy jest opcja samego śniadania? obiad zjemy gdzieś na mieście',\n",
" 'tak proszę tak zrobić',\n",
" 'mam jeszcze takie pytanie na koniec, czy hotel zapewnia swoje miejsca parkingowe dla gości?',\n",
" 'dziękuję to chyba wszystko',\n",
" 'Czy mają Państwo wolny pokój 2-osobowy w terminie od 10 do 12 kwietnia?',\n",
" 'Prosze podac dostepne lokalizacje', 'Kraków', 'Tak',\n",
" 'To wszystko dziękuję.', 'Czy dostanę potwierdzenie mailem?',\n",
" 'xyz@gmail.com', 'Dziękuję', 'To wszystko', 'Pozdrawiam',\n",
" 'chciałbym zwiedzić coś egoztycznego. Co możesz mi polecić?',\n",
" 'Lokalizacja jest ok, a jak to wygląda cenowo?',\n",
" 'Interesuje mnie dwuosobowy pokój', 'Tak, to super cena. Poproszę',\n",
" 'Dziękuję to wszystko.',\n",
" 'Dzień dobry. Ile kosztuje u was wynajęcie pokoju dla dwóch osób?',\n",
" 'Chciałbym wynająć pokój na przyszły weekend w Poznaniu', 'Tak',\n",
" 'Płatność kartą na miejscu. To wszystko, dziękuję.',\n",
" 'Chciałbym zarezerwować pokój na jutro',\n",
" 'jakie pokoje są dostępne?',\n",
" 'Poznań Wilda, koło rynku wildeckiego',\n",
" 'poproszę pokój 1 osobowy, najlepiej byłoby na jak najwyższym piętrze',\n",
" 'Tak jestem zainteresowany. Czy w hotelu jest restauracja?',\n",
" 'Tak poproszę pokój + posiłki, W jakich porach mogę się zakwaterować?',\n",
" 'Tak, proszę o rezerwację', 'To wszystko, dziękuję za pomoc ',\n",
" 'Szukam noclegu w Pieninach',\n",
" 'Chodzi o koniec lipca lub początek sierpnia',\n",
" '5 nocy od poniedziałku do soboty', 'Nocleg dla 6 osób',\n",
" 'Znalazłby się jeden pokój sześcioosobowy?',\n",
" 'Wrócę do wersji z 2 pokojami. Proszę o rezerwację',\n",
" 'Kiedy mogę dokonać płatności?', 'To wszystko. Dziękuję bardzo.',\n",
" 'A jaki jest najbliższy hotel 5 gwiazdkowy?', 'Tak poproszę.',\n",
" 'poproszę zarezerwować na 02.07.27',\n",
" 'chciałbym zamówić dwa pokoje 4 osobowe', 'Tak poproszę',\n",
" 'Nie, bo będzie wszystko', 'dzięki',\n",
" 'Chciałbym zarezerwować apartament dla dwóch osób na najbliższy weekend.',\n",
" 'W warszawie, centrum', 'Tak jak najbardziej, brzmi świetnie',\n",
" 'NIe, dziękuję bardzo',\n",
" 'Chciałbym wybrać się nad polskie morze, tak na 7 dni. Najlepiej by hotel znajdował się w 1 linii brzegowej, jak również by było tam jak najmniej Januszostwa.',\n",
" 'A no i żeby było tanio!', 'Wie Pan, ale ja chce jechać w lipcu.',\n",
" 'Jak najbardziej. Proszę o zarezerwowanie tego pokoju.',\n",
" 'Nie, dziękuję i życzę miłego dnia'], dtype=object)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\generation_utils.py:24: FutureWarning: Importing `GenerationMixin` from `src/transformers/generation_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import GenerationMixin` instead.\n",
" warnings.warn(\n",
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\generation_tf_utils.py:24: FutureWarning: Importing `TFGenerationMixin` from `src/transformers/generation_tf_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import TFGenerationMixin` instead.\n",
" warnings.warn(\n",
"loading file vocab.txt from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\vocab.txt\n",
"loading file added_tokens.json from cache at None\n",
"loading file special_tokens_map.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\special_tokens_map.json\n",
"loading file tokenizer_config.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\tokenizer_config.json\n",
"loading configuration file config.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\config.json\n",
"Model config BertConfig {\n",
" \"_name_or_path\": \"dkleczek/bert-base-polish-uncased-v1\",\n",
" \"architectures\": [\n",
" \"BertForMaskedLM\",\n",
" \"BertForPreTraining\"\n",
" ],\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"classifier_dropout\": null,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"layer_norm_eps\": 1e-12,\n",
" \"max_position_embeddings\": 512,\n",
" \"model_type\": \"bert\",\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"output_past\": true,\n",
" \"pad_token_id\": 0,\n",
" \"position_embedding_type\": \"absolute\",\n",
" \"transformers_version\": \"4.28.1\",\n",
" \"type_vocab_size\": 2,\n",
" \"use_cache\": true,\n",
" \"vocab_size\": 60000\n",
"}\n",
"\n",
"loading configuration file config.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\config.json\n",
"Model config BertConfig {\n",
" \"architectures\": [\n",
" \"BertForMaskedLM\",\n",
" \"BertForPreTraining\"\n",
" ],\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"classifier_dropout\": null,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"id2label\": {\n",
" \"0\": \"LABEL_0\",\n",
" \"1\": \"LABEL_1\",\n",
" \"2\": \"LABEL_2\",\n",
" \"3\": \"LABEL_3\",\n",
" \"4\": \"LABEL_4\",\n",
" \"5\": \"LABEL_5\",\n",
" \"6\": \"LABEL_6\",\n",
" \"7\": \"LABEL_7\",\n",
" \"8\": \"LABEL_8\",\n",
" \"9\": \"LABEL_9\",\n",
" \"10\": \"LABEL_10\",\n",
" \"11\": \"LABEL_11\",\n",
" \"12\": \"LABEL_12\",\n",
" \"13\": \"LABEL_13\",\n",
" \"14\": \"LABEL_14\",\n",
" \"15\": \"LABEL_15\",\n",
" \"16\": \"LABEL_16\"\n",
" },\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"label2id\": {\n",
" \"LABEL_0\": 0,\n",
" \"LABEL_1\": 1,\n",
" \"LABEL_10\": 10,\n",
" \"LABEL_11\": 11,\n",
" \"LABEL_12\": 12,\n",
" \"LABEL_13\": 13,\n",
" \"LABEL_14\": 14,\n",
" \"LABEL_15\": 15,\n",
" \"LABEL_16\": 16,\n",
" \"LABEL_2\": 2,\n",
" \"LABEL_3\": 3,\n",
" \"LABEL_4\": 4,\n",
" \"LABEL_5\": 5,\n",
" \"LABEL_6\": 6,\n",
" \"LABEL_7\": 7,\n",
" \"LABEL_8\": 8,\n",
" \"LABEL_9\": 9\n",
" },\n",
" \"layer_norm_eps\": 1e-12,\n",
" \"max_position_embeddings\": 512,\n",
" \"model_type\": \"bert\",\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"output_past\": true,\n",
" \"pad_token_id\": 0,\n",
" \"position_embedding_type\": \"absolute\",\n",
" \"transformers_version\": \"4.28.1\",\n",
" \"type_vocab_size\": 2,\n",
" \"use_cache\": true,\n",
" \"vocab_size\": 60000\n",
"}\n",
"\n",
"loading weights file pytorch_model.bin from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\pytorch_model.bin\n",
"Loading PyTorch weights from C:\\Users\\macty\\.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\pytorch_model.bin\n",
"PyTorch checkpoint contains 178,915,010 parameters\n",
"Loaded 132,121,344 parameters in the TF 2.0 model.\n",
"All PyTorch model weights were used when initializing TFBertForSequenceClassification.\n",
"\n",
"Some weights or buffers of the TF 2.0 model TFBertForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import *\n",
"tokenizer = BertTokenizer.from_pretrained(\"dkleczek/bert-base-polish-uncased-v1\")\n",
"model = TFBertForSequenceClassification.from_pretrained(\"dkleczek/bert-base-polish-uncased-v1\", from_pt=True, num_labels=17)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"acts_text= [tokenizer.encode(text, add_special_tokens=True, max_length=128, padding='max_length', truncation=True) for text in acts[\"text\"]]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"input_ids = np.array(acts_text)\n",
"attention_masks = np.where(input_ids != 0, 1, 0) # checks for existing words"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"train_inputs, test_inputs, train_labels, test_labels = train_test_split(input_ids, labels, test_size=0.2, random_state=42)\n",
"train_masks, test_masks, _, _ = train_test_split(attention_masks, input_ids, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)\n",
"loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
"metrics = tfa.metrics.F1Score(num_classes=17, threshold=0.5)\n",
"model.layers[-1].activation = tf.keras.activations.sigmoid\n",
"model.compile(optimizer=optimizer, loss=loss, metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"4/4 [==============================] - 63s 15s/step - loss: 0.1764 - f1_score: 0.1641 - val_loss: 0.1709 - val_f1_score: 0.1149\n",
"Epoch 2/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.1712 - f1_score: 0.1957 - val_loss: 0.1672 - val_f1_score: 0.1280\n",
"Epoch 3/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.1636 - f1_score: 0.1932 - val_loss: 0.1633 - val_f1_score: 0.1385\n",
"Epoch 4/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.1574 - f1_score: 0.2331 - val_loss: 0.1592 - val_f1_score: 0.1627\n",
"Epoch 5/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.1519 - f1_score: 0.2435 - val_loss: 0.1554 - val_f1_score: 0.1627\n",
"Epoch 6/30\n",
"4/4 [==============================] - 60s 15s/step - loss: 0.1471 - f1_score: 0.2603 - val_loss: 0.1516 - val_f1_score: 0.1646\n",
"Epoch 7/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.1417 - f1_score: 0.2654 - val_loss: 0.1494 - val_f1_score: 0.1627\n",
"Epoch 8/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.1371 - f1_score: 0.2688 - val_loss: 0.1476 - val_f1_score: 0.1627\n",
"Epoch 9/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1328 - f1_score: 0.2728 - val_loss: 0.1449 - val_f1_score: 0.1627\n",
"Epoch 10/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1291 - f1_score: 0.2818 - val_loss: 0.1422 - val_f1_score: 0.1627\n",
"Epoch 11/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.1246 - f1_score: 0.2808 - val_loss: 0.1392 - val_f1_score: 0.1646\n",
"Epoch 12/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1215 - f1_score: 0.2781 - val_loss: 0.1387 - val_f1_score: 0.1605\n",
"Epoch 13/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1185 - f1_score: 0.2846 - val_loss: 0.1352 - val_f1_score: 0.1646\n",
"Epoch 14/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1139 - f1_score: 0.2856 - val_loss: 0.1343 - val_f1_score: 0.1649\n",
"Epoch 15/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1111 - f1_score: 0.2957 - val_loss: 0.1307 - val_f1_score: 0.1669\n",
"Epoch 16/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1085 - f1_score: 0.2878 - val_loss: 0.1289 - val_f1_score: 0.1588\n",
"Epoch 17/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1055 - f1_score: 0.3281 - val_loss: 0.1306 - val_f1_score: 0.1628\n",
"Epoch 18/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.1020 - f1_score: 0.3212 - val_loss: 0.1273 - val_f1_score: 0.1588\n",
"Epoch 19/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0991 - f1_score: 0.3294 - val_loss: 0.1269 - val_f1_score: 0.1568\n",
"Epoch 20/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0978 - f1_score: 0.3272 - val_loss: 0.1272 - val_f1_score: 0.1568\n",
"Epoch 21/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0950 - f1_score: 0.3490 - val_loss: 0.1257 - val_f1_score: 0.1649\n",
"Epoch 22/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0922 - f1_score: 0.3481 - val_loss: 0.1233 - val_f1_score: 0.1669\n",
"Epoch 23/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0902 - f1_score: 0.3516 - val_loss: 0.1236 - val_f1_score: 0.1597\n",
"Epoch 24/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0883 - f1_score: 0.3704 - val_loss: 0.1208 - val_f1_score: 0.1669\n",
"Epoch 25/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0865 - f1_score: 0.3499 - val_loss: 0.1195 - val_f1_score: 0.2414\n",
"Epoch 26/30\n",
"4/4 [==============================] - 61s 15s/step - loss: 0.0841 - f1_score: 0.3725 - val_loss: 0.1199 - val_f1_score: 0.1649\n",
"Epoch 27/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.0829 - f1_score: 0.3943 - val_loss: 0.1180 - val_f1_score: 0.1697\n",
"Epoch 28/30\n",
"4/4 [==============================] - 59s 15s/step - loss: 0.0803 - f1_score: 0.4423 - val_loss: 0.1164 - val_f1_score: 0.1669\n",
"Epoch 29/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0784 - f1_score: 0.3961 - val_loss: 0.1167 - val_f1_score: 0.1669\n",
"Epoch 30/30\n",
"4/4 [==============================] - 58s 15s/step - loss: 0.0774 - f1_score: 0.4275 - val_loss: 0.1177 - val_f1_score: 0.1625\n"
]
}
],
"source": [
"history = model.fit(\n",
" [train_inputs, train_masks],\n",
" train_labels,\n",
" validation_data=([test_inputs, test_masks], test_labels),\n",
" batch_size=32,\n",
" epochs=30\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 64ms/step\n",
"TFSequenceClassifierOutput(loss=None, logits=array([[0.06025698, 0.04095593, 0.7684139 , 0.04259768, 0.05270579,\n",
" 0.10895084, 0.03345573, 0.03596378, 0.0508336 , 0.05108728,\n",
" 0.04185295, 0.06173437, 0.21573795, 0.10289352, 0.05192397,\n",
" 0.07420129, 0.08625325]], dtype=float32), hidden_states=None, attentions=None)\n"
]
}
],
"source": [
"input_sentence = \"do widzenia\"\n",
"inputs = tokenizer.encode_plus(input_sentence, add_special_tokens=True, return_tensors='tf')\n",
"predictions = model.predict([inputs['input_ids'], inputs['attention_mask']])\n",
"print(predictions)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:Found untraced functions such as _update_step_xla, embeddings_layer_call_fn, embeddings_layer_call_and_return_conditional_losses, encoder_layer_call_fn, encoder_layer_call_and_return_conditional_losses while saving (showing 5 of 421). These functions will not be directly callable after loading.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: model\\assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: model\\assets\n"
]
}
],
"source": [
"model.save('model', save_format='tf')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}