aitech-eks-pub/wyk/14_pretrenowanie.ipynb

370 lines
12 KiB
Plaintext
Raw Normal View History

2021-06-14 15:39:15 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pretrenowanie modeli\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"System AlphaZero uczy się grając sam ze sobą — wystarczy 24 godziny,\n",
"by system nauczył się grać w szachy lub go na nadludzkim poziomie.\n",
"\n",
"**Pytanie**: Dlaczego granie samemu ze sobą nie jest dobrym sposobem\n",
" nauczenia się grania w szachy dla człowieka, a dla maszyny jest?\n",
"\n",
"Co jest odpowiednikiem grania samemu ze sobą w świecie przetwarzania tekstu?\n",
"Tzn. **pretrenowanie** (*pretraining*) na dużym korpusie tekstu. (Tekst jest tani!)\n",
"\n",
"Jest kilka sposobów na pretrenowanie modelu, w każdym razie sprowadza\n",
"się do odgadywania następnego bądź zamaskowanego słowa.\n",
"W każdym razie zawsze stosujemy softmax (być może ze „sztuczkami” takimi jak\n",
2021-09-27 07:42:48 +02:00
"negatywne próbkowanie albo hierarchiczny softmax) na pewnej **reprezentacji kontekstowej**:\n",
2021-06-14 15:39:15 +02:00
"\n",
"$$\\vec{p} = \\operatorname{softmax}(f(\\vec{c})).$$\n",
"\n",
"Model jest karany używając funkcji log loss:\n",
"\n",
"$$-\\log(p_j),$$\n",
"\n",
"gdzie $w_j$ jest wyrazem, który pojawił się rzeczywiście w korpusie.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Przewidywanie słowa (GPT-2)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Jeden ze sposobów pretrenowania modelu to po prostu przewidywanie\n",
"następnego słowa.\n",
"\n",
"Zainstalujmy najpierw bibliotekę transformers.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"! pip install transformers"
]
},
{
"cell_type": "code",
2021-09-27 07:36:37 +02:00
"execution_count": 17,
2021-06-14 15:39:15 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"50257\n"
]
},
{
"data": {
"text/plain": [
2021-09-27 07:36:37 +02:00
"[('Âł', 0.6182783842086792),\n",
" ('È', 0.1154019758105278),\n",
" ('Ñģ', 0.026960616931319237),\n",
" ('_____', 0.024418892338871956),\n",
" ('________', 0.014962316490709782),\n",
" ('ÃĤ', 0.010653386823832989),\n",
" ('ä¸Ń', 0.008340531960129738),\n",
" ('Ñ', 0.007557711564004421),\n",
" ('Ê', 0.007046067621558905),\n",
" ('ãĢ', 0.006875576451420784),\n",
" ('ile', 0.006685272324830294),\n",
" ('____', 0.006307446397840977),\n",
" ('âĢĭ', 0.006306538358330727),\n",
" ('ÑĢ', 0.006197483278810978),\n",
" ('ĠBelarus', 0.006108700763434172),\n",
" ('Æ', 0.005720408633351326),\n",
" ('ĠPoland', 0.0053678699769079685),\n",
" ('á¹', 0.004606408067047596),\n",
" ('îĢ', 0.004161055199801922),\n",
" ('????', 0.004056799225509167),\n",
" ('_______', 0.0038176667876541615),\n",
" ('ä¸', 0.0036082742735743523),\n",
" ('Ì', 0.003221835708245635),\n",
" ('urs', 0.003080119378864765),\n",
" ('________________', 0.0027312245219945908),\n",
" ('ĠLithuania', 0.0023860156070441008),\n",
" ('ich', 0.0021211160346865654),\n",
" ('iz', 0.002069818088784814),\n",
" ('vern', 0.002001357264816761),\n",
" ('ÅĤ', 0.001717406208626926)]"
2021-06-14 15:39:15 +02:00
]
},
2021-09-27 07:36:37 +02:00
"execution_count": 17,
2021-06-14 15:39:15 +02:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')\n",
"model = GPT2LMHeadModel.from_pretrained('gpt2-large')\n",
2021-09-27 07:36:37 +02:00
"text = 'Warsaw is the capital city of'\n",
2021-06-14 15:39:15 +02:00
"encoded_input = tokenizer(text, return_tensors='pt')\n",
"output = model(**encoded_input)\n",
"next_token_probs = torch.softmax(output[0][:, -1, :][0], dim=0)\n",
"\n",
"nb_of_tokens = next_token_probs.size()[0]\n",
"print(nb_of_tokens)\n",
"\n",
"_, top_k_indices = torch.topk(next_token_probs, 30, sorted=True)\n",
"\n",
"words = tokenizer.convert_ids_to_tokens(top_k_indices)\n",
"\n",
"top_probs = []\n",
"\n",
"for ix in range(len(top_k_indices)):\n",
" top_probs.append((words[ix], next_token_probs[top_k_indices[ix]].item()))\n",
"\n",
"top_probs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Zalety tego podejścia:\n",
"\n",
"- prostota,\n",
"- dobra podstawa do strojenia systemów generowania tekstu zwłaszcza\n",
" „otwartego” (systemy dialogowe, generowanie (fake) newsów, streszczanie tekstu),\n",
" ale niekoniecznie tłumaczenia maszynowego,\n",
"- zaskakująca skuteczność przy uczeniu *few-shot* i *zero-shot*.\n",
"\n",
"Wady:\n",
"\n",
"- asymetryczność, przetwarzanie tylko z lewej do prawej, preferencja\n",
" dla lewego kontekstu,\n",
"- mniejsza skuteczność przy dostrajaniu do zadań klasyfikacji i innych zadań\n",
" niepolegających na prostym generowaniu.\n",
"\n",
"Przykłady modeli: GPT, GPT-2, GPT-3, DialoGPT.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Maskowanie słów (BERT)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inną metodą jest maskowanie słów (*Masked Language Modeling*, *MLM*).\n",
"\n",
"W tym podejściu losowe wybrane zastępujemy losowe słowa specjalnym\n",
"tokenem (`[MASK]`) i każemy modelowi odgadywać w ten sposób\n",
"zamaskowane słowa (z uwzględnieniem również prawego kontekstu!).\n",
"\n",
"Móciąc ściśle, w jednym z pierwszych modeli tego typu (BERT)\n",
"zastosowano schemat, w którym również niezamaskowane słowa są odgadywane (!):\n",
"\n",
"- wybieramy losowe 15% wyrazów do odgadnięcia\n",
"- 80% z nich zastępujemy tokenem `[MASK]`,\n",
"- 10% zastępujemy innym losowym wyrazem,\n",
"- 10% pozostawiamy bez zmian.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
2021-09-27 07:36:37 +02:00
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/filipg/.local/lib/python3.9/site-packages/transformers/models/auto/modeling_auto.py:806: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
" warnings.warn(\n"
]
},
2021-06-14 15:39:15 +02:00
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-09-27 07:36:37 +02:00
"W którym państwie leży Bombaj? W USA. (score: 0.16715531051158905)\n",
"W którym państwie leży Bombaj? W India. (score: 0.09912960231304169)\n",
"W którym państwie leży Bombaj? W Indian. (score: 0.039642028510570526)\n",
"W którym państwie leży Bombaj? W Nepal. (score: 0.027137665078043938)\n",
"W którym państwie leży Bombaj? W Pakistan. (score: 0.027065709233283997)\n",
"W którym państwie leży Bombaj? W Polsce. (score: 0.023737527430057526)\n",
"W którym państwie leży Bombaj? W .... (score: 0.02306722290813923)\n",
"W którym państwie leży Bombaj? W Bangladesh. (score: 0.022106658667325974)\n",
"W którym państwie leży Bombaj? W .... (score: 0.01628892682492733)\n",
"W którym państwie leży Bombaj? W Niemczech. (score: 0.014501162804663181)\n"
2021-06-14 15:39:15 +02:00
]
}
],
"source": [
"from transformers import AutoModelWithLMHead, AutoTokenizer\n",
"import torch\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-large\")\n",
"model = AutoModelWithLMHead.from_pretrained(\"xlm-roberta-large\")\n",
"\n",
2021-09-27 07:36:37 +02:00
"sequence = f'W którym państwie leży Bombaj? W {tokenizer.mask_token}.'\n",
2021-06-14 15:39:15 +02:00
"\n",
"input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n",
"mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n",
"\n",
"token_logits = model(input_ids)[0]\n",
"mask_token_logits = token_logits[0, mask_token_index, :]\n",
"mask_token_logits = torch.softmax(mask_token_logits, dim=1)\n",
"\n",
"top_10 = torch.topk(mask_token_logits, 10, dim=1)\n",
"top_10_tokens = zip(top_10.indices[0].tolist(), top_10.values[0].tolist())\n",
"\n",
"for token, score in top_10_tokens:\n",
" print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])), f\"(score: {score})\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Przykłady: BERT, RoBERTa (również Polish RoBERTa).\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Podejście generatywne (koder-dekoder).\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"System ma wygenerować odpowiedź na różne pytania (również\n",
"odpowiadające zadaniu MLM), np.:\n",
"\n",
"- \"translate English to German: That is good.\" => \"Das ist gut.\"\n",
"- \"cola sentence: The course is jumping well.\" => \"not acceptable\"\n",
"- \"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi…\"\n",
" => \"six people hospitalized after a storm in attala county\"\n",
"- \"Thank you for <X> me to your party <Y> week.\" => <X> for inviting <Y> last <Z>\n",
"\n"
]
},
{
"cell_type": "code",
2021-09-27 07:36:37 +02:00
"execution_count": 2,
2021-06-14 15:39:15 +02:00
"metadata": {},
2021-09-27 07:36:37 +02:00
"outputs": [
{
"data": {
"text/plain": [
"['World War II ended in World War II.',\n",
" 'World War II ended in 1945..',\n",
" 'World War II ended in 1945.',\n",
" 'World War II ended in 1945.',\n",
" 'World War II ended in 1945.']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
2021-06-14 15:39:15 +02:00
"source": [
"from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n",
"\n",
"T5_PATH = 't5-base'\n",
"\n",
"t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)\n",
"t5_config = T5Config.from_pretrained(T5_PATH)\n",
"t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config)\n",
"\n",
"slot = '<extra_id_0>'\n",
"\n",
2021-09-27 07:36:37 +02:00
"text = f'World War II ended in {slot}.'\n",
2021-06-14 15:39:15 +02:00
"\n",
"encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n",
"input_ids = encoded['input_ids']\n",
"\n",
"outputs = t5_mlm.generate(input_ids=input_ids,\n",
" num_beams=200, num_return_sequences=5,\n",
" max_length=5)\n",
"\n",
"_0_index = text.index(slot)\n",
"_result_prefix = text[:_0_index]\n",
"_result_suffix = text[_0_index+len(slot):]\n",
"\n",
"def _filter(output, end_token='<extra_id_1>'):\n",
" _txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n",
" if end_token in _txt:\n",
" _end_token_index = _txt.index(end_token)\n",
" return _result_prefix + _txt[:_end_token_index] + _result_suffix\n",
" else:\n",
" return _result_prefix + _txt + _result_suffix\n",
"\n",
"\n",
"results = [_filter(out) for out in outputs]\n",
"results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(Zob. [https://arxiv.org/pdf/1910.10683.pdf](https://arxiv.org/pdf/1910.10683.pdf))\n",
"\n",
"Przykład: T5, mT5\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
2021-09-27 07:42:48 +02:00
"display_name": "Python 3 (ipykernel)",
2021-06-14 15:39:15 +02:00
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2021-09-27 07:42:48 +02:00
"version": "3.9.6"
2021-06-14 15:39:15 +02:00
},
"org": null
},
"nbformat": 4,
2021-09-27 07:36:37 +02:00
"nbformat_minor": 4
2021-06-14 15:39:15 +02:00
}