1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for image_dataset.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import shutil 23 24import numpy as np 25 26from tensorflow.python.compat import v2_compat 27from tensorflow.python.eager import def_function 28from tensorflow.python.keras import keras_parameterized 29from tensorflow.python.keras.preprocessing import image as image_preproc 30from tensorflow.python.keras.preprocessing import image_dataset 31from tensorflow.python.platform import test 32 33try: 34 import PIL # pylint:disable=g-import-not-at-top 35except ImportError: 36 PIL = None 37 38 39class ImageDatasetFromDirectoryTest(keras_parameterized.TestCase): 40 41 def _get_images(self, count=16, color_mode='rgb'): 42 width = height = 24 43 imgs = [] 44 for _ in range(count): 45 if color_mode == 'grayscale': 46 img = np.random.randint(0, 256, size=(height, width, 1)) 47 elif color_mode == 'rgba': 48 img = np.random.randint(0, 256, size=(height, width, 4)) 49 else: 50 img = np.random.randint(0, 256, size=(height, width, 3)) 51 img = image_preproc.array_to_img(img) 52 imgs.append(img) 53 return imgs 54 55 def _prepare_directory(self, 56 num_classes=2, 57 grayscale=False, 58 nested_dirs=False, 59 color_mode='rgb', 60 count=16): 61 # Get a unique temp directory 62 temp_dir = os.path.join(self.get_temp_dir(), str(np.random.randint(1e6))) 63 os.mkdir(temp_dir) 64 self.addCleanup(shutil.rmtree, temp_dir) 65 66 # Generate paths to class subdirectories 67 paths = [] 68 for class_index in range(num_classes): 69 class_directory = 'class_%s' % (class_index,) 70 if nested_dirs: 71 class_paths = [ 72 class_directory, os.path.join(class_directory, 'subfolder_1'), 73 os.path.join(class_directory, 'subfolder_2'), os.path.join( 74 class_directory, 'subfolder_1', 'sub-subfolder') 75 ] 76 else: 77 class_paths = [class_directory] 78 for path in class_paths: 79 os.mkdir(os.path.join(temp_dir, path)) 80 paths += class_paths 81 82 # Save images to the paths 83 i = 0 84 for img in self._get_images(color_mode=color_mode, count=count): 85 path = paths[i % len(paths)] 86 if color_mode == 'rgb': 87 ext = 'jpg' 88 else: 89 ext = 'png' 90 filename = os.path.join(path, 'image_%s.%s' % (i, ext)) 91 img.save(os.path.join(temp_dir, filename)) 92 i += 1 93 return temp_dir 94 95 def test_image_dataset_from_directory_standalone(self): 96 # Test retrieving images without labels from a directory and its subdirs. 97 if PIL is None: 98 return # Skip test if PIL is not available. 99 100 # Save a few extra images in the parent directory. 101 directory = self._prepare_directory(count=7, num_classes=2) 102 for i, img in enumerate(self._get_images(3)): 103 filename = 'image_%s.jpg' % (i,) 104 img.save(os.path.join(directory, filename)) 105 106 dataset = image_dataset.image_dataset_from_directory( 107 directory, batch_size=5, image_size=(18, 18), labels=None) 108 batch = next(iter(dataset)) 109 # We return plain images 110 self.assertEqual(batch.shape, (5, 18, 18, 3)) 111 self.assertEqual(batch.dtype.name, 'float32') 112 # Count samples 113 batch_count = 0 114 sample_count = 0 115 for batch in dataset: 116 batch_count += 1 117 sample_count += batch.shape[0] 118 self.assertEqual(batch_count, 2) 119 self.assertEqual(sample_count, 10) 120 121 def test_image_dataset_from_directory_binary(self): 122 if PIL is None: 123 return # Skip test if PIL is not available. 124 125 directory = self._prepare_directory(num_classes=2) 126 dataset = image_dataset.image_dataset_from_directory( 127 directory, batch_size=8, image_size=(18, 18), label_mode='int') 128 batch = next(iter(dataset)) 129 self.assertLen(batch, 2) 130 self.assertEqual(batch[0].shape, (8, 18, 18, 3)) 131 self.assertEqual(batch[0].dtype.name, 'float32') 132 self.assertEqual(batch[1].shape, (8,)) 133 self.assertEqual(batch[1].dtype.name, 'int32') 134 135 dataset = image_dataset.image_dataset_from_directory( 136 directory, batch_size=8, image_size=(18, 18), label_mode='binary') 137 batch = next(iter(dataset)) 138 self.assertLen(batch, 2) 139 self.assertEqual(batch[0].shape, (8, 18, 18, 3)) 140 self.assertEqual(batch[0].dtype.name, 'float32') 141 self.assertEqual(batch[1].shape, (8, 1)) 142 self.assertEqual(batch[1].dtype.name, 'float32') 143 144 dataset = image_dataset.image_dataset_from_directory( 145 directory, batch_size=8, image_size=(18, 18), label_mode='categorical') 146 batch = next(iter(dataset)) 147 self.assertLen(batch, 2) 148 self.assertEqual(batch[0].shape, (8, 18, 18, 3)) 149 self.assertEqual(batch[0].dtype.name, 'float32') 150 self.assertEqual(batch[1].shape, (8, 2)) 151 self.assertEqual(batch[1].dtype.name, 'float32') 152 153 def test_static_shape_in_graph(self): 154 if PIL is None: 155 return # Skip test if PIL is not available. 156 157 directory = self._prepare_directory(num_classes=2) 158 dataset = image_dataset.image_dataset_from_directory( 159 directory, batch_size=8, image_size=(18, 18), label_mode='int') 160 test_case = self 161 162 @def_function.function 163 def symbolic_fn(ds): 164 for x, _ in ds.take(1): 165 test_case.assertListEqual(x.shape.as_list(), [None, 18, 18, 3]) 166 167 symbolic_fn(dataset) 168 169 def test_sample_count(self): 170 if PIL is None: 171 return # Skip test if PIL is not available. 172 173 directory = self._prepare_directory(num_classes=4, count=15) 174 dataset = image_dataset.image_dataset_from_directory( 175 directory, batch_size=8, image_size=(18, 18), label_mode=None) 176 sample_count = 0 177 for batch in dataset: 178 sample_count += batch.shape[0] 179 self.assertEqual(sample_count, 15) 180 181 def test_image_dataset_from_directory_multiclass(self): 182 if PIL is None: 183 return # Skip test if PIL is not available. 184 185 directory = self._prepare_directory(num_classes=4, count=15) 186 187 dataset = image_dataset.image_dataset_from_directory( 188 directory, batch_size=8, image_size=(18, 18), label_mode=None) 189 batch = next(iter(dataset)) 190 self.assertEqual(batch.shape, (8, 18, 18, 3)) 191 192 dataset = image_dataset.image_dataset_from_directory( 193 directory, batch_size=8, image_size=(18, 18), label_mode=None) 194 sample_count = 0 195 iterator = iter(dataset) 196 for batch in dataset: 197 sample_count += next(iterator).shape[0] 198 self.assertEqual(sample_count, 15) 199 200 dataset = image_dataset.image_dataset_from_directory( 201 directory, batch_size=8, image_size=(18, 18), label_mode='int') 202 batch = next(iter(dataset)) 203 self.assertLen(batch, 2) 204 self.assertEqual(batch[0].shape, (8, 18, 18, 3)) 205 self.assertEqual(batch[0].dtype.name, 'float32') 206 self.assertEqual(batch[1].shape, (8,)) 207 self.assertEqual(batch[1].dtype.name, 'int32') 208 209 dataset = image_dataset.image_dataset_from_directory( 210 directory, batch_size=8, image_size=(18, 18), label_mode='categorical') 211 batch = next(iter(dataset)) 212 self.assertLen(batch, 2) 213 self.assertEqual(batch[0].shape, (8, 18, 18, 3)) 214 self.assertEqual(batch[0].dtype.name, 'float32') 215 self.assertEqual(batch[1].shape, (8, 4)) 216 self.assertEqual(batch[1].dtype.name, 'float32') 217 218 def test_image_dataset_from_directory_color_modes(self): 219 if PIL is None: 220 return # Skip test if PIL is not available. 221 222 directory = self._prepare_directory(num_classes=4, color_mode='rgba') 223 dataset = image_dataset.image_dataset_from_directory( 224 directory, batch_size=8, image_size=(18, 18), color_mode='rgba') 225 batch = next(iter(dataset)) 226 self.assertLen(batch, 2) 227 self.assertEqual(batch[0].shape, (8, 18, 18, 4)) 228 self.assertEqual(batch[0].dtype.name, 'float32') 229 230 directory = self._prepare_directory(num_classes=4, color_mode='grayscale') 231 dataset = image_dataset.image_dataset_from_directory( 232 directory, batch_size=8, image_size=(18, 18), color_mode='grayscale') 233 batch = next(iter(dataset)) 234 self.assertLen(batch, 2) 235 self.assertEqual(batch[0].shape, (8, 18, 18, 1)) 236 self.assertEqual(batch[0].dtype.name, 'float32') 237 238 def test_image_dataset_from_directory_validation_split(self): 239 if PIL is None: 240 return # Skip test if PIL is not available. 241 242 directory = self._prepare_directory(num_classes=2, count=10) 243 dataset = image_dataset.image_dataset_from_directory( 244 directory, batch_size=10, image_size=(18, 18), 245 validation_split=0.2, subset='training', seed=1337) 246 batch = next(iter(dataset)) 247 self.assertLen(batch, 2) 248 self.assertEqual(batch[0].shape, (8, 18, 18, 3)) 249 dataset = image_dataset.image_dataset_from_directory( 250 directory, batch_size=10, image_size=(18, 18), 251 validation_split=0.2, subset='validation', seed=1337) 252 batch = next(iter(dataset)) 253 self.assertLen(batch, 2) 254 self.assertEqual(batch[0].shape, (2, 18, 18, 3)) 255 256 def test_image_dataset_from_directory_manual_labels(self): 257 if PIL is None: 258 return # Skip test if PIL is not available. 259 260 directory = self._prepare_directory(num_classes=2, count=2) 261 dataset = image_dataset.image_dataset_from_directory( 262 directory, batch_size=8, image_size=(18, 18), 263 labels=[0, 1], shuffle=False) 264 batch = next(iter(dataset)) 265 self.assertLen(batch, 2) 266 self.assertAllClose(batch[1], [0, 1]) 267 268 def test_image_dataset_from_directory_follow_links(self): 269 if PIL is None: 270 return # Skip test if PIL is not available. 271 272 directory = self._prepare_directory(num_classes=2, count=25, 273 nested_dirs=True) 274 dataset = image_dataset.image_dataset_from_directory( 275 directory, batch_size=8, image_size=(18, 18), label_mode=None, 276 follow_links=True) 277 sample_count = 0 278 for batch in dataset: 279 sample_count += batch.shape[0] 280 self.assertEqual(sample_count, 25) 281 282 def test_image_dataset_from_directory_no_images(self): 283 directory = self._prepare_directory(num_classes=2, count=0) 284 with self.assertRaisesRegex(ValueError, 'No images found.'): 285 _ = image_dataset.image_dataset_from_directory(directory) 286 287 def test_image_dataset_from_directory_smart_resize(self): 288 if PIL is None: 289 return # Skip test if PIL is not available. 290 291 directory = self._prepare_directory(num_classes=2, count=5) 292 dataset = image_dataset.image_dataset_from_directory( 293 directory, batch_size=5, image_size=(18, 18), smart_resize=True) 294 batch = next(iter(dataset)) 295 self.assertLen(batch, 2) 296 self.assertEqual(batch[0].shape, (5, 18, 18, 3)) 297 298 def test_image_dataset_from_directory_errors(self): 299 if PIL is None: 300 return # Skip test if PIL is not available. 301 302 directory = self._prepare_directory(num_classes=3, count=5) 303 304 with self.assertRaisesRegex(ValueError, '`labels` argument should be'): 305 _ = image_dataset.image_dataset_from_directory( 306 directory, labels='other') 307 308 with self.assertRaisesRegex(ValueError, '`label_mode` argument must be'): 309 _ = image_dataset.image_dataset_from_directory( 310 directory, label_mode='other') 311 312 with self.assertRaisesRegex(ValueError, '`color_mode` must be one of'): 313 _ = image_dataset.image_dataset_from_directory( 314 directory, color_mode='other') 315 316 with self.assertRaisesRegex( 317 ValueError, 'only pass `class_names` if the labels are inferred'): 318 _ = image_dataset.image_dataset_from_directory( 319 directory, labels=[0, 0, 1, 1, 1], 320 class_names=['class_0', 'class_1', 'class_2']) 321 322 with self.assertRaisesRegex( 323 ValueError, 324 'Expected the lengths of `labels` to match the number of files'): 325 _ = image_dataset.image_dataset_from_directory( 326 directory, labels=[0, 0, 1, 1]) 327 328 with self.assertRaisesRegex( 329 ValueError, '`class_names` passed did not match'): 330 _ = image_dataset.image_dataset_from_directory( 331 directory, class_names=['class_0', 'class_2']) 332 333 with self.assertRaisesRegex(ValueError, 'there must exactly 2 classes'): 334 _ = image_dataset.image_dataset_from_directory( 335 directory, label_mode='binary') 336 337 with self.assertRaisesRegex(ValueError, 338 '`validation_split` must be between 0 and 1'): 339 _ = image_dataset.image_dataset_from_directory( 340 directory, validation_split=2) 341 342 with self.assertRaisesRegex(ValueError, 343 '`subset` must be either "training" or'): 344 _ = image_dataset.image_dataset_from_directory( 345 directory, validation_split=0.2, subset='other') 346 347 with self.assertRaisesRegex(ValueError, '`validation_split` must be set'): 348 _ = image_dataset.image_dataset_from_directory( 349 directory, validation_split=0, subset='training') 350 351 with self.assertRaisesRegex(ValueError, 'must provide a `seed`'): 352 _ = image_dataset.image_dataset_from_directory( 353 directory, validation_split=0.2, subset='training') 354 355 356if __name__ == '__main__': 357 v2_compat.enable_v2_behavior() 358 test.main() 359