make code more flexible

s473558 2023-06-05 04:42:53 +02:00
parent 28bf53c037
commit e27acbacaf


@@ -10,61 +10,50 @@ parser.add_argument('-i', '--input',
                     help='path to the input image')
 args = vars(parser.parse_args())
-# the computation device
-device = ('cuda' if torch.cuda.is_available() else 'cpu')
-# list containing all the class labels
-labels = [
-    'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
-    'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
-    'papaya', 'potato', 'pumpkin', 'radish', 'tomato'
-]
-# initialize the model and load the trained weights
-model = CNNModel().to(device)
-checkpoint = torch.load('outputs/model.pth', map_location=device)
-model.load_state_dict(checkpoint['model_state_dict'])
-model.eval()
-# define preprocess transforms
-transform = transforms.Compose([
-    transforms.ToPILImage(),
-    transforms.Resize(224),
-    transforms.ToTensor(),
-    transforms.Normalize(
-        mean=[0.5, 0.5, 0.5],
-        std=[0.5, 0.5, 0.5]
-    )
-])
-# read and preprocess the image
-image = cv2.imread(args['input'])
-# get the ground truth class
-gt_class = args['input'].split('/')[-2]
-orig_image = image.copy()
-# convert to RGB format
-image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-image = transform(image)
-# add batch dimension
-image = torch.unsqueeze(image, 0)
-with torch.no_grad():
-    outputs = model(image.to(device))
-output_label = torch.topk(outputs, 1)
-pred_class = labels[int(output_label.indices)]
-cv2.putText(orig_image,
-    f"GT: {gt_class}",
-    (10, 25),
-    cv2.FONT_HERSHEY_SIMPLEX,
-    0.6, (0, 255, 0), 2, cv2.LINE_AA
-)
-cv2.putText(orig_image,
-    f"Pred: {pred_class}",
-    (10, 55),
-    cv2.FONT_HERSHEY_SIMPLEX,
-    0.6, (0, 0, 255), 2, cv2.LINE_AA
-)
-print(f"GT: {gt_class}, pred: {pred_class}")
-cv2.imshow('Result', orig_image)
-cv2.waitKey(0)
-cv2.imwrite(f"outputs/{gt_class}{args['input'].split('/')[-1].split('.')[0]}.png",
-    orig_image)
+
+def main(path):
+    # the computation device
+    device = ('cuda' if torch.cuda.is_available() else 'cpu')
+    # list containing all the class labels
+    labels = [
+        'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
+        'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
+        'papaya', 'potato', 'pumpkin', 'radish', 'tomato'
+    ]
+    # initialize the model and load the trained weights
+    model = CNNModel().to(device)
+    checkpoint = torch.load('outputs/model.pth', map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    model.eval()
+    # define preprocess transforms
+    transform = transforms.Compose([
+        transforms.ToPILImage(),
+        transforms.Resize(224),
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.5, 0.5, 0.5],
+            std=[0.5, 0.5, 0.5]
+        )
+    ])
+    # read and preprocess the image
+    image = cv2.imread(path)
+    # get the ground truth class
+    gt_class = path.split('/')[-2]
+    orig_image = image.copy()
+    # convert to RGB format
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = transform(image)
+    # add batch dimension
+    image = torch.unsqueeze(image, 0)
+    with torch.no_grad():
+        outputs = model(image.to(device))
+    output_label = torch.topk(outputs, 1)
+    pred_class = labels[int(output_label.indices)]
+
+    return pred_class
+
+if __name__ == "__main__":
+    main(args['input'])