JARVIS/nlg_train.ipynb

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    pipeline,
)

from datasets import load_dataset

model_name = "google/flan-t5-small"
dataset = load_dataset('csv', data_files='nlg_data.csv', split='train').train_test_split(test_size=0.1)
dataset
Generating train split: 20628 examples [00:00, 237850.07 examples/s]
DatasetDict({
    train: Dataset({
        features: ['meaning_representation', 'human_reference'],
        num_rows: 18565
    })
    test: Dataset({
        features: ['meaning_representation', 'human_reference'],
        num_rows: 2063
    })
})
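Each row pairs a slot-style meaning_representation with a human_reference sentence. An optional inspection cell, sketched here (it is not part of the original run), peeks at one example and counts references that loaded as None, since blank CSV cells typically end up as null values:

# Look at one training example.
print(dataset["train"][0])

# Count rows whose reference text is missing; such rows cannot be tokenized.
missing = sum(ref is None for ref in dataset["train"]["human_reference"])
print("rows with missing human_reference:", missing)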
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_samples(samples):
    # Prefix each meaning representation with a task instruction for T5.
    inputs = [f"generate text: {mr}" for mr in samples["meaning_representation"]]

    tokenized_inputs = tokenizer(
        inputs,
        max_length=128,
        padding="max_length",
        truncation=True,
    )

    # Tokenize the reference sentences as targets.
    labels = tokenizer(
        text_target=samples["human_reference"],
        max_length=128,
        padding="max_length",
        truncation=True,
    )

    # Replace padding token ids in the labels with -100 so they are
    # ignored by the cross-entropy loss.
    labels["input_ids"] = [
        [
            (token_id if token_id != tokenizer.pad_token_id else -100)
            for token_id in label
        ]
        for label in labels["input_ids"]
    ]

    tokenized_inputs["labels"] = labels["input_ids"]
    return tokenized_inputs


tokenized_dataset = dataset.map(
    tokenize_samples,
    batched=True,
    remove_columns=["meaning_representation", "human_reference"],
)

tokenized_dataset
Map:  70%|███████   | 13000/18565 [00:02<00:01, 4762.33 examples/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 33
---> 33 tokenized_dataset = dataset.map(
     34     tokenize_samples,
     35     batched=True,
     36     remove_columns=["meaning_representation", "human_reference"],
     37 )
    (...)
Cell In[3], line 14, in tokenize_samples(samples)
---> 14 labels = tokenizer(
     15     text_target=samples["human_reference"],
     16     max_length=128,
     17     padding="max_length",
     18     truncation=True,
     19 )
    (...)
TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]
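The TypeError above is what the fast tokenizer raises when an element of the batch is not a plain string; with a CSV source the most likely culprit is a blank human_reference cell loaded as None. A minimal workaround, sketched under that assumption (the filter condition is not part of the original notebook), drops such rows and re-runs the map:

# Keep only rows where both columns contain text, then tokenize again.
dataset = dataset.filter(
    lambda sample: sample["meaning_representation"] is not None
    and sample["human_reference"] is not None
)

tokenized_dataset = dataset.map(
    tokenize_samples,
    batched=True,
    remove_columns=["meaning_representation", "human_reference"],
)

tokenized_dataset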
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,  # keep padded label positions out of the loss
    pad_to_multiple_of=8,     # pad batches to a multiple of 8 (friendlier for tensor cores)
)
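DataCollatorForSeq2Seq also prepares decoder_input_ids from the labels when a model is passed. An optional sanity check, assuming tokenized_dataset was built successfully above, is to collate two examples and inspect the tensor shapes the trainer will feed the model:

# Collate two tokenized examples into a padded batch of tensors.
sample_features = [tokenized_dataset["train"][i] for i in range(2)]
batch = data_collator(sample_features)
print({k: tuple(v.shape) for k, v in batch.items()})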
training_args = Seq2SeqTrainingArguments(
    output_dir="nlg_model",      # writable directory for checkpoints (name is arbitrary)
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    predict_with_generate=True,  # generate text during evaluation instead of only computing loss
    learning_rate=5e-5,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True, # reload the checkpoint with the lowest eval loss after training
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)
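Since predict_with_generate=True, the evaluation loop decodes generated sequences, so a text-overlap metric can be attached. Below is a minimal sketch of a ROUGE compute_metrics, assuming the separate evaluate package is available (it is not imported in this notebook); it would be passed to Seq2SeqTrainer via compute_metrics=compute_metrics:

import numpy as np
import evaluate  # assumed to be installed; not part of the original environment setup

rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # The trainer may pad generated ids and labels with -100; map those back
    # to the pad token before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return rouge.compute(predictions=decoded_preds, references=decoded_labels)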
trainer.train()
# The fine-tuned model can be queried through a text2text pipeline; the
# 'summarization' task returns its output under the 'summary_text' key.
nlg = pipeline('summarization', model=model, tokenizer=tokenizer)
nlg('generate text: dish[tatar], price[50], ingredient[wolowina]')[0]['summary_text']
nlg('generate text: payment_methods[gotowka], price[150], addresses[ulica Dluga 5]')[0]['summary_text']
nlg('generate text: dish[tiramisu], ingredient[mleko], allergy[laktoza]')[0]['summary_text']
nlg('generate text: time[dziesiata]')[0]['summary_text']
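The same model can also be served through the dedicated text2text-generation pipeline, which returns its output under 'generated_text' instead of 'summary_text'; an equivalent call, shown for comparison:

nlg_t2t = pipeline('text2text-generation', model=model, tokenizer=tokenizer)
nlg_t2t('generate text: dish[tatar], price[50], ingredient[wolowina]')[0]['generated_text']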
model.save_pretrained("nlg_model")  # save the fine-tuned weights to a local directory
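To reuse the model outside this notebook, the tokenizer should be saved next to the weights and both reloaded into a fresh pipeline. A minimal sketch, assuming the local nlg_model directory used above:

tokenizer.save_pretrained("nlg_model")

# Reload from disk and rebuild the generation pipeline.
reloaded_model = AutoModelForSeq2SeqLM.from_pretrained("nlg_model")
reloaded_tokenizer = AutoTokenizer.from_pretrained("nlg_model")

nlg_reloaded = pipeline('summarization', model=reloaded_model, tokenizer=reloaded_tokenizer)
nlg_reloaded('generate text: dish[tatar], price[50]')[0]['summary_text']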