{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n", "
\n", "

Ekstrakcja informacji

\n", "

14. Pretrenowane modele j\u0119zyka [wyk\u0142ad]

\n", "

Filip Grali\u0144ski (2021)

\n", "
\n", "\n", "![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pretrenowanie modeli\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "System AlphaZero uczy si\u0119 graj\u0105c sam ze sob\u0105 \u2014 wystarczy 24 godziny,\n", "by system nauczy\u0142 si\u0119 gra\u0107 w szachy lub go na nadludzkim poziomie.\n", "\n", "**Pytanie**: Dlaczego granie samemu ze sob\u0105 nie jest dobrym sposobem\n", " nauczenia si\u0119 grania w szachy dla cz\u0142owieka, a dla maszyny jest?\n", "\n", "Co jest odpowiednikiem grania samemu ze sob\u0105 w \u015bwiecie przetwarzania tekstu?\n", "Tzn. **pretrenowanie** (*pretraining*) na du\u017cym korpusie tekstu. (Tekst jest tani!)\n", "\n", "Jest kilka sposob\u00f3w na pretrenowanie modelu, w ka\u017cdym razie sprowadza\n", "si\u0119 do odgadywania nast\u0119pnego b\u0105d\u017a zamaskowanego s\u0142owa.\n", "W ka\u017cdym razie zawsze stosujemy softmax (by\u0107 mo\u017ce ze \u201esztuczkami\u201d takimi jak\n", "negatywne pr\u00f3bkowanie albo hierarchiczny softmax) na pewnej **reprezentacji kontekstowej**:\n", "\n", "$$\\vec{p} = \\operatorname{softmax}(f(\\vec{c})).$$\n", "\n", "Model jest karany u\u017cywaj\u0105c funkcji log loss:\n", "\n", "$$-\\log(p_j),$$\n", "\n", "gdzie $w_j$ jest wyrazem, kt\u00f3ry pojawi\u0142 si\u0119 rzeczywi\u015bcie w korpusie.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Przewidywanie s\u0142owa (GPT-2)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Jeden ze sposob\u00f3w pretrenowania modelu to po prostu przewidywanie\n", "nast\u0119pnego s\u0142owa.\n", "\n", "Zainstalujmy najpierw bibliotek\u0119 transformers.\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "! 
pip install transformers" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "50257\n" ] }, { "data": { "text/plain": [ "[('\u00c2\u0142', 0.6182783842086792),\n", " ('\u00c8', 0.1154019758105278),\n", " ('\u00d1\u0123', 0.026960616931319237),\n", " ('_____', 0.024418892338871956),\n", " ('________', 0.014962316490709782),\n", " ('\u00c3\u0124', 0.010653386823832989),\n", " ('\u00e4\u00b8\u0143', 0.008340531960129738),\n", " ('\u00d1', 0.007557711564004421),\n", " ('\u00ca', 0.007046067621558905),\n", " ('\u00e3\u0122', 0.006875576451420784),\n", " ('ile', 0.006685272324830294),\n", " ('____', 0.006307446397840977),\n", " ('\u00e2\u0122\u012d', 0.006306538358330727),\n", " ('\u00d1\u0122', 0.006197483278810978),\n", " ('\u0120Belarus', 0.006108700763434172),\n", " ('\u00c6', 0.005720408633351326),\n", " ('\u0120Poland', 0.0053678699769079685),\n", " ('\u00e1\u00b9', 0.004606408067047596),\n", " ('\u00ee\u0122', 0.004161055199801922),\n", " ('????', 0.004056799225509167),\n", " ('_______', 0.0038176667876541615),\n", " ('\u00e4\u00b8', 0.0036082742735743523),\n", " ('\u00cc', 0.003221835708245635),\n", " ('urs', 0.003080119378864765),\n", " ('________________', 0.0027312245219945908),\n", " ('\u0120Lithuania', 0.0023860156070441008),\n", " ('ich', 0.0021211160346865654),\n", " ('iz', 0.002069818088784814),\n", " ('vern', 0.002001357264816761),\n", " ('\u00c5\u0124', 0.001717406208626926)]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", "tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')\n", "model = GPT2LMHeadModel.from_pretrained('gpt2-large')\n", "text = 'Warsaw is the capital city of'\n", "encoded_input = tokenizer(text, return_tensors='pt')\n", "output = model(**encoded_input)\n", "next_token_probs = torch.softmax(output[0][:, -1, :][0], dim=0)\n", "\n", "nb_of_tokens = next_token_probs.size()[0]\n", "print(nb_of_tokens)\n", "\n", "_, top_k_indices = torch.topk(next_token_probs, 30, sorted=True)\n", "\n", "words = tokenizer.convert_ids_to_tokens(top_k_indices)\n", "\n", "top_probs = []\n", "\n", "for ix in range(len(top_k_indices)):\n", " top_probs.append((words[ix], next_token_probs[top_k_indices[ix]].item()))\n", "\n", "top_probs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Zalety tego podej\u015bcia:\n", "\n", "- prostota,\n", "- dobra podstawa do strojenia system\u00f3w generowania tekstu zw\u0142aszcza\n", " \u201eotwartego\u201d (systemy dialogowe, generowanie (fake) news\u00f3w, streszczanie tekstu),\n", " ale niekoniecznie t\u0142umaczenia maszynowego,\n", "- zaskakuj\u0105ca skuteczno\u015b\u0107 przy uczeniu *few-shot* i *zero-shot*.\n", "\n", "Wady:\n", "\n", "- asymetryczno\u015b\u0107, przetwarzanie tylko z lewej do prawej, preferencja\n", " dla lewego kontekstu,\n", "- mniejsza skuteczno\u015b\u0107 przy dostrajaniu do zada\u0144 klasyfikacji i innych zada\u0144\n", " niepolegaj\u0105cych na prostym generowaniu.\n", "\n", "Przyk\u0142ady modeli: GPT, GPT-2, GPT-3, DialoGPT.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Maskowanie s\u0142\u00f3w (BERT)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inn\u0105 metod\u0105 jest maskowanie s\u0142\u00f3w (*Masked Language Modeling*, *MLM*).\n", "\n", "W tym podej\u015bciu losowe wybrane zast\u0119pujemy losowe s\u0142owa specjalnym\n", "tokenem 
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W USA. (score: 0.16715531051158905)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W India. (score: 0.09912960231304169)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W Indian. (score: 0.039642028510570526)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W Nepal. (score: 0.027137665078043938)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W Pakistan. (score: 0.027065709233283997)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W Polsce. (score: 0.023737527430057526)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W .... (score: 0.02306722290813923)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W Bangladesh. (score: 0.022106658667325974)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W .... (score: 0.01628892682492733)\n", "W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W Niemczech. (score: 0.014501162804663181)\n" ] } ], "source": [ "from transformers import AutoModelForMaskedLM, AutoTokenizer\n", "import torch\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-large\")\n", "model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n", "\n", "sequence = f'W kt\u00f3rym pa\u0144stwie le\u017cy Bombaj? W {tokenizer.mask_token}.'\n", "\n", "input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n", "# position of the mask token in the input\n", "mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n", "\n", "token_logits = model(input_ids)[0]\n", "mask_token_logits = token_logits[0, mask_token_index, :]\n", "mask_token_logits = torch.softmax(mask_token_logits, dim=1)\n", "\n", "top_10 = torch.topk(mask_token_logits, 10, dim=1)\n", "top_10_tokens = zip(top_10.indices[0].tolist(), top_10.values[0].tolist())\n", "\n", "for token, score in top_10_tokens:\n", "    print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])), f\"(score: {score})\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Examples: BERT, RoBERTa (also Polish RoBERTa).\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The generative approach (encoder-decoder)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The system is to generate an answer to various queries (including ones\n", "corresponding to the MLM task), e.g.:\n", "\n", "- \"translate English to German: That is good.\" => \"Das ist gut.\"\n", "- \"cola sentence: The course is jumping well.\" => \"not acceptable\"\n", "- \"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi\u2026\"\n", "  => \"six people hospitalized after a storm in attala county\"\n", "- \"Thank you <X> me to your party <Y> week.\" => \"<X> for inviting <Y> last <Z>\"\n", "  (the span-corruption task: `<X>`, `<Y>`, `<Z>` are sentinel tokens marking the removed spans; see the sketch after this list)\n", "\n" ] },
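{ "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch (not part of the original lecture) of how such span-corruption training pairs can be built: each removed span of the input is replaced by a sentinel token (`<extra_id_0>`, `<extra_id_1>`, \u2026), and the target lists the removed spans, each preceded by its sentinel, plus a closing sentinel. The helper below and its toy spans are illustrative.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def span_corruption_pair(words, spans):\n", "    \"\"\"Build a T5-style (input, target) pair; spans is a list of (start, end) ranges.\"\"\"\n", "    input_toks, target_toks = [], []\n", "    prev = 0\n", "    for i, (start, end) in enumerate(spans):\n", "        sentinel = f'<extra_id_{i}>'\n", "        input_toks += words[prev:start] + [sentinel]  # replace the span with a sentinel\n", "        target_toks += [sentinel] + words[start:end]  # the target restores the span\n", "        prev = end\n", "    input_toks += words[prev:]\n", "    target_toks.append(f'<extra_id_{len(spans)}>')  # closing sentinel\n", "    return ' '.join(input_toks), ' '.join(target_toks)\n", "\n", "words = 'Thank you for inviting me to your party last week .'.split()\n", "span_corruption_pair(words, [(2, 4), (8, 9)])\n", "# => ('Thank you <extra_id_0> me to your party <extra_id_1> week .',\n", "#     '<extra_id_0> for inviting <extra_id_1> last <extra_id_2>')" ] },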
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['World War II ended in World War II.',\n", " 'World War II ended in 1945..',\n", " 'World War II ended in 1945.',\n", " 'World War II ended in 1945.',\n", " 'World War II ended in 1945.']" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n", "\n", "T5_PATH = 't5-base'\n", "\n", "t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)\n", "t5_config = T5Config.from_pretrained(T5_PATH)\n", "t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config)\n", "\n", "# sentinel token marking the span to be filled in\n", "slot = '<extra_id_0>'\n", "\n", "text = f'World War II ended in {slot}.'\n", "\n", "encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n", "input_ids = encoded['input_ids']\n", "\n", "outputs = t5_mlm.generate(input_ids=input_ids,\n", "                          num_beams=200, num_return_sequences=5,\n", "                          max_length=5)\n", "\n", "_0_index = text.index(slot)\n", "_result_prefix = text[:_0_index]\n", "_result_suffix = text[_0_index+len(slot):]\n", "\n", "def _filter(output, end_token='<extra_id_1>'):\n", "    # skip the initial <pad> and <extra_id_0> tokens of the generated sequence\n", "    _txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n", "    if end_token in _txt:\n", "        _end_token_index = _txt.index(end_token)\n", "        return _result_prefix + _txt[:_end_token_index] + _result_suffix\n", "    else:\n", "        return _result_prefix + _txt + _result_suffix\n", "\n", "\n", "results = [_filter(out) for out in outputs]\n", "results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(See\n", "[https://arxiv.org/pdf/1910.10683.pdf](https://arxiv.org/pdf/1910.10683.pdf))\n", "\n", "Examples: T5, mT5\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" }, "org": null, "author": "Filip Grali\u0144ski", "email": "filipg@amu.edu.pl", "lang": "pl", "subtitle": "14. Pretrenowane modele j\u0119zyka [wyk\u0142ad]", "title": "Ekstrakcja informacji", "year": "2021" }, "nbformat": 4, "nbformat_minor": 4 }