# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image dataset loading utilities."""

import multiprocessing
import os
import random
import time
import warnings

import numpy as np
import tensorflow.compat.v2 as tf

from keras.utils import io_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.utils.split_dataset", v1=[])
def split_dataset(
    dataset, left_size=None, right_size=None, shuffle=False, seed=None
):
    """Split a dataset into a left half and a right half (e.g. train / test).

    Args:
        dataset: A `tf.data.Dataset` object, or a list/tuple of arrays with
            the same length.
        left_size: If float (in the range `[0, 1]`), it signifies the
            fraction of the data to pack in the left dataset. If integer, it
            signifies the number of samples to pack in the left dataset. If
            `None`, it defaults to the complement to `right_size`.
        right_size: If float (in the range `[0, 1]`), it signifies the
            fraction of the data to pack in the right dataset. If integer, it
            signifies the number of samples to pack in the right dataset. If
            `None`, it defaults to the complement to `left_size`.
        shuffle: Boolean, whether to shuffle the data before splitting it.
        seed: A random seed for shuffling.

    Returns:
        A tuple of two `tf.data.Dataset` objects: the left and right splits.

    Example:

    >>> data = np.random.random(size=(1000, 4))
    >>> left_ds, right_ds = tf.keras.utils.split_dataset(data, left_size=0.8)
    >>> int(left_ds.cardinality())
    800
    >>> int(right_ds.cardinality())
    200
    """
    dataset_type_spec = _get_type_spec(dataset)

    if dataset_type_spec not in [tf.data.Dataset, list, tuple, np.ndarray]:
        raise TypeError(
            "The `dataset` argument must be either a `tf.data.Dataset` "
            "object or a list/tuple of arrays. "
            f"Received: dataset={dataset} of type {type(dataset)}"
        )

    if right_size is None and left_size is None:
        raise ValueError(
            "At least one of the `left_size` or `right_size` "
            "must be specified. Received: left_size=None and "
            "right_size=None"
        )

    dataset_as_list = _convert_dataset_to_list(dataset, dataset_type_spec)

    if shuffle:
        if seed is None:
            seed = random.randint(0, int(1e6))
        random.seed(seed)
        random.shuffle(dataset_as_list)

    total_length = len(dataset_as_list)

    left_size, right_size = _rescale_dataset_split_sizes(
        left_size, right_size, total_length
    )
    left_split = list(dataset_as_list[:left_size])
    right_split = list(dataset_as_list[-right_size:])

    left_split = _restore_dataset_from_list(
        left_split, dataset_type_spec, dataset
    )
    right_split = _restore_dataset_from_list(
        right_split, dataset_type_spec, dataset
    )

    left_split = tf.data.Dataset.from_tensor_slices(left_split)
    right_split = tf.data.Dataset.from_tensor_slices(right_split)

    # Apply batching to the splits if the original dataset was batched.
    if dataset_type_spec is tf.data.Dataset and is_batched(dataset):
        batch_size = get_batch_size(dataset)
        if batch_size is not None:
            left_split = left_split.batch(batch_size)
            right_split = right_split.batch(batch_size)

    left_split = left_split.prefetch(tf.data.AUTOTUNE)
    right_split = right_split.prefetch(tf.data.AUTOTUNE)

    return left_split, right_split


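# A minimal usage sketch for the tuple/list code path (illustrative only; the
# array shapes and names below are assumptions, not part of the library):
#
#     features = np.random.random((10, 4))
#     labels = np.random.randint(0, 2, size=(10,))
#     train_ds, test_ds = split_dataset((features, labels), left_size=0.8)
#     # `train_ds` yields 8 (feature_row, label) pairs, `test_ds` yields 2.

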
def _convert_dataset_to_list(
    dataset,
    dataset_type_spec,
    data_size_warning_flag=True,
    ensure_shape_similarity=True,
):
    """Convert `tf.data.Dataset` object or list/tuple of NumPy arrays to a list.

    Args:
        dataset: A `tf.data.Dataset` object or a list/tuple of arrays.
        dataset_type_spec: the type of the dataset.
        data_size_warning_flag (bool, optional): If set to True, a warning
            will be issued if the dataset takes longer than 10 seconds to
            iterate. Defaults to True.
        ensure_shape_similarity (bool, optional): If set to True, the shape
            of the first sample will be used to validate the shape of the
            rest of the samples. Defaults to True.

    Returns:
        List: A list of tuples/NumPy arrays.
    """
    dataset_iterator = _get_data_iterator_from_dataset(
        dataset, dataset_type_spec
    )
    dataset_as_list = []

    start_time = time.time()
    for sample in _get_next_sample(
        dataset_iterator,
        ensure_shape_similarity,
        data_size_warning_flag,
        start_time,
    ):
        if dataset_type_spec in [tuple, list]:
            # The try-except here is for NumPy 1.24 compatibility, see:
            # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
            try:
                arr = np.array(sample)
            except ValueError:
                arr = np.array(sample, dtype=object)
            dataset_as_list.append(arr)
        else:
            dataset_as_list.append(sample)

    return dataset_as_list


def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
    """Get the iterator from a dataset.

    Args:
        dataset: A `tf.data.Dataset` object or a list/tuple of arrays.
        dataset_type_spec: the type of the dataset.

    Raises:
        ValueError:
            - If the dataset is empty.
            - If the dataset is not a `tf.data.Dataset` object
              or a list/tuple of arrays.
            - If the dataset is a list/tuple of arrays and the
              arrays do not all have the same length.

    Returns:
        iterator: An `iterator` object.
    """
    if dataset_type_spec == list:
        if len(dataset) == 0:
            raise ValueError(
                "Received an empty list dataset. "
                "Please provide a non-empty list of arrays."
            )

        if _get_type_spec(dataset[0]) is np.ndarray:
            expected_shape = dataset[0].shape
            for i, element in enumerate(dataset):
                if np.array(element).shape[0] != expected_shape[0]:
                    raise ValueError(
                        "Received a list of NumPy arrays with different "
                        f"lengths. Mismatch found at index {i}, "
                        f"Expected shape={expected_shape} "
                        f"Received shape={np.array(element).shape}. "
                        "Please provide a list of NumPy arrays with "
                        "the same length."
                    )
        else:
            raise ValueError(
                "Expected a list of `numpy.ndarray` objects, "
                f"Received: {type(dataset[0])}"
            )

        return iter(zip(*dataset))
    elif dataset_type_spec == tuple:
        if len(dataset) == 0:
            raise ValueError(
                "Received an empty tuple dataset. "
                "Please provide a non-empty tuple of arrays."
            )

        if _get_type_spec(dataset[0]) is np.ndarray:
            expected_shape = dataset[0].shape
            for i, element in enumerate(dataset):
                if np.array(element).shape[0] != expected_shape[0]:
                    raise ValueError(
                        "Received a tuple of NumPy arrays with different "
                        f"lengths. Mismatch found at index {i}, "
                        f"Expected shape={expected_shape} "
                        f"Received shape={np.array(element).shape}. "
                        "Please provide a tuple of NumPy arrays with "
                        "the same length."
                    )
        else:
            raise ValueError(
                "Expected a tuple of `numpy.ndarray` objects, "
                f"Received: {type(dataset[0])}"
            )

        return iter(zip(*dataset))
    elif dataset_type_spec == tf.data.Dataset:
        if is_batched(dataset):
            dataset = dataset.unbatch()
        return iter(dataset)
    elif dataset_type_spec == np.ndarray:
        return iter(dataset)


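# A hedged sketch of what the list/tuple branch returns: `zip(*dataset)`
# pairs the i-th row of every array, so iteration yields per-sample tuples.
# The arrays below are illustrative assumptions:
#
#     x = np.arange(6).reshape(3, 2)
#     y = np.array([0, 1, 0])
#     it = _get_data_iterator_from_dataset((x, y), tuple)
#     next(it)  # -> (array([0, 1]), 0)

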
def _get_next_sample(
    dataset_iterator,
    ensure_shape_similarity,
    data_size_warning_flag,
    start_time,
):
    """Yield data samples from the `dataset_iterator`.

    Args:
        dataset_iterator: An `iterator` object.
        ensure_shape_similarity (bool, optional): If set to True, the shape
            of the first sample will be used to validate the shape of the
            rest of the samples. Defaults to True.
        data_size_warning_flag (bool, optional): If set to True, a warning
            will be issued if the dataset takes longer than 10 seconds to
            iterate. Defaults to True.
        start_time (float): the start time of the dataset iteration. This is
            used only if `data_size_warning_flag` is set to True.

    Raises:
        ValueError:
            - If the dataset is empty.
            - If `ensure_shape_similarity` is set to True and the
              shape of the first sample is not equal to the shape of
              at least one of the rest of the samples.

    Yields:
        data_sample: A tuple/list of NumPy arrays.
    """
    try:
        dataset_iterator = iter(dataset_iterator)
        first_sample = next(dataset_iterator)
        if isinstance(first_sample, (tf.Tensor, np.ndarray)):
            first_sample_shape = np.array(first_sample).shape
        else:
            first_sample_shape = None
            ensure_shape_similarity = False
        yield first_sample
    except StopIteration:
        raise ValueError(
            "Received an empty Dataset. `dataset` must "
            "be a non-empty list/tuple of `numpy.ndarray` objects "
            "or `tf.data.Dataset` objects."
        )

    for i, sample in enumerate(dataset_iterator):
        if ensure_shape_similarity:
            if first_sample_shape != np.array(sample).shape:
                raise ValueError(
                    "All `dataset` samples must have the same shape, "
                    f"Expected shape: {np.array(first_sample).shape} "
                    f"Received shape: {np.array(sample).shape} at index "
                    f"{i}."
                )
        if data_size_warning_flag:
            if i % 10 == 0:
                cur_time = time.time()
                # Warn the user if the dataset is taking longer than 10
                # seconds to iterate over.
                if int(cur_time - start_time) > 10 and data_size_warning_flag:
                    warnings.warn(
                        "The dataset is taking longer than 10 seconds to "
                        "iterate over. This may be due to the size of the "
                        "dataset. Keep in mind that the `split_dataset` "
                        "utility is only for small in-memory datasets "
                        "(e.g. < 10,000 samples).",
                        category=ResourceWarning,
                        source="split_dataset",
                    )
                    data_size_warning_flag = False
        yield sample


def _restore_dataset_from_list(
    dataset_as_list, dataset_type_spec, original_dataset
):
    """Restore the dataset from the list of arrays."""
    if dataset_type_spec in [tuple, list]:
        return tuple(np.array(sample) for sample in zip(*dataset_as_list))
    elif dataset_type_spec == tf.data.Dataset:
        if isinstance(original_dataset.element_spec, dict):
            restored_dataset = {}
            for d in dataset_as_list:
                for k, v in d.items():
                    if k not in restored_dataset:
                        restored_dataset[k] = [v]
                    else:
                        restored_dataset[k].append(v)
            return restored_dataset
        else:
            return tuple(np.array(sample) for sample in zip(*dataset_as_list))
    return dataset_as_list


def _rescale_dataset_split_sizes(left_size, right_size, total_length):
    """Rescale the dataset split sizes.

    We want to ensure that the sum of
    the split sizes is equal to the total length of the dataset.

    Args:
        left_size: The size of the left dataset split.
        right_size: The size of the right dataset split.
        total_length: The total length of the dataset.

    Raises:
        TypeError: If `left_size` or `right_size` is not an integer or float.
        ValueError: If `left_size` or `right_size` is negative, greater than
            1 (for floats), or greater than `total_length` (for integers).

    Returns:
        tuple: A tuple of the rescaled `left_size` and `right_size`.
    """
    left_size_type = type(left_size)
    right_size_type = type(right_size)

    # Check that both `left_size` and `right_size` are integers or floats.
    if (left_size is not None and left_size_type not in [int, float]) and (
        right_size is not None and right_size_type not in [int, float]
    ):
        raise TypeError(
            "Invalid `left_size` and `right_size` Types. Expected: "
            "integer or float or None, Received: type(left_size)="
            f"{left_size_type} and type(right_size)={right_size_type}"
        )

    # Check that `left_size` is an integer or float.
    if left_size is not None and left_size_type not in [int, float]:
        raise TypeError(
            "Invalid `left_size` Type. Expected: int or float or None, "
            f"Received: type(left_size)={left_size_type}. "
        )

    # Check that `right_size` is an integer or float.
    if right_size is not None and right_size_type not in [int, float]:
        raise TypeError(
            "Invalid `right_size` Type. "
            "Expected: int or float or None, "
            f"Received: type(right_size)={right_size_type}."
        )

    # Check that `left_size` and `right_size` are not both zero.
    if left_size == 0 and right_size == 0:
        raise ValueError(
            "Both `left_size` and `right_size` are zero. "
            "At least one of the split sizes must be non-zero."
        )

    # Check that `left_size` is positive and, if an integer, smaller than
    # `total_length`, or, if a float, smaller than 1.
    if (
        left_size_type == int
        and (left_size <= 0 or left_size >= total_length)
        or left_size_type == float
        and (left_size <= 0 or left_size >= 1)
    ):
        raise ValueError(
            "`left_size` should be either a positive integer "
            f"smaller than {total_length}, or a float "
            "within the range `[0, 1]`. Received: left_size="
            f"{left_size}"
        )

    # Check that `right_size` is positive and, if an integer, smaller than
    # `total_length`, or, if a float, smaller than 1.
    if (
        right_size_type == int
        and (right_size <= 0 or right_size >= total_length)
        or right_size_type == float
        and (right_size <= 0 or right_size >= 1)
    ):
        raise ValueError(
            "`right_size` should be either a positive integer "
            f"smaller than {total_length}, or a float "
            "within the range `[0, 1]`. Received: right_size="
            f"{right_size}"
        )

    # Check that the sum of `left_size` and `right_size` does not exceed 1
    # when both are floats.
    if (
        right_size_type == left_size_type == float
        and right_size + left_size > 1
    ):
        raise ValueError(
            "The sum of `left_size` and `right_size` is greater "
            "than 1. It must be less than or equal to 1."
        )

    if left_size_type == float:
        left_size = round(left_size * total_length)
    elif left_size_type == int:
        left_size = float(left_size)

    if right_size_type == float:
        right_size = round(right_size * total_length)
    elif right_size_type == int:
        right_size = float(right_size)

    if left_size is None:
        left_size = total_length - right_size
    elif right_size is None:
        right_size = total_length - left_size

    if left_size + right_size > total_length:
        raise ValueError(
            "The sum of `left_size` and `right_size` should "
            f"be smaller than the total length ({total_length}). "
            f"Received: left_size + right_size = {left_size + right_size} "
            f"and total_length = {total_length}"
        )

    for split, side in [(left_size, "left"), (right_size, "right")]:
        if split == 0:
            raise ValueError(
                f"With `dataset` of length={total_length}, `left_size`="
                f"{left_size} and `right_size`={right_size}. "
                f"Resulting {side} side dataset split will be empty. "
                "Adjust any of the aforementioned parameters."
            )

    left_size, right_size = int(left_size), int(right_size)
    return left_size, right_size


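# A few illustrative evaluations of the rescaling logic (assumed inputs,
# shown for clarity):
#
#     _rescale_dataset_split_sizes(0.8, None, 1000)  # -> (800, 200)
#     _rescale_dataset_split_sizes(None, 300, 1000)  # -> (700, 300)
#     _rescale_dataset_split_sizes(0.5, 0.25, 1000)  # -> (500, 250)

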
def _get_type_spec(dataset):
    """Get the type spec of the dataset."""
    if isinstance(dataset, tuple):
        return tuple
    elif isinstance(dataset, list):
        return list
    elif isinstance(dataset, np.ndarray):
        return np.ndarray
    elif isinstance(dataset, dict):
        return dict
    elif isinstance(dataset, tf.data.Dataset):
        return tf.data.Dataset
    else:
        return None


def is_batched(tf_dataset):
    """Check if the `tf.data.Dataset` is batched."""
    return hasattr(tf_dataset, "_batch_size")


def get_batch_size(tf_dataset):
    """Get the batch size of the dataset."""
    if is_batched(tf_dataset):
        return tf_dataset._batch_size
    else:
        return None


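# Note: `is_batched` and `get_batch_size` rely on the private `_batch_size`
# attribute, so they only detect a dataset whose outermost transformation is
# `batch()`. A sketch of the expected behavior (implementation-dependent):
#
#     ds = tf.data.Dataset.range(10).batch(4)
#     is_batched(ds)               # -> True
#     int(get_batch_size(ds))      # -> 4
#     is_batched(ds.prefetch(1))   # -> False; batching is no longer outermost

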
def index_directory(
    directory,
    labels,
    formats,
    class_names=None,
    shuffle=True,
    seed=None,
    follow_links=False,
):
    """Make list of all files in `directory`, with their labels.

    Args:
        directory: Directory where the data is located.
            If `labels` is "inferred", it should contain
            subdirectories, each containing files for a class.
            Otherwise, the directory structure is ignored.
        labels: Either "inferred"
            (labels are generated from the directory structure),
            None (no labels),
            or a list/tuple of integer labels of the same size as the number
            of valid files found in the directory. Labels should be sorted
            according to the alphanumeric order of the image file paths
            (obtained via `os.walk(directory)` in Python).
        formats: Allowlist of file extensions to index
            (e.g. ".jpg", ".txt").
        class_names: Only valid if `labels` is "inferred". This is the
            explicit list of class names (must match names of
            subdirectories). Used to control the order of the classes
            (otherwise alphanumerical order is used).
        shuffle: Whether to shuffle the data. Default: True.
            If set to False, sorts the data in alphanumeric order.
        seed: Optional random seed for shuffling.
        follow_links: Whether to visit subdirectories pointed to by symlinks.

    Returns:
        tuple (file_paths, labels, class_names).
            file_paths: list of file paths (strings).
            labels: list of matching integer labels (same length as
                file_paths).
            class_names: names of the classes corresponding to these labels,
                in order.
    """
    if labels != "inferred":
        # In the explicit/no-label cases, index from the parent directory
        # down.
        subdirs = [""]
        class_names = subdirs
    else:
        subdirs = []
        for subdir in sorted(tf.io.gfile.listdir(directory)):
            if tf.io.gfile.isdir(tf.io.gfile.join(directory, subdir)):
                if subdir.endswith("/"):
                    subdir = subdir[:-1]
                subdirs.append(subdir)
        if not class_names:
            class_names = subdirs
        else:
            if set(class_names) != set(subdirs):
                raise ValueError(
                    "The `class_names` passed did not match the "
                    "names of the subdirectories of the target directory. "
                    f"Expected: {subdirs}, but received: {class_names}"
                )
    class_indices = dict(zip(class_names, range(len(class_names))))

    # Build an index of the files
    # in the different class subfolders.
    pool = multiprocessing.pool.ThreadPool()
    results = []
    filenames = []

    for dirpath in (tf.io.gfile.join(directory, subdir) for subdir in subdirs):
        results.append(
            pool.apply_async(
                index_subdirectory,
                (dirpath, class_indices, follow_links, formats),
            )
        )
    labels_list = []
    for res in results:
        partial_filenames, partial_labels = res.get()
        labels_list.append(partial_labels)
        filenames += partial_filenames
    if labels not in ("inferred", None):
        if len(labels) != len(filenames):
            raise ValueError(
                "Expected the lengths of `labels` to match the number "
                "of files in the target directory. len(labels) is "
                f"{len(labels)} while we found {len(filenames)} files "
                f"in directory {directory}."
            )
        class_names = sorted(set(labels))
    else:
        i = 0
        labels = np.zeros((len(filenames),), dtype="int32")
        for partial_labels in labels_list:
            labels[i : i + len(partial_labels)] = partial_labels
            i += len(partial_labels)

    if labels is None:
        io_utils.print_msg(f"Found {len(filenames)} files.")
    else:
        io_utils.print_msg(
            f"Found {len(filenames)} files belonging "
            f"to {len(class_names)} classes."
        )
    pool.close()
    pool.join()
    file_paths = [tf.io.gfile.join(directory, fname) for fname in filenames]

    if shuffle:
        # Shuffle globally to erase macro-structure.
        if seed is None:
            seed = np.random.randint(1e6)
        rng = np.random.RandomState(seed)
        rng.shuffle(file_paths)
        rng = np.random.RandomState(seed)
        rng.shuffle(labels)
    return file_paths, labels, class_names


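# An illustrative directory layout and call (hypothetical paths; `formats`
# is fed to `str.endswith`, so a tuple of suffixes works):
#
#     main_directory/
#     ...class_a/
#     ......a_image_1.jpg
#     ...class_b/
#     ......b_image_1.jpg
#
#     file_paths, labels, class_names = index_directory(
#         "main_directory", labels="inferred", formats=(".jpg",)
#     )
#     # class_names == ["class_a", "class_b"]; each labels[i] is an index
#     # into class_names for the matching file_paths[i].

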
def iter_valid_files(directory, follow_links, formats):
    if not follow_links:
        walk = tf.io.gfile.walk(directory)
    else:
        walk = os.walk(directory, followlinks=follow_links)
    for root, _, files in sorted(walk, key=lambda x: x[0]):
        for fname in sorted(files):
            if fname.lower().endswith(formats):
                yield root, fname


def index_subdirectory(directory, class_indices, follow_links, formats):
    """Recursively walks directory and lists image paths and their class
    index.

    Args:
        directory: string, target directory.
        class_indices: dict mapping class names to their index.
        follow_links: boolean, whether to recursively follow subdirectories
            (if False, we only list top-level images in `directory`).
        formats: Allowlist of file extensions to index
            (e.g. ".jpg", ".txt").

    Returns:
        tuple `(filenames, labels)`. `filenames` is a list of relative file
        paths, and `labels` is a list of integer labels corresponding to
        these files.
    """
    dirname = os.path.basename(directory)
    valid_files = iter_valid_files(directory, follow_links, formats)
    labels = []
    filenames = []
    for root, fname in valid_files:
        labels.append(class_indices[dirname])
        absolute_path = tf.io.gfile.join(root, fname)
        relative_path = tf.io.gfile.join(
            dirname, os.path.relpath(absolute_path, directory)
        )
        filenames.append(relative_path)
    return filenames, labels


def get_training_or_validation_split(
    samples, labels, validation_split, subset
):
    """Potentially restrict samples & labels to a training or validation
    split.

    Args:
        samples: List of elements.
        labels: List of corresponding labels.
        validation_split: Float, fraction of data to reserve for validation.
        subset: Subset of the data to return.
            Either "training", "validation", or None. If None, we return all
            of the data.

    Returns:
        tuple (samples, labels), potentially restricted to the specified
        subset.
    """
    if not validation_split:
        return samples, labels

    num_val_samples = int(validation_split * len(samples))
    if subset == "training":
        print(f"Using {len(samples) - num_val_samples} files for training.")
        samples = samples[:-num_val_samples]
        labels = labels[:-num_val_samples]
    elif subset == "validation":
        print(f"Using {num_val_samples} files for validation.")
        samples = samples[-num_val_samples:]
        labels = labels[-num_val_samples:]
    else:
        raise ValueError(
            '`subset` must be either "training" '
            f'or "validation", received: {subset}'
        )
    return samples, labels


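# A worked example of the slicing (assumed toy inputs):
#
#     samples = ["a", "b", "c", "d", "e"]
#     labels = [0, 0, 1, 1, 1]
#     get_training_or_validation_split(samples, labels, 0.2, "training")
#     # int(0.2 * 5) == 1 sample is held out, so this returns
#     # (["a", "b", "c", "d"], [0, 0, 1, 1]); subset="validation" would
#     # return the complementary tail (["e"], [1]).

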
def labels_to_dataset(labels, label_mode, num_classes):
    """Create a tf.data.Dataset from the list/tuple of labels.

    Args:
        labels: list/tuple of labels to be converted into a tf.data.Dataset.
        label_mode: String describing the encoding of `labels`. Options are:
            - 'binary' indicates that the labels (there can be only 2) are
              encoded as `float32` scalars with values 0 or 1
              (e.g. for `binary_crossentropy`).
            - 'categorical' means that the labels are mapped into a
              categorical vector (e.g. for `categorical_crossentropy` loss).
        num_classes: number of classes of labels.

    Returns:
        A `Dataset` instance.
    """
    label_ds = tf.data.Dataset.from_tensor_slices(labels)
    if label_mode == "binary":
        label_ds = label_ds.map(
            lambda x: tf.expand_dims(tf.cast(x, "float32"), axis=-1),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    elif label_mode == "categorical":
        label_ds = label_ds.map(
            lambda x: tf.one_hot(x, num_classes),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    return label_ds


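# A hedged sketch of the two encodings (toy labels assumed):
#
#     ds = labels_to_dataset([0, 1, 2], "categorical", num_classes=3)
#     next(iter(ds))  # -> tf.Tensor([1., 0., 0.], shape=(3,), dtype=float32)
#
#     ds = labels_to_dataset([0, 1], "binary", num_classes=2)
#     next(iter(ds))  # -> tf.Tensor([0.], shape=(1,), dtype=float32)

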
def check_validation_split_arg(validation_split, subset, shuffle, seed):
    """Raise errors in case of invalid argument values.

    Args:
        validation_split: float between 0 and 1, fraction of data to reserve
            for validation.
        subset: One of "training", "validation" or "both". Only used if
            `validation_split` is set.
        shuffle: Whether to shuffle the data. Either True or False.
        seed: random seed for shuffling and transformations.
    """
    if validation_split and not 0 < validation_split < 1:
        raise ValueError(
            "`validation_split` must be between 0 and 1, "
            f"received: {validation_split}"
        )
    if (validation_split or subset) and not (validation_split and subset):
        raise ValueError(
            "If `subset` is set, `validation_split` must be set, and "
            "inversely."
        )
    if subset not in ("training", "validation", "both", None):
        raise ValueError(
            '`subset` must be either "training", '
            f'"validation" or "both", received: {subset}'
        )
    if validation_split and shuffle and seed is None:
        raise ValueError(
            "If using `validation_split` and shuffling the data, you must "
            "provide a `seed` argument, to make sure that there is no "
            "overlap between the training and validation subset."
        )