GOATS/MachineLearningNLG.py

48 lines
1.9 KiB
Python
Raw Normal View History

2024-06-03 22:36:02 +02:00
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import requests
2024-06-04 00:38:14 +02:00
import os
import time
2024-06-03 22:36:02 +02:00
class MachineLearningNLG:
    """Natural-language generation backed by a local seq2seq model.

    A system act (e.g. ``"inform(date.from=15.07, date.to=22.07)"``) is fed to
    a Hugging Face ``text2text-generation`` pipeline. The generated text is
    then machine-translated (Polish by default, source language auto-detected)
    through Google's public ``translate_a/single`` endpoint.
    """

    # Public Google Translate endpoint. Query parameters are passed through
    # ``params=`` so that requests URL-encodes them (the text may contain
    # ``&``, ``#``, ``+`` or spaces).
    _TRANSLATE_URL = 'https://translate.googleapis.com/translate_a/single'

    def __init__(self, model_name="./nlg_model"):
        """Load the tokenizer, model and text2text-generation pipeline.

        Args:
            model_name: Path to the trained model. The default is a relative
                path, so it is resolved against the current working directory,
                not against this file's location.

        Raises:
            ValueError: If ``model_name`` does not exist on disk.
        """
        self.model_name = model_name  # path to the trained model
        if not os.path.exists(self.model_name):
            raise ValueError(
                f"Ścieżka {self.model_name} nie istnieje. Upewnij się, że model został poprawnie zapisany.")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.generator = pipeline('text2text-generation', model=self.model, tokenizer=self.tokenizer)

    @staticmethod
    def _extract_translation(payload):
        """Join the translated segments from a ``translate_a/single`` JSON reply.

        The endpoint splits longer input into segments. ``payload[0]`` is the
        list of segments, and each segment's first element is its translated
        text.

        Raises:
            IndexError, TypeError: If ``payload`` does not have that shape.
            ValueError: If no translated text could be extracted.
        """
        segments = payload[0]
        translated = ''.join(seg[0] for seg in segments if seg and seg[0])
        if not translated:
            raise ValueError("translation payload contained no text")
        return translated

    def translate_text(self, text, target_language='pl', timeout=10):
        """Translate ``text`` into ``target_language``.

        This is best-effort: on any network error, non-200 status or
        unexpected response body, the original ``text`` is returned unchanged.

        Args:
            text: Text to translate. Empty text is returned as-is without a
                network call.
            target_language: Target language code, sent as the ``tl`` parameter.
            timeout: Seconds to wait for the HTTP request, so a stalled
                server cannot block generation indefinitely.

        Returns:
            The translated text, or ``text`` itself if translation failed.
        """
        if not text:
            return text
        params = {
            'client': 'gtx',
            'sl': 'auto',
            'tl': target_language,
            'dt': 't',
            'q': text,
        }
        try:
            response = requests.get(self._TRANSLATE_URL, params=params, timeout=timeout)
        except requests.RequestException:
            return text  # fall back to the original text on network problems
        if response.status_code != 200:
            return text  # fall back to the original text on translation problems
        try:
            # ValueError also covers a JSON body that cannot be decoded.
            return self._extract_translation(response.json())
        except (ValueError, IndexError, TypeError):
            return text

    def nlg(self, system_act, target_language='pl'):
        """Generate text for ``system_act`` and translate it.

        Args:
            system_act: The system act; it is formatted into the prompt
                ``"generate text: {system_act}"``.
            target_language: Language code the generated text is translated to.

        Returns:
            The translated generated text (or the untranslated text if
            translation failed).
        """
        input_text = f"generate text: {system_act}"
        result = self.generator(input_text)
        response = result[0]['generated_text']
        return self.translate_text(response, target_language=target_language)

    def generate(self, action):
        """Delegate to :meth:`nlg`.

        ``action`` is assumed to be an already-formatted system-act string.
        """
        return self.nlg(action)

    def init_session(self):
        """Do nothing: this NLG keeps no per-session state.

        Presumably present because a calling framework expects every NLG
        component to expose ``init_session``; confirm against the caller.
        """
        pass
2024-06-03 22:36:02 +02:00
2024-06-04 00:38:14 +02:00
# Usage example: build the NLG component and render one sample system act.
if __name__ == "__main__":
    nlg_component = MachineLearningNLG()
    sample_act = "inform(date.from=15.07, date.to=22.07)"
    generated = nlg_component.nlg(sample_act)
    print(generated)