This commit is contained in:
Maciej Matusz 2024-06-04 00:38:14 +02:00
parent a2e2dd4e23
commit 0b2efaa3f4
3 changed files with 23 additions and 3 deletions

View File

@ -1,15 +1,21 @@
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import requests import requests
import os
import time
class MachineLearningNLG: class MachineLearningNLG:
def __init__(self): def __init__(self):
self.model_name = "./nlg_model" # Ścieżka do wytrenowanego modelu self.model_name = "./nlg_model" # Ścieżka do wytrenowanego modelu
if not os.path.exists(self.model_name):
raise ValueError(
f"Ścieżka {self.model_name} nie istnieje. Upewnij się, że model został poprawnie zapisany.")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name) self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
self.generator = pipeline('text2text-generation', model=self.model, tokenizer=self.tokenizer) self.generator = pipeline('text2text-generation', model=self.model, tokenizer=self.tokenizer)
def translate_text(self, text, target_language='pl'): def translate_text(self, text, target_language='pl'):
url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(target_language, text) url = f'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={target_language}&dt=t&q={text}'
response = requests.get(url) response = requests.get(url)
if response.status_code == 200: if response.status_code == 200:
translated_text = response.json()[0][0][0] translated_text = response.json()[0][0][0]
@ -19,12 +25,22 @@ class MachineLearningNLG:
def nlg(self, system_act): def nlg(self, system_act):
input_text = f"generate text: {system_act}" input_text = f"generate text: {system_act}"
start_time = time.time()
result = self.generator(input_text) result = self.generator(input_text)
response_time = time.time() - start_time
response = result[0]['generated_text'] response = result[0]['generated_text']
translated_response = self.translate_text(response, target_language='pl') translated_response = self.translate_text(response, target_language='pl')
return translated_response return translated_response
def generate(self, action):
# Przyjmujemy, że 'action' jest formatowanym stringiem, który jest przekazywany do self.nlg
return self.nlg(action)
def init_session(self):
pass # Dodanie pustej metody init_session
# Przykład użycia
if __name__ == "__main__": if __name__ == "__main__":
nlg = MachineLearningNLG() nlg = MachineLearningNLG()
system_act = "inform(date.from=15.07, date.to=22.07)" system_act = "inform(date.from=15.07, date.to=22.07)"

View File

@ -9,7 +9,7 @@ if __name__ == "__main__":
nlu = NaturalLanguageAnalyzer() nlu = NaturalLanguageAnalyzer()
dst = DialogueStateTracker() dst = DialogueStateTracker()
policy = DialoguePolicy() policy = DialoguePolicy()
nlg = MachineLearningNLG() # Używamy nowego komponentu NLG nlg = MachineLearningNLG()
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys') agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
response = agent.response(text) response = agent.response(text)

View File

@ -52,7 +52,7 @@ training_args = Seq2SeqTrainingArguments(
num_train_epochs=3, num_train_epochs=3,
evaluation_strategy="epoch", evaluation_strategy="epoch",
save_strategy="epoch", save_strategy="epoch",
save_total_limit=1, save_total_limit=None, # Wyłącz rotację punktów kontrolnych
load_best_model_at_end=True, load_best_model_at_end=True,
) )
@ -67,3 +67,7 @@ trainer = Seq2SeqTrainer(
# Trening modelu # Trening modelu
trainer.train() trainer.train()
# Zapisanie wytrenowanego modelu
trainer.save_model("./nlg_model")
tokenizer.save_pretrained("./nlg_model")