1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for image preprocessing utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import shutil
23import tempfile
24
25import numpy as np
26
27from tensorflow.python import keras
28from tensorflow.python.platform import test
29
30try:
31  import PIL  # pylint:disable=g-import-not-at-top
32except ImportError:
33  PIL = None
34
35
def _generate_test_images(num_images=8, img_w=20, img_h=20):
  """Generates random RGB and grayscale images for the tests in this module.

  Each pixel is drawn around a random per-pixel bias with a random per-pixel
  spread, so the images are not plain uniform noise. All values lie in
  [0, 255) and are converted with `scale=False`, i.e. they are used directly
  as pixel intensities.

  Args:
    num_images: Number of images of each kind (RGB and grayscale) to create.
    img_w: Image width in pixels.
    img_h: Image height in pixels.

  Returns:
    A two-element list `[rgb_images, gray_images]`; each entry is a list of
    `num_images` images produced by `keras.preprocessing.image.array_to_img`.
  """
  rgb_images = []
  gray_images = []
  for _ in range(num_images):
    # Image arrays are laid out as (height, width, channels); using
    # (width, height) here would silently transpose non-square images.
    bias = np.random.rand(img_h, img_w, 1) * 64
    variance = np.random.rand(img_h, img_w, 1) * (255 - 64)

    rgb_array = np.random.rand(img_h, img_w, 3) * variance + bias
    rgb_images.append(
        keras.preprocessing.image.array_to_img(rgb_array, scale=False))

    # A single channel yields a grayscale image; it shares bias/variance with
    # the RGB image of the same iteration.
    gray_array = np.random.rand(img_h, img_w, 1) * variance + bias
    gray_images.append(
        keras.preprocessing.image.array_to_img(gray_array, scale=False))

  return [rgb_images, gray_images]
52
53
class TestImage(test.TestCase):
  """Tests for the `keras.preprocessing.image` utilities and generators."""

  def _populate_image_directory(self, root_dir, num_classes):
    """Writes the generated test images into a class/subfolder tree.

    Creates one `class-{i}` directory per class under `root_dir`, each with
    nested subfolders. It then saves every image from `_generate_test_images()`
    as a JPEG, assigning images to classes and subfolders round-robin.

    Args:
      root_dir: Existing directory in which the tree is created.
      num_classes: Number of `class-{i}` directories to create.

    Returns:
      The list of saved image paths, relative to `root_dir`, in save order.
    """
    paths = []
    for cl in range(num_classes):
      class_directory = 'class-{}'.format(cl)
      # Parents are listed before children so os.mkdir succeeds in order.
      classpaths = [
          class_directory,
          os.path.join(class_directory, 'subfolder-1'),
          os.path.join(class_directory, 'subfolder-2'),
          os.path.join(class_directory, 'subfolder-1', 'sub-subfolder'),
      ]
      for path in classpaths:
        os.mkdir(os.path.join(root_dir, path))
      paths.append(classpaths)

    filenames = []
    for test_images in _generate_test_images():
      for im in test_images:
        count = len(filenames)
        # Rotate through classes, then through that class's subfolders.
        classpaths = paths[count % num_classes]
        filename = os.path.join(classpaths[count % len(classpaths)],
                                'image-{}.jpg'.format(count))
        filenames.append(filename)
        im.save(os.path.join(root_dir, filename))
    return filenames

  def test_image_data_generator(self):
    if PIL is None:
      self.skipTest('PIL is required for this test.')

    for test_images in _generate_test_images():
      images = np.vstack([
          keras.preprocessing.image.img_to_array(im)[None, ...]
          for im in test_images
      ])
      generator = keras.preprocessing.image.ImageDataGenerator(
          featurewise_center=True,
          samplewise_center=True,
          featurewise_std_normalization=True,
          samplewise_std_normalization=True,
          zca_whitening=True,
          rotation_range=90.,
          width_shift_range=0.1,
          height_shift_range=0.1,
          shear_range=0.5,
          zoom_range=0.2,
          channel_shift_range=0.,
          brightness_range=(1, 5),
          fill_mode='nearest',
          cval=0.5,
          horizontal_flip=True,
          vertical_flip=True)
      # Basic test before fit
      x = np.random.random((32, 10, 10, 3))
      generator.flow(x)

      # Fit
      generator.fit(images, augment=True)

      # Only the first batch is needed to check the per-sample output shape.
      for batch_x, _ in generator.flow(
          images,
          np.arange(images.shape[0]),
          shuffle=True):
        self.assertEqual(batch_x.shape[1:], images.shape[1:])
        break

  def test_image_data_generator_with_split_value_error(self):
    with self.assertRaises(ValueError):
      keras.preprocessing.image.ImageDataGenerator(validation_split=5)

  def test_image_data_generator_invalid_data(self):
    generator = keras.preprocessing.image.ImageDataGenerator(
        featurewise_center=True,
        samplewise_center=True,
        featurewise_std_normalization=True,
        samplewise_std_normalization=True,
        zca_whitening=True,
        data_format='channels_last')

    # Test fit with invalid data: a 3-D array, whereas the valid fit inputs
    # in this file are 4-D batches.
    x = np.random.random((3, 10, 10))
    with self.assertRaises(ValueError):
      generator.fit(x)
    # Test flow with invalid data
    with self.assertRaises(ValueError):
      generator.flow(np.arange(5))
    # Invalid number of channels: will work but raise a warning
    x = np.random.random((32, 10, 10, 5))
    generator.flow(x)

    with self.assertRaises(ValueError):
      keras.preprocessing.image.ImageDataGenerator(data_format='unknown')

    # A (min, max) tuple for `zoom_range` must be accepted without raising.
    keras.preprocessing.image.ImageDataGenerator(zoom_range=(2, 2))

  def test_image_data_generator_fit(self):
    generator = keras.preprocessing.image.ImageDataGenerator(
        featurewise_center=True,
        samplewise_center=True,
        featurewise_std_normalization=True,
        samplewise_std_normalization=True,
        zca_whitening=True,
        data_format='channels_last')
    # Test grayscale
    x = np.random.random((32, 10, 10, 1))
    generator.fit(x)
    # Test RGB
    x = np.random.random((32, 10, 10, 3))
    generator.fit(x)
    generator = keras.preprocessing.image.ImageDataGenerator(
        featurewise_center=True,
        samplewise_center=True,
        featurewise_std_normalization=True,
        samplewise_std_normalization=True,
        zca_whitening=True,
        data_format='channels_first')
    # Test grayscale
    x = np.random.random((32, 1, 10, 10))
    generator.fit(x)
    # Test RGB
    x = np.random.random((32, 3, 10, 10))
    generator.fit(x)

  def test_directory_iterator(self):
    if PIL is None:
      self.skipTest('PIL is required for this test.')

    num_classes = 2

    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir)

    # create folders and subfolders, and save the images in them
    filenames = self._populate_image_directory(temp_dir, num_classes)
    count = len(filenames)

    # Test image loading util
    fname = os.path.join(temp_dir, filenames[0])
    _ = keras.preprocessing.image.load_img(fname)
    _ = keras.preprocessing.image.load_img(fname, grayscale=True)
    _ = keras.preprocessing.image.load_img(fname, target_size=(10, 10))
    _ = keras.preprocessing.image.load_img(fname, target_size=(10, 10),
                                           interpolation='bilinear')

    # create iterator
    generator = keras.preprocessing.image.ImageDataGenerator()
    dir_iterator = generator.flow_from_directory(temp_dir)

    # check number of classes and images
    self.assertEqual(len(dir_iterator.class_indices), num_classes)
    self.assertEqual(len(dir_iterator.classes), count)
    self.assertEqual(set(dir_iterator.filenames), set(filenames))

    def preprocessing_function(x):
      """Fails unless given a NumPy array of the expected shape.

      Note: This is made to enforce backward compatibility.

      Args:
          x: A numpy array.

      Returns:
          An array of zeros with the same shape as the given array.
      """
      self.assertEqual(x.shape, (26, 26, 3))
      self.assertIs(type(x), np.ndarray)
      return np.zeros_like(x)

    # Test usage as Sequence
    batch_size = 3
    generator = keras.preprocessing.image.ImageDataGenerator(
        preprocessing_function=preprocessing_function)
    dir_seq = generator.flow_from_directory(
        str(temp_dir),
        target_size=(26, 26),
        color_mode='rgb',
        batch_size=batch_size,
        class_mode='categorical')
    # The final batch may be partial, so the length is ceil(count / batch).
    num_batches = (count + batch_size - 1) // batch_size
    self.assertEqual(len(dir_seq), num_batches)
    x1, y1 = dir_seq[1]
    self.assertEqual(x1.shape, (batch_size, 26, 26, 3))
    self.assertEqual(y1.shape, (batch_size, num_classes))
    # preprocessing_function zeroes every image, including the last batch.
    x1, _ = dir_seq[num_batches - 1]
    self.assertTrue((x1 == 0).all())

  def directory_iterator_with_validation_split_test_helper(
      self, validation_split):
    """Checks the training/validation subsets of `flow_from_directory`.

    Args:
      validation_split: Fraction passed as
        `ImageDataGenerator(validation_split=...)`.
    """
    if PIL is None:
      self.skipTest('PIL is required for this test.')

    num_classes = 2
    tmp_folder = tempfile.mkdtemp(prefix='test_images')
    # Register cleanup immediately so the folder is removed even when an
    # assertion below fails.
    self.addCleanup(shutil.rmtree, tmp_folder)

    # create folders and subfolders, and save the images in them
    filenames = self._populate_image_directory(tmp_folder, num_classes)
    count = len(filenames)

    # create iterator
    generator = keras.preprocessing.image.ImageDataGenerator(
        validation_split=validation_split)

    with self.assertRaises(ValueError):
      generator.flow_from_directory(tmp_folder, subset='foo')

    num_validation = int(count * validation_split)
    num_training = count - num_validation
    train_iterator = generator.flow_from_directory(
        tmp_folder, subset='training')
    self.assertEqual(train_iterator.samples, num_training)

    valid_iterator = generator.flow_from_directory(
        tmp_folder, subset='validation')
    self.assertEqual(valid_iterator.samples, num_validation)

    # check number of classes and images
    self.assertEqual(len(train_iterator.class_indices), num_classes)
    self.assertEqual(len(train_iterator.classes), num_training)
    self.assertEqual(
        len(set(train_iterator.filenames) & set(filenames)), num_training)

  def test_directory_iterator_with_validation_split_25_percent(self):
    self.directory_iterator_with_validation_split_test_helper(0.25)

  def test_directory_iterator_with_validation_split_40_percent(self):
    self.directory_iterator_with_validation_split_test_helper(0.40)

  def test_directory_iterator_with_validation_split_50_percent(self):
    self.directory_iterator_with_validation_split_test_helper(0.50)

  def test_img_utils(self):
    if PIL is None:
      self.skipTest('PIL is required for this test.')

    height, width = 10, 8

    # Test channels_first data format
    x = np.random.random((3, height, width))
    img = keras.preprocessing.image.array_to_img(
        x, data_format='channels_first')
    self.assertEqual(img.size, (width, height))
    x = keras.preprocessing.image.img_to_array(
        img, data_format='channels_first')
    self.assertEqual(x.shape, (3, height, width))
    # Test 2D
    x = np.random.random((1, height, width))
    img = keras.preprocessing.image.array_to_img(
        x, data_format='channels_first')
    self.assertEqual(img.size, (width, height))
    x = keras.preprocessing.image.img_to_array(
        img, data_format='channels_first')
    self.assertEqual(x.shape, (1, height, width))

    # Test channels_last data format
    x = np.random.random((height, width, 3))
    img = keras.preprocessing.image.array_to_img(x, data_format='channels_last')
    self.assertEqual(img.size, (width, height))
    x = keras.preprocessing.image.img_to_array(img, data_format='channels_last')
    self.assertEqual(x.shape, (height, width, 3))
    # Test 2D
    x = np.random.random((height, width, 1))
    img = keras.preprocessing.image.array_to_img(x, data_format='channels_last')
    self.assertEqual(img.size, (width, height))
    x = keras.preprocessing.image.img_to_array(img, data_format='channels_last')
    self.assertEqual(x.shape, (height, width, 1))

  def test_batch_standardize(self):
    if PIL is None:
      self.skipTest('PIL is required for this test.')

    # ImageDataGenerator.standardize should work on batches
    for test_images in _generate_test_images():
      images = np.vstack([
          keras.preprocessing.image.img_to_array(im)[None, ...]
          for im in test_images
      ])
      generator = keras.preprocessing.image.ImageDataGenerator(
          featurewise_center=True,
          samplewise_center=True,
          featurewise_std_normalization=True,
          samplewise_std_normalization=True,
          zca_whitening=True,
          rotation_range=90.,
          width_shift_range=0.1,
          height_shift_range=0.1,
          shear_range=0.5,
          zoom_range=0.2,
          channel_shift_range=0.,
          brightness_range=(1, 5),
          fill_mode='nearest',
          cval=0.5,
          horizontal_flip=True,
          vertical_flip=True)
      generator.fit(images, augment=True)

      transformed = np.copy(images)
      for i, im in enumerate(transformed):
        transformed[i] = generator.random_transform(im)
      transformed = generator.standardize(transformed)
      # Batch standardization must not change the batch's shape.
      self.assertEqual(transformed.shape, images.shape)

  def test_img_transforms(self):
    x = np.random.random((3, 200, 200))
    _ = keras.preprocessing.image.random_rotation(x, 20)
    _ = keras.preprocessing.image.random_shift(x, 0.2, 0.2)
    _ = keras.preprocessing.image.random_shear(x, 2.)
    _ = keras.preprocessing.image.random_zoom(x, (0.5, 0.5))
    _ = keras.preprocessing.image.apply_channel_shift(x, 2, 2)
    _ = keras.preprocessing.image.apply_affine_transform(x, 2)
    # random_zoom expects a two-element zoom range.
    with self.assertRaises(ValueError):
      keras.preprocessing.image.random_zoom(x, (0, 0, 0))
    _ = keras.preprocessing.image.random_channel_shift(x, 2.)
394
395
if __name__ == '__main__':
  # Run the test cases in this module through TensorFlow's test runner.
  test.main()
398