42 KiB
42 KiB
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
pipeline,
)
from datasets import load_dataset
# Base checkpoint: small instruction-tuned T5, fine-tuned below for data-to-text NLG.
model_name = "google/flan-t5-small"
# Load the MR -> reference pairs from a local CSV and hold out 10% for evaluation.
# NOTE(review): csv loading maps empty cells to None — downstream tokenization
# must tolerate that (see tokenize_samples).
dataset = load_dataset('csv', data_files='nlg_data.csv', split='train').train_test_split(test_size=0.1)
dataset
DatasetDict({ train: Dataset({ features: ['meaning_representation', 'human_reference'], num_rows: 18546 }) test: Dataset({ features: ['meaning_representation', 'human_reference'], num_rows: 2061 }) })
tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_samples(samples):
    """Tokenize a batch of meaning-representation / reference pairs.

    Args:
        samples: batch dict with string columns ``meaning_representation``
            and ``human_reference`` (entries may be None for empty CSV cells).

    Returns:
        dict with ``input_ids``/``attention_mask`` for the prompted inputs and
        ``labels`` (target token ids, padding replaced by -100).
    """
    # Coerce None -> "" before tokenizing: load_dataset('csv') yields None for
    # empty cells, and the fast tokenizer raises
    # "TypeError: TextEncodeInput must be Union[TextInputSequence, ...]" on
    # non-string batch entries (the original run crashed at ~14000/18546).
    inputs = [
        f"generate text: {mr if mr is not None else ''}"
        for mr in samples["meaning_representation"]
    ]
    targets = [
        ref if ref is not None else ""
        for ref in samples["human_reference"]
    ]

    tokenized_inputs = tokenizer(
        inputs,
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    labels = tokenizer(
        text_target=targets,
        max_length=128,
        padding="max_length",
        truncation=True,
    )

    # Mask padding positions with -100 so the cross-entropy loss ignores them.
    labels["input_ids"] = [
        [
            (token_id if token_id != tokenizer.pad_token_id else -100)
            for token_id in label
        ]
        for label in labels["input_ids"]
    ]
    tokenized_inputs["labels"] = labels["input_ids"]
    return tokenized_inputs
# Apply batched tokenization to both splits; drop the raw text columns so the
# collator only sees tensor-ready fields (input_ids, attention_mask, labels).
tokenized_dataset = dataset.map(
tokenize_samples,
batched=True,
remove_columns=["meaning_representation", "human_reference"],
)
tokenized_dataset
/home/filnow/anaconda3/envs/jarvis/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. warnings.warn( Map: 75%|███████▌ | 14000/18546 [00:02<00:00, 4940.94 examples/s]
[0;31m---------------------------------------------------------------------------[0m [0;31mTypeError[0m Traceback (most recent call last) Cell [0;32mIn[5], line 33[0m [1;32m 29[0m tokenized_inputs[[38;5;124m"[39m[38;5;124mlabels[39m[38;5;124m"[39m] [38;5;241m=[39m labels[[38;5;124m"[39m[38;5;124minput_ids[39m[38;5;124m"[39m] [1;32m 30[0m [38;5;28;01mreturn[39;00m tokenized_inputs [0;32m---> 33[0m tokenized_dataset [38;5;241m=[39m [43mdataset[49m[38;5;241;43m.[39;49m[43mmap[49m[43m([49m [1;32m 34[0m [43m [49m[43mtokenize_samples[49m[43m,[49m [1;32m 35[0m [43m [49m[43mbatched[49m[38;5;241;43m=[39;49m[38;5;28;43;01mTrue[39;49;00m[43m,[49m [1;32m 36[0m [43m [49m[43mremove_columns[49m[38;5;241;43m=[39;49m[43m[[49m[38;5;124;43m"[39;49m[38;5;124;43mmeaning_representation[39;49m[38;5;124;43m"[39;49m[43m,[49m[43m [49m[38;5;124;43m"[39;49m[38;5;124;43mhuman_reference[39;49m[38;5;124;43m"[39;49m[43m][49m[43m,[49m [1;32m 37[0m [43m)[49m [1;32m 39[0m tokenized_dataset File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/datasets/dataset_dict.py:869[0m, in [0;36mDatasetDict.map[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_names, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, desc)[0m [1;32m 866[0m [38;5;28;01mif[39;00m cache_file_names [38;5;129;01mis[39;00m [38;5;28;01mNone[39;00m: [1;32m 867[0m cache_file_names [38;5;241m=[39m {k: [38;5;28;01mNone[39;00m [38;5;28;01mfor[39;00m k [38;5;129;01min[39;00m [38;5;28mself[39m} [1;32m 868[0m [38;5;28;01mreturn[39;00m DatasetDict( [0;32m--> 869[0m { [1;32m 870[0m k: dataset[38;5;241m.[39mmap( [1;32m 871[0m function[38;5;241m=[39mfunction, [1;32m 872[0m with_indices[38;5;241m=[39mwith_indices, [1;32m 873[0m with_rank[38;5;241m=[39mwith_rank, [1;32m 874[0m input_columns[38;5;241m=[39minput_columns, [1;32m 875[0m batched[38;5;241m=[39mbatched, [1;32m 876[0m 
batch_size[38;5;241m=[39mbatch_size, [1;32m 877[0m drop_last_batch[38;5;241m=[39mdrop_last_batch, [1;32m 878[0m remove_columns[38;5;241m=[39mremove_columns, [1;32m 879[0m keep_in_memory[38;5;241m=[39mkeep_in_memory, [1;32m 880[0m load_from_cache_file[38;5;241m=[39mload_from_cache_file, [1;32m 881[0m cache_file_name[38;5;241m=[39mcache_file_names[k], [1;32m 882[0m writer_batch_size[38;5;241m=[39mwriter_batch_size, [1;32m 883[0m features[38;5;241m=[39mfeatures, [1;32m 884[0m disable_nullable[38;5;241m=[39mdisable_nullable, [1;32m 885[0m fn_kwargs[38;5;241m=[39mfn_kwargs, [1;32m 886[0m num_proc[38;5;241m=[39mnum_proc, [1;32m 887[0m desc[38;5;241m=[39mdesc, [1;32m 888[0m ) [1;32m 889[0m [38;5;28;01mfor[39;00m k, dataset [38;5;129;01min[39;00m [38;5;28mself[39m[38;5;241m.[39mitems() [1;32m 890[0m } [1;32m 891[0m ) File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/datasets/dataset_dict.py:870[0m, in [0;36m<dictcomp>[0;34m(.0)[0m [1;32m 866[0m [38;5;28;01mif[39;00m cache_file_names [38;5;129;01mis[39;00m [38;5;28;01mNone[39;00m: [1;32m 867[0m cache_file_names [38;5;241m=[39m {k: [38;5;28;01mNone[39;00m [38;5;28;01mfor[39;00m k [38;5;129;01min[39;00m [38;5;28mself[39m} [1;32m 868[0m [38;5;28;01mreturn[39;00m DatasetDict( [1;32m 869[0m { [0;32m--> 870[0m k: [43mdataset[49m[38;5;241;43m.[39;49m[43mmap[49m[43m([49m [1;32m 871[0m [43m [49m[43mfunction[49m[38;5;241;43m=[39;49m[43mfunction[49m[43m,[49m [1;32m 872[0m [43m [49m[43mwith_indices[49m[38;5;241;43m=[39;49m[43mwith_indices[49m[43m,[49m [1;32m 873[0m [43m [49m[43mwith_rank[49m[38;5;241;43m=[39;49m[43mwith_rank[49m[43m,[49m [1;32m 874[0m [43m [49m[43minput_columns[49m[38;5;241;43m=[39;49m[43minput_columns[49m[43m,[49m [1;32m 875[0m [43m [49m[43mbatched[49m[38;5;241;43m=[39;49m[43mbatched[49m[43m,[49m [1;32m 876[0m [43m [49m[43mbatch_size[49m[38;5;241;43m=[39;49m[43mbatch_size[49m[43m,[49m [1;32m 877[0m [43m [49m[43mdrop_last_batch[49m[38;5;241;43m=[39;49m[43mdrop_last_batch[49m[43m,[49m [1;32m 878[0m [43m 
[49m[43mremove_columns[49m[38;5;241;43m=[39;49m[43mremove_columns[49m[43m,[49m [1;32m 879[0m [43m [49m[43mkeep_in_memory[49m[38;5;241;43m=[39;49m[43mkeep_in_memory[49m[43m,[49m [1;32m 880[0m [43m [49m[43mload_from_cache_file[49m[38;5;241;43m=[39;49m[43mload_from_cache_file[49m[43m,[49m [1;32m 881[0m [43m [49m[43mcache_file_name[49m[38;5;241;43m=[39;49m[43mcache_file_names[49m[43m[[49m[43mk[49m[43m][49m[43m,[49m [1;32m 882[0m [43m [49m[43mwriter_batch_size[49m[38;5;241;43m=[39;49m[43mwriter_batch_size[49m[43m,[49m [1;32m 883[0m [43m [49m[43mfeatures[49m[38;5;241;43m=[39;49m[43mfeatures[49m[43m,[49m [1;32m 884[0m [43m [49m[43mdisable_nullable[49m[38;5;241;43m=[39;49m[43mdisable_nullable[49m[43m,[49m [1;32m 885[0m [43m [49m[43mfn_kwargs[49m[38;5;241;43m=[39;49m[43mfn_kwargs[49m[43m,[49m [1;32m 886[0m [43m [49m[43mnum_proc[49m[38;5;241;43m=[39;49m[43mnum_proc[49m[43m,[49m [1;32m 887[0m [43m [49m[43mdesc[49m[38;5;241;43m=[39;49m[43mdesc[49m[43m,[49m [1;32m 888[0m [43m [49m[43m)[49m [1;32m 889[0m [38;5;28;01mfor[39;00m k, dataset [38;5;129;01min[39;00m [38;5;28mself[39m[38;5;241m.[39mitems() [1;32m 890[0m } [1;32m 891[0m ) File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/datasets/arrow_dataset.py:602[0m, in [0;36mtransmit_tasks.<locals>.wrapper[0;34m(*args, **kwargs)[0m [1;32m 600[0m [38;5;28mself[39m: [38;5;124m"[39m[38;5;124mDataset[39m[38;5;124m"[39m [38;5;241m=[39m kwargs[38;5;241m.[39mpop([38;5;124m"[39m[38;5;124mself[39m[38;5;124m"[39m) [1;32m 601[0m [38;5;66;03m# apply actual function[39;00m [0;32m--> 602[0m out: Union[[38;5;124m"[39m[38;5;124mDataset[39m[38;5;124m"[39m, [38;5;124m"[39m[38;5;124mDatasetDict[39m[38;5;124m"[39m] [38;5;241m=[39m [43mfunc[49m[43m([49m[38;5;28;43mself[39;49m[43m,[49m[43m [49m[38;5;241;43m*[39;49m[43margs[49m[43m,[49m[43m [49m[38;5;241;43m*[39;49m[38;5;241;43m*[39;49m[43mkwargs[49m[43m)[49m [1;32m 603[0m datasets: List[[38;5;124m"[39m[38;5;124mDataset[39m[38;5;124m"[39m] [38;5;241m=[39m 
[38;5;28mlist[39m(out[38;5;241m.[39mvalues()) [38;5;28;01mif[39;00m [38;5;28misinstance[39m(out, [38;5;28mdict[39m) [38;5;28;01melse[39;00m [out] [1;32m 604[0m [38;5;28;01mfor[39;00m dataset [38;5;129;01min[39;00m datasets: [1;32m 605[0m [38;5;66;03m# Remove task templates if a column mapping of the template is no longer valid[39;00m File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/datasets/arrow_dataset.py:567[0m, in [0;36mtransmit_format.<locals>.wrapper[0;34m(*args, **kwargs)[0m [1;32m 560[0m self_format [38;5;241m=[39m { [1;32m 561[0m [38;5;124m"[39m[38;5;124mtype[39m[38;5;124m"[39m: [38;5;28mself[39m[38;5;241m.[39m_format_type, [1;32m 562[0m [38;5;124m"[39m[38;5;124mformat_kwargs[39m[38;5;124m"[39m: [38;5;28mself[39m[38;5;241m.[39m_format_kwargs, [1;32m 563[0m [38;5;124m"[39m[38;5;124mcolumns[39m[38;5;124m"[39m: [38;5;28mself[39m[38;5;241m.[39m_format_columns, [1;32m 564[0m [38;5;124m"[39m[38;5;124moutput_all_columns[39m[38;5;124m"[39m: [38;5;28mself[39m[38;5;241m.[39m_output_all_columns, [1;32m 565[0m } [1;32m 566[0m [38;5;66;03m# apply actual function[39;00m [0;32m--> 567[0m out: Union[[38;5;124m"[39m[38;5;124mDataset[39m[38;5;124m"[39m, [38;5;124m"[39m[38;5;124mDatasetDict[39m[38;5;124m"[39m] [38;5;241m=[39m [43mfunc[49m[43m([49m[38;5;28;43mself[39;49m[43m,[49m[43m [49m[38;5;241;43m*[39;49m[43margs[49m[43m,[49m[43m [49m[38;5;241;43m*[39;49m[38;5;241;43m*[39;49m[43mkwargs[49m[43m)[49m [1;32m 568[0m datasets: List[[38;5;124m"[39m[38;5;124mDataset[39m[38;5;124m"[39m] [38;5;241m=[39m [38;5;28mlist[39m(out[38;5;241m.[39mvalues()) [38;5;28;01mif[39;00m [38;5;28misinstance[39m(out, [38;5;28mdict[39m) [38;5;28;01melse[39;00m [out] [1;32m 569[0m [38;5;66;03m# re-apply format to the output[39;00m File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/datasets/arrow_dataset.py:3156[0m, in [0;36mDataset.map[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, 
load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)[0m [1;32m 3150[0m [38;5;28;01mif[39;00m transformed_dataset [38;5;129;01mis[39;00m [38;5;28;01mNone[39;00m: [1;32m 3151[0m [38;5;28;01mwith[39;00m hf_tqdm( [1;32m 3152[0m unit[38;5;241m=[39m[38;5;124m"[39m[38;5;124m examples[39m[38;5;124m"[39m, [1;32m 3153[0m total[38;5;241m=[39mpbar_total, [1;32m 3154[0m desc[38;5;241m=[39mdesc [38;5;129;01mor[39;00m [38;5;124m"[39m[38;5;124mMap[39m[38;5;124m"[39m, [1;32m 3155[0m ) [38;5;28;01mas[39;00m pbar: [0;32m-> 3156[0m [38;5;28;01mfor[39;00m rank, done, content [38;5;129;01min[39;00m Dataset[38;5;241m.[39m_map_single([38;5;241m*[39m[38;5;241m*[39mdataset_kwargs): [1;32m 3157[0m [38;5;28;01mif[39;00m done: [1;32m 3158[0m shards_done [38;5;241m+[39m[38;5;241m=[39m [38;5;241m1[39m File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/datasets/arrow_dataset.py:3547[0m, in [0;36mDataset._map_single[0;34m(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)[0m [1;32m 3543[0m indices [38;5;241m=[39m [38;5;28mlist[39m( [1;32m 3544[0m [38;5;28mrange[39m([38;5;241m*[39m([38;5;28mslice[39m(i, i [38;5;241m+[39m batch_size)[38;5;241m.[39mindices(shard[38;5;241m.[39mnum_rows))) [1;32m 3545[0m ) [38;5;66;03m# Something simpler?[39;00m [1;32m 3546[0m [38;5;28;01mtry[39;00m: [0;32m-> 3547[0m batch [38;5;241m=[39m [43mapply_function_on_filtered_inputs[49m[43m([49m [1;32m 3548[0m [43m [49m[43mbatch[49m[43m,[49m [1;32m 3549[0m [43m [49m[43mindices[49m[43m,[49m [1;32m 3550[0m [43m [49m[43mcheck_same_num_examples[49m[38;5;241;43m=[39;49m[38;5;28;43mlen[39;49m[43m([49m[43mshard[49m[38;5;241;43m.[39;49m[43mlist_indexes[49m[43m([49m[43m)[49m[43m)[49m[43m [49m[38;5;241;43m>[39;49m[43m [49m[38;5;241;43m0[39;49m[43m,[49m 
[1;32m 3551[0m [43m [49m[43moffset[49m[38;5;241;43m=[39;49m[43moffset[49m[43m,[49m [1;32m 3552[0m [43m [49m[43m)[49m [1;32m 3553[0m [38;5;28;01mexcept[39;00m NumExamplesMismatchError: [1;32m 3554[0m [38;5;28;01mraise[39;00m DatasetTransformationNotAllowedError( [1;32m 3555[0m [38;5;124m"[39m[38;5;124mUsing `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn[39m[38;5;124m'[39m[38;5;124mt create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it.[39m[38;5;124m"[39m [1;32m 3556[0m ) [38;5;28;01mfrom[39;00m [38;5;28;01mNone[39;00m File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/datasets/arrow_dataset.py:3416[0m, in [0;36mDataset._map_single.<locals>.apply_function_on_filtered_inputs[0;34m(pa_inputs, indices, check_same_num_examples, offset)[0m [1;32m 3414[0m [38;5;28;01mif[39;00m with_rank: [1;32m 3415[0m additional_args [38;5;241m+[39m[38;5;241m=[39m (rank,) [0;32m-> 3416[0m processed_inputs [38;5;241m=[39m [43mfunction[49m[43m([49m[38;5;241;43m*[39;49m[43mfn_args[49m[43m,[49m[43m [49m[38;5;241;43m*[39;49m[43madditional_args[49m[43m,[49m[43m [49m[38;5;241;43m*[39;49m[38;5;241;43m*[39;49m[43mfn_kwargs[49m[43m)[49m [1;32m 3417[0m [38;5;28;01mif[39;00m [38;5;28misinstance[39m(processed_inputs, LazyDict): [1;32m 3418[0m processed_inputs [38;5;241m=[39m { [1;32m 3419[0m k: v [38;5;28;01mfor[39;00m k, v [38;5;129;01min[39;00m processed_inputs[38;5;241m.[39mdata[38;5;241m.[39mitems() [38;5;28;01mif[39;00m k [38;5;129;01mnot[39;00m [38;5;129;01min[39;00m processed_inputs[38;5;241m.[39mkeys_to_format [1;32m 3420[0m } Cell [0;32mIn[5], line 14[0m, in [0;36mtokenize_samples[0;34m(samples)[0m [1;32m 5[0m inputs [38;5;241m=[39m [[38;5;124mf[39m[38;5;124m"[39m[38;5;124mgenerate text: [39m[38;5;132;01m{[39;00mmr[38;5;132;01m}[39;00m[38;5;124m"[39m [38;5;28;01mfor[39;00m mr [38;5;129;01min[39;00m samples[[38;5;124m"[39m[38;5;124mmeaning_representation[39m[38;5;124m"[39m]] [1;32m 7[0m 
tokenized_inputs [38;5;241m=[39m tokenizer( [1;32m 8[0m inputs, [1;32m 9[0m max_length[38;5;241m=[39m[38;5;241m128[39m, [1;32m 10[0m padding[38;5;241m=[39m[38;5;124m"[39m[38;5;124mmax_length[39m[38;5;124m"[39m, [1;32m 11[0m truncation[38;5;241m=[39m[38;5;28;01mTrue[39;00m, [1;32m 12[0m ) [0;32m---> 14[0m labels [38;5;241m=[39m [43mtokenizer[49m[43m([49m [1;32m 15[0m [43m [49m[43mtext_target[49m[38;5;241;43m=[39;49m[43msamples[49m[43m[[49m[38;5;124;43m"[39;49m[38;5;124;43mhuman_reference[39;49m[38;5;124;43m"[39;49m[43m][49m[43m,[49m [1;32m 16[0m [43m [49m[43mmax_length[49m[38;5;241;43m=[39;49m[38;5;241;43m128[39;49m[43m,[49m [1;32m 17[0m [43m [49m[43mpadding[49m[38;5;241;43m=[39;49m[38;5;124;43m"[39;49m[38;5;124;43mmax_length[39;49m[38;5;124;43m"[39;49m[43m,[49m [1;32m 18[0m [43m [49m[43mtruncation[49m[38;5;241;43m=[39;49m[38;5;28;43;01mTrue[39;49;00m[43m,[49m [1;32m 19[0m [43m[49m[43m)[49m [1;32m 21[0m labels[[38;5;124m"[39m[38;5;124minput_ids[39m[38;5;124m"[39m] [38;5;241m=[39m [ [1;32m 22[0m [ [1;32m 23[0m (token_id [38;5;28;01mif[39;00m token_id [38;5;241m!=[39m tokenizer[38;5;241m.[39mpad_token_id [38;5;28;01melse[39;00m [38;5;241m-[39m[38;5;241m100[39m) [0;32m (...)[0m [1;32m 26[0m [38;5;28;01mfor[39;00m label [38;5;129;01min[39;00m labels[[38;5;124m"[39m[38;5;124minput_ids[39m[38;5;124m"[39m] [1;32m 27[0m ] [1;32m 29[0m tokenized_inputs[[38;5;124m"[39m[38;5;124mlabels[39m[38;5;124m"[39m] [38;5;241m=[39m labels[[38;5;124m"[39m[38;5;124minput_ids[39m[38;5;124m"[39m] File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2491[0m, in [0;36mPreTrainedTokenizerBase.__call__[0;34m(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)[0m [1;32m 
2489[0m [38;5;28;01mif[39;00m text_target [38;5;129;01mis[39;00m [38;5;129;01mnot[39;00m [38;5;28;01mNone[39;00m: [1;32m 2490[0m [38;5;28mself[39m[38;5;241m.[39m_switch_to_target_mode() [0;32m-> 2491[0m target_encodings [38;5;241m=[39m [38;5;28;43mself[39;49m[38;5;241;43m.[39;49m[43m_call_one[49m[43m([49m[43mtext[49m[38;5;241;43m=[39;49m[43mtext_target[49m[43m,[49m[43m [49m[43mtext_pair[49m[38;5;241;43m=[39;49m[43mtext_pair_target[49m[43m,[49m[43m [49m[38;5;241;43m*[39;49m[38;5;241;43m*[39;49m[43mall_kwargs[49m[43m)[49m [1;32m 2492[0m [38;5;66;03m# Leave back tokenizer in input mode[39;00m [1;32m 2493[0m [38;5;28mself[39m[38;5;241m.[39m_switch_to_input_mode() File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2574[0m, in [0;36mPreTrainedTokenizerBase._call_one[0;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)[0m [1;32m 2569[0m [38;5;28;01mraise[39;00m [38;5;167;01mValueError[39;00m( [1;32m 2570[0m [38;5;124mf[39m[38;5;124m"[39m[38;5;124mbatch length of `text`: [39m[38;5;132;01m{[39;00m[38;5;28mlen[39m(text)[38;5;132;01m}[39;00m[38;5;124m does not match batch length of `text_pair`:[39m[38;5;124m"[39m [1;32m 2571[0m [38;5;124mf[39m[38;5;124m"[39m[38;5;124m [39m[38;5;132;01m{[39;00m[38;5;28mlen[39m(text_pair)[38;5;132;01m}[39;00m[38;5;124m.[39m[38;5;124m"[39m [1;32m 2572[0m ) [1;32m 2573[0m batch_text_or_text_pairs [38;5;241m=[39m [38;5;28mlist[39m([38;5;28mzip[39m(text, text_pair)) [38;5;28;01mif[39;00m text_pair [38;5;129;01mis[39;00m [38;5;129;01mnot[39;00m [38;5;28;01mNone[39;00m [38;5;28;01melse[39;00m text [0;32m-> 2574[0m [38;5;28;01mreturn[39;00m [38;5;28;43mself[39;49m[38;5;241;43m.[39;49m[43mbatch_encode_plus[49m[43m([49m [1;32m 2575[0m [43m 
[49m[43mbatch_text_or_text_pairs[49m[38;5;241;43m=[39;49m[43mbatch_text_or_text_pairs[49m[43m,[49m [1;32m 2576[0m [43m [49m[43madd_special_tokens[49m[38;5;241;43m=[39;49m[43madd_special_tokens[49m[43m,[49m [1;32m 2577[0m [43m [49m[43mpadding[49m[38;5;241;43m=[39;49m[43mpadding[49m[43m,[49m [1;32m 2578[0m [43m [49m[43mtruncation[49m[38;5;241;43m=[39;49m[43mtruncation[49m[43m,[49m [1;32m 2579[0m [43m [49m[43mmax_length[49m[38;5;241;43m=[39;49m[43mmax_length[49m[43m,[49m [1;32m 2580[0m [43m [49m[43mstride[49m[38;5;241;43m=[39;49m[43mstride[49m[43m,[49m [1;32m 2581[0m [43m [49m[43mis_split_into_words[49m[38;5;241;43m=[39;49m[43mis_split_into_words[49m[43m,[49m [1;32m 2582[0m [43m [49m[43mpad_to_multiple_of[49m[38;5;241;43m=[39;49m[43mpad_to_multiple_of[49m[43m,[49m [1;32m 2583[0m [43m [49m[43mreturn_tensors[49m[38;5;241;43m=[39;49m[43mreturn_tensors[49m[43m,[49m [1;32m 2584[0m [43m [49m[43mreturn_token_type_ids[49m[38;5;241;43m=[39;49m[43mreturn_token_type_ids[49m[43m,[49m [1;32m 2585[0m [43m [49m[43mreturn_attention_mask[49m[38;5;241;43m=[39;49m[43mreturn_attention_mask[49m[43m,[49m [1;32m 2586[0m [43m [49m[43mreturn_overflowing_tokens[49m[38;5;241;43m=[39;49m[43mreturn_overflowing_tokens[49m[43m,[49m [1;32m 2587[0m [43m [49m[43mreturn_special_tokens_mask[49m[38;5;241;43m=[39;49m[43mreturn_special_tokens_mask[49m[43m,[49m [1;32m 2588[0m [43m [49m[43mreturn_offsets_mapping[49m[38;5;241;43m=[39;49m[43mreturn_offsets_mapping[49m[43m,[49m [1;32m 2589[0m [43m [49m[43mreturn_length[49m[38;5;241;43m=[39;49m[43mreturn_length[49m[43m,[49m [1;32m 2590[0m [43m [49m[43mverbose[49m[38;5;241;43m=[39;49m[43mverbose[49m[43m,[49m [1;32m 2591[0m [43m [49m[38;5;241;43m*[39;49m[38;5;241;43m*[39;49m[43mkwargs[49m[43m,[49m [1;32m 2592[0m [43m [49m[43m)[49m [1;32m 2593[0m [38;5;28;01melse[39;00m: [1;32m 2594[0m [38;5;28;01mreturn[39;00m [38;5;28mself[39m[38;5;241m.[39mencode_plus( [1;32m 2595[0m text[38;5;241m=[39mtext, [1;32m 2596[0m text_pair[38;5;241m=[39mtext_pair, [0;32m (...)[0m 
[1;32m 2612[0m [38;5;241m*[39m[38;5;241m*[39mkwargs, [1;32m 2613[0m ) File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2765[0m, in [0;36mPreTrainedTokenizerBase.batch_encode_plus[0;34m(self, batch_text_or_text_pairs, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)[0m [1;32m 2755[0m [38;5;66;03m# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'[39;00m [1;32m 2756[0m padding_strategy, truncation_strategy, max_length, kwargs [38;5;241m=[39m [38;5;28mself[39m[38;5;241m.[39m_get_padding_truncation_strategies( [1;32m 2757[0m padding[38;5;241m=[39mpadding, [1;32m 2758[0m truncation[38;5;241m=[39mtruncation, [0;32m (...)[0m [1;32m 2762[0m [38;5;241m*[39m[38;5;241m*[39mkwargs, [1;32m 2763[0m ) [0;32m-> 2765[0m [38;5;28;01mreturn[39;00m [38;5;28;43mself[39;49m[38;5;241;43m.[39;49m[43m_batch_encode_plus[49m[43m([49m [1;32m 2766[0m [43m [49m[43mbatch_text_or_text_pairs[49m[38;5;241;43m=[39;49m[43mbatch_text_or_text_pairs[49m[43m,[49m [1;32m 2767[0m [43m [49m[43madd_special_tokens[49m[38;5;241;43m=[39;49m[43madd_special_tokens[49m[43m,[49m [1;32m 2768[0m [43m [49m[43mpadding_strategy[49m[38;5;241;43m=[39;49m[43mpadding_strategy[49m[43m,[49m [1;32m 2769[0m [43m [49m[43mtruncation_strategy[49m[38;5;241;43m=[39;49m[43mtruncation_strategy[49m[43m,[49m [1;32m 2770[0m [43m [49m[43mmax_length[49m[38;5;241;43m=[39;49m[43mmax_length[49m[43m,[49m [1;32m 2771[0m [43m [49m[43mstride[49m[38;5;241;43m=[39;49m[43mstride[49m[43m,[49m [1;32m 2772[0m [43m [49m[43mis_split_into_words[49m[38;5;241;43m=[39;49m[43mis_split_into_words[49m[43m,[49m [1;32m 2773[0m [43m [49m[43mpad_to_multiple_of[49m[38;5;241;43m=[39;49m[43mpad_to_multiple_of[49m[43m,[49m [1;32m 2774[0m [43m 
[49m[43mreturn_tensors[49m[38;5;241;43m=[39;49m[43mreturn_tensors[49m[43m,[49m [1;32m 2775[0m [43m [49m[43mreturn_token_type_ids[49m[38;5;241;43m=[39;49m[43mreturn_token_type_ids[49m[43m,[49m [1;32m 2776[0m [43m [49m[43mreturn_attention_mask[49m[38;5;241;43m=[39;49m[43mreturn_attention_mask[49m[43m,[49m [1;32m 2777[0m [43m [49m[43mreturn_overflowing_tokens[49m[38;5;241;43m=[39;49m[43mreturn_overflowing_tokens[49m[43m,[49m [1;32m 2778[0m [43m [49m[43mreturn_special_tokens_mask[49m[38;5;241;43m=[39;49m[43mreturn_special_tokens_mask[49m[43m,[49m [1;32m 2779[0m [43m [49m[43mreturn_offsets_mapping[49m[38;5;241;43m=[39;49m[43mreturn_offsets_mapping[49m[43m,[49m [1;32m 2780[0m [43m [49m[43mreturn_length[49m[38;5;241;43m=[39;49m[43mreturn_length[49m[43m,[49m [1;32m 2781[0m [43m [49m[43mverbose[49m[38;5;241;43m=[39;49m[43mverbose[49m[43m,[49m [1;32m 2782[0m [43m [49m[38;5;241;43m*[39;49m[38;5;241;43m*[39;49m[43mkwargs[49m[43m,[49m [1;32m 2783[0m [43m[49m[43m)[49m File [0;32m~/anaconda3/envs/jarvis/lib/python3.9/site-packages/transformers/tokenization_utils_fast.py:429[0m, in [0;36mPreTrainedTokenizerFast._batch_encode_plus[0;34m(self, batch_text_or_text_pairs, add_special_tokens, padding_strategy, truncation_strategy, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose)[0m [1;32m 420[0m [38;5;66;03m# Set the truncation and padding strategy and restore the initial configuration[39;00m [1;32m 421[0m [38;5;28mself[39m[38;5;241m.[39mset_truncation_and_padding( [1;32m 422[0m padding_strategy[38;5;241m=[39mpadding_strategy, [1;32m 423[0m truncation_strategy[38;5;241m=[39mtruncation_strategy, [0;32m (...)[0m [1;32m 426[0m pad_to_multiple_of[38;5;241m=[39mpad_to_multiple_of, [1;32m 427[0m ) [0;32m--> 429[0m encodings [38;5;241m=[39m 
[38;5;28;43mself[39;49m[38;5;241;43m.[39;49m[43m_tokenizer[49m[38;5;241;43m.[39;49m[43mencode_batch[49m[43m([49m [1;32m 430[0m [43m [49m[43mbatch_text_or_text_pairs[49m[43m,[49m [1;32m 431[0m [43m [49m[43madd_special_tokens[49m[38;5;241;43m=[39;49m[43madd_special_tokens[49m[43m,[49m [1;32m 432[0m [43m [49m[43mis_pretokenized[49m[38;5;241;43m=[39;49m[43mis_split_into_words[49m[43m,[49m [1;32m 433[0m [43m[49m[43m)[49m [1;32m 435[0m [38;5;66;03m# Convert encoding to dict[39;00m [1;32m 436[0m [38;5;66;03m# `Tokens` has type: Tuple[[39;00m [1;32m 437[0m [38;5;66;03m# List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],[39;00m [1;32m 438[0m [38;5;66;03m# List[EncodingFast][39;00m [1;32m 439[0m [38;5;66;03m# ][39;00m [1;32m 440[0m [38;5;66;03m# with nested dimensions corresponding to batch, overflows, sequence length[39;00m [1;32m 441[0m tokens_and_encodings [38;5;241m=[39m [ [1;32m 442[0m [38;5;28mself[39m[38;5;241m.[39m_convert_encoding( [1;32m 443[0m encoding[38;5;241m=[39mencoding, [0;32m (...)[0m [1;32m 452[0m [38;5;28;01mfor[39;00m encoding [38;5;129;01min[39;00m encodings [1;32m 453[0m ] [0;31mTypeError[0m: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]
# Load the pretrained encoder-decoder weights for fine-tuning.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Dynamic padding per batch; label_pad_token_id=-100 matches the masking done in
# tokenize_samples; pad_to_multiple_of=8 helps tensor-core efficiency.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100, pad_to_multiple_of=8)
training_args = Seq2SeqTrainingArguments(
output_dir="/kaggle/",
per_device_train_batch_size=8,
per_device_eval_batch_size=16,
predict_with_generate=True,
learning_rate=5e-5,
num_train_epochs=3,
# NOTE(review): renamed to `eval_strategy` in transformers >= 4.41 — confirm
# against the installed version if upgrading.
evaluation_strategy="epoch",
save_strategy="epoch",
save_total_limit=1,
# Requires eval and save strategies to match (both "epoch" here); best model
# is selected by eval loss by default.
load_best_model_at_end=True,
)
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["test"],
)
trainer.train()
# NOTE(review): 'summarization' works for any seq2seq model but
# 'text2text-generation' is the more general task name for this use case.
nlg = pipeline('summarization', model=model, tokenizer=tokenizer)
# Smoke-test generations from structured meaning representations.
nlg(f'generate text: dish[tatar], price[50], ingredient[wolowina]')[0]['summary_text']
nlg(f'generate text: payment_methods[gotowka], price[150], addresses[ulica Dluga 5]')[0]['summary_text']
nlg(f'generate text: dish[tiramisu], ingredient[mleko], allergy[laktoza]')[0]['summary_text']
nlg(f'generate text: time[dziesiata]')[0]['summary_text']
# Persist only the model weights/config; the tokenizer is not saved here.
model.save_pretrained("/kaggle/output")