Finish chatbot

This commit is contained in:
PawelDopierala 2024-06-07 00:38:21 +02:00
parent 7bb36b5bac
commit e24d1894a3
5 changed files with 38 additions and 28 deletions

View File

@ -7,6 +7,7 @@ db_path = './hotels_data.json'
class DialoguePolicy(Policy): class DialoguePolicy(Policy):
info_dict = None
def __init__(self): def __init__(self):
Policy.__init__(self) Policy.__init__(self)
self.db = self.load_database(db_path) self.db = self.load_database(db_path)

View File

@ -32,6 +32,7 @@ def default_state():
class DialogueStateTracker(DST): class DialogueStateTracker(DST):
info_dict = None
def __init__(self): def __init__(self):
DST.__init__(self) DST.__init__(self)
self.state = default_state() self.state = default_state()

27
Main.py
View File

@ -1,31 +1,22 @@
import requests
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
from DialoguePolicy import DialoguePolicy from DialoguePolicy import DialoguePolicy
from DialogueStateTracker import DialogueStateTracker from DialogueStateTracker import DialogueStateTracker
from NaturalLanguageGeneration import NaturalLanguageGeneration
from convlab.dialog_agent import PipelineAgent from convlab.dialog_agent import PipelineAgent
from MachineLearningNLG import MachineLearningNLG # Importujemy nowy komponent NLG import warnings
warnings.filterwarnings("ignore")
def translate_text(text, target_language='pl'):
url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(
target_language, text)
response = requests.get(url)
if response.status_code == 200:
translated_text = response.json()[0]
translated_text_joined = ''.join([sentence[0] for sentence in translated_text])
return translated_text_joined
else:
return None
if __name__ == "__main__": if __name__ == "__main__":
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
nlu = NaturalLanguageAnalyzer() nlu = NaturalLanguageAnalyzer()
dst = DialogueStateTracker() dst = DialogueStateTracker()
policy = DialoguePolicy() policy = DialoguePolicy()
nlg = MachineLearningNLG() nlg = NaturalLanguageGeneration()
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys') agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
print("Witam, jestem systemem do rezerwowania pokoi hotelowych. W czym mogę Ci pomóc?")
while True:
text = input()
response = agent.response(text) response = agent.response(text)
print(translate_text(response)) print(response)

View File

@ -14,6 +14,7 @@ def translate_text(text, target_language='en'):
class NaturalLanguageAnalyzer: class NaturalLanguageAnalyzer:
info_dict = None
def predict(self, text, context=None): def predict(self, text, context=None):
# Inicjalizacja modelu NLU # Inicjalizacja modelu NLU
model_name = "ConvLab/t5-small-nlu-multiwoz21" model_name = "ConvLab/t5-small-nlu-multiwoz21"

View File

@ -1,12 +1,28 @@
import re import requests
from convlab.nlg.template.multiwoz import TemplateNLG
def translate_text(text, target_language='pl'):
url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(
target_language, text)
response = requests.get(url)
if response.status_code == 200:
translated_text = response.json()[0]
translated_text_joined = ''.join([sentence[0] for sentence in translated_text])
return translated_text_joined
else:
return None
class NaturalLanguageGeneration: class NaturalLanguageGeneration:
info_dict = None
def generate(self, system_act):
if len(system_act) == 0:
return "Nie rozumiem."
tnlg = TemplateNLG(is_user=False)
response_en = tnlg.generate(system_act)
return translate_text(response_en)
def nlg(self, system_act):
response = None def init_session(self):
pattern = r'inform\(name=([^\)]+)\)' pass # Dodanie pustej metody init_session
matching = re.search(pattern, system_act)
if matching:
name = matching.group(1)
response = f"Witaj, nazywam się {name}"
return response