make code more flexible
This commit is contained in:
parent
28bf53c037
commit
e27acbacaf
@ -10,61 +10,50 @@ parser.add_argument('-i', '--input',
|
||||
help='path to the input image')
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
# the computation device
|
||||
device = ('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# list containing all the class labels
|
||||
labels = [
|
||||
'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
|
||||
'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
|
||||
'papaya', 'potato', 'pumpkin', 'radish', 'tomato'
|
||||
]
|
||||
def main(path):
|
||||
# the computation device
|
||||
device = ('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# list containing all the class labels
|
||||
labels = [
|
||||
'bean', 'bitter gourd', 'bottle gourd', 'brinjal', 'broccoli',
|
||||
'cabbage', 'capsicum', 'carrot', 'cauliflower', 'cucumber',
|
||||
'papaya', 'potato', 'pumpkin', 'radish', 'tomato'
|
||||
]
|
||||
|
||||
# initialize the model and load the trained weights
|
||||
model = CNNModel().to(device)
|
||||
checkpoint = torch.load('outputs/model.pth', map_location=device)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
# initialize the model and load the trained weights
|
||||
model = CNNModel().to(device)
|
||||
checkpoint = torch.load('outputs/model.pth', map_location=device)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
# define preprocess transforms
|
||||
transform = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5]
|
||||
)
|
||||
])
|
||||
# define preprocess transforms
|
||||
transform = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5]
|
||||
)
|
||||
])
|
||||
|
||||
|
||||
# read and preprocess the image
|
||||
image = cv2.imread(args['input'])
|
||||
# get the ground truth class
|
||||
gt_class = args['input'].split('/')[-2]
|
||||
orig_image = image.copy()
|
||||
# convert to RGB format
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = transform(image)
|
||||
# add batch dimension
|
||||
image = torch.unsqueeze(image, 0)
|
||||
with torch.no_grad():
|
||||
outputs = model(image.to(device))
|
||||
output_label = torch.topk(outputs, 1)
|
||||
pred_class = labels[int(output_label.indices)]
|
||||
cv2.putText(orig_image,
|
||||
f"GT: {gt_class}",
|
||||
(10, 25),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.6, (0, 255, 0), 2, cv2.LINE_AA
|
||||
)
|
||||
cv2.putText(orig_image,
|
||||
f"Pred: {pred_class}",
|
||||
(10, 55),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.6, (0, 0, 255), 2, cv2.LINE_AA
|
||||
)
|
||||
print(f"GT: {gt_class}, pred: {pred_class}")
|
||||
cv2.imshow('Result', orig_image)
|
||||
cv2.waitKey(0)
|
||||
cv2.imwrite(f"outputs/{gt_class}{args['input'].split('/')[-1].split('.')[0]}.png",
|
||||
orig_image)
|
||||
# read and preprocess the image
|
||||
image = cv2.imread(path)
|
||||
# get the ground truth class
|
||||
gt_class = path.split('/')[-2]
|
||||
orig_image = image.copy()
|
||||
# convert to RGB format
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = transform(image)
|
||||
# add batch dimension
|
||||
image = torch.unsqueeze(image, 0)
|
||||
with torch.no_grad():
|
||||
outputs = model(image.to(device))
|
||||
output_label = torch.topk(outputs, 1)
|
||||
pred_class = labels[int(output_label.indices)]
|
||||
|
||||
return pred_class
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args['input'])
|
Loading…
Reference in New Issue
Block a user