nlg

parent a2e2dd4e23
commit 0b2efaa3f4
@@ -1,15 +1,21 @@
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 import requests
+import os
+import time
 
+
 class MachineLearningNLG:
     def __init__(self):
         self.model_name = "./nlg_model"  # Path to the trained model
+        if not os.path.exists(self.model_name):
+            raise ValueError(
+                f"Path {self.model_name} does not exist. Make sure the model has been saved correctly.")
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
         self.generator = pipeline('text2text-generation', model=self.model, tokenizer=self.tokenizer)
 
     def translate_text(self, text, target_language='pl'):
-        url = 'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={}&dt=t&q={}'.format(target_language, text)
+        url = f'https://translate.googleapis.com/translate_a/single?client=gtx&sl=auto&tl={target_language}&dt=t&q={text}'
         response = requests.get(url)
         if response.status_code == 200:
             translated_text = response.json()[0][0][0]
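Note on the rewritten URL: the f-string still interpolates text straight into the query string, so spaces and special characters are not URL-encoded, and the unofficial endpoint returns the translation split into segments, of which only the first is read. A more defensive variant, sketched under the assumption that the same endpoint and query parameters are kept, could look like:

    # Sketch only: same unofficial translate endpoint as above, but requests
    # handles URL-encoding and all returned segments are joined.
    import requests

    def translate_text(text, target_language="pl"):
        url = "https://translate.googleapis.com/translate_a/single"
        params = {"client": "gtx", "sl": "auto", "tl": target_language, "dt": "t", "q": text}
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        # Element 0 of the response holds one entry per translated segment;
        # each entry's first field is the translated text.
        return "".join(segment[0] for segment in response.json()[0])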
@@ -19,12 +25,22 @@ class MachineLearningNLG:
 
     def nlg(self, system_act):
         input_text = f"generate text: {system_act}"
+        start_time = time.time()
         result = self.generator(input_text)
+        response_time = time.time() - start_time
         response = result[0]['generated_text']
         translated_response = self.translate_text(response, target_language='pl')
         return translated_response
 
+    def generate(self, action):
+        # We assume that 'action' is a formatted string that is passed to self.nlg
+        return self.nlg(action)
+
+    def init_session(self):
+        pass  # Add an empty init_session method
+
 
+# Usage example
 if __name__ == "__main__":
     nlg = MachineLearningNLG()
     system_act = "inform(date.from=15.07, date.to=22.07)"
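The new start_time / response_time pair measures generation latency inside nlg(), but the measured value is not returned or logged in the lines shown. If the intent is latency monitoring, a small helper along these lines could centralize it (the decorator and logger are assumptions, not part of this commit):

    # Hypothetical timing helper that could replace the inline bookkeeping in nlg().
    import functools
    import logging
    import time

    logger = logging.getLogger(__name__)

    def timed(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.time()
            result = fn(*args, **kwargs)
            logger.info("%s took %.3f s", fn.__name__, time.time() - start)
            return result
        return wrapper

Applying such a decorator to MachineLearningNLG.nlg would log each generation's duration instead of discarding it.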
Main.py (2 changed lines)
@ -9,7 +9,7 @@ if __name__ == "__main__":
|
|||||||
nlu = NaturalLanguageAnalyzer()
|
nlu = NaturalLanguageAnalyzer()
|
||||||
dst = DialogueStateTracker()
|
dst = DialogueStateTracker()
|
||||||
policy = DialoguePolicy()
|
policy = DialoguePolicy()
|
||||||
nlg = MachineLearningNLG() # Używamy nowego komponentu NLG
|
nlg = MachineLearningNLG()
|
||||||
|
|
||||||
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
||||||
response = agent.response(text)
|
response = agent.response(text)
|
||||||
|
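The generate(action) and init_session() methods added above appear to exist so MachineLearningNLG can be passed to PipelineAgent as its nlg component. Judging only from the calls visible in this commit, the expected shape of that component is roughly the following sketch (the base class is hypothetical, not taken from the PipelineAgent source):

    # Hypothetical outline of the NLG interface PipelineAgent seems to rely on.
    class BaseNLG:
        def generate(self, action):
            # Turn a system act such as "inform(date.from=15.07, date.to=22.07)"
            # into a natural-language response.
            raise NotImplementedError

        def init_session(self):
            # Reset any per-dialogue state when a new session starts.
            pass

MachineLearningNLG satisfies this by delegating generate() to nlg() and keeping init_session() as a no-op.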
@@ -52,7 +52,7 @@ training_args = Seq2SeqTrainingArguments(
     num_train_epochs=3,
     evaluation_strategy="epoch",
     save_strategy="epoch",
-    save_total_limit=1,
+    save_total_limit=None,  # Disable checkpoint rotation
     load_best_model_at_end=True,
 )
 
@@ -67,3 +67,7 @@ trainer = Seq2SeqTrainer(
 
 # Train the model
 trainer.train()
+
+# Save the trained model
+trainer.save_model("./nlg_model")
+tokenizer.save_pretrained("./nlg_model")
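Since trainer.save_model("./nlg_model") and tokenizer.save_pretrained("./nlg_model") write to the same directory that MachineLearningNLG.__init__ checks and loads from, the constructor should now find both the model weights and the tokenizer files. A minimal smoke test of the saved artifacts, reusing only paths and calls that already appear in this commit:

    # Smoke test sketch: confirm the saved model and tokenizer load and generate.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

    tokenizer = AutoTokenizer.from_pretrained("./nlg_model")
    model = AutoModelForSeq2SeqLM.from_pretrained("./nlg_model")
    generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

    print(generator("generate text: inform(date.from=15.07, date.to=22.07)")[0]["generated_text"])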