import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm
from PIL import Image

import matplotlib.pyplot as plt # For data viz
import pandas as pd
import numpy as np
import sys
from siec import dataset
from glob import glob

class LabelClassifier(nn.Module):
    """EfficientNet-B0 backbone (created via timm) with a new linear classification head.

    Args:
        num_classes: number of logits produced by the head.
        pretrained: passed to ``timm.create_model``; when True timm loads its
            pretrained backbone weights (which may require a download).
    """

    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        self.model = timm.create_model('efficientnet_b0', pretrained=pretrained)
        # Keep every top-level child of the backbone except the last one (timm's
        # own classifier head); the remaining stack ends with its global pooling.
        self.features = nn.Sequential(*list(self.model.children())[:-1])
        # Width of the pooled feature vector (1280 for efficientnet_b0). Read it
        # from the backbone rather than hard-coding the magic number.
        out_size = self.model.num_features
        self.classifier = nn.Linear(out_size, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for the image batch ``x``."""
        x = self.features(x)
        # The head needs (batch, features). Flatten explicitly so trailing 1x1
        # spatial dims from the extractor cannot break it; this is a no-op when
        # the pooling layer already produced a 2-D tensor.
        x = torch.flatten(x, 1)
        return self.classifier(x)
# Select the compute device first so the checkpoint can be mapped straight onto
# it; without map_location a checkpoint saved on a GPU fails to load on a
# CPU-only machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# "./model" is unpickled as a whole nn.Module (presumably a LabelClassifier --
# confirm), not a state_dict, so full unpickling is required: PyTorch >= 2.6
# defaults torch.load to weights_only=True, which refuses arbitrary classes.
# (The weights_only keyword needs torch >= 1.13.)
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
model = torch.load("./model", map_location=device, weights_only=False)
model.to(device)

criterion = nn.CrossEntropyLoss()
# Build the optimizer after the model is on its final device, as the
# torch.optim documentation recommends.
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Sanity check: loss of the loaded model on dataset.images / dataset.labels
# (assumed to be tensors provided by siec.dataset -- confirm).
# eval() + no_grad(): a forward pass in train mode would update the running
# statistics of any BatchNorm layers and apply dropout, silently changing the
# model exported below, and autograd would retain activations for nothing.
model.eval()
with torch.no_grad():
    print(criterion(model(dataset.images.to(device)), dataset.labels.to(device)))

num_epochs = 0  # 0 = only evaluate and export the loaded model; raise to fine-tune.
trainLosses = []
valLosses = []

for epoch in range(num_epochs):
    # Training pass.
    model.train()
    runningLoss = 0.0
    for images, labels in dataset.dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight by batch size to get a per-sample epoch average.
        runningLoss += loss.item() * images.size(0)
    trainLoss = runningLoss / len(dataset.dataloader.dataset)
    trainLosses.append(trainLoss)

    # Validation pass (no parameter updates, no autograd bookkeeping).
    model.eval()
    runningLoss = 0.0
    with torch.no_grad():
        for images, labels in dataset.valloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            runningLoss += loss.item() * images.size(0)
    valLoss = runningLoss / len(dataset.valloader.dataset)
    valLosses.append(valLoss)
    print(f"Epoch {epoch + 1}/{num_epochs} - Train loss: {trainLoss}, Validation loss: {valLoss}")

# Export TorchScript. The model is in eval mode here (set above, or by the last
# validation pass), so the scripted module is saved in inference mode.
modell = torch.jit.script(model)
modell.save('model.pt')


def preprocess_image(image_path, transform):
    """Load an image from disk and prepare it as a single-item model batch.

    Args:
        image_path: path of the image file to open.
        transform: callable applied to the RGB PIL image (in this script a
            torchvision ``Compose`` that returns a tensor).

    Returns:
        ``(image, batch)``: the RGB PIL image (kept for display) and the
        transformed image with a leading batch dimension added by ``unsqueeze(0)``.
    """
    # Use a context manager so the underlying file handle is released
    # deterministically. convert() always returns a new, fully loaded image
    # (a copy even if the file is already RGB), so it stays valid after close.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    return image, transform(image).unsqueeze(0)

# Predict using the model
def predict(model, image_tensor, device):
    """Run ``model`` on ``image_tensor`` and return softmax class probabilities.

    Side effects: switches ``model`` to eval mode (this persists after the
    call) and prints the raw model outputs.

    Args:
        model: module to evaluate.
        image_tensor: input batch; moved to ``device`` before the forward pass.
        device: device to run inference on.

    Returns:
        A 1-D NumPy array: softmax over dim 1 of the outputs, flattened (for a
        one-image batch, one probability per class).
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        print(logits)
        probs = torch.softmax(logits, dim=1)
    flat_probs = probs.cpu().numpy().flatten()
    return flat_probs

# Visualization
def visualize_predictions(original_image, probabilities, class_names):
    """Show an image beside a horizontal bar chart of its class probabilities.

    Args:
        original_image: image drawn in the left panel (anything ``imshow`` accepts).
        probabilities: one value per entry of ``class_names``, same order.
        class_names: bar labels for the right panel.
    """
    _, (image_ax, bar_ax) = plt.subplots(1, 2, figsize=(14, 7))

    # Left panel: the input image without axis ticks.
    image_ax.imshow(original_image)
    image_ax.axis("off")

    # Right panel: one bar per class, x-axis pinned to the probability range.
    bar_ax.barh(class_names, probabilities)
    bar_ax.set_xlabel("Probability")
    bar_ax.set_title("Class Predictions")
    bar_ax.set_xlim(0, 1)

    plt.tight_layout()
    plt.show()
# Inference preprocessing: resize to 128x128 and convert to a tensor.
# NOTE(review): there is no Normalize step here -- confirm this matches the
# transform siec.dataset used for training, otherwise predictions are skewed.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# NOTE(review): these "test" images are drawn from ./train, i.e. data the model
# has probably been trained on; point this at a held-out split for an honest check.
test_images = glob('./train/*/*')
# Draw up to 10 *distinct* images. np.random.choice samples with replacement by
# default (duplicates) and raises on an empty list; replace=False rejects
# requests larger than the population, hence the min().
num_examples = min(10, len(test_images))
test_examples = np.random.choice(test_images, num_examples, replace=False)

# Class names do not change between iterations; look them up once.
# (Presumably an ImageFolder's .classes in label-index order -- confirm.)
class_names = dataset.dataset.classes
for example in test_examples:
    # predict() switches the model to eval mode and moves the tensor to `device`.
    original_image, image_tensor = preprocess_image(example, transform)
    probabilities = predict(model, image_tensor, device)
    print(probabilities)
    visualize_predictions(original_image, probabilities, class_names)