pjn-2024-cw/08_vector_representations.ipynb

Class 8: Vector representations

All exercises should be solved in Python, in your copy of the Jupyter Notebook for the given class, in the designated places (cells with the comment # Solution).

Do not delete the cells containing the exercise descriptions.

Display outputs using print.

Optional! (may come in handy for more ambitious final projects)

https://github.com/huggingface/smol-course - a course on fine-tuning LLMs for your own tasks

https://github.com/unslothai/unsloth - a library for efficient fine-tuning of LLMs (ready-made Colab notebooks with code are available)

What is a vector?

A vector - a one-dimensional matrix (array of numbers)

[0, 1, 0, 0, 0] - one-hot encoding - only 0/1 values

[0, 2, 0, 5, 1, 100] - frequency encoding - integers >= 0

[-1.5, 0.0002, 5000.01] - a general vector - arbitrary real values
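
A minimal sketch of both encodings on a toy example (the vocabulary and sentence below are made up purely for illustration):

# Toy vocabulary and sentence, purely for illustration
vocab = ["ala", "ma", "kota", "i", "psa"]
sentence = ["ala", "ma", "kota", "ma", "psa"]

# One-hot encoding of a single token: 1 at the token's position in the vocabulary, 0 elsewhere
one_hot_ma = [1 if word == "ma" else 0 for word in vocab]
print(one_hot_ma)   # [0, 1, 0, 0, 0]

# Frequency encoding of the whole sentence: occurrence count per vocabulary position
frequency = [sentence.count(word) for word in vocab]
print(frequency)    # [1, 2, 1, 0, 1]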

Exercise 1

Preprocess the texts from https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/in.tsv (preprocessing - the initial "cleaning" of texts: lowercasing, tokenization, etc.) and compute:

  • a) the one-hot encoding
  • b) the frequency encoding

for two example sentences that each contain at least 2 occurrences of some word present in the vocabulary (i.e. appearing in the corpus from the in.tsv file). Because the corpus contains a large number of unique words, do not print the full vectors - print only the indices and values (see the sketch below).
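
A possible sketch, under the assumptions that in.tsv has been downloaded locally with one text per line, that preprocessing is limited to lowercasing and whitespace tokenization, and with a made-up example sentence:

from collections import Counter

# Assumption: in.tsv is downloaded locally, one text per line
with open("in.tsv", encoding="utf-8") as f:
    texts = [line.strip().lower().split() for line in f]  # simple preprocessing: lowercase + whitespace tokenization

# Vocabulary: every unique word in the corpus gets an index
vocab = {word: idx for idx, word in enumerate(sorted({w for text in texts for w in text}))}

def sparse_encodings(sentence):
    tokens = sentence.lower().split()
    counts = Counter(t for t in tokens if t in vocab)
    frequency = {vocab[t]: c for t, c in counts.items()}   # index -> count
    one_hot = {idx: 1 for idx in frequency}                # index -> 1
    return one_hot, frequency

# Hypothetical example sentence with a repeated in-vocabulary word ("free")
one_hot, frequency = sparse_encodings("free entry now text free to claim your prize")
print("one-hot (index: value):", one_hot)
print("frequency (index: value):", frequency)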

Exercise 2

Based on the file https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/in.tsv and the file https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/expected.tsv, split the texts by the spam / not-spam class. Compute the IDF value separately for the spam texts and for the not-spam texts, for the words:

  • free
  • send
  • are
  • the

as well as 2-3 words of your own choosing (one possible approach is sketched below).
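
One common IDF variant is IDF(t) = log(N / df(t)), where N is the number of documents in the given class and df(t) is the number of those documents containing t. A sketch under the assumptions that both files are downloaded locally, that expected.tsv holds one class label per line, and with the spam label guessed here (verify against the actual file):

import math

# Assumptions: files downloaded locally; expected.tsv has one class label per line
with open("in.tsv", encoding="utf-8") as f:
    documents = [set(line.strip().lower().split()) for line in f]
with open("expected.tsv", encoding="utf-8") as f:
    labels = [line.strip() for line in f]

def idf(docs, word):
    df = sum(1 for doc in docs if word in doc)              # document frequency of the word
    return math.log(len(docs) / df) if df else float("inf")

SPAM_LABEL = "1"  # assumption - check the actual values in expected.tsv
spam = [doc for doc, label in zip(documents, labels) if label == SPAM_LABEL]
ham = [doc for doc, label in zip(documents, labels) if label != SPAM_LABEL]

for word in ["free", "send", "are", "the"]:
    print(word, "| spam IDF:", idf(spam, word), "| not-spam IDF:", idf(ham, word))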

Exercise 3

Using the embedding layer of the gpt2 model, print the 15 most similar tokens (by cosine similarity) to the words:

  • cat
  • tree

as well as 2 tokens of your own choosing (a sketch of one possible approach is given at the end of this section).

from transformers import GPT2Tokenizer, GPT2Model
import torch


tokenizer = GPT2Tokenizer.from_pretrained('gpt2')  # BPE tokenizer of the gpt2 model
model = GPT2Model.from_pretrained('gpt2')
embedding_layer = model.wte                        # token embedding layer (lookup table)
cos_sim = torch.nn.CosineSimilarity()              # cosine similarity (along dim=1 by default)
[
    [0.1, 0.2, 0.3], # Ala
    [-0.5, 0.5, 0.9], # ma
    ...
    # 50254
    ...
    [0.1, -0.1, -0.2] # in GPT-2 a single vector has 768 values, not 3
]
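
The illustration above shows the embedding layer as a lookup table with one row per token; the shape of the real weight matrix can be checked directly with the embedding_layer created earlier:

# One 768-dimensional row for each of the 50257 tokens in the GPT-2 vocabulary
print(embedding_layer.weight.shape)   # torch.Size([50257, 768])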
print("Tekst 'cat' jest konwertowany do tokenu 9246")
print("\nTokenizacja")
print(tokenizer("cat"))
print("\nDetokenizacja")
print(tokenizer.decode([9246]))
print("\nLiczba tokenów w słowniku")
print(len(tokenizer))
The text 'cat' is converted to token 9246

Tokenization
{'input_ids': [9246], 'attention_mask': [1]}

Detokenization
cat

Number of tokens in the vocabulary
50257
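
Note that the GPT-2 BPE tokenizer is whitespace-sensitive: 'cat' and ' cat' (with a leading space) map to different token ids, which is worth keeping in mind when picking tokens for Exercise 3. A quick check, reusing the tokenizer loaded above:

print(tokenizer("cat"))    # token id of 'cat' without a leading space
print(tokenizer(" cat"))   # a different id - the leading space is part of the BPE token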
print("Embedding tokenu: 9246")
cat_embedding = embedding_layer(torch.LongTensor([9246]))
print("\nRozmiar embeddingu (wektora)")
print(cat_embedding.shape)
print("\nWartości embeddingu")
print(cat_embedding)
Embedding of token 9246

Size of the embedding (vector)
torch.Size([1, 768])

Embedding values
tensor([[-0.0164, -0.0934,  0.2425,  0.1398,  0.0388, -0.2592, -0.2724, -0.1625,
          0.1683,  0.0829,  0.0136, -0.2788,  0.1493,  0.1408,  0.0557, -0.3691,
          0.2200, -0.0428,  0.2206,  0.0865,  0.1237, -0.1499,  0.1446, -0.1150,
         -0.1425, -0.0715, -0.0526,  0.1550, -0.0678, -0.2059,  0.2065, -0.0297,
          0.0834, -0.0483,  0.1207,  0.1975, -0.3193,  0.0124,  0.1067, -0.0473,
         -0.3037,  0.1139,  0.0949, -0.2175,  0.0796, -0.0941, -0.0394, -0.0704,
          0.2033, -0.1555,  0.2928, -0.0770,  0.0787,  0.1214,  0.1528, -0.1464,
          0.4247,  0.1921, -0.0415, -0.0850, -0.2787,  0.0656, -0.2026,  0.1856,
          0.1353, -0.0820, -0.0639,  0.0701,  0.1680,  0.0597,  0.3265, -0.1100,
          0.1056,  0.1845, -0.1156,  0.0054,  0.0663,  0.1842, -0.1069,  0.0491,
         -0.0853, -0.2519,  0.0031,  0.1805,  0.1505,  0.0442, -0.2427,  0.1104,
          0.0970,  0.1123, -0.1519, -0.1444,  0.2323, -0.0241, -0.0677,  0.1157,
         -0.2668, -0.1229,  0.1120,  0.0601, -0.0535,  0.1259, -0.0966, -0.1975,
         -0.2031,  0.1323,  0.0176, -0.1332,  0.1159,  0.1037,  0.0722,  0.1644,
         -0.0775, -0.0227,  0.1146,  0.0060,  0.3959, -0.0828,  0.0125,  0.0415,
         -0.0147, -0.1352, -0.0579,  0.0423, -0.0793, -0.2702,  0.2806, -0.0744,
          0.1118,  0.0908,  0.0639, -0.0882, -0.0190, -0.1386, -0.0490, -0.1785,
          0.1416,  0.0497, -0.0461, -0.1544,  0.0662, -0.0538,  0.0992,  0.1308,
         -0.0885, -0.2840, -0.0297, -0.0882, -0.0340, -0.1495,  0.0295,  0.0700,
          0.0661, -0.1282, -0.0546, -0.1392, -0.1368,  0.0353,  0.0814, -0.1436,
         -0.0559,  0.1523, -0.0780,  0.2562,  0.0164, -0.0433,  0.0468,  0.2896,
          0.0069,  0.2136,  0.0378, -0.1625, -0.0421, -0.0109, -0.0386,  0.0453,
          0.2572,  0.0323, -0.1206,  0.0135, -0.0171,  0.0404, -0.2544,  0.1455,
          0.3265, -0.0545,  0.0887, -0.0321,  0.1485, -0.0699, -0.0606,  0.2177,
         -0.1566, -0.0619, -0.2655,  0.2471, -0.1213, -0.0741,  0.1074, -0.0243,
         -0.2747,  0.1828,  0.0046,  0.1426,  0.0201,  0.0413, -0.0189, -0.0070,
         -0.0338, -0.1384,  0.0269,  0.1447, -0.0216, -0.0042, -0.0363, -0.0579,
         -0.0909, -0.1359,  0.1407,  0.1421,  0.0041,  0.0100,  0.0997, -0.0718,
         -0.0958,  0.0051, -0.2576,  0.1980,  0.1545,  0.0226, -0.2521, -0.1091,
         -0.1467, -0.1140,  0.1383, -0.1952,  0.0554, -0.0036, -0.2915, -0.1645,
          0.0469, -0.2251,  0.2000,  0.1070,  0.1651, -0.0781,  0.1511, -0.0095,
          0.0925,  0.0776,  0.1631,  0.1563,  0.0286, -0.1157,  0.0349,  0.0033,
         -0.0870, -0.0864,  0.1233, -0.0691, -0.0458, -0.0601,  0.0501, -0.1450,
         -0.2425, -0.0773,  0.1182,  0.1351,  0.1904,  0.1746,  0.0925, -0.1253,
         -0.1149, -0.1312, -0.2170,  0.0682, -0.0121, -0.1774, -0.0857, -0.1906,
          0.2842, -0.0410,  0.0530,  0.0480, -0.0641, -0.0911,  0.2907, -0.2503,
         -0.1085,  0.1753,  0.0610,  0.0466,  0.0097, -0.1300, -0.0273,  0.0498,
         -0.0619, -0.1867, -0.0769, -0.1091, -0.0410, -0.0617, -0.0537,  0.0582,
         -0.0986, -0.2655,  0.1236, -0.0026,  0.0444, -0.1018, -0.1652,  0.0174,
         -0.2561,  0.0440,  0.2048,  0.0049, -0.0220, -0.1031, -0.1387,  0.0493,
          0.2048,  0.0473,  0.1630, -0.0195, -0.0714, -0.0713, -0.2524,  0.0852,
         -0.1682,  0.0713,  0.1301, -0.1088, -0.0188,  0.0092,  0.0078,  0.2213,
          0.0638, -0.1617, -0.0365, -0.0923, -0.1052,  0.1108, -0.1175, -0.0016,
         -0.0258,  0.0902,  0.1089,  0.1685, -0.2664, -0.0309, -0.0187, -0.0678,
         -0.1424, -0.0026, -0.0623, -0.0575, -0.1009,  0.0142, -0.1950,  0.0085,
         -0.1402,  0.0371, -0.4072, -0.0478,  0.4013,  0.3212,  0.1051,  0.0349,
         -0.1302, -0.0298, -0.1738,  0.0692, -0.2638,  0.1268,  0.1773, -0.1094,
          0.0737, -0.0460,  0.1870, -0.0605, -0.1308, -0.0920, -0.0290, -0.0542,
          0.1214, -0.0308, -0.1173, -0.2127,  0.0209,  0.2911, -0.1751,  0.0469,
          0.0740, -0.1323,  0.0283,  0.2125,  0.1870,  0.0978,  0.1799,  0.2669,
          0.1709, -0.1191, -0.2022,  0.0445,  0.0601, -0.1820,  0.0224,  0.1902,
         -0.3199, -0.2551,  0.0795,  0.0814,  0.1245,  0.0871, -0.0455, -0.2342,
          0.1167,  0.0870,  0.0257, -0.2073,  0.1849, -0.0184,  0.0498, -0.1423,
         -0.0682,  0.1386,  0.0406, -0.0325,  0.2179, -0.0567, -0.2568, -0.0935,
         -0.0453, -0.1317,  0.0682, -0.2721, -0.2026, -0.0565, -0.0134, -0.0423,
         -0.0415, -0.0560, -0.1522,  0.1617, -0.0753, -0.1967,  0.0536, -0.0988,
          0.0539, -0.0489,  0.0296,  0.1272,  0.1567,  0.0185, -0.0855, -0.2336,
          0.1859,  0.1528, -0.1824, -0.0834, -0.1414, -0.0526,  0.1744, -0.0290,
         -0.0753, -0.0100, -0.1702, -0.0676,  0.0856,  0.0493, -0.1256, -0.1652,
         -0.1317,  0.0677,  0.0209, -0.0346,  0.0048,  0.1209,  0.1959,  0.1520,
          0.0793, -0.1492,  0.3141,  0.1526, -0.1732, -0.0914,  0.1339, -0.1410,
         -0.0595, -0.0250, -0.1136, -0.1206, -0.1126, -0.0470,  0.1898,  0.0565,
         -0.2058,  0.0389,  0.0177, -0.2718,  0.2021, -0.0779,  0.1444,  0.1047,
         -0.2096, -0.0210,  0.1791, -0.4005, -0.1931,  0.1083,  0.2465,  0.1026,
         -0.0503,  0.1047,  0.0299, -0.1043,  0.0964,  0.0852, -0.2067,  0.1263,
          0.2064,  0.2248,  0.2739, -0.1881, -0.0745,  0.0769,  0.2994,  0.2803,
          0.0063,  0.2585, -0.0176,  0.2318, -0.0432,  0.1889, -0.0766,  0.0751,
         -0.0157,  0.0517,  0.1274, -0.2235, -0.0450,  0.1606,  0.0876,  0.1240,
          0.4417, -0.0625,  0.0591, -0.0181,  0.1996,  0.0959, -0.2623, -0.2826,
          0.0023,  0.1835,  0.1931, -0.1054,  0.1816, -0.1599, -0.0871,  0.0115,
          0.2386,  0.0161,  0.0580, -0.0558,  0.0963,  0.1206, -0.3461,  0.0726,
          0.0301,  0.1058,  0.0532,  0.0515,  0.0216, -0.0531, -0.0217,  0.0539,
         -0.0191,  0.0636, -0.1527, -0.1670,  0.0756,  0.0167, -0.0437,  0.0050,
         -0.1861,  0.0304,  0.2442, -0.0126, -0.2314,  0.1562,  0.1635, -0.1206,
         -0.0428,  0.1079,  0.1216, -0.0113,  0.1757, -0.0235,  0.2049, -0.3030,
          0.0067, -0.3157,  0.1435, -0.1737, -0.1698,  0.2276, -0.0360,  0.0048,
         -0.2974,  0.2021,  0.1380,  0.1129, -0.0626,  0.1347,  0.0729,  0.0481,
         -0.1397,  0.0197, -0.0932,  0.1717, -0.1519, -0.0554, -0.0344,  0.0201,
          0.1316,  0.0743, -0.1189,  0.2787,  0.0597, -0.2073, -0.3555, -0.0645,
         -0.1326,  0.1094,  0.1512,  0.0241,  0.0608,  0.0334,  0.1340, -0.0510,
         -0.0197,  0.0681, -0.1494,  0.2410, -0.1016,  0.1148, -0.1280,  0.0576,
         -0.0101,  0.1957, -0.2854,  0.0231, -0.0282, -0.0101,  0.0383, -0.0086,
         -0.1152,  0.1464,  0.0351, -0.2189,  0.2156,  0.0722,  0.0881,  0.2360,
          0.1182,  0.0676,  0.0506, -0.0226, -0.2842, -0.2781,  0.0791,  0.1958,
          0.0470, -0.0645,  0.0129,  0.2878,  0.1339, -0.2548, -0.0855, -0.1515,
          0.1442, -0.1694, -0.2864, -0.1720, -0.0473, -0.1325,  0.2390,  0.0736,
         -0.1975, -0.1141, -0.1409,  0.1901, -0.0673,  0.0887,  0.2132,  0.1869,
         -0.2635,  0.1148, -0.0100,  0.0344, -0.1273, -0.0943,  0.1698,  0.1511,
         -0.2248,  0.0495,  0.1384,  0.1260,  0.0787,  0.3357,  0.2288, -0.0885,
          0.0622, -0.2151,  0.0553,  0.0690,  0.2568, -0.2637, -0.0460,  0.0911,
         -0.0214,  0.1712,  0.0342, -0.0438, -0.1457, -0.0953,  0.0289, -0.0830,
         -0.0591, -0.0244,  0.1259,  0.0075, -0.1052,  0.1119,  0.0621, -0.2382,
          0.1763, -0.1132, -0.1201,  0.0162, -0.1807,  0.1419,  0.0749, -0.2318,
          0.0132,  0.2532, -0.1028,  0.0556, -0.0214,  0.1974, -0.1573,  0.1861,
         -0.1729, -0.0737, -0.1499, -0.1533,  0.0011,  0.0681, -0.1828,  0.0519,
         -0.0837,  0.0671, -0.1867,  0.0012,  0.0377,  0.1061, -0.1713, -0.1579]],
       grad_fn=<EmbeddingBackward0>)
print("Podobieństwo tego samego embeddingu (powinno wyjść 1)")
print(cos_sim(cat_embedding, cat_embedding))
Podobieństwo tego samego embeddingu (powinno wyjść 1)
tensor([1.0000], grad_fn=<SumBackward1>)
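
A possible sketch for Exercise 3 (one way among several): compare the chosen token's embedding against every row of the embedding matrix at once and keep the highest-scoring entries; torch.topk asks for k+1 results so the query token itself can be dropped. The function below assumes the input word maps to exactly one token.

def most_similar(word, k=15):
    token_id = tokenizer(word)["input_ids"][0]     # assumption: the word is a single token
    query = embedding_layer.weight[token_id].unsqueeze(0)
    # cosine similarity of the query row against all rows of the embedding matrix
    similarities = torch.nn.functional.cosine_similarity(query, embedding_layer.weight, dim=1)
    top = torch.topk(similarities, k + 1)          # +1 because the token itself scores 1.0
    results = []
    for value, idx in zip(top.values.tolist(), top.indices.tolist()):
        if idx != token_id:
            results.append((tokenizer.decode([idx]), round(value, 4)))
    return results[:k]

print(most_similar("cat"))
print(most_similar("tree"))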