{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Pretrenowanie modeli\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "System AlphaZero uczy się grając sam ze sobą — wystarczy 24 godziny,\n", "by system nauczył się grać w szachy lub go na nadludzkim poziomie.\n", "\n", "**Pytanie**: Dlaczego granie samemu ze sobą nie jest dobrym sposobem\n", " nauczenia się grania w szachy dla człowieka, a dla maszyny jest?\n", "\n", "Co jest odpowiednikiem grania samemu ze sobą w świecie przetwarzania tekstu?\n", "Tzn. **pretrenowanie** (*pretraining*) na dużym korpusie tekstu. (Tekst jest tani!)\n", "\n", "Jest kilka sposobów na pretrenowanie modelu, w każdym razie sprowadza\n", "się do odgadywania następnego bądź zamaskowanego słowa.\n", "W każdym razie zawsze stosujemy softmax (być może ze „sztuczkami” takimi jak\n", "negatywne próbkowanie albo hierarchiczny softamx) na pewnej **representecji kontekstowej**:\n", "\n", "$$\\vec{p} = \\operatorname{softmax}(f(\\vec{c})).$$\n", "\n", "Model jest karany używając funkcji log loss:\n", "\n", "$$-\\log(p_j),$$\n", "\n", "gdzie $w_j$ jest wyrazem, który pojawił się rzeczywiście w korpusie.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Przewidywanie słowa (GPT-2)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Jeden ze sposobów pretrenowania modelu to po prostu przewidywanie\n", "następnego słowa.\n", "\n", "Zainstalujmy najpierw bibliotekę transformers.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "! pip install transformers" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "50257\n" ] }, { "data": { "text/plain": [ "[('Ġon', 0.6786560416221619),\n", " ('Ġupon', 0.04339785501360893),\n", " ('Ġheavily', 0.02208443358540535),\n", " ('Ġin', 0.021049050614237785),\n", " (',', 0.020188499242067337),\n", " ('Ġa', 0.01833895780146122),\n", " ('Ġvery', 0.017935041338205338),\n", " ('Ġentirely', 0.017528969794511795),\n", " ('Ġlargely', 0.016769640147686005),\n", " ('Ġto', 0.01009418722242117),\n", " ('Ġgreatly', 0.010009866207838058),\n", " ('Ġnot', 0.009016563184559345),\n", " ('Ġmore', 0.005853226874023676),\n", " ('Ġprimarily', 0.005203146021813154),\n", " ('Ġstrongly', 0.0034501152113080025),\n", " ('Ġpartly', 0.0033184229396283627),\n", " ('Ġmuch', 0.0033095215912908316),\n", " ('Ġmostly', 0.0032150144688785076),\n", " ('Ġmainly', 0.0030899408739060163),\n", " ('Ġfor', 0.003034428460523486),\n", " ('.', 0.0028878094162791967),\n", " ('Ġboth', 0.0028405177872627974),\n", " ('Ġsomewhat', 0.0028194624464958906),\n", " ('Ġcru', 0.002263976726680994),\n", " ('Ġas', 0.00221616611815989),\n", " ('Ġof', 0.0022000609897077084),\n", " ('Ġalmost', 0.001968063646927476),\n", " ('Ġat', 0.0018015997484326363),\n", " ('Ġhighly', 0.0017461496172472835),\n", " ('Ġcompletely', 0.001692073536105454)]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", "tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')\n", "model = GPT2LMHeadModel.from_pretrained('gpt2-large')\n", "text = \"This issue depends\"\n", "encoded_input = tokenizer(text, return_tensors='pt')\n", "output = model(**encoded_input)\n", "next_token_probs = torch.softmax(output[0][:, -1, :][0], dim=0)\n", "\n", "next_token_probs\n", "nb_of_tokens = next_token_probs.size()[0]\n", "print(nb_of_tokens)\n", "\n", "_, top_k_indices = torch.topk(next_token_probs, 30, sorted=True)\n", "\n", "words = tokenizer.convert_ids_to_tokens(top_k_indices)\n", "\n", "top_probs = []\n", "\n", "for ix in range(len(top_k_indices)):\n", " top_probs.append((words[ix], next_token_probs[top_k_indices[ix]].item()))\n", "\n", "top_probs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Zalety tego podejścia:\n", "\n", "- prostota,\n", "- dobra podstawa do strojenia systemów generowania tekstu zwłaszcza\n", " „otwartego” (systemy dialogowe, generowanie (fake) newsów, streszczanie tekstu),\n", " ale niekoniecznie tłumaczenia maszynowego,\n", "- zaskakująca skuteczność przy uczeniu *few-shot* i *zero-shot*.\n", "\n", "Wady:\n", "\n", "- asymetryczność, przetwarzanie tylko z lewej do prawej, preferencja\n", " dla lewego kontekstu,\n", "- mniejsza skuteczność przy dostrajaniu do zadań klasyfikacji i innych zadań\n", " niepolegających na prostym generowaniu.\n", "\n", "Przykłady modeli: GPT, GPT-2, GPT-3, DialoGPT.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Maskowanie słów (BERT)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inną metodą jest maskowanie słów (*Masked Language Modeling*, *MLM*).\n", "\n", "W tym podejściu losowe wybrane zastępujemy losowe słowa specjalnym\n", "tokenem (`[MASK]`) i każemy modelowi odgadywać w ten sposób\n", "zamaskowane słowa (z uwzględnieniem również prawego kontekstu!).\n", "\n", "Móciąc ściśle, w jednym z pierwszych modeli tego typu (BERT)\n", "zastosowano schemat, w którym również niezamaskowane słowa są odgadywane (!):\n", "\n", "- wybieramy losowe 15% wyrazów do odgadnięcia\n", "- 80% z nich zastępujemy tokenem `[MASK]`,\n", "- 10% zastępujemy innym losowym wyrazem,\n", "- 10% pozostawiamy bez zmian.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# Out[3]:" ] } ], "source": [ "from transformers import AutoModelWithLMHead, AutoTokenizer\n", "import torch\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-large\")\n", "model = AutoModelWithLMHead.from_pretrained(\"xlm-roberta-large\")\n", "\n", "sequence = f'II wojna światowa zakończyła się w {tokenizer.mask_token} roku.'\n", "\n", "input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n", "mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n", "\n", "token_logits = model(input_ids)[0]\n", "mask_token_logits = token_logits[0, mask_token_index, :]\n", "mask_token_logits = torch.softmax(mask_token_logits, dim=1)\n", "\n", "top_10 = torch.topk(mask_token_logits, 10, dim=1)\n", "top_10_tokens = zip(top_10.indices[0].tolist(), top_10.values[0].tolist())\n", "\n", "for token, score in top_10_tokens:\n", " print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])), f\"(score: {score})\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Przykłady: BERT, RoBERTa (również Polish RoBERTa).\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Podejście generatywne (koder-dekoder).\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "System ma wygenerować odpowiedź na różne pytania (również\n", "odpowiadające zadaniu MLM), np.:\n", "\n", "- \"translate English to German: That is good.\" => \"Das ist gut.\"\n", "- \"cola sentence: The course is jumping well.\" => \"not acceptable\"\n", "- \"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi…\"\n", " => \"six people hospitalized after a storm in attala county\"\n", "- \"Thank you for me to your party week.\" => for inviting last \n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n", "\n", "T5_PATH = 't5-base'\n", "\n", "t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)\n", "t5_config = T5Config.from_pretrained(T5_PATH)\n", "t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config)\n", "\n", "slot = ''\n", "\n", "text = f'Warsaw is the {slot} of Poland.'\n", "\n", "encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n", "input_ids = encoded['input_ids']\n", "\n", "outputs = t5_mlm.generate(input_ids=input_ids,\n", " num_beams=200, num_return_sequences=5,\n", " max_length=5)\n", "\n", "_0_index = text.index(slot)\n", "_result_prefix = text[:_0_index]\n", "_result_suffix = text[_0_index+len(slot):]\n", "\n", "def _filter(output, end_token=''):\n", " _txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n", " if end_token in _txt:\n", " _end_token_index = _txt.index(end_token)\n", " return _result_prefix + _txt[:_end_token_index] + _result_suffix\n", " else:\n", " return _result_prefix + _txt + _result_suffix\n", "\n", "\n", "results = [_filter(out) for out in outputs]\n", "results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Zob. [https://arxiv.org/pdf/1910.10683.pdf](https://arxiv.org/pdf/1910.10683.pdf))\n", "\n", "Przykład: T5, mT5\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" }, "org": null }, "nbformat": 4, "nbformat_minor": 1 }