class ML_NLU:
    """Slot-filling NLU backed by a flair ``SequenceTagger``.

    The tagger assigns a ``slot`` label to every token of an utterance.
    A trained checkpoint is loaded from ``MODEL_PATH`` when it exists;
    otherwise a new model is trained from the CoNLL-U files below and the
    resulting checkpoint is loaded.
    """

    # Output directory handed to the flair trainer, and the checkpoint that
    # is loaded back from it.
    MODEL_DIR = 'slot-model'
    MODEL_PATH = 'slot-model/final-model.pt'

    # NOTE(review): train and test data are read from the same file, so the
    # evaluation is not held out. Kept as in the original because the
    # intended training file name is not visible here -- confirm and split.
    TRAIN_FILE = 'Janet_test.conllu'
    TEST_FILE = 'Janet_test.conllu'

    # Column layout of the CoNLL-U files.
    CONLLU_FIELDS = ['id', 'form', 'frame', 'slot']

    def __init__(self, acts, arguments):
        """Store the dialogue acts and arguments shared by the system modules.

        Args:
            acts: Dialogue acts known to the system (stored, not inspected).
            arguments: Act arguments known to the system (stored, not
                inspected).
        """
        self.acts = acts
        self.arguments = arguments
        # Loaded lazily by setup() and reused across calls.
        self._model = None

    def nolabel2o(self, line, i):
        """Field parser for the ``slot`` column: map ``NoLabel`` to ``O``.

        Args:
            line: The split columns of one CoNLL-U row.
            i: Index of the column being parsed.

        Returns:
            ``'O'`` when the column holds ``'NoLabel'``, otherwise the
            column value unchanged.
        """
        value = line[i]
        return 'O' if value == 'NoLabel' else value

    def conllu2flair(self, sentences, label=None):
        """Convert parsed CoNLL-U sentences into a flair ``SentenceDataset``.

        Args:
            sentences: Iterable of sentences, each an iterable of token
                mappings with at least a ``'form'`` key.
            label: Optional name of the token field to copy onto each flair
                token as a tag of the same name.

        Returns:
            A ``SentenceDataset`` with one flair ``Sentence`` per input
            sentence.
        """
        fsentences = []
        for sentence in sentences:
            fsentence = Sentence()
            for token in sentence:
                ftoken = Token(token['form'])
                if label:
                    ftoken.add_tag(label, token[label])
                fsentence.add_token(ftoken)
            fsentences.append(fsentence)
        return SentenceDataset(fsentences)

    def predict(self, model, sentence):
        """Tag a tokenized sentence with slot labels.

        Args:
            model: A loaded ``SequenceTagger``.
            sentence: List of word strings.

        Returns:
            A list of ``(word, slot_label)`` pairs, in input order.
        """
        csentence = [{'form': word} for word in sentence]
        fsentence = self.conllu2flair([csentence])[0]
        model.predict(fsentence)
        return [
            (word, ftoken.get_tag('slot').value)
            for word, ftoken in zip(sentence, fsentence)
        ]

    def _load_conllu(self, path):
        """Parse a CoNLL-U file into a list, normalising the slot column."""
        with open(path, encoding='utf-8') as handle:
            return list(parse_incr(
                handle,
                fields=self.CONLLU_FIELDS,
                field_parsers={'slot': self.nolabel2o},
            ))

    def _train(self):
        """Train a slot tagger and let the trainer write it to ``MODEL_DIR``."""
        trainset = self._load_conllu(self.TRAIN_FILE)
        testset = self._load_conllu(self.TEST_FILE)

        corpus = Corpus(
            train=self.conllu2flair(trainset, 'slot'),
            test=self.conllu2flair(testset, 'slot'),
        )
        tag_dictionary = corpus.make_tag_dictionary(tag_type='slot')

        embeddings = StackedEmbeddings(embeddings=[
            WordEmbeddings('pl'),
            FlairEmbeddings('pl-forward'),
            FlairEmbeddings('pl-backward'),
            CharacterEmbeddings(),
        ])
        tagger = SequenceTagger(
            hidden_size=256,
            embeddings=embeddings,
            tag_dictionary=tag_dictionary,
            tag_type='slot',
            use_crf=True,
        )

        ModelTrainer(tagger, corpus).train(
            self.MODEL_DIR,
            learning_rate=0.1,
            mini_batch_size=32,
            max_epochs=10,
            train_with_dev=False,
        )

    def setup(self):
        """Return the slot tagger, training it first if no checkpoint exists.

        The model is loaded once per instance and then reused, so repeated
        calls do not re-read the checkpoint from disk.

        Returns:
            The loaded ``SequenceTagger``.
        """
        if self._model is None:
            if not os.path.isfile(self.MODEL_PATH):
                self._train()
            self._model = SequenceTagger.load(self.MODEL_PATH)
        return self._model

    def test_nlu(self, utterance):
        """Tag ``utterance`` and render the result as an HTML table.

        Args:
            utterance: Raw user text; it is split on whitespace into tokens.

        Returns:
            An HTML table of ``(word, slot_label)`` rows, or the string
            ``'Critical Error'`` when the utterance is empty or contains
            only whitespace.
        """
        tokens = utterance.split() if utterance else []
        # Check the input before touching the model: loading (or training)
        # it is expensive and pointless when there is nothing to tag.
        if not tokens:
            return 'Critical Error'

        model = self.setup()
        return tabulate(self.predict(model, tokens), tablefmt='html')
@@ -136,10 +244,11 @@ class Janet: self.nlg = NLG(self.acts, self.arguments) self.dp = DP(self.acts, self.arguments) self.dst = DST(self.acts, self.arguments) - self.nlu = NLU(self.acts, self.arguments, jsgf.parse_grammar_file('book.jsgf')) + self.nlu = Book_NLU(self.acts, self.arguments, jsgf.parse_grammar_file('book.jsgf')) + self.nlu_v2 = ML_NLU(self.acts, self.arguments) def test(self, command): - out = self.nlu.test_nlu(command) + out = self.nlu_v2.test_nlu(command) return out def process(self, command):