# Male_zoo_Projekt_SI/classification.py
#
# 48 lines
# 1.4 KiB
# Python

import torch
import torchvision.transforms as transforms
import PIL.Image as Image
class AnimalClassifier:
    """Classify a single image file into one of a fixed list of labels.

    Bundles a serialized PyTorch model with the preprocessing pipeline
    (square resize, tensor conversion, per-channel normalization) that is
    applied to every image before inference.
    """

    def __init__(self, model_path, classes, image_size=224, mean=None, std=None):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path or file-like object passed to ``torch.load``.
            classes: Sequence of labels; it is indexed with the class index
                predicted by the model, so its order must match the model's
                output order.
            image_size: Side length in pixels; every image is resized to
                ``image_size x image_size``.
            mean: Per-channel normalization means. Defaults to
                ``[0.5164, 0.5147, 0.4746]``.
            std: Per-channel normalization standard deviations. Defaults to
                ``[0.2180, 0.2126, 0.2172]``.
        """
        self.classes = classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load model files from a trusted source.
        # NOTE(review): newer PyTorch releases may default to weights-only
        # loading, which rejects fully pickled modules - confirm against the
        # PyTorch version this project pins.
        # The model is deserialized on the CPU first so that a checkpoint saved
        # on a GPU machine still loads on a CPU-only host, then moved over.
        model = torch.load(model_path, map_location=torch.device('cpu'))
        self.model = model.to(self.device).eval()

        self.image_size = image_size
        self.mean = mean if mean is not None else [0.5164, 0.5147, 0.4746]
        self.std = std if std is not None else [0.2180, 0.2126, 0.2172]
        self.image_transforms = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])

    def classify(self, image_path):
        """Return the label predicted for the image at ``image_path``.

        Args:
            image_path: Path or file-like object accepted by ``PIL.Image.open``.

        Returns:
            The element of ``self.classes`` at the index of the highest
            model output.

        Raises:
            OSError: If the image cannot be opened or decoded (raised by
                Pillow).
            IndexError: If the predicted index is outside ``self.classes``.
        """
        # The context manager releases the file handle even when decoding or
        # preprocessing fails. Pillow loads pixel data lazily, so the
        # transforms must run while the file is still open.
        with Image.open(image_path) as source:
            # Normalization is configured for three channels, so every
            # non-RGB mode (RGBA, L, P, LA, CMYK, ...) is converted, not
            # just RGBA.
            rgb_image = source if source.mode == 'RGB' else source.convert('RGB')
            tensor = self.image_transforms(rgb_image).float()

        # Add the batch dimension the model expects and move to its device.
        batch = tensor.unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model(batch)

        predicted = output.argmax(dim=1)
        return self.classes[predicted.item()]
# Labels the classifier can return.
# NOTE(review): presumably passed as `classes` to AnimalClassifier; if so, the
# order must match the output indices of the trained model - confirm against
# the training code before reordering.
classes = [
    "bat",
    "bear",
    "elephant",
    "giraffe",
    "owl",
    "parrot",
    "penguin",
]