Compare commits

...

2 Commits

Author SHA1 Message Date
PawelDopierala aca0beb345 Merge remote-tracking branch 'origin/master' 2024-05-08 23:36:13 +02:00
PawelDopierala 95a7cd4305 do NaturalLanguageAnalyzer.py and evaluate.py 2024-05-08 23:36:04 +02:00
5 changed files with 85 additions and 32 deletions

View File

@ -1,35 +1,28 @@
import jsgf
from convlab.base_models.t5.nlu import T5NLU
import requests
def translate_text(text, target_language='en'):
url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(
target_language, text)
response = requests.get(url)
if response.status_code == 200:
translated_text = response.json()[0][0][0]
return translated_text
else:
return None
class NaturalLanguageAnalyzer:
# def process(self, text):
# user_act = None
# if ("imie" in text or "imię" in text) and "?" in text:
# user_act = "request(firstname)"
# return user_act
def process(self, text):
with open('grammar_1.jsgf', 'r', encoding='utf-8') as f:
content = f.read()
book_grammar = jsgf.parse_grammar_string(content)
matched = book_grammar.find_matching_rules(text)
if matched:
return self.get_dialog_act(matched[0])
else:
return {'act': 'null', 'slots': []}
# Inicjalizacja modelu NLU
model_name = "ConvLab/t5-small-nlu-multiwoz21"
nlu_model = T5NLU(speaker='user', context_window_size=0, model_name_or_path=model_name)
def get_slots(self, expansion, slots):
if expansion.tag != '':
slots.append((expansion.tag, expansion.current_match))
return
# Automatyczne tłumaczenie na język angielski
translated_input = translate_text(text)
for child in expansion.children:
self.get_slots(child, slots)
# Wygenerowanie odpowiedzi z modelu NLU
nlu_output = nlu_model.predict(translated_input)
if not expansion.children and isinstance(expansion, jsgf.NamedRuleRef):
self.get_slots(expansion.referenced_rule.expansion, slots)
def get_dialog_act(self, rule):
slots = []
self.get_slots(rule.expansion, slots)
return {'act': rule.grammar.name, 'slots': slots}
return nlu_output

29
archives/iobes_slot.py Normal file
View File

@ -0,0 +1,29 @@
import pandas as pd
# Wczytanie danych z pliku TSV
data = pd.read_csv("combined_df.tsv", sep="\t")
# Inicjalizacja pustej listy do przechowywania tagów IOBES
tags = []
# Iteracja po każdym wierszu danych
for index, row in data.iterrows():
# Podział akcji na pojedyncze słowa
words = row['act'].split()
# Początkowy tag IOBES to 'O' dla każdego słowa
current_tags = ['O'] * len(words)
# Ustawienie tagu 'B' dla pierwszego słowa
current_tags[0] = 'B'
# Ustawienie tagu 'E' dla ostatniego słowa
current_tags[-1] = 'E'
# Ustawienie tagu 'I' dla pozostałych słów, jeśli są
if len(words) > 2:
current_tags[1:-1] = ['I'] * (len(words) - 2)
# Dodanie tagów do listy
tags.extend(current_tags)
# Dodanie kolumny z tagami do danych
data['tags'] = tags
# Zapisanie danych z tagami do nowego pliku TSV
data.to_csv("nazwa_pliku_z_tagami.tsv", sep="\t", index=False)

View File

@ -1,5 +1,7 @@
import os
import pandas as pd
import re
from NaturalLanguageAnalyzer import NaturalLanguageAnalyzer
data_directory = 'data'
@ -14,9 +16,38 @@ for file_name in file_list:
combined_df = pd.concat(dfs, ignore_index=True)
for text, act in zip(combined_df["value"].values, combined_df["act"].values):
change_act_format = {
"thankyou": "thank",
"bye": "thank",
"hello": "greet",
"inform": "inform",
"request": "request",
"reqmore": "request"
}
correct = 0
incorrect = 0
for text, ground_act in zip(combined_df["value"].values, combined_df["act"].values):
nla = NaturalLanguageAnalyzer()
user_act = nla.process(text)
print(user_act)
print(act)
nla_output = nla.process(text)
predicted_act = set([i[0] for i in nla_output])
pattern = re.compile(r'([^(&]+)(?=\()')
matches = re.findall(pattern, ground_act)
ground_act_processed = set()
for match in matches:
if match in change_act_format:
ground_act_processed.add(change_act_format[match])
for i in ground_act_processed:
if i in predicted_act:
correct += 1
else:
incorrect += 1
print("Predicted:", predicted_act)
print("Ground truth:", ground_act_processed)
print()
accuracy = correct/(correct+incorrect)
print("Accuracy: ", accuracy)