# GOATS/train_nlg.py
# Training script: fine-tunes a seq2seq model to generate text from the `act` column of the translated TSV data.
import os
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
translated_data_directory = 'translated_data'

# Merge every translated TSV file into a single training table.
# The file names are sorted because os.listdir() returns them in arbitrary
# order, which would make the row order differ between runs and machines.
dfs = []
for file_name in sorted(os.listdir(translated_data_directory)):
    if not file_name.endswith('.tsv'):
        continue
    file_path = os.path.join(translated_data_directory, file_name)
    df = pd.read_csv(file_path, sep='\t')
    # Keep only the rows whose role is 'system'; the 'role' column is
    # constant after the filter, so it is dropped.
    df_system = df[df['role'] == 'system'].drop('role', axis=1)
    dfs.append(df_system)

# pd.concat([]) fails with an opaque "No objects to concatenate"; raise the
# same exception type with a message that points at the real cause.
if not dfs:
    raise ValueError(
        f"No .tsv files found in {translated_data_directory!r}; nothing to train on."
    )
combined_df = pd.concat(dfs, ignore_index=True)

# Rows without a target text cannot be tokenized as labels, so drop them.
# The index is reset so Dataset.from_pandas does not carry the old index
# along as an extra column.
combined_df = combined_df.dropna(subset=['value_en']).reset_index(drop=True)
if combined_df.empty:
    raise ValueError(
        f"No usable 'system' rows found in {translated_data_directory!r}."
    )

# Build the Hugging Face dataset used for training.
dataset = Dataset.from_pandas(combined_df)
# Pick the pretrained checkpoint and load the tokenizer that belongs to it.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name,
)
# Tokenization function applied to batches of the dataset.
def tokenize_samples(samples, *, max_length=128, prefix="generate text: ",
                     padding="max_length", tok=None):
    """Tokenize a batch of samples into model inputs and labels.

    Args:
        samples: Mapping with an ``"act"`` list (source side) and a
            ``"value_en"`` list (target text), as passed by a batched
            ``Dataset.map`` call.
        max_length: Maximum number of tokens for both inputs and labels;
            longer sequences are truncated.
        prefix: Prompt text prepended to every ``act`` value.
        padding: Padding strategy forwarded to the tokenizer. The default
            pads every sequence to ``max_length``; pass ``False`` to leave
            sequences unpadded so a padding-aware collator can pad per batch.
        tok: Tokenizer to use. Defaults to the module-level ``tokenizer``.

    Returns:
        The tokenizer output for the inputs, with an added ``"labels"`` entry
        holding the target token ids in which every padding id is replaced
        by -100.
    """
    active_tokenizer = tokenizer if tok is None else tok

    inputs = [f"{prefix}{act}" for act in samples["act"]]
    tokenized_inputs = active_tokenizer(
        inputs, max_length=max_length, padding=padding, truncation=True
    )
    labels = active_tokenizer(
        samples["value_en"], max_length=max_length, padding=padding, truncation=True
    )

    # Look the pad id up once instead of once per token.
    pad_id = active_tokenizer.pad_token_id
    # -100 is the label value used in this script (see the collator's
    # label_pad_token_id) to mark padded label positions.
    tokenized_inputs["labels"] = [
        [(token_id if token_id != pad_id else -100) for token_id in label]
        for label in labels["input_ids"]
    ]
    return tokenized_inputs
# Run the tokenization function over the whole dataset in batches.
tokenized_dataset = dataset.map(function=tokenize_samples, batched=True)

# Load the pretrained seq2seq model weights.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Collator that assembles batches; label padding uses -100 and padded
# lengths are rounded up to a multiple of 8.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8,
)
# Training configuration, collected as keyword options and then handed to
# Seq2SeqTrainingArguments in one go.
training_options = {
    # Where checkpoints are written.
    "output_dir": "./nlg_model",
    # Optimisation settings.
    "learning_rate": 5e-5,
    "num_train_epochs": 10,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 16,
    # Evaluate and checkpoint once per epoch.
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "save_total_limit": None,  # Disable checkpoint rotation
    "load_best_model_at_end": True,
    "predict_with_generate": True,
}
training_args = Seq2SeqTrainingArguments(**training_options)
# Hold out part of the data for evaluation. Evaluating on the training set
# itself would make the per-epoch evaluation (and any best-model selection
# based on it) reward memorisation rather than generalisation.
# The seed is fixed so the split is the same on every run.
if len(tokenized_dataset) >= 2:
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_split = split_dataset["train"]
    eval_split = split_dataset["test"]
else:
    # Too few rows to split; fall back to using the same data for both.
    train_split = eval_split = tokenized_dataset

# Trainer initialisation.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_split,
    eval_dataset=eval_split,
)
# Run the training loop.
trainer.train()

# Persist the trained model and its tokenizer side by side in one directory.
final_model_dir = "./nlg_model"
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)