forked from filipg/aitech-eks-pub
commit 78fb510cba (parent edf3811cd7)

wyk/14_pretrenowanie.ipynb (new file, 338 lines)
@@ -0,0 +1,338 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model pretraining\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The AlphaZero system learns by playing against itself: 24 hours are enough\n",
"for the system to learn to play chess or Go at a superhuman level.\n",
"\n",
"**Question**: Why is playing against oneself not a good way for a human\n",
"to learn chess, while for a machine it works?\n",
"\n",
"What is the counterpart of playing against oneself in the world of text\n",
"processing? It is **pretraining** on a large corpus of text. (Text is cheap!)\n",
"\n",
"There are several ways to pretrain a model, but they all come down to\n",
"guessing the next word or a masked word. In every case we apply a softmax\n",
"(possibly with \"tricks\" such as negative sampling or hierarchical softmax)\n",
"to some **contextual representation**:\n",
"\n",
"$$\\vec{p} = \\operatorname{softmax}(f(\\vec{c})).$$\n",
"\n",
"The model is penalized with the log-loss function:\n",
"\n",
"$$-\\log(p_j),$$\n",
"\n",
"where $j$ is the index of the word $w_j$ that actually occurred in the corpus.\n",
"\n"
]
},
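{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration, the cell below is a minimal sketch with made-up\n",
"scores for a toy four-word vocabulary (not taken from any real model):\n",
"it computes $\\vec{p}$ with softmax and the log loss for the word that\n",
"actually occurred.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Made-up scores f(c) for a toy four-word vocabulary (illustration only).\n",
"scores = torch.tensor([2.0, 1.0, 0.5, -1.0])\n",
"p = torch.softmax(scores, dim=0)  # the distribution over the vocabulary\n",
"j = 1  # index of the word that actually occurred in the corpus\n",
"loss = -torch.log(p[j])  # log loss: the higher p_j, the smaller the penalty\n",
"print(p, loss.item())"
]
},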
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predicting the next word (GPT-2)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to pretrain a model is simply to predict the next word.\n",
"\n",
"Let us first install the transformers library.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"! pip install transformers"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"50257\n"
]
},
{
"data": {
"text/plain": [
"[('Ġon', 0.6786560416221619),\n",
" ('Ġupon', 0.04339785501360893),\n",
" ('Ġheavily', 0.02208443358540535),\n",
" ('Ġin', 0.021049050614237785),\n",
" (',', 0.020188499242067337),\n",
" ('Ġa', 0.01833895780146122),\n",
" ('Ġvery', 0.017935041338205338),\n",
" ('Ġentirely', 0.017528969794511795),\n",
" ('Ġlargely', 0.016769640147686005),\n",
" ('Ġto', 0.01009418722242117),\n",
" ('Ġgreatly', 0.010009866207838058),\n",
" ('Ġnot', 0.009016563184559345),\n",
" ('Ġmore', 0.005853226874023676),\n",
" ('Ġprimarily', 0.005203146021813154),\n",
" ('Ġstrongly', 0.0034501152113080025),\n",
" ('Ġpartly', 0.0033184229396283627),\n",
" ('Ġmuch', 0.0033095215912908316),\n",
" ('Ġmostly', 0.0032150144688785076),\n",
" ('Ġmainly', 0.0030899408739060163),\n",
" ('Ġfor', 0.003034428460523486),\n",
" ('.', 0.0028878094162791967),\n",
" ('Ġboth', 0.0028405177872627974),\n",
" ('Ġsomewhat', 0.0028194624464958906),\n",
" ('Ġcru', 0.002263976726680994),\n",
" ('Ġas', 0.00221616611815989),\n",
" ('Ġof', 0.0022000609897077084),\n",
" ('Ġalmost', 0.001968063646927476),\n",
" ('Ġat', 0.0018015997484326363),\n",
" ('Ġhighly', 0.0017461496172472835),\n",
" ('Ġcompletely', 0.001692073536105454)]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"\n",
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')\n",
"model = GPT2LMHeadModel.from_pretrained('gpt2-large')\n",
"\n",
"text = \"This issue depends\"\n",
"encoded_input = tokenizer(text, return_tensors='pt')\n",
"output = model(**encoded_input)\n",
"# distribution over the next token, taken from the last position\n",
"next_token_probs = torch.softmax(output[0][:, -1, :][0], dim=0)\n",
"\n",
"nb_of_tokens = next_token_probs.size()[0]  # vocabulary size\n",
"print(nb_of_tokens)\n",
"\n",
"# the 30 most probable next tokens\n",
"_, top_k_indices = torch.topk(next_token_probs, 30, sorted=True)\n",
"words = tokenizer.convert_ids_to_tokens(top_k_indices)\n",
"\n",
"top_probs = []\n",
"for ix in range(len(top_k_indices)):\n",
"    top_probs.append((words[ix], next_token_probs[top_k_indices[ix]].item()))\n",
"\n",
"top_probs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Advantages of this approach:\n",
"\n",
"- simplicity,\n",
"- a good basis for fine-tuning text generation systems, especially\n",
"  \"open-ended\" ones (dialogue systems, generating (fake) news, text\n",
"  summarization), though not necessarily machine translation; a short\n",
"  generation sketch follows below,\n",
"- surprising effectiveness in *few-shot* and *zero-shot* learning.\n",
"\n",
"Disadvantages:\n",
"\n",
"- asymmetry: processing only from left to right, a preference for the\n",
"  left context,\n",
"- lower effectiveness when fine-tuning for classification and other\n",
"  tasks that do not boil down to simple generation.\n",
"\n",
"Example models: GPT, GPT-2, GPT-3, DialoGPT.\n",
"\n"
]
},
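{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal generation sketch: it assumes `model` and\n",
"`tokenizer` (`gpt2-large`) from the cell above are still loaded, and uses\n",
"`generate` with top-k sampling, one of several possible decoding strategies.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch; assumes `model` and `tokenizer` from the cell above.\n",
"input_ids = tokenizer('This issue depends', return_tensors='pt').input_ids\n",
"generated = model.generate(input_ids, max_length=20, do_sample=True, top_k=50)\n",
"print(tokenizer.decode(generated[0], skip_special_tokens=True))"
]
},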
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Word masking (BERT)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another method is word masking (*Masked Language Modeling*, *MLM*).\n",
"\n",
"In this approach we replace randomly selected words with a special\n",
"token (`[MASK]`) and have the model guess the words masked this way\n",
"(taking the right context into account as well!).\n",
"\n",
"Strictly speaking, one of the first models of this kind (BERT) used a\n",
"scheme in which unmasked words are guessed too (!); a sketch of the\n",
"scheme follows after this list:\n",
"\n",
"- we select a random 15% of the words to be guessed,\n",
"- 80% of them are replaced with the `[MASK]` token,\n",
"- 10% are replaced with another random word,\n",
"- 10% are left unchanged.\n",
"\n"
]
},
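{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal sketch of this 80/10/10 scheme operating on a\n",
"plain list of tokens; the `bert_style_mask` helper and its parameters are\n",
"illustrative, not the actual BERT implementation.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def bert_style_mask(tokens, vocab, mask_token='[MASK]', select_prob=0.15):\n",
"    # Returns the corrupted tokens and the indices the model must guess.\n",
"    masked, targets = list(tokens), []\n",
"    for i in range(len(tokens)):\n",
"        if random.random() < select_prob:  # pick ~15% of positions\n",
"            targets.append(i)\n",
"            r = random.random()\n",
"            if r < 0.8:  # 80%: replace with the mask token\n",
"                masked[i] = mask_token\n",
"            elif r < 0.9:  # 10%: replace with a random word\n",
"                masked[i] = random.choice(vocab)\n",
"            # remaining 10%: leave the word unchanged\n",
"    return masked, targets\n",
"\n",
"bert_style_mask('the cat sat on the mat'.split(), vocab=['dog', 'tree', 'blue'])"
]
},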
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# Out[3]:"
]
}
],
"source": [
"from transformers import AutoModelWithLMHead, AutoTokenizer\n",
"import torch\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-large\")\n",
"model = AutoModelWithLMHead.from_pretrained(\"xlm-roberta-large\")\n",
"\n",
"sequence = f'II wojna światowa zakończyła się w {tokenizer.mask_token} roku.'\n",
"\n",
"input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n",
"# position of the masked token in the input\n",
"mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n",
"\n",
"token_logits = model(input_ids)[0]\n",
"mask_token_logits = token_logits[0, mask_token_index, :]\n",
"mask_token_logits = torch.softmax(mask_token_logits, dim=1)\n",
"\n",
"# the 10 most probable fillers for the masked position\n",
"top_10 = torch.topk(mask_token_logits, 10, dim=1)\n",
"top_10_tokens = zip(top_10.indices[0].tolist(), top_10.values[0].tolist())\n",
"\n",
"for token, score in top_10_tokens:\n",
"    print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])), f\"(score: {score})\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Examples: BERT, RoBERTa (including Polish RoBERTa).\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The generative approach (encoder-decoder)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The system is supposed to generate answers to various queries (including\n",
"ones corresponding to the MLM task), e.g.:\n",
"\n",
"- \"translate English to German: That is good.\" => \"Das ist gut.\"\n",
"- \"cola sentence: The course is jumping well.\" => \"not acceptable\"\n",
"- \"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi…\"\n",
"  => \"six people hospitalized after a storm in attala county\"\n",
"- \"Thank you for <X> me to your party <Y> week.\" => <X> for inviting <Y> last <Z>\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n",
"\n",
"T5_PATH = 't5-base'\n",
"\n",
"t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)\n",
"t5_config = T5Config.from_pretrained(T5_PATH)\n",
"t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config)\n",
"\n",
"# T5 marks the gap to be filled with a sentinel token\n",
"slot = '<extra_id_0>'\n",
"\n",
"text = f'Warsaw is the {slot} of Poland.'\n",
"\n",
"encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n",
"input_ids = encoded['input_ids']\n",
"\n",
"outputs = t5_mlm.generate(input_ids=input_ids,\n",
"                          num_beams=200, num_return_sequences=5,\n",
"                          max_length=5)\n",
"\n",
"_0_index = text.index(slot)\n",
"_result_prefix = text[:_0_index]\n",
"_result_suffix = text[_0_index+len(slot):]\n",
"\n",
"def _filter(output, end_token='<extra_id_1>'):\n",
"    # Cut out the generated filler (skipping the leading <pad> and\n",
"    # <extra_id_0> tokens) and splice it back into the original text.\n",
"    _txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n",
"    if end_token in _txt:\n",
"        _end_token_index = _txt.index(end_token)\n",
"        return _result_prefix + _txt[:_end_token_index] + _result_suffix\n",
"    else:\n",
"        return _result_prefix + _txt + _result_suffix\n",
"\n",
"results = [_filter(out) for out in outputs]\n",
"results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(See [https://arxiv.org/pdf/1910.10683.pdf](https://arxiv.org/pdf/1910.10683.pdf))\n",
"\n",
"Examples: T5, mT5\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
},
"org": null
},
"nbformat": 4,
"nbformat_minor": 1
}