Computer_Vision/Chapter10/crowd_counting.ipynb
2024-02-13 03:34:51 +01:00

56 KiB
Raw Permalink Blame History

Open In Colab

%%time
# Colab bootstrap: install Python deps, clone the CSRNet repo and download the
# ShanghaiTech (with density maps) dataset from Kaggle. The whole download block
# is skipped when the repo directory already exists, so re-running the cell does
# not re-fetch the multi-GB archive.
import os
if not os.path.exists('CSRNet-pytorch/'):
    # NOTE(review): versions are unpinned; consider `%pip install pkg==x.y` for
    # a reproducible environment.
    !pip install -U scipy torch_snippets torch_summary
    !git clone https://github.com/sizhky/CSRNet-pytorch.git
    # Kaggle CLI credentials: interactively upload kaggle.json (Colab-only
    # widget), move it into ~/.kaggle and make it owner-readable only.
    from google.colab import files
    files.upload() # upload kaggle.json
    !mkdir -p ~/.kaggle
    !mv kaggle.json ~/.kaggle/
    !ls ~/.kaggle
    # NOTE(review): hard-codes /root as $HOME (true on Colab; confirm elsewhere).
    !chmod 600 /root/.kaggle/kaggle.json
    print('downloading data...')
    !kaggle datasets download -d tthien/shanghaitech-with-people-density-map/
    print('unzipping data...')
    !unzip -qq shanghaitech-with-people-density-map.zip

# Everything below runs on every execution, not only on first setup.
%cd CSRNet-pytorch
# Link the unzipped dataset into the repo dir so the relative paths used later
# resolve. On re-runs `ln -s` reports that the link already exists (harmless).
!ln -s ../shanghaitech_with_people_density_map
# The star import is what brings Glob, read, resize, Report, T, Dataset,
# DataLoader, optim, np, plt, ... into scope for later cells — TODO confirm list.
from torch_snippets import *
import h5py
from scipy import io
Collecting scipy
[?25l  Downloading https://files.pythonhosted.org/packages/c8/89/63171228d5ced148f5ced50305c89e8576ffc695a90b58fe5bb602b910c2/scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl (25.9MB)
     |████████████████████████████████| 25.9MB 119kB/s 
[?25hCollecting torch_snippets
  Downloading https://files.pythonhosted.org/packages/e5/57/7d513a66ffc00d1495c8a8eeac8754b42233d8a68aa565077db8939b0452/torch_snippets-0.235-py3-none-any.whl
Collecting torch_summary
  Downloading https://files.pythonhosted.org/packages/83/49/f9db57bcad7246591b93519fd8e5166c719548c45945ef7d2fc9fcba46fb/torch_summary-1.4.3-py3-none-any.whl
Requirement already satisfied, skipping upgrade: numpy>=1.14.5 in /usr/local/lib/python3.6/dist-packages (from scipy) (1.18.5)
Collecting opencv-python-headless
[?25l  Downloading https://files.pythonhosted.org/packages/08/e9/57d869561389884136be65a2d1bc038fe50171e2ba348fda269a4aab8032/opencv_python_headless-4.4.0.46-cp36-cp36m-manylinux2014_x86_64.whl (36.7MB)
     |████████████████████████████████| 36.7MB 89kB/s 
[?25hRequirement already satisfied, skipping upgrade: Pillow in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (7.0.0)
Requirement already satisfied, skipping upgrade: matplotlib in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (3.2.2)
Requirement already satisfied, skipping upgrade: tqdm in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (4.41.1)
Requirement already satisfied, skipping upgrade: pandas in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (1.1.4)
Collecting loguru
[?25l  Downloading https://files.pythonhosted.org/packages/6d/48/0a7d5847e3de329f1d0134baf707b689700b53bd3066a5a8cfd94b3c9fc8/loguru-0.5.3-py3-none-any.whl (57kB)
     |████████████████████████████████| 61kB 9.5MB/s 
[?25hRequirement already satisfied, skipping upgrade: dill in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (0.3.3)
Requirement already satisfied, skipping upgrade: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (2.4.7)
Requirement already satisfied, skipping upgrade: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (1.3.1)
Requirement already satisfied, skipping upgrade: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (0.10.0)
Requirement already satisfied, skipping upgrade: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (2.8.1)
Requirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->torch_snippets) (2018.9)
Collecting aiocontextvars>=0.2.0; python_version < "3.7"
  Downloading https://files.pythonhosted.org/packages/db/c1/7a723e8d988de0a2e623927396e54b6831b68cb80dce468c945b849a9385/aiocontextvars-0.2.2-py2.py3-none-any.whl
Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from cycler>=0.10->matplotlib->torch_snippets) (1.15.0)
Collecting contextvars==2.4; python_version < "3.7"
  Downloading https://files.pythonhosted.org/packages/83/96/55b82d9f13763be9d672622e1b8106c85acb83edd7cc2fa5bc67cd9877e9/contextvars-2.4.tar.gz
Collecting immutables>=0.9
[?25l  Downloading https://files.pythonhosted.org/packages/99/e0/ea6fd4697120327d26773b5a84853f897a68e33d3f9376b00a8ff96e4f63/immutables-0.14-cp36-cp36m-manylinux1_x86_64.whl (98kB)
     |████████████████████████████████| 102kB 14.1MB/s 
[?25hBuilding wheels for collected packages: contextvars
  Building wheel for contextvars (setup.py) ... [?25l[?25hdone
  Created wheel for contextvars: filename=contextvars-2.4-cp36-none-any.whl size=7666 sha256=5976ad63e4d7be29f86455820048fd851c169bc1671ad4627743b2b16053a1e3
  Stored in directory: /root/.cache/pip/wheels/a5/7d/68/1ebae2668bda2228686e3c1cf16f2c2384cea6e9334ad5f6de
Successfully built contextvars
ERROR: tensorflow 2.3.0 has requirement scipy==1.4.1, but you'll have scipy 1.5.4 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
Installing collected packages: scipy, opencv-python-headless, immutables, contextvars, aiocontextvars, loguru, torch-snippets, torch-summary
  Found existing installation: scipy 1.4.1
    Uninstalling scipy-1.4.1:
      Successfully uninstalled scipy-1.4.1
Successfully installed aiocontextvars-0.2.2 contextvars-2.4 immutables-0.14 loguru-0.5.3 opencv-python-headless-4.4.0.46 scipy-1.5.4 torch-snippets-0.235 torch-summary-1.4.3
Cloning into 'CSRNet-pytorch'...
remote: Enumerating objects: 6, done.
remote: Counting objects: 100% (6/6), done.
remote: Compressing objects: 100% (6/6), done.
remote: Total 92 (delta 1), reused 0 (delta 0), pack-reused 86
Unpacking objects: 100% (92/92), done.
Upload widget is only available when the cell has been executed in the current browser session. Please rerun this cell to enable.
Saving kaggle.json to kaggle.json
kaggle.json
downloading data...
Downloading shanghaitech-with-people-density-map.zip to /content
100% 4.78G/4.79G [01:37<00:00, 41.5MB/s]
100% 4.79G/4.79G [01:37<00:00, 52.6MB/s]
unzipping data...
/content/CSRNet-pytorch
CPU times: user 1.64 s, sys: 431 ms, total: 2.07 s
Wall time: 5min 34s
# Root of the ShanghaiTech part-A training split (symlinked in by the setup cell).
_part_a_train = 'shanghaitech_with_people_density_map/ShanghaiTech/part_A/train_data/'
part_A = Glob(_part_a_train)

# Per-sample files share a stem:
#   images/<stem>.jpg, ground-truth-h5/<stem>.h5, ground-truth/GT_<stem>.mat
image_folder = _part_a_train + 'images/'
heatmap_folder = _part_a_train + 'ground-truth-h5/'
gt_folder = _part_a_train + 'ground-truth/'
2020-11-08 10:55:51.544 | INFO     | torch_snippets.loader:Glob:181 - 3 files found at shanghaitech_with_people_density_map/ShanghaiTech/part_A/train_data/
# Use the GPU when PyTorch can see one; models and batches are moved here.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Shared preprocessing for images and density maps: array/PIL -> CHW tensor
# (uint8 inputs are rescaled to [0, 1]). Deliberately no mean/std normalization.
tfm = T.Compose([
    T.ToTensor(),
])

class Crowds(Dataset):
    """ShanghaiTech part-A samples: (image, 1/8-scale density map, head count).

    Files are located by stem in the module-level `image_folder`,
    `heatmap_folder` and `gt_folder` directories.
    """

    def __init__(self, stems):
        # File stems (names without extension) identifying each sample.
        self.stems = stems

    def __len__(self):
        return len(self.stems)

    def __getitem__(self, ix):
        """Load image, downsampled density map and annotated head count for item `ix`."""
        _stem = self.stems[ix]
        image_path = f'{image_folder}/{_stem}.jpg'
        heatmap_path = f'{heatmap_folder}/{_stem}.h5'
        gt_path = f'{gt_folder}/GT_{_stem}.mat'

        # The head count is the number of annotated points in the .mat file.
        annotations = io.loadmat(gt_path)
        n_people = len(annotations['image_info'][0,0][0,0][0])

        image = read(image_path, 1)
        with h5py.File(heatmap_path, 'r') as hf:
            gt = hf['density'][:]
        # CSRNet predicts at 1/8 resolution. Shrinking each side by 8 leaves
        # 1/64 of the pixels, so scaling by 64 keeps the map's sum (the count)
        # approximately unchanged (assuming `resize` interpolates values).
        gt = resize(gt, 1/8)*64
        return image.copy(), gt.copy(), n_people

    def collate_fn(self, batch):
        """Stack samples into (images, density maps, counts) tensors on `device`.

        torch.cat requires every sample in a batch to share H and W.
        """
        ims, gts, pts = list(zip(*batch))
        ims = torch.cat([tfm(im)[None] for im in ims]).to(device)
        # Force float32 so the targets match the model's float32 output in the
        # loss, whatever dtype the .h5 density maps were stored with.
        gts = torch.cat([tfm(gt)[None] for gt in gts]).float().to(device)
        return ims, gts, torch.tensor(pts).to(device)

    def choose(self):
        """Return a random sample (index drawn with torch_snippets' `randint`)."""
        return self[randint(len(self))]

from sklearn.model_selection import train_test_split
# Split the part-A training images by stem into train/validation subsets;
# random_state makes the split reproducible.
trn_stems, val_stems = train_test_split(stems(Glob(image_folder)), random_state=10)

trn_ds = Crowds(trn_stems)
val_ds = Crowds(val_stems)

# batch_size=1: collate_fn concatenates images without padding, so a batch
# must share H and W (presumably ShanghaiTech image sizes vary — confirm).
trn_dl = DataLoader(trn_ds, batch_size=1, shuffle=True, collate_fn=trn_ds.collate_fn)
# Validation order does not affect averaged metrics; keep it fixed so runs
# are reproducible.
val_dl = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=val_ds.collate_fn)
2020-11-08 10:55:53.547 | INFO     | torch_snippets.loader:Glob:181 - 300 files found at shanghaitech_with_people_density_map/ShanghaiTech/part_A/train_data/images/
import torch.nn as nn
import torch
from torchvision import models
from utils import save_net,load_net

def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build a VGG-style stack of layers from a compact spec.

    Args:
        cfg: sequence whose items are either an int (output channels of a 3x3
            conv followed by ReLU) or 'M' (2x2 max-pool, stride 2).
        in_channels: channels entering the first conv.
        batch_norm: insert BatchNorm2d between each conv and its ReLU.
        dilation: use dilation 2 (padding 2) instead of 1, which keeps the
            spatial size unchanged either way.

    Returns:
        nn.Sequential containing the layers in spec order.
    """
    d_rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=d_rate, dilation=d_rate))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)
class CSRNet(nn.Module):
    """CSRNet crowd-density regressor.

    A VGG-16-style front end (10 convs, 3 max-pools), a dilated-conv back end
    and a 1x1 conv producing a single-channel density map.

    Args:
        load_weights: when False (default), every conv is initialised from
            N(0, 0.01) with zero bias and the front end is then overwritten
            with torchvision's pretrained VGG-16 weights. When True, no
            initialisation or transfer happens here (presumably the caller
            loads a checkpoint afterwards).
    """

    def __init__(self, load_weights=False):
        super().__init__()
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if load_weights:
            return
        # Build VGG-16 first, then re-initialise our layers (same RNG order).
        vgg = models.vgg16(pretrained=True)
        self._initialize_weights()
        # Positional transfer: the front end's state_dict entries line up with
        # the leading conv entries of VGG-16's state_dict.
        targets = list(self.frontend.state_dict().values())
        sources = list(vgg.state_dict().values())
        for dst, src in zip(targets, sources):
            dst.data[:] = src.data[:]

    def forward(self, x):
        """Map an image batch to a 1-channel density map (3 pools => ~1/8 size)."""
        features = self.backend(self.frontend(x))
        return self.output_layer(features)

    def _initialize_weights(self):
        """Conv: weight ~ N(0, 0.01), bias 0. BatchNorm: weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
def train_batch(model, data, optimizer, criterion):
    """Run one optimisation step on a batch.

    Args:
        model: density-map regressor; put into train mode here.
        data: (images, target density maps, head counts) from collate_fn.
        optimizer: optimizer over `model`'s parameters.
        criterion: pixel-wise loss between predicted and target maps.

    Returns:
        (pixel loss, mean absolute per-image count error) as Python floats.
    """
    model.train()
    optimizer.zero_grad()
    ims, gts, pts = data
    _gts = model(ims)
    loss = criterion(_gts, gts)
    loss.backward()
    optimizer.step()
    # Monitoring only, so no autograd graph. Counts are compared per image:
    # summing the whole batch first would let over- and under-counts on
    # different images cancel out.
    with torch.no_grad():
        pred_counts = _gts.flatten(1).sum(1)
        true_counts = gts.flatten(1).sum(1)
        pts_loss = (pred_counts - true_counts).abs().mean()
    return loss.item(), pts_loss.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    """Evaluate a batch without gradient tracking.

    Args:
        model: density-map regressor; put into eval mode here.
        data: (images, target density maps, head counts) from collate_fn.
        criterion: pixel-wise loss between predicted and target maps.

    Returns:
        (pixel loss, mean absolute per-image count error) as Python floats.
    """
    model.eval()
    ims, gts, pts = data
    _gts = model(ims)
    loss = criterion(_gts, gts)
    # Compare counts per image so errors on different images cannot cancel.
    pred_counts = _gts.flatten(1).sum(1)
    true_counts = gts.flatten(1).sum(1)
    pts_loss = (pred_counts - true_counts).abs().mean()
    return loss.item(), pts_loss.item()
# Seed torch's RNG before building the model (random init of the back end and
# output layer) and before the shuffled training loader draws batches.
torch.manual_seed(10)

model = CSRNet().to(device)
criterion = nn.MSELoss()
LR = 1e-6  # small step size for fine-tuning on top of the pretrained front end
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
n_epochs = 20

# Validation count error can swing between epochs, so keep the weights from
# the epoch with the lowest mean validation count error, not just the last one.
best_val_pts_loss = float('inf')
best_state = None

log = Report(n_epochs)
for ex in range(n_epochs):
    N = len(trn_dl)
    for bx, data in enumerate(trn_dl):
        loss, pts_loss = train_batch(model, data, optimizer, criterion)
        log.record(ex+(bx+1)/N, trn_loss=loss, trn_pts_loss=pts_loss, end='\r')

    N = len(val_dl)
    val_pts_losses = []
    for bx, data in enumerate(val_dl):
        loss, pts_loss = validate_batch(model, data, criterion)
        val_pts_losses.append(pts_loss)
        log.record(ex+(bx+1)/N, val_loss=loss, val_pts_loss=pts_loss, end='\r')

    log.report_avgs(ex+1)

    if val_pts_losses:
        epoch_val_pts_loss = sum(val_pts_losses) / len(val_pts_losses)
        if epoch_val_pts_loss < best_val_pts_loss:
            best_val_pts_loss = epoch_val_pts_loss
            # Clone so later optimizer steps cannot overwrite the snapshot in place.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

# Continue with the best-on-validation weights.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f'Restored weights with best mean val_pts_loss: {best_val_pts_loss:.3f}')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))
EPOCH: 1.000	trn_loss: 0.048	trn_pts_loss: 427.918	val_loss: 0.053	val_pts_loss: 347.524	(77.39s - 1470.45s remaining)
EPOCH: 2.000	trn_loss: 0.042	trn_pts_loss: 306.686	val_loss: 0.048	val_pts_loss: 262.405	(154.97s - 1394.71s remaining)
EPOCH: 3.000	trn_loss: 0.036	trn_pts_loss: 204.923	val_loss: 0.043	val_pts_loss: 190.447	(232.47s - 1317.31s remaining)
EPOCH: 4.000	trn_loss: 0.033	trn_pts_loss: 159.764	val_loss: 0.042	val_pts_loss: 184.419	(310.06s - 1240.24s remaining)
EPOCH: 5.000	trn_loss: 0.032	trn_pts_loss: 145.682	val_loss: 0.041	val_pts_loss: 241.129	(387.48s - 1162.44s remaining)
EPOCH: 6.000	trn_loss: 0.031	trn_pts_loss: 139.674	val_loss: 0.039	val_pts_loss: 172.444	(464.97s - 1084.93s remaining)
EPOCH: 7.000	trn_loss: 0.031	trn_pts_loss: 143.573	val_loss: 0.039	val_pts_loss: 154.602	(542.31s - 1007.15s remaining)
EPOCH: 8.000	trn_loss: 0.030	trn_pts_loss: 126.261	val_loss: 0.038	val_pts_loss: 152.646	(619.65s - 929.47s remaining)
EPOCH: 9.000	trn_loss: 0.030	trn_pts_loss: 121.418	val_loss: 0.036	val_pts_loss: 168.236	(696.97s - 851.86s remaining)
EPOCH: 10.000	trn_loss: 0.029	trn_pts_loss: 122.989	val_loss: 0.039	val_pts_loss: 183.355	(774.30s - 774.30s remaining)
EPOCH: 11.000	trn_loss: 0.029	trn_pts_loss: 121.840	val_loss: 0.038	val_pts_loss: 198.861	(851.56s - 696.73s remaining)
EPOCH: 12.000	trn_loss: 0.029	trn_pts_loss: 123.117	val_loss: 0.032	val_pts_loss: 131.696	(928.88s - 619.26s remaining)
EPOCH: 13.000	trn_loss: 0.029	trn_pts_loss: 117.252	val_loss: 0.041	val_pts_loss: 145.208	(1006.07s - 541.73s remaining)
EPOCH: 14.000	trn_loss: 0.028	trn_pts_loss: 110.343	val_loss: 0.036	val_pts_loss: 134.451	(1083.44s - 464.33s remaining)
EPOCH: 15.000	trn_loss: 0.028	trn_pts_loss: 113.773	val_loss: 0.036	val_pts_loss: 182.355	(1160.97s - 386.99s remaining)
EPOCH: 16.000	trn_loss: 0.028	trn_pts_loss: 109.048	val_loss: 0.036	val_pts_loss: 152.386	(1238.67s - 309.67s remaining)
EPOCH: 17.000	trn_loss: 0.028	trn_pts_loss: 107.314	val_loss: 0.035	val_pts_loss: 139.919	(1316.50s - 232.32s remaining)
EPOCH: 18.000	trn_loss: 0.028	trn_pts_loss: 111.054	val_loss: 0.036	val_pts_loss: 133.333	(1394.37s - 154.93s remaining)
EPOCH: 19.000	trn_loss: 0.027	trn_pts_loss: 110.610	val_loss: 0.036	val_pts_loss: 141.666	(1472.28s - 77.49s remaining)
EPOCH: 20.000	trn_loss: 0.028	trn_pts_loss: 113.707	val_loss: 0.034	val_pts_loss: 167.978	(1550.17s - 0.00s remaining)
from matplotlib import cm as c
from torchvision import datasets, transforms
from PIL import Image
# Use exactly the preprocessing the model was trained with: Crowds.collate_fn
# passes images through `tfm` (ToTensor only). Adding ImageNet mean/std
# normalization here would feed the trained weights inputs from a different
# distribution than they saw during training.
transform = tfm

test_folder = 'shanghaitech_with_people_density_map/ShanghaiTech/part_A/test_data/'
imgs = Glob(f'{test_folder}images')
f = choose(imgs)
print(f)
# PIL RGB image -> float CHW tensor in [0, 1]. Training images come from
# torch_snippets `read(path, 1)`, assumed to be RGB uint8 — confirm.
img = transform(Image.open(f).convert('RGB')).to(device)
2020-11-08 11:30:25.780 | INFO     | torch_snippets.loader:Glob:181 - 182 files found at shanghaitech_with_people_density_map/ShanghaiTech/part_A/test_data//images
shanghaitech_with_people_density_map/ShanghaiTech/part_A/test_data//images/IMG_4.jpg
# Predict a density map for the chosen test image and display it.
# CSRNet here has no dropout/batch-norm, but use eval mode and skip autograd
# anyway: no backward pass is needed.
model.eval()
with torch.no_grad():
    output = model(img[None])

# Drop the batch and channel axes: (1, 1, h, w) -> (h, w).
temp = output[0, 0].cpu().numpy()
# The density map sums to the predicted head count; round rather than truncate.
print("Predicted Count : ", int(round(float(temp.sum()))))

fig, ax = plt.subplots()
ax.imshow(temp, cmap=c.jet)
ax.set_title('Predicted density map')
ax.axis('off')
plt.show()
Predicted Count :  183