2023-06-05 03:35:16 +02:00
|
|
|
import torch
|
|
|
|
import cv2
|
|
|
|
import torchvision.transforms as transforms
|
|
|
|
import argparse
|
2023-06-17 11:49:38 +02:00
|
|
|
from agent.neural_network.model import CNNModel
|
2023-06-05 03:35:16 +02:00
|
|
|
# Command-line interface: one optional flag naming the image to classify.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-i',
    '--input',
    default='',
    help='path to the input image',
)

# Expose the parsed options as a plain dict, e.g. args['input'].
# NOTE(review): this runs when the module is imported, so importing this file
# from another program parses that program's command line.
args = vars(parser.parse_args())
|
|
|
|
|
2023-06-05 04:42:53 +02:00
|
|
|
def main(path):
    """Classify the vegetable shown in the image at *path*.

    Builds a ``CNNModel``, restores its weights from
    ``./agent/neural_network/outputs/model.pth``, preprocesses the image and
    returns the label whose output score is highest.

    Args:
        path: Filesystem path of the image to classify.

    Returns:
        The predicted class name (one of the entries of ``labels`` below).

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from *path*.
    """
    # the computation device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # list containing all the class labels; the position in this list is the
    # class index produced by the model
    labels = [
        'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
        'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
        'papaya', 'potato', 'pumpkin', 'radish', 'tomato',
    ]

    # initialize the model and load the trained weights
    model = CNNModel().to(device)
    checkpoint = torch.load(
        './agent/neural_network/outputs/model.pth', map_location=device
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # define preprocess transforms
    # NOTE(review): Resize with a single int scales the shorter edge to 224
    # and keeps the aspect ratio; confirm CNNModel accepts non-square input.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        ),
    ])

    # read the image; cv2.imread signals failure by returning None instead of
    # raising, so turn that into an explicit error naming the path
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f'could not read an image from {path!r}')

    # OpenCV loads images as BGR; convert to RGB before preprocessing
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = transform(image)
    # add batch dimension
    image = torch.unsqueeze(image, 0)

    with torch.no_grad():
        outputs = model(image.to(device))

    # index of the single highest-scoring class
    output_label = torch.topk(outputs, 1)
    pred_class = labels[int(output_label.indices)]

    return pred_class
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: classify the image given via -i/--input and show the
    # result, which would otherwise be computed and silently discarded.
    print(main(args['input']))
|