1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras image dataset loading utilities.""" 16# pylint: disable=g-classes-have-attributes 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.keras.layers.preprocessing import image_preprocessing 25from tensorflow.python.keras.preprocessing import dataset_utils 26from tensorflow.python.keras.preprocessing import image as keras_image_ops 27from tensorflow.python.ops import image_ops 28from tensorflow.python.ops import io_ops 29from tensorflow.python.util.tf_export import keras_export 30 31 32ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png') 33 34 35@keras_export('keras.preprocessing.image_dataset_from_directory', v1=[]) 36def image_dataset_from_directory(directory, 37 labels='inferred', 38 label_mode='int', 39 class_names=None, 40 color_mode='rgb', 41 batch_size=32, 42 image_size=(256, 256), 43 shuffle=True, 44 seed=None, 45 validation_split=None, 46 subset=None, 47 interpolation='bilinear', 48 follow_links=False, 49 smart_resize=False): 50 """Generates a `tf.data.Dataset` from image files in a directory. 51 52 If your directory structure is: 53 54 ``` 55 main_directory/ 56 ...class_a/ 57 ......a_image_1.jpg 58 ......a_image_2.jpg 59 ...class_b/ 60 ......b_image_1.jpg 61 ......b_image_2.jpg 62 ``` 63 64 Then calling `image_dataset_from_directory(main_directory, labels='inferred')` 65 will return a `tf.data.Dataset` that yields batches of images from 66 the subdirectories `class_a` and `class_b`, together with labels 67 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`). 68 69 Supported image formats: jpeg, png, bmp, gif. 70 Animated gifs are truncated to the first frame. 71 72 Args: 73 directory: Directory where the data is located. 74 If `labels` is "inferred", it should contain 75 subdirectories, each containing images for a class. 76 Otherwise, the directory structure is ignored. 77 labels: Either "inferred" 78 (labels are generated from the directory structure), 79 None (no labels), 80 or a list/tuple of integer labels of the same size as the number of 81 image files found in the directory. Labels should be sorted according 82 to the alphanumeric order of the image file paths 83 (obtained via `os.walk(directory)` in Python). 84 label_mode: 85 - 'int': means that the labels are encoded as integers 86 (e.g. for `sparse_categorical_crossentropy` loss). 87 - 'categorical' means that the labels are 88 encoded as a categorical vector 89 (e.g. for `categorical_crossentropy` loss). 90 - 'binary' means that the labels (there can be only 2) 91 are encoded as `float32` scalars with values 0 or 1 92 (e.g. for `binary_crossentropy`). 93 - None (no labels). 94 class_names: Only valid if "labels" is "inferred". This is the explict 95 list of class names (must match names of subdirectories). Used 96 to control the order of the classes 97 (otherwise alphanumerical order is used). 98 color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb". 99 Whether the images will be converted to 100 have 1, 3, or 4 channels. 101 batch_size: Size of the batches of data. Default: 32. 102 image_size: Size to resize images to after they are read from disk. 103 Defaults to `(256, 256)`. 104 Since the pipeline processes batches of images that must all have 105 the same size, this must be provided. 106 shuffle: Whether to shuffle the data. Default: True. 107 If set to False, sorts the data in alphanumeric order. 108 seed: Optional random seed for shuffling and transformations. 109 validation_split: Optional float between 0 and 1, 110 fraction of data to reserve for validation. 111 subset: One of "training" or "validation". 112 Only used if `validation_split` is set. 113 interpolation: String, the interpolation method used when resizing images. 114 Defaults to `bilinear`. Supports `bilinear`, `nearest`, `bicubic`, 115 `area`, `lanczos3`, `lanczos5`, `gaussian`, `mitchellcubic`. 116 follow_links: Whether to visits subdirectories pointed to by symlinks. 117 Defaults to False. 118 smart_resize: If True, the resizing function used will be 119 `tf.keras.preprocessing.image.smart_resize`, which preserves the aspect 120 ratio of the original image by using a mixture of resizing and cropping. 121 If False (default), the resizing function is `tf.image.resize`, which 122 does not preserve aspect ratio. 123 124 Returns: 125 A `tf.data.Dataset` object. 126 - If `label_mode` is None, it yields `float32` tensors of shape 127 `(batch_size, image_size[0], image_size[1], num_channels)`, 128 encoding images (see below for rules regarding `num_channels`). 129 - Otherwise, it yields a tuple `(images, labels)`, where `images` 130 has shape `(batch_size, image_size[0], image_size[1], num_channels)`, 131 and `labels` follows the format described below. 132 133 Rules regarding labels format: 134 - if `label_mode` is `int`, the labels are an `int32` tensor of shape 135 `(batch_size,)`. 136 - if `label_mode` is `binary`, the labels are a `float32` tensor of 137 1s and 0s of shape `(batch_size, 1)`. 138 - if `label_mode` is `categorial`, the labels are a `float32` tensor 139 of shape `(batch_size, num_classes)`, representing a one-hot 140 encoding of the class index. 141 142 Rules regarding number of channels in the yielded images: 143 - if `color_mode` is `grayscale`, 144 there's 1 channel in the image tensors. 145 - if `color_mode` is `rgb`, 146 there are 3 channel in the image tensors. 147 - if `color_mode` is `rgba`, 148 there are 4 channel in the image tensors. 149 """ 150 if labels not in ('inferred', None): 151 if not isinstance(labels, (list, tuple)): 152 raise ValueError( 153 '`labels` argument should be a list/tuple of integer labels, of ' 154 'the same size as the number of image files in the target ' 155 'directory. If you wish to infer the labels from the subdirectory ' 156 'names in the target directory, pass `labels="inferred"`. ' 157 'If you wish to get a dataset that only contains images ' 158 '(no labels), pass `label_mode=None`.') 159 if class_names: 160 raise ValueError('You can only pass `class_names` if the labels are ' 161 'inferred from the subdirectory names in the target ' 162 'directory (`labels="inferred"`).') 163 if label_mode not in {'int', 'categorical', 'binary', None}: 164 raise ValueError( 165 '`label_mode` argument must be one of "int", "categorical", "binary", ' 166 'or None. Received: %s' % (label_mode,)) 167 if labels is None or label_mode is None: 168 labels = None 169 label_mode = None 170 if color_mode == 'rgb': 171 num_channels = 3 172 elif color_mode == 'rgba': 173 num_channels = 4 174 elif color_mode == 'grayscale': 175 num_channels = 1 176 else: 177 raise ValueError( 178 '`color_mode` must be one of {"rbg", "rgba", "grayscale"}. ' 179 'Received: %s' % (color_mode,)) 180 interpolation = image_preprocessing.get_interpolation(interpolation) 181 dataset_utils.check_validation_split_arg( 182 validation_split, subset, shuffle, seed) 183 184 if seed is None: 185 seed = np.random.randint(1e6) 186 image_paths, labels, class_names = dataset_utils.index_directory( 187 directory, 188 labels, 189 formats=ALLOWLIST_FORMATS, 190 class_names=class_names, 191 shuffle=shuffle, 192 seed=seed, 193 follow_links=follow_links) 194 195 if label_mode == 'binary' and len(class_names) != 2: 196 raise ValueError( 197 'When passing `label_mode="binary", there must exactly 2 classes. ' 198 'Found the following classes: %s' % (class_names,)) 199 200 image_paths, labels = dataset_utils.get_training_or_validation_split( 201 image_paths, labels, validation_split, subset) 202 if not image_paths: 203 raise ValueError('No images found.') 204 205 dataset = paths_and_labels_to_dataset( 206 image_paths=image_paths, 207 image_size=image_size, 208 num_channels=num_channels, 209 labels=labels, 210 label_mode=label_mode, 211 num_classes=len(class_names), 212 interpolation=interpolation, 213 smart_resize=smart_resize) 214 if shuffle: 215 # Shuffle locally at each iteration 216 dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed) 217 dataset = dataset.batch(batch_size) 218 # Users may need to reference `class_names`. 219 dataset.class_names = class_names 220 # Include file paths for images as attribute. 221 dataset.file_paths = image_paths 222 return dataset 223 224 225def paths_and_labels_to_dataset(image_paths, 226 image_size, 227 num_channels, 228 labels, 229 label_mode, 230 num_classes, 231 interpolation, 232 smart_resize=False): 233 """Constructs a dataset of images and labels.""" 234 # TODO(fchollet): consider making num_parallel_calls settable 235 path_ds = dataset_ops.Dataset.from_tensor_slices(image_paths) 236 args = (image_size, num_channels, interpolation, smart_resize) 237 img_ds = path_ds.map( 238 lambda x: load_image(x, *args)) 239 if label_mode: 240 label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes) 241 img_ds = dataset_ops.Dataset.zip((img_ds, label_ds)) 242 return img_ds 243 244 245def load_image(path, image_size, num_channels, interpolation, 246 smart_resize=False): 247 """Load an image from a path and resize it.""" 248 img = io_ops.read_file(path) 249 img = image_ops.decode_image( 250 img, channels=num_channels, expand_animations=False) 251 if smart_resize: 252 img = keras_image_ops.smart_resize(img, image_size, 253 interpolation=interpolation) 254 else: 255 img = image_ops.resize_images_v2(img, image_size, method=interpolation) 256 img.set_shape((image_size[0], image_size[1], num_channels)) 257 return img 258