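Make code more flexible: wrap the inference script in a main(path) function that returns the predicted class, and remove the result annotation, printing, display and image-saving steps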

This commit is contained in:
s473558 2023-06-05 04:42:53 +02:00
parent 28bf53c037
commit e27acbacaf

View File

@ -10,23 +10,24 @@ parser.add_argument('-i', '--input',
help='path to the input image') help='path to the input image')
args = vars(parser.parse_args()) args = vars(parser.parse_args())
# the computation device def main(path):
device = ('cuda' if torch.cuda.is_available() else 'cpu') # the computation device
# list containing all the class labels device = ('cuda' if torch.cuda.is_available() else 'cpu')
labels = [ # list containing all the class labels
labels = [
'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli', 'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber', 'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
'papaya', 'potato', 'pumpkin', 'radish', 'tomato' 'papaya', 'potato', 'pumpkin', 'radish', 'tomato'
] ]
# initialize the model and load the trained weights # initialize the model and load the trained weights
model = CNNModel().to(device) model = CNNModel().to(device)
checkpoint = torch.load('outputs/model.pth', map_location=device) checkpoint = torch.load('outputs/model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict']) model.load_state_dict(checkpoint['model_state_dict'])
model.eval() model.eval()
# define preprocess transforms # define preprocess transforms
transform = transforms.Compose([ transform = transforms.Compose([
transforms.ToPILImage(), transforms.ToPILImage(),
transforms.Resize(224), transforms.Resize(224),
transforms.ToTensor(), transforms.ToTensor(),
@ -34,37 +35,25 @@ transform = transforms.Compose([
mean=[0.5, 0.5, 0.5], mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5] std=[0.5, 0.5, 0.5]
) )
]) ])
# read and preprocess the image # read and preprocess the image
image = cv2.imread(args['input']) image = cv2.imread(path)
# get the ground truth class # get the ground truth class
gt_class = args['input'].split('/')[-2] gt_class = path.split('/')[-2]
orig_image = image.copy() orig_image = image.copy()
# convert to RGB format # convert to RGB format
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = transform(image) image = transform(image)
# add batch dimension # add batch dimension
image = torch.unsqueeze(image, 0) image = torch.unsqueeze(image, 0)
with torch.no_grad(): with torch.no_grad():
outputs = model(image.to(device)) outputs = model(image.to(device))
output_label = torch.topk(outputs, 1) output_label = torch.topk(outputs, 1)
pred_class = labels[int(output_label.indices)] pred_class = labels[int(output_label.indices)]
cv2.putText(orig_image,
f"GT: {gt_class}", return pred_class
(10, 25),
cv2.FONT_HERSHEY_SIMPLEX, if __name__ == "__main__":
0.6, (0, 255, 0), 2, cv2.LINE_AA main(args['input'])
)
cv2.putText(orig_image,
f"Pred: {pred_class}",
(10, 55),
cv2.FONT_HERSHEY_SIMPLEX,
0.6, (0, 0, 255), 2, cv2.LINE_AA
)
print(f"GT: {gt_class}, pred: {pred_class}")
cv2.imshow('Result', orig_image)
cv2.waitKey(0)
cv2.imwrite(f"outputs/{gt_class}{args['input'].split('/')[-1].split('.')[0]}.png",
orig_image)