Compare commits
4 Commits
master
...
test_branc
Author | SHA1 | Date | |
---|---|---|---|
|
893733025b | ||
99306f0532 | |||
7a39f24359 | |||
|
7306fdb83a |
1
ConvLab-3
Submodule
1
ConvLab-3
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 60f4e5641f93e99b8d61b49cf5fd6dc818a83c4c
|
@ -33,15 +33,14 @@ class MachineLearningNLG:
|
|||||||
return translated_response
|
return translated_response
|
||||||
|
|
||||||
def generate(self, action):
|
def generate(self, action):
|
||||||
# Przyjmujemy, że 'action' jest formatowanym stringiem, który jest przekazywany do self.nlg
|
|
||||||
return self.nlg(action)
|
return self.nlg(action)
|
||||||
|
|
||||||
def init_session(self):
|
def init_session(self):
|
||||||
pass # Dodanie pustej metody init_session
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Przykład użycia
|
# Przykład użycia
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
nlg = MachineLearningNLG()
|
nlg = MachineLearningNLG()
|
||||||
system_act = "inform(date.from=15.07, date.to=22.07)"
|
system_act = "inform(hotel='Four Seasons Hotel')"
|
||||||
print(nlg.nlg(system_act))
|
print(nlg.nlg(system_act))
|
||||||
|
3
Main.py
3
Main.py
@ -3,13 +3,14 @@ from DialoguePolicy import DialoguePolicy
|
|||||||
from DialogueStateTracker import DialogueStateTracker
|
from DialogueStateTracker import DialogueStateTracker
|
||||||
from convlab.dialog_agent import PipelineAgent
|
from convlab.dialog_agent import PipelineAgent
|
||||||
from MachineLearningNLG import MachineLearningNLG # Importujemy nowy komponent NLG
|
from MachineLearningNLG import MachineLearningNLG # Importujemy nowy komponent NLG
|
||||||
|
from convlab.nlg.template.multiwoz import TemplateNLG
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
|
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
|
||||||
nlu = NaturalLanguageAnalyzer()
|
nlu = NaturalLanguageAnalyzer()
|
||||||
dst = DialogueStateTracker()
|
dst = DialogueStateTracker()
|
||||||
policy = DialoguePolicy()
|
policy = DialoguePolicy()
|
||||||
nlg = MachineLearningNLG()
|
nlg = TemplateNLG(is_user=False)
|
||||||
|
|
||||||
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
||||||
response = agent.response(text)
|
response = agent.response(text)
|
||||||
|
@ -11,7 +11,8 @@ for file_name in os.listdir(translated_data_directory):
|
|||||||
if file_name.endswith('.tsv'):
|
if file_name.endswith('.tsv'):
|
||||||
file_path = os.path.join(translated_data_directory, file_name)
|
file_path = os.path.join(translated_data_directory, file_name)
|
||||||
df = pd.read_csv(file_path, sep='\t')
|
df = pd.read_csv(file_path, sep='\t')
|
||||||
dfs.append(df)
|
df_user = df[df['role'] == 'system'].drop('role', axis=1)
|
||||||
|
dfs.append(df_user)
|
||||||
combined_df = pd.concat(dfs, ignore_index=True)
|
combined_df = pd.concat(dfs, ignore_index=True)
|
||||||
|
|
||||||
# Przygotowanie zbioru danych do trenowania
|
# Przygotowanie zbioru danych do trenowania
|
||||||
@ -49,7 +50,7 @@ training_args = Seq2SeqTrainingArguments(
|
|||||||
per_device_eval_batch_size=16,
|
per_device_eval_batch_size=16,
|
||||||
predict_with_generate=True,
|
predict_with_generate=True,
|
||||||
learning_rate=5e-5,
|
learning_rate=5e-5,
|
||||||
num_train_epochs=3,
|
num_train_epochs=10,
|
||||||
evaluation_strategy="epoch",
|
evaluation_strategy="epoch",
|
||||||
save_strategy="epoch",
|
save_strategy="epoch",
|
||||||
save_total_limit=None, # Wyłącz rotację punktów kontrolnych
|
save_total_limit=None, # Wyłącz rotację punktów kontrolnych
|
||||||
|
Loading…
Reference in New Issue
Block a user