1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python layer for image_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
20from tensorflow.python.eager import context
21from tensorflow.contrib.image.ops import gen_image_ops
22from tensorflow.contrib.util import loader
23from tensorflow.python.framework import common_shapes
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import linalg_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.platform import resource_loader
# Load the custom-op shared library `_image_ops.so`, located next to this
# module via `resource_loader.get_path_to_datafile`.
_image_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_image_ops.so"))

# Element dtypes accepted for image tensors by `rotate`, `transform`, and the
# projective-transform gradient below.
_IMAGE_DTYPES = {
    dtypes.uint8,
    dtypes.int32,
    dtypes.int64,
    dtypes.float16,
    dtypes.float32,
    dtypes.float64,
}
# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name
# used by PIL, maybe more readable) mode, which determines the correct
# output_shape and translation for the transform.
def rotate(images, angles, interpolation="NEAREST", name=None):
  """Rotate image(s) counterclockwise by the passed angle(s) in radians.

  Args:
    images: A tensor of shape (num_images, num_rows, num_columns, num_channels)
       (NHWC), (num_rows, num_columns, num_channels) (HWC), or
       (num_rows, num_columns) (HW). The rank must be statically known (the
       shape is not `TensorShape(None)`).
    angles: A scalar angle to rotate all images by, or (if images has rank 4)
       a vector of length num_images, with an angle for each image in the batch.
    interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
    name: The name of the op.

  Returns:
    Image(s) with the same type and shape as `images`, rotated by the given
    angle(s). Empty space due to the rotation will be filled with zeros.

  Raises:
    TypeError: If `images` has an unsupported dtype, an unknown rank, or a
      rank outside [2, 4].
  """
  with ops.name_scope(name, "rotate"):
    image_or_images = ops.convert_to_tensor(images)
    if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
      raise TypeError("Invalid dtype %s." % image_or_images.dtype)
    # The rank decides both how the input is promoted to NHWC here and how
    # the result is squeezed back below, so it must be static.
    ndims = image_or_images.get_shape().ndims
    if ndims is None:
      raise TypeError("image_or_images rank must be statically known")
    elif ndims == 2:
      images = image_or_images[None, :, :, None]
    elif ndims == 3:
      images = image_or_images[None, :, :, :]
    elif ndims == 4:
      images = image_or_images
    else:
      raise TypeError("Images should have rank between 2 and 4.")

    # Height/width as float32 length-1 vectors so they broadcast against the
    # (possibly batched) angle vector inside angles_to_projective_transforms.
    image_height = math_ops.cast(array_ops.shape(images)[1],
                                 dtypes.float32)[None]
    image_width = math_ops.cast(array_ops.shape(images)[2],
                                dtypes.float32)[None]
    output = transform(
        images,
        angles_to_projective_transforms(angles, image_height, image_width),
        interpolation=interpolation)
    # Undo the batch/channel expansion applied above. The rank was validated
    # before transforming, so no unknown-rank case can reach this point.
    if ndims == 2:
      return output[0, :, :, 0]
    elif ndims == 3:
      return output[0, :, :, :]
    else:
      return output
def translate(images, translations, interpolation="NEAREST", name=None):
  """Shift image(s) by the given [dx, dy] vector(s).

  Args:
    images: A tensor of shape (num_images, num_rows, num_columns, num_channels)
        (NHWC), (num_rows, num_columns, num_channels) (HWC), or
        (num_rows, num_columns) (HW). The rank must be statically known (the
        shape is not `TensorShape(None)`).
    translations: Either one [dx, dy] vector applied to every image or (if
        images has rank 4) a num_images x 2 matrix holding one [dx, dy] row
        per image in the batch.
    interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
    name: The name of the op.

  Returns:
    Image(s) with the same type and shape as `images`, shifted by the given
        vector(s). Pixels uncovered by the shift are filled with zeros.

  Raises:
    TypeError: If `images` (or `translations`) has an invalid type or rank.
  """
  with ops.name_scope(name, "translate"):
    # Express the shift as flat projective transforms, then warp with them.
    flat_transforms = translations_to_projective_transforms(translations)
    return transform(images, flat_transforms, interpolation=interpolation)
def angles_to_projective_transforms(angles,
                                    image_height,
                                    image_width,
                                    name=None):
  """Returns projective transform(s) for the given angle(s).

  Args:
    angles: A scalar angle to rotate all images by, or (for batches of images)
        a vector with an angle to rotate each image in the batch. The rank must
        be statically known (the shape is not `TensorShape(None)`).
    image_height: Height of the image(s) to be transformed.
    image_width: Width of the image(s) to be transformed.
    name: The name of the op.

  Returns:
    A tensor of shape (num_images, 8). Projective transforms which can be given
      to `tf.contrib.image.transform`.

  Raises:
    TypeError: If `angles` has an unknown rank or a rank other than 0 or 1.
  """
  with ops.name_scope(name, "angles_to_projective_transforms"):
    angle_or_angles = ops.convert_to_tensor(
        angles, name="angles", dtype=dtypes.float32)
    # Check `ndims` explicitly: len() on an unknown-rank shape would raise a
    # ValueError instead of the TypeError the sibling helpers use.
    ndims = angle_or_angles.get_shape().ndims
    if ndims is None:
      raise TypeError("angle_or_angles rank must be statically known")
    elif ndims == 0:
      angles = angle_or_angles[None]
    elif ndims == 1:
      angles = angle_or_angles
    else:
      raise TypeError("Angles should have rank 0 or 1.")
    # Each cosine/sine is used several times below; build the ops once.
    cos_angles = math_ops.cos(angles)
    sin_angles = math_ops.sin(angles)
    # Offsets chosen so the point ((w - 1) / 2, (h - 1) / 2) maps to itself,
    # i.e. the rotation is about the image center rather than the origin.
    x_offset = ((image_width - 1) - (cos_angles * (image_width - 1) -
                                     sin_angles * (image_height - 1))) / 2.0
    y_offset = ((image_height - 1) - (sin_angles * (image_width - 1) +
                                      cos_angles * (image_height - 1))) / 2.0
    num_angles = array_ops.shape(angles)[0]
    return array_ops.concat(
        values=[
            cos_angles[:, None],
            -sin_angles[:, None],
            x_offset[:, None],
            sin_angles[:, None],
            cos_angles[:, None],
            y_offset[:, None],
            array_ops.zeros((num_angles, 2), dtypes.float32),
        ],
        axis=1)
def translations_to_projective_transforms(translations, name=None):
  """Builds flat projective transform(s) that shift images by [dx, dy].

  Args:
      translations: A 2-element list representing [dx, dy] or a matrix of
          2-element lists representing [dx, dy] to translate for each image
          (for a batch of images). The rank must be statically known (the shape
          is not `TensorShape(None)`).
      name: The name of the op.

  Returns:
      A float32 tensor of shape (num_images, 8) holding projective transforms
          suitable for `tf.contrib.image.transform`.

  Raises:
      TypeError: If `translations` has an unknown rank or a rank other than
          1 or 2.
  """
  with ops.name_scope(name, "translations_to_projective_transforms"):
    translation_or_translations = ops.convert_to_tensor(
        translations, name="translations", dtype=dtypes.float32)
    rank = translation_or_translations.get_shape().ndims
    if rank is None:
      raise TypeError(
          "translation_or_translations rank must be statically known")
    if rank not in (1, 2):
      raise TypeError("Translations should have rank 1 or 2.")
    # A single [dx, dy] vector becomes a batch of one.
    translations = (translation_or_translations[None] if rank == 1
                    else translation_or_translations)
    batch = array_ops.shape(translations)[0]
    # Each row is the flattened inverse-translation matrix
    #     [[1 0 -dx]
    #      [0 1 -dy]
    #      [0 0 1]]
    # whose final 1 is implicit; all columns are float32.
    columns = [
        array_ops.ones((batch, 1), dtypes.float32),   # a0
        array_ops.zeros((batch, 1), dtypes.float32),  # a1
        -translations[:, 0, None],                    # a2 = -dx
        array_ops.zeros((batch, 1), dtypes.float32),  # b0
        array_ops.ones((batch, 1), dtypes.float32),   # b1
        -translations[:, 1, None],                    # b2 = -dy
        array_ops.zeros((batch, 2), dtypes.float32),  # c0, c1
    ]
    return array_ops.concat(values=columns, axis=1)
def transform(images,
              transforms,
              interpolation="NEAREST",
              output_shape=None,
              name=None):
  """Applies the given transform(s) to the image(s).

  Args:
    images: A tensor of shape (num_images, num_rows, num_columns, num_channels)
       (NHWC), (num_rows, num_columns, num_channels) (HWC), or
       (num_rows, num_columns) (HW). The rank must be statically known (the
       shape is not `TensorShape(None)`).
    transforms: Projective transform matrix/matrices. A vector of length 8 or
       tensor of size N x 8. If one row of transforms is
       [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
       `(x, y)` to a transformed *input* point
       `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
       where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
       the transform mapping input points to output points. Note that gradients
       are not backpropagated into transformation parameters.
    interpolation: Interpolation mode (case-insensitive). Supported values:
       "NEAREST", "BILINEAR".
    output_shape: Output dimension after the transform, [height, width].
       If None, output is the same size as input image.
    name: The name of the op.

  Returns:
    Image(s) with the same type and shape as `images`, with the given
    transform(s) applied. Transformed coordinates outside of the input image
    will be filled with zeros.

  Raises:
    TypeError: If `images` has an unsupported dtype, an unknown rank, or a
      rank outside [2, 4]; or if `transforms` has an unknown rank or a rank
      other than 1 or 2.
    ValueError: If `output_shape` is not a 1-D Tensor of 2 elements.
  """
  with ops.name_scope(name, "transform"):
    image_or_images = ops.convert_to_tensor(images, name="images")
    transform_or_transforms = ops.convert_to_tensor(
        transforms, name="transforms", dtype=dtypes.float32)
    if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
      raise TypeError("Invalid dtype %s." % image_or_images.dtype)
    image_ndims = image_or_images.get_shape().ndims
    if image_ndims is None:
      raise TypeError("image_or_images rank must be statically known")
    elif image_ndims == 2:
      images = image_or_images[None, :, :, None]
    elif image_ndims == 3:
      images = image_or_images[None, :, :, :]
    elif image_ndims == 4:
      images = image_or_images
    else:
      raise TypeError("Images should have rank between 2 and 4.")

    if output_shape is None:
      output_shape = array_ops.shape(images)[1:3]
      if not context.executing_eagerly():
        # In graph mode, fold to a constant when the spatial dims are known.
        # NOTE(review): presumably so shape inference sees static dims.
        output_shape_value = tensor_util.constant_value(output_shape)
        if output_shape_value is not None:
          output_shape = output_shape_value

    output_shape = ops.convert_to_tensor(
        output_shape, dtypes.int32, name="output_shape")

    if not output_shape.get_shape().is_compatible_with([2]):
      raise ValueError("output_shape must be a 1-D Tensor of 2 elements: "
                       "new_height, new_width")

    # Test for an unknown rank *before* comparing ranks: len() on an
    # unknown-rank TensorShape raises ValueError, which would bypass the
    # TypeError below.
    transform_ndims = transform_or_transforms.get_shape().ndims
    if transform_ndims is None:
      raise TypeError(
          "transform_or_transforms rank must be statically known")
    elif transform_ndims == 1:
      transforms = transform_or_transforms[None]
    elif transform_ndims == 2:
      transforms = transform_or_transforms
    else:
      raise TypeError("Transforms should have rank 1 or 2.")

    output = gen_image_ops.image_projective_transform_v2(
        images,
        output_shape=output_shape,
        transforms=transforms,
        interpolation=interpolation.upper())
    # Strip the batch/channel dims that were added for rank-2/3 inputs.
    if image_ndims == 2:
      return output[0, :, :, 0]
    elif image_ndims == 3:
      return output[0, :, :, :]
    else:
      return output
def compose_transforms(*transforms):
  """Composes the transforms tensors.

  Args:
    *transforms: List of image projective transforms to be composed. Each
        transform is length 8 (single transform) or shape (N, 8) (batched
        transforms). The shapes of all inputs must be equal, and at least one
        input must be given.

  Returns:
    A composed transform tensor. When passed to `tf.contrib.image.transform`,
        equivalent to applying each of the given transforms to the image in
        order.

  Raises:
    ValueError: If no transforms are given.
  """
  # Validate explicitly: an `assert` is stripped under `python -O`, which
  # would leave an opaque IndexError on `transforms[0]` instead.
  if not transforms:
    raise ValueError("transforms cannot be empty")
  with ops.name_scope("compose_transforms"):
    composed = flat_transforms_to_matrices(transforms[0])
    for tr in transforms[1:]:
      # Multiply batches of matrices.
      composed = math_ops.matmul(composed, flat_transforms_to_matrices(tr))
    return matrices_to_flat_transforms(composed)
def flat_transforms_to_matrices(transforms):
  """Converts `tf.contrib.image` projective transforms to 3x3 matrices.

  Note that the output matrices map output coordinates to input coordinates. For
  the forward transformation matrix, call `tf.linalg.inv` on the result.

  Args:
    transforms: Vector of length 8, or batches of transforms with shape
      `(N, 8)`.

  Returns:
    3D tensor of matrices with shape `(N, 3, 3)` and the same dtype as
      `transforms`. The output matrices map the *output coordinates* (in
      homogeneous coordinates) of each transform to the corresponding
      *input coordinates*.

  Raises:
    ValueError: If `transforms` have an invalid (or unknown) rank.
  """
  with ops.name_scope("flat_transforms_to_matrices"):
    transforms = ops.convert_to_tensor(transforms, name="transforms")
    # `ndims` is None for an unknown rank, which also fails this check.
    if transforms.shape.ndims not in (1, 2):
      raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms)
    # Make the transform(s) 2D in case the input is a single transform.
    transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8]))
    num_transforms = array_ops.shape(transforms)[0]
    # Append the implicit last matrix entry (1). Match the input dtype so
    # that e.g. float64 transforms concatenate instead of failing against a
    # float32 column.
    ones_column = array_ops.ones([num_transforms, 1], dtype=transforms.dtype)
    return array_ops.reshape(
        array_ops.concat([transforms, ones_column], axis=1),
        constant_op.constant([-1, 3, 3]))
def matrices_to_flat_transforms(transform_matrices):
  """Flattens 3x3 (output -> input) matrices into `tf.contrib.image` transforms.

  The matrices must map output coordinates to input coordinates. Forward
  transformation matrices should first be inverted with `tf.linalg.inv`.

  Args:
    transform_matrices: One or more 3x3 transformation matrices for the
      reverse mapping in homogeneous coordinates. Shape `(3, 3)` or
      `(N, 3, 3)`.

  Returns:
    2D tensor of flat transforms with shape `(N, 8)`, which may be passed into
      `tf.contrib.image.transform`.

  Raises:
    ValueError: If `transform_matrices` does not have rank 2 or 3.
  """
  with ops.name_scope("matrices_to_flat_transforms"):
    transform_matrices = ops.convert_to_tensor(
        transform_matrices, name="transform_matrices")
    rank = transform_matrices.shape.ndims
    if rank not in (2, 3):
      raise ValueError(
          "Matrices should be 2D or 3D, got: %s" % transform_matrices)
    # One row of 9 entries per matrix.
    flat = array_ops.reshape(transform_matrices,
                             constant_op.constant([-1, 9]))
    # Scale each row so its last entry (normally already 1) becomes 1, then
    # drop that entry, which is implicit in the 8-element form.
    normalized = flat / flat[:, 8:9]
    return normalized[:, :8]
def _image_projective_transform_grad(op, grad):
  """Computes the gradient for ImageProjectiveTransform.

  The image gradient is produced by warping `grad` with the inverse of the
  forward transform(s). `None` is returned for the op's other two inputs
  (the transforms and, presumably, the output shape).

  Args:
    op: The forward op. Input 0 holds the images, input 1 the transforms, and
      the op carries an "interpolation" attr.
    grad: Gradient with respect to the op's output.

  Returns:
    A list `[images_grad, None, None]`.

  Raises:
    TypeError: If the images have an unsupported dtype or the transforms do
      not have rank 1 or 2.
  """
  forward_images = op.inputs[0]
  forward_transforms = op.inputs[1]
  interp_mode = op.get_attr("interpolation")

  image_or_images = ops.convert_to_tensor(forward_images, name="images")
  transform_or_transforms = ops.convert_to_tensor(
      forward_transforms, name="transforms", dtype=dtypes.float32)

  if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
    raise TypeError("Invalid dtype %s." % image_or_images.dtype)
  transform_rank = len(transform_or_transforms.get_shape())
  if transform_rank == 1:
    batched = transform_or_transforms[None]
  elif transform_rank == 2:
    batched = transform_or_transforms
  else:
    raise TypeError("Transforms should have rank 1 or 2.")

  # Invert each transform by round-tripping through 3x3 matrices.
  inverse_flat = matrices_to_flat_transforms(
      linalg_ops.matrix_inverse(
          flat_transforms_to_matrices(transforms=batched)))
  images_grad = gen_image_ops.image_projective_transform_v2(
      images=grad,
      transforms=inverse_flat,
      output_shape=array_ops.shape(image_or_images)[1:3],
      interpolation=interp_mode)
  return [images_grad, None, None]
def bipartite_match(distance_mat,
                    num_valid_rows,
                    top_k=-1,
                    name="bipartite_match"):
  """Greedily matches rows to columns of a distance matrix.

  The `BipartiteMatch` op performs a greedy bi-partite matching that aims for
  the (greedily) smallest total distance.

  Args:
    distance_mat: A 2-D float tensor of shape `[num_rows, num_columns]` of
      pair-wise distances between the entities represented by each row and
      each column. The matrix need not be symmetric. Smaller distances mean
      more similar pairs; the matching minimizes distance.
    num_valid_rows: A scalar or a 1-D tensor with one element giving how many
      rows of `distance_mat` take part in the matching. A negative value
      means every row of `distance_mat` is used.
    top_k: A scalar giving the number of top-k matches to retrieve. A negative
      value means it is derived from the maximum number of matches available
      in `distance_mat`.
    name: The name of the op.

  Returns:
    The op's two outputs:
    row_to_col_match_indices: A vector of length num_rows (the number of rows
      of `distance_mat`). If `row_to_col_match_indices[i]` is not -1, row i is
      matched to column `row_to_col_match_indices[i]`.
    col_to_row_match_indices: A vector of length num_columns (the number of
      columns of `distance_mat`). If `col_to_row_match_indices[j]` is not -1,
      column j is matched to row `col_to_row_match_indices[j]`.
  """
  return gen_image_ops.bipartite_match(
      distance_mat, num_valid_rows, top_k, name=name)
def connected_components(images):
  """Labels the connected components in a batch of images.

  A component is a set of pixels in a single input image, which are all adjacent
  and all have the same non-zero value. Components use a squared connectivity
  of one (all True entries are joined with their neighbors above, below, left,
  and right). Components across all images have consecutive ids 1 through n.
  Components are labeled according to the first pixel of the component
  appearing in row-major order (lexicographic order by image_index_in_batch,
  row, col). Zero entries all have an output id of 0.

  This op is equivalent to `scipy.ndimage.measurements.label` on a 2D array
  with the default structuring element (which is the connectivity used here).

  Args:
    images: A 2D (H, W) or 3D (N, H, W) Tensor of boolean image(s).

  Returns:
    Components with the same shape as `images`. False entries in `images` have
    value 0, and all True entries map to a component id > 0.

  Raises:
    TypeError: if `images` is not 2D or 3D, including when its rank is not
      statically known.
  """
  with ops.name_scope("connected_components"):
    image_or_images = ops.convert_to_tensor(images, name="images")
    # Use `ndims` (None for an unknown rank) rather than len(), which raises
    # ValueError for an unknown rank instead of the documented TypeError.
    ndims = image_or_images.get_shape().ndims
    if ndims == 2:
      images = image_or_images[None, :, :]
    elif ndims == 3:
      images = image_or_images
    else:
      raise TypeError(
          "images should have rank 2 (HW) or 3 (NHW). Static shape is %s" %
          image_or_images.get_shape())
    components = gen_image_ops.image_connected_components(images)

    # TODO(ringwalt): Component id renaming should be done in the op, to avoid
    # constructing multiple additional large tensors.
    components_flat = array_ops.reshape(components, [-1])
    # `unique` lists ids in order of first appearance in the flattened
    # (row-major) tensor; `id_index` maps every element back to that list.
    unique_ids, id_index = array_ops.unique(components_flat)
    id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0]
    # Map each nonzero id to consecutive values.
    nonzero_consecutive_ids = math_ops.range(
        array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1

    def no_zero():
      # No need to insert a zero into the ids.
      return nonzero_consecutive_ids

    def has_zero():
      # Insert a zero in the consecutive ids where zero appears in unique_ids.
      # id_is_zero has length 1.
      zero_id_ind = math_ops.cast(id_is_zero[0], dtypes.int32)
      ids_before = nonzero_consecutive_ids[:zero_id_ind]
      ids_after = nonzero_consecutive_ids[zero_id_ind:]
      return array_ops.concat([ids_before, [0], ids_after], axis=0)

    new_ids = control_flow_ops.cond(
        math_ops.equal(array_ops.shape(id_is_zero)[0], 0), no_zero, has_zero)
    components = array_ops.reshape(
        array_ops.gather(new_ids, id_index), array_ops.shape(components))
    # Drop the batch dimension that was added for a single 2D image.
    if ndims == 2:
      return components[0, :, :]
    else:
      return components