399 lines
13 KiB
Plaintext
399 lines
13 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tensorflow_addons\\utils\\tfa_eol_msg.py:23: UserWarning: \n",
|
|
"\n",
|
|
"TensorFlow Addons (TFA) has ended development and introduction of new features.\n",
|
|
"TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.\n",
|
|
"Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). \n",
|
|
"\n",
|
|
"For more information see: https://github.com/tensorflow/addons/issues/2807 \n",
|
|
"\n",
|
|
" warnings.warn(\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import tensorflow as tf\n",
|
|
"from tensorflow.keras.models import load_model\n",
|
|
"import tensorflow_addons as tfa\n",
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
" from .autonotebook import tqdm as notebook_tqdm\n",
|
|
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\generation_utils.py:24: FutureWarning: Importing `GenerationMixin` from `src/transformers/generation_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import GenerationMixin` instead.\n",
|
|
" warnings.warn(\n",
|
|
"c:\\Users\\macty\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\generation_tf_utils.py:24: FutureWarning: Importing `TFGenerationMixin` from `src/transformers/generation_tf_utils.py` is deprecated and will be removed in Transformers v5. Import as `from transformers import TFGenerationMixin` instead.\n",
|
|
" warnings.warn(\n",
|
|
"loading file vocab.txt from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\vocab.txt\n",
|
|
"loading file added_tokens.json from cache at None\n",
|
|
"loading file special_tokens_map.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\special_tokens_map.json\n",
|
|
"loading file tokenizer_config.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\tokenizer_config.json\n",
|
|
"loading configuration file config.json from cache at C:\\Users\\macty/.cache\\huggingface\\hub\\models--dkleczek--bert-base-polish-uncased-v1\\snapshots\\62be9821055981deafb23f217b68cc41f38cdb76\\config.json\n",
|
|
"Model config BertConfig {\n",
|
|
" \"_name_or_path\": \"dkleczek/bert-base-polish-uncased-v1\",\n",
|
|
" \"architectures\": [\n",
|
|
" \"BertForMaskedLM\",\n",
|
|
" \"BertForPreTraining\"\n",
|
|
" ],\n",
|
|
" \"attention_probs_dropout_prob\": 0.1,\n",
|
|
" \"classifier_dropout\": null,\n",
|
|
" \"hidden_act\": \"gelu\",\n",
|
|
" \"hidden_dropout_prob\": 0.1,\n",
|
|
" \"hidden_size\": 768,\n",
|
|
" \"initializer_range\": 0.02,\n",
|
|
" \"intermediate_size\": 3072,\n",
|
|
" \"layer_norm_eps\": 1e-12,\n",
|
|
" \"max_position_embeddings\": 512,\n",
|
|
" \"model_type\": \"bert\",\n",
|
|
" \"num_attention_heads\": 12,\n",
|
|
" \"num_hidden_layers\": 12,\n",
|
|
" \"output_past\": true,\n",
|
|
" \"pad_token_id\": 0,\n",
|
|
" \"position_embedding_type\": \"absolute\",\n",
|
|
" \"transformers_version\": \"4.28.1\",\n",
|
|
" \"type_vocab_size\": 2,\n",
|
|
" \"use_cache\": true,\n",
|
|
" \"vocab_size\": 60000\n",
|
|
"}\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from transformers import BertTokenizer\n",
"\n",
"# Fine-tuned dialogue-act classifier, loaded from a path relative to the notebook\n",
"# (presumably a Keras SavedModel directory - confirm).\n",
"MODEL_DIR = \"model\"\n",
"# Hugging Face checkpoint whose vocabulary is used to encode user utterances;\n",
"# presumably the same tokenizer the classifier was fine-tuned with - confirm.\n",
"TOKENIZER_NAME = \"dkleczek/bert-base-polish-uncased-v1\"\n",
"\n",
"loaded_model = load_model(MODEL_DIR)\n",
"tokenizer = BertTokenizer.from_pretrained(TOKENIZER_NAME)"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# ASR"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def asr(inputText: str) -> str:\n",
"    \"\"\"Speech-recognition stage (placeholder).\n",
"\n",
"    No recogniser is wired in yet, so the input text is returned unchanged,\n",
"    which lets typed text be fed straight into the NLU stage.\n",
"\n",
"    Args:\n",
"        inputText: the user's utterance.\n",
"\n",
"    Returns:\n",
"        The same text.\n",
"    \"\"\"\n",
"    # TODO: replace this pass-through with a real ASR call.\n",
"    return inputText\n"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# NLU"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class NLU:\n",
"    \"\"\"Natural-language-understanding stage: predicts dialogue acts for one utterance.\"\"\"\n",
"\n",
"    # Label for each classifier output column: column i of a score row is CLASSES[i].\n",
"    # NOTE(review): must match the label order used when the model was trained - confirm.\n",
"    CLASSES = [\"ack\", \"affirm\", \"bye\", \"hello\", \"help\", \"negate\", \"null\", \"repeat\", \"reqalts\",\n",
"               \"reqmore\", \"restart\", \"silence\", \"thankyou\", \"confirm\", \"deny\", \"inform\", \"request\"]\n",
"\n",
"    def __init__(self, text: str):\n",
"        self.text = text\n",
"        # Filled by get_dialog_act() with the list of predicted act names.\n",
"        self.act = []\n",
"\n",
"    def get_dialog_act(self, threshold: float = 0.5):\n",
"        \"\"\"Classify ``self.text`` and store the predicted dialogue acts in ``self.act``.\n",
"\n",
"        Every class whose score exceeds ``threshold`` is predicted (multi-label).\n",
"        A row where no class passes the threshold falls back to its single\n",
"        highest-scoring class, so at least one act is always returned.\n",
"\n",
"        Args:\n",
"            threshold: decision threshold applied to each class score\n",
"                (presumably sigmoid probabilities - TODO confirm).\n",
"\n",
"        Returns:\n",
"            The list of predicted act names (also stored in ``self.act``).\n",
"        \"\"\"\n",
"        encoded_input = tokenizer.batch_encode_plus(\n",
"            [self.text], padding=True, truncation=True, return_tensors='tf'\n",
"        )\n",
"        dataset = tf.data.Dataset.from_tensor_slices({\n",
"            'input_ids': encoded_input['input_ids'],\n",
"            'attention_mask': encoded_input['attention_mask'],\n",
"            'token_type_ids': encoded_input['token_type_ids'],\n",
"        }).batch(2)\n",
"        predictions = loaded_model.predict(dataset)\n",
"\n",
"        # predict() may return a dict of named outputs or a bare array; handle both.\n",
"        outputs = predictions.values() if isinstance(predictions, dict) else [predictions]\n",
"\n",
"        predicted_classes_names = []\n",
"        for scores in outputs:\n",
"            scores = np.atleast_2d(np.asarray(scores))\n",
"            selected = scores > threshold\n",
"            # Fallback: rows with nothing above the threshold take their argmax class.\n",
"            empty_rows = ~selected.any(axis=-1)\n",
"            selected[empty_rows, scores[empty_rows].argmax(axis=-1)] = True\n",
"            # Column indices of the selected cells are the predicted class ids.\n",
"            for class_index in np.nonzero(selected)[1]:\n",
"                predicted_classes_names.append(self.CLASSES[class_index])\n",
"        self.act = predicted_classes_names\n",
"        return self.act\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1/1 [==============================] - 0s 58ms/step\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['request']"
|
|
]
|
|
},
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Classify one sample utterance; the predicted acts end up in nlu.act.\n",
"nlu = NLU(text=\"Jaki pokój proponujesz w tym hotelu?\")\n",
"nlu.get_dialog_act()\n",
"nlu.act"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# DST"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class DialogueStateTracker:\n",
"    \"\"\"Dialogue-state-tracking stage: finds which slots the user's text mentions.\"\"\"\n",
"\n",
"    # Groups of lower-cased trigger words -> the slot they indicate.\n",
"    slots_dict: dict[tuple[str, ...], str] = {\n",
"        (\"osoby\", \"ludzie\", \"osób\", \"osobowy\"): \"people\",\n",
"        (\"miasto\", \"miasta\", \"miejscowość\", \"poznań\", \"warszawa\", \"warszawie\", \"poznaniu\", \"kraków\", \"krakowie\"): \"city\",\n",
"        (\"basen\", \"parking\", \"śniadania\"): \"facilities\",\n",
"        (\"data\", \"datę\"): \"date\",\n",
"        (\"pokój\", \"pokoje\"): \"room\"\n",
"    }\n",
"\n",
"    def __init__(self, nlu: NLU):\n",
"        self.slots = []\n",
"        # Store the predicted act(s), not the NLU object: DialoguePolicy uses\n",
"        # dst.act as a lookup key. Read at construction time, so call\n",
"        # nlu.get_dialog_act() before building the tracker.\n",
"        self.act = nlu.act\n",
"        self.text = nlu.text\n",
"\n",
"    def get_dialog_slots(self):\n",
"        \"\"\"Record in ``self.slots`` every slot whose trigger word occurs in the text.\n",
"\n",
"        Matching is case-insensitive and ignores punctuation attached to words.\n",
"        Each slot is recorded at most once, so repeated calls add no duplicates.\n",
"\n",
"        Returns:\n",
"            The list of detected slot names (also stored in ``self.slots``).\n",
"        \"\"\"\n",
"        import re\n",
"\n",
"        # \\w is Unicode-aware, so Polish letters stay inside words.\n",
"        for word in re.findall(r\"\\w+\", self.text.lower()):\n",
"            for triggers, slot in DialogueStateTracker.slots_dict.items():\n",
"                if word in triggers and slot not in self.slots:\n",
"                    self.slots.append(slot)\n",
"        return self.slots\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['room']"
|
|
]
|
|
},
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Extract the slots mentioned in the same utterance.\n",
"dst = DialogueStateTracker(nlu=nlu)\n",
"dst.get_dialog_slots()\n",
"dst.slots\n"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Dialogue Policy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class DialoguePolicy:\n",
"    \"\"\"Dialogue-policy stage: chooses the system act that answers the user's act(s).\"\"\"\n",
"\n",
"    # User dialogue act -> system dialogue act.\n",
"    user_act_to_system_act_dict: dict[str, str] = {\n",
"        \"ack\": \"reqmore\",\n",
"        \"bye\": \"bye\",\n",
"        \"hello\": \"welcomemsg\",\n",
"        \"help\": \"inform\",\n",
"        \"negate\": \"offer\",\n",
"        \"reqalts\": \"offer\",\n",
"        \"reqmore\": \"inform\",\n",
"        \"restart\": \"welcomemsg\",\n",
"        \"thankyou\": \"reqmore\",\n",
"        \"confirm\": \"reqmore\",\n",
"        \"deny\": \"offer\",\n",
"        \"inform\": \"offer\",\n",
"        \"request\": \"inform\",\n",
"        \"null\": \"null\"\n",
"    }\n",
"    # Used when none of the user's acts has a mapping (e.g. affirm, repeat, silence).\n",
"    default_system_act: str = \"null\"\n",
"\n",
"    def __init__(self, dst: DialogueStateTracker):\n",
"        self.user_text = dst.text\n",
"        self.user_act = dst.act\n",
"        self.user_slots = dst.slots\n",
"        self.system_act = \"\"\n",
"\n",
"    @staticmethod\n",
"    def _as_act_list(user_act) -> list:\n",
"        \"\"\"Normalise a user act given as a string, an iterable of strings, or an NLU-like object.\"\"\"\n",
"        # Tolerate trackers that store the NLU object itself rather than its prediction.\n",
"        user_act = getattr(user_act, \"act\", user_act)\n",
"        if not user_act:\n",
"            return []\n",
"        if isinstance(user_act, str):\n",
"            return [user_act]\n",
"        return list(user_act)\n",
"\n",
"    def get_system_act(self):\n",
"        \"\"\"Set ``self.system_act`` from the first user act that has a mapping.\n",
"\n",
"        NLU may predict several acts; they are tried in order and the first one\n",
"        found in ``user_act_to_system_act_dict`` wins. If none is mapped,\n",
"        ``default_system_act`` is used instead of raising KeyError.\n",
"\n",
"        Returns:\n",
"            The chosen system act (also stored in ``self.system_act``).\n",
"        \"\"\"\n",
"        mapping = DialoguePolicy.user_act_to_system_act_dict\n",
"        self.system_act = next(\n",
"            (mapping[act] for act in self._as_act_list(self.user_act) if act in mapping),\n",
"            self.default_system_act,\n",
"        )\n",
"        return self.system_act\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'inform'"
|
|
]
|
|
},
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Choose the system response act for the tracked state.\n",
"dp = DialoguePolicy(dst=dst)\n",
"dp.get_system_act()\n",
"dp.system_act"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# NLG"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class NaturalLanguageGeneration:\n",
"    \"\"\"Natural-language-generation stage: renders the system act (and slots) as Polish text.\"\"\"\n",
"\n",
"    # System act -> response template (each act listed exactly once).\n",
"    system_act_to_text = {\n",
"        \"bye\": \"Do widzenia\",\n",
"        \"welcomemsg\": \"Witaj w systemie rezerwacji hotelowych. W czym mogę pomóc?\",\n",
"        \"inform\": \"Informuję cię o \",\n",
"        \"offer\": \"Co myślisz o hotelu z \",\n",
"        \"reqmore\": \"Czy mogę jeszcze jakoś Ci pomóc?\",\n",
"        \"null\": \"\"\n",
"    }\n",
"    # Templates that end mid-sentence and are completed with the user's slots;\n",
"    # the other templates are complete sentences and take no slots.\n",
"    acts_with_slots = {\"inform\", \"offer\"}\n",
"    # Slot name -> Polish phrase inserted into the templates above.\n",
"    user_slots_to_text = {\n",
"        \"people\": \"pojemności pokoju\",\n",
"        \"city\": \"mieście\",\n",
"        \"facilities\": \"udogodnieniach\",\n",
"        \"date\": \"dacie\",\n",
"        \"room\": \"pokoju\"\n",
"    }\n",
"\n",
"    def __init__(self, dp: DialoguePolicy):\n",
"        self.user_text = dp.user_text\n",
"        self.user_act = dp.user_act\n",
"        self.user_slots = dp.user_slots\n",
"        self.system_act = dp.system_act\n",
"        self.system_text = \"\"\n",
"\n",
"    def generate_system_text(self):\n",
"        \"\"\"Render ``self.system_act`` as text, appending slots where the template expects them.\n",
"\n",
"        Unknown system acts render as an empty string; unknown slots are skipped\n",
"        and repeated slots are mentioned only once. Slots are joined with \\\" i \\\".\n",
"\n",
"        Returns:\n",
"            The generated text (also stored in ``self.system_text``).\n",
"        \"\"\"\n",
"        text: str = self.system_act_to_text.get(self.system_act, \"\")\n",
"        if self.system_act in self.acts_with_slots:\n",
"            slots_transformed = [\n",
"                self.user_slots_to_text[slot]\n",
"                for slot in dict.fromkeys(self.user_slots)  # order-preserving dedupe\n",
"                if slot in self.user_slots_to_text\n",
"            ]\n",
"            text += \" i \".join(slots_transformed)\n",
"        # Drop the trailing space left by a slot template when no slots were found.\n",
"        self.system_text = text.strip()\n",
"        return self.system_text\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'Informuje cię o pokoju'"
|
|
]
|
|
},
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Render the chosen system act as the final response text.\n",
"nlg = NaturalLanguageGeneration(dp=dp)\n",
"nlg.generate_system_text()\n",
"nlg.system_text"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "SDenv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.2"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|