1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras image dataset loading utilities."""
16# pylint: disable=g-classes-have-attributes
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import multiprocessing
22import os
23
24import numpy as np
25
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29
30
def index_directory(directory,
                    labels,
                    formats,
                    class_names=None,
                    shuffle=True,
                    seed=None,
                    follow_links=False):
  """Make list of all files in the subdirs of `directory`, with their labels.

  Args:
    directory: The target directory (string).
    labels: Either "inferred"
        (labels are generated from the directory structure),
        None (no labels),
        or a list/tuple of integer labels of the same size as the number of
        valid files found in the directory. Labels should be sorted according
        to the alphanumeric order of the image file paths
        (obtained via `os.walk(directory)` in Python).
    formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").
    class_names: Only valid if "labels" is "inferred". This is the explict
        list of class names (must match names of subdirectories). Used
        to control the order of the classes
        (otherwise alphanumerical order is used).
    shuffle: Whether to shuffle the data. Default: True.
        If set to False, sorts the data in alphanumeric order.
    seed: Optional random seed for shuffling.
    follow_links: Whether to visits subdirectories pointed to by symlinks.

  Returns:
    tuple (file_paths, labels, class_names).
      file_paths: list of file paths (strings).
      labels: list of matching integer labels (same length as file_paths)
      class_names: names of the classes corresponding to these labels, in order.

  Raises:
    ValueError: if `class_names` does not match the subdirectories found, or
      if an explicit `labels` list has a different length than the number of
      files found.
  """
  # Capture the no-label mode up front: `labels` is overwritten with an int32
  # array below, so testing `labels is None` later would always be False
  # (this made the "Found %d files." message unreachable).
  no_labels = labels is None
  if no_labels:
    # in the no-label case, index from the parent directory down.
    subdirs = ['']
    class_names = subdirs
  else:
    subdirs = []
    for subdir in sorted(os.listdir(directory)):
      if os.path.isdir(os.path.join(directory, subdir)):
        subdirs.append(subdir)
    if not class_names:
      class_names = subdirs
    elif set(class_names) != set(subdirs):
      raise ValueError(
          'The `class_names` passed did not match the '
          'names of the subdirectories of the target directory. '
          'Expected: %s, but received: %s' %
          (subdirs, class_names))
  class_indices = dict(zip(class_names, range(len(class_names))))

  # Build an index of the files in the different class subfolders,
  # one worker task per subdirectory.
  pool = multiprocessing.pool.ThreadPool()
  results = []
  filenames = []

  for dirpath in (os.path.join(directory, subdir) for subdir in subdirs):
    results.append(
        pool.apply_async(index_subdirectory,
                         (dirpath, class_indices, follow_links, formats)))
  labels_list = []
  try:
    for res in results:
      partial_filenames, partial_labels = res.get()
      labels_list.append(partial_labels)
      filenames += partial_filenames
  finally:
    # Release the worker threads even if one of the tasks raised.
    pool.close()
    pool.join()

  if labels not in ('inferred', None):
    # Explicit labels: must align one-to-one with the files found.
    if len(labels) != len(filenames):
      raise ValueError('Expected the lengths of `labels` to match the number '
                       'of files in the target directory. len(labels) is %s '
                       'while we found %s files in %s.' % (
                           len(labels), len(filenames), directory))
  else:
    # Inferred (or absent) labels: concatenate the per-subdirectory results
    # into one contiguous int32 array.
    i = 0
    labels = np.zeros((len(filenames),), dtype='int32')
    for partial_labels in labels_list:
      labels[i:i + len(partial_labels)] = partial_labels
      i += len(partial_labels)

  if no_labels:
    print('Found %d files.' % (len(filenames),))
  else:
    print('Found %d files belonging to %d classes.' %
          (len(filenames), len(class_names)))
  file_paths = [os.path.join(directory, fname) for fname in filenames]

  if shuffle:
    # Shuffle globally to erase macro-structure. Re-seeding with the same
    # seed keeps file_paths and labels aligned across the two shuffles.
    if seed is None:
      seed = np.random.randint(1e6)
    rng = np.random.RandomState(seed)
    rng.shuffle(file_paths)
    rng = np.random.RandomState(seed)
    rng.shuffle(labels)
  return file_paths, labels, class_names
131
132
def iter_valid_files(directory, follow_links, formats):
  """Yield `(root, filename)` pairs for files under `directory`.

  Directories are visited in sorted order of their path, and files are
  yielded in sorted order within each directory; only filenames whose
  lowercased name ends with one of `formats` are yielded.
  """
  entries = sorted(
      os.walk(directory, followlinks=follow_links),
      key=lambda entry: entry[0])
  for root, _, files in entries:
    for fname in sorted(files):
      if not fname.lower().endswith(formats):
        continue
      yield root, fname
139
140
def index_subdirectory(directory, class_indices, follow_links, formats):
  """Recursively walks directory and list image paths and their class index.

  Args:
    directory: string, target directory.
    class_indices: dict mapping class names to their index.
    follow_links: boolean, whether to recursively follow subdirectories
      (if False, we only list top-level images in `directory`).
    formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").

  Returns:
    tuple `(filenames, labels)`. `filenames` is a list of relative file
      paths, and `labels` is a list of integer labels corresponding to these
      files.
  """
  class_name = os.path.basename(directory)
  filenames = []
  labels = []
  for root, fname in iter_valid_files(directory, follow_links, formats):
    labels.append(class_indices[class_name])
    # Store paths relative to the parent of `directory`, i.e. prefixed
    # with the class subdirectory name.
    full_path = os.path.join(root, fname)
    filenames.append(
        os.path.join(class_name, os.path.relpath(full_path, directory)))
  return filenames, labels
167
168
def get_training_or_validation_split(samples, labels, validation_split, subset):
  """Potentially restrict samples & labels to a training or validation split.

  Args:
    samples: List of elements.
    labels: List of corresponding labels.
    validation_split: Float, fraction of data to reserve for validation.
    subset: Subset of the data to return.
      Either "training", "validation", or None. If None, we return all of the
      data.

  Returns:
    tuple (samples, labels), potentially restricted to the specified subset.

  Raises:
    ValueError: if `subset` is neither "training" nor "validation" while
      `validation_split` is set.
  """
  if not validation_split:
    return samples, labels

  num_val_samples = int(validation_split * len(samples))
  # Bug fix: slicing with `[:-num_val_samples]` / `[-num_val_samples:]` is
  # wrong when `num_val_samples` is 0 (e.g. a small split on a small dataset):
  # `[:-0]` is an empty training set and `[-0:]` is the full list. Use an
  # explicit split index instead.
  split_index = len(samples) - num_val_samples
  if subset == 'training':
    print('Using %d files for training.' % (split_index,))
    samples = samples[:split_index]
    labels = labels[:split_index]
  elif subset == 'validation':
    print('Using %d files for validation.' % (num_val_samples,))
    samples = samples[split_index:]
    labels = labels[split_index:]
  else:
    raise ValueError('`subset` must be either "training" '
                     'or "validation", received: %s' % (subset,))
  return samples, labels
199
200
def labels_to_dataset(labels, label_mode, num_classes):
  """Create a tf.data.Dataset from the list/tuple of labels.

  Args:
    labels: list/tuple of labels to be converted into a tf.data.Dataset.
    label_mode: - 'binary' indicates that the labels (there can be only 2) are
      encoded as `float32` scalars with values 0 or 1 (e.g. for
      `binary_crossentropy`). - 'categorical' means that the labels are mapped
      into a categorical vector. (e.g. for `categorical_crossentropy` loss).
    num_classes: number of classes of labels.

  Returns:
    A `Dataset` yielding one label per element, encoded according to
    `label_mode` (labels pass through unchanged for any other mode).
  """
  dataset = dataset_ops.Dataset.from_tensor_slices(labels)
  if label_mode == 'categorical':
    # One-hot encode integer labels into vectors of length `num_classes`.
    dataset = dataset.map(lambda label: array_ops.one_hot(label, num_classes))
  elif label_mode == 'binary':
    # Cast to float32 and add a trailing axis of size 1.
    dataset = dataset.map(
        lambda label: array_ops.expand_dims(
            math_ops.cast(label, 'float32'), axis=-1))
  return dataset
219
220
def check_validation_split_arg(validation_split, subset, shuffle, seed):
  """Raise errors in case of invalid argument values.

  Args:
    validation_split: float between 0 and 1, fraction of data to reserve for
      validation.
    subset: One of "training" or "validation". Only used if `validation_split`
      is set.
    shuffle: Whether to shuffle the data. Either True or False.
    seed: random seed for shuffling and transformations.
  """
  if validation_split and not 0 < validation_split < 1:
    raise ValueError(
        '`validation_split` must be between 0 and 1, received: %s' %
        (validation_split,))
  # `validation_split` and `subset` must be provided together (or neither).
  if bool(validation_split) != bool(subset):
    raise ValueError(
        'If `subset` is set, `validation_split` must be set, and inversely.')
  if subset not in ('training', 'validation', None):
    raise ValueError('`subset` must be either "training" '
                     'or "validation", received: %s' % (subset,))
  # Without a fixed seed the two subsets would be drawn from independently
  # shuffled orderings and could overlap.
  if validation_split and shuffle and seed is None:
    raise ValueError(
        'If using `validation_split` and shuffling the data, you must provide '
        'a `seed` argument, to make sure that there is no overlap between the '
        'training and validation subset.')
247