wko_projekt/Projekt widzenie komputerowe.ipynb

#!git clone https://github.com/spMohanty/PlantVillage-Dataset
#!cd PlantVillage-Dataset
import numpy as np
import pickle
import cv2
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
from os import listdir, mkdir
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import img_to_array  # newer Keras: from keras.utils import img_to_array
import matplotlib.pyplot as plt
data_dir = "/kaggle/input/plantvillage-dataset/color"

classes = listdir(data_dir)
classes
['Tomato___Late_blight',
 'Tomato___healthy',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Potato___healthy',
 'Corn_(maize)___Northern_Leaf_Blight',
 'Tomato___Early_blight',
 'Tomato___Septoria_leaf_spot',
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 'Strawberry___Leaf_scorch',
 'Peach___healthy',
 'Apple___Apple_scab',
 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
 'Tomato___Bacterial_spot',
 'Apple___Black_rot',
 'Blueberry___healthy',
 'Cherry_(including_sour)___Powdery_mildew',
 'Peach___Bacterial_spot',
 'Apple___Cedar_apple_rust',
 'Tomato___Target_Spot',
 'Pepper,_bell___healthy',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Potato___Late_blight',
 'Tomato___Tomato_mosaic_virus',
 'Strawberry___healthy',
 'Apple___healthy',
 'Grape___Black_rot',
 'Potato___Early_blight',
 'Cherry_(including_sour)___healthy',
 'Corn_(maize)___Common_rust_',
 'Grape___Esca_(Black_Measles)',
 'Raspberry___healthy',
 'Tomato___Leaf_Mold',
 'Tomato___Spider_mites Two-spotted_spider_mite',
 'Pepper,_bell___Bacterial_spot',
 'Corn_(maize)___healthy']
#mkdir(data_dir + 'test_data')
#mkdir(data_dir + 'train_data')
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # Resize input images to this size
patch_size = 6  # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier
num_classes = 38
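As a quick sanity check on the patch arithmetic (an illustrative snippet, not part of the original pipeline): a 72x72 input split into non-overlapping 6x6 patches gives (72 // 6)**2 = 144 patches, each flattened to 6*6*3 = 108 RGB values, matching the numbers printed by the visualization cell further below.
# Illustrative check of the patch arithmetic used above.
patches_per_image = (image_size // patch_size) ** 2   # (72 // 6) ** 2 = 144
elements_per_patch = patch_size * patch_size * 3      # 6 * 6 * 3 = 108 (RGB)
print(f"Patches per image: {patches_per_image}, elements per patch: {elements_per_patch}")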
def convert_image_to_array(image_dir):
    image = cv2.imread(image_dir)
    if image is not None:
        return img_to_array(image)
    else:
        return np.array([])
image_list, label_list = [], []

print("Loading images ...")
root_dir = listdir(data_dir)

for plant_folder in root_dir:
    plant_disease_image_list = listdir(f"{data_dir}/{plant_folder}")
    print(len(plant_disease_image_list))
    for image in plant_disease_image_list[:100]:
        image_directory = f"{data_dir}/{plant_folder}/{image}"
        if image_directory.lower().endswith(".jpg"):
            image_list.append(convert_image_to_array(image_directory))
            label_list.append(plant_folder)
print("Image loading completed")  
Loading images ...
1909
1591
423
5507
5090
1835
152
985
1000
1771
513
1109
360
630
5357
2127
621
1502
1052
2297
275
1404
1478
1076
1000
373
456
1645
1180
1000
854
1192
1383
371
952
1676
997
1162
Image loading completed
label_binarizer = LabelBinarizer()
image_labels = label_binarizer.fit_transform(label_list)
with open('label_transform.pkl', 'wb') as f:
    pickle.dump(label_binarizer, f)
n_classes = len(label_binarizer.classes_)
print(label_binarizer.classes_)
['Apple___Apple_scab' 'Apple___Black_rot' 'Apple___Cedar_apple_rust'
 'Apple___healthy' 'Blueberry___healthy'
 'Cherry_(including_sour)___Powdery_mildew'
 'Cherry_(including_sour)___healthy'
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot'
 'Corn_(maize)___Common_rust_' 'Corn_(maize)___Northern_Leaf_Blight'
 'Corn_(maize)___healthy' 'Grape___Black_rot'
 'Grape___Esca_(Black_Measles)'
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)' 'Grape___healthy'
 'Orange___Haunglongbing_(Citrus_greening)' 'Peach___Bacterial_spot'
 'Peach___healthy' 'Pepper,_bell___Bacterial_spot'
 'Pepper,_bell___healthy' 'Potato___Early_blight' 'Potato___Late_blight'
 'Potato___healthy' 'Raspberry___healthy' 'Soybean___healthy'
 'Squash___Powdery_mildew' 'Strawberry___Leaf_scorch'
 'Strawberry___healthy' 'Tomato___Bacterial_spot' 'Tomato___Early_blight'
 'Tomato___Late_blight' 'Tomato___Leaf_Mold' 'Tomato___Septoria_leaf_spot'
 'Tomato___Spider_mites Two-spotted_spider_mite' 'Tomato___Target_Spot'
 'Tomato___Tomato_Yellow_Leaf_Curl_Virus' 'Tomato___Tomato_mosaic_virus'
 'Tomato___healthy']
np_image_list = np.array(image_list, dtype=np.float16) / 255.0
print("Spliting data to train, test")
x_train, x_test, y_train, y_test = train_test_split(np_image_list, image_labels, test_size=0.2, random_state = 42) 
Splitting data into train and test sets
input_shape = x_train[0].shape
print(input_shape)
(256, 256, 3)
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
2023-02-01 19:16:02.344060: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-01 19:16:05.656433: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13789 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
2023-02-01 19:16:05.660219: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13789 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
2023-02-01 19:16:08.194863: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 2390753280 exceeds 10% of free system memory.
2023-02-01 19:16:11.025936: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 2390753280 exceeds 10% of free system memory.
2023-02-01 19:16:12.991353: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
image = image * 255.0
plt.imshow(image.astype("uint8"))
plt.axis("off")

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")
Image size: 72 X 72
Patch size: 6 X 6
Patches per image: 144
Elements per patch: 108
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
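A quick shape check of the encoder (illustrative only, assuming the Patches and PatchEncoder cells above have been run): each flattened 108-value patch is projected to projection_dim and a learned position embedding is added, so the output has shape (batch, num_patches, projection_dim).
# Illustrative shape check of PatchEncoder (assumes the cells above are in scope).
dummy_patches = tf.random.uniform((1, num_patches, patch_size * patch_size * 3))
encoded_example = PatchEncoder(num_patches, projection_dim)(dummy_patches)
print(encoded_example.shape)  # expected: (1, 144, 64)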
def create_vit_classifier(with_attention_scores_output=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    
    attention_score_dict = dict()

    # Create multiple layers of the Transformer block.
    for i in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output, attention_score = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1, return_attention_scores=True)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])
        
        attention_score_dict[i] = attention_score

    # Create a [batch_size, projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    if with_attention_scores_output:
        model = keras.Model(inputs=inputs, outputs=[logits, attention_score_dict])
    else:
        model = keras.Model(inputs=inputs, outputs=logits)
    return model
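A small optional check (illustrative sketch, not part of the original training flow): instantiating the classifier once and counting its parameters confirms the model builds with the 8 Transformer blocks and the [2048, 1024] MLP head defined above.
# Illustrative: build the model once and report its size (sketch, optional).
_probe_model = create_vit_classifier()
print(f"Total parameters: {_probe_model.count_params():,}")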
def run_experiment(model):
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
            keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)
2023-02-01 19:16:22.696720: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 2151677952 exceeds 10% of free system memory.
2023-02-01 19:16:25.088615: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 2151677952 exceeds 10% of free system memory.
Epoch 1/100
11/11 [==============================] - 19s 634ms/step - loss: 4.5576 - accuracy: 0.1012 - top-5-accuracy: 0.2982 - val_loss: 2.7751 - val_accuracy: 0.2599 - val_top-5-accuracy: 0.5691
Epoch 2/100
11/11 [==============================] - 5s 468ms/step - loss: 3.0885 - accuracy: 0.1970 - top-5-accuracy: 0.4952 - val_loss: 2.4063 - val_accuracy: 0.3191 - val_top-5-accuracy: 0.6743
Epoch 3/100
11/11 [==============================] - 5s 451ms/step - loss: 2.7145 - accuracy: 0.2657 - top-5-accuracy: 0.6023 - val_loss: 2.1349 - val_accuracy: 0.3947 - val_top-5-accuracy: 0.7533
Epoch 4/100
11/11 [==============================] - 5s 478ms/step - loss: 2.4462 - accuracy: 0.3249 - top-5-accuracy: 0.6703 - val_loss: 2.0473 - val_accuracy: 0.4243 - val_top-5-accuracy: 0.7829
Epoch 5/100
11/11 [==============================] - 5s 468ms/step - loss: 2.2455 - accuracy: 0.3644 - top-5-accuracy: 0.7251 - val_loss: 1.8077 - val_accuracy: 0.4934 - val_top-5-accuracy: 0.8388
Epoch 6/100
11/11 [==============================] - 5s 454ms/step - loss: 2.0483 - accuracy: 0.4090 - top-5-accuracy: 0.7686 - val_loss: 1.5991 - val_accuracy: 0.5296 - val_top-5-accuracy: 0.8717
Epoch 7/100
11/11 [==============================] - 5s 462ms/step - loss: 1.8823 - accuracy: 0.4529 - top-5-accuracy: 0.8132 - val_loss: 1.4373 - val_accuracy: 0.5757 - val_top-5-accuracy: 0.9178
Epoch 8/100
11/11 [==============================] - 5s 462ms/step - loss: 1.7941 - accuracy: 0.4762 - top-5-accuracy: 0.8326 - val_loss: 1.3213 - val_accuracy: 0.6513 - val_top-5-accuracy: 0.9046
Epoch 9/100
11/11 [==============================] - 5s 482ms/step - loss: 1.6349 - accuracy: 0.5128 - top-5-accuracy: 0.8578 - val_loss: 1.3080 - val_accuracy: 0.5921 - val_top-5-accuracy: 0.9112
Epoch 10/100
11/11 [==============================] - 5s 463ms/step - loss: 1.4902 - accuracy: 0.5585 - top-5-accuracy: 0.8819 - val_loss: 1.1707 - val_accuracy: 0.6546 - val_top-5-accuracy: 0.9276
Epoch 11/100
11/11 [==============================] - 5s 478ms/step - loss: 1.3675 - accuracy: 0.5936 - top-5-accuracy: 0.9046 - val_loss: 1.0757 - val_accuracy: 0.6809 - val_top-5-accuracy: 0.9408
Epoch 12/100
11/11 [==============================] - 5s 465ms/step - loss: 1.3511 - accuracy: 0.5826 - top-5-accuracy: 0.9050 - val_loss: 1.0533 - val_accuracy: 0.6546 - val_top-5-accuracy: 0.9605
Epoch 13/100
11/11 [==============================] - 5s 477ms/step - loss: 1.2636 - accuracy: 0.6107 - top-5-accuracy: 0.9192 - val_loss: 0.9733 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9474
Epoch 14/100
11/11 [==============================] - 5s 505ms/step - loss: 1.1770 - accuracy: 0.6316 - top-5-accuracy: 0.9284 - val_loss: 1.0223 - val_accuracy: 0.7072 - val_top-5-accuracy: 0.9539
Epoch 15/100
11/11 [==============================] - 6s 531ms/step - loss: 1.0597 - accuracy: 0.6678 - top-5-accuracy: 0.9382 - val_loss: 0.9859 - val_accuracy: 0.7204 - val_top-5-accuracy: 0.9539
Epoch 16/100
11/11 [==============================] - 5s 479ms/step - loss: 1.0553 - accuracy: 0.6732 - top-5-accuracy: 0.9474 - val_loss: 0.9171 - val_accuracy: 0.7237 - val_top-5-accuracy: 0.9539
Epoch 17/100
11/11 [==============================] - 5s 469ms/step - loss: 0.9795 - accuracy: 0.6919 - top-5-accuracy: 0.9503 - val_loss: 0.9283 - val_accuracy: 0.7303 - val_top-5-accuracy: 0.9572
Epoch 18/100
11/11 [==============================] - 5s 470ms/step - loss: 0.9205 - accuracy: 0.7069 - top-5-accuracy: 0.9583 - val_loss: 0.8204 - val_accuracy: 0.7961 - val_top-5-accuracy: 0.9737
Epoch 19/100
11/11 [==============================] - 5s 469ms/step - loss: 0.9587 - accuracy: 0.6959 - top-5-accuracy: 0.9594 - val_loss: 0.7994 - val_accuracy: 0.7664 - val_top-5-accuracy: 0.9638
Epoch 20/100
11/11 [==============================] - 5s 503ms/step - loss: 0.8644 - accuracy: 0.7336 - top-5-accuracy: 0.9635 - val_loss: 0.7162 - val_accuracy: 0.7697 - val_top-5-accuracy: 0.9671
Epoch 21/100
11/11 [==============================] - 5s 478ms/step - loss: 0.8480 - accuracy: 0.7288 - top-5-accuracy: 0.9653 - val_loss: 0.7097 - val_accuracy: 0.7862 - val_top-5-accuracy: 0.9737
Epoch 22/100
11/11 [==============================] - 5s 483ms/step - loss: 0.7660 - accuracy: 0.7518 - top-5-accuracy: 0.9730 - val_loss: 0.7207 - val_accuracy: 0.7763 - val_top-5-accuracy: 0.9704
Epoch 23/100
11/11 [==============================] - 5s 470ms/step - loss: 0.7501 - accuracy: 0.7664 - top-5-accuracy: 0.9671 - val_loss: 0.6781 - val_accuracy: 0.7961 - val_top-5-accuracy: 0.9737
Epoch 24/100
11/11 [==============================] - 5s 467ms/step - loss: 0.7344 - accuracy: 0.7683 - top-5-accuracy: 0.9726 - val_loss: 0.8229 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9704
Epoch 25/100
11/11 [==============================] - 5s 470ms/step - loss: 0.7047 - accuracy: 0.7756 - top-5-accuracy: 0.9744 - val_loss: 0.6895 - val_accuracy: 0.7829 - val_top-5-accuracy: 0.9671
Epoch 26/100
11/11 [==============================] - 5s 467ms/step - loss: 0.6215 - accuracy: 0.7968 - top-5-accuracy: 0.9784 - val_loss: 0.7865 - val_accuracy: 0.7566 - val_top-5-accuracy: 0.9704
Epoch 27/100
11/11 [==============================] - 5s 485ms/step - loss: 0.6753 - accuracy: 0.7880 - top-5-accuracy: 0.9799 - val_loss: 0.6987 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9737
Epoch 28/100
11/11 [==============================] - 5s 467ms/step - loss: 0.6023 - accuracy: 0.8085 - top-5-accuracy: 0.9814 - val_loss: 0.6609 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9770
Epoch 29/100
11/11 [==============================] - 5s 482ms/step - loss: 0.5867 - accuracy: 0.8077 - top-5-accuracy: 0.9828 - val_loss: 0.6644 - val_accuracy: 0.7961 - val_top-5-accuracy: 0.9704
Epoch 30/100
11/11 [==============================] - 5s 470ms/step - loss: 0.5852 - accuracy: 0.8216 - top-5-accuracy: 0.9825 - val_loss: 0.5996 - val_accuracy: 0.8158 - val_top-5-accuracy: 0.9737
Epoch 31/100
11/11 [==============================] - 5s 470ms/step - loss: 0.4852 - accuracy: 0.8454 - top-5-accuracy: 0.9894 - val_loss: 0.6340 - val_accuracy: 0.8224 - val_top-5-accuracy: 0.9704
Epoch 32/100
11/11 [==============================] - 5s 491ms/step - loss: 0.4860 - accuracy: 0.8425 - top-5-accuracy: 0.9857 - val_loss: 0.7302 - val_accuracy: 0.7993 - val_top-5-accuracy: 0.9704
Epoch 33/100
11/11 [==============================] - 5s 475ms/step - loss: 0.4684 - accuracy: 0.8428 - top-5-accuracy: 0.9857 - val_loss: 0.5790 - val_accuracy: 0.8257 - val_top-5-accuracy: 0.9704
Epoch 34/100
11/11 [==============================] - 5s 471ms/step - loss: 0.4278 - accuracy: 0.8626 - top-5-accuracy: 0.9883 - val_loss: 0.5776 - val_accuracy: 0.8092 - val_top-5-accuracy: 0.9836
Epoch 35/100
11/11 [==============================] - 5s 468ms/step - loss: 0.4116 - accuracy: 0.8607 - top-5-accuracy: 0.9949 - val_loss: 0.5644 - val_accuracy: 0.8388 - val_top-5-accuracy: 0.9770
Epoch 36/100
11/11 [==============================] - 5s 476ms/step - loss: 0.4202 - accuracy: 0.8626 - top-5-accuracy: 0.9920 - val_loss: 0.6480 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9770
Epoch 37/100
11/11 [==============================] - 5s 471ms/step - loss: 0.4682 - accuracy: 0.8520 - top-5-accuracy: 0.9898 - val_loss: 0.6602 - val_accuracy: 0.8059 - val_top-5-accuracy: 0.9638
Epoch 38/100
11/11 [==============================] - 5s 489ms/step - loss: 0.4195 - accuracy: 0.8648 - top-5-accuracy: 0.9931 - val_loss: 0.5412 - val_accuracy: 0.8158 - val_top-5-accuracy: 0.9737
Epoch 39/100
11/11 [==============================] - 5s 469ms/step - loss: 0.3823 - accuracy: 0.8750 - top-5-accuracy: 0.9945 - val_loss: 0.6372 - val_accuracy: 0.8224 - val_top-5-accuracy: 0.9704
Epoch 40/100
11/11 [==============================] - 5s 470ms/step - loss: 0.3818 - accuracy: 0.8812 - top-5-accuracy: 0.9956 - val_loss: 0.6658 - val_accuracy: 0.7993 - val_top-5-accuracy: 0.9737
Epoch 41/100
11/11 [==============================] - 5s 481ms/step - loss: 0.3293 - accuracy: 0.8933 - top-5-accuracy: 0.9952 - val_loss: 0.6289 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9671
Epoch 42/100
11/11 [==============================] - 5s 469ms/step - loss: 0.2960 - accuracy: 0.9031 - top-5-accuracy: 0.9967 - val_loss: 0.5719 - val_accuracy: 0.8257 - val_top-5-accuracy: 0.9770
Epoch 43/100
11/11 [==============================] - 5s 486ms/step - loss: 0.3025 - accuracy: 0.9010 - top-5-accuracy: 0.9956 - val_loss: 0.5237 - val_accuracy: 0.8388 - val_top-5-accuracy: 0.9803
Epoch 44/100
11/11 [==============================] - 5s 479ms/step - loss: 0.3484 - accuracy: 0.8966 - top-5-accuracy: 0.9942 - val_loss: 0.6792 - val_accuracy: 0.7862 - val_top-5-accuracy: 0.9803
Epoch 45/100
11/11 [==============================] - 5s 472ms/step - loss: 0.2933 - accuracy: 0.9046 - top-5-accuracy: 0.9971 - val_loss: 0.5624 - val_accuracy: 0.8421 - val_top-5-accuracy: 0.9868
Epoch 46/100
11/11 [==============================] - 5s 466ms/step - loss: 0.3030 - accuracy: 0.9013 - top-5-accuracy: 0.9985 - val_loss: 0.6261 - val_accuracy: 0.7961 - val_top-5-accuracy: 0.9770
Epoch 47/100
11/11 [==============================] - 5s 469ms/step - loss: 0.3210 - accuracy: 0.8977 - top-5-accuracy: 0.9960 - val_loss: 0.5802 - val_accuracy: 0.8388 - val_top-5-accuracy: 0.9836
Epoch 48/100
11/11 [==============================] - 5s 468ms/step - loss: 0.3209 - accuracy: 0.8958 - top-5-accuracy: 0.9971 - val_loss: 0.6645 - val_accuracy: 0.8158 - val_top-5-accuracy: 0.9803
Epoch 49/100
11/11 [==============================] - 5s 469ms/step - loss: 0.3030 - accuracy: 0.9134 - top-5-accuracy: 0.9945 - val_loss: 0.5909 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9770
Epoch 50/100
11/11 [==============================] - 5s 489ms/step - loss: 0.2841 - accuracy: 0.9061 - top-5-accuracy: 0.9971 - val_loss: 0.6537 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9671
Epoch 51/100
11/11 [==============================] - 5s 480ms/step - loss: 0.2866 - accuracy: 0.9072 - top-5-accuracy: 0.9982 - val_loss: 0.5929 - val_accuracy: 0.8224 - val_top-5-accuracy: 0.9770
Epoch 52/100
11/11 [==============================] - 5s 468ms/step - loss: 0.2941 - accuracy: 0.9046 - top-5-accuracy: 0.9974 - val_loss: 0.5234 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9803
Epoch 53/100
11/11 [==============================] - 5s 473ms/step - loss: 0.2900 - accuracy: 0.9141 - top-5-accuracy: 0.9956 - val_loss: 0.6007 - val_accuracy: 0.8289 - val_top-5-accuracy: 0.9737
Epoch 54/100
11/11 [==============================] - 5s 471ms/step - loss: 0.2311 - accuracy: 0.9280 - top-5-accuracy: 0.9993 - val_loss: 0.5893 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9803
Epoch 55/100
11/11 [==============================] - 5s 478ms/step - loss: 0.2484 - accuracy: 0.9218 - top-5-accuracy: 0.9985 - val_loss: 0.6348 - val_accuracy: 0.8355 - val_top-5-accuracy: 0.9770
Epoch 56/100
11/11 [==============================] - 5s 496ms/step - loss: 0.2685 - accuracy: 0.9159 - top-5-accuracy: 0.9956 - val_loss: 0.5722 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9770
Epoch 57/100
11/11 [==============================] - 5s 478ms/step - loss: 0.2363 - accuracy: 0.9287 - top-5-accuracy: 0.9989 - val_loss: 0.5732 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9737
Epoch 58/100
11/11 [==============================] - 5s 470ms/step - loss: 0.2536 - accuracy: 0.9229 - top-5-accuracy: 0.9982 - val_loss: 0.5657 - val_accuracy: 0.8421 - val_top-5-accuracy: 0.9836
Epoch 59/100
11/11 [==============================] - 5s 466ms/step - loss: 0.2083 - accuracy: 0.9393 - top-5-accuracy: 0.9971 - val_loss: 0.5523 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9704
Epoch 60/100
11/11 [==============================] - 5s 466ms/step - loss: 0.2203 - accuracy: 0.9306 - top-5-accuracy: 0.9989 - val_loss: 0.6903 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9737
Epoch 61/100
11/11 [==============================] - 5s 470ms/step - loss: 0.2407 - accuracy: 0.9207 - top-5-accuracy: 0.9982 - val_loss: 0.5560 - val_accuracy: 0.8322 - val_top-5-accuracy: 0.9868
Epoch 62/100
11/11 [==============================] - 5s 483ms/step - loss: 0.2638 - accuracy: 0.9152 - top-5-accuracy: 0.9967 - val_loss: 0.5574 - val_accuracy: 0.8553 - val_top-5-accuracy: 0.9803
Epoch 63/100
11/11 [==============================] - 5s 476ms/step - loss: 0.2159 - accuracy: 0.9317 - top-5-accuracy: 0.9985 - val_loss: 0.6279 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9836
Epoch 64/100
11/11 [==============================] - 5s 463ms/step - loss: 0.2551 - accuracy: 0.9251 - top-5-accuracy: 0.9967 - val_loss: 0.6940 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9803
Epoch 65/100
11/11 [==============================] - 5s 479ms/step - loss: 0.2325 - accuracy: 0.9258 - top-5-accuracy: 0.9989 - val_loss: 0.6324 - val_accuracy: 0.8355 - val_top-5-accuracy: 0.9770
Epoch 66/100
11/11 [==============================] - 5s 469ms/step - loss: 0.2144 - accuracy: 0.9364 - top-5-accuracy: 0.9982 - val_loss: 0.6224 - val_accuracy: 0.8454 - val_top-5-accuracy: 0.9836
Epoch 67/100
11/11 [==============================] - 5s 470ms/step - loss: 0.2023 - accuracy: 0.9346 - top-5-accuracy: 0.9993 - val_loss: 0.5875 - val_accuracy: 0.8553 - val_top-5-accuracy: 0.9704
Epoch 68/100
11/11 [==============================] - 5s 475ms/step - loss: 0.1912 - accuracy: 0.9375 - top-5-accuracy: 0.9989 - val_loss: 0.6924 - val_accuracy: 0.8224 - val_top-5-accuracy: 0.9737
Epoch 69/100
11/11 [==============================] - 5s 472ms/step - loss: 0.2205 - accuracy: 0.9295 - top-5-accuracy: 0.9974 - val_loss: 0.6561 - val_accuracy: 0.8257 - val_top-5-accuracy: 0.9737
Epoch 70/100
11/11 [==============================] - 5s 471ms/step - loss: 0.2002 - accuracy: 0.9368 - top-5-accuracy: 0.9985 - val_loss: 0.5291 - val_accuracy: 0.8651 - val_top-5-accuracy: 0.9868
Epoch 71/100
11/11 [==============================] - 5s 468ms/step - loss: 0.1925 - accuracy: 0.9382 - top-5-accuracy: 0.9993 - val_loss: 0.5830 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9836
Epoch 72/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1846 - accuracy: 0.9430 - top-5-accuracy: 0.9996 - val_loss: 0.6433 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9803
Epoch 73/100
11/11 [==============================] - 5s 470ms/step - loss: 0.1780 - accuracy: 0.9444 - top-5-accuracy: 0.9993 - val_loss: 0.5253 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9836
Epoch 74/100
11/11 [==============================] - 6s 505ms/step - loss: 0.1646 - accuracy: 0.9485 - top-5-accuracy: 0.9982 - val_loss: 0.5431 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9803
Epoch 75/100
11/11 [==============================] - 5s 474ms/step - loss: 0.1678 - accuracy: 0.9470 - top-5-accuracy: 0.9978 - val_loss: 0.6069 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9868
Epoch 76/100
11/11 [==============================] - 5s 469ms/step - loss: 0.1517 - accuracy: 0.9488 - top-5-accuracy: 0.9985 - val_loss: 0.5667 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9901
Epoch 77/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1416 - accuracy: 0.9525 - top-5-accuracy: 0.9989 - val_loss: 0.6547 - val_accuracy: 0.8421 - val_top-5-accuracy: 0.9803
Epoch 78/100
11/11 [==============================] - 5s 479ms/step - loss: 0.1623 - accuracy: 0.9477 - top-5-accuracy: 0.9989 - val_loss: 0.5051 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9934
Epoch 79/100
11/11 [==============================] - 5s 475ms/step - loss: 0.1666 - accuracy: 0.9488 - top-5-accuracy: 0.9989 - val_loss: 0.6188 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9737
Epoch 80/100
11/11 [==============================] - 5s 478ms/step - loss: 0.2011 - accuracy: 0.9368 - top-5-accuracy: 0.9985 - val_loss: 0.6103 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9803
Epoch 81/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1709 - accuracy: 0.9488 - top-5-accuracy: 0.9985 - val_loss: 0.5960 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9671
Epoch 82/100
11/11 [==============================] - 5s 480ms/step - loss: 0.1613 - accuracy: 0.9477 - top-5-accuracy: 0.9993 - val_loss: 0.6275 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9737
Epoch 83/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1537 - accuracy: 0.9507 - top-5-accuracy: 0.9985 - val_loss: 0.6007 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9803
Epoch 84/100
11/11 [==============================] - 5s 477ms/step - loss: 0.1711 - accuracy: 0.9481 - top-5-accuracy: 0.9982 - val_loss: 0.5334 - val_accuracy: 0.8487 - val_top-5-accuracy: 0.9803
Epoch 85/100
11/11 [==============================] - 5s 468ms/step - loss: 0.1666 - accuracy: 0.9496 - top-5-accuracy: 0.9989 - val_loss: 0.6471 - val_accuracy: 0.8421 - val_top-5-accuracy: 0.9770
Epoch 86/100
11/11 [==============================] - 5s 483ms/step - loss: 0.1671 - accuracy: 0.9496 - top-5-accuracy: 0.9996 - val_loss: 0.7173 - val_accuracy: 0.8289 - val_top-5-accuracy: 0.9836
Epoch 87/100
11/11 [==============================] - 5s 463ms/step - loss: 0.1811 - accuracy: 0.9426 - top-5-accuracy: 0.9985 - val_loss: 0.7476 - val_accuracy: 0.8191 - val_top-5-accuracy: 0.9836
Epoch 88/100
11/11 [==============================] - 5s 482ms/step - loss: 0.1487 - accuracy: 0.9547 - top-5-accuracy: 1.0000 - val_loss: 0.5163 - val_accuracy: 0.8717 - val_top-5-accuracy: 0.9868
Epoch 89/100
11/11 [==============================] - 5s 470ms/step - loss: 0.1241 - accuracy: 0.9609 - top-5-accuracy: 1.0000 - val_loss: 0.5221 - val_accuracy: 0.8717 - val_top-5-accuracy: 0.9868
Epoch 90/100
11/11 [==============================] - 5s 484ms/step - loss: 0.1255 - accuracy: 0.9613 - top-5-accuracy: 0.9993 - val_loss: 0.4853 - val_accuracy: 0.8750 - val_top-5-accuracy: 0.9868
Epoch 91/100
11/11 [==============================] - 5s 488ms/step - loss: 0.1113 - accuracy: 0.9653 - top-5-accuracy: 1.0000 - val_loss: 0.5162 - val_accuracy: 0.8882 - val_top-5-accuracy: 0.9901
Epoch 92/100
11/11 [==============================] - 5s 476ms/step - loss: 0.0999 - accuracy: 0.9638 - top-5-accuracy: 0.9996 - val_loss: 0.4899 - val_accuracy: 0.8882 - val_top-5-accuracy: 0.9934
Epoch 93/100
11/11 [==============================] - 5s 466ms/step - loss: 0.1115 - accuracy: 0.9645 - top-5-accuracy: 0.9993 - val_loss: 0.6128 - val_accuracy: 0.8520 - val_top-5-accuracy: 0.9934
Epoch 94/100
11/11 [==============================] - 5s 474ms/step - loss: 0.1242 - accuracy: 0.9580 - top-5-accuracy: 1.0000 - val_loss: 0.5649 - val_accuracy: 0.8849 - val_top-5-accuracy: 0.9836
Epoch 95/100
11/11 [==============================] - 5s 471ms/step - loss: 0.1530 - accuracy: 0.9569 - top-5-accuracy: 0.9996 - val_loss: 0.5064 - val_accuracy: 0.8651 - val_top-5-accuracy: 0.9836
Epoch 96/100
11/11 [==============================] - 5s 475ms/step - loss: 0.1813 - accuracy: 0.9496 - top-5-accuracy: 0.9989 - val_loss: 0.5598 - val_accuracy: 0.8914 - val_top-5-accuracy: 0.9836
Epoch 97/100
11/11 [==============================] - 5s 488ms/step - loss: 0.1615 - accuracy: 0.9496 - top-5-accuracy: 0.9993 - val_loss: 0.6164 - val_accuracy: 0.8454 - val_top-5-accuracy: 0.9770
Epoch 98/100
11/11 [==============================] - 5s 482ms/step - loss: 0.1599 - accuracy: 0.9499 - top-5-accuracy: 0.9996 - val_loss: 0.6491 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9836
Epoch 99/100
11/11 [==============================] - 5s 470ms/step - loss: 0.1286 - accuracy: 0.9576 - top-5-accuracy: 0.9993 - val_loss: 0.5429 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9836
Epoch 100/100
11/11 [==============================] - 5s 469ms/step - loss: 0.1263 - accuracy: 0.9627 - top-5-accuracy: 0.9989 - val_loss: 0.6173 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9868
2023-02-01 19:25:51.290614: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 597688320 exceeds 10% of free system memory.
24/24 [==============================] - 1s 36ms/step - loss: 0.6115 - accuracy: 0.8579 - top-5-accuracy: 0.9750
Test accuracy: 85.79%
Test top 5 accuracy: 97.5%
print(history.history.keys())
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
top_5_acc = history.history['top-5-accuracy']
top_5_val_acc = history.history['val_top-5-accuracy']
epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and Validation accuracy')
plt.legend()

plt.figure()
plt.plot(epochs, top_5_acc, 'b', label='Top-5 Training accuracy')
plt.plot(epochs, top_5_val_acc, 'r', label='Top-5 Validation accuracy')
plt.title('Training and Validation Top-5 accuracy')
plt.legend()

plt.figure()
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation loss')
plt.legend()
plt.show()
dict_keys(['loss', 'accuracy', 'top-5-accuracy', 'val_loss', 'val_accuracy', 'val_top-5-accuracy'])

print(vit_classifier.predict(np.expand_dims(x_test[0], axis=0)))
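The model outputs raw logits, so to read the prediction as a class name one can, for example, apply a softmax and map the argmax back through the fitted LabelBinarizer (a minimal sketch using the binarizer pickled earlier as label_transform.pkl).
# Illustrative decoding of the logits into a human-readable class name (sketch).
logits = vit_classifier.predict(np.expand_dims(x_test[0], axis=0))
probs = tf.nn.softmax(logits, axis=-1).numpy()
predicted_label = label_binarizer.classes_[np.argmax(probs, axis=-1)[0]]
print(predicted_label, float(probs.max()))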

vit_classifier_with_att_scores = create_vit_classifier(with_attention_scores_output=True)
checkpoint_filepath = "/tmp/checkpoint"
vit_classifier_with_att_scores.load_weights(checkpoint_filepath)
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fc3eefcfd90>
plt.figure(figsize=(4, 4))
image_to_predict = x_train[np.random.choice(range(x_train.shape[0]))]
image_to_draw = image_to_predict * 255.0
plt.imshow(image_to_draw.astype("uint8"))
plt.axis("off")
(-0.5, 255.5, 255.5, -0.5)
image_to_predict = np.expand_dims(image_to_predict, axis=0)
output, attention_scores  = vit_classifier_with_att_scores(image_to_predict)
# computes attention rollout map
def attention_map(src_image, attention_score_dict):
    attention_mat = tf.stack(list(attention_score_dict.values()))
    print(attention_mat.shape)
    attention_mat = tf.squeeze(attention_mat, axis=1)
    print(attention_mat.shape)

    # Average the attention weights across all heads for each MultiHeadAttention layer
    attention_mat = tf.reduce_mean(attention_mat, axis=1)
    print(attention_mat.shape)

    # "to account for residual connections, we add an identity matrix to the attention matrix and re-normalize the weights"
    residual_attn = tf.eye(attention_mat.shape[1])
    print(residual_attn.shape)
    aug_attention_mat = attention_mat + residual_attn
    aug_attention_mat = np.array(aug_attention_mat / tf.reduce_sum(aug_attention_mat, axis=-1)[..., None])
    print(aug_attention_mat.shape)

    # Recursive formula on Vl+1
    joint_attentions = np.zeros(aug_attention_mat.shape)
    joint_attentions[0] = aug_attention_mat[0]
    for n in range(1, aug_attention_mat.shape[0]):
        joint_attentions[n] = np.matmul(aug_attention_mat[n], joint_attentions[n - 1])

    # Reshape the last accumulated attention matrix so it can be multiplied by the input RGB image
    v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_attention_mat.shape[-1]))
    mask = v[0].reshape(grid_size, grid_size)
    #print(src_image.shape[:2])
    #print(src_image.size)
    mask = cv2.resize(mask / mask.max(), src_image.shape[:2])[..., np.newaxis] #[: ,:, np.newaxis]
    #print(mask.shape)
    result = (mask * src_image).astype("uint8")
    return result
attn_rollout_result = attention_map(image_to_draw, attention_scores)

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 10))

ax1.imshow(image_to_draw.astype("uint8"))
ax2.imshow(attn_rollout_result)
ax1.axis("off")
ax2.axis("off")
ax1.set_title("Input")
ax2.set_title("Attention Map")
fig.tight_layout()
fig.show()
(8, 1, 4, 144, 144)
(8, 4, 144, 144)
(8, 144, 144)
(144, 144)
(8, 144, 144)