System_Dialogowy_Janet/Makiety.py

264 lines
8.3 KiB
Python
Raw Normal View History

2021-05-09 13:57:44 +02:00
import jsgf
2021-05-16 19:42:35 +02:00
import codecs
from conllu import parse_incr
from tabulate import tabulate
import os.path
from flair.data import Corpus, Sentence, Token
from flair.datasets import SentenceDataset
from flair.embeddings import StackedEmbeddings
from flair.embeddings import WordEmbeddings
from flair.embeddings import CharacterEmbeddings
from flair.embeddings import FlairEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
import random
import torch
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
class ML_NLU:
def __init__(self, acts, arguments):
self.acts = acts
self.arguments = arguments
self.model = self.setup()
2021-05-16 19:42:35 +02:00
def nolabel2o(self, line, i):
return 'O' if line[i] == 'NoLabel' else line[i]
def conllu2flair(self, sentences, label=None):
fsentences = []
for sentence in sentences:
fsentence = Sentence()
for token in sentence:
ftoken = Token(token['form'])
if label:
ftoken.add_tag(label, token[label])
fsentence.add_token(ftoken)
fsentences.append(fsentence)
return SentenceDataset(fsentences)
2021-05-09 13:57:44 +02:00
2021-04-19 17:03:06 +02:00
2021-05-16 19:42:35 +02:00
def predict(self, model, sentence):
csentence = [{'form': word} for word in sentence]
fsentence = self.conllu2flair([csentence])[0]
model.predict(fsentence)
return [(token, ftoken.get_tag('slot').value) for token, ftoken in zip(sentence, fsentence)]
def setup(self):
if os.path.isfile('slot-model/final-model.pt'):
model = SequenceTagger.load('slot-model/final-model.pt')
else:
fields = ['id', 'form', 'frame', 'slot']
with open('Janet.conllu', encoding='utf-8') as trainfile:
2021-05-16 19:42:35 +02:00
trainset = list(parse_incr(trainfile, fields=fields, field_parsers={'slot': self.nolabel2o}))
with open('Janet.conllu', encoding='utf-8') as testfile:
2021-05-16 19:42:35 +02:00
testset = list(parse_incr(testfile, fields=fields, field_parsers={'slot': self.nolabel2o}))
tabulate(trainset[0], tablefmt='html')
corpus = Corpus(train=self.conllu2flair(trainset, 'slot'), test=self.conllu2flair(testset, 'slot'))
tag_dictionary = corpus.make_tag_dictionary(tag_type='slot')
embedding_types = [
WordEmbeddings('pl'),
FlairEmbeddings('pl-forward'),
FlairEmbeddings('pl-backward'),
CharacterEmbeddings(),
]
embeddings = StackedEmbeddings(embeddings=embedding_types)
tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
tag_dictionary=tag_dictionary,
tag_type='slot', use_crf=True)
trainer = ModelTrainer(tagger, corpus)
trainer.train('slot-model',
learning_rate=0.1,
mini_batch_size=32,
max_epochs=10,
train_with_dev=False)
model = SequenceTagger.load('slot-model/final-model.pt')
return model
def test_nlu(self, utterance):
if utterance:
2021-05-16 21:49:53 +02:00
return tabulate(self.predict(self.model, utterance.split()), tablefmt='tsv')
2021-05-16 19:42:35 +02:00
else:
return 'Critical Error'
2021-04-19 17:03:06 +02:00
2021-05-16 19:42:35 +02:00
class Book_NLU: #Natural Language Understanding
2021-04-19 17:03:06 +02:00
"""
Moduł odpowiedzialny za analizę tekstu. W wyniku jego działania tekstowa reprezentacja wypowiedzi użytkownika zostaje zamieniona na jej reprezentację semantyczną, najczęściej w postaci ramy.
Wejście: Tekst
Wyjście: Akt użytkownika (rama)
"""
2021-05-09 13:57:44 +02:00
def __init__(self, acts, arguments, book_grammar):
2021-04-25 11:44:21 +02:00
self.acts = acts
self.arguments = arguments
2021-05-09 13:57:44 +02:00
self.book_grammar = book_grammar
2021-05-09 14:19:01 +02:00
def get_dialog_act(self, rule):
2021-05-09 13:57:44 +02:00
slots = []
2021-05-09 14:19:01 +02:00
self.get_slots(rule.expansion, slots)
2021-05-09 13:57:44 +02:00
return {'act': rule.grammar.name, 'slots': slots}
2021-05-09 14:19:01 +02:00
def get_slots(self, expansion, slots):
2021-05-09 13:57:44 +02:00
if expansion.tag != '':
slots.append((expansion.tag, expansion.current_match))
return
for child in expansion.children:
2021-05-09 14:19:01 +02:00
self.get_slots(child, slots)
2021-05-09 13:57:44 +02:00
if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef):
2021-05-09 14:19:01 +02:00
self.get_slots(expansion.referenced_rule.expansion, slots)
2021-04-25 11:44:21 +02:00
def analyze(self, text):
"""
Analiza Tekstu wprowadzonego przez użytkownika i zamiana na akt (rama)
"""
2021-04-19 17:03:06 +02:00
print("Analiza Tekstu: " + text)
2021-04-25 11:44:21 +02:00
act = "(greetings()&request(name))"
print("Akt to: " + act)
2021-04-19 17:03:06 +02:00
#przerobienie na wektor
2021-04-25 11:44:21 +02:00
act_vector = [[0],[1,0]] #1 wektor to greetings, a 2 wektor to request z argumentem "name"
2021-04-19 17:03:06 +02:00
print("Zamiana na: ")
2021-04-25 11:44:21 +02:00
print(act_vector)
return act_vector
2021-04-19 17:03:06 +02:00
2021-05-09 14:19:01 +02:00
def test_nlu(self, utterance):
matched = self.book_grammar.find_matching_rules(utterance)
print(matched)
2021-05-09 13:57:44 +02:00
if matched:
2021-05-09 14:19:01 +02:00
return self.get_dialog_act(matched[0])
2021-05-09 13:57:44 +02:00
else:
return {'act': 'null', 'slots': []}
2021-04-19 17:03:06 +02:00
class DST: #Dialogue State Tracker
"""
Moduł odpowiedzialny za śledzenie stanu dialogu. Przechowuje informacje o tym jakie dane zostały uzyskane od użytkownika w toku prowadzonej konwersacji.
Wejście: Akt użytkownika (rama)
Wyjście: Reprezentacja stanu dialogu (rama)
"""
2021-04-25 11:44:21 +02:00
def __init__(self, acts, arguments):
self.acts = acts
self.arguments = arguments
self.frame_list= []
2021-04-19 17:03:06 +02:00
2021-04-25 11:44:21 +02:00
def store(self, rama):
"""
Dodanie nowego aktu do listy
"""
2021-04-19 17:03:06 +02:00
print("\nDodanie do listy nowej ramy: ")
print(rama)
2021-04-25 11:44:21 +02:00
self.frame_list.append(rama)
2021-04-19 17:03:06 +02:00
2021-04-25 11:44:21 +02:00
def transfer(self):
2021-04-19 17:03:06 +02:00
print("Przekazanie dalej listy ram: ")
2021-04-25 11:44:21 +02:00
print(self.frame_list)
2021-04-19 17:03:06 +02:00
return self.frame_list
class DP:
"""
Moduł decydujący o wyborze kolejnego aktu, który ma podjąć system prowadząc rozmowę.
Wejście: Reprezentacja stanu dialogu (rama)
Wyjście: Akt systemu (rama)
"""
2021-04-25 11:44:21 +02:00
def __init__(self, acts, arguments):
self.acts = acts
self.arguments = arguments
2021-04-19 17:03:06 +02:00
2021-04-25 11:44:21 +02:00
def choose_tactic(self, frame_list):
"""
Obieranie taktyki na podstawie aktów usera. Bardzo ważna jest kolejność dodawanych do frame_list wartości.
"""
act_vector = [0, 0]
return act_vector
2021-04-19 17:03:06 +02:00
class NLG:
"""
Moduł, który tworzy reprezentację tekstową aktu systemowego wybranego przez taktykę dialogu.
Wejście: Akt systemu (rama)
Wyjście: Tekst
"""
2021-04-25 11:44:21 +02:00
def __init__(self, acts, arguments):
self.acts = acts
self.arguments = arguments
def change_to_text(self, act_vector):
"""
Funkcja zamieniająca akt systemu na tekst rozumiany przez użytkownika.
"""
if(act_vector == [0, 0]):
return "Cześć, mam na imię Janet"
return "Nie rozumiem"
class Janet:
def __init__(self):
self.acts={
0: "greetings",
1: "request",
}
self.arguments={
0: "name"
}
self.nlg = NLG(self.acts, self.arguments)
self.dp = DP(self.acts, self.arguments)
self.dst = DST(self.acts, self.arguments)
2021-05-16 19:42:35 +02:00
self.nlu = Book_NLU(self.acts, self.arguments, jsgf.parse_grammar_file('book.jsgf'))
self.nlu_v2 = ML_NLU(self.acts, self.arguments)
2021-04-25 11:44:21 +02:00
2021-05-09 13:57:44 +02:00
def test(self, command):
2021-05-16 19:42:35 +02:00
out = self.nlu_v2.test_nlu(command)
2021-05-09 14:19:01 +02:00
return out
2021-04-25 11:44:21 +02:00
def process(self, command):
act = self.nlu.analyze(command)
self.dst.store(act)
dest_act = self.dp.choose_tactic(self.dst.transfer())
return self.nlg.change_to_text(dest_act)
janet = Janet()
print(janet.test('chciałbym się umówić na wizytę do Piotra Pająka na jutro')) #Testowy print na start
2021-04-25 11:44:21 +02:00
while(1):
print('\n')
2021-04-25 11:44:21 +02:00
text = input("Wpisz tekst: ")
2021-05-09 13:57:44 +02:00
print(janet.test(text))