12 KiB
Pretrenowanie modeli
System AlphaZero uczy się grając sam ze sobą — wystarczy 24 godziny, by system nauczył się grać w szachy lub go na nadludzkim poziomie.
Pytanie: Dlaczego granie samemu ze sobą nie jest dobrym sposobem nauczenia się grania w szachy dla człowieka, a dla maszyny jest?
Co jest odpowiednikiem grania samemu ze sobą w świecie przetwarzania tekstu? Tzn. pretrenowanie (_pretraining) na dużym korpusie tekstu. (Tekst jest tani!)
Jest kilka sposobów na pretrenowanie modelu, w każdym razie sprowadza się do odgadywania następnego bądź zamaskowanego słowa. W każdym razie zawsze stosujemy softmax (być może ze „sztuczkami” takimi jak negatywne próbkowanie albo hierarchiczny softmax) na pewnej reprezentacji kontekstowej:
$$\vec{p} = \operatorname{softmax}(f(\vec{c})).$$
Model jest karany używając funkcji log loss:
$$-\log(p_j),$$
gdzie $w_j$ jest wyrazem, który pojawił się rzeczywiście w korpusie.
Przewidywanie słowa (GPT-2)
Jeden ze sposobów pretrenowania modelu to po prostu przewidywanie następnego słowa.
Zainstalujmy najpierw bibliotekę transformers.
! pip install transformers
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
model = GPT2LMHeadModel.from_pretrained('gpt2-large')
text = 'Warsaw is the capital city of'
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)
next_token_probs = torch.softmax(output[0][:, -1, :][0], dim=0)
nb_of_tokens = next_token_probs.size()[0]
print(nb_of_tokens)
_, top_k_indices = torch.topk(next_token_probs, 30, sorted=True)
words = tokenizer.convert_ids_to_tokens(top_k_indices)
top_probs = []
for ix in range(len(top_k_indices)):
top_probs.append((words[ix], next_token_probs[top_k_indices[ix]].item()))
top_probs
50257
[('Âł', 0.6182783842086792), ('È', 0.1154019758105278), ('Ñģ', 0.026960616931319237), ('_____', 0.024418892338871956), ('________', 0.014962316490709782), ('ÃĤ', 0.010653386823832989), ('ä¸Ń', 0.008340531960129738), ('Ñ', 0.007557711564004421), ('Ê', 0.007046067621558905), ('ãĢ', 0.006875576451420784), ('ile', 0.006685272324830294), ('____', 0.006307446397840977), ('âĢĭ', 0.006306538358330727), ('ÑĢ', 0.006197483278810978), ('ĠBelarus', 0.006108700763434172), ('Æ', 0.005720408633351326), ('ĠPoland', 0.0053678699769079685), ('á¹', 0.004606408067047596), ('îĢ', 0.004161055199801922), ('????', 0.004056799225509167), ('_______', 0.0038176667876541615), ('ä¸', 0.0036082742735743523), ('Ì', 0.003221835708245635), ('urs', 0.003080119378864765), ('________________', 0.0027312245219945908), ('ĠLithuania', 0.0023860156070441008), ('ich', 0.0021211160346865654), ('iz', 0.002069818088784814), ('vern', 0.002001357264816761), ('ÅĤ', 0.001717406208626926)]
Zalety tego podejścia:
- prostota,
- dobra podstawa do strojenia systemów generowania tekstu zwłaszcza „otwartego” (systemy dialogowe, generowanie (fake) newsów, streszczanie tekstu), ale niekoniecznie tłumaczenia maszynowego,
- zaskakująca skuteczność przy uczeniu _few-shot i zero-shot.
Wady:
- asymetryczność, przetwarzanie tylko z lewej do prawej, preferencja dla lewego kontekstu,
- mniejsza skuteczność przy dostrajaniu do zadań klasyfikacji i innych zadań niepolegających na prostym generowaniu.
Przykłady modeli: GPT, GPT-2, GPT-3, DialoGPT.
Maskowanie słów (BERT)
Inną metodą jest maskowanie słów (_Masked Language Modeling, MLM).
W tym podejściu losowe wybrane zastępujemy losowe słowa specjalnym
tokenem ([MASK]
) i każemy modelowi odgadywać w ten sposób
zamaskowane słowa (z uwzględnieniem również prawego kontekstu!).
Móciąc ściśle, w jednym z pierwszych modeli tego typu (BERT) zastosowano schemat, w którym również niezamaskowane słowa są odgadywane (!):
- wybieramy losowe 15% wyrazów do odgadnięcia
- 80% z nich zastępujemy tokenem
[MASK]
, - 10% zastępujemy innym losowym wyrazem,
- 10% pozostawiamy bez zmian.
from transformers import AutoModelWithLMHead, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
model = AutoModelWithLMHead.from_pretrained("xlm-roberta-large")
sequence = f'W którym państwie leży Bombaj? W {tokenizer.mask_token}.'
input_ids = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]
token_logits = model(input_ids)[0]
mask_token_logits = token_logits[0, mask_token_index, :]
mask_token_logits = torch.softmax(mask_token_logits, dim=1)
top_10 = torch.topk(mask_token_logits, 10, dim=1)
top_10_tokens = zip(top_10.indices[0].tolist(), top_10.values[0].tolist())
for token, score in top_10_tokens:
print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])), f"(score: {score})")
/home/filipg/.local/lib/python3.9/site-packages/transformers/models/auto/modeling_auto.py:806: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models. warnings.warn(
W którym państwie leży Bombaj? W USA. (score: 0.16715531051158905) W którym państwie leży Bombaj? W India. (score: 0.09912960231304169) W którym państwie leży Bombaj? W Indian. (score: 0.039642028510570526) W którym państwie leży Bombaj? W Nepal. (score: 0.027137665078043938) W którym państwie leży Bombaj? W Pakistan. (score: 0.027065709233283997) W którym państwie leży Bombaj? W Polsce. (score: 0.023737527430057526) W którym państwie leży Bombaj? W .... (score: 0.02306722290813923) W którym państwie leży Bombaj? W Bangladesh. (score: 0.022106658667325974) W którym państwie leży Bombaj? W .... (score: 0.01628892682492733) W którym państwie leży Bombaj? W Niemczech. (score: 0.014501162804663181)
Przykłady: BERT, RoBERTa (również Polish RoBERTa).
Podejście generatywne (koder-dekoder).
System ma wygenerować odpowiedź na różne pytania (również odpowiadające zadaniu MLM), np.:
- "translate English to German: That is good." => "Das ist gut."
- "cola sentence: The course is jumping well." => "not acceptable"
- "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi…" => "six people hospitalized after a storm in attala county"
- "Thank you for me to your party week." => for inviting last
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration
T5_PATH = 't5-base'
t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)
t5_config = T5Config.from_pretrained(T5_PATH)
t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config)
slot = '<extra_id_0>'
text = f'World War II ended in {slot}.'
encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')
input_ids = encoded['input_ids']
outputs = t5_mlm.generate(input_ids=input_ids,
num_beams=200, num_return_sequences=5,
max_length=5)
_0_index = text.index(slot)
_result_prefix = text[:_0_index]
_result_suffix = text[_0_index+len(slot):]
def _filter(output, end_token='<extra_id_1>'):
_txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
if end_token in _txt:
_end_token_index = _txt.index(end_token)
return _result_prefix + _txt[:_end_token_index] + _result_suffix
else:
return _result_prefix + _txt + _result_suffix
results = [_filter(out) for out in outputs]
results
['World War II ended in World War II.', 'World War II ended in 1945..', 'World War II ended in 1945.', 'World War II ended in 1945.', 'World War II ended in 1945.']
(Zob. https://arxiv.org/pdf/1910.10683.pdf)
Przykład: T5, mT5