JARVIS/nlg_train.ipynb

1 line
24 KiB
Plaintext
Raw Normal View History

2024-06-03 14:49:59 +02:00
{"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":8587424,"sourceType":"datasetVersion","datasetId":5135632}],"dockerImageVersionId":30716,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"from transformers import (\n AutoModelForSeq2SeqLM,\n AutoTokenizer,\n DataCollatorForSeq2Seq,\n Seq2SeqTrainer,\n Seq2SeqTrainingArguments,\n pipeline,\n)\n\nfrom datasets import load_dataset\n\nmodel_name = \"google/umt5-small\"","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:18:55.032642Z","iopub.execute_input":"2024-06-03T11:18:55.033345Z","iopub.status.idle":"2024-06-03T11:19:13.773777Z","shell.execute_reply.started":"2024-06-03T11:18:55.033313Z","shell.execute_reply":"2024-06-03T11:19:13.772989Z"},"trusted":true},"execution_count":1,"outputs":[{"name":"stderr","text":"2024-06-03 11:19:02.256736: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n2024-06-03 11:19:02.256864: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n2024-06-03 11:19:02.368948: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n","output_type":"stream"}]},{"cell_type":"code","source":"dataset = load_dataset('csv', data_files='/kaggle/input/ngl-data/nlg_data.csv', split='train').train_test_split(test_size=0.1)\ndataset","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:19:13.775364Z","iopub.execute_input":"2024-06-03T11:19:13.775904Z","iopub.status.idle":"2024-06-03T11:19:14.356839Z","shell.execute_reply.started":"2024-06-03T11:19:13.775878Z","shell.execute_reply":"2024-06-03T11:19:14.355976Z"},"trusted":true},"execution_count":2,"outputs":[{"output_type":"display_data","data":{"text/plain":"Generating train split: 0 examples [00:00, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"fdd37b65a44d42b2931bdc0db8229fa7"}},"metadata":{}},{"execution_count":2,"output_type":"execute_result","data":{"text/plain":"DatasetDict({\n train: Dataset({\n features: ['mr', 'ref'],\n num_rows: 18564\n })\n test: Dataset({\n features: ['mr', 'ref'],\n num_rows: 2063\n })\n})"},"metadata":{}}]},{"cell_type":"code","source":"tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n\ndef tokenize_samples(samples):\n inputs = [f\"generate text: {mr}\" for mr in samples[\"mr\"]]\n\n tokenized_inputs = tokenizer(\n inputs,\n max_length=128,\n padding=\"max_length\",\n truncation=True,\n )\n\n labels = tokenizer(\n text_target=samples[\"ref\"],\n max_length=128,\n padding=\"max_length\",\n truncation=True,\n )\n\n labels[\"input_ids\"] = [\n [\n (token_id if token_id != tokenizer.pad_token_id else -100)\n for token_id in label\n ]\n for label in labels[\"input_ids\"]\n ]\n\n tokenized_inputs[\"labels\"] = labels[\"input_ids\"]\n return tokenized_inputs\n\n\ntokenized_dataset = dataset.map(\n tokenize_samples,\n batched=True,\n remove_columns=[\"mr\", \"ref\"],\n)\n\ntokenized_dataset","metadata":{"execution":{"iopub.status.busy":"2024-06-03T11:19:14.357803Z","iopub.execute_input":"2024-06-03T11:19:14.358052Z","iopub.status.idle":"2024-06-03T11:19:24.614600Z","shell.execute_reply.started":"2024-06-03T11:19:14.