2023-03-31 19:23:16 +02:00
{
"cells": [
2023-04-21 10:50:26 +02:00
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tensorflow_addons\\utils\\tfa_eol_msg.py:23: UserWarning: \n",
"\n",
"TensorFlow Addons (TFA) has ended development and introduction of new features.\n",
"TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.\n",
"Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). \n",
"\n",
"For more information see: https://github.com/tensorflow/addons/issues/2807 \n",
"\n",
" warnings.warn(\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.models import load_model\n",
"import tensorflow_addons as tfa\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\generation_utils.py:24: FutureWarning: Importing `GenerationMixin` from `src/transformers/generation_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import GenerationMixin` instead.\n",
" warnings.warn(\n",
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\generation_tf_utils.py:24: FutureWarning: Importing `TFGenerationMixin` from `src/transformers/generation_tf_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import TFGenerationMixin` instead.\n",
" warnings.warn(\n",
"loading file vocab.txt from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\vocab.txt\n",
"loading file added_tokens.json from cache at None\n",
"loading file special_tokens_map.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\special_tokens_map.json\n",
"loading file tokenizer_config.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\tokenizer_config.json\n",
"loading configuration file config.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\config.json\n",
"Model config BertConfig {\n",
" \"_name_or_path\": \"dkleczek/bert-base-polish-uncased-v1\",\n",
" \"architectures\": [\n",
" \"BertForMaskedLM\",\n",
" \"BertForPreTraining\"\n",
" ],\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"classifier_dropout\": null,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"layer_norm_eps\": 1e-12,\n",
" \"max_position_embeddings\": 512,\n",
" \"model_type\": \"bert\",\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"output_past\": true,\n",
" \"pad_token_id\": 0,\n",
" \"position_embedding_type\": \"absolute\",\n",
" \"transformers_version\": \"4.28.1\",\n",
" \"type_vocab_size\": 2,\n",
" \"use_cache\": true,\n",
" \"vocab_size\": 60000\n",
"}\n",
"\n"
]
}
],
"source": [
"loaded_model = tf.keras.models.load_model('model')\n",
"from transformers import *\n",
"tokenizer = BertTokenizer.from_pretrained(\"dkleczek/bert-base-polish-uncased-v1\")"
]
},
2023-03-31 19:23:16 +02:00
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# ASR"
]
},
{
"cell_type": "code",
2023-04-18 19:57:43 +02:00
"execution_count": 28,
2023-03-31 19:23:16 +02:00
"metadata": {},
"outputs": [],
"source": [
"def asr(inputText: str) -> str:\n",
" # Do something\n",
" inputText\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# NLU"
]
},
{
"cell_type": "code",
2023-04-21 10:50:26 +02:00
"execution_count": 13,
2023-03-31 19:23:16 +02:00
"metadata": {},
"outputs": [],
"source": [
2023-04-21 10:50:26 +02:00
"class NLU:\n",
2023-03-31 19:23:16 +02:00
" def __init__(self, text: str):\n",
" self.text = text\n",
2023-04-18 19:57:43 +02:00
" self.act = \"\"\n",
2023-04-21 10:50:26 +02:00
"\n",
2023-04-18 19:57:43 +02:00
" def get_dialog_act(self): \n",
2023-04-21 10:50:26 +02:00
" predicted_classes_names=[]\n",
" input = [self.text]\n",
" encoded_input = tokenizer.batch_encode_plus(input, padding=True, truncation=True, return_tensors='tf')\n",
" dataset = tf.data.Dataset.from_tensor_slices({\n",
" 'input_ids': encoded_input['input_ids'],\n",
" 'attention_mask': encoded_input['attention_mask'],\n",
" 'token_type_ids': encoded_input['token_type_ids']\n",
" }).batch(2)\n",
" predictions = loaded_model.predict(dataset)\n",
" classes = [\"ack\",\"affirm\",\"bye\",\"hello\",\"help\",\"negate\",\"null\",\"repeat\",\"reqalts\",\"reqmore\",\"restart\",\"silence\",\"thankyou\",\"confirm\",\"deny\",\"inform\",\"request\"]\n",
" for prediction in predictions: #trying to get predictions, if none it take maximum\n",
" predicted_classes = (predictions[prediction]> 0.5).astype(\"int32\")\n",
" if predicted_classes.sum()==0:\n",
" predicted_classes=max(predictions[prediction])\n",
" predicted_classes_indexes= np.where(predicted_classes==1)[1]\n",
" for p_classes in predicted_classes_indexes:\n",
" predicted_classes_names.append(classes[p_classes])\n",
" self.act=predicted_classes_names\n",
" return self.act\n"
2023-03-31 19:23:16 +02:00
]
},
{
"cell_type": "code",
2023-04-21 10:50:26 +02:00
"execution_count": 17,
2023-03-31 19:23:16 +02:00
"metadata": {},
"outputs": [
2023-04-21 10:50:26 +02:00
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 58ms/step\n"
]
},
2023-03-31 19:23:16 +02:00
{
"data": {
"text/plain": [
2023-04-21 10:50:26 +02:00
"['request']"
2023-03-31 19:23:16 +02:00
]
},
2023-04-21 10:50:26 +02:00
"execution_count": 17,
2023-03-31 19:23:16 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2023-04-21 10:50:26 +02:00
"nlu = NLU(\"Jaki pokój proponujesz w tym hotelu?\")\n",
2023-04-18 19:57:43 +02:00
"nlu.get_dialog_act()\n",
"nlu.act"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# DST"
2023-03-31 19:23:16 +02:00
]
},
2023-03-31 22:30:50 +02:00
{
"cell_type": "code",
2023-04-21 10:50:26 +02:00
"execution_count": 18,
2023-03-31 22:30:50 +02:00
"metadata": {},
"outputs": [],
2023-04-18 19:57:43 +02:00
"source": [
"class DialogueStateTracker:\n",
" \n",
" slots_dict: dict[tuple[str], str] = {\n",
" (\"osoby\", \"ludzie\", \"osób\", \"osobowy\"): \"people\",\n",
" (\"miasto\", \"miasta\", \"miejsowość\", \"poznań\", \"warszawa\", \"warszawie\", \"poznaniu\", \"kraków\", \"krakowie\"): \"city\",\n",
" (\"basen\", \"parking\", \"śniadania\"): \"facilities\",\n",
" (\"data\", \"datę\"): \"date\",\n",
" (\"pokój\", \"pokoje\"): \"room\"\n",
" }\n",
" \n",
2023-04-21 10:50:26 +02:00
" def __init__(self, nlu: NLU):\n",
2023-04-18 19:57:43 +02:00
" self.slots = []\n",
2023-04-21 10:50:26 +02:00
" self.act = nlu\n",
2023-04-18 19:57:43 +02:00
" self.text = nlu.text\n",
" \n",
" def get_dialog_slots(self):\n",
" for word in self.text.lower().split():\n",
" for key in DialogueStateTracker.slots_dict:\n",
" if word in key:\n",
" self.slots.append(DialogueStateTracker.slots_dict[key])\n",
" "
]
},
{
"cell_type": "code",
2023-04-21 10:50:26 +02:00
"execution_count": 19,
2023-04-18 19:57:43 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['room']"
]
},
2023-04-21 10:50:26 +02:00
"execution_count": 19,
2023-04-18 19:57:43 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dst: DialogueStateTracker = DialogueStateTracker(nlu)\n",
"dst.get_dialog_slots()\n",
"dst.slots\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dialogue Policy"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"class DialoguePolicy:\n",
" user_act_to_system_act_dict: dict[str, str] = {\n",
" \"ack\": \"reqmore\",\n",
" \"bye\": \"bye\",\n",
" \"hello\": \"welcomemsg\",\n",
" \"help\": \"inform\",\n",
" \"negate\": \"offer\",\n",
" \"requalts\": \"offer\",\n",
" \"reqmore\": \"inform\",\n",
" \"restart\": \"welcomemsg\",\n",
" \"thankyou\": \"reqmore\",\n",
" \"confirm\": \"reqmore\",\n",
" \"deny\": \"offer\",\n",
" \"inform\": \"offer\",\n",
" \"request\": \"inform\",\n",
" \"null\": \"null\"\n",
" }\n",
" \n",
" def __init__(self, dst: DialogueStateTracker):\n",
" self.user_text = dst.text\n",
" self.user_act = dst.act\n",
" self.user_slots = dst.slots\n",
" self.system_act = \"\"\n",
" \n",
" def get_system_act(self):\n",
" self.system_act = DialoguePolicy.user_act_to_system_act_dict[self.user_act]\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'inform'"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dp: DialoguePolicy = DialoguePolicy(dst)\n",
"dp.get_system_act()\n",
"dp.system_act"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# NLG"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"class NaturalLanguageGeneration:\n",
" system_act_to_text = {\n",
" \"reqmore\": \"Informuje więcej o \",\n",
" \"bye\": \"Do widzenia\",\n",
" \"welcomemsg\": \"Witaj w systemie rezerwacji hotelowych. W czym mogę pomóc?\",\n",
" \"inform\": \"Informuje cię o \",\n",
" \"offer\": \"Co myślisz o hotlu z \",\n",
" \"reqmore\": \"Czy mogę jeszcze jakoś Ci pomóc?\",\n",
" \"null\": \"\"\n",
" }\n",
" user_slots_to_text = {\n",
" \"people\": \"pojemności pokoju\",\n",
" \"city\": \"mieście\",\n",
" \"facilities\": \"udogodnieniach\",\n",
" \"date\": \"dacie\",\n",
" \"room\": \"pokoju\"\n",
" }\n",
" \n",
" def __init__(self, dp: DialoguePolicy):\n",
" self.user_text = dp.user_text\n",
" self.user_act = dp.user_act\n",
" self.user_slots = dp.user_slots\n",
" self.system_act = dp.system_act\n",
" self.system_text = \"\"\n",
" \n",
" def generate_system_text(self):\n",
" text: str = NaturalLanguageGeneration.system_act_to_text[self.system_act]\n",
" slots_transformed = [NaturalLanguageGeneration.user_slots_to_text[slot] for slot in self.user_slots]\n",
" self.system_text = text + \" i \".join(slots_transformed)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Informuje cię o pokoju'"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nlg: NaturalLanguageGeneration = NaturalLanguageGeneration(dp)\n",
"nlg.generate_system_text()\n",
"nlg.system_text"
]
2023-03-31 22:30:50 +02:00
},
2023-03-31 19:23:16 +02:00
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "SDenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2023-04-21 10:50:26 +02:00
"version": "3.11.2"
2023-03-31 19:23:16 +02:00
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}