ISI-transformers/transformers.ipynb
2021-06-07 15:13:35 +02:00

37 KiB

BPE (Byte Pair Encoding)

pip install tokenizers

from tokenizers import Tokenizer, models, trainers
from tokenizers.trainers import BpeTrainer
tokenizer = Tokenizer(models.BPE())
trainer = trainers.BpeTrainer(vocab_size=20000, min_frequency=2)
tokenizer.train(files = ['/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-train.txt'], trainer = trainer)
output = tokenizer.encode("Nie śpiewają piosenek: pracują leniwo,")
output.ids
[236, 2255, 2069, 3898, 9908, 14, 8675, 8319, 191, 7]
# Display the string form of each token; leading spaces are part of the tokens.
output.tokens
['Nie', ' śpie', 'wają', ' pios', 'enek', ':', ' pracują', ' leni', 'wo', ',']
# Serialize the trained tokenizer to a JSON file (pretty=True -> indented JSON).
tokenizer.save("./my-bpe.tokenizer.json", pretty=True)

ZADANIE

Stworzyć tokenizer BPE na podstawie pliku https://git.wmi.amu.edu.pl/kubapok/lalka-lm/src/branch/master/train/train.tsv i za jego pomocą stokenizować plik https://git.wmi.amu.edu.pl/kubapok/lalka-lm/src/branch/master/test-A/in.tsv (wynikiem ma być stokenizowana lista).

Przyjąć vocab_size = 8000 (8k) i uwzględnić dodatkowe tokeny specjalne BOS oraz EOS, wplatając je w zbiór testowy (BOS na początku, EOS na końcu każdej sekwencji).

transformery

# pip install transformers
# torch must be installed for return_tensors='pt' below; `pipeline` and
# `set_seed` are imported for later cells (set_seed is not called in any
# visible cell).
import torch
from transformers import pipeline, set_seed
from transformers import RobertaTokenizer, RobertaModel
# Load the pretrained RoBERTa-base tokenizer and encoder weights.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
text = "Replace me by any text you'd like. Bla Bla"
# Tokenize the text and return the encoding as PyTorch tensors.
encoded_input = tokenizer(text, return_tensors='pt')
# The ids start with 0 and end with 2 — RoBERTa's <s> and </s> special tokens.
encoded_input['input_ids']
tensor([[   0, 9064, 6406,  162,   30,  143, 2788,   47, 1017,  101,    4, 2091,
          102, 2091,  102,    2]])
# Re-evaluating the same expression shows identical ids.
encoded_input['input_ids']
tensor([[   0, 9064, 6406,  162,   30,  143, 2788,   47, 1017,  101,    4, 2091,
          102, 2091,  102,    2]])
# Decode a single id: 162 is ' me' (the leading space belongs to the BPE token).
tokenizer.decode([162])
' me'
# Forward pass; the returned object holds last_hidden_state and pooler_output.
output = model(**encoded_input)
output
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-4.4858e-02,  8.6642e-02, -7.2129e-03,  ..., -4.6295e-02,
          -3.9316e-02,  4.5264e-04],
         [-6.0603e-02,  1.5684e-01,  4.3705e-02,  ...,  5.3485e-01,
           8.4371e-02,  1.4826e-01],
         [-2.3786e-02, -1.2086e-02,  7.8233e-02,  ..., -4.9132e-01,
           1.2500e-01,  3.3293e-01],
         ...,
         [ 6.7192e-02,  2.4028e-01, -2.9984e-01,  ...,  2.1992e-01,
           1.9186e-02,  1.5355e-01],
         [ 1.7611e-01,  1.4001e-01, -1.3774e-01,  ..., -5.0379e-01,
           7.3958e-02,  3.7870e-02],
         [-3.4405e-02,  8.7648e-02, -3.9429e-02,  ..., -9.1916e-02,
          -3.5529e-02, -2.6777e-02]]], grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[ 5.1994e-03, -2.1290e-01, -2.2585e-01, -9.3315e-02,  1.1761e-01,
          1.9024e-01,  2.4873e-01, -8.1097e-02, -4.4050e-02, -1.6596e-01,
          2.1572e-01, -8.3736e-03, -7.7416e-02,  8.2714e-02, -1.2487e-01,
          4.8405e-01,  2.1145e-01, -4.4653e-01,  4.1008e-02, -1.2578e-02,
         -2.5560e-01,  7.3874e-02,  4.6924e-01,  3.2284e-01,  1.2382e-01,
          6.3117e-02, -1.2633e-01, -1.3542e-02,  1.6195e-01,  2.1738e-01,
          2.8682e-01,  6.2675e-02,  8.5778e-02,  2.3686e-01, -2.5219e-01,
          4.3382e-02, -3.0992e-01,  4.0551e-02,  2.3078e-01, -1.8763e-01,
         -7.2084e-02,  1.5847e-01,  2.0228e-01, -1.2756e-01, -1.1429e-01,
          3.9603e-01,  2.6018e-01,  2.9297e-02, -1.2162e-01, -8.2051e-02,
         -3.5665e-01,  3.4722e-01,  2.8292e-01,  1.9929e-01, -1.4734e-02,
          4.8892e-02, -1.4596e-01,  2.5527e-01, -7.5540e-02, -8.7507e-02,
         -1.2164e-01, -2.0039e-01, -7.1405e-03, -5.9407e-02,  4.1914e-02,
         -1.3208e-01,  6.8612e-02, -1.4154e-01, -1.1040e-01,  5.3550e-02,
         -7.3125e-02,  1.6401e-01,  1.6232e-01, -3.0164e-01, -2.8429e-01,
          5.9359e-02, -5.7466e-01, -9.3444e-02,  3.1263e-01,  4.3317e-01,
         -1.0505e-01,  2.0451e-01,  3.2512e-02,  2.0723e-01, -2.0692e-02,
         -7.1687e-02, -4.7481e-02, -9.9879e-02,  1.8269e-01,  2.6383e-01,
         -1.8467e-01, -3.7178e-01,  7.7521e-02,  2.5127e-02, -1.0401e-01,
          3.6107e-03, -2.1520e-02, -8.3309e-02, -1.6115e-01, -1.6062e-01,
          4.8140e-02, -2.4970e-01, -1.4153e-01,  2.7024e-01, -3.3932e-02,
         -2.0374e-01, -2.1484e-02,  2.9751e-01,  7.9157e-02, -1.1653e-01,
         -1.9188e-01,  4.2722e-01,  2.7925e-01,  2.1393e-03, -1.2266e-02,
          1.6474e-01,  1.5236e-01, -2.8896e-01,  4.2948e-01, -3.1129e-01,
         -8.3017e-03, -1.2746e-01,  9.3701e-02,  1.3876e-01, -2.2297e-01,
          2.7802e-01,  1.4876e-01,  2.7402e-01,  1.8048e-01,  1.0370e-01,
         -1.9344e-02,  1.5151e-01, -1.0991e-01,  1.4710e-01,  2.1547e-01,
          1.1827e-01, -3.4855e-03, -3.2536e-01, -2.0146e-01,  2.6933e-01,
          3.2889e-01,  1.6086e-01, -4.0214e-02,  1.7948e-01,  9.3386e-02,
          2.3509e-01,  1.3713e-01, -3.9689e-01,  2.7102e-02,  3.2785e-01,
          9.5205e-02,  1.6616e-01, -8.5767e-02, -2.9547e-01, -2.5672e-01,
         -9.8167e-02,  4.9819e-02, -3.1859e-01, -1.0416e-01,  3.6841e-01,
          3.9702e-02,  1.0092e-02, -1.5538e-01, -2.3540e-01, -3.1184e-02,
         -1.0212e-01,  2.2566e-02,  8.2764e-02, -8.8543e-02, -4.3323e-01,
         -9.3633e-02, -5.2647e-01, -1.1645e-01,  1.9042e-01, -3.2305e-01,
          2.3381e-01, -2.9311e-01,  9.7979e-02,  3.9168e-01,  4.9259e-02,
         -1.4776e-03, -1.9380e-01, -2.3509e-02,  1.1105e-01,  3.2593e-01,
          2.3894e-01, -4.0230e-01,  1.0931e-01,  1.4188e-01,  2.5775e-01,
          1.4593e-01, -6.1273e-02, -1.1612e-01,  1.5098e-01, -2.0754e-01,
          1.7434e-01, -2.2652e-01,  1.8130e-01, -2.4506e-01, -2.1763e-01,
          2.7327e-01, -4.0463e-01, -3.2816e-02,  8.4527e-02,  2.6249e-01,
          5.3814e-03, -3.7149e-02, -8.5346e-02,  1.2447e-01,  1.6730e-01,
          1.2632e-01, -3.9535e-01,  2.6029e-01, -2.7662e-02, -1.0249e-02,
         -3.3478e-02,  1.7810e-01,  2.5045e-01,  8.5239e-02, -3.8860e-01,
         -1.2383e-01,  1.0050e-01,  2.8403e-01, -2.3275e-01,  1.4317e-01,
         -2.5891e-01, -3.7882e-01, -1.4209e-01,  2.0352e-01,  2.2653e-01,
          1.7335e-01, -2.7438e-01,  1.6319e-01, -1.0932e-01, -4.1693e-01,
         -3.6249e-01, -1.0035e-01,  2.3360e-01,  1.6881e-01,  1.8671e-01,
          2.4667e-01,  3.3260e-02,  1.0321e-01,  1.5088e-01,  1.4948e-01,
         -1.4621e-01,  1.5615e-01, -3.4600e-01, -4.0963e-02, -2.5380e-01,
         -1.9230e-01, -2.1718e-01,  3.9236e-01, -2.2989e-01,  2.3594e-01,
          3.7562e-01, -3.1114e-01, -1.2649e-01,  1.6051e-01,  1.0577e-01,
          8.3363e-02, -1.2593e-01,  2.0584e-01,  1.4127e-01, -1.0632e-01,
          2.3048e-01, -9.1782e-03,  2.4953e-01,  1.7332e-01,  8.9003e-02,
          1.5702e-01,  1.1363e-01, -1.4771e-01,  3.8719e-02,  3.1532e-03,
         -1.9275e-02, -2.3103e-01, -1.4421e-01,  2.3575e-01, -5.3696e-02,
          3.7994e-02, -1.6641e-01, -1.1399e-01,  1.1890e-02,  3.9652e-01,
         -3.6048e-01,  2.4448e-01,  7.5291e-02,  1.5195e-01, -2.2864e-01,
         -2.1031e-01,  8.9896e-02,  1.6992e-01, -3.9863e-01,  1.7435e-03,
          1.6546e-01,  1.0483e-01,  2.1218e-01,  2.8159e-01, -5.0768e-05,
         -9.4031e-02,  4.9128e-01, -1.6932e-01, -1.2761e-01,  2.5249e-01,
         -2.7310e-01, -2.7744e-01,  2.4945e-01, -2.9252e-02,  3.0148e-01,
          1.1743e-01,  3.8095e-02,  6.4731e-02, -6.0554e-01,  6.6722e-02,
         -4.5959e-01, -8.5795e-03,  3.6827e-02, -7.1025e-02, -2.0848e-01,
          1.4228e-01,  2.9630e-01, -2.4077e-01, -4.3098e-02,  2.2334e-01,
          8.4614e-02, -1.2553e-01,  4.8810e-01,  1.4035e-03,  2.0182e-01,
         -6.3799e-02,  2.4192e-01, -2.0799e-01,  2.6687e-01, -2.7694e-01,
         -9.9754e-02,  4.7169e-03,  8.3846e-02,  4.5165e-02, -5.9169e-02,
         -3.5243e-01,  2.2125e-01, -1.9798e-02, -5.7139e-02, -4.0613e-02,
          9.3967e-02, -1.7488e-02,  6.2663e-02,  5.2265e-02,  3.4857e-01,
          2.2626e-01, -2.6472e-02, -3.7240e-01, -1.1370e-02, -9.6964e-02,
          5.6610e-02,  2.6788e-02, -2.4786e-02,  4.3236e-01, -7.6491e-02,
          5.2428e-03, -1.4826e-01,  2.6029e-01,  1.8327e-01,  1.2439e-01,
          1.3042e-01,  6.1801e-02,  1.3667e-01, -6.7373e-02, -1.2048e-02,
         -1.4603e-01, -2.2505e-01, -2.9760e-01,  2.0056e-01, -2.2011e-01,
         -1.8115e-01,  1.4262e-01,  2.1523e-01, -1.3893e-01,  1.4466e-01,
          2.9357e-01,  1.2101e-01, -1.4499e-01,  2.6571e-01, -1.0101e-01,
          1.1599e-01,  2.9781e-01, -2.0156e-02,  2.0140e-01,  5.0007e-01,
          2.1717e-01, -3.5394e-01, -1.6840e-02, -2.2424e-01,  1.1153e-02,
          2.4465e-01, -1.5139e-01,  1.9410e-01,  3.8725e-01,  2.9424e-01,
          4.2920e-01, -7.3521e-03, -1.1083e-01,  9.0861e-02,  2.2493e-01,
          2.7805e-02, -1.5957e-01, -1.9878e-01,  2.5621e-01,  6.2884e-02,
         -1.5698e-01, -1.3003e-02, -1.1539e-01,  2.8383e-02, -1.2329e-01,
         -3.8461e-01,  3.5215e-02,  1.8146e-01, -4.7366e-01,  7.3029e-02,
         -2.9380e-01,  4.3577e-02, -2.3151e-01,  2.0148e-01, -2.2549e-01,
         -1.0662e-01,  3.9916e-01, -7.4222e-02,  4.3181e-02, -1.7905e-01,
         -1.3722e-01,  2.4749e-02,  7.4731e-03, -1.4543e-02, -4.1486e-03,
          3.3529e-01, -1.2101e-01,  3.6759e-02,  3.6844e-02,  1.9582e-01,
         -5.1381e-02,  2.0516e-01,  3.1175e-02, -1.4019e-01, -4.0386e-01,
          1.4271e-01, -1.8793e-01, -4.2023e-01, -3.6638e-01,  3.6652e-01,
         -1.3753e-01, -2.5359e-01, -2.0423e-01, -2.4466e-01,  8.1067e-02,
          1.6987e-01,  4.7120e-01, -3.9858e-01, -6.8325e-02,  4.7077e-01,
         -6.8745e-02, -1.8953e-01,  2.7360e-01,  1.8793e-01, -3.3325e-01,
          3.1144e-01,  2.6919e-01, -5.6080e-02,  1.5771e-02,  5.0668e-01,
          1.1729e-01,  1.8437e-01, -2.0954e-01,  4.4338e-01, -2.1112e-01,
          3.1039e-01, -1.6460e-01, -2.1319e-01, -2.1592e-01, -1.9942e-02,
          3.3144e-01,  1.8923e-01, -4.2029e-01, -1.0169e-01,  3.1353e-02,
          3.6021e-01, -3.7626e-01, -8.6387e-02,  1.3697e-02, -3.3636e-01,
          1.2770e-01,  1.0668e-01,  2.2197e-01, -3.7968e-01, -1.5053e-02,
          3.9753e-01, -2.9535e-01,  1.3459e-01,  3.2518e-01,  7.6786e-02,
          3.4168e-01, -2.8172e-02,  1.0189e-02,  5.9536e-02, -2.3156e-01,
         -3.8199e-02,  1.3041e-01,  5.4866e-01,  1.5127e-01, -3.6896e-01,
          9.5292e-02,  2.4462e-01, -1.6506e-01,  3.1529e-01, -8.9680e-02,
         -4.6637e-02,  2.6508e-01, -3.6751e-02,  1.5445e-01, -9.7824e-02,
         -2.1623e-01, -3.0666e-01,  3.6944e-01, -1.8711e-01, -1.1481e-01,
         -1.6787e-01, -1.1253e-01, -1.4680e-01,  4.1271e-02, -3.6980e-01,
          3.3081e-01,  1.2455e-01, -1.8123e-01, -6.8767e-02, -9.6390e-02,
         -1.4910e-01, -2.0524e-01, -2.6686e-01,  4.2154e-01, -1.6543e-01,
         -4.5050e-01,  2.5019e-01,  3.4722e-02,  3.3103e-01,  4.8806e-02,
          9.9796e-02, -3.0042e-02,  1.4140e-01,  9.6566e-02, -1.1071e-01,
          2.7989e-01,  6.1347e-02, -5.6286e-01, -1.4422e-01, -2.1070e-01,
          7.6292e-02,  1.9691e-01, -3.3576e-01,  1.7638e-02,  2.2105e-02,
          1.4265e-01,  3.0694e-02, -1.0665e-01, -5.5197e-02,  3.9164e-01,
          2.0961e-01,  2.9841e-01,  8.7946e-02,  2.4076e-01, -1.0636e-02,
         -3.3807e-01,  2.2974e-02,  8.5258e-02, -1.8663e-01,  4.1414e-01,
         -9.9141e-02, -3.8117e-01, -5.6155e-02,  3.9692e-01,  9.7551e-02,
         -1.8710e-02, -5.0913e-02,  2.0049e-01,  1.5407e-01, -1.2523e-01,
          1.8187e-01, -9.7470e-03, -1.3372e-01, -1.0178e-01,  8.5468e-02,
         -2.1953e-01,  4.7566e-02, -1.3239e-01,  1.3200e-03, -2.0911e-01,
          3.2521e-03, -2.1387e-01,  2.4508e-01, -3.2182e-01,  9.8634e-02,
          7.5848e-02,  2.9231e-01, -3.5121e-01, -1.4159e-01, -5.7937e-02,
          1.6263e-01,  2.5830e-01,  3.5601e-01,  2.3997e-02,  2.5322e-02,
         -1.5363e-01, -2.6361e-01,  5.4986e-02, -2.0897e-01,  1.2282e-01,
          7.1346e-02,  2.4762e-01, -3.0430e-01, -1.8016e-01,  2.2226e-01,
         -9.7989e-02, -1.4158e-01,  4.2292e-01,  2.5139e-01,  2.1049e-01,
          2.2865e-02,  2.4210e-01,  3.7744e-02, -1.7568e-01, -1.1512e-01,
         -2.4392e-01,  6.9097e-02, -8.5799e-02, -5.8893e-02, -7.1211e-02,
         -1.2143e-01, -1.9825e-01, -1.5658e-01,  1.5637e-01,  1.3693e-01,
          2.9095e-02, -5.5552e-02, -2.4771e-02, -2.7771e-01,  2.9286e-01,
          2.4894e-02,  7.2069e-02, -4.8322e-02,  2.3967e-02, -1.5199e-01,
          2.3989e-01,  2.0234e-01,  8.2009e-02, -1.8899e-01, -4.8667e-02,
         -2.9075e-01, -3.5470e-01,  4.1930e-02,  1.3129e-01,  1.1387e-01,
         -1.0165e-01, -2.7247e-01, -2.7974e-02, -1.3051e-01,  1.8051e-01,
         -9.6646e-03, -1.5500e-01, -7.4565e-02, -6.0039e-02, -5.1055e-02,
          6.7692e-02, -2.0781e-01, -1.9844e-01, -1.2495e-01, -7.5151e-02,
         -6.6146e-02,  3.6196e-01, -3.5989e-02,  2.7737e-01, -1.5471e-01,
          1.1208e-02, -1.9818e-01,  1.0743e-01, -7.3001e-02,  7.3365e-02,
          2.6398e-01, -4.2969e-01, -1.5308e-01,  6.1186e-03, -2.1301e-01,
         -1.4149e-01, -7.1113e-02, -4.0364e-02,  2.1242e-01, -3.4205e-01,
          2.1659e-01, -8.0915e-02,  1.8907e-01, -9.4013e-02, -2.5456e-01,
         -1.6216e-01,  2.3130e-02,  2.4984e-01, -3.3239e-01, -2.2947e-01,
         -2.6681e-01, -9.7903e-02, -9.0469e-02, -2.6217e-01,  4.1510e-01,
         -1.0590e-01, -5.5713e-02,  9.9271e-03,  4.3321e-01,  1.9454e-01,
          1.5135e-01,  2.1670e-01, -1.3371e-02,  2.7091e-02,  1.0805e-01,
         -4.6743e-01,  2.3397e-01, -2.2627e-01, -1.2724e-01,  2.7149e-02,
          8.9104e-02, -3.1547e-02,  1.2930e-02, -1.1888e-01, -1.0141e-01,
          2.0849e-01, -3.6962e-01, -1.2304e-02,  2.7230e-01,  1.4519e-01,
         -2.4969e-01,  4.2865e-02,  1.2965e-01,  3.7797e-01,  8.8492e-02,
         -2.2487e-01,  1.3100e-01, -3.4240e-01, -2.4896e-02, -1.8675e-01,
         -2.9198e-01,  1.3836e-01, -6.9468e-02,  5.4983e-02, -6.8482e-02,
         -2.7968e-01,  2.1223e-01, -5.0621e-02, -6.3859e-02,  4.1759e-01,
          3.3747e-02, -1.1644e-01,  1.5398e-01,  1.5137e-02, -5.4925e-03,
         -1.0726e-01,  2.6553e-01,  2.0031e-01, -2.7755e-01,  1.2135e-01,
         -1.2860e-01, -2.5987e-02, -1.1620e-01]], grad_fn=<TanhBackward>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)
# Default call: the output holds two entries, last_hidden_state and pooler_output.
len(output)
2
# last_hidden_state: (batch, tokens, hidden size) = (1, 16, 768).
output[0].shape
torch.Size([1, 16, 768])

# pooler_output: one 768-dim vector for the (single) input sequence.
output[1].shape
torch.Size([1, 768])
# Also request every layer's hidden states; they are returned as output[2].
output = model(**encoded_input, output_hidden_states=True)
len(output)
3
# 13 hidden states: the embedding output plus one per each of the 12 encoder layers.
len(output[2])
13
# Index 0 is the embedding-layer output, i.e. the input to the first encoder layer.
output[2][0]
tensor([[[ 0.1664, -0.0541, -0.0014,  ..., -0.0811,  0.0794,  0.0155],
         [-0.7241,  0.1035,  0.0784,  ...,  0.2474, -0.0535,  0.4320],
         [ 0.5926, -0.1062,  0.0372,  ..., -0.0140,  0.1021, -0.2212],
         ...,
         [ 0.4734, -0.0570, -0.2506,  ...,  0.4071,  0.4481, -0.2180],
         [ 0.7836, -0.2838, -0.2083,  ..., -0.0959, -0.0136,  0.1995],
         [ 0.2733, -0.1372, -0.0387,  ...,  0.5187,  0.1545, -0.2604]]],
       grad_fn=<NativeLayerNormBackward>)
# Every hidden state shares the (1, 16, 768) shape.
output[2][0].shape
torch.Size([1, 16, 768])
output[2][1].shape
torch.Size([1, 16, 768])
output[2][12].shape
torch.Size([1, 16, 768])
# Request attention weights instead (hidden states off), so they become output[2].
output = model(**encoded_input, output_attentions=True)
len(output)
3
# One attention tensor per encoder layer.
len(output[2])
12
# Shape (batch, heads, query tokens, key tokens) = (1, 12, 16, 16).
output[2][0].shape
torch.Size([1, 12, 16, 16])
# Attention weights of the third encoder layer (index 2); each row is a
# softmax distribution over key positions.
output[2][2]
tensor([[[[9.8775e-01, 6.2288e-04, 8.7264e-04,  ..., 5.4309e-04,
           1.3059e-03, 1.0826e-03],
          [3.3152e-01, 7.3213e-03, 3.0339e-02,  ..., 1.6386e-03,
           1.1041e-03, 1.0450e-03],
          [8.4058e-01, 7.4270e-04, 1.8587e-04,  ..., 1.9484e-03,
           8.3106e-04, 2.2206e-03],
          ...,
          [8.3998e-01, 6.3201e-06, 7.9328e-06,  ..., 1.8371e-02,
           5.9146e-02, 7.1377e-02],
          [9.4819e-01, 3.9591e-06, 2.9191e-06,  ..., 9.6707e-03,
           1.1201e-02, 2.4954e-02],
          [9.2851e-01, 4.9144e-04, 2.2858e-04,  ..., 9.3861e-03,
           1.7582e-02, 2.4180e-02]],

         [[9.2353e-01, 4.3481e-03, 1.9423e-02,  ..., 5.0829e-03,
           7.5931e-03, 4.6599e-03],
          [9.7840e-01, 4.1909e-03, 9.0263e-03,  ..., 2.1102e-06,
           5.4437e-07, 6.7581e-06],
          [8.3596e-01, 6.3265e-02, 7.9091e-02,  ..., 1.4975e-05,
           2.1750e-06, 2.3804e-06],
          ...,
          [4.7469e-01, 1.1083e-04, 1.8293e-03,  ..., 9.7021e-03,
           6.5544e-03, 1.9043e-03],
          [2.1963e-01, 1.3427e-06, 1.2042e-04,  ..., 7.5510e-01,
           2.8724e-03, 6.2941e-03],
          [4.2043e-01, 3.4030e-06, 6.4028e-05,  ..., 8.2335e-02,
           3.9994e-01, 9.1114e-02]],

         [[9.8968e-01, 2.9357e-04, 2.4483e-04,  ..., 2.0526e-04,
           4.1698e-04, 3.3650e-03],
          [9.0939e-01, 3.1261e-03, 2.7859e-02,  ..., 3.1149e-04,
           8.0127e-05, 2.8887e-03],
          [8.9282e-01, 2.4450e-04, 5.3892e-03,  ..., 8.5178e-04,
           9.8922e-05, 2.7169e-03],
          ...,
          [9.3745e-01, 2.0096e-06, 4.1223e-06,  ..., 4.7319e-02,
           3.8060e-03, 6.3264e-03],
          [9.5799e-01, 1.2817e-04, 1.0723e-05,  ..., 1.0232e-03,
           2.1168e-02, 3.7038e-03],
          [9.1897e-01, 4.5952e-04, 7.4514e-05,  ..., 5.2304e-05,
           3.8385e-05, 5.9209e-02]],

         ...,

         [[9.7214e-01, 1.8048e-03, 2.0910e-03,  ..., 1.5654e-03,
           2.0380e-03, 2.9465e-03],
          [2.0737e-01, 1.5373e-02, 3.4949e-01,  ..., 1.0591e-04,
           3.8994e-06, 1.9794e-05],
          [7.0131e-01, 2.8094e-03, 7.6395e-03,  ..., 1.2338e-03,
           8.6231e-05, 8.1068e-05],
          ...,
          [4.1426e-01, 1.9507e-06, 5.5085e-05,  ..., 3.8152e-02,
           4.5979e-01, 6.9998e-02],
          [7.5517e-01, 2.2428e-07, 3.2856e-06,  ..., 1.3153e-02,
           5.5085e-03, 2.1891e-01],
          [9.4142e-01, 3.3256e-05, 6.0546e-06,  ..., 9.1890e-04,
           8.7666e-03, 3.8735e-02]],

         [[9.7447e-01, 1.1291e-03, 2.3473e-03,  ..., 1.6628e-03,
           1.7247e-03, 3.7978e-03],
          [7.2027e-01, 5.4353e-02, 5.0394e-03,  ..., 4.7070e-03,
           1.4477e-03, 7.9330e-02],
          [9.1602e-01, 6.2537e-03, 6.2520e-03,  ..., 3.0431e-03,
           1.6902e-03, 2.6523e-02],
          ...,
          [8.7035e-01, 5.6680e-03, 2.5519e-04,  ..., 1.0693e-02,
           1.0154e-02, 2.8158e-02],
          [7.8992e-01, 1.3184e-03, 5.2799e-04,  ..., 3.8399e-03,
           2.3379e-02, 5.4757e-02],
          [4.0584e-01, 5.6631e-03, 8.5153e-03,  ..., 1.0006e-02,
           1.0799e-02, 1.9912e-01]],

         [[9.8713e-01, 3.3973e-04, 9.6788e-04,  ..., 2.1040e-04,
           1.3595e-03, 8.0080e-04],
          [1.0312e-01, 4.2905e-03, 8.3475e-01,  ..., 7.3782e-06,
           1.9842e-04, 1.3445e-03],
          [7.9036e-01, 2.8547e-02, 5.0725e-02,  ..., 1.9356e-05,
           6.4891e-05, 2.8477e-03],
          ...,
          [2.1335e-01, 9.7233e-06, 6.9469e-05,  ..., 3.6693e-04,
           3.3324e-01, 1.3384e-02],
          [1.1667e-02, 3.0911e-05, 2.5899e-06,  ..., 5.6125e-01,
           2.7517e-04, 1.5053e-03],
          [8.4494e-01, 8.0791e-04, 1.0116e-03,  ..., 2.4602e-03,
           6.7727e-02, 1.1728e-02]]]], grad_fn=<SoftmaxBackward>)

gotowe API (pipeline)

generowanie tekstu

# Text generation with GPT-2 through the ready-made pipeline API. Loading prints
# a warning that the 'masked_bias' entries were not in the checkpoint and were
# newly initialized. NOTE(review): presumably these are fixed attention-mask
# buffers rather than trained weights — confirm for the installed version.
model = pipeline('text-generation', model='gpt2')
Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# Generate 5 continuations, each at most 30 tokens long including the prompt.
# The 5 results differ, so decoding samples; no seed is set (set_seed is not
# called in any visible cell), so reruns produce different text.
model("Hello, I'm a computer science student", max_length=30, num_return_sequences=5)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[{'generated_text': 'Hello, I\'m a computer science student by trade. I don\'t really like science."\n\nThen I hear him say: "I love the'},
 {'generated_text': "Hello, I'm a computer science student.\n\nAnd if you're curious what I'm doing here, don't hesitate:\n\n\nI've"},
 {'generated_text': "Hello, I'm a computer science student, not an engineer. But, I'm also fascinated, because all the people I'm talking to are engineers"},
 {'generated_text': "Hello, I'm a computer science student with a big project called the Data Science project, to help students create and understand and improve their data science."},
 {'generated_text': "Hello, I'm a computer science student from North Carolina. My work involves a number of questions (and many possible answers as well). I can't"}]
# Same settings with a longer prompt, which leaves fewer of the 30 tokens for the continuation.
model("I want to contribute to Google's Computer Vision Program, which is doing extensive work on big", max_length=30, num_return_sequences=5)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
[{'generated_text': "I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data. It is also working on creating a new type of"},
 {'generated_text': "I want to contribute to Google's Computer Vision Program, which is doing extensive work on big datasets. It has now given every request to Google for information"},
 {'generated_text': "I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data, robotics and artificial intelligence. We understand the potential impact"},
 {'generated_text': "I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data, artificial intelligence, and machine learning. I think we"},
 {'generated_text': "I want to contribute to Google's Computer Vision Program, which is doing extensive work on big new ways to see and interact with image data. We'll"}]

sentiment analysis

# Sentiment analysis with a ready-made pipeline.
from transformers import pipeline

# DistilBERT (uncased) fine-tuned on SST-2: labels text POSITIVE or NEGATIVE
# together with a confidence score.
model = pipeline("sentiment-analysis", model='distilbert-base-uncased-finetuned-sst-2-english')
# Displaying the object shows the concrete pipeline class.
model
<transformers.pipelines.text_classification.TextClassificationPipeline at 0x7f8e45b319d0>
# Each call returns a list with one {'label', 'score'} dict for the input text.
model("I'm very happy. Today is the beatifull weather")
[{'label': 'POSITIVE', 'score': 0.9998474717140198}]
model("It's raining. What a terrible day...")
[{'label': 'NEGATIVE', 'score': 0.9946851134300232}]

NER

# Named-entity recognition with the default model for the "ner" task.
# Each returned entry is a token tagged as part of an entity (e.g. I-PER, I-LOC)
# with its confidence score, token index and character span [start, end) in
# `text`; tokens outside any entity are not listed.
from transformers import pipeline
model = pipeline("ner")
text = "George Washington went to Washington"
model(text)
[{'word': 'George',
  'score': 0.9983943104743958,
  'entity': 'I-PER',
  'index': 1,
  'start': 0,
  'end': 6},
 {'word': 'Washington',
  'score': 0.9992505311965942,
  'entity': 'I-PER',
  'index': 2,
  'start': 7,
  'end': 17},
 {'word': 'Washington',
  'score': 0.98389732837677,
  'entity': 'I-LOC',
  'index': 5,
  'start': 26,
  'end': 36}]

masked language modelling

ZADANIE (10 minut)

przewidziać