Compare commits

...

4 Commits

Author          SHA1        Message                                                              Date
PawelDopierala  893733025b  Change to TemplateNLG                                                2024-06-04 13:07:41 +02:00
PawelDopierala  99306f0532  Merge remote-tracking branch 'origin/test_branch' into test_branch  2024-06-04 11:44:58 +02:00
PawelDopierala  7a39f24359  test_learning                                                        2024-06-04 11:43:14 +02:00
Maciej Matusz   7306fdb83a  test- commit                                                         2024-06-04 10:56:40 +02:00
4 changed files with 8 additions and 6 deletions

ConvLab-3 Submodule

@@ -0,0 +1 @@
+Subproject commit 60f4e5641f93e99b8d61b49cf5fd6dc818a83c4c

View File

@@ -33,15 +33,14 @@ class MachineLearningNLG:
        return translated_response

    def generate(self, action):
        # We assume that 'action' is a formatted string which is passed on to self.nlg
        return self.nlg(action)

    def init_session(self):
-        pass  # Added an empty init_session method
+        pass

# Example usage
if __name__ == "__main__":
    nlg = MachineLearningNLG()
-    system_act = "inform(date.from=15.07, date.to=22.07)"
+    system_act = "inform(hotel='Four Seasons Hotel')"
    print(nlg.nlg(system_act))
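Since the diff only shows a fragment of MachineLearningNLG, here is a minimal, self-contained sketch of the interface it appears to expose. The act-string parsing and the canned reply are illustrative assumptions, not the actual model-backed nlg() implementation.

# Minimal sketch of the NLG interface used above. The real nlg() method
# (model loading / translation) is not shown in this diff; the parsing and
# the hard-coded answer below are assumptions for illustration only.
class MachineLearningNLGSketch:
    def nlg(self, system_act: str) -> str:
        # Hypothetical handling of an act string such as
        # "inform(hotel='Four Seasons Hotel')".
        intent, _, args = system_act.partition("(")
        slots = dict(
            part.split("=", 1) for part in args.rstrip(")").split(", ") if "=" in part
        )
        if intent == "inform" and "hotel" in slots:
            hotel = slots["hotel"].strip("'")
            return f"I can recommend {hotel}."
        return "Sorry, I did not understand the request."

    def generate(self, action):
        # Mirrors the class above: generate() simply delegates to nlg().
        return self.nlg(action)

    def init_session(self):
        pass

if __name__ == "__main__":
    sketch = MachineLearningNLGSketch()
    print(sketch.generate("inform(hotel='Four Seasons Hotel')"))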

View File

@@ -3,13 +3,14 @@ from DialoguePolicy import DialoguePolicy
from DialogueStateTracker import DialogueStateTracker
from convlab.dialog_agent import PipelineAgent
from MachineLearningNLG import MachineLearningNLG  # Import the new NLG component
+from convlab.nlg.template.multiwoz import TemplateNLG

if __name__ == "__main__":
    text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
    nlu = NaturalLanguageAnalyzer()
    dst = DialogueStateTracker()
    policy = DialoguePolicy()
-    nlg = MachineLearningNLG()
+    nlg = TemplateNLG(is_user=False)
    agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
    response = agent.response(text)
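For context on the swapped-in component: ConvLab's TemplateNLG is a rule-based generator that renders structured dialogue acts as text. The sketch below is not taken from this repository; the [intent, domain, slot, value] act format is an assumption based on ConvLab's usual convention and may need adjusting to the exact ConvLab-3 version pinned by the submodule.

# Rough sketch: feeding tuple-style dialogue acts to the rule-based TemplateNLG.
# The act format below is an assumed ConvLab convention, not from this repo.
from convlab.nlg.template.multiwoz import TemplateNLG

template_nlg = TemplateNLG(is_user=False)
system_acts = [
    ["Inform", "Hotel", "Area", "centre"],
    ["Inform", "Hotel", "Price", "expensive"],
]
print(template_nlg.generate(system_acts))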

View File

@@ -11,7 +11,8 @@ for file_name in os.listdir(translated_data_directory):
    if file_name.endswith('.tsv'):
        file_path = os.path.join(translated_data_directory, file_name)
        df = pd.read_csv(file_path, sep='\t')
-        dfs.append(df)
+        df_user = df[df['role'] == 'system'].drop('role', axis=1)
+        dfs.append(df_user)

combined_df = pd.concat(dfs, ignore_index=True)

# Prepare the dataset for training
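The change above keeps only the system-side turns for NLG training. A tiny illustration with made-up rows is below; only the 'role' column appears in the diff, so the 'text' column name and values are assumptions.

import pandas as pd

# Made-up frame standing in for one translated .tsv file.
df = pd.DataFrame({
    "role": ["user", "system", "user", "system"],
    "text": ["dzień dobry", "Hello, how can I help?", "poproszę hotel", "I found three hotels."],
})

# Same filter as in the diff: keep system turns, drop the 'role' column.
df_user = df[df["role"] == "system"].drop("role", axis=1)
print(df_user)  # two rows, only the system utterances remain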
@@ -49,7 +50,7 @@ training_args = Seq2SeqTrainingArguments(
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    learning_rate=5e-5,
-    num_train_epochs=3,
+    num_train_epochs=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=None,  # Disable checkpoint rotation
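The diff only shows the argument block, so for orientation here is a compact, self-contained sketch of how arguments like these plug into a Hugging Face Seq2SeqTrainer. The model checkpoint and the toy act/utterance pair are placeholders, not taken from this repository.

# Hedged sketch: wiring Seq2SeqTrainingArguments into a Seq2SeqTrainer.
from datasets import Dataset
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForSeq2Seq, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)

model_name = "google/mt5-small"  # placeholder checkpoint, not from this repo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Toy act -> utterance pair standing in for the combined_df built earlier.
raw = Dataset.from_dict({
    "source": ["inform(hotel='Four Seasons Hotel')"],
    "target": ["I recommend the Four Seasons Hotel."],
})

def tokenize(batch):
    enc = tokenizer(batch["source"], truncation=True)
    enc["labels"] = tokenizer(text_target=batch["target"], truncation=True)["input_ids"]
    return enc

tokenized = raw.map(tokenize, batched=True, remove_columns=["source", "target"])

training_args = Seq2SeqTrainingArguments(
    output_dir="nlg_out",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    learning_rate=5e-5,
    num_train_epochs=10,          # the value raised from 3 in this commit
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=None,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    eval_dataset=tokenized,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    tokenizer=tokenizer,
)
trainer.train()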