Computer_Vision/Chapter11/neural_style_transfer.ipynb
2024-02-13 03:34:51 +01:00

955 KiB
Raw Blame History

Open In Colab

!pip install torch_snippets
from torch_snippets import *
from torchvision import transforms as T
from torch.nn import functional as F
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
Collecting torch_snippets
  Downloading https://files.pythonhosted.org/packages/50/13/302867fc4189c33290179a92e745cbfe6132c3120f5cbad245026a7eccf9/torch_snippets-0.234-py3-none-any.whl
Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (7.0.0)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (3.2.2)
Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (0.3.2)
Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (1.1.3)
Collecting opencv-python-headless
[?25l  Downloading https://files.pythonhosted.org/packages/08/e9/57d869561389884136be65a2d1bc038fe50171e2ba348fda269a4aab8032/opencv_python_headless-4.4.0.46-cp36-cp36m-manylinux2014_x86_64.whl (36.7MB)
     |████████████████████████████████| 36.7MB 78kB/s 
[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (4.41.1)
Collecting loguru
[?25l  Downloading https://files.pythonhosted.org/packages/6d/48/0a7d5847e3de329f1d0134baf707b689700b53bd3066a5a8cfd94b3c9fc8/loguru-0.5.3-py3-none-any.whl (57kB)
     |████████████████████████████████| 61kB 8.3MB/s 
[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (1.18.5)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (1.2.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (2.4.7)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (2.8.1)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (0.10.0)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->torch_snippets) (2018.9)
Collecting aiocontextvars>=0.2.0; python_version < "3.7"
  Downloading https://files.pythonhosted.org/packages/db/c1/7a723e8d988de0a2e623927396e54b6831b68cb80dce468c945b849a9385/aiocontextvars-0.2.2-py2.py3-none-any.whl
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.1->matplotlib->torch_snippets) (1.15.0)
Collecting contextvars==2.4; python_version < "3.7"
  Downloading https://files.pythonhosted.org/packages/83/96/55b82d9f13763be9d672622e1b8106c85acb83edd7cc2fa5bc67cd9877e9/contextvars-2.4.tar.gz
Collecting immutables>=0.9
[?25l  Downloading https://files.pythonhosted.org/packages/99/e0/ea6fd4697120327d26773b5a84853f897a68e33d3f9376b00a8ff96e4f63/immutables-0.14-cp36-cp36m-manylinux1_x86_64.whl (98kB)
     |████████████████████████████████| 102kB 9.3MB/s 
[?25hBuilding wheels for collected packages: contextvars
  Building wheel for contextvars (setup.py) ... [?25l[?25hdone
  Created wheel for contextvars: filename=contextvars-2.4-cp36-none-any.whl size=7666 sha256=a115fa47d9d1589770a9e44a93d5eb71890fbda661cd92a297d791aa6b61b384
  Stored in directory: /root/.cache/pip/wheels/a5/7d/68/1ebae2668bda2228686e3c1cf16f2c2384cea6e9334ad5f6de
Successfully built contextvars
Installing collected packages: opencv-python-headless, immutables, contextvars, aiocontextvars, loguru, torch-snippets
Successfully installed aiocontextvars-0.2.2 contextvars-2.4 immutables-0.14 loguru-0.5.3 opencv-python-headless-4.4.0.46 torch-snippets-0.234
from torchvision.models import vgg19
# Map a PIL image into the input space used for VGG here: ImageNet mean/std
# normalisation followed by a x255 scaling (undone again by `postprocess`).
preprocess = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # Out-of-place multiply: a transform should never mutate its input.
    T.Lambda(lambda x: x.mul(255)),
])
# Inverse of `preprocess` for display: undo the x255 scaling, then undo the
# ImageNet normalisation ((y + m/s) * s == y*s + m), then clamp to the
# displayable [0, 1] range.
postprocess = T.Compose([
    # Out-of-place: an in-place `mul_` here would rescale the caller's tensor
    # (e.g. a view of the image being optimised) as a hidden side effect.
    T.Lambda(lambda x: x.mul(1./255)),
    T.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225]),
    # Keep values in [0, 1] so image display does not have to clip them.
    T.Lambda(lambda x: x.clamp(0, 1)),
])
class GramMatrix(nn.Module):
    """Channel-by-channel Gram matrix of a batch of feature maps.

    Entry (i, j) of each matrix is the inner product of channel i and
    channel j over all spatial positions, divided by the number of spatial
    positions (h * w). Used as the style representation of a layer.
    """

    def forward(self, input):
        """Return a ``(b, c, c)`` tensor of Gram matrices.

        Args:
            input: 4-D feature maps unpacked as (batch, channels, height,
                width). They do not need to be contiguous in memory.
        """
        b, c, h, w = input.size()
        # reshape (not view) so non-contiguous feature maps are accepted too.
        feat = input.reshape(b, c, h * w)
        gram = feat @ feat.transpose(1, 2)
        return gram / (h * w)
class GramMSELoss(nn.Module):
    """Style loss: mean squared error between Gram matrices.

    The Gram matrix of ``input`` is compared against ``target``, which is
    expected to already be a Gram matrix (the module keeps no state).
    """

    def forward(self, input, target):
        """Return ``mse_loss(GramMatrix()(input), target)``."""
        input_gram = GramMatrix()(input)
        loss = F.mse_loss(input_gram, target)
        return loss
class vgg19_modified(nn.Module):
    """VGG-19 convolutional trunk that can return intermediate activations.

    Wraps the ``features`` part of torchvision's pretrained VGG-19 as an
    ``nn.ModuleList`` so the output of any layer index can be tapped. The
    trunk is put in eval mode and its weights are frozen: it is used only as
    a fixed feature extractor, so backward passes should not compute (and
    accumulate) gradients for the VGG weights. Gradients still flow through
    it to an input that requires grad.
    """

    def __init__(self):
        super().__init__()
        features = list(vgg19(pretrained = True).features)
        self.features = nn.ModuleList(features).eval()
        for param in self.features.parameters():
            param.requires_grad_(False)

    def forward(self, x, layers=None):
        """Run ``x`` through the trunk, optionally tapping intermediate layers.

        Args:
            x: input batch for the first layer (presumably a (batch, 3, H, W)
                image tensor produced by ``preprocess`` -- confirm at call site).
            layers: indices into ``self.features`` whose outputs to return.
                The caller's order is preserved and repeated indices are
                allowed. ``None`` or an empty sequence means "no taps".

        Returns:
            A list with one activation per entry of ``layers``, in the same
            order, or the final trunk output when no layers are requested.

        Raises:
            KeyError: if an entry of ``layers`` is not a valid layer index.
        """
        layers = [] if layers is None else list(layers)
        if not layers:
            # Full pass. (A `layers is not []` test is always True, so it
            # cannot be used to detect "no layers requested".)
            for layer in self.features:
                x = layer(x)
            return x

        wanted = set(layers)
        deepest = max(wanted)
        taps = {}
        for ix, layer in enumerate(self.features):
            x = layer(x)
            if ix in wanted:
                taps[ix] = x
            if ix == deepest:
                # Nothing deeper was requested; skip the remaining layers.
                break
        # Look outputs up by layer index so they follow the caller's order
        # exactly (an argsort of `layers` is only its own inverse by luck).
        return [taps[ix] for ix in layers]
# Build the VGG-19 feature extractor (downloads pretrained weights on first
# use) and move it to the selected device.
vgg = vgg19_modified().to(device)
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))
# Fetch the content image (60.jpg) and the style image (style_image.png).
!wget https://www.dropbox.com/s/z1y0fy2r6z6m6py/60.jpg
!wget https://www.dropbox.com/s/1svdliljyo0a98v/style_image.png
--2020-11-04 12:49:49--  https://www.dropbox.com/s/z1y0fy2r6z6m6py/60.jpg
Resolving www.dropbox.com (www.dropbox.com)... 162.125.67.1, 2620:100:6023:1::a27d:4301
Connecting to www.dropbox.com (www.dropbox.com)|162.125.67.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/z1y0fy2r6z6m6py/60.jpg [following]
--2020-11-04 12:49:49--  https://www.dropbox.com/s/raw/z1y0fy2r6z6m6py/60.jpg
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucafd289f0976768196009ab2c29.dl.dropboxusercontent.com/cd/0/inline/BCiua8nqxIBZbZnuplfwgaPjkY9BijOWW5zdYXeRmEH0e3l1ZNCyDelMxefD4Uy280ncvactSbQsFOkPGuutIXPcbrRMzWGF4rmopfH0K-xqIA/file# [following]
--2020-11-04 12:49:50--  https://ucafd289f0976768196009ab2c29.dl.dropboxusercontent.com/cd/0/inline/BCiua8nqxIBZbZnuplfwgaPjkY9BijOWW5zdYXeRmEH0e3l1ZNCyDelMxefD4Uy280ncvactSbQsFOkPGuutIXPcbrRMzWGF4rmopfH0K-xqIA/file
Resolving ucafd289f0976768196009ab2c29.dl.dropboxusercontent.com (ucafd289f0976768196009ab2c29.dl.dropboxusercontent.com)... 162.125.65.15, 2620:100:6021:15::a27d:410f
Connecting to ucafd289f0976768196009ab2c29.dl.dropboxusercontent.com (ucafd289f0976768196009ab2c29.dl.dropboxusercontent.com)|162.125.65.15|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3650056 (3.5M) [image/jpeg]
Saving to: 60.jpg

60.jpg              100%[===================>]   3.48M  9.61MB/s    in 0.4s    

2020-11-04 12:49:51 (9.61 MB/s) - 60.jpg saved [3650056/3650056]

--2020-11-04 12:49:51--  https://www.dropbox.com/s/1svdliljyo0a98v/style_image.png
Resolving www.dropbox.com (www.dropbox.com)... 162.125.65.1, 2620:100:6023:1::a27d:4301
Connecting to www.dropbox.com (www.dropbox.com)|162.125.65.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/1svdliljyo0a98v/style_image.png [following]
--2020-11-04 12:49:51--  https://www.dropbox.com/s/raw/1svdliljyo0a98v/style_image.png
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc733cdedba7ffc778a8467d938f.dl.dropboxusercontent.com/cd/0/inline/BCjyqFGrcJXgZVt3ZfCBdahhbcd90Gr7yEHg9J-9-1qm3cjA2TXb029HRoyAmUMImB56Xvo7M8YcF_JtzMePyDGbBGGySy7X1WdQ5eCHAcUGhI_okKtAUfTqs2LR5n3ErNk/file# [following]
--2020-11-04 12:49:52--  https://uc733cdedba7ffc778a8467d938f.dl.dropboxusercontent.com/cd/0/inline/BCjyqFGrcJXgZVt3ZfCBdahhbcd90Gr7yEHg9J-9-1qm3cjA2TXb029HRoyAmUMImB56Xvo7M8YcF_JtzMePyDGbBGGySy7X1WdQ5eCHAcUGhI_okKtAUfTqs2LR5n3ErNk/file
Resolving uc733cdedba7ffc778a8467d938f.dl.dropboxusercontent.com (uc733cdedba7ffc778a8467d938f.dl.dropboxusercontent.com)... 162.125.65.15, 2620:100:6023:15::a27d:430f
Connecting to uc733cdedba7ffc778a8467d938f.dl.dropboxusercontent.com (uc733cdedba7ffc778a8467d938f.dl.dropboxusercontent.com)|162.125.65.15|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 618980 (604K) [image/png]
Saving to: style_image.png

style_image.png     100%[===================>] 604.47K  --.-KB/s    in 0.1s    

2020-11-04 12:49:52 (4.91 MB/s) - style_image.png saved [618980/618980]

# Load the style and content images at a common 512x512 size.
imgs = []
for path in ['style_image.png', '60.jpg']:
    # Convert to RGB *before* resizing: PIL always resizes palette ('P')
    # images with nearest-neighbour. The `with` closes the file handle.
    with Image.open(path) as src:
        imgs.append(src.convert('RGB').resize((512, 512)))
# Map both into VGG's input space and add a leading batch dimension.
style_image, content_image = [preprocess(img).to(device)[None] for img in imgs]

# The image being optimised starts as a detached copy of the content image.
opt_img = content_image.detach().clone().requires_grad_(True)

# Layer indices into vgg.features (presumably conv1_1, conv2_1, conv3_1,
# conv4_1, conv5_1 and conv4_2 of torchvision's VGG-19 -- confirm).
style_layers = [0, 5, 10, 19, 28]
content_layers = [21]
loss_layers = style_layers + content_layers

# One loss per tapped layer. The repeated instances are shared, which is fine
# because neither loss module keeps state between calls.
loss_fns = [GramMSELoss()] * len(style_layers) + [nn.MSELoss()] * len(content_layers)
loss_fns = [loss_fn.to(device) for loss_fn in loss_fns]

# Style terms are scaled by 1/n**2; n is presumably each style layer's
# channel count, since the Gram matrix there is n x n -- confirm.
style_weights = [1000/n**2 for n in [64,128,256,512,512]]
content_weights = [1]
weights = style_weights + content_weights

# Targets are constants: compute them without recording an autograd graph.
with torch.no_grad():
    style_targets = [GramMatrix()(A).detach() for A in vgg(style_image, style_layers)]
    content_targets = [A.detach() for A in vgg(content_image, content_layers)]
targets = style_targets + content_targets
max_iters = 500
optimizer = optim.LBFGS([opt_img])
log = Report(max_iters)
iters = 0


def closure():
    """One L-BFGS function evaluation.

    Bumps the global evaluation counter, recomputes the weighted sum of
    style and content losses for ``opt_img``, backpropagates it, logs it
    and returns it to the optimiser.
    """
    global iters
    iters += 1
    optimizer.zero_grad()
    feats = vgg(opt_img, loss_layers)
    loss = sum([
        weights[k] * loss_fns[k](feat, targets[k])
        for k, feat in enumerate(feats)
    ])
    loss.backward()
    log.record(pos=iters, loss=loss, end='\r')
    return loss


# optimizer.step() calls the closure several times per step, so `iters` is
# only checked between steps and can end a little past max_iters.
while iters < max_iters:
    optimizer.step(closure)
EPOCH: 502.000	loss: 9652914.000	(247.12s - -0.98s remaining)
log.plot(log=True)
# Convert the optimised image back to a displayable (H, W, C) image.
with torch.no_grad():
    # Hand `postprocess` a copy so it can never modify `opt_img` in place,
    # and clamp to [0, 1] so display does not have to clip (and warn).
    out_img = postprocess(opt_img[0].clone()).permute(1, 2, 0).clamp(0, 1)
show(out_img)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).