Computer_Vision/Chapter11/Generating_deep_fakes.ipynb
2024-02-13 03:34:51 +01:00

1004 KiB
Raw Permalink Blame History

Open In Colab

import os
if not os.path.exists('Faceswap-Deepfake-Pytorch'):
    !wget -q https://www.dropbox.com/s/5ji7jl7httso9ny/person_images.zip
    !wget -q https://raw.githubusercontent.com/sizhky/deep-fake-util/main/random_warp.py
    !unzip -q person_images.zip
!pip install -q torch_snippets torch_summary
from torch_snippets import *
from random_warp import get_training_data
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
def crop_face(img):
    """Detect a face in ``img`` and return it cropped and resized to 256x256.

    Detection runs on a grayscale copy using the module-level ``face_cascade``
    (scaleFactor=1.3, minNeighbors=5).

    Args:
        img: colour image array indexed as ``img[y, x, channel]``.

    Returns:
        ``(face_crop_256x256, True)`` when at least one face is found,
        otherwise ``(img, False)`` with the input returned unchanged.
    """
    # NOTE(review): crop_images passes images it later converts RGB->BGR before
    # writing, i.e. RGB input; BGR2GRAY only swaps the R/B weights here — confirm.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)
    if len(faces) == 0:
        return img, False
    # The detector may return several boxes; keep the largest one rather than
    # whichever box happens to be listed last.
    x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
    face = img[y:y + h, x:x + w, :]
    return cv2.resize(face, (256, 256)), True
# Output folders for the cropped faces; exist_ok avoids mkdir's
# "File exists" error when the notebook is re-run.
os.makedirs('cropped_faces_personA', exist_ok=True)
os.makedirs('cropped_faces_personB', exist_ok=True)

def crop_images(folder):
    """Crop the face out of every ``<folder>/*.jpg`` into ``cropped_faces_<folder>/``.

    Images in which ``crop_face`` detects no face are skipped. Each saved file
    is named after the image's index in the ``Glob`` listing (``<i>.jpg``), so
    numbering can have gaps where faces were not found.

    Raises:
        OSError: if OpenCV fails to write a cropped image.
    """
    out_dir = 'cropped_faces_' + folder
    # cv2.imwrite does not create missing directories (it just returns False),
    # so make sure the destination exists before writing.
    os.makedirs(out_dir, exist_ok=True)
    for i, path in enumerate(Glob(folder + '/*.jpg')):
        img = read(path, 1)
        face, face_detected = crop_face(img)
        if not face_detected:
            continue
        out_path = os.path.join(out_dir, str(i) + '.jpg')
        # read(..., 1) appears to yield RGB here; cv2.imwrite expects BGR.
        if not cv2.imwrite(out_path, cv2.cvtColor(face, cv2.COLOR_RGB2BGR)):
            raise OSError('failed to write ' + out_path)
# Crop faces for both identities (personA first, then personB).
for person in ('personA', 'personB'):
    crop_images(person)
mkdir: cannot create directory cropped_faces_personA: File exists
mkdir: cannot create directory cropped_faces_personB: File exists
2020-11-08 07:23:24.933 | INFO     | torch_snippets.loader:Glob:181 - 444 files found at personA/*.jpg
class ImageDataset(Dataset):
    """Dataset pairing one face image from identity A with one from identity B.

    All images are loaded eagerly and scaled by 1/255. Identity A's images are
    then shifted so their per-channel mean equals identity B's. After that
    shift A's values may fall slightly outside [0, 1]. ``collate_fn`` runs
    ``get_training_data`` on each side of a batch and moves the results to
    ``device`` as channel-first tensors.
    """

    def __init__(self, items_A, items_B):
        def load_all(paths):
            # Stack every image along a new leading axis, presumably giving
            # (N, H, W, C) — TODO confirm read()'s layout.
            return np.concatenate([read(path, 1)[None] for path in paths]) / 255.

        self.items_A = load_all(items_A)
        self.items_B = load_all(items_B)
        # Colour-match A to B: average over every axis except the last (channel).
        channel_axes = (0, 1, 2)
        mean_gap = self.items_B.mean(axis=channel_axes) - self.items_A.mean(axis=channel_axes)
        self.items_A += mean_gap

    def __len__(self):
        # Length is bounded by the smaller of the two image pools.
        return min(map(len, (self.items_A, self.items_B)))

    def __getitem__(self, ix):
        # ix is unused: each item is a choose() pick from each pool
        # (presumably random — see torch_snippets.choose).
        return choose(self.items_A), choose(self.items_B)

    def collate_fn(self, batch):
        """Turn a list of (A, B) image pairs into four tensors on ``device``.

        Returns:
            ``(inputA, inputB, targetA, targetB)``, each permuted from the
            arrays' (N, H, W, C) order to (N, C, H, W).
        """
        imsA, imsB = list(zip(*batch))
        inputA, targetA = get_training_data(imsA, len(imsA))
        inputB, targetB = get_training_data(imsB, len(imsB))
        return tuple(
            torch.Tensor(arr).permute(0, 3, 1, 2).to(device)
            for arr in (inputA, inputB, targetA, targetB)
        )

# Small sanity-check dataset/loader over the cropped faces (batches of 32).
a = ImageDataset(
    items_A=Glob('cropped_faces_personA'),
    items_B=Glob('cropped_faces_personB'),
)
x = DataLoader(a, collate_fn=a.collate_fn, batch_size=32)
2020-11-08 07:16:09.186 | INFO     | torch_snippets.loader:Glob:181 - 349 files found at cropped_faces_personA
2020-11-08 07:16:09.189 | INFO     | torch_snippets.loader:Glob:181 - 105 files found at cropped_faces_personB
# Draw a single batch and use it both for the stats and for the plots. Each
# draw from the loader re-samples, so two separate next(iter(x)) calls would
# plot different images than the ones inspected.
batch = next(iter(x))
inspect(*batch)

for tensor in batch:
    subplots(tensor[:8], nc=4, sz=(4,2))
==================================================================
Tensor	Shape: torch.Size([32, 3, 64, 64])	Min: -0.006	Max: 0.884	Mean: 0.506	dtype: torch.float32
==================================================================
Tensor	Shape: torch.Size([32, 3, 64, 64])	Min: 0.023	Max: 0.927	Mean: 0.490	dtype: torch.float32
==================================================================
Tensor	Shape: torch.Size([32, 3, 64, 64])	Min: 0.003	Max: 0.882	Mean: 0.506	dtype: torch.float32
==================================================================
Tensor	Shape: torch.Size([32, 3, 64, 64])	Min: 0.023	Max: 0.927	Mean: 0.492	dtype: torch.float32
==================================================================
2020-11-08 07:16:16.342 | INFO     | torch_snippets.loader:subplots:375 - plotting 8 images in a grid of 2x4 @ (5, 5)
2020-11-08 07:16:16.710 | INFO     | torch_snippets.loader:subplots:375 - plotting 8 images in a grid of 2x4 @ (5, 5)
2020-11-08 07:16:17.026 | INFO     | torch_snippets.loader:subplots:375 - plotting 8 images in a grid of 2x4 @ (5, 5)
2020-11-08 07:16:17.411 | INFO     | torch_snippets.loader:subplots:375 - plotting 8 images in a grid of 2x4 @ (5, 5)
def _ConvLayer(input_features, output_features):
    return nn.Sequential(
        nn.Conv2d(input_features, output_features, kernel_size=5, stride=2, padding=2),
        nn.LeakyReLU(0.1, inplace=True)
    )

def _UpScale(input_features, output_features):
    return nn.Sequential(
        nn.ConvTranspose2d(input_features, output_features, kernel_size=2, stride=2, padding=0),
        nn.LeakyReLU(0.1, inplace=True)
    )

class Reshape(nn.Module):
    """View the input as ``(-1, 1024, 4, 4)``.

    In this notebook it un-flattens the output of the encoder's
    ``nn.Linear(1024, 1024 * 4 * 4)`` back into a 1024-channel 4x4 map.
    """

    # Per-sample target shape: (channels, height, width).
    _TARGET_SHAPE = (1024, 4, 4)

    def forward(self, input):
        return input.view(-1, *self._TARGET_SHAPE)
class Autoencoder(nn.Module):
    """Shared-encoder autoencoder with one decoder per face identity.

    ``forward(x, 'A')`` reconstructs through ``decoder_A``; any other
    ``select`` value uses ``decoder_B``. For a 3x64x64 input the encoder
    produces a 512x8x8 map and each decoder maps it back to 3x64x64 in (0, 1).
    """

    def __init__(self):
        super().__init__()

        # Four stride-2 convs: 3x64x64 -> 1024x4x4; linear bottleneck of size
        # 1024; un-flatten to 1024x4x4 and upscale once to 512x8x8.
        self.encoder = nn.Sequential(
            _ConvLayer(3, 128),
            _ConvLayer(128, 256),
            _ConvLayer(256, 512),
            _ConvLayer(512, 1024),
            nn.Flatten(),
            nn.Linear(1024 * 4 * 4, 1024),
            nn.Linear(1024, 1024 * 4 * 4),
            Reshape(),
            _UpScale(1024, 512),
        )

        def build_decoder():
            # 512x8x8 -> 64x64x64 via three 2x upscales, then a 3x3 conv to
            # 3 channels squashed into (0, 1) by a sigmoid.
            return nn.Sequential(
                _UpScale(512, 256),
                _UpScale(256, 128),
                _UpScale(128, 64),
                nn.Conv2d(64, 3, kernel_size=3, padding=1),
                nn.Sigmoid(),
            )

        # Built in the same order as before (A then B) so parameter
        # initialisation consumes the RNG identically.
        self.decoder_A = build_decoder()
        self.decoder_B = build_decoder()

    def forward(self, x, select='A'):
        decoder = self.decoder_A if select == 'A' else self.decoder_B
        return decoder(self.encoder(x))
from torchsummary import summary
model = Autoencoder()
# Per-layer output shapes and parameter counts for a dummy batch of
# 32 images of 3x64x64, routed through decoder 'A'.
dummy_batch = torch.zeros(32, 3, 64, 64)
summary(model, dummy_batch, 'A');
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 512, 8, 8]           --
|    └─Sequential: 2-1                   [-1, 128, 32, 32]         --
|    |    └─Conv2d: 3-1                  [-1, 128, 32, 32]         9,728
|    |    └─LeakyReLU: 3-2               [-1, 128, 32, 32]         --
|    └─Sequential: 2-2                   [-1, 256, 16, 16]         --
|    |    └─Conv2d: 3-3                  [-1, 256, 16, 16]         819,456
|    |    └─LeakyReLU: 3-4               [-1, 256, 16, 16]         --
|    └─Sequential: 2-3                   [-1, 512, 8, 8]           --
|    |    └─Conv2d: 3-5                  [-1, 512, 8, 8]           3,277,312
|    |    └─LeakyReLU: 3-6               [-1, 512, 8, 8]           --
|    └─Sequential: 2-4                   [-1, 1024, 4, 4]          --
|    |    └─Conv2d: 3-7                  [-1, 1024, 4, 4]          13,108,224
|    |    └─LeakyReLU: 3-8               [-1, 1024, 4, 4]          --
|    └─Flatten: 2-5                      [-1, 16384]               --
|    └─Linear: 2-6                       [-1, 1024]                16,778,240
|    └─Linear: 2-7                       [-1, 16384]               16,793,600
|    └─Reshape: 2-8                      [-1, 1024, 4, 4]          --
|    └─Sequential: 2-9                   [-1, 512, 8, 8]           --
|    |    └─ConvTranspose2d: 3-9         [-1, 512, 8, 8]           2,097,664
|    |    └─LeakyReLU: 3-10              [-1, 512, 8, 8]           --
├─Sequential: 1-2                        [-1, 3, 64, 64]           --
|    └─Sequential: 2-10                  [-1, 256, 16, 16]         --
|    |    └─ConvTranspose2d: 3-11        [-1, 256, 16, 16]         524,544
|    |    └─LeakyReLU: 3-12              [-1, 256, 16, 16]         --
|    └─Sequential: 2-11                  [-1, 128, 32, 32]         --
|    |    └─ConvTranspose2d: 3-13        [-1, 128, 32, 32]         131,200
|    |    └─LeakyReLU: 3-14              [-1, 128, 32, 32]         --
|    └─Sequential: 2-12                  [-1, 64, 64, 64]          --
|    |    └─ConvTranspose2d: 3-15        [-1, 64, 64, 64]          32,832
|    |    └─LeakyReLU: 3-16              [-1, 64, 64, 64]          --
|    └─Conv2d: 2-13                      [-1, 3, 64, 64]           1,731
|    └─Sigmoid: 2-14                     [-1, 3, 64, 64]           --
==========================================================================================
Total params: 53,574,531
Trainable params: 53,574,531
Non-trainable params: 0
Total mult-adds (G): 1.29
==========================================================================================
Input size (MB): 1.50
Forward/backward pass size (MB): 5.85
Params size (MB): 204.37
Estimated Total Size (MB): 211.72
==========================================================================================
def train_batch(model, data, criterion, optimizes):
    """Run one optimisation step on both identities and return their losses.

    Args:
        model: network called as ``model(x, 'A')`` and ``model(x, 'B')``.
        data: ``(imgA, imgB, targetA, targetB)``, as returned by
            ``ImageDataset.collate_fn``.
        criterion: reconstruction loss applied to (prediction, target).
        optimizes: pair ``(optA, optB)`` of optimizers for the A and B paths.
            (The parameter name is kept unchanged for compatibility.)

    Returns:
        ``(lossA, lossB)`` as Python floats.
    """
    # Use the optimizers passed in. The previous body read a global named
    # `optimizers` and silently ignored this argument.
    optA, optB = optimizes
    optA.zero_grad()
    optB.zero_grad()
    imgA, imgB, targetA, targetB = data
    _imgA, _imgB = model(imgA, 'A'), model(imgB, 'B')

    lossA = criterion(_imgA, targetA)
    lossB = criterion(_imgB, targetB)

    # Gradients accumulate, so parameters used by both paths receive the sum
    # of both losses' gradients.
    lossA.backward()
    lossB.backward()

    # NOTE(review): parameters registered in BOTH optimizers (the notebook's
    # setup gives both the shared encoder) are stepped twice per batch with
    # the summed gradient — matches the reference design, but confirm intent.
    optA.step()
    optB.step()

    return lossA.item(), lossB.item()
model = Autoencoder().to(device)

dataset = ImageDataset(Glob('cropped_faces_personA'), Glob('cropped_faces_personB'))
dataloader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)

# One Adam optimizer per identity. Each owns the shared encoder plus its own
# decoder, with identical hyper-parameters.
adam_kwargs = dict(lr=5e-5, betas=(0.5, 0.999))
optimizers = (
    optim.Adam([{'params': model.encoder.parameters()},
                {'params': model.decoder_A.parameters()}], **adam_kwargs),
    optim.Adam([{'params': model.encoder.parameters()},
                {'params': model.decoder_B.parameters()}], **adam_kwargs),
)

criterion = nn.L1Loss()
2020-11-08 07:16:45.033 | INFO     | torch_snippets.loader:Glob:181 - 349 files found at cropped_faces_personA
2020-11-08 07:16:45.036 | INFO     | torch_snippets.loader:Glob:181 - 105 files found at cropped_faces_personB
n_epochs = 10000
log = Report(n_epochs)
# Create the checkpoint folder; exist_ok avoids mkdir's "File exists" error.
os.makedirs('checkpoint', exist_ok=True)
N = len(dataloader)  # batches per epoch; loop-invariant, so computed once
for ex in range(n_epochs):
    for bx, data in enumerate(dataloader):
        lossA, lossB = train_batch(model, data, criterion, optimizers)
        # Fractional epoch position so the log advances within an epoch.
        log.record(ex + (1 + bx) / N, lossA=lossA, lossB=lossB, end='\r')

    log.report_avgs(ex + 1)

    # Every 100 epochs: save a checkpoint, then preview reconstructions and
    # identity swaps on the last batch of the epoch.
    if (ex + 1) % 100 == 0:
        state = {
            'state': model.state_dict(),
            'epoch': ex,
        }
        torch.save(state, './checkpoint/autoencoder.pth')

        bs = 5
        a, b, A, B = data
        # Preview only: no autograd graph is needed for these forward passes.
        with torch.no_grad():
            line('A to B')
            _a = model(a[:bs], 'A')
            _b = model(a[:bs], 'B')
            x = torch.cat([A[:bs], _a, _b])
            subplots(x, nc=bs, figsize=(bs*2, 5))

            line('B to A')
            _a = model(b[:bs], 'A')
            _b = model(b[:bs], 'B')
            x = torch.cat([B[:bs], _a, _b])
            subplots(x, nc=bs, figsize=(bs*2, 5))

log.plot_epochs()
mkdir: cannot create directory checkpoint: File exists
EPOCH: 1.000	lossA: 0.139	lossB: 0.149	(1.76s - 17552.61s remaining)
EPOCH: 2.000	lossA: 0.137	lossB: 0.149	(2.62s - 13099.50s remaining)
EPOCH: 3.000	lossA: 0.134	lossB: 0.142	(3.49s - 11622.35s remaining)
EPOCH: 4.000	lossA: 0.118	lossB: 0.125	(4.37s - 10921.87s remaining)
EPOCH: 5.000	lossA: 0.109	lossB: 0.113	(5.25s - 10493.76s remaining)
EPOCH: 6.000	lossA: 0.096	lossB: 0.105	(6.15s - 10247.40s remaining)
EPOCH: 7.000	lossA: 0.093	lossB: 0.103	(7.02s - 10028.24s remaining)
EPOCH: 8.000	lossA: 0.092	lossB: 0.102	(7.91s - 9874.61s remaining))
EPOCH: 9.000	lossA: 0.094	lossB: 0.099	(8.90s - 9878.02s remaining))
EPOCH: 10.000	lossA: 0.093	lossB: 0.103	(10.01s - 10003.98s remaining)
EPOCH: 11.000	lossA: 0.093	lossB: 0.099	(10.93s - 9929.14s remaining))
EPOCH: 12.000	lossA: 0.090	lossB: 0.098	(11.84s - 9851.34s remaining)
EPOCH: 13.000	lossA: 0.093	lossB: 0.102	(12.87s - 9889.94s remaining)
EPOCH: 14.000	lossA: 0.096	lossB: 0.102	(13.75s - 9810.10s remaining)
EPOCH: 15.000	lossA: 0.094	lossB: 0.101	(14.65s - 9750.43s remaining)
EPOCH: 16.000	lossA: 0.092	lossB: 0.097	(15.53s - 9690.47s remaining)
EPOCH: 17.000	lossA: 0.089	lossB: 0.100	(16.43s - 9647.20s remaining)
EPOCH: 18.000	lossA: 0.088	lossB: 0.095	(17.31s - 9601.83s remaining)
EPOCH: 19.000	lossA: 0.092	lossB: 0.090	(18.25s - 9585.49s remaining)
EPOCH: 20.000	lossA: 0.090	lossB: 0.090	(19.12s - 9541.80s remaining)
EPOCH: 21.000	lossA: 0.091	lossB: 0.087	(20.02s - 9512.30s remaining)
EPOCH: 22.000	lossA: 0.091	lossB: 0.091	(20.92s - 9486.22s remaining)
EPOCH: 23.000	lossA: 0.090	lossB: 0.086	(21.80s - 9455.60s remaining)
EPOCH: 24.000	lossA: 0.088	lossB: 0.084	(22.69s - 9429.68s remaining)
EPOCH: 25.000	lossA: 0.087	lossB: 0.081	(23.75s - 9474.44s remaining)
EPOCH: 26.000	lossA: 0.094	lossB: 0.083	(24.64s - 9453.14s remaining)
EPOCH: 27.000	lossA: 0.093	lossB: 0.087	(25.53s - 9428.41s remaining)
EPOCH: 28.000	lossA: 0.090	lossB: 0.080	(26.42s - 9408.78s remaining)
EPOCH: 29.000	lossA: 0.088	lossB: 0.079	(27.32s - 9394.34s remaining)
EPOCH: 30.000	lossA: 0.088	lossB: 0.077	(28.22s - 9378.07s remaining)
EPOCH: 31.000	lossA: 0.091	lossB: 0.079	(29.09s - 9356.22s remaining)
EPOCH: 32.000	lossA: 0.089	lossB: 0.077	(29.96s - 9332.41s remaining)
EPOCH: 33.000	lossA: 0.090	lossB: 0.079	(30.83s - 9310.78s remaining)
EPOCH: 34.000	lossA: 0.088	lossB: 0.076	(31.71s - 9296.19s remaining)
EPOCH: 35.000	lossA: 0.087	lossB: 0.080	(32.59s - 9279.21s remaining)
EPOCH: 36.000	lossA: 0.088	lossB: 0.078	(33.73s - 9334.53s remaining)
EPOCH: 37.000	lossA: 0.088	lossB: 0.077	(34.77s - 9361.94s remaining)
EPOCH: 38.000	lossA: 0.086	lossB: 0.078	(35.66s - 9347.50s remaining)
EPOCH: 39.000	lossA: 0.090	lossB: 0.077	(36.55s - 9334.48s remaining)
EPOCH: 40.000	lossA: 0.087	lossB: 0.076	(37.63s - 9370.58s remaining)
EPOCH: 41.000	lossA: 0.086	lossB: 0.076	(38.51s - 9353.73s remaining)
EPOCH: 42.000	lossA: 0.085	lossB: 0.076	(39.42s - 9345.13s remaining)
EPOCH: 43.000	lossA: 0.089	lossB: 0.079	(40.30s - 9331.86s remaining)
EPOCH: 44.000	lossA: 0.088	lossB: 0.077	(41.22s - 9327.07s remaining)
EPOCH: 45.000	lossA: 0.087	lossB: 0.079	(42.30s - 9356.74s remaining)
EPOCH: 46.000	lossA: 0.088	lossB: 0.079	(43.37s - 9384.66s remaining)
EPOCH: 47.000	lossA: 0.089	lossB: 0.078	(44.24s - 9368.66s remaining)
EPOCH: 48.000	lossA: 0.092	lossB: 0.081	(45.23s - 9378.62s remaining)
EPOCH: 49.000	lossA: 0.089	lossB: 0.077	(46.19s - 9379.72s remaining)
EPOCH: 50.000	lossA: 0.088	lossB: 0.076	(47.08s - 9368.09s remaining)
EPOCH: 51.000	lossA: 0.086	lossB: 0.076	(47.97s - 9357.04s remaining)
EPOCH: 52.000	lossA: 0.087	lossB: 0.077	(48.86s - 9347.82s remaining)
EPOCH: 53.000	lossA: 0.084	lossB: 0.075	(49.74s - 9334.42s remaining)
EPOCH: 54.000	lossA: 0.085	lossB: 0.073	(50.62s - 9323.13s remaining)
EPOCH: 55.000	lossA: 0.083	lossB: 0.074	(51.50s - 9312.77s remaining)
EPOCH: 56.000	lossA: 0.082	lossB: 0.072	(52.38s - 9301.78s remaining)
EPOCH: 57.000	lossA: 0.080	lossB: 0.071	(53.26s - 9290.23s remaining)
EPOCH: 58.000	lossA: 0.080	lossB: 0.071	(54.13s - 9278.51s remaining)
EPOCH: 59.000	lossA: 0.079	lossB: 0.071	(55.01s - 9269.24s remaining)
EPOCH: 60.000	lossA: 0.076	lossB: 0.068	(55.93s - 9265.31s remaining)
EPOCH: 61.000	lossA: 0.079	lossB: 0.069	(56.92s - 9274.52s remaining)
EPOCH: 62.000	lossA: 0.076	lossB: 0.069	(57.80s - 9265.46s remaining)
EPOCH: 63.000	lossA: 0.076	lossB: 0.066	(58.68s - 9255.77s remaining)
EPOCH: 64.000	lossA: 0.074	lossB: 0.067	(59.56s - 9247.13s remaining)
EPOCH: 65.000	lossA: 0.076	lossB: 0.067	(60.44s - 9237.91s remaining)
EPOCH: 66.000	lossA: 0.075	lossB: 0.067	(61.33s - 9230.99s remaining)
EPOCH: 67.000	lossA: 0.073	lossB: 0.064	(62.21s - 9222.39s remaining)
EPOCH: 68.000	lossA: 0.071	lossB: 0.063	(63.09s - 9215.32s remaining)
EPOCH: 69.000	lossA: 0.070	lossB: 0.064	(63.96s - 9205.16s remaining)
EPOCH: 70.000	lossA: 0.070	lossB: 0.066	(64.84s - 9197.61s remaining)
EPOCH: 71.000	lossA: 0.068	lossB: 0.065	(65.72s - 9190.89s remaining)
EPOCH: 72.000	lossA: 0.072	lossB: 0.065	(66.59s - 9181.92s remaining)
EPOCH: 73.000	lossA: 0.069	lossB: 0.065	(67.67s - 9201.66s remaining)
EPOCH: 74.000	lossA: 0.069	lossB: 0.063	(68.54s - 9193.62s remaining)
EPOCH: 75.000	lossA: 0.066	lossB: 0.059	(69.41s - 9185.85s remaining)
EPOCH: 76.000	lossA: 0.066	lossB: 0.060	(70.30s - 9179.45s remaining)
EPOCH: 77.000	lossA: 0.067	lossB: 0.061	(71.21s - 9176.62s remaining)
EPOCH: 78.000	lossA: 0.066	lossB: 0.058	(72.12s - 9173.91s remaining)
EPOCH: 79.000	lossA: 0.067	lossB: 0.060	(73.01s - 9168.72s remaining)
EPOCH: 80.000	lossA: 0.066	lossB: 0.059	(73.90s - 9163.35s remaining)
EPOCH: 81.000	lossA: 0.067	lossB: 0.056	(74.89s - 9170.86s remaining)
EPOCH: 82.000	lossA: 0.064	lossB: 0.058	(75.98s - 9189.94s remaining)
EPOCH: 83.000	lossA: 0.068	lossB: 0.058	(76.87s - 9185.00s remaining)
EPOCH: 84.000	lossA: 0.064	lossB: 0.059	(77.81s - 9185.63s remaining)
EPOCH: 85.000	lossA: 0.065	lossB: 0.057	(78.88s - 9201.17s remaining)
EPOCH: 86.000	lossA: 0.063	lossB: 0.055	(79.76s - 9194.83s remaining)
EPOCH: 87.000	lossA: 0.064	lossB: 0.056	(80.65s - 9189.49s remaining)
EPOCH: 88.000	lossA: 0.066	lossB: 0.057	(81.54s - 9184.14s remaining)
EPOCH: 89.000	lossA: 0.065	lossB: 0.055	(82.43s - 9179.16s remaining)
EPOCH: 90.000	lossA: 0.064	lossB: 0.058	(83.30s - 9172.47s remaining)
EPOCH: 91.000	lossA: 0.062	lossB: 0.055	(84.19s - 9166.94s remaining)
EPOCH: 92.000	lossA: 0.062	lossB: 0.055	(85.06s - 9160.67s remaining)
EPOCH: 93.000	lossA: 0.063	lossB: 0.055	(85.93s - 9154.27s remaining)
EPOCH: 94.000	lossA: 0.064	lossB: 0.054	(86.82s - 9148.97s remaining)
EPOCH: 95.000	lossA: 0.065	lossB: 0.055	(87.69s - 9143.33s remaining)
EPOCH: 96.000	lossA: 0.063	lossB: 0.054	(88.57s - 9137.44s remaining)
EPOCH: 97.000	lossA: 0.061	lossB: 0.057	(89.66s - 9153.39s remaining)
EPOCH: 98.000	lossA: 0.061	lossB: 0.053	(90.54s - 9147.81s remaining)
EPOCH: 99.000	lossA: 0.060	lossB: 0.050	(91.41s - 9141.92s remaining)
EPOCH: 100.000	lossA: 0.061	lossB: 0.053	(92.29s - 9136.65s remaining)
2020-11-08 07:18:21.553 | INFO     | torch_snippets.loader:subplots:375 - plotting 15 images in a grid of 3x5 @ (10, 5)
==============================A TO B==============================
2020-11-08 07:18:22.547 | INFO     | torch_snippets.loader:subplots:375 - plotting 15 images in a grid of 3x5 @ (10, 5)
==============================B TO A==============================
EPOCH: 101.000	lossA: 0.061	lossB: 0.051	(96.39s - 9447.10s remaining)
EPOCH: 102.000	lossA: 0.063	lossB: 0.051	(97.27s - 9439.45s remaining)
EPOCH: 103.000	lossA: 0.063	lossB: 0.053	(98.21s - 9437.00s remaining)
EPOCH: 104.000	lossA: 0.058	lossB: 0.051	(99.11s - 9431.09s remaining)
EPOCH: 105.000	lossA: 0.060	lossB: 0.049	(100.01s - 9424.46s remaining)
EPOCH: 106.000	lossA: 0.062	lossB: 0.052	(101.20s - 9445.77s remaining)
EPOCH: 106.250	lossA: 0.058	lossB: 0.051	(101.45s - 9446.94s remaining)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-14-767a20d7fcbf> in <module>
      5     N = len(dataloader)
      6     for bx,data in enumerate(dataloader):
----> 7         lossA, lossB = train_batch(model, data, criterion, optimizers)
      8         log.record(ex+(1+bx)/N, lossA=lossA, lossB=lossB, end='\r')
      9 

<ipython-input-10-b295b379f984> in train_batch(model, data, criterion, optimizes)
     15     optB.step()
     16 
---> 17     return lossA.item(), lossB.item()

KeyboardInterrupt: