1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras image dataset loading utilities.""" 16# pylint: disable=g-classes-have-attributes 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import multiprocessing 22import os 23 24import numpy as np 25 26from tensorflow.python.data.ops import dataset_ops 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import math_ops 29 30 31def index_directory(directory, 32 labels, 33 formats, 34 class_names=None, 35 shuffle=True, 36 seed=None, 37 follow_links=False): 38 """Make list of all files in the subdirs of `directory`, with their labels. 39 40 Args: 41 directory: The target directory (string). 42 labels: Either "inferred" 43 (labels are generated from the directory structure), 44 None (no labels), 45 or a list/tuple of integer labels of the same size as the number of 46 valid files found in the directory. Labels should be sorted according 47 to the alphanumeric order of the image file paths 48 (obtained via `os.walk(directory)` in Python). 49 formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt"). 50 class_names: Only valid if "labels" is "inferred". This is the explict 51 list of class names (must match names of subdirectories). Used 52 to control the order of the classes 53 (otherwise alphanumerical order is used). 54 shuffle: Whether to shuffle the data. Default: True. 55 If set to False, sorts the data in alphanumeric order. 56 seed: Optional random seed for shuffling. 57 follow_links: Whether to visits subdirectories pointed to by symlinks. 58 59 Returns: 60 tuple (file_paths, labels, class_names). 61 file_paths: list of file paths (strings). 62 labels: list of matching integer labels (same length as file_paths) 63 class_names: names of the classes corresponding to these labels, in order. 64 """ 65 if labels is None: 66 # in the no-label case, index from the parent directory down. 67 subdirs = [''] 68 class_names = subdirs 69 else: 70 subdirs = [] 71 for subdir in sorted(os.listdir(directory)): 72 if os.path.isdir(os.path.join(directory, subdir)): 73 subdirs.append(subdir) 74 if not class_names: 75 class_names = subdirs 76 else: 77 if set(class_names) != set(subdirs): 78 raise ValueError( 79 'The `class_names` passed did not match the ' 80 'names of the subdirectories of the target directory. ' 81 'Expected: %s, but received: %s' % 82 (subdirs, class_names)) 83 class_indices = dict(zip(class_names, range(len(class_names)))) 84 85 # Build an index of the files 86 # in the different class subfolders. 87 pool = multiprocessing.pool.ThreadPool() 88 results = [] 89 filenames = [] 90 91 for dirpath in (os.path.join(directory, subdir) for subdir in subdirs): 92 results.append( 93 pool.apply_async(index_subdirectory, 94 (dirpath, class_indices, follow_links, formats))) 95 labels_list = [] 96 for res in results: 97 partial_filenames, partial_labels = res.get() 98 labels_list.append(partial_labels) 99 filenames += partial_filenames 100 if labels not in ('inferred', None): 101 if len(labels) != len(filenames): 102 raise ValueError('Expected the lengths of `labels` to match the number ' 103 'of files in the target directory. len(labels) is %s ' 104 'while we found %s files in %s.' % ( 105 len(labels), len(filenames), directory)) 106 else: 107 i = 0 108 labels = np.zeros((len(filenames),), dtype='int32') 109 for partial_labels in labels_list: 110 labels[i:i + len(partial_labels)] = partial_labels 111 i += len(partial_labels) 112 113 if labels is None: 114 print('Found %d files.' % (len(filenames),)) 115 else: 116 print('Found %d files belonging to %d classes.' % 117 (len(filenames), len(class_names))) 118 pool.close() 119 pool.join() 120 file_paths = [os.path.join(directory, fname) for fname in filenames] 121 122 if shuffle: 123 # Shuffle globally to erase macro-structure 124 if seed is None: 125 seed = np.random.randint(1e6) 126 rng = np.random.RandomState(seed) 127 rng.shuffle(file_paths) 128 rng = np.random.RandomState(seed) 129 rng.shuffle(labels) 130 return file_paths, labels, class_names 131 132 133def iter_valid_files(directory, follow_links, formats): 134 walk = os.walk(directory, followlinks=follow_links) 135 for root, _, files in sorted(walk, key=lambda x: x[0]): 136 for fname in sorted(files): 137 if fname.lower().endswith(formats): 138 yield root, fname 139 140 141def index_subdirectory(directory, class_indices, follow_links, formats): 142 """Recursively walks directory and list image paths and their class index. 143 144 Args: 145 directory: string, target directory. 146 class_indices: dict mapping class names to their index. 147 follow_links: boolean, whether to recursively follow subdirectories 148 (if False, we only list top-level images in `directory`). 149 formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt"). 150 151 Returns: 152 tuple `(filenames, labels)`. `filenames` is a list of relative file 153 paths, and `labels` is a list of integer labels corresponding to these 154 files. 155 """ 156 dirname = os.path.basename(directory) 157 valid_files = iter_valid_files(directory, follow_links, formats) 158 labels = [] 159 filenames = [] 160 for root, fname in valid_files: 161 labels.append(class_indices[dirname]) 162 absolute_path = os.path.join(root, fname) 163 relative_path = os.path.join( 164 dirname, os.path.relpath(absolute_path, directory)) 165 filenames.append(relative_path) 166 return filenames, labels 167 168 169def get_training_or_validation_split(samples, labels, validation_split, subset): 170 """Potentially restict samples & labels to a training or validation split. 171 172 Args: 173 samples: List of elements. 174 labels: List of corresponding labels. 175 validation_split: Float, fraction of data to reserve for validation. 176 subset: Subset of the data to return. 177 Either "training", "validation", or None. If None, we return all of the 178 data. 179 180 Returns: 181 tuple (samples, labels), potentially restricted to the specified subset. 182 """ 183 if not validation_split: 184 return samples, labels 185 186 num_val_samples = int(validation_split * len(samples)) 187 if subset == 'training': 188 print('Using %d files for training.' % (len(samples) - num_val_samples,)) 189 samples = samples[:-num_val_samples] 190 labels = labels[:-num_val_samples] 191 elif subset == 'validation': 192 print('Using %d files for validation.' % (num_val_samples,)) 193 samples = samples[-num_val_samples:] 194 labels = labels[-num_val_samples:] 195 else: 196 raise ValueError('`subset` must be either "training" ' 197 'or "validation", received: %s' % (subset,)) 198 return samples, labels 199 200 201def labels_to_dataset(labels, label_mode, num_classes): 202 """Create a tf.data.Dataset from the list/tuple of labels. 203 204 Args: 205 labels: list/tuple of labels to be converted into a tf.data.Dataset. 206 label_mode: - 'binary' indicates that the labels (there can be only 2) are 207 encoded as `float32` scalars with values 0 or 1 (e.g. for 208 `binary_crossentropy`). - 'categorical' means that the labels are mapped 209 into a categorical vector. (e.g. for `categorical_crossentropy` loss). 210 num_classes: number of classes of labels. 211 """ 212 label_ds = dataset_ops.Dataset.from_tensor_slices(labels) 213 if label_mode == 'binary': 214 label_ds = label_ds.map( 215 lambda x: array_ops.expand_dims(math_ops.cast(x, 'float32'), axis=-1)) 216 elif label_mode == 'categorical': 217 label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes)) 218 return label_ds 219 220 221def check_validation_split_arg(validation_split, subset, shuffle, seed): 222 """Raise errors in case of invalid argument values. 223 224 Args: 225 shuffle: Whether to shuffle the data. Either True or False. 226 seed: random seed for shuffling and transformations. 227 validation_split: float between 0 and 1, fraction of data to reserve for 228 validation. 229 subset: One of "training" or "validation". Only used if `validation_split` 230 is set. 231 """ 232 if validation_split and not 0 < validation_split < 1: 233 raise ValueError( 234 '`validation_split` must be between 0 and 1, received: %s' % 235 (validation_split,)) 236 if (validation_split or subset) and not (validation_split and subset): 237 raise ValueError( 238 'If `subset` is set, `validation_split` must be set, and inversely.') 239 if subset not in ('training', 'validation', None): 240 raise ValueError('`subset` must be either "training" ' 241 'or "validation", received: %s' % (subset,)) 242 if validation_split and shuffle and seed is None: 243 raise ValueError( 244 'If using `validation_split` and shuffling the data, you must provide ' 245 'a `seed` argument, to make sure that there is no overlap between the ' 246 'training and validation subset.') 247