import torch
import pytorch_lightning as pl
import torch.nn as nn
from torch.optim import SGD, Adam, lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader
from watersandtreegrass import WaterSandTreeGrass
from common.constants import DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, SETUP_PHOTOS, ID_TO_CLASS


class NeuralNetwork(pl.LightningModule):
    """Small CNN classifier: two 3x3 conv layers, one max-pool, three FC layers.

    The network ends in ``LogSoftmax(dim=1)``, so its outputs are
    log-probabilities; that is why every step computes ``F.nll_loss`` on them
    (log-softmax + NLL is equivalent to cross-entropy on raw logits).

    Args:
        numChannels: Number of channels in the input images.
        batch_size: Stored as ``self.batch_size``; not used inside this class.
        learning_rate: Learning rate handed to Adam in ``configure_optimizers``.
        num_classes: Number of output classes (width of the final layer).
        input_size: Height and width of the square input images. The first
            fully connected layer is sized from it; the default of 36 gives
            the original ``48 * 18 * 18`` flattened feature count.
    """

    def __init__(self, numChannels=3, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE,
                 num_classes=4, input_size=36):
        super().__init__()
        self.conv1 = nn.Conv2d(numChannels, 24, (3, 3), padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d((2, 2), stride=2)
        self.conv2 = nn.Conv2d(24, 48, (3, 3), padding=1)
        self.relu2 = nn.ReLU()
        # 3x3 convs with padding=1 keep the spatial size unchanged; the single
        # 2x2 / stride-2 max-pool (ceil_mode=False) floors it to input_size // 2.
        pooled_size = input_size // 2
        self.fc1 = nn.Linear(48 * pooled_size * pooled_size, 800)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(800, 400)
        self.relu4 = nn.ReLU()
        # Previously hard-coded to 4, which silently ignored num_classes.
        self.fc3 = nn.Linear(400, num_classes)
        self.logSoftmax = nn.LogSoftmax(dim=1)

        self.batch_size = batch_size
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, num_classes)``.

        Assumes ``x`` is ``(batch, numChannels, input_size, input_size)`` —
        any other spatial size will not match ``fc1``'s input features.
        """
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        # Flatten everything except the batch dimension for the FC head.
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.relu4(x)
        x = self.fc3(x)
        x = self.logSoftmax(x)
        return x

    def configure_optimizers(self):
        """Optimize all parameters with Adam at ``self.learning_rate``."""
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def _shared_loss(self, batch):
        """Return the NLL loss of the model's log-probabilities on an ``(x, y)`` batch."""
        x, y = batch
        scores = self(x)
        return F.nll_loss(scores, y)

    def training_step(self, batch, batch_idx):
        """Compute and return the training loss for one batch."""
        return self._shared_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` per step and aggregated per epoch (synced across devices)."""
        val_loss = self._shared_loss(batch)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` per step and aggregated per epoch (synced across devices)."""
        test_loss = self._shared_loss(batch)
        self.log("test_loss", test_loss, on_step=True, on_epoch=True, sync_dist=True)