212 KiB
212 KiB
import torch
import numpy as np, cv2, pandas as pd, glob, time
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, models, datasets
# Pick the compute device once; reused by the data pipeline and model below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Colab-only flow: prompt the user for OAuth consent, then wrap the resulting
# application-default credentials in a PyDrive client used for Drive downloads.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
def getFile_from_drive( file_id, name ):
    """Download the Google Drive file with id *file_id* and save it locally as *name*."""
    drive_file = drive.CreateFile({'id': file_id})
    drive_file.GetContentFile(name)
# Fetch the FairFace image archive and the train/val label CSVs by Drive file id.
getFile_from_drive('1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86', 'fairface-img-margin025-trainval.zip')
getFile_from_drive('1k5vvyREmHDW5TSM9QgB04Bvc8C8_7dl-', 'fairface-label-train.csv')
getFile_from_drive('1_rtz1M1zhvS0d5vVoXUamnohB6cJ02iJ', 'fairface-label-val.csv')
# Notebook shell magic: extract images into train/ and val/ directories.
!unzip -qq fairface-img-margin025-trainval.zip
# Label rows carry: file path, age, gender, race, service_test (see table below).
trn_df = pd.read_csv('fairface-label-train.csv')
val_df = pd.read_csv('fairface-label-val.csv')
trn_df.head()
file | age | gender | race | service_test | |
---|---|---|---|---|---|
0 | train/1.jpg | 59 | Male | East Asian | True |
1 | train/2.jpg | 39 | Female | Indian | False |
2 | train/3.jpg | 11 | Female | Black | False |
3 | train/4.jpg | 26 | Female | Indian | True |
4 | train/5.jpg | 26 | Female | Indian | True |
from torch.utils.data import Dataset, DataLoader
import cv2

# Side length (pixels) every face crop is resized to — VGG16's expected input.
IMAGE_SIZE = 224
class GenderAgeClass(Dataset):
    """Dataset over a FairFace label dataframe yielding (image, age, gender).

    ``__getitem__`` returns the raw RGB image; resizing, scaling,
    normalization and batching are deferred to ``collate_fn`` so the
    DataLoader can move whole tensors to ``device`` in one shot.
    """

    def __init__(self, df, tfms=None):
        self.df = df
        # Fix: `tfms` was previously accepted but silently discarded.
        # It is now stored and applied (if given) after normalization;
        # the default of None preserves the original behavior exactly.
        self.tfms = tfms
        # ImageNet statistics — required because the backbone is a pretrained VGG16.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, ix):
        f = self.df.iloc[ix].squeeze()
        file = f.file
        gen = f.gender == 'Female'   # binary target: True/1 = Female, False/0 = Male
        age = f.age
        im = cv2.imread(file)
        if im is None:
            # Fail loudly with the offending path instead of a cryptic cvtColor error.
            raise FileNotFoundError(f'could not read image: {file}')
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        return im, age, gen

    def preprocess_image(self, im):
        """Resize, scale to [0,1] and ImageNet-normalize; returns a 1x3xHxW tensor."""
        im = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE))
        im = torch.tensor(im).permute(2, 0, 1)
        im = self.normalize(im / 255.)
        if self.tfms is not None:
            im = self.tfms(im)
        return im[None]

    def collate_fn(self, batch):
        'preprocess images, ages and genders'
        ims, ages, genders = [], [], []
        for im, age, gender in batch:
            ims.append(self.preprocess_image(im))
            ages.append(float(int(age)/80))   # regression target scaled to roughly [0,1]
            genders.append(float(gender))
        ages, genders = [torch.tensor(x).to(device).float() for x in [ages, genders]]
        ims = torch.cat(ims).to(device)
        return ims, ages, genders
# Wrap the label dataframes in datasets and build batched loaders.
trn = GenderAgeClass(trn_df)
val = GenderAgeClass(val_df)
# Bug fix: device was hard-coded to 'cuda' here, which crashes on CPU-only
# hosts and silently overrode the availability check made at import time.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# collate_fn performs all preprocessing; drop_last avoids a ragged final train batch.
train_loader = DataLoader(trn, batch_size=32, shuffle=True, drop_last=True, collate_fn=trn.collate_fn)
test_loader = DataLoader(val, batch_size=32, collate_fn=val.collate_fn)
# Sanity-check one batch: images [32,3,224,224], ages [32], genders [32].
a,b,c, = next(iter(train_loader))
print(a.shape, b.shape, c.shape)
torch.Size([32, 3, 224, 224]) torch.Size([32]) torch.Size([32])
def get_model():
    """Build the frozen-VGG16 age+gender model, its loss functions and optimizer.

    Returns:
        (model, (gender_criterion, age_criterion), optimizer) with the model
        already moved to the global ``device``.
    """
    model = models.vgg16(pretrained = True)
    # Freeze the convolutional backbone so only the new head is trained.
    for param in model.parameters():
        param.requires_grad = False
    # Replace the 7x7 average pool with a small trainable conv head that
    # flattens the 512x7x7 feature map down to 2048 features (512 x 2 x 2).
    model.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten()
    )

    class ageGenderClassifier(nn.Module):
        """Two-headed classifier: sigmoid gender probability and sigmoid age in [0,1]."""

        def __init__(self):
            super(ageGenderClassifier, self).__init__()
            # Shared trunk over the 2048-d flattened features.
            self.intermediate = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(512, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(128, 64),
                nn.ReLU(),
            )
            self.age_classifier = nn.Sequential(
                nn.Linear(64, 1),
                nn.Sigmoid()
            )
            self.gender_classifier = nn.Sequential(
                nn.Linear(64, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            x = self.intermediate(x)
            age = self.age_classifier(x)
            gender = self.gender_classifier(x)
            return gender, age

    model.classifier = ageGenderClassifier()
    gender_criterion = nn.BCELoss()   # gender head: BCE on sigmoid outputs
    age_criterion = nn.L1Loss()       # age head: MAE on the [0,1]-scaled age
    loss_functions = gender_criterion, age_criterion
    # Improvement: pass only trainable parameters; the frozen backbone would
    # otherwise sit uselessly in the optimizer's parameter list and state.
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4)
    return model.to(device), loss_functions, optimizer
model, loss_functions, optimizer = get_model()
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))
# torchsummary prints a per-layer output-shape / parameter-count table
# for the given input size (see the table in the cell output below).
!pip install torchsummary
from torchsummary import summary
summary(model, input_size=(3,224,224), device=device)
Requirement already satisfied: torchsummary in /usr/local/lib/python3.6/dist-packages (1.5.1) ---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 64, 224, 224] 1,792 ReLU-2 [-1, 64, 224, 224] 0 Conv2d-3 [-1, 64, 224, 224] 36,928 ReLU-4 [-1, 64, 224, 224] 0 MaxPool2d-5 [-1, 64, 112, 112] 0 Conv2d-6 [-1, 128, 112, 112] 73,856 ReLU-7 [-1, 128, 112, 112] 0 Conv2d-8 [-1, 128, 112, 112] 147,584 ReLU-9 [-1, 128, 112, 112] 0 MaxPool2d-10 [-1, 128, 56, 56] 0 Conv2d-11 [-1, 256, 56, 56] 295,168 ReLU-12 [-1, 256, 56, 56] 0 Conv2d-13 [-1, 256, 56, 56] 590,080 ReLU-14 [-1, 256, 56, 56] 0 Conv2d-15 [-1, 256, 56, 56] 590,080 ReLU-16 [-1, 256, 56, 56] 0 MaxPool2d-17 [-1, 256, 28, 28] 0 Conv2d-18 [-1, 512, 28, 28] 1,180,160 ReLU-19 [-1, 512, 28, 28] 0 Conv2d-20 [-1, 512, 28, 28] 2,359,808 ReLU-21 [-1, 512, 28, 28] 0 Conv2d-22 [-1, 512, 28, 28] 2,359,808 ReLU-23 [-1, 512, 28, 28] 0 MaxPool2d-24 [-1, 512, 14, 14] 0 Conv2d-25 [-1, 512, 14, 14] 2,359,808 ReLU-26 [-1, 512, 14, 14] 0 Conv2d-27 [-1, 512, 14, 14] 2,359,808 ReLU-28 [-1, 512, 14, 14] 0 Conv2d-29 [-1, 512, 14, 14] 2,359,808 ReLU-30 [-1, 512, 14, 14] 0 MaxPool2d-31 [-1, 512, 7, 7] 0 Conv2d-32 [-1, 512, 5, 5] 2,359,808 MaxPool2d-33 [-1, 512, 2, 2] 0 ReLU-34 [-1, 512, 2, 2] 0 Flatten-35 [-1, 2048] 0 Linear-36 [-1, 512] 1,049,088 ReLU-37 [-1, 512] 0 Dropout-38 [-1, 512] 0 Linear-39 [-1, 128] 65,664 ReLU-40 [-1, 128] 0 Dropout-41 [-1, 128] 0 Linear-42 [-1, 64] 8,256 ReLU-43 [-1, 64] 0 Linear-44 [-1, 1] 65 Sigmoid-45 [-1, 1] 0 Linear-46 [-1, 1] 65 Sigmoid-47 [-1, 1] 0 ageGenderClassifier-48 [[-1, 1], [-1, 1]] 0 ================================================================ Total params: 18,197,634 Trainable params: 3,482,946 Non-trainable params: 14,714,688 ---------------------------------------------------------------- Input size (MB): 0.57 Forward/backward pass size (MB): 218.55 Params size (MB): 69.42 
Estimated Total Size (MB): 288.55 ----------------------------------------------------------------
def train_batch(data, model, optimizer, criteria):
    """Run one optimization step on a single batch.

    Args:
        data: ``(ims, age, gender)`` as produced by ``collate_fn``.
        model: network whose forward returns ``(pred_gender, pred_age)``.
        optimizer: optimizer over the model's trainable parameters.
        criteria: ``(gender_criterion, age_criterion)`` loss pair.

    Returns:
        The combined gender + age loss as a 0-d tensor (still attached to the graph).
    """
    model.train()
    ims, age, gender = data
    optimizer.zero_grad()
    pred_gender, pred_age = model(ims)
    gender_criterion, age_criterion = criteria
    # Fix: squeeze(-1) instead of squeeze() — a batch of size 1 keeps shape [1]
    # and matches the target, rather than collapsing to a 0-d tensor and
    # tripping BCELoss's input/target shape check.
    gender_loss = gender_criterion(pred_gender.squeeze(-1), gender)
    age_loss = age_criterion(pred_age.squeeze(-1), age)
    total_loss = gender_loss + age_loss
    total_loss.backward()
    optimizer.step()
    return total_loss
def validate_batch(data, model, criteria):
    """Evaluate one batch without gradients.

    Args:
        data: ``(ims, age, gender)`` as produced by ``collate_fn``.
        model: network whose forward returns ``(pred_gender, pred_age)``.
        criteria: ``(gender_criterion, age_criterion)`` loss pair.

    Returns:
        ``(total_loss, gender_correct_count, age_abs_error_sum)`` — the counts
        are sums over the batch; the caller divides by the sample count.
    """
    model.eval()
    ims, age, gender = data
    with torch.no_grad():
        pred_gender, pred_age = model(ims)
    gender_criterion, age_criterion = criteria
    # Bug fix: squeeze the trailing dim so predictions are [batch].
    # Previously `age - pred_age` broadcast [batch] against [batch, 1] into a
    # [batch, batch] matrix, inflating the summed "MAE" by ~the batch size
    # (which is why the reported age MAE looked like ~6 on a [0,1]-scaled target).
    pred_gender = pred_gender.squeeze(-1)
    pred_age = pred_age.squeeze(-1)
    gender_loss = gender_criterion(pred_gender, gender)
    age_loss = age_criterion(pred_age, age)
    total_loss = gender_loss + age_loss
    gender_acc = ((pred_gender > 0.5) == gender).float().sum()
    age_mae = torch.abs(age - pred_age).float().sum()
    return total_loss, gender_acc, age_mae
# Fresh model/criteria/optimizer and per-epoch metric history for the run.
model, criteria, optimizer = get_model()
val_gender_accuracies = []
val_age_maes = []
# NOTE(review): these two lists are initialized but never appended to below.
train_losses = []
val_losses = []

n_epochs = 5
best_test_loss = 1000   # sentinel; any real epoch loss will be smaller
start = time.time()
for epoch in range(n_epochs):
    epoch_train_loss, epoch_test_loss = 0, 0
    val_age_mae, val_gender_acc, ctr = 0, 0, 0
    _n = len(train_loader)
    # --- training pass over all batches ---
    for ix, data in enumerate(train_loader):
        # if ix == 100: break
        loss = train_batch(data, model, optimizer, criteria)
        epoch_train_loss += loss.item()
    # --- validation pass: accumulate loss, correct-gender count, age error sum ---
    for ix, data in enumerate(test_loader):
        # if ix == 10: break
        loss, gender_acc, age_mae = validate_batch(data, model, criteria)
        epoch_test_loss += loss.item()
        val_age_mae += age_mae
        val_gender_acc += gender_acc
        ctr += len(data[0])   # number of samples seen, for per-sample averaging
    # Convert batch sums to per-sample / per-batch averages.
    val_age_mae /= ctr
    val_gender_acc /= ctr
    epoch_train_loss /= len(train_loader)
    epoch_test_loss /= len(test_loader)
    elapsed = time.time()-start
    # Track the best validation loss seen so far (no checkpointing is done here).
    best_test_loss = min(best_test_loss, epoch_test_loss)
    print('{}/{} ({:.2f}s - {:.2f}s remaining)'.format(epoch+1, n_epochs, time.time()-start, (n_epochs-epoch)*(elapsed/(epoch+1))))
    info = f'''Epoch: {epoch+1:03d}\tTrain Loss: {epoch_train_loss:.3f}\tTest: {epoch_test_loss:.3f}\tBest Test Loss: {best_test_loss:.4f}'''
    info += f'\nGender Accuracy: {val_gender_acc*100:.2f}%\tAge MAE: {val_age_mae:.2f}\n'
    print(info)
    val_gender_accuracies.append(val_gender_acc)
    val_age_maes.append(val_age_mae)
1/5 (844.58s - 4222.92s remaining) Epoch: 001 Train Loss: 0.548 Test: 0.469 Best Test Loss: 0.4695 Gender Accuracy: 83.54% Age MAE: 6.34 2/5 (1682.39s - 3364.79s remaining) Epoch: 002 Train Loss: 0.400 Test: 0.440 Best Test Loss: 0.4399 Gender Accuracy: 84.88% Age MAE: 6.23 3/5 (2526.27s - 2526.27s remaining) Epoch: 003 Train Loss: 0.286 Test: 0.494 Best Test Loss: 0.4399 Gender Accuracy: 84.67% Age MAE: 6.27 4/5 (3370.63s - 1685.32s remaining) Epoch: 004 Train Loss: 0.199 Test: 0.613 Best Test Loss: 0.4399 Gender Accuracy: 83.80% Age MAE: 6.41 5/5 (4208.70s - 841.74s remaining) Epoch: 005 Train Loss: 0.159 Test: 0.710 Best Test Loss: 0.4399 Gender Accuracy: 83.30% Age MAE: 6.29
# Plot validation gender accuracy and age MAE per epoch, side by side.
epochs = np.arange(1,len(val_gender_accuracies)+1)
fig,ax = plt.subplots(1,2,figsize=(10,5))
ax = ax.flat
ax[0].plot(epochs, val_gender_accuracies, 'bo')
ax[1].plot(epochs, val_age_maes, 'r')
ax[0].set_xlabel('Epochs')
ax[1].set_xlabel('Epochs')
ax[0].set_ylabel('Accuracy')
ax[1].set_ylabel('MAE')
ax[0].set_title('Validation Gender Accuracy')
# Bug fix: the age-MAE title was previously set on ax[0] again, overwriting
# the gender-accuracy title and leaving the MAE panel untitled.
ax[1].set_title('Validation Age Mean-Absolute-Error')
plt.show()
# Smoke-test the trained model on one downloaded face image.
!wget https://www.dropbox.com/s/6kzr8l68e9kpjkf/5_9.JPG
im = cv2.imread('/content/5_9.JPG')
# Reuse the training-time preprocessing (resize + ImageNet normalize, adds batch dim).
im = trn.preprocess_image(im).to(device)
gender, age = model(im)
pred_gender = gender.to('cpu').detach().numpy()
pred_age = age.to('cpu').detach().numpy()
# Re-read for display: plt.imshow expects RGB while cv2.imread returns BGR.
im = cv2.imread('/content/5_9.JPG')
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
plt.imshow(im)
# Threshold the sigmoid gender output at 0.5; age was trained on age/80, so scale back.
print('predicted gender:',np.where(pred_gender[0][0]<0.5,'Male','Female'), '; Predicted age', int(pred_age[0][0]*80))
--2020-11-09 14:08:08-- https://www.dropbox.com/s/6kzr8l68e9kpjkf/5_9.JPG Resolving www.dropbox.com (www.dropbox.com)... 162.125.5.1, 2620:100:601f:1::a27d:901 Connecting to www.dropbox.com (www.dropbox.com)|162.125.5.1|:443... connected. HTTP request sent, awaiting response... 301 Moved Permanently Location: /s/raw/6kzr8l68e9kpjkf/5_9.JPG [following] --2020-11-09 14:08:08-- https://www.dropbox.com/s/raw/6kzr8l68e9kpjkf/5_9.JPG Reusing existing connection to www.dropbox.com:443. HTTP request sent, awaiting response... 302 Found Location: https://uca373aa04bc3c60dd027c22b5aa.dl.dropboxusercontent.com/cd/0/inline/BC3TH3a9c3lP0QNxxm8x3r7gmpJ4kF79o3OdKPnoPKKZckshL_T1F5dD_lg7QKQdBUIUXjUwZ_Ljau6bhBMpll1ZeIuk42O44KyGUGEJyV3VAzJzHvtn7gN00jGfqvtrTeU/file# [following] --2020-11-09 14:08:09-- https://uca373aa04bc3c60dd027c22b5aa.dl.dropboxusercontent.com/cd/0/inline/BC3TH3a9c3lP0QNxxm8x3r7gmpJ4kF79o3OdKPnoPKKZckshL_T1F5dD_lg7QKQdBUIUXjUwZ_Ljau6bhBMpll1ZeIuk42O44KyGUGEJyV3VAzJzHvtn7gN00jGfqvtrTeU/file Resolving uca373aa04bc3c60dd027c22b5aa.dl.dropboxusercontent.com (uca373aa04bc3c60dd027c22b5aa.dl.dropboxusercontent.com)... 162.125.9.15, 2620:100:601d:15::a27d:50f Connecting to uca373aa04bc3c60dd027c22b5aa.dl.dropboxusercontent.com (uca373aa04bc3c60dd027c22b5aa.dl.dropboxusercontent.com)|162.125.9.15|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 46983 (46K) [image/jpeg] Saving to: ‘5_9.JPG.1’ 5_9.JPG.1 100%[===================>] 45.88K --.-KB/s in 0.02s 2020-11-09 14:08:09 (2.44 MB/s) - ‘5_9.JPG.1’ saved [46983/46983] predicted gender: Female ; Predicted age 24