# si23traktor/agent/neural_network/inference.py

import torch
import cv2
import torchvision.transforms as transforms
import argparse
from agent.neural_network.model import CNNModel
# Command-line interface: `python inference.py -i path/to/image.jpg`.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-i', '--input',
    default='',
    help='path to the input image',
)

# Parse sys.argv only when this file is executed as a script. Parsing at
# import time would make `from ... import main` consume the importing
# program's own command line, and argparse exits on options it does not
# recognise.
if __name__ == "__main__":
    args = vars(parser.parse_args())
# Default checkpoint location. The path is relative to the current working
# directory, so it presumably assumes the process is started from the project
# root (the same root that makes `agent.neural_network.model` importable).
DEFAULT_CHECKPOINT = './agent/neural_network/outputs/modelFruits.pth'

# Class names indexed by the model's output: logit i corresponds to LABELS[i].
# The order must match the class order the checkpoint was trained with.
LABELS = ('strawberry', 'mango', 'grape', 'banana', 'apple')


def _load_model(device, checkpoint_path):
    """Build a CNNModel on *device*, load its weights and switch to eval mode.

    The checkpoint file must be a dict with a ``'model_state_dict'`` entry.
    """
    model = CNNModel().to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def _build_transform():
    """Return the preprocessing pipeline: HxWx3 RGB array -> normalized tensor."""
    return transforms.Compose([
        transforms.ToPILImage(),
        # A single int resizes the shorter side to 224, keeping aspect ratio.
        # NOTE(review): confirm this matches the training-time transform.
        transforms.Resize(224),
        transforms.ToTensor(),
        # Maps ToTensor's [0, 1] pixel range to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


def main(path, checkpoint_path=DEFAULT_CHECKPOINT):
    """Classify the image at *path* and return its predicted label.

    Args:
        path: Path to an image file readable by ``cv2.imread``.
        checkpoint_path: Path to the saved checkpoint; defaults to
            ``DEFAULT_CHECKPOINT``.

    Returns:
        The predicted class name, one of ``LABELS``.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from *path*
            (missing file or undecodable content).
    """
    # Read the image first so a bad path fails before the costly model load.
    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"could not read image: {path!r}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Note: the checkpoint is loaded from disk on every call.
    model = _load_model(device, checkpoint_path)
    transform = _build_transform()

    # OpenCV loads channels in BGR order; convert to RGB before transforming.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)
    # Top-1 prediction: index of the largest logit for the single batch item.
    pred_index = outputs.argmax(dim=1).item()
    return LABELS[pred_index]
if __name__ == "__main__":
    # main() returns the predicted label rather than printing it; echo it so
    # running the script actually shows the prediction.
    print(main(args['input']))