# Source notebook: JARVIS/nlg_train.ipynb (5.5 KiB)

import inspect

from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    pipeline,
)

model_name = "google/umt5-small"

# The CSV must provide the "mr" (model input) and "ref" (target text) columns
# that tokenize_samples reads.
data_files = "/kaggle/input/ngl-data/nlg_data.csv"

# Hold out 10% of the rows for evaluation. The fixed seed makes the split
# reproducible; without it every run trains and evaluates on different rows,
# so runs cannot be compared.
dataset = load_dataset("csv", data_files=data_files, split="train").train_test_split(
    test_size=0.1,
    seed=42,
)
print(dataset)

tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_samples(samples):
    """Convert a batch of dataset rows into seq2seq training features.

    Intended for ``Dataset.map(..., batched=True)``: ``samples`` maps the
    column names ``"mr"`` (source) and ``"ref"`` (target) to parallel lists.

    Each source string gets the ``"generate text: "`` task prefix, which must
    match the prefix used when prompting the model at inference time. Sources
    and targets are both truncated and padded to exactly 128 tokens.

    Returns the tokenizer encoding of the prefixed sources, extended with a
    ``"labels"`` entry: the target token ids with every pad id replaced by
    -100, the label value Hugging Face seq2seq models ignore in the loss (and
    the value the data collator uses as ``label_pad_token_id``).
    """
    shared_kwargs = {"max_length": 128, "padding": "max_length", "truncation": True}
    pad_id = tokenizer.pad_token_id

    prompts = [f"generate text: {mr}" for mr in samples["mr"]]
    encoded = tokenizer(prompts, **shared_kwargs)

    target_ids = tokenizer(text_target=samples["ref"], **shared_kwargs)["input_ids"]
    encoded["labels"] = [
        [-100 if tok == pad_id else tok for tok in row] for row in target_ids
    ]
    return encoded


# Tokenize both splits in batches. Every original column is dropped (not just
# a hard-coded list), so only the tokenizer outputs plus "labels" remain even
# if the CSV carries extra columns.
tokenized_dataset = dataset.map(
    tokenize_samples,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

print(tokenized_dataset)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# label_pad_token_id=-100 matches the value tokenize_samples writes into
# padded label positions. pad_to_multiple_of=8 currently has no effect:
# tokenize_samples already pads every row to a fixed 128 tokens.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8,
)

# `evaluation_strategy` was renamed to `eval_strategy` in transformers 4.41 and
# the old spelling is rejected by later releases. Pass whichever name the
# installed Seq2SeqTrainingArguments accepts so the script runs on both.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(Seq2SeqTrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = Seq2SeqTrainingArguments(
    output_dir="/kaggle/working",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    # NOTE(review): the trainer gets no compute_metrics, so evaluation
    # presumably reports only the loss and this flag has no visible effect.
    # Confirm whether generation metrics were intended.
    predict_with_generate=True,
    learning_rate=5e-5,
    num_train_epochs=3,
    # Evaluate and checkpoint once per epoch. load_best_model_at_end requires
    # the evaluation and save strategies to match.
    **{_eval_strategy_kwarg: "epoch"},
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)
trainer.train()
nlg = pipeline("summarization", model=model, tokenizer=tokenizer)


def generate_text(mr, max_length=128):
    """Generate text for the meaning representation ``mr`` with the tuned model.

    The ``"generate text: "`` prefix must match the one used in
    tokenize_samples during training. ``max_length`` is passed to generation
    explicitly: without it generation falls back to the model's default limit
    (20 tokens), which cuts off outputs that training allowed up to 128.
    """
    return nlg(f"generate text: {mr}", max_length=max_length)[0]["summary_text"]


for sample_mr in (
    "dish[tatar], price[50], ingredient[wolowina]",
    "payment_methods[gotowka], price[150], addresses[ulica Dluga 5]",
    "dish[tiramisu], ingredient[mleko], allergy[laktoza]",
    "time[dziesiata]",
    "dish[spaghetti], ingredient[ser]",
    "dish[pierogi], ingredient[kozi ser]",
    # NOTE(review): "adres" differs from the "addresses" slot used above.
    # Confirm whether this probes an unseen slot name or is a typo.
    "time[23:00], adres[ul Krótka 256]",
):
    print(f"{sample_mr} -> {generate_text(sample_mr)}")

# Save the tokenizer next to the weights so the directory can be reloaded
# with from_pretrained() for both model and tokenizer.
model.save_pretrained("/kaggle/working")
tokenizer.save_pretrained("/kaggle/working")