wozek/ai-wozek/siec/model.py

114 lines
3.6 KiB
Python
Raw Normal View History

2024-06-17 04:58:21 +02:00
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm
from PIL import Image
import matplotlib.pyplot as plt # For data viz
import pandas as pd
import numpy as np
import sys
from siec import dataset
from glob import glob
class LabelClassifier(nn.Module):
    """EfficientNet-B0 backbone (created via timm) followed by a linear head.

    Args:
        num_classes: Number of logits produced by ``forward``.
        pretrained: Forwarded to ``timm.create_model``. ``True`` (the original
            behaviour) requests pretrained backbone weights; ``False`` builds the
            backbone without them.
    """

    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        self.model = timm.create_model('efficientnet_b0', pretrained=pretrained)
        # Drop the backbone's last child (its classifier head) and keep the rest
        # as the feature extractor. NOTE(review): assumes the remaining stack ends
        # in timm's global pooling and yields (batch, num_features) -- confirm if
        # the backbone is ever changed.
        self.features = nn.Sequential(*list(self.model.children())[:-1])
        # Read the feature width from the backbone instead of hard-coding 1280,
        # so the head cannot fall out of sync with the backbone.
        out_size = self.model.num_features
        self.classifier = nn.Linear(out_size, num_classes)

    def forward(self, x):
        """Return raw class logits (no softmax) for the input batch ``x``."""
        x = self.features(x)
        output = self.classifier(x)
        return output
# Choose the compute device first so the checkpoint can be mapped straight onto
# it; without map_location a checkpoint saved on a GPU fails to load on a
# CPU-only machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# "./model" is loaded as a whole pickled module (its result is used directly as
# a model below), so weights_only=False is required -- PyTorch 2.6 changed the
# default to True. Unpickling can execute arbitrary code: only load trusted files.
model = torch.load("./model", map_location=device, weights_only=False)
model.to(device)

criterion = nn.CrossEntropyLoss()
# Built after the model is on its final device, as torch.optim recommends.
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Sanity check: loss of the loaded model on the batch exposed by `dataset`.
# Run in eval mode without autograd so BatchNorm running statistics are not
# modified by this check (they would otherwise be exported below).
# NOTE(review): assumes dataset.images / dataset.labels are tensors -- confirm.
model.eval()
with torch.no_grad():
    print(criterion(model(dataset.images.to(device)), dataset.labels.to(device)))

num_epochs = 0  # 0 skips training: the loaded model is only re-exported below
trainLosses = []
valLosses = []
for epoch in range(num_epochs):
    model.train()
    runningLoss = 0.0
    for images, labels in dataset.dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion averages over the batch; re-weight by batch size so the
        # epoch mean stays exact when the last batch is smaller.
        runningLoss += loss.item() * images.size(0)
    trainLoss = runningLoss / len(dataset.dataloader.dataset)
    trainLosses.append(trainLoss)

    model.eval()
    runningLoss = 0.0
    with torch.no_grad():
        for images, labels in dataset.valloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            runningLoss += loss.item() * images.size(0)
    valLoss = runningLoss / len(dataset.valloader.dataset)
    valLosses.append(valLoss)
    print(f"Epoch {epoch + 1}/{num_epochs} - Train loss: {trainLoss}, Validation loss: {valLoss}")

# Export a TorchScript copy in eval mode (the scripted module keeps the
# training flag it was scripted with).
model.eval()
modell = torch.jit.script(model)
modell.save('model.pt')
def preprocess_image(image_path, transform):
    """Load an image from disk and turn it into a single-image model batch.

    Args:
        image_path: Path of the image file to open.
        transform: Callable applied to the RGB ``PIL.Image``; its result must
            support ``.unsqueeze`` (e.g. a torch tensor).

    Returns:
        ``(image, batch)``: the RGB ``PIL.Image`` (kept for display) and
        ``transform(image)`` with a new leading dimension of size 1.
    """
    # The context manager closes the underlying file. convert() returns a new,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return image, transform(image).unsqueeze(0)
# Predict using the model
def predict(model, image_tensor, device):
    """Return softmax class probabilities of ``model`` for ``image_tensor``.

    Puts ``model`` into eval mode, moves the input to ``device`` and runs the
    forward pass without autograd. The raw logits are printed as a debug aid.
    Softmax is taken over dim 1, and the result is returned as a flattened 1-D
    numpy array on the CPU. For a batch, the per-sample rows are concatenated.
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probs = torch.nn.functional.softmax(logits, dim=1)
        print(logits)
    host_probs = probs.cpu().numpy()
    return host_probs.flatten()
# Visualization
def visualize_predictions(original_image, probabilities, class_names):
    """Display an image next to a horizontal bar chart of class probabilities.

    Left panel: ``original_image`` with axes hidden. Right panel: one bar per
    entry of ``class_names``, with lengths taken from ``probabilities`` and the
    x-axis fixed to [0, 1]. The figure is shown with ``plt.show()``.
    """
    _, (image_ax, bars_ax) = plt.subplots(1, 2, figsize=(14, 7))

    image_ax.imshow(original_image)
    image_ax.axis("off")

    bars_ax.barh(class_names, probabilities)
    bars_ax.set_xlabel("Probability")
    bars_ax.set_title("Class Predictions")
    bars_ax.set_xlim(0, 1)

    plt.tight_layout()
    plt.show()
# Single-image preprocessing for the demo below.
# NOTE(review): should match the transform used during training (presumably
# defined in `siec.dataset`) -- confirm.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# NOTE(review): images are drawn from the *training* folder, so these are not
# held-out examples despite the variable name.
test_images = glob('./train/*/*')
if not test_images:
    raise FileNotFoundError("No images found matching './train/*/*'")
# Sample without replacement so the same image is not shown twice.
test_examples = np.random.choice(test_images, min(10, len(test_images)), replace=False)

# Class names do not change per example; look them up once.
# Assumes dataset.dataset is an ImageFolder-like object exposing `.classes`.
class_names = dataset.dataset.classes
for example in test_examples:
    original_image, image_tensor = preprocess_image(example, transform)
    # predict() sets eval mode and moves the tensor to `device` itself.
    probabilities = predict(model, image_tensor, device)
    print(probabilities)
    visualize_predictions(original_image, probabilities, class_names)