diff --git a/dataset/consts.py b/dataset/consts.py new file mode 100644 index 0000000..5f71df7 --- /dev/null +++ b/dataset/consts.py @@ -0,0 +1,40 @@ +PLANT_CLASSES = [ + "Tomato", + "Potato", + "Corn_(maize)", + "Apple", + "Blueberry", + "Soybean", + "Cherry_(including_sour)", + "Squash", + "Strawberry", + "Pepper,_bell", + "Peach", + "Grape", + "Orange", + "Raspberry", +] + +DISEASE_CLASSES = [ + "healthy", + "Northern_Leaf_Blight", + "Tomato_mosaic_virus", + "Early_blight", + "Leaf_scorch", + "Tomato_Yellow_Leaf_Curl_Virus", + "Cedar_apple_rust", + "Late_blight", + "Spider_mites Two-spotted_spider_mite", + "Black_rot", + "Bacterial_spot", + "Apple_scab", + "Powdery_mildew", + "Esca_(Black_Measles)", + "Haunglongbing_(Citrus_greening)", + "Leaf_Mold", + "Common_rust_", + "Target_Spot", + "Leaf_blight_(Isariopsis_Leaf_Spot)", + "Septoria_leaf_spot", + "Cercospora_leaf_spot Gray_leaf_spot", +] \ No newline at end of file diff --git a/dataset/dataset.py b/dataset/dataset.py index 6b4489b..b299e25 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -3,6 +3,8 @@ from pathlib import Path import tensorflow as tf +from .consts import DISEASE_CLASSES, PLANT_CLASSES + class Dataset: ''' Class to load and preprocess the dataset. @@ -50,16 +52,23 @@ class Dataset: path = tf.strings.split(image_path, os.path.sep)[-2] plant = tf.strings.split(path, '___')[0] disease = tf.strings.split(path, '___')[1] - return tf.cast(plant, dtype=tf.string, name=None), tf.cast(disease, dtype=tf.string, name=None) + + one_hot_plant = plant == PLANT_CLASSES + one_hot_disease = disease == DISEASE_CLASSES + + return tf.cast(one_hot_plant, dtype=tf.uint8, name=None), tf.cast(one_hot_disease, dtype=tf.uint8, name=None) def _get_image(self, image_path): img = tf.io.read_file(image_path) - img = tf.io.decode_jpeg(img, channels=3) / 255. - return tf.cast(img, dtype=tf.float32, name=None) + img = tf.io.decode_jpeg(img, channels=3) + return tf.cast(img, dtype=tf.float32, name=None) / 255. def _preprocess(self, image_path): labels = self._get_labels(image_path) image = self._get_image(image_path) # returns X, Y1, Y2 - return image, labels + return image, labels[0], labels[1] + + def __getattr__(self, attr): + return getattr(self.dataset, attr) diff --git a/test.py b/test.py index 41d5b6f..a75f18f 100644 --- a/test.py +++ b/test.py @@ -6,5 +6,5 @@ from dataset.dataset import Dataset train_dataset = Dataset(Path('data/resized_dataset/train')) valid_dataset = Dataset(Path('data/resized_dataset/valid')) -for image, labels in train_dataset.dataset.take(1): - print(image, labels) +for i in train_dataset.take(1): + print(i)