import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image

import torchvision.transforms as transforms
from torchvision.models import vgg19, VGG19_Weights

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)


def image_loader(image):
    # use a small size if no GPU is available
    imsize = 512 if torch.cuda.is_available() else 128

    loader = transforms.Compose([
        transforms.Resize(imsize),   # scale the imported image
        transforms.ToTensor()])      # transform it into a torch tensor

    # fake batch dimension required to fit the network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)


def save_image(tensor, path):
    image = tensor.clone().detach()
    image = image.squeeze(0)  # remove the fake batch dimension
    image = transforms.ToPILImage()(image)
    image.save(path)


class ContentLoss(nn.Module):

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
        # will throw an error.
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input


def gram_matrix(input):
    # a = batch size (=1)
    # b = number of feature maps
    # (c, d) = dimensions of a feature map (N = c * d)
    a, b, c, d = input.size()

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map.
    return G.div(a * b * c * d)


class StyleLoss(nn.Module):

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input


# create a module to normalize the input image so we can easily put it in an
# ``nn.Sequential``
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, img):
        # normalize ``img``
        return (img - self.mean) / self.std


# desired depth layers to compute style/content losses:
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']


def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    # normalization module
    normalization = Normalization(normalization_mean, normalization_std)

    # just in order to have iterable access to the content/style losses
    content_losses = []
    style_losses = []

    # assuming that ``cnn`` is a ``nn.Sequential``, we make a new ``nn.Sequential``
    # to put in modules that are supposed to be activated sequentially
    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place version doesn't play very nicely with the ``ContentLoss``
            # and ``StyleLoss`` we insert below. So we replace with out-of-place
            # ones here.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # add content loss:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # add style loss:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # now we trim off the layers after the last content and style losses
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses


def get_input_optimizer(input_img):
    # this line shows that the input is a parameter that requires a gradient
    optimizer = optim.LBFGS([input_img])
    return optimizer


class StyleTransferModel:
    def __init__(self, content_img, style_img, num_steps=300,
                 style_weight=1000000, content_weight=1):
        self.content_img = content_img
        # resize the style image (PIL) to the content image's size before loading
        self.style_img = style_img.resize(content_img.size)
        self.style_img = image_loader(self.style_img)
        self.content_img = image_loader(self.content_img)
        self.input_img = self.content_img.clone()
        self.num_steps = num_steps
        self.style_weight = style_weight
        self.content_weight = content_weight
        self.cnn = vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
        self.cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        self.cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

    def run_style_transfer(self):
        print('Building the style transfer model..')
        model, style_losses, content_losses = get_style_model_and_losses(
            self.cnn, self.cnn_normalization_mean, self.cnn_normalization_std,
            self.style_img, self.content_img)

        # we want to optimize the input image, not the model parameters, so
        # update the requires_grad fields accordingly
        self.input_img.requires_grad_(True)
        model.eval()
        model.requires_grad_(False)

        optimizer = get_input_optimizer(self.input_img)

        print('Optimizing..')
        run = [0]
        while run[0] <= self.num_steps:

            def closure():
                # correct the values of the updated input image
                with torch.no_grad():
                    self.input_img.clamp_(0, 1)

                optimizer.zero_grad()
                model(self.input_img)
                style_score = 0
                content_score = 0

                for sl in style_losses:
                    style_score += sl.loss
                for cl in content_losses:
                    content_score += cl.loss

                style_score *= self.style_weight
                content_score *= self.content_weight

                loss = style_score + content_score
                loss.backward()

                run[0] += 1
                if run[0] % 50 == 0:
                    print(f"run {run[0]}:")
                    print(f'Style Loss : {style_score.item():4f} Content Loss: {content_score.item():4f}')
                    print()

                return style_score + content_score

            optimizer.step(closure)

        # a last correction of the output values
        with torch.no_grad():
            self.input_img.clamp_(0, 1)

        return self.input_img
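

# --- Example usage: a minimal sketch, not part of the pipeline above. ---
# The file paths ("content.jpg", "style.jpg", "output.png") are hypothetical
# placeholders; substitute your own images. StyleTransferModel expects PIL
# images, resizes the style image to the content image's size, runs LBFGS
# optimization on a copy of the content image, and returns the optimized
# tensor, which save_image() writes back to disk.
if __name__ == "__main__":
    content = Image.open("content.jpg").convert("RGB")  # hypothetical path
    style = Image.open("style.jpg").convert("RGB")      # hypothetical path

    transfer = StyleTransferModel(content, style, num_steps=300,
                                  style_weight=1e6, content_weight=1)
    output = transfer.run_style_transfer()
    save_image(output, "output.png")                    # hypothetical path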