diff --git a/dataset/dataset.py b/dataset/dataset.py index 6631a20..d098ead 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -7,8 +7,8 @@ import tensorflow as tf class Dataset: ''' Class to load and preprocess the dataset. Loads images and labels from the given directory to tf.data.Dataset. - - + + Args: `data_dir (Path)`: Path to the dataset directory. `seed (int)`: Seed for shuffling the dataset. @@ -16,6 +16,7 @@ class Dataset: `shuffle_buffer_size (int)`: Size of the buffer for shuffling the dataset. `batch_size (int)`: Batch size for the dataset. ''' + def __init__(self, data_dir: Path, seed: int = 42, @@ -25,10 +26,11 @@ class Dataset: self.data_dir = data_dir self.seed = seed self.repeat = repeat + self.shuffle_buffer_size = shuffle_buffer_size self.batch_size = batch_size self.dataset = self._load_dataset()\ - .shuffle(shuffle_buffer_size, seed=self.seed)\ + .shuffle(self.shuffle_buffer_size, seed=self.seed)\ .repeat(self.repeat)\ .prefetch(tf.data.experimental.AUTOTUNE) @@ -40,27 +42,24 @@ class Dataset: pass else: dataset = dataset.map( - _preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE) + self._preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE) return dataset + def _get_labels(self, image_path): + path = tf.strings.split(image_path, os.path.sep)[-2] + plant = tf.strings.split(path, '___')[0] + disease = tf.strings.split(path, '___')[1] + return tf.cast(plant, dtype=tf.string, name=None), tf.cast(disease, dtype=tf.string, name=None) -def _get_labels(image_path): - path = tf.strings.split(image_path, os.path.sep)[-2] - plant = tf.strings.split(path, '___')[0] - disease = tf.strings.split(path, '___')[1] - return tf.cast(plant, dtype=tf.string, name=None), tf.cast(disease, dtype=tf.string, name=None) + def _get_image(self, image_path): + img = tf.io.read_file(image_path) + img = tf.io.decode_jpeg(img, channels=3) / 255 + return tf.cast(img, dtype=tf.float32, name=None) + def _preprocess(self, image_path): + labels = self._get_labels(image_path) + image = self._get_image(image_path) -def _get_image(image_path): - img = tf.io.read_file(image_path) - img = tf.io.decode_jpeg(img, channels=3) / 255 - return tf.cast(img, dtype=tf.float32, name=None) - - -def _preprocess(image_path): - labels = _get_labels(image_path) - image = _get_image(image_path) - - # returns X, Y1, Y2 - return image, labels + # returns X, Y1, Y2 + return image, labels