From 6114ec26fd1a4c89c43d83ac1ff188def693d367 Mon Sep 17 00:00:00 2001 From: Krzysztof Bojakowski Date: Wed, 8 May 2024 01:59:30 +0200 Subject: [PATCH] Skrypt do trenowanie modelu w oparciu o frame oraz slots, wstepny skrypt do ewaluacji evaluate.py, skrypt do testow --- evaluate.py | 45 +++++++++++++++++++++++ nlu_tests.py | 30 ++++++++++++++++ nlu_train.py | 46 ++++++++++++++++++++++++ nlu_utils.py | 100 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 221 insertions(+) create mode 100644 evaluate.py create mode 100644 nlu_tests.py create mode 100644 nlu_train.py create mode 100644 nlu_utils.py diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..67e849d --- /dev/null +++ b/evaluate.py @@ -0,0 +1,45 @@ +import re +import os +import pandas as pd +import numpy as np +from nlu_utils import predict_multiple +from flair.models import SequenceTagger + +def __parse_acts(acts): + acts_split = acts.split('&') + remove_slot_regex = "[\(\[].*?[\)\]]" + return set(re.sub(remove_slot_regex, "", act) for act in acts_split) + +def __parse_predictions(predictions): + return set(prediction.split('/')[0] for prediction in predictions) + +# Exploratory tests +frame_model = SequenceTagger.load('frame-model-prod/best-model.pt') +# slot_model = SequenceTagger.load('slot-model-prod/final-model.pt') + +total_acts = 0 +act_correct_predictions = 0 +slot_correct_predictions = 0 + +for file_name in os.listdir('data'): + if file_name.split('.')[-1] != 'tsv': + continue + + df = pd.read_csv(f'data/{file_name}', sep='\t', names=['kto', 'treść', 'akt']) + df = df[df.kto == 'user'] + all_data = np.array(df) + + for row in all_data: + sentence = row[1] + acts = __parse_acts(row[2]) + + predictions_raw = predict_multiple(frame_model, sentence.split(), 'frame') + predictions = __parse_predictions(predictions_raw) + + for act in acts: + total_acts += 1 + if act in predictions: + act_correct_predictions += 1 + + +print(f"Accuracy - predicting acts: {(act_correct_predictions / total_acts)*100} ({act_correct_predictions}/{total_acts})") \ No newline at end of file diff --git a/nlu_tests.py b/nlu_tests.py new file mode 100644 index 0000000..a65da84 --- /dev/null +++ b/nlu_tests.py @@ -0,0 +1,30 @@ +from flair.models import SequenceTagger +from nlu_utils import predict_single, predict_multiple, predict_and_annotate + +# Exploratory tests +frame_model = SequenceTagger.load('frame-model/best-model.pt') +tests = [ + 'chciałbym zamówić pizzę', + 'na godzinę 12', + 'prosiłbym o pizzę z pieczarkami', + 'to wszystko, jaka cena?', + 'ile kosztuje pizza', + 'do widzenia', + 'tak', + 'nie dziękuję', + 'dodatkowy ser', + 'pizzę barcelona bez cebuli', +] + +# print("=== Exploratory tests - frame model ===") +for test in tests: + print(f"Sentence: {test}") + print(f"Single prediction: {predict_single(frame_model, test.split(), 'frame')}") + print(f"Multiple predictions: {predict_multiple(frame_model, test.split(), 'frame')}") + print(f"Annotated sentence: {predict_and_annotate(frame_model, test.split(), 'frame')}") + +print("=== Exploratory tests - slot model ===") +slot_model = SequenceTagger.load('slot-model/final-model.pt') +for test in tests: + print(f"Sentence: {test}") + print(f"Prediction: {predict_and_annotate(slot_model, test.split(), 'slot')}") \ No newline at end of file diff --git a/nlu_train.py b/nlu_train.py new file mode 100644 index 0000000..2018763 --- /dev/null +++ b/nlu_train.py @@ -0,0 +1,46 @@ +from conllu import parse_incr +from flair.data import Corpus +from flair.embeddings import StackedEmbeddings +from flair.embeddings import WordEmbeddings +from flair.embeddings import CharacterEmbeddings +from flair.embeddings import FlairEmbeddings +from flair.models import SequenceTagger +from flair.trainers import ModelTrainer +from nlu_utils import conllu2flair, nolabel2o + +import random +import torch +random.seed(42) +torch.manual_seed(42) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cudnn.enabled = False + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def train_model(label_type, field_parsers = {}): + with open('data/train_dialog.conllu', encoding='utf-8') as trainfile: + trainset = list(parse_incr(trainfile, fields=['id', 'form', 'frame', 'slot'], field_parsers=field_parsers)) + + corpus = Corpus(train=conllu2flair(trainset, label_type), test=conllu2flair(trainset, label_type)) + label_dictionary = corpus.make_label_dictionary(label_type=label_type) + + embedding_types = [ + WordEmbeddings('pl'), + FlairEmbeddings('pl-forward'), + FlairEmbeddings('pl-backward'), + CharacterEmbeddings(), + ] + + embeddings = StackedEmbeddings(embeddings=embedding_types) + tagger = SequenceTagger(hidden_size=256, embeddings=embeddings, tag_dictionary=label_dictionary, tag_type=label_type, use_crf=True, tag_format="BIO") + + frame_trainer = ModelTrainer(tagger, corpus) + frame_trainer.train(f'{label_type}-model', learning_rate=0.1, mini_batch_size=32, max_epochs=75, train_with_dev=False) + +if __name__ == '__main__': + train_model("frame") + train_model('slot', field_parsers={'slot': nolabel2o}) \ No newline at end of file diff --git a/nlu_utils.py b/nlu_utils.py new file mode 100644 index 0000000..ae5e51b --- /dev/null +++ b/nlu_utils.py @@ -0,0 +1,100 @@ +from flair.data import Sentence +from flair.datasets import FlairDatapointDataset + +def nolabel2o(line, i): + return 'O' if line[i] == 'NoLabel' else line[i] + +def conllu2flair(sentences, label=None): + if label == "frame": + return conllu2flair_frame(sentences, label) + else: + return conllu2flair_slot(sentences, label) + +def conllu2flair_frame(sentences, label=None): + fsentences = [] + for sentence in sentences: + tokens = [token["form"] for token in sentence] + fsentence = Sentence(' '.join(tokens), use_tokenizer=False) + + for i in range(len(fsentence)): + fsentence[i:i+1].add_label(label, sentence[i][label]) + + fsentences.append(fsentence) + + return FlairDatapointDataset(fsentences) + +def conllu2flair_slot(sentences, label=None): + fsentences = [] + + for sentence in sentences: + fsentence = Sentence(' '.join(token['form'] for token in sentence), use_tokenizer=False) + start_idx = None + end_idx = None + tag = None + + if label: + for idx, (token, ftoken) in enumerate(zip(sentence, fsentence)): + if token[label].startswith('B-'): + start_idx = idx + end_idx = idx + tag = token[label][2:] + elif token[label].startswith('I-'): + end_idx = idx + elif token[label] == 'O': + if start_idx is not None: + fsentence[start_idx:end_idx+1].add_label(label, tag) + start_idx = None + end_idx = None + tag = None + + if start_idx is not None: + fsentence[start_idx:end_idx+1].add_label(label, tag) + + fsentences.append(fsentence) + return FlairDatapointDataset(fsentences) + +def __predict(model, csentence): + fsentence = conllu2flair([csentence])[0] + model.predict(fsentence) + return fsentence + +def __csentence(sentence, label_type): + if label_type == "frame": + return [{'form': word } for word in sentence] + else: + return [{'form': word, 'slot': 'O'} for word in sentence] + +def predict_single(model, sentence, label_type): + csentence = __csentence(sentence, label_type) + fsentence = __predict(model, csentence) + intent = {} + + for span in fsentence.get_spans(label_type): + tag = span.get_label(label_type).value + if tag in intent: + intent[tag] += 1 + else: + intent[tag] = 1 + + return max(intent, key=intent.get) + +def predict_multiple(model, sentence, label_type): + csentence = __csentence(sentence, label_type) + fsentence = __predict(model, csentence) + + return set(span.get_label(label_type).value for span in fsentence.get_spans(label_type)) + +def predict_and_annotate(model, sentence, label_type): + csentence = __csentence(sentence, label_type) + fsentence = __predict(model, csentence) + + for span in fsentence.get_spans(label_type): + tag = span.get_label(label_type).value + if label_type == "frame": + csentence[span.tokens[0].idx-1]['frame'] = tag + else: + csentence[span.tokens[0].idx - 1]['slot'] = f'B-{tag}' + for token in span.tokens[1:]: + csentence[token.idx - 1]['slot'] = f'I-{tag}' + + return csentence \ No newline at end of file