diff --git a/08_vector_representations.ipynb b/08_vector_representations.ipynb new file mode 100644 index 0000000..71c770a --- /dev/null +++ b/08_vector_representations.ipynb @@ -0,0 +1,259 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Class 8: Vector representations\n", + "\n", + "All exercises must be solved in Python, in a copy of the Jupyter Notebook for the given class, in the designated places (cells with the comment `# Solution`).\n", + "\n", + "Do not delete the cells that contain the task descriptions.\n", + "\n", + "Display outputs using `print`.\n", + "\n", + "## For the keen! (may come in handy for more ambitious final projects)\n", + "\n", + "https://github.com/huggingface/smol-course - a course on fine-tuning LLMs for your own tasks\n", + "\n", + "https://github.com/unslothai/unsloth - a library for efficient fine-tuning of LLMs (ready-made notebooks with code are available on the Colab platform)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Task 1\n", + "\n", + "Preprocess the texts from https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/in.tsv (preprocessing is the initial processing of texts: lowercasing, tokenization, etc.) and then perform:\n", + "* a) one-hot encoding\n", + "* b) frequency encoding\n", + "\n", + "for two example sentences that contain at least 2 occurrences of a word that is in the vocabulary (i.e. a word that appears in the corpus from the in.tsv file). Because the corpus contains a large number of unique words, please do not print the whole vectors, only the indices and values."
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Task 2\n", + "\n", + "Using the file https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/in.tsv and the file https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/expected.tsv, split the texts by class (spam / not spam). Compute the IDF value separately for the spam texts and for the non-spam texts, for the words:\n", + "* free\n", + "* send\n", + "* are\n", + "* the\n", + "\n", + "and for 2-3 words of your own choice." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Task 3\n", + "\n", + "Using the embedding layer of the gpt2 model, print the 15 tokens most similar (by cosine similarity) to the words:\n", + "* cat\n", + "* tree\n", + "\n", + "and to 2 tokens of your own choice." + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import GPT2Tokenizer, GPT2Model\n", + "import torch\n", + "\n", + "\n", + "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n", + "model = GPT2Model.from_pretrained('gpt2')\n", + "embedding_layer = model.wte  # token embedding layer\n", + "cos_sim = torch.nn.CosineSimilarity()  # default dim=1\n" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The text 'cat' is converted to token 9246\n", + "{'input_ids': [9246], 'attention_mask': [1]}\n", + "cat\n" + ] + } + ], + "source": [ + "print(\"The text 'cat' is converted to token 9246\")\n", + "print(tokenizer(\"cat\"))\n", + "print(tokenizer.decode([9246]))" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Embedding tokenu: 9246\n", + "torch.Size([1, 768])\n", + "tensor([[-0.0164, -0.0934, 0.2425, 0.1398, 0.0388, -0.2592, -0.2724, -0.1625,\n",
+ " 0.1683, 0.0829, 0.0136, -0.2788, 0.1493, 0.1408, 0.0557, -0.3691,\n", + " 0.2200, -0.0428, 0.2206, 0.0865, 0.1237, -0.1499, 0.1446, -0.1150,\n", + " -0.1425, -0.0715, -0.0526, 0.1550, -0.0678, -0.2059, 0.2065, -0.0297,\n", + " 0.0834, -0.0483, 0.1207, 0.1975, -0.3193, 0.0124, 0.1067, -0.0473,\n", + " -0.3037, 0.1139, 0.0949, -0.2175, 0.0796, -0.0941, -0.0394, -0.0704,\n", + " 0.2033, -0.1555, 0.2928, -0.0770, 0.0787, 0.1214, 0.1528, -0.1464,\n", + " 0.4247, 0.1921, -0.0415, -0.0850, -0.2787, 0.0656, -0.2026, 0.1856,\n", + " 0.1353, -0.0820, -0.0639, 0.0701, 0.1680, 0.0597, 0.3265, -0.1100,\n", + " 0.1056, 0.1845, -0.1156, 0.0054, 0.0663, 0.1842, -0.1069, 0.0491,\n", + " -0.0853, -0.2519, 0.0031, 0.1805, 0.1505, 0.0442, -0.2427, 0.1104,\n", + " 0.0970, 0.1123, -0.1519, -0.1444, 0.2323, -0.0241, -0.0677, 0.1157,\n", + " -0.2668, -0.1229, 0.1120, 0.0601, -0.0535, 0.1259, -0.0966, -0.1975,\n", + " -0.2031, 0.1323, 0.0176, -0.1332, 0.1159, 0.1037, 0.0722, 0.1644,\n", + " -0.0775, -0.0227, 0.1146, 0.0060, 0.3959, -0.0828, 0.0125, 0.0415,\n", + " -0.0147, -0.1352, -0.0579, 0.0423, -0.0793, -0.2702, 0.2806, -0.0744,\n", + " 0.1118, 0.0908, 0.0639, -0.0882, -0.0190, -0.1386, -0.0490, -0.1785,\n", + " 0.1416, 0.0497, -0.0461, -0.1544, 0.0662, -0.0538, 0.0992, 0.1308,\n", + " -0.0885, -0.2840, -0.0297, -0.0882, -0.0340, -0.1495, 0.0295, 0.0700,\n", + " 0.0661, -0.1282, -0.0546, -0.1392, -0.1368, 0.0353, 0.0814, -0.1436,\n", + " -0.0559, 0.1523, -0.0780, 0.2562, 0.0164, -0.0433, 0.0468, 0.2896,\n", + " 0.0069, 0.2136, 0.0378, -0.1625, -0.0421, -0.0109, -0.0386, 0.0453,\n", + " 0.2572, 0.0323, -0.1206, 0.0135, -0.0171, 0.0404, -0.2544, 0.1455,\n", + " 0.3265, -0.0545, 0.0887, -0.0321, 0.1485, -0.0699, -0.0606, 0.2177,\n", + " -0.1566, -0.0619, -0.2655, 0.2471, -0.1213, -0.0741, 0.1074, -0.0243,\n", + " -0.2747, 0.1828, 0.0046, 0.1426, 0.0201, 0.0413, -0.0189, -0.0070,\n", + " -0.0338, -0.1384, 0.0269, 0.1447, -0.0216, -0.0042, -0.0363, -0.0579,\n", + " -0.0909, -0.1359, 
0.1407, 0.1421, 0.0041, 0.0100, 0.0997, -0.0718,\n", + " -0.0958, 0.0051, -0.2576, 0.1980, 0.1545, 0.0226, -0.2521, -0.1091,\n", + " -0.1467, -0.1140, 0.1383, -0.1952, 0.0554, -0.0036, -0.2915, -0.1645,\n", + " 0.0469, -0.2251, 0.2000, 0.1070, 0.1651, -0.0781, 0.1511, -0.0095,\n", + " 0.0925, 0.0776, 0.1631, 0.1563, 0.0286, -0.1157, 0.0349, 0.0033,\n", + " -0.0870, -0.0864, 0.1233, -0.0691, -0.0458, -0.0601, 0.0501, -0.1450,\n", + " -0.2425, -0.0773, 0.1182, 0.1351, 0.1904, 0.1746, 0.0925, -0.1253,\n", + " -0.1149, -0.1312, -0.2170, 0.0682, -0.0121, -0.1774, -0.0857, -0.1906,\n", + " 0.2842, -0.0410, 0.0530, 0.0480, -0.0641, -0.0911, 0.2907, -0.2503,\n", + " -0.1085, 0.1753, 0.0610, 0.0466, 0.0097, -0.1300, -0.0273, 0.0498,\n", + " -0.0619, -0.1867, -0.0769, -0.1091, -0.0410, -0.0617, -0.0537, 0.0582,\n", + " -0.0986, -0.2655, 0.1236, -0.0026, 0.0444, -0.1018, -0.1652, 0.0174,\n", + " -0.2561, 0.0440, 0.2048, 0.0049, -0.0220, -0.1031, -0.1387, 0.0493,\n", + " 0.2048, 0.0473, 0.1630, -0.0195, -0.0714, -0.0713, -0.2524, 0.0852,\n", + " -0.1682, 0.0713, 0.1301, -0.1088, -0.0188, 0.0092, 0.0078, 0.2213,\n", + " 0.0638, -0.1617, -0.0365, -0.0923, -0.1052, 0.1108, -0.1175, -0.0016,\n", + " -0.0258, 0.0902, 0.1089, 0.1685, -0.2664, -0.0309, -0.0187, -0.0678,\n", + " -0.1424, -0.0026, -0.0623, -0.0575, -0.1009, 0.0142, -0.1950, 0.0085,\n", + " -0.1402, 0.0371, -0.4072, -0.0478, 0.4013, 0.3212, 0.1051, 0.0349,\n", + " -0.1302, -0.0298, -0.1738, 0.0692, -0.2638, 0.1268, 0.1773, -0.1094,\n", + " 0.0737, -0.0460, 0.1870, -0.0605, -0.1308, -0.0920, -0.0290, -0.0542,\n", + " 0.1214, -0.0308, -0.1173, -0.2127, 0.0209, 0.2911, -0.1751, 0.0469,\n", + " 0.0740, -0.1323, 0.0283, 0.2125, 0.1870, 0.0978, 0.1799, 0.2669,\n", + " 0.1709, -0.1191, -0.2022, 0.0445, 0.0601, -0.1820, 0.0224, 0.1902,\n", + " -0.3199, -0.2551, 0.0795, 0.0814, 0.1245, 0.0871, -0.0455, -0.2342,\n", + " 0.1167, 0.0870, 0.0257, -0.2073, 0.1849, -0.0184, 0.0498, -0.1423,\n", + " -0.0682, 0.1386, 0.0406, -0.0325, 
0.2179, -0.0567, -0.2568, -0.0935,\n", + " -0.0453, -0.1317, 0.0682, -0.2721, -0.2026, -0.0565, -0.0134, -0.0423,\n", + " -0.0415, -0.0560, -0.1522, 0.1617, -0.0753, -0.1967, 0.0536, -0.0988,\n", + " 0.0539, -0.0489, 0.0296, 0.1272, 0.1567, 0.0185, -0.0855, -0.2336,\n", + " 0.1859, 0.1528, -0.1824, -0.0834, -0.1414, -0.0526, 0.1744, -0.0290,\n", + " -0.0753, -0.0100, -0.1702, -0.0676, 0.0856, 0.0493, -0.1256, -0.1652,\n", + " -0.1317, 0.0677, 0.0209, -0.0346, 0.0048, 0.1209, 0.1959, 0.1520,\n", + " 0.0793, -0.1492, 0.3141, 0.1526, -0.1732, -0.0914, 0.1339, -0.1410,\n", + " -0.0595, -0.0250, -0.1136, -0.1206, -0.1126, -0.0470, 0.1898, 0.0565,\n", + " -0.2058, 0.0389, 0.0177, -0.2718, 0.2021, -0.0779, 0.1444, 0.1047,\n", + " -0.2096, -0.0210, 0.1791, -0.4005, -0.1931, 0.1083, 0.2465, 0.1026,\n", + " -0.0503, 0.1047, 0.0299, -0.1043, 0.0964, 0.0852, -0.2067, 0.1263,\n", + " 0.2064, 0.2248, 0.2739, -0.1881, -0.0745, 0.0769, 0.2994, 0.2803,\n", + " 0.0063, 0.2585, -0.0176, 0.2318, -0.0432, 0.1889, -0.0766, 0.0751,\n", + " -0.0157, 0.0517, 0.1274, -0.2235, -0.0450, 0.1606, 0.0876, 0.1240,\n", + " 0.4417, -0.0625, 0.0591, -0.0181, 0.1996, 0.0959, -0.2623, -0.2826,\n", + " 0.0023, 0.1835, 0.1931, -0.1054, 0.1816, -0.1599, -0.0871, 0.0115,\n", + " 0.2386, 0.0161, 0.0580, -0.0558, 0.0963, 0.1206, -0.3461, 0.0726,\n", + " 0.0301, 0.1058, 0.0532, 0.0515, 0.0216, -0.0531, -0.0217, 0.0539,\n", + " -0.0191, 0.0636, -0.1527, -0.1670, 0.0756, 0.0167, -0.0437, 0.0050,\n", + " -0.1861, 0.0304, 0.2442, -0.0126, -0.2314, 0.1562, 0.1635, -0.1206,\n", + " -0.0428, 0.1079, 0.1216, -0.0113, 0.1757, -0.0235, 0.2049, -0.3030,\n", + " 0.0067, -0.3157, 0.1435, -0.1737, -0.1698, 0.2276, -0.0360, 0.0048,\n", + " -0.2974, 0.2021, 0.1380, 0.1129, -0.0626, 0.1347, 0.0729, 0.0481,\n", + " -0.1397, 0.0197, -0.0932, 0.1717, -0.1519, -0.0554, -0.0344, 0.0201,\n", + " 0.1316, 0.0743, -0.1189, 0.2787, 0.0597, -0.2073, -0.3555, -0.0645,\n", + " -0.1326, 0.1094, 0.1512, 0.0241, 0.0608, 0.0334, 0.1340, 
-0.0510,\n", + " -0.0197, 0.0681, -0.1494, 0.2410, -0.1016, 0.1148, -0.1280, 0.0576,\n", + " -0.0101, 0.1957, -0.2854, 0.0231, -0.0282, -0.0101, 0.0383, -0.0086,\n", + " -0.1152, 0.1464, 0.0351, -0.2189, 0.2156, 0.0722, 0.0881, 0.2360,\n", + " 0.1182, 0.0676, 0.0506, -0.0226, -0.2842, -0.2781, 0.0791, 0.1958,\n", + " 0.0470, -0.0645, 0.0129, 0.2878, 0.1339, -0.2548, -0.0855, -0.1515,\n", + " 0.1442, -0.1694, -0.2864, -0.1720, -0.0473, -0.1325, 0.2390, 0.0736,\n", + " -0.1975, -0.1141, -0.1409, 0.1901, -0.0673, 0.0887, 0.2132, 0.1869,\n", + " -0.2635, 0.1148, -0.0100, 0.0344, -0.1273, -0.0943, 0.1698, 0.1511,\n", + " -0.2248, 0.0495, 0.1384, 0.1260, 0.0787, 0.3357, 0.2288, -0.0885,\n", + " 0.0622, -0.2151, 0.0553, 0.0690, 0.2568, -0.2637, -0.0460, 0.0911,\n", + " -0.0214, 0.1712, 0.0342, -0.0438, -0.1457, -0.0953, 0.0289, -0.0830,\n", + " -0.0591, -0.0244, 0.1259, 0.0075, -0.1052, 0.1119, 0.0621, -0.2382,\n", + " 0.1763, -0.1132, -0.1201, 0.0162, -0.1807, 0.1419, 0.0749, -0.2318,\n", + " 0.0132, 0.2532, -0.1028, 0.0556, -0.0214, 0.1974, -0.1573, 0.1861,\n", + " -0.1729, -0.0737, -0.1499, -0.1533, 0.0011, 0.0681, -0.1828, 0.0519,\n", + " -0.0837, 0.0671, -0.1867, 0.0012, 0.0377, 0.1061, -0.1713, -0.1579]],\n", + " grad_fn=)\n" + ] + } + ], + "source": [ + "print(\"Embedding tokenu: 9246\")\n", + "cat_embedding = embedding_layer(torch.LongTensor([9246]))\n", + "print(cat_embedding.shape)\n", + "print(cat_embedding)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Similarity of the embedding with itself (should be 1)\n", + "tensor([1.0000], grad_fn=)\n" + ] + } + ], + "source": [ + "print(\"Similarity of the embedding with itself (should be 1)\")\n", + "print(cos_sim(cat_embedding, cat_embedding))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode":
{ + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}