From 78fb510cba3fae1b53c3d47a071013ab0b32112d Mon Sep 17 00:00:00 2001
From: Filip Gralinski
Date: Mon, 14 Jun 2021 15:39:15 +0200
Subject: [PATCH] 14

---
 wyk/14_pretrenowanie.ipynb | 338 +++++++++++++++++++++++++++++++++++++
 wyk/14_pretrenowanie.org   |  69 +++++++-
 2 files changed, 401 insertions(+), 6 deletions(-)
 create mode 100644 wyk/14_pretrenowanie.ipynb

diff --git a/wyk/14_pretrenowanie.ipynb b/wyk/14_pretrenowanie.ipynb
new file mode 100644
index 0000000..ae7721a
--- /dev/null
+++ b/wyk/14_pretrenowanie.ipynb
@@ -0,0 +1,338 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Pretraining models\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The AlphaZero system learns by playing against itself: 24 hours of\n",
+    "self-play are enough for the system to learn to play chess or Go at a\n",
+    "superhuman level.\n",
+    "\n",
+    "**Question**: Why is playing against yourself not a good way for a human\n",
+    " to learn chess, while for a machine it is?\n",
+    "\n",
+    "What is the counterpart of self-play in the world of text processing?\n",
+    "It is **pretraining** on a large corpus of text. (Text is cheap!)\n",
+    "\n",
+    "There are several ways to pretrain a model, but they all come down to\n",
+    "guessing the next word or a masked word. In every case we apply a\n",
+    "softmax (possibly with \"tricks\" such as negative sampling or\n",
+    "hierarchical softmax) to some **contextual representation**:\n",
+    "\n",
+    "$$\\\\vec{p} = \\\\operatorname{softmax}(f(\\\\vec{c})).$$\n",
+    "\n",
+    "The model is penalized using the log-loss function:\n",
+    "\n",
+    "$$-\\\\log(p_j),$$\n",
+    "\n",
+    "where $j$ is the index of the word that actually occurred in the corpus.\n",
+    "\n"
+   ]
+  },
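+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Below is a minimal numerical sketch of this setup (an illustration added\n",
+    "here, not part of the original lecture code): a toy contextual\n",
+    "representation is mapped to logits by a linear scoring function, softmax\n",
+    "turns the logits into a distribution, and the log loss is computed for\n",
+    "the word that actually occurred. The vocabulary, the weights and the\n",
+    "index are made-up values.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "# Toy setup: a vocabulary of 4 words and a 3-dimensional contextual\n",
+    "# representation; W and c are made-up illustration values.\n",
+    "vocab = ['the', 'cat', 'sat', 'mat']\n",
+    "c = torch.tensor([0.5, -1.0, 2.0])   # contextual representation c\n",
+    "W = torch.randn(len(vocab), 3)       # f(c) = Wc, a linear scoring layer\n",
+    "\n",
+    "p = torch.softmax(W @ c, dim=0)      # p = softmax(f(c))\n",
+    "j = vocab.index('cat')               # suppose 'cat' actually occurred\n",
+    "loss = -torch.log(p[j])              # log loss: -log(p_j)\n",
+    "print(p, loss.item())"
+   ]
+  },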
pip install transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "50257\n" + ] + }, + { + "data": { + "text/plain": [ + "[('Ġon', 0.6786560416221619),\n", + " ('Ġupon', 0.04339785501360893),\n", + " ('Ġheavily', 0.02208443358540535),\n", + " ('Ġin', 0.021049050614237785),\n", + " (',', 0.020188499242067337),\n", + " ('Ġa', 0.01833895780146122),\n", + " ('Ġvery', 0.017935041338205338),\n", + " ('Ġentirely', 0.017528969794511795),\n", + " ('Ġlargely', 0.016769640147686005),\n", + " ('Ġto', 0.01009418722242117),\n", + " ('Ġgreatly', 0.010009866207838058),\n", + " ('Ġnot', 0.009016563184559345),\n", + " ('Ġmore', 0.005853226874023676),\n", + " ('Ġprimarily', 0.005203146021813154),\n", + " ('Ġstrongly', 0.0034501152113080025),\n", + " ('Ġpartly', 0.0033184229396283627),\n", + " ('Ġmuch', 0.0033095215912908316),\n", + " ('Ġmostly', 0.0032150144688785076),\n", + " ('Ġmainly', 0.0030899408739060163),\n", + " ('Ġfor', 0.003034428460523486),\n", + " ('.', 0.0028878094162791967),\n", + " ('Ġboth', 0.0028405177872627974),\n", + " ('Ġsomewhat', 0.0028194624464958906),\n", + " ('Ġcru', 0.002263976726680994),\n", + " ('Ġas', 0.00221616611815989),\n", + " ('Ġof', 0.0022000609897077084),\n", + " ('Ġalmost', 0.001968063646927476),\n", + " ('Ġat', 0.0018015997484326363),\n", + " ('Ġhighly', 0.0017461496172472835),\n", + " ('Ġcompletely', 0.001692073536105454)]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", + "tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')\n", + "model = GPT2LMHeadModel.from_pretrained('gpt2-large')\n", + "text = \"This issue depends\"\n", + "encoded_input = tokenizer(text, return_tensors='pt')\n", + "output = model(**encoded_input)\n", + "next_token_probs = torch.softmax(output[0][:, -1, :][0], dim=0)\n", + "\n", + "next_token_probs\n", + "nb_of_tokens = next_token_probs.size()[0]\n", + "print(nb_of_tokens)\n", + "\n", + "_, top_k_indices = torch.topk(next_token_probs, 30, sorted=True)\n", + "\n", + "words = tokenizer.convert_ids_to_tokens(top_k_indices)\n", + "\n", + "top_probs = []\n", + "\n", + "for ix in range(len(top_k_indices)):\n", + " top_probs.append((words[ix], next_token_probs[top_k_indices[ix]].item()))\n", + "\n", + "top_probs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Zalety tego podejścia:\n", + "\n", + "- prostota,\n", + "- dobra podstawa do strojenia systemów generowania tekstu zwłaszcza\n", + " „otwartego” (systemy dialogowe, generowanie (fake) newsów, streszczanie tekstu),\n", + " ale niekoniecznie tłumaczenia maszynowego,\n", + "- zaskakująca skuteczność przy uczeniu *few-shot* i *zero-shot*.\n", + "\n", + "Wady:\n", + "\n", + "- asymetryczność, przetwarzanie tylko z lewej do prawej, preferencja\n", + " dla lewego kontekstu,\n", + "- mniejsza skuteczność przy dostrajaniu do zadań klasyfikacji i innych zadań\n", + " niepolegających na prostym generowaniu.\n", + "\n", + "Przykłady modeli: GPT, GPT-2, GPT-3, DialoGPT.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Maskowanie słów (BERT)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Inną metodą jest maskowanie słów (*Masked Language Modeling*, *MLM*).\n", + "\n", + "W tym podejściu losowe wybrane zastępujemy losowe słowa specjalnym\n", + 
"tokenem (`[MASK]`) i każemy modelowi odgadywać w ten sposób\n", + "zamaskowane słowa (z uwzględnieniem również prawego kontekstu!).\n", + "\n", + "Móciąc ściśle, w jednym z pierwszych modeli tego typu (BERT)\n", + "zastosowano schemat, w którym również niezamaskowane słowa są odgadywane (!):\n", + "\n", + "- wybieramy losowe 15% wyrazów do odgadnięcia\n", + "- 80% z nich zastępujemy tokenem `[MASK]`,\n", + "- 10% zastępujemy innym losowym wyrazem,\n", + "- 10% pozostawiamy bez zmian.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Out[3]:" + ] + } + ], + "source": [ + "from transformers import AutoModelWithLMHead, AutoTokenizer\n", + "import torch\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-large\")\n", + "model = AutoModelWithLMHead.from_pretrained(\"xlm-roberta-large\")\n", + "\n", + "sequence = f'II wojna światowa zakończyła się w {tokenizer.mask_token} roku.'\n", + "\n", + "input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n", + "mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n", + "\n", + "token_logits = model(input_ids)[0]\n", + "mask_token_logits = token_logits[0, mask_token_index, :]\n", + "mask_token_logits = torch.softmax(mask_token_logits, dim=1)\n", + "\n", + "top_10 = torch.topk(mask_token_logits, 10, dim=1)\n", + "top_10_tokens = zip(top_10.indices[0].tolist(), top_10.values[0].tolist())\n", + "\n", + "for token, score in top_10_tokens:\n", + " print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])), f\"(score: {score})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Przykłady: BERT, RoBERTa (również Polish RoBERTa).\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Podejście generatywne (koder-dekoder).\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "System ma wygenerować odpowiedź na różne pytania (również\n", + "odpowiadające zadaniu MLM), np.:\n", + "\n", + "- \"translate English to German: That is good.\" => \"Das ist gut.\"\n", + "- \"cola sentence: The course is jumping well.\" => \"not acceptable\"\n", + "- \"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi…\"\n", + " => \"six people hospitalized after a storm in attala county\"\n", + "- \"Thank you for me to your party week.\" => for inviting last \n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n", + "\n", + "T5_PATH = 't5-base'\n", + "\n", + "t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)\n", + "t5_config = T5Config.from_pretrained(T5_PATH)\n", + "t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config)\n", + "\n", + "slot = ''\n", + "\n", + "text = f'Warsaw is the {slot} of Poland.'\n", + "\n", + "encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n", + "input_ids = encoded['input_ids']\n", + "\n", + "outputs = t5_mlm.generate(input_ids=input_ids,\n", + " num_beams=200, num_return_sequences=5,\n", + " max_length=5)\n", + "\n", + "_0_index = text.index(slot)\n", + "_result_prefix = text[:_0_index]\n", + "_result_suffix = text[_0_index+len(slot):]\n", + "\n", + "def _filter(output, end_token=''):\n", + " _txt = 
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoModelForMaskedLM, AutoTokenizer\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-large\")\n",
+    "model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
+    "\n",
+    "# Polish for: 'World War II ended in the year <mask>.'\n",
+    "sequence = f'II wojna światowa zakończyła się w {tokenizer.mask_token} roku.'\n",
+    "\n",
+    "input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n",
+    "mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n",
+    "\n",
+    "token_logits = model(input_ids)[0]\n",
+    "mask_token_logits = token_logits[0, mask_token_index, :]\n",
+    "mask_token_logits = torch.softmax(mask_token_logits, dim=1)\n",
+    "\n",
+    "top_10 = torch.topk(mask_token_logits, 10, dim=1)\n",
+    "top_10_tokens = zip(top_10.indices[0].tolist(), top_10.values[0].tolist())\n",
+    "\n",
+    "for token, score in top_10_tokens:\n",
+    "    print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])), f\"(score: {score})\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Examples: BERT, RoBERTa (including Polish RoBERTa).\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### The generative approach (encoder-decoder)\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The system is supposed to generate answers to various queries (including\n",
+    "queries corresponding to the MLM task), e.g.:\n",
+    "\n",
+    "- \"translate English to German: That is good.\" => \"Das ist gut.\"\n",
+    "- \"cola sentence: The course is jumping well.\" => \"not acceptable\"\n",
+    "- \"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi…\"\n",
+    "  => \"six people hospitalized after a storm in attala county\"\n",
+    "- \"Thank you <X> me to your party <Y> week.\" => \"<X> for inviting <Y> last <Z>\"\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n",
+    "\n",
+    "T5_PATH = 't5-base'\n",
+    "\n",
+    "t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)\n",
+    "t5_config = T5Config.from_pretrained(T5_PATH)\n",
+    "t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config)\n",
+    "\n",
+    "# T5 marks masked spans with sentinel tokens: <extra_id_0>, <extra_id_1>, ...\n",
+    "slot = '<extra_id_0>'\n",
+    "\n",
+    "text = f'Warsaw is the {slot} of Poland.'\n",
+    "\n",
+    "encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n",
+    "input_ids = encoded['input_ids']\n",
+    "\n",
+    "outputs = t5_mlm.generate(input_ids=input_ids,\n",
+    "                          num_beams=200, num_return_sequences=5,\n",
+    "                          max_length=5)\n",
+    "\n",
+    "_0_index = text.index(slot)\n",
+    "_result_prefix = text[:_0_index]\n",
+    "_result_suffix = text[_0_index+len(slot):]\n",
+    "\n",
+    "def _filter(output, end_token='<extra_id_1>'):\n",
+    "    # Skip <pad> and <extra_id_0> at the start, cut at the closing sentinel.\n",
+    "    _txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n",
+    "    if end_token in _txt:\n",
+    "        _end_token_index = _txt.index(end_token)\n",
+    "        return _result_prefix + _txt[:_end_token_index] + _result_suffix\n",
+    "    else:\n",
+    "        return _result_prefix + _txt + _result_suffix\n",
+    "\n",
+    "\n",
+    "results = [_filter(out) for out in outputs]\n",
+    "results"
+   ]
+  },
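+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The same model handles the other task prefixes listed above. As a quick\n",
+    "illustration (an addition, not part of the original lecture code), we\n",
+    "can feed it the translation prompt; the decoding parameters are\n",
+    "arbitrary choices.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Reusing t5_mlm and t5_tokenizer from above: the task is selected purely\n",
+    "# by the text prefix; beam size and max_length are arbitrary choices.\n",
+    "task = 'translate English to German: That is good.'\n",
+    "task_ids = t5_tokenizer.encode(task, return_tensors='pt')\n",
+    "translated = t5_mlm.generate(input_ids=task_ids, num_beams=4, max_length=20)\n",
+    "print(t5_tokenizer.decode(translated[0], skip_special_tokens=True))"
+   ]
+  },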
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "(See [https://arxiv.org/pdf/1910.10683.pdf](https://arxiv.org/pdf/1910.10683.pdf))\n",
+    "\n",
+    "Examples: T5, mT5\n",
+    "\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.2"
+  },
+  "org": null
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/wyk/14_pretrenowanie.org b/wyk/14_pretrenowanie.org
index 1b33f10..cc46039 100644
--- a/wyk/14_pretrenowanie.org
+++ b/wyk/14_pretrenowanie.org
@@ -132,10 +132,10 @@ used a scheme in which unmasked words are guessed as well (!):
-from transformers import AutoModelWithLMHead, AutoTokenizer
+from transformers import AutoModelForMaskedLM, AutoTokenizer
 import torch
 
-tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
-model = AutoModelWithLMHead.from_pretrained("distilroberta-base")
+tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
+model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-large")
 
-sequence = f"Hugging Face is a French company based in {tokenizer.mask_token}"
+sequence = f'II wojna światowa zakończyła się w {tokenizer.mask_token} roku.'
 
 input_ids = tokenizer.encode(sequence, return_tensors="pt")
 mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]
@@ -144,12 +144,69 @@ token_logits = model(input_ids)[0]
 mask_token_logits = token_logits[0, mask_token_index, :]
 mask_token_logits = torch.softmax(mask_token_logits, dim=1)
 
-top_5 = torch.topk(mask_token_logits, 5, dim=1)
-top_5_tokens = zip(top_5.indices[0].tolist(), top_5.values[0].tolist())
+top_10 = torch.topk(mask_token_logits, 10, dim=1)
+top_10_tokens = zip(top_10.indices[0].tolist(), top_10.values[0].tolist())
 
-for token, score in top_5_tokens:
+for token, score in top_10_tokens:
     print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])), f"(score: {score})")
 #+END_SRC
 
+#+RESULTS:
+:results:
+# Out[3]:
+:end:
+
 Examples: BERT, RoBERTa (including Polish RoBERTa).
+
+** The generative approach (encoder-decoder)
+
+The system is supposed to generate answers to various queries (including
+queries corresponding to the MLM task), e.g.:
+
+- "translate English to German: That is good." => "Das ist gut."
+- "cola sentence: The course is jumping well." => "not acceptable"
+- "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi..."
+  => "six people hospitalized after a storm in attala county"
+- "Thank you <X> me to your party <Y> week." => "<X> for inviting <Y> last <Z>"
+
+#+BEGIN_SRC ipython :session mysession :exports both :results raw drawer
+from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
+
+T5_PATH = 't5-base'
+
+t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)
+t5_config = T5Config.from_pretrained(T5_PATH)
+t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config)
+
+# T5 marks masked spans with sentinel tokens: <extra_id_0>, <extra_id_1>, ...
+slot = '<extra_id_0>'
+
+text = f'Warsaw is the {slot} of Poland.'
+
+encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')
+input_ids = encoded['input_ids']
+
+outputs = t5_mlm.generate(input_ids=input_ids,
+                          num_beams=200, num_return_sequences=5,
+                          max_length=5)
+
+_0_index = text.index(slot)
+_result_prefix = text[:_0_index]
+_result_suffix = text[_0_index+len(slot):]
+
+def _filter(output, end_token='<extra_id_1>'):
+    # Skip <pad> and <extra_id_0> at the start, cut at the closing sentinel.
+    _txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
+    if end_token in _txt:
+        _end_token_index = _txt.index(end_token)
+        return _result_prefix + _txt[:_end_token_index] + _result_suffix
+    else:
+        return _result_prefix + _txt + _result_suffix
+
+
+results = [_filter(out) for out in outputs]
+results
+#+END_SRC
+
+(See https://arxiv.org/pdf/1910.10683.pdf)
+
+Examples: T5, mT5