#!/usr/bin/env python3

# Wprowadzamy minibatche

import sys
import torch
from torch import optim
import itertools

# Preprocessing i wektoryzację tekstów wydzialamy do osobnego modułu,
# z którego będzie korzystał zarówno kod do uczenia, jak i predykcji.
from analyzer import vectorizer, vector_length, process_line, vectorize_batch

from my_linear_regressor import MyLinearRegressor

regressor = MyLinearRegressor(vector_length)

# Rozmiar minibatcha
batch_size = 16

# Pomocnicza funkcja do batchowania
def grouper(n, iterable):
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk


# Tym razem użyjemy optymalizatora
optimizer = optim.Adam(regressor.parameters())


# Funkcja kosztu.
def loss_fun(y_hat, y_exp):
    return torch.sum((y_hat - y_exp)**2) / batch_size


# Co ile kroków będziemy wypisywali informacje o średniej funkcji kosztu.
# To nie jest hiperparametr uczenia, nie ma to żadnego, ani pozytywnego, ani
# negatywnego wpływu na uczenie.
step = 500
i = 1
closs = torch.tensor(0.0, dtype=torch.double, requires_grad=False)

for batch in grouper(batch_size, sys.stdin):
    t = [process_line(line) for line in batch]
    contents = [entry[0] for entry in t]
    # y_exp będzie teraz wektorem!
    y_exp = torch.tensor([entry[1] for entry in t], dtype=torch.double)

    optimizer.zero_grad()

    x = vectorize_batch(contents)

    # wartość z predykcji (też wektor!)
    y_hat = regressor(x)

    # wyliczamy funkcję kosztu
    loss = loss_fun(y_hat, y_exp)

    loss.backward()

    with torch.no_grad():
        closs += loss

    # Optymalizator automagicznie zadba o aktualizację wag!
    optimizer.step()

    # za jakiś czas pokazujemy uśrednioną funkcję kosztu
    if i % step == 0:
        print("Sample item: ", y_exp[0].item(), " => ", y_hat[0].item(),
              " | Avg loss: ", (closs / step).item())
        closs = torch.tensor(0.0, dtype=torch.double, requires_grad=False)

    i += 1


# serializujemy nasz model
torch.save(regressor, "model.bin")