{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Class 8: Vector representations\n",
"\n",
"All exercises must be solved in Python, in your copy of the Jupyter Notebook for the given class, in the designated places (cells with the `# Solution` comment).\n",
"\n",
"Do not delete the cells containing the task descriptions.\n",
"\n",
"Display outputs using `print`.\n",
"\n",
"## Optional (may be useful for more ambitious final projects)\n",
"\n",
"https://github.com/huggingface/smol-course - a course on fine-tuning LLMs for your own tasks\n",
"\n",
"https://github.com/unslothai/unsloth - a library for efficient LLM fine-tuning (ready-made notebooks with code are available on Colab)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Task 1\n",
"\n",
"Preprocess the texts from https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/in.tsv (preprocessing is the initial processing of texts: lowercasing, tokenization, etc.) and then perform:\n",
"* a) one-hot encoding\n",
"* b) frequency encoding\n",
"\n",
"for two example sentences that contain at least 2 occurrences of a word that is in the vocabulary (i.e. a word that occurs in the corpus from in.tsv). Because the corpus contains a large number of unique words, do not print the whole vectors; print only the indices and the values."
]
},
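{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the two encodings from Task 1. It runs on a hypothetical three-sentence corpus standing in for `train/in.tsv`; loading the real file is left out. It makes three assumptions. Preprocessing means lowercasing plus a regular expression that keeps runs of letters and digits. One-hot encoding is binary word presence, and frequency encoding is word counts. Each vector is stored sparsely as an `{index: value}` dictionary, so only the indices and values are printed. The names `preprocess`, `build_vocab`, `one_hot` and `frequency` are illustrative, not library functions.\n",
"\n",
"```python\n",
"import re\n",
"from collections import Counter\n",
"\n",
"\n",
"def preprocess(text):\n",
"    # Assumed preprocessing: lowercase, then keep runs of letters/digits as tokens\n",
"    return re.findall('[a-z0-9]+', text.lower())\n",
"\n",
"\n",
"def build_vocab(corpus):\n",
"    # Give every unique corpus token a column index (alphabetical order)\n",
"    tokens = sorted({tok for text in corpus for tok in preprocess(text)})\n",
"    return {tok: i for i, tok in enumerate(tokens)}\n",
"\n",
"\n",
"def one_hot(sentence, vocab):\n",
"    # Sparse binary vector: index -> 1 for every vocabulary word present\n",
"    return {vocab[tok]: 1 for tok in preprocess(sentence) if tok in vocab}\n",
"\n",
"\n",
"def frequency(sentence, vocab):\n",
"    # Sparse count vector: index -> number of occurrences in the sentence\n",
"    counts = Counter(tok for tok in preprocess(sentence) if tok in vocab)\n",
"    return {vocab[tok]: n for tok, n in counts.items()}\n",
"\n",
"\n",
"# Hypothetical toy corpus standing in for train/in.tsv\n",
"corpus = ['Free entry in a weekly competition', 'Are you free tonight?', 'Send me the free offer']\n",
"vocab = build_vocab(corpus)\n",
"sentence = 'Free free stuff: send it now'\n",
"print(sorted(one_hot(sentence, vocab).items()))\n",
"print(sorted(frequency(sentence, vocab).items()))\n",
"```\n",
"\n",
"In the output, the repeated word `free` has the value 1 in the one-hot vector and 2 in the frequency vector. Words outside the vocabulary (`stuff`, `it`, `now`) are simply skipped."
]
},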
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Task 2\n",
"\n",
"Using the files https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/in.tsv and https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/expected.tsv, split the texts by class (spam / not spam). Compute the IDF value separately for the spam texts and for the non-spam texts, for the words:\n",
"* free\n",
"* send\n",
"* are\n",
"* the\n",
"\n",
"and for 2-3 words of your own choice."
]
},
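{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the per-class IDF computation from Task 2. It runs on hypothetical toy data standing in for `in.tsv` (texts) and `expected.tsv` (labels). The convention that label 1 marks spam is an assumption made only for this example. The sketch uses the unsmoothed variant $\\mathrm{idf}(t) = \\ln(N / \\mathrm{df}(t))$. Here $N$ is the number of texts in the class and $\\mathrm{df}(t)$ is the number of those texts that contain $t$. This variant is undefined when $\\mathrm{df}(t) = 0$, so the sketch returns `None` in that case. Smoothed IDF variants give different values. The helper names are illustrative.\n",
"\n",
"```python\n",
"import math\n",
"import re\n",
"\n",
"\n",
"def preprocess(text):\n",
"    # Assumed preprocessing: lowercase, then keep runs of letters/digits as tokens\n",
"    return re.findall('[a-z0-9]+', text.lower())\n",
"\n",
"\n",
"def idf(word, docs):\n",
"    # Document frequency: how many texts of this class contain the word at least once\n",
"    df = sum(1 for doc in docs if word in set(preprocess(doc)))\n",
"    # Unsmoothed IDF = ln(N / df); None when the word never occurs (df = 0)\n",
"    return math.log(len(docs) / df) if df else None\n",
"\n",
"\n",
"# Hypothetical texts and labels standing in for in.tsv and expected.tsv (1 = spam)\n",
"texts = ['Free entry, send WIN now', 'Send free ringtones now', 'Claim your prize now', 'Are you free for the party?', 'The meeting is at noon']\n",
"labels = [1, 1, 1, 0, 0]\n",
"spam = [t for t, y in zip(texts, labels) if y == 1]\n",
"ham = [t for t, y in zip(texts, labels) if y == 0]\n",
"\n",
"for word in ['free', 'send', 'are', 'the']:\n",
"    print(word, 'spam:', idf(word, spam), 'not spam:', idf(word, ham))\n",
"```\n",
"\n",
"A word gets IDF 0 in a class when it occurs in every text of that class. Its IDF grows the rarer it is there, so comparing the two values for a word shows in which class it is relatively more common."
]
},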
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Task 3\n",
"\n",
"Using the embedding layer of the gpt2 model, print the 15 tokens most similar (by cosine similarity) to the words:\n",
"* cat\n",
"* tree\n",
"\n",
"and to 2 tokens of your own choice."
]
},
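{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal plain-Python sketch of the nearest-neighbour search from Task 3. It runs on a hypothetical 4-token, 3-dimensional embedding matrix; GPT-2's real `wte` matrix has 50257 rows of 768 values. The sketch scores every other row against the query row with cosine similarity, skips the query token itself and keeps the `k` best scores. With the model loaded in the cells below, the rows would come from `embedding_layer.weight`, and the resulting ids could be turned back into text with `tokenizer.decode`. For the full matrix, a vectorized torch computation is much faster than this loop. GPT-2's BPE vocabulary also has separate tokens for a word with and without a leading space (`cat` vs ` cat`), so it matters which form is looked up. The names `cosine` and `most_similar` are illustrative.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def cosine(u, v):\n",
"    # cos(u, v) = (u . v) / (||u|| * ||v||)\n",
"    dot = sum(a * b for a, b in zip(u, v))\n",
"    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))\n",
"\n",
"\n",
"def most_similar(token_id, embeddings, k):\n",
"    # Compare the query row with every other row of the embedding matrix\n",
"    query = embeddings[token_id]\n",
"    scores = [(i, cosine(query, row)) for i, row in enumerate(embeddings) if i != token_id]\n",
"    # Highest similarity first; keep the top k (token_id, similarity) pairs\n",
"    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:k]\n",
"\n",
"\n",
"# Hypothetical embedding matrix: one row per token id\n",
"embeddings = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]\n",
"print(most_similar(0, embeddings, k=2))\n",
"```\n",
"\n",
"For this toy matrix, token 1 (nearly parallel to token 0) ranks first and token 2 (orthogonal, similarity 0) second. Token 3 points the opposite way (similarity -1) and falls outside the top 2. For Task 3, the query would be the id of `cat` or `tree` and `k=15`."
]
},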
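{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"from transformers import GPT2Tokenizer, GPT2Model\n",
"import torch\n",
"\n",
"\n",
"# Load the pretrained GPT-2 tokenizer and model (downloaded from the Hugging Face Hub on first use)\n",
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
"model = GPT2Model.from_pretrained('gpt2')\n",
"# wte = word token embeddings: an nn.Embedding with one 768-dimensional row per vocabulary token\n",
"embedding_layer = model.wte\n",
"# Cosine similarity along dim=1 (the default), i.e. row by row for (N, 768) inputs\n",
"cos_sim = torch.nn.CosineSimilarity(dim=1)"
]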
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The text 'cat' is converted to token 9246\n",
"{'input_ids': [9246], 'attention_mask': [1]}\n",
"cat\n"
]
}
],
"source": [
"print(\"The text 'cat' is converted to token 9246\")\n",
"# Tokenize 'cat', then map the id back to text with decode\n",
"print(tokenizer(\"cat\"))\n",
"print(tokenizer.decode([9246]))"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embedding of token 9246:\n",
"torch.Size([1, 768])\n",
"tensor([[-0.0164, -0.0934, 0.2425, 0.1398, 0.0388, -0.2592, -0.2724, -0.1625,\n",
" 0.1683, 0.0829, 0.0136, -0.2788, 0.1493, 0.1408, 0.0557, -0.3691,\n",
" 0.2200, -0.0428, 0.2206, 0.0865, 0.1237, -0.1499, 0.1446, -0.1150,\n",
" -0.1425, -0.0715, -0.0526, 0.1550, -0.0678, -0.2059, 0.2065, -0.0297,\n",
" 0.0834, -0.0483, 0.1207, 0.1975, -0.3193, 0.0124, 0.1067, -0.0473,\n",
" -0.3037, 0.1139, 0.0949, -0.2175, 0.0796, -0.0941, -0.0394, -0.0704,\n",
" 0.2033, -0.1555, 0.2928, -0.0770, 0.0787, 0.1214, 0.1528, -0.1464,\n",
" 0.4247, 0.1921, -0.0415, -0.0850, -0.2787, 0.0656, -0.2026, 0.1856,\n",
" 0.1353, -0.0820, -0.0639, 0.0701, 0.1680, 0.0597, 0.3265, -0.1100,\n",
" 0.1056, 0.1845, -0.1156, 0.0054, 0.0663, 0.1842, -0.1069, 0.0491,\n",
" -0.0853, -0.2519, 0.0031, 0.1805, 0.1505, 0.0442, -0.2427, 0.1104,\n",
" 0.0970, 0.1123, -0.1519, -0.1444, 0.2323, -0.0241, -0.0677, 0.1157,\n",
" -0.2668, -0.1229, 0.1120, 0.0601, -0.0535, 0.1259, -0.0966, -0.1975,\n",
" -0.2031, 0.1323, 0.0176, -0.1332, 0.1159, 0.1037, 0.0722, 0.1644,\n",
" -0.0775, -0.0227, 0.1146, 0.0060, 0.3959, -0.0828, 0.0125, 0.0415,\n",
" -0.0147, -0.1352, -0.0579, 0.0423, -0.0793, -0.2702, 0.2806, -0.0744,\n",
" 0.1118, 0.0908, 0.0639, -0.0882, -0.0190, -0.1386, -0.0490, -0.1785,\n",
" 0.1416, 0.0497, -0.0461, -0.1544, 0.0662, -0.0538, 0.0992, 0.1308,\n",
" -0.0885, -0.2840, -0.0297, -0.0882, -0.0340, -0.1495, 0.0295, 0.0700,\n",
" 0.0661, -0.1282, -0.0546, -0.1392, -0.1368, 0.0353, 0.0814, -0.1436,\n",
" -0.0559, 0.1523, -0.0780, 0.2562, 0.0164, -0.0433, 0.0468, 0.2896,\n",
" 0.0069, 0.2136, 0.0378, -0.1625, -0.0421, -0.0109, -0.0386, 0.0453,\n",
" 0.2572, 0.0323, -0.1206, 0.0135, -0.0171, 0.0404, -0.2544, 0.1455,\n",
" 0.3265, -0.0545, 0.0887, -0.0321, 0.1485, -0.0699, -0.0606, 0.2177,\n",
" -0.1566, -0.0619, -0.2655, 0.2471, -0.1213, -0.0741, 0.1074, -0.0243,\n",
" -0.2747, 0.1828, 0.0046, 0.1426, 0.0201, 0.0413, -0.0189, -0.0070,\n",
" -0.0338, -0.1384, 0.0269, 0.1447, -0.0216, -0.0042, -0.0363, -0.0579,\n",
" -0.0909, -0.1359, 0.1407, 0.1421, 0.0041, 0.0100, 0.0997, -0.0718,\n",
" -0.0958, 0.0051, -0.2576, 0.1980, 0.1545, 0.0226, -0.2521, -0.1091,\n",
" -0.1467, -0.1140, 0.1383, -0.1952, 0.0554, -0.0036, -0.2915, -0.1645,\n",
" 0.0469, -0.2251, 0.2000, 0.1070, 0.1651, -0.0781, 0.1511, -0.0095,\n",
" 0.0925, 0.0776, 0.1631, 0.1563, 0.0286, -0.1157, 0.0349, 0.0033,\n",
" -0.0870, -0.0864, 0.1233, -0.0691, -0.0458, -0.0601, 0.0501, -0.1450,\n",
" -0.2425, -0.0773, 0.1182, 0.1351, 0.1904, 0.1746, 0.0925, -0.1253,\n",
" -0.1149, -0.1312, -0.2170, 0.0682, -0.0121, -0.1774, -0.0857, -0.1906,\n",
" 0.2842, -0.0410, 0.0530, 0.0480, -0.0641, -0.0911, 0.2907, -0.2503,\n",
" -0.1085, 0.1753, 0.0610, 0.0466, 0.0097, -0.1300, -0.0273, 0.0498,\n",
" -0.0619, -0.1867, -0.0769, -0.1091, -0.0410, -0.0617, -0.0537, 0.0582,\n",
" -0.0986, -0.2655, 0.1236, -0.0026, 0.0444, -0.1018, -0.1652, 0.0174,\n",
" -0.2561, 0.0440, 0.2048, 0.0049, -0.0220, -0.1031, -0.1387, 0.0493,\n",
" 0.2048, 0.0473, 0.1630, -0.0195, -0.0714, -0.0713, -0.2524, 0.0852,\n",
" -0.1682, 0.0713, 0.1301, -0.1088, -0.0188, 0.0092, 0.0078, 0.2213,\n",
" 0.0638, -0.1617, -0.0365, -0.0923, -0.1052, 0.1108, -0.1175, -0.0016,\n",
" -0.0258, 0.0902, 0.1089, 0.1685, -0.2664, -0.0309, -0.0187, -0.0678,\n",
" -0.1424, -0.0026, -0.0623, -0.0575, -0.1009, 0.0142, -0.1950, 0.0085,\n",
" -0.1402, 0.0371, -0.4072, -0.0478, 0.4013, 0.3212, 0.1051, 0.0349,\n",
" -0.1302, -0.0298, -0.1738, 0.0692, -0.2638, 0.1268, 0.1773, -0.1094,\n",
" 0.0737, -0.0460, 0.1870, -0.0605, -0.1308, -0.0920, -0.0290, -0.0542,\n",
" 0.1214, -0.0308, -0.1173, -0.2127, 0.0209, 0.2911, -0.1751, 0.0469,\n",
" 0.0740, -0.1323, 0.0283, 0.2125, 0.1870, 0.0978, 0.1799, 0.2669,\n",
" 0.1709, -0.1191, -0.2022, 0.0445, 0.0601, -0.1820, 0.0224, 0.1902,\n",
" -0.3199, -0.2551, 0.0795, 0.0814, 0.1245, 0.0871, -0.0455, -0.2342,\n",
" 0.1167, 0.0870, 0.0257, -0.2073, 0.1849, -0.0184, 0.0498, -0.1423,\n",
" -0.0682, 0.1386, 0.0406, -0.0325, 0.2179, -0.0567, -0.2568, -0.0935,\n",
" -0.0453, -0.1317, 0.0682, -0.2721, -0.2026, -0.0565, -0.0134, -0.0423,\n",
" -0.0415, -0.0560, -0.1522, 0.1617, -0.0753, -0.1967, 0.0536, -0.0988,\n",
" 0.0539, -0.0489, 0.0296, 0.1272, 0.1567, 0.0185, -0.0855, -0.2336,\n",
" 0.1859, 0.1528, -0.1824, -0.0834, -0.1414, -0.0526, 0.1744, -0.0290,\n",
" -0.0753, -0.0100, -0.1702, -0.0676, 0.0856, 0.0493, -0.1256, -0.1652,\n",
" -0.1317, 0.0677, 0.0209, -0.0346, 0.0048, 0.1209, 0.1959, 0.1520,\n",
" 0.0793, -0.1492, 0.3141, 0.1526, -0.1732, -0.0914, 0.1339, -0.1410,\n",
" -0.0595, -0.0250, -0.1136, -0.1206, -0.1126, -0.0470, 0.1898, 0.0565,\n",
" -0.2058, 0.0389, 0.0177, -0.2718, 0.2021, -0.0779, 0.1444, 0.1047,\n",
" -0.2096, -0.0210, 0.1791, -0.4005, -0.1931, 0.1083, 0.2465, 0.1026,\n",
" -0.0503, 0.1047, 0.0299, -0.1043, 0.0964, 0.0852, -0.2067, 0.1263,\n",
" 0.2064, 0.2248, 0.2739, -0.1881, -0.0745, 0.0769, 0.2994, 0.2803,\n",
" 0.0063, 0.2585, -0.0176, 0.2318, -0.0432, 0.1889, -0.0766, 0.0751,\n",
" -0.0157, 0.0517, 0.1274, -0.2235, -0.0450, 0.1606, 0.0876, 0.1240,\n",
" 0.4417, -0.0625, 0.0591, -0.0181, 0.1996, 0.0959, -0.2623, -0.2826,\n",
" 0.0023, 0.1835, 0.1931, -0.1054, 0.1816, -0.1599, -0.0871, 0.0115,\n",
" 0.2386, 0.0161, 0.0580, -0.0558, 0.0963, 0.1206, -0.3461, 0.0726,\n",
" 0.0301, 0.1058, 0.0532, 0.0515, 0.0216, -0.0531, -0.0217, 0.0539,\n",
" -0.0191, 0.0636, -0.1527, -0.1670, 0.0756, 0.0167, -0.0437, 0.0050,\n",
" -0.1861, 0.0304, 0.2442, -0.0126, -0.2314, 0.1562, 0.1635, -0.1206,\n",
" -0.0428, 0.1079, 0.1216, -0.0113, 0.1757, -0.0235, 0.2049, -0.3030,\n",
" 0.0067, -0.3157, 0.1435, -0.1737, -0.1698, 0.2276, -0.0360, 0.0048,\n",
" -0.2974, 0.2021, 0.1380, 0.1129, -0.0626, 0.1347, 0.0729, 0.0481,\n",
" -0.1397, 0.0197, -0.0932, 0.1717, -0.1519, -0.0554, -0.0344, 0.0201,\n",
" 0.1316, 0.0743, -0.1189, 0.2787, 0.0597, -0.2073, -0.3555, -0.0645,\n",
" -0.1326, 0.1094, 0.1512, 0.0241, 0.0608, 0.0334, 0.1340, -0.0510,\n",
" -0.0197, 0.0681, -0.1494, 0.2410, -0.1016, 0.1148, -0.1280, 0.0576,\n",
" -0.0101, 0.1957, -0.2854, 0.0231, -0.0282, -0.0101, 0.0383, -0.0086,\n",
" -0.1152, 0.1464, 0.0351, -0.2189, 0.2156, 0.0722, 0.0881, 0.2360,\n",
" 0.1182, 0.0676, 0.0506, -0.0226, -0.2842, -0.2781, 0.0791, 0.1958,\n",
" 0.0470, -0.0645, 0.0129, 0.2878, 0.1339, -0.2548, -0.0855, -0.1515,\n",
" 0.1442, -0.1694, -0.2864, -0.1720, -0.0473, -0.1325, 0.2390, 0.0736,\n",
" -0.1975, -0.1141, -0.1409, 0.1901, -0.0673, 0.0887, 0.2132, 0.1869,\n",
" -0.2635, 0.1148, -0.0100, 0.0344, -0.1273, -0.0943, 0.1698, 0.1511,\n",
" -0.2248, 0.0495, 0.1384, 0.1260, 0.0787, 0.3357, 0.2288, -0.0885,\n",
" 0.0622, -0.2151, 0.0553, 0.0690, 0.2568, -0.2637, -0.0460, 0.0911,\n",
" -0.0214, 0.1712, 0.0342, -0.0438, -0.1457, -0.0953, 0.0289, -0.0830,\n",
" -0.0591, -0.0244, 0.1259, 0.0075, -0.1052, 0.1119, 0.0621, -0.2382,\n",
" 0.1763, -0.1132, -0.1201, 0.0162, -0.1807, 0.1419, 0.0749, -0.2318,\n",
" 0.0132, 0.2532, -0.1028, 0.0556, -0.0214, 0.1974, -0.1573, 0.1861,\n",
" -0.1729, -0.0737, -0.1499, -0.1533, 0.0011, 0.0681, -0.1828, 0.0519,\n",
" -0.0837, 0.0671, -0.1867, 0.0012, 0.0377, 0.1061, -0.1713, -0.1579]],\n",
" grad_fn=<EmbeddingBackward0>)\n"
]
}
],
"source": [
"print(\"Embedding of token 9246:\")\n",
"# Look up the embedding row for token id 9246 ('cat'); the result has shape (1, 768)\n",
"cat_embedding = embedding_layer(torch.LongTensor([9246]))\n",
"print(cat_embedding.shape)\n",
"print(cat_embedding)"
]
},
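{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cosine similarity is what `cos_sim` computes in the next cell and what Task 3 uses to rank tokens. It is the dot product of two vectors divided by the product of their lengths:\n",
"\n",
"$$\\cos(\\mathbf{u}, \\mathbf{v}) = \\frac{\\mathbf{u} \\cdot \\mathbf{v}}{\\lVert \\mathbf{u} \\rVert \\, \\lVert \\mathbf{v} \\rVert}$$\n",
"\n",
"It ranges from -1 (opposite directions) to 1 (same direction) and ignores vector length. For $\\mathbf{u} = \\mathbf{v}$ the numerator equals $\\lVert \\mathbf{u} \\rVert^2$, so the result is exactly 1 up to floating-point rounding. This is why the check below should print 1."
]
},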
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Similarity of an embedding with itself (should be 1)\n",
"tensor([1.0000], grad_fn=<SumBackward1>)\n"
]
}
],
"source": [
"print(\"Similarity of an embedding with itself (should be 1)\")\n",
"print(cos_sim(cat_embedding, cat_embedding))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}