558 lines
41 KiB
Plaintext
558 lines
41 KiB
Plaintext
{
|
|
"cells": [
|
|
{
 "cell_type": "markdown",
 "id": "5c1e9a70",
 "metadata": {},
 "source": [
  "# Flat price vs. floor area\n",
  "\n",
  "Fits `price` as a function of `sqrMetres` (from `data_flats.tsv`) with a small PyTorch model trained by full-batch SGD.\n",
  "\n",
  "Both the feature and the target are standardized with **training-set** statistics before fitting. On the raw values the initial MSE was ~1.6e11, the SGD gradients exploded, and the loss reached `inf`/`nan` within a few epochs even at `lr=1e-7`. Predictions are mapped back to price units for evaluation and plotting."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "fda33bda",
 "metadata": {},
 "outputs": [],
 "source": [
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "import torch\n",
  "from sklearn.model_selection import train_test_split\n",
  "\n",
  "DATA_PATH = \"data_flats.tsv\"\n",
  "SEED = 42\n",
  "TEST_SIZE = 0.2\n",
  "\n",
  "INPUT_DIM = 1\n",
  "OUTPUT_DIM = 1\n",
  "HIDDEN_DIM = 10\n",
  "# Step size for *standardized* data; on raw prices even 1e-7 diverged.\n",
  "LEARNING_RATE = 0.05\n",
  "EPOCHS = 500\n",
  "LOG_EVERY = 50\n",
  "\n",
  "torch.manual_seed(SEED);"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "ac50018c",
 "metadata": {},
 "outputs": [],
 "source": [
  "class LinearRegression(torch.nn.Module):\n",
  "    \"\"\"Two stacked ``torch.nn.Linear`` layers: input_size -> hidden_size -> output_size.\n",
  "\n",
  "    No activation sits between the layers, so the composition is still a single\n",
  "    affine map; ``hidden_size`` changes the parameterisation, not what can be fit.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self, input_size, output_size, hidden_size):\n",
  "        super().__init__()\n",
  "        self.linear = torch.nn.Linear(input_size, hidden_size)\n",
  "        self.linear2 = torch.nn.Linear(hidden_size, output_size)\n",
  "\n",
  "    def forward(self, x):\n",
  "        \"\"\"Apply both layers to ``x`` (last dimension must equal ``input_size``).\"\"\"\n",
  "        return self.linear2(self.linear(x))\n",
  "\n",
  "\n",
  "def fit_scaler(values: np.ndarray) -> tuple[float, float]:\n",
  "    \"\"\"Return ``(mean, std)`` of ``values``; std falls back to 1.0 for constant data.\"\"\"\n",
  "    mean = float(values.mean())\n",
  "    std = float(values.std())\n",
  "    return mean, (std if std > 0 else 1.0)\n",
  "\n",
  "\n",
  "def scale(values: np.ndarray, mean: float, std: float) -> torch.Tensor:\n",
  "    \"\"\"Standardize ``values`` with ``(mean, std)`` and wrap them as a float32 tensor.\"\"\"\n",
  "    return torch.from_numpy(((values - mean) / std).astype(np.float32))"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "0b7d2f41",
 "metadata": {},
 "outputs": [],
 "source": [
  "data = pd.read_csv(DATA_PATH, sep=\"\\t\")\n",
  "x = data[\"sqrMetres\"].to_numpy(dtype=np.float32).reshape(-1, 1)\n",
  "y = data[\"price\"].to_numpy(dtype=np.float32).reshape(-1, 1)\n",
  "\n",
  "x_train, x_test, y_train, y_test = train_test_split(\n",
  "    x, y, test_size=TEST_SIZE, random_state=SEED\n",
  ")\n",
  "\n",
  "# Scaling statistics come from the training split only, so nothing about the\n",
  "# test set leaks into the fit.\n",
  "x_mean, x_std = fit_scaler(x_train)\n",
  "y_mean, y_std = fit_scaler(y_train)\n",
  "\n",
  "# Built once: the full-batch inputs/targets do not change between epochs.\n",
  "inputs = scale(x_train, x_mean, x_std)\n",
  "labels = scale(y_train, y_mean, y_std)"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "9e6a3c15",
 "metadata": {},
 "outputs": [],
 "source": [
  "model = LinearRegression(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM)\n",
  "criterion = torch.nn.MSELoss()\n",
  "optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)\n",
  "\n",
  "# Loss values are in standardized units (1.0 ~ always predicting the training mean).\n",
  "losses = []\n",
  "for epoch in range(EPOCHS):\n",
  "    optimizer.zero_grad()\n",
  "    loss = criterion(model(inputs), labels)\n",
  "    if not torch.isfinite(loss):\n",
  "        # Fail loudly instead of silently training on nan for the remaining epochs.\n",
  "        raise RuntimeError(f\"training diverged at {epoch=}: loss={loss.item()}\")\n",
  "    loss.backward()\n",
  "    optimizer.step()\n",
  "    losses.append(loss.item())\n",
  "    if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:\n",
  "        print(f\"{epoch=}, loss={losses[-1]:.4f}\")"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "id": "d4f8b2a6",
 "metadata": {},
 "outputs": [],
 "source": [
  "with torch.no_grad():\n",
  "    # Undo the target standardization so predictions are in price units.\n",
  "    predicted = model(scale(x_test, x_mean, x_std)).numpy() * y_std + y_mean\n",
  "\n",
  "test_rmse = float(np.sqrt(np.mean((predicted - y_test) ** 2)))\n",
  "print(f\"test RMSE: {test_rmse:,.0f}\")\n",
  "print(f\"{predicted[:5]=}\")\n",
  "\n",
  "# x_test is in random split order; sort it so the prediction line is drawn\n",
  "# left-to-right instead of zig-zagging between points.\n",
  "order = np.argsort(x_test[:, 0])\n",
  "fig, ax = plt.subplots()\n",
  "ax.plot(x_train, y_train, \"go\", label=\"train\")\n",
  "ax.plot(x_test[order], predicted[order], \"--\", label=\"prediction (test)\")\n",
  "ax.set(xlabel=\"sqrMetres\", ylabel=\"price\", title=\"Flat price vs. area\")\n",
  "ax.legend()\n",
  "plt.show()"
 ]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|