uczenie-maszynowe/lab/linear_regression_pytorch.ipynb

558 lines
41 KiB
Plaintext
Raw Permalink Normal View History

2023-01-13 14:30:44 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "fda33bda",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch=0, loss.item()=161257439232.0\n",
"epoch=1, loss.item()=160922075136.0\n",
"epoch=2, loss.item()=146514542592.0\n",
"epoch=3, loss.item()=113313726464.0\n",
"epoch=4, loss.item()=12049310547968.0\n",
"epoch=5, loss.item()=4.334163566116261e+19\n",
"epoch=6, loss.item()=inf\n",
"epoch=7, loss.item()=nan\n",
"epoch=8, loss.item()=nan\n",
"epoch=9, loss.item()=nan\n",
"epoch=10, loss.item()=nan\n",
"epoch=11, loss.item()=nan\n",
"epoch=12, loss.item()=nan\n",
"epoch=13, loss.item()=nan\n",
"epoch=14, loss.item()=nan\n",
"epoch=15, loss.item()=nan\n",
"epoch=16, loss.item()=nan\n",
"epoch=17, loss.item()=nan\n",
"epoch=18, loss.item()=nan\n",
"epoch=19, loss.item()=nan\n",
"epoch=20, loss.item()=nan\n",
"epoch=21, loss.item()=nan\n",
"epoch=22, loss.item()=nan\n",
"epoch=23, loss.item()=nan\n",
"epoch=24, loss.item()=nan\n",
"epoch=25, loss.item()=nan\n",
"epoch=26, loss.item()=nan\n",
"epoch=27, loss.item()=nan\n",
"epoch=28, loss.item()=nan\n",
"epoch=29, loss.item()=nan\n",
"epoch=30, loss.item()=nan\n",
"epoch=31, loss.item()=nan\n",
"epoch=32, loss.item()=nan\n",
"epoch=33, loss.item()=nan\n",
"epoch=34, loss.item()=nan\n",
"epoch=35, loss.item()=nan\n",
"epoch=36, loss.item()=nan\n",
"epoch=37, loss.item()=nan\n",
"epoch=38, loss.item()=nan\n",
"epoch=39, loss.item()=nan\n",
"epoch=40, loss.item()=nan\n",
"epoch=41, loss.item()=nan\n",
"epoch=42, loss.item()=nan\n",
"epoch=43, loss.item()=nan\n",
"epoch=44, loss.item()=nan\n",
"epoch=45, loss.item()=nan\n",
"epoch=46, loss.item()=nan\n",
"epoch=47, loss.item()=nan\n",
"epoch=48, loss.item()=nan\n",
"epoch=49, loss.item()=nan\n",
"epoch=50, loss.item()=nan\n",
"epoch=51, loss.item()=nan\n",
"epoch=52, loss.item()=nan\n",
"epoch=53, loss.item()=nan\n",
"epoch=54, loss.item()=nan\n",
"epoch=55, loss.item()=nan\n",
"epoch=56, loss.item()=nan\n",
"epoch=57, loss.item()=nan\n",
"epoch=58, loss.item()=nan\n",
"epoch=59, loss.item()=nan\n",
"epoch=60, loss.item()=nan\n",
"epoch=61, loss.item()=nan\n",
"epoch=62, loss.item()=nan\n",
"epoch=63, loss.item()=nan\n",
"epoch=64, loss.item()=nan\n",
"epoch=65, loss.item()=nan\n",
"epoch=66, loss.item()=nan\n",
"epoch=67, loss.item()=nan\n",
"epoch=68, loss.item()=nan\n",
"epoch=69, loss.item()=nan\n",
"epoch=70, loss.item()=nan\n",
"epoch=71, loss.item()=nan\n",
"epoch=72, loss.item()=nan\n",
"epoch=73, loss.item()=nan\n",
"epoch=74, loss.item()=nan\n",
"epoch=75, loss.item()=nan\n",
"epoch=76, loss.item()=nan\n",
"epoch=77, loss.item()=nan\n",
"epoch=78, loss.item()=nan\n",
"epoch=79, loss.item()=nan\n",
"epoch=80, loss.item()=nan\n",
"epoch=81, loss.item()=nan\n",
"epoch=82, loss.item()=nan\n",
"epoch=83, loss.item()=nan\n",
"epoch=84, loss.item()=nan\n",
"epoch=85, loss.item()=nan\n",
"epoch=86, loss.item()=nan\n",
"epoch=87, loss.item()=nan\n",
"epoch=88, loss.item()=nan\n",
"epoch=89, loss.item()=nan\n",
"epoch=90, loss.item()=nan\n",
"epoch=91, loss.item()=nan\n",
"epoch=92, loss.item()=nan\n",
"epoch=93, loss.item()=nan\n",
"epoch=94, loss.item()=nan\n",
"epoch=95, loss.item()=nan\n",
"epoch=96, loss.item()=nan\n",
"epoch=97, loss.item()=nan\n",
"epoch=98, loss.item()=nan\n",
"epoch=99, loss.item()=nan\n",
"predicted=array([[nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan],\n",
" [nan]], dtype=float32)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiQAAAGsCAYAAADt+LxYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABNn0lEQVR4nO3de1xUdf4/8NcwCkoyGKKCzChKZpZmZv4UXQq/meLXChZJs/pqu13WvAResmy3zLbNvloKu1nt9t3EtrwgjvrNNRIvKAVd1PymxrJqKojgjWRQFHTm/P5wZ5ZhbufMnJkzl9fTxzwecuYzZz7nzOW853N5f1SCIAggIiIiUlCY0hUgIiIiYkBCREREimNAQkRERIpjQEJERESKY0BCREREimNAQkRERIpjQEJERESKY0BCREREimNAQkRERIpjQEJERESKC6iAZM+ePXjooYfQo0cPqFQqbNq0SfI+BEHA22+/jVtvvRURERFISEjAH/7wB/krS0RERKK1U7oCUly+fBmDBg3Cr3/9a2RmZrq1j+zsbGzbtg1vv/02Bg4ciPr6etTX18tcUyIiIpJCFaiL66lUKmzcuBEZGRmWbc3Nzfjtb3+LNWvW4OLFixgwYAD++7//G6mpqQCAiooK3HnnnTh06BD69eunTMWJiIjIRkB12bgyc+ZMlJeXY+3atfjhhx/wyCOPIC0tDUeOHAEAfPbZZ+jTpw+2bNmC3r17IzExEU8//TRbSIiIiBQWNAFJVVUVVq5cifXr1yMlJQVJSUmYN28efvGLX2DlypUAgJ9++gknT57E+vXr8fHHHyM/Px/79u1DVlaWwrUnIiIKbQE1hsSZgwcPwmg04tZbb7Xa3tzcjC5dugAATCYTmpub8fHHH1vK/fWvf8WQIUNQWVnJbhwiIiKFBE1AcunSJajVauzbtw9qtdrqvk6dOgEA4uPj0a5dO6ugpX///gButLAwICEiIlJG0AQkgwcPhtFoxNmzZ5GSkmK3zMiRI3H9+nUcO3YMSUlJAIB//vOfAIBevXr5rK5ERERkLaBm2Vy6dAlHjx4FcCMAWbZsGUaNGoWYmBj07NkTTzzxBL766iu88847GDx4MM6dO4cdO3bgzjvvxPjx42EymTB06FB06tQJubm5MJlMmDFjBjQaDbZt26bw0REREYWugApISkpKMGrUKJvtU6dORX5+Pq5du4Y33ngDH3/8MWpqahAbG4vhw4dj0aJFGDhwIADg9OnTmDVrFrZt24abbroJ48aNwzvvvIOYmBhfHw4RERH9S0AFJERERBScgmbaLxEREQUuBiRERESkuICYZWMymXD69GlERUVBpVIpXR0iIiISQRAENDY2okePHggLc94GEhAByenTp6HT6ZSuBhEREbmhuroaWq3WaZmACEiioqIA3DggjUajcG2IiIhIDIPBAJ1OZ7mOOxMQAYm5m0aj0TAgISIiCjBihltwUCsREREpjgEJERERKY4BCRERESmOAQkREREpjgEJERERKY4BCRERESmOAQkREREpjgEJERERKS4gEqMRkXyMJiNKq0pR21iL+Kh4pPRMgTpMrXS1iCjEMSAhCiH6Cj2yi7JxynDKsk2r0SIvLQ+Z/TMVrBkRhTp22RCFCH2FHlkFWVbBCADUGGqQVZAFfYVeoZoRETEgIQoJRpMR2UXZECDY3GfellOUA6PJ6OuqEREBYEBCFBJKq0ptWkZaEyCg2lCN0qpSH9aKiOjfGJAQhYDaxlpZyxERyY0BCVEIiI+Kl7UcEZHcGJAQhYCUninQarRQQWX3fhVU0Gl0SOmZ4uOaERHdwICEKASow9TIS8sDAJugxPx3blou85EQkWIYkBCFiMz+mSicWIgETYLVdq1Gi8KJhcxDQkSKUgmCYDsP0M8YDAZER0ejoaEBGo1G6eoQBTRmaiUiX5Fy/WamVqIQow5TIzUxVelqEBFZYZcNERERKY4BCRERESmOAQkREREpjgEJERERKY4BCRERESmOAQkREREpjgEJERERKY4BCRERESmOAQkREREpjgEJERERKY4BCRERESmOAQkREREpjgEJERERKY4BCRERESmOAQkREREpTlJAsnjxYgwdOhRRUVHo1q0bMjIyUFlZ6fQx+fn5UKlUVrcOHTp4VGkiIiIKLpICkt27d2PGjBn4+uuvUVxcjGvXrmHMmDG4fPmy08dpNBrU1tZabidPnvSo0kRERBRc2kkpXFRUZPV3fn4+unXrhn379uHee+91+DiVSoW4uDj3akhERERBz6MxJA0NDQCAmJgYp+UuXbqEXr16QafTIT09HYcPH3Zavrm5GQaDwepGREREwcvtgMRkMiEnJwcjR47EgAEDHJbr168fPvroI2zevBmffPIJTCYTRowYgVOnTjl8zOLFixEdHW256XQ6d6tJREREAUAlCILgzgOfe+45fP755/jyyy+h1WpFP+7atWvo378/Jk+ejN///vd2yzQ3N6O5udnyt8FggE6nQ0NDAzQajTvVJSIiIh8zGAyIjo4Wdf2WNIbEbObMmdiyZQv27NkjKRgBgPbt22Pw4ME4evSowzIRERGIiIhwp2pEREQUgCR12QiCgJkzZ2Ljxo3YuXMnevfuLfkJjUYjDh48iPj4eMmPJSIiouAkqYVkxowZWL16NTZv3oyoqCjU1dUBAKKjo9GxY0cAwJQpU5CQkIDFixcDAF5//XUMHz4ct9xyCy5evIilS5fi5MmTePrpp2U+FCIiIgpUkgKS999/HwCQmppqtX3lypV48sknAQBVVVUIC/t3w8vPP/+MZ555BnV1dbj55psxZMgQlJWV4fbbb/es5kRERBQ03B7U6ktSBsUQERGRf5By/eZaNkRERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDgGJERERKQ4BiRERESkOAYkREREpDhJAcnixYsxdOhQREVFoVu3bsjIyEBlZaXLx61fvx633XYbOnTogIEDB2Lr1q1uV5iIiIiCj6SAZPfu3ZgxYwa+/vprFBcX49q1axgzZgwuX77s8DFlZWWYPHkynnrqKXz//ffIyMhARkYGDh065HHliYiIKDioBEEQ3H3wuXPn0K1bN+zevRv33nuv3TKTJk3C5cuXsWXLFsu24cOH46677sIHH3wg6nkMBgOio6PR0NAAjUbjbnWJiIjIh6Rcvz0aQ9LQ0AAAiImJcVimvLwco0ePtto2duxYlJeXO3xMc3MzDAaD1Y2IiIiCl9sBiclkQk5ODkaOHIkBAwY4LFdXV4fu3btbbevevTvq6uocPmbx4sWIjo623HQ6nbvVJCIiogDgdkAyY8YMHDp0CGvXrpWzPgCABQsWoKGhwXKrrq6W/TmIiIjIf7Rz50EzZ87Eli1bsGfPHmi1Wqdl4+LicObMGattZ86cQVxcnMPHREREICIiwp2qERERUQCS1EIiCAJmzpyJjRs3YufOnejdu7fLxyQnJ2PHjh1W24qLi5GcnCytpkRERBS0JLWQzJgxA6tXr8bmzZsRFRVlGQcSHR2Njh07AgCmTJmChIQELF68GACQnZ2N++67D++88w7Gjx+PtWvXYu/evfjLX/4i86EQERFRoJLUQvL++++joaEBqampiI+Pt9zWrVtnKVN
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#! /usr/bin/env python3\n",
"# -*- coding: utf-8 -*-\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"class LinearRegression(torch.nn.Module):\n",
" def __init__(self, input_size, output_size, hidden_size):\n",
" super().__init__()\n",
" self.linear = torch.nn.Linear(input_size, hidden_size)\n",
" self.linear2 = torch.nn.Linear(hidden_size, output_size)\n",
"\n",
" def forward(self, x):\n",
" x = self.linear(x)\n",
" y = self.linear2(x)\n",
" return y\n",
"\n",
"\n",
"data = pd.read_csv(\"data_flats.tsv\", sep=\"\\t\")\n",
"x = data[\"sqrMetres\"].to_numpy(dtype=np.float32).reshape(-1, 1)\n",
"y = data[\"price\"].to_numpy(dtype=np.float32).reshape(-1, 1)\n",
"\n",
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)\n",
"\n",
"input_dim = 1\n",
"output_dim = 1\n",
"hidden_dim = 10\n",
"learning_rate = 0.0000001\n",
"epochs = 100\n",
"\n",
"model = LinearRegression(input_dim, output_dim, hidden_dim)\n",
"\n",
"criterion = torch.nn.MSELoss()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
"\n",
"for epoch in range(epochs):\n",
" inputs = torch.autograd.Variable(torch.from_numpy(x_train))\n",
" labels = torch.autograd.Variable(torch.from_numpy(y_train))\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" outputs = model(inputs)\n",
"\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"\n",
" print(f\"{epoch=}, {loss.item()=}\")\n",
"\n",
"with torch.no_grad():\n",
" predicted = model(torch.autograd.Variable(torch.from_numpy(x_test))).data.numpy()\n",
"\n",
"print(f\"{predicted=}\")\n",
"\n",
"plt.plot(x_train, y_train, \"go\")\n",
"plt.plot(x_test, predicted, \"--\")\n",
"\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac50018c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2023-12-14 18:05:05 +01:00
"version": "3.10.12"
2023-01-13 14:30:44 +01:00
}
},
"nbformat": 4,
"nbformat_minor": 5
}