diff --git a/MachineLearningNLG.py b/MachineLearningNLG.py
index 4fd6090..f9434dd 100644
--- a/MachineLearningNLG.py
+++ b/MachineLearningNLG.py
@@ -1,15 +1,21 @@
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 import requests
+import os
+import time
+
 
 class MachineLearningNLG:
     def __init__(self):
         self.model_name = "./nlg_model"  # Path to the trained model
+        if not os.path.exists(self.model_name):
+            raise ValueError(
+                f"Path {self.model_name} does not exist. Make sure the model was saved correctly.")
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
         self.generator = pipeline('text2text-generation', model=self.model, tokenizer=self.tokenizer)
 
     def translate_text(self, text, target_language='pl'):
-        url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(target_language, text)
+        url = f'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={target_language}&dt=t&q={text}'
         response = requests.get(url)
         if response.status_code == 200:
             translated_text = response.json()[0][0][0]
@@ -19,12 +25,22 @@ class MachineLearningNLG:
 
     def nlg(self, system_act):
         input_text = f"generate text: {system_act}"
+        start_time = time.time()
         result = self.generator(input_text)
+        response_time = time.time() - start_time  # generation latency (currently unused)
         response = result[0]['generated_text']
         translated_response = self.translate_text(response, target_language='pl')
         return translated_response
 
+    def generate(self, action):
+        # We assume 'action' is a formatted string that is passed on to self.nlg
+        return self.nlg(action)
+
+    def init_session(self):
+        pass  # Add an empty init_session method
+
+
+# Usage example
 if __name__ == "__main__":
     nlg = MachineLearningNLG()
     system_act = "inform(date.from=15.07, date.to=22.07)"
diff --git a/Main.py b/Main.py
index 077c2b0..2575c44 100644
--- a/Main.py
+++ b/Main.py
@@ -9,7 +9,7 @@ if __name__ == "__main__":
     nlu = NaturalLanguageAnalyzer()
     dst = DialogueStateTracker()
     policy = DialoguePolicy()
-    nlg = MachineLearningNLG()  # Use the new NLG component
+    nlg = MachineLearningNLG()
 
     agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
     response = agent.response(text)
diff --git a/train_nlg.py b/train_nlg.py
index 00d2269..032083c 100644
--- a/train_nlg.py
+++ b/train_nlg.py
@@ -52,7 +52,7 @@ training_args = Seq2SeqTrainingArguments(
     num_train_epochs=3,
     evaluation_strategy="epoch",
     save_strategy="epoch",
-    save_total_limit=1,
+    save_total_limit=None,  # Disable checkpoint rotation
     load_best_model_at_end=True,
 )
 
@@ -67,3 +67,7 @@ trainer = Seq2SeqTrainer(
 
 # Train the model
 trainer.train()
+
+# Save the trained model
+trainer.save_model("./nlg_model")
+tokenizer.save_pretrained("./nlg_model")
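
Note on translate_text: the rewritten f-string interpolates the raw text straight
into the query string, so inputs containing spaces, '&', or '#' will corrupt the
request. A minimal sketch of a safer variant, letting requests percent-encode the
parameters; the endpoint and response parsing are taken from the diff, and the
fallback return value on failure is an assumption:

    import requests

    def translate_text(text, target_language='pl'):
        url = 'https://translate.googleapis.com/translate_a/single'
        params = {'client': 'gtx', 'sl': 'auto', 'tl': target_language,
                  'dt': 't', 'q': text}
        # requests builds and percent-encodes the query string from params
        response = requests.get(url, params=params)
        if response.status_code == 200:
            return response.json()[0][0][0]
        return text  # assumption: fall back to the untranslated text on failure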
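
Since train_nlg.py now persists the model to ./nlg_model and
MachineLearningNLG.__init__ raises ValueError when that path is missing, a quick
smoke test can confirm the round trip. A minimal sketch, assuming training has
already been run and ./nlg_model exists; it uses only the calls already present
in the diff:

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

    model_dir = "./nlg_model"  # path written by trainer.save_model / tokenizer.save_pretrained
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
    generator = pipeline('text2text-generation', model=model, tokenizer=tokenizer)

    # Same input format the nlg() method builds: "generate text: <system_act>"
    result = generator("generate text: inform(date.from=15.07, date.to=22.07)")
    print(result[0]['generated_text'])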