1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for image preprocessing utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import shutil
23import tempfile
24
25from absl.testing import parameterized
26import numpy as np
27
28from tensorflow.python.data import Dataset
29from tensorflow.python.keras import keras_parameterized
30from tensorflow.python.keras import layers
31from tensorflow.python.keras import testing_utils
32from tensorflow.python.keras.engine import sequential
33from tensorflow.python.keras.preprocessing import image as preprocessing_image
34from tensorflow.python.platform import test
35
36try:
37  import PIL  # pylint:disable=g-import-not-at-top
38except ImportError:
39  PIL = None
40
41
def _generate_test_images():
  """Builds two batches of random 20x20 test images.

  Returns:
    A two-element list `[rgb_images, gray_images]`. Each entry holds eight
    images produced by `preprocessing_image.array_to_img(..., scale=False)`,
    from 3-channel and 1-channel arrays respectively.
  """
  side = 20
  num_images = 8
  rgb_batch, gray_batch = [], []
  for _ in range(num_images):
    # One per-pixel offset/spread pair is drawn per iteration and shared by
    # the RGB and the grayscale image; values stay below 64 + (255 - 64).
    offset = np.random.rand(side, side, 1) * 64
    spread = np.random.rand(side, side, 1) * (255 - 64)

    rgb_pixels = np.random.rand(side, side, 3) * spread + offset
    rgb_batch.append(
        preprocessing_image.array_to_img(rgb_pixels, scale=False))

    gray_pixels = np.random.rand(side, side, 1) * spread + offset
    gray_batch.append(
        preprocessing_image.array_to_img(gray_pixels, scale=False))

  return [rgb_batch, gray_batch]
58
59
class TestImage(keras_parameterized.TestCase):
  """Tests for `preprocessing_image` helpers and `ImageDataGenerator`."""

  def _skip_if_no_pil(self):
    """Marks the running test as skipped when PIL could not be imported."""
    if PIL is None:
      self.skipTest('PIL is not available.')

  def _write_test_image_tree(self, root_dir, num_classes):
    """Creates nested class folders under `root_dir` and saves test images.

    For every class a `class-<i>` directory with `subfolder-1`, `subfolder-2`
    and `subfolder-1/sub-subfolder` is created. The images returned by
    `_generate_test_images()` are then written as `image-<n>.jpg`, rotating
    over the classes and over each class's folders.

    Args:
      root_dir: Existing directory in which the tree is created.
      num_classes: Number of `class-<i>` directories to create.

    Returns:
      List of the saved image paths, relative to `root_dir`, in save order.
    """
    # create folders and subfolders
    paths = []
    for cl in range(num_classes):
      class_directory = 'class-{}'.format(cl)
      classpaths = [
          class_directory,
          os.path.join(class_directory, 'subfolder-1'),
          os.path.join(class_directory, 'subfolder-2'),
          os.path.join(class_directory, 'subfolder-1', 'sub-subfolder')
      ]
      for path in classpaths:
        os.mkdir(os.path.join(root_dir, path))
      paths.append(classpaths)

    # save the images in the paths
    filenames = []
    for test_images in _generate_test_images():
      for im in test_images:
        count = len(filenames)
        # rotate image class
        classpaths = paths[count % num_classes]
        # rotate subfolders
        filename = os.path.join(classpaths[count % len(classpaths)],
                                'image-{}.jpg'.format(count))
        filenames.append(filename)
        im.save(os.path.join(root_dir, filename))
    return filenames

  @testing_utils.run_v2_only
  def test_smart_resize(self):
    """`smart_resize` returns a NumPy array with the requested height/width."""
    test_input = np.random.random((20, 40, 3))
    for size in [(50, 50), (10, 10), (100, 50), (5, 15)]:
      output = preprocessing_image.smart_resize(test_input, size=size)
      self.assertIsInstance(output, np.ndarray)
      self.assertListEqual(list(output.shape), [size[0], size[1], 3])

  @parameterized.named_parameters(
      ('size1', (50, 50)),
      ('size2', (10, 10)),
      ('size3', (100, 50)),
      ('size4', (5, 15)))
  @testing_utils.run_v2_only
  def test_smart_resize_tf_dataset(self, size):
    """`smart_resize` can be mapped over a `Dataset` of images."""
    test_input_np = np.random.random((2, 20, 40, 3))
    test_ds = Dataset.from_tensor_slices(test_input_np)
    test_ds = test_ds.map(
        lambda img: preprocessing_image.smart_resize(img, size=size))

    for sample in test_ds.as_numpy_iterator():
      self.assertIsInstance(sample, np.ndarray)
      self.assertListEqual(list(sample.shape), [size[0], size[1], 3])

  def test_smart_resize_errors(self):
    """A 3-element `size` or a rank-2 input raises `ValueError`."""
    with self.assertRaisesRegex(ValueError, 'a tuple of 2 integers'):
      preprocessing_image.smart_resize(
          np.random.random((20, 20, 2)), size=(10, 5, 3))
    with self.assertRaisesRegex(ValueError, 'incorrect rank'):
      preprocessing_image.smart_resize(np.random.random((20, 40)), size=(10, 5))

  def test_image_data_generator(self):
    """A fully configured generator fits and yields same-shaped images."""
    self._skip_if_no_pil()

    for test_images in _generate_test_images():
      images = np.vstack([
          preprocessing_image.img_to_array(im)[None, ...]
          for im in test_images
      ])
      generator = preprocessing_image.ImageDataGenerator(
          featurewise_center=True,
          samplewise_center=True,
          featurewise_std_normalization=True,
          samplewise_std_normalization=True,
          zca_whitening=True,
          rotation_range=90.,
          width_shift_range=0.1,
          height_shift_range=0.1,
          shear_range=0.5,
          zoom_range=0.2,
          channel_shift_range=0.,
          brightness_range=(1, 5),
          fill_mode='nearest',
          cval=0.5,
          horizontal_flip=True,
          vertical_flip=True)
      # Basic test before fit
      x = np.random.random((32, 10, 10, 3))
      generator.flow(x)

      # Fit
      generator.fit(images, augment=True)

      # Only the first batch is inspected.
      for x, _ in generator.flow(
          images,
          np.arange(images.shape[0]),
          shuffle=True):
        self.assertEqual(x.shape[1:], images.shape[1:])
        break

  def test_image_data_generator_with_split_value_error(self):
    """`validation_split=5` is rejected with `ValueError`."""
    with self.assertRaises(ValueError):
      preprocessing_image.ImageDataGenerator(validation_split=5)

  def test_image_data_generator_invalid_data(self):
    """Malformed inputs and an unknown data format raise `ValueError`."""
    generator = preprocessing_image.ImageDataGenerator(
        featurewise_center=True,
        samplewise_center=True,
        featurewise_std_normalization=True,
        samplewise_std_normalization=True,
        zca_whitening=True,
        data_format='channels_last')

    # Test fit with invalid data. Only the call expected to raise lives inside
    # the `assertRaises` block.
    x = np.random.random((3, 10, 10))
    with self.assertRaises(ValueError):
      generator.fit(x)
    # Test flow with invalid data
    with self.assertRaises(ValueError):
      generator.flow(np.arange(5))
    # Invalid number of channels: will work but raise a warning
    x = np.random.random((32, 10, 10, 5))
    generator.flow(x)

    with self.assertRaises(ValueError):
      preprocessing_image.ImageDataGenerator(data_format='unknown')

    # A two-element `zoom_range` must construct without raising.
    preprocessing_image.ImageDataGenerator(zoom_range=(2., 2.))

  def test_image_data_generator_fit(self):
    """`fit` accepts grayscale and RGB batches in both data formats."""
    # For each data format: (grayscale batch shape, RGB batch shape).
    shapes_by_format = [
        ('channels_last', [(32, 10, 10, 1), (32, 10, 10, 3)]),
        ('channels_first', [(32, 1, 10, 10), (32, 3, 10, 10)]),
    ]
    for data_format, batch_shapes in shapes_by_format:
      generator = preprocessing_image.ImageDataGenerator(
          featurewise_center=True,
          samplewise_center=True,
          featurewise_std_normalization=True,
          samplewise_std_normalization=True,
          zca_whitening=True,
          data_format=data_format)
      # The same generator is fit on grayscale first, then on RGB.
      for batch_shape in batch_shapes:
        generator.fit(np.random.random(batch_shape))

  def test_directory_iterator(self):
    """`flow_from_directory` finds all images and works as a Sequence."""
    self._skip_if_no_pil()

    num_classes = 2

    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir)

    filenames = self._write_test_image_tree(temp_dir, num_classes)
    count = len(filenames)

    # Test image loading util
    fname = os.path.join(temp_dir, filenames[0])
    _ = preprocessing_image.load_img(fname)
    _ = preprocessing_image.load_img(fname, grayscale=True)
    _ = preprocessing_image.load_img(fname, target_size=(10, 10))
    _ = preprocessing_image.load_img(fname, target_size=(10, 10),
                                     interpolation='bilinear')

    # create iterator
    generator = preprocessing_image.ImageDataGenerator()
    dir_iterator = generator.flow_from_directory(temp_dir)

    # check number of classes and images
    self.assertEqual(len(dir_iterator.class_indices), num_classes)
    self.assertEqual(len(dir_iterator.classes), count)
    self.assertEqual(set(dir_iterator.filenames), set(filenames))

    def preprocessing_function(x):
      """This will fail if not provided by a Numpy array.

      Note: This is made to enforce backward compatibility.

      Args:
          x: A numpy array.

      Returns:
          An array of zeros with the same shape as the given array.
      """
      self.assertEqual(x.shape, (26, 26, 3))
      self.assertIs(type(x), np.ndarray)
      return np.zeros_like(x)

    # Test usage as Sequence
    batch_size = 3
    generator = preprocessing_image.ImageDataGenerator(
        preprocessing_function=preprocessing_function)
    dir_seq = generator.flow_from_directory(
        str(temp_dir),
        target_size=(26, 26),
        color_mode='rgb',
        batch_size=batch_size,
        class_mode='categorical')
    # Ceiling division: a trailing partial batch counts as one batch.
    num_batches = (count + batch_size - 1) // batch_size
    self.assertEqual(len(dir_seq), num_batches)
    x1, y1 = dir_seq[1]
    self.assertEqual(x1.shape, (batch_size, 26, 26, 3))
    self.assertEqual(y1.shape, (batch_size, num_classes))
    # The last batch must consist solely of `preprocessing_function` zeros.
    x_last, _ = dir_seq[num_batches - 1]
    self.assertTrue((x_last == 0).all())

  def directory_iterator_with_validation_split_test_helper(
      self, validation_split):
    """Checks training/validation subsets of `flow_from_directory`.

    Args:
      validation_split: Value passed to `ImageDataGenerator` as
        `validation_split`; `int(count * validation_split)` images are
        expected in the validation subset and the rest in the training one.
    """
    self._skip_if_no_pil()

    num_classes = 2
    tmp_folder = tempfile.mkdtemp(prefix='test_images')
    # Registered immediately so the directory is removed even when an
    # assertion or `model.fit` below fails.
    self.addCleanup(shutil.rmtree, tmp_folder)

    filenames = self._write_test_image_tree(tmp_folder, num_classes)
    count = len(filenames)

    # create iterator
    generator = preprocessing_image.ImageDataGenerator(
        validation_split=validation_split)

    with self.assertRaises(ValueError):
      generator.flow_from_directory(tmp_folder, subset='foo')

    num_validation = int(count * validation_split)
    num_training = count - num_validation
    train_iterator = generator.flow_from_directory(
        tmp_folder, subset='training')
    self.assertEqual(train_iterator.samples, num_training)

    valid_iterator = generator.flow_from_directory(
        tmp_folder, subset='validation')
    self.assertEqual(valid_iterator.samples, num_validation)

    # check number of classes and images
    self.assertEqual(len(train_iterator.class_indices), num_classes)
    self.assertEqual(len(train_iterator.classes), num_training)
    self.assertEqual(
        len(set(train_iterator.filenames) & set(filenames)), num_training)

    # The training iterator must be usable directly as `model.fit` input.
    model = sequential.Sequential([layers.Flatten(), layers.Dense(2)])
    model.compile(optimizer='sgd', loss='mse')
    model.fit(train_iterator, epochs=1)

  @keras_parameterized.run_all_keras_modes
  def test_directory_iterator_with_validation_split_25_percent(self):
    """Runs the validation-split helper with a 25% split."""
    self.directory_iterator_with_validation_split_test_helper(0.25)

  @keras_parameterized.run_all_keras_modes
  def test_directory_iterator_with_validation_split_40_percent(self):
    """Runs the validation-split helper with a 40% split."""
    self.directory_iterator_with_validation_split_test_helper(0.40)

  @keras_parameterized.run_all_keras_modes
  def test_directory_iterator_with_validation_split_50_percent(self):
    """Runs the validation-split helper with a 50% split."""
    self.directory_iterator_with_validation_split_test_helper(0.50)

  def test_img_utils(self):
    """`array_to_img` / `img_to_array` round-trip keeps the array shape."""
    self._skip_if_no_pil()

    height, width = 10, 8

    # 3-channel and single-channel ("2D") inputs for each data format.
    cases = [
        ((3, height, width), 'channels_first'),
        ((1, height, width), 'channels_first'),
        ((height, width, 3), 'channels_last'),
        ((height, width, 1), 'channels_last'),
    ]
    for shape, data_format in cases:
      x = np.random.random(shape)
      img = preprocessing_image.array_to_img(x, data_format=data_format)
      # The image reports its size as (width, height).
      self.assertEqual(img.size, (width, height))
      x = preprocessing_image.img_to_array(img, data_format=data_format)
      self.assertEqual(x.shape, shape)

  def test_batch_standardize(self):
    """`ImageDataGenerator.standardize` works on a whole batch."""
    self._skip_if_no_pil()

    # ImageDataGenerator.standardize should work on batches
    for test_images in _generate_test_images():
      images = np.vstack([
          preprocessing_image.img_to_array(im)[None, ...]
          for im in test_images
      ])
      generator = preprocessing_image.ImageDataGenerator(
          featurewise_center=True,
          samplewise_center=True,
          featurewise_std_normalization=True,
          samplewise_std_normalization=True,
          zca_whitening=True,
          rotation_range=90.,
          width_shift_range=0.1,
          height_shift_range=0.1,
          shear_range=0.5,
          zoom_range=0.2,
          channel_shift_range=0.,
          brightness_range=(1, 5),
          fill_mode='nearest',
          cval=0.5,
          horizontal_flip=True,
          vertical_flip=True)
      generator.fit(images, augment=True)

      transformed = np.copy(images)
      for i, im in enumerate(transformed):
        transformed[i] = generator.random_transform(im)
      standardized = generator.standardize(transformed)
      # Standardizing a batch is expected to leave its shape untouched.
      self.assertEqual(standardized.shape, images.shape)

  def test_img_transforms(self):
    """Smoke-tests the standalone random/affine transform functions."""
    x = np.random.random((3, 200, 200))
    _ = preprocessing_image.random_rotation(x, 20)
    _ = preprocessing_image.random_shift(x, 0.2, 0.2)
    _ = preprocessing_image.random_shear(x, 2.)
    _ = preprocessing_image.random_zoom(x, (0.5, 0.5))
    _ = preprocessing_image.apply_channel_shift(x, 2, 2)
    _ = preprocessing_image.apply_affine_transform(x, 2)
    # A three-element zoom range is rejected.
    with self.assertRaises(ValueError):
      preprocessing_image.random_zoom(x, (0, 0, 0))
    _ = preprocessing_image.random_channel_shift(x, 2.)
442
443
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()
446