uczenie-maszynowe/lab/linear_regression_pytorch.ipynb
2023-12-14 18:05:05 +01:00

41 KiB

#! /usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split


class LinearRegression(torch.nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.linear(x)
        y = self.linear2(x)
        return y


data = pd.read_csv("data_flats.tsv", sep="\t")
x = data["sqrMetres"].to_numpy(dtype=np.float32).reshape(-1, 1)
y = data["price"].to_numpy(dtype=np.float32).reshape(-1, 1)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

input_dim = 1
output_dim = 1
hidden_dim = 10
learning_rate = 0.0000001
epochs = 100

model = LinearRegression(input_dim, output_dim, hidden_dim)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    inputs = torch.autograd.Variable(torch.from_numpy(x_train))
    labels = torch.autograd.Variable(torch.from_numpy(y_train))

    optimizer.zero_grad()

    outputs = model(inputs)

    loss = criterion(outputs, labels)
    loss.backward()

    optimizer.step()

    print(f"{epoch=}, {loss.item()=}")

with torch.no_grad():
    predicted = model(torch.autograd.Variable(torch.from_numpy(x_test))).data.numpy()

print(f"{predicted=}")

plt.plot(x_train, y_train, "go")
plt.plot(x_test, predicted, "--")

plt.show()
epoch=0, loss.item()=161257439232.0
epoch=1, loss.item()=160922075136.0
epoch=2, loss.item()=146514542592.0
epoch=3, loss.item()=113313726464.0
epoch=4, loss.item()=12049310547968.0
epoch=5, loss.item()=4.334163566116261e+19
epoch=6, loss.item()=inf
epoch=7, loss.item()=nan
epoch=8, loss.item()=nan
epoch=9, loss.item()=nan
epoch=10, loss.item()=nan
epoch=11, loss.item()=nan
epoch=12, loss.item()=nan
epoch=13, loss.item()=nan
epoch=14, loss.item()=nan
epoch=15, loss.item()=nan
epoch=16, loss.item()=nan
epoch=17, loss.item()=nan
epoch=18, loss.item()=nan
epoch=19, loss.item()=nan
epoch=20, loss.item()=nan
epoch=21, loss.item()=nan
epoch=22, loss.item()=nan
epoch=23, loss.item()=nan
epoch=24, loss.item()=nan
epoch=25, loss.item()=nan
epoch=26, loss.item()=nan
epoch=27, loss.item()=nan
epoch=28, loss.item()=nan
epoch=29, loss.item()=nan
epoch=30, loss.item()=nan
epoch=31, loss.item()=nan
epoch=32, loss.item()=nan
epoch=33, loss.item()=nan
epoch=34, loss.item()=nan
epoch=35, loss.item()=nan
epoch=36, loss.item()=nan
epoch=37, loss.item()=nan
epoch=38, loss.item()=nan
epoch=39, loss.item()=nan
epoch=40, loss.item()=nan
epoch=41, loss.item()=nan
epoch=42, loss.item()=nan
epoch=43, loss.item()=nan
epoch=44, loss.item()=nan
epoch=45, loss.item()=nan
epoch=46, loss.item()=nan
epoch=47, loss.item()=nan
epoch=48, loss.item()=nan
epoch=49, loss.item()=nan
epoch=50, loss.item()=nan
epoch=51, loss.item()=nan
epoch=52, loss.item()=nan
epoch=53, loss.item()=nan
epoch=54, loss.item()=nan
epoch=55, loss.item()=nan
epoch=56, loss.item()=nan
epoch=57, loss.item()=nan
epoch=58, loss.item()=nan
epoch=59, loss.item()=nan
epoch=60, loss.item()=nan
epoch=61, loss.item()=nan
epoch=62, loss.item()=nan
epoch=63, loss.item()=nan
epoch=64, loss.item()=nan
epoch=65, loss.item()=nan
epoch=66, loss.item()=nan
epoch=67, loss.item()=nan
epoch=68, loss.item()=nan
epoch=69, loss.item()=nan
epoch=70, loss.item()=nan
epoch=71, loss.item()=nan
epoch=72, loss.item()=nan
epoch=73, loss.item()=nan
epoch=74, loss.item()=nan
epoch=75, loss.item()=nan
epoch=76, loss.item()=nan
epoch=77, loss.item()=nan
epoch=78, loss.item()=nan
epoch=79, loss.item()=nan
epoch=80, loss.item()=nan
epoch=81, loss.item()=nan
epoch=82, loss.item()=nan
epoch=83, loss.item()=nan
epoch=84, loss.item()=nan
epoch=85, loss.item()=nan
epoch=86, loss.item()=nan
epoch=87, loss.item()=nan
epoch=88, loss.item()=nan
epoch=89, loss.item()=nan
epoch=90, loss.item()=nan
epoch=91, loss.item()=nan
epoch=92, loss.item()=nan
epoch=93, loss.item()=nan
epoch=94, loss.item()=nan
epoch=95, loss.item()=nan
epoch=96, loss.item()=nan
epoch=97, loss.item()=nan
epoch=98, loss.item()=nan
epoch=99, loss.item()=nan
predicted=array([[nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan],
       [nan]], dtype=float32)