48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
import torch
|
|
import torchvision.transforms as transforms
|
|
import PIL.Image as Image
|
|
|
|
class AnimalClassifier:
    """Classify a single image file with a pre-trained PyTorch model.

    The model is loaded once at construction time, moved to CUDA when it is
    available (CPU otherwise) and switched to evaluation mode. Each call to
    ``classify`` resizes, tensorizes and normalizes one image, runs it through
    the model, and maps the highest-scoring output index to a label.
    """

    def __init__(self, model_path, classes, image_size=224, mean=None, std=None):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path to a file saved with ``torch.save`` that holds the
                whole model object (not only a ``state_dict``) -- the loaded
                object is moved with ``.to``, put in ``.eval()`` mode and called
                directly.
            classes: Indexable sequence of labels; ``classify`` returns
                ``classes[i]`` where ``i`` is the model's arg-max output index.
            image_size: Images are resized to ``image_size x image_size``
                pixels (aspect ratio is not preserved).
            mean: Per-channel normalization means. ``None`` selects the
                built-in 3-channel defaults below.
            std: Per-channel normalization standard deviations. ``None``
                selects the built-in 3-channel defaults below.
        """
        self.classes = classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # SECURITY NOTE(review): torch.load unpickles arbitrary Python objects
        # and can execute code from the file; only load trusted model files.
        # NOTE(review): newer PyTorch releases default to weights_only=True,
        # which refuses full pickled models -- confirm against the installed
        # torch version before upgrading.
        # Map to CPU first so a checkpoint saved on a GPU still loads on a
        # CPU-only machine; then move it to whichever device was selected.
        self.model = torch.load(model_path, map_location=torch.device('cpu'))
        self.model = self.model.to(self.device)
        self.model = self.model.eval()

        self.image_size = image_size
        # Default statistics are presumably the training set's per-channel
        # mean/std -- confirm they match how the checkpoint was trained.
        self.mean = mean if mean is not None else [0.5164, 0.5147, 0.4746]
        self.std = std if std is not None else [0.2180, 0.2126, 0.2172]

        self.image_transforms = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(torch.Tensor(self.mean), torch.Tensor(self.std)),
        ])

    def classify(self, image_path):
        """Return the predicted label for the image at ``image_path``.

        The image is converted to RGB (if needed), preprocessed with
        ``self.image_transforms``, wrapped in a batch of one, and passed
        through the model without gradient tracking.

        Args:
            image_path: Image location, passed straight to ``PIL.Image.open``.

        Returns:
            The element of ``self.classes`` at the model's arg-max output index.
        """
        # Use a context manager so the file handle is always released;
        # Image.open is lazy, so the transforms must run inside the block.
        with Image.open(image_path) as img:
            # The default mean/std have three entries, so the pipeline expects
            # RGB input. Convert every other mode (RGBA, grayscale 'L',
            # palette 'P', CMYK, YCbCr, ...), not just RGBA; otherwise the
            # channel count breaks Normalize or the colors are misread.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            tensor = self.image_transforms(img).float()

        # Add a leading batch dimension of size 1 and move to the model's device.
        batch = tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(batch)

        # Index of the highest score along dim 1 (first maximum on ties).
        predicted = output.argmax(dim=1)

        return self.classes[predicted.item()]
|
|
|
|
# Label names, indexed by position. Presumably passed as the ``classes``
# argument of AnimalClassifier, whose classify() returns classes[i] for the
# model's arg-max output index i -- so this order must match the label order
# used when the model was trained (confirm against the training code).
classes = [
    "bat", "bear", "elephant", "giraffe", "owl", "parrot", "penguin",
]
|
|
|
|
|