def plot_image_grid(image_list, label_list, sample_images=False, num_images=6,
                    pre_title='class', num_img_per_row=3, cmap=None, img_h_w=3):
    """Plot up to ``num_images`` images, with titles, in a grid.

    Args:
        image_list: Images to plot. Each item is passed as-is to ``imshow``.
            Must support ``len()`` and integer indexing when
            ``sample_images`` is true.
        label_list: Labels corresponding to ``image_list``; ``label_list[k]``
            is the label of ``image_list[k]``.
        sample_images: If true, draw ``num_images`` random indices with
            ``random.choices`` (sampling with replacement, so an image may
            appear more than once). Otherwise plot the first ``num_images``
            images in order.
        num_images: Maximum number of images to show; also determines the
            number of grid rows.
        pre_title: Prefix of each subplot title (``'<pre_title> - <label>'``).
        num_img_per_row: Number of grid columns.
        cmap: Colormap forwarded to ``imshow``.
        img_h_w: Width and height, in figure units, reserved per grid cell.

    Raises:
        ValueError: If ``num_images`` or ``num_img_per_row`` is less than 1.
    """
    if num_images < 1:
        raise ValueError('num_images must be a positive integer')
    if num_img_per_row < 1:
        raise ValueError('num_img_per_row must be a positive integer')

    # Pick the (image, label) pairs before creating the figure, so a failure
    # here (e.g. an index missing from label_list) does not leave an empty
    # figure behind.
    if sample_images:
        sampled_ids = random.choices(range(len(image_list)), k=num_images)
        # Image and label must be looked up with the SAME sampled index;
        # otherwise titles do not match the images shown.
        items = [(image_list[idx], label_list[idx]) for idx in sampled_ids]
    else:
        items = []
        for i, img in enumerate(image_list):
            items.append((img, label_list[i]))
            if len(items) == num_images:
                break

    # Ceiling division: a partially filled last row still needs its own row.
    n_row = (num_images + num_img_per_row - 1) // num_img_per_row

    # squeeze=False always returns a 2-D array of axes, so the grid can be
    # flattened uniformly even when it has a single row or column.
    _, axes = plt.subplots(
        n_row,
        num_img_per_row,
        figsize=(img_h_w * num_img_per_row, img_h_w * n_row),
        squeeze=False,
    )
    cells = list(axes.flat)

    # Hide the axes of every cell, including cells that stay empty when
    # fewer images than grid cells are available.
    for ax in cells:
        ax.axis('off')

    for ax, (img, label) in zip(cells, items):
        ax.set_title(f'{pre_title} - {label}')
        ax.imshow(img, cmap=cmap)

    plt.tight_layout()
    plt.show()