update code to show visualizations
parent 7f19f6329d
commit 669c39f659
@@ -1,5 +1,5 @@
-from flask import Flask, render_template, request, redirect, url_for, send_file, jsonify
+from flask import Flask, render_template, request, redirect, url_for, send_file, jsonify, g
-from mode_style_transfer import StyleTransferModel, save_image, StyleTransferVisualizer
+from mode_style_transfer import StyleTransferModel
 from PIL import Image
 import io
 import torch
@@ -8,6 +8,7 @@ import torchvision.transforms as transforms
 import os
 import matplotlib.pyplot as plt
 import base64
+import torch.nn as nn

 app = Flask(__name__)

@@ -20,7 +21,8 @@ loader = transforms.Compose([

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-visualizations = []
+# Global variable to store the output tensor
+output_tensor = None

 def image_loader(image_bytes):
     image = Image.open(io.BytesIO(image_bytes))
@@ -38,10 +40,10 @@ def image_to_base64(image):
     img_io.seek(0)
     return base64.b64encode(img_io.getvalue()).decode('utf-8')


 @app.route('/', methods=['GET', 'POST'])
 def index():
     if request.method == 'POST':
+        global output_tensor
         content_image_file = request.files['content_image']
         style_image_file = request.files['style_image']

@@ -52,7 +54,7 @@ def index():
         # Pass the images to the StyleTransferModel
         style_transfer = StyleTransferModel(content_image, style_image)
         output = style_transfer.run_style_transfer()
+        output_tensor = output
         # Convert the output tensor to an image
         output_image = tensor_to_image(output)

@@ -63,37 +65,53 @@ def index():

     return render_template('index.html')

-@app.route('/visualize', methods=['POST'])
+@app.route('/visualize', methods=['GET'])
 def visualize():
-    cnn = vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
+    pretrained_model = vgg19(weights=VGG19_Weights.DEFAULT).features.eval().to(device)

-    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
-    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

-    content_image_bytes = visualizations[0]  # The last saved content image
-    content_image = image_loader(content_image_bytes)

-    style_transfer = StyleTransferModel(content_image, content_image)
-
-    # Running the model for visualization purposes
-    input_img = content_image.clone().requires_grad_(True)
-
-    model, _, _ = style_transfer.get_style_model_and_losses(
-        cnn, cnn_normalization_mean, cnn_normalization_std, content_image, content_image)
+    # Extract convolutional layers from VGG19
+    conv_layers = []
+    for module in pretrained_model.children():
+        if isinstance(module, nn.Conv2d):
+            conv_layers.append(module)

-    layer_visualizations = []
+    # Pass the resulting image through the convolutional layers and capture feature maps
+    feature_maps = []
+    layer_names = []
+    input_image = output_tensor.clone()

-    # Run the image through each layer and store the output
-    for i, layer in enumerate(model):
-        input_img = layer(input_img)
-        with torch.no_grad():
-            output_image = tensor_to_image(input_img.clamp(0, 1))
-            img_io = io.BytesIO()
-            output_image.save(img_io, 'JPEG')
-            img_io.seek(0)
-            layer_visualizations.append(img_io.getvalue())  # Save the image bytes
+    for i, layer in enumerate(conv_layers):
+        input_image = layer(input_image)
+        feature_maps.append(input_image)
+        layer_names.append(f"Layer {i + 1}: {str(layer)}")

-    return render_template('visualize.html', visualizations=layer_visualizations)
+    # Process the feature maps
+    processed_feature_maps = []
+    for feature_map in feature_maps:
+        feature_map = feature_map.squeeze(0)  # Remove the batch dimension
+        mean_feature_map = torch.mean(feature_map, dim=0).cpu().detach().numpy()  # Compute mean across channels
+        processed_feature_maps.append(mean_feature_map)
+
+    # Plot the feature maps
+    fig = plt.figure(figsize=(20, 20))
+    for i, fm in enumerate(processed_feature_maps):
+        ax = fig.add_subplot(4, 4, i + 1)  # Adjust grid size as needed
+        ax.imshow(fm, cmap='viridis')  # Display feature map as image
+        ax.axis("off")
+        ax.set_title(layer_names[i], fontsize=8)
+
+    plt.tight_layout()
+
+    # Save the plot to a BytesIO object and encode it as base64
+    img_io = io.BytesIO()
+    plt.savefig(img_io, format='png')
+    img_io.seek(0)
+    plt.close(fig)
+    plot_base64 = base64.b64encode(img_io.getvalue()).decode('utf-8')
+
+    # Return the image as a base64-encoded string that can be embedded in HTML
+    return f'<img src="data:image/png;base64,{plot_base64}" alt="Layer Visualizations"/>'

+# Run the app
 if __name__ == '__main__':
     app.run(debug=True)
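Note: stripped of the Flask plumbing, the technique in the new /visualize route is: collect VGG19's Conv2d modules, chain the stylized image through them (skipping the ReLU and pooling layers between them), average each activation over its channel dimension, and tile the means in a 4x4 grid. Below is a minimal standalone sketch of the same idea; it is illustrative rather than part of the commit, and the name feature_map_grid and the max_layers parameter are invented here. VGG19 happens to have exactly 16 conv layers, which is why the 4x4 grid fits.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.models import vgg19, VGG19_Weights

def feature_map_grid(image, max_layers=16):
    # Chain an image through VGG19's conv layers and plot the
    # channel-averaged activation of each as one tile in a 4x4 grid.
    features = vgg19(weights=VGG19_Weights.DEFAULT).features.eval()
    conv_layers = [m for m in features.children() if isinstance(m, nn.Conv2d)]
    fig = plt.figure(figsize=(20, 20))
    x = image
    with torch.no_grad():  # the route relies on .detach() instead; either works
        for i, layer in enumerate(conv_layers[:max_layers]):
            x = layer(x)  # ReLU/pool layers are skipped, as in the route
            fm = x.squeeze(0).mean(dim=0).cpu().numpy()  # mean across channels
            ax = fig.add_subplot(4, 4, i + 1)
            ax.imshow(fm, cmap='viridis')
            ax.axis('off')
            ax.set_title(f"Layer {i + 1}", fontsize=8)
    fig.tight_layout()
    return fig

# e.g.: feature_map_grid(torch.rand(1, 3, 224, 224)).savefig("maps.png")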
mode_style_transfer.py:

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+import io
 from PIL import Image
 import matplotlib.pyplot as plt

@@ -11,6 +11,8 @@ from torchvision.models import vgg19, VGG19_Weights

 from torchvision import models
 import matplotlib.pyplot as plt
+import torchvision.utils as vutils
+

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 torch.set_default_device(device)

@@ -217,42 +219,3 @@ class StyleTransferModel:

         return self.input_img


-class StyleTransferVisualizer(StyleTransferModel):
-    def __init__(self, content_img, style_img):
-        super().__init__(content_img, style_img)
-        self.model_layers = self.get_model_layers()
-
-    def get_model_layers(self):
-        cnn = models.vgg19(pretrained=True).features.to(self.device).eval()
-        model_layers = []
-        i = 0
-        for layer in cnn.children():
-            if isinstance(layer, torch.nn.Conv2d):
-                i += 1
-                model_layers.append((f'conv_{i}', layer))
-        return model_layers
-
-    def visualize_layers(self):
-        fig, axs = plt.subplots(len(self.model_layers), 3, figsize=(15, 20))
-
-        input_img = self.content_img.clone().detach()
-
-        for idx, (name, layer) in enumerate(self.model_layers):
-            input_img = layer(input_img)
-            axs[idx, 0].imshow(self.content_img.squeeze(0).permute(1, 2, 0).cpu().numpy())
-            axs[idx, 0].set_title("Original Image")
-            axs[idx, 0].axis('off')
-
-            axs[idx, 1].imshow(input_img.squeeze(0).permute(1, 2, 0).cpu().detach().numpy())
-            axs[idx, 1].set_title(f"After {name}")
-            axs[idx, 1].axis('off')
-
-            combined = input_img.clone()
-            combined += self.style_img.squeeze(0)
-            axs[idx, 2].imshow(combined.permute(1, 2, 0).cpu().detach().numpy())
-            axs[idx, 2].set_title(f"Combined (Content + Style) after {name}")
-            axs[idx, 2].axis('off')
-
-        plt.tight_layout()
-        plt.show()
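Note: the second hunk above adds import torchvision.utils as vutils, but nothing in this diff uses it. Its usual job in feature-map code is tiling a batch of maps into a single image tensor; a minimal sketch of that (illustrative only, not from the commit):

import torch
import torchvision.utils as vutils

# 16 single-channel feature maps -> one 3-channel image tensor, 4 per row
maps = torch.rand(16, 1, 64, 64)
grid = vutils.make_grid(maps, nrow=4, normalize=True)
print(grid.shape)  # torch.Size([3, 266, 266]) with the default padding of 2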
index.html:

@@ -35,9 +35,8 @@
     <div id="result-container">
         <h2>Resulting Image:</h2>
         <div id="image-container"></div>
-        <a href="{{ url_for('visualize') }}">
-            <button>Visualize Layers</button>
-        </a>
+        <button id="visualize-btn">Visualize Layers</button>
+        <div id="visualization-container"></div> <!-- Container for the visualization -->
     </div>

     <script>
@@ -66,6 +65,16 @@
             resultContainer.style.display = 'block'; // Show the container
         }
     });

+    // JavaScript to handle visualization button click
+    document.getElementById('visualize-btn').addEventListener('click', async function() {
+        const response = await fetch('/visualize');
+
+        if (response.ok) {
+            const visualizationContainer = document.getElementById('visualization-container');
+            visualizationContainer.innerHTML = await response.text(); // Display the plot
+        }
+    });
     </script>
 </body>
 </html>
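Note: because /visualize now answers GET requests with an HTML fragment, it can be smoke-tested outside the page. A hedged example, assuming the dev server started by app.run(debug=True) on its default http://127.0.0.1:5000 and that one style transfer has already run (before that, output_tensor is still None and the route will error):

import requests

resp = requests.get("http://127.0.0.1:5000/visualize")
print(resp.status_code)  # 200 once output_tensor has been populated
print(resp.text[:40])    # expected to start with '<img src="data:image/png;base64,'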
visualize.html (deleted):

@@ -1,25 +0,0 @@
-<!DOCTYPE html>
-<html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <meta name="viewport" content="width=device-width, initial-scale=1.0">
-    <title>Visualize Layers</title>
-</head>
-<body>
-    <h1>Layer Visualizations</h1>
-    <p>Select a layer to view the image before and after processing through that layer:</p>
-
-    {% for i in range(visualizations|length) %}
-    <div>
-        <h2>Layer {{ i + 1 }}</h2>
-        <button onclick="document.getElementById('img_before').src='{{ url_for(show_image, index=i) }}';">
-            View Layer {{ i + 1 }}
-        </button>
-    </div>
-    {% endfor %}
-
-    <h2>Layer Output:</h2>
-    <img id="img_before" src="" alt="Layer Image Output" style="max-width: 100%; height: auto;">
-</body>
-</html>