This commit is contained in:
Filip Gralinski 2022-06-11 08:36:47 +02:00
parent 33a7a1f83e
commit 6da452bc4f

View File

@ -90,25 +90,23 @@ Dokonajmy najpierw tokenizacji:
#+BEGIN_SRC python :session mysession :exports both :results raw drawer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "The World War III will begin in"
text = "The World War III will begin in 2028 in"
encoded_input = tokenizer(text, return_tensors='pt')
encoded_input
#+END_SRC
#+RESULTS:
:results:
{'input_ids': tensor([[ 464, 2159, 1810, 6711, 481, 2221, 287]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}
{'input_ids': tensor([[ 464, 2159, 1810, 6711, 481, 2221, 287, 1160, 2078, 287]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
:end:
Możemy podejrzeć uzyskane tokeny:
#+BEGIN_SRC python :session mysession :exports both :results raw drawer
[tokenizer.decode(i) for i in encoded_input.input_ids[0]]
[tokenizer.decode(i) for i in encoded_input.input_ids[0]]
#+END_SRC
#+RESULTS:
:results:
['The', ' World', ' War', ' III', ' will', ' begin', ' in']
['The', ' World', ' War', ' III', ' will', ' begin', ' in', ' 20', '28', ' in']
:end:
Zwróćmy uwagę, że w GPT-2 tokeny obejmują spacje!
@ -125,6 +123,15 @@ Teraz uruchommy zasadniczy model:
:results:
:end:
#+BEGIN_SRC python :session mysession :exports both :results raw drawer
softmax(outputs[0][0][-1])
#+END_SRC
#+RESULTS:
:results:
:end:
Z modelu GPT-2 otrzymamy rozkład prawdopodobieństwa kolejnego wyrazu, najpierw w postaci
nieznormalizowanych *logitów*:
@ -140,14 +147,14 @@ tensor([-130.2947, -129.5677, -136.4030, ..., -138.3791, -138.8967,
:end:
#+BEGIN_SRC python :session mysession :exports both :results raw drawer
from torch import softmax, topk
from torch import softmax, topk
k = 20
k = 20
t = topk(softmax(logits, -1), k)
t = topk(softmax(logits, -1), k)
tb = [[tokenizer.decode(t.indices[ix]), t.values[ix].item()] for ix in range(k)]
tb
tb = [[tokenizer.decode(t.indices[ix]), t.values[ix].item()] for ix in range(k)]
tb
#+END_SRC
#+RESULTS: