114 lines
3.6 KiB
Python
114 lines
3.6 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.optim as optim
|
||
|
from torch.utils.data import Dataset, DataLoader
|
||
|
import torchvision
|
||
|
import torchvision.transforms as transforms
|
||
|
from torchvision.datasets import ImageFolder
|
||
|
import timm
|
||
|
from PIL import Image
|
||
|
|
||
|
import matplotlib.pyplot as plt # For data viz
|
||
|
import pandas as pd
|
||
|
import numpy as np
|
||
|
import sys
|
||
|
from siec import dataset
|
||
|
from glob import glob
|
||
|
|
||
|
class LabelClassifier(nn.Module):
    """EfficientNet-B0 backbone with a fresh linear classification head.

    The pretrained timm model's final classifier layer is stripped off and
    replaced with a single ``nn.Linear`` that maps the pooled 1280-dim
    features to ``num_classes`` logits.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = timm.create_model('efficientnet_b0', pretrained=True)
        self.model = backbone
        # Every child module except the last one (the pretrained classifier).
        self.features = nn.Sequential(*list(backbone.children())[:-1])
        feature_dim = 1280  # pooled feature width of efficientnet_b0
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        # Backbone (including its global pooling) followed by the new head.
        pooled = self.features(x)
        return self.classifier(pooled)
|
||
|
# Load a previously trained model object from disk.
# NOTE(review): torch.load of a full pickled module executes arbitrary code —
# only safe for checkpoints you produced yourself. PyTorch >= 2.6 defaults to
# weights_only=True and will refuse a pickled nn.Module; confirm the installed
# torch version (may need weights_only=False here).
model=torch.load("./model")

# Multi-class loss over the raw logits the model returns.
criterion=nn.CrossEntropyLoss()
# Adam over every parameter of the loaded model.
optimizer=optim.Adam(model.parameters(),lr=0.0005)
# Sanity check: loss over the full in-memory tensors exposed by the project's
# dataset module (presumably dataset.images is a batched float tensor and
# dataset.labels holds class indices — verify against siec.dataset).
print(criterion(model(dataset.images),dataset.labels))
|
||
|
|
||
|
# NOTE(review): num_epochs is 0, so the training loop below never executes —
# this looks like a leftover from an already-finished run; confirm intentional.
num_epochs=0
# Per-epoch loss histories for later inspection/plotting.
trainLosses=[]
valLosses=[]

# Use the first CUDA device when available, otherwise fall back to the CPU.
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
|
||
|
|
||
|
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    runningLoss=0.0
    for images,labels in dataset.dataloader:
        images, labels=images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs=model(images)
        loss=criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by batch size so the epoch figure is a true
        # per-sample average even when the last batch is smaller.
        runningLoss+=loss.item()*images.size(0)
    trainLoss=runningLoss/len(dataset.dataloader.dataset)
    trainLosses.append(trainLoss)

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    runningLoss=0.0
    with torch.no_grad():
        for images, labels in dataset.valloader:
            images, labels = images.to(device), labels.to(device)
            outputs=model(images)
            loss=criterion(outputs,labels)
            runningLoss+=loss.item()*images.size(0)
    valLoss=runningLoss/len(dataset.valloader.dataset)
    valLosses.append(valLoss)
    print(f"Epoch {epoch + 1}/{num_epochs} - Train loss: {trainLoss}, Validation loss: {valLoss}")
|
||
|
# Compile the network with TorchScript and persist it so it can later be
# reloaded without the Python class definition being importable.
scripted_model = torch.jit.script(model)
scripted_model.save('model.pt')
|
||
|
|
||
|
|
||
|
def preprocess_image(image_path, transform):
    """Load an image file and return (PIL image, batched model-ready tensor)."""
    pil_image = Image.open(image_path).convert("RGB")
    tensor = transform(pil_image).unsqueeze(0)  # prepend a batch dimension
    return pil_image, tensor
|
||
|
|
||
|
# Predict using the model
def predict(model, image_tensor, device):
    """Run one forward pass and return class probabilities as a flat numpy array.

    Puts the model in eval mode, moves the input to *device*, applies a
    softmax over the class dimension, and echoes the raw logits to stdout.
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probs = torch.nn.functional.softmax(logits, dim=1)
        print(logits)
        return probs.cpu().numpy().flatten()
|
||
|
|
||
|
# Visualization
def visualize_predictions(original_image, probabilities, class_names):
    """Show the input image beside a horizontal bar chart of class scores."""
    figure, (image_axis, bar_axis) = plt.subplots(1, 2, figsize=(14, 7))

    # Left panel: the raw input image with no axis decorations.
    image_axis.imshow(original_image)
    image_axis.axis("off")

    # Right panel: one bar per class; probabilities live on [0, 1].
    bar_axis.barh(class_names, probabilities)
    bar_axis.set_xlabel("Probability")
    bar_axis.set_title("Class Predictions")
    bar_axis.set_xlim(0, 1)

    plt.tight_layout()
    plt.show()
|
||
|
# Inference-time preprocessing: resize to the 128x128 input size and convert
# to a [0, 1] float tensor.
# NOTE(review): no mean/std normalization is applied — presumably training
# used this same transform; verify against siec.dataset, since a pretrained
# efficientnet usually expects ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
|
||
|
# NOTE(review): despite the name, these paths come from the *training* split.
test_images = glob('./train/*/*')
test_examples = np.random.choice(test_images, 10)

# Hoisted loop invariants: eval mode only needs to be set once, and the class
# names do not change between examples.
model.eval()
# Assuming dataset.classes gives the class names
class_names = dataset.dataset.classes

for example in test_examples:
    original_image, image_tensor = preprocess_image(example, transform)
    probabilities = predict(model, image_tensor, device)
    print(probabilities)
    visualize_predictions(original_image, probabilities, class_names)
# Removed a trailing `model(image_tensor)` call whose result was discarded —
# it only re-ran a forward pass on the last example with no effect.
|