1583 lines
106 KiB
Plaintext
1583 lines
106 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "spread-happiness",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"%matplotlib inline\n",
|
||
|
"%load_ext autoreload\n",
|
||
|
"%autoreload 2\n",
|
||
|
"\n",
|
||
|
"import numpy as np\n",
|
||
|
"import pandas as pd\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import seaborn as sns\n",
|
||
|
"import matplotlib.ticker as ticker\n",
|
||
|
"from IPython.display import Markdown, display, HTML\n",
|
||
|
"\n",
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import torch.optim as optim\n",
|
||
|
"\n",
|
||
|
"# Fix the dying kernel problem (only a problem in some installations - you can remove it, if it works without it)\n",
|
||
|
"import os\n",
|
||
|
"os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "approximate-classic",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# PyTorch\n",
|
||
|
"\n",
|
||
|
"Here's your best friend when working with PyTorch: https://pytorch.org/docs/stable/index.html.\n",
|
||
|
"\n",
|
||
|
"The beginning of this notebook shows that PyTorch tensors can be used exactly like numpy arrays. Later in the notebook additional features of tensors will be presented."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "renewable-chase",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Creating PyTorch tensors"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "afraid-consortium",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Directly"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "textile-mainland",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"[[1. 2. 3.]\n",
|
||
|
" [4. 5. 6.]\n",
|
||
|
" [7. 8. 9.]]\n",
|
||
|
"\n",
|
||
|
"tensor([[1., 2., 3.],\n",
|
||
|
" [4., 5., 6.],\n",
|
||
|
" [7., 8., 9.]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = np.array(\n",
|
||
|
" [[1.0, 2.0, 3.0], \n",
|
||
|
" [4.0, 5.0, 6.0], \n",
|
||
|
" [7.0, 8.0, 9.0]]\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"t = torch.tensor(\n",
|
||
|
" [[1.0, 2.0, 3.0], \n",
|
||
|
" [4.0, 5.0, 6.0], \n",
|
||
|
" [7.0, 8.0, 9.0]]\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"print(t)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "floating-junior",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### From a list"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "reasonable-mistress",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]\n",
|
||
|
"\n",
|
||
|
"[[1. 2. 3.]\n",
|
||
|
" [4. 5. 6.]\n",
|
||
|
" [7. 8. 9.]]\n",
|
||
|
"\n",
|
||
|
"tensor([[1., 2., 3.],\n",
|
||
|
" [4., 5., 6.],\n",
|
||
|
" [7., 8., 9.]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"l = [[1.0, 2.0, 3.0], \n",
|
||
|
" [4.0, 5.0, 6.0], \n",
|
||
|
" [7.0, 8.0, 9.0]]\n",
|
||
|
"\n",
|
||
|
"print(l)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"a = np.array(l)\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"t = torch.tensor(l)\n",
|
||
|
"print(t)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "incorrect-practitioner",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### From a list comprehension"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "straight-cooling",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]\n",
|
||
|
"\n",
|
||
|
"[ 0 1 4 9 16 25 36 49 64 81]\n",
|
||
|
"\n",
|
||
|
"tensor([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = [i**2 for i in range(10)]\n",
|
||
|
"\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"print(np.array(a))\n",
|
||
|
"print()\n",
|
||
|
"print(torch.tensor(a))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "enormous-drink",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### From a numpy array"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "parental-judges",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([[1., 2., 3.],\n",
|
||
|
" [4., 5., 6.],\n",
|
||
|
" [7., 8., 9.]], dtype=torch.float64)\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = np.array(\n",
|
||
|
" [[1.0, 2.0, 3.0], \n",
|
||
|
" [4.0, 5.0, 6.0], \n",
|
||
|
" [7.0, 8.0, 9.0]]\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"t = torch.tensor(a)\n",
|
||
|
"\n",
|
||
|
"print(t)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "suffering-myanmar",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Ready-made functions in PyTorch"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "expensive-bowling",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"All zeros\n",
|
||
|
"tensor([[0., 0., 0., 0.],\n",
|
||
|
" [0., 0., 0., 0.],\n",
|
||
|
" [0., 0., 0., 0.]])\n",
|
||
|
"\n",
|
||
|
"All chosen value (variant 1)\n",
|
||
|
"tensor([[7., 7., 7., 7.],\n",
|
||
|
" [7., 7., 7., 7.],\n",
|
||
|
" [7., 7., 7., 7.]])\n",
|
||
|
"\n",
|
||
|
"All chosen value (variant 2)\n",
|
||
|
"tensor([[7., 7., 7., 7.],\n",
|
||
|
" [7., 7., 7., 7.],\n",
|
||
|
" [7., 7., 7., 7.]])\n",
|
||
|
"\n",
|
||
|
"Random integers\n",
|
||
|
"[[6 6]\n",
|
||
|
" [8 9]\n",
|
||
|
" [1 0]]\n",
|
||
|
"\n",
|
||
|
"tensor([[9, 5],\n",
|
||
|
" [9, 3],\n",
|
||
|
" [3, 8]])\n",
|
||
|
"\n",
|
||
|
"Random values from the normal distribution\n",
|
||
|
"[[ -5.34346728 0.97207777]\n",
|
||
|
" [ -7.26648922 -12.2890286 ]\n",
|
||
|
" [ -2.68082928 10.95819034]]\n",
|
||
|
"\n",
|
||
|
"tensor([[ 1.1231, -5.9980],\n",
|
||
|
" [20.4600, -6.4359],\n",
|
||
|
" [-6.6826, -0.4491]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# All zeros\n",
|
||
|
"a = torch.zeros((3, 4))\n",
|
||
|
"print(\"All zeros\")\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"# All a chosen value\n",
|
||
|
"a = torch.full((3, 4), 7.0)\n",
|
||
|
"print(\"All chosen value (variant 1)\")\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"# or\n",
|
||
|
"\n",
|
||
|
"a = torch.zeros((3, 4))\n",
|
||
|
"a[:] = 7.0\n",
|
||
|
"print(\"All chosen value (variant 2)\")\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"# Random integers\n",
|
||
|
"\n",
|
||
|
"print(\"Random integers\")\n",
|
||
|
"a = np.random.randint(low=0, high=10, size=(3, 2))\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"a = torch.randint(low=0, high=10, size=(3, 2))\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"# Random values from the normal distribution (Gaussian)\n",
|
||
|
"\n",
|
||
|
"print(\"Random values from the normal distribution\")\n",
|
||
|
"a = np.random.normal(loc=0, scale=10, size=(3, 2))\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"a = torch.normal(mean=0, std=10, size=(3, 2))\n",
|
||
|
"print(a)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "aggressive-titanium",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Slicing PyTorch tensors"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "former-richardson",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Slicing in 1D\n",
|
||
|
"\n",
|
||
|
"To obtain only specific values from a PyTorch tensor one can use so called slicing. It has the form\n",
|
||
|
"\n",
|
||
|
"**arr[low:high:step]**\n",
|
||
|
"\n",
|
||
|
"where low is the lowest index to be retrieved, high is the lowest index not to be retrieved (i.e. an exclusive upper bound), and step indicates that every step-th element will be taken."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "desirable-documentary",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Original: tensor([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81])\n",
|
||
|
"First 5 elements: tensor([ 0, 1, 4, 9, 16])\n",
|
||
|
"Elements from index 3 to index 5: tensor([ 9, 16, 25])\n",
|
||
|
"Last 3 elements (negative indexing): tensor([49, 64, 81])\n",
|
||
|
"Every second element: tensor([ 0, 4, 16, 36, 64])\n",
|
||
|
"Negative step a[::-1] to obtain reverse order does not work for tensors\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = torch.tensor([i**2 for i in range(10)])\n",
|
||
|
"\n",
|
||
|
"print(\"Original: \", a)\n",
|
||
|
"print(\"First 5 elements:\", a[:5])\n",
|
||
|
"print(\"Elements from index 3 to index 5:\", a[3:6])\n",
|
||
|
"print(\"Last 3 elements (negative indexing):\", a[-3:])\n",
|
||
|
"print(\"Every second element:\", a[::2])\n",
|
||
|
"\n",
|
||
|
"print(\"Negative step a[::-1] to obtain reverse order does not work for tensors\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "micro-explosion",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Slicing in 2D\n",
|
||
|
"\n",
|
||
|
"In two dimensions it works similarly, just the slicing is separate for every dimension."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "disciplinary-think",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Original: \n",
|
||
|
"tensor([[ 0, 1, 2, 3, 4],\n",
|
||
|
" [ 5, 6, 7, 8, 9],\n",
|
||
|
" [10, 11, 12, 13, 14],\n",
|
||
|
" [15, 16, 17, 18, 19],\n",
|
||
|
" [20, 21, 22, 23, 24]])\n",
|
||
|
"\n",
|
||
|
"First 2 elements of the first 3 row:\n",
|
||
|
"tensor([[ 0, 1],\n",
|
||
|
" [ 5, 6],\n",
|
||
|
" [10, 11]])\n",
|
||
|
"\n",
|
||
|
"Middle 3 elements from the middle 3 rows:\n",
|
||
|
"tensor([[ 6, 7, 8],\n",
|
||
|
" [11, 12, 13],\n",
|
||
|
" [16, 17, 18]])\n",
|
||
|
"\n",
|
||
|
"Bottom-right 3 by 3 submatrix (negative indexing):\n",
|
||
|
"tensor([[12, 13, 14],\n",
|
||
|
" [17, 18, 19],\n",
|
||
|
" [22, 23, 24]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = torch.tensor([i for i in range(25)]).reshape(5, 5)\n",
|
||
|
"\n",
|
||
|
"print(\"Original: \")\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"print(\"First 2 elements of the first 3 row:\")\n",
|
||
|
"print(a[:3, :2])\n",
|
||
|
"print()\n",
|
||
|
"print(\"Middle 3 elements from the middle 3 rows:\")\n",
|
||
|
"print(a[1:4, 1:4])\n",
|
||
|
"print()\n",
|
||
|
"print(\"Bottom-right 3 by 3 submatrix (negative indexing):\")\n",
|
||
|
"print(a[-3:, -3:])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "removable-canyon",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Setting PyTorch tensor field values"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"id": "senior-serbia",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Original: \n",
|
||
|
"tensor([[ 0, 1, 2, 3, 4],\n",
|
||
|
" [ 5, 6, 7, 8, 9],\n",
|
||
|
" [10, 11, 12, 13, 14],\n",
|
||
|
" [15, 16, 17, 18, 19],\n",
|
||
|
" [20, 21, 22, 23, 24]])\n",
|
||
|
"\n",
|
||
|
"Middle values changed to 5\n",
|
||
|
"tensor([[ 0, 1, 2, 3, 4],\n",
|
||
|
" [ 5, 5, 5, 5, 9],\n",
|
||
|
" [10, 5, 5, 5, 14],\n",
|
||
|
" [15, 5, 5, 5, 19],\n",
|
||
|
" [20, 21, 22, 23, 24]])\n",
|
||
|
"\n",
|
||
|
"Second matrix\n",
|
||
|
"tensor([[ 0, 0, 2],\n",
|
||
|
" [ 6, 12, 20],\n",
|
||
|
" [30, 42, 56]])\n",
|
||
|
"\n",
|
||
|
"Second matrix substituted into the middle of the first matrix\n",
|
||
|
"tensor([[ 0, 1, 2, 3, 4],\n",
|
||
|
" [ 5, 0, 0, 2, 9],\n",
|
||
|
" [10, 6, 12, 20, 14],\n",
|
||
|
" [15, 30, 42, 56, 19],\n",
|
||
|
" [20, 21, 22, 23, 24]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = torch.tensor([i for i in range(25)]).reshape(5, 5)\n",
|
||
|
"\n",
|
||
|
"print(\"Original: \")\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"a[1:4, 1:4] = 5.0\n",
|
||
|
"\n",
|
||
|
"print(\"Middle values changed to 5\")\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"b = torch.tensor([i**2 - i for i in range(9)]).reshape(3, 3)\n",
|
||
|
"\n",
|
||
|
"print(\"Second matrix\")\n",
|
||
|
"print(b)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"a[1:4, 1:4] = b\n",
|
||
|
"\n",
|
||
|
"print(\"Second matrix substituted into the middle of the first matrix\")\n",
|
||
|
"print(a)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "federal-wayne",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Operations on PyTorch tensors\n",
|
||
|
"\n",
|
||
|
"It is important to remember that arithmetic operations on PyTorch tensors are always element-wise."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "southwest-biotechnology",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([[ 0, 1, 4],\n",
|
||
|
" [ 9, 16, 25],\n",
|
||
|
" [36, 49, 64]])\n",
|
||
|
"\n",
|
||
|
"tensor([[0.0000, 1.0000, 1.4142],\n",
|
||
|
" [1.7321, 2.0000, 2.2361],\n",
|
||
|
" [2.4495, 2.6458, 2.8284]])\n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = torch.tensor([i**2 for i in range(9)]).reshape((3, 3))\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"b = torch.tensor([i**0.5 for i in range(9)]).reshape((3, 3))\n",
|
||
|
"print(b)\n",
|
||
|
"print()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "intensive-gates",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Element-wise sum"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "behavioral-safety",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([[ 0.0000, 2.0000, 5.4142],\n",
|
||
|
" [10.7321, 18.0000, 27.2361],\n",
|
||
|
" [38.4495, 51.6458, 66.8284]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(a + b)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "occupied-trial",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Element-wise multiplication"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "charming-pleasure",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([[ 0.0000, 1.0000, 5.6569],\n",
|
||
|
" [ 15.5885, 32.0000, 55.9017],\n",
|
||
|
" [ 88.1816, 129.6418, 181.0193]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(a * b)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "efficient-league",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Matrix multiplication"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "changing-community",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([[ 11.5300, 12.5830, 13.5498],\n",
|
||
|
" [ 88.9501, 107.1438, 119.2157],\n",
|
||
|
" [241.6378, 303.3281, 341.4984]], dtype=torch.float64)\n",
|
||
|
"\n",
|
||
|
"tensor([[ 0., 1., 4.],\n",
|
||
|
" [ 9., 16., 25.],\n",
|
||
|
" [36., 49., 64.]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(np.matmul(a, b))\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"# Multiplication by the identity matrix (to check it works as expected)\n",
|
||
|
"id_matrix = torch.tensor(\n",
|
||
|
" [[1.0, 0.0, 0.0], \n",
|
||
|
" [0.0, 1.0, 0.0], \n",
|
||
|
" [0.0, 0.0, 1.0]]\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"# Tensor a contained integers (type Long by default) and must be changed to the float type\n",
|
||
|
"a = a.type(torch.FloatTensor)\n",
|
||
|
"\n",
|
||
|
"print(torch.matmul(id_matrix, a))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "assisted-communications",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Calculating the mean"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"id": "defensive-wrong",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([3, 8, 7, 2, 6])\n",
|
||
|
"\n",
|
||
|
"Mean: tensor(5.2000)\n",
|
||
|
"\n",
|
||
|
"Mean: 5.199999809265137\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = torch.randint(low=0, high=10, size=(5,))\n",
|
||
|
"\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"print(\"Mean: \", torch.sum(a) / len(a))\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"# To get a single value use tensor.item()\n",
|
||
|
"\n",
|
||
|
"print(\"Mean: \", (torch.sum(a) / len(a)).item())"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "complex-karma",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Calculating the mean of every row"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"id": "correct-dietary",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([[1, 6, 8],\n",
|
||
|
" [6, 4, 8],\n",
|
||
|
" [1, 5, 8],\n",
|
||
|
" [2, 5, 7],\n",
|
||
|
" [1, 0, 4]])\n",
|
||
|
"\n",
|
||
|
"Mean: tensor([5.0000, 6.0000, 4.6667, 4.6667, 1.6667])\n",
|
||
|
"Mean in the original matrix form:\n",
|
||
|
"tensor([[5.0000],\n",
|
||
|
" [6.0000],\n",
|
||
|
" [4.6667],\n",
|
||
|
" [4.6667],\n",
|
||
|
" [1.6667]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = torch.randint(low=0, high=10, size=(5, 3))\n",
|
||
|
"\n",
|
||
|
"print(a)\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"print(\"Mean:\", torch.sum(a, axis=1) / a.shape[1])\n",
|
||
|
"\n",
|
||
|
"print(\"Mean in the original matrix form:\")\n",
|
||
|
"print((torch.sum(a, axis=1) / a.shape[1]).reshape(-1, 1)) # -1 calculates the right size to use all elements"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "indian-orlando",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### More complex operations\n",
|
||
|
"\n",
|
||
|
"Note that more complex tensor operations can only be performed on tensors. Numpy operations can be performed on numpy arrays but also directly on lists."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"id": "presidential-cologne",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Vector to power 2 (element-wise)\n",
|
||
|
"tensor([1., 4., 9.])\n",
|
||
|
"\n",
|
||
|
"Euler number to the power a (element-wise)\n",
|
||
|
"tensor([ 2.7183, 7.3891, 20.0855])\n",
|
||
|
"\n",
|
||
|
"An even more complex expression\n",
|
||
|
"tensor([0.6197, 1.8982, 4.8476])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"a = torch.tensor([1.0, 2.0, 3.0])\n",
|
||
|
"\n",
|
||
|
"print(\"Vector to power 2 (element-wise)\")\n",
|
||
|
"print(torch.pow(a, 2))\n",
|
||
|
"print()\n",
|
||
|
"print(\"Euler number to the power a (element-wise)\")\n",
|
||
|
"print(torch.exp(a))\n",
|
||
|
"print()\n",
|
||
|
"print(\"An even more complex expression\")\n",
|
||
|
"print((torch.pow(a, 2) + torch.exp(a)) / torch.sum(a))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "hearing-street",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## PyTorch basic operations tasks"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "regular-niger",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"**Task 1.** Calculate the sigmoid (logistic) function on every element of the following array [0.3, 1.2, -1.4, 0.2, -0.1, 0.1, 0.8, -0.25] and print the last 5 elements. Use only tensor operations."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"id": "agreed-single",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Write your code here"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "another-catch",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"**Task 2.** Calculate the dot product of the following two vectors:<br/>\n",
|
||
|
"$x = [3, 1, 4, 2, 6, 1, 4, 8]$<br/>\n",
|
||
|
"$y = [5, 2, 3, 12, 2, 4, 17, 9]$<br/>\n",
|
||
|
"a) by using element-wise multiplication and torch.sum,<br/>\n",
|
||
|
"b) by using torch.dot,<br/>\n",
|
||
|
"c) by using torch.matmul and transposition (x.T)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"id": "forbidden-journalism",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Write your code here"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "acute-amber",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"**Task 3.** Calculate the following expression<br/>\n",
|
||
|
"$$\\frac{1}{1 + e^{-x_0 \\theta_0 - \\ldots - x_9 \\theta_9 - \\theta_{10}}}$$\n",
|
||
|
"for<br/>\n",
|
||
|
"$x = [1.2, 2.3, 3.4, -0.7, 4.2, 2.7, -0.5, 1.4, -3.3, 0.2]$<br/>\n",
|
||
|
"$\\theta = [1.7, 0.33, -2.12, -1.73, 2.9, -5.8, -0.9, 12.11, 3.43, -0.5, -1.65]$<br/>\n",
|
||
|
"and print the result. Use only tensor operations."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"id": "falling-holder",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Write your code here"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "latter-vector",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Tensor gradients\n",
|
||
|
"\n",
|
||
|
"Tensors are designed to be used in neural networks. Their most important functionality is automatic gradient and backward propagation calculation."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"id": "guided-interface",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"out=35.0\n",
|
||
|
"\n",
|
||
|
"gradient\n",
|
||
|
"tensor([[12., 3.],\n",
|
||
|
" [27., 3.]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x = torch.tensor([[2., -1.], [3., 1.]], requires_grad=True)\n",
|
||
|
"out = x.pow(3).sum() # the actual derivative is 3*x^2\n",
|
||
|
"print(\"out={}\".format(out))\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"out.backward()\n",
|
||
|
"print(\"gradient\")\n",
|
||
|
"print(x.grad)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"id": "nuclear-gothic",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([[ 4., 2., -1.]])\n",
|
||
|
"tensor([[ 2., -1., 3.]])\n",
|
||
|
"tensor([[ 0.1807, 0.0904, -0.0452]])\n",
|
||
|
"tensor([[ 0.0904, -0.0452, 0.1355]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x = torch.tensor([[2., -1., 3.]], requires_grad=True)\n",
|
||
|
"y = torch.tensor([[4., 2., -1.]], requires_grad=True)\n",
|
||
|
"\n",
|
||
|
"z = torch.sum(x * y)\n",
|
||
|
"\n",
|
||
|
"z.backward()\n",
|
||
|
"print(x.grad)\n",
|
||
|
"print(y.grad)\n",
|
||
|
"\n",
|
||
|
"x.grad.data.zero_()\n",
|
||
|
"y.grad.data.zero_()\n",
|
||
|
"\n",
|
||
|
"z = torch.sigmoid(torch.sum(x * y))\n",
|
||
|
"\n",
|
||
|
"z.backward()\n",
|
||
|
"print(x.grad)\n",
|
||
|
"print(y.grad)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "innovative-provider",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Backpropagation\n",
|
||
|
"\n",
|
||
|
"In this section we train weights $w$ of a simple model $y = \\text{sigmoid}(w * x)$ to obtain $y = 0.65$ on $x = [2.0, -1.0, 3.0]$."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"id": "supposed-sellers",
|
||
|
"metadata": {
|
||
|
"scrolled": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"x\n",
|
||
|
"tensor([ 2., -1., 3.])\n",
|
||
|
"x.grad\n",
|
||
|
"None\n",
|
||
|
"w\n",
|
||
|
"tensor([ 4., 2., -1.], requires_grad=True)\n",
|
||
|
"w.grad\n",
|
||
|
"None\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.9945, 2.0027, -1.0082], requires_grad=True)\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 0.0547, -0.0273, 0.0820])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.9526, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(0.0916, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.9889, 2.0055, -1.0166], requires_grad=True)\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 0.0563, -0.0281, 0.0844])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.9508, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(0.0905, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.9831, 2.0084, -1.0253], requires_grad=True)\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 0.0579, -0.0290, 0.0869])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.9489, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(0.0894, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.6599, 2.1701, -1.5102], requires_grad=True)\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 6.1291e-06, -3.0645e-06, 9.1936e-06])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.6500, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(4.5365e-11, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.6599, 2.1701, -1.5102], requires_grad=True)\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 5.0985e-06, -2.5493e-06, 7.6478e-06])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.6500, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(3.1392e-11, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.6599, 2.1701, -1.5102], requires_grad=True)\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 4.4477e-06, -2.2238e-06, 6.6715e-06])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.6500, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(2.3888e-11, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAkGklEQVR4nO3deXhV5bn+8e+TGQiEKUwZCKOIgIoBxFr1J9aCtVILKuBArRbbU22t9LTYns7n9KenA7Z1ONpSiyNYrRa1dZ4HhgCiDIIRZJ6nAIGEJM/5Yy88Md1IgOysPdyf6+JiDe/e+1nXgtxZ77vXu8zdERERaSgt7AJERCQ+KSBERCQqBYSIiESlgBARkagUECIiEpUCQkREolJAiByGmf3TzCY2ddujrOEcM1vX1O8r0hgZYRcg0pTMbG+91ZZAFVAbrF/n7g829r3cfVQs2ookCgWEJBV3zz20bGYfAde6+wsN25lZhrvXNGdtIolGXUySEg511ZjZ981sE3CvmbUzs6fMbKuZ7QyWC+u95hUzuzZY/oqZvWFmvw7arjKzUcfYtoeZvWZme8zsBTO7w8weaORxnBh81i4zW2JmF9Xbd4GZLQ3ed72ZfTfY3jE4tl1mtsPMXjcz/d+XI9I/EkklXYD2QHdgEpF///cG68XAfuD2T3n9MGA50BH4b2CamdkxtH0ImAt0AH4KXNmY4s0sE3gSeA7oBNwAPGhmJwRNphHpRmsNDABeCrZPBtYB+UBn4AeA5tiRI1JASCqpA37i7lXuvt/dt7v7Y+5e6e57gP8Czv6U16929z+6ey0wHehK5Aduo9uaWTEwBPixu1e7+xvArEbWfzqQC9wSvPYl4ClgfLD/INDfzNq4+053X1Bve1egu7sfdPfXXZOwSSMoICSVbHX3A4dWzKylmd1tZqvNrAJ4DWhrZumHef2mQwvuXhks5h5l227AjnrbANY2sv5uwFp3r6u3bTVQECyPAS4AVpvZq2Y2PNj+K6AceM7MVprZlEZ+nqQ4BYSkkoa/NU8GTgCGuXsb4Kxg++G6jZrCRqC9mbWst62oka/dABQ1GD8oBtYDuPs8dx9NpPvpCeCRYPsed5/s7j2Bi4CbzGzE8R2GpAIFhKSy1kTGHXaZWXvgJ7H+QHdfDZQBPzWzrOC3/C828uVzgErge2aWaWbnBK+dEbzX5WaW5+4HgQoiXWqY2YVm1jsYA9lN5Gu/dVE/QaQeBYSkstuAFsA2YDbwTDN97uXAcGA78J/ATCL3a3wqd68mEgijiNR8J3CVu78fNLkS+CjoLvt68DkAfYAXgL3A28Cd7v5ykx2NJC3TWJVIuMxsJvC+u8f8CkbkaOgKQqSZmdkQM+tlZmlmNhIYTWTMQCSu6E5qkebXBfgbkfsg1gHfcPeF4ZYk8q/UxSQiIlGpi0lERKJKmi6mjh07eklJSdhliIgklPnz529z9/xo+5ImIEpKSigrKwu7DBGRhGJmqw+3T11MIiISlQJCRESiUkCIiEhUCggREYlKASEiIlEpIEREJCoFhIiIRJXyAXHgYC0/nbWEbXuPONuyiEhKSfmAWLR2Fw/NXcMFv3ud2Su3h12OiEjcSPmAGNazA0/822fIzc5gwh9n8/sXP6C2ThMYioikfEAA9O/Whlk3nMlFJ3fjt8+v4Kt/mceuyuqwyxIRCZUCIpCbncHUy07hlxcP5K0Pt3HR7W+ydENF2GWJiIRGAVGPmTFhWDEzrxtOdU0dX77rTZ5ctCHsskREQqGAiGJwcTuevOFMBhbkccPDC5n6/Ar0YCURSTUKiMPIb53NA9cOY+xphfzuxQ+4/uGF7K+uDbssEZFmkzTPg4iF7Ix0fjV2EH065XLLM++zYdd+pk0cQvtWWWGXJiISc7qCOAIz47qze3HX5aexdEMFY+96i7U7KsMuS0Qk5hQQjTRyQBcevHYY2/dVc/Gdb7F4/e6wSxIRiSkFxFEoLWnPY984g+yMNM
bfM5u5q3aEXZKISMwoII5S7065PPqN4eS3yeaqP8/h1RVbwy5JRCQmFBDHoGteCx65bjg9O+Zy7fR5PLN4U9gliYg0OQXEMeqYm83Dk05nYEEe1z+0QCEhIklHAXEc8lpkMv2rQxlUqJAQkeSjgDhOrXMiITEwCIlnlygkRCQ5KCCawKGQGFCQxw0PLeT1DzRwLSKJTwHRRNrkZDL96qH0zG/FpPvmM3/1zrBLEhE5LgqIJpTXMpP7rxlG5zbZXH3vXE0XLiIJTQHRxA5N8tcqO4Or/jxX03KISMJSQMRAYbuW3H/NUKpravnKvXP1dDoRSUgKiBjp3ak1f7yqlLU79vO1+8o4cFBThYtIYolpQJjZSDNbbmblZjYlyv5sM5sZ7J9jZiXB9kwzm25m75nZMjO7OZZ1xsqwnh349aUnM++jnUz+6yLq6vTQIRFJHDELCDNLB+4ARgH9gfFm1r9Bs2uAne7eG5gK3BpsvwTIdveBwGnAdYfCI9FcdHI3pozqx9PvbuT3L30QdjkiIo0WyyuIoUC5u69092pgBjC6QZvRwPRg+VFghJkZ4EArM8sAWgDVQMJ+Jei6s3oyZnAht73wAU+/uzHsckREGiWWAVEArK23vi7YFrWNu9cAu4EORMJiH7ARWAP82t3/ZW5tM5tkZmVmVrZ1a/zenGZm/PLLAxhc3JbJf31Hz5IQkYQQr4PUQ4FaoBvQA5hsZj0bNnL3e9y91N1L8/Pzm7vGo5Kdkc7dV5bSvmUWX7uvjO17q8IuSUTkU8UyINYDRfXWC4NtUdsE3Ul5wHZgAvCMux909y3Am0BpDGttFvmts7nnqlJ27KvmWzMWUqtBaxGJY7EMiHlAHzPrYWZZwDhgVoM2s4CJwfJY4CV3dyLdSucCmFkr4HTg/RjW2mwGFOTxiy8N4M3y7fzmueVhlyMiclgxC4hgTOF64FlgGfCIuy8xs5+b2UVBs2lABzMrB24CDn0V9g4g18yWEAmae9393VjV2twuLS1i/NAi7nzlQ57T7K8iEqcs8gt74istLfWysrKwy2i0AwdrueR/3uaj7fv4x7c+S1H7lmGXJCIpyMzmu3vULvx4HaROejmZ6dx5+WBw+NaMhRysrQu7JBGRT1BAhKiofUv+/5iBLFyzi6nPrwi7HBGRT1BAhOzCQd0YN6SIu179kDfLt4VdjojIxxQQceDHX+xPz46t+M7Md9ixTzO/ikh8UEDEgZZZGfxh/GB2VlbzH0+8R7J8cUBEEpsCIk7079aG73yuL/94bxOzFm0IuxwREQVEPLnurF4MLm7Lj55YzKbdB8IuR0RSnAIijqSnGb+59BQO1jrfe+xddTWJSKgUEHGmR8dW3HxBP15bsZW/lq0LuxwRSWEKiDh0xbDuDO3Rnv98eilbKtTVJCLhUEDEobQ045YvD+RATR0/mbUk7HJEJEUpIOJUz/xcvj2iD/9cvIlnFmtCPxFpfgqIODbprJ6c2LUNP/77YnbvPxh2OSKSYhQQcSwzPY1bxwxk294qfqtnR4hIM1NAxLlBhW254vTu3D97tZ5lLSLNSgGRACaffwLtWmbxo78vpk6PKRWRZqKASAB5LTK5+YITWbhmF3+dvzbsckQkRSggEsSYwQUMKWnHLf98n52a8VVEmoECIkGYGb/40gB27z/IbS/o4UIiEnsKiATSr0sbJgwr5oE5a/hg856wyxGRJKeASDDfOa8vLbPS+c+nl4VdiogkOQVEgumQm823R/Th1RVbeXn5lrDLEZEkpoBIQFcNL6GkQ0v+6+llHKytC7scEUlSCogElJWRxg+/0J/yLXt5eO6asMsRkSSlgEhQ553YiaE92vP7Fz9gX1VN2OWISBJSQCQoM2PKqH5s21vNtDdWhV2OiCQhBUQCG1zcjs+f1Jm7X/2Q7Xurwi5HRJKMAiLB/fvnT2D/wVpuf7k87FJEJMkoIBJc706tubS0iAdmr2btjsqwyxGRJKKASAI3nteXNDNue+
GDsEsRkSSigEgCXfJyuOL07jy+cB2rtu0LuxwRSRIKiCTx9bN7kZWRxu9f1FWEiDQNBUSSyG+dzcThJfz9nfWUb9F
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x = torch.tensor([2., -1., 3.], requires_grad=False)\n",
|
||
|
"w = torch.tensor([4., 2., -1.], requires_grad=True)\n",
|
||
|
"y_target = 0.65\n",
|
||
|
"\n",
|
||
|
"print(\"x\")\n",
|
||
|
"print(x)\n",
|
||
|
"print(\"x.grad\")\n",
|
||
|
"print(x.grad)\n",
|
||
|
"print(\"w\")\n",
|
||
|
"print(w)\n",
|
||
|
"print(\"w.grad\")\n",
|
||
|
"print(w.grad)\n",
|
||
|
"print()\n",
|
||
|
"print()\n",
|
||
|
"\n",
|
||
|
"optimizer = optim.SGD([w], lr=0.1)\n",
|
||
|
"\n",
|
||
|
"losses = []\n",
|
||
|
"n_epochs = 100\n",
|
||
|
"for epoch in range(n_epochs):\n",
|
||
|
"\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" y = torch.sigmoid(torch.sum(x * w))\n",
|
||
|
" loss = torch.pow(y - y_target, 2)\n",
|
||
|
" loss.backward()\n",
|
||
|
" losses.append(loss.item())\n",
|
||
|
" optimizer.step()\n",
|
||
|
"\n",
|
||
|
" if epoch < 3 or epoch > 96:\n",
|
||
|
" print(\"w\")\n",
|
||
|
" print(w)\n",
|
||
|
" print(\"w.grad\")\n",
|
||
|
" print(w.grad)\n",
|
||
|
" print(\"y\")\n",
|
||
|
" print(y)\n",
|
||
|
" print(\"loss\")\n",
|
||
|
" print(loss)\n",
|
||
|
" print()\n",
|
||
|
" print()\n",
|
||
|
" \n",
|
||
|
"sns.lineplot(x=np.arange(n_epochs), y=losses).set_title('Training loss')\n",
|
||
|
"plt.xlabel(\"epoch\")\n",
|
||
|
"plt.ylabel(\"loss\")\n",
|
||
|
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "addressed-anxiety",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Proper PyTorch model with a fully-connected layer\n",
|
||
|
"\n",
|
||
|
"A fully-connected layer is represented by torch.nn.Linear. Its parameters are:\n",
|
||
|
" - in_features - the number of input neurons,\n",
|
||
|
" - out_features - the number of output neurons,\n",
|
||
|
" - bias - boolean if bias should be included.\n",
|
||
|
" \n",
|
||
|
"Documentation: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"id": "lovely-wesley",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class FullyConnectedNetworkModel(nn.Module):\n",
|
||
|
" def __init__(self, seed):\n",
|
||
|
" super().__init__()\n",
|
||
|
"\n",
|
||
|
" self.seed = torch.manual_seed(seed)\n",
|
||
|
"\n",
|
||
|
" self.fc = nn.Linear(3, 1, bias=False)\n",
|
||
|
"\n",
|
||
|
" self.fc.weight.data = torch.tensor([4., 2., -1.], requires_grad=True)\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = torch.sigmoid(self.fc(x))\n",
|
||
|
"\n",
|
||
|
" return x"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"id": "hourly-apollo",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"w\n",
|
||
|
"tensor([ 3.9945, 2.0027, -1.0082])\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 0.0547, -0.0273, 0.0820])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.9526, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(0.0916, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.9889, 2.0055, -1.0166])\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 0.0563, -0.0281, 0.0844])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.9508, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(0.0905, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.9831, 2.0084, -1.0253])\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 0.0579, -0.0290, 0.0869])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.9489, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(0.0894, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.6599, 2.1701, -1.5102])\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 6.1291e-06, -3.0645e-06, 9.1936e-06])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.6500, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(4.5365e-11, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.6599, 2.1701, -1.5102])\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 5.0985e-06, -2.5493e-06, 7.6478e-06])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.6500, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(3.1392e-11, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"w\n",
|
||
|
"tensor([ 3.6599, 2.1701, -1.5102])\n",
|
||
|
"w.grad\n",
|
||
|
"tensor([ 4.4477e-06, -2.2238e-06, 6.6715e-06])\n",
|
||
|
"y\n",
|
||
|
"tensor(0.6500, grad_fn=<SigmoidBackward>)\n",
|
||
|
"loss\n",
|
||
|
"tensor(2.3888e-11, grad_fn=<PowBackward0>)\n",
|
||
|
"\n",
|
||
|
"\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAkGklEQVR4nO3deXhV5bn+8e+TGQiEKUwZCKOIgIoBxFr1J9aCtVILKuBArRbbU22t9LTYns7n9KenA7Z1ONpSiyNYrRa1dZ4HhgCiDIIRZJ6nAIGEJM/5Yy88Md1IgOysPdyf6+JiDe/e+1nXgtxZ77vXu8zdERERaSgt7AJERCQ+KSBERCQqBYSIiESlgBARkagUECIiEpUCQkREolJAiByGmf3TzCY2ddujrOEcM1vX1O8r0hgZYRcg0pTMbG+91ZZAFVAbrF/n7g829r3cfVQs2ookCgWEJBV3zz20bGYfAde6+wsN25lZhrvXNGdtIolGXUySEg511ZjZ981sE3CvmbUzs6fMbKuZ7QyWC+u95hUzuzZY/oqZvWFmvw7arjKzUcfYtoeZvWZme8zsBTO7w8weaORxnBh81i4zW2JmF9Xbd4GZLQ3ed72ZfTfY3jE4tl1mtsPMXjcz/d+XI9I/EkklXYD2QHdgEpF///cG68XAfuD2T3n9MGA50BH4b2CamdkxtH0ImAt0AH4KXNmY4s0sE3gSeA7oBNwAPGhmJwRNphHpRmsNDABeCrZPBtYB+UBn4AeA5tiRI1JASCqpA37i7lXuvt/dt7v7Y+5e6e57gP8Czv6U16929z+6ey0wHehK5Aduo9uaWTEwBPixu1e7+xvArEbWfzqQC9wSvPYl4ClgfLD/INDfzNq4+053X1Bve1egu7sfdPfXXZOwSSMoICSVbHX3A4dWzKylmd1tZqvNrAJ4DWhrZumHef2mQwvuXhks5h5l227AjnrbANY2sv5uwFp3r6u3bTVQECyPAS4AVpvZq2Y2PNj+K6AceM7MVprZlEZ+nqQ4BYSkkoa/NU8GTgCGuXsb4Kxg++G6jZrCRqC9mbWst62oka/dABQ1GD8oBtYDuPs8dx9NpPvpCeCRYPsed5/s7j2Bi4CbzGzE8R2GpAIFhKSy1kTGHXaZWXvgJ7H+QHdfDZQBPzWzrOC3/C828uVzgErge2aWaWbnBK+dEbzX5WaW5+4HgQoiXWqY2YVm1jsYA9lN5Gu/dVE/QaQeBYSkstuAFsA2YDbwTDN97uXAcGA78J/ATCL3a3wqd68mEgijiNR8J3CVu78fNLkS+CjoLvt68DkAfYAXgL3A28Cd7v5ykx2NJC3TWJVIuMxsJvC+u8f8CkbkaOgKQqSZmdkQM+tlZmlmNhIYTWTMQCSu6E5qkebXBfgbkfsg1gHfcPeF4ZYk8q/UxSQiIlGpi0lERKJKmi6mjh07eklJSdhliIgklPnz529z9/xo+5ImIEpKSigrKwu7DBGRhGJmqw+3T11MIiISlQJCRESiUkCIiEhUCggREYlKASEiIlEpIEREJCoFhIiIRJXyAXHgYC0/nbWEbXuPONuyiEhKSfmAWLR2Fw/NXcMFv3ud2Su3h12OiEjcSPmAGNazA0/822fIzc5gwh9n8/sXP6C2ThMYioikfEAA9O/Whlk3nMlFJ3fjt8+v4Kt/mceuyuqwyxIRCZUCIpCbncHUy07hlxcP5K0Pt3HR7W+ydENF2GWJiIRGAVGPmTFhWDEzrxtOdU0dX77rTZ5ctCHsskREQqGAiGJwcTuevOFMBhbkccPDC5n6/Ar0YCURSTUKiMPIb53NA9cOY+xphfzuxQ+4/uGF7K+uDbssEZFmkzTPg4iF7Ix0fjV2EH065XLLM++zYdd+pk0cQvtWWWGXJiISc7qCOAIz47qze3HX5aexdEMFY+96i7U7KsMuS0Qk5hQQjTRyQBcevHYY2/dVc/Gdb7F4/e6wSxIRiSkFxFEoLWnPY984g+yMNM
bfM5u5q3aEXZKISMwoII5S7065PPqN4eS3yeaqP8/h1RVbwy5JRCQmFBDHoGteCx65bjg9O+Zy7fR5PLN4U9gliYg0OQXEMeqYm83Dk05nYEEe1z+0QCEhIklHAXEc8lpkMv2rQxlUqJAQkeSjgDhOrXMiITEwCIlnlygkRCQ5KCCawKGQGFCQxw0PLeT1DzRwLSKJTwHRRNrkZDL96qH0zG/FpPvmM3/1zrBLEhE5LgqIJpTXMpP7rxlG5zbZXH3vXE0XLiIJTQHRxA5N8tcqO4Or/jxX03KISMJSQMRAYbuW3H/NUKpravnKvXP1dDoRSUgKiBjp3ak1f7yqlLU79vO1+8o4cFBThYtIYolpQJjZSDNbbmblZjYlyv5sM5sZ7J9jZiXB9kwzm25m75nZMjO7OZZ1xsqwnh349aUnM++jnUz+6yLq6vTQIRFJHDELCDNLB+4ARgH9gfFm1r9Bs2uAne7eG5gK3BpsvwTIdveBwGnAdYfCI9FcdHI3pozqx9PvbuT3L30QdjkiIo0WyyuIoUC5u69092pgBjC6QZvRwPRg+VFghJkZ4EArM8sAWgDVQMJ+Jei6s3oyZnAht73wAU+/uzHsckREGiWWAVEArK23vi7YFrWNu9cAu4EORMJiH7ARWAP82t3/ZW5tM5tkZmVmVrZ1a/zenGZm/PLLAxhc3JbJf31Hz5IQkYQQr4PUQ4FaoBvQA5hsZj0bNnL3e9y91N1L8/Pzm7vGo5Kdkc7dV5bSvmUWX7uvjO17q8IuSUTkU8UyINYDRfXWC4NtUdsE3Ul5wHZgAvCMux909y3Am0BpDGttFvmts7nnqlJ27KvmWzMWUqtBaxGJY7EMiHlAHzPrYWZZwDhgVoM2s4CJwfJY4CV3dyLdSucCmFkr4HTg/RjW2mwGFOTxiy8N4M3y7fzmueVhlyMiclgxC4hgTOF64FlgGfCIuy8xs5+b2UVBs2lABzMrB24CDn0V9g4g18yWEAmae9393VjV2twuLS1i/NAi7nzlQ57T7K8iEqcs8gt74istLfWysrKwy2i0AwdrueR/3uaj7fv4x7c+S1H7lmGXJCIpyMzmu3vULvx4HaROejmZ6dx5+WBw+NaMhRysrQu7JBGRT1BAhKiofUv+/5iBLFyzi6nPrwi7HBGRT1BAhOzCQd0YN6SIu179kDfLt4VdjojIxxQQceDHX+xPz46t+M7Md9ixTzO/ikh8UEDEgZZZGfxh/GB2VlbzH0+8R7J8cUBEEpsCIk7079aG73yuL/94bxOzFm0IuxwREQVEPLnurF4MLm7Lj55YzKbdB8IuR0RSnAIijqSnGb+59BQO1jrfe+xddTWJSKgUEHGmR8dW3HxBP15bsZW/lq0LuxwRSWEKiDh0xbDuDO3Rnv98eilbKtTVJCLhUEDEobQ045YvD+RATR0/mbUk7HJEJEUpIOJUz/xcvj2iD/9cvIlnFmtCPxFpfgqIODbprJ6c2LUNP/77YnbvPxh2OSKSYhQQcSwzPY1bxwxk294qfqtnR4hIM1NAxLlBhW254vTu3D97tZ5lLSLNSgGRACaffwLtWmbxo78vpk6PKRWRZqKASAB5LTK5+YITWbhmF3+dvzbsckQkRSggEsSYwQUMKWnHLf98n52a8VVEmoECIkGYGb/40gB27z/IbS/o4UIiEnsKiATSr0sbJgwr5oE5a/hg856wyxGRJKeASDDfOa8vLbPS+c+nl4VdiogkOQVEgumQm823R/Th1RVbeXn5lrDLEZEkpoBIQFcNL6GkQ0v+6+llHKytC7scEUlSCogElJWRxg+/0J/yLXt5eO6asMsRkSSlgEhQ553YiaE92vP7Fz9gX1VN2OWISBJSQCQoM2PKqH5s21vNtDdWhV2OiCQhBUQCG1zcjs+f1Jm7X/2Q7Xurwi5HRJKMAiLB/fvnT2D/wVpuf7k87FJEJMkoIBJc706tubS0iAdmr2btjsqwyxGRJKKASAI3nteXNDNue+
GDsEsRkSSigEgCXfJyuOL07jy+cB2rtu0LuxwRSRIKiCTx9bN7kZWRxu9f1FWEiDQNBUSSyG+dzcThJfz9nfWUb9F
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x = torch.tensor([2., -1., 3.])\n",
|
||
|
"y_target = 0.65\n",
|
||
|
"\n",
|
||
|
"fc_neural_net = FullyConnectedNetworkModel(seed=6789)\n",
|
||
|
"\n",
|
||
|
"optimizer = optim.SGD(fc_neural_net.parameters(), lr=0.1)\n",
|
||
|
"\n",
|
||
|
"losses = []\n",
|
||
|
"n_epochs = 100\n",
|
||
|
"for epoch in range(n_epochs):\n",
|
||
|
"\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" y = fc_neural_net(x)\n",
|
||
|
" loss = torch.pow(y - y_target, 2)\n",
|
||
|
" loss.backward()\n",
|
||
|
" losses.append(loss.item())\n",
|
||
|
" optimizer.step()\n",
|
||
|
" \n",
|
||
|
" if epoch < 3 or epoch > 96:\n",
|
||
|
" print(\"w\")\n",
|
||
|
" print(fc_neural_net.fc.weight.data)\n",
|
||
|
" print(\"w.grad\")\n",
|
||
|
" print(next(fc_neural_net.parameters()).grad)\n",
|
||
|
" print(\"y\")\n",
|
||
|
" print(y)\n",
|
||
|
" print(\"loss\")\n",
|
||
|
" print(loss)\n",
|
||
|
" print()\n",
|
||
|
" print()\n",
|
||
|
" \n",
|
||
|
"sns.lineplot(x=np.arange(n_epochs), y=losses).set_title('Training loss')\n",
|
||
|
"plt.xlabel(\"epoch\")\n",
|
||
|
"plt.ylabel(\"loss\")\n",
|
||
|
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "breeding-sailing",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Embedding layer\n",
|
||
|
"\n",
|
||
|
"An embedding layer is represented by torch.nn.Embedding. Its main parameters are:\n",
|
||
|
" - num_embeddings - the number of ids to embed,\n",
|
||
|
" - embedding_dim - the dimension of the embedding vector.\n",
|
||
|
" \n",
|
||
|
"Documentation: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html\n",
|
||
|
"\n",
|
||
|
"In the example below we will have 3 movies and 3 users. The movies have already trained representations:\n",
|
||
|
" - $m0 = [0.6, 0.4, -0.2]$\n",
|
||
|
" - $m1 = [-0.7, 0.8, -0.7]$\n",
|
||
|
" - $m2 = [0.8, -0.75, 0.9]$\n",
|
||
|
"where the three dimensions represent: level of violence, positive message, foul language.\n",
|
||
|
"\n",
|
||
|
"We want to find user embeddings so that:\n",
|
||
|
" - user 0 likes movie 0 and dislikes movie 1 and 2,\n",
|
||
|
" - user 1 likes movie 1 and dislikes movie 0 and 2,\n",
|
||
|
" - user 2 likes movie 2 and dislikes movie 0 and 1."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"id": "posted-performer",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class EmbeddingNetworkModel(nn.Module):\n",
|
||
|
" def __init__(self, seed):\n",
|
||
|
" super().__init__()\n",
|
||
|
"\n",
|
||
|
" self.seed = torch.manual_seed(seed)\n",
|
||
|
"\n",
|
||
|
" self.embedding = nn.Embedding(3, 3)\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" user_id = x[0]\n",
|
||
|
" item_repr = x[1]\n",
|
||
|
" \n",
|
||
|
" y = self.embedding(user_id) * item_repr\n",
|
||
|
" y = torch.sum(y)\n",
|
||
|
" y = torch.sigmoid(y)\n",
|
||
|
"\n",
|
||
|
" return y"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"id": "pleased-distributor",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAeH0lEQVR4nO3de5RdZZ3m8e9z6pJUpSqVWxEgVyCICEu5RCDSrYy2DNAYXEtsYRTBhpUZl91iD91qtAcv071W2zqirajQ0IqIigLamFaQS1CYaYIVDNcQCTcJJKRyT8ilbr/5Y+8Kh0pVUlWpXSc57/NZ66zss/d7zvnt7KSeeve793sUEZiZWbpKlS7AzMwqy0FgZpY4B4GZWeIcBGZmiXMQmJklzkFgZpY4B4ElT9KvJF080m2HWMMZklaN9PuaDUZtpQswGw5J28qeNgK7gO78+X+PiJsG+14RcXYRbc0OFg4COyhFRFPvsqTngcsi4u6+7STVRkTXaNZmdrDxqSGrKr2nWCR9StIa4LuSJkpaJKld0sZ8eXrZa+6TdFm+fImkByR9JW/7nKSzh9n2CEm/lbRV0t2Srpb0g0Hux7H5Z22S9ISk+WXbzpH0ZP6+L0n623z9lHzfNknaIOl+Sf4/bvvkfyRWjQ4FJgGzgAVk/86/mz+fCewAvrmX158KrACmAP8MXC9Jw2j7Q+AhYDLweeCiwRQvqQ74BfBr4BDgr4GbJB2TN7me7PRXM3A8cG++/gpgFdAKTAU+A3gOGdsnB4FVox7gcxGxKyJ2RMT6iLg1IrZHxFbgH4F37OX1L0TEv0ZEN3ADcBjZD9ZBt5U0E3grcGVEdETEA8Dtg6z/NKAJ+Kf8tfcCi4AL8+2dwJskjY+IjRHxcNn6w4BZEdEZEfeHJxOzQXAQWDVqj4idvU8kNUq6RtILkrYAvwUmSKoZ4PVrehciYnu+2DTEtocDG8rWAbw4yPoPB16MiJ6ydS8A0/Ll9wHnAC9I+o2kefn6LwMrgV9LelbSpwf5eZY4B4FVo76/BV8BHAOcGhHjgbfn6wc63TMSVgOTJDWWrZsxyNe+DMzoc35/JvASQET8LiLOIztt9HPgJ/n6rRFxRUQcCcwH/qekd+3fblgKHASWgmaycYFNkiYBnyv6AyPiBaAN+Lyk+vy39vcM8uVLgO3AJyXVSTojf+2P8/f6oKSWiOgEtpCdCkPSuZLm5GMUm8kup+3p9xPMyjgILAVfAxqAdcCDwB2j9LkfBOYB64F/AG4mu99hryKig+wH/9lkNX8L+HBEPJU3uQh4Pj/N9T/yzwE4Grgb2Ab8J/CtiFg8YntjVUseSzIbHZJuBp6KiMJ7JGZD4R6BWUEkvVXSUZJKks4CziM7p292QPGdxWbFORS4jew+glXARyPi95UtyWxPPjVkZpY4nxoyM0vcQXdqaMqUKTF79uxKl2FmdlBZunTpuoho7W/bQRcEs2fPpq2trdJlmJkdVCS9MNA2nxoyM0ucg8DMLHEOAjOzxDkIzMwS5yAwM0ucg8DMLHEOAjOzxCUTBCvWbOUrd65g46sdlS7FzOyAkkwQPLfuVb65eCWrN+/cd2Mzs4QkEwQtDXUAbNrhHoGZWblkgmBCYxYEm7d3VrgSM7MDS3JBsGmHg8DMrFw6QdBQD8Am9wjMzF4nmSAYW1eivrbkMQIzsz6SCQJJTGio8xiBmVkfyQQBZOMEPjVkZvZ6aQVBQz2bPVhsZvY6SQVBS2OdrxoyM+sjqSDIxgg8WGxmVi6tIHCPwMxsD4kFQT3bO7rZ1dVd6VLMzA4YSQVB73xDHjA2M3tNmkHgS0jNzHYrPAgk1Uj6vaRF/Wy7RFK7pGX547Iia/F8Q2Zme6odhc+4HFgOjB9g+80R8VejUMfu+YbcIzAze02hPQJJ04E/B64r8nMGq7dHsNGXkJqZ7Vb0qaGvAZ8EevbS5n2SHpV0i6QZ/TWQtEBSm6S29vb2YRczaVzWI9
jgr6s0M9utsCCQdC6wNiKW7qXZL4DZEfFm4C7ghv4aRcS1ETE3Iua2trYOu6bG+hrG1pUcBGZmZYrsEZwOzJf0PPBj4J2SflDeICLWR8Su/Ol1wMkF1oMkJo8bw7ptDgIzs16FBUFELIyI6RExG7gAuDciPlTeRtJhZU/nkw0qF2pyUz3rX92174ZmZokYjauGXkfSF4G2iLgd+Lik+UAXsAG4pOjPnzyu3j0CM7MyoxIEEXEfcF++fGXZ+oXAwtGoodfkpjGsWLN1ND/SzOyAltSdxdB7aqiDiKh0KWZmB4T0gmBcPbu6eni1wxPPmZlBkkEwBoD12zxgbGYGKQZBU3ZTmQeMzcwyyQXBlKasR+CbyszMMskFQe80Ez41ZGaWSTcI3CMwMwMSDIKxdTU0j6llnXsEZmZAgkEA+b0EHiw2MwMSDYJJ4+o9WGxmlksyCCY3jfGpITOzXJJBMKXJE8+ZmfVKMggmjxvDxu0d9PR4viEzsySDYEpTPd094e8uNjMj1SBozucb8oCxmVmaQdA78dy6rR4wNjNLMgham7O7i9t95ZCZWZpB0DvxnK8cMjNLNAhaGuqoLckTz5mZkWgQSGJyU71vKjMzYxSCQFKNpN9LWtTPtjGSbpa0UtISSbOLrqfXlKYxPjVkZsbo9AguB5YPsO1SYGNEzAGuAr40CvUAvUHgHoGZWaFBIGk68OfAdQM0OQ+4IV++BXiXJBVZU68pTWM8A6mZGcX3CL4GfBLoGWD7NOBFgIjoAjYDk/s2krRAUpuktvb29hEpbEpTPe3bdhHhaSbMLG2FBYGkc4G1EbF0f98rIq6NiLkRMbe1tXUEqst6BB1dPWzd1TUi72dmdrAqskdwOjBf0vPAj4F3SvpBnzYvATMAJNUCLcD6AmvabUp+U5nvLjaz1BUWBBGxMCKmR8Rs4ALg3oj4UJ9mtwMX58vn521G5VxN701lnm/IzFJXO9ofKOmLQFtE3A5cD9woaSWwgSwwRoXnGzIzy4xKEETEfcB9+fKVZet3Au8fjRr62n1qyJeQmlnikryzGGBSYz0StPsSUjNLXLJBUFtTYlJjvecbMrPkJRsEgOcbMjMj8SDwfENmZg4C9wjMLHnJB4HnGzKz1CUdBJOb6tm2q4udnd2VLsXMrGKSDoLW/O7idt9UZmYJSzoIfFOZmVnqQdA735DHCcwsYUkHweQ8CNwjMLOUpR0E43xqyMws6SAYW1dD89ha31RmZklLOgggu3LIPQIzS1nyQeD5hswsdckHgecbMrPUOQh8asjMEpd8EEwaV8/mHZ1094zKVyWbmR1wkg+CCY11RMCWHZ2VLsXMrCKSD4KJjdm9BBu3e5zAzNJUWBBIGivpIUmPSHpC0hf6aXOJpHZJy/LHZUXVM5CWxjoANrlHYGaJqi3wvXcB74yIbZLqgAck/SoiHuzT7uaI+KsC69ir3h7BJvcIzCxRhQVBRASwLX9alz8OuBHZCQ15j2C7ewRmlqZCxwgk1UhaBqwF7oqIJf00e5+kRyXdImnGAO+zQFKbpLb29vYRrfG1MQIHgZmlqdAgiIjuiDgBmA6cIun4Pk1+AcyOiDcDdwE3DPA+10bE3IiY29raOqI1No+tpSSfGjKzdI3KVUMRsQlYDJzVZ/36iOi9m+s64OTRqKdcqSRaGup8asjMklXkVUOtkibkyw3Au4Gn+rQ5rOzpfGB5UfXszYTGel8+ambJKvKqocOAGyTVkAXOTyJikaQvAm0RcTvwcUnzgS5gA3BJgfUMaEJjHZt9+aiZJarIq4YeBU7sZ/2VZcsLgYVF1TBYExvrWbt1Z6XLMDOriOTvLIbsElKPEZhZqhwEZGMEDgIzS5WDgGyMYNuuLjq6eipdipnZqHMQABPz+YY8YGxmKXIQAC2eb8jMEuYg4LX5htwjMLMUOQiAFk88Z2YJcxDwWhC4R2BmKXIQkF01BA4CM0uTgwBoHusgMLN0OQiAmpJoHlvrIDCzJA0qCC
RdLmm8MtdLeljSmUUXN5paGjzxnJmlabA9gr+MiC3AmcBE4CLgnwqrqgI8A6mZpWqwQaD8z3OAGyPiibJ1VcE9AjN
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"user_ids = [torch.tensor(0), torch.tensor(1), torch.tensor(2)]\n",
|
||
|
"items = [torch.tensor([0.6, 0.4, -0.2]), \n",
|
||
|
" torch.tensor([-0.7, 0.8, -0.7]), \n",
|
||
|
" torch.tensor([0.8, -0.75, 0.9])]\n",
|
||
|
"responses = [1, 0, 0, 0, 1, 0, 0, 0, 1]\n",
|
||
|
"data = [(user_ids[user_id], items[item_id]) for user_id in range(3) for item_id in range(3)]\n",
|
||
|
"\n",
|
||
|
"embedding_nn = EmbeddingNetworkModel(seed=6789)\n",
|
||
|
"\n",
|
||
|
"optimizer = optim.SGD(embedding_nn.parameters(), lr=0.1)\n",
|
||
|
"\n",
|
||
|
"losses = []\n",
|
||
|
"n_epochs = 1000\n",
|
||
|
"for epoch in range(n_epochs):\n",
|
||
|
"\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" \n",
|
||
|
" for i in range(len(data)):\n",
|
||
|
" user_id = data[i][0]\n",
|
||
|
" item_repr = data[i][1]\n",
|
||
|
" \n",
|
||
|
" y = embedding_nn((user_id, item_repr))\n",
|
||
|
" if i == 0:\n",
|
||
|
" loss = torch.pow(y - responses[i], 2)\n",
|
||
|
" else:\n",
|
||
|
" loss += torch.pow(y - responses[i], 2)\n",
|
||
|
" \n",
|
||
|
" for param in embedding_nn.parameters():\n",
|
||
|
" loss += 1 / 5 * torch.norm(param)\n",
|
||
|
" \n",
|
||
|
" loss.backward()\n",
|
||
|
" losses.append(loss.item())\n",
|
||
|
" optimizer.step()\n",
|
||
|
"\n",
|
||
|
"sns.lineplot(x=np.arange(n_epochs), y=losses).set_title('Training loss')\n",
|
||
|
"plt.xlabel(\"epoch\")\n",
|
||
|
"plt.ylabel(\"loss\")\n",
|
||
|
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"id": "turkish-thinking",
|
||
|
"metadata": {
|
||
|
"scrolled": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Embedding for user 0\n",
|
||
|
"tensor([ 0.9887, 0.2676, -0.7881], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 0\n",
|
||
|
"tensor([ 0.6000, 0.4000, -0.2000])\n",
|
||
|
"Score=0.7\n",
|
||
|
"\n",
|
||
|
"Embedding for user 0\n",
|
||
|
"tensor([ 0.9887, 0.2676, -0.7881], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 1\n",
|
||
|
"tensor([-0.7000, 0.8000, -0.7000])\n",
|
||
|
"Score=0.52\n",
|
||
|
"\n",
|
||
|
"Embedding for user 0\n",
|
||
|
"tensor([ 0.9887, 0.2676, -0.7881], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 2\n",
|
||
|
"tensor([ 0.8000, -0.7500, 0.9000])\n",
|
||
|
"Score=0.47\n",
|
||
|
"\n",
|
||
|
"Embedding for user 1\n",
|
||
|
"tensor([-1.7678, 0.1267, -0.4628], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 0\n",
|
||
|
"tensor([ 0.6000, 0.4000, -0.2000])\n",
|
||
|
"Score=0.29\n",
|
||
|
"\n",
|
||
|
"Embedding for user 1\n",
|
||
|
"tensor([-1.7678, 0.1267, -0.4628], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 1\n",
|
||
|
"tensor([-0.7000, 0.8000, -0.7000])\n",
|
||
|
"Score=0.84\n",
|
||
|
"\n",
|
||
|
"Embedding for user 1\n",
|
||
|
"tensor([-1.7678, 0.1267, -0.4628], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 2\n",
|
||
|
"tensor([ 0.8000, -0.7500, 0.9000])\n",
|
||
|
"Score=0.13\n",
|
||
|
"\n",
|
||
|
"Embedding for user 2\n",
|
||
|
"tensor([-0.2462, -1.4256, 1.1095], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 0\n",
|
||
|
"tensor([ 0.6000, 0.4000, -0.2000])\n",
|
||
|
"Score=0.28\n",
|
||
|
"\n",
|
||
|
"Embedding for user 2\n",
|
||
|
"tensor([-0.2462, -1.4256, 1.1095], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 1\n",
|
||
|
"tensor([-0.7000, 0.8000, -0.7000])\n",
|
||
|
"Score=0.15\n",
|
||
|
"\n",
|
||
|
"Embedding for user 2\n",
|
||
|
"tensor([-0.2462, -1.4256, 1.1095], grad_fn=<EmbeddingBackward>)\n",
|
||
|
"Representation for item 2\n",
|
||
|
"tensor([ 0.8000, -0.7500, 0.9000])\n",
|
||
|
"Score=0.87\n",
|
||
|
"\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD4CAYAAADrRI2NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA1RElEQVR4nO3deXwT5dbA8d9pCxRKW9oCZZcdARVQFlFBNhFRQcWruIKiqC/iiveieL3u4vV6cb0qKorivoOIrKKCrEqFgqLs+9oCLbSlbc77R4aaQksT0iZNc7585pOZZ57MnAnTkyfPPJmIqmKMMSZ8RAQ7AGOMMYFlid8YY8KMJX5jjAkzlviNMSbMWOI3xpgwExXsAPyVs3ahDUsqY7mvPh3sECq8lz6OCXYIYWH0xkni7zZy96zzOudUqtnU7/2VBWvxG2NMmAn5Fr8xxgSUKz/YEfjNEr8xxvgiPy/YEfjNEr8xxvhA1RXsEPxmffzGGOMLl8v7qQQiMkFEdolIajHrRUReEJE1IrJcRE4vjUOwxG+MMb5Ql/dTyd4G+h1n/QVAC2caDrzid/xYV48xxvimFC/uquoPItL4OFUGAu+o+26aC0WkhojUVdXt/uzXEr8xxvgisH389YHNHstbnDJL/MYYEyjqw6geERmOu4vmiPGqOr7Ug/KRJX5jjPGFFxdtj3CSvD+JfivQ0GO5gVPmF7u4a4wxvijdi7slmQxc74zuORPY72//PliL3xhjfFOKF3dF5AOgB1BTRLYA/wIqAajqq8A3QH9gDXAIuKE09muJ3xhjfFGKF3dV9aoS1iswotR26LDEX4rmLV3O06+9h8vl4rLzz2XYFRcVWr9t5x4eeu5N0vcfID62Ok/edwt1aiYCMG7CR/yw5FcAbhk8kH7ndgl4/KEgslUHqgwYBhER5C6eRe53nxdaX6n7ACp17oO68tHMA+R8/BK6bzeRzU6h8oAbC+pF1KpP9nvPkr9ycaAPIWT0efg6mvVsT25WDlNHjWdn6oZC6yvHRHPNJ/8sWI6tm8jKL+Yz+9FJxNVPov8zw6mWGEv2voNMuesVMnakBfgIyojdssEckZ/v4sn/vcP4J/5Ocs1ErrrrYXqc2YFmjeoX1Hn2zQ+5uPfZDOxzDotSVvHCW5/w5H238MPiFH5bs5FPXnqMw7l5DPvHU5zT6TSqV6saxCMqhySCKpcOJ2v8w+j+vVS949/krVyM7tpSUMW1dR2Hnh8FuYeJ6no+lS+8npz3niV/bSpZ4+5xV6panZjR/yP/j5TgHEcIaNqzHQlN6vDaufdSr0Mzzn98KO9c8nChOocPZvNW/zEFy0O/fow/vl0CQK8xV5P62TxSP/uRk85qw7n/uIKv7341kIdQdny4uFtelbuLuyIyRET+dKYhwY7HW6l/rKNRvWQa1K1NpUpR9Ovehe8W/FKozrpNW+nSrjUAndu15ruF7vVrN23jjFNaERUZSbXoKrRs0pD5S5cH/BjKu4hGLXDt2Y6m7YT8PPJS5hHVtnOhOvlrUyH3MACujX8QUSPpmO1EndaVvN9/KahnjtXivDNI/WweANuWraVKXAwxtWsUWz+hSR2qJcWxefFqAJJa1GfjTysB2PjTKlqcd0aZxxwoqvleT+VVUBK/iBT5SUNEEnFf3OgCdAb+JSIJgYztRO3cm06y020DkFwzkV170wvVadmkEbPm/wzA7J9+5mBWNvsOZNKqaUPm/7ycrOwc0vdnsHj5b+zYU0E+FpciiUtE9+0pWNb9e5H4YxP7EVGd+7gT/NHl7buRlzKvTGKsKGLrJJCxbW/BcsaONGKTi/9TbHPxmfz29cKC5V2/baJVv04AtOzXkSqxVYmuUb3sAg6kwI7qKRNeJX4Raex5EyERGSUiD4vIHSKyyrl50IfOuhjnxkOLRWSZiAx0yoeKyGQRmQPMLmZX5wMzVTVNVdOBmRRxHwsRGS4iS0Vk6RsffunbEQfRvTcN5u
fU37ni9n+ydMXv1E5KICJCOOv0UzmnUzuuH/U4/3j6Fdqd3JyIiHL3YSykRJ1+LpENmpE798tC5RKbQGSdRuSvXhacwCqo1gO6suqrBQXL3z3+Pg3PPJkbvnmcRl1ac2B7GloBukiAUr1JW7D428c/GmiiqjkiUsMpGwPMUdUbnbLFIjLLWXc6cJqqFtecLe7ryYV4fimivPz0YnJSAjs9Wuk796RRO6lwC6l2UgLjHrwDgENZ2cyav5S46u6f3Bs+eADDBw8A4B9Pv0Lj+nUCFHno0ANpSI2aBcsSn4Tu33tMvcgWp1G51+VkvfLgMRfiotqdTV7qogrxYxql7fTr+9BucE8Ati9fR2y9vz5NxdZJJGNnepHPq926ERGREYUu/mbu2scXtzwPQKVqVWh5QSdyDhwqu+ADqRy35L3lb7NyOfCeiFwLHPkL6wuMFpEUYC4QDTRy1s08TtIPaW1bNmHjtp1s2bGb3Nw8vv1hET3O7FCoTvr+DFxOK+CNj7/m0r7dAfeF4X0HMgH4Y/0m/tiwma6nnxLYAwgBrs1/ElGzLpJQGyKjiGp/DvmrlhSqE1GvCVUG3UbW20+iB/cfs42o9ueQl/JjoEIOKb+8M4u3+o/hrf5j+HPGz5wy6BwA6nVoRk7GIQ7u2lfk81oP6MqqyQsKlVVNqA7i/rnZriMGsOLj78s09oDKz/V+Kqe8bfHnUfhNItp5vBDoDlwMjBGRUwEBBqnqas8NiEgX4GAJ+9mK+8sMRzTA/eZR7kVFRvLAbddx24PPkO9ycUnf7jQ/qQEvv/s5bVo0pueZp7Nkxe+88PYnCHD6Ka0YM+J6APLy8xh63xMAxFSrylOjbiEqMjKIR1NOuVzkfPk6VW/+lzOcczaunZup3Pcq8resIX/VEipfNAQqRxN93X0AaPpust9+CgBJqIXUqEn+upXBPIqQsHZOCk17tuOWH54lN+sw34z6664DN3zzRKHRPK0v6sLHQ58p9PxGXVtz7t+vBFU2L17NjH++HajQy1457sLxlri/H1BCJZFKuO8G1wrIBL4HZgATVHWDs34j0Ab4OxAHjFRVFZEOqrpMRIYCHVX19uPsJxH4GXeXEMAvwBnH+5RQXrp6KrLcV58OdggV3ksfxwQ7hLAweuMk8Xcb2Qs+8DrnRHe9yu/9lQWvWvyqmisijwKLcbfKfwcigUkiEo+7lf+Cqu4TkceA54DlIhIBrAcuKnrLx+wnzXn+kc/vj1bUriFjTIiqAC1+ry/uquoLwAte1MsCbimi/G3cvzZT0vMnABO8jcsYYwIqnBK/McYY0HJ80dZbQUn8zkXgd48qzlFVu0GNMaZ8qwDDOYOS+FV1BdA+GPs2xhi/WFePMcaEGWvxG2NMmLEWvzHGhBlr8RtjTJjJsx9iMcaY8GItfmOMCTPWx2+MMWHGWvzGGBNmrMUffDGtBwU7hApv75UnBzuECu/07CrBDsF4y1r8xhgTZmxUjzHGhBkvfsOkvLPEb4wxvrA+fmOMCTOW+I0xJszYxV1jjAkz+fnBjsBvEcEOwBhjQorL5f1UAhHpJyKrRWSNiIwuYv1QEdktIinOdFNpHIK1+I0xxhel1McvIpHAy8B5wBZgiYhMVtVVR1X9SFVvL5WdOqzFb4wxvlCX99PxdQbWqOo6VT0MfAgMLPP4scRvjDE+UZd6PYnIcBFZ6jEN99hUfWCzx/IWp+xog0RkuYh8KiINS+MYrKvHGGN84UNXj6qOB8b7sbcpwAeqmiMitwATgV5+bA+wxG+MMb4pvVE9WwHPFnwDp6yAqu71WHwD+Hdp7Ni6eowxxhelN6pnCdBCRJqISGVgMDDZs4KI1PVYHAD8VhqHYC1+Y4zxRSmN6lHVPBG5HZgORAITVHWliDwKLFXVycAdIjIAyAPSgKGlsW9L/GWkVatmvPn6ODp0OIV/PvQ0/x33WpH1evU8h7FjHyQiIoKDmQe58aa7Wbt2Q2CDDTFRp3Yi+roREBFB7t
xvyPn6w0LrI1udStVrRxDRsCmHXn6cvCU/ACBJtYm561EQgcgoDs/8gsNzvg7GIYSEVk8MoVbvDuRn5ZB6xytkrNh
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEGCAYAAAB7DNKzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAhhklEQVR4nO3df3hV1Z3v8ff3EExuAYFg+CGIIUgHwcSISbRaR0EhWL1QW7G2SvVWRudeR9p6hzZzGbm2d5gbexlRM63CVAtWHQS0yqBXUSvqPLZCxPDDogIa+VEVjBJJpolJznf+ODtpwOTkhOT8SPJ5Pc95ztn77Oz1Xc9Jzjd7rbXXMndHRESkPaFkByAiIqlNiUJERKJSohARkaiUKEREJColChERiSot2QHEw0knneTZ2dnJDkNEpMd4/fXXP3b3rLbe65WJIjs7m/Ly8mSHISLSY5jZ++29p6YnERGJSolCRESiUqIQEZGolChERCQqJQoREYmqV456EmkWDjuVVbV89FkdI07MIHvYAEIhS3ZYIj2KEoX0WuGw88ybH3Lr6grqGsJk9A9x51X5zJw8UslCpBPU9CS9VmVVbUuSAKhrCHPr6goqq2qTHJlIz6JEIb3WR5/VtSSJZnUNYQ4eqUtSRCI9kxKF9FojTswgo//Rv+IZ/UMMH5SRpIhEeiYlCum1socN4M6r8luSRXMfRfawAUmOTKRnUWe29FqhkDFz8kgmzr+Ag0fqGD5Io55EjocShfRqoZCRkzWQnKyByQ5FpMdS05OIiESlRCEiIlEpUUifcd555wFQWVnJI488EteyXn75ZaZMmUJaWhpr166Na1ki8aZEIX3Gq6++CiQmUYwdO5YVK1bwne98J67liCSCEoX0GQMHRjq0S0pKeOWVV8jPz2fp0qU0NTWxYMECCgsLycvLY9myZQBs3LiRCy+8kNmzZ5OTk0NJSQkPP/wwRUVF5ObmsmfPnnbLys7OJi8vj1BIf2LS82nUk/Q5paWlLFmyhPXr1wOwfPlyBg8ezObNm6mvr+f8889nxowZAGzdupWdO3eSmZlJTk4O8+bNY9OmTdx9992UlZVx1113JbEmIomhRCF93oYNG9i2bVtLX0J1dTW7du3ihBNOoLCwkFGjRgEwfvz4lgSSm5vLiy++mLSYRRJJiUL6PHenrKyM4uLio/Zv3LiR9PT0lu1QKNSyHQqFaGxsTGicIsmiBlTpcwYNGsSRI0datouLi7n33ntpaGgA4J133qG2VjPMijRTopA+Jy8vj379+nHmmWeydOlS5s2bx6RJk5gyZQpnnHEGN910U5evFjZv3syYMWNYs2YNN910E5MnT+6m6EUSz9w92TF0u4KCAi8vL092GCIiPYaZve7uBW29pysKERGJSp3ZIl2wePFi1qxZc9S+OXPmsHDhwiRFJNL91PQkIiJqehIRkeOnRCEiIlEpUYiISFRKFCIiEpUShYiIRJXURGFmM83sbTPbbWYlbbx/vZkdMrOK4DEvGXGKiPRlSbuPwsz6AT8HpgP7gc1mts7d/3DMoY+6+98kPEAREQGSe0VRBOx293fd/XNgFTA7ifGIiEgbkpkoRgP7Wm3vD/Yd65tmts3M1prZKe2dzMxuNLNyMys/dOhQd8cqItJnpXpn9r8B2e6eBzwHrGzvQHdf7u4F7l6QlZWVsABFRHq7ZCaKA0DrK4Qxwb4W7l7l7vXB5i+BsxMUm4iIBJKZKDYDE8xsnJmdAFwNrGt9gJmNarU5C9iZwPhERIQkjnpy90Yz+xvgWaAf8IC7v2lmPwXK3X0dMN/MZgGNwCfA9cmKV0Skr9LssSIiEnX2WK1HISLShnDYqayq5aPP6hhxYgbZwwYQClmyw0oKJQoRkWOEw84zb37IrasrqGsIk9E/xJ1X5TNz8sg+mSxSfXisiEjCVVbVtiQJgLqGMLeurqCyqjbJkSWHEoWIyDE++qyuJUk0q2sIc/BIXZIiSi
4lChGRY4w4MYOM/kd/PWb0DzF8UEaSIkouJQoRkWNkDxvAnVfltySL5j6K7GEDkhxZcqgzW0TkGKGQMXPySCbOv4CDR+oYPkijnkRE5BihkJGTNZCcrIHJDiXp1PQkIiJRKVGIiEhUShQi0mucd955AFRWVvLII4/Etaz6+nq+9a1vcdppp3HOOedQWVkZ1/KSSYlCRHqNV199FUhMorj//vsZOnQou3fv5oc//CE//vGP41peMilRiEivMXBgpOO5pKSEV155hfz8fJYuXUpTUxMLFiygsLCQvLw8li1bBsDGjRu58MILmT17Njk5OZSUlPDwww9TVFREbm4ue/bsabesJ598kuuuuw6AK6+8khdeeIFkTbIaDjvvHqrhd3s+5t1DNYTD3RuHRj2JSK9TWlrKkiVLWL9+PQDLly9n8ODBbN68mfr6es4//3xmzJgBwNatW9m5cyeZmZnk5OQwb948Nm3axN13301ZWRl33XVXm2UcOHCAU06JrL2WlpbG4MGDqaqq4qSTTkpIHZslYl4qXVGISK+3YcMGHnzwQfLz8znnnHOoqqpi165dABQWFjJq1CjS09MZP358SwLJzc3tEf0OiZiXSlcUItLruTtlZWUUFxcftX/jxo2kp6e3bIdCoZbtUChEY2Nju+ccPXo0+/btY8yYMTQ2NlJdXc2wYcPiU4Eoos1L1V33gOiKQkR6nUGDBnHkyJGW7eLiYu69914aGhoAeOedd6it7dp/3LNmzWLlypUArF27lmnTpmGW+Du3EzEvla4oRKTXycvLo1+/fpx55plcf/31fP/736eyspIpU6bg7mRlZfHEE090qYwbbriBuXPnctppp5GZmcmqVau6J/hOap6X6tg+iu6cl0pLoYqI9HDNq/F1ZV4qLYUqItKLxXteKiUKEZEoFi9ezJo1a47aN2fOHBYuXJikiBJPTU8iIhK16UmjnkREermVK1cyYcIEJkyY0DJSqzPU9CQi0gs0NjaSlvbFr/RPPvmEn/zkJ5SXl2NmnH322cyaNYuhQ4fGfO4Orygs4lozWxRsjzWzos5UoCeI91wpIiKtVVZWcsYZZ7RsL1myhNtvv5177rmHSZMmkZeXx9VXXw1AbW0t3/ve9ygqKuKss87iySefBGDFihXMmjWLadOmcfHFF7dZzrPPPsv06dPJzMxk6NChTJ8+nWeeeaZTscZyRfELIAxMA34KHAEeAwo7VVIKS8RcKSIisSgtLeW9994jPT2dw4cPA5EO9WnTpvHAAw9w+PBhioqKuOSSSwDYsmUL27ZtIzMzs83ztZ6TCmDMmDEcOHCgUzHF0kdxjrvfDNQBuPunwAmdKiXFJWKuFBGRWOTl5XHNNdfw0EMPtTQlbdiwgdLSUvLz87nooouoq6tj7969AC1XC/EUS6JoMLN+gAOYWRaRK4xeI9pcKSIi8ZCWlkY4/Ofvnbq6yPfNU089xc0338yWLVsoLCyksbERd+exxx6joqKCiooK9u7dy+mnnw7AgAHR78BunpOq2f79+xk9enSnYo0lUdwD/AYYbmaLgX8H/rFTpaS4RMyVIiLS2ogRIzh48CBVVVXU19ezfv16wuEw+/btY+rUqdxxxx1UV1dTU1NDcXExZWVlLetdvPHGGzGXU1xczIYNG/j000/59NNP2bBhwxcmR+xIh30U7v6wmb0OXAwY8HV339mpUlJcIuZKERFprX///ixatIiioiJGjx7NxIkTaWpq4tprr6W6uhp3Z/78+QwZMoTbbruNH/zgB+Tl5REOhxk3blzLWhsdyczM5LbbbqOwMNKtvGjRok43VXV4w52ZtXXGI+7e0KmSEuh4brjrjrlSRER6qq7O9bQFOAX4lMgVxRDgQzP7CPgrd3+9uwJNpnjPlSIi0lPFkiieA9a6+7MAZjYD+CbwKyJDZ8+JX3giIhKL7du3M3fu3KP2paen89prr3X53LE0PW1399xj9m1z9zwzq3D3/C5H0c0015OISO
d0da6nD8zsx2Z2avD4EfBRMGS2Vw2TTYaZM2cyZMgQLr/88mSHIiLSplgSxXeAMcATwWNssK8fcFVXCjezmWb2tpn
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"id_pairs = [(user_id, item_id) for user_id in range(3) for item_id in range(3)]\n",
|
||
|
"\n",
|
||
|
"for id_pair in id_pairs:\n",
|
||
|
" print(\"Embedding for user {}\".format(id_pair[0]))\n",
|
||
|
" print(embedding_nn.embedding(user_ids[id_pair[0]]))\n",
|
||
|
" print(\"Representation for item {}\".format(id_pair[1]))\n",
|
||
|
" print(items[id_pair[1]])\n",
|
||
|
" print(\"Score={}\".format(round(embedding_nn((user_ids[id_pair[0]], items[id_pair[1]])).item(), 2)))\n",
|
||
|
" print()\n",
|
||
|
" \n",
|
||
|
"embeddings = pd.DataFrame(\n",
|
||
|
" [\n",
|
||
|
" ['user_0'] + embedding_nn.embedding(user_ids[0]).tolist(),\n",
|
||
|
" ['user_1'] + embedding_nn.embedding(user_ids[1]).tolist(),\n",
|
||
|
" ['user_2'] + embedding_nn.embedding(user_ids[2]).tolist(),\n",
|
||
|
" ['item_0'] + items[0].tolist(),\n",
|
||
|
" ['item_1'] + items[1].tolist(),\n",
|
||
|
" ['item_2'] + items[2].tolist()\n",
|
||
|
" \n",
|
||
|
" ],\n",
|
||
|
" columns=['entity', 'violence', 'positive message', 'language']\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"ax = sns.heatmap(embeddings.loc[:, ['violence', 'positive message', 'language']], annot=True)\n",
|
||
|
"ax.yaxis.set_major_formatter(ticker.FixedFormatter(embeddings.loc[:, 'entity'].tolist()))\n",
|
||
|
"plt.yticks(rotation=0)\n",
|
||
|
"plt.show()\n",
|
||
|
"\n",
|
||
|
"ax = sns.scatterplot(data=embeddings, x='violence', y='positive message')\n",
|
||
|
"for i in range(embeddings.shape[0]):\n",
|
||
|
" x = embeddings['violence'][i]\n",
|
||
|
" x = x + (-0.1 + 0.1 * -np.sign(x - np.mean(embeddings['violence'])))\n",
|
||
|
" y = embeddings['positive message'][i]\n",
|
||
|
" y = y + (-0.02 + 0.13 * -np.sign(y - np.mean(embeddings['positive message'])))\n",
|
||
|
" plt.text(x=x, y=y, s=embeddings['entity'][i])\n",
|
||
|
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "middle-newman",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## PyTorch advanced operations tasks"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "manual-serial",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"**Task 4.** Calculate the derivative $f'(w)$ using PyTorch and backward propagation (the backward method of the Tensor class) for the following functions and points:\n",
|
||
|
" - $f(w) = w^3 + w^2$ and $w = 2.0$,\n",
|
||
|
" - $f(w) = \\text{sin}(w)$ and $w = \\pi$,\n",
|
||
|
" - $f(w) = \\ln(w * e^{3w})$ and $w = 1.0$."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"id": "copyrighted-perry",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Write your code here"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "frequent-sarah",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"**Task 5.** Calculate the derivative $\\frac{\\partial f}{\\partial w_1}(w_1, w_2, w_3)$ using PyTorch and backward propagation (the backward method of the Tensor class) for the following functions and points:\n",
|
||
|
" - $f(w_1, w_2) = w_1^3 + w_1^2 + w_2$ and $(w_1, w_2) = (2.0, 3.0)$,\n",
|
||
|
" - $f(w_1, w_2, w_3) = \\text{sin}(w_1) * w_2 + w_1^2 * w_3$ and $(w_1, w_2, w_3) = (\\pi, 2.0, 4.0)$,\n",
|
||
|
" - $f(w_1, w_2, w_3) = e^{w_1^2 + w_2^2 + w_3^2} + w_1^2 + w_2^2 + w_3^2$ and $(w_1, w_2, w_3) = (0.5, 0.67, 0.55)$."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"id": "dietary-columbia",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Write your code here"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "short-border",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"**Task 6*.** Train a neural network with:\n",
|
||
|
" - two input neurons, \n",
|
||
|
" - four hidden neurons with sigmoid activation in the first hidden layer,\n",
|
||
|
" - four hidden neurons with sigmoid activation in the second hidden layer,\n",
|
||
|
" - one output neuron without sigmoid activation \n",
|
||
|
" \n",
|
||
|
"to get a good approximation of $f(x) = x_1 * x_2 + 1$ on the following dataset $D = \\{(1.0, 1.0), (0.0, 0.0), (2.0, -1.0), (-1.0, 0.5), (-0.5, -2.0)\\}$, i.e. the network should satisfy:\n",
|
||
|
" - $\\text{net}(1.0, 1.0) \\sim 2.0$,\n",
|
||
|
" - $\\text{net}(0.0, 0.0) \\sim 1.0$,\n",
|
||
|
" - $\\text{net}(2.0, -1.0) \\sim -1.0$,\n",
|
||
|
" - $\\text{net}(-1.0, 0.5) \\sim 0.5$,\n",
|
||
|
" - $\\text{net}(-0.5, -2.0) \\sim 2.0$.\n",
|
||
|
" \n",
|
||
|
"After training print all weights and separately print $w_{1, 2}^{(1)}$ (the weight from the second input to the first hidden neuron in the first hidden layer) and $w_{1, 3}^{(3)}$ (the weight from the third hidden neuron in the second hidden layer to the output unit).\n",
|
||
|
"\n",
|
||
|
"Print the values of the network on the training points and verify that these values are closer to the real values of the $f$ function than $\\epsilon = 0.1$, i.e. $|\\text{net}(x) - f(x)| < \\epsilon$ for $x \\in D$.\n",
|
||
|
"\n",
|
||
|
"Because this network is only tested on the training set, it will certainly overfit if trained long enough. Train for 1000 epochs and then calculate\n",
|
||
|
" - $\\text{net}(2.0, 2.0)$,\n",
|
||
|
" - $\\text{net}(-1.0, -1.0)$,\n",
|
||
|
" - $\\text{net}(3.0, -3.0)$.\n",
|
||
|
" \n",
|
||
|
"How far are these values from real values of the function $f$?"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"id": "documentary-petersburg",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Write your code here"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.6.9"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|