Computer_Vision/Chapter05/age_gender_torch_snippets.i...

334 KiB
Raw Permalink Blame History

Open In Colab

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate the Colab user and build a PyDrive client (`drive`), which
# getFile_from_drive uses to fetch the FairFace archive and label CSVs.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

def getFile_from_drive(file_id, name):
    """Download the Google Drive file ``file_id`` to the local path ``name``.

    Relies on the module-level authenticated ``drive`` client.
    """
    drive.CreateFile({'id': file_id}).GetContentFile(name)

# Download the FairFace image archive and the train/validation label CSVs.
getFile_from_drive('1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86', 'fairface-img-margin025-trainval.zip')
getFile_from_drive('1k5vvyREmHDW5TSM9QgB04Bvc8C8_7dl-', 'fairface-label-train.csv')
getFile_from_drive('1_rtz1M1zhvS0d5vVoXUamnohB6cJ02iJ', 'fairface-label-val.csv')

# Extract the images into the working directory; the CSV `file` column
# (e.g. train/1.jpg) is presumably relative to it.
!unzip -qq fairface-img-margin025-trainval.zip
!pip install torch_snippets
# NOTE(review): this star import is relied on for pd, torch, nn, Dataset,
# DataLoader, read/resize/show/inspect, Report, device, ... used below —
# confirm against the installed torch_snippets version.
from torch_snippets import *

# Label tables: one row per image with file, age, gender, race, service_test.
trn_df = pd.read_csv('fairface-label-train.csv')
val_df = pd.read_csv('fairface-label-val.csv')
trn_df.head()
Collecting torch_snippets
  Downloading https://files.pythonhosted.org/packages/e5/57/7d513a66ffc00d1495c8a8eeac8754b42233d8a68aa565077db8939b0452/torch_snippets-0.235-py3-none-any.whl
Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (3.2.2)
Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (1.1.4)
Collecting loguru
  Downloading https://files.pythonhosted.org/packages/6d/48/0a7d5847e3de329f1d0134baf707b689700b53bd3066a5a8cfd94b3c9fc8/loguru-0.5.3-py3-none-any.whl (57kB)
     |████████████████████████████████| 61kB 4.3MB/s 
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (1.18.5)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (4.41.1)
Collecting opencv-python-headless
  Downloading https://files.pythonhosted.org/packages/08/e9/57d869561389884136be65a2d1bc038fe50171e2ba348fda269a4aab8032/opencv_python_headless-4.4.0.46-cp36-cp36m-manylinux2014_x86_64.whl (36.7MB)
     |████████████████████████████████| 36.7MB 84kB/s 
Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (0.3.3)
Requirement already satisfied: Pillow in /usr/local/lib/python3.6/dist-packages (from torch_snippets) (7.0.0)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (2.8.1)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (1.3.1)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->torch_snippets) (2.4.7)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->torch_snippets) (2018.9)
Collecting aiocontextvars>=0.2.0; python_version < "3.7"
  Downloading https://files.pythonhosted.org/packages/db/c1/7a723e8d988de0a2e623927396e54b6831b68cb80dce468c945b849a9385/aiocontextvars-0.2.2-py2.py3-none-any.whl
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.1->matplotlib->torch_snippets) (1.15.0)
Collecting contextvars==2.4; python_version < "3.7"
  Downloading https://files.pythonhosted.org/packages/83/96/55b82d9f13763be9d672622e1b8106c85acb83edd7cc2fa5bc67cd9877e9/contextvars-2.4.tar.gz
Collecting immutables>=0.9
  Downloading https://files.pythonhosted.org/packages/99/e0/ea6fd4697120327d26773b5a84853f897a68e33d3f9376b00a8ff96e4f63/immutables-0.14-cp36-cp36m-manylinux1_x86_64.whl (98kB)
     |████████████████████████████████| 102kB 13.5MB/s 
Building wheels for collected packages: contextvars
  Building wheel for contextvars (setup.py) ... done
  Created wheel for contextvars: filename=contextvars-2.4-cp36-none-any.whl size=7666 sha256=8739abbc29c7cccf1228c26f5f6c1f8f85fd05588ecd8af21af7ef90e86b8876
  Stored in directory: /root/.cache/pip/wheels/a5/7d/68/1ebae2668bda2228686e3c1cf16f2c2384cea6e9334ad5f6de
Successfully built contextvars
Installing collected packages: immutables, contextvars, aiocontextvars, loguru, opencv-python-headless, torch-snippets
Successfully installed aiocontextvars-0.2.2 contextvars-2.4 immutables-0.14 loguru-0.5.3 opencv-python-headless-4.4.0.46 torch-snippets-0.235
file age gender race service_test
0 train/1.jpg 59 Male East Asian True
1 train/2.jpg 39 Female Indian False
2 train/3.jpg 11 Female Black False
3 train/4.jpg 26 Female Indian True
4 train/5.jpg 26 Female Indian True
IMAGE_SIZE = 224


class GenderAgeClass(Dataset):
    """FairFace dataset yielding raw images with their age and gender labels.

    ``__getitem__`` returns the unprocessed image; resizing, normalization,
    label scaling and device placement all happen batch-wise in
    ``collate_fn``, which is meant to be passed to a ``DataLoader``.

    Args:
        df: DataFrame with at least ``file``, ``age`` and ``gender`` columns.
        tfms: accepted but currently unused.
    """

    def __init__(self, df, tfms=None):
        self.df = df
        # ImageNet mean/std, matching the pretrained VGG16 backbone.
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, ix):
        """Return ``(image, age, is_female)`` for row ``ix``."""
        row = self.df.iloc[ix].squeeze()
        path = row.file
        is_female = row.gender == 'Female'
        years = row.age
        return read(path, 1), years, is_female

    def preprocess_image(self, im):
        """Turn an HxWx3 image into a normalized ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)`` tensor.

        Pixel values are divided by 255 before normalization.
        """
        square = resize(im, IMAGE_SIZE)
        channels_first = torch.tensor(square).permute(2, 0, 1)
        return self.normalize(channels_first / 255.)[None]

    def collate_fn(self, batch):
        """Stack a list of ``(image, age, gender)`` samples into device tensors.

        Returns ``(ims, ages, genders)``: images as ``(B, 3, H, W)``, ages
        divided by 80, and genders as 0.0/1.0 floats.
        """
        images = [self.preprocess_image(im) for im, _, _ in batch]
        scaled_ages = [float(int(age) / 80) for _, age, _ in batch]
        gender_flags = [float(flag) for _, _, flag in batch]

        ages = torch.tensor(scaled_ages).to(device).float()
        genders = torch.tensor(gender_flags).to(device).float()
        ims = torch.cat(images).to(device)
        return ims, ages, genders
trn = GenderAgeClass(trn_df)
val = GenderAgeClass(val_df)

# __getitem__ returns (image, age, is_female) — unpack in exactly that order.
im, age, gen = trn[0]
gender_label = 'Female' if gen else 'Male'
show(im, title=f'Gender: {gender_label}\nAge: {age}', sz=5)

train_loader = DataLoader(trn, batch_size=32, shuffle=True, drop_last=True, collate_fn=trn.collate_fn)
test_loader = DataLoader(val, batch_size=32, collate_fn=val.collate_fn)

# collate_fn returns (images, scaled ages, genders).
ims, ages, gens = next(iter(train_loader))
inspect(ims, ages, gens)
==================================================================
Tensor	Shape: torch.Size([32, 3, 224, 224])	Min: -2.118	Max: 2.640	Mean: -0.403	dtype: torch.float32
==================================================================
Tensor	Shape: torch.Size([32])	Min: 0.000	Max: 0.762	Mean: 0.345	dtype: torch.float32
==================================================================
Tensor	Shape: torch.Size([32])	Min: 0.000	Max: 1.000	Mean: 0.438	dtype: torch.float32
==================================================================
from torchvision import models
def get_model():
    """Build the VGG16-based age/gender network, its losses and its optimizer.

    The ImageNet-pretrained VGG16 is frozen first. Then ``avgpool`` is
    replaced with a trainable conv/pool block that flattens to 2048 features
    (512 channels x 2 x 2 for a 224x224 input), and ``classifier`` with a
    two-head module. Because they are created after the freeze, these new
    layers stay trainable.

    Returns:
        ``(model, (gender_criterion, age_criterion), optimizer)``. The model
        has been moved to ``device`` and returns ``(gender, age)``, each a
        sigmoid output of shape ``(B, 1)``. Gender uses BCE and age uses L1;
        the optimizer is Adam with lr=1e-4.
    """
    backbone = models.vgg16(pretrained=True)
    # Freeze every pretrained weight; the replacement layers below are added
    # afterwards and therefore remain trainable.
    for weight in backbone.parameters():
        weight.requires_grad = False

    backbone.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten(),
    )

    class ageGenderClassifier(nn.Module):
        """Shared MLP feeding separate sigmoid heads for gender and age."""

        def __init__(self):
            super().__init__()
            self.intermediate = nn.Sequential(
                nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.4),
                nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.4),
                nn.Linear(128, 64), nn.ReLU(),
            )
            self.age_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())
            self.gender_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())

        def forward(self, x):
            shared = self.intermediate(x)
            age = self.age_classifier(shared)
            gender = self.gender_classifier(shared)
            return gender, age

    backbone.classifier = ageGenderClassifier()

    criteria = (nn.BCELoss(), nn.L1Loss())
    optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-4)
    return backbone.to(device), criteria, optimizer

# Build a model here so its architecture can be inspected with torchsummary
# below; the training cell calls get_model() again for a fresh model/optimizer.
model, loss_functions, optimizer = get_model()
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))
!pip install torchsummary
from torchsummary import summary
# Print per-layer output shapes and parameter counts for a 3x224x224 input.
summary(model, input_size=(3,224,224), device=device)
Requirement already satisfied: torchsummary in /usr/local/lib/python3.6/dist-packages (1.5.1)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Conv2d-32            [-1, 512, 5, 5]       2,359,808
        MaxPool2d-33            [-1, 512, 2, 2]               0
             ReLU-34            [-1, 512, 2, 2]               0
          Flatten-35                 [-1, 2048]               0
           Linear-36                  [-1, 512]       1,049,088
             ReLU-37                  [-1, 512]               0
          Dropout-38                  [-1, 512]               0
           Linear-39                  [-1, 128]          65,664
             ReLU-40                  [-1, 128]               0
          Dropout-41                  [-1, 128]               0
           Linear-42                   [-1, 64]           8,256
             ReLU-43                   [-1, 64]               0
           Linear-44                    [-1, 1]              65
          Sigmoid-45                    [-1, 1]               0
           Linear-46                    [-1, 1]              65
          Sigmoid-47                    [-1, 1]               0
ageGenderClassifier-48         [[-1, 1], [-1, 1]]               0
================================================================
Total params: 18,197,634
Trainable params: 3,482,946
Non-trainable params: 14,714,688
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.55
Params size (MB): 69.42
Estimated Total Size (MB): 288.55
----------------------------------------------------------------
def train_batch(data, model, optimizer, criteria):
    """Run one optimization step on a single batch.

    Args:
        data: ``(ims, age, gender)`` where ``age`` and ``gender`` are 1-D
            float targets of length B.
        model: network returning ``(pred_gender, pred_age)``, each ``(B, 1)``.
        optimizer: optimizer over ``model``'s parameters.
        criteria: ``(gender_criterion, age_criterion)`` loss pair.

    Returns:
        The summed gender + age loss for the batch (scalar tensor).
    """
    model.train()
    ims, age, gender = data
    optimizer.zero_grad()
    pred_gender, pred_age = model(ims)
    gender_criterion, age_criterion = criteria
    # squeeze(-1) drops only the trailing head dimension. A bare .squeeze()
    # would also remove the batch axis when B == 1, so the 0-d prediction
    # would no longer match the (1,) target and BCELoss would raise.
    gender_loss = gender_criterion(pred_gender.squeeze(-1), gender)
    age_loss = age_criterion(pred_age.squeeze(-1), age)
    total_loss = gender_loss + age_loss
    total_loss.backward()
    optimizer.step()
    return total_loss

def validate_batch(data, model, criteria):
    """Evaluate ``model`` on one batch without updating its weights.

    Args:
        data: ``(ims, age, gender)`` where ``age`` and ``gender`` are 1-D
            float targets of length B.
        model: network returning ``(pred_gender, pred_age)``, each ``(B, 1)``.
        criteria: ``(gender_criterion, age_criterion)`` loss pair.

    Returns:
        ``(total_loss, gender_acc, age_mae)``. ``gender_acc`` is the number of
        correctly classified genders in the batch (threshold 0.5), and
        ``age_mae`` is the *sum* of absolute age errors in the same units as
        ``age``. Callers divide both by the batch size.
    """
    model.eval()
    ims, age, gender = data
    with torch.no_grad():
        pred_gender, pred_age = model(ims)
    # Drop only the trailing (B, 1) head dimension so predictions are (B,) like
    # the targets. Leaving pred_age as (B, 1) would broadcast against (B,) into
    # a (B, B) matrix of pairwise differences, and a bare .squeeze() would also
    # remove the batch axis when B == 1.
    pred_gender = pred_gender.squeeze(-1)
    pred_age = pred_age.squeeze(-1)
    gender_criterion, age_criterion = criteria
    gender_loss = gender_criterion(pred_gender, gender)
    age_loss = age_criterion(pred_age, age)
    total_loss = gender_loss + age_loss
    gender_acc = ((pred_gender > 0.5) == gender).float().sum()
    age_mae = torch.abs(age - pred_age).float().sum()
    return total_loss, gender_acc, age_mae
# Fresh model, (gender, age) loss pair and optimizer for training.
model, criterion, optimizer = get_model()
n_epochs = 5
log = Report(n_epochs)
for epoch in range(n_epochs):
    # Training pass: one optimizer step per batch.
    N = len(train_loader)
    for ix, data in enumerate(train_loader):
        total_loss = train_batch(data, model, optimizer, criterion)
        log.record(epoch+(ix+1)/N, trn_loss=total_loss, end='\r')

    # Validation pass: no weight updates.
    N = len(test_loader)
    for ix, data in enumerate(test_loader):
        total_loss, gender_acc, age_mae = validate_batch(data, model, criterion)
        # validate_batch returns per-batch sums; convert them to per-sample means.
        batch_size = len(data[0])
        gender_acc /= batch_size
        age_mae /= batch_size
        log.record(epoch+(ix+1)/N, val_loss=total_loss, val_gender_acc=gender_acc, val_age_mae=age_mae, end='\r')
    log.report_avgs(epoch+1)
log.plot_epochs(['trn_loss','val_loss'])
EPOCH: 1.000	trn_loss: 0.551	val_loss: 0.465	val_gender_acc: 0.834	val_age_mae: 6.238	(779.37s - 3117.47s remaining)
EPOCH: 2.000	trn_loss: 0.401	val_loss: 0.444	val_gender_acc: 0.847	val_age_mae: 6.229	(1555.70s - 2333.56s remaining)
EPOCH: 3.000	trn_loss: 0.284	val_loss: 0.493	val_gender_acc: 0.846	val_age_mae: 6.340	(2335.09s - 1556.73s remaining)
EPOCH: 4.000	trn_loss: 0.198	val_loss: 0.655	val_gender_acc: 0.842	val_age_mae: 6.339	(3110.83s - 777.71s remaining)
EPOCH: 4.994	val_loss: 0.375	val_gender_acc: 0.969	val_age_mae: 6.541	(3887.91s - 4.54s remaining)
  0%|          | 0/6 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 6/6 [00:00<00:00, 291.40it/s]
EPOCH: 4.997	val_loss: 1.030	val_gender_acc: 0.844	val_age_mae: 7.533	(3888.16s - 2.27s remaining)
EPOCH: 5.000	val_loss: 0.171	val_gender_acc: 0.900	val_age_mae: 1.977	(3888.25s - 0.00s remaining)
EPOCH: 5.000	trn_loss: 0.157	val_loss: 0.733	val_gender_acc: 0.847	val_age_mae: 6.322	(3888.25s - 0.00s remaining)
!wget https://www.dropbox.com/s/6kzr8l68e9kpjkf/5_9.JPG
im = cv2.imread('/content/5_9.JPG')
im = trn.preprocess_image(im).to(device)
gender, age = model(im)
pred_gender = gender.to('cpu').detach().numpy()
pred_age = age.to('cpu').detach().numpy()
im = cv2.imread('/content/5_9.JPG')
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
plt.imshow(im)
print('predicted gender:',np.where(pred_gender[0][0]<0.5,'Male','Female'), '; Predicted age', int(pred_age[0][0]*80))
--2020-11-08 12:56:14--  https://www.dropbox.com/s/6kzr8l68e9kpjkf/5_9.JPG
Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.1, 2620:100:6016:1::a27d:101
Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.1|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/6kzr8l68e9kpjkf/5_9.JPG [following]
--2020-11-08 12:56:14--  https://www.dropbox.com/s/raw/6kzr8l68e9kpjkf/5_9.JPG
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc177111820f4b375b4969ae95e7.dl.dropboxusercontent.com/cd/0/inline/BCx_9g_kyOTpcIFyCAe6R019q6zv2sreGAJzOnco0RygfRp0a253M-JC2VhM9pBdc9xj107QBfquDj3PkN8_exIM5Qnz4eSpipop6_K1f5IM8IC-5z2-zxEmkQUiSQRE-XM/file# [following]
--2020-11-08 12:56:14--  https://uc177111820f4b375b4969ae95e7.dl.dropboxusercontent.com/cd/0/inline/BCx_9g_kyOTpcIFyCAe6R019q6zv2sreGAJzOnco0RygfRp0a253M-JC2VhM9pBdc9xj107QBfquDj3PkN8_exIM5Qnz4eSpipop6_K1f5IM8IC-5z2-zxEmkQUiSQRE-XM/file
Resolving uc177111820f4b375b4969ae95e7.dl.dropboxusercontent.com (uc177111820f4b375b4969ae95e7.dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f
Connecting to uc177111820f4b375b4969ae95e7.dl.dropboxusercontent.com (uc177111820f4b375b4969ae95e7.dl.dropboxusercontent.com)|162.125.1.15|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 46983 (46K) [image/jpeg]
Saving to: 5_9.JPG

5_9.JPG             100%[===================>]  45.88K  --.-KB/s    in 0.01s   

2020-11-08 12:56:15 (3.59 MB/s) - 5_9.JPG saved [46983/46983]

predicted gender: Female ; Predicted age 25