"""Flask front end for a neural style-transfer model.

Routes:
    ``/``          GET renders the upload form; POST runs style transfer on the
                   uploaded ``content_image`` / ``style_image`` files and returns
                   the result as a Base64-encoded JPEG inside JSON.
    ``/visualize`` POST runs the most recently uploaded content image through
                   the model returned by ``get_style_model_and_losses`` one layer
                   at a time and renders each layer's output.
"""
import base64
import functools
import io
import os

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from flask import Flask, jsonify, redirect, render_template, request, send_file, url_for
from PIL import Image
from torchvision.models import VGG19_Weights, vgg19

from mode_style_transfer import StyleTransferModel, StyleTransferVisualizer, save_image

app = Flask(__name__)

# Image transformation: use smaller images when no GPU is available.
imsize = 512 if torch.cuda.is_available() else 128
loader = transforms.Compose([
    transforms.Resize(imsize),
    transforms.ToTensor(),
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Raw encoded bytes of the most recently uploaded content image. index() replaces
# the contents on every successful upload (so memory stays bounded) and
# visualize() reads the last entry.
# NOTE(review): this is process-global state shared by every client.
visualizations = []

# Per-channel normalisation constants handed to get_style_model_and_losses
# (the standard torchvision ImageNet mean/std), created once instead of per request.
CNN_NORMALIZATION_MEAN = torch.tensor([0.485, 0.456, 0.406]).to(device)
CNN_NORMALIZATION_STD = torch.tensor([0.229, 0.224, 0.225]).to(device)


def image_loader(image_bytes):
    """Decode encoded image bytes into a (1, 3, H, W) float tensor on ``device``.

    The image is converted to RGB first so grayscale, palette or RGBA uploads
    still yield the 3 channels VGG19 expects.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)


def tensor_to_image(tensor):
    """Convert an image tensor (optionally with a leading batch dim of 1) to PIL.

    The tensor is detached, moved to the CPU and clamped to [0, 1] first:
    ``ToPILImage`` scales float input by 255 and casts to uint8, so values
    outside [0, 1] would otherwise wrap around.
    """
    image = tensor.detach().cpu().clamp(0, 1).squeeze(0)
    return transforms.ToPILImage()(image)


def _image_to_bytes(image, fmt='JPEG'):
    """Encode a PIL image in format ``fmt`` and return the encoded bytes."""
    if fmt.upper() == 'JPEG' and image.mode not in ('RGB', 'L'):
        # JPEG cannot store alpha or palettes; Pillow raises on e.g. RGBA.
        image = image.convert('RGB')
    buffer = io.BytesIO()
    image.save(buffer, fmt)
    return buffer.getvalue()


def image_to_base64(image, fmt='JPEG'):
    """Encode a PIL image (JPEG by default) and return it as a Base64 string."""
    return base64.b64encode(_image_to_bytes(image, fmt)).decode('utf-8')


@functools.lru_cache(maxsize=1)
def _get_vgg_features():
    """Load the pretrained VGG19 feature extractor once and reuse it afterwards.

    NOTE(review): assumes get_style_model_and_losses does not mutate ``cnn``
    in a way that should not persist between requests - confirm.
    """
    return vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()


def _feature_map_to_image(feature):
    """Render one layer's output (assumed shape (1, C, H, W)) as a PIL image.

    ``ToPILImage`` cannot handle more than 4 channels, so any map whose channel
    count is not 1 or 3 is averaged over channels and min-max scaled to [0, 1],
    producing a single grayscale image.
    """
    fmap = feature.detach().squeeze(0)
    if fmap.dim() == 3 and fmap.size(0) not in (1, 3):
        fmap = fmap.mean(dim=0, keepdim=True)
        lo, hi = fmap.min(), fmap.max()
        if hi > lo:
            fmap = (fmap - lo) / (hi - lo)
        else:
            # Constant activation: avoid division by zero.
            fmap = torch.zeros_like(fmap)
    return tensor_to_image(fmap)


@app.route('/', methods=['GET', 'POST'])
def index():
    """GET: render the upload form. POST: run style transfer and return JSON.

    On success the POST response is ``{'image': <base64 JPEG>}``. If either
    upload cannot be decoded as an image, ``{'error': ...}`` is returned with
    status 400.
    """
    if request.method != 'POST':
        return render_template('index.html')

    content_image_file = request.files['content_image']
    style_image_file = request.files['style_image']

    # Keep the raw content bytes so /visualize can reload the same image later.
    content_bytes = content_image_file.read()
    try:
        content_image = Image.open(io.BytesIO(content_bytes)).convert('RGB')
        style_image = Image.open(style_image_file).convert('RGB')
    except OSError:  # also covers PIL.UnidentifiedImageError
        return jsonify({'error': 'Uploaded files must be valid images.'}), 400

    # Remember only the latest upload for /visualize.
    visualizations[:] = [content_bytes]

    # Pass the images to the StyleTransferModel
    style_transfer = StyleTransferModel(content_image, style_image)
    output = style_transfer.run_style_transfer()

    # Convert the output tensor to an image, then to Base64 for the JSON response
    output_image = tensor_to_image(output)
    return jsonify({'image': image_to_base64(output_image)})


@app.route('/visualize', methods=['POST'])
def visualize():
    """Render every model layer's output for the last uploaded content image.

    Returns ``{'error': ...}`` with status 400 if no content image has been
    uploaded through ``/`` yet.
    """
    if not visualizations:
        return jsonify({'error': 'Upload a content image first.'}), 400

    cnn = _get_vgg_features()
    content_image = image_loader(visualizations[-1])  # most recent upload
    style_transfer = StyleTransferModel(content_image, content_image)

    model, _, _ = style_transfer.get_style_model_and_losses(
        cnn, CNN_NORMALIZATION_MEAN, CNN_NORMALIZATION_STD,
        content_image, content_image)

    layer_visualizations = []
    # Pure inference: no autograd graph is needed just to display activations.
    with torch.no_grad():
        activation = content_image.clone()
        for layer in model:
            activation = layer(activation)
            # NOTE(review): raw encoded JPEG bytes are handed to the template;
            # confirm visualize.html knows how to embed them.
            layer_visualizations.append(
                _image_to_bytes(_feature_map_to_image(activation)))

    return render_template('visualize.html', visualizations=layer_visualizations)


if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so only enable it when explicitly requested via FLASK_DEBUG=1.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1')