{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{},"source":["# Fine tuning GPT-2\n","Model dotrenowano z wykorzystaniem odwróconego ciągu tokenów i odgadywanych słów (od prawej do lewej) – `f\"{word} {right_context}\".split()[::-1]]` – ignorując lewy kontekst.\n","\n","https://gonito.net/view-variant/9580"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":2080,"status":"ok","timestamp":1686690844437,"user":{"displayName":"Mateusz Ogrodowczyk","userId":"07738898985493819606"},"user_tz":-120},"id":"EL-Tr90XTOL4"},"outputs":[],"source":["import torch\n","\n","from transformers import TextDataset, DataCollatorForLanguageModeling\n","from transformers import GPT2Tokenizer, GPT2LMHeadModel\n","from transformers import Trainer, TrainingArguments, PreTrainedModel\n","\n","import lzma\n","\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["torch.__version__, device"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"R8PxRwwzTOL_"},"source":["### Methods"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":31,"status":"ok","timestamp":1686690844438,"user":{"displayName":"Mateusz Ogrodowczyk","userId":"07738898985493819606"},"user_tz":-120},"id":"cOIuYlKzTOMC"},"outputs":[],"source":["def reverse_sentence(sentence):\n"," return \" \".join(sentence.split()[::-1])\n","\n","\n","def file_iterator(file_path):\n"," print(file_path, file_path.endswith(\".xz\"))\n"," if file_path.endswith(\".xz\"):\n"," with lzma.open(file_path, mode=\"r\") as fp:\n"," for line in fp.readlines():\n"," yield line.decode(\"utf-8\")\n"," else:\n"," with open(file_path, \"r\", encoding=\"utf-8\") as fp:\n"," for line in fp.readlines():\n"," yield line\n","\n","\n","def clear_line(line):\n"," return line.lower().replace(\"\\\\n\", \" \").strip(\"\\n\\t \")\n","\n","\n","def prepare_training_data(dir_path):\n"," data_iter = file_iterator(dir_path + \"/in.tsv.xz\")\n"," expected_iter = file_iterator(dir_path + \"/expected.tsv\")\n"," new_file_path = dir_path + \"/in.txt\" \n"," with open(new_file_path, \"w\", encoding=\"utf-8\") as fp:\n"," for word, line in zip(expected_iter, data_iter):\n"," left_context = clear_line(line.split(\"\\t\")[6])\n"," text = left_context + \" \" + word.lower().strip() + \"\\n\"\n"," fp.write(text)\n"," return new_file_path\n","\n","\n","def train(\n"," dataset,\n"," model,\n"," data_collator,\n"," batch_size,\n"," epochs,\n"," output_path,\n"," overwrite_output_path=False,\n"," save_steps=10000,\n","):\n"," training_args = TrainingArguments(\n"," output_dir=output_path,\n"," overwrite_output_dir=overwrite_output_path,\n"," per_device_train_batch_size=batch_size,\n"," num_train_epochs=epochs,\n"," logging_steps=save_steps,\n"," save_steps=save_steps,\n"," )\n"," trainer = Trainer(\n"," model=model,\n"," args=training_args,\n"," data_collator=data_collator,\n"," train_dataset=dataset,\n"," )\n"," trainer.train()\n"," trainer.save_model()\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"JEDaF5SCTOML"},"source":["### Load & prepare data and model"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":31,"status":"ok","timestamp":1686690844441,"user":{"displayName":"Mateusz Ogrodowczyk","userId":"07738898985493819606"},"user_tz":-120},"id":"PV2j-D48gM1C","outputId":"6d4b561b-ed46-4793-92fc-f1b749a57cef"},"outputs":[],"source":["!cat train/in.txt | head -n 5"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":67006,"status":"ok","timestamp":1686690912814,"user":{"displayName":"Mateusz Ogrodowczyk","userId":"07738898985493819606"},"user_tz":-120},"id":"iecEffdcfAcv"},"outputs":[],"source":["training_data_path = prepare_training_data(\"train\")\n"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":19,"status":"ok","timestamp":1686690912817,"user":{"displayName":"Mateusz Ogrodowczyk","userId":"07738898985493819606"},"user_tz":-120},"id":"dqBikr2UfDbm"},"outputs":[],"source":["MODEL_NAME = \"gpt2\"\n","OUTPUT_PATH = \"results\"\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":27403,"status":"ok","timestamp":1686690940205,"user":{"displayName":"Mateusz Ogrodowczyk","userId":"07738898985493819606"},"user_tz":-120},"id":"4eZ80aX9TOMN","outputId":"316ac128-d4ce-4f93-a156-4bef6572b80b"},"outputs":[],"source":["tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)\n","tokenizer.save_pretrained(OUTPUT_PATH)\n","\n","train_dataset = TextDataset(\n"," tokenizer=tokenizer,\n"," file_path=training_data_path,\n"," block_size=128,\n",")\n","data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n","\n","model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)\n","model.save_pretrained(OUTPUT_PATH)\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"A-C83WVeTOMU"},"source":["### Train model"]},{"cell_type":"code","execution_count":null,"metadata":{"executionInfo":{"elapsed":13,"status":"ok","timestamp":1686690940206,"user":{"displayName":"Mateusz Ogrodowczyk","userId":"07738898985493819606"},"user_tz":-120},"id":"uJWb9P2NekZJ"},"outputs":[],"source":["EPOCHS = 1\n","BATCH_SIZE = 32\n"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["train(\n"," dataset=train_dataset,\n"," model=model,\n"," data_collator=data_collator,\n"," batch_size=BATCH_SIZE,\n"," epochs=EPOCHS,\n"," output_path=OUTPUT_PATH,\n"," save_steps=10000\n",")\n"]},{"attachments":{},"cell_type":"markdown","metadata":{"id":"01QLoSptTOMY"},"source":["### Inference"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":26954,"status":"ok","timestamp":1686691110863,"user":{"displayName":"Mateusz Ogrodowczyk","userId":"07738898985493819606"},"user_tz":-120},"id":"wglgCZ5enrFN","outputId":"a8737e64-0fdd-40c0-9000-31b7058159c5"},"outputs":[],"source":["# model = GPT2LMHeadModel.from_pretrained('results/checkpoint-48000/')\n","# model\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dqYwJTcynNES"},"outputs":[],"source":["for file_path in (\"test/in.tsv.xz\", \"dev/in.tsv.xz\"):\n"," with open(file_path.split(\"/\")[0] + \"/out.tsv\", \"w\", encoding=\"utf-8\") as fp:\n"," for line in file_iterator(file_path):\n"," line = reverse_sentence(line.lower().strip(\"\\n\").replace(\"\\\\n\", \" \"))\n"," inputs = tokenizer.encode(line, return_tensors=\"pt\").to(device)\n"," output = model(inputs)\n","\n"," z_dist = output[0][0][-1]\n"," prob_dist = torch.softmax(z_dist, dim=0)\n"," top_k_values, top_k_indices = prob_dist.topk(20)\n","\n"," remainder = 1\n"," result = \"\"\n"," probs = []\n"," result = [\n"," (\n"," tokenizer.decode(idx).strip(),\n"," probs.append(prob) or prob if prob <= 0.7 else 0.7,\n"," )\n"," for prob, idx in zip(top_k_values, top_k_indices)\n"," ]\n"," result = (\n"," \" \".join(f\"{pair[0]}:{pair[1]}\" for pair in result)\n"," + f\" :{1. - sum(probs)}\\n\"\n"," )\n"," fp.write(result)\n"]}],"metadata":{"colab":{"provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.6"},"orig_nbformat":4},"nbformat":4,"nbformat_minor":0}