{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# BPE (Byte Pair Encoding)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instalacja: `pip install tokenizers`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dokumentacja: [huggingface/tokenizers — Python bindings](https://github.com/huggingface/tokenizers/tree/master/bindings/python)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from tokenizers import Tokenizer, models, trainers" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from tokenizers.trainers import BpeTrainer" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "tokenizer = Tokenizer(models.BPE())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "trainer = trainers.BpeTrainer(vocab_size=20000, min_frequency=2)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "tokenizer.train(files = ['/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-train.txt'], trainer = trainer)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "output = tokenizer.encode(\"Nie śpiewają piosenek: pracują leniwo,\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[236, 2255, 2069, 3898, 9908, 14, 8675, 8319, 191, 7]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.ids" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Nie', ' śpie', 'wają', ' pios', 'enek', ':', ' pracują', ' leni', 'wo', ',']" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.tokens" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "tokenizer.save(\"./my-bpe.tokenizer.json\", pretty=True)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## ZADANIE\n", "stworzyć BPE tokenizer na podstawie https://git.wmi.amu.edu.pl/kubapok/lalka-lm/src/branch/master/train/train.tsv\n", "i stworzyć stokenizowaną listę: \n", "https://git.wmi.amu.edu.pl/kubapok/lalka-lm/src/branch/master/test-A/in.tsv\n", "\n", "wybrać vocab_size = 8k, uwzględnić dodatkowe tokeny: BOS oraz EOS i wpleść je do zbioru testowego" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# transformery" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# pip install transformers" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline, set_seed" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "from transformers import RobertaTokenizer, RobertaModel" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "model = RobertaModel.from_pretrained('roberta-base')" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "text = \"Replace me by any text you'd like. 
Bla Bla\"" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "encoded_input = tokenizer(text, return_tensors='pt')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0, 9064, 6406, 162, 30, 143, 2788, 47, 1017, 101, 4, 2091,\n", " 102, 2091, 102, 2]])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoded_input['input_ids']" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0, 9064, 6406, 162, 30, 143, 2788, 47, 1017, 101, 4, 2091,\n", " 102, 2091, 102, 2]])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoded_input['input_ids']" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "' me'" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode([162])" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "output = model(**encoded_input)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-4.4858e-02, 8.6642e-02, -7.2129e-03, ..., -4.6295e-02,\n", " -3.9316e-02, 4.5264e-04],\n", " [-6.0603e-02, 1.5684e-01, 4.3705e-02, ..., 5.3485e-01,\n", " 8.4371e-02, 1.4826e-01],\n", " [-2.3786e-02, -1.2086e-02, 7.8233e-02, ..., -4.9132e-01,\n", " 1.2500e-01, 3.3293e-01],\n", " ...,\n", " [ 6.7192e-02, 2.4028e-01, -2.9984e-01, ..., 2.1992e-01,\n", " 1.9186e-02, 1.5355e-01],\n", " [ 1.7611e-01, 1.4001e-01, -1.3774e-01, ..., -5.0379e-01,\n", " 7.3958e-02, 3.7870e-02],\n", " [-3.4405e-02, 8.7648e-02, -3.9429e-02, ..., 
-9.1916e-02,\n", " -3.5529e-02, -2.6777e-02]]], grad_fn=), pooler_output=tensor([[ 5.1994e-03, -2.1290e-01, -2.2585e-01, -9.3315e-02, 1.1761e-01,\n", " 1.9024e-01, 2.4873e-01, -8.1097e-02, -4.4050e-02, -1.6596e-01,\n", " 2.1572e-01, -8.3736e-03, -7.7416e-02, 8.2714e-02, -1.2487e-01,\n", " 4.8405e-01, 2.1145e-01, -4.4653e-01, 4.1008e-02, -1.2578e-02,\n", " -2.5560e-01, 7.3874e-02, 4.6924e-01, 3.2284e-01, 1.2382e-01,\n", " 6.3117e-02, -1.2633e-01, -1.3542e-02, 1.6195e-01, 2.1738e-01,\n", " 2.8682e-01, 6.2675e-02, 8.5778e-02, 2.3686e-01, -2.5219e-01,\n", " 4.3382e-02, -3.0992e-01, 4.0551e-02, 2.3078e-01, -1.8763e-01,\n", " -7.2084e-02, 1.5847e-01, 2.0228e-01, -1.2756e-01, -1.1429e-01,\n", " 3.9603e-01, 2.6018e-01, 2.9297e-02, -1.2162e-01, -8.2051e-02,\n", " -3.5665e-01, 3.4722e-01, 2.8292e-01, 1.9929e-01, -1.4734e-02,\n", " 4.8892e-02, -1.4596e-01, 2.5527e-01, -7.5540e-02, -8.7507e-02,\n", " -1.2164e-01, -2.0039e-01, -7.1405e-03, -5.9407e-02, 4.1914e-02,\n", " -1.3208e-01, 6.8612e-02, -1.4154e-01, -1.1040e-01, 5.3550e-02,\n", " -7.3125e-02, 1.6401e-01, 1.6232e-01, -3.0164e-01, -2.8429e-01,\n", " 5.9359e-02, -5.7466e-01, -9.3444e-02, 3.1263e-01, 4.3317e-01,\n", " -1.0505e-01, 2.0451e-01, 3.2512e-02, 2.0723e-01, -2.0692e-02,\n", " -7.1687e-02, -4.7481e-02, -9.9879e-02, 1.8269e-01, 2.6383e-01,\n", " -1.8467e-01, -3.7178e-01, 7.7521e-02, 2.5127e-02, -1.0401e-01,\n", " 3.6107e-03, -2.1520e-02, -8.3309e-02, -1.6115e-01, -1.6062e-01,\n", " 4.8140e-02, -2.4970e-01, -1.4153e-01, 2.7024e-01, -3.3932e-02,\n", " -2.0374e-01, -2.1484e-02, 2.9751e-01, 7.9157e-02, -1.1653e-01,\n", " -1.9188e-01, 4.2722e-01, 2.7925e-01, 2.1393e-03, -1.2266e-02,\n", " 1.6474e-01, 1.5236e-01, -2.8896e-01, 4.2948e-01, -3.1129e-01,\n", " -8.3017e-03, -1.2746e-01, 9.3701e-02, 1.3876e-01, -2.2297e-01,\n", " 2.7802e-01, 1.4876e-01, 2.7402e-01, 1.8048e-01, 1.0370e-01,\n", " -1.9344e-02, 1.5151e-01, -1.0991e-01, 1.4710e-01, 2.1547e-01,\n", " 1.1827e-01, -3.4855e-03, -3.2536e-01, -2.0146e-01, 2.6933e-01,\n", " 
3.2889e-01, 1.6086e-01, -4.0214e-02, 1.7948e-01, 9.3386e-02,\n", " 2.3509e-01, 1.3713e-01, -3.9689e-01, 2.7102e-02, 3.2785e-01,\n", " 9.5205e-02, 1.6616e-01, -8.5767e-02, -2.9547e-01, -2.5672e-01,\n", " -9.8167e-02, 4.9819e-02, -3.1859e-01, -1.0416e-01, 3.6841e-01,\n", " 3.9702e-02, 1.0092e-02, -1.5538e-01, -2.3540e-01, -3.1184e-02,\n", " -1.0212e-01, 2.2566e-02, 8.2764e-02, -8.8543e-02, -4.3323e-01,\n", " -9.3633e-02, -5.2647e-01, -1.1645e-01, 1.9042e-01, -3.2305e-01,\n", " 2.3381e-01, -2.9311e-01, 9.7979e-02, 3.9168e-01, 4.9259e-02,\n", " -1.4776e-03, -1.9380e-01, -2.3509e-02, 1.1105e-01, 3.2593e-01,\n", " 2.3894e-01, -4.0230e-01, 1.0931e-01, 1.4188e-01, 2.5775e-01,\n", " 1.4593e-01, -6.1273e-02, -1.1612e-01, 1.5098e-01, -2.0754e-01,\n", " 1.7434e-01, -2.2652e-01, 1.8130e-01, -2.4506e-01, -2.1763e-01,\n", " 2.7327e-01, -4.0463e-01, -3.2816e-02, 8.4527e-02, 2.6249e-01,\n", " 5.3814e-03, -3.7149e-02, -8.5346e-02, 1.2447e-01, 1.6730e-01,\n", " 1.2632e-01, -3.9535e-01, 2.6029e-01, -2.7662e-02, -1.0249e-02,\n", " -3.3478e-02, 1.7810e-01, 2.5045e-01, 8.5239e-02, -3.8860e-01,\n", " -1.2383e-01, 1.0050e-01, 2.8403e-01, -2.3275e-01, 1.4317e-01,\n", " -2.5891e-01, -3.7882e-01, -1.4209e-01, 2.0352e-01, 2.2653e-01,\n", " 1.7335e-01, -2.7438e-01, 1.6319e-01, -1.0932e-01, -4.1693e-01,\n", " -3.6249e-01, -1.0035e-01, 2.3360e-01, 1.6881e-01, 1.8671e-01,\n", " 2.4667e-01, 3.3260e-02, 1.0321e-01, 1.5088e-01, 1.4948e-01,\n", " -1.4621e-01, 1.5615e-01, -3.4600e-01, -4.0963e-02, -2.5380e-01,\n", " -1.9230e-01, -2.1718e-01, 3.9236e-01, -2.2989e-01, 2.3594e-01,\n", " 3.7562e-01, -3.1114e-01, -1.2649e-01, 1.6051e-01, 1.0577e-01,\n", " 8.3363e-02, -1.2593e-01, 2.0584e-01, 1.4127e-01, -1.0632e-01,\n", " 2.3048e-01, -9.1782e-03, 2.4953e-01, 1.7332e-01, 8.9003e-02,\n", " 1.5702e-01, 1.1363e-01, -1.4771e-01, 3.8719e-02, 3.1532e-03,\n", " -1.9275e-02, -2.3103e-01, -1.4421e-01, 2.3575e-01, -5.3696e-02,\n", " 3.7994e-02, -1.6641e-01, -1.1399e-01, 1.1890e-02, 3.9652e-01,\n", " -3.6048e-01, 
2.4448e-01, 7.5291e-02, 1.5195e-01, -2.2864e-01,\n", " -2.1031e-01, 8.9896e-02, 1.6992e-01, -3.9863e-01, 1.7435e-03,\n", " 1.6546e-01, 1.0483e-01, 2.1218e-01, 2.8159e-01, -5.0768e-05,\n", " -9.4031e-02, 4.9128e-01, -1.6932e-01, -1.2761e-01, 2.5249e-01,\n", " -2.7310e-01, -2.7744e-01, 2.4945e-01, -2.9252e-02, 3.0148e-01,\n", " 1.1743e-01, 3.8095e-02, 6.4731e-02, -6.0554e-01, 6.6722e-02,\n", " -4.5959e-01, -8.5795e-03, 3.6827e-02, -7.1025e-02, -2.0848e-01,\n", " 1.4228e-01, 2.9630e-01, -2.4077e-01, -4.3098e-02, 2.2334e-01,\n", " 8.4614e-02, -1.2553e-01, 4.8810e-01, 1.4035e-03, 2.0182e-01,\n", " -6.3799e-02, 2.4192e-01, -2.0799e-01, 2.6687e-01, -2.7694e-01,\n", " -9.9754e-02, 4.7169e-03, 8.3846e-02, 4.5165e-02, -5.9169e-02,\n", " -3.5243e-01, 2.2125e-01, -1.9798e-02, -5.7139e-02, -4.0613e-02,\n", " 9.3967e-02, -1.7488e-02, 6.2663e-02, 5.2265e-02, 3.4857e-01,\n", " 2.2626e-01, -2.6472e-02, -3.7240e-01, -1.1370e-02, -9.6964e-02,\n", " 5.6610e-02, 2.6788e-02, -2.4786e-02, 4.3236e-01, -7.6491e-02,\n", " 5.2428e-03, -1.4826e-01, 2.6029e-01, 1.8327e-01, 1.2439e-01,\n", " 1.3042e-01, 6.1801e-02, 1.3667e-01, -6.7373e-02, -1.2048e-02,\n", " -1.4603e-01, -2.2505e-01, -2.9760e-01, 2.0056e-01, -2.2011e-01,\n", " -1.8115e-01, 1.4262e-01, 2.1523e-01, -1.3893e-01, 1.4466e-01,\n", " 2.9357e-01, 1.2101e-01, -1.4499e-01, 2.6571e-01, -1.0101e-01,\n", " 1.1599e-01, 2.9781e-01, -2.0156e-02, 2.0140e-01, 5.0007e-01,\n", " 2.1717e-01, -3.5394e-01, -1.6840e-02, -2.2424e-01, 1.1153e-02,\n", " 2.4465e-01, -1.5139e-01, 1.9410e-01, 3.8725e-01, 2.9424e-01,\n", " 4.2920e-01, -7.3521e-03, -1.1083e-01, 9.0861e-02, 2.2493e-01,\n", " 2.7805e-02, -1.5957e-01, -1.9878e-01, 2.5621e-01, 6.2884e-02,\n", " -1.5698e-01, -1.3003e-02, -1.1539e-01, 2.8383e-02, -1.2329e-01,\n", " -3.8461e-01, 3.5215e-02, 1.8146e-01, -4.7366e-01, 7.3029e-02,\n", " -2.9380e-01, 4.3577e-02, -2.3151e-01, 2.0148e-01, -2.2549e-01,\n", " -1.0662e-01, 3.9916e-01, -7.4222e-02, 4.3181e-02, -1.7905e-01,\n", " -1.3722e-01, 2.4749e-02, 
7.4731e-03, -1.4543e-02, -4.1486e-03,\n", " 3.3529e-01, -1.2101e-01, 3.6759e-02, 3.6844e-02, 1.9582e-01,\n", " -5.1381e-02, 2.0516e-01, 3.1175e-02, -1.4019e-01, -4.0386e-01,\n", " 1.4271e-01, -1.8793e-01, -4.2023e-01, -3.6638e-01, 3.6652e-01,\n", " -1.3753e-01, -2.5359e-01, -2.0423e-01, -2.4466e-01, 8.1067e-02,\n", " 1.6987e-01, 4.7120e-01, -3.9858e-01, -6.8325e-02, 4.7077e-01,\n", " -6.8745e-02, -1.8953e-01, 2.7360e-01, 1.8793e-01, -3.3325e-01,\n", " 3.1144e-01, 2.6919e-01, -5.6080e-02, 1.5771e-02, 5.0668e-01,\n", " 1.1729e-01, 1.8437e-01, -2.0954e-01, 4.4338e-01, -2.1112e-01,\n", " 3.1039e-01, -1.6460e-01, -2.1319e-01, -2.1592e-01, -1.9942e-02,\n", " 3.3144e-01, 1.8923e-01, -4.2029e-01, -1.0169e-01, 3.1353e-02,\n", " 3.6021e-01, -3.7626e-01, -8.6387e-02, 1.3697e-02, -3.3636e-01,\n", " 1.2770e-01, 1.0668e-01, 2.2197e-01, -3.7968e-01, -1.5053e-02,\n", " 3.9753e-01, -2.9535e-01, 1.3459e-01, 3.2518e-01, 7.6786e-02,\n", " 3.4168e-01, -2.8172e-02, 1.0189e-02, 5.9536e-02, -2.3156e-01,\n", " -3.8199e-02, 1.3041e-01, 5.4866e-01, 1.5127e-01, -3.6896e-01,\n", " 9.5292e-02, 2.4462e-01, -1.6506e-01, 3.1529e-01, -8.9680e-02,\n", " -4.6637e-02, 2.6508e-01, -3.6751e-02, 1.5445e-01, -9.7824e-02,\n", " -2.1623e-01, -3.0666e-01, 3.6944e-01, -1.8711e-01, -1.1481e-01,\n", " -1.6787e-01, -1.1253e-01, -1.4680e-01, 4.1271e-02, -3.6980e-01,\n", " 3.3081e-01, 1.2455e-01, -1.8123e-01, -6.8767e-02, -9.6390e-02,\n", " -1.4910e-01, -2.0524e-01, -2.6686e-01, 4.2154e-01, -1.6543e-01,\n", " -4.5050e-01, 2.5019e-01, 3.4722e-02, 3.3103e-01, 4.8806e-02,\n", " 9.9796e-02, -3.0042e-02, 1.4140e-01, 9.6566e-02, -1.1071e-01,\n", " 2.7989e-01, 6.1347e-02, -5.6286e-01, -1.4422e-01, -2.1070e-01,\n", " 7.6292e-02, 1.9691e-01, -3.3576e-01, 1.7638e-02, 2.2105e-02,\n", " 1.4265e-01, 3.0694e-02, -1.0665e-01, -5.5197e-02, 3.9164e-01,\n", " 2.0961e-01, 2.9841e-01, 8.7946e-02, 2.4076e-01, -1.0636e-02,\n", " -3.3807e-01, 2.2974e-02, 8.5258e-02, -1.8663e-01, 4.1414e-01,\n", " -9.9141e-02, -3.8117e-01, -5.6155e-02, 
3.9692e-01, 9.7551e-02,\n", " -1.8710e-02, -5.0913e-02, 2.0049e-01, 1.5407e-01, -1.2523e-01,\n", " 1.8187e-01, -9.7470e-03, -1.3372e-01, -1.0178e-01, 8.5468e-02,\n", " -2.1953e-01, 4.7566e-02, -1.3239e-01, 1.3200e-03, -2.0911e-01,\n", " 3.2521e-03, -2.1387e-01, 2.4508e-01, -3.2182e-01, 9.8634e-02,\n", " 7.5848e-02, 2.9231e-01, -3.5121e-01, -1.4159e-01, -5.7937e-02,\n", " 1.6263e-01, 2.5830e-01, 3.5601e-01, 2.3997e-02, 2.5322e-02,\n", " -1.5363e-01, -2.6361e-01, 5.4986e-02, -2.0897e-01, 1.2282e-01,\n", " 7.1346e-02, 2.4762e-01, -3.0430e-01, -1.8016e-01, 2.2226e-01,\n", " -9.7989e-02, -1.4158e-01, 4.2292e-01, 2.5139e-01, 2.1049e-01,\n", " 2.2865e-02, 2.4210e-01, 3.7744e-02, -1.7568e-01, -1.1512e-01,\n", " -2.4392e-01, 6.9097e-02, -8.5799e-02, -5.8893e-02, -7.1211e-02,\n", " -1.2143e-01, -1.9825e-01, -1.5658e-01, 1.5637e-01, 1.3693e-01,\n", " 2.9095e-02, -5.5552e-02, -2.4771e-02, -2.7771e-01, 2.9286e-01,\n", " 2.4894e-02, 7.2069e-02, -4.8322e-02, 2.3967e-02, -1.5199e-01,\n", " 2.3989e-01, 2.0234e-01, 8.2009e-02, -1.8899e-01, -4.8667e-02,\n", " -2.9075e-01, -3.5470e-01, 4.1930e-02, 1.3129e-01, 1.1387e-01,\n", " -1.0165e-01, -2.7247e-01, -2.7974e-02, -1.3051e-01, 1.8051e-01,\n", " -9.6646e-03, -1.5500e-01, -7.4565e-02, -6.0039e-02, -5.1055e-02,\n", " 6.7692e-02, -2.0781e-01, -1.9844e-01, -1.2495e-01, -7.5151e-02,\n", " -6.6146e-02, 3.6196e-01, -3.5989e-02, 2.7737e-01, -1.5471e-01,\n", " 1.1208e-02, -1.9818e-01, 1.0743e-01, -7.3001e-02, 7.3365e-02,\n", " 2.6398e-01, -4.2969e-01, -1.5308e-01, 6.1186e-03, -2.1301e-01,\n", " -1.4149e-01, -7.1113e-02, -4.0364e-02, 2.1242e-01, -3.4205e-01,\n", " 2.1659e-01, -8.0915e-02, 1.8907e-01, -9.4013e-02, -2.5456e-01,\n", " -1.6216e-01, 2.3130e-02, 2.4984e-01, -3.3239e-01, -2.2947e-01,\n", " -2.6681e-01, -9.7903e-02, -9.0469e-02, -2.6217e-01, 4.1510e-01,\n", " -1.0590e-01, -5.5713e-02, 9.9271e-03, 4.3321e-01, 1.9454e-01,\n", " 1.5135e-01, 2.1670e-01, -1.3371e-02, 2.7091e-02, 1.0805e-01,\n", " -4.6743e-01, 2.3397e-01, -2.2627e-01, 
-1.2724e-01, 2.7149e-02,\n", " 8.9104e-02, -3.1547e-02, 1.2930e-02, -1.1888e-01, -1.0141e-01,\n", " 2.0849e-01, -3.6962e-01, -1.2304e-02, 2.7230e-01, 1.4519e-01,\n", " -2.4969e-01, 4.2865e-02, 1.2965e-01, 3.7797e-01, 8.8492e-02,\n", " -2.2487e-01, 1.3100e-01, -3.4240e-01, -2.4896e-02, -1.8675e-01,\n", " -2.9198e-01, 1.3836e-01, -6.9468e-02, 5.4983e-02, -6.8482e-02,\n", " -2.7968e-01, 2.1223e-01, -5.0621e-02, -6.3859e-02, 4.1759e-01,\n", " 3.3747e-02, -1.1644e-01, 1.5398e-01, 1.5137e-02, -5.4925e-03,\n", " -1.0726e-01, 2.6553e-01, 2.0031e-01, -2.7755e-01, 1.2135e-01,\n", " -1.2860e-01, -2.5987e-02, -1.1620e-01]], grad_fn=), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Opis zwracanego obiektu: [BaseModelOutputWithPoolingAndCrossAttentions](https://huggingface.co/transformers/main_classes/output.html#basemodeloutputwithpoolingandcrossattentions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Artykuł o modelu RoBERTa (Liu i in., 2019): https://arxiv.org/pdf/1907.11692.pdf" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(output)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 768])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output[0].shape" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 768])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "output[1].shape" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "output = model(**encoded_input, output_hidden_states=True)" ] }, { "cell_type": "code", 
"execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(output)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "13" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(output[2])" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 0.1664, -0.0541, -0.0014, ..., -0.0811, 0.0794, 0.0155],\n", " [-0.7241, 0.1035, 0.0784, ..., 0.2474, -0.0535, 0.4320],\n", " [ 0.5926, -0.1062, 0.0372, ..., -0.0140, 0.1021, -0.2212],\n", " ...,\n", " [ 0.4734, -0.0570, -0.2506, ..., 0.4071, 0.4481, -0.2180],\n", " [ 0.7836, -0.2838, -0.2083, ..., -0.0959, -0.0136, 0.1995],\n", " [ 0.2733, -0.1372, -0.0387, ..., 0.5187, 0.1545, -0.2604]]],\n", " grad_fn=)" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output[2][0]" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 768])" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output[2][0].shape" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 768])" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output[2][1].shape" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 768])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output[2][12].shape" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "output = model(**encoded_input, output_attentions=True)" ] }, { "cell_type": "code", 
"execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(output)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "12" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(output[2])" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 12, 16, 16])" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output[2][0].shape" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[[[9.8775e-01, 6.2288e-04, 8.7264e-04, ..., 5.4309e-04,\n", " 1.3059e-03, 1.0826e-03],\n", " [3.3152e-01, 7.3213e-03, 3.0339e-02, ..., 1.6386e-03,\n", " 1.1041e-03, 1.0450e-03],\n", " [8.4058e-01, 7.4270e-04, 1.8587e-04, ..., 1.9484e-03,\n", " 8.3106e-04, 2.2206e-03],\n", " ...,\n", " [8.3998e-01, 6.3201e-06, 7.9328e-06, ..., 1.8371e-02,\n", " 5.9146e-02, 7.1377e-02],\n", " [9.4819e-01, 3.9591e-06, 2.9191e-06, ..., 9.6707e-03,\n", " 1.1201e-02, 2.4954e-02],\n", " [9.2851e-01, 4.9144e-04, 2.2858e-04, ..., 9.3861e-03,\n", " 1.7582e-02, 2.4180e-02]],\n", "\n", " [[9.2353e-01, 4.3481e-03, 1.9423e-02, ..., 5.0829e-03,\n", " 7.5931e-03, 4.6599e-03],\n", " [9.7840e-01, 4.1909e-03, 9.0263e-03, ..., 2.1102e-06,\n", " 5.4437e-07, 6.7581e-06],\n", " [8.3596e-01, 6.3265e-02, 7.9091e-02, ..., 1.4975e-05,\n", " 2.1750e-06, 2.3804e-06],\n", " ...,\n", " [4.7469e-01, 1.1083e-04, 1.8293e-03, ..., 9.7021e-03,\n", " 6.5544e-03, 1.9043e-03],\n", " [2.1963e-01, 1.3427e-06, 1.2042e-04, ..., 7.5510e-01,\n", " 2.8724e-03, 6.2941e-03],\n", " [4.2043e-01, 3.4030e-06, 6.4028e-05, ..., 8.2335e-02,\n", " 3.9994e-01, 9.1114e-02]],\n", "\n", " [[9.8968e-01, 2.9357e-04, 2.4483e-04, ..., 2.0526e-04,\n", " 
4.1698e-04, 3.3650e-03],\n", " [9.0939e-01, 3.1261e-03, 2.7859e-02, ..., 3.1149e-04,\n", " 8.0127e-05, 2.8887e-03],\n", " [8.9282e-01, 2.4450e-04, 5.3892e-03, ..., 8.5178e-04,\n", " 9.8922e-05, 2.7169e-03],\n", " ...,\n", " [9.3745e-01, 2.0096e-06, 4.1223e-06, ..., 4.7319e-02,\n", " 3.8060e-03, 6.3264e-03],\n", " [9.5799e-01, 1.2817e-04, 1.0723e-05, ..., 1.0232e-03,\n", " 2.1168e-02, 3.7038e-03],\n", " [9.1897e-01, 4.5952e-04, 7.4514e-05, ..., 5.2304e-05,\n", " 3.8385e-05, 5.9209e-02]],\n", "\n", " ...,\n", "\n", " [[9.7214e-01, 1.8048e-03, 2.0910e-03, ..., 1.5654e-03,\n", " 2.0380e-03, 2.9465e-03],\n", " [2.0737e-01, 1.5373e-02, 3.4949e-01, ..., 1.0591e-04,\n", " 3.8994e-06, 1.9794e-05],\n", " [7.0131e-01, 2.8094e-03, 7.6395e-03, ..., 1.2338e-03,\n", " 8.6231e-05, 8.1068e-05],\n", " ...,\n", " [4.1426e-01, 1.9507e-06, 5.5085e-05, ..., 3.8152e-02,\n", " 4.5979e-01, 6.9998e-02],\n", " [7.5517e-01, 2.2428e-07, 3.2856e-06, ..., 1.3153e-02,\n", " 5.5085e-03, 2.1891e-01],\n", " [9.4142e-01, 3.3256e-05, 6.0546e-06, ..., 9.1890e-04,\n", " 8.7666e-03, 3.8735e-02]],\n", "\n", " [[9.7447e-01, 1.1291e-03, 2.3473e-03, ..., 1.6628e-03,\n", " 1.7247e-03, 3.7978e-03],\n", " [7.2027e-01, 5.4353e-02, 5.0394e-03, ..., 4.7070e-03,\n", " 1.4477e-03, 7.9330e-02],\n", " [9.1602e-01, 6.2537e-03, 6.2520e-03, ..., 3.0431e-03,\n", " 1.6902e-03, 2.6523e-02],\n", " ...,\n", " [8.7035e-01, 5.6680e-03, 2.5519e-04, ..., 1.0693e-02,\n", " 1.0154e-02, 2.8158e-02],\n", " [7.8992e-01, 1.3184e-03, 5.2799e-04, ..., 3.8399e-03,\n", " 2.3379e-02, 5.4757e-02],\n", " [4.0584e-01, 5.6631e-03, 8.5153e-03, ..., 1.0006e-02,\n", " 1.0799e-02, 1.9912e-01]],\n", "\n", " [[9.8713e-01, 3.3973e-04, 9.6788e-04, ..., 2.1040e-04,\n", " 1.3595e-03, 8.0080e-04],\n", " [1.0312e-01, 4.2905e-03, 8.3475e-01, ..., 7.3782e-06,\n", " 1.9842e-04, 1.3445e-03],\n", " [7.9036e-01, 2.8547e-02, 5.0725e-02, ..., 1.9356e-05,\n", " 6.4891e-05, 2.8477e-03],\n", " ...,\n", " [2.1335e-01, 9.7233e-06, 6.9469e-05, ..., 3.6693e-04,\n", " 
3.3324e-01, 1.3384e-02],\n", " [1.1667e-02, 3.0911e-05, 2.5899e-06, ..., 5.6125e-01,\n", " 2.7517e-04, 1.5053e-03],\n", " [8.4494e-01, 8.0791e-04, 1.0116e-03, ..., 2.4602e-03,\n", " 6.7727e-02, 1.1728e-02]]]], grad_fn=)" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output[2][2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## gotowe api" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### generowanie tekstu" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model = pipeline('text-generation', model='gpt2')" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "data": { "text/plain": [ "[{'generated_text': 'Hello, I\\'m a computer science student by trade. I don\\'t really like science.\"\\n\\nThen I hear him say: \"I love the'},\n", " {'generated_text': \"Hello, I'm a computer science student.\\n\\nAnd if you're curious what I'm doing here, don't hesitate:\\n\\n\\nI've\"},\n", " {'generated_text': \"Hello, I'm a computer science student, not an engineer. 
But, I'm also fascinated, because all the people I'm talking to are engineers\"},\n", " {'generated_text': \"Hello, I'm a computer science student with a big project called the Data Science project, to help students create and understand and improve their data science.\"},\n", " {'generated_text': \"Hello, I'm a computer science student from North Carolina. My work involves a number of questions (and many possible answers as well). I can't\"}]" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(\"Hello, I'm a computer science student\", max_length=30, num_return_sequences=5)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "data": { "text/plain": [ "[{'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data. It is also working on creating a new type of\"},\n", " {'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big datasets. It has now given every request to Google for information\"},\n", " {'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data, robotics and artificial intelligence. We understand the potential impact\"},\n", " {'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big data, artificial intelligence, and machine learning. I think we\"},\n", " {'generated_text': \"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big new ways to see and interact with image data. 
We'll\"}]" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(\"I want to contribute to Google's Computer Vision Program, which is doing extensive work on big\", max_length=30, num_return_sequences=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### sentiment analysis" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline\n", "\n", "model = pipeline(\"sentiment-analysis\", model='distilbert-base-uncased-finetuned-sst-2-english')" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'label': 'POSITIVE', 'score': 0.9998474717140198}]" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(\"I'm very happy. Today is the beatifull weather\")" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'label': 'NEGATIVE', 'score': 0.9946851134300232}]" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(\"It's raining. 
What a terrible day...\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### NER" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "model = pipeline(\"sentiment-analysis\", model='distilbert-base-uncased-finetuned-sst-2-english')" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline\n", "model = pipeline(\"ner\")" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "text = \"George Washington went to Washington\"" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'word': 'George',\n", " 'score': 0.9983943104743958,\n", " 'entity': 'I-PER',\n", " 'index': 1,\n", " 'start': 0,\n", " 'end': 6},\n", " {'word': 'Washington',\n", " 'score': 0.9992505311965942,\n", " 'entity': 'I-PER',\n", " 'index': 2,\n", " 'start': 7,\n", " 'end': 17},\n", " {'word': 'Washington',\n", " 'score': 0.98389732837677,\n", " 'entity': 'I-LOC',\n", " 'index': 5,\n", " 'start': 26,\n", " 'end': 36}]" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### masked language modelling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ZADANIE (10 minut)\n", "\n", "Przewidzieć brakujący token w zdaniu \"The world [MASK] II started in 1939\" za pomocą dowolnego anglojęzycznego modelu (symbol maski zależy od modelu, np. `[MASK]` w BERT, `<mask>` w RoBERTa)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }