Computer_Vision/Chapter12/Face_generation_using_DCGAN.ipynb

!wget https://www.dropbox.com/s/rbajpdlh7efkdo1/male_female_face_images.zip
!unzip -q male_female_face_images.zip
!pip install -q --upgrade torch_snippets
from torch_snippets import *
import torchvision
from torchvision import transforms
import torchvision.utils as vutils
import cv2, numpy as np, pandas as pd
device = "cuda" if torch.cuda.is_available() else "cpu"
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
!mkdir -p cropped_faces
images = Glob('/content/females/*.jpg') + Glob('/content/males/*.jpg')
for i in range(len(images)):
    img = read(images[i], 1)  # torch_snippets.read returns the image as RGB
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # scaleFactor=1.3, minNeighbors=5: standard Haar-cascade detection settings
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)
    for (x,y,w,h) in faces:
        img2 = img[y:(y+h),x:(x+w),:]
        # write inside the loop so images with no detected face are skipped
        # (the original wrote the stale crop from the previous image in that case)
        cv2.imwrite('cropped_faces/'+str(i)+'.jpg', cv2.cvtColor(img2, cv2.COLOR_RGB2BGR))
2020-11-06 18:40:53.399 | INFO     | torch_snippets.loader:Glob:181 - 14688 files found at /content/females/*.jpg
2020-11-06 18:40:53.430 | INFO     | torch_snippets.loader:Glob:181 - 13948 files found at /content/males/*.jpg
transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.CenterCrop(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
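Normalizing with a per-channel mean and standard deviation of 0.5 maps pixel values from [0, 1] to [-1, 1], which matches the Tanh output range of the generator defined below. A quick numeric check (a minimal sketch, not part of the original run):
# verify the normalization range: (x - 0.5) / 0.5 sends 0 -> -1 and 1 -> 1
x = torch.rand(3, 64, 64)  # dummy image with values in [0, 1]
y = transforms.Normalize((0.5,)*3, (0.5,)*3)(x)
print(y.min().item(), y.max().item())  # values now lie within [-1, 1]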
class Faces(Dataset):
    def __init__(self, folder):
        super().__init__()
        self.folder = folder
        self.images = sorted(Glob(folder))
    def __len__(self):
        return len(self.images)
    def __getitem__(self, ix):
        image_path = self.images[ix]
        image = Image.open(image_path)
        image = transform(image)
        return image
ds = Faces(folder='cropped_faces/')
2020-11-06 18:43:23.108 | INFO     | torch_snippets.loader:Glob:181 - 28636 files found at cropped_faces/
dataloader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=8)
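As a sanity check on the pipeline (not in the original run), one batch from the loader should have shape (64, 3, 64, 64):
batch = next(iter(dataloader))
print(batch.shape)  # expected: torch.Size([64, 3, 64, 64])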
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
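This initialization follows the DCGAN paper: conv weights drawn from N(0, 0.02), batch-norm scales from N(1, 0.02) with zero bias. A quick probe (a sketch, not in the original notebook):
probe = nn.Conv2d(3, 64, 4)
weights_init(probe)  # 'Conv2d' matches the 'Conv' branch above
print(probe.weight.std().item())  # should be close to 0.02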
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,64,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64,64*2,4,2,1,bias=False),
            nn.BatchNorm2d(64*2),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*2,64*4,4,2,1,bias=False),
            nn.BatchNorm2d(64*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*4,64*8,4,2,1,bias=False),
            nn.BatchNorm2d(64*8),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64*8,1,4,1,0,bias=False),
            nn.Sigmoid()
        )
        self.apply(weights_init)
    def forward(self, input): return self.model(input)
!pip install torch_summary
from torchsummary import summary
discriminator = Discriminator().to(device)
summary(discriminator,torch.zeros(1,3,64,64));
Successfully installed torch-summary-1.4.3
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 1, 1, 1]             --
|    └─Conv2d: 2-1                       [-1, 64, 32, 32]          3,072
|    └─LeakyReLU: 2-2                    [-1, 64, 32, 32]          --
|    └─Conv2d: 2-3                       [-1, 128, 16, 16]         131,072
|    └─BatchNorm2d: 2-4                  [-1, 128, 16, 16]         256
|    └─LeakyReLU: 2-5                    [-1, 128, 16, 16]         --
|    └─Conv2d: 2-6                       [-1, 256, 8, 8]           524,288
|    └─BatchNorm2d: 2-7                  [-1, 256, 8, 8]           512
|    └─LeakyReLU: 2-8                    [-1, 256, 8, 8]           --
|    └─Conv2d: 2-9                       [-1, 512, 4, 4]           2,097,152
|    └─BatchNorm2d: 2-10                 [-1, 512, 4, 4]           1,024
|    └─LeakyReLU: 2-11                   [-1, 512, 4, 4]           --
|    └─Conv2d: 2-12                      [-1, 1, 1, 1]             8,192
|    └─Sigmoid: 2-13                     [-1, 1, 1, 1]             --
==========================================================================================
Total params: 2,765,568
Trainable params: 2,765,568
Non-trainable params: 0
Total mult-adds (M): 106.58
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 1.38
Params size (MB): 10.55
Estimated Total Size (MB): 11.97
==========================================================================================
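The summary shows the discriminator collapsing a 64×64 RGB image into a single probability per sample. A direct forward pass confirms the shapes (a quick check, not part of the original run):
with torch.no_grad():
    out = discriminator(torch.randn(8, 3, 64, 64).to(device))
print(out.shape)  # expected: torch.Size([8, 1, 1, 1])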
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.model = nn.Sequential(
nn.ConvTranspose2d(100,64*8,4,1,0,bias=False),
            nn.BatchNorm2d(64*8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8,64*4,4,2,1,bias=False),
            nn.BatchNorm2d(64*4),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4,64*2,4,2,1,bias=False),
            nn.BatchNorm2d(64*2),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2,64,4,2,1,bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64,3,4,2,1,bias=False),
            nn.Tanh()
        )
        self.apply(weights_init)
    def forward(self,input): return self.model(input)
generator = Generator().to(device)
summary(generator,torch.zeros(1,100,1,1));
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 3, 64, 64]           --
|    └─ConvTranspose2d: 2-1              [-1, 512, 4, 4]           819,200
|    └─BatchNorm2d: 2-2                  [-1, 512, 4, 4]           1,024
|    └─ReLU: 2-3                         [-1, 512, 4, 4]           --
|    └─ConvTranspose2d: 2-4              [-1, 256, 8, 8]           2,097,152
|    └─BatchNorm2d: 2-5                  [-1, 256, 8, 8]           512
|    └─ReLU: 2-6                         [-1, 256, 8, 8]           --
|    └─ConvTranspose2d: 2-7              [-1, 128, 16, 16]         524,288
|    └─BatchNorm2d: 2-8                  [-1, 128, 16, 16]         256
|    └─ReLU: 2-9                         [-1, 128, 16, 16]         --
|    └─ConvTranspose2d: 2-10             [-1, 64, 32, 32]          131,072
|    └─BatchNorm2d: 2-11                 [-1, 64, 32, 32]          128
|    └─ReLU: 2-12                        [-1, 64, 32, 32]          --
|    └─ConvTranspose2d: 2-13             [-1, 3, 64, 64]           3,072
|    └─Tanh: 2-14                        [-1, 3, 64, 64]           --
==========================================================================================
Total params: 3,576,704
Trainable params: 3,576,704
Non-trainable params: 0
Total mult-adds (M): 431.92
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 1.97
Params size (MB): 13.64
Estimated Total Size (MB): 15.61
==========================================================================================
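Conversely, a 100-dimensional noise vector maps up to a 64×64 RGB image. A direct check (a minimal sketch, not in the original run):
with torch.no_grad():
    fake = generator(torch.randn(8, 100, 1, 1).to(device))
print(fake.shape)  # expected: torch.Size([8, 3, 64, 64])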
def discriminator_train_step(real_data, fake_data):
    d_optimizer.zero_grad()
    # real images should be classified as 1
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real.squeeze(), torch.ones(len(real_data)).to(device))
    error_real.backward()
    # generated images should be classified as 0
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake.squeeze(), torch.zeros(len(fake_data)).to(device))
    error_fake.backward()
    d_optimizer.step()
    return error_real + error_fake

def generator_train_step(fake_data):
    g_optimizer.zero_grad()
    # the generator is rewarded when the discriminator labels its output as real
    prediction = discriminator(fake_data)
    # the target length must come from fake_data; the original referenced the
    # global real_data, which only worked because the two batches matched in size
    error = loss(prediction.squeeze(), torch.ones(len(fake_data)).to(device))
    error.backward()
    g_optimizer.step()
    return error
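With BCE against an all-ones target, the generator loss reduces to -log(D(G(z))), the non-saturating objective from the original GAN paper. A quick numeric check (a sketch, not in the original notebook):
import math
p = torch.tensor([0.9, 0.1])               # discriminator outputs for two fakes
t = torch.ones(2)                           # generator's target: "real"
print(nn.BCELoss()(p, t).item())            # mean of -log(0.9) and -log(0.1)
print(-(math.log(0.9) + math.log(0.1))/2)   # same value, computed by hand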
discriminator = Discriminator().to(device)
generator = Generator().to(device)
loss = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
log = Report(25)
for epoch in range(25):
    N = len(dataloader)
    for i, images in enumerate(dataloader):
        real_data = images.to(device)
# detach so the discriminator step does not backpropagate into the generator
        fake_data = generator(torch.randn(len(real_data), 100, 1, 1).to(device)).detach()
        d_loss = discriminator_train_step(real_data, fake_data)
        # fresh fake images for the generator step, this time kept in the graph
        fake_data = generator(torch.randn(len(real_data), 100, 1, 1).to(device))
        g_loss = generator_train_step(fake_data)
        log.record(epoch+(1+i)/N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
    log.report_avgs(epoch+1)
log.plot_epochs(['d_loss','g_loss'])
EPOCH: 1.000	d_loss: 0.023	g_loss: 36.309	(45.01s - 1080.18s remaining)
EPOCH: 2.000	d_loss: 1.288	g_loss: 7.765	(91.21s - 1048.87s remaining)
EPOCH: 3.000	d_loss: 0.725	g_loss: 4.605	(138.52s - 1015.83s remaining)
EPOCH: 4.000	d_loss: 0.594	g_loss: 4.297	(187.17s - 982.62s remaining)
EPOCH: 5.000	d_loss: 0.575	g_loss: 4.024	(236.63s - 946.51s remaining)
EPOCH: 6.000	d_loss: 0.605	g_loss: 3.599	(287.06s - 909.01s remaining)
EPOCH: 7.000	d_loss: 0.614	g_loss: 3.401	(337.72s - 868.43s remaining)
EPOCH: 8.000	d_loss: 0.646	g_loss: 3.144	(388.28s - 825.09s remaining)
EPOCH: 9.000	d_loss: 0.667	g_loss: 3.028	(438.73s - 779.96s remaining)
EPOCH: 10.000	d_loss: 0.620	g_loss: 2.907	(489.32s - 733.98s remaining)
EPOCH: 11.000	d_loss: 0.661	g_loss: 2.954	(539.77s - 686.98s remaining)
EPOCH: 12.000	d_loss: 0.618	g_loss: 2.899	(590.34s - 639.54s remaining)
EPOCH: 13.000	d_loss: 0.573	g_loss: 2.901	(640.96s - 591.65s remaining)
EPOCH: 14.000	d_loss: 0.613	g_loss: 2.928	(691.42s - 543.26s remaining)
EPOCH: 15.000	d_loss: 0.566	g_loss: 3.032	(741.80s - 494.53s remaining)
EPOCH: 16.000	d_loss: 0.569	g_loss: 3.048	(792.24s - 445.64s remaining)
EPOCH: 17.000	d_loss: 0.496	g_loss: 3.029	(842.80s - 396.61s remaining)
EPOCH: 18.000	d_loss: 0.555	g_loss: 3.008	(893.66s - 347.53s remaining)
EPOCH: 19.000	d_loss: 0.497	g_loss: 3.129	(944.73s - 298.34s remaining)
EPOCH: 20.000	d_loss: 0.554	g_loss: 3.072	(995.37s - 248.84s remaining)
EPOCH: 21.000	d_loss: 0.410	g_loss: 3.251	(1045.87s - 199.21s remaining)
EPOCH: 22.000	d_loss: 0.480	g_loss: 3.209	(1096.45s - 149.52s remaining)
EPOCH: 23.000	d_loss: 0.470	g_loss: 3.323	(1146.98s - 99.74s remaining)
EPOCH: 24.000	d_loss: 0.508	g_loss: 3.262	(1197.44s - 49.89s remaining)
EPOCH: 25.000	d_loss: 0.462	g_loss: 3.340	(1248.02s - 0.00s remaining)
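Before sampling, it is worth persisting the trained weights so generation can be repeated without retraining (a minimal sketch, not in the original notebook; the file names are arbitrary):
torch.save(generator.state_dict(), 'generator.pth')        # hypothetical filename
torch.save(discriminator.state_dict(), 'discriminator.pth')
# to restore later:
# generator.load_state_dict(torch.load('generator.pth', map_location=device))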
generator.eval()
with torch.no_grad():
    noise = torch.randn(64, 100, 1, 1, device=device)
    sample_images = generator(noise).cpu()
# normalize=True rescales the Tanh output from [-1,1] to [0,1] for display
grid = vutils.make_grid(sample_images, nrow=8, normalize=True)
show(grid.permute(1,2,0), sz=10, title='Generated images')