41 KiB
41 KiB
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
class LinearRegression(torch.nn.Module):
    """Regressor built from two stacked affine layers.

    There is no activation between ``linear`` and ``linear2``, so the composed
    map is itself affine: ``hidden_size`` changes the parameterisation (and the
    optimisation dynamics) but not the family of functions the model can fit.

    Args:
        input_size: number of input features (last dimension of ``x``).
        output_size: number of regression outputs.
        hidden_size: width of the intermediate representation.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        # Layer names and creation order are significant: they fix the
        # state_dict keys and the order in which the RNG initialises weights.
        self.linear = torch.nn.Linear(in_features=input_size, out_features=hidden_size)
        self.linear2 = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return ``linear2(linear(x))``; the last dim of ``x`` must equal ``input_size``."""
        return self.linear2(self.linear(x))
# Fit price as a function of floor area ("sqrMetres") and plot the fit.
data = pd.read_csv("data_flats.tsv", sep="\t")
x = data["sqrMetres"].to_numpy(dtype=np.float32).reshape(-1, 1)
y = data["price"].to_numpy(dtype=np.float32).reshape(-1, 1)
# Fixed random_state so repeated runs use the same split.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

# Standardise feature and target using TRAINING statistics only (no test
# leakage). Unscaled, the squared price error is enormous, and full-batch SGD
# diverged to inf/nan within a few epochs even at lr=1e-7. In the scaled space
# an ordinary learning rate converges stably.
x_mean, x_std = x_train.mean(), x_train.std()
y_mean, y_std = y_train.mean(), y_train.std()
# Guard against a constant column (std == 0), which would divide by zero.
x_std = x_std or 1.0
y_std = y_std or 1.0

input_dim = 1
output_dim = 1
hidden_dim = 10
learning_rate = 0.01
epochs = 1000
log_every = 100

model = LinearRegression(input_dim, output_dim, hidden_dim)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# The training tensors never change, so build them once, outside the loop.
# Plain tensors suffice: torch.autograd.Variable has been a deprecated no-op
# wrapper since PyTorch 0.4.
inputs = torch.from_numpy((x_train - x_mean) / x_std)
labels = torch.from_numpy((y_train - y_mean) / y_std)

for epoch in range(epochs):
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()
    if epoch % log_every == 0 or epoch == epochs - 1:
        print(f"{epoch=}, {loss.item()=}")

with torch.no_grad():
    scaled_pred = model(torch.from_numpy((x_test - x_mean) / x_std)).numpy()
# Undo the target scaling so predictions are in the original price units.
predicted = scaled_pred * y_std + y_mean
print(f"{predicted=}")

# A line plot joins points in array order; sort by x so the fit is drawn as a
# line rather than a zigzag across the unsorted test split.
order = np.argsort(x_test[:, 0])
plt.plot(x_train, y_train, "go")
plt.plot(x_test[order], predicted[order], "--")
plt.show()
epoch=0, loss.item()=161257439232.0 epoch=1, loss.item()=160922075136.0 epoch=2, loss.item()=146514542592.0 epoch=3, loss.item()=113313726464.0 epoch=4, loss.item()=12049310547968.0 epoch=5, loss.item()=4.334163566116261e+19 epoch=6, loss.item()=inf epoch=7, loss.item()=nan epoch=8, loss.item()=nan epoch=9, loss.item()=nan epoch=10, loss.item()=nan epoch=11, loss.item()=nan epoch=12, loss.item()=nan epoch=13, loss.item()=nan epoch=14, loss.item()=nan epoch=15, loss.item()=nan epoch=16, loss.item()=nan epoch=17, loss.item()=nan epoch=18, loss.item()=nan epoch=19, loss.item()=nan epoch=20, loss.item()=nan epoch=21, loss.item()=nan epoch=22, loss.item()=nan epoch=23, loss.item()=nan epoch=24, loss.item()=nan epoch=25, loss.item()=nan epoch=26, loss.item()=nan epoch=27, loss.item()=nan epoch=28, loss.item()=nan epoch=29, loss.item()=nan epoch=30, loss.item()=nan epoch=31, loss.item()=nan epoch=32, loss.item()=nan epoch=33, loss.item()=nan epoch=34, loss.item()=nan epoch=35, loss.item()=nan epoch=36, loss.item()=nan epoch=37, loss.item()=nan epoch=38, loss.item()=nan epoch=39, loss.item()=nan epoch=40, loss.item()=nan epoch=41, loss.item()=nan epoch=42, loss.item()=nan epoch=43, loss.item()=nan epoch=44, loss.item()=nan epoch=45, loss.item()=nan epoch=46, loss.item()=nan epoch=47, loss.item()=nan epoch=48, loss.item()=nan epoch=49, loss.item()=nan epoch=50, loss.item()=nan epoch=51, loss.item()=nan epoch=52, loss.item()=nan epoch=53, loss.item()=nan epoch=54, loss.item()=nan epoch=55, loss.item()=nan epoch=56, loss.item()=nan epoch=57, loss.item()=nan epoch=58, loss.item()=nan epoch=59, loss.item()=nan epoch=60, loss.item()=nan epoch=61, loss.item()=nan epoch=62, loss.item()=nan epoch=63, loss.item()=nan epoch=64, loss.item()=nan epoch=65, loss.item()=nan epoch=66, loss.item()=nan epoch=67, loss.item()=nan epoch=68, loss.item()=nan epoch=69, loss.item()=nan epoch=70, loss.item()=nan epoch=71, loss.item()=nan epoch=72, loss.item()=nan epoch=73, loss.item()=nan epoch=74, 
loss.item()=nan epoch=75, loss.item()=nan epoch=76, loss.item()=nan epoch=77, loss.item()=nan epoch=78, loss.item()=nan epoch=79, loss.item()=nan epoch=80, loss.item()=nan epoch=81, loss.item()=nan epoch=82, loss.item()=nan epoch=83, loss.item()=nan epoch=84, loss.item()=nan epoch=85, loss.item()=nan epoch=86, loss.item()=nan epoch=87, loss.item()=nan epoch=88, loss.item()=nan epoch=89, loss.item()=nan epoch=90, loss.item()=nan epoch=91, loss.item()=nan epoch=92, loss.item()=nan epoch=93, loss.item()=nan epoch=94, loss.item()=nan epoch=95, loss.item()=nan epoch=96, loss.item()=nan epoch=97, loss.item()=nan epoch=98, loss.item()=nan epoch=99, loss.item()=nan predicted=array([[nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], 
[nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan], [nan]], dtype=float32)