neural_style/neural_style_app/app.py

100 lines
3.2 KiB
Python
Raw Normal View History

2024-08-10 15:25:51 +02:00
from flask import Flask, render_template, request, redirect, url_for, send_file, jsonify
from mode_style_transfer import StyleTransferModel, save_image, StyleTransferVisualizer
from PIL import Image
import io
import torch
from torchvision.models import vgg19, VGG19_Weights
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import base64
# Flask application object; the route decorators below register against it.
app = Flask(__name__)

# Tensors are placed on the GPU when CUDA is available, otherwise on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Size handed to transforms.Resize: 512 with CUDA, 128 without.
imsize = 128 if not torch.cuda.is_available() else 512

# Preprocessing applied by image_loader: resize, then convert to a tensor.
loader = transforms.Compose(
    [
        transforms.Resize(imsize),
        transforms.ToTensor(),
    ]
)

# Module-level list that the /visualize route reads content-image bytes from.
visualizations = []
def image_loader(image_bytes):
    """Decode raw image bytes into a batched float tensor on ``device``.

    Args:
        image_bytes: Encoded image file contents that PIL can open.

    Returns:
        The output of the module-level ``loader`` transform with a leading
        batch dimension added, moved to ``device`` as ``torch.float``.
    """
    with Image.open(io.BytesIO(image_bytes)) as image:
        # Force three colour channels.  PNGs with alpha, palette images and
        # greyscale files would otherwise yield a 4- or 1-channel tensor,
        # which does not match the 3-value normalisation mean/std used with
        # the VGG features in this file.
        rgb_image = image.convert("RGB")
    batch = loader(rgb_image).unsqueeze(0)
    return batch.to(device, torch.float)
def tensor_to_image(tensor):
    """Convert a tensor into a PIL image.

    A detached copy of ``tensor`` is made, dimension 0 is squeezed (it is
    removed only when its size is 1), and the result is passed through
    ``transforms.ToPILImage``.  The input tensor itself is not modified.
    """
    to_pil = transforms.ToPILImage()
    detached_copy = tensor.clone().detach()
    return to_pil(detached_copy.squeeze(0))
def image_to_base64(image):
    """Return ``image`` saved as JPEG and encoded as a Base64 text string.

    ``image`` only needs a ``save(fileobj, format)`` method, as PIL images
    have.  The JPEG bytes are held in memory; nothing is written to disk.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, 'JPEG')
        # getvalue() returns the whole buffer regardless of stream position.
        jpeg_bytes = buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode('utf-8')
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the upload page (GET) or run style transfer on two uploads (POST).

    A POST must carry ``content_image`` and ``style_image`` file fields.  On
    success the stylised result is returned as JSON of the form
    ``{'image': <Base64-encoded JPEG>}``.  Uploads that PIL cannot decode
    produce a 400 JSON response with an ``error`` message.
    """
    if request.method != 'POST':
        return render_template('index.html')

    content_image_file = request.files['content_image']
    style_image_file = request.files['style_image']

    # Read the content upload once so the same bytes can be decoded here and
    # kept for the /visualize route, which reads them from ``visualizations``.
    content_bytes = content_image_file.read()

    try:
        # convert("RGB") forces a full decode (so corrupt files fail here
        # rather than deep inside the model) and drops alpha, palette and
        # greyscale modes.
        # NOTE(review): assumes StyleTransferModel accepts RGB PIL images,
        # as it already received PIL images before this change -- confirm.
        content_image = Image.open(io.BytesIO(content_bytes)).convert("RGB")
        style_image = Image.open(style_image_file).convert("RGB")
    except OSError:
        # PIL's UnidentifiedImageError is a subclass of OSError.
        return jsonify({'error': 'Uploaded files must be valid images.'}), 400

    # Keep only the most recent content image.  Replacing the list contents
    # in place bounds memory use and keeps index 0 and index -1 identical.
    # The list is process-wide, so it is shared by all clients.
    visualizations[:] = [content_bytes]

    # Pass the images to the StyleTransferModel
    style_transfer = StyleTransferModel(content_image, style_image)
    output = style_transfer.run_style_transfer()

    # Convert the output tensor to an image, then to Base64 for the JSON body.
    output_image = tensor_to_image(output)
    return jsonify({'image': image_to_base64(output_image)})
def _as_displayable(feature_map):
    """Reduce a layer output to a 1- or 3-channel tensor with values in [0, 1].

    ``transforms.ToPILImage`` rejects tensors with more than four channels,
    and JPEG cannot store two- or four-channel images, so any output whose
    channel count is not 1 or 3 is averaged over the channel dimension and
    min-max rescaled.  One- and three-channel outputs are only clamped.

    Assumes ``feature_map`` is laid out as (batch, channels, height, width),
    which is what ``image_loader`` produces -- TODO confirm for every layer
    of the style model.
    """
    if feature_map.shape[1] not in (1, 3):
        feature_map = feature_map.mean(dim=1, keepdim=True)
        low = feature_map.min()
        high = feature_map.max()
        # clamp_min avoids division by zero for a constant feature map.
        feature_map = (feature_map - low) / (high - low).clamp_min(1e-8)
    return feature_map.clamp(0, 1)


@app.route('/visualize', methods=['POST'])
def visualize():
    """Render the output of every style-model layer for the latest content image.

    The content image bytes are taken from the end of the module-level
    ``visualizations`` list.  If the list is empty a 400 JSON error is
    returned.  Otherwise the image is pushed through the model one layer at a
    time and each intermediate output is saved as JPEG bytes, which are
    passed to ``visualize.html`` as ``visualizations``.
    """
    if not visualizations:
        return jsonify({'error': 'No content image is available to visualize.'}), 400

    cnn = vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    # The most recently saved content image.
    content_image = image_loader(visualizations[-1])

    # The content image doubles as the style image: only the layer stack is
    # needed here, not a meaningful style loss.
    # NOTE(review): tensors are passed to StyleTransferModel here, whereas the
    # index route passes PIL images -- confirm the constructor accepts both.
    style_transfer = StyleTransferModel(content_image, content_image)
    model, _, _ = style_transfer.get_style_model_and_losses(
        cnn, cnn_normalization_mean, cnn_normalization_std,
        content_image, content_image)

    layer_visualizations = []
    # Work on a copy so in-place layers cannot alter ``content_image``.
    activation = content_image.clone()

    # No gradients are needed for visualisation, so skip building the
    # autograd graph for the whole forward pass.
    with torch.no_grad():
        for layer in model:
            activation = layer(activation)
            snapshot = tensor_to_image(_as_displayable(activation))
            img_io = io.BytesIO()
            snapshot.save(img_io, 'JPEG')
            layer_visualizations.append(img_io.getvalue())  # JPEG bytes

    # NOTE(review): the template receives raw JPEG bytes; verify that
    # visualize.html encodes them (e.g. Base64) before embedding.
    return render_template('visualize.html', visualizations=layer_visualizations)
if __name__ == '__main__':
    # The Werkzeug interactive debugger lets anyone who can reach the server
    # execute arbitrary Python, so debug mode is opt-in through the
    # FLASK_DEBUG environment variable instead of being always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')
    app.run(debug=debug_enabled)