challenging-america-word-ga.../run.ipynb

2 lines
7.6 KiB
Plaintext
Raw Normal View History

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning GPT-2\n",
    "\n",
    "Fine-tunes GPT-2 on the `train` split, where each training line is the cleaned left context of a gap followed by the expected word.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 2080,
     "status": "ok",
     "timestamp": 1686690844437,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "EL-Tr90XTOL4"
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "import lzma\n",
    "\n",
    "import torch\n",
    "\n",
    "from transformers import TextDataset, DataCollatorForLanguageModeling\n",
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.__version__, device"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "R8PxRwwzTOL_"
   },
   "source": [
    "### Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 31,
     "status": "ok",
     "timestamp": 1686690844438,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "cOIuYlKzTOMC"
   },
   "outputs": [],
   "source": [
    "def file_iterator(file_path):\n",
    "    \"\"\"Lazily yield the lines of a UTF-8 text file, decompressing `.xz` files.\n",
    "\n",
    "    Each yielded line is a str that keeps its trailing newline. Lines are\n",
    "    streamed one at a time instead of being read into memory all at once.\n",
    "    \"\"\"\n",
    "    if file_path.endswith(\".xz\"):\n",
    "        # Binary mode + per-line decode, so lines are split on newline bytes only.\n",
    "        with lzma.open(file_path, mode=\"rb\") as fp:\n",
    "            for line in fp:\n",
    "                yield line.decode(\"utf-8\")\n",
    "    else:\n",
    "        with open(file_path, \"r\", encoding=\"utf-8\") as fp:\n",
    "            yield from fp\n",
    "\n",
    "\n",
    "def clear_line(line):\n",
    "    \"\"\"Lowercase a raw TSV field, turn literal backslash-n sequences into\n",
    "    spaces and strip surrounding newlines, tabs and spaces.\"\"\"\n",
    "    return line.lower().replace(\"\\\\n\", \" \").strip(\"\\n\\t \")\n",
    "\n",
    "\n",
    "# Tab-separated column of in.tsv that is used as the left context of the gap.\n",
    "LEFT_CONTEXT_COLUMN = 6\n",
    "\n",
    "\n",
    "def prepare_training_data(dir_path):\n",
    "    \"\"\"Write `<dir_path>/in.txt` with one training example per line and return its path.\n",
    "\n",
    "    Each example is the cleaned left context from `<dir_path>/in.tsv.xz`, a space,\n",
    "    and the lowercased, stripped word from the matching line of `<dir_path>/expected.tsv`.\n",
    "\n",
    "    Raises ValueError if the two input files have different numbers of lines.\n",
    "    \"\"\"\n",
    "    data_iter = file_iterator(dir_path + \"/in.tsv.xz\")\n",
    "    expected_iter = file_iterator(dir_path + \"/expected.tsv\")\n",
    "    new_file_path = dir_path + \"/in.txt\"\n",
    "    with open(new_file_path, \"w\", encoding=\"utf-8\") as fp:\n",
    "        # zip_longest instead of zip: a length mismatch between the two files\n",
    "        # raises instead of silently truncating the training data.\n",
    "        pairs = itertools.zip_longest(expected_iter, data_iter)\n",
    "        for line_no, (word, line) in enumerate(pairs, start=1):\n",
    "            if word is None or line is None:\n",
    "                raise ValueError(\n",
    "                    f\"{dir_path}: in.tsv.xz and expected.tsv have different \"\n",
    "                    f\"numbers of lines (first unmatched line: {line_no})\"\n",
    "                )\n",
    "            left_context = clear_line(line.split(\"\\t\")[LEFT_CONTEXT_COLUMN])\n",
    "            fp.write(left_context + \" \" + word.lower().strip() + \"\\n\")\n",
    "    return new_file_path\n",
    "\n",
    "\n",
    "def train(\n",
    "    dataset,\n",
    "    model,\n",
    "    data_collator,\n",
    "    batch_size,\n",
    "    epochs,\n",
    "    output_path,\n",
    "    overwrite_output_path=False,\n",
    "    save_steps=10000,\n",
    "):\n",
    "    \"\"\"Fine-tune `model` on `dataset` with a Hugging Face `Trainer`.\n",
    "\n",
    "    Logging and checkpointing both happen every `save_steps` steps; when training\n",
    "    finishes, the final model is saved to `output_path`.\n",
    "    \"\"\"\n",
    "    training_args = TrainingArguments(\n",
    "        output_dir=output_path,\n",
    "        overwrite_output_dir=overwrite_output_path,\n",
    "        per_device_train_batch_size=batch_size,\n",
    "        num_train_epochs=epochs,\n",
    "        logging_steps=save_steps,\n",
    "        save_steps=save_steps,\n",
    "    )\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        data_collator=data_collator,\n",
    "        train_dataset=dataset,\n",
    "    )\n",
    "    trainer.train()\n",
    "    # With no argument, save_model() writes to training_args.output_dir (= output_path).\n",
    "    trainer.save_model()\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "JEDaF5SCTOML"
   },
   "source": [
    "### Load & prepare data and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 67006,
     "status": "ok",
     "timestamp": 1686690912814,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "iecEffdcfAcv"
   },
   "outputs": [],
   "source": [
    "training_data_path = prepare_training_data(\"train\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 19,
     "status": "ok",
     "timestamp": 1686690912817,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "dqBikr2UfDbm"
   },
   "outputs": [],
   "source": [
    "MODEL_NAME = \"gpt2\"\n",
    "OUTPUT_PATH = \"results\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 27403,
     "status": "ok",
     "timestamp": 1686690940205,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "4eZ80aX9TOMN",
    "outputId": "316ac128-d4ce-4f93-a156-4bef6572b80b"
   },
   "outputs"