413 KiB
413 KiB
!pip install -q torch_snippets
from torch_snippets import *
from torchvision.datasets import MNIST
from torchvision import transforms
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Per-image preprocessing applied by the datasets below:
#   1. convert the image to a tensor,
#   2. normalise with mean 0.5 / std 0.5 (i.e. (x - 0.5) / 0.5),
#   3. move the tensor onto `device`.
# NOTE(review): because step 3 lives inside the transform, batches come out of
# the loaders already on `device`; this presumably rules out DataLoader worker
# processes when `device` is CUDA -- confirm before adding `num_workers`.
img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
        transforms.Lambda(lambda x: x.to(device)),
    ]
)

# MNIST train / test splits, downloaded into /content/ if not already present.
trn_ds = MNIST(root='/content/', train=True, transform=img_transform, download=True)
val_ds = MNIST(root='/content/', train=False, transform=img_transform, download=True)

# Mini-batch loaders: the training split is reshuffled every epoch, the
# validation split keeps its stored order.
batch_size = 128
trn_dl = DataLoader(dataset=trn_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(dataset=val_ds, batch_size=batch_size, shuffle=False)
[K |████████████████████████████████| 36.7MB 80kB/s [K |████████████████████████████████| 61kB 9.3MB/s [K |████████████████████████████████| 102kB 13.8MB/s [?25h Building wheel for contextvars (setup.py) ... [?25l[?25hdone Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/MNIST/raw/train-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Extracting /content/MNIST/raw/train-images-idx3-ubyte.gz to /content/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/MNIST/raw/train-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Extracting /content/MNIST/raw/train-labels-idx1-ubyte.gz to /content/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/MNIST/raw/t10k-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Extracting /content/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw/t10k-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Extracting /content/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/MNIST/raw Processing... Done!
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:480: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
class ConvAutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The encoder stacks two strided convolutions, each followed by ReLU and
    max-pooling; the decoder stacks three transposed convolutions and ends
    in ``Tanh``.  For a ``(N, 1, 28, 28)`` input the encoder produces a
    ``(N, 64, 2, 2)`` code and the decoder returns a ``(N, 1, 28, 28)``
    reconstruction with values in ``[-1, 1]``.

    Attributes:
        encoder: ``nn.Sequential`` mapping images to latent feature maps.
        decoder: ``nn.Sequential`` mapping latent feature maps to images.
    """

    def __init__(self):
        super().__init__()

        # Spatial size for a 28x28 input: 28 -> 10 -> 5 -> 3 -> 2.
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),
        ]

        # Spatial size for a 2x2 code: 2 -> 5 -> 15 -> 28.
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=5, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2, padding=1),
            nn.Tanh(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        return self.decoder(self.encoder(x))
# Instantiate the autoencoder and move its parameters onto the selected device.
model = ConvAutoEncoder().to(device)
# Notebook shell magic: install the package that provides the `torchsummary` module.
!pip install torch_summary
from torchsummary import summary
# Print a layer-by-layer summary of the model for a dummy batch of two
# 1x28x28 inputs. The trailing semicolon presumably suppresses the notebook's
# echo of the returned value -- confirm if the return value is needed.
summary(model, torch.zeros(2,1,28,28));
Collecting torch_summary Downloading https://files.pythonhosted.org/packages/83/49/f9db57bcad7246591b93519fd8e5166c719548c45945ef7d2fc9fcba46fb/torch_summary-1.4.3-py3-none-any.whl Installing collected packages: torch-summary Successfully installed torch-summary-1.4.3 ========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Sequential: 1-1 [-1, 64, 2, 2] -- | └─Conv2d: 2-1 [-1, 32, 10, 10] 320 | └─ReLU: 2-2 [-1, 32, 10, 10] -- | └─MaxPool2d: 2-3 [-1, 32, 5, 5] -- | └─Conv2d: 2-4 [-1, 64, 3, 3] 18,496 | └─ReLU: 2-5 [-1, 64, 3, 3] -- | └─MaxPool2d: 2-6 [-1, 64, 2, 2] -- ├─Sequential: 1-2 [-1, 1, 28, 28] -- | └─ConvTranspose2d: 2-7 [-1, 32, 5, 5] 18,464 | └─ReLU: 2-8 [-1, 32, 5, 5] -- | └─ConvTranspose2d: 2-9 [-1, 16, 15, 15] 12,816 | └─ReLU: 2-10 [-1, 16, 15, 15] -- | └─ConvTranspose2d: 2-11 [-1, 1, 28, 28] 65 | └─Tanh: 2-12 [-1, 1, 28, 28] -- ========================================================================================== Total params: 50,161 Trainable params: 50,161 Non-trainable params: 0 Total mult-adds (M): 3.64 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 0.07 Params size (MB): 0.19 Estimated Total Size (MB): 0.27 ==========================================================================================
def train_batch(input, model, criterion, optimizer):
    """Run one optimisation step of an autoencoder on a single batch.

    The model is switched to training mode, the reconstruction of ``input``
    is scored against ``input`` itself with ``criterion``, gradients are
    back-propagated and ``optimizer`` takes one step.

    Args:
        input: Batch tensor fed to the model and also used as the target.
        model: Module whose output has the same shape as ``input``.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        optimizer: Optimiser holding ``model``'s parameters.

    Returns:
        The batch loss as a 0-d tensor detached from the autograd graph, so
        callers can log or accumulate it without keeping graph history.
    """
    model.train()
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, input)
    loss.backward()
    optimizer.step()
    # The graph is no longer needed once the step is taken; returning the
    # detached value keeps logged losses from holding on to `grad_fn`.
    return loss.detach()
def validate_batch(input, model, criterion):
    """Score the model's reconstruction of one batch without tracking gradients.

    Args:
        input: Batch tensor fed to the model and also used as the target.
        model: Module to evaluate; it is put into evaluation mode.
        criterion: Callable ``criterion(output, target)`` returning the loss.

    Returns:
        The loss tensor for this batch (computed with autograd disabled).
    """
    model.eval()
    with torch.no_grad():
        return criterion(model(input), input)
# Fresh model plus the reconstruction loss and optimiser used to train it.
model = ConvAutoEncoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-5,
)

num_epochs = 5
log = Report(num_epochs)

for epoch in range(num_epochs):
    # Training pass. `pos` is the fractional epoch reached after this batch.
    # Labels are ignored: the autoencoder's target is the input itself.
    N = len(trn_dl)
    for ix, (data, _) in enumerate(trn_dl):
        loss = train_batch(data, model, criterion, optimizer)
        pos = epoch + (ix + 1) / N
        log.record(pos=pos, trn_loss=loss, end='\r')

    # Validation pass over the held-out split, logged on the same epoch axis.
    N = len(val_dl)
    for ix, (data, _) in enumerate(val_dl):
        loss = validate_batch(data, model, criterion)
        pos = epoch + (ix + 1) / N
        log.record(pos=pos, val_loss=loss, end='\r')

    # Report the averages of the metrics recorded during this epoch.
    log.report_avgs(epoch + 1)
EPOCH: 1.000 trn_loss: 0.179 val_loss: 0.101 (19.26s - 77.05s remaining) EPOCH: 2.000 trn_loss: 0.089 val_loss: 0.076 (38.33s - 57.49s remaining) EPOCH: 3.000 trn_loss: 0.073 val_loss: 0.066 (57.80s - 38.53s remaining) EPOCH: 4.000 trn_loss: 0.065 val_loss: 0.061 (77.49s - 19.37s remaining) EPOCH: 5.000 trn_loss: 0.060 val_loss: 0.057 (96.65s - 0.00s remaining)
# Plot the metrics recorded in `log` across epochs.
# NOTE(review): `log=True` presumably selects a logarithmic axis -- confirm
# against the torch_snippets `Report.plot_epochs` documentation.
log.plot_epochs(log=True)
0%| | 0/6 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice. out=out, **kwargs) /usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars ret = ret.dtype.type(ret / rcount) 100%|██████████| 6/6 [00:00<00:00, 1408.35it/s]
# Show three randomly chosen validation images next to their reconstructions.
for _ in range(3):
    ix = np.random.randint(len(val_ds))
    im, _ = val_ds[ix]
    # Add a leading batch axis for the model, then strip it from the output.
    _im = model(im.unsqueeze(0))[0]
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(3, 3))
    # Index 0 selects the single channel of each image for display.
    for axis, picture, label in zip(ax, (im, _im), ('input', 'prediction')):
        show(picture[0], ax=axis, title=label)
    plt.tight_layout()
    plt.show()
from sklearn.manifold import TSNE

# Encode the whole validation set and collect one flat latent vector and one
# integer class label per image.
latent_vectors = []
classes = []
# Inference only: without no_grad every batch would keep its autograd graph
# alive for as long as its encoder output sits in `latent_vectors`.
with torch.no_grad():
    for im, clss in val_dl:
        latent_vectors.append(model.encoder(im).flatten(1))
        # Store plain Python ints rather than 0-d tensors.
        classes.extend(clss.tolist())
latent_vectors = torch.cat(latent_vectors).cpu().numpy()

# Project the latent vectors to 2-D with t-SNE.
tsne = TSNE(2)
clustered = tsne.fit_transform(latent_vectors)

# Scatter the 2-D embedding, coloured by digit class (10 discrete colours).
fig = plt.figure(figsize=(12,10))
cmap = plt.get_cmap('Spectral', 10)
plt.scatter(clustered[:, 0], clustered[:, 1], c=classes, cmap=cmap)
plt.colorbar(drawedges=True)
<matplotlib.colorbar.Colorbar at 0x7fc0d81c8390>
# Encode the whole validation set, keeping the encoder's native output shape.
latent_vectors = []
classes = []
# Inference only: no autograd graph is needed for the encoder outputs.
with torch.no_grad():
    for im, clss in val_dl:
        latent_vectors.append(model.encoder(im))
        classes.extend(clss.tolist())
latent_vectors = torch.cat(latent_vectors).cpu().numpy()

# Remember the per-sample latent shape for the decoder, then flatten to
# (num_samples, num_latent_dims). Sizes come from the data itself rather than
# a hard-coded dataset length or channel/spatial layout.
latent_shape = latent_vectors.shape[1:]
latent_vectors = latent_vectors.reshape(len(latent_vectors), -1)

# Draw one random latent vector per grid cell from an independent Gaussian per
# latent dimension, using that dimension's mean and standard deviation over the
# validation set (NumPy's default population std, ddof=0).
n_rows, n_cols = 10, 10
mu = torch.from_numpy(latent_vectors.mean(axis=0))
sigma = torch.from_numpy(latent_vectors.std(axis=0))
rand_vectors = (sigma * torch.randn(n_rows * n_cols, mu.numel()) + mu).to(device)

# Decode all samples in one batch and draw them on an n_rows x n_cols grid.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(7,7))
with torch.no_grad():
    imgs = model.decoder(rand_vectors.reshape(-1, *latent_shape))
for img, axis in zip(imgs, ax.flat):
    # Index 0 drops the single channel axis so a 2-D image is shown.
    show(img[0], ax=axis)