move funcs to class

This commit is contained in:
mszmyd 2024-05-05 12:41:54 +02:00
parent e3002d5ef8
commit de27695d53

View File

@ -7,8 +7,8 @@ import tensorflow as tf
class Dataset: class Dataset:
''' Class to load and preprocess the dataset. ''' Class to load and preprocess the dataset.
Loads images and labels from the given directory to tf.data.Dataset. Loads images and labels from the given directory to tf.data.Dataset.
Args: Args:
`data_dir (Path)`: Path to the dataset directory. `data_dir (Path)`: Path to the dataset directory.
`seed (int)`: Seed for shuffling the dataset. `seed (int)`: Seed for shuffling the dataset.
@ -16,6 +16,7 @@ class Dataset:
`shuffle_buffer_size (int)`: Size of the buffer for shuffling the dataset. `shuffle_buffer_size (int)`: Size of the buffer for shuffling the dataset.
`batch_size (int)`: Batch size for the dataset. `batch_size (int)`: Batch size for the dataset.
''' '''
def __init__(self, def __init__(self,
data_dir: Path, data_dir: Path,
seed: int = 42, seed: int = 42,
@ -25,10 +26,11 @@ class Dataset:
self.data_dir = data_dir self.data_dir = data_dir
self.seed = seed self.seed = seed
self.repeat = repeat self.repeat = repeat
self.shuffle_buffer_size = shuffle_buffer_size
self.batch_size = batch_size self.batch_size = batch_size
self.dataset = self._load_dataset()\ self.dataset = self._load_dataset()\
.shuffle(shuffle_buffer_size, seed=self.seed)\ .shuffle(self.shuffle_buffer_size, seed=self.seed)\
.repeat(self.repeat)\ .repeat(self.repeat)\
.prefetch(tf.data.experimental.AUTOTUNE) .prefetch(tf.data.experimental.AUTOTUNE)
@ -40,27 +42,24 @@ class Dataset:
pass pass
else: else:
dataset = dataset.map( dataset = dataset.map(
_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE) self._preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset return dataset
def _get_labels(self, image_path):
path = tf.strings.split(image_path, os.path.sep)[-2]
plant = tf.strings.split(path, '___')[0]
disease = tf.strings.split(path, '___')[1]
return tf.cast(plant, dtype=tf.string, name=None), tf.cast(disease, dtype=tf.string, name=None)
def _get_labels(image_path): def _get_image(self, image_path):
path = tf.strings.split(image_path, os.path.sep)[-2] img = tf.io.read_file(image_path)
plant = tf.strings.split(path, '___')[0] img = tf.io.decode_jpeg(img, channels=3) / 255
disease = tf.strings.split(path, '___')[1] return tf.cast(img, dtype=tf.float32, name=None)
return tf.cast(plant, dtype=tf.string, name=None), tf.cast(disease, dtype=tf.string, name=None)
def _preprocess(self, image_path):
labels = self._get_labels(image_path)
image = self._get_image(image_path)
def _get_image(image_path): # returns X, Y1, Y2
img = tf.io.read_file(image_path) return image, labels
img = tf.io.decode_jpeg(img, channels=3) / 255
return tf.cast(img, dtype=tf.float32, name=None)
def _preprocess(image_path):
labels = _get_labels(image_path)
image = _get_image(image_path)
# returns X, Y1, Y2
return image, labels