2 lines
7.6 KiB
Plaintext
2 lines
7.6 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning GPT-2\n",
    "\n",
    "Fine-tunes the pretrained `gpt2` model on the `train` split (each cleaned left context followed by its expected word), then writes to `test-A/out.tsv` and `dev-0/out.tsv` a probability distribution over the word that follows each left context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EL-Tr90XTOL4"
   },
   "outputs": [],
   "source": [
    "import lzma\n",
    "\n",
    "import torch\n",
    "\n",
    "from transformers import TextDataset, DataCollatorForLanguageModeling\n",
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.__version__, device"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "R8PxRwwzTOL_"
   },
   "source": [
    "### Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cOIuYlKzTOMC"
   },
   "outputs": [],
   "source": [
    "def file_iterator(file_path):\n",
    "    \"\"\"Yield the lines of a UTF-8 text file, decompressing `.xz` files on the fly.\n",
    "\n",
    "    Lines are streamed one at a time (trailing newline included) instead of\n",
    "    reading the whole file into memory first.\n",
    "    \"\"\"\n",
    "    if file_path.endswith(\".xz\"):\n",
    "        # Binary mode splits only on newline bytes, so a stray carriage return\n",
    "        # inside a TSV field cannot break one record into several lines.\n",
    "        with lzma.open(file_path, mode=\"rb\") as fp:\n",
    "            for line in fp:\n",
    "                yield line.decode(\"utf-8\")\n",
    "    else:\n",
    "        with open(file_path, \"r\", encoding=\"utf-8\") as fp:\n",
    "            yield from fp\n",
    "\n",
    "\n",
    "def clear_line(line):\n",
    "    \"\"\"Lowercase `line`, replace literal backslash-n sequences with spaces and\n",
    "    strip surrounding newlines, tabs and spaces.\"\"\"\n",
    "    return line.lower().replace(\"\\\\n\", \" \").strip(\"\\n\\t \")\n",
    "\n",
    "\n",
    "def prepare_training_data(dir_path):\n",
    "    \"\"\"Write a plain-text training corpus for `TextDataset` and return its path.\n",
    "\n",
    "    Each output line is the cleaned left context (tab-separated field at index 6\n",
    "    of `<dir_path>/in.tsv.xz`) followed by the lowercased expected word from\n",
    "    `<dir_path>/expected.tsv`. The corpus is written to `<dir_path>/in.txt`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two input files have different numbers of lines.\n",
    "    \"\"\"\n",
    "    data_iter = file_iterator(dir_path + \"/in.tsv.xz\")\n",
    "    expected_iter = file_iterator(dir_path + \"/expected.tsv\")\n",
    "    new_file_path = dir_path + \"/in.txt\"\n",
    "    with open(new_file_path, \"w\", encoding=\"utf-8\") as fp:\n",
    "        # strict=True: a line-count mismatch would otherwise silently truncate\n",
    "        # the corpus.\n",
    "        for word, line in zip(expected_iter, data_iter, strict=True):\n",
    "            left_context = clear_line(line.split(\"\\t\")[6])\n",
    "            text = left_context + \" \" + word.lower().strip() + \"\\n\"\n",
    "            fp.write(text)\n",
    "    return new_file_path\n",
    "\n",
    "\n",
    "def train(\n",
    "    dataset,\n",
    "    model,\n",
    "    data_collator,\n",
    "    batch_size,\n",
    "    epochs,\n",
    "    output_path,\n",
    "    overwrite_output_path=False,\n",
    "    save_steps=10000,\n",
    "):\n",
    "    \"\"\"Fine-tune `model` on `dataset` with the Hugging Face `Trainer`.\n",
    "\n",
    "    Loss is logged and a checkpoint is written every `save_steps` steps; the\n",
    "    final model is saved to `output_path`.\n",
    "    \"\"\"\n",
    "    training_args = TrainingArguments(\n",
    "        output_dir=output_path,\n",
    "        overwrite_output_dir=overwrite_output_path,\n",
    "        per_device_train_batch_size=batch_size,\n",
    "        num_train_epochs=epochs,\n",
    "        logging_steps=save_steps,\n",
    "        save_steps=save_steps,\n",
    "    )\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        data_collator=data_collator,\n",
    "        train_dataset=dataset,\n",
    "    )\n",
    "    trainer.train()\n",
    "    trainer.save_model()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "JEDaF5SCTOML"
   },
   "source": [
    "### Load & prepare data and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iecEffdcfAcv"
   },
   "outputs": [],
   "source": [
    "training_data_path = prepare_training_data(\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dqBikr2UfDbm"
   },
   "outputs": [],
   "source": [
    "MODEL_NAME = \"gpt2\"\n",
    "OUTPUT_PATH = \"results\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4eZ80aX9TOMN"
   },
   "outputs": [],
   "source": [
    "tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)\n",
    "tokenizer.save_pretrained(OUTPUT_PATH)\n",
    "\n",
    "# NOTE: TextDataset is deprecated in newer `transformers` releases; the\n",
    "# `datasets` library is the suggested replacement.\n",
    "train_dataset = TextDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    file_path=training_data_path,\n",
    "    block_size=128,\n",
    ")\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
    "\n",
    "model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)\n",
    "model.save_pretrained(OUTPUT_PATH)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "A-C83WVeTOMU"
   },
   "source": [
    "### Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uJWb9P2NekZJ"
   },
   "outputs": [],
   "source": [
    "EPOCHS = 1\n",
    "BATCH_SIZE = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(\n",
    "    dataset=train_dataset,\n",
    "    model=model,\n",
    "    data_collator=data_collator,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    output_path=OUTPUT_PATH,\n",
    "    save_steps=10000,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "01QLoSptTOMY"
   },
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_PROB = 0.7  # cap on any single word's probability in the output\n",
    "TOP_K = 20  # number of candidate tokens considered per line\n",
    "# Distribution written for lines whose prediction is too short (see below).\n",
    "# NOTE(review): it contains an empty-word entry (\" :0.0149...\"), which\n",
    "# predict_distribution() deliberately never emits.\n",
    "FALLBACK_RESULT = \"the:0.5175086259841919 and:0.12364283204078674 ,:0.05142376944422722 of:0.03426751121878624 .:0.028525719419121742 or:0.02097073383629322 :0.014924607239663601 every:0.008976494893431664 each:0.008128014393150806 a:0.007482781074941158 ;:0.005168373696506023 -:0.004823171999305487 holy:0.004624966997653246 one:0.004140088334679604 tho:0.003332334803417325 only:0.0030411879997700453 that:0.002834469312801957 !:0.0022952412255108356 ):0.002251386409625411 t:0.0021530792582780123 :0.14948463439941406\\n\"\n",
    "\n",
    "\n",
    "def predict_distribution(left_context):\n",
    "    \"\"\"Return one output line for `left_context`.\n",
    "\n",
    "    The line holds `word:prob` pairs for the TOP_K most probable next tokens\n",
    "    (each probability capped at MAX_PROB), followed by ` :remainder`. The\n",
    "    remainder is computed from the capped values, so everything written\n",
    "    sums to 1.\n",
    "    \"\"\"\n",
    "    token_ids = tokenizer.encode(left_context)\n",
    "    if not token_ids:\n",
    "        # The model needs at least one token; condition on end-of-text instead.\n",
    "        token_ids = [tokenizer.eos_token_id]\n",
    "    # GPT-2 attends to at most n_positions tokens; keep the tail of the\n",
    "    # context, which is closest to the predicted word.\n",
    "    token_ids = token_ids[-model.config.n_positions:]\n",
    "    inputs = torch.tensor([token_ids], device=device)\n",
    "\n",
    "    with torch.no_grad():  # inference only: don't build an autograd graph\n",
    "        logits = model(inputs).logits[0, -1]\n",
    "    top_k_values, top_k_indices = torch.softmax(logits, dim=0).topk(TOP_K)\n",
    "\n",
    "    pairs = []\n",
    "    total = 0.0\n",
    "    for prob, idx in zip(top_k_values.tolist(), top_k_indices.tolist()):\n",
    "        word = tokenizer.decode(idx).strip()\n",
    "        if not word:\n",
    "            # An empty word would be indistinguishable from the trailing\n",
    "            # remainder entry; leave its mass in the remainder instead.\n",
    "            continue\n",
    "        # The capped value (not the raw one) must count toward the total,\n",
    "        # otherwise the written distribution sums to more than 1.\n",
    "        prob = min(prob, MAX_PROB)\n",
    "        total += prob\n",
    "        pairs.append(f\"{word}:{prob}\")\n",
    "    return \" \".join(pairs) + f\" :{max(0.0, 1. - total)}\\n\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dqYwJTcynNES"
   },
   "outputs": [],
   "source": [
    "model.eval()  # disable dropout; the model may still be in training mode after Trainer.train()\n",
    "\n",
    "for file_path, lines_no in ((\"test-A/in.tsv.xz\", 7414), (\"dev-0/in.tsv.xz\", 10519)):\n",
    "    with open(file_path.split(\"/\")[0] + \"/out.tsv\", \"w\", encoding=\"utf-8\") as fp:\n",
    "        print(f'Working on file: {file_path}...')\n",
    "        missed_lines = []\n",
    "        # lines_no is only used for the progress display.\n",
    "        for i, line in enumerate(file_iterator(file_path), start=1):\n",
    "            print(f'\\r\\t{100.0*i/lines_no:.2f}% ({i}/{lines_no})', end='')\n",
    "            result = predict_distribution(clear_line(line.split(\"\\t\")[6]))\n",
    "            # NOTE(review): the 250-character threshold is a heuristic for\n",
    "            # \"prediction too short\"; such lines get FALLBACK_RESULT instead.\n",
    "            if len(result) < 250:\n",
    "                missed_lines.append(i)\n",
    "                result = FALLBACK_RESULT\n",
    "            fp.write(result)\n",
    "        print(\"\\t...processing finished\\n\\tMissed lines:\", missed_lines)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
|