186 KiB
186 KiB
!pip install -q torch_snippets
from torch_snippets import *
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
from torchvision.utils import make_grid
[K |████████████████████████████████| 36.7MB 81kB/s [K |████████████████████████████████| 61kB 8.7MB/s [K |████████████████████████████████| 102kB 13.2MB/s [?25h Building wheel for contextvars (setup.py) ... [?25l[?25hdone
from torchvision.datasets import MNIST
from torchvision import transforms
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) computes (x - 0.5) / 0.5,
# mapping them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# Shuffled MNIST training batches of 128; drop_last keeps every batch full-sized.
data_loader = torch.utils.data.DataLoader(
    MNIST('~/data', train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True,
    drop_last=True,
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/data/MNIST/raw/train-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Extracting /root/data/MNIST/raw/train-images-idx3-ubyte.gz to /root/data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/data/MNIST/raw/train-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Extracting /root/data/MNIST/raw/train-labels-idx1-ubyte.gz to /root/data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/data/MNIST/raw/t10k-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Extracting /root/data/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/data/MNIST/raw/t10k-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Extracting /root/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/data/MNIST/raw Processing... Done!
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:480: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
class Discriminator(nn.Module):
    """MLP that scores flattened 784-value images as real (1) or fake (0).

    Three hidden stages of Linear -> LeakyReLU(0.2) -> Dropout(0.3) map the
    width 784 -> 1024 -> 512 -> 256. A final Linear + Sigmoid then emits one
    value in (0, 1) per input row.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 1024, 512, 256)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)))
        # Output head: a single sigmoid score per sample, consumed by BCELoss.
        layers.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid scores for input rows *x*."""
        return self.model(x)
!pip install torch_summary
from torchsummary import summary
# Build a discriminator on `device` and print its layer/parameter table for a
# single flattened 784-value input.
discriminator = Discriminator()
discriminator.to(device)  # Module.to moves parameters in place (and returns self)
summary(discriminator, torch.zeros(1, 784))
Requirement already satisfied: torch_summary in /usr/local/lib/python3.6/dist-packages (1.4.3) ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Sequential: 1-1 [-1, 1] -- | └─Linear: 2-1 [-1, 1024] 803,840 | └─LeakyReLU: 2-2 [-1, 1024] -- | └─Dropout: 2-3 [-1, 1024] -- | └─Linear: 2-4 [-1, 512] 524,800 | └─LeakyReLU: 2-5 [-1, 512] -- | └─Dropout: 2-6 [-1, 512] -- | └─Linear: 2-7 [-1, 256] 131,328 | └─LeakyReLU: 2-8 [-1, 256] -- | └─Dropout: 2-9 [-1, 256] -- | └─Linear: 2-10 [-1, 1] 257 | └─Sigmoid: 2-11 [-1, 1] -- ========================================================================================== Total params: 1,460,225 Trainable params: 1,460,225 Non-trainable params: 0 Total mult-adds (M): 2.92 ========================================================================================== Input size (MB): 0.00 Forward/backward pass size (MB): 0.01 Params size (MB): 5.57 Estimated Total Size (MB): 5.59 ==========================================================================================
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Sequential: 1-1 [-1, 1] -- | └─Linear: 2-1 [-1, 1024] 803,840 | └─LeakyReLU: 2-2 [-1, 1024] -- | └─Dropout: 2-3 [-1, 1024] -- | └─Linear: 2-4 [-1, 512] 524,800 | └─LeakyReLU: 2-5 [-1, 512] -- | └─Dropout: 2-6 [-1, 512] -- | └─Linear: 2-7 [-1, 256] 131,328 | └─LeakyReLU: 2-8 [-1, 256] -- | └─Dropout: 2-9 [-1, 256] -- | └─Linear: 2-10 [-1, 1] 257 | └─Sigmoid: 2-11 [-1, 1] -- ========================================================================================== Total params: 1,460,225 Trainable params: 1,460,225 Non-trainable params: 0 Total mult-adds (M): 2.92 ========================================================================================== Input size (MB): 0.00 Forward/backward pass size (MB): 0.01 Params size (MB): 5.57 Estimated Total Size (MB): 5.59 ==========================================================================================
class Generator(nn.Module):
    """MLP that maps a 100-d latent vector to a flattened 28x28 (784) image.

    Hidden widths grow 100 -> 256 -> 512 -> 1024, each Linear followed by
    LeakyReLU(0.2). The last Linear projects to 784 values, and Tanh bounds
    them to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (100, 256, 512, 1024)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        # Pixel head; Tanh keeps outputs in the same range as the normalized data.
        layers += [nn.Linear(widths[-1], 784), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (batch, 784) tensor of generated pixels for latent batch *x*."""
        return self.model(x)
# Build a generator on `device` and print its layer/parameter table for a
# single 100-d latent input.
generator = Generator()
generator.to(device)  # Module.to moves parameters in place (and returns self)
summary(generator, torch.zeros(1, 100))
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Sequential: 1-1 [-1, 784] -- | └─Linear: 2-1 [-1, 256] 25,856 | └─LeakyReLU: 2-2 [-1, 256] -- | └─Linear: 2-3 [-1, 512] 131,584 | └─LeakyReLU: 2-4 [-1, 512] -- | └─Linear: 2-5 [-1, 1024] 525,312 | └─LeakyReLU: 2-6 [-1, 1024] -- | └─Linear: 2-7 [-1, 784] 803,600 | └─Tanh: 2-8 [-1, 784] -- ========================================================================================== Total params: 1,486,352 Trainable params: 1,486,352 Non-trainable params: 0 Total mult-adds (M): 2.97 ========================================================================================== Input size (MB): 0.00 Forward/backward pass size (MB): 0.02 Params size (MB): 5.67 Estimated Total Size (MB): 5.69 ==========================================================================================
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Sequential: 1-1 [-1, 784] -- | └─Linear: 2-1 [-1, 256] 25,856 | └─LeakyReLU: 2-2 [-1, 256] -- | └─Linear: 2-3 [-1, 512] 131,584 | └─LeakyReLU: 2-4 [-1, 512] -- | └─Linear: 2-5 [-1, 1024] 525,312 | └─LeakyReLU: 2-6 [-1, 1024] -- | └─Linear: 2-7 [-1, 784] 803,600 | └─Tanh: 2-8 [-1, 784] -- ========================================================================================== Total params: 1,486,352 Trainable params: 1,486,352 Non-trainable params: 0 Total mult-adds (M): 2.97 ========================================================================================== Input size (MB): 0.00 Forward/backward pass size (MB): 0.02 Params size (MB): 5.67 Estimated Total Size (MB): 5.69 ==========================================================================================
def noise(size):
    """Return a ``(size, 100)`` batch of standard-normal latent vectors on ``device``.

    The draw happens on the CPU and is then moved, so the values follow the
    CPU random generator even when ``device`` is a GPU.
    """
    return torch.randn(size, 100).to(device)
def discriminator_train_step(real_data, fake_data):
    """Apply one optimizer update to the module-level ``discriminator``.

    Real samples are scored against target 1 and fake samples against target 0
    using the global ``loss`` criterion. Both backward passes accumulate into the
    discriminator's gradients before a single ``d_optimizer.step()``.

    Args:
        real_data: batch of real inputs for ``discriminator``.
        fake_data: batch of generated inputs (the training loop in this
            notebook passes a detached batch).

    Returns:
        The summed real + fake loss tensor.
    """
    d_optimizer.zero_grad()

    # Real images should be classified as 1.
    real_scores = discriminator(real_data)
    real_error = loss(real_scores, torch.ones_like(real_scores))
    real_error.backward()

    # Generated images should be classified as 0.
    fake_scores = discriminator(fake_data)
    fake_error = loss(fake_scores, torch.zeros_like(fake_scores))
    fake_error.backward()

    d_optimizer.step()
    return real_error + fake_error
def generator_train_step(fake_data):
    """Apply one optimizer update to the generator through the discriminator.

    The discriminator's scores for *fake_data* are compared against "real"
    targets (1) with the global ``loss`` criterion. Minimizing this pushes the
    generator toward samples the discriminator accepts. ``g_optimizer`` only
    receives gradients if *fake_data* is still attached to the generator's
    autograd graph (i.e. not detached).

    Args:
        fake_data: generator output batch, not detached.

    Returns:
        The generator loss tensor.
    """
    g_optimizer.zero_grad()
    prediction = discriminator(fake_data)
    # Size the targets from this call's own batch rather than from a global
    # left over by the training loop.
    error = loss(prediction, torch.ones_like(prediction))
    # backward() also writes gradients into the discriminator's parameters;
    # only g_optimizer steps here, and discriminator_train_step zeroes them
    # before its own update.
    error.backward()
    g_optimizer.step()
    return error
# Fresh networks for training; these replace the instances built for the
# summaries. The discriminator is still created first, so weight
# initialization consumes the random generator in the same order.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

# One Adam optimizer per player, both with lr=2e-4.
d_optimizer, g_optimizer = (
    optim.Adam(net.parameters(), lr=0.0002) for net in (discriminator, generator)
)

# Binary cross-entropy on the discriminator's sigmoid outputs.
loss = nn.BCELoss()

num_epochs = 200
log = Report(num_epochs)
# Batches per epoch; the loader's length does not change between epochs.
N = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Flatten each image to a single row for the MLP discriminator.
        real_data = images.view(len(images), -1).to(device)

        # Discriminator update on a detached fake batch, so no gradient flows
        # back into the generator during this step.
        fake_data = generator(noise(len(real_data))).detach()
        d_loss = discriminator_train_step(real_data, fake_data)

        # Generator update on a fresh fake batch that keeps its autograd graph.
        fake_data = generator(noise(len(real_data)))
        g_loss = generator_train_step(fake_data)

        log.record(epoch + (i + 1) / N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
    log.report_avgs(epoch + 1)
log.plot_epochs(['d_loss', 'g_loss'])
EPOCH: 1.000 d_loss: 0.839 g_loss: 3.254 (14.88s - 2960.95s remaining) EPOCH: 2.000 d_loss: 0.786 g_loss: 4.338 (27.37s - 2709.66s remaining) EPOCH: 3.000 d_loss: 0.915 g_loss: 2.351 (40.28s - 2644.75s remaining) EPOCH: 4.000 d_loss: 0.810 g_loss: 2.462 (52.94s - 2594.29s remaining) EPOCH: 5.000 d_loss: 0.679 g_loss: 2.651 (65.44s - 2552.31s remaining) EPOCH: 6.000 d_loss: 0.333 g_loss: 3.991 (78.65s - 2543.15s remaining) EPOCH: 7.000 d_loss: 0.444 g_loss: 3.567 (92.02s - 2537.07s remaining) EPOCH: 8.000 d_loss: 0.454 g_loss: 3.225 (104.66s - 2511.72s remaining) EPOCH: 9.000 d_loss: 0.494 g_loss: 2.993 (118.02s - 2504.74s remaining) EPOCH: 10.000 d_loss: 0.530 g_loss: 2.856 (130.87s - 2486.59s remaining) EPOCH: 11.000 d_loss: 0.477 g_loss: 2.913 (143.87s - 2471.88s remaining) EPOCH: 12.000 d_loss: 0.496 g_loss: 2.963 (156.28s - 2448.38s remaining) EPOCH: 13.000 d_loss: 0.541 g_loss: 2.702 (169.04s - 2431.55s remaining) EPOCH: 14.000 d_loss: 0.655 g_loss: 2.362 (181.85s - 2415.95s remaining) EPOCH: 15.000 d_loss: 0.659 g_loss: 2.362 (194.41s - 2397.70s remaining) EPOCH: 16.000 d_loss: 0.667 g_loss: 2.316 (206.99s - 2380.38s remaining) EPOCH: 17.000 d_loss: 0.701 g_loss: 2.160 (219.70s - 2365.04s remaining) EPOCH: 18.000 d_loss: 0.680 g_loss: 2.165 (232.28s - 2348.66s remaining) EPOCH: 19.000 d_loss: 0.731 g_loss: 2.103 (245.18s - 2335.64s remaining) EPOCH: 20.000 d_loss: 0.745 g_loss: 2.059 (258.06s - 2322.58s remaining) EPOCH: 21.000 d_loss: 0.820 g_loss: 1.865 (271.10s - 2310.83s remaining) EPOCH: 22.000 d_loss: 0.853 g_loss: 1.761 (283.77s - 2295.94s remaining) EPOCH: 23.000 d_loss: 0.781 g_loss: 1.954 (296.40s - 2280.98s remaining) EPOCH: 24.000 d_loss: 0.805 g_loss: 1.933 (309.01s - 2266.06s remaining) EPOCH: 25.000 d_loss: 0.828 g_loss: 1.914 (321.43s - 2250.02s remaining) EPOCH: 26.000 d_loss: 0.804 g_loss: 1.906 (334.11s - 2235.97s remaining) EPOCH: 27.000 d_loss: 0.837 g_loss: 1.788 (347.03s - 2223.54s remaining) EPOCH: 28.000 d_loss: 0.811 g_loss: 1.838 
(359.27s - 2206.95s remaining) EPOCH: 29.000 d_loss: 0.858 g_loss: 1.763 (372.05s - 2193.79s remaining) EPOCH: 30.000 d_loss: 0.882 g_loss: 1.684 (384.63s - 2179.56s remaining) EPOCH: 31.000 d_loss: 0.928 g_loss: 1.574 (397.68s - 2167.99s remaining) EPOCH: 32.000 d_loss: 0.920 g_loss: 1.598 (410.50s - 2155.14s remaining) EPOCH: 33.000 d_loss: 0.936 g_loss: 1.562 (423.38s - 2142.54s remaining) EPOCH: 34.000 d_loss: 0.952 g_loss: 1.535 (436.11s - 2129.23s remaining) EPOCH: 35.000 d_loss: 0.969 g_loss: 1.474 (449.16s - 2117.47s remaining) EPOCH: 36.000 d_loss: 0.975 g_loss: 1.480 (462.00s - 2104.67s remaining) EPOCH: 37.000 d_loss: 1.000 g_loss: 1.418 (474.62s - 2090.87s remaining) EPOCH: 38.000 d_loss: 0.999 g_loss: 1.442 (486.77s - 2075.18s remaining) EPOCH: 39.000 d_loss: 1.018 g_loss: 1.404 (499.91s - 2063.75s remaining) EPOCH: 40.000 d_loss: 1.032 g_loss: 1.379 (512.46s - 2049.83s remaining) EPOCH: 41.000 d_loss: 1.029 g_loss: 1.346 (524.71s - 2034.84s remaining) EPOCH: 42.000 d_loss: 1.038 g_loss: 1.333 (537.49s - 2022.00s remaining) EPOCH: 43.000 d_loss: 1.041 g_loss: 1.344 (550.13s - 2008.62s remaining) EPOCH: 44.000 d_loss: 1.052 g_loss: 1.308 (563.33s - 1997.25s remaining) EPOCH: 45.000 d_loss: 1.071 g_loss: 1.281 (575.70s - 1982.98s remaining) EPOCH: 46.000 d_loss: 1.060 g_loss: 1.302 (588.65s - 1970.69s remaining) EPOCH: 47.000 d_loss: 1.064 g_loss: 1.302 (601.14s - 1956.90s remaining) EPOCH: 48.000 d_loss: 1.075 g_loss: 1.259 (613.72s - 1943.45s remaining) EPOCH: 49.000 d_loss: 1.106 g_loss: 1.215 (626.47s - 1930.55s remaining) EPOCH: 50.000 d_loss: 1.112 g_loss: 1.184 (638.79s - 1916.38s remaining) EPOCH: 51.000 d_loss: 1.115 g_loss: 1.175 (651.77s - 1904.20s remaining) EPOCH: 52.000 d_loss: 1.119 g_loss: 1.177 (664.13s - 1890.21s remaining) EPOCH: 53.000 d_loss: 1.121 g_loss: 1.179 (676.57s - 1876.53s remaining) EPOCH: 54.000 d_loss: 1.108 g_loss: 1.196 (689.72s - 1864.81s remaining) EPOCH: 55.000 d_loss: 1.112 g_loss: 1.205 (702.57s - 1852.22s 
remaining) EPOCH: 56.000 d_loss: 1.104 g_loss: 1.214 (715.56s - 1840.02s remaining) EPOCH: 57.000 d_loss: 1.134 g_loss: 1.147 (728.13s - 1826.72s remaining) EPOCH: 58.000 d_loss: 1.143 g_loss: 1.124 (740.25s - 1812.34s remaining) EPOCH: 59.000 d_loss: 1.138 g_loss: 1.136 (752.99s - 1799.53s remaining) EPOCH: 60.000 d_loss: 1.151 g_loss: 1.114 (766.41s - 1788.28s remaining) EPOCH: 61.000 d_loss: 1.157 g_loss: 1.103 (780.01s - 1777.41s remaining) EPOCH: 62.000 d_loss: 1.149 g_loss: 1.116 (793.34s - 1765.81s remaining) EPOCH: 63.000 d_loss: 1.150 g_loss: 1.115 (806.27s - 1753.32s remaining) EPOCH: 64.000 d_loss: 1.156 g_loss: 1.108 (818.97s - 1740.31s remaining) EPOCH: 65.000 d_loss: 1.148 g_loss: 1.135 (831.50s - 1726.95s remaining) EPOCH: 66.000 d_loss: 1.149 g_loss: 1.122 (843.93s - 1713.43s remaining) EPOCH: 67.000 d_loss: 1.156 g_loss: 1.123 (857.07s - 1701.34s remaining) EPOCH: 68.000 d_loss: 1.145 g_loss: 1.115 (869.75s - 1688.33s remaining) EPOCH: 69.000 d_loss: 1.159 g_loss: 1.095 (882.44s - 1675.36s remaining) EPOCH: 70.000 d_loss: 1.178 g_loss: 1.070 (894.98s - 1662.10s remaining) EPOCH: 71.000 d_loss: 1.170 g_loss: 1.074 (907.33s - 1648.53s remaining) EPOCH: 72.000 d_loss: 1.187 g_loss: 1.049 (919.78s - 1635.17s remaining) EPOCH: 73.000 d_loss: 1.189 g_loss: 1.045 (932.45s - 1622.20s remaining) EPOCH: 74.000 d_loss: 1.187 g_loss: 1.056 (945.63s - 1610.12s remaining) EPOCH: 75.000 d_loss: 1.181 g_loss: 1.069 (959.24s - 1598.73s remaining) EPOCH: 76.000 d_loss: 1.179 g_loss: 1.049 (972.26s - 1586.31s remaining) EPOCH: 77.000 d_loss: 1.194 g_loss: 1.048 (984.91s - 1573.30s remaining) EPOCH: 78.000 d_loss: 1.200 g_loss: 1.033 (997.68s - 1560.48s remaining) EPOCH: 79.000 d_loss: 1.194 g_loss: 1.036 (1010.18s - 1547.24s remaining) EPOCH: 80.000 d_loss: 1.196 g_loss: 1.034 (1023.43s - 1535.15s remaining) EPOCH: 81.000 d_loss: 1.199 g_loss: 1.028 (1036.71s - 1523.07s remaining) EPOCH: 82.000 d_loss: 1.195 g_loss: 1.033 (1049.94s - 1510.89s remaining) EPOCH: 83.000 
d_loss: 1.202 g_loss: 1.025 (1063.00s - 1498.44s remaining) EPOCH: 84.000 d_loss: 1.200 g_loss: 1.020 (1076.17s - 1486.14s remaining) EPOCH: 85.000 d_loss: 1.201 g_loss: 1.016 (1088.91s - 1473.23s remaining) EPOCH: 86.000 d_loss: 1.209 g_loss: 1.014 (1101.26s - 1459.81s remaining) EPOCH: 87.000 d_loss: 1.209 g_loss: 1.009 (1114.89s - 1448.08s remaining) EPOCH: 88.000 d_loss: 1.211 g_loss: 1.004 (1127.54s - 1435.06s remaining) EPOCH: 89.000 d_loss: 1.217 g_loss: 1.001 (1139.91s - 1421.68s remaining) EPOCH: 90.000 d_loss: 1.194 g_loss: 1.035 (1152.62s - 1408.76s remaining) EPOCH: 91.000 d_loss: 1.210 g_loss: 1.010 (1165.60s - 1396.15s remaining) EPOCH: 92.000 d_loss: 1.217 g_loss: 0.989 (1178.52s - 1383.47s remaining) EPOCH: 93.000 d_loss: 1.221 g_loss: 0.994 (1191.82s - 1371.23s remaining) EPOCH: 94.000 d_loss: 1.217 g_loss: 0.999 (1204.47s - 1358.23s remaining) EPOCH: 95.000 d_loss: 1.220 g_loss: 0.988 (1217.05s - 1345.16s remaining) EPOCH: 96.000 d_loss: 1.222 g_loss: 0.980 (1229.63s - 1332.10s remaining) EPOCH: 97.000 d_loss: 1.212 g_loss: 1.004 (1243.06s - 1319.95s remaining) EPOCH: 98.000 d_loss: 1.215 g_loss: 0.997 (1256.17s - 1307.44s remaining) EPOCH: 99.000 d_loss: 1.230 g_loss: 0.971 (1268.61s - 1294.23s remaining) EPOCH: 100.000 d_loss: 1.218 g_loss: 0.994 (1281.10s - 1281.10s remaining) EPOCH: 101.000 d_loss: 1.222 g_loss: 0.986 (1293.43s - 1267.82s remaining) EPOCH: 102.000 d_loss: 1.229 g_loss: 0.973 (1306.13s - 1254.91s remaining) EPOCH: 103.000 d_loss: 1.223 g_loss: 0.980 (1319.25s - 1242.40s remaining) EPOCH: 104.000 d_loss: 1.222 g_loss: 0.999 (1332.81s - 1230.29s remaining) EPOCH: 105.000 d_loss: 1.222 g_loss: 0.979 (1346.31s - 1218.09s remaining) EPOCH: 106.000 d_loss: 1.234 g_loss: 0.961 (1359.16s - 1205.29s remaining) EPOCH: 107.000 d_loss: 1.231 g_loss: 0.979 (1372.86s - 1193.24s remaining) EPOCH: 108.000 d_loss: 1.236 g_loss: 0.954 (1386.20s - 1180.84s remaining) EPOCH: 109.000 d_loss: 1.246 g_loss: 0.949 (1399.69s - 1168.55s remaining) 
EPOCH: 110.000 d_loss: 1.246 g_loss: 0.934 (1413.41s - 1156.42s remaining) EPOCH: 111.000 d_loss: 1.241 g_loss: 0.966 (1427.00s - 1144.17s remaining) EPOCH: 112.000 d_loss: 1.237 g_loss: 0.965 (1440.03s - 1131.45s remaining) EPOCH: 113.000 d_loss: 1.240 g_loss: 0.959 (1453.07s - 1118.74s remaining) EPOCH: 114.000 d_loss: 1.245 g_loss: 0.944 (1466.09s - 1106.00s remaining) EPOCH: 115.000 d_loss: 1.241 g_loss: 0.965 (1479.53s - 1093.56s remaining) EPOCH: 116.000 d_loss: 1.242 g_loss: 0.954 (1493.15s - 1081.25s remaining) EPOCH: 117.000 d_loss: 1.247 g_loss: 0.946 (1506.42s - 1068.66s remaining) EPOCH: 118.000 d_loss: 1.247 g_loss: 0.926 (1519.91s - 1056.21s remaining) EPOCH: 119.000 d_loss: 1.251 g_loss: 0.935 (1533.72s - 1043.96s remaining) EPOCH: 120.000 d_loss: 1.248 g_loss: 0.941 (1547.61s - 1031.74s remaining) EPOCH: 121.000 d_loss: 1.245 g_loss: 0.955 (1561.19s - 1019.29s remaining) EPOCH: 122.000 d_loss: 1.249 g_loss: 0.950 (1574.56s - 1006.69s remaining) EPOCH: 123.000 d_loss: 1.247 g_loss: 0.944 (1587.89s - 994.05s remaining) EPOCH: 124.000 d_loss: 1.238 g_loss: 0.956 (1601.54s - 981.59s remaining) EPOCH: 125.000 d_loss: 1.251 g_loss: 0.941 (1615.39s - 969.23s remaining) EPOCH: 126.000 d_loss: 1.249 g_loss: 0.948 (1628.82s - 956.61s remaining) EPOCH: 127.000 d_loss: 1.254 g_loss: 0.933 (1642.68s - 944.22s remaining) EPOCH: 128.000 d_loss: 1.251 g_loss: 0.927 (1656.09s - 931.55s remaining) EPOCH: 129.000 d_loss: 1.260 g_loss: 0.922 (1669.69s - 918.98s remaining) EPOCH: 130.000 d_loss: 1.262 g_loss: 0.916 (1683.25s - 906.36s remaining) EPOCH: 131.000 d_loss: 1.249 g_loss: 0.941 (1696.50s - 893.58s remaining) EPOCH: 132.000 d_loss: 1.254 g_loss: 0.935 (1710.02s - 880.92s remaining) EPOCH: 133.000 d_loss: 1.251 g_loss: 0.941 (1723.55s - 868.25s remaining) EPOCH: 134.000 d_loss: 1.245 g_loss: 0.949 (1737.01s - 855.54s remaining) EPOCH: 135.000 d_loss: 1.254 g_loss: 0.927 (1750.90s - 843.02s remaining) EPOCH: 136.000 d_loss: 1.257 g_loss: 0.922 (1764.25s - 830.23s 
remaining) EPOCH: 137.000 d_loss: 1.260 g_loss: 0.919 (1777.60s - 817.43s remaining) EPOCH: 138.000 d_loss: 1.264 g_loss: 0.917 (1791.03s - 804.67s remaining) EPOCH: 139.000 d_loss: 1.267 g_loss: 0.909 (1804.99s - 792.12s remaining) EPOCH: 140.000 d_loss: 1.268 g_loss: 0.912 (1819.06s - 779.60s remaining) EPOCH: 141.000 d_loss: 1.260 g_loss: 0.928 (1833.03s - 767.01s remaining) EPOCH: 142.000 d_loss: 1.262 g_loss: 0.920 (1847.04s - 754.42s remaining) EPOCH: 143.000 d_loss: 1.263 g_loss: 0.925 (1861.11s - 741.84s remaining) EPOCH: 144.000 d_loss: 1.262 g_loss: 0.915 (1874.95s - 729.15s remaining) EPOCH: 145.000 d_loss: 1.268 g_loss: 0.903 (1888.85s - 716.46s remaining) EPOCH: 146.000 d_loss: 1.270 g_loss: 0.909 (1902.87s - 703.80s remaining) EPOCH: 147.000 d_loss: 1.268 g_loss: 0.907 (1916.79s - 691.09s remaining) EPOCH: 148.000 d_loss: 1.268 g_loss: 0.913 (1930.89s - 678.42s remaining) EPOCH: 149.000 d_loss: 1.266 g_loss: 0.903 (1944.73s - 665.65s remaining) EPOCH: 150.000 d_loss: 1.268 g_loss: 0.906 (1959.24s - 653.08s remaining) EPOCH: 151.000 d_loss: 1.274 g_loss: 0.905 (1973.09s - 640.27s remaining) EPOCH: 152.000 d_loss: 1.273 g_loss: 0.919 (1987.13s - 627.52s remaining) EPOCH: 153.000 d_loss: 1.256 g_loss: 0.937 (2000.86s - 614.64s remaining) EPOCH: 154.000 d_loss: 1.270 g_loss: 0.897 (2015.12s - 601.92s remaining) EPOCH: 155.000 d_loss: 1.272 g_loss: 0.894 (2029.28s - 589.15s remaining) EPOCH: 156.000 d_loss: 1.265 g_loss: 0.911 (2043.64s - 576.41s remaining) EPOCH: 157.000 d_loss: 1.271 g_loss: 0.895 (2057.61s - 563.55s remaining) EPOCH: 158.000 d_loss: 1.276 g_loss: 0.895 (2072.00s - 550.79s remaining) EPOCH: 159.000 d_loss: 1.272 g_loss: 0.906 (2086.20s - 537.95s remaining) EPOCH: 160.000 d_loss: 1.272 g_loss: 0.907 (2100.51s - 525.13s remaining) EPOCH: 161.000 d_loss: 1.273 g_loss: 0.904 (2114.72s - 512.26s remaining) EPOCH: 162.000 d_loss: 1.266 g_loss: 0.913 (2129.18s - 499.44s remaining) EPOCH: 163.000 d_loss: 1.275 g_loss: 0.887 (2143.25s - 486.51s 
remaining) EPOCH: 164.000 d_loss: 1.271 g_loss: 0.901 (2157.50s - 473.60s remaining) EPOCH: 165.000 d_loss: 1.272 g_loss: 0.918 (2171.67s - 460.66s remaining) EPOCH: 166.000 d_loss: 1.269 g_loss: 0.907 (2186.02s - 447.74s remaining) EPOCH: 167.000 d_loss: 1.281 g_loss: 0.888 (2200.47s - 434.82s remaining) EPOCH: 168.000 d_loss: 1.273 g_loss: 0.910 (2214.70s - 421.85s remaining) EPOCH: 169.000 d_loss: 1.271 g_loss: 0.901 (2228.81s - 408.84s remaining) EPOCH: 170.000 d_loss: 1.275 g_loss: 0.895 (2242.68s - 395.77s remaining) EPOCH: 171.000 d_loss: 1.273 g_loss: 0.898 (2257.16s - 382.79s remaining) EPOCH: 172.000 d_loss: 1.270 g_loss: 0.906 (2271.58s - 369.79s remaining) EPOCH: 173.000 d_loss: 1.268 g_loss: 0.904 (2285.98s - 356.77s remaining) EPOCH: 174.000 d_loss: 1.273 g_loss: 0.893 (2300.40s - 343.74s remaining) EPOCH: 175.000 d_loss: 1.273 g_loss: 0.894 (2314.91s - 330.70s remaining) EPOCH: 176.000 d_loss: 1.269 g_loss: 0.907 (2329.32s - 317.63s remaining) EPOCH: 177.000 d_loss: 1.280 g_loss: 0.889 (2344.18s - 304.61s remaining) EPOCH: 178.000 d_loss: 1.272 g_loss: 0.896 (2358.27s - 291.47s remaining) EPOCH: 179.000 d_loss: 1.275 g_loss: 0.894 (2372.23s - 278.31s remaining) EPOCH: 180.000 d_loss: 1.272 g_loss: 0.909 (2386.40s - 265.16s remaining) EPOCH: 181.000 d_loss: 1.271 g_loss: 0.897 (2400.74s - 252.01s remaining) EPOCH: 182.000 d_loss: 1.278 g_loss: 0.891 (2414.91s - 238.84s remaining) EPOCH: 183.000 d_loss: 1.274 g_loss: 0.899 (2428.85s - 225.63s remaining) EPOCH: 184.000 d_loss: 1.270 g_loss: 0.905 (2442.80s - 212.42s remaining) EPOCH: 185.000 d_loss: 1.281 g_loss: 0.884 (2456.75s - 199.20s remaining) EPOCH: 186.000 d_loss: 1.282 g_loss: 0.874 (2471.23s - 186.01s remaining) EPOCH: 187.000 d_loss: 1.279 g_loss: 0.892 (2485.90s - 172.82s remaining) EPOCH: 188.000 d_loss: 1.282 g_loss: 0.886 (2500.06s - 159.58s remaining) EPOCH: 189.000 d_loss: 1.278 g_loss: 0.889 (2514.38s - 146.34s remaining) EPOCH: 190.000 d_loss: 1.276 g_loss: 0.895 (2528.89s - 133.10s 
remaining) EPOCH: 191.000 d_loss: 1.283 g_loss: 0.887 (2543.72s - 119.86s remaining) EPOCH: 192.000 d_loss: 1.277 g_loss: 0.903 (2558.13s - 106.59s remaining) EPOCH: 193.000 d_loss: 1.279 g_loss: 0.881 (2572.89s - 93.32s remaining) EPOCH: 194.000 d_loss: 1.282 g_loss: 0.882 (2587.38s - 80.02s remaining) EPOCH: 195.000 d_loss: 1.282 g_loss: 0.887 (2601.27s - 66.70s remaining) EPOCH: 196.000 d_loss: 1.275 g_loss: 0.894 (2615.26s - 53.37s remaining) EPOCH: 197.000 d_loss: 1.279 g_loss: 0.885 (2629.34s - 40.04s remaining) EPOCH: 198.000 d_loss: 1.281 g_loss: 0.887 (2643.65s - 26.70s remaining) EPOCH: 199.000 d_loss: 1.277 g_loss: 0.909 (2658.09s - 13.36s remaining) EPOCH: 199.996 d_loss: 1.246 g_loss: 0.942 (2672.44s - 0.06s remaining)
0%| | 0/200 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice. out=out, **kwargs) /usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars ret = ret.dtype.type(ret / rcount) 13%|█▎ | 26/200 [00:00<00:00, 253.49it/s]
EPOCH: 199.998 d_loss: 1.263 g_loss: 0.876 (2672.47s - 0.03s remaining) EPOCH: 200.000 d_loss: 1.256 g_loss: 0.935 (2672.49s - 0.00s remaining) EPOCH: 200.000 d_loss: 1.285 g_loss: 0.884 (2672.52s - 0.00s remaining)
100%|██████████| 200/200 [00:03<00:00, 62.85it/s]
# Draw 64 samples from the trained generator and show them as an 8x8 grid.
# This is inference only: no_grad() avoids building an autograd graph, which
# replaces the legacy `.data` accessor.
with torch.no_grad():
    z = noise(64)  # same draw as torch.randn(64, 100).to(device)
    # Reshape the flat 784-value outputs into (N, C, H, W) images for make_grid.
    sample_images = generator(z).cpu().view(64, 1, 28, 28)
# normalize=True rescales the grid to [0, 1] using its own min/max for display.
grid = make_grid(sample_images, nrow=8, normalize=True)
# The grid is already on the CPU and carries no grad; permute CHW -> HWC for display.
show(grid.permute(1, 2, 0), sz=5)