import inspect
from typing import Optional, Sequence

import PIL.Image as Image
import torch
import torchvision.transforms as transforms


def _load_model(model_path):
    """Unpickle a whole model object from ``model_path`` onto the CPU.

    PyTorch 2.6 changed the default of ``torch.load`` to ``weights_only=True``,
    which refuses to unpickle arbitrary objects such as a full ``nn.Module``.
    The classifier below needs the full object (it calls ``.to()`` and
    ``.eval()`` on it), so opt out explicitly on versions that know the
    parameter, and use the plain call on older versions that do not.

    SECURITY: full unpickling can execute arbitrary code embedded in the
    file. Only pass checkpoints that come from a trusted source.
    """
    load_kwargs = {"map_location": torch.device("cpu")}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(model_path, **load_kwargs)


class AnimalClassifier:
    """Classify a single image file with a pickled PyTorch model.

    Images are converted to RGB, resized to a square ``image_size``, turned
    into a tensor, normalized per channel, and passed through the model. The
    arg-max of the model output (along dim 1) indexes into ``classes``.
    """

    def __init__(
        self,
        model_path,
        classes: Sequence[str],
        image_size: int = 224,
        mean: Optional[Sequence[float]] = None,
        std: Optional[Sequence[float]] = None,
    ):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load``.
                It must contain a full pickled model object (not a bare
                state dict), since ``.to()`` and ``.eval()`` are called on it.
            classes: Labels indexed by the model's output position.
            image_size: Side length that images are resized to (square).
            mean: Per-channel normalization mean. Defaults to hard-coded
                values; presumably dataset statistics (TODO: confirm origin).
            std: Per-channel normalization std, same caveat as ``mean``.
        """
        self.classes = classes
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load onto the CPU first, then move, so checkpoints saved on another
        # device (e.g. a GPU index that does not exist here) still load.
        self.model = _load_model(model_path).to(self.device).eval()

        self.image_size = image_size
        self.mean = mean if mean is not None else [0.5164, 0.5147, 0.4746]
        self.std = std if std is not None else [0.2180, 0.2126, 0.2172]
        self.image_transforms = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(torch.Tensor(self.mean), torch.Tensor(self.std)),
        ])

    def classify(self, image_path) -> str:
        """Return the label from ``classes`` predicted for one image.

        Args:
            image_path: Anything ``PIL.Image.open`` accepts (path or file).

        Returns:
            ``self.classes[i]``, where ``i`` is the arg-max of the model output.
        """
        # Use a context manager so the underlying file handle is released.
        # convert() returns a new, fully loaded image, so the result stays
        # valid after the file is closed. Always converting to RGB (not only
        # from RGBA) keeps grayscale/palette/CMYK inputs compatible with the
        # 3-channel Normalize step.
        with Image.open(image_path) as img:
            rgb_image = img.convert("RGB")

        # ToTensor already yields a float tensor; add a batch dimension.
        batch = self.image_transforms(rgb_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(batch)

        predicted_index = output.argmax(dim=1).item()
        return self.classes[predicted_index]


classes = [
    "bat",
    "bear",
    "elephant",
    "giraffe",
    "owl",
    "parrot",
    "penguin",
]