{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zajęcia 8: Vector representations\n",
    "\n",
    "Wszystkie zadania ćwiczeniowe należy rozwiązywać w języku Python w kopii Jupyter Notebook'a dla danych zajęć w wyznaczonych miejscach (komórki z komentarzem `# Solution`).\n",
    "\n",
    "Nie należy usuwać komórek z treścią zadań.\n",
    "\n",
    "Należy wyświetlać outputy przy pomocy `print`\n",
    "\n",
    "## Dla chętnych! (może się przydać od ambitniejszych projektów końcowych)\n",
    "\n",
    "https://github.com/huggingface/smol-course - kurs finetune'owania LLMów do własnych zadań\n",
    "\n",
    "https://github.com/unslothai/unsloth - biblioteka do efektywnego finetune'owania LLMów (są gotowe notebooki z kodem na platformie Colab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Co to jest wektor?\n",
    "\n",
    "Wektor - jednowymiarowa macierz\n",
    "\n",
    "[0, 1, 0, 0, 0] - one hot encoding - tylko wartości 0/1\n",
    "\n",
    "[0, 2, 0, 5, 1, 100] - frequency encoding - liczby całkowite >= 0\n",
    "\n",
    "[-1.5, 0.0002, 5000.01] - wektor"
   ]
  },
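  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of both encodings for a single sentence, assuming a tiny hand-made vocabulary and whitespace tokenization (the vocabulary and the sentence are illustrative only, not taken from the course data). Here one-hot encoding is applied at the sentence level: position *i* is 1 if the *i*-th vocabulary word occurs in the sentence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Hypothetical toy vocabulary: word -> position in the vector\n",
    "vocab = {\"free\": 0, \"call\": 1, \"now\": 2, \"win\": 3, \"the\": 4}\n",
    "sentence = \"call now call now win\"\n",
    "tokens = sentence.split()\n",
    "\n",
    "# One-hot (binary) encoding: 1 if the word occurs in the sentence, otherwise 0\n",
    "one_hot = [0] * len(vocab)\n",
    "for token in tokens:\n",
    "    if token in vocab:\n",
    "        one_hot[vocab[token]] = 1\n",
    "\n",
    "# Frequency encoding: number of occurrences of each vocabulary word\n",
    "counts = Counter(tokens)\n",
    "frequency = [0] * len(vocab)\n",
    "for word, index in vocab.items():\n",
    "    frequency[index] = counts[word]\n",
    "\n",
    "print(one_hot)    # [0, 1, 1, 1, 0]\n",
    "print(frequency)  # [0, 2, 2, 1, 0]"
   ]
  },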
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zadanie 1\n",
    "\n",
    "Dokonaj preprocessingu tekstów https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/in.tsv (preprocessing - proces wstępnej \"obróbki tekstów - sprowadzenie do małych liter, tokenizacja itd..) i dokonaj:\n",
    "* a) one hot encodingu\n",
    "* b) frequency encodingu\n",
    "\n",
    "dla dwóch przykładowych zdań, które zawierają przynajmniej 2 wystąpienia słowa, które znajduje się w słowniku (czyli występuje w korpusie z pliku in.tsv). Ze względu na dużą liczbę unikalnych słów w korpusie, proszę nie printować całych wektorów, tylko indeksy oraz wartości."
   ]
  },
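  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal preprocessing sketch, assuming that lowercasing followed by a simple regular-expression tokenizer (runs of ASCII letters, digits and apostrophes) is sufficient; the two sample texts are made up, and reading in.tsv is not shown. The hypothetical `build_vocab` helper gives every unique token an index, and these indices are the vector positions used by the encodings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "\n",
    "def preprocess(text):\n",
    "    # Lowercase, then keep runs of ASCII letters, digits and apostrophes as tokens\n",
    "    return re.findall(r\"[a-z0-9']+\", text.lower())\n",
    "\n",
    "\n",
    "def build_vocab(texts):\n",
    "    # Assign consecutive indices to unique tokens in order of first appearance\n",
    "    vocab = {}\n",
    "    for text in texts:\n",
    "        for token in preprocess(text):\n",
    "            if token not in vocab:\n",
    "                vocab[token] = len(vocab)\n",
    "    return vocab\n",
    "\n",
    "\n",
    "corpus = [\"FREE entry! Call now.\", \"See you at the meeting, call me later\"]\n",
    "vocab = build_vocab(corpus)\n",
    "print(preprocess(corpus[0]))  # ['free', 'entry', 'call', 'now']\n",
    "print(len(vocab))             # 11"
   ]
  },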
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zadanie 2\n",
    "\n",
    "Na podstawie pliku https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/in.tsv oraz pliku https://git.wmi.amu.edu.pl/ryssta/spam-classification/src/branch/master/train/expected.tsv podziel teksty względem klasy spam/nie spam. Oblicz wartość IDF osobno dla tekstów klasy spam oraz dla tekstów klasy nie spam, dla słów:\n",
    "* free\n",
    "* send\n",
    "* are\n",
    "* the\n",
    "\n",
    "oraz 2-3 własnoręcznie wybranych słów."
   ]
  },
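  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common definition of IDF (variants with smoothing or an added 1 also exist) is\n",
    "\n",
    "$$\\mathrm{idf}(t) = \\log \\frac{N}{\\mathrm{df}(t)}$$\n",
    "\n",
    "where $N$ is the number of documents in the collection (here: all spam texts, or all non-spam texts) and $\\mathrm{df}(t)$ is the number of those documents that contain the word $t$. The sketch below assumes this unsmoothed variant with the natural logarithm and two made-up, already tokenized mini-collections. For a word that does not occur in a collection, $\\mathrm{df}(t) = 0$ and this variant is undefined, so the hypothetical `idf` helper returns `None` in that case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def idf(word, documents):\n",
    "    # documents: list of token lists; df = number of documents containing the word\n",
    "    df = sum(1 for tokens in documents if word in tokens)\n",
    "    if df == 0:\n",
    "        return None  # undefined for the unsmoothed variant\n",
    "    return math.log(len(documents) / df)\n",
    "\n",
    "\n",
    "# Made-up mini-collections (not the course data), tokenized by whitespace\n",
    "spam_docs = [t.split() for t in [\"win free cash now\", \"free entry send win\", \"call now\"]]\n",
    "ham_docs = [t.split() for t in [\"are you at home\", \"send me the file\", \"the bus is late\", \"free now\"]]\n",
    "\n",
    "for word in [\"free\", \"send\", \"are\", \"the\"]:\n",
    "    print(word, idf(word, spam_docs), idf(word, ham_docs))"
   ]
  },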
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zadanie 3\n",
    "\n",
    "Na podstawie warstwy embedding modelu gpt2 wypisz 15 najbardziej podobnych (względem miary podobieństwa cosinuowego) tokenów do słów:\n",
    "* cat\n",
    "* tree\n",
    "\n",
    "oraz 2 własnoręcznie wybranych tokenów."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\ryssta\\AppData\\Local\\anaconda3\\Lib\\site-packages\\transformers\\tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import GPT2Tokenizer, GPT2Model\n",
    "import torch\n",
    "\n",
    "\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
    "model = GPT2Model.from_pretrained('gpt2')\n",
    "embedding_layer = model.wte\n",
    "cos_sim = torch.nn.CosineSimilarity()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[\n",
    "    [0.1, 0.2, 0.3], # Ala\n",
    "    [-0.5, 0.5, 0.9], # ma\n",
    "    ...\n",
    "    # 50254\n",
    "    ...\n",
    "    [0.1, -0.1, -0.2] # w GPT2 jest 768 wartości w pojedynczym wektorze, a nie 3\n",
    "]"
   ]
  },
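  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An embedding layer stores a matrix like the one sketched above, and looking up a token id simply returns the corresponding row of its weight matrix. A minimal check on a small, randomly initialized `torch.nn.Embedding` (5 tokens, 3 values per vector; the sizes are illustrative, while in GPT-2 `model.wte.weight` has shape 50257 x 768):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Toy embedding layer: a vocabulary of 5 tokens, vectors with 3 values each\n",
    "embedding = torch.nn.Embedding(num_embeddings=5, embedding_dim=3)\n",
    "\n",
    "token_ids = torch.tensor([2, 4])\n",
    "vectors = embedding(token_ids)  # one row per token id\n",
    "\n",
    "print(embedding.weight.shape)  # torch.Size([5, 3])\n",
    "print(vectors.shape)           # torch.Size([2, 3])\n",
    "# The lookup is plain row indexing into the weight matrix\n",
    "print(torch.equal(vectors, embedding.weight[token_ids]))  # True"
   ]
  },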
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tekst 'cat' jest konwertowany do tokenu 9246\n",
      "\n",
      "Tokenizacja\n",
      "{'input_ids': [9246], 'attention_mask': [1]}\n",
      "\n",
      "Detokenizacja\n",
      "cat\n",
      "\n",
      "Liczba tokenów w słowniku\n",
      "50257\n"
     ]
    }
   ],
   "source": [
    "print(\"Tekst 'cat' jest konwertowany do tokenu 9246\")\n",
    "print(\"\\nTokenizacja\")\n",
    "print(tokenizer(\"cat\"))\n",
    "print(\"\\nDetokenizacja\")\n",
    "print(tokenizer.decode([9246]))\n",
    "print(\"\\nLiczba tokenów w słowniku\")\n",
    "print(len(tokenizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding tokenu: 9246\n",
      "\n",
      "Rozmiar embeddingu (wektora)\n",
      "torch.Size([1, 768])\n",
      "\n",
      "Wartości embeddingu\n",
      "tensor([[-0.0164, -0.0934,  0.2425,  0.1398,  0.0388, -0.2592, -0.2724, -0.1625,\n",
      "          0.1683,  0.0829,  0.0136, -0.2788,  0.1493,  0.1408,  0.0557, -0.3691,\n",
      "          0.2200, -0.0428,  0.2206,  0.0865,  0.1237, -0.1499,  0.1446, -0.1150,\n",
      "         -0.1425, -0.0715, -0.0526,  0.1550, -0.0678, -0.2059,  0.2065, -0.0297,\n",
      "          0.0834, -0.0483,  0.1207,  0.1975, -0.3193,  0.0124,  0.1067, -0.0473,\n",
      "         -0.3037,  0.1139,  0.0949, -0.2175,  0.0796, -0.0941, -0.0394, -0.0704,\n",
      "          0.2033, -0.1555,  0.2928, -0.0770,  0.0787,  0.1214,  0.1528, -0.1464,\n",
      "          0.4247,  0.1921, -0.0415, -0.0850, -0.2787,  0.0656, -0.2026,  0.1856,\n",
      "          0.1353, -0.0820, -0.0639,  0.0701,  0.1680,  0.0597,  0.3265, -0.1100,\n",
      "          0.1056,  0.1845, -0.1156,  0.0054,  0.0663,  0.1842, -0.1069,  0.0491,\n",
      "         -0.0853, -0.2519,  0.0031,  0.1805,  0.1505,  0.0442, -0.2427,  0.1104,\n",
      "          0.0970,  0.1123, -0.1519, -0.1444,  0.2323, -0.0241, -0.0677,  0.1157,\n",
      "         -0.2668, -0.1229,  0.1120,  0.0601, -0.0535,  0.1259, -0.0966, -0.1975,\n",
      "         -0.2031,  0.1323,  0.0176, -0.1332,  0.1159,  0.1037,  0.0722,  0.1644,\n",
      "         -0.0775, -0.0227,  0.1146,  0.0060,  0.3959, -0.0828,  0.0125,  0.0415,\n",
      "         -0.0147, -0.1352, -0.0579,  0.0423, -0.0793, -0.2702,  0.2806, -0.0744,\n",
      "          0.1118,  0.0908,  0.0639, -0.0882, -0.0190, -0.1386, -0.0490, -0.1785,\n",
      "          0.1416,  0.0497, -0.0461, -0.1544,  0.0662, -0.0538,  0.0992,  0.1308,\n",
      "         -0.0885, -0.2840, -0.0297, -0.0882, -0.0340, -0.1495,  0.0295,  0.0700,\n",
      "          0.0661, -0.1282, -0.0546, -0.1392, -0.1368,  0.0353,  0.0814, -0.1436,\n",
      "         -0.0559,  0.1523, -0.0780,  0.2562,  0.0164, -0.0433,  0.0468,  0.2896,\n",
      "          0.0069,  0.2136,  0.0378, -0.1625, -0.0421, -0.0109, -0.0386,  0.0453,\n",
      "          0.2572,  0.0323, -0.1206,  0.0135, -0.0171,  0.0404, -0.2544,  0.1455,\n",
      "          0.3265, -0.0545,  0.0887, -0.0321,  0.1485, -0.0699, -0.0606,  0.2177,\n",
      "         -0.1566, -0.0619, -0.2655,  0.2471, -0.1213, -0.0741,  0.1074, -0.0243,\n",
      "         -0.2747,  0.1828,  0.0046,  0.1426,  0.0201,  0.0413, -0.0189, -0.0070,\n",
      "         -0.0338, -0.1384,  0.0269,  0.1447, -0.0216, -0.0042, -0.0363, -0.0579,\n",
      "         -0.0909, -0.1359,  0.1407,  0.1421,  0.0041,  0.0100,  0.0997, -0.0718,\n",
      "         -0.0958,  0.0051, -0.2576,  0.1980,  0.1545,  0.0226, -0.2521, -0.1091,\n",
      "         -0.1467, -0.1140,  0.1383, -0.1952,  0.0554, -0.0036, -0.2915, -0.1645,\n",
      "          0.0469, -0.2251,  0.2000,  0.1070,  0.1651, -0.0781,  0.1511, -0.0095,\n",
      "          0.0925,  0.0776,  0.1631,  0.1563,  0.0286, -0.1157,  0.0349,  0.0033,\n",
      "         -0.0870, -0.0864,  0.1233, -0.0691, -0.0458, -0.0601,  0.0501, -0.1450,\n",
      "         -0.2425, -0.0773,  0.1182,  0.1351,  0.1904,  0.1746,  0.0925, -0.1253,\n",
      "         -0.1149, -0.1312, -0.2170,  0.0682, -0.0121, -0.1774, -0.0857, -0.1906,\n",
      "          0.2842, -0.0410,  0.0530,  0.0480, -0.0641, -0.0911,  0.2907, -0.2503,\n",
      "         -0.1085,  0.1753,  0.0610,  0.0466,  0.0097, -0.1300, -0.0273,  0.0498,\n",
      "         -0.0619, -0.1867, -0.0769, -0.1091, -0.0410, -0.0617, -0.0537,  0.0582,\n",
      "         -0.0986, -0.2655,  0.1236, -0.0026,  0.0444, -0.1018, -0.1652,  0.0174,\n",
      "         -0.2561,  0.0440,  0.2048,  0.0049, -0.0220, -0.1031, -0.1387,  0.0493,\n",
      "          0.2048,  0.0473,  0.1630, -0.0195, -0.0714, -0.0713, -0.2524,  0.0852,\n",
      "         -0.1682,  0.0713,  0.1301, -0.1088, -0.0188,  0.0092,  0.0078,  0.2213,\n",
      "          0.0638, -0.1617, -0.0365, -0.0923, -0.1052,  0.1108, -0.1175, -0.0016,\n",
      "         -0.0258,  0.0902,  0.1089,  0.1685, -0.2664, -0.0309, -0.0187, -0.0678,\n",
      "         -0.1424, -0.0026, -0.0623, -0.0575, -0.1009,  0.0142, -0.1950,  0.0085,\n",
      "         -0.1402,  0.0371, -0.4072, -0.0478,  0.4013,  0.3212,  0.1051,  0.0349,\n",
      "         -0.1302, -0.0298, -0.1738,  0.0692, -0.2638,  0.1268,  0.1773, -0.1094,\n",
      "          0.0737, -0.0460,  0.1870, -0.0605, -0.1308, -0.0920, -0.0290, -0.0542,\n",
      "          0.1214, -0.0308, -0.1173, -0.2127,  0.0209,  0.2911, -0.1751,  0.0469,\n",
      "          0.0740, -0.1323,  0.0283,  0.2125,  0.1870,  0.0978,  0.1799,  0.2669,\n",
      "          0.1709, -0.1191, -0.2022,  0.0445,  0.0601, -0.1820,  0.0224,  0.1902,\n",
      "         -0.3199, -0.2551,  0.0795,  0.0814,  0.1245,  0.0871, -0.0455, -0.2342,\n",
      "          0.1167,  0.0870,  0.0257, -0.2073,  0.1849, -0.0184,  0.0498, -0.1423,\n",
      "         -0.0682,  0.1386,  0.0406, -0.0325,  0.2179, -0.0567, -0.2568, -0.0935,\n",
      "         -0.0453, -0.1317,  0.0682, -0.2721, -0.2026, -0.0565, -0.0134, -0.0423,\n",
      "         -0.0415, -0.0560, -0.1522,  0.1617, -0.0753, -0.1967,  0.0536, -0.0988,\n",
      "          0.0539, -0.0489,  0.0296,  0.1272,  0.1567,  0.0185, -0.0855, -0.2336,\n",
      "          0.1859,  0.1528, -0.1824, -0.0834, -0.1414, -0.0526,  0.1744, -0.0290,\n",
      "         -0.0753, -0.0100, -0.1702, -0.0676,  0.0856,  0.0493, -0.1256, -0.1652,\n",
      "         -0.1317,  0.0677,  0.0209, -0.0346,  0.0048,  0.1209,  0.1959,  0.1520,\n",
      "          0.0793, -0.1492,  0.3141,  0.1526, -0.1732, -0.0914,  0.1339, -0.1410,\n",
      "         -0.0595, -0.0250, -0.1136, -0.1206, -0.1126, -0.0470,  0.1898,  0.0565,\n",
      "         -0.2058,  0.0389,  0.0177, -0.2718,  0.2021, -0.0779,  0.1444,  0.1047,\n",
      "         -0.2096, -0.0210,  0.1791, -0.4005, -0.1931,  0.1083,  0.2465,  0.1026,\n",
      "         -0.0503,  0.1047,  0.0299, -0.1043,  0.0964,  0.0852, -0.2067,  0.1263,\n",
      "          0.2064,  0.2248,  0.2739, -0.1881, -0.0745,  0.0769,  0.2994,  0.2803,\n",
      "          0.0063,  0.2585, -0.0176,  0.2318, -0.0432,  0.1889, -0.0766,  0.0751,\n",
      "         -0.0157,  0.0517,  0.1274, -0.2235, -0.0450,  0.1606,  0.0876,  0.1240,\n",
      "          0.4417, -0.0625,  0.0591, -0.0181,  0.1996,  0.0959, -0.2623, -0.2826,\n",
      "          0.0023,  0.1835,  0.1931, -0.1054,  0.1816, -0.1599, -0.0871,  0.0115,\n",
      "          0.2386,  0.0161,  0.0580, -0.0558,  0.0963,  0.1206, -0.3461,  0.0726,\n",
      "          0.0301,  0.1058,  0.0532,  0.0515,  0.0216, -0.0531, -0.0217,  0.0539,\n",
      "         -0.0191,  0.0636, -0.1527, -0.1670,  0.0756,  0.0167, -0.0437,  0.0050,\n",
      "         -0.1861,  0.0304,  0.2442, -0.0126, -0.2314,  0.1562,  0.1635, -0.1206,\n",
      "         -0.0428,  0.1079,  0.1216, -0.0113,  0.1757, -0.0235,  0.2049, -0.3030,\n",
      "          0.0067, -0.3157,  0.1435, -0.1737, -0.1698,  0.2276, -0.0360,  0.0048,\n",
      "         -0.2974,  0.2021,  0.1380,  0.1129, -0.0626,  0.1347,  0.0729,  0.0481,\n",
      "         -0.1397,  0.0197, -0.0932,  0.1717, -0.1519, -0.0554, -0.0344,  0.0201,\n",
      "          0.1316,  0.0743, -0.1189,  0.2787,  0.0597, -0.2073, -0.3555, -0.0645,\n",
      "         -0.1326,  0.1094,  0.1512,  0.0241,  0.0608,  0.0334,  0.1340, -0.0510,\n",
      "         -0.0197,  0.0681, -0.1494,  0.2410, -0.1016,  0.1148, -0.1280,  0.0576,\n",
      "         -0.0101,  0.1957, -0.2854,  0.0231, -0.0282, -0.0101,  0.0383, -0.0086,\n",
      "         -0.1152,  0.1464,  0.0351, -0.2189,  0.2156,  0.0722,  0.0881,  0.2360,\n",
      "          0.1182,  0.0676,  0.0506, -0.0226, -0.2842, -0.2781,  0.0791,  0.1958,\n",
      "          0.0470, -0.0645,  0.0129,  0.2878,  0.1339, -0.2548, -0.0855, -0.1515,\n",
      "          0.1442, -0.1694, -0.2864, -0.1720, -0.0473, -0.1325,  0.2390,  0.0736,\n",
      "         -0.1975, -0.1141, -0.1409,  0.1901, -0.0673,  0.0887,  0.2132,  0.1869,\n",
      "         -0.2635,  0.1148, -0.0100,  0.0344, -0.1273, -0.0943,  0.1698,  0.1511,\n",
      "         -0.2248,  0.0495,  0.1384,  0.1260,  0.0787,  0.3357,  0.2288, -0.0885,\n",
      "          0.0622, -0.2151,  0.0553,  0.0690,  0.2568, -0.2637, -0.0460,  0.0911,\n",
      "         -0.0214,  0.1712,  0.0342, -0.0438, -0.1457, -0.0953,  0.0289, -0.0830,\n",
      "         -0.0591, -0.0244,  0.1259,  0.0075, -0.1052,  0.1119,  0.0621, -0.2382,\n",
      "          0.1763, -0.1132, -0.1201,  0.0162, -0.1807,  0.1419,  0.0749, -0.2318,\n",
      "          0.0132,  0.2532, -0.1028,  0.0556, -0.0214,  0.1974, -0.1573,  0.1861,\n",
      "         -0.1729, -0.0737, -0.1499, -0.1533,  0.0011,  0.0681, -0.1828,  0.0519,\n",
      "         -0.0837,  0.0671, -0.1867,  0.0012,  0.0377,  0.1061, -0.1713, -0.1579]],\n",
      "       grad_fn=<EmbeddingBackward0>)\n"
     ]
    }
   ],
   "source": [
    "print(\"Embedding tokenu: 9246\")\n",
    "cat_embedding = embedding_layer(torch.LongTensor([9246]))\n",
    "print(\"\\nRozmiar embeddingu (wektora)\")\n",
    "print(cat_embedding.shape)\n",
    "print(\"\\nWartości embeddingu\")\n",
    "print(cat_embedding)"
   ]
  },
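  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparing one embedding with every row of the embedding matrix at once gives a ranking of the most similar tokens. A minimal sketch, assuming a hypothetical helper `most_similar` that L2-normalizes the rows and takes dot products (which equals cosine similarity), demonstrated on a small random matrix so that it runs without downloading a model. With the objects loaded earlier it could be called as `most_similar(embedding_layer.weight, 9246, k=15)`, mapping each returned id back to text with `tokenizer.decode([token_id])`. The query token itself is excluded from the ranking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def most_similar(weight, token_id, k=15):\n",
    "    # Return (row index, cosine similarity) for the k rows most similar to row `token_id`\n",
    "    with torch.no_grad():\n",
    "        normalized = torch.nn.functional.normalize(weight, dim=1)  # unit-length rows\n",
    "        similarities = normalized @ normalized[token_id]  # one cosine similarity per row\n",
    "        similarities[token_id] = -2.0  # exclude the query itself (cosine is always >= -1)\n",
    "        values, indices = torch.topk(similarities, k)\n",
    "    return list(zip(indices.tolist(), values.tolist()))\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "toy_weight = torch.randn(100, 8)  # stand-in for a (vocab_size, dim) embedding matrix\n",
    "for token_id, similarity in most_similar(toy_weight, token_id=7, k=5):\n",
    "    print(token_id, round(similarity, 4))"
   ]
  },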
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Podobieństwo tego samego embeddingu (powinno wyjść 1)\n",
      "tensor([1.0000], grad_fn=<SumBackward1>)\n"
     ]
    }
   ],
   "source": [
    "print(\"Podobieństwo tego samego embeddingu (powinno wyjść 1)\")\n",
    "print(cos_sim(cat_embedding, cat_embedding))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}