# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-import-not-at-top
"""Set of tools for real-time data augmentation on image data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras_preprocessing import image
try:
  from scipy import linalg  # pylint: disable=unused-import
  from scipy import ndimage  # pylint: disable=unused-import
except ImportError:
  pass

from tensorflow.python.keras import backend
from tensorflow.python.keras import utils
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export

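# The symbols below are thin aliases for their `keras_preprocessing.image`
# implementations; they are re-exported under the `tf.keras.preprocessing.image`
# namespace via the `keras_export` calls at the bottom of this module.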
random_rotation = image.random_rotation
random_shift = image.random_shift
random_shear = image.random_shear
random_zoom = image.random_zoom
apply_channel_shift = image.apply_channel_shift
random_channel_shift = image.random_channel_shift
apply_brightness_shift = image.apply_brightness_shift
random_brightness = image.random_brightness
apply_affine_transform = image.apply_affine_transform
load_img = image.load_img


@keras_export('keras.preprocessing.image.array_to_img')
def array_to_img(x, data_format=None, scale=True, dtype=None):
  """Converts a 3D Numpy array to a PIL Image instance.

  Arguments:
      x: Input Numpy array.
      data_format: Image data format,
          either "channels_first" or "channels_last".
      scale: Whether to rescale image values
          to be within `[0, 255]`.
      dtype: Dtype to use.

  Returns:
      A PIL Image instance.

  Raises:
      ImportError: if PIL is not available.
      ValueError: if invalid `x` or `data_format` is passed.
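
  Example:

  A minimal illustrative sketch (assumes NumPy and PIL are installed; the
  input values are random placeholders):

  ```python
  import numpy as np

  # Random 64x64 RGB data in "channels_last" layout.
  x = 255 * np.random.rand(64, 64, 3)
  img = array_to_img(x)  # PIL Image instance, values rescaled to [0, 255]
  ```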
  """

  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.array_to_img(x, data_format=data_format, scale=scale, **kwargs)


@keras_export('keras.preprocessing.image.img_to_array')
def img_to_array(img, data_format=None, dtype=None):
  """Converts a PIL Image instance to a Numpy array.

  Arguments:
      img: PIL Image instance.
      data_format: Image data format,
          either "channels_first" or "channels_last".
      dtype: Dtype to use for the returned array.

  Returns:
      A 3D Numpy array.

  Raises:
      ValueError: if invalid `img` or `data_format` is passed.
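
  Example:

  A minimal illustrative sketch (assumes PIL is installed; the blank image is
  a placeholder):

  ```python
  from PIL import Image

  img = Image.new('RGB', (150, 150))  # a blank placeholder image
  x = img_to_array(img)               # Numpy array of shape (150, 150, 3)
  ```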
  """

  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.img_to_array)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.img_to_array(img, data_format=data_format, **kwargs)


@keras_export('keras.preprocessing.image.save_img')
def save_img(path,
             x,
             data_format=None,
             file_format=None,
             scale=True,
             **kwargs):
  """Saves an image stored as a Numpy array to a path or file object.

  Arguments:
      path: Path or file object.
      x: Numpy array.
      data_format: Image data format,
          either "channels_first" or "channels_last".
      file_format: Optional file format override. If omitted, the
          format to use is determined from the filename extension.
          If a file object was used instead of a filename, this
          parameter should always be used.
      scale: Whether to rescale image values to be within `[0, 255]`.
      **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
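
  Example:

  A minimal illustrative sketch (the output path and random data are
  placeholders):

  ```python
  import numpy as np

  x = 255 * np.random.rand(64, 64, 3)
  # The ".png" extension determines the file format here.
  save_img('/tmp/random.png', x)
  ```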
  """
  if data_format is None:
    data_format = backend.image_data_format()
  image.save_img(path,
                 x,
                 data_format=data_format,
                 file_format=file_format,
                 scale=scale, **kwargs)


@keras_export('keras.preprocessing.image.Iterator')
class Iterator(image.Iterator, utils.Sequence):
  pass


@keras_export('keras.preprocessing.image.DirectoryIterator')
class DirectoryIterator(image.DirectoryIterator, Iterator):
  """Iterator capable of reading images from a directory on disk.

  Arguments:
      directory: Path to the directory to read images from.
          Each subdirectory in this directory will be
          considered to contain images from one class,
          or alternatively you could specify class subdirectories
          via the `classes` argument.
      image_data_generator: Instance of `ImageDataGenerator`
          to use for random transformations and normalization.
      target_size: Tuple of integers, dimensions to resize input images to.
      color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
          Color mode to read images.
      classes: Optional list of strings, names of subdirectories
          containing images from each class (e.g. `["dogs", "cats"]`).
          It will be computed automatically if not set.
      class_mode: Mode for yielding the targets:
          `"binary"`: binary targets (if there are only two classes),
          `"categorical"`: categorical targets,
          `"sparse"`: integer targets,
          `"input"`: targets are images identical to input images (mainly
              used to work with autoencoders),
          `None`: no targets get yielded (only input images are yielded).
      batch_size: Integer, size of a batch.
      shuffle: Boolean, whether to shuffle the data between epochs.
      seed: Random seed for data shuffling.
      data_format: String, one of `channels_first`, `channels_last`.
      save_to_dir: Optional directory where to save the pictures
          being yielded, in a viewable format. This is useful
          for visualizing the random transformations being
          applied, for debugging purposes.
      save_prefix: String prefix to use for saving sample
          images (if `save_to_dir` is set).
      save_format: Format to use for saving sample images
          (if `save_to_dir` is set).
      subset: Subset of data (`"training"` or `"validation"`) if
          `validation_split` is set in `ImageDataGenerator`.
      interpolation: Interpolation method used to resample the image if the
          target size is different from that of the loaded image.
          Supported methods are "nearest", "bilinear", and "bicubic".
          If PIL version 1.1.3 or newer is installed, "lanczos" is also
          supported. If PIL version 3.4.0 or newer is installed, "box" and
          "hamming" are also supported. By default, "nearest" is used.
      dtype: Dtype to use for generated arrays.
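
  Example:

  `DirectoryIterator` instances are typically obtained via
  `ImageDataGenerator.flow_from_directory` rather than constructed directly.
  A minimal sketch (the `'data/train'` path is hypothetical and should
  contain one subdirectory per class):

  ```python
  datagen = ImageDataGenerator(rescale=1./255)
  iterator = datagen.flow_from_directory('data/train', target_size=(150, 150))
  x_batch, y_batch = next(iterator)  # x_batch: (batch_size, 150, 150, 3)
  ```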
  """

  def __init__(self, directory, image_data_generator,
               target_size=(256, 256),
               color_mode='rgb',
               classes=None,
               class_mode='categorical',
               batch_size=32,
               shuffle=True,
               seed=None,
               data_format=None,
               save_to_dir=None,
               save_prefix='',
               save_format='png',
               follow_links=False,
               subset=None,
               interpolation='nearest',
               dtype=None):
    if data_format is None:
      data_format = backend.image_data_format()
    kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(
        image.ImageDataGenerator.__init__)[0]:
      if dtype is None:
        dtype = backend.floatx()
      kwargs['dtype'] = dtype
    super(DirectoryIterator, self).__init__(
        directory, image_data_generator,
        target_size=target_size,
        color_mode=color_mode,
        classes=classes,
        class_mode=class_mode,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        data_format=data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        follow_links=follow_links,
        subset=subset,
        interpolation=interpolation,
        **kwargs)


@keras_export('keras.preprocessing.image.NumpyArrayIterator')
class NumpyArrayIterator(image.NumpyArrayIterator, Iterator):
  """Iterator yielding data from a Numpy array.

  Arguments:
      x: Numpy array of input data or tuple.
          If a tuple, the second element is either
          another Numpy array or a list of Numpy arrays,
          each of which gets passed
          through as an output without any modifications.
      y: Numpy array of targets data.
      image_data_generator: Instance of `ImageDataGenerator`
          to use for random transformations and normalization.
      batch_size: Integer, size of a batch.
      shuffle: Boolean, whether to shuffle the data between epochs.
      sample_weight: Numpy array of sample weights.
      seed: Random seed for data shuffling.
      data_format: String, one of `channels_first`, `channels_last`.
      save_to_dir: Optional directory where to save the pictures
          being yielded, in a viewable format. This is useful
          for visualizing the random transformations being
          applied, for debugging purposes.
      save_prefix: String prefix to use for saving sample
          images (if `save_to_dir` is set).
      save_format: Format to use for saving sample images
          (if `save_to_dir` is set).
      subset: Subset of data (`"training"` or `"validation"`) if
          `validation_split` is set in `ImageDataGenerator`.
      dtype: Dtype to use for the generated arrays.
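
  Example:

  `NumpyArrayIterator` instances are typically obtained via
  `ImageDataGenerator.flow` rather than constructed directly. A minimal
  sketch using random placeholder data:

  ```python
  import numpy as np

  x = np.random.rand(10, 64, 64, 3)
  y = np.arange(10)
  datagen = ImageDataGenerator(horizontal_flip=True)
  iterator = datagen.flow(x, y, batch_size=5)  # a NumpyArrayIterator
  x_batch, y_batch = next(iterator)            # x_batch: (5, 64, 64, 3)
  ```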
  """

  def __init__(self, x, y, image_data_generator,
               batch_size=32,
               shuffle=False,
               sample_weight=None,
               seed=None,
               data_format=None,
               save_to_dir=None,
               save_prefix='',
               save_format='png',
               subset=None,
               dtype=None):
    if data_format is None:
      data_format = backend.image_data_format()
    kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(
        image.NumpyArrayIterator.__init__)[0]:
      if dtype is None:
        dtype = backend.floatx()
      kwargs['dtype'] = dtype
    super(NumpyArrayIterator, self).__init__(
        x, y, image_data_generator,
        batch_size=batch_size,
        shuffle=shuffle,
        sample_weight=sample_weight,
        seed=seed,
        data_format=data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        subset=subset,
        **kwargs)


@keras_export('keras.preprocessing.image.ImageDataGenerator')
class ImageDataGenerator(image.ImageDataGenerator):
  """Generate batches of tensor image data with real-time data augmentation.

  The data will be looped over (in batches).

  Arguments:
      featurewise_center: Boolean.
          Set input mean to 0 over the dataset, feature-wise.
      samplewise_center: Boolean. Set each sample mean to 0.
      featurewise_std_normalization: Boolean.
          Divide inputs by std of the dataset, feature-wise.
      samplewise_std_normalization: Boolean. Divide each input by its std.
      zca_epsilon: Epsilon for ZCA whitening. Default is 1e-6.
      zca_whitening: Boolean. Apply ZCA whitening.
      rotation_range: Int. Degree range for random rotations.
      width_shift_range: Float, 1-D array-like or int
          - float: fraction of total width, if < 1, or pixels if >= 1.
          - 1-D array-like: random elements from the array.
          - int: integer number of pixels from interval
              `(-width_shift_range, +width_shift_range)`
          - With `width_shift_range=2` possible values
              are integers `[-1, 0, +1]`,
              same as with `width_shift_range=[-1, 0, +1]`,
              while with `width_shift_range=1.0` possible values are floats
              in the interval [-1.0, +1.0).
      height_shift_range: Float, 1-D array-like or int
          - float: fraction of total height, if < 1, or pixels if >= 1.
          - 1-D array-like: random elements from the array.
          - int: integer number of pixels from interval
              `(-height_shift_range, +height_shift_range)`
          - With `height_shift_range=2` possible values
              are integers `[-1, 0, +1]`,
              same as with `height_shift_range=[-1, 0, +1]`,
              while with `height_shift_range=1.0` possible values are floats
              in the interval [-1.0, +1.0).
      brightness_range: Tuple or list of two floats. Range for picking
          a brightness shift value from.
      shear_range: Float. Shear intensity
          (shear angle in counter-clockwise direction, in degrees).
      zoom_range: Float or [lower, upper]. Range for random zoom.
          If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
      channel_shift_range: Float. Range for random channel shifts.
      fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
          Default is 'nearest'.
          Points outside the boundaries of the input are filled
          according to the given mode:
          - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
          - 'nearest':  aaaaaaaa|abcd|dddddddd
          - 'reflect':  abcddcba|abcd|dcbaabcd
          - 'wrap':  abcdabcd|abcd|abcdabcd
      cval: Float or Int.
          Value used for points outside the boundaries
          when `fill_mode = "constant"`.
      horizontal_flip: Boolean. Randomly flip inputs horizontally.
      vertical_flip: Boolean. Randomly flip inputs vertically.
      rescale: Rescaling factor. Defaults to None.
          If None or 0, no rescaling is applied,
          otherwise the data is multiplied by the value provided
          (after applying all other transformations).
      preprocessing_function: Function that will be applied to each input.
          The function will run after the image is resized and augmented.
          The function should take one argument:
          one image (Numpy tensor with rank 3),
          and should output a Numpy tensor with the same shape.
      data_format: Image data format,
          either "channels_first" or "channels_last".
          "channels_last" mode means that the images should have shape
          `(samples, height, width, channels)`,
          "channels_first" mode means that the images should have shape
          `(samples, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
      validation_split: Float. Fraction of images reserved for validation
          (strictly between 0 and 1).
      dtype: Dtype to use for the generated arrays.

  Examples:

  Example of using `.flow(x, y)`:

  ```python
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  y_train = np_utils.to_categorical(y_train, num_classes)
  y_test = np_utils.to_categorical(y_test, num_classes)
  datagen = ImageDataGenerator(
      featurewise_center=True,
      featurewise_std_normalization=True,
      rotation_range=20,
      width_shift_range=0.2,
      height_shift_range=0.2,
      horizontal_flip=True)
  # compute quantities required for featurewise normalization
  # (std, mean, and principal components if ZCA whitening is applied)
  datagen.fit(x_train)
  # fits the model on batches with real-time data augmentation:
  model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                      steps_per_epoch=len(x_train) / 32, epochs=epochs)
  # here's a more "manual" example
  for e in range(epochs):
      print('Epoch', e)
      batches = 0
      for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
          model.fit(x_batch, y_batch)
          batches += 1
          if batches >= len(x_train) / 32:
              # we need to break the loop by hand because
              # the generator loops indefinitely
              break
  ```

  Example of using `.flow_from_directory(directory)`:

  ```python
  train_datagen = ImageDataGenerator(
          rescale=1./255,
          shear_range=0.2,
          zoom_range=0.2,
          horizontal_flip=True)
  test_datagen = ImageDataGenerator(rescale=1./255)
  train_generator = train_datagen.flow_from_directory(
          'data/train',
          target_size=(150, 150),
          batch_size=32,
          class_mode='binary')
  validation_generator = test_datagen.flow_from_directory(
          'data/validation',
          target_size=(150, 150),
          batch_size=32,
          class_mode='binary')
  model.fit_generator(
          train_generator,
          steps_per_epoch=2000,
          epochs=50,
          validation_data=validation_generator,
          validation_steps=800)
  ```

  Example of transforming images and masks together:

  ```python
  # we create two instances with the same arguments
  data_gen_args = dict(featurewise_center=True,
                       featurewise_std_normalization=True,
                       rotation_range=90,
                       width_shift_range=0.1,
                       height_shift_range=0.1,
                       zoom_range=0.2)
  image_datagen = ImageDataGenerator(**data_gen_args)
  mask_datagen = ImageDataGenerator(**data_gen_args)
  # Provide the same seed and keyword arguments to the fit and flow methods
  seed = 1
  image_datagen.fit(images, augment=True, seed=seed)
  mask_datagen.fit(masks, augment=True, seed=seed)
  image_generator = image_datagen.flow_from_directory(
      'data/images',
      class_mode=None,
      seed=seed)
  mask_generator = mask_datagen.flow_from_directory(
      'data/masks',
      class_mode=None,
      seed=seed)
  # combine generators into one which yields image and masks
  train_generator = zip(image_generator, mask_generator)
  model.fit_generator(
      train_generator,
      steps_per_epoch=2000,
      epochs=50)
  ```
  """

  def __init__(self,
               featurewise_center=False,
               samplewise_center=False,
               featurewise_std_normalization=False,
               samplewise_std_normalization=False,
               zca_whitening=False,
               zca_epsilon=1e-6,
               rotation_range=0,
               width_shift_range=0.,
               height_shift_range=0.,
               brightness_range=None,
               shear_range=0.,
               zoom_range=0.,
               channel_shift_range=0.,
               fill_mode='nearest',
               cval=0.,
               horizontal_flip=False,
               vertical_flip=False,
               rescale=None,
               preprocessing_function=None,
               data_format=None,
               validation_split=0.0,
               dtype=None):
    if data_format is None:
      data_format = backend.image_data_format()
    kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(
        image.ImageDataGenerator.__init__)[0]:
      if dtype is None:
        dtype = backend.floatx()
      kwargs['dtype'] = dtype
    super(ImageDataGenerator, self).__init__(
        featurewise_center=featurewise_center,
        samplewise_center=samplewise_center,
        featurewise_std_normalization=featurewise_std_normalization,
        samplewise_std_normalization=samplewise_std_normalization,
        zca_whitening=zca_whitening,
        zca_epsilon=zca_epsilon,
        rotation_range=rotation_range,
        width_shift_range=width_shift_range,
        height_shift_range=height_shift_range,
        brightness_range=brightness_range,
        shear_range=shear_range,
        zoom_range=zoom_range,
        channel_shift_range=channel_shift_range,
        fill_mode=fill_mode,
        cval=cval,
        horizontal_flip=horizontal_flip,
        vertical_flip=vertical_flip,
        rescale=rescale,
        preprocessing_function=preprocessing_function,
        data_format=data_format,
        validation_split=validation_split,
        **kwargs)


keras_export('keras.preprocessing.image.random_rotation')(random_rotation)
keras_export('keras.preprocessing.image.random_shift')(random_shift)
keras_export('keras.preprocessing.image.random_shear')(random_shear)
keras_export('keras.preprocessing.image.random_zoom')(random_zoom)
keras_export(
    'keras.preprocessing.image.apply_channel_shift')(apply_channel_shift)
keras_export(
    'keras.preprocessing.image.random_channel_shift')(random_channel_shift)
keras_export(
    'keras.preprocessing.image.apply_brightness_shift')(apply_brightness_shift)
keras_export('keras.preprocessing.image.random_brightness')(random_brightness)
keras_export(
    'keras.preprocessing.image.apply_affine_transform')(apply_affine_transform)
keras_export('keras.preprocessing.image.load_img')(load_img)