diff --git a/reguly.ipynb b/reguly.ipynb new file mode 100644 index 0000000..4ddb3a9 --- /dev/null +++ b/reguly.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 29, + "id": "706dd5e1-57ee-416b-a77c-5d15df8dbdc8", + "metadata": {}, + "outputs": [], + "source": [ + "from convlab.base_models.t5.nlu import T5NLU\n", + "import requests\n", + "\n", + "\n", + "def translate_text(text, target_language='en'):\n", + " url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(\n", + " target_language, text)\n", + " response = requests.get(url)\n", + " if response.status_code == 200:\n", + " translated_text = response.json()[0][0][0]\n", + " return translated_text\n", + " else:\n", + " return None\n", + "\n", + "\n", + "class NaturalLanguageAnalyzer: \n", + " def predict(self, text, context=None):\n", + " # Inicjalizacja modelu NLU\n", + " model_name = \"ConvLab/t5-small-nlu-multiwoz21\"\n", + " nlu_model = T5NLU(speaker='user', context_window_size=0, model_name_or_path=model_name)\n", + "\n", + " # Automatyczne tłumaczenie na język angielski\n", + " translated_input = translate_text(text)\n", + "\n", + " # Wygenerowanie odpowiedzi z modelu NLU\n", + " nlu_output = nlu_model.predict(translated_input)\n", + "\n", + " return nlu_output\n", + "\n", + " def init_session(self):\n", + " # Inicjalizacja sesji (jeśli konieczne)\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "423f0821-000a-4aaa-b400-2e7554866175", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'user_action': [],\n", + " 'system_action': [],\n", + " 'belief_state': {'attraction': {'type': '', 'name': '', 'area': ''},\n", + " 'hotel': {'name': '',\n", + " 'area': '',\n", + " 'parking': '',\n", + " 'price range': '',\n", + " 'stars': '4',\n", + " 'internet': 'yes',\n", + " 'type': 'hotel',\n", + " 'book stay': '',\n", + " 'book day': '',\n", + " 'book people': ''},\n", + " 'restaurant': {'food': '',\n", + " 'price range': '',\n", + " 'name': '',\n", + " 'area': '',\n", + " 'book time': '',\n", + " 'book day': '',\n", + " 'book people': ''},\n", + " 'taxi': {'leave at': '',\n", + " 'destination': '',\n", + " 'departure': '',\n", + " 'arrive by': ''},\n", + " 'train': {'leave at': '',\n", + " 'destination': '',\n", + " 'day': '',\n", + " 'arrive by': '',\n", + " 'departure': '',\n", + " 'book people': ''},\n", + " 'hospital': {'department': ''}},\n", + " 'booked': {},\n", + " 'request_state': {},\n", + " 'terminated': False,\n", + " 'history': []}" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from convlab.util.multiwoz.state import default_state\n", + "default_state()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "06926543-cab1-48e7-8e82-0560fc0fa16a", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from convlab.dst.dst import DST\n", + "from convlab.dst.rule.multiwoz.dst_util import normalize_value\n", + "\n", + "class SimpleRuleDST(DST):\n", + " def __init__(self):\n", + " DST.__init__(self)\n", + " self.state = default_state()\n", + " self.value_dict = json.load(open('value_dict.json'))\n", + "\n", + " def update(self, user_act=None):\n", + " for intent, domain, slot, value in user_act:\n", + " domain = domain.lower()\n", + " intent = intent.lower()\n", + " slot = slot.lower()\n", + " \n", + " if domain not in self.state['belief_state']:\n", + " continue\n", + "\n", + " if intent == 'inform':\n", + " if slot == 'none' or slot == '':\n", + " continue\n", + "\n", + " domain_dic = self.state['belief_state'][domain]\n", + "\n", + " if slot in domain_dic:\n", + " nvalue = normalize_value(self.value_dict, domain, slot, value)\n", + " self.state['belief_state'][domain][slot] = nvalue\n", + "\n", + " elif intent == 'request':\n", + " if domain not in self.state['request_state']:\n", + " self.state['request_state'][domain] = {}\n", + " if slot not in self.state['request_state'][domain]:\n", + " self.state['request_state'][domain][slot] = 0\n", + "\n", + " return self.state\n", + "\n", + " def init_session(self):\n", + " self.state = default_state()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1d42d5f-e923-4c46-a930-48da9b72d77b", + "metadata": {}, + "outputs": [], + "source": [ + "dst = SimpleRuleDST()\n", + "dst.state" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "749e3a90-17c3-4a3e-acd7-856560445eaf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'name': '',\n", + " 'area': '',\n", + " 'parking': 'yes',\n", + " 'price range': 'cheap',\n", + " 'stars': '4',\n", + " 'internet': 'yes',\n", + " 'type': 'hotel',\n", + " 'book stay': '',\n", + " 'book day': '',\n", + " 'book people': ''}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dst.update([['Inform', 'Hotel', 'Price Range', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])\n", + "dst.state['belief_state']['hotel']" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a7f3d067-3a95-4ef5-b216-be5840bc8831", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "import copy\n", + "import json\n", + "from copy import deepcopy\n", + "\n", + "from convlab.policy.policy import Policy\n", + "from convlab.util.multiwoz.dbquery import Database\n", + "\n", + "\n", + "class SimpleRulePolicy(Policy):\n", + " def __init__(self):\n", + " Policy.__init__(self)\n", + " self.db = Database()\n", + "\n", + " def predict(self, state):\n", + " self.results = []\n", + " system_action = defaultdict(list)\n", + " user_action = defaultdict(list)\n", + "\n", + " for intent, domain, slot, value in state['user_action']:\n", + " user_action[(domain.lower(), intent.lower())].append((slot.lower(), value))\n", + "\n", + " for user_act in user_action:\n", + " self.update_system_action(user_act, user_action, state, system_action)\n", + "\n", + " # Reguła 3\n", + " if any(True for slots in user_action.values() for (slot, _) in slots if slot in ['book stay', 'book day', 'book people']):\n", + " if self.results:\n", + " system_action = {('Booking', 'Book'): [[\"Ref\", self.results[0].get('Ref', 'N/A')]]}\n", + "\n", + " system_acts = [[intent, domain, slot, value] for (domain, intent), slots in system_action.items() for slot, value in slots]\n", + " state['system_action'] = system_acts\n", + " return system_acts\n", + "\n", + " def update_system_action(self, user_act, user_action, state, system_action):\n", + " domain, intent = user_act\n", + " constraints = [(slot, value) for slot, value in state['belief_state'][domain.lower()].items() if value != '']\n", + " self.results = deepcopy(self.db.query(domain.lower(), constraints))\n", + "\n", + " # Reguła 1\n", + " if intent == 'request':\n", + " if len(self.results) == 0:\n", + " system_action[(domain, 'NoOffer')] = []\n", + " else:\n", + " for slot in user_action[user_act]: \n", + " if slot[0] in self.results[0]:\n", + " system_action[(domain, 'Inform')].append([slot[0], self.results[0].get(slot[0], 'unknown')])\n", + "\n", + " # Reguła 2\n", + " elif intent == 'inform':\n", + " if len(self.results) == 0:\n", + " system_action[(domain, 'NoOffer')] = []\n", + " else:\n", + " system_action[(domain, 'Inform')].append(['Choice', str(len(self.results))])\n", + " choice = self.results[0]\n", + "\n", + " if domain in [\"hotel\", \"attraction\", \"police\", \"restaurant\"]:\n", + " system_action[(domain, 'Recommend')].append(['Name', choice['name']])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "089dbfa8-d34a-457c-9084-ef335372ea05", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:nlu info_dict is not initialized\n", + "WARNING:root:dst info_dict is not initialized\n", + "WARNING:root:policy info_dict is not initialized\n", + "WARNING:root:nlg info_dict is not initialized\n" + ] + } + ], + "source": [ + "from convlab.dialog_agent import PipelineAgent\n", + "dst.init_session()\n", + "policy = SimpleRulePolicy()\n", + "agent = PipelineAgent(nlu=None, dst=dst, policy=policy, nlg=None, name='sys')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5ac57cc8-6650-4a1b-a87e-2cda67d9b0f3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['Inform', 'hotel', 'Choice', '3'],\n", + " ['Recommend', 'hotel', 'Name', 'huntingdon marriott hotel']]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.response([['Inform', 'Hotel', 'Price Range', 'cheap'], ['Inform', 'Hotel', 'Parking', 'yes']])" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "eaeca7b0-08d5-4db0-9eb3-3aceda24f987", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:nlu info_dict is not initialized\n", + "WARNING:root:dst info_dict is not initialized\n", + "WARNING:root:policy info_dict is not initialized\n", + "WARNING:root:nlg info_dict is not initialized\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NLG seed 0\n" + ] + } + ], + "source": [ + "from convlab.base_models.t5.nlu import T5NLU\n", + "from convlab.nlg.template.multiwoz import TemplateNLG\n", + "\n", + "# nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')\n", + "nlu = NaturalLanguageAnalyzer()\n", + "nlg = TemplateNLG(is_user=False)\n", + "agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "b559fcd3-861b-49d7-ac2b-d3160d4c5a1d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'We have 3 such places . Would huntingdon marriott hotel work for you ?'" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.response(\"I need a cheap hotel with free parking .\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f831f56-10ba-40da-a89c-baeed37df81e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}