import os
# download the person images and the warping utility once
if not os.path.exists('person_images.zip'):
    !wget -q https://www.dropbox.com/s/5ji7jl7httso9ny/person_images.zip
    !wget -q https://raw.githubusercontent.com/sizhky/deep-fake-util/main/random_warp.py
    !unzip -q person_images.zip
!pip install -q torch_snippets torch_summary
from torch_snippets import *
from random_warp import get_training_data
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
def crop_face(img):
    # detect faces with the Haar cascade and return the first detection,
    # resized to 256x256; fall back to the full image if no face is found
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)
    if(len(faces)>0):
        for (x,y,w,h) in faces:
            img2 = img[y:(y+h), x:(x+w), :]
            img2 = cv2.resize(img2, (256,256))
            return img2, True
    else:
        return img, False
!mkdir -p cropped_faces_personA
!mkdir -p cropped_faces_personB
def crop_images(folder):
    # crop every image in the folder and save only those with a detected face
    images = Glob(folder+'/*.jpg')
    for i in range(len(images)):
        img = read(images[i], 1)
        img2, face_detected = crop_face(img)
        if(face_detected==False):
            continue
        else:
            cv2.imwrite('cropped_faces_'+folder+'/'+str(i)+'.jpg', cv2.cvtColor(img2, cv2.COLOR_RGB2BGR))
crop_images('personA')
crop_images('personB')
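Before building the dataset, it can help to eyeball a few of the crops. A minimal sanity-check sketch, assuming the folders above now contain images:

# optional: visualize the first few cropped faces of person A
sample = [read(f, 1) for f in Glob('cropped_faces_personA/*.jpg')[:4]]
subplots(sample, nc=4, sz=(8, 2))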
2020-11-08 07:23:24.933 | INFO | torch_snippets.loader:Glob:181 - 444 files found at personA/*.jpg
class ImageDataset(Dataset):
    def __init__(self, items_A, items_B):
        self.items_A = np.concatenate([read(f,1)[None] for f in items_A])/255.
        self.items_B = np.concatenate([read(f,1)[None] for f in items_B])/255.
        # shift person A's per-channel means to match person B's, so the two
        # sets share similar color statistics
        self.items_A += self.items_B.mean(axis=(0, 1, 2)) - self.items_A.mean(axis=(0, 1, 2))
    def __len__(self):
        return min(len(self.items_A), len(self.items_B))
    def __getitem__(self, ix):
        a, b = choose(self.items_A), choose(self.items_B)
        return a, b
    def collate_fn(self, batch):
        imsA, imsB = list(zip(*batch))
        # get_training_data returns randomly warped 64x64 inputs and the
        # corresponding 64x64 target crops
        imsA, targetA = get_training_data(imsA, len(imsA))
        imsB, targetB = get_training_data(imsB, len(imsB))
        imsA, imsB, targetA, targetB = [torch.Tensor(i).permute(0,3,1,2).to(device) for i in [imsA, imsB, targetA, targetB]]
        return imsA, imsB, targetA, targetB

a = ImageDataset(Glob('cropped_faces_personA'), Glob('cropped_faces_personB'))
x = DataLoader(a, batch_size=32, collate_fn=a.collate_fn)
2020-11-08 07:16:09.186 | INFO | torch_snippets.loader:Glob:181 - 349 files found at cropped_faces_personA 2020-11-08 07:16:09.189 | INFO | torch_snippets.loader:Glob:181 - 105 files found at cropped_faces_personB
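As a quick check of the color matching done in __init__ (a sketch, assuming the dataset `a` created above), the per-channel means of the two sets should now be roughly equal:

print(a.items_A.mean(axis=(0,1,2)))   # shifted toward person B's channel means
print(a.items_B.mean(axis=(0,1,2)))   # unchanged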
inspect(*next(iter(x)))
for i in next(iter(x)):
    subplots(i[:8], nc=4, sz=(4,2))
==================================================================
Tensor  Shape: torch.Size([32, 3, 64, 64])  Min: -0.006  Max: 0.884  Mean: 0.506  dtype: torch.float32
==================================================================
Tensor  Shape: torch.Size([32, 3, 64, 64])  Min: 0.023  Max: 0.927  Mean: 0.490  dtype: torch.float32
==================================================================
Tensor  Shape: torch.Size([32, 3, 64, 64])  Min: 0.003  Max: 0.882  Mean: 0.506  dtype: torch.float32
==================================================================
Tensor  Shape: torch.Size([32, 3, 64, 64])  Min: 0.023  Max: 0.927  Mean: 0.492  dtype: torch.float32
==================================================================
[Four plots follow: the first 8 images of imsA, imsB, targetA, and targetB, each shown in a 2x4 grid]
def _ConvLayer(input_features, output_features):
    return nn.Sequential(
        nn.Conv2d(input_features, output_features, kernel_size=5, stride=2, padding=2),
        nn.LeakyReLU(0.1, inplace=True)
    )

def _UpScale(input_features, output_features):
    return nn.Sequential(
        nn.ConvTranspose2d(input_features, output_features, kernel_size=2, stride=2, padding=0),
        nn.LeakyReLU(0.1, inplace=True)
    )

class Reshape(nn.Module):
    def forward(self, input):
        output = input.view(-1, 1024, 4, 4) # channel * 4 * 4
        return output
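On recent PyTorch versions the same reshape could also be written with the built-in nn.Unflatten; this is an equivalent sketch, not what the code above uses:

# equivalent to Reshape() above for a flattened 16384-dim vector
unflatten = nn.Unflatten(1, (1024, 4, 4))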
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # shared encoder: 64x64x3 -> 4x4x1024 -> 1024-dim bottleneck -> 8x8x512
        self.encoder = nn.Sequential(
            _ConvLayer(3, 128),
            _ConvLayer(128, 256),
            _ConvLayer(256, 512),
            _ConvLayer(512, 1024),
            nn.Flatten(),
            nn.Linear(1024 * 4 * 4, 1024),
            nn.Linear(1024, 1024 * 4 * 4),
            Reshape(),
            _UpScale(1024, 512),
        )
        # one decoder per person, upsampling 8x8x512 back to 64x64x3
        self.decoder_A = nn.Sequential(
            _UpScale(512, 256),
            _UpScale(256, 128),
            _UpScale(128, 64),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )
        self.decoder_B = nn.Sequential(
            _UpScale(512, 256),
            _UpScale(256, 128),
            _UpScale(128, 64),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )
    def forward(self, x, select='A'):
        if select == 'A':
            out = self.encoder(x)
            out = self.decoder_A(out)
        else:
            out = self.encoder(x)
            out = self.decoder_B(out)
        return out
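A minimal shape check (a sketch): each stride-2 conv halves the 64x64 input down to 4x4, the two linear layers form the 1024-dimensional bottleneck, and the transposed convs bring the feature map back up, so a batch should come out with the same spatial size it went in with.

_m = Autoencoder()
_x = torch.zeros(2, 3, 64, 64)
print(_m(_x, 'A').shape, _m(_x, 'B').shape)   # both torch.Size([2, 3, 64, 64])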
from torchsummary import summary
model = Autoencoder()
summary(model, torch.zeros(32,3,64,64), 'A');
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 512, 8, 8]           --
|    └─Sequential: 2-1                   [-1, 128, 32, 32]         --
|    |    └─Conv2d: 3-1                  [-1, 128, 32, 32]         9,728
|    |    └─LeakyReLU: 3-2               [-1, 128, 32, 32]         --
|    └─Sequential: 2-2                   [-1, 256, 16, 16]         --
|    |    └─Conv2d: 3-3                  [-1, 256, 16, 16]         819,456
|    |    └─LeakyReLU: 3-4               [-1, 256, 16, 16]         --
|    └─Sequential: 2-3                   [-1, 512, 8, 8]           --
|    |    └─Conv2d: 3-5                  [-1, 512, 8, 8]           3,277,312
|    |    └─LeakyReLU: 3-6               [-1, 512, 8, 8]           --
|    └─Sequential: 2-4                   [-1, 1024, 4, 4]          --
|    |    └─Conv2d: 3-7                  [-1, 1024, 4, 4]          13,108,224
|    |    └─LeakyReLU: 3-8               [-1, 1024, 4, 4]          --
|    └─Flatten: 2-5                      [-1, 16384]               --
|    └─Linear: 2-6                       [-1, 1024]                16,778,240
|    └─Linear: 2-7                       [-1, 16384]               16,793,600
|    └─Reshape: 2-8                      [-1, 1024, 4, 4]          --
|    └─Sequential: 2-9                   [-1, 512, 8, 8]           --
|    |    └─ConvTranspose2d: 3-9         [-1, 512, 8, 8]           2,097,664
|    |    └─LeakyReLU: 3-10              [-1, 512, 8, 8]           --
├─Sequential: 1-2                        [-1, 3, 64, 64]           --
|    └─Sequential: 2-10                  [-1, 256, 16, 16]         --
|    |    └─ConvTranspose2d: 3-11        [-1, 256, 16, 16]         524,544
|    |    └─LeakyReLU: 3-12              [-1, 256, 16, 16]         --
|    └─Sequential: 2-11                  [-1, 128, 32, 32]         --
|    |    └─ConvTranspose2d: 3-13        [-1, 128, 32, 32]         131,200
|    |    └─LeakyReLU: 3-14              [-1, 128, 32, 32]         --
|    └─Sequential: 2-12                  [-1, 64, 64, 64]          --
|    |    └─ConvTranspose2d: 3-15        [-1, 64, 64, 64]          32,832
|    |    └─LeakyReLU: 3-16              [-1, 64, 64, 64]          --
|    └─Conv2d: 2-13                      [-1, 3, 64, 64]           1,731
|    └─Sigmoid: 2-14                     [-1, 3, 64, 64]           --
==========================================================================================
Total params: 53,574,531
Trainable params: 53,574,531
Non-trainable params: 0
Total mult-adds (G): 1.29
==========================================================================================
Input size (MB): 1.50
Forward/backward pass size (MB): 5.85
Params size (MB): 204.37
Estimated Total Size (MB): 211.72
==========================================================================================
def train_batch(model, data, criterion, optimizers):
    # one optimizer updates (encoder + decoder_A), the other (encoder + decoder_B)
    optA, optB = optimizers
    optA.zero_grad()
    optB.zero_grad()
    imgA, imgB, targetA, targetB = data
    _imgA, _imgB = model(imgA, 'A'), model(imgB, 'B')
    lossA = criterion(_imgA, targetA)
    lossB = criterion(_imgB, targetB)
    lossA.backward()
    lossB.backward()
    optA.step()
    optB.step()
    return lossA.item(), lossB.item()
model = Autoencoder().to(device)   # device is provided by torch_snippets
dataset = ImageDataset(Glob('cropped_faces_personA'), Glob('cropped_faces_personB'))
dataloader = DataLoader(dataset, 32, collate_fn=dataset.collate_fn)
# both optimizers update the shared encoder, but each updates only one decoder
optimizers = optim.Adam([{'params': model.encoder.parameters()},
                         {'params': model.decoder_A.parameters()}],
                        lr=5e-5, betas=(0.5, 0.999)), \
             optim.Adam([{'params': model.encoder.parameters()},
                         {'params': model.decoder_B.parameters()}],
                        lr=5e-5, betas=(0.5, 0.999))
criterion = nn.L1Loss()
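A small check (a sketch, assuming the setup above): the encoder's parameters should appear in both optimizers' parameter groups, which is what lets the two decoders share a common latent representation.

enc_params = set(map(id, model.encoder.parameters()))
for opt in optimizers:
    opt_params = {id(p) for g in opt.param_groups for p in g['params']}
    print(enc_params <= opt_params)   # expected: True for both optimizers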
2020-11-08 07:16:45.033 | INFO | torch_snippets.loader:Glob:181 - 349 files found at cropped_faces_personA 2020-11-08 07:16:45.036 | INFO | torch_snippets.loader:Glob:181 - 105 files found at cropped_faces_personB
n_epochs = 10000
log = Report(n_epochs)
!mkdir -p checkpoint
for ex in range(n_epochs):
    N = len(dataloader)
    for bx, data in enumerate(dataloader):
        lossA, lossB = train_batch(model, data, criterion, optimizers)
        log.record(ex+(1+bx)/N, lossA=lossA, lossB=lossB, end='\r')
    log.report_avgs(ex+1)
    if (ex+1)%100 == 0:
        # every 100 epochs, save a checkpoint and visualize the swaps
        state = {
            'state': model.state_dict(),
            'epoch': ex
        }
        torch.save(state, './checkpoint/autoencoder.pth')
        bs = 5
        a, b, A, B = data
        line('A to B')
        _a = model(a[:bs], 'A')
        _b = model(a[:bs], 'B')
        x = torch.cat([A[:bs], _a, _b])
        subplots(x, nc=bs, figsize=(bs*2, 5))
        line('B to A')
        _a = model(b[:bs], 'A')
        _b = model(b[:bs], 'B')
        x = torch.cat([B[:bs], _a, _b])
        subplots(x, nc=bs, figsize=(bs*2, 5))
log.plot_epochs()
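Once training has run for a while, a face swap is just a forward pass through the other person's decoder. A minimal inference sketch, assuming the checkpoint saved above and the existing dataloader:

ckpt = torch.load('./checkpoint/autoencoder.pth', map_location=device)
model.load_state_dict(ckpt['state'])
model.eval()
with torch.no_grad():
    imsA, imsB, targetA, targetB = next(iter(dataloader))
    swapped = model(imsA[:5], 'B')   # person A inputs rendered by person B's decoder
subplots(torch.cat([targetA[:5], swapped]), nc=5, figsize=(10, 4))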
EPOCH: 1.000   lossA: 0.139  lossB: 0.149  (1.76s - 17552.61s remaining)
EPOCH: 2.000   lossA: 0.137  lossB: 0.149  (2.62s - 13099.50s remaining)
EPOCH: 3.000   lossA: 0.134  lossB: 0.142  (3.49s - 11622.35s remaining)
...
EPOCH: 50.000  lossA: 0.088  lossB: 0.076  (47.08s - 9368.09s remaining)
...
EPOCH: 99.000  lossA: 0.060  lossB: 0.050  (91.41s - 9141.92s remaining)
EPOCH: 100.000 lossA: 0.061  lossB: 0.053  (92.29s - 9136.65s remaining)
==============================A TO B==============================
[3x5 grid: person A targets (row 1), reconstructions via decoder_A (row 2), swaps via decoder_B (row 3)]
==============================B TO A==============================
[3x5 grid: person B targets (row 1), swaps via decoder_A (row 2), reconstructions via decoder_B (row 3)]
EPOCH: 101.000 lossA: 0.061  lossB: 0.051  (96.39s - 9447.10s remaining)
...
EPOCH: 106.000 lossA: 0.062  lossB: 0.052  (101.20s - 9445.77s remaining)
EPOCH: 106.250 lossA: 0.058  lossB: 0.051  (101.45s - 9446.94s remaining)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-14-767a20d7fcbf> in <module>
      5     N = len(dataloader)
      6     for bx,data in enumerate(dataloader):
----> 7         lossA, lossB = train_batch(model, data, criterion, optimizers)
      8         log.record(ex+(1+bx)/N, lossA=lossA, lossB=lossB, end='\r')

<ipython-input-10-b295b379f984> in train_batch(model, data, criterion, optimizers)
     15     optB.step()
     16
---> 17     return lossA.item(), lossB.item()

KeyboardInterrupt: