1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for image preprocessing utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import shutil 23import tempfile 24 25from absl.testing import parameterized 26import numpy as np 27 28from tensorflow.python.data import Dataset 29from tensorflow.python.keras import keras_parameterized 30from tensorflow.python.keras import layers 31from tensorflow.python.keras import testing_utils 32from tensorflow.python.keras.engine import sequential 33from tensorflow.python.keras.preprocessing import image as preprocessing_image 34from tensorflow.python.platform import test 35 36try: 37 import PIL # pylint:disable=g-import-not-at-top 38except ImportError: 39 PIL = None 40 41 42def _generate_test_images(): 43 img_w = img_h = 20 44 rgb_images = [] 45 gray_images = [] 46 for _ in range(8): 47 bias = np.random.rand(img_w, img_h, 1) * 64 48 variance = np.random.rand(img_w, img_h, 1) * (255 - 64) 49 imarray = np.random.rand(img_w, img_h, 3) * variance + bias 50 im = preprocessing_image.array_to_img(imarray, scale=False) 51 rgb_images.append(im) 52 53 imarray = np.random.rand(img_w, img_h, 1) * variance + bias 54 im = preprocessing_image.array_to_img(imarray, scale=False) 55 gray_images.append(im) 56 57 return [rgb_images, gray_images] 58 59 60class TestImage(keras_parameterized.TestCase): 61 62 @testing_utils.run_v2_only 63 def test_smart_resize(self): 64 test_input = np.random.random((20, 40, 3)) 65 output = preprocessing_image.smart_resize(test_input, size=(50, 50)) 66 self.assertIsInstance(output, np.ndarray) 67 self.assertListEqual(list(output.shape), [50, 50, 3]) 68 output = preprocessing_image.smart_resize(test_input, size=(10, 10)) 69 self.assertListEqual(list(output.shape), [10, 10, 3]) 70 output = preprocessing_image.smart_resize(test_input, size=(100, 50)) 71 self.assertListEqual(list(output.shape), [100, 50, 3]) 72 output = preprocessing_image.smart_resize(test_input, size=(5, 15)) 73 self.assertListEqual(list(output.shape), [5, 15, 3]) 74 75 @parameterized.named_parameters( 76 ('size1', (50, 50)), 77 ('size2', (10, 10)), 78 ('size3', (100, 50)), 79 ('size4', (5, 15))) 80 @testing_utils.run_v2_only 81 def test_smart_resize_tf_dataset(self, size): 82 test_input_np = np.random.random((2, 20, 40, 3)) 83 test_ds = Dataset.from_tensor_slices(test_input_np) 84 85 resize = lambda img: preprocessing_image.smart_resize(img, size=size) 86 test_ds = test_ds.map(resize) 87 for sample in test_ds.as_numpy_iterator(): 88 self.assertIsInstance(sample, np.ndarray) 89 self.assertListEqual(list(sample.shape), [size[0], size[1], 3]) 90 91 def test_smart_resize_errors(self): 92 with self.assertRaisesRegex(ValueError, 'a tuple of 2 integers'): 93 preprocessing_image.smart_resize( 94 np.random.random((20, 20, 2)), size=(10, 5, 3)) 95 with self.assertRaisesRegex(ValueError, 'incorrect rank'): 96 preprocessing_image.smart_resize(np.random.random((20, 40)), size=(10, 5)) 97 98 def test_image_data_generator(self): 99 if PIL is None: 100 return # Skip test if PIL is not available. 101 102 for test_images in _generate_test_images(): 103 img_list = [] 104 for im in test_images: 105 img_list.append(preprocessing_image.img_to_array(im)[None, ...]) 106 107 images = np.vstack(img_list) 108 generator = preprocessing_image.ImageDataGenerator( 109 featurewise_center=True, 110 samplewise_center=True, 111 featurewise_std_normalization=True, 112 samplewise_std_normalization=True, 113 zca_whitening=True, 114 rotation_range=90., 115 width_shift_range=0.1, 116 height_shift_range=0.1, 117 shear_range=0.5, 118 zoom_range=0.2, 119 channel_shift_range=0., 120 brightness_range=(1, 5), 121 fill_mode='nearest', 122 cval=0.5, 123 horizontal_flip=True, 124 vertical_flip=True) 125 # Basic test before fit 126 x = np.random.random((32, 10, 10, 3)) 127 generator.flow(x) 128 129 # Fit 130 generator.fit(images, augment=True) 131 132 for x, _ in generator.flow( 133 images, 134 np.arange(images.shape[0]), 135 shuffle=True): 136 self.assertEqual(x.shape[1:], images.shape[1:]) 137 break 138 139 def test_image_data_generator_with_split_value_error(self): 140 with self.assertRaises(ValueError): 141 preprocessing_image.ImageDataGenerator(validation_split=5) 142 143 def test_image_data_generator_invalid_data(self): 144 generator = preprocessing_image.ImageDataGenerator( 145 featurewise_center=True, 146 samplewise_center=True, 147 featurewise_std_normalization=True, 148 samplewise_std_normalization=True, 149 zca_whitening=True, 150 data_format='channels_last') 151 152 # Test fit with invalid data 153 with self.assertRaises(ValueError): 154 x = np.random.random((3, 10, 10)) 155 generator.fit(x) 156 # Test flow with invalid data 157 with self.assertRaises(ValueError): 158 generator.flow(np.arange(5)) 159 # Invalid number of channels: will work but raise a warning 160 x = np.random.random((32, 10, 10, 5)) 161 generator.flow(x) 162 163 with self.assertRaises(ValueError): 164 generator = preprocessing_image.ImageDataGenerator( 165 data_format='unknown') 166 167 generator = preprocessing_image.ImageDataGenerator(zoom_range=(2., 2.)) 168 169 def test_image_data_generator_fit(self): 170 generator = preprocessing_image.ImageDataGenerator( 171 featurewise_center=True, 172 samplewise_center=True, 173 featurewise_std_normalization=True, 174 samplewise_std_normalization=True, 175 zca_whitening=True, 176 data_format='channels_last') 177 # Test grayscale 178 x = np.random.random((32, 10, 10, 1)) 179 generator.fit(x) 180 # Test RBG 181 x = np.random.random((32, 10, 10, 3)) 182 generator.fit(x) 183 generator = preprocessing_image.ImageDataGenerator( 184 featurewise_center=True, 185 samplewise_center=True, 186 featurewise_std_normalization=True, 187 samplewise_std_normalization=True, 188 zca_whitening=True, 189 data_format='channels_first') 190 # Test grayscale 191 x = np.random.random((32, 1, 10, 10)) 192 generator.fit(x) 193 # Test RBG 194 x = np.random.random((32, 3, 10, 10)) 195 generator.fit(x) 196 197 def test_directory_iterator(self): 198 if PIL is None: 199 return # Skip test if PIL is not available. 200 201 num_classes = 2 202 203 temp_dir = self.get_temp_dir() 204 self.addCleanup(shutil.rmtree, temp_dir) 205 206 # create folders and subfolders 207 paths = [] 208 for cl in range(num_classes): 209 class_directory = 'class-{}'.format(cl) 210 classpaths = [ 211 class_directory, os.path.join(class_directory, 'subfolder-1'), 212 os.path.join(class_directory, 'subfolder-2'), os.path.join( 213 class_directory, 'subfolder-1', 'sub-subfolder') 214 ] 215 for path in classpaths: 216 os.mkdir(os.path.join(temp_dir, path)) 217 paths.append(classpaths) 218 219 # save the images in the paths 220 count = 0 221 filenames = [] 222 for test_images in _generate_test_images(): 223 for im in test_images: 224 # rotate image class 225 im_class = count % num_classes 226 # rotate subfolders 227 classpaths = paths[im_class] 228 filename = os.path.join(classpaths[count % len(classpaths)], 229 'image-{}.jpg'.format(count)) 230 filenames.append(filename) 231 im.save(os.path.join(temp_dir, filename)) 232 count += 1 233 234 # Test image loading util 235 fname = os.path.join(temp_dir, filenames[0]) 236 _ = preprocessing_image.load_img(fname) 237 _ = preprocessing_image.load_img(fname, grayscale=True) 238 _ = preprocessing_image.load_img(fname, target_size=(10, 10)) 239 _ = preprocessing_image.load_img(fname, target_size=(10, 10), 240 interpolation='bilinear') 241 242 # create iterator 243 generator = preprocessing_image.ImageDataGenerator() 244 dir_iterator = generator.flow_from_directory(temp_dir) 245 246 # check number of classes and images 247 self.assertEqual(len(dir_iterator.class_indices), num_classes) 248 self.assertEqual(len(dir_iterator.classes), count) 249 self.assertEqual(set(dir_iterator.filenames), set(filenames)) 250 251 def preprocessing_function(x): 252 """This will fail if not provided by a Numpy array. 253 254 Note: This is made to enforce backward compatibility. 255 256 Args: 257 x: A numpy array. 258 259 Returns: 260 An array of zeros with the same shape as the given array. 261 """ 262 self.assertEqual(x.shape, (26, 26, 3)) 263 self.assertIs(type(x), np.ndarray) 264 return np.zeros_like(x) 265 266 # Test usage as Sequence 267 generator = preprocessing_image.ImageDataGenerator( 268 preprocessing_function=preprocessing_function) 269 dir_seq = generator.flow_from_directory( 270 str(temp_dir), 271 target_size=(26, 26), 272 color_mode='rgb', 273 batch_size=3, 274 class_mode='categorical') 275 self.assertEqual(len(dir_seq), count // 3 + 1) 276 x1, y1 = dir_seq[1] 277 self.assertEqual(x1.shape, (3, 26, 26, 3)) 278 self.assertEqual(y1.shape, (3, num_classes)) 279 x1, y1 = dir_seq[5] 280 self.assertTrue((x1 == 0).all()) 281 282 def directory_iterator_with_validation_split_test_helper( 283 self, validation_split): 284 if PIL is None: 285 return # Skip test if PIL is not available. 286 287 num_classes = 2 288 tmp_folder = tempfile.mkdtemp(prefix='test_images') 289 290 # create folders and subfolders 291 paths = [] 292 for cl in range(num_classes): 293 class_directory = 'class-{}'.format(cl) 294 classpaths = [ 295 class_directory, 296 os.path.join(class_directory, 'subfolder-1'), 297 os.path.join(class_directory, 'subfolder-2'), 298 os.path.join(class_directory, 'subfolder-1', 'sub-subfolder') 299 ] 300 for path in classpaths: 301 os.mkdir(os.path.join(tmp_folder, path)) 302 paths.append(classpaths) 303 304 # save the images in the paths 305 count = 0 306 filenames = [] 307 for test_images in _generate_test_images(): 308 for im in test_images: 309 # rotate image class 310 im_class = count % num_classes 311 # rotate subfolders 312 classpaths = paths[im_class] 313 filename = os.path.join(classpaths[count % len(classpaths)], 314 'image-{}.jpg'.format(count)) 315 filenames.append(filename) 316 im.save(os.path.join(tmp_folder, filename)) 317 count += 1 318 319 # create iterator 320 generator = preprocessing_image.ImageDataGenerator( 321 validation_split=validation_split) 322 323 with self.assertRaises(ValueError): 324 generator.flow_from_directory(tmp_folder, subset='foo') 325 326 num_validation = int(count * validation_split) 327 num_training = count - num_validation 328 train_iterator = generator.flow_from_directory( 329 tmp_folder, subset='training') 330 self.assertEqual(train_iterator.samples, num_training) 331 332 valid_iterator = generator.flow_from_directory( 333 tmp_folder, subset='validation') 334 self.assertEqual(valid_iterator.samples, num_validation) 335 336 # check number of classes and images 337 self.assertEqual(len(train_iterator.class_indices), num_classes) 338 self.assertEqual(len(train_iterator.classes), num_training) 339 self.assertEqual( 340 len(set(train_iterator.filenames) & set(filenames)), num_training) 341 342 model = sequential.Sequential([layers.Flatten(), layers.Dense(2)]) 343 model.compile(optimizer='sgd', loss='mse') 344 model.fit(train_iterator, epochs=1) 345 346 shutil.rmtree(tmp_folder) 347 348 @keras_parameterized.run_all_keras_modes 349 def test_directory_iterator_with_validation_split_25_percent(self): 350 self.directory_iterator_with_validation_split_test_helper(0.25) 351 352 @keras_parameterized.run_all_keras_modes 353 def test_directory_iterator_with_validation_split_40_percent(self): 354 self.directory_iterator_with_validation_split_test_helper(0.40) 355 356 @keras_parameterized.run_all_keras_modes 357 def test_directory_iterator_with_validation_split_50_percent(self): 358 self.directory_iterator_with_validation_split_test_helper(0.50) 359 360 def test_img_utils(self): 361 if PIL is None: 362 return # Skip test if PIL is not available. 363 364 height, width = 10, 8 365 366 # Test channels_first data format 367 x = np.random.random((3, height, width)) 368 img = preprocessing_image.array_to_img( 369 x, data_format='channels_first') 370 self.assertEqual(img.size, (width, height)) 371 x = preprocessing_image.img_to_array( 372 img, data_format='channels_first') 373 self.assertEqual(x.shape, (3, height, width)) 374 # Test 2D 375 x = np.random.random((1, height, width)) 376 img = preprocessing_image.array_to_img( 377 x, data_format='channels_first') 378 self.assertEqual(img.size, (width, height)) 379 x = preprocessing_image.img_to_array( 380 img, data_format='channels_first') 381 self.assertEqual(x.shape, (1, height, width)) 382 383 # Test channels_last data format 384 x = np.random.random((height, width, 3)) 385 img = preprocessing_image.array_to_img(x, data_format='channels_last') 386 self.assertEqual(img.size, (width, height)) 387 x = preprocessing_image.img_to_array(img, data_format='channels_last') 388 self.assertEqual(x.shape, (height, width, 3)) 389 # Test 2D 390 x = np.random.random((height, width, 1)) 391 img = preprocessing_image.array_to_img(x, data_format='channels_last') 392 self.assertEqual(img.size, (width, height)) 393 x = preprocessing_image.img_to_array(img, data_format='channels_last') 394 self.assertEqual(x.shape, (height, width, 1)) 395 396 def test_batch_standardize(self): 397 if PIL is None: 398 return # Skip test if PIL is not available. 399 400 # ImageDataGenerator.standardize should work on batches 401 for test_images in _generate_test_images(): 402 img_list = [] 403 for im in test_images: 404 img_list.append(preprocessing_image.img_to_array(im)[None, ...]) 405 406 images = np.vstack(img_list) 407 generator = preprocessing_image.ImageDataGenerator( 408 featurewise_center=True, 409 samplewise_center=True, 410 featurewise_std_normalization=True, 411 samplewise_std_normalization=True, 412 zca_whitening=True, 413 rotation_range=90., 414 width_shift_range=0.1, 415 height_shift_range=0.1, 416 shear_range=0.5, 417 zoom_range=0.2, 418 channel_shift_range=0., 419 brightness_range=(1, 5), 420 fill_mode='nearest', 421 cval=0.5, 422 horizontal_flip=True, 423 vertical_flip=True) 424 generator.fit(images, augment=True) 425 426 transformed = np.copy(images) 427 for i, im in enumerate(transformed): 428 transformed[i] = generator.random_transform(im) 429 transformed = generator.standardize(transformed) 430 431 def test_img_transforms(self): 432 x = np.random.random((3, 200, 200)) 433 _ = preprocessing_image.random_rotation(x, 20) 434 _ = preprocessing_image.random_shift(x, 0.2, 0.2) 435 _ = preprocessing_image.random_shear(x, 2.) 436 _ = preprocessing_image.random_zoom(x, (0.5, 0.5)) 437 _ = preprocessing_image.apply_channel_shift(x, 2, 2) 438 _ = preprocessing_image.apply_affine_transform(x, 2) 439 with self.assertRaises(ValueError): 440 preprocessing_image.random_zoom(x, (0, 0, 0)) 441 _ = preprocessing_image.random_channel_shift(x, 2.) 442 443 444if __name__ == '__main__': 445 test.main() 446