import numpy as np
import pickle
import cv2
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
from os import listdir, mkdir
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
#from keras.utils import img_to_array
from keras.preprocessing.image import img_to_array
import matplotlib.pyplot as plt
data_dir = "/kaggle/input/plantvillage-dataset/color"

# Every entry of data_dir is treated as one class folder by the loader below.
classes = listdir(data_dir)
# Sample entries of `classes` (pasted notebook output, kept for reference):
#  'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
#  'Tomato___Spider_mites Two-spotted_spider_mite',

# Optimisation hyper-parameters.
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100

# ViT architecture hyper-parameters.
image_size = 72  # Resize input images to this size
patch_size = 6  # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
# Hidden widths of the MLP inside each Transformer block. The last width must
# equal projection_dim, because the MLP output is added back onto the
# projection_dim-wide residual stream (layers.Add in create_vit_classifier).
transformer_units = [
    projection_dim * 2,
    projection_dim,
]
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier
# Must match the number of labels found while loading (n_classes below).
num_classes = 38
def convert_image_to_array(image_dir):
    """Load the image file at ``image_dir`` and return it as an array.

    The file is read with ``cv2.imread``, so channels come in OpenCV's BGR
    order. When ``cv2.imread`` cannot read or decode the file it returns
    ``None``; in that case an empty array is returned instead of raising.
    """
    loaded = cv2.imread(image_dir)
    if loaded is None:
        return np.array([])
    return img_to_array(loaded)
image_list, label_list = [], []
# Cap per class folder, to keep the whole dataset in memory.
max_images_per_class = 100

print("Loading images ...")
root_dir = listdir(data_dir)

for plant_folder in root_dir:
    plant_disease_image_list = listdir(f"{data_dir}/{plant_folder}")
    for image in plant_disease_image_list[:max_images_per_class]:
        image_directory = f"{data_dir}/{plant_folder}/{image}"
        if image_directory.lower().endswith(".jpg"):
            image_array = convert_image_to_array(image_directory)
            # convert_image_to_array returns an empty array for unreadable
            # files. Skip those so all images share one shape and the image
            # and label lists stay index-aligned.
            if image_array.size:
                image_list.append(image_array)
                # The folder name is the class label.
                label_list.append(plant_folder)
print("Image loading completed")
# Output:
# Loading images ...
# Image loading completed
# Encode the folder-name labels as indicator (one-hot) rows.
label_binarizer = LabelBinarizer()
image_labels = label_binarizer.fit_transform(label_list)
# Persist the fitted binarizer so predicted indices can be mapped back to class
# names later. The `with` block guarantees the file handle is closed.
with open('label_transform.pkl', 'wb') as label_file:
    pickle.dump(label_binarizer, label_file)
n_classes = len(label_binarizer.classes_)
# Output (label_binarizer.classes_):
# ['Apple___Apple_scab' 'Apple___Black_rot' 'Apple___Cedar_apple_rust'
#  'Apple___healthy' 'Blueberry___healthy'
#  'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot'
#  'Corn_(maize)___Common_rust_' 'Corn_(maize)___Northern_Leaf_Blight'
#  'Corn_(maize)___healthy' 'Grape___Black_rot'
#  'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)' 'Grape___healthy'
#  'Orange___Haunglongbing_(Citrus_greening)' 'Peach___Bacterial_spot'
#  'Peach___healthy' 'Pepper,_bell___Bacterial_spot'
#  'Pepper,_bell___healthy' 'Potato___Early_blight' 'Potato___Late_blight'
#  'Potato___healthy' 'Raspberry___healthy' 'Soybean___healthy'
#  'Squash___Powdery_mildew' 'Strawberry___Leaf_scorch'
#  'Strawberry___healthy' 'Tomato___Bacterial_spot' 'Tomato___Early_blight'
#  'Tomato___Late_blight' 'Tomato___Leaf_Mold' 'Tomato___Septoria_leaf_spot'
#  'Tomato___Spider_mites Two-spotted_spider_mite' 'Tomato___Target_Spot'
#  'Tomato___Tomato_Yellow_Leaf_Curl_Virus' 'Tomato___Tomato_mosaic_virus'
# Scale 8-bit pixel values into [0, 1]; 255 is the uint8 maximum. Dividing by
# 225 would leave values up to ~1.13. float16 halves the memory footprint.
np_image_list = np.array(image_list, dtype=np.float16) / 255.0
print("Spliting data to train, test")
x_train, x_test, y_train, y_test = train_test_split(
    np_image_list, image_labels, test_size=0.2, random_state=42
)
# Output: Spliting data to train, test
input_shape = x_train[0].shape
# Output: (256, 256, 3)
# Preprocessing / augmentation front end used as the first stage of the model.
# NOTE(review): this cell was truncated in the export; the layer list is rebuilt
# from the visible Resizing / RandomZoom lines plus the Normalization layer
# implied by the adapt() step below. Confirm against the original notebook.
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
def mlp(x, hidden_units, dropout_rate):
    """Apply one Dense(GELU) -> Dropout pair per entry of ``hidden_units``.

    Args:
        x: Input Keras tensor.
        hidden_units: Layer widths, applied in order.
        dropout_rate: Dropout rate applied after every Dense layer.

    Returns:
        The output tensor of the last Dropout layer, or ``x`` itself when
        ``hidden_units`` is empty.
    """
    out = x
    for width in hidden_units:
        dense_out = layers.Dense(width, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(dense_out)
    return out
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    ``images`` is a 4-D (batch, rows, cols, channels) tensor, as required by
    ``tf.image.extract_patches``. The output has shape
    (batch, num_patches, patch_size * patch_size * channels).
    """

    def __init__(self, patch_size, **kwargs):
        # Keras layers must run Layer.__init__ before use; it was missing.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # The window size equals the stride, so the patches tile the image
        # without overlap.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        # Collapse the 2-D grid of patches into a single sequence axis.
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the layer config so the layer can be re-created/serialised."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
# Show one random training image, then the grid of patches it is split into.
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
# Undo the input scaling for display. Clip so the later uint8 cast can never
# wrap around, whatever divisor was used during normalisation.
image = np.clip(image * 255.0, 0, 255)
# cv2.imread loaded the pixels as BGR; reverse the channel axis for matplotlib.
plt.imshow(image[..., ::-1].astype("uint8"))
plt.axis("off")

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy()[..., ::-1].astype("uint8"))
    plt.axis("off")
# Output:
# Image size: 72 X 72
# Patch size: 6 X 6
# Patches per image: 144
# Elements per patch: 108
class PatchEncoder(layers.Layer):
    """Project each patch linearly and add a learned position embedding.

    Input: (batch, num_patches, patch_dim), as produced by ``Patches``.
    Output: (batch, num_patches, projection_dim).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # Keras layers must run Layer.__init__ before use; it was missing.
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The (num_patches, projection_dim) embedding broadcasts over the batch.
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the layer config so the layer can be re-created/serialised."""
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config
def create_vit_classifier(with_attention_scores_output=False):
    """Build the Vision Transformer classifier as a functional Keras model.

    Pipeline: data_augmentation -> Patches -> PatchEncoder ->
    ``transformer_layers`` pre-LayerNorm Transformer blocks -> LayerNorm ->
    Flatten -> Dropout -> MLP head -> Dense(num_classes). The final Dense layer
    has no activation, so the outputs are logits.

    Args:
        with_attention_scores_output: If True, the model outputs
            ``[logits, attention_score_dict]``; the dict maps each Transformer
            block index to its MultiHeadAttention scores. If False, the model
            outputs only the logits.

    Returns:
        An uncompiled ``keras.Model``. It reads the module-level
        hyper-parameters and ``input_shape``.
    """
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    attention_score_dict = dict()

    # Create multiple layers of the Transformer block.
    for i in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head self-attention; also return the scores for visualisation.
        attention_output, attention_score = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1, return_attention_scores=True)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])
        attention_score_dict[i] = attention_score

    # Flatten to a [batch_size, num_patches * projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model. Without the `else`, the attention-score model was
    # always overwritten by the logits-only one.
    if with_attention_scores_output:
        model = keras.Model(inputs=inputs, outputs=[logits, attention_score_dict])
    else:
        model = keras.Model(inputs=inputs, outputs=logits)
    return model
def run_experiment(model):
    """Compile, train, checkpoint and evaluate ``model`` on the module-level split.

    Labels are one-hot rows from LabelBinarizer, and the model outputs logits.
    Hence the loss is ``CategoricalCrossentropy(from_logits=True)`` and the
    metrics are the categorical (non-sparse) variants. The weights with the best
    validation accuracy are checkpointed and restored before the test-set
    evaluation.

    NOTE(review): the compile/fit/callback arguments were lost in the export and
    are reconstructed here. They match the logged metric names and the 11 steps
    per epoch implied by validation_split=0.1; confirm against the original.

    Args:
        model: An uncompiled single-output Keras model (logits).

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
            keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    # Evaluate the best checkpoint rather than the final-epoch weights.
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history

# Build the logits-only ViT, then train and evaluate it. The returned History
# drives the plots further down.
vit_classifier = create_vit_classifier(with_attention_scores_output=False)
history = run_experiment(model=vit_classifier)
Epoch 1/100
11/11 [==============================] - 19s 634ms/step - loss: 4.5576 - accuracy: 0.1012 - top-5-accuracy: 0.2982 - val_loss: 2.7751 - val_accuracy: 0.2599 - val_top-5-accuracy: 0.5691
Epoch 2/100
11/11 [==============================] - 5s 468ms/step - loss: 3.0885 - accuracy: 0.1970 - top-5-accuracy: 0.4952 - val_loss: 2.4063 - val_accuracy: 0.3191 - val_top-5-accuracy: 0.6743
Epoch 3/100
11/11 [==============================] - 5s 451ms/step - loss: 2.7145 - accuracy: 0.2657 - top-5-accuracy: 0.6023 - val_loss: 2.1349 - val_accuracy: 0.3947 - val_top-5-accuracy: 0.7533
Epoch 4/100
11/11 [==============================] - 5s 478ms/step - loss: 2.4462 - accuracy: 0.3249 - top-5-accuracy: 0.6703 - val_loss: 2.0473 - val_accuracy: 0.4243 - val_top-5-accuracy: 0.7829
Epoch 5/100
11/11 [==============================] - 5s 468ms/step - loss: 2.2455 - accuracy: 0.3644 - top-5-accuracy: 0.7251 - val_loss: 1.8077 - val_accuracy: 0.4934 - val_top-5-accuracy: 0.8388
Epoch 6/100
11/11 [==============================] - 5s 454ms/step - loss: 2.0483 - accuracy: 0.4090 - top-5-accuracy: 0.7686 - val_loss: 1.5991 - val_accuracy: 0.5296 - val_top-5-accuracy: 0.8717
Epoch 7/100
11/11 [==============================] - 5s 462ms/step - loss: 1.8823 - accuracy: 0.4529 - top-5-accuracy: 0.8132 - val_loss: 1.4373 - val_accuracy: 0.5757 - val_top-5-accuracy: 0.9178
Epoch 8/100
11/11 [==============================] - 5s 462ms/step - loss: 1.7941 - accuracy: 0.4762 - top-5-accuracy: 0.8326 - val_loss: 1.3213 - val_accuracy: 0.6513 - val_top-5-accuracy: 0.9046
Epoch 9/100
11/11 [==============================] - 5s 482ms/step - loss: 1.6349 - accuracy: 0.5128 - top-5-accuracy: 0.8578 - val_loss: 1.3080 - val_accuracy: 0.5921 - val_top-5-accuracy: 0.9112
Epoch 10/100
11/11 [==============================] - 5s 463ms/step - loss: 1.4902 - accuracy: 0.5585 - top-5-accuracy: 0.8819 - val_loss: 1.1707 - val_accuracy: 0.6546 - val_top-5-accuracy: 0.9276
Epoch 11/100
11/11 [==============================] - 5s 478ms/step - loss: 1.3675 - accuracy: 0.5936 - top-5-accuracy: 0.9046 - val_loss: 1.0757 - val_accuracy: 0.6809 - val_top-5-accuracy: 0.9408
Epoch 12/100
11/11 [==============================] - 5s 465ms/step - loss: 1.3511 - accuracy: 0.5826 - top-5-accuracy: 0.9050 - val_loss: 1.0533 - val_accuracy: 0.6546 - val_top-5-accuracy: 0.9605
Epoch 13/100
11/11 [==============================] - 5s 477ms/step - loss: 1.2636 - accuracy: 0.6107 - top-5-accuracy: 0.9192 - val_loss: 0.9733 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9474
Epoch 14/100
11/11 [==============================] - 5s 505ms/step - loss: 1.1770 - accuracy: 0.6316 - top-5-accuracy: 0.9284 - val_loss: 1.0223 - val_accuracy: 0.7072 - val_top-5-accuracy: 0.9539
Epoch 15/100
11/11 [==============================] - 6s 531ms/step - loss: 1.0597 - accuracy: 0.6678 - top-5-accuracy: 0.9382 - val_loss: 0.9859 - val_accuracy: 0.7204 - val_top-5-accuracy: 0.9539
Epoch 16/100
11/11 [==============================] - 5s 479ms/step - loss: 1.0553 - accuracy: 0.6732 - top-5-accuracy: 0.9474 - val_loss: 0.9171 - val_accuracy: 0.7237 - val_top-5-accuracy: 0.9539
Epoch 17/100
11/11 [==============================] - 5s 469ms/step - loss: 0.9795 - accuracy: 0.6919 - top-5-accuracy: 0.9503 - val_loss: 0.9283 - val_accuracy: 0.7303 - val_top-5-accuracy: 0.9572
Epoch 18/100
11/11 [==============================] - 5s 470ms/step - loss: 0.9205 - accuracy: 0.7069 - top-5-accuracy: 0.9583 - val_loss: 0.8204 - val_accuracy: 0.7961 - val_top-5-accuracy: 0.9737
Epoch 19/100
11/11 [==============================] - 5s 469ms/step - loss: 0.9587 - accuracy: 0.6959 - top-5-accuracy: 0.9594 - val_loss: 0.7994 - val_accuracy: 0.7664 - val_top-5-accuracy: 0.9638
Epoch 20/100
11/11 [==============================] - 5s 503ms/step - loss: 0.8644 - accuracy: 0.7336 - top-5-accuracy: 0.9635 - val_loss: 0.7162 - val_accuracy: 0.7697 - val_top-5-accuracy: 0.9671
Epoch 21/100
11/11 [==============================] - 5s 478ms/step - loss: 0.8480 - accuracy: 0.7288 - top-5-accuracy: 0.9653 - val_loss: 0.7097 - val_accuracy: 0.7862 - val_top-5-accuracy: 0.9737
Epoch 22/100
11/11 [==============================] - 5s 483ms/step - loss: 0.7660 - accuracy: 0.7518 - top-5-accuracy: 0.9730 - val_loss: 0.7207 - val_accuracy: 0.7763 - val_top-5-accuracy: 0.9704
Epoch 23/100
11/11 [==============================] - 5s 470ms/step - loss: 0.7501 - accuracy: 0.7664 - top-5-accuracy: 0.9671 - val_loss: 0.6781 - val_accuracy: 0.7961 - val_top-5-accuracy: 0.9737
Epoch 24/100
11/11 [==============================] - 5s 467ms/step - loss: 0.7344 - accuracy: 0.7683 - top-5-accuracy: 0.9726 - val_loss: 0.8229 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9704
Epoch 25/100
11/11 [==============================] - 5s 470ms/step - loss: 0.7047 - accuracy: 0.7756 - top-5-accuracy: 0.9744 - val_loss: 0.6895 - val_accuracy: 0.7829 - val_top-5-accuracy: 0.9671
Epoch 26/100
11/11 [==============================] - 5s 467ms/step - loss: 0.6215 - accuracy: 0.7968 - top-5-accuracy: 0.9784 - val_loss: 0.7865 - val_accuracy: 0.7566 - val_top-5-accuracy: 0.9704
Epoch 27/100
11/11 [==============================] - 5s 485ms/step - loss: 0.6753 - accuracy: 0.7880 - top-5-accuracy: 0.9799 - val_loss: 0.6987 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9737
Epoch 28/100
11/11 [==============================] - 5s 467ms/step - loss: 0.6023 - accuracy: 0.8085 - top-5-accuracy: 0.9814 - val_loss: 0.6609 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9770
Epoch 29/100
11/11 [==============================] - 5s 482ms/step - loss: 0.5867 - accuracy: 0.8077 - top-5-accuracy: 0.9828 - val_loss: 0.6644 - val_accuracy: 0.7961 - val_top-5-accuracy: 0.9704
Epoch 30/100
11/11 [==============================] - 5s 470ms/step - loss: 0.5852 - accuracy: 0.8216 - top-5-accuracy: 0.9825 - val_loss: 0.5996 - val_accuracy: 0.8158 - val_top-5-accuracy: 0.9737
Epoch 31/100
11/11 [==============================] - 5s 470ms/step - loss: 0.4852 - accuracy: 0.8454 - top-5-accuracy: 0.9894 - val_loss: 0.6340 - val_accuracy: 0.8224 - val_top-5-accuracy: 0.9704
Epoch 32/100
11/11 [==============================] - 5s 491ms/step - loss: 0.4860 - accuracy: 0.8425 - top-5-accuracy: 0.9857 - val_loss: 0.7302 - val_accuracy: 0.7993 - val_top-5-accuracy: 0.9704
Epoch 33/100
11/11 [==============================] - 5s 475ms/step - loss: 0.4684 - accuracy: 0.8428 - top-5-accuracy: 0.9857 - val_loss: 0.5790 - val_accuracy: 0.8257 - val_top-5-accuracy: 0.9704
Epoch 34/100
11/11 [==============================] - 5s 471ms/step - loss: 0.4278 - accuracy: 0.8626 - top-5-accuracy: 0.9883 - val_loss: 0.5776 - val_accuracy: 0.8092 - val_top-5-accuracy: 0.9836
Epoch 35/100
11/11 [==============================] - 5s 468ms/step - loss: 0.4116 - accuracy: 0.8607 - top-5-accuracy: 0.9949 - val_loss: 0.5644 - val_accuracy: 0.8388 - val_top-5-accuracy: 0.9770
Epoch 36/100
11/11 [==============================] - 5s 476ms/step - loss: 0.4202 - accuracy: 0.8626 - top-5-accuracy: 0.9920 - val_loss: 0.6480 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9770
Epoch 37/100
11/11 [==============================] - 5s 471ms/step - loss: 0.4682 - accuracy: 0.8520 - top-5-accuracy: 0.9898 - val_loss: 0.6602 - val_accuracy: 0.8059 - val_top-5-accuracy: 0.9638
Epoch 38/100
11/11 [==============================] - 5s 489ms/step - loss: 0.4195 - accuracy: 0.8648 - top-5-accuracy: 0.9931 - val_loss: 0.5412 - val_accuracy: 0.8158 - val_top-5-accuracy: 0.9737
Epoch 39/100
11/11 [==============================] - 5s 469ms/step - loss: 0.3823 - accuracy: 0.8750 - top-5-accuracy: 0.9945 - val_loss: 0.6372 - val_accuracy: 0.8224 - val_top-5-accuracy: 0.9704
Epoch 40/100
11/11 [==============================] - 5s 470ms/step - loss: 0.3818 - accuracy: 0.8812 - top-5-accuracy: 0.9956 - val_loss: 0.6658 - val_accuracy: 0.7993 - val_top-5-accuracy: 0.9737
Epoch 41/100
11/11 [==============================] - 5s 481ms/step - loss: 0.3293 - accuracy: 0.8933 - top-5-accuracy: 0.9952 - val_loss: 0.6289 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9671
Epoch 42/100
11/11 [==============================] - 5s 469ms/step - loss: 0.2960 - accuracy: 0.9031 - top-5-accuracy: 0.9967 - val_loss: 0.5719 - val_accuracy: 0.8257 - val_top-5-accuracy: 0.9770
Epoch 43/100
11/11 [==============================] - 5s 486ms/step - loss: 0.3025 - accuracy: 0.9010 - top-5-accuracy: 0.9956 - val_loss: 0.5237 - val_accuracy: 0.8388 - val_top-5-accuracy: 0.9803
Epoch 44/100
11/11 [==============================] - 5s 479ms/step - loss: 0.3484 - accuracy: 0.8966 - top-5-accuracy: 0.9942 - val_loss: 0.6792 - val_accuracy: 0.7862 - val_top-5-accuracy: 0.9803
Epoch 45/100
11/11 [==============================] - 5s 472ms/step - loss: 0.2933 - accuracy: 0.9046 - top-5-accuracy: 0.9971 - val_loss: 0.5624 - val_accuracy: 0.8421 - val_top-5-accuracy: 0.9868
Epoch 46/100
11/11 [==============================] - 5s 466ms/step - loss: 0.3030 - accuracy: 0.9013 - top-5-accuracy: 0.9985 - val_loss: 0.6261 - val_accuracy: 0.7961 - val_top-5-accuracy: 0.9770
Epoch 47/100
11/11 [==============================] - 5s 469ms/step - loss: 0.3210 - accuracy: 0.8977 - top-5-accuracy: 0.9960 - val_loss: 0.5802 - val_accuracy: 0.8388 - val_top-5-accuracy: 0.9836
Epoch 48/100
11/11 [==============================] - 5s 468ms/step - loss: 0.3209 - accuracy: 0.8958 - top-5-accuracy: 0.9971 - val_loss: 0.6645 - val_accuracy: 0.8158 - val_top-5-accuracy: 0.9803
Epoch 49/100
11/11 [==============================] - 5s 469ms/step - loss: 0.3030 - accuracy: 0.9134 - top-5-accuracy: 0.9945 - val_loss: 0.5909 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9770
Epoch 50/100
11/11 [==============================] - 5s 489ms/step - loss: 0.2841 - accuracy: 0.9061 - top-5-accuracy: 0.9971 - val_loss: 0.6537 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9671
Epoch 51/100
11/11 [==============================] - 5s 480ms/step - loss: 0.2866 - accuracy: 0.9072 - top-5-accuracy: 0.9982 - val_loss: 0.5929 - val_accuracy: 0.8224 - val_top-5-accuracy: 0.9770
Epoch 52/100
11/11 [==============================] - 5s 468ms/step - loss: 0.2941 - accuracy: 0.9046 - top-5-accuracy: 0.9974 - val_loss: 0.5234 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9803
Epoch 53/100
11/11 [==============================] - 5s 473ms/step - loss: 0.2900 - accuracy: 0.9141 - top-5-accuracy: 0.9956 - val_loss: 0.6007 - val_accuracy: 0.8289 - val_top-5-accuracy: 0.9737
Epoch 54/100
11/11 [==============================] - 5s 471ms/step - loss: 0.2311 - accuracy: 0.9280 - top-5-accuracy: 0.9993 - val_loss: 0.5893 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9803
Epoch 55/100
11/11 [==============================] - 5s 478ms/step - loss: 0.2484 - accuracy: 0.9218 - top-5-accuracy: 0.9985 - val_loss: 0.6348 - val_accuracy: 0.8355 - val_top-5-accuracy: 0.9770
Epoch 56/100
11/11 [==============================] - 5s 496ms/step - loss: 0.2685 - accuracy: 0.9159 - top-5-accuracy: 0.9956 - val_loss: 0.5722 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9770
Epoch 57/100
11/11 [==============================] - 5s 478ms/step - loss: 0.2363 - accuracy: 0.9287 - top-5-accuracy: 0.9989 - val_loss: 0.5732 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9737
Epoch 58/100
11/11 [==============================] - 5s 470ms/step - loss: 0.2536 - accuracy: 0.9229 - top-5-accuracy: 0.9982 - val_loss: 0.5657 - val_accuracy: 0.8421 - val_top-5-accuracy: 0.9836
Epoch 59/100
11/11 [==============================] - 5s 466ms/step - loss: 0.2083 - accuracy: 0.9393 - top-5-accuracy: 0.9971 - val_loss: 0.5523 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9704
Epoch 60/100
11/11 [==============================] - 5s 466ms/step - loss: 0.2203 - accuracy: 0.9306 - top-5-accuracy: 0.9989 - val_loss: 0.6903 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9737
Epoch 61/100
11/11 [==============================] - 5s 470ms/step - loss: 0.2407 - accuracy: 0.9207 - top-5-accuracy: 0.9982 - val_loss: 0.5560 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9868
Epoch 62/100
11/11 [==============================] - 5s 483ms/step - loss: 0.2638 - accuracy: 0.9152 - top-5-accuracy: 0.9967 - val_loss: 0.5574 - val_accuracy: 0.8553 - val_top-5-accuracy: 0.9803
Epoch 63/100
11/11 [==============================] - 5s 476ms/step - loss: 0.2159 - accuracy: 0.9317 - top-5-accuracy: 0.9985 - val_loss: 0.6279 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9836
Epoch 64/100
11/11 [==============================] - 5s 463ms/step - loss: 0.2551 - accuracy: 0.9251 - top-5-accuracy: 0.9967 - val_loss: 0.6940 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9803
Epoch 65/100
11/11 [==============================] - 5s 479ms/step - loss: 0.2325 - accuracy: 0.9258 - top-5-accuracy: 0.9989 - val_loss: 0.6324 - val_accuracy: 0.8355 - val_top-5-accuracy: 0.9770
Epoch 66/100
11/11 [==============================] - 5s 469ms/step - loss: 0.2144 - accuracy: 0.9364 - top-5-accuracy: 0.9982 - val_loss: 0.6224 - val_accuracy: 0.8454 - val_top-5-accuracy: 0.9836
Epoch 67/100
11/11 [==============================] - 5s 470ms/step - loss: 0.2023 - accuracy: 0.9346 - top-5-accuracy: 0.9993 - val_loss: 0.5875 - val_accuracy: 0.8553 - val_top-5-accuracy: 0.9704
Epoch 68/100
11/11 [==============================] - 5s 475ms/step - loss: 0.1912 - accuracy: 0.9375 - top-5-accuracy: 0.9989 - val_loss: 0.6924 - val_accuracy: 0.8224 - val_top-5-accuracy: 0.9737
Epoch 69/100
11/11 [==============================] - 5s 472ms/step - loss: 0.2205 - accuracy: 0.9295 - top-5-accuracy: 0.9974 - val_loss: 0.6561 - val_accuracy: 0.8257 - val_top-5-accuracy: 0.9737
Epoch 70/100
11/11 [==============================] - 5s 471ms/step - loss: 0.2002 - accuracy: 0.9368 - top-5-accuracy: 0.9985 - val_loss: 0.5291 - val_accuracy: 0.8651 - val_top-5-accuracy: 0.9868
Epoch 71/100
11/11 [==============================] - 5s 468ms/step - loss: 0.1925 - accuracy: 0.9382 - top-5-accuracy: 0.9993 - val_loss: 0.5830 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9836
Epoch 72/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1846 - accuracy: 0.9430 - top-5-accuracy: 0.9996 - val_loss: 0.6433 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9803
Epoch 73/100
11/11 [==============================] - 5s 470ms/step - loss: 0.1780 - accuracy: 0.9444 - top-5-accuracy: 0.9993 - val_loss: 0.5253 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9836
Epoch 74/100
11/11 [==============================] - 6s 505ms/step - loss: 0.1646 - accuracy: 0.9485 - top-5-accuracy: 0.9982 - val_loss: 0.5431 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9803
Epoch 75/100
11/11 [==============================] - 5s 474ms/step - loss: 0.1678 - accuracy: 0.9470 - top-5-accuracy: 0.9978 - val_loss: 0.6069 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9868
Epoch 76/100
11/11 [==============================] - 5s 469ms/step - loss: 0.1517 - accuracy: 0.9488 - top-5-accuracy: 0.9985 - val_loss: 0.5667 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9901
Epoch 77/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1416 - accuracy: 0.9525 - top-5-accuracy: 0.9989 - val_loss: 0.6547 - val_accuracy: 0.8421 - val_top-5-accuracy: 0.9803
Epoch 78/100
11/11 [==============================] - 5s 479ms/step - loss: 0.1623 - accuracy: 0.9477 - top-5-accuracy: 0.9989 - val_loss: 0.5051 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9934
Epoch 79/100
11/11 [==============================] - 5s 475ms/step - loss: 0.1666 - accuracy: 0.9488 - top-5-accuracy: 0.9989 - val_loss: 0.6188 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9737
Epoch 80/100
11/11 [==============================] - 5s 478ms/step - loss: 0.2011 - accuracy: 0.9368 - top-5-accuracy: 0.9985 - val_loss: 0.6103 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9803
Epoch 81/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1709 - accuracy: 0.9488 - top-5-accuracy: 0.9985 - val_loss: 0.5960 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9671
Epoch 82/100
11/11 [==============================] - 5s 480ms/step - loss: 0.1613 - accuracy: 0.9477 - top-5-accuracy: 0.9993 - val_loss: 0.6275 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9737
Epoch 83/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1537 - accuracy: 0.9507 - top-5-accuracy: 0.9985 - val_loss: 0.6007 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9803
Epoch 84/100
11/11 [==============================] - 5s 477ms/step - loss: 0.1711 - accuracy: 0.9481 - top-5-accuracy: 0.9982 - val_loss: 0.5334 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9803
Epoch 85/100
11/11 [==============================] - 5s 468ms/step - loss: 0.1666 - accuracy: 0.9496 - top-5-accuracy: 0.9989 - val_loss: 0.6471 - val_accuracy: 0.8421 - val_top-5-accuracy: 0.9770
Epoch 86/100
11/11 [==============================] - 5s 483ms/step - loss: 0.1671 - accuracy: 0.9496 - top-5-accuracy: 0.9996 - val_loss: 0.7173 - val_accuracy: 0.8289 - val_top-5-accuracy: 0.9836
Epoch 87/100
11/11 [==============================] - 5s 463ms/step - loss: 0.1811 - accuracy: 0.9426 - top-5-accuracy: 0.9985 - val_loss: 0.7476 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9836
Epoch 88/100
11/11 [==============================] - 5s 482ms/step - loss: 0.1487 - accuracy: 0.9547 - top-5-accuracy: 1.0000 - val_loss: 0.5163 - val_accuracy: 0.8717 - val_top-5-accuracy: 0.9868
Epoch 89/100
11/11 [==============================] - 5s 470ms/step - loss: 0.1241 - accuracy: 0.9609 - top-5-accuracy: 1.0000 - val_loss: 0.5221 - val_accuracy: 0.8717 - val_top-5-accuracy: 0.9868
Epoch 90/100
11/11 [==============================] - 5s 484ms/step - loss: 0.1255 - accuracy: 0.9613 - top-5-accuracy: 0.9993 - val_loss: 0.4853 - val_accuracy: 0.8750 - val_top-5-accuracy: 0.9868
Epoch 91/100
11/11 [==============================] - 5s 488ms/step - loss: 0.1113 - accuracy: 0.9653 - top-5-accuracy: 1.0000 - val_loss: 0.5162 - val_accuracy: 0.8882 - val_top-5-accuracy: 0.9901
Epoch 92/100
11/11 [==============================] - 5s 476ms/step - loss: 0.0999 - accuracy: 0.9638 - top-5-accuracy: 0.9996 - val_loss: 0.4899 - val_accuracy: 0.8882 - val_top-5-accuracy: 0.9934
Epoch 93/100
11/11 [==============================] - 5s 466ms/step - loss: 0.1115 - accuracy: 0.9645 - top-5-accuracy: 0.9993 - val_loss: 0.6128 - val_accuracy: 0.8520 - val_top-5-accuracy: 0.9934
Epoch 94/100
11/11 [==============================] - 5s 474ms/step - loss: 0.1242 - accuracy: 0.9580 - top-5-accuracy: 1.0000 - val_loss: 0.5649 - val_accuracy: 0.8849 - val_top-5-accuracy: 0.9836
Epoch 95/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1530 - accuracy: 0.9569 - top-5-accuracy: 0.9996 - val_loss: 0.5064 - val_accuracy: 0.8651 - val_top-5-accuracy: 0.9836
Epoch 96/100
11/11 [==============================] - 5s 475ms/step - loss: 0.1813 - accuracy: 0.9496 - top-5-accuracy: 0.9989 - val_loss: 0.5598 - val_accuracy: 0.8914 - val_top-5-accuracy: 0.9836
Epoch 97/100
11/11 [==============================] - 5s 488ms/step - loss: 0.1615 - accuracy: 0.9496 - top-5-accuracy: 0.9993 - val_loss: 0.6164 - val_accuracy: 0.8454 - val_top-5-accuracy: 0.9770
Epoch 98/100
11/11 [==============================] - 5s 482ms/step - loss: 0.1599 - accuracy: 0.9499 - top-5-accuracy: 0.9996 - val_loss: 0.6491 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9836
Epoch 99/100
11/11 [==============================] - 5s 470ms/step - loss: 0.1286 - accuracy: 0.9576 - top-5-accuracy: 0.9993 - val_loss: 0.5429 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9836
Epoch 100/100
11/11 [==============================] - 5s 469ms/step - loss: 0.1263 - accuracy: 0.9627 - top-5-accuracy: 0.9989 - val_loss: 0.6173 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9868
24/24 [==============================] - 1s 36ms/step - loss: 0.6115 - accuracy: 0.8579 - top-5-accuracy: 0.9750
Test accuracy: 85.79%
Test top 5 accuracy: 97.5%
# Per-epoch metrics recorded by model.fit in run_experiment.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
top_5_acc = history.history['top-5-accuracy']
top_5_val_acc = history.history['val_top-5-accuracy']
epochs = range(1, len(acc) + 1)

# Give each metric its own figure. Without plt.figure(), all curves would be
# drawn onto the same axes.
plt.figure()
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and Validation accuracy')
plt.legend()

# The top-5 figure previously plotted acc/val_acc; use the top-5 series.
plt.figure()
plt.plot(epochs, top_5_acc, 'b', label='Top-5 Training accuracy')
plt.plot(epochs, top_5_val_acc, 'r', label='Top-5 Validation accuracy')
plt.title('Training and Validation Top-5 accuracy')
plt.legend()

plt.figure()
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation loss')
plt.legend()
plt.show()
# Output (history.history.keys()):
# dict_keys(['loss', 'accuracy', 'top-5-accuracy', 'val_loss', 'val_accuracy', 'val_top-5-accuracy'])


# Rebuild the same architecture with the attention scores as an extra output.
# Then load the best weights that run_experiment checkpointed; without this the
# attention comes from a freshly initialised, untrained model.
# NOTE(review): this relies on both models being built by create_vit_classifier
# with the same layer structure, so the weight checkpoint lines up. Confirm if
# the architecture changes.
vit_classifier_with_att_scores = create_vit_classifier(with_attention_scores_output=True)
checkpoint_filepath = "/tmp/checkpoint"
vit_classifier_with_att_scores.load_weights(checkpoint_filepath)

plt.figure(figsize=(4, 4))
image_to_predict = x_train[np.random.choice(range(x_train.shape[0]))]
# Undo the input scaling for drawing. Clip so the uint8 casts downstream can
# never wrap around. The pixels stay in cv2's BGR order.
image_to_draw = np.clip(image_to_predict * 255.0, 0, 255)
plt.imshow(image_to_draw[..., ::-1].astype("uint8"))  # BGR -> RGB for matplotlib
plt.axis("off")
image_to_predict = np.expand_dims(image_to_predict, axis=0)
output, attention_scores = vit_classifier_with_att_scores(image_to_predict)
def attention_map(src_image, attention_score_dict, query_index=0):
    """Compute an attention-rollout mask and apply it to ``src_image``.

    The per-layer attention is averaged over heads. An identity matrix is added
    to account for the residual connections and each row is re-normalised.
    The layer matrices are then multiplied together ("attention rollout").

    Args:
        src_image: (H, W, C) image with values in [0, 255]; it is multiplied by
            the mask.
        attention_score_dict: Mapping from Transformer block index to attention
            scores, as output by
            ``create_vit_classifier(with_attention_scores_output=True)``. The
            batch axis (axis 1 after stacking) is squeezed, so a batch of one
            image is assumed.
        query_index: Row of the rolled-out (N, N) matrix to visualise. The
            model has no class token, so the default ``0`` shows the attention
            of the top-left patch. Pass ``None`` to average over all query
            patches instead.

    Returns:
        A ``uint8`` array shaped like ``src_image``: the image scaled by the
        rollout mask, normalised to a maximum of 1.
    """
    attention_mat = tf.stack(list(attention_score_dict.values()))
    attention_mat = tf.squeeze(attention_mat, axis=1)

    # Average attention (for all heads) per MultiHeadAttention layers
    attention_mat = tf.reduce_mean(attention_mat, axis=1)

    # "to account for residual connections, we add an identity matrix to the attention matrix and re-normalize the weights"
    residual_attn = tf.eye(attention_mat.shape[1])
    aug_attention_mat = attention_mat + residual_attn
    aug_attention_mat = np.array(aug_attention_mat / tf.reduce_sum(aug_attention_mat, axis=-1)[..., None])

    # Recursive formula: joint_n = A_n @ joint_(n-1)
    joint_attentions = np.zeros(aug_attention_mat.shape)
    joint_attentions[0] = aug_attention_mat[0]
    for n in range(1, aug_attention_mat.shape[0]):
        joint_attentions[n] = np.matmul(aug_attention_mat[n], joint_attentions[n - 1])

    # Take the last rolled-out matrix and pick (or average) the query row(s).
    v = joint_attentions[-1]
    patch_weights = v.mean(axis=0) if query_index is None else v[query_index]
    # Assumes a square patch grid.
    grid_size = int(np.sqrt(aug_attention_mat.shape[-1]))
    mask = patch_weights.reshape(grid_size, grid_size)
    # cv2.resize takes dsize as (width, height): the reverse of numpy's shape.
    height, width = src_image.shape[:2]
    mask = cv2.resize(mask / mask.max(), (width, height))[..., np.newaxis]
    result = (mask * src_image).astype("uint8")
    return result
attn_rollout_result = attention_map(image_to_draw, attention_scores)

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 10))

# The images were loaded with cv2 (BGR); reverse the channel axis for matplotlib.
ax1.set_title("Original")
ax1.imshow(np.clip(image_to_draw, 0, 255).astype("uint8")[..., ::-1])
ax1.axis("off")
ax2.set_title("Attention Map")
ax2.imshow(attn_rollout_result[..., ::-1])
ax2.axis("off")
(8, 1, 4, 144, 144)
(8, 4, 144, 144)
(8, 144, 144)
(144, 144)
(8, 144, 144)