# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image dataset loading utilities."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.preprocessing import dataset_utils
from tensorflow.python.keras.preprocessing import image as keras_image_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.util.tf_export import keras_export


ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')


@keras_export('keras.preprocessing.image_dataset_from_directory', v1=[])
def image_dataset_from_directory(directory,
                                 labels='inferred',
                                 label_mode='int',
                                 class_names=None,
                                 color_mode='rgb',
                                 batch_size=32,
                                 image_size=(256, 256),
                                 shuffle=True,
                                 seed=None,
                                 validation_split=None,
                                 subset=None,
                                 interpolation='bilinear',
                                 follow_links=False,
                                 smart_resize=False):
  """Generates a `tf.data.Dataset` from image files in a directory.

  If your directory structure is:

  ```
  main_directory/
  ...class_a/
  ......a_image_1.jpg
  ......a_image_2.jpg
  ...class_b/
  ......b_image_1.jpg
  ......b_image_2.jpg
  ```

  Then calling `image_dataset_from_directory(main_directory, labels='inferred')`
  will return a `tf.data.Dataset` that yields batches of images from
  the subdirectories `class_a` and `class_b`, together with labels
  0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).

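  For example, given the hypothetical `main_directory` layout above and the
  default arguments, a call along these lines yields batches of images and
  integer labels (shapes shown assume at least 32 images and the defaults
  `batch_size=32`, `image_size=(256, 256)`, `color_mode='rgb'`):

  ```
  import tensorflow as tf

  dataset = tf.keras.preprocessing.image_dataset_from_directory(
      'main_directory', labels='inferred')
  for images, labels in dataset.take(1):
    print(images.shape)  # (32, 256, 256, 3), float32 image batch
    print(labels.shape)  # (32,), int32 labels with the default `label_mode='int'`
  ```
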
  Supported image formats: jpeg, png, bmp, gif.
  Animated gifs are truncated to the first frame.

  Args:
    directory: Directory where the data is located.
        If `labels` is "inferred", it should contain
        subdirectories, each containing images for a class.
        Otherwise, the directory structure is ignored.
    labels: Either "inferred"
        (labels are generated from the directory structure),
        None (no labels),
        or a list/tuple of integer labels of the same size as the number of
        image files found in the directory. Labels should be sorted according
        to the alphanumeric order of the image file paths
        (obtained via `os.walk(directory)` in Python).
    label_mode:
        - 'int' means that the labels are encoded as integers
            (e.g. for `sparse_categorical_crossentropy` loss).
        - 'categorical' means that the labels are
            encoded as a categorical vector
            (e.g. for `categorical_crossentropy` loss).
        - 'binary' means that the labels (there can be only 2)
            are encoded as `float32` scalars with values 0 or 1
            (e.g. for `binary_crossentropy`).
        - None (no labels).
    class_names: Only valid if "labels" is "inferred". This is the explicit
        list of class names (must match names of subdirectories). Used
        to control the order of the classes
        (otherwise alphanumerical order is used).
    color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb".
        Whether the images will be converted to
        have 1, 3, or 4 channels.
    batch_size: Size of the batches of data. Default: 32.
    image_size: Size to resize images to after they are read from disk.
        Defaults to `(256, 256)`.
        Since the pipeline processes batches of images that must all have
        the same size, this must be provided.
    shuffle: Whether to shuffle the data. Default: True.
        If set to False, sorts the data in alphanumeric order.
    seed: Optional random seed for shuffling and transformations.
    validation_split: Optional float between 0 and 1,
        fraction of data to reserve for validation.
    subset: One of "training" or "validation".
        Only used if `validation_split` is set.
    interpolation: String, the interpolation method used when resizing images.
      Defaults to `bilinear`. Supports `bilinear`, `nearest`, `bicubic`,
      `area`, `lanczos3`, `lanczos5`, `gaussian`, `mitchellcubic`.
    follow_links: Whether to visit subdirectories pointed to by symlinks.
        Defaults to False.
    smart_resize: If True, the resizing function used will be
      `tf.keras.preprocessing.image.smart_resize`, which preserves the aspect
      ratio of the original image by using a mixture of resizing and cropping.
      If False (default), the resizing function is `tf.image.resize`, which
      does not preserve aspect ratio.

  Returns:
    A `tf.data.Dataset` object.
      - If `label_mode` is None, it yields `float32` tensors of shape
        `(batch_size, image_size[0], image_size[1], num_channels)`,
        encoding images (see below for rules regarding `num_channels`).
      - Otherwise, it yields a tuple `(images, labels)`, where `images`
        has shape `(batch_size, image_size[0], image_size[1], num_channels)`,
        and `labels` follows the format described below.

  Rules regarding labels format:
    - if `label_mode` is `int`, the labels are an `int32` tensor of shape
      `(batch_size,)`.
    - if `label_mode` is `binary`, the labels are a `float32` tensor of
      1s and 0s of shape `(batch_size, 1)`.
    - if `label_mode` is `categorical`, the labels are a `float32` tensor
      of shape `(batch_size, num_classes)`, representing a one-hot
      encoding of the class index.

  Rules regarding number of channels in the yielded images:
    - if `color_mode` is `grayscale`,
      there's 1 channel in the image tensors.
    - if `color_mode` is `rgb`,
      there are 3 channels in the image tensors.
    - if `color_mode` is `rgba`,
      there are 4 channels in the image tensors.
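
  For instance, a sketch under the same hypothetical `main_directory` layout
  (two class subdirectories, default `batch_size=32`) with
  `label_mode='categorical'` yields one-hot labels:

  ```
  dataset = tf.keras.preprocessing.image_dataset_from_directory(
      'main_directory', label_mode='categorical')
  images, labels = next(iter(dataset))
  print(labels.dtype)  # float32
  print(labels.shape)  # (32, 2): one-hot vectors over the 2 classes
  ```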
  """
  if labels not in ('inferred', None):
    if not isinstance(labels, (list, tuple)):
      raise ValueError(
          '`labels` argument should be a list/tuple of integer labels, of '
          'the same size as the number of image files in the target '
          'directory. If you wish to infer the labels from the subdirectory '
          'names in the target directory, pass `labels="inferred"`. '
          'If you wish to get a dataset that only contains images '
          '(no labels), pass `label_mode=None`.')
    if class_names:
      raise ValueError('You can only pass `class_names` if the labels are '
                       'inferred from the subdirectory names in the target '
                       'directory (`labels="inferred"`).')
  if label_mode not in {'int', 'categorical', 'binary', None}:
    raise ValueError(
        '`label_mode` argument must be one of "int", "categorical", "binary", '
        'or None. Received: %s' % (label_mode,))
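  # Passing either `labels=None` or `label_mode=None` means "no labels":
  # normalize both to None so the rest of the pipeline skips label handling.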
  if labels is None or label_mode is None:
    labels = None
    label_mode = None
  if color_mode == 'rgb':
    num_channels = 3
  elif color_mode == 'rgba':
    num_channels = 4
  elif color_mode == 'grayscale':
    num_channels = 1
  else:
    raise ValueError(
        '`color_mode` must be one of {"rgb", "rgba", "grayscale"}. '
        'Received: %s' % (color_mode,))
  interpolation = image_preprocessing.get_interpolation(interpolation)
  dataset_utils.check_validation_split_arg(
      validation_split, subset, shuffle, seed)

  if seed is None:
    seed = np.random.randint(1e6)
  image_paths, labels, class_names = dataset_utils.index_directory(
      directory,
      labels,
      formats=ALLOWLIST_FORMATS,
      class_names=class_names,
      shuffle=shuffle,
      seed=seed,
      follow_links=follow_links)

  if label_mode == 'binary' and len(class_names) != 2:
    raise ValueError(
        'When passing `label_mode="binary"`, there must be exactly 2 classes. '
        'Found the following classes: %s' % (class_names,))

  image_paths, labels = dataset_utils.get_training_or_validation_split(
      image_paths, labels, validation_split, subset)
  if not image_paths:
    raise ValueError('No images found.')

  dataset = paths_and_labels_to_dataset(
      image_paths=image_paths,
      image_size=image_size,
      num_channels=num_channels,
      labels=labels,
      label_mode=label_mode,
      num_classes=len(class_names),
      interpolation=interpolation,
      smart_resize=smart_resize)
  if shuffle:
    # Shuffle locally at each iteration
    dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
  dataset = dataset.batch(batch_size)
  # Users may need to reference `class_names`.
  dataset.class_names = class_names
  # Include file paths for images as attribute.
  dataset.file_paths = image_paths
  return dataset


def paths_and_labels_to_dataset(image_paths,
                                image_size,
                                num_channels,
                                labels,
                                label_mode,
                                num_classes,
                                interpolation,
                                smart_resize=False):
  """Constructs a dataset of images and labels."""
  # TODO(fchollet): consider making num_parallel_calls settable
  path_ds = dataset_ops.Dataset.from_tensor_slices(image_paths)
  args = (image_size, num_channels, interpolation, smart_resize)
  img_ds = path_ds.map(
      lambda x: load_image(x, *args))
  if label_mode:
    label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
    img_ds = dataset_ops.Dataset.zip((img_ds, label_ds))
  return img_ds


def load_image(path, image_size, num_channels, interpolation,
               smart_resize=False):
  """Load an image from a path and resize it."""
  img = io_ops.read_file(path)
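  # Decode with `expand_animations=False` so animated GIFs are truncated to
  # their first frame, as documented in `image_dataset_from_directory`.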
  img = image_ops.decode_image(
      img, channels=num_channels, expand_animations=False)
  if smart_resize:
    img = keras_image_ops.smart_resize(img, image_size,
                                       interpolation=interpolation)
  else:
    img = image_ops.resize_images_v2(img, image_size, method=interpolation)
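  # The decoded image carries no fully-defined static shape, so pin the
  # expected (height, width, channels) shape for downstream layers.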
  img.set_shape((image_size[0], image_size[1], num_channels))
  return img