wko-on-cloud-n/src/utils.py
2022-02-15 04:03:45 +01:00

48 lines
1.4 KiB
Python

def plot_image_grid(image_list,
                    label_list,
                    sample_images=False,
                    num_images=6,
                    pre_title='class',
                    num_img_per_row=3,
                    cmap=None,
                    img_h_w=3):
    """Plot images in a grid, each titled with its label, and show the figure.

    INPUTS:
        image_list: sequence of images; each item is passed to ``imshow``.
        label_list: sequence of labels, where ``label_list[k]`` is the label
            of ``image_list[k]``.
        sample_images: if True, plot a random subset of the images (no image
            is repeated); otherwise plot the first images in order.
        num_images: maximum number of images to plot. Fewer are plotted when
            ``image_list`` holds fewer images.
        pre_title: text placed before each label in the subplot titles.
        num_img_per_row: number of columns in the grid.
        cmap: colormap forwarded to ``imshow``.
        img_h_w: height and width of each grid cell, used to size the figure.

    Raises:
        ValueError: if ``num_images`` or ``num_img_per_row`` is not positive.
    """
    if num_images < 1 or num_img_per_row < 1:
        raise ValueError('num_images and num_img_per_row must be positive')

    n_available = len(image_list)
    n_plot = min(num_images, n_available)
    if n_plot == 0:
        # Nothing to draw; a 0-row grid would be rejected by plt.subplots.
        return

    if sample_images:
        # Sample indices without replacement so no image is shown twice.
        ids = random.sample(range(n_available), k=n_plot)
    else:
        ids = range(n_plot)

    # Ceiling division: a partially filled last row still needs its own row.
    n_row = -(-n_plot // num_img_per_row)
    _, axes = plt.subplots(n_row, num_img_per_row,
                           figsize=(img_h_w * num_img_per_row, img_h_w * n_row),
                           squeeze=False)
    # Hide every cell first so unused cells in the last row stay blank.
    for ax in axes.flat:
        ax.axis('off')

    for ax, idx in zip(axes.flat, ids):
        # Index labels with the image's own index so image and title stay paired.
        ax.set_title(f'{pre_title} - {label_list[idx]}')
        ax.imshow(image_list[idx], cmap=cmap)

    plt.tight_layout()
    plt.show()