diff --git a/dataset/dataset.py b/dataset/dataset.py index 0c9c035..e190384 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -34,6 +34,7 @@ class Dataset: self.dataset = self.__load_dataset()\ .shuffle(self.shuffle_buffer_size, seed=self.seed)\ .repeat(self.repeat)\ + .batch(self.batch_size, drop_remainder=True)\ .prefetch(tf.data.experimental.AUTOTUNE) def __load_dataset(self) -> tf.data.Dataset: