256 KiB
256 KiB
import os
# Fetch and unpack dataset1 only when the dataset1/ directory is not already present.
# NOTE: the `!` lines are IPython/Jupyter shell escapes; this cell does not run as plain Python.
if not os.path.exists('dataset1'):
    !wget -q https://www.dropbox.com/s/0pigmmmynbf9xwq/dataset1.zip
    !unzip -q dataset1.zip
    # The archive is no longer needed once extracted.
    !rm dataset1.zip
    # NOTE(review): indentation was lost in this export; if this install sits inside the
    # guard it is skipped whenever dataset1/ already exists — confirm that is intended.
    !pip install -q torch_snippets pytorch_model_summary
from torch_snippets import *
from torchvision import transforms
from sklearn.model_selection import train_test_split
# Prefer the GPU ('cuda') when PyTorch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
[K |████████████████████████████████| 61kB 4.7MB/s [K |████████████████████████████████| 36.7MB 82kB/s [K |████████████████████████████████| 102kB 13.2MB/s [?25h Building wheel for contextvars (setup.py) ... [?25l[?25hdone
# Per-image input transform used by SegData.collate_fn: ToTensor turns the HWC
# array into a CHW tensor (collate_fn divides by 255 itself beforehand, because
# ToTensor only rescales uint8 input), then Normalize standardizes each channel
# with the ImageNet mean/std.
tfms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
class SegData(Dataset):
    """Image/mask pairs from ``dataset1`` for one split, resized to 224x224.

    Images come from ``dataset1/images_prepped_{split}`` and masks from
    ``dataset1/annotations_prepped_{split}``; the two are matched by file stem
    (both stored as ``.png``).

    Args:
        split: directory suffix selecting the split, e.g. ``'train'`` or ``'test'``.
    """

    def __init__(self, split):
        # File stems (names without extension) of every image in this split.
        self.items = stems(f'dataset1/images_prepped_{split}')
        self.split = split

    def __len__(self):
        return len(self.items)

    def __getitem__(self, ix):
        """Return ``(image, mask)`` arrays for sample ``ix``, both resized to 224x224."""
        stem = self.items[ix]
        image = read(f'dataset1/images_prepped_{self.split}/{stem}.png', 1)
        image = cv2.resize(image, (224, 224))
        mask = read(f'dataset1/annotations_prepped_{self.split}/{stem}.png')
        # The mask stores integer class ids (collate_fn turns it into CE targets).
        # Nearest-neighbour resizing keeps every pixel a real label; the default
        # bilinear interpolation would blend neighbouring ids into spurious
        # classes along region boundaries.
        mask = cv2.resize(mask, (224, 224), interpolation=cv2.INTER_NEAREST)
        return image, mask

    def choose(self):
        """Return a randomly chosen sample."""
        return self[randint(len(self))]

    def collate_fn(self, batch):
        """Batch ``(image, mask)`` samples into tensors on ``device``.

        Returns:
            ims: float32 tensor ``(B, C, 224, 224)`` scaled to [0, 1] and then
                passed through ``tfms`` (ImageNet normalization).
            ce_masks: int64 tensor ``(B, 224, 224)`` of class ids, usable directly
                as ``nn.CrossEntropyLoss`` targets.
        """
        ims, masks = zip(*batch)
        # Divide by 255 here: ToTensor does not rescale non-uint8 arrays.
        ims = torch.stack([tfms(im.copy() / 255.) for im in ims]).float().to(device)
        ce_masks = torch.stack([torch.as_tensor(mask) for mask in masks]).long().to(device)
        return ims, ce_masks
# The 'test' split of dataset1 doubles as the validation set.
trn_ds, val_ds = SegData('train'), SegData('test')
# Training uses batches of 4, validation one image at a time; both are shuffled
# and batched by the dataset's own collate_fn.
trn_dl, val_dl = (
    DataLoader(ds, batch_size=bs, shuffle=True, collate_fn=ds.collate_fn)
    for ds, bs in ((trn_ds, 4), (val_ds, 1))
)
2020-11-07 12:49:50.267 | INFO | torch_snippets.loader:Glob:181 - 367 files found at dataset1/images_prepped_train 2020-11-07 12:49:50.270 | INFO | torch_snippets.loader:Glob:181 - 101 files found at dataset1/images_prepped_test
# Quick visual sanity check: display the image half of training sample 10.
sample_image = trn_ds[10][0]
show(sample_image)
def conv(in_channels, out_channels):
    """Build a 3x3 stride-1, padding-1 convolution followed by BatchNorm and an in-place ReLU.

    Spatial size is preserved; only the channel count changes from
    ``in_channels`` to ``out_channels``.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
def up_conv(in_channels, out_channels):
    """Build a 2x upsampler: a 2x2 stride-2 transposed convolution followed by an in-place ReLU."""
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, activation)
from torchvision.models import vgg16_bn
class UNet(nn.Module):
    """U-Net for semantic segmentation built on a VGG16-BN encoder.

    The VGG16-BN ``features`` stack is sliced into five skip-producing stages
    (``block1``..``block5``) plus a final ``bottleneck`` stage. The decoder
    repeatedly upsamples 2x with ``up_conv``, concatenates the matching encoder
    stage output, and fuses the result with a 3x3 ``conv``. A final 1x1 conv
    emits ``out_channels`` per-pixel scores.

    NOTE(review): skip concatenation requires each upsampled map to match its
    encoder stage spatially; inputs presumably need sides divisible by 32
    (224x224 is what this notebook feeds).

    Args:
        pretrained: forwarded to ``vgg16_bn`` to load its pretrained weights.
        out_channels: channels (classes) produced by the final 1x1 conv.
    """

    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()
        self.encoder = vgg16_bn(pretrained=pretrained).features
        # Slice points into the encoder's children. The stages reuse the same
        # module objects as self.encoder, so no weights are copied.
        cuts = (0, 6, 13, 20, 27, 34)
        stage_names = ('block1', 'block2', 'block3', 'block4', 'block5')
        for name, start, stop in zip(stage_names, cuts[:-1], cuts[1:]):
            setattr(self, name, nn.Sequential(*self.encoder[start:stop]))
        self.bottleneck = nn.Sequential(*self.encoder[cuts[-1]:])
        self.conv_bottleneck = conv(512, 1024)

        # Decoder stages, registered in order up_conv6, conv6, up_conv7, conv7, ...
        # Each tuple is (stage index, upsampler in-channels, upsampler
        # out-channels, channels of the concatenated encoder skip).
        decoder_spec = (
            (6, 1024, 512, 512),
            (7, 512, 256, 512),
            (8, 256, 128, 256),
            (9, 128, 64, 128),
            (10, 64, 32, 64),
        )
        for idx, c_in, c_out, c_skip in decoder_spec:
            setattr(self, f'up_conv{idx}', up_conv(c_in, c_out))
            setattr(self, f'conv{idx}', conv(c_out + c_skip, c_out))
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel scores with ``out_channels`` channels for input batch ``x``."""
        skips = []
        for name in ('block1', 'block2', 'block3', 'block4', 'block5'):
            x = getattr(self, name)(x)
            skips.append(x)
        x = self.conv_bottleneck(self.bottleneck(x))
        # Deepest skip first: decoder stage 6 pairs with block5, ..., stage 10 with block1.
        for idx, skip in zip(range(6, 11), reversed(skips)):
            x = getattr(self, f'up_conv{idx}')(x)
            x = getattr(self, f'conv{idx}')(torch.cat([x, skip], dim=1))
        return self.conv11(x)
ce = nn.CrossEntropyLoss()


def UnetLoss(preds, targets):
    """Return ``(cross_entropy_loss, pixel_accuracy)`` for per-pixel class scores.

    Args:
        preds: class scores with classes on dim 1.
        targets: integer class ids matching ``preds`` with dim 1 removed.

    Returns:
        The mean cross-entropy (a tensor that supports backward) and the fraction
        of pixels whose arg-max class equals the target (a float tensor).
    """
    predicted_labels = preds.max(dim=1).indices
    pixel_acc = predicted_labels.eq(targets).float().mean()
    return ce(preds, targets), pixel_acc
def train_batch(model, data, optimizer, criterion):
    """Run one optimization step on a single batch.

    Args:
        model: network to train; it is switched to training mode.
        data: ``(inputs, targets)`` pair.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``(outputs, targets) -> (loss, acc)`` returning tensors.

    Returns:
        ``(loss, acc)`` as Python floats.
    """
    images, targets = data
    model.train()
    # Clear stale gradients before the backward pass.
    optimizer.zero_grad()
    logits = model(images)
    loss, acc = criterion(logits, targets)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()
@torch.no_grad()
def validate_batch(model, data, criterion):
    """Evaluate ``criterion`` on one batch in eval mode without tracking gradients.

    Args:
        model: network to evaluate; it is switched to eval mode (and left there).
        data: ``(inputs, targets)`` pair.
        criterion: callable ``(outputs, targets) -> (loss, acc)`` returning tensors.

    Returns:
        ``(loss, acc)`` as Python floats.
    """
    model.eval()
    images, targets = data
    loss, acc = criterion(model(images), targets)
    return loss.item(), acc.item()
n_epochs = 20
# Default UNet: pretrained VGG16-BN encoder, 12 output channels.
model = UNet().to(device)
criterion = UnetLoss  # returns (cross-entropy loss, pixel accuracy)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
HBox(children=(FloatProgress(value=0.0, max=553507836.0), HTML(value='')))
log = Report(n_epochs)


def _run_epoch_phase(epoch, loader, step, tag):
    """Apply ``step`` to every batch of ``loader`` and record ``<tag>_loss`` / ``<tag>_acc``.

    Each record is positioned at the fractional epoch ``epoch + (batch + 1) / len(loader)``.
    """
    n_batches = len(loader)
    for batch_idx, batch in enumerate(loader):
        loss_val, acc_val = step(batch)
        log.record(epoch + (batch_idx + 1) / n_batches,
                   **{f'{tag}_loss': loss_val, f'{tag}_acc': acc_val}, end='\r')


for ex in range(n_epochs):
    _run_epoch_phase(ex, trn_dl, lambda b: train_batch(model, b, optimizer, criterion), 'trn')
    _run_epoch_phase(ex, val_dl, lambda b: validate_batch(model, b, criterion), 'val')
    log.report_avgs(ex + 1)
EPOCH: 1.000 trn_loss: 1.169 trn_acc: 0.737 val_loss: 0.976 val_acc: 0.742 (18.50s - 351.43s remaining) EPOCH: 2.000 trn_loss: 0.692 trn_acc: 0.827 val_loss: 0.590 val_acc: 0.858 (36.80s - 331.24s remaining) EPOCH: 3.000 trn_loss: 0.567 trn_acc: 0.851 val_loss: 0.606 val_acc: 0.840 (54.90s - 311.11s remaining) EPOCH: 4.000 trn_loss: 0.496 trn_acc: 0.866 val_loss: 0.544 val_acc: 0.856 (73.06s - 292.25s remaining) EPOCH: 5.000 trn_loss: 0.468 trn_acc: 0.872 val_loss: 0.487 val_acc: 0.864 (91.31s - 273.92s remaining) EPOCH: 6.000 trn_loss: 0.420 trn_acc: 0.883 val_loss: 0.603 val_acc: 0.809 (109.59s - 255.72s remaining) EPOCH: 7.000 trn_loss: 0.402 trn_acc: 0.887 val_loss: 0.423 val_acc: 0.870 (127.91s - 237.55s remaining) EPOCH: 8.000 trn_loss: 0.373 trn_acc: 0.894 val_loss: 0.424 val_acc: 0.863 (146.29s - 219.43s remaining) EPOCH: 9.000 trn_loss: 0.341 trn_acc: 0.902 val_loss: 0.428 val_acc: 0.869 (164.71s - 201.31s remaining) EPOCH: 10.000 trn_loss: 0.327 trn_acc: 0.907 val_loss: 0.404 val_acc: 0.879 (183.10s - 183.10s remaining) EPOCH: 11.000 trn_loss: 0.310 trn_acc: 0.913 val_loss: 0.615 val_acc: 0.819 (201.44s - 164.81s remaining) EPOCH: 12.000 trn_loss: 0.316 trn_acc: 0.910 val_loss: 0.381 val_acc: 0.894 (219.77s - 146.51s remaining) EPOCH: 13.000 trn_loss: 0.288 trn_acc: 0.919 val_loss: 0.473 val_acc: 0.870 (238.21s - 128.27s remaining) EPOCH: 14.000 trn_loss: 0.266 trn_acc: 0.925 val_loss: 0.402 val_acc: 0.885 (256.77s - 110.04s remaining) EPOCH: 15.000 trn_loss: 0.254 trn_acc: 0.928 val_loss: 0.358 val_acc: 0.897 (275.37s - 91.79s remaining) EPOCH: 16.000 trn_loss: 0.247 trn_acc: 0.930 val_loss: 0.386 val_acc: 0.893 (294.03s - 73.51s remaining) EPOCH: 17.000 trn_loss: 0.236 trn_acc: 0.933 val_loss: 0.408 val_acc: 0.888 (312.70s - 55.18s remaining) EPOCH: 18.000 trn_loss: 0.227 trn_acc: 0.936 val_loss: 0.400 val_acc: 0.895 (331.29s - 36.81s remaining) EPOCH: 19.000 trn_loss: 0.217 trn_acc: 0.938 val_loss: 0.422 val_acc: 0.896 (349.77s - 18.41s remaining) 
EPOCH: 20.000 trn_loss: 0.207 trn_acc: 0.941 val_loss: 0.405 val_acc: 0.896 (368.25s - 0.00s remaining)
# Plot the training and validation loss curves recorded in `log`.
loss_keys = ['trn_loss', 'val_loss']
log.plot_epochs(loss_keys)
0%| | 0/20 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice. out=out, **kwargs) /usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars ret = ret.dtype.type(ret / rcount) 100%|██████████| 20/20 [00:00<00:00, 1512.44it/s]
# Qualitative check: predict one validation batch and plot image, ground truth and prediction.
# Set eval mode explicitly instead of relying on validate_batch having left it there.
model.eval()
im, mask = next(iter(val_dl))
with torch.no_grad():  # inference only: skip building an autograd graph
    _mask = model(im)
_, _mask = torch.max(_mask, dim=1)  # per-pixel arg-max class id
# NOTE(review): im[0, 0] is channel 0 of the ImageNet-normalized tensor,
# not the raw RGB image the title suggests.
subplots(
    [im[0, 0].cpu(), mask[0].cpu(), _mask[0].cpu()],
    nc=3,
    titles=['Original image', 'Original mask', 'Predicted mask'],
)
2020-11-07 13:00:34.595 | INFO | torch_snippets.loader:subplots:375 - plotting 3 images in a grid of 1x3 @ (5, 5)