import PIL
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from matplotlib.pyplot import imshow


def to_negative(img):
    """Return a colour-inverted copy of the PIL image ``img``.

    This is the function applied by the ``Negative`` transform.

    Args:
        img: a PIL image. Whether a given image mode is supported is decided
            by ``PIL.ImageOps.invert``; some modes may raise — confirm against
            the Pillow version in use.

    Returns:
        The inverted image produced by ``PIL.ImageOps.invert``.
    """
    # Import the submodule explicitly: ``import PIL`` alone does not load
    # ``PIL.ImageOps``, so attribute access ``PIL.ImageOps`` only works when
    # some other import has already loaded that submodule.
    from PIL import ImageOps

    return ImageOps.invert(img)


class Negative:
    """Callable transform that colour-inverts an image via ``to_negative``.

    It holds no state, so one instance can be reused anywhere a
    torchvision-style transform (a callable taking an image) is expected,
    e.g. inside ``transforms.Compose``.
    """

    def __call__(self, img):
        inverted = to_negative(img)
        return inverted


def plotdigit(image):
    """Display ``image`` with matplotlib's ``imshow`` using the 'Greys' colormap.

    The data is reshaped into rows of 100 values before plotting. For an
    input of shape (3, 100, 100) that gives a 300x100 picture with the
    channels stacked vertically (assumes a contiguous, C-ordered input).
    """
    imshow(np.reshape(image, (-1, 100)), cmap='Greys')


# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Training-time preprocessing: invert colours, then convert the PIL image to a
# tensor.
transform = transforms.Compose([
    Negative(),
    transforms.ToTensor(),
])

# NOTE(review): the dataset is built at import time, so merely importing this
# module fails if '../src/train' (relative to the current working directory)
# is missing — consider making this lazy if inference-only imports matter.
train_set = torchvision.datasets.ImageFolder(
    root='../src/train',
    transform=transform,
)

# Human-readable labels indexed by predicted class id. Assumes this order
# matches the class ids ImageFolder assigns (presumably sorted sub-folder
# names) — TODO confirm against train_set.classes.
classes = ("apple", "potato")

BATCH_SIZE = 2
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)


class Net(nn.Module):
    """Three-layer fully connected classifier producing 2 class scores.

    Each sample is flattened and must then have 3 * 100 * 100 = 30000
    features, presumably images of shape (N, 3, 100, 100) — confirm against
    the dataset. The output has shape (N, 2) and holds raw, unnormalised
    scores (logits), which is what ``nn.CrossEntropyLoss`` expects. Apply a
    softmax yourself if probabilities are needed.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # The last Linear layer deliberately has no activation after it. A
        # trailing ReLU would clamp negative logits to 0, which:
        #   * starves CrossEntropyLoss of gradient for those outputs;
        #   * makes tied [0, 0] outputs argmax to class 0.
        # Layer indices 0/2/4 (the only ones with parameters) are unchanged,
        # so existing state_dict checkpoints still load. Weights trained with
        # the old clamp may behave differently — consider retraining.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(3 * 100 * 100, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 2),
        )
        # Parameters are placed on the module-level ``device`` at construction
        # time; callers rely on a freshly built Net already living there.
        self.linear_relu_stack = self.linear_relu_stack.to(device)

    def forward(self, x):
        # Move the input to wherever the parameters currently are, so the
        # model still works after a caller moves it with ``net.to(...)``.
        param_device = next(self.parameters()).device
        x = self.flatten(x).to(param_device)
        return self.linear_relu_stack(x)


def training_network(epochs=4, log_every=2000, lr=0.001, momentum=0.9):
    """Train a fresh ``Net`` on ``train_loader`` and save its weights.

    Uses cross-entropy loss and SGD with momentum. When training finishes, the
    weights are written with ``save_network_to_file``.

    Args:
        epochs: number of full passes over ``train_loader``.
        log_every: print the mean loss of each window of this many
            mini-batches. Must be >= 1.
        lr: SGD learning rate.
        momentum: SGD momentum.

    Returns:
        The trained network.

    Raises:
        ValueError: if ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be >= 1")

    net = Net().to(device)
    net.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    for epoch in range(epochs):
        window_loss = 0.0
        epoch_loss = 0.0
        n_batches = 0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            window_loss += batch_loss
            epoch_loss += batch_loss
            n_batches += 1
            if (i + 1) % log_every == 0:
                # Report the mean over the window, not the raw sum.
                print('[%d, %5d] loss: %.3f'
                      % (epoch + 1, i + 1, window_loss / log_every))
                window_loss = 0.0

        # Always report something per epoch; with a small dataset the
        # windowed report above may never trigger.
        if n_batches:
            print('[%d] epoch mean loss: %.3f'
                  % (epoch + 1, epoch_loss / n_batches))

    print("Finished training")
    save_network_to_file(net)
    return net


def result_from_network(net, loaded_image):
    """Classify one image file with ``net`` and return its class name.

    The image gets the same preprocessing as training (the module-level
    ``transform``: colour inversion, then tensor conversion). No resize is
    applied, so the image presumably must already be 100x100 to match
    ``Net``'s 3 * 100 * 100 input — TODO confirm or add a resize step.

    Args:
        net: a trained network mapping a (1, ...) batch to per-class scores.
        loaded_image: anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        The entry of ``classes`` for the highest-scoring output.
    """
    # Explicit import: ``import PIL`` alone does not load ``PIL.Image``.
    from PIL import Image

    with Image.open(loaded_image) as image:
        # ``convert`` loads the pixels, so the tensor stays valid after the
        # file is closed.
        batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        outputs = net(batch.to(device))
    return classes[outputs.argmax(dim=1).item()]


def save_network_to_file(network, path='network_model.pth'):
    """Save ``network.state_dict()`` to ``path`` with ``torch.save``.

    Only parameters and buffers are stored, not the class itself. Rebuild the
    module and call ``load_network_from_structure`` to restore it.

    Args:
        network: the module whose ``state_dict`` is saved.
        path: destination file. Defaults to ``network_model.pth`` in the
            current working directory.
    """
    torch.save(network.state_dict(), path)
    print("Network saved to file")


def load_network_from_structure(network, path='network_model.pth'):
    """Load weights saved by ``save_network_to_file`` into ``network`` in place.

    Args:
        network: an already-constructed module with the same structure as the
            one that was saved.
        path: checkpoint file. Defaults to ``network_model.pth`` in the current
            working directory.

    Returns:
        ``network`` itself, for convenient chaining.
    """
    # map_location="cpu" lets a checkpoint saved on a GPU machine be read on a
    # CPU-only one. load_state_dict then copies each tensor into the existing
    # parameter on whatever device ``network`` lives on.
    # torch.load unpickles data: only load checkpoint files you trust.
    state = torch.load(path, map_location="cpu")
    network.load_state_dict(state)
    return network


# Script entry point: report whether CUDA is available, then train the
# network. training_network() saves the weights (network_model.pth by default).
if __name__ == "__main__":
    cuda_available = torch.cuda.is_available()
    print(cuda_available)
    training_network()