add terminal notifications
This commit is contained in:
parent
f82b25ee68
commit
84eabfae0e
@ -76,6 +76,7 @@ def visualize():
|
||||
conv_layers = []
|
||||
for module in pretrained_model.children():
|
||||
if isinstance(module, nn.Conv2d):
|
||||
print("Adding module to the layers... ")
|
||||
conv_layers.append(module)
|
||||
|
||||
# Pass the resulting image through the convolutional layers and capture feature maps
|
||||
@ -85,12 +86,14 @@ def visualize():
|
||||
|
||||
for i, layer in enumerate(conv_layers):
|
||||
input_image = layer(input_image)
|
||||
print("Passing through feature maps - layer " , i)
|
||||
feature_maps.append(input_image)
|
||||
layer_names.append(f"Layer {i + 1}: {str(layer)}")
|
||||
|
||||
# Process and feature maps
|
||||
processed_feature_maps = []
|
||||
for feature_map in feature_maps:
|
||||
print("Processing feature map...")
|
||||
feature_map = feature_map.squeeze(0) # Remove the batch dimension
|
||||
mean_feature_map = torch.mean(feature_map, dim=0).cpu().detach().numpy() # Compute mean across channels
|
||||
processed_feature_maps.append(mean_feature_map)
|
||||
@ -98,6 +101,7 @@ def visualize():
|
||||
# Plot the feature maps
|
||||
fig = plt.figure(figsize=(20, 20))
|
||||
for i, fm in enumerate(processed_feature_maps):
|
||||
print("Plotting feature maps... now at map number ", i)
|
||||
ax = fig.add_subplot(4, 4, i + 1) # Adjust grid size as needed
|
||||
ax.imshow(fm, cmap='viridis') # Display feature map as image
|
||||
ax.axis("off")
|
||||
|
Loading…
Reference in New Issue
Block a user