forked from kubapok/ISI-transformers
1158 lines
37 KiB
Plaintext
1158 lines
37 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# bpe"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"pip install tokenizers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"https://github.com/huggingface/tokenizers/tree/master/bindings/python"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from tokenizers import Tokenizer, models, trainers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from tokenizers.trainers import BpeTrainer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tokenizer = Tokenizer(models.BPE())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"trainer = trainers.BpeTrainer(vocab_size=20000, min_frequency=2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tokenizer.train(files = ['/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-train.txt'], trainer = trainer)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"output = tokenizer.encode(\"Nie śpiewają piosenek: pracują leniwo,\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[236, 2255, 2069, 3898, 9908, 14, 8675, 8319, 191, 7]"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output.ids"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['Nie', ' śpie', 'wają', ' pios', 'enek', ':', ' pracują', ' leni', 'wo', ',']"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output.tokens"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tokenizer.save(\"./my-bpe.tokenizer.json\", pretty=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## ZADANIE\n",
|
|
"stworzyć BPE tokenizer na podstawie https://git.wmi.amu.edu.pl/kubapok/lalka-lm/src/branch/master/train/train.tsv\n",
|
|
"i stworzyć stokenizowaną listę: \n",
|
|
"https://git.wmi.amu.edu.pl/kubapok/lalka-lm/src/branch/master/test-A/in.tsv\n",
|
|
"\n",
|
|
"wybrać vocab_size = 8k, uwzględnić dodatkowe tokeny: BOS oraz EOS i wpleść je do zbioru testowego"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# transformery"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# pip install transformers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers import pipeline, set_seed"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers import RobertaTokenizer, RobertaModel"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = RobertaModel.from_pretrained('roberta-base')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"Replace me by any text you'd like. Bla Bla\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"encoded_input = tokenizer(text, return_tensors='pt')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[ 0, 9064, 6406, 162, 30, 143, 2788, 47, 1017, 101, 4, 2091,\n",
|
|
" 102, 2091, 102, 2]])"
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"encoded_input['input_ids']"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[ 0, 9064, 6406, 162, 30, 143, 2788, 47, 1017, 101, 4, 2091,\n",
|
|
" 102, 2091, 102, 2]])"
|
|
]
|
|
},
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"encoded_input['input_ids']"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"' me'"
|
|
]
|
|
},
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"tokenizer.decode([162])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"output = model(**encoded_input)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-4.4858e-02, 8.6642e-02, -7.2129e-03, ..., -4.6295e-02,\n",
|
|
" -3.9316e-02, 4.5264e-04],\n",
|
|
" [-6.0603e-02, 1.5684e-01, 4.3705e-02, ..., 5.3485e-01,\n",
|
|
" 8.4371e-02, 1.4826e-01],\n",
|
|
" [-2.3786e-02, -1.2086e-02, 7.8233e-02, ..., -4.9132e-01,\n",
|
|
" 1.2500e-01, 3.3293e-01],\n",
|
|
" ...,\n",
|
|
" [ 6.7192e-02, 2.4028e-01, -2.9984e-01, ..., 2.1992e-01,\n",
|
|
" 1.9186e-02, 1.5355e-01],\n",
|
|
" [ 1.7611e-01, 1.4001e-01, -1.3774e-01, ..., -5.0379e-01,\n",
|
|
" 7.3958e-02, 3.7870e-02],\n",
|
|
" [-3.4405e-02, 8.7648e-02, -3.9429e-02, ..., -9.1916e-02,\n",
|
|
" -3.5529e-02, -2.6777e-02]]], grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[ 5.1994e-03, -2.1290e-01, -2.2585e-01, -9.3315e-02, 1.1761e-01,\n",
|
|
" 1.9024e-01, 2.4873e-01, -8.1097e-02, -4.4050e-02, -1.6596e-01,\n",
|
|
" 2.1572e-01, -8.3736e-03, -7.7416e-02, 8.2714e-02, -1.2487e-01,\n",
|
|
" 4.8405e-01, 2.1145e-01, -4.4653e-01, 4.1008e-02, -1.2578e-02,\n",
|
|
" -2.5560e-01, 7.3874e-02, 4.6924e-01, 3.2284e-01, 1.2382e-01,\n",
|
|
" 6.3117e-02, -1.2633e-01, -1.3542e-02, 1.6195e-01, 2.1738e-01,\n",
|
|
" 2.8682e-01, 6.2675e-02, 8.5778e-02, 2.3686e-01, -2.5219e-01,\n",
|
|
" 4.3382e-02, -3.0992e-01, 4.0551e-02, 2.3078e-01, -1.8763e-01,\n",
|
|
" -7.2084e-02, 1.5847e-01, 2.0228e-01, -1.2756e-01, -1.1429e-01,\n",
|
|
" 3.9603e-01, 2.6018e-01, 2.9297e-02, -1.2162e-01, -8.2051e-02,\n",
|
|
" -3.5665e-01, 3.4722e-01, 2.8292e-01, 1.9929e-01, -1.4734e-02,\n",
|
|
" 4.8892e-02, -1.4596e-01, 2.5527e-01, -7.5540e-02, -8.7507e-02,\n",
|
|
" -1.2164e-01, -2.0039e-01, -7.1405e-03, -5.9407e-02, 4.1914e-02,\n",
|
|
" -1.3208e-01, 6.8612e-02, -1.4154e-01, -1.1040e-01, 5.3550e-02,\n",
|
|
" -7.3125e-02, 1.6401e-01, 1.6232e-01, -3.0164e-01, -2.8429e-01,\n",
|
|
" 5.9359e-02, -5.7466e-01, -9.3444e-02, 3.1263e-01, 4.3317e-01,\n",
|
|
" -1.0505e-01, 2.0451e-01, 3.2512e-02, 2.0723e-01, -2.0692e-02,\n",
|
|
" -7.1687e-02, -4.7481e-02, -9.9879e-02, 1.8269e-01, 2.6383e-01,\n",
|
|
" -1.8467e-01, -3.7178e-01, 7.7521e-02, 2.5127e-02, -1.0401e-01,\n",
|
|
" 3.6107e-03, -2.1520e-02, -8.3309e-02, -1.6115e-01, -1.6062e-01,\n",
|
|
" 4.8140e-02, -2.4970e-01, -1.4153e-01, 2.7024e-01, -3.3932e-02,\n",
|
|
" -2.0374e-01, -2.1484e-02, 2.9751e-01, 7.9157e-02, -1.1653e-01,\n",
|
|
" -1.9188e-01, 4.2722e-01, 2.7925e-01, 2.1393e-03, -1.2266e-02,\n",
|
|
" 1.6474e-01, 1.5236e-01, -2.8896e-01, 4.2948e-01, -3.1129e-01,\n",
|
|
" -8.3017e-03, -1.2746e-01, 9.3701e-02, 1.3876e-01, -2.2297e-01,\n",
|
|
" 2.7802e-01, 1.4876e-01, 2.7402e-01, 1.8048e-01, 1.0370e-01,\n",
|
|
" -1.9344e-02, 1.5151e-01, -1.0991e-01, 1.4710e-01, 2.1547e-01,\n",
|
|
" 1.1827e-01, -3.4855e-03, -3.2536e-01, -2.0146e-01, 2.6933e-01,\n",
|
|
" 3.2889e-01, 1.6086e-01, -4.0214e-02, 1.7948e-01, 9.3386e-02,\n",
|
|
" 2.3509e-01, 1.3713e-01, -3.9689e-01, 2.7102e-02, 3.2785e-01,\n",
|
|
" 9.5205e-02, 1.6616e-01, -8.5767e-02, -2.9547e-01, -2.5672e-01,\n",
|
|
" -9.8167e-02, 4.9819e-02, -3.1859e-01, -1.0416e-01, 3.6841e-01,\n",
|
|
" 3.9702e-02, 1.0092e-02, -1.5538e-01, -2.3540e-01, -3.1184e-02,\n",
|
|
" -1.0212e-01, 2.2566e-02, 8.2764e-02, -8.8543e-02, -4.3323e-01,\n",
|
|
" -9.3633e-02, -5.2647e-01, -1.1645e-01, 1.9042e-01, -3.2305e-01,\n",
|
|
" 2.3381e-01, -2.9311e-01, 9.7979e-02, 3.9168e-01, 4.9259e-02,\n",
|
|
" -1.4776e-03, -1.9380e-01, -2.3509e-02, 1.1105e-01, 3.2593e-01,\n",
|
|
" 2.3894e-01, -4.0230e-01, 1.0931e-01, 1.4188e-01, 2.5775e-01,\n",
|
|
" 1.4593e-01, -6.1273e-02, -1.1612e-01, 1.5098e-01, -2.0754e-01,\n",
|
|
" 1.7434e-01, -2.2652e-01, 1.8130e-01, -2.4506e-01, -2.1763e-01,\n",
|
|
" 2.7327e-01, -4.0463e-01, -3.2816e-02, 8.4527e-02, 2.6249e-01,\n",
|
|
" 5.3814e-03, -3.7149e-02, -8.5346e-02, 1.2447e-01, 1.6730e-01,\n",
|
|
" 1.2632e-01, -3.9535e-01, 2.6029e-01, -2.7662e-02, -1.0249e-02,\n",
|
|
" -3.3478e-02, 1.7810e-01, 2.5045e-01, 8.5239e-02, -3.8860e-01,\n",
|
|
" -1.2383e-01, 1.0050e-01, 2.8403e-01, -2.3275e-01, 1.4317e-01,\n",
|
|
" -2.5891e-01, -3.7882e-01, -1.4209e-01, 2.0352e-01, 2.2653e-01,\n",
|
|
" 1.7335e-01, -2.7438e-01, 1.6319e-01, -1.0932e-01, -4.1693e-01,\n",
|
|
" -3.6249e-01, -1.0035e-01, 2.3360e-01, 1.6881e-01, 1.8671e-01,\n",
|
|
" 2.4667e-01, 3.3260e-02, 1.0321e-01, 1.5088e-01, 1.4948e-01,\n",
|
|
" -1.4621e-01, 1.5615e-01, -3.4600e-01, -4.0963e-02, -2.5380e-01,\n",
|
|
" -1.9230e-01, -2.1718e-01, 3.9236e-01, -2.2989e-01, 2.3594e-01,\n",
|
|
" 3.7562e-01, -3.1114e-01, -1.2649e-01, 1.6051e-01, 1.0577e-01,\n",
|
|
" 8.3363e-02, -1.2593e-01, 2.0584e-01, 1.4127e-01, -1.0632e-01,\n",
|
|
" 2.3048e-01, -9.1782e-03, 2.4953e-01, 1.7332e-01, 8.9003e-02,\n",
|
|
" 1.5702e-01, 1.1363e-01, -1.4771e-01, 3.8719e-02, 3.1532e-03,\n",
|
|
" -1.9275e-02, -2.3103e-01, -1.4421e-01, 2.3575e-01, -5.3696e-02,\n",
|
|
" 3.7994e-02, -1.6641e-01, -1.1399e-01, 1.1890e-02, 3.9652e-01,\n",
|
|
" -3.6048e-01, 2.4448e-01, 7.5291e-02, 1.5195e-01, -2.2864e-01,\n",
|
|
" -2.1031e-01, 8.9896e-02, 1.6992e-01, -3.9863e-01, 1.7435e-03,\n",
|
|
" 1.6546e-01, 1.0483e-01, 2.1218e-01, 2.8159e-01, -5.0768e-05,\n",
|
|
" -9.4031e-02, 4.9128e-01, -1.6932e-01, -1.2761e-01, 2.5249e-01,\n",
|
|
" -2.7310e-01, -2.7744e-01, 2.4945e-01, -2.9252e-02, 3.0148e-01,\n",
|
|
" 1.1743e-01, 3.8095e-02, 6.4731e-02, -6.0554e-01, 6.6722e-02,\n",
|
|
" -4.5959e-01, -8.5795e-03, 3.6827e-02, -7.1025e-02, -2.0848e-01,\n",
|
|
" 1.4228e-01, 2.9630e-01, -2.4077e-01, -4.3098e-02, 2.2334e-01,\n",
|
|
" 8.4614e-02, -1.2553e-01, 4.8810e-01, 1.4035e-03, 2.0182e-01,\n",
|
|
" -6.3799e-02, 2.4192e-01, -2.0799e-01, 2.6687e-01, -2.7694e-01,\n",
|
|
" -9.9754e-02, 4.7169e-03, 8.3846e-02, 4.5165e-02, -5.9169e-02,\n",
|
|
" -3.5243e-01, 2.2125e-01, -1.9798e-02, -5.7139e-02, -4.0613e-02,\n",
|
|
" 9.3967e-02, -1.7488e-02, 6.2663e-02, 5.2265e-02, 3.4857e-01,\n",
|
|
" 2.2626e-01, -2.6472e-02, -3.7240e-01, -1.1370e-02, -9.6964e-02,\n",
|
|
" 5.6610e-02, 2.6788e-02, -2.4786e-02, 4.3236e-01, -7.6491e-02,\n",
|
|
" 5.2428e-03, -1.4826e-01, 2.6029e-01, 1.8327e-01, 1.2439e-01,\n",
|
|
" 1.3042e-01, 6.1801e-02, 1.3667e-01, -6.7373e-02, -1.2048e-02,\n",
|
|
" -1.4603e-01, -2.2505e-01, -2.9760e-01, 2.0056e-01, -2.2011e-01,\n",
|
|
" -1.8115e-01, 1.4262e-01, 2.1523e-01, -1.3893e-01, 1.4466e-01,\n",
|
|
" 2.9357e-01, 1.2101e-01, -1.4499e-01, 2.6571e-01, -1.0101e-01,\n",
|
|
" 1.1599e-01, 2.9781e-01, -2.0156e-02, 2.0140e-01, 5.0007e-01,\n",
|
|
" 2.1717e-01, -3.5394e-01, -1.6840e-02, -2.2424e-01, 1.1153e-02,\n",
|
|
" 2.4465e-01, -1.5139e-01, 1.9410e-01, 3.8725e-01, 2.9424e-01,\n",
|
|
" 4.2920e-01, -7.3521e-03, -1.1083e-01, 9.0861e-02, 2.2493e-01,\n",
|
|
" 2.7805e-02, -1.5957e-01, -1.9878e-01, 2.5621e-01, 6.2884e-02,\n",
|
|
" -1.5698e-01, -1.3003e-02, -1.1539e-01, 2.8383e-02, -1.2329e-01,\n",
|
|
" -3.8461e-01, 3.5215e-02, 1.8146e-01, -4.7366e-01, 7.3029e-02,\n",
|
|
" -2.9380e-01, 4.3577e-02, -2.3151e-01, 2.0148e-01, -2.2549e-01,\n",
|
|
" -1.0662e-01, 3.9916e-01, -7.4222e-02, 4.3181e-02, -1.7905e-01,\n",
|
|
" -1.3722e-01, 2.4749e-02, 7.4731e-03, -1.4543e-02, -4.1486e-03,\n",
|
|
" 3.3529e-01, -1.2101e-01, 3.6759e-02, 3.6844e-02, 1.9582e-01,\n",
|
|
" -5.1381e-02, 2.0516e-01, 3.1175e-02, -1.4019e-01, -4.0386e-01,\n",
|
|
" 1.4271e-01, -1.8793e-01, -4.2023e-01, -3.6638e-01, 3.6652e-01,\n",
|
|
" -1.3753e-01, -2.5359e-01, -2.0423e-01, -2.4466e-01, 8.1067e-02,\n",
|
|
" 1.6987e-01, 4.7120e-01, -3.9858e-01, -6.8325e-02, 4.7077e-01,\n",
|
|
" -6.8745e-02, -1.8953e-01, 2.7360e-01, 1.8793e-01, -3.3325e-01,\n",
|
|
" 3.1144e-01, 2.6919e-01, -5.6080e-02, 1.5771e-02, 5.0668e-01,\n",
|
|
" 1.1729e-01, 1.8437e-01, -2.0954e-01, 4.4338e-01, -2.1112e-01,\n",
|
|
" 3.1039e-01, -1.6460e-01, -2.1319e-01, -2.1592e-01, -1.9942e-02,\n",
|
|
" 3.3144e-01, 1.8923e-01, -4.2029e-01, -1.0169e-01, 3.1353e-02,\n",
|
|
" 3.6021e-01, -3.7626e-01, -8.6387e-02, 1.3697e-02, -3.3636e-01,\n",
|
|
" 1.2770e-01, 1.0668e-01, 2.2197e-01, -3.7968e-01, -1.5053e-02,\n",
|
|
" 3.9753e-01, -2.9535e-01, 1.3459e-01, 3.2518e-01, 7.6786e-02,\n",
|
|
" 3.4168e-01, -2.8172e-02, 1.0189e-02, 5.9536e-02, -2.3156e-01,\n",
|
|
" -3.8199e-02, 1.3041e-01, 5.4866e-01, 1.5127e-01, -3.6896e-01,\n",
|
|
" 9.5292e-02, 2.4462e-01, -1.6506e-01, 3.1529e-01, -8.9680e-02,\n",
|
|
" -4.6637e-02, 2.6508e-01, -3.6751e-02, 1.5445e-01, -9.7824e-02,\n",
|
|
" -2.1623e-01, -3.0666e-01, 3.6944e-01, -1.8711e-01, -1.1481e-01,\n",
|
|
" -1.6787e-01, -1.1253e-01, -1.4680e-01, 4.1271e-02, -3.6980e-01,\n",
|
|
" 3.3081e-01, 1.2455e-01, -1.8123e-01, -6.8767e-02, -9.6390e-02,\n",
|
|
" -1.4910e-01, -2.0524e-01, -2.6686e-01, 4.2154e-01, -1.6543e-01,\n",
|
|
" -4.5050e-01, 2.5019e-01, 3.4722e-02, 3.3103e-01, 4.8806e-02,\n",
|
|
" 9.9796e-02, -3.0042e-02, 1.4140e-01, 9.6566e-02, -1.1071e-01,\n",
|
|
" 2.7989e-01, 6.1347e-02, -5.6286e-01, -1.4422e-01, -2.1070e-01,\n",
|
|
" 7.6292e-02, 1.9691e-01, -3.3576e-01, 1.7638e-02, 2.2105e-02,\n",
|
|
" 1.4265e-01, 3.0694e-02, -1.0665e-01, -5.5197e-02, 3.9164e-01,\n",
|
|
" 2.0961e-01, 2.9841e-01, 8.7946e-02, 2.4076e-01, -1.0636e-02,\n",
|
|
" -3.3807e-01, 2.2974e-02, 8.5258e-02, -1.8663e-01, 4.1414e-01,\n",
|
|
" -9.9141e-02, -3.8117e-01, -5.6155e-02, 3.9692e-01, 9.7551e-02,\n",
|
|
" -1.8710e-02, -5.0913e-02, 2.0049e-01, 1.5407e-01, -1.2523e-01,\n",
|
|
" 1.8187e-01, -9.7470e-03, -1.3372e-01, -1.0178e-01, 8.5468e-02,\n",
|
|
" -2.1953e-01, 4.7566e-02, -1.3239e-01, 1.3200e-03, -2.0911e-01,\n",
|
|
" 3.2521e-03, -2.1387e-01, 2.4508e-01, -3.2182e-01, 9.8634e-02,\n",
|
|
" 7.5848e-02, 2.9231e-01, -3.5121e-01, -1.4159e-01, -5.7937e-02,\n",
|
|
" 1.6263e-01, 2.5830e-01, 3.5601e-01, 2.3997e-02, 2.5322e-02,\n",
|
|
" -1.5363e-01, -2.6361e-01, 5.4986e-02, -2.0897e-01, 1.2282e-01,\n",
|
|
" 7.1346e-02, 2.4762e-01, -3.0430e-01, -1.8016e-01, 2.2226e-01,\n",
|
|
" -9.7989e-02, -1.4158e-01, 4.2292e-01, 2.5139e-01, 2.1049e-01,\n",
|
|
" 2.2865e-02, 2.4210e-01, 3.7744e-02, -1.7568e-01, -1.1512e-01,\n",
|
|
" -2.4392e-01, 6.9097e-02, -8.5799e-02, -5.8893e-02, -7.1211e-02,\n",
|
|
" -1.2143e-01, -1.9825e-01, -1.5658e-01, 1.5637e-01, 1.3693e-01,\n",
|
|
" 2.9095e-02, -5.5552e-02, -2.4771e-02, -2.7771e-01, 2.9286e-01,\n",
|
|
" 2.4894e-02, 7.2069e-02, -4.8322e-02, 2.3967e-02, -1.5199e-01,\n",
|
|
" 2.3989e-01, 2.0234e-01, 8.2009e-02, -1.8899e-01, -4.8667e-02,\n",
|
|
" -2.9075e-01, -3.5470e-01, 4.1930e-02, 1.3129e-01, 1.1387e-01,\n",
|
|
" -1.0165e-01, -2.7247e-01, -2.7974e-02, -1.3051e-01, 1.8051e-01,\n",
|
|
" -9.6646e-03, -1.5500e-01, -7.4565e-02, -6.0039e-02, -5.1055e-02,\n",
|
|
" 6.7692e-02, -2.0781e-01, -1.9844e-01, -1.2495e-01, -7.5151e-02,\n",
|
|
" -6.6146e-02, 3.6196e-01, -3.5989e-02, 2.7737e-01, -1.5471e-01,\n",
|
|
" 1.1208e-02, -1.9818e-01, 1.0743e-01, -7.3001e-02, 7.3365e-02,\n",
|
|
" 2.6398e-01, -4.2969e-01, -1.5308e-01, 6.1186e-03, -2.1301e-01,\n",
|
|
" -1.4149e-01, -7.1113e-02, -4.0364e-02, 2.1242e-01, -3.4205e-01,\n",
|
|
" 2.1659e-01, -8.0915e-02, 1.8907e-01, -9.4013e-02, -2.5456e-01,\n",
|
|
" -1.6216e-01, 2.3130e-02, 2.4984e-01, -3.3239e-01, -2.2947e-01,\n",
|
|
" -2.6681e-01, -9.7903e-02, -9.0469e-02, -2.6217e-01, 4.1510e-01,\n",
|
|
" -1.0590e-01, -5.5713e-02, 9.9271e-03, 4.3321e-01, 1.9454e-01,\n",
|
|
" 1.5135e-01, 2.1670e-01, -1.3371e-02, 2.7091e-02, 1.0805e-01,\n",
|
|
" -4.6743e-01, 2.3397e-01, -2.2627e-01, -1.2724e-01, 2.7149e-02,\n",
|
|
" 8.9104e-02, -3.1547e-02, 1.2930e-02, -1.1888e-01, -1.0141e-01,\n",
|
|
" 2.0849e-01, -3.6962e-01, -1.2304e-02, 2.7230e-01, 1.4519e-01,\n",
|
|
" -2.4969e-01, 4.2865e-02, 1.2965e-01, 3.7797e-01, 8.8492e-02,\n",
|
|
" -2.2487e-01, 1.3100e-01, -3.4240e-01, -2.4896e-02, -1.8675e-01,\n",
|
|
" -2.9198e-01, 1.3836e-01, -6.9468e-02, 5.4983e-02, -6.8482e-02,\n",
|
|
" -2.7968e-01, 2.1223e-01, -5.0621e-02, -6.3859e-02, 4.1759e-01,\n",
|
|
" 3.3747e-02, -1.1644e-01, 1.5398e-01, 1.5137e-02, -5.4925e-03,\n",
|
|
" -1.0726e-01, 2.6553e-01, 2.0031e-01, -2.7755e-01, 1.2135e-01,\n",
|
|
" -1.2860e-01, -2.5987e-02, -1.1620e-01]], grad_fn=<TanhBackward>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)"
|
|
]
|
|
},
|
|
"execution_count": 32,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"https://huggingface.co/transformers/main_classes/output.html#basemodeloutputwithpoolingandcrossattentionsM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"https://arxiv.org/pdf/1907.11692.pdf"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"2"
|
|
]
|
|
},
|
|
"execution_count": 33,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(output)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([1, 16, 768])"
|
|
]
|
|
},
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output[0].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([1, 768])"
|
|
]
|
|
},
|
|
"execution_count": 35,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"output[1].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"output = model(**encoded_input, output_hidden_states=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"3"
|
|
]
|
|
},
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(output)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"13"
|
|
]
|
|
},
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(output[2])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[[ 0.1664, -0.0541, -0.0014, ..., -0.0811, 0.0794, 0.0155],\n",
|
|
" [-0.7241, 0.1035, 0.0784, ..., 0.2474, -0.0535, 0.4320],\n",
|
|
" [ 0.5926, -0.1062, 0.0372, ..., -0.0140, 0.1021, -0.2212],\n",
|
|
" ...,\n",
|
|
" [ 0.4734, -0.0570, -0.2506, ..., 0.4071, 0.4481, -0.2180],\n",
|
|
" [ 0.7836, -0.2838, -0.2083, ..., -0.0959, -0.0136, 0.1995],\n",
|
|
" [ 0.2733, -0.1372, -0.0387, ..., 0.5187, 0.1545, -0.2604]]],\n",
|
|
" grad_fn=<NativeLayerNormBackward>)"
|
|
]
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output[2][0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([1, 16, 768])"
|
|
]
|
|
},
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output[2][0].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([1, 16, 768])"
|
|
]
|
|
},
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output[2][1].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([1, 16, 768])"
|
|
]
|
|
},
|
|
"execution_count": 42,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output[2][12].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"output = model(**encoded_input, output_attentions=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"3"
|
|
]
|
|
},
|
|
"execution_count": 44,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(output)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"12"
|
|
]
|
|
},
|
|
"execution_count": 45,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"len(output[2])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([1, 12, 16, 16])"
|
|
]
|
|
},
|
|
"execution_count": 46,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output[2][0].shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[[[9.8775e-01, 6.2288e-04, 8.7264e-04, ..., 5.4309e-04,\n",
|
|
" 1.3059e-03, 1.0826e-03],\n",
|
|
" [3.3152e-01, 7.3213e-03, 3.0339e-02, ..., 1.6386e-03,\n",
|
|
" 1.1041e-03, 1.0450e-03],\n",
|
|
" [8.4058e-01, 7.4270e-04, 1.8587e-04, ..., 1.9484e-03,\n",
|
|
" 8.3106e-04, 2.2206e-03],\n",
|
|
" ...,\n",
|
|
" [8.3998e-01, 6.3201e-06, 7.9328e-06, ..., 1.8371e-02,\n",
|
|
" 5.9146e-02, 7.1377e-02],\n",
|
|
" [9.4819e-01, 3.9591e-06, 2.9191e-06, ..., 9.6707e-03,\n",
|
|
" 1.1201e-02, 2.4954e-02],\n",
|
|
" [9.2851e-01, 4.9144e-04, 2.2858e-04, ..., 9.3861e-03,\n",
|
|
" 1.7582e-02, 2.4180e-02]],\n",
|
|
"\n",
|
|
" [[9.2353e-01, 4.3481e-03, 1.9423e-02, ..., 5.0829e-03,\n",
|
|
" 7.5931e-03, 4.6599e-03],\n",
|
|
" [9.7840e-01, 4.1909e-03, 9.0263e-03, ..., 2.1102e-06,\n",
|
|
" 5.4437e-07, 6.7581e-06],\n",
|
|
" [8.3596e-01, 6.3265e-02, 7.9091e-02, ..., 1.4975e-05,\n",
|
|
" 2.1750e-06, 2.3804e-06],\n",
|
|
" ...,\n",
|
|
" [4.7469e-01, 1.1083e-04, 1.8293e-03, ..., 9.7021e-03,\n",
|
|
" 6.5544e-03, 1.9043e-03],\n",
|
|
" [2.1963e-01, 1.3427e-06, 1.2042e-04, ..., 7.5510e-01,\n",
|
|
" 2.8724e-03, 6.2941e-03],\n",
|
|
" [4.2043e-01, 3.4030e-06, 6.4028e-05, ..., 8.2335e-02,\n",
|
|
" 3.9994e-01, 9.1114e-02]],\n",
|
|
"\n",
|
|
" [[9.8968e-01, 2.9357e-04, 2.4483e-04, ..., 2.0526e-04,\n",
|
|
" 4.1698e-04, 3.3650e-03],\n",
|
|
" [9.0939e-01, 3.1261e-03, 2.7859e-02, ..., 3.1149e-04,\n",
|
|
" 8.0127e-05, 2.8887e-03],\n",
|
|
" [8.9282e-01, 2.4450e-04, 5.3892e-03, ..., 8.5178e-04,\n",
|
|
" 9.8922e-05, 2.7169e-03],\n",
|
|
" ...,\n",
|
|
" [9.3745e-01, 2.0096e-06, 4.1223e-06, ..., 4.7319e-02,\n",
|
|
" 3.8060e-03, 6.3264e-03],\n",
|
|
" [9.5799e-01, 1.2817e-04, 1.0723e-05, ..., 1.0232e-03,\n",
|
|
" 2.1168e-02, 3.7038e-03],\n",
|
|
" [9.1897e-01, 4.5952e-04, 7.4514e-05, ..., 5.2304e-05,\n",
|
|
" 3.8385e-05, 5.9209e-02]],\n",
|
|
"\n",
|
|
" ...,\n",
|
|
"\n",
|
|
" [[9.7214e-01, 1.8048e-03, 2.0910e-03, ..., 1.5654e-03,\n",
|
|
" 2.0380e-03, 2.9465e-03],\n",
|
|
" [2.0737e-01, 1.5373e-02, 3.4949e-01, ..., 1.0591e-04,\n",
|
|
" 3.8994e-06, 1.9794e-05],\n",
|
|
" [7.0131e-01, 2.8094e-03, 7.6395e-03, ..., 1.2338e-03,\n",
|
|
" 8.6231e-05, 8.1068e-05],\n",
|
|
" ...,\n",
|
|
" [4.1426e-01, 1.9507e-06, 5.5085e-05, ..., 3.8152e-02,\n",
|
|
" 4.5979e-01, 6.9998e-02],\n",
|
|
" [7.5517e-01, 2.2428e-07, 3.2856e-06, ..., 1.3153e-02,\n",
|
|
" 5.5085e-03, 2.1891e-01],\n",
|
|
" [9.4142e-01, 3.3256e-05, 6.0546e-06, ..., 9.1890e-04,\n",
|
|
" 8.7666e-03, 3.8735e-02]],\n",
|
|
"\n",
|
|
" [[9.7447e-01, 1.1291e-03, 2.3473e-03, ..., 1.6628e-03,\n",
|
|
" 1.7247e-03, 3.7978e-03],\n",
|
|
" [7.2027e-01, 5.4353e-02, 5.0394e-03, ..., 4.7070e-03,\n",
|
|
" 1.4477e-03, 7.9330e-02],\n",
|
|
" [9.1602e-01, 6.2537e-03, 6.2520e-03, ..., 3.0431e-03,\n",
|
|
" 1.6902e-03, 2.6523e-02],\n",
|
|
" ...,\n",
|
|
" [8.7035e-01, 5.6680e-03, 2.5519e-04, ..., 1.0693e-02,\n",
|
|
" 1.0154e-02, 2.8158e-02],\n",
|
|
" [7.8992e-01, 1.3184e-03, 5.2799e-04, ..., 3.8399e-03,\n",
|
|
" 2.3379e-02, 5.4757e-02],\n",
|
|
" [4.0584e-01, 5.6631e-03, 8.5153e-03, ..., 1.0006e-02,\n",
|
|
" 1.0799e-02, 1.9912e-01]],\n",
|
|
"\n",
|
|
" [[9.8713e-01, 3.3973e-04, 9.6788e-04, ..., 2.1040e-04,\n",
|
|
" 1.3595e-03, 8.0080e-04],\n",
|
|
" [1.0312e-01, 4.2905e-03, 8.3475e-01, ..., 7.3782e-06,\n",
|
|
" 1.9842e-04, 1.3445e-03],\n",
|
|
" [7.9036e-01, 2.8547e-02, 5.0725e-02, ..., 1.9356e-05,\n",
|
|
" 6.4891e-05, 2.8477e-03],\n",
|
|
" ...,\n",
|
|
" [2.1335e-01, 9.7233e-06, 6.9469e-05, ..., 3.6693e-04,\n",
|
|
" 3.3324e-01, 1.3384e-02],\n",
|
|
" [1.1667e-02, 3.0911e-05, 2.5899e-06, ..., 5.6125e-01,\n",
|
|
" 2.7517e-04, 1.5053e-03],\n",
|
|
" [8.4494e-01, 8.0791e-04, 1.0116e-03, ..., 2.4602e-03,\n",
|
|
" 6.7727e-02, 1.1728e-02]]]], grad_fn=<SoftmaxBackward>)"
|
|
]
|
|
},
|
|
"execution_count": 47,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"output[2][2]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## gotowe api"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### generowanie tekstu"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias']\n",
|
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = pipeline('text-generation', model='gpt2')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 49,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[{'generated_text': 'Hello, I\\'m a computer science student by trade. I don\\'t really like science.\"\\n\\nThen I hear him say: \"I love the'},\n",
|
|
" {'generated_text': \"Hello, I'm a computer science student.\\n\\nAnd if you're curious what I'm doing here, don't hesitate:\\n\\n\\nI've\"},\n",
|
|
" {'generated_text': \"Hello, I'm a computer science student, not an engineer. But, I'm also fascinated, because all the people I'm talking to are engineers\"},\n",
|
|
" {'generated_text': \"Hello, I'm a computer science student with a big project called the Data Science project, to help students create and understand and improve their data science.\"},\n",
|
|
" {'generated_text': \"Hello, I'm a computer science student from North Carolina. My work involves a number of questions (and many possible answers as well). I can't\"}]"
|
|
]
|
|
},
|
|
"execution_count": 49,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model(\"Hello, I'm a computer science student\", max_length=30, num_return_sequences=5)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 50,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[{'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data. It is also working on creating a new type of\"},\n",
|
|
" {'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big datasets. It has now given every request to Google for information\"},\n",
|
|
" {'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data, robotics and artificial intelligence. We understand the potential impact\"},\n",
|
|
" {'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data, artificial intelligence, and machine learning. I think we\"},\n",
|
|
" {'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big new ways to see and interact with image data. We'll\"}]"
|
|
]
|
|
},
|
|
"execution_count": 50,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model(\"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big\", max_length=30, num_return_sequences=5)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### sentiment analysis"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 51,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers import pipeline\n",
|
|
"\n",
|
|
"model = pipeline(\"sentiment-analysis\", model='distilbert-base-uncased-finetuned-sst-2-english')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 52,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<transformers.pipelines.text_classification.TextClassificationPipeline at 0x7f8e45b319d0>"
|
|
]
|
|
},
|
|
"execution_count": 52,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 53,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[{'label': 'POSITIVE', 'score': 0.9998474717140198}]"
|
|
]
|
|
},
|
|
"execution_count": 53,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model(\"I'm very happy. Today is the beatifull weather\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 54,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[{'label': 'NEGATIVE', 'score': 0.9946851134300232}]"
|
|
]
|
|
},
|
|
"execution_count": 54,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model(\"It's raining. What a terrible day...\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## NER"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 55,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = pipeline(\"sentiment-analysis\", model='distilbert-base-uncased-finetuned-sst-2-english')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers import pipeline\n",
|
|
"model = pipeline(\"ner\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"George Washington went to Washington\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[{'word': 'George',\n",
|
|
" 'score': 0.9983943104743958,\n",
|
|
" 'entity': 'I-PER',\n",
|
|
" 'index': 1,\n",
|
|
" 'start': 0,\n",
|
|
" 'end': 6},\n",
|
|
" {'word': 'Washington',\n",
|
|
" 'score': 0.9992505311965942,\n",
|
|
" 'entity': 'I-PER',\n",
|
|
" 'index': 2,\n",
|
|
" 'start': 7,\n",
|
|
" 'end': 17},\n",
|
|
" {'word': 'Washington',\n",
|
|
" 'score': 0.98389732837677,\n",
|
|
" 'entity': 'I-LOC',\n",
|
|
" 'index': 5,\n",
|
|
" 'start': 26,\n",
|
|
" 'end': 36}]"
|
|
]
|
|
},
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### masked language modelling"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### ZADANIE (10 minut)\n",
|
|
"\n",
|
|
"przewidziać <mask> token w \"The world <MASK> II started in 1939\"\" wg dowolnego anglojęzycznego modelu"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|