Compare commits
4 Commits
master
...
test_branc
Author | SHA1 | Date | |
---|---|---|---|
|
893733025b | ||
99306f0532 | |||
7a39f24359 | |||
|
7306fdb83a |
1
ConvLab-3
Submodule
1
ConvLab-3
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 60f4e5641f93e99b8d61b49cf5fd6dc818a83c4c
|
@ -33,15 +33,14 @@ class MachineLearningNLG:
|
||||
return translated_response
|
||||
|
||||
def generate(self, action):
|
||||
# Przyjmujemy, że 'action' jest formatowanym stringiem, który jest przekazywany do self.nlg
|
||||
return self.nlg(action)
|
||||
|
||||
def init_session(self):
|
||||
pass # Dodanie pustej metody init_session
|
||||
pass
|
||||
|
||||
|
||||
# Przykład użycia
|
||||
if __name__ == "__main__":
|
||||
nlg = MachineLearningNLG()
|
||||
system_act = "inform(date.from=15.07, date.to=22.07)"
|
||||
system_act = "inform(hotel='Four Seasons Hotel')"
|
||||
print(nlg.nlg(system_act))
|
||||
|
3
Main.py
3
Main.py
@ -3,13 +3,14 @@ from DialoguePolicy import DialoguePolicy
|
||||
from DialogueStateTracker import DialogueStateTracker
|
||||
from convlab.dialog_agent import PipelineAgent
|
||||
from MachineLearningNLG import MachineLearningNLG # Importujemy nowy komponent NLG
|
||||
from convlab.nlg.template.multiwoz import TemplateNLG
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = "chciałbym zarezerwować drogi hotel z parkingiem 1 stycznia w Warszawie w centrum"
|
||||
nlu = NaturalLanguageAnalyzer()
|
||||
dst = DialogueStateTracker()
|
||||
policy = DialoguePolicy()
|
||||
nlg = MachineLearningNLG()
|
||||
nlg = TemplateNLG(is_user=False)
|
||||
|
||||
agent = PipelineAgent(nlu=nlu, dst=dst, policy=policy, nlg=nlg, name='sys')
|
||||
response = agent.response(text)
|
||||
|
@ -11,7 +11,8 @@ for file_name in os.listdir(translated_data_directory):
|
||||
if file_name.endswith('.tsv'):
|
||||
file_path = os.path.join(translated_data_directory, file_name)
|
||||
df = pd.read_csv(file_path, sep='\t')
|
||||
dfs.append(df)
|
||||
df_user = df[df['role'] == 'system'].drop('role', axis=1)
|
||||
dfs.append(df_user)
|
||||
combined_df = pd.concat(dfs, ignore_index=True)
|
||||
|
||||
# Przygotowanie zbioru danych do trenowania
|
||||
@ -49,7 +50,7 @@ training_args = Seq2SeqTrainingArguments(
|
||||
per_device_eval_batch_size=16,
|
||||
predict_with_generate=True,
|
||||
learning_rate=5e-5,
|
||||
num_train_epochs=3,
|
||||
num_train_epochs=10,
|
||||
evaluation_strategy="epoch",
|
||||
save_strategy="epoch",
|
||||
save_total_limit=None, # Wyłącz rotację punktów kontrolnych
|
||||
|
Loading…
Reference in New Issue
Block a user