# plot_image_grid (Python, 48 lines, 1.4 KiB)
def plot_image_grid(image_list,
                    label_list,
                    sample_images=False,
                    num_images=6,
                    pre_title='class',
                    num_img_per_row=3,
                    cmap=None,
                    img_h_w=3):
    '''Visualise images in a grid, each subplot titled with its label.

    INPUTS:
        image_list: a list of images to be plotted (each passed to ``imshow``).
        label_list: a list of corresponding image labels, aligned by position
            with ``image_list``.
        sample_images: if True, pick up to ``num_images`` distinct images at
            random (via the ``random`` module, so ``random.seed`` makes it
            reproducible); otherwise plot the first ``num_images`` images.
        num_images: maximum number of images to plot.
        pre_title: prefix of each subplot title, rendered as
            ``"{pre_title} - {label}"``.
        num_img_per_row: number of grid columns.
        cmap: colormap forwarded unchanged to ``imshow``.
        img_h_w: size in inches allotted to each grid cell (used for both
            width and height when computing ``figsize``).

    RAISES:
        ValueError: if ``num_images`` or ``num_img_per_row`` is less than 1.
    '''
    from itertools import islice

    if num_images < 1 or num_img_per_row < 1:
        raise ValueError('num_images and num_img_per_row must both be >= 1')

    # Ceiling division: a partial last row still needs its own row of axes.
    # Floor division made the grid too small (or zero rows) whenever
    # num_images was not a multiple of num_img_per_row.
    n_row = -(-num_images // num_img_per_row)

    # squeeze=False always yields a 2-D array of axes, even for 1 row/column,
    # so it can be flattened uniformly.
    _, axes = plt.subplots(n_row,
                           num_img_per_row,
                           figsize=(img_h_w * num_img_per_row, img_h_w * n_row),
                           squeeze=False)
    axes = axes.ravel()

    if sample_images:
        # Sample without replacement so no image is shown twice, and pair each
        # sampled image with the label at the SAME index.
        k = min(num_images, len(image_list))
        sampled_ids = random.sample(range(len(image_list)), k=k)
        pairs = [(image_list[idx], label_list[idx]) for idx in sampled_ids]
    else:
        # First num_images (image, label) pairs, in order.
        pairs = islice(zip(image_list, label_list), num_images)

    # Hide every cell up front so unused slots in a partial last row stay blank.
    for ax in axes:
        ax.axis('off')

    for ax, (img, label) in zip(axes, pairs):
        ax.set_title(f'{pre_title} - {label}')
        ax.imshow(img, cmap=cmap)

    #show
    plt.tight_layout()
    plt.show()