1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for image_dataset."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import shutil
23
24import numpy as np
25
26from tensorflow.python.compat import v2_compat
27from tensorflow.python.eager import def_function
28from tensorflow.python.keras import keras_parameterized
29from tensorflow.python.keras.preprocessing import image as image_preproc
30from tensorflow.python.keras.preprocessing import image_dataset
31from tensorflow.python.platform import test
32
33try:
34  import PIL  # pylint:disable=g-import-not-at-top
35except ImportError:
36  PIL = None
37
38
39class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase):
40
41  def _get_images(self, count=16, color_mode='rgb'):
42    width = height = 24
43    imgs = []
44    for _ in range(count):
45      if color_mode == 'grayscale':
46        img = np.random.randint(0, 256, size=(height, width, 1))
47      elif color_mode == 'rgba':
48        img = np.random.randint(0, 256, size=(height, width, 4))
49      else:
50        img = np.random.randint(0, 256, size=(height, width, 3))
51      img = image_preproc.array_to_img(img)
52      imgs.append(img)
53    return imgs
54
55  def _prepare_directory(self,
56                         num_classes=2,
57                         grayscale=False,
58                         nested_dirs=False,
59                         color_mode='rgb',
60                         count=16):
61    # Get a unique temp directory
62    temp_dir = os.path.join(self.get_temp_dir(), str(np.random.randint(1e6)))
63    os.mkdir(temp_dir)
64    self.addCleanup(shutil.rmtree, temp_dir)
65
66    # Generate paths to class subdirectories
67    paths = []
68    for class_index in range(num_classes):
69      class_directory = 'class_%s' % (class_index,)
70      if nested_dirs:
71        class_paths = [
72            class_directory, os.path.join(class_directory, 'subfolder_1'),
73            os.path.join(class_directory, 'subfolder_2'), os.path.join(
74                class_directory, 'subfolder_1', 'sub-subfolder')
75        ]
76      else:
77        class_paths = [class_directory]
78      for path in class_paths:
79        os.mkdir(os.path.join(temp_dir, path))
80      paths += class_paths
81
82    # Save images to the paths
83    i = 0
84    for img in self._get_images(color_mode=color_mode, count=count):
85      path = paths[i % len(paths)]
86      if color_mode == 'rgb':
87        ext = 'jpg'
88      else:
89        ext = 'png'
90      filename = os.path.join(path, 'image_%s.%s' % (i, ext))
91      img.save(os.path.join(temp_dir, filename))
92      i += 1
93    return temp_dir
94
95  def test_image_dataset_from_directory_standalone(self):
96    # Test retrieving images without labels from a directory and its subdirs.
97    if PIL is None:
98      return  # Skip test if PIL is not available.
99
100    # Save a few extra images in the parent directory.
101    directory = self._prepare_directory(count=7, num_classes=2)
102    for i, img in enumerate(self._get_images(3)):
103      filename = 'image_%s.jpg' % (i,)
104      img.save(os.path.join(directory, filename))
105
106    dataset = image_dataset.image_dataset_from_directory(
107        directory, batch_size=5, image_size=(18, 18), labels=None)
108    batch = next(iter(dataset))
109    # We return plain images
110    self.assertEqual(batch.shape, (5, 18, 18, 3))
111    self.assertEqual(batch.dtype.name, 'float32')
112    # Count samples
113    batch_count = 0
114    sample_count = 0
115    for batch in dataset:
116      batch_count += 1
117      sample_count += batch.shape[0]
118    self.assertEqual(batch_count, 2)
119    self.assertEqual(sample_count, 10)
120
121  def test_image_dataset_from_directory_binary(self):
122    if PIL is None:
123      return  # Skip test if PIL is not available.
124
125    directory = self._prepare_directory(num_classes=2)
126    dataset = image_dataset.image_dataset_from_directory(
127        directory, batch_size=8, image_size=(18, 18), label_mode='int')
128    batch = next(iter(dataset))
129    self.assertLen(batch, 2)
130    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
131    self.assertEqual(batch[0].dtype.name, 'float32')
132    self.assertEqual(batch[1].shape, (8,))
133    self.assertEqual(batch[1].dtype.name, 'int32')
134
135    dataset = image_dataset.image_dataset_from_directory(
136        directory, batch_size=8, image_size=(18, 18), label_mode='binary')
137    batch = next(iter(dataset))
138    self.assertLen(batch, 2)
139    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
140    self.assertEqual(batch[0].dtype.name, 'float32')
141    self.assertEqual(batch[1].shape, (8, 1))
142    self.assertEqual(batch[1].dtype.name, 'float32')
143
144    dataset = image_dataset.image_dataset_from_directory(
145        directory, batch_size=8, image_size=(18, 18), label_mode='categorical')
146    batch = next(iter(dataset))
147    self.assertLen(batch, 2)
148    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
149    self.assertEqual(batch[0].dtype.name, 'float32')
150    self.assertEqual(batch[1].shape, (8, 2))
151    self.assertEqual(batch[1].dtype.name, 'float32')
152
153  def test_static_shape_in_graph(self):
154    if PIL is None:
155      return  # Skip test if PIL is not available.
156
157    directory = self._prepare_directory(num_classes=2)
158    dataset = image_dataset.image_dataset_from_directory(
159        directory, batch_size=8, image_size=(18, 18), label_mode='int')
160    test_case = self
161
162    @def_function.function
163    def symbolic_fn(ds):
164      for x, _ in ds.take(1):
165        test_case.assertListEqual(x.shape.as_list(), [None, 18, 18, 3])
166
167    symbolic_fn(dataset)
168
169  def test_sample_count(self):
170    if PIL is None:
171      return  # Skip test if PIL is not available.
172
173    directory = self._prepare_directory(num_classes=4, count=15)
174    dataset = image_dataset.image_dataset_from_directory(
175        directory, batch_size=8, image_size=(18, 18), label_mode=None)
176    sample_count = 0
177    for batch in dataset:
178      sample_count += batch.shape[0]
179    self.assertEqual(sample_count, 15)
180
181  def test_image_dataset_from_directory_multiclass(self):
182    if PIL is None:
183      return  # Skip test if PIL is not available.
184
185    directory = self._prepare_directory(num_classes=4, count=15)
186
187    dataset = image_dataset.image_dataset_from_directory(
188        directory, batch_size=8, image_size=(18, 18), label_mode=None)
189    batch = next(iter(dataset))
190    self.assertEqual(batch.shape, (8, 18, 18, 3))
191
192    dataset = image_dataset.image_dataset_from_directory(
193        directory, batch_size=8, image_size=(18, 18), label_mode=None)
194    sample_count = 0
195    iterator = iter(dataset)
196    for batch in dataset:
197      sample_count += next(iterator).shape[0]
198    self.assertEqual(sample_count, 15)
199
200    dataset = image_dataset.image_dataset_from_directory(
201        directory, batch_size=8, image_size=(18, 18), label_mode='int')
202    batch = next(iter(dataset))
203    self.assertLen(batch, 2)
204    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
205    self.assertEqual(batch[0].dtype.name, 'float32')
206    self.assertEqual(batch[1].shape, (8,))
207    self.assertEqual(batch[1].dtype.name, 'int32')
208
209    dataset = image_dataset.image_dataset_from_directory(
210        directory, batch_size=8, image_size=(18, 18), label_mode='categorical')
211    batch = next(iter(dataset))
212    self.assertLen(batch, 2)
213    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
214    self.assertEqual(batch[0].dtype.name, 'float32')
215    self.assertEqual(batch[1].shape, (8, 4))
216    self.assertEqual(batch[1].dtype.name, 'float32')
217
218  def test_image_dataset_from_directory_color_modes(self):
219    if PIL is None:
220      return  # Skip test if PIL is not available.
221
222    directory = self._prepare_directory(num_classes=4, color_mode='rgba')
223    dataset = image_dataset.image_dataset_from_directory(
224        directory, batch_size=8, image_size=(18, 18), color_mode='rgba')
225    batch = next(iter(dataset))
226    self.assertLen(batch, 2)
227    self.assertEqual(batch[0].shape, (8, 18, 18, 4))
228    self.assertEqual(batch[0].dtype.name, 'float32')
229
230    directory = self._prepare_directory(num_classes=4, color_mode='grayscale')
231    dataset = image_dataset.image_dataset_from_directory(
232        directory, batch_size=8, image_size=(18, 18), color_mode='grayscale')
233    batch = next(iter(dataset))
234    self.assertLen(batch, 2)
235    self.assertEqual(batch[0].shape, (8, 18, 18, 1))
236    self.assertEqual(batch[0].dtype.name, 'float32')
237
238  def test_image_dataset_from_directory_validation_split(self):
239    if PIL is None:
240      return  # Skip test if PIL is not available.
241
242    directory = self._prepare_directory(num_classes=2, count=10)
243    dataset = image_dataset.image_dataset_from_directory(
244        directory, batch_size=10, image_size=(18, 18),
245        validation_split=0.2, subset='training', seed=1337)
246    batch = next(iter(dataset))
247    self.assertLen(batch, 2)
248    self.assertEqual(batch[0].shape, (8, 18, 18, 3))
249    dataset = image_dataset.image_dataset_from_directory(
250        directory, batch_size=10, image_size=(18, 18),
251        validation_split=0.2, subset='validation', seed=1337)
252    batch = next(iter(dataset))
253    self.assertLen(batch, 2)
254    self.assertEqual(batch[0].shape, (2, 18, 18, 3))
255
256  def test_image_dataset_from_directory_manual_labels(self):
257    if PIL is None:
258      return  # Skip test if PIL is not available.
259
260    directory = self._prepare_directory(num_classes=2, count=2)
261    dataset = image_dataset.image_dataset_from_directory(
262        directory, batch_size=8, image_size=(18, 18),
263        labels=[0, 1], shuffle=False)
264    batch = next(iter(dataset))
265    self.assertLen(batch, 2)
266    self.assertAllClose(batch[1], [0, 1])
267
268  def test_image_dataset_from_directory_follow_links(self):
269    if PIL is None:
270      return  # Skip test if PIL is not available.
271
272    directory = self._prepare_directory(num_classes=2, count=25,
273                                        nested_dirs=True)
274    dataset = image_dataset.image_dataset_from_directory(
275        directory, batch_size=8, image_size=(18, 18), label_mode=None,
276        follow_links=True)
277    sample_count = 0
278    for batch in dataset:
279      sample_count += batch.shape[0]
280    self.assertEqual(sample_count, 25)
281
282  def test_image_dataset_from_directory_no_images(self):
283    directory = self._prepare_directory(num_classes=2, count=0)
284    with self.assertRaisesRegex(ValueError, 'No images found.'):
285      _ = image_dataset.image_dataset_from_directory(directory)
286
287  def test_image_dataset_from_directory_smart_resize(self):
288    if PIL is None:
289      return  # Skip test if PIL is not available.
290
291    directory = self._prepare_directory(num_classes=2, count=5)
292    dataset = image_dataset.image_dataset_from_directory(
293        directory, batch_size=5, image_size=(18, 18), smart_resize=True)
294    batch = next(iter(dataset))
295    self.assertLen(batch, 2)
296    self.assertEqual(batch[0].shape, (5, 18, 18, 3))
297
298  def test_image_dataset_from_directory_errors(self):
299    if PIL is None:
300      return  # Skip test if PIL is not available.
301
302    directory = self._prepare_directory(num_classes=3, count=5)
303
304    with self.assertRaisesRegex(ValueError, '`labels` argument should be'):
305      _ = image_dataset.image_dataset_from_directory(
306          directory, labels='other')
307
308    with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'):
309      _ = image_dataset.image_dataset_from_directory(
310          directory, label_mode='other')
311
312    with self.assertRaisesRegex(ValueError, '`color_mode` must be one of'):
313      _ = image_dataset.image_dataset_from_directory(
314          directory, color_mode='other')
315
316    with self.assertRaisesRegex(
317        ValueError, 'only pass `class_names` if the labels are inferred'):
318      _ = image_dataset.image_dataset_from_directory(
319          directory, labels=[0, 0, 1, 1, 1],
320          class_names=['class_0', 'class_1', 'class_2'])
321
322    with self.assertRaisesRegex(
323        ValueError,
324        'Expected the lengths of `labels` to match the number of files'):
325      _ = image_dataset.image_dataset_from_directory(
326          directory, labels=[0, 0, 1, 1])
327
328    with self.assertRaisesRegex(
329        ValueError, '`class_names` passed did not match'):
330      _ = image_dataset.image_dataset_from_directory(
331          directory, class_names=['class_0', 'class_2'])
332
333    with self.assertRaisesRegex(ValueError, 'there must exactly 2 classes'):
334      _ = image_dataset.image_dataset_from_directory(
335          directory, label_mode='binary')
336
337    with self.assertRaisesRegex(ValueError,
338                                '`validation_split` must be between 0 and 1'):
339      _ = image_dataset.image_dataset_from_directory(
340          directory, validation_split=2)
341
342    with self.assertRaisesRegex(ValueError,
343                                '`subset` must be either "training" or'):
344      _ = image_dataset.image_dataset_from_directory(
345          directory, validation_split=0.2, subset='other')
346
347    with self.assertRaisesRegex(ValueError, '`validation_split` must be set'):
348      _ = image_dataset.image_dataset_from_directory(
349          directory, validation_split=0, subset='training')
350
351    with self.assertRaisesRegex(ValueError, 'must provide a `seed`'):
352      _ = image_dataset.image_dataset_from_directory(
353          directory, validation_split=0.2, subset='training')
354
355
if __name__ == '__main__':
  # Enable v2 behavior before handing control to the test runner.
  v2_compat.enable_v2_behavior()
  test.main()
359