# Fine-tune google/umt5-small as an NLG model (meaning representation -> text)
# on a Kaggle-hosted CSV dataset.
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
pipeline,
)
from datasets import load_dataset
# Base checkpoint, data, and tokenizer: load the MR/reference CSV from the
# Kaggle input directory and hold out 10% of it as a test split.
model_name = "google/umt5-small"

dataset = load_dataset(
    "csv",
    data_files="/kaggle/input/ngl-data/nlg_data.csv",
    split="train",
).train_test_split(test_size=0.1)
dataset  # notebook-style echo of the split sizes

tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_samples(samples):
    """Tokenize a batch of MR/reference pairs for seq2seq training.

    Each meaning representation in ``samples["mr"]`` is prefixed with the
    task prompt ``"generate text: "``; inputs and targets are both padded /
    truncated to 128 tokens. Padding token ids in the labels are replaced
    with -100 so the loss ignores those positions.
    """
    prompts = [f"generate text: {mr}" for mr in samples["mr"]]
    model_inputs = tokenizer(
        prompts,
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    targets = tokenizer(
        text_target=samples["ref"],
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    # Mask out padding in the labels: -100 is the ignore index of the loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if tok == pad_id else tok for tok in seq]
        for seq in targets["input_ids"]
    ]
    return model_inputs
# Tokenize both splits in batches, dropping the raw text columns once they
# have been converted to token ids.
tokenized_dataset = dataset.map(
    tokenize_samples,
    batched=True,
    remove_columns=["mr", "ref"],
)
tokenized_dataset  # notebook-style echo

# Pretrained seq2seq model plus a collator that pads per batch; label padding
# uses -100 so padded positions are excluded from the loss.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8,
)
# Training configuration: evaluate and checkpoint once per epoch, keep only
# one checkpoint on disk, and decode with generate() during evaluation so
# eval predictions reflect real generation rather than teacher forcing.
# NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in newer
# transformers releases — confirm against the pinned version.
training_args = Seq2SeqTrainingArguments(
    output_dir="/kaggle/working",  # checkpoints go to the Kaggle working dir
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    predict_with_generate=True,  # use generate() for eval predictions
    learning_rate=5e-5,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",  # must match eval cadence for load_best_model_at_end
    save_total_limit=1,
    load_best_model_at_end=True,
)
# Wire model, data, collator, and config into the seq2seq trainer and fine-tune.
# NOTE(review): passing the tokenizer to the Trainer (the `tokenizer` /
# `processing_class` argument, depending on version) would make checkpoints
# include it — confirm and consider adding.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)
trainer.train()
# Smoke-test generation on a few meaning representations.
# Fix: this model is a generic text-to-text generator, so use the
# "text2text-generation" pipeline task instead of "summarization" —
# the summarization pipeline may apply summarization-specific preprocessing
# from the model config, and it returns its output under "summary_text"
# rather than the generic "generated_text" key.
nlg = pipeline("text2text-generation", model=model, tokenizer=tokenizer)


def generate_text(mr: str) -> str:
    """Generate a natural-language utterance from one meaning representation."""
    return nlg(f"generate text: {mr}")[0]["generated_text"]


generate_text("dish[tatar], price[50], ingredient[wolowina]")
generate_text("payment_methods[gotowka], price[150], addresses[ulica Dluga 5]")
generate_text("dish[tiramisu], ingredient[mleko], allergy[laktoza]")
generate_text("time[dziesiata]")
generate_text("dish[spaghetti], ingredient[ser]")
generate_text("dish[pierogi], ingredient[kozi ser]")
generate_text("time[23:00], adres[ul Krótka 256]")
model.save_pretrained("/kaggle/working")