386 lines
44 KiB
Plaintext
386 lines
44 KiB
Plaintext
|
{
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 0,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"name": "Zero_shot_learning.ipynb",
|
||
|
"provenance": [],
|
||
|
"include_colab_link": true
|
||
|
},
|
||
|
"kernelspec": {
|
||
|
"name": "python3",
|
||
|
"display_name": "Python 3"
|
||
|
},
|
||
|
"accelerator": "GPU"
|
||
|
},
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "view-in-github",
|
||
|
"colab_type": "text"
|
||
|
},
|
||
|
"source": [
|
||
|
"<a href=\"https://colab.research.google.com/github/PacktPublishing/Hands-On-Computer-Vision-with-PyTorch/blob/master/Chapter14/Zero_shot_learning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "-q5sBEc5EIpI",
|
||
|
"outputId": "4d91014a-9266-4dfe-8a4b-07f90a931d8e",
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/",
|
||
|
"height": 187
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"!git clone https://github.com/sizhky/zero-shot-learning/\n",
|
||
|
"!pip install -Uq torch_snippets\n",
|
||
|
"%cd zero-shot-learning/src\n",
|
||
|
"import gzip\n",
|
||
|
"import _pickle as cPickle\n",
|
||
|
"from torch_snippets import *\n",
|
||
|
"from sklearn.preprocessing import LabelEncoder, normalize\n",
|
||
|
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Cloning into 'zero-shot-learning'...\n",
|
||
|
"remote: Enumerating objects: 102, done.\u001b[K\n",
|
||
|
"remote: Total 102 (delta 0), reused 0 (delta 0), pack-reused 102\u001b[K\n",
|
||
|
"Receiving objects: 100% (102/102), 134.47 MiB | 36.40 MiB/s, done.\n",
|
||
|
"Resolving deltas: 100% (45/45), done.\n",
|
||
|
"\u001b[K |████████████████████████████████| 36.7MB 87kB/s \n",
|
||
|
"\u001b[K |████████████████████████████████| 61kB 7.8MB/s \n",
|
||
|
"\u001b[K |████████████████████████████████| 102kB 9.4MB/s \n",
|
||
|
"\u001b[?25h Building wheel for contextvars (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
||
|
"/content/zero-shot-learning/src\n"
|
||
|
],
|
||
|
"name": "stdout"
|
||
|
}
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "S33kWM_TEeC0"
|
||
|
},
|
||
|
"source": [
|
||
|
"WORD2VECPATH = \"../data/class_vectors.npy\"\n",
|
||
|
"DATAPATH = \"../data/zeroshot_data.pkl\""
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "kULYCEK7Ezip"
|
||
|
},
|
||
|
"source": [
|
||
|
"with open('train_classes.txt', 'r') as infile:\n",
|
||
|
" train_classes = [str.strip(line) for line in infile]"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "iomjeHmyE1xi"
|
||
|
},
|
||
|
"source": [
|
||
|
"with gzip.GzipFile(DATAPATH, 'rb') as infile:\n",
|
||
|
" data = cPickle.load(infile)"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "7qZBBK9jE3Sg"
|
||
|
},
|
||
|
"source": [
|
||
|
"training_data = [instance for instance in data if instance[0] in train_classes]\n",
|
||
|
"zero_shot_data = [instance for instance in data if instance[0] not in train_classes]\n",
|
||
|
"np.random.shuffle(training_data)"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "_ZAv4KiaE420"
|
||
|
},
|
||
|
"source": [
|
||
|
"train_size = 300 # per class\n",
|
||
|
"train_data, valid_data = [], []\n",
|
||
|
"for class_label in train_classes:\n",
|
||
|
" ctr = 0\n",
|
||
|
" for instance in training_data:\n",
|
||
|
" if instance[0] == class_label:\n",
|
||
|
" if ctr < train_size:\n",
|
||
|
" train_data.append(instance)\n",
|
||
|
" ctr+=1\n",
|
||
|
" else:\n",
|
||
|
" valid_data.append(instance)"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "tYjuW30RE6fx"
|
||
|
},
|
||
|
"source": [
|
||
|
"np.random.shuffle(train_data)\n",
|
||
|
"np.random.shuffle(valid_data)\n",
|
||
|
"vectors = {i:j for i,j in np.load(WORD2VECPATH, allow_pickle=True)}"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "fWkDpklSE79O"
|
||
|
},
|
||
|
"source": [
|
||
|
"train_data = [(feat, vectors[clss]) for clss,feat in train_data]\n",
|
||
|
"valid_data = [(feat, vectors[clss]) for clss,feat in valid_data]"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "haQky_e2E9SL"
|
||
|
},
|
||
|
"source": [
|
||
|
"train_clss = [clss for clss,feat in train_data]\n",
|
||
|
"valid_clss = [clss for clss,feat in valid_data]\n",
|
||
|
"zero_shot_clss = [clss for clss,feat in zero_shot_data]"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "VAmZUA_LE-pu"
|
||
|
},
|
||
|
"source": [
|
||
|
"x_train, y_train = zip(*train_data)\n",
|
||
|
"x_train, y_train = np.squeeze(np.asarray(x_train)), np.squeeze(np.asarray(y_train))\n",
|
||
|
"x_train = normalize(x_train, norm='l2')\n",
|
||
|
"\n",
|
||
|
"x_valid, y_valid = zip(*valid_data)\n",
|
||
|
"x_valid, y_valid = np.squeeze(np.asarray(x_valid)), np.squeeze(np.asarray(y_valid))\n",
|
||
|
"x_valid = normalize(x_valid, norm='l2')\n",
|
||
|
"\n",
|
||
|
"y_zsl, x_zsl = zip(*zero_shot_data)\n",
|
||
|
"x_zsl, y_zsl = np.squeeze(np.asarray(x_zsl)), np.squeeze(np.asarray(y_zsl))\n",
|
||
|
"x_zsl = normalize(x_zsl, norm='l2')"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "6HCmIEmhFAiD"
|
||
|
},
|
||
|
"source": [
|
||
|
"from torch.utils.data import TensorDataset\n",
|
||
|
"\n",
|
||
|
"trn_ds = TensorDataset(*[torch.Tensor(t).to(device) for t in [x_train, y_train]])\n",
|
||
|
"val_ds = TensorDataset(*[torch.Tensor(t).to(device) for t in [x_valid, y_valid]])\n",
|
||
|
"\n",
|
||
|
"trn_dl = DataLoader(trn_ds, batch_size=32, shuffle=True)\n",
|
||
|
"val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "LJod2CHNFCGF"
|
||
|
},
|
||
|
"source": [
|
||
|
"def build_model(): \n",
|
||
|
" return nn.Sequential(\n",
|
||
|
" nn.Linear(4096, 1024), nn.ReLU(inplace=True),\n",
|
||
|
" nn.BatchNorm1d(1024), nn.Dropout(0.8),\n",
|
||
|
" nn.Linear(1024, 512), nn.ReLU(inplace=True),\n",
|
||
|
" nn.BatchNorm1d(512), nn.Dropout(0.8),\n",
|
||
|
" nn.Linear(512, 256), nn.ReLU(inplace=True),\n",
|
||
|
" nn.BatchNorm1d(256), nn.Dropout(0.8),\n",
|
||
|
" nn.Linear(256, 300)\n",
|
||
|
" )"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "0h1dNnFSFM1O"
|
||
|
},
|
||
|
"source": [
|
||
|
"def train_batch(model, data, optimizer, criterion):\n",
|
||
|
" ims, labels = data\n",
|
||
|
" _preds = model(ims)\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" loss = criterion(_preds, labels)\n",
|
||
|
" loss.backward()\n",
|
||
|
" optimizer.step()\n",
|
||
|
" return loss.item()\n",
|
||
|
"\n",
|
||
|
"@torch.no_grad()\n",
|
||
|
"def validate_batch(model, data, criterion):\n",
|
||
|
" ims, labels = data\n",
|
||
|
" _preds = model(ims)\n",
|
||
|
" loss = criterion(_preds, labels)\n",
|
||
|
" return loss.item()"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "XF-YyXTXFOut",
|
||
|
"outputId": "ffbb890b-551a-45a0-efdb-f1e10a53ae6a",
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/",
|
||
|
"height": 490
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"model = build_model().to(device)\n",
|
||
|
"criterion = nn.MSELoss()\n",
|
||
|
"optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
|
||
|
"n_epochs = 60\n",
|
||
|
"\n",
|
||
|
"log = Report(n_epochs)\n",
|
||
|
"for ex in range(n_epochs):\n",
|
||
|
" N = len(trn_dl)\n",
|
||
|
" for bx, data in enumerate(trn_dl):\n",
|
||
|
" loss = train_batch(model, data, optimizer, criterion)\n",
|
||
|
" log.record(ex+(bx+1)/N, trn_loss=loss, end='\\r')\n",
|
||
|
"\n",
|
||
|
" N = len(val_dl)\n",
|
||
|
" for bx, data in enumerate(val_dl):\n",
|
||
|
" loss = validate_batch(model, data, criterion)\n",
|
||
|
" log.record(ex+(bx+1)/N, val_loss=loss, end='\\r')\n",
|
||
|
" \n",
|
||
|
" if ex == 10: optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
|
||
|
" if ex == 40: optimizer = optim.Adam(model.parameters(), lr=1e-5)\n",
|
||
|
" if not (ex+1)%10: log.report_avgs(ex+1)\n",
|
||
|
"\n",
|
||
|
"log.plot(log=True)"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"EPOCH: 10.000\ttrn_loss: 0.019\tval_loss: 0.019\t(122.92s - 614.58s remaining)\n",
|
||
|
"EPOCH: 20.000\ttrn_loss: 0.014\tval_loss: 0.014\t(237.73s - 475.46s remaining)\n",
|
||
|
"EPOCH: 30.000\ttrn_loss: 0.013\tval_loss: 0.013\t(359.53s - 359.53s remaining)\n",
|
||
|
"EPOCH: 40.000\ttrn_loss: 0.012\tval_loss: 0.013\t(479.74s - 239.87s remaining)\n",
|
||
|
"EPOCH: 50.000\ttrn_loss: 0.011\tval_loss: 0.013\t(592.00s - 118.40s remaining)\n",
|
||
|
"EPOCH: 60.000\ttrn_loss: 0.011\tval_loss: 0.013\t(704.25s - 0.00s remaining)\n"
|
||
|
],
|
||
|
"name": "stdout"
|
||
|
},
|
||
|
{
|
||
|
"output_type": "display_data",
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAAFzCAYAAADWqstZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdZ3RU5dqH8euZkoQO0puEoiBFRRE7BPXYELGhYi/H3vUcXywoioKo2I4VC6iIqCgqoKAiUXqVKr2HXkNNpu33w57MZMgkmdRJ4v+3louZXe9sIfd+urEsCxEREamYHPEOQEREREqOEr2IiEgFpkQvIiJSgSnRi4iIVGBK9CIiIhWYEr2IiEgF5op3ACWhTp06VnJycrFd7+DBg1SpUqXYrlfe6XmE6VlE0vOIpOcRpmcRqbifx9y5c3dallU32r4KmeiTk5OZM2dOsV0vNTWVlJSUYrteeafnEaZnEUnPI5KeR5ieRaTifh7GmPW57VPVvYiISAWmRC8iIlKBKdGLiIhUYBWyjV5ERMoXr9dLWloaGRkZ8Q6lVNSoUYOlS5cW+LykpCSaNGmC2+2O+ZwKleiNMT2AHq1atYp3KCIiUgBpaWlUq1aN5ORkjDHxDqfE7d+/n2rVqhXoHMuy2LVrF2lpaTRv3jzm8ypU1b1lWWMsy7qzRo0a8Q5FREQKICMjg9q1a/8jknxhGWOoXbt2gWs9KlSiFxGR8ktJPn+FeUZK9CIiIhWYEr2IiPzj7d27l3fffbdYrpWSklKsk7YVlRK9iIj84+WW6H0+XxyiKV4Vqte9iIiUf8+NWcLfm/cV6zXbNqrOsz3a5bq/T58+rF69mhNPPBG3201SUhK1atVi2bJlDBkyhH79+lGnTh0WL17MySefzPDhw2NqL//yyy8ZMGAAlmXRvXt3Bg0ahN/v5+6772bBggUYY7jtttt45JFHeOutt3j//fdxuVy0bduWkSNHFsvPrkSfj90HPezLtOIdhoiIlKCXXnqJxYsXM3/+fFJTU+nevTuLFy+mefPmpKam8tdff7FkyRIaNWrEmWeeydSpUznrrLPyvObmzZv5v//7P+bOnUutWrU4//zz+f7772natClbtmxh8eLFgF2bkBXD2rVrSUxMDG0rDkr0+Tip/68AXHpBnAMREfmHyKvkXVo6d+4cMVa9c+fONGnSBIATTzyRdevW5ZvoZ8+eTUpKCnXr2ovKXX/99fz555/07duXtWvX8sADD9C9e3fOP/98AI4//niuv/56LrvsMi677LJi+1nURi8iInKEI5eQTUxMDH12Op1FaruvVasW06ZNIyUlhffff59///vfAIwbN4777ruPefPmccoppxRb/wAlehER+cerVq0a+/fvL9Zrdu7cmT/++IOdO3fi9/v58ssv6dq1Kzt37iQQCHDllVfywgsvMG/ePAKBABs3bqRbt24MGjSI9PR0Dhw4UCxxqOpeRET+8WrXrs2ZZ55J+/btqVSpEvXr1y/yNRs2bMhLL71Et27dQp3xevbsyYIFC7j55ptDxw0cOBC/388NN9xAeno6lmXx4IMPUrNmzSLHAEr0IiIiAIwYMSLq9pSUFFJSUkLf33777Tyvk5qaGvrcu3dvevfuHbH/hBNOYPLkyTnmup8yZUrBAo6Rqu5jZFnqeS8iIuWPSvQx8gUs3E7NwywiIrbLL7+ctWvXRmwbNGgQF1xQtoZpKdHHyOML4HaqAkRERGyjR4+OdwgxUeaK0Z5DnniHICIiUmBK9DH638RV8Q5BRESkwJTo83FMvaoA7MvwxjkSERGRglOiz0elBCcAhzz+OEciIiJScEr0+Uhy24n+sFeJXkREbFWrVs1137p162jfvn0pRpM3Jfp8XNy+AQDtGlWPcyQiIiIFp+F1+Ti/XQP6jfmbX5ZsKxMrKomI/CMM7Z5zW7vLoPMd4DkEX/TKuf/E66Dj9XBwF3x9U+S+W8flebs+ffrQtGlT7rvvPgD69euHy+Vi0qRJ7NmzB6/XywsvvEDPnj0L9GNkZGRwzz33MGfOHFwuF6+99hrdunVj6dKl3H///Xg8HgKBAN9++y2NGjXi6quvJi0tDb/fT9++fbnmmmsKdL9olOjzkTV2ftPew3GORERESso111zDww8/HEr0X3/9NRMmTODBBx+kevXq7Ny5k9NOO41LL70UY2KfPO2dd97BGMOiRYtYtmwZ559/PitWrODjjz/moYce4vrrr8fj8eD3+/npp59o1KgR48bZLyXp6enF8rMp0edDs+GJiMRBXiXwhMp5769SO98S/JE6duzI9u3b2bx5Mzt27KBWrVo0aNCARx55hD///BOHw8GmTZvYtm0bDRo0iPm6U6ZM4YEHHgCgTZs2NGvWjBUrVtC5c2cGDBhAWloaV1xxBccccwwdOnTgscce4//+7/+45JJLOPvsswv0M+RGbfT50Gx4IiL/DL169WLUqFF89dVXXHPNNXzxxRfs2LGDuXPnMn/+fOrXr09GRkax3Ovqq6/mxx9/pFKlSlx88cX8/vvvHHvsscybN48OHTrw9NNP8/zzzxfLvZTF8pHV615ERCq2a665hpEjRzJq1Ch69epFeno69erVw+12M2nSJNavX1/ga5599tl88cUXAKxYsYINGzbQunVr1q5dS4sWLXjwwQfp2bMnCxcuZPPmzVSuXJkbbriB//73v8ybN69Yfi5V3efD6VDVvYjIP0G7du3Yv38/jRs3pmHDhlx//fX06NGDDh060KlTJ9q0aVPga957773cc889dOjQAZfLxbBhw0hMTGT06NH07t0bt9tNgwYNePLJJ5k9ezb//e9/cTgcuN1u3nvvvWL5uZToY9C6loOaNWvGOwwRESlhixYtCn2uU6cO06dPj3rcgQMHcr1GcnIyixcvBiApKYmhQ4fmOObRRx/l2Wefjdh2wQUXlMjKd6q6j4HLAV5/IN5hiIiIFJhK9DFwOgy+gBXvMEREpAxZtGgRN954Y8S2xMREZs6cGaeIolOij4HTwGGfSvQiIiXJsqwCjVGPtw4dOjB//vxSvadlFbzQqar7GKjqXkSkZCUlJbFr165CJbJ/Csuy2LVrF0lJSQU6TyX6GLgMqroXESlBTZo0IS0tjR07dsQ7lFKRkZFR4IQN9gtRkyZNCnSOEn0MnA6DV1X3IiIlxu1207x583iHUWpSU1Pp2LFjqdyrzCd6Y0wV4F3AA6RalvVFacfgMuBViV5ERMqhuLTRG2M+McZsN8YsPmL7hcaY5caYVcaYPsHNVwCjLMu6A7i01IMFnGqjFxGRcipenfGGARdm32CMcQLvABcBbYHexpi2QBNgY/AwfynGGGJZsPeQNx63FhERKRITrx6OxphkYKxlWe2D308H+lmWdUHw+xPBQ9OAPZZljTXGjLQs69pcrncncCdA/fr1Tx45cmSxxXrL+IMAvJFSiZpJGqhw4MABqlatGu8wygQ9i0h6HpH0PML0LCIV9/Po1q3bXMuyOkXbV5ba6BsTLrmDneBPBd4C3jbGdAfG5HayZVlDgCEAnTp1slJSUootsKNSf2J3hsUHy938/FDxLBtYnqWmplKcz7c807OIpOcRSc8jTM8iUmk+j7KU6KOyLOsgcGs8Y0iu7mB3hp+lW/bFMwwREZECK0v10JuAptm+Nwlui7sm1crSYxIREYldWcpgs4FjjDHNjTEJwLXAj3GOCYDyMyGjiIhIpHgNr/sSmA60NsakGWNutyzLB9wPTACWAl9blrUkHvEdSUvSi4hIeRWXNnrLsnrnsv0n4KfCXtcY0wPo0apVq8JeQkREpEIpS1X3RWZZ1hjLsu6sUaNGsV63HC2mJCIiEqFCJXoRERGJpEQfA62aKCIi5ZUSfQy0cJ2IiJRXSvQxaFFTj0lERMonZbAYHFPTGe8QRERECqVCJXpjTA9jzJD09PRivW7VBHW7FxGR8qlCJfqSGl4nIiJSXlWoRC8iIiKRlOhFREQqMCX6GFVLKvMr+oqIiOSgRB+jE5vWjHcIIiIiBaZEH6PJK3cCsC/DG+dIREREY
|
||
|
"text/plain": [
|
||
|
"<Figure size 576x432 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"tags": [],
|
||
|
"needs_background": "light"
|
||
|
}
|
||
|
}
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "U39Fmo0VFQ6C"
|
||
|
},
|
||
|
"source": [
|
||
|
"pred_zsl = model(torch.Tensor(x_zsl).to(device)).cpu().detach().numpy()\n",
|
||
|
"\n",
|
||
|
"class_vectors = sorted(np.load(WORD2VECPATH, allow_pickle=True), key=lambda x: x[0])\n",
|
||
|
"classnames, vectors = zip(*class_vectors)\n",
|
||
|
"classnames = list(classnames)\n",
|
||
|
"\n",
|
||
|
"vectors = np.array(vectors)"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "PPKdmdl7HHat",
|
||
|
"outputId": "f9a0fe71-e7f9-445b-effd-839d5dbfa1ef",
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/",
|
||
|
"height": 34
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"dists = (pred_zsl[None] - vectors[:,None])\n",
|
||
|
"dists = (dists**2).sum(-1).T\n",
|
||
|
"\n",
|
||
|
"best_classes = []\n",
|
||
|
"for item in dists:\n",
|
||
|
" best_classes.append([classnames[j] for j in np.argsort(item)[:5]])\n",
|
||
|
"\n",
|
||
|
"np.mean([i in J for i,J in zip(zero_shot_clss, best_classes)])"
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "execute_result",
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"0.7248624312156078"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"tags": []
|
||
|
},
|
||
|
"execution_count": 16
|
||
|
}
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"metadata": {
|
||
|
"id": "88CMg_0XIjMZ"
|
||
|
},
|
||
|
"source": [
|
||
|
""
|
||
|
],
|
||
|
"execution_count": null,
|
||
|
"outputs": []
|
||
|
}
|
||
|
]
|
||
|
}
|