import torch
import common.helpers
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS
from watersandtreegrass import WaterSandTreeGrass
from torch.utils.data import DataLoader
from neural_network import NeuralNetwork
from torchvision.io import read_image, ImageReadMode
import torch.nn as nn
from torch.optim import Adam
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping


def train(model):
    """Train `model` with Adam + cross-entropy, save its weights, then report accuracy.

    Runs NUM_EPOCHS passes over './data/train_csv_file.csv', printing the loss
    every 4th batch, writes the state dict to './learnednetwork.pth', and then
    prints the accuracy on the train set, the test set and the reference tiles.

    Args:
        model: the network to train. ``nn.Module.to`` moves it to DEVICE in
            place, so the caller's object is trained and left on DEVICE.
    """
    model = model.to(DEVICE)
    model.train()
    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
    testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
    train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    # The test loader is only used for evaluation, where order is irrelevant.
    test_loader = DataLoader(testset, batch_size=BATCH_SIZE)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        for batch_idx, (data, targets) in enumerate(train_loader):
            data = data.to(device=DEVICE)
            targets = targets.to(device=DEVICE)

            scores = model(data)
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 4 == 0:
                print(f"epoch: {epoch} loss: {loss.item():.4f}")

    print("FINISHED TRAINING!")
    torch.save(model.state_dict(), "./learnednetwork.pth")

    print("Checking accuracy for the train set.")
    _check_accuracy(train_loader, model)
    print("Checking accuracy for the test set.")
    _check_accuracy(test_loader, model)
    print("Checking accuracy for the tiles.")
    check_accuracy_tiles()


def _check_accuracy(loader, model, device=None):
    """Print and return the accuracy (in percent) of `model` over every batch of `loader`.

    The model runs in eval mode without gradient tracking. Its previous
    train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        loader: iterable of ``(data, targets)`` batches.
        model: the classifier to evaluate.
        device: where to move each batch; defaults to DEVICE.

    Returns:
        Accuracy as a float in [0, 100]; 0.0 for an empty loader.
    """
    device = DEVICE if device is None else device
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data, targets in loader:
                data = data.to(device=device)
                targets = targets.to(device=device)
                # Assumes targets are class indices, as consumed by
                # CrossEntropyLoss in train() - TODO confirm against the dataset.
                predictions = model(data).argmax(dim=1)
                correct += int((predictions == targets).sum())
                total += targets.size(0)
    finally:
        model.train(was_training)
    accuracy = 100.0 * correct / total if total else 0.0
    print(f"Got {correct} / {total} correct ({accuracy:.2f}%)")
    return accuracy


# Reference textures and the class each one should be recognised as, listed
# in the order they are reported.
_TILE_EXPECTATIONS = (
    ('grass_with_tree.jpg', 'tree'),
    ('grass2.png', 'grass'),
    ('grass3.png', 'grass'),
    ('grass4.png', 'grass'),
    ('grass1.png', 'grass'),
    ('water.png', 'water'),
    ('sand.png', 'sand'),
)


def check_accuracy_tiles(trials=100, classify=None, tiles=_TILE_EXPECTATIONS,
                         textures_dir='../../resources/textures'):
    """Classify each reference texture `trials` times and print its hit rate.

    Each texture is classified repeatedly. This is presumably because
    SETUP_PHOTOS may apply random transforms, so predictions can vary between
    calls - TODO confirm. With the default of 100 trials, the printed
    percentage equals the raw hit count, exactly as before.

    Args:
        trials: number of classifications per texture; must be positive.
        classify: callable mapping an image path to a class name; defaults
            to ``what_is_it``.
        tiles: ``(file name, expected class)`` pairs to check.
        textures_dir: directory containing the texture files.

    Returns:
        dict mapping each file name to its accuracy in percent (an integer,
        rounded down).

    Raises:
        ValueError: if ``trials`` is not positive.
    """
    if trials <= 0:
        raise ValueError(f"trials must be positive, got {trials}")
    if classify is None:
        classify = what_is_it

    results = {}
    for file_name, expected in tiles:
        path = f"{textures_dir}/{file_name}"
        hits = sum(classify(path) == expected for _ in range(trials))
        accuracy = hits * 100 // trials
        print(f"Accuracy(%) {file_name}", accuracy)
        results[file_name] = accuracy
    return results


_DEFAULT_CHECKPOINT = './lightning_logs/version_3/checkpoints/epoch=8-step=810.ckpt'

# Checkpoint path -> restored NeuralNetwork. Avoids re-reading the checkpoint
# from disk on every prediction (check_accuracy_tiles makes hundreds of calls).
_LOADED_MODELS = {}


def _load_model(checkpoint_path):
    """Return the NeuralNetwork restored from `checkpoint_path`, loading it at most once.

    NOTE: the cache is keyed only on the path, so a checkpoint file that is
    overwritten after the first load will not be re-read in this process.
    """
    model = _LOADED_MODELS.get(checkpoint_path)
    if model is None:
        model = NeuralNetwork.load_from_checkpoint(checkpoint_path)
        _LOADED_MODELS[checkpoint_path] = model
    return model


def what_is_it(img_path, show_img=False, checkpoint_path=_DEFAULT_CHECKPOINT):
    """Classify a single image file and return its class name.

    Args:
        img_path: path of the image; it is read as RGB.
        show_img: if True, display the image with matplotlib (blocking)
            before classifying it.
        checkpoint_path: Lightning checkpoint to restore the model from.

    Returns:
        The ``ID_TO_CLASS`` entry for the highest-scoring output.
    """
    image = read_image(img_path, mode=ImageReadMode.RGB)
    if show_img:
        plt.imshow(plt.imread(img_path))
        plt.show()
    # Add a batch dimension of size 1 for the model.
    batch = SETUP_PHOTOS(image).unsqueeze(0)
    model = _load_model(checkpoint_path)
    model.eval()

    with torch.no_grad():
        # Put the input wherever the checkpoint's weights were restored to,
        # so a model loaded onto the GPU does not hit a device mismatch.
        idx = int(model(batch.to(model.device)).argmax(dim=1))
    return ID_TO_CLASS[idx]


if __name__ == "__main__":
    # Guarded so that importing train()/what_is_it() from this module does not
    # load the datasets, build a Trainer, run hundreds of predictions and
    # open a blocking plot window.
    CNN = NeuralNetwork()

    # Lightning rejects accelerator='gpu' when CUDA is unavailable, so fall
    # back to the CPU. This lets the inference-only path below run anywhere.
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): EarlyStopping('val_loss') assumes NeuralNetwork logs a
    # 'val_loss' metric during validation - confirm before enabling fit().
    trainer = pl.Trainer(accelerator=accelerator, devices=1, auto_scale_batch_size=True,
                         callbacks=[EarlyStopping('val_loss')], max_epochs=NUM_EPOCHS)
    # Alternative Trainer for learning-rate search:
    #trainer = pl.Trainer(accelerator=accelerator, devices=1, auto_lr_find=True, max_epochs=NUM_EPOCHS)

    trainset = WaterSandTreeGrass('./data/train_csv_file.csv', transform=SETUP_PHOTOS)
    testset = WaterSandTreeGrass('./data/test_csv_file.csv', transform=SETUP_PHOTOS)
    train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(testset, batch_size=BATCH_SIZE)

    # Uncomment one of these to train or tune instead of only evaluating the
    # saved checkpoint:
    #trainer.fit(CNN, train_loader, test_loader)
    #trainer.tune(CNN, train_loader, test_loader)
    check_accuracy_tiles()
    print(what_is_it('../../resources/textures/sand.png', True))