{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Neurozoo\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sigmoid function\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sigmoid function maps any value (a “signal”) to a value in the interval $(0,1)$, i.e. a value that can be interpreted as a probability.\n",
"\n",
"$$\\sigma(x) = \\frac{1}{1 + e^{-x}}$$\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6457)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"\n",
"def sigmoid(x):\n",
"    return 1 / (1 + torch.exp(-x))\n",
"\n",
"sigmoid(torch.tensor(0.6))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'sigmoid.png'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhNUlEQVR4nO3deXyV5Z3+8c83+04gCVsIqyAguwHcWm2tDijuti5YFW39aWvHWtuqHaszdanLtFNn1PKjFpda12opVQpq6zYubMoWQiCGJWHLRvY9554/EpkMBgTMk+cs1/v1Oq/knOdJvI7Gcz3rfZtzDhERiVxRfgcQERF/qQhERCKcikBEJMKpCEREIpyKQEQkwsX4HeBIZWZmuuHDh/sdQ0QkpKxevbrcOZfV3bKQK4Lhw4ezatUqv2OIiIQUM9t+sGU6NCQiEuFUBCIiEU5FICIS4VQEIiIRzrMiMLOFZlZqZhsOstzM7D/NrNDM1pnZNK+yiIjIwXm5R/AkMOsQy2cDozsf1wG/9TCLiIgchGdF4Jx7F6g8xCrnAU+7Dh8B6WY2yKs8IiLSPT/vI8gGirs8L+l8bfeBK5rZdXTsNTB06NBeCSci0hucczS3BahpaqWuqY265rb9Xxta2qlvaaOhuePr8cP68pXR3d4T9qX4WQTWzWvdTo7gnFsALADIzc3VBAoiEpQCAce+hhbK61qoqG+moq6FfQ0tVNa3UNXQyr6Gjq/VjR2PmsZWapvaaGkPHNbvv+G0UWFXBCVATpfnQ4BdPmURETmkptZ2dlY1squqkd3VTeyuamJPTROlNU3srW2itKaZivoW2gPdb6umJcSQnhRH36RY0hJjGdI3kT6JHd+nJsSQmhBLWkIMyXExpCTEkBIfQ3J8DMlx0STFx5AYG010VHfbz1+en0WwGLjRzJ4HZgLVzrnPHRYSEekt1Y2tbC2vp6isjm0VDeyoqGd7ZQPFlY2U1zV/bv2M5DgGpCUwIC2e8YPSyEqNJyslnoyUeDJS4shMiadfchzpibHERAfv1fqeFYGZPQecBmSaWQlwFxAL4JybDywBzgIKgQZgnldZRES6qmlqpWBPLZt217B5bx1bSmspLK2jvK5l/zpRBoP6JDIsI4nTx/Ynu28iQ/omMjg9kcF9EumfFk9CbLSP76LneFYEzrnLvmC5A77v1T9fRAQ6tvLXlVSxrqSaDTurWb+zmpJ9jfuXpybEMLp/Cl8f259j+qcwIjOFEZnJDO2XRFxM8G7F96SQG31URORgnHNsr2hgxdZKVm6r5JPiKgpL6/YvH5aRxOScdC6fOZRxA9MYOyiVgWkJmHlz7D1UqAhEJKTtqmrkvwvL+aCwnA8+raC0tuNYft+kWKYN7cv5UwYzJacvE4f0oU9irM9pg5OKQERCSmt7gJXbKnm7oIy3NpWypXOLPzMljhNHZXLCyH7MGN6PY/qnRPyW/uFSEYhI0GtqbeftgjJez9vD3zeVUt3YSmy0MXNEBpdMz+GU0ZkcOyBVH/xHSUUgIkGptT3Af28pZ/HaXbyet4f6lnb6JMZy+rj+nDl+IKeMziQlXh9hPUH/FkUkqBTsqeWlVcUsWrOT8roW0hJimDNpMOdMHszMkf2IDeLr8UOVikBEfNfU2s5r63bzh4+2s6a4ipgo4/Rx/bn4+BxOHZMVMZdx+kVFICK+2VPdxFMfbuP5FTvY19DKyKxk7jh7HBdMzSYjJd7veBFDRSAivW7TnhoWvFvEX9fuoj3gOHP8QK48cRgnjsrQCV8fqAhEpNds2FnNf/1jC8vy9pIUF83cmcO49pQR5PRL8jtaRFMRiIjnCvbU8tCyAt7M30tqQgw3nT6aeScPJz0pzu9ogopARDxUsq+BX7+xmT9/spOUuBh+dMYYrj55OGkJusM3mKgIRKTHNbS08dhbn7LgvSIAvvuVkdxw6ij6JmsPIBipCESkxzjnWLx2F79csok9NU2cP2UwP5k1lu
z0RL+jySGoCESkR2yvqOeORRt4b0s5k4b04dG5Uzl+WD+/Y8lhUBGIyJfS1h5gwXtFPPzmFuKio7j7/AnMnTGUKI+mVZSepyIQkaNWWFrLLS+uZW1JNbMnDORfzz2OAWkJfseSI6QiEJEjFgg4Fr6/lQeXFZAcF81jc6dx1sRBfseSo6QiEJEjUlrbxC0vruW9LeV8Y9wAfnnhRLJSNRxEKFMRiMhhe3dzGT96cQ21TW3cd8FELpuRoyEhwoCKQES+UCDgePjvW3j471sYMyCFP37nBI4dmOp3LOkhKgIROaTqxlZufmEN/9hUykXThnDP+RNIjIv2O5b0IBWBiBzUlr21fOfpVezc18jd5x3HFScM06GgMKQiEJFuvbeljO898zHxsdE8f90J5A7XzWHhSkUgIp/zzEfbuWtxHqP7p/D7q6driIgwpyIQkf2cczywtID573zKacdm8V+XTSVVI4WGPRWBiAAdQ0Xc9sp6/rS6hMtnDuUX5x5HjCaKjwgqAhGhsaWdG5/9mL9vKuWm00fzw2+M1knhCKIiEIlw9c1tXPPkSlZsq+Tu8yfw7ROG+R1JepmKQCSC1TS1Mu+JlawpruI3l0zhvCnZfkcSH6gIRCJUdUMrVz6xgryd1Txy2VRma9C4iOXpmSAzm2VmBWZWaGa3dbO8j5n91czWmlmemc3zMo+IdKhpauXKhcvJ31XDb684XiUQ4TwrAjOLBh4FZgPjgcvMbPwBq30f2OicmwycBvzKzDSpqYiH6pvbmPfESvJ21fDY3GmcMX6A35HEZ17uEcwACp1zRc65FuB54LwD1nFAqnVcnpACVAJtHmYSiWhNre1856lVfLJjH/952VS+oRIQvC2CbKC4y/OSzte6egQYB+wC1gM3OecCB/4iM7vOzFaZ2aqysjKv8oqEtdb2ADc8s5qPtlbw629N0UQysp+XRdDdRcjugOf/BKwBBgNTgEfMLO1zP+TcAudcrnMuNysrq6dzioS9QMBx65/W8VZBGfeeP5Hzp+rqIPlfXhZBCZDT5fkQOrb8u5oHvOI6FAJbgbEeZhKJSPcv3cQrn+zkljPGcPnMoX7HkSDjZRGsBEab2YjOE8CXAosPWGcHcDqAmQ0AjgWKPMwkEnF+924RC94t4qoTh3Hj14/xO44EIc/uI3DOtZnZjcAyIBpY6JzLM7PrO5fPB+4GnjSz9XQcSrrVOVfuVSaRSLNk/W7uXZLP2RMHcdc5x2nYCOmWpzeUOeeWAEsOeG1+l+93AWd6mUEkUn28Yx83v7CG44f15VffmkxUlEpAuqehBUXCUHFlA999ahUD0hJY8O3jSYjV1JJycCoCkTBT29TKNU+upC3geGLedDJS4v2OJEFOYw2JhJFAwHHzC2soKq/nD9fMYFRWit+RJARoj0AkjPz6jc28mV/Kz88ex0nHZPodR0KEikAkTLy6bhePvFXIJbk5XHXScL/jSAhREYiEgYI9tfzkpXUcP6wvvzhfl4nKkVERiIS42qZWbnhmNcnxMfx27jTiY3SFkBwZnSwWCWHOOX7y0jq2Vzbw7Hdm0j8twe9IEoK0RyASwn73XhFL8/Zw26yxzByZ4XccCVEqApEQtXp7JQ8sLWD2hIF85ysj/I4jIUxFIBKC9tW38INnPyE7PZEHLp6kk8PypegcgUiIcc7xkz+tpayumZdvOIm0hFi/I0mI0x6BSIhZ+P423swv5fbZ45g0JN3vOBIGVAQiIWTDzmru/1s+3xg3gHknD/c7joQJFYFIiGhsaeefn/+EfslxPKTzAtKDdI5AJETc89pGisrq+eN3ZtI3Oc7vOBJGtEcgEgJez9vDH5fv4LqvjuRkDSYnPUxFIBLkSmubuPXldRw3OI1bzhzjdxwJQyoCkSDmnOP2l9fT0NLOw5dO0ThC4gkVgUgQe3FVMX/fVMqts8ZyTP9Uv+NImFIRiASp4soGfvHXjZw4MoOrNb+AeEhFIBKEAgHHLS+tJcqMf//WZK
KidKmoeEdFIBKEnvhgGyu2VnLnOePJTk/0O46EORWBSJDZWl7PQ8s2cfrY/lx8/BC/40gEUBGIBJFAwPHTP60lLjqK+y6cqLuHpVeoCESCyFMfbmPltn3cec5xDNBsY9JLVAQiQWJ7RT0PLN3E147N4qJp2X7HkQiiIhAJAs45bnt5PbFROiQkvU9FIBIEXlhZzIdFFdx+1jgG9dFVQtK7VAQiPttb08S9S/I5YWQ/Lp2e43cciUAqAhEfOee4Y9EGWtoC3H/hJN04Jr7wtAjMbJaZFZhZoZnddpB1TjOzNWaWZ2bveJlHJNj8bcMe3ti4l1vOHMPwzGS/40iE8mxiGjOLBh4FzgBKgJVmttg5t7HLOunAY8As59wOM+vvVR6RYFPd2Mpdi/OYkJ3GNSeP8DuORDAv9whmAIXOuSLnXAvwPHDeAetcDrzinNsB4Jwr9TCPSFB5YOkmKuqauf/CScRE6yit+MfLv75soLjL85LO17oaA/Q1s7fNbLWZXdndLzKz68xslZmtKisr8yiuSO9Zua2SZ5fv4NpTRjAhu4/fcSTCeVkE3Z31cgc8jwGOB84G/gn4uZl9bgom59wC51yucy43Kyur55OK9KLmtnZuf2U92emJ3HyGZhwT/3k5eX0J0PVauCHArm7WKXfO1QP1ZvYuMBnY7GEuEV/9/3eKKCyt44l500mK8/J/QZHD4+UewUpgtJmNMLM44FJg8QHr/AX4ipnFmFkSMBPI9zCTiK+2ltfzyFuFzJk0iK8dq2sjJDh4tjninGszsxuBZUA0sNA5l2dm13cun++cyzezpcA6IAA87pzb4FUmET913DOwnvjoKO6cM97vOCL7ebpf6pxbAiw54LX5Bzx/CHjIyxwiweAva3bxfmEFd58/gf4aWVSCiK5ZE+kFVQ0t3PPaRqbkpDN3xlC/44j8HzpTJdILHlhawL6GVp6+ZqKGkZCgoz0CEY+t3r6P51bsYN5Jwxk/OM3vOCKfoyIQ8VBbe4A7Fm1gYFoCP9Q9AxKkVAQiHnryg23k767hrnPGkxKvI7ESnFQEIh7ZXd3If7yxmdOOzWLWhIF+xxE5KBWBiEfufnUjbQHHL86doKknJaipCEQ88HZBKUvW7+HGrx3D0Iwkv+OIHJKKQKSHNbW2c9fiPEZkJnPdqSP9jiPyhXT2SqSHzX/nU7ZXNPDMtTOJj4n2O47IF9IegUgP2lZez2Nvf8o5kwdzyuhMv+OIHBYVgUgPcc5x1+I84qKjuOPscX7HETlsKgKRHrIsbw/vbC7j5jPGMECDykkI+cIiMLMbzaxvb4QRCVX1zW382183MnZgKledOMzvOCJH5HD2CAYCK83sRTObZbogWuRz/vMfW9hd3cS9F0zQRPQScr7wL9Y5dwcwGvg9cDWwxczuM7NRHmcTCQmb99by+/e28q3cIRw/rJ/fcUSO2GFtujjnHLCn89EG9AX+ZGYPephNJOg55/j5og0kx8dw66yxfscROSpfeB+Bmf0zcBVQDjwO/MQ512pmUcAW4KfeRhQJXovW7GT51kruu2AiGSnxfscROSqHc0NZJnChc2571xedcwEzm+NNLJHgV93Qyr2v5TM5J51Lp+f4HUfkqH1hETjn7jzEsvyejSMSOv799QIq61t4ct4MzTomIU2XN4gchXUlVTyzfDtXnjicCdl9/I4j8qWoCESOUHvAcceiDWQkx/OjMzXrmIQ+FYHIEXp2+XbWlVTz8znjSEuI9TuOyJemIhA5AqW1TTy4rICTj8ng3MmD/Y4j0iNUBCJH4L7X8mluDXD3eZp1TMKHikDkMH1QWM6iNbu4/rRRjMxK8TuOSI9REYgchua2du74ywaGZSTxvdM0uoqEF81QJnIY5r9dRFFZPU9fM4OEWM06JuFFewQiX6CorI5H3y7knMmD+eqYLL/jiPQ4FYHIITjXcc9AfEwUP5+jWcckPKkIRA5h0ZqdfPBpBT+dNZb+qZp1TMKTp0XQOZFNgZ
kVmtlth1hvupm1m9nFXuYRORL76lu459V8puSkM3fGUL/jiHjGs5PFZhYNPAqcAZTQMcvZYufcxm7WewBY5lUWkaNx35J8qhpb+cMFEzWonIQ1L/cIZgCFzrki51wL8DxwXjfr/QB4GSj1MIvIEfmgsJyXVpdw3VdHMn5wmt9xRDzlZRFkA8Vdnpd0vrafmWUDFwDzD/WLzOw6M1tlZqvKysp6PKhIV02t7fzsz+sZlpHETaeP9juOiOe8LILu9qXdAc9/A9zqnGs/1C9yzi1wzuU653KzsnT5nnjrkX8Usq2igXvPn6h7BiQieHlDWQnQddqmIcCuA9bJBZ7vHLMlEzjLzNqcc4s8zCVyUPm7a5j/zqdcODWbU0Zn+h1HpFd4WQQrgdFmNgLYCVwKXN51BefciM++N7MngVdVAuKXtvYAt768jj6Jsfx8zni/44j0Gs+KwDnXZmY30nE1UDSw0DmXZ2bXdy4/5HkBkd72xPvbWFdSzX9dNpW+yXF+xxHpNZ6ONeScWwIsOeC1bgvAOXe1l1lEDmVbeT2/eqOAb4wbwJxJg/yOI9KrdGexRDznHLe/sp7YqCjuOV/zDEjkURFIxPvj8h18WFTB7WeNY2AfDSMhkUdFIBGtuLKBXy7J55RjMrlsRs4X/4BIGFIRSMRyznHbK+sAuP+iiTokJBFLRSAR67kVxbxfWMHPzh7HkL5JfscR8Y2KQCJScWUD9y3J56RRGVyukUUlwqkIJOIEAo4fv7QWgAcumqRDQhLxVAQScRa+v5XlWyu585zx5PTTISERFYFElC17a3lwWceNY988fojfcUSCgopAIkZre4CbX1xDSnwMv7xQVwmJfMbTISZEgslv3tzMhp01zL9iGlmp8X7HEQka2iOQiLC8qILH3v6Ub+UOYdYEjSUk0pWKQMJedUMrN7+whmH9krjrnOP8jiMSdHRoSMKac46fLVpPaW0zL99wEsnx+pMXOZD2CCSsvbS6hNfW7ebmM8YwOSfd7zgiQUlFIGFry95a7vzLBk4cmcH1p47yO45I0FIRSFhqbGnnxmc/ITkuhocvnUJ0lC4VFTkYHTCVsPSLV/Mo2FvLU9fMoH+a5hgQORTtEUjYWfTJTp5bUcwNp43i1DFZfscRCXoqAgkrBXtquf2V9Uwf3pcfnTHG7zgiIUFFIGGjpqmV659ZTUpCDI9ePo3YaP15ixwOnSOQsOCc4ycvrWVHZQPPffcEnRcQOQLaZJKw8Nt3PmVZ3l5unz2WGSP6+R1HJKSoCCTk/T1/Lw8tK2DOpEFce8oIv+OIhBwVgYS0wtJabnp+DeMHpfHQxZM1tLTIUVARSMiqbmjlu0+vJiE2igVX5pIYF+13JJGQpJPFEpJa2wN879nVlOxr4NnvnkB2eqLfkURClopAQo5zjjv+vIH3Cyt46OJJTB+uk8MiX4YODUnI+e07n/LCqmJ+8PVj+GZujt9xREKeikBCyqvrdvHg0gLOnTxYdw6L9BAVgYSM9wvLufmFNcwY3o8HL56kK4REeoinRWBms8yswMwKzey2bpbPNbN1nY8PzGyyl3kkdK0vqea6p1cxMjOF312ZS0KsrhAS6SmeFYGZRQOPArOB8cBlZjb+gNW2Aqc65yYBdwMLvMojoWtreT1XP7GC9KQ4nr52Bn2SYv2OJBJWvNwjmAEUOueKnHMtwPPAeV1XcM594Jzb1/n0I2CIh3kkBJXsa+CKx5cTcI6nr53BAI0hJNLjvCyCbKC4y/OSztcO5lrgb90tMLPrzGyVma0qKyvrwYgSzPZUNzH38eXUNLXyh2tnMiorxe9IImHJyyLo7kye63ZFs6/RUQS3drfcObfAOZfrnMvNytJEI5GgvK6ZuY9/RHltM09dM4MJ2X38jiQStry8oawE6HqR9xBg14Ermdkk4HFgtnOuwsM8EiLKajtKYGdVI0/Nm8G0oX39jiQS1rzcI1gJjDazEWYWB1wKLO66gpkNBV4Bvu2c2+xhFg
kRe6qbuGTBhxRXNrLwqunMHJnhdySRsOfZHoFzrs3MbgSWAdHAQudcnpld37l8PnAnkAE81nlNeJtzLterTBLcdlY1cvnv/vdwkOYVEOkd5ly3h+2DVm5urlu1apXfMaSHFZbWcdXCFdQ0tfLUNTocJNLTzGz1wTa0Neic+G5NcRXznlhBdJTx3HdP0IlhkV6mIhBfvbu5jOufWU1GShx/uGYmwzOT/Y4kEnFUBOKb51bs4I5FGxgzIJWn5k3XhPMiPlERSK9rDzgeWLqJBe8WceqYLB65fCqpCRo2QsQvKgLpVbVNrfzoxbW8sXEvV544jDvnjCcmWoPgivhJRSC9prC0jv/3h1Vsq2jgX88Zz9Unj/A7koigIpBesnTDHn780lriY6J45tqZnDhKN4qJBAsVgXiqua2d+/+2iSfe38bknHTmXzGNQX000bxIMFERiGeKyur4wXOfkLerhnknD+e22WOJj9GEMiLBRkUgPc45x3MrirnntY3ExUTxuytzOWP8AL9jichBqAikR+2pbuLWl9fxzuYyTj4mg4cunszgdB0KEglmKgLpEYGA48VVxdy3JJ/WdscvzjuOK2YOIypKE8yLBDsVgXxphaW1/OyVDazYVsmMEf144KJJjNBQESIhQ0UgR62uuY1H/lHIwv/eSmJcNA9eNIlv5g6hc0hxEQkRKgI5YoGA45VPdvLA0k2U1TZz0bQh3H7WWDJT4v2OJiJHQUUgh805x9uby3hwaQH5u2uYnJPOgm8fz1TNHSAS0lQEclhWbqvkV68X8FFRJTn9EvnNJVM4d/JgnQwWCQMqAjmkj4oqePjNLXxYVEFmShz/es54Lp85jLgYDRQnEi5UBPI57QHHGxv3sODdIj7eUUVWajx3nD2OuTOHkRinO4NFwo2KQParbmzl5dUlPPXhNrZXNDC0XxL/du5xXDI9h4RYFYBIuFIRRDjnHOtKqnluxQ4WrdlJU2uAaUPTuW3WWM48biDROgcgEvZUBBGqrLaZv6zZyYuritm8t47E2GgumJrN3JnDNHm8SIRREUSQ6oZWluXtYfHaXXzwaTkBB1Ny0rn3ggnMmTSYPomaLlIkEqkIwtzemiZe37iX1/P28OGnFbQFHEP7JfG9047hvCmDGT0g1e+IIuIzFUGYaW0PsLa4ircLyniroJS8XTUAjMhM5tqvjGD2hEFMHtJHw0CIyH4qghDX1h4gf3cty7dW8MGnFSwvqqC+pZ3oKOP4oX356axjOX3sAMYMSNGHv4h0S0UQYqoaWlhTXMUnO6r4eMc+PtlRRV1zG9Cx1X/BtGxOHpXJSaMy6ZOkY/4i8sVUBEHKOUdZbTP5e2rZuKuGDTurWb+zmh2VDQBEGYwZkMr5UwczY0QGM4b3Y2CfBJ9Ti0goUhH4LBBw7K5pYmtZPUXldWzZW8fmvbVsKa2jsr5l/3pD+iYyMbsPl0zPYerQdCYNSSclXv/5ROTL0yeJx5xz1DS2sbOqkZ1VjZTsa6C4spEdlQ3sqKxnR2UDTa2B/eunJsQwZkAqZ44fwNiBqYwdlMbYgamkJ8X5+C5EJJypCI5SIOCoamyloq6Z8roWyuqaKa9tprS2mdKaJvbWNrG7uondVU00trb/n59NiI1iWL9khmUk89XRWYzMSmFEZjIjs5Lpnxqvk7oi0qs8LQIzmwU8DEQDjzvn7j9guXUuPwtoAK52zn3sZabPOOdobgtQ39xGfXM7tc2t1DW1UdvURm1zKzWNbdQ0tlLd2EpVYytVDa1UNbSwr6GFqoZW9jW0EHCf/72x0Ub/1AQGpMVz7IBUThvTn8HpCQzqk8iQvolk900kIzlOH/YiEjQ8KwIziwYeBc4ASoCVZrbYObexy2qzgdGdj5nAbzu/9ri3Ckq559WNNLS0dz7aaG3v5pP8AElx0aQnxpKWGEvfpDiO7TxMk5EcR7/OR2ZKPP1T48lMiSc9KVYf8iISUrzcI5gBFDrnigDM7HngPKBrEZwHPO
2cc8BHZpZuZoOcc7t7OkyfxFjGDkojOS6apLgYkuKiSY6PISU+Zv/X1ISOr2mJsaQlxJCaEKtx90Uk7HlZBNlAcZfnJXx+a7+7dbKB/1MEZnYdcB3A0KFDjyrMtKF9mXa5plQUETmQl5u73R0fOfBYzOGsg3NugXMu1zmXm5WV1SPhRESkg5dFUALkdHk+BNh1FOuIiIiHvCyClcBoMxthZnHApcDiA9ZZDFxpHU4Aqr04PyAiIgfn2TkC51ybmd0ILKPj8tGFzrk8M7u+c/l8YAkdl44W0nH56Dyv8oiISPc8vY/AObeEjg/7rq/N7/K9A77vZQYRETk0XRspIhLhVAQiIhFORSAiEuGs4zB96DCzMmC73zmOQiZQ7neIXqb3HP4i7f1C6L7nYc65bm/ECrkiCFVmtso5l+t3jt6k9xz+Iu39Qni+Zx0aEhGJcCoCEZEIpyLoPQv8DuADvefwF2nvF8LwPescgYhIhNMegYhIhFMRiIhEOBWBD8zsx2bmzCzT7yxeMrOHzGyTma0zsz+bWbrfmbxiZrPMrMDMCs3sNr/zeM3McszsLTPLN7M8M7vJ70y9xcyizewTM3vV7yw9RUXQy8wsh455nHf4naUXvAFMcM5NAjYDt/ucxxNd5ueeDYwHLjOz8f6m8lwbcItzbhxwAvD9CHjPn7kJyPc7RE9SEfS+/wB+SjczsYUb59zrzrm2zqcf0THxUDjaPz+3c64F+Gx+7rDlnNvtnPu48/taOj4Ys/1N5T0zGwKcDTzud5aepCLoRWZ2LrDTObfW7yw+uAb4m98hPHKwubcjgpkNB6YCy32O0ht+Q8eGXMDnHD3K0/kIIpGZvQkM7GbRvwA/A87s3UTeOtT7dc79pXOdf6HjUMIfezNbLzqsubfDkZmlAC8DP3TO1fidx0tmNgcodc6tNrPTfI7To1QEPcw5943uXjezicAIYK2ZQcdhko/NbIZzbk8vRuxRB3u/nzGzq4A5wOkufG9aici5t80slo4S+KNz7hW/8/SCk4FzzewsIAFIM7NnnHNX+JzrS9MNZT4xs21ArnMuFEcxPCxmNgv4NXCqc67M7zxeMbMYOk6Gnw7spGO+7sudc3m+BvOQdWzNPAVUOud+6HOcXte5R/Bj59wcn6P0CJ0jEC89AqQCb5jZGjOb/0U/EIo6T4h/Nj93PvBiOJdAp5OBbwNf7/xvu6ZzS1lCkPYIREQinPYIREQinIpARCTCqQhERCKcikBEJMKpCEREIpyKQEQkwqkIREQinIpA5Esys+mdcy4kmFly5/j8E/zOJXK4dEOZSA8ws3voGH8mEShxzv3S50gih01FINIDzCyOjjGGmoCTnHPtPkcSOWw6NCTSM/oBKXSMrZTgcxaRI6I9ApEeYGaL6ZiZbAQwyDl3o8+RRA6b5iMQ+ZLM7EqgzTn3bOf8xR+Y2dedc//wO5vI4dAegYhIhNM5AhGRCKciEBGJcCoCEZEIpyIQEYlwKgIRkQinIhARiXAqAhGRCPc/OnN4AJzLYKEAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"x = torch.linspace(-5,5,100)\n",
"plt.xlabel(\"x\")\n",
"plt.ylabel(\"y\")\n",
"plt.plot(x, sigmoid(x))\n",
"fname = 'sigmoid.png'\n",
"plt.savefig(fname)\n",
"fname"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![sigmoid.png](./obipy-resources/Tb0Of9.png)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### PyTorch\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The function `torch.sigmoid` simply applies the sigmoid to each element of a tensor (*element-wise*).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.6457, 0.7311, 0.0067])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"\n",
"torch.sigmoid(torch.tensor([0.6, 1.0, -5.0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There is also `torch.nn.Sigmoid`, which can be used as a layer.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.5000, 0.4502, 0.5987])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"s = nn.Sigmoid()\n",
"s(torch.tensor([0.0, -0.2, 0.4]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Implementation in PyTorch\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.5000, 0.6225, 0.5744])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch.nn as nn\n",
"import torch\n",
"\n",
"class MySigmoid(nn.Module):\n",
"    def __init__(self):\n",
"        super(MySigmoid, self).__init__()\n",
"\n",
"    def forward(self, x):\n",
"        # Apply the sigmoid to every element of x.\n",
"        return 1 / (1 + torch.exp(-x))\n",
"\n",
"s = MySigmoid()\n",
"s(torch.tensor([0.0, 0.5, 0.3]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Weights\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sigmoid function has no learnable weights.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### **Question**: Can the sigmoid function be extended with some learnable weights?\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Linear regression\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Softmax\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In multi-class classification we have to return a probability\n",
"distribution over all classes, unlike in binary classification, where\n",
"it is enough to return a single number: the probability of the\n",
"positive class ($p$; the probability of the other class is simply\n",
"$1-p$).\n",
"\n",
"For multi-class classification we therefore need a vector counterpart\n",
"of the sigmoid function, that is, a function which turns an\n",
"unnormalized vector $\\vec{z} = [z_1,\\dots,z_k]$ (coming e.g. from a\n",
"preceding linear layer) into a probability distribution.\n",
"So we need a function $s: \\mathbb{R}^k \\rightarrow [0,1]^k$\n",
"satisfying the following conditions:\n",
"\n",
"- $s(z_i) = s_i(z) \\in [0,1]$\n",
"- $\\sum_i s(z_i) = 1$\n",
"- $z_i > z_j \\Rightarrow s(z_i) > s(z_j)$\n",
"\n",
"One could propose the following (**incorrect**!) solution:\n",
"\n",
"$$s(z_i) = \\frac{z_i}{\\sum_{j=1}^k z_j}$$\n",
"\n",
"This solution breaks down for negative numbers (some outputs become\n",
"negative and the denominator may even be zero), so we first have to\n",
"apply a monotonic function that maps $\\mathbb{R}$ onto $\\mathbb{R}^+$.\n",
"A natural function of this kind is the exponential function $\\exp(x) = e^x$.\n",
"This way we arrive at the softmax function:\n",
"\n",
"$$s(z_i) = \\frac{e^{z_i}}{\\sum_{j=1}^k e^{z_j}}$$\n",
"\n",
"The denominator in the definition of softmax is sometimes called the normalization factor:\n",
"$Z(\\vec{z}) = \\sum_{j=1}^k e^{z_j}$, so that:\n",
"\n",
"$$s(z_i) = \\frac{e^{z_i}}{Z(\\vec{z})}$$\n",
"\n",
"Definition in PyTorch:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.1182, 0.0022, 0.0059, 0.8737])"
]
}
],
"source": [
"import torch\n",
"\n",
"def softmax(z):\n",
"    # Exponentiate so every entry is positive, then normalize to sum to 1.\n",
"    z_plus = torch.exp(z)\n",
"    return z_plus / torch.sum(z_plus)\n",
"\n",
"softmax(torch.tensor([3., -1., 0., 5.]))"
]
},
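{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the formula concrete, here is the computation behind the output above, worked by hand for $\\vec{z} = [3, -1, 0, 5]$. The intermediate numbers are rounded to four decimal places, so treat them as approximate:\n",
"\n",
"$$Z(\\vec{z}) = e^{3} + e^{-1} + e^{0} + e^{5} \\approx 20.0855 + 0.3679 + 1 + 148.4132 = 169.8666$$\n",
"\n",
"$$s(\\vec{z}) \\approx \\left[\\frac{20.0855}{169.8666}, \\frac{0.3679}{169.8666}, \\frac{1}{169.8666}, \\frac{148.4132}{169.8666}\\right] \\approx [0.1182, 0.0022, 0.0059, 0.8737]$$\n",
"\n",
"All entries lie in $[0,1]$, they sum to 1, and the ordering of the inputs is preserved, as the three conditions require.\n",
"\n"
]
},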
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Soft vs hard\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why *softmax*? Sometimes the **hardmax** function is used, which e.g.\n",
"for the vector $[3, -1, 0, 5]$ would return $[0, 0, 0, 5]$: it is simply\n",
"a vector version of the function returning the maximum. There is also\n",
"the hard*arg*max function, which returns a *one-hot* vector with a single 1 at\n",
"the position of the largest value (instead of the largest value\n",
"itself), e.g. hardargmax for $[3, -1, 0, 5]$ would return\n",
"$[0, 0, 0, 1]$.\n",
"\n",
"Note that the commonly accepted name *softmax* is, strictly speaking,\n",
"a misnomer. The function should be called *softargmax*, since it\n",
"identifies the largest value in a “soft” way, by a value close\n",
"to 1 (and the remaining positions of the vector will not be exactly 0).\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### **Question**: How could one define a *softmax* function in the strict sense of the word (a “soft” counterpart of hardmax, not of hardargmax)?\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### PyTorch\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a one-dimensional tensor, `torch.nn.functional.softmax` normalizes the values over the whole tensor:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.4007, 0.5978, 0.0015])"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"nn.functional.softmax(torch.tensor([0.6, 1.0, -5.0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"… let us see how this function behaves for a matrix:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.4013, 0.5987],\n",
"        [0.0041, 0.9959]])"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"nn.functional.softmax(torch.tensor([[0.6, 1.0], [-2.0, 3.5]]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the `dim` argument (which is recommended anyway) we can specify the dimension along which the normalization is performed:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.9309, 0.0759],\n",
"        [0.0691, 0.9241]])"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"nn.functional.softmax(torch.tensor([[0.6, 1.0], [-2.0, 3.5]]), dim=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There is also `torch.nn.Softmax`, which can be used as a layer.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.3021, 0.2473, 0.4506])"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"s = nn.Softmax(dim=0)\n",
"s(torch.tensor([0.0, -0.2, 0.4]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Implementation in PyTorch\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.2501, 0.4123, 0.3376])"
]
}
],
"source": [
"import torch.nn as nn\n",
"import torch\n",
"\n",
"class MySoftmax(nn.Module):\n",
"    def __init__(self):\n",
"        super(MySoftmax, self).__init__()\n",
"\n",
"    def forward(self, x):\n",
"        ex = torch.exp(x)\n",
"        return ex / torch.sum(ex)\n",
"\n",
"s = MySoftmax()\n",
"s(torch.tensor([0.0, 0.5, 0.3]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"###### **Question**: The `MySoftmax` class defined above does not actually behave identically to `nn.Softmax`. What is the difference?\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Special case\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sigmoid is a special case of the softmax function:\n",
"\n",
"$$\\sigma(x) = \\frac{1}{1 + e^{-x}} = \\frac{e^x}{e^x + 1} = \\frac{e^x}{e^x + e^0} = s([x, 0])_1$$\n",
"\n",
"More generally, softmax on two-element vectors gives a shifted sigmoid (when one of the two values is held fixed).\n",
"\n"
]
},
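{
"cell_type": "markdown",
"metadata": {},
"source": [
"The more general statement can be checked in one line. For a two-element vector $[x, a]$ with $a$ held fixed, dividing the numerator and the denominator by $e^x$ gives:\n",
"\n",
"$$s([x, a])_1 = \\frac{e^x}{e^x + e^a} = \\frac{1}{1 + e^{-(x - a)}} = \\sigma(x - a)$$\n",
"\n",
"So the first softmax output is the sigmoid shifted by $a$ along the $x$ axis. With $a = 2$, for instance, the curve passes through $0.5$ at $x = 2$ rather than at $x = 0$.\n",
"\n"
]
},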
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"x = torch.linspace(-5,5,100)\n",
"plt.xlabel(\"x\")\n",
"plt.ylabel(\"y\")\n",
"a = torch.Tensor(x.size()[0]).fill_(2.)\n",
"m = torch.stack([x, a])\n",
"plt.plot(x, nn.functional.softmax(m, dim=0)[0])\n",
"fname = 'softmax3.png'\n",
"plt.savefig(fname)\n",
"fname"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![softmax3.png](./obipy-resources/gjBA7K.png)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits import mplot3d\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"x = torch.linspace(-5,5,10)\n",
"y = torch.linspace(-5,5,10)\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"plt.xlabel(\"x\")\n",
"plt.ylabel(\"y\")\n",
"X, Y = torch.meshgrid(x, y)\n",
"m = torch.stack([X, Y])\n",
"z = nn.functional.softmax(m, dim=0)\n",
"# Pass the 2-D grids X, Y (not the 1-D x, y) so each grid point gets its own coordinates.\n",
"ax.plot_wireframe(X, Y, z[0])\n",
"fname = 'softmax3d.png'\n",
"plt.savefig(fname)\n",
"fname"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![softmax3d.png](./obipy-resources/p96515.png)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Weights\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Like the sigmoid function, softmax has no learnable weights.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Applications\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The main application of the softmax function is multi-class\n",
"classification, including sequence-processing tasks that\n",
"can be interpreted as multi-class classification:\n",
"\n",
"- predicting the next word in language modelling (the class is a word, the set of classes is the vocabulary),\n",
"- assigning labels (e.g. parts of speech) to words.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LogSoftmax\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For computational reasons, the **LogSoftmax** function is often used;\n",
"it returns the logarithms of the probabilities (*logprobs*).\n",
"\n",
"$$\\log s(z_i) = \\log \\frac{e^{z_i}}{\\sum_{j=1}^k e^{z_j}}$$\n",
"\n"
]
},
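{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short derivation suggests why the logarithmic form is convenient. This is the standard log-sum-exp identity, given here as a sketch rather than as a description of PyTorch internals. With $m = \\max_j z_j$:\n",
"\n",
"$$\\log s(z_i) = z_i - \\log \\sum_{j=1}^k e^{z_j} = (z_i - m) - \\log \\sum_{j=1}^k e^{z_j - m}$$\n",
"\n",
"After subtracting $m$ every exponent is at most 0, so no term can overflow. As a small worked example, for $\\vec{z} = [0, -0.2, 0.4]$ (values rounded to four decimal places):\n",
"\n",
"$$\\log \\sum_{j} e^{z_j} \\approx \\log(1 + 0.8187 + 1.4918) = \\log 3.3105 \\approx 1.1971$$\n",
"\n",
"$$\\log s(\\vec{z}) \\approx [0 - 1.1971, -0.2 - 1.1971, 0.4 - 1.1971] = [-1.1971, -1.3971, -0.7971]$$\n",
"\n"
]
},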
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### PyTorch\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([-1.1971, -1.3971, -0.7971])"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"s = nn.LogSoftmax(dim=0)\n",
"s(torch.tensor([0.0, -0.2, 0.4]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some loss functions implemented in PyTorch (e.g. `NLLLoss`)\n",
"operate precisely on log probabilities.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example: multi-class classification\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use the task of recognizing the sport discipline of a text as an example: git://gonito.net/sport-text-classification.git\n",
"\n",
"We load the training set:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'zimowe': 0,\n",
" 'moto': 1,\n",
" 'tenis': 2,\n",
" 'pilka-reczna': 3,\n",
" 'sporty-walki': 4,\n",
" 'koszykowka': 5,\n",
" 'siatkowka': 6,\n",
" 'pilka-nozna': 7}"
]
}
],
"source": [
"import gzip\n",
"from pytorch_regression.analyzer import vectorize_text, vector_length\n",
"\n",
"texts = []\n",
"labels = []\n",
"labels_dic = {}\n",
"labels_revdic = {}\n",
"c = 0\n",
"\n",
"with gzip.open('sport-text-classification/train/train.tsv.gz', 'rt') as fh:\n",
"    for line in fh:\n",
"        line = line.rstrip('\\n')\n",
"        line = line.replace('\\\\\\t', ' ')\n",
"        label, text = line.split('\\t')\n",
"        texts.append(text)\n",
"        # Assign consecutive integer ids to labels in order of first appearance.\n",
"        if label not in labels_dic:\n",
"            labels_dic[label] = c\n",
"            labels_revdic[c] = label\n",
"            c += 1\n",
"        labels.append(labels_dic[label])\n",
"nb_of_labels = len(labels_dic)\n",
"labels_dic"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We prepare the model:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
],
"source": [
"import torch.nn as nn\n",
"from torch import optim\n",
"\n",
"model = nn.Sequential(\n",
"    nn.Linear(vector_length, nb_of_labels),\n",
"    # Inputs are batches of shape (1, vector_length), so normalize over dim=1.\n",
"    nn.LogSoftmax(dim=1)\n",
")\n",
"\n",
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss function is log-loss (negative log-likelihood).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(2.3026)"
]
}
],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"loss_fn = torch.nn.NLLLoss()\n",
"\n",
"expected_class_id = torch.tensor([2])\n",
"loss_fn(torch.log(\n",
" torch.tensor([[0.3, 0.5, 0.1, 0.0, 0.1]])),\n",
" expected_class_id)"
]
},
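{
"cell_type": "markdown",
"metadata": {},
"source": [
"The value above can be verified by hand. For a single example with expected class $y$ and predicted distribution $p$, the negative log-likelihood loss is:\n",
"\n",
"$$\\mathrm{NLL}(p, y) = -\\log p_y$$\n",
"\n",
"Here $y = 2$ (counting from 0) and $p = [0.3, 0.5, 0.1, 0.0, 0.1]$, so $p_y = 0.1$ and the loss is $-\\log 0.1 = \\log 10 \\approx 2.3026$. A perfect prediction ($p_y = 1$) would give a loss of 0, while $p_y \\to 0$ drives the loss towards infinity.\n",
"\n"
]
},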
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training loop:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
],
"source": [
"iteration = 0\n",
"step = 50\n",
"closs = torch.tensor(0.0, dtype=torch.float, requires_grad=False)\n",
"\n",
"for t, y_exp in zip(texts, labels):\n",
"    x = vectorize_text(t).float().unsqueeze(dim=0)\n",
"\n",
"    optimizer.zero_grad()\n",
"\n",
"    y_logprobs = model(x)\n",
"\n",
"    loss = loss_fn(y_logprobs, torch.tensor([y_exp]))\n",
"\n",
"    loss.backward()\n",
"\n",
"    with torch.no_grad():\n",
"        closs += loss\n",
"\n",
"    optimizer.step()\n",
"\n",
"    # Every `step` iterations report the averaged loss and reset the accumulator.\n",
"    if iteration % step == 0:\n",
"        print((closs / step).item(), loss.item(), iteration, y_exp, torch.exp(y_logprobs), t)\n",
"        closs = torch.tensor(0.0, dtype=torch.float, requires_grad=False)\n",
"    iteration += 1\n",
"\n",
"    if iteration == 5000:\n",
"        break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model is so simple that its weights are interpretable.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.0070, 0.0075, 0.0059, 0.0061, 0.0093, 0.9509, 0.0062, 0.0071]])"
]
}
],
"source": [
"with torch.no_grad():\n",
"    x = vectorize_text('NBA').float().unsqueeze(dim=0)\n",
"    # The model returns log probabilities; exponentiate to get probabilities.\n",
"    y_logprobs = model(x)\n",
"torch.exp(y_logprobs)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([-2.3693, -2.3421, -2.4205, -2.4353, -2.1499,  2.5163, -2.4351, -2.4546],\n",
"       grad_fn=<SelectBackward>)"
]
}
],
"source": [
"with torch.no_grad():\n",
"    x = vectorize_text('NBA').float().unsqueeze(dim=0)\n",
"    ix = torch.argmax(x).item()\n",
"model[0].weight[:,ix]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can even draw a plot showing how words are placed with respect to two axes corresponding to two selected disciplines.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAD4CAYAAAAQP7oXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAaRklEQVR4nO3dfXRV1bnv8e9DQKVcMdKoGMAGvIBCCMQRqIoSrCKK1PAST/EeNFxrM/TWjnMdt6kg1aH22GO1l6o9Sg+HFkGr0iMQEaj4VhrBtwRQQDSKmCqEC0EMqAE1+Nw/sskIcSeZsHf2TsLvM0YGe80195rPJLU/1lxr72XujoiISIhOyS5ARETaD4WGiIgEU2iIiEgwhYaIiARTaIiISLDOyS6gOWlpaZ6RkZHsMkRE2o21a9fudvdTWuv4bTo0MjIyKCsrS3YZIiLthpn9ozWPr+UpEREJ1qFDo6KigszMzKN6b3FxMZs3b45zRSIi7VuHDo2jVVtbq9AQEYnimAmNrVu3kp2dTWlpKeeeey5ZWVlMnDiRTz/9FIDRo0dz6623kpuby29+8xuWLl1KUVERw4YN44MPPkhy9SIibUObvhAeL+Xl5UyZMoV58+Zx7bXX8vvf/57c3Fxuv/127rzzTu6//34Aqqur+fvf/w7A+++/z/jx48nPz09i5SIibUuHC43i9du5b2U5ldX76eF72bZjJ3l5eSxatIjevXtTXV1Nbm4uAAUFBVx11VX17/3Rj36UrLJFRNqFDrU8Vbx+OzMWb2R79X4c2LnvADUczwmpp7JmzZoW39+tW7fWL1JEpB3rUKFx38py9n998PDGTimcMO4WFixYwPLlyzn55JN5+eWXAXj00UfrzzoaO/HEE/nss89au2QRkXalQy1PVVbvj9q+swbWL1vGmDFjmDRpEkVFRdTU1NCvXz/mzZsX9T1TpkzhJz/5CQ8++CBPPfUUZ555ZmuWLiLSLlhbfghTTk6OH8knwkfe8xLbowRHr9SurJn+g3iWJiLSJpnZWnfPaa3jd6jlqaKxA+naJeWwtq5dUigaOzBJFYmIdCwdanlqQnYvgPq7p9JTu1I0dmB9u4iIxKZDhQbUBYdCQkSkdXSo5SkREWldCg0REQkWl9Aws8vMrNzMtpjZ9Cj7R5vZXjN7M/JzezzGFRGRxIr5moaZpQAPAWOAbUCpmS1198ZfEfuyu4+PdTwREUmeeJxpjAC2uPtWd/8KeBLIi8NxRUSkjYlHaPQCPm6wvS3S1th5ZvaWmf3VzAbHYVwREUmweNxya1HaGn/MfB3wPXf/3MzGAcVA/6gHMysECgHOOOOMOJQnIiLxEo8zjW1AnwbbvYHKhh3cfZ+7fx55vQLoYmZp0Q7m7nPcPcfdc0455ZQ4lCciIvESj9AoBfqbWV8zOw6YAixt2MHMepqZRV6PiIz7SRzGFhGRBIp5ecrda83sJmAlkAL8yd3fNrMbIvv/AOQDN5pZLbAfmOJt+ZsSRUQkqg71LbciIsc6fcutiIi0GQoNEREJptAQEZFgCg0REQmm0BARkWAKDRERCabQEBGRYAoNEREJptAQEZFgCg0REQmm0BARkWAKDRERCabQEBGRYAoNEREJptAQEZFgCg0REQmm0BARkWAKDRERCabQEBGRYAoNEREJptAQEZFgCg0REQmm0BARkWAKDRERCabQEBGRYAoNEREJptAQEZFgCg0REQmm0BARkWBxCQ0zu8zMys1si5lNj7LfzOzByP4NZnZOPMYVEZHEijk0zCwFeAi4HBgEXG1mgxp1uxzoH/kpBGbHOq6IiCRePM40RgBb3H2ru38FPAnkNeqTByzwOq8BqWZ2ehzGFhGRBIpHaPQCPm6wvS3SdqR9ADCzQjMrM7OyqqqqOJQnIiLxEo/QsChtfhR96hrd57h7jrvnnHLKKTEXJyIi8ROP0NgG9Gmw3RuoPIo+IiLSxsUjNEqB/mbW18yOA6YASx
v1WQpcG7mL6lxgr7vviMPYIiKSQJ1jPYC715rZTcBKIAX4k7u/bWY3RPb/AVgBjAO2ADXA/4x1XBERSbyYQwPA3VdQFwwN2/7Q4LUDP43HWCIikjz6RLiIiARTaIiISDCFhoiIBFNoiIhIMIWGiIgEU2iIiEgwhYaIiARTaIiISDCFhoiIBFNoiIhIMIWGiIgEU2iIiEgwhYaIiARTaIiISDCFhoiIBFNoiIhIMIWGiMTk+uuvZ/PmzQBkZGSwe/duKioqyMzMTHJl0hri8uQ+ETl2zZ07N9klSALpTENEglRUVHDWWWdRUFBAVlYW+fn51NTUMHr0aMrKyur7ffDBB4e9b+vWrWRnZ1NaWsobb7zB+eefT3Z2Nueffz7l5eWJnobESKEhIsHKy8spLCxkw4YNdO/enYcffrh+38GDB/nqq69YsGDBYf0nT57MvHnzGD58OGeddRYlJSWsX7+eu+66i1tvvTUZ05AYaHlKRJpUvH47960sp7J6Pz18L2k90xk5ciQAU6dO5cEHH6zve9ttt7Fv3z7uvPNOPv/8c6qqqsjLy2PRokUMHjwYgL1791JQUMD777+PmfH1118nZV5y9HSmIc1auHAhFRUVyS5DkqB4/XZmLN7I9ur9OLBz3wGqa2opXr+9vo+Z1b/+9a9/TVpaWv32SSedRJ8+fVizZk1922233cZFF13Epk2beOaZZzhw4EBC5iLxo9A4hk2bNo2nnnqqyf2PPfYYH330ERkZGUf1fmnf7ltZzv6vDx7WVrtvF7fPWQzAE088wQUXXNDk+4877jiKi4tZsGABjz/+OFB3ptGrVy8AHnnkkdYpXFqVlqekSVOnTk12CZJEldX7v9XW5bt9+PC1FWRl/Sf9+/fnxhtv5JlnnmnyGN26dWPZsmWMGTOGbt268Ytf/IKCggJmzZrFD37wg9YsX1qJzjSS6IsvvuCKK65g6NChZGZmsnDhQu666y6GDx9OZmYmhYWFuDsAo0eP5uabb2bUqFGcffbZlJaWMmnSJPr3788vf/nL+mPOmjWLzMxMMjMzuf/+++vbFyxYQFZWFkOHDuWaa66pby8pKeH888+nX79+9WcN7k5RURGZmZkMGTKEhQsX1rffdNNNDBo0iCuuuIJdu3Yl4G9JkiU9teu3G83I+lERGzZsYNGiRXznO99h1apV5OTkAHV3WKWlpZGRkcGmTZsASE1NpbS0lLy8PM477zzee+891qxZw69+9SstfbZDOtNIomeffZb09HSWL18O1J26jxkzhttvvx2Aa665hmXLlvHDH/4QqDvdLykp4YEHHiAvL4+1a9fSo0cPzjzzTG6++WYqKiqYN28er7/+Ou7O97//fXJzcznuuOO4++67WbNmDWlpaezZs6e+hh07drB69WreffddrrzySvLz81m8eDFvvvkmb731Frt372b48OGMGjWKV199lfLycjZu3MjOnTsZNGgQ1113XeL/4iQhisYOZMbijYctUZkZRWMHJrEqSTaFRoI1vBvl5K8/Z/uKlfS45RbGjx/PhRdeyKJFi7j33nupqalhz549DB48uD40rrzySgCGDBnC4MGDOf300wHo168fH3/8MatXr2bixIl069YNgEmTJvHyyy9jZuTn59dfpOzRo0d9PRMmTKBTp04MGjSInTt3ArB69WquvvpqUlJSOO2008jNzaW0tJSSkpL69vT0dC0vdHATsuuuPRz63+v3vpfBv69cU98uxyaFRgIduhvl0L/c9nRJ46Sr/y9fnriDGTNmcOmll/LQQw9RVlZGnz59uOOOOw67u+T4448HoFOnTvWvD23X1tbWL2U15u6H3eXSUMPjHHp/U8cBmjyOdEwTsnspJOQwuqaRQI3vRqn97BO+pDOlnTP5+c9/zrp16wBIS0vj888/P+I7k0aNGkVxcTE1NTV88cUXLFmyhAsvvJCLL76Yv/zlL3zyyScAhy1PNXWchQsXcvDgQaqqqigpKWHEiBGMGj
WKJ598koMHD7Jjxw7+9re/HeHfgIi0dzGdaZhZD2AhkAFUAP/k7p9G6VcBfAYcBGrdPSeWcdurxnejfF1Vwa5V89hhxt1nfJfZs2dTXFzMkCFDyMjIYPjw4Ud0/HPOOYdp06YxYsQIoO6L5LKzswGYOXMmubm5pKSkkJ2d3eztjhMnTuTVV19l6NChmBn33nsvPXv2ZOLEibz00ksMGTKEAQMGkJube2R/ASLS7llzSxEtvtnsXmCPu99jZtOBk939lij9KoAcd999JMfPycnxht9p096NvOcltke5jbFXalfWTNf1ARGJnZmtbc1/mMe6PJUHzI+8ng9MiPF4HVrR2IF07ZJyWFvXLim6G0VE2o1YQ+M0d98BEPnz1Cb6OfCcma01s8LmDmhmhWZWZmZlVVVVMZbXtkzI7sW/TRpCr9SuGHVnGP82aYguNIpIu9Hi8pSZvQD0jLJrJjDf3VMb9P3U3U+Ocox0d680s1OB54GfuXtJS8V1tOUpEZHW1trLUy1eCHf3S5raZ2Y7zex0d99hZqcDUT8i7O6VkT93mdkSYATQYmiIiEjbEuvy1FKgIPK6AHi6cQcz62ZmJx56DVwKbIpxXBERSYJYQ+MeYIyZvQ+MiWxjZulmtiLS5zRgtZm9BbwBLHf3Z2McV0REkiCmz2m4+yfAxVHaK4FxkddbgaGxjCMiIm2DPhEuItIGVVRUkJmZGdz/+uuvZ/PmzU3uN7NVZhbzBXJ995SISAcwd+7chIyjMw0RkTaqtraWgoICsrKyyM/Pp6amhhdffJHs7GyGDBnCddddx5dffgnUPXPn0EcUzOwRM9tkZhvN7OaGxzSzTmY238z+1cxOMLN5kX7rzeyilmpSaIiItFHl5eUUFhayYcMGunfvzqxZs5g2bRoLFy5k48aN1NbWMnv27MZv+w7Qy90z3X0IMK/Bvs7An4H33P2XwE8BIv2uBuab2QnN1aTQEBFpI4rXb2fkPS/Rd/pyJs9+hbSe6YwcORKoe/zyiy++SN++fRkwYAAABQUFlJR86yNvXwL9zOz3ZnYZsK/Bvv8ANrn73ZHtC4BHAdz9XeAfwIDmalRoiIi0AYeet7O9ej8O7Nx3gOqaWorXbz/SQx2k7o7VVdSdSTS82PEKcFGDs4kjfkCOQkNEpA1o/LwdgNp9u7h9zmIAnnjiCS655BIqKirYsmULAI8++mi0RxR0Bjq5+yLgNuCcBvv+CKwA/svMOlP3zRz/DGBmA4AzgPLm6tTdUyIibUDj5+0AdPluHz58bQVZWf9J//79eeCBBzj33HO56qqrqK2tZfjw4dxwww3fehuwyswOnRTMaLjT3WeZ2UnULUv9GHjYzDYCtcA0d/+yuTpjep5Ga9MXForIsSJez9tp68/TEBGROGgvz9vR8pSISBtw6Lk6960sp7J6P+mpXSkaO7DNPW9HoSEi0kZMyO7V5kKiMS1PiYhIMIWGiIgEU2iIiEgwhYaIiARTaIiISDCFhoiIBFNoiIhIMIWGiIgEU2iIiEgwhYaIiARTaIiISDCFhoiIBFNoiIhIMIWGiIgEU2iIiEgwhYaIiASLKTTM7Coze9vMvjGzJp9Ja2aXmVm5mW0xs+mxjCkiIskT65nGJmASUNJUBzNLAR4CLgcGAVeb2aAYxxURkSSI6XGv7v4OgJk1120EsMXdt0b6PgnkAZtjGVtERBIvEdc0egEfN9jeFmmLyswKzazMzMqqqqpavTgREQnX4pmGmb0A9Iyya6a7Px0wRrTTEG+qs7vPAeYA5OTkNNlPREQSr8XQcPdLYhxjG9CnwXZvoDLGY4qISBIkYnmqFOhvZn3N7DhgCrA0AeOKiEicxXrL7UQz2wacByw3s5WR9nQzWwHg7rXATcBK4B3gL+7+dmxli4hIMsR699QSYEmU9kpgXIPtFcCKWMYSEZHk0yfCRUQkmEJDRESCKTRERCSYQkNERIIpNEREJJhCQ0REgik0REQkmEJDRESCKTRERC
SYQkNERIIpNEREJJhCQ0REgik0REQkmEJDRESCKTRERCSYQkNERIIpNEREJJhCQ0REgik0REQkmEJDRESCKTRERCSYQkNERIIpNEREJJhCQ0REgik0REQkmEJDRESCKTRERCSYQkNERIIpNEREJFhMoWFmV5nZ22b2jZnlNNOvwsw2mtmbZlYWy5giIpI8nWN8/yZgEvAfAX0vcvfdMY4nIiJJFFNouPs7AGYWn2pERKRNS9Q1DQeeM7O1ZlbYXEczKzSzMjMrq6qqSlB5IiISosUzDTN7AegZZddMd386cJyR7l5pZqcCz5vZu+5eEq2ju88B5gDk5OR44PFFRCQBWgwNd78k1kHcvTLy5y4zWwKMAKKGhoiItF2tvjxlZt3M7MRDr4FLqbuALiIi7Uyst9xONLNtwHnAcjNbGWlPN7MVkW6nAavN7C3gDWC5uz8by7giIpIcsd49tQRYEqW9EhgXeb0VGBrLOCIi0jboE+EiIhJMoSEiIsEUGiIiEkyhISIiwRQaIiISTKEhIiLBFBoiIhJMoSEiIsEUGiIiEkyhISIiwRQaIiISTKEhIiLBFBoiIhJMoSEiIsEUGiIiEkyhISIiwRQaIiISTKEhIiLBFBoiIhJMoSEiIsEUGiIiEkyhISIiwRQaIiISTKEhIiLBFBoiIglWUVFBZmbmYW2rVq1i/PjxR/y+RFNoiIhIMIWGiEgSbd26lezsbEpLS+vb7rjjDn7729/Wb2dmZlJRUQFAbW0tBQUFZGVlkZ+fT01NTULrVWiIiCRJeXk5kydPZt68eQwfPjz4PYWFhWzYsIHu3bvz8MMPt3KVh4spNMzsPjN718w2mNkSM0ttot9lZlZuZlvMbHosY4qItEfF67cz8p6X6Dt9OZNnv8K2HTvJy8vjscceY9iwYcHH6dOnDyNHjgRg6tSprF69upUqji7WM43ngUx3zwLeA2Y07mBmKcBDwOXAIOBqMxsU47giIu1G8frtzFi8ke3V+3Fg574D1HA8J6Seypo1a77Vv3PnznzzzTf12wcOHKh/bWaH9W283dpiCg13f87dayObrwG9o3QbAWxx963u/hXwJJAXy7giIu3JfSvL2f/1wcMbO6VwwrhbWLBgAY8//vhhuzIyMli3bh0A69at48MPP6zf99FHH/Hqq68C8MQTT3DBBRe0bvGNxPOaxnXAX6O09wI+brC9LdImInJMqKzeH7V9Zw0sW7aM3/3ud+zdu7e+ffLkyezZs4dhw4Yxe/ZsBgwYUL/v7LPPZv78+WRlZbFnzx5uvPHGVq+/oc4tdTCzF4CeUXbNdPenI31mArXAn6MdIkqbNzNeIVAIcMYZZ7RUnohIm5ee2pXtDYKj80mnkf7jh0lP7Upqamr9nVN5eXWLMF27duW5556LeqzNmze3fsHNaDE03P2S5vabWQEwHrjY3aOFwTagT4Pt3kBlM+PNAeYA5OTkNBkuIiLtRdHYgcxYvPGwJaquXVIoGjswiVUdnRZDozlmdhlwC5Dr7k3dLFwK9DezvsB2YArwP2IZV0SkPZmQXbcif9/Kciqr95Oe2pWisQPr29uTmEID+HfgeOD5yBX819z9BjNLB+a6+zh3rzWzm4CVQArwJ3d/O8ZxRUTalQnZvdplSDQWU2i4+39vor0SGNdgewWwIpaxREQk+fSJcBERCabQEBGRYAoNEREJptAQEZFgFv2jFW2DmVUB/4jjIdOA3XE8XluiubVPmlv71Jbn9j13P6W1Dt6mQyPezKzM3XOSXUdr0NzaJ82tferIc2uJlqdERCSYQkNERIIda6ExJ9kFtCLNrX3S3Nqnjjy3Zh1T1zRERCQ2x9qZhoiIxEChISIiwTp0aJhZDzN73szej/x5chP9Us3sKTN718zeMbPzEl3rkQqdW6RvipmtN7NliazxaIXMzcz6mNnfIr+vt83sX5JRaygzu8zMys1si5lNj7LfzOzByP4NZnZOMuo8GgFz++fInDaY2StmNj
QZdR6NlubWoN9wMztoZvmJrC8ZOnRoANOBF929P/BiZDuaB4Bn3f0sYCjwToLqi0Xo3AD+hfYxp0NC5lYL/B93Pxs4F/ipmQ1KYI3BzCwFeAi4HBgEXB2l1suB/pGfQmB2Qos8SoFz+5C6Z+5kAb+inVxEDpzboX6/oe7xDx1eRw+NPGB+5PV8YELjDmbWHRgF/BHA3b9y9+oE1ReLFucGYGa9gSuAuYkpKy5anJu773D3dZHXn1EXim31YQUjgC3uvtXdvwKepG6ODeUBC7zOa0CqmZ2e6EKPQotzc/dX3P3TyOZr1D29sz0I+b0B/AxYBOxKZHHJ0tFD4zR33wF1/ycDnBqlTz+gCpgXWcKZa2bdElnkUQqZG8D9wC+AbxJUVzyEzg0AM8sAsoHXW7+0o9IL+LjB9ja+HXAhfdqiI637x8BfW7Wi+GlxbmbWC5gI/CGBdSVVrE/uSzozewHoGWXXzMBDdAbOAX7m7q+b2QPULYfcFqcSj1qsczOz8cAud19rZqPjWFrM4vB7O3Sc/0bdv/L+t7vvi0dtrcCitDW+1z2kT1sUXLeZXURdaFzQqhXFT8jc7gducfeDkaeXdnjtPjTc/ZKm9pnZTjM73d13RE71o50+bgO2ufuhf6U+RfPXBxImDnMbCVxpZuOAE4DuZvaYu09tpZKDxWFumFkX6gLjz+6+uJVKjYdtQJ8G272ByqPo0xYF1W1mWdQtkV7u7p8kqLZYhcwtB3gyEhhpwDgzq3X34oRUmAQdfXlqKVAQeV0APN24g7v/P+BjMxsYaboY2JyY8mISMrcZ7t7b3TOAKcBLbSEwArQ4N6v7r/SPwDvuPiuBtR2NUqC/mfU1s+Oo+10sbdRnKXBt5C6qc4G9h5bo2rgW52ZmZwCLgWvc/b0k1Hi0Wpybu/d194zIf2NPAf+rIwcGAO7eYX+A71J39837kT97RNrTgRUN+g0DyoANQDFwcrJrj9fcGvQfDSxLdt3xmht1Sxwe+Z29GfkZl+zam5nTOOA94ANgZqTtBuCGyGuj7k6dD4CNQE6ya47j3OYCnzb4PZUlu+Z4za1R30eA/GTX3No/+hoREREJ1tGXp0REJI4UGiIiEkyhISIiwRQaIiISTKEhIiLBFBoiIhJMoSEiIsH+P7xRGbngDhcEAAAAAElFTkSuQmCC",
"text/plain": [
"<matplotlib.figure.Figure>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"with torch.no_grad():\n",
" words = ['piłka', 'klub', 'kort', 'boisko', 'samochód']\n",
" words_ixs = [torch.argmax(vectorize_text(w).float().unsqueeze(dim=0)).item() for w in words]\n",
"\n",
" x_label = labels_dic['pilka-nozna']\n",
" y_label = labels_dic['tenis']\n",
"\n",
" x = [model[0].weight[x_label, ix] for ix in words_ixs]\n",
" y = [model[0].weight[y_label, ix] for ix in words_ixs]\n",
"\n",
" fig, ax = plt.subplots()\n",
" ax.scatter(x, y)\n",
"\n",
" for i, txt in enumerate(words):\n",
" ax.annotate(txt, (x[i], y[i]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Zadanie etykietowania sekwencji\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Zadanie etykietowania sekwencji (*sequence labelling*) polega na przypisaniu poszczególnym wyrazom (tokenom) tekstu **etykiet** ze skończonego zbioru. Definiując formalnie:\n",
"\n",
"- rozpatrujemy ciąg wejściowy tokenów $(t^1,\\dots,t^K)$\n",
"- dany jest skończony zbiór etykiet $L = \\{l_1,\\dots,l_{|L|}\\}$, dla uproszczenia można założyć, że etykietami\n",
" są po prostu kolejne liczby, tj. $L=\\{0,\\dots,|L|-1\\}$\n",
"- zadanie polega na wygenerowaniu sekwencji etykiet (o tej samej długości co ciąg wejściowy!) $(y^1,\\dots,y^K)$,\n",
" $y^k \\in L$\n",
"\n",
"Zadanie etykietowania można traktować jako przypadek szczególny klasyfikacji wieloklasowej, z tym, że klasyfikacji dokonujemy wielokrotnie — dla każdego tokenu (nie dla każdego tekstu).\n",
"\n",
"Przykłady zastosowań:\n",
"\n",
"- oznaczanie częściami mowy (*POS tagger*) — czasownik, przymiotnik, rzeczownik itd.\n",
"- oznaczanie etykiet nazw w zadaniu NER (nazwisko, kwoty, adresy — najwięcej tokenów będzie miało etykietę pustą, zazwyczaj oznaczaną przez `O`)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### **Pytanie**: czy zadanie tłumaczenia maszynowego można potraktować jako problem etykietowania sekwencji?\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Przykładowe wyzwanie NER CoNLL-2003\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Zob. [https://gonito.net/challenge/en-ner-conll-2003](https://gonito.net/challenge/en-ner-conll-2003).\n",
"\n",
"Przykładowy przykład uczący (`xzcat train.tsv.xz| head -n 1`):\n",
"\n",
"O O B-MISC I-MISC O O O O O B-LOC O B-LOC O O O O O O O O O O O B-MISC I-MISC O O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER I-PER O B-LOC O O O O O B-PER I-PER O O B-LOC O O O O O O B-PER I-PER O B-LOC O O O O O B-PER I-PER O O O O O B-PER I-PER O B-LOC O O O O O B-PER I-PER O B-LOC O B-LOC O O O O O O B-PER I-PER O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER O B-LOC O O O O O B-PER I-PER O O O O O B-PER I-PER O B-LOC O O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER O O O O O B-PER I-PER O B-LOC O O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER O O O O B-PER I-PER I-PER O B-LOC O O O O O O B-PER I-PER O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER O O O O B-PER I-PER O B-LOC O O O O O O B-PER I-PER O B-LOC O O O O O B-PER I-PER O B-LOC O B-LOC O O O O O B-PER I-PER O O O O O\tGOLF - BRITISH MASTERS THIRD ROUND SCORES . </S> NORTHAMPTON , England 1996-08-30 </S> Leading scores after </S> the third round of the British Masters on Friday : </S> 211 Robert Allenby ( Australia ) 69 71 71 </S> 212 Pedro Linhart ( Spain ) 72 73 67 </S> 216 Miguel Angel Martin ( Spain ) 75 70 71 , Costantino Rocca </S> ( Italy ) 71 73 72 </S> 217 Antoine Lebouc ( France ) 74 73 70 , Ian Woosnam 70 76 71 , </S> Francisco Cea ( Spain ) 70 71 76 , Gavin Levenson ( South </S> Africa ) 66 75 76 </S> 218 Stephen McAllister 73 76 69 , Joakim Haeggman ( Swe ) 71 77 </S> 70 , Jose Coceres ( Argentina ) 69 78 71 , Paul Eales 75 71 72 , </S> Klas Eriksson ( Sweden ) 71 75 72 , Mike Clayton ( Australia ) </S> 69 76 73 , Mark Roe 69 71 78 </S> 219 Eamonn Darcy ( Ireland ) 74 76 69 , Bob May ( U.S. 
) 74 75 70 , </S> Paul Lawrie 72 75 72 , Miguel Angel Jimenez ( Spain ) 74 72 </S> 73 , Peter Mitchell 74 71 75 , Philip Walton ( Ireland ) 71 74 </S> 74 , Peter O'Malley ( Australia ) 71 73 75 </S> 220 Barry Lane 73 77 70 , Wayne Riley ( Australia ) 71 78 71 , </S> Martin Gates 71 77 72 , Bradley Hughes ( Australia ) 73 75 72 , </S> Peter Hedblom ( Sweden ) 70 75 75 , Retief Goosen ( South </S> Africa ) 71 74 75 , David Gilford 69 74 77 . </S>\n",
"\n",
"W pierwszym polu oczekiwany wynik zapisany za pomocą notacji **BIO**.\n",
"\n",
"Jako metrykę używamy F1 (z pominięciem tagu `O`)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Metryka F1\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Etykietowanie za pomocą klasyfikacji wieloklasowej\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Można potraktować problem etykietowania dokładnie tak jak problem\n",
"klasyfikacji wieloklasowej (jak w przykładzie klasyfikacji dyscyplin\n",
"sportowych powyżej), tzn. rozkład prawdopodobieństwa możliwych etykiet\n",
"uzyskujemy poprzez zastosowanie prostej warstwy liniowej i funkcji softmax:\n",
"\n",
"$$p(l^k=i) = s(\\vec{w}\\vec{v}(t^k))_i = \\frac{e^{\\vec{w}\\vec{v}(t^k)}}{Z},$$\n",
"\n",
"gdzie $\\vec{v}(t^k)$ to reprezentacja wektorowa tokenu $t^k$.\n",
"Zauważmy, że tutaj (w przeciwieństwie do klasyfikacji całego tekstu)\n",
"reprezentacja wektorowa jest bardzo uboga: wektor <u>one-hot</u>! Taki\n",
"klasyfikator w ogóle nie będzie brał pod uwagę kontekstu, tylko sam\n",
"wyraz, więc tak naprawdę zdegeneruje się to do zapamiętania częstości\n",
"etykiet dla każdego słowa osobno.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Bogatsza reprezentacja słowa\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Można spróbować uzyskać bogatszą reprezentację dla słowa biorąc pod uwagę na przykład:\n",
"\n",
"- długość słowa\n",
"- kształt słowa (*word shape*), np. czy pisany wielkimi literami, czy składa się z cyfr itp.\n",
"- n-gramy znakowe wewnątrz słowa (np. słowo *Kowalski* można zakodować jako sumę wektorów\n",
" trigramów znakówych $\\vec{v}(Kow) + \\vec{v}(owa) + \\vec{v}(wal) + \\vec{v}(als) + \\vec{v}(lsk) + + \\vec{v}(ski)$\n",
"\n",
"Cały czas nie rozpatrujemy jednak w tej metodzie kontekstu wyrazu.\n",
"(*Renault* w pewnym kontekście może być nazwą firmy, w innym —\n",
"nazwiskiem).\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Reprezentacja kontekstu\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Za pomocą wektora można przedstawić nie pojedynczy token $t^k$, lecz\n",
"cały kontekst, dla *okna* o długości $c$ będzie to kontekst $t^{k-c},\\dots,t^k,\\dots,t^{k+c}$.\n",
"Innymi słowy klasyfikujemy token na podstawie jego samego oraz jego kontekstu:\n",
"\n",
"$$p(l^k=i) = \\frac{e^{\\vec{w}\\vec{v}(t^{k-c},\\dots,t^k,\\dots,t^{k+c})}}{Z_k}.$$\n",
"\n",
"Zauważmy, że w tej metodzie w ogóle nie rozpatrujemy sensowności\n",
"sekwencji wyjściowej (etykiet), np. może być bardzo mało\n",
"prawdopodobne, że bezpośrednio po nazwisku występuje data.\n",
"\n",
"Napiszmy wzór określający prawdopodobieństwo całej sekwencji, nie\n",
"tylko pojedynczego tokenu. Na razie będzie to po prostu iloczyn poszczególnych wartości.\n",
"\n",
"$$p(l) = \\prod_{k=1}^K \\frac{e^{\\vec{w}\\vec{v}(t^{k-c},\\dots,t^k,\\dots,t^{k+c})}}{Z_k} = \\frac{e^{\\sum_{k=1}^K\\vec{w}\\vec{v}(t^{k-c},\\dots,t^k,\\dots,t^{k+c})}}{\\prod_{k=1}^K Z_k}$$\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Warunkowe pola losowe\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Warunkowe pola losowe (*Conditional Random Fields*, *CRF*) to klasa\n",
"modeli, które pozwalają uwzględnić zależności między punktami danych\n",
"(które można wyrazić jako graf). Najprostszym przykładem będzie prosty\n",
"graf wyrażający „następowanie po” (czyli sekwencje). Do poprzedniego\n",
"wzoru dodamy składnik $V_{i,j}$ (który można interpretować jako\n",
"macierz) określający prawdopodobieństwo, że po etykiecie o numerze $i$ wystąpi etykieta o numerze $j$.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### **Pytanie**: Czy macierz $V$ musi być symetryczna? Czy $V_{i,j} = V_{j,i}$? Czy jakieś specjalne wartości występują na przekątnej?\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Macierz $V$ wraz z wektorem $\\vec{w}$ będzie stanowiła wyuczalne wagi w naszym modelu.\n",
"\n",
"Wartości $V_{i,j}$ nie stanowią bezpośrednio prawdopodobieństwa, mogą\n",
"przyjmować dowolne wartości, które będę normalizowane podobnie jak to się dzieje w funkcji Softmax.\n",
"\n",
"W takiej wersji warunkowych pól losowych otrzymamy następujący wzór na prawdopodobieństwo całej sekwencji.\n",
"\n",
"$$p(l) = \\frac{e^{\\sum_{k=1}^K\\vec{w}\\vec{v}(t^{k-c},\\dots,t^k,\\dots,t^{k+c}) + \\sum_{k=1}^{K-1} V_{l_k,l_{k+1}}}}{\\prod_{k=1}^K Z_k}$$\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
},
"org": null
},
"nbformat": 4,
"nbformat_minor": 1
}