challenging-america-word-ga.../run.ipynb
2023-06-29 12:28:49 +02:00

2 lines
8.0 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine tuning GPT-2\n",
    "The model is fine-tuned on the reversed sequence of the gap word followed by its right context (read right to left), i.e. `f\"{word} {right_context}\".split()[::-1]`, ignoring the left context.\n",
    "\n",
    "https://gonito.net/view-variant/9580"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 2080,
     "status": "ok",
     "timestamp": 1686690844437,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "EL-Tr90XTOL4"
   },
   "outputs": [],
   "source": [
    "import lzma\n",
    "\n",
    "import torch\n",
    "from transformers import DataCollatorForLanguageModeling, TextDataset\n",
    "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
    "from transformers import PreTrainedModel, Trainer, TrainingArguments\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.__version__, device"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "R8PxRwwzTOL_"
   },
   "source": [
    "### Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 31,
     "status": "ok",
     "timestamp": 1686690844438,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "cOIuYlKzTOMC"
   },
   "outputs": [],
   "source": [
    "# Index of the in.tsv column with the text to the right of the gap.\n",
    "# NOTE(review): assumes the challenge layout where column 6 is the left\n",
    "# context and column 7 the right context - confirm against the data.\n",
    "RIGHT_CONTEXT_COLUMN = 7\n",
    "\n",
    "\n",
    "def reverse_sentence(sentence):\n",
    "    \"\"\"Return `sentence` with its whitespace-separated tokens in reverse order.\"\"\"\n",
    "    return \" \".join(reversed(sentence.split()))\n",
    "\n",
    "\n",
    "def file_iterator(file_path):\n",
    "    \"\"\"Lazily yield the UTF-8 decoded lines of a plain or `.xz` text file.\n",
    "\n",
    "    Lines are split on line feeds only and keep their terminator, so both\n",
    "    file kinds are split the same way and a stray carriage return inside a\n",
    "    record cannot shift the pairing between in.tsv and expected.tsv.\n",
    "    \"\"\"\n",
    "    opener = lzma.open if file_path.endswith(\".xz\") else open\n",
    "    with opener(file_path, mode=\"rt\", encoding=\"utf-8\", newline=\"\\n\") as fp:\n",
    "        # Stream line by line instead of loading the whole file into memory.\n",
    "        yield from fp\n",
    "\n",
    "\n",
    "def clear_line(line):\n",
    "    \"\"\"Lower-case `line`, replace literal backslash-n markers with spaces, trim it.\"\"\"\n",
    "    return line.lower().replace(\"\\\\n\", \" \").strip(\"\\n\\t \")\n",
    "\n",
    "\n",
    "def extract_right_context(line):\n",
    "    \"\"\"Return the cleaned right-context column of one tab-separated record.\"\"\"\n",
    "    return clear_line(line.split(\"\\t\")[RIGHT_CONTEXT_COLUMN])\n",
    "\n",
    "\n",
    "def prepare_training_data(dir_path):\n",
    "    \"\"\"Build the fine-tuning corpus `<dir_path>/in.txt` and return its path.\n",
    "\n",
    "    Each output line is the expected word followed by its right context with\n",
    "    the token order reversed. The expected word therefore ends the line, and\n",
    "    the model learns to predict it from the right context read right to left,\n",
    "    which is exactly the input built at inference time.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if in.tsv.xz and expected.tsv have different line counts.\n",
    "    \"\"\"\n",
    "    data_iter = file_iterator(dir_path + \"/in.tsv.xz\")\n",
    "    expected_iter = file_iterator(dir_path + \"/expected.tsv\")\n",
    "    new_file_path = dir_path + \"/in.txt\"\n",
    "    with open(new_file_path, \"w\", encoding=\"utf-8\") as fp:\n",
    "        # strict=True: misaligned inputs and labels must fail loudly.\n",
    "        for word, line in zip(expected_iter, data_iter, strict=True):\n",
    "            right_context = extract_right_context(line)\n",
    "            sample = reverse_sentence(word.lower().strip() + \" \" + right_context)\n",
    "            fp.write(sample + \"\\n\")\n",
    "    return new_file_path\n",
    "\n",
    "\n",
    "def train(\n",
    "    dataset,\n",
    "    model,\n",
    "    data_collator,\n",
    "    batch_size,\n",
    "    epochs,\n",
    "    output_path,\n",
    "    overwrite_output_path=False,\n",
    "    save_steps=10000,\n",
    "):\n",
    "    \"\"\"Fine-tune `model` on `dataset` and save the result to `output_path`.\n",
    "\n",
    "    `save_steps` sets both the checkpointing and the logging interval.\n",
    "    \"\"\"\n",
    "    training_args = TrainingArguments(\n",
    "        output_dir=output_path,\n",
    "        overwrite_output_dir=overwrite_output_path,\n",
    "        per_device_train_batch_size=batch_size,\n",
    "        num_train_epochs=epochs,\n",
    "        logging_steps=save_steps,\n",
    "        save_steps=save_steps,\n",
    "    )\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        data_collator=data_collator,\n",
    "        train_dataset=dataset,\n",
    "    )\n",
    "    trainer.train()\n",
    "    trainer.save_model()\n",
    "\n",
    "\n",
    "def top_k_next_words(model, tokenizer, text, top_k=20):\n",
    "    \"\"\"Return `(word, probability)` pairs for the `top_k` likeliest next tokens.\n",
    "\n",
    "    Words are decoded single tokens with surrounding whitespace stripped;\n",
    "    probabilities are plain floats from a softmax over the vocabulary.\n",
    "    \"\"\"\n",
    "    # An empty context would give an empty input tensor; start from BOS instead.\n",
    "    token_ids = tokenizer.encode(text) or [tokenizer.bos_token_id]\n",
    "    # GPT-2 has a fixed number of positions; keep the tokens closest to the gap.\n",
    "    token_ids = token_ids[-model.config.n_positions :]\n",
    "    inputs = torch.tensor([token_ids], device=model.device)\n",
    "    with torch.no_grad():  # prediction only, no autograd graph needed\n",
    "        next_token_logits = model(inputs).logits[0, -1]\n",
    "    probabilities = torch.softmax(next_token_logits, dim=0)\n",
    "    top_values, top_indices = probabilities.topk(top_k)\n",
    "    return [\n",
    "        (tokenizer.decode([token_id]).strip(), prob)\n",
    "        for token_id, prob in zip(top_indices.tolist(), top_values.tolist())\n",
    "    ]\n",
    "\n",
    "\n",
    "def format_prediction(candidates, max_prob=0.7):\n",
    "    \"\"\"Format `(word, probability)` pairs as one `word:prob ... :rest` line.\n",
    "\n",
    "    Candidates that would corrupt the format (empty, containing whitespace or\n",
    "    a colon) are dropped, duplicates are merged, and each probability is\n",
    "    capped at `max_prob`. The final `:rest` entry receives all remaining\n",
    "    mass, so the written values sum to 1.\n",
    "    \"\"\"\n",
    "    merged = {}\n",
    "    for word, prob in candidates:\n",
    "        if not word or \":\" in word or any(ch.isspace() for ch in word):\n",
    "            continue\n",
    "        merged[word] = merged.get(word, 0.0) + prob\n",
    "    capped = {word: min(prob, max_prob) for word, prob in merged.items()}\n",
    "    # Computed from the capped values; max() guards against float rounding.\n",
    "    remainder = max(1.0 - sum(capped.values()), 0.0)\n",
    "    entries = [f\"{word}:{prob}\" for word, prob in capped.items()]\n",
    "    entries.append(f\":{remainder}\")\n",
    "    return \" \".join(entries)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "JEDaF5SCTOML"
   },
   "source": [
    "### Load & prepare data and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 67006,
     "status": "ok",
     "timestamp": 1686690912814,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "iecEffdcfAcv"
   },
   "outputs": [],
   "source": [
    "training_data_path = prepare_training_data(\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 31,
     "status": "ok",
     "timestamp": 1686690844441,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "PV2j-D48gM1C",
    "outputId": "6d4b561b-ed46-4793-92fc-f1b749a57cef"
   },
   "outputs": [],
   "source": [
    "!head -n 5 {training_data_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 19,
     "status": "ok",
     "timestamp": 1686690912817,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "dqBikr2UfDbm"
   },
   "outputs": [],
   "source": [
    "MODEL_NAME = \"gpt2\"\n",
    "OUTPUT_PATH = \"results\"\n",
    "BLOCK_SIZE = 128  # tokens per training example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 27403,
     "status": "ok",
     "timestamp": 1686690940205,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "4eZ80aX9TOMN",
    "outputId": "316ac128-d4ce-4f93-a156-4bef6572b80b"
   },
   "outputs": [],
   "source": [
    "tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)\n",
    "tokenizer.save_pretrained(OUTPUT_PATH)\n",
    "\n",
    "train_dataset = TextDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    file_path=training_data_path,\n",
    "    block_size=BLOCK_SIZE,\n",
    ")\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
    "\n",
    "model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)\n",
    "model.save_pretrained(OUTPUT_PATH)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "A-C83WVeTOMU"
   },
   "source": [
    "### Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1686690940206,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "uJWb9P2NekZJ"
   },
   "outputs": [],
   "source": [
    "EPOCHS = 1\n",
    "BATCH_SIZE = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(\n",
    "    dataset=train_dataset,\n",
    "    model=model,\n",
    "    data_collator=data_collator,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    output_path=OUTPUT_PATH,\n",
    "    save_steps=10000,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "01QLoSptTOMY"
   },
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 26954,
     "status": "ok",
     "timestamp": 1686691110863,
     "user": {
      "displayName": "Mateusz Ogrodowczyk",
      "userId": "07738898985493819606"
     },
     "user_tz": -120
    },
    "id": "wglgCZ5enrFN",
    "outputId": "a8737e64-0fdd-40c0-9000-31b7058159c5"
   },
   "outputs": [],
   "source": [
    "# To score a saved checkpoint instead of the model trained above, load it first:\n",
    "# model = GPT2LMHeadModel.from_pretrained(\"results/checkpoint-48000/\").to(device)\n",
    "\n",
    "# Trainer leaves the model in training mode; switch dropout off for prediction.\n",
    "model.eval()\n",
    "\n",
    "TOP_K = 20  # candidate tokens written per input line\n",
    "MAX_PROB = 0.7  # upper bound for a single candidate's probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dqYwJTcynNES"
   },
   "outputs": [],
   "source": [
    "for file_path in (\"test/in.tsv.xz\", \"dev/in.tsv.xz\"):\n",
    "    out_path = file_path.split(\"/\")[0] + \"/out.tsv\"\n",
    "    with open(out_path, \"w\", encoding=\"utf-8\") as fp:\n",
    "        for line in file_iterator(file_path):\n",
    "            # Same representation as in training: right context read right to left.\n",
    "            context = reverse_sentence(extract_right_context(line))\n",
    "            candidates = top_k_next_words(model, tokenizer, context, top_k=TOP_K)\n",
    "            fp.write(format_prediction(candidates, max_prob=MAX_PROB) + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 0
}