inteligentna_smieciarka/trash-classification/trash_classification.ipynb
2023-05-31 23:40:47 +02:00

165 KiB

import cv2
import imghdr
import os
# Scan every class folder under data_dir.
# - Files whose detected image type is not in image_exts are deleted.
# - Files that OpenCV cannot decode are reported (not deleted).
# NOTE(review): only the test split is scanned, while trash/train and
# trash/validation are also loaded below — confirm they are already clean.
# NOTE: `imghdr` is deprecated (PEP 594) and removed in Python 3.13.
data_dir = 'trash/test'
image_exts = ['jpeg', 'jpg', 'bmp', 'png']
for image_class in os.listdir(data_dir):
    class_dir = os.path.join(data_dir, image_class)
    if not os.path.isdir(class_dir):
        # Stray top-level files are not class folders; os.listdir would fail on them.
        continue
    for image in os.listdir(class_dir):
        image_path = os.path.join(class_dir, image)
        try:
            tip = imghdr.what(image_path)
            if tip not in image_exts:
                print('Image not in ext list {}'.format(image_path))
                os.remove(image_path)
            elif cv2.imread(image_path) is None:
                # cv2.imread does not raise on unreadable/corrupt files; it returns None.
                print('Issue with image {}'.format(image_path))
        except Exception as e:
            # Best-effort scan: report the problem (with its cause) and keep going.
            print('Issue with image {}: {}'.format(image_path, e))
import tensorflow as tf
import matplotlib.pyplot as plt
# Every split is resized to img_height x img_width and batched identically.
img_height, img_width = 32, 32
batch_size = 20


def _load_split(directory):
    """Load one labelled image split from `directory` (one subfolder per class)."""
    return tf.keras.utils.image_dataset_from_directory(
        directory,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )


train_ds = _load_split("trash/train")
val_ds = _load_split("trash/validation")
test_ds = _load_split("trash/test")
Found 5538 files belonging to 4 classes.
Found 991 files belonging to 4 classes.
Found 1441 files belonging to 4 classes.
# Label index -> class name. image_dataset_from_directory infers labels from
# the class subfolder names and exposes their order as `class_names`; reading
# it back keeps plot titles aligned with the integer labels instead of relying
# on a hand-maintained list.
# NOTE(review): val/test are loaded independently; this assumes they contain
# the same class folders as train.
class_names = train_ds.class_names

# Show up to a 3x3 grid of images from one training batch with their labels.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    # A batch may hold fewer than 9 images.
    for i in range(min(9, len(images))):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[int(labels[i])])
        plt.axis("off")
# Small CNN classifier: three conv/pool stages, then a dense head.
# The output layer emits one raw logit per class.
num_classes = len(class_names)
model = tf.keras.Sequential(
    [
        # Scale pixel values from [0, 255] to [0, 1].
        tf.keras.layers.Rescaling(1./255),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        # No activation: logits, matched by from_logits=True in the loss below.
        tf.keras.layers.Dense(num_classes),
    ]
)
model.compile(
    optimizer="adam",
    # Integer labels (not one-hot), hence the "Sparse" variant.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
# Keep the History so per-epoch loss/accuracy curves can be inspected.
# NOTE(review): in the recorded run val_loss bottoms out around epoch 6 and
# then rises while train loss keeps falling (overfitting); consider an
# EarlyStopping callback or augmentation.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
)
Epoch 1/10
277/277 [==============================] - 21s 72ms/step - loss: 1.3038 - accuracy: 0.3759 - val_loss: 1.2956 - val_accuracy: 0.3582
Epoch 2/10
277/277 [==============================] - 20s 71ms/step - loss: 1.1874 - accuracy: 0.4807 - val_loss: 1.2681 - val_accuracy: 0.3956
Epoch 3/10
277/277 [==============================] - 20s 71ms/step - loss: 1.1256 - accuracy: 0.5105 - val_loss: 1.2188 - val_accuracy: 0.4470
Epoch 4/10
277/277 [==============================] - 20s 71ms/step - loss: 1.0791 - accuracy: 0.5428 - val_loss: 1.1535 - val_accuracy: 0.4884
Epoch 5/10
277/277 [==============================] - 20s 71ms/step - loss: 1.0317 - accuracy: 0.5668 - val_loss: 1.1305 - val_accuracy: 0.4955
Epoch 6/10
277/277 [==============================] - 20s 71ms/step - loss: 0.9874 - accuracy: 0.5791 - val_loss: 1.1266 - val_accuracy: 0.5166
Epoch 7/10
277/277 [==============================] - 20s 72ms/step - loss: 0.9342 - accuracy: 0.6154 - val_loss: 1.1693 - val_accuracy: 0.5227
Epoch 8/10
277/277 [==============================] - 20s 71ms/step - loss: 0.8914 - accuracy: 0.6345 - val_loss: 1.1575 - val_accuracy: 0.5177
Epoch 9/10
277/277 [==============================] - 20s 71ms/step - loss: 0.8300 - accuracy: 0.6643 - val_loss: 1.1687 - val_accuracy: 0.5580
Epoch 10/10
277/277 [==============================] - 20s 72ms/step - loss: 0.7733 - accuracy: 0.6876 - val_loss: 1.2046 - val_accuracy: 0.5399
<keras.callbacks.History at 0x1ee7c37f100>
# Score the trained model on the test split; yields [loss, accuracy] (compile order).
test_metrics = model.evaluate(x=test_ds)
test_metrics
73/73 [==============================] - 5s 59ms/step - loss: 1.1350 - accuracy: 0.5420
[1.134972333908081, 0.5419847369194031]
import numpy

# Show up to a 3x3 grid of test images, each titled with predicted vs. true class.
plt.figure(figsize=(10, 10))
for images, labels in test_ds.take(1):
    # Forward pass in inference mode; the model outputs raw logits.
    classifications = model(images, training=False)
    # argmax over logits equals argmax over softmax probabilities.
    predicted = numpy.argmax(classifications, axis=-1)

    # A batch may hold fewer than 9 images.
    for i in range(min(9, len(images))):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(f"Pred: {class_names[predicted[i]]} | Real: {class_names[int(labels[i])]}")
        plt.axis("off")