#!git clone https://github.com/spMohanty/PlantVillage-Dataset
#!cd PlantVillage-Dataset
import numpy as np
import pickle
import cv2
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
from os import listdir, mkdir
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
#from keras.utils import img_to_array
from keras.preprocessing.image import img_to_array
import matplotlib.pyplot as plt
data_dir = "/kaggle/input/plantvillage-dataset/color"
classes = listdir(data_dir)
classes
['Tomato___Late_blight', 'Tomato___healthy', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Potato___healthy', 'Corn_(maize)___Northern_Leaf_Blight', 'Tomato___Early_blight', 'Tomato___Septoria_leaf_spot', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Strawberry___Leaf_scorch', 'Peach___healthy', 'Apple___Apple_scab', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Bacterial_spot', 'Apple___Black_rot', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Peach___Bacterial_spot', 'Apple___Cedar_apple_rust', 'Tomato___Target_Spot', 'Pepper,_bell___healthy', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Potato___Late_blight', 'Tomato___Tomato_mosaic_virus', 'Strawberry___healthy', 'Apple___healthy', 'Grape___Black_rot', 'Potato___Early_blight', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Common_rust_', 'Grape___Esca_(Black_Measles)', 'Raspberry___healthy', 'Tomato___Leaf_Mold', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Pepper,_bell___Bacterial_spot', 'Corn_(maize)___healthy']
#mkdir(data_dir + 'test_data')
#mkdir(data_dir + 'train_data')
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72 # Resize input images to this size
patch_size = 6 # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024] # Size of the dense layers of the final classifier
num_classes = 38
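With these settings, each 72 x 72 input is split into (72 // 6)^2 = 144 non-overlapping patches, and each patch flattens to 6 x 6 x 3 = 108 raw values before being projected to 64 dimensions. A quick sanity check of that arithmetic:
patch_elements = patch_size * patch_size * 3  # flattened RGB patch: 6 * 6 * 3 = 108
assert num_patches == 144 and patch_elements == 108
print(num_patches, patch_elements)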
def convert_image_to_array(image_dir):
    # Note: cv2.imread returns images in BGR channel order.
    image = cv2.imread(image_dir)
    if image is not None:
        return img_to_array(image)
    else:
        return np.array([])
image_list, label_list = [], []
print("Loading images ...")
root_dir = listdir(data_dir)
for plant_folder in root_dir:
    plant_disease_image_list = listdir(f"{data_dir}/{plant_folder}")
    print(len(plant_disease_image_list))
    # Cap at 100 images per class to keep the in-memory dataset manageable.
    for image in plant_disease_image_list[:100]:
        image_directory = f"{data_dir}/{plant_folder}/{image}"
        if image_directory.lower().endswith(".jpg"):
            image_list.append(convert_image_to_array(image_directory))
            label_list.append(plant_folder)
print("Image loading completed")
Loading images ...
1909 1591 423 5507 5090 1835 152 985 1000 1771 513 1109 360 630 5357 2127 621 1502 1052 2297 275 1404 1478 1076 1000 373 456 1645 1180 1000 854 1192 1383 371 952 1676 997 1162
Image loading completed
label_binarizer = LabelBinarizer()
image_labels = label_binarizer.fit_transform(label_list)
pickle.dump(label_binarizer, open('label_transform.pkl', 'wb'))
n_classes = len(label_binarizer.classes_)
print(label_binarizer.classes_)
['Apple___Apple_scab' 'Apple___Black_rot' 'Apple___Cedar_apple_rust' 'Apple___healthy' 'Blueberry___healthy' 'Cherry_(including_sour)___Powdery_mildew' 'Cherry_(including_sour)___healthy' 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot' 'Corn_(maize)___Common_rust_' 'Corn_(maize)___Northern_Leaf_Blight' 'Corn_(maize)___healthy' 'Grape___Black_rot' 'Grape___Esca_(Black_Measles)' 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)' 'Grape___healthy' 'Orange___Haunglongbing_(Citrus_greening)' 'Peach___Bacterial_spot' 'Peach___healthy' 'Pepper,_bell___Bacterial_spot' 'Pepper,_bell___healthy' 'Potato___Early_blight' 'Potato___Late_blight' 'Potato___healthy' 'Raspberry___healthy' 'Soybean___healthy' 'Squash___Powdery_mildew' 'Strawberry___Leaf_scorch' 'Strawberry___healthy' 'Tomato___Bacterial_spot' 'Tomato___Early_blight' 'Tomato___Late_blight' 'Tomato___Leaf_Mold' 'Tomato___Septoria_leaf_spot' 'Tomato___Spider_mites Two-spotted_spider_mite' 'Tomato___Target_Spot' 'Tomato___Tomato_Yellow_Leaf_Curl_Virus' 'Tomato___Tomato_mosaic_virus' 'Tomato___healthy']
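Since the fitted binarizer is pickled above, it can be reloaded in a later session to decode one-hot vectors back into class names. A minimal sketch, assuming 'label_transform.pkl' sits in the working directory:
# Reload the fitted LabelBinarizer and invert a one-hot row back to its label.
loaded_binarizer = pickle.load(open('label_transform.pkl', 'rb'))
print(loaded_binarizer.inverse_transform(image_labels[:1]))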
np_image_list = np.array(image_list, dtype=np.float16) / 255.0
print("Splitting data to train, test")
x_train, x_test, y_train, y_test = train_test_split(np_image_list, image_labels, test_size=0.2, random_state = 42)
Splitting data to train, test
input_shape = x_train[0].shape
print(input_shape)
(256, 256, 3)
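Storing the images as float16 rather than float32 halves the in-memory footprint, which matters at this scale; the exact size can be checked directly:
# Rough memory footprint of the image tensor (float16 = 2 bytes per value).
print(f"{np_image_list.nbytes / 1e9:.2f} GB")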
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
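After adapt, the Normalization layer standardizes each input as (x - mean) / sqrt(var) using statistics estimated from x_train. A minimal sketch to inspect them (assuming the layer exposes mean and variance attributes, as recent tf.keras versions do):
norm_layer = data_augmentation.layers[0]
print(norm_layer.mean.numpy().squeeze())      # per-channel mean of x_train
print(norm_layer.variance.numpy().squeeze())  # per-channel variance of x_train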
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
image = image * 255.0
plt.imshow(image.astype("uint8"))
plt.axis("off")
resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")
Image size: 72 X 72
Patch size: 6 X 6
Patches per image: 144
Elements per patch: 108
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
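As a quick sanity check, the encoder can be applied to the patches tensor produced in the visualization cell above; a minimal sketch (illustrative, not part of the original run):
encoded = PatchEncoder(num_patches, projection_dim)(patches)
print(encoded.shape)  # expected: (1, 144, 64), one 64-d token per patch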
def create_vit_classifier(with_attention_scores_output=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    # Collect the attention scores of each layer for later visualization.
    attention_score_dict = dict()
    # Create multiple layers of the Transformer block.
    for i in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output, attention_score = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1, return_attention_scores=True)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])
        attention_score_dict[i] = attention_score
    # Flatten to a [batch_size, num_patches * projection_dim] representation.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model, optionally exposing the per-layer attention scores.
    if with_attention_scores_output:
        model = keras.Model(inputs=inputs, outputs=[logits, attention_score_dict])
    else:
        model = keras.Model(inputs=inputs, outputs=logits)
    return model
def run_experiment(model):
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
            keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )
    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )
    # Restore the best weights (highest val_accuracy) before evaluating.
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
    return history
vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)
Epoch 1/100
11/11 [==============================] - 19s 634ms/step - loss: 4.5576 - accuracy: 0.1012 - top-5-accuracy: 0.2982 - val_loss: 2.7751 - val_accuracy: 0.2599 - val_top-5-accuracy: 0.5691
Epoch 2/100
11/11 [==============================] - 5s 468ms/step - loss: 3.0885 - accuracy: 0.1970 - top-5-accuracy: 0.4952 - val_loss: 2.4063 - val_accuracy: 0.3191 - val_top-5-accuracy: 0.6743
Epoch 3/100
11/11 [==============================] - 5s 451ms/step - loss: 2.7145 - accuracy: 0.2657 - top-5-accuracy: 0.6023 - val_loss: 2.1349 - val_accuracy: 0.3947 - val_top-5-accuracy: 0.7533
...
Epoch 98/100
11/11 [==============================] - 5s 482ms/step - loss: 0.1599 - accuracy: 0.9499 - top-5-accuracy: 0.9996 - val_loss: 0.6491 - val_accuracy: 0.8586 - val_top-5-accuracy: 0.9836
Epoch 99/100
11/11 [==============================] - 5s 470ms/step - loss: 0.1286 - accuracy: 0.9576 - top-5-accuracy: 0.9993 - val_loss: 0.5429 - val_accuracy: 0.8684 - val_top-5-accuracy: 0.9836
Epoch 100/100
11/11 [==============================] - 5s 469ms/step - loss: 0.1263 - accuracy: 0.9627 - top-5-accuracy: 0.9989 - val_loss: 0.6173 - val_accuracy: 0.8618 - val_top-5-accuracy: 0.9868
24/24 [==============================] - 1s 36ms/step - loss: 0.6115 - accuracy: 0.8579 - top-5-accuracy: 0.9750
Test accuracy: 85.79%
Test top 5 accuracy: 97.5%
print(history.history.keys())
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
top_5_acc = history.history['top-5-accuracy']
top_5_val_acc = history.history['val_top-5-accuracy']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and Validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, top_5_acc, 'b', label='Top-5 Training accuracy')
plt.plot(epochs, top_5_val_acc, 'r', label='Top-5 Validation accuracy')
plt.title('Training and Validation Top-5 accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation loss')
plt.legend()
plt.show()
dict_keys(['loss', 'accuracy', 'top-5-accuracy', 'val_loss', 'val_accuracy', 'val_top-5-accuracy'])
# Predict on a single test image (keep the batch dimension; `model` was local
# to run_experiment, so use the trained vit_classifier here).
print(vit_classifier.predict(x_test[:1]))
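The raw predictions are 38 unnormalized logits; converting them to probabilities and decoding with the fitted binarizer gives a readable class name. A minimal sketch (illustrative, not part of the original notebook):
# Decode the highest-scoring logit into a class name and compare to ground truth.
logits = vit_classifier.predict(x_test[:1])
probs = tf.nn.softmax(logits, axis=-1).numpy()[0]
print("predicted:", label_binarizer.classes_[np.argmax(probs)])
print("actual:", label_binarizer.classes_[np.argmax(y_test[0])])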
vit_classifier_with_att_scores = create_vit_classifier(with_attention_scores_output=True)
checkpoint_filepath = "/tmp/checkpoint"
vit_classifier_with_att_scores.load_weights(checkpoint_filepath)
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fc3eefcfd90>
plt.figure(figsize=(4, 4))
image_to_predict = x_train[np.random.choice(range(x_train.shape[0]))]
image_to_draw = image_to_predict * 255.0
plt.imshow(image_to_draw.astype("uint8"))
plt.axis("off")
(-0.5, 255.5, 255.5, -0.5)
image_to_predict = np.expand_dims(image_to_predict, axis=0)
output, attention_scores = vit_classifier_with_att_scores(image_to_predict)
# Computes the attention rollout map.
def attention_map(src_image, attention_score_dict):
    attention_mat = tf.stack(list(attention_score_dict.values()))
    print(attention_mat.shape)
    attention_mat = tf.squeeze(attention_mat, axis=1)
    print(attention_mat.shape)
    # Average attention over all heads of each MultiHeadAttention layer.
    attention_mat = tf.reduce_mean(attention_mat, axis=1)
    print(attention_mat.shape)
    # "To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights."
    residual_attn = tf.eye(attention_mat.shape[1])
    print(residual_attn.shape)
    aug_attention_mat = attention_mat + residual_attn
    aug_attention_mat = np.array(
        aug_attention_mat / tf.reduce_sum(aug_attention_mat, axis=-1)[..., None]
    )
    print(aug_attention_mat.shape)
    # Recursively chain the normalized attention matrices across layers.
    joint_attentions = np.zeros(aug_attention_mat.shape)
    joint_attentions[0] = aug_attention_mat[0]
    for n in range(1, aug_attention_mat.shape[0]):
        joint_attentions[n] = np.matmul(aug_attention_mat[n], joint_attentions[n - 1])
    # Reshape the last layer's rollout so it can be multiplied with the input RGB image.
    v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_attention_mat.shape[-1]))
    mask = v[0].reshape(grid_size, grid_size)
    mask = cv2.resize(mask / mask.max(), src_image.shape[:2])[..., np.newaxis]
    result = (mask * src_image).astype("uint8")
    return result
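For reference, the recursion above implements attention rollout (Abnar & Zuidema, 2020): with the per-layer attention $A^{(l)}$ averaged over heads, the residual-corrected matrices are row-normalized and chained across the 8 Transformer layers,

$$\hat{A}^{(l)} = \mathrm{norm}\big(A^{(l)} + I\big), \qquad R^{(l)} = \hat{A}^{(l)} R^{(l-1)}, \qquad R^{(0)} = \hat{A}^{(0)},$$

and the row of the final $R$ belonging to the first patch (this model has no class token, so patch 0 is used) is reshaped to the 12 x 12 patch grid and upsampled to form the mask.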
attn_rollout_result = attention_map(image_to_draw, attention_scores)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 10))
ax1.imshow(image_to_draw.astype("uint8"))
ax2.imshow(attn_rollout_result)
ax1.axis("off")
ax2.axis("off")
ax1.set_title("Input")
ax2.set_title("Attention Map")
fig.tight_layout()
fig.show()
(8, 1, 4, 144, 144)
(8, 4, 144, 144)
(8, 144, 144)
(144, 144)
(8, 144, 144)