Computer_Vision/Chapter12/Handwritten_digit_generatio...

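This notebook trains a simple fully connected GAN on MNIST: a discriminator learns to tell real handwritten digits from generated ones, while a generator learns to map 100-dimensional random noise vectors to 28x28 digit images. The two networks are trained in alternation, and the final cell samples a grid of generated digits.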

!pip install -q torch_snippets
from torch_snippets import *
device = "cuda" if torch.cuda.is_available() else "cpu"
from torchvision.utils import make_grid
from torchvision.datasets import MNIST
from torchvision import transforms

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])
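As a quick sanity check (an illustrative snippet, not part of the original notebook): normalizing with mean 0.5 and std 0.5 maps ToTensor's [0, 1] pixel range onto [-1, 1], which matches the Tanh output range of the generator defined below.

x = torch.rand(16, 1, 28, 28)                         # dummy pixels in [0, 1]
y = transforms.Normalize(mean=(0.5,), std=(0.5,))(x)  # (x - 0.5) / 0.5
assert y.min() >= -1.0 and y.max() <= 1.0             # lands in [-1, 1]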

data_loader = torch.utils.data.DataLoader(
    MNIST('~/data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True, drop_last=True
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /root/data/MNIST/raw/train-images-idx3-ubyte.gz to /root/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /root/data/MNIST/raw/train-labels-idx1-ubyte.gz to /root/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /root/data/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /root/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/data/MNIST/raw
Processing...
Done!
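The discriminator is a fully connected binary classifier: it takes a flattened 28x28 image (784 values), passes it through three progressively narrower LeakyReLU/Dropout blocks, and emits a single real-vs-fake probability through a final Sigmoid.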
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential( 
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x): return self.model(x)
!pip install torch_summary
from torchsummary import summary
discriminator = Discriminator().to(device)
summary(discriminator, torch.zeros(1, 784))
Requirement already satisfied: torch_summary in /usr/local/lib/python3.6/dist-packages (1.4.3)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 1]                   --
|    └─Linear: 2-1                       [-1, 1024]                803,840
|    └─LeakyReLU: 2-2                    [-1, 1024]                --
|    └─Dropout: 2-3                      [-1, 1024]                --
|    └─Linear: 2-4                       [-1, 512]                 524,800
|    └─LeakyReLU: 2-5                    [-1, 512]                 --
|    └─Dropout: 2-6                      [-1, 512]                 --
|    └─Linear: 2-7                       [-1, 256]                 131,328
|    └─LeakyReLU: 2-8                    [-1, 256]                 --
|    └─Dropout: 2-9                      [-1, 256]                 --
|    └─Linear: 2-10                      [-1, 1]                   257
|    └─Sigmoid: 2-11                     [-1, 1]                   --
==========================================================================================
Total params: 1,460,225
Trainable params: 1,460,225
Non-trainable params: 0
Total mult-adds (M): 2.92
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 5.57
Estimated Total Size (MB): 5.59
==========================================================================================
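The generator mirrors this architecture in reverse: it expands a 100-dimensional noise vector through progressively wider LeakyReLU blocks up to 784 outputs, and the final Tanh keeps the generated pixels in the same [-1, 1] range as the normalized training images.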
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x): return self.model(x)
generator = Generator().to(device)
summary(generator, torch.zeros(1, 100))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 784]                 --
|    └─Linear: 2-1                       [-1, 256]                 25,856
|    └─LeakyReLU: 2-2                    [-1, 256]                 --
|    └─Linear: 2-3                       [-1, 512]                 131,584
|    └─LeakyReLU: 2-4                    [-1, 512]                 --
|    └─Linear: 2-5                       [-1, 1024]                525,312
|    └─LeakyReLU: 2-6                    [-1, 1024]                --
|    └─Linear: 2-7                       [-1, 784]                 803,600
|    └─Tanh: 2-8                         [-1, 784]                 --
==========================================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
Total mult-adds (M): 2.97
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 5.67
Estimated Total Size (MB): 5.69
==========================================================================================
def noise(size):
    # Sample `size` latent vectors from a standard normal distribution
    n = torch.randn(size, 100)
    return n.to(device)
def discriminator_train_step(real_data, fake_data):
    d_optimizer.zero_grad()
    # Real images should be scored as 1
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, torch.ones(len(real_data), 1).to(device))
    error_real.backward()
    # Generated images should be scored as 0
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, torch.zeros(len(fake_data), 1).to(device))
    error_fake.backward()
    d_optimizer.step()
    return error_real + error_fake
def generator_train_step(fake_data):
    g_optimizer.zero_grad()
    # The generator wants its samples scored as real (1)
    prediction = discriminator(fake_data)
    error = loss(prediction, torch.ones(len(fake_data), 1).to(device))
    error.backward()
    g_optimizer.step()
    return error
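With BCE as the criterion, these two steps implement the standard non-saturating GAN losses:

L_D = -[ log D(x) + log(1 - D(G(z))) ]
L_G = -log D(G(z))

where x is a real image and z a noise vector: the discriminator pushes D(x) toward 1 and D(G(z)) toward 0, while the generator pushes D(G(z)) back toward 1.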
discriminator = Discriminator().to(device)
generator = Generator().to(device)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
loss = nn.BCELoss()
num_epochs = 200
log = Report(num_epochs)
for epoch in range(num_epochs):
    N = len(data_loader)
    for i, (images, _) in enumerate(data_loader):
        # Flatten each 28x28 image into a 784-dimensional vector
        real_data = images.view(len(images), -1).to(device)
        # Detach so the discriminator step does not backpropagate into the generator
        fake_data = generator(noise(len(real_data))).detach()
        d_loss = discriminator_train_step(real_data, fake_data)
        # Fresh fake batch, graph kept this time so the generator can learn
        fake_data = generator(noise(len(real_data)))
        g_loss = generator_train_step(fake_data)
        log.record(epoch+(1+i)/N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
    log.report_avgs(epoch+1)
log.plot_epochs(['d_loss', 'g_loss'])
EPOCH: 1.000	d_loss: 0.839	g_loss: 3.254	(14.88s - 2960.95s remaining)
EPOCH: 2.000	d_loss: 0.786	g_loss: 4.338	(27.37s - 2709.66s remaining)
EPOCH: 3.000	d_loss: 0.915	g_loss: 2.351	(40.28s - 2644.75s remaining)
EPOCH: 4.000	d_loss: 0.810	g_loss: 2.462	(52.94s - 2594.29s remaining)
EPOCH: 5.000	d_loss: 0.679	g_loss: 2.651	(65.44s - 2552.31s remaining)
[...]
EPOCH: 100.000	d_loss: 1.218	g_loss: 0.994	(1281.10s - 1281.10s remaining)
[...]
EPOCH: 198.000	d_loss: 1.281	g_loss: 0.887	(2643.65s - 26.70s remaining)
EPOCH: 199.000	d_loss: 1.277	g_loss: 0.909	(2658.09s - 13.36s remaining)
EPOCH: 200.000	d_loss: 1.285	g_loss: 0.884	(2672.52s - 0.00s remaining)
[per-epoch plot of d_loss and g_loss]
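Both curves flatten out over training: d_loss climbs toward the balance point of 2 ln 2 ≈ 1.39 (where the discriminator outputs 0.5 for everything) while g_loss falls toward ln 2 ≈ 0.69, suggesting the two networks settle into a rough equilibrium. Finally, sample 64 latent vectors and render the generated digits as an 8x8 grid: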
z = torch.randn(64, 100).to(device)                      # 64 random latent vectors
sample_images = generator(z).detach().cpu().view(64, 1, 28, 28)
grid = make_grid(sample_images, nrow=8, normalize=True)  # 8x8 grid, rescaled for display
show(grid.permute(1, 2, 0), sz=5)