Computer_Vision/Chapter14/Zero_shot_learning.ipynb

!git clone https://github.com/sizhky/zero-shot-learning/
!pip install -Uq torch_snippets
%cd zero-shot-learning/src
import gzip
import _pickle as cPickle  # C implementation of pickle, used to read the gzipped dataset
from torch_snippets import *
from sklearn.preprocessing import normalize
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Cloning into 'zero-shot-learning'...
remote: Enumerating objects: 102, done.
remote: Total 102 (delta 0), reused 0 (delta 0), pack-reused 102
Receiving objects: 100% (102/102), 134.47 MiB | 36.40 MiB/s, done.
Resolving deltas: 100% (45/45), done.
/content/zero-shot-learning/src
WORD2VECPATH = "../data/class_vectors.npy"
DATAPATH = "../data/zeroshot_data.pkl"
# Classes seen during training; all remaining classes are held out for zero-shot evaluation
with open('train_classes.txt', 'r') as infile:
    train_classes = [line.strip() for line in infile]
with gzip.GzipFile(DATAPATH, 'rb') as infile:
    data = cPickle.load(infile)  # list of (class_name, feature_vector) tuples
training_data = [instance for instance in data if instance[0] in train_classes]
zero_shot_data = [instance for instance in data if instance[0] not in train_classes]
np.random.shuffle(training_data)
train_size = 300  # training instances per seen class; the rest go to validation
train_data, valid_data = [], []
for class_label in train_classes:
    ctr = 0
    for instance in training_data:
        if instance[0] == class_label:
            if ctr < train_size:
                train_data.append(instance)
                ctr += 1
            else:
                valid_data.append(instance)
np.random.shuffle(train_data)
np.random.shuffle(valid_data)
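At this point the data is split three ways: 300 training instances per seen class, the remaining seen-class instances for validation, and every unseen-class instance for zero-shot evaluation. A quick size check (a minimal sketch over the variables defined above):

print(f'{len(train_classes)} seen classes, '
      f'{len(set(i[0] for i in zero_shot_data))} unseen classes')
print(f'{len(train_data)} train / {len(valid_data)} valid / '
      f'{len(zero_shot_data)} zero-shot instances')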
vectors = {i: j for i, j in np.load(WORD2VECPATH, allow_pickle=True)}
# Capture the class names before remapping: after the two lines that follow,
# train_data and valid_data hold (feature, word_vector) pairs, so unpacking
# them as (clss, feat) would yield features, not class names
train_clss = [clss for clss, feat in train_data]
valid_clss = [clss for clss, feat in valid_data]
zero_shot_clss = [clss for clss, feat in zero_shot_data]
train_data = [(feat, vectors[clss]) for clss, feat in train_data]
valid_data = [(feat, vectors[clss]) for clss, feat in valid_data]
x_train, y_train = zip(*train_data)
x_train, y_train = np.squeeze(np.asarray(x_train)), np.squeeze(np.asarray(y_train))
x_train = normalize(x_train, norm='l2')

x_valid, y_valid = zip(*valid_data)
x_valid, y_valid = np.squeeze(np.asarray(x_valid)), np.squeeze(np.asarray(y_valid))
x_valid = normalize(x_valid, norm='l2')

# zero_shot_data still holds (class_name, feature) tuples, hence the swapped unpacking order
y_zsl, x_zsl = zip(*zero_shot_data)
x_zsl, y_zsl = np.squeeze(np.asarray(x_zsl)), np.squeeze(np.asarray(y_zsl))
x_zsl = normalize(x_zsl, norm='l2')
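The features are L2-normalized so every sample lies on the unit sphere, keeping distances in feature space comparable across samples. A quick check (a minimal sketch):

assert np.allclose(np.linalg.norm(x_train, axis=1), 1.0, atol=1e-5)  # unit-length rows
print(x_train.shape, y_train.shape)  # (n_train, 4096) features, (n_train, 300) word vectors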
from torch.utils.data import TensorDataset

# The arrays are small enough to move to the GPU once, up front
trn_ds = TensorDataset(*[torch.Tensor(t).to(device) for t in [x_train, y_train]])
val_ds = TensorDataset(*[torch.Tensor(t).to(device) for t in [x_valid, y_valid]])

trn_dl = DataLoader(trn_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)
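Peeking at one batch confirms what the model will see (a sketch; the shapes follow from the architecture below):

xb, yb = next(iter(trn_dl))
print(xb.shape, yb.shape)  # torch.Size([32, 4096]) torch.Size([32, 300])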
def build_model():
    # MLP head: 4096-dim image features -> 300-dim word-vector predictions
    return nn.Sequential(
        nn.Linear(4096, 1024), nn.ReLU(inplace=True),
        nn.BatchNorm1d(1024), nn.Dropout(0.8),
        nn.Linear(1024, 512), nn.ReLU(inplace=True),
        nn.BatchNorm1d(512), nn.Dropout(0.8),
        nn.Linear(512, 256), nn.ReLU(inplace=True),
        nn.BatchNorm1d(256), nn.Dropout(0.8),
        nn.Linear(256, 300)
    )
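The network is a plain MLP that regresses 4096-dimensional pre-extracted image features onto the 300-dimensional word2vec embedding of the class name, with heavy Dropout(0.8) for regularization. A quick shape sanity check (a minimal sketch):

m = build_model().eval()  # eval mode so BatchNorm uses running statistics
with torch.no_grad():
    print(m(torch.randn(8, 4096)).shape)  # torch.Size([8, 300])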
def train_batch(model, data, optimizer, criterion):
    ims, labels = data
    _preds = model(ims)
    optimizer.zero_grad()
    loss = criterion(_preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    ims, labels = data
    _preds = model(ims)
    loss = criterion(_preds, labels)
    return loss.item()
model = build_model().to(device)
criterion = nn.MSELoss()  # regress predictions onto the true word vectors
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 60

log = Report(n_epochs)
for ex in range(n_epochs):
    N = len(trn_dl)
    for bx, data in enumerate(trn_dl):
        loss = train_batch(model, data, optimizer, criterion)
        log.record(ex+(bx+1)/N, trn_loss=loss, end='\r')

    N = len(val_dl)
    for bx, data in enumerate(val_dl):
        loss = validate_batch(model, data, criterion)
        log.record(ex+(bx+1)/N, val_loss=loss, end='\r')
        
    # drop the learning rate 10x partway through training (epochs are 0-indexed)
    if ex == 10: optimizer = optim.Adam(model.parameters(), lr=1e-4)
    if ex == 40: optimizer = optim.Adam(model.parameters(), lr=1e-5)
    if not (ex+1)%10: log.report_avgs(ex+1)

log.plot(log=True)
EPOCH: 10.000	trn_loss: 0.019	val_loss: 0.019	(122.92s - 614.58s remaining)
EPOCH: 20.000	trn_loss: 0.014	val_loss: 0.014	(237.73s - 475.46s remaining)
EPOCH: 30.000	trn_loss: 0.013	val_loss: 0.013	(359.53s - 359.53s remaining)
EPOCH: 40.000	trn_loss: 0.012	val_loss: 0.013	(479.74s - 239.87s remaining)
EPOCH: 50.000	trn_loss: 0.011	val_loss: 0.013	(592.00s - 118.40s remaining)
EPOCH: 60.000	trn_loss: 0.011	val_loss: 0.013	(704.25s - 0.00s remaining)
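The manual learning-rate drops above could also be expressed with a built-in scheduler. A sketch using torch.optim.lr_scheduler.MultiStepLR (not byte-identical: re-creating the optimizer, as done above, also resets Adam's moment estimates, which a scheduler does not):

optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[11, 41], gamma=0.1)
# then, once per epoch after the validation pass:
#     scheduler.step()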
model.eval()  # disable Dropout and use BatchNorm running stats at inference time
pred_zsl = model(torch.Tensor(x_zsl).to(device)).cpu().detach().numpy()

class_vectors = sorted(np.load(WORD2VECPATH, allow_pickle=True), key=lambda x: x[0])
classnames, vectors = zip(*class_vectors)
classnames = list(classnames)

vectors = np.array(vectors)
# Squared Euclidean distance from every prediction to every class vector via
# broadcasting: (1, N, 300) - (C, 1, 300) -> (C, N, 300), summed and transposed to (N, C)
dists = (pred_zsl[None] - vectors[:, None])
dists = (dists ** 2).sum(-1).T
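The broadcasting above computes the squared Euclidean distance from each of the N predictions to each of the C class vectors, yielding an (N, C) matrix. An equivalent, arguably more readable formulation with scipy (a sketch; scipy is preinstalled on Colab):

from scipy.spatial.distance import cdist
dists_alt = cdist(pred_zsl, vectors, metric='sqeuclidean')  # same (N, C) matrix
assert np.allclose(dists, dists_alt, atol=1e-4)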

# For each sample, the five classes whose word vectors are nearest to the prediction
best_classes = []
for item in dists:
    best_classes.append([classnames[j] for j in np.argsort(item)[:5]])

# Top-5 accuracy: fraction of zero-shot samples whose true class is in best_classes
np.mean([i in J for i, J in zip(zero_shot_clss, best_classes)])
0.7248624312156078
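That final cell is the top-5 accuracy on the unseen classes: roughly 72% of the zero-shot images have their true class among the five nearest class vectors. A small helper generalizes this to other values of k (a sketch reusing the names defined above):

def topk_accuracy(dists, true_labels, classnames, k=5):
    # fraction of samples whose true class is among the k nearest class vectors
    hits = 0
    for row, label in zip(dists, true_labels):
        hits += label in [classnames[j] for j in np.argsort(row)[:k]]
    return hits / len(true_labels)

print(topk_accuracy(dists, zero_shot_clss, classnames, k=1))  # top-1
print(topk_accuracy(dists, zero_shot_clss, classnames, k=5))  # top-5, ~0.725 in the run above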