1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Contains the TFExampleDecoder and its associated helper classes.
16
The TFExampleDecoder is a DataDecoder used to decode TensorFlow Example protos.
18In order to do so each requested item must be paired with one or more Example
19features that are parsed to produce the Tensor-based manifestation of the item.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import abc
27
28import six
29
30from tensorflow.contrib.slim.python.slim.data import data_decoder
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import map_fn
36from tensorflow.python.ops import image_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import parsing_ops
39from tensorflow.python.ops import sparse_ops
40
41
@six.add_metaclass(abc.ABCMeta)
class ItemHandler(object):
  """Base class pairing an item with the Example features it is decoded from.

  A handler serves two purposes: its `keys` tell the decoder which Example
  features must be parsed, and `tensors_to_item` post-processes the parsed
  tensors for those keys into the final item.
  """

  def __init__(self, keys):
    """Records the TF-Example feature key(s) this handler consumes.

    See third_party/tensorflow/core/example/feature.proto

    Args:
      keys: a single feature key, or a tuple/list of feature keys. A single
        key is wrapped into a one-element list; tuples and lists are stored
        as given.
    """
    is_sequence = isinstance(keys, (tuple, list))
    self._keys = keys if is_sequence else [keys]

  @property
  def keys(self):
    """The feature keys this handler reads (a list, or the tuple given)."""
    return self._keys

  @abc.abstractmethod
  def tensors_to_item(self, keys_to_tensors):
    """Builds the handled item from the parsed tensors.

    Subclasses must override this.

    Args:
      keys_to_tensors: a dictionary from each of `self.keys` to the tensor
        parsed for that key.

    Returns:
      The tensor that represents the handled item.
    """
78
79
class ItemHandlerCallback(ItemHandler):
  """An ItemHandler whose item is produced by a user-supplied function.

  Instead of a fixed recipe for combining the parsed tensors, this handler
  forwards the whole `keys_to_tensors` dictionary to `func` and returns
  whatever `func` returns.
  """

  def __init__(self, keys, func):
    """Initializes the ItemHandlerCallback.

    Args:
      keys: a TF-Example key, or a list/tuple of TF-Example keys.
      func: a callable taking one argument, the dictionary mapping each of
        `keys` to its parsed Tensor, and returning the item.
    """
    self._func = func
    super(ItemHandlerCallback, self).__init__(keys)

  def tensors_to_item(self, keys_to_tensors):
    """Returns the result of calling `func` on `keys_to_tensors`."""
    callback = self._func
    return callback(keys_to_tensors)
100
101
class BoundingBox(ItemHandler):
  """An ItemHandler that concatenates a set of parsed Tensors to Bounding Boxes.
  """

  def __init__(self, keys=None, prefix=''):
    """Initialize the bounding box handler.

    Args:
      keys: A list of four key names representing the ymin, xmin, ymax, xmax
        coordinates, in that order. Defaults to
        ['ymin', 'xmin', 'ymax', 'xmax'].
      prefix: An optional prefix for each of the bounding box keys.
        If provided, `prefix` is prepended to each key in `keys`.

    Raises:
      ValueError: if keys is not `None` and also not a list of exactly 4 keys
    """
    if keys is None:
      keys = ['ymin', 'xmin', 'ymax', 'xmax']
    elif len(keys) != 4:
      raise ValueError('BoundingBox expects 4 keys but got {}'.format(
          len(keys)))
    self._prefix = prefix
    # Fully-qualified feature names looked up in the parsed Example; the base
    # class stores these as `self._keys`.
    self._full_keys = [prefix + k for k in keys]
    super(BoundingBox, self).__init__(self._full_keys)

  def tensors_to_item(self, keys_to_tensors):
    """Maps the given dictionary of tensors to a concatenated list of bboxes.

    Args:
      keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

    Returns:
      [num_boxes, 4] tensor of bounding box coordinates, i.e. 1 bounding box
        per row, with columns in the order of `keys` (by default
        [y_min, x_min, y_max, x_max]). This assumes each coordinate feature
        parses to a 1-D tensor of length num_boxes.
    """
    sides = []
    for key in self._full_keys:
      side = keys_to_tensors[key]
      # For sparse (variable-length) features only the values are needed.
      if isinstance(side, sparse_tensor.SparseTensor):
        side = side.values
      side = array_ops.expand_dims(side, 0)
      sides.append(side)

    # Concatenate the four coordinate rows, then transpose so that each row
    # holds one box.
    bounding_box = array_ops.concat(sides, 0)
    return array_ops.transpose(bounding_box)
147
148
class Tensor(ItemHandler):
  """An ItemHandler that returns a parsed Tensor."""

  def __init__(self, tensor_key, shape_keys=None, shape=None, default_value=0):
    """Initializes the Tensor handler.

    Tensors are, by default, returned without any reshaping. However, there are
    two mechanisms which allow reshaping to occur at load time. If `shape_keys`
    is provided, both the `Tensor` corresponding to `tensor_key` and
    `shape_keys` are loaded and the former `Tensor` is reshaped with the values
    of the latter. Alternatively, if a fixed `shape` is provided, the `Tensor`
    corresponding to `tensor_key` is loaded and reshaped appropriately.
    If neither `shape_keys` nor `shape` are provided, the `Tensor` will be
    returned without any reshaping.

    Args:
      tensor_key: the name of the `TFExample` feature to read the tensor from.
      shape_keys: Optional name, or list/tuple of names, of the TF-Example
        features in which the tensor shape is stored. If several names are
        given, each corresponds to one dimension of the shape.
      shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
        reshaped accordingly.
      default_value: Fill value used when the parsed tensor is a SparseTensor
        and is converted to a dense `Tensor`; positions absent from the sparse
        tensor take this value.

    Raises:
      ValueError: if both `shape_keys` and `shape` are specified.
    """
    if shape_keys and shape is not None:
      raise ValueError('Cannot specify both shape_keys and shape parameters.')
    if shape_keys:
      # Normalize to a list. A tuple is treated as several keys rather than
      # being wrapped as one (unparseable) tuple-valued key.
      if isinstance(shape_keys, (list, tuple)):
        shape_keys = list(shape_keys)
      else:
        shape_keys = [shape_keys]
    self._tensor_key = tensor_key
    self._shape_keys = shape_keys
    self._shape = shape
    self._default_value = default_value
    keys = [tensor_key]
    if shape_keys:
      keys.extend(shape_keys)
    super(Tensor, self).__init__(keys)

  def tensors_to_item(self, keys_to_tensors):
    """Returns the parsed tensor, optionally reshaped and densified.

    Args:
      keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

    Returns:
      A dense `Tensor`. A SparseTensor is reshaped (when a shape is known) and
      then densified using `default_value`; a dense tensor is only reshaped.
    """
    tensor = keys_to_tensors[self._tensor_key]
    shape = self._shape
    if self._shape_keys:
      shape_dims = []
      for k in self._shape_keys:
        shape_dim = keys_to_tensors[k]
        if isinstance(shape_dim, sparse_tensor.SparseTensor):
          shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim)
        shape_dims.append(shape_dim)
      # Flatten the stacked per-key values into a 1-D shape vector.
      shape = array_ops.reshape(array_ops.stack(shape_dims), [-1])
    if isinstance(tensor, sparse_tensor.SparseTensor):
      if shape is not None:
        tensor = sparse_ops.sparse_reshape(tensor, shape)
      tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
    elif shape is not None:
      tensor = array_ops.reshape(tensor, shape)
    return tensor
209
210
class LookupTensor(Tensor):
  """An ItemHandler that returns a parsed Tensor, the result of a lookup."""

  def __init__(self,
               tensor_key,
               table,
               shape_keys=None,
               shape=None,
               default_value=''):
    """Initializes the LookupTensor handler.

    Behaves like `Tensor`, then maps the resulting tensor through `table`
    (typically a vocabulary or label mapping).

    Args:
      tensor_key: the name of the `TFExample` feature to read the tensor from.
      table: A tf.lookup table; its `lookup` method is applied to the tensor
        produced by `Tensor.tensors_to_item`.
      shape_keys: Optional name or list of names of the TF-Example feature in
        which the tensor shape is stored. If a list, then each corresponds to
        one dimension of the shape.
      shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
        reshaped accordingly.
      default_value: Forwarded to `Tensor`; used as the fill value when a
        sparse parsed tensor is densified.

    Raises:
      ValueError: if both `shape_keys` and `shape` are specified.
    """
    self._table = table
    super(LookupTensor, self).__init__(
        tensor_key,
        shape_keys=shape_keys,
        shape=shape,
        default_value=default_value)

  def tensors_to_item(self, keys_to_tensors):
    """Returns `table.lookup` applied to the tensor built by `Tensor`."""
    raw_values = super(LookupTensor, self).tensors_to_item(keys_to_tensors)
    table = self._table
    return table.lookup(raw_values)
245
246
class BackupHandler(ItemHandler):
  """An ItemHandler that tries two ItemHandlers in order."""

  def __init__(self, handler, backup):
    """Initializes the BackupHandler handler.

    If the first Handler's tensors_to_item returns a Tensor with no elements,
    the second Handler is used.

    Args:
      handler: The primary ItemHandler.
      backup: The backup ItemHandler.

    Raises:
      ValueError: if either is not an ItemHandler.
    """
    if not isinstance(handler, ItemHandler):
      raise ValueError('Primary handler is of type %s instead of ItemHandler'
                       % type(handler))
    if not isinstance(backup, ItemHandler):
      raise ValueError('Backup handler is of type %s instead of ItemHandler'
                       % type(backup))
    self._handler = handler
    self._backup = backup
    # ItemHandler may store keys as either a list or a tuple; copy both into
    # lists so mixing the two does not raise a TypeError on concatenation.
    super(BackupHandler, self).__init__(
        list(handler.keys) + list(backup.keys))

  def tensors_to_item(self, keys_to_tensors):
    """Returns the primary item, or the backup item if the primary is empty.

    The primary item counts as empty when the product of its dimensions is 0.
    A scalar (empty shape) therefore counts as non-empty.

    Args:
      keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

    Returns:
      The tensor chosen at run time by a `cond` on the primary's element
      count.
    """
    item = self._handler.tensors_to_item(keys_to_tensors)
    return control_flow_ops.cond(
        pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0),
        true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
        false_fn=lambda: item)
279
280
class SparseTensor(ItemHandler):
  """An ItemHandler for SparseTensors."""

  def __init__(self,
               indices_key=None,
               values_key=None,
               shape_key=None,
               shape=None,
               densify=False,
               default_value=0):
    """Initializes the SparseTensor handler.

    Args:
      indices_key: the name of the TF-Example feature that contains the ids.
        Defaults to 'indices'.
      values_key: the name of the TF-Example feature that contains the values.
        Defaults to 'values'.
      shape_key: the name of the TF-Example feature that contains the shape.
        If provided, it takes precedence over `shape`.
      shape: the output shape of the SparseTensor, used when `shape_key` is
        not provided. If neither is given, the dense shape of the parsed
        `indices` feature is used.
      densify: whether to convert the SparseTensor into a dense Tensor.
      default_value: Scalar value to set when making dense for indices not
        specified in the `SparseTensor`.
    """
    indices_key = indices_key or 'indices'
    values_key = values_key or 'values'
    self._indices_key = indices_key
    self._values_key = values_key
    self._shape_key = shape_key
    self._shape = shape
    self._densify = densify
    self._default_value = default_value
    keys = [indices_key, values_key]
    if shape_key:
      keys.append(shape_key)
    super(SparseTensor, self).__init__(keys)

  def tensors_to_item(self, keys_to_tensors):
    """Builds a SparseTensor (optionally densified) from the parsed features.

    The parsed `indices` and `values` features are both SparseTensors. The
    output keeps every index column of the `indices` feature except the last
    one. The last column is replaced by the `indices` feature's values (cast
    to int64), and the `values` feature's values are used as the output
    values.

    Args:
      keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

    Returns:
      A SparseTensor, or a dense Tensor if `densify` was set.
    """
    indices = keys_to_tensors[self._indices_key]
    values = keys_to_tensors[self._values_key]
    if self._shape_key:
      shape = keys_to_tensors[self._shape_key]
      if isinstance(shape, sparse_tensor.SparseTensor):
        shape = sparse_ops.sparse_tensor_to_dense(shape)
    elif self._shape is not None:
      # Compare against None rather than testing truthiness, so a Tensor or
      # ndarray shape does not raise when used as a Python bool.
      shape = self._shape
    else:
      shape = indices.dense_shape
    indices_shape = array_ops.shape(indices.indices)
    rank = indices_shape[1]
    ids = math_ops.cast(indices.values, dtypes.int64)
    indices_columns_to_preserve = array_ops.slice(
        indices.indices, [0, 0], array_ops.stack([-1, rank - 1]))
    new_indices = array_ops.concat(
        [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1)

    tensor = sparse_tensor.SparseTensor(new_indices, values.values, shape)
    if self._densify:
      tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
    return tensor
342
343
class Image(ItemHandler):
  """An ItemHandler that decodes a parsed Tensor as an image."""

  def __init__(self,
               image_key=None,
               format_key=None,
               shape=None,
               channels=3,
               dtype=dtypes.uint8,
               repeated=False,
               dct_method=''):
    """Initializes the image handler.

    Args:
      image_key: name of the TF-Example feature holding the encoded image.
        Falls back to 'image/encoded' when empty or None.
      format_key: name of the TF-Example feature holding the image format.
        Falls back to 'image/format' when empty or None.
      shape: optional 1-D output shape [height, width, channels]. When given,
        each decoded image is reshaped to it; when None, no reshape happens.
        Supply it only if every stored image has the same shape.
      channels: number of channels to decode the image into.
      dtype: bit depth the image is decoded/cast to. Different formats
        support different bit depths; see tf.image.decode_image and
        tf.decode_raw.
      repeated: when False, a single image is decoded. When True, the image
        feature is a 1-D tensor of strings, and each element is decoded.
      dct_method: optional string, empty by default. Only used for jpeg
        images, as a hint for the decompression algorithm. Currently valid
        values are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The jpeg library may
        ignore the hint if it lacks that option.
    """
    image_key = image_key or 'image/encoded'
    format_key = format_key or 'image/format'

    super(Image, self).__init__([image_key, format_key])
    self._image_key, self._format_key = image_key, format_key
    self._shape = shape
    self._channels = channels
    self._dtype = dtype
    self._repeated = repeated
    self._dct_method = dct_method

  def tensors_to_item(self, keys_to_tensors):
    """See base class."""
    encoded = keys_to_tensors[self._image_key]
    fmt = keys_to_tensors[self._format_key]

    if not self._repeated:
      return self._decode(encoded, fmt)
    # Several encoded strings: decode each with the shared format tensor.
    return map_fn.map_fn(lambda buf: self._decode(buf, fmt),
                         encoded, dtype=self._dtype)

  def _decode(self, image_buffer, image_format):
    """Decodes one encoded image.

    Args:
      image_buffer: tensor holding the encoded image bytes.
      image_format: tensor holding the format of `image_buffer`. 'raw'/'RAW'
        selects raw decoding; any other value selects jpeg decoding when the
        bytes are jpeg, and tf.image.decode_image otherwise.

    Returns:
      The decoded image tensor. Its static shape is set to
      (?, ?, self._channels), and it is reshaped to self._shape when that is
      given.
    """

    def _generic():
      """Decodes by header sniffing (decode_image), then casts to dtype."""
      decoded = image_ops.decode_image(image_buffer, channels=self._channels)
      return math_ops.cast(decoded, self._dtype)

    def _jpeg():
      """Decodes as jpeg so that `dct_method` can be honoured."""
      decoded = image_ops.decode_jpeg(
          image_buffer, channels=self._channels, dct_method=self._dct_method)
      return math_ops.cast(decoded, self._dtype)

    def _jpeg_or_generic():
      """Routes jpeg bytes to decode_jpeg and everything else to decode_image."""
      return control_flow_ops.cond(
          image_ops.is_jpeg(image_buffer), _jpeg, _generic, name='cond_jpeg')

    def _raw():
      """Reinterprets the raw bytes as a flat tensor of dtype."""
      return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

    is_raw = math_ops.logical_or(
        math_ops.equal(image_format, 'raw'),
        math_ops.equal(image_format, 'RAW'))
    image = control_flow_ops.case(
        {is_raw: _raw}, default=_jpeg_or_generic, exclusive=True)

    image.set_shape([None, None, self._channels])
    if self._shape is None:
      return image
    return array_ops.reshape(image, self._shape)
459
460
class TFExampleDecoder(data_decoder.DataDecoder):
  """A decoder for TensorFlow Examples.

  Decoding Example proto buffers is comprised of two stages: (1) Example parsing
  and (2) tensor manipulation.

  In the first stage, `parse_single_example` is called with the
  `keys_to_features` spec. This is a dictionary of `FixedLenFeature` and
  `VarLenFeature` instances that tells TF how to parse the example. The output
  of this stage is a dictionary of tensors keyed by feature name. Entries for
  `FixedLenFeature`s are reshaped to their declared shape.

  In the second stage, the resulting tensors are manipulated to provide the
  requested 'item' tensors.

  To perform this decoding operation, a TFExampleDecoder is given a dictionary
  from item names to ItemHandlers. Each ItemHandler names the features it
  needs from stage 1 and contains the instructions for post-processing its
  tensors in stage 2.
  """

  def __init__(self, keys_to_features, items_to_handlers):
    """Constructs the decoder.

    Args:
      keys_to_features: a dictionary from TF-Example keys to either
        tf.VarLenFeature or tf.FixedLenFeature instances. See tensorflow's
        parsing_ops.py.
      items_to_handlers: a dictionary from items (strings) to ItemHandler
        instances. Each ItemHandler receives only the parsed tensors for its
        own `keys`, which it uses to produce the final item.
    """
    self._keys_to_features = keys_to_features
    self._items_to_handlers = items_to_handlers

  def list_items(self):
    """See base class."""
    return list(self._items_to_handlers.keys())

  def decode(self, serialized_example, items=None):
    """Decodes the given serialized TF-example.

    Args:
      serialized_example: a serialized TF-example tensor.
      items: the list of items to decode. These must be a subset of the item
        keys in self._items_to_handlers. If `items` is None or empty, all of
        the items in self._items_to_handlers are decoded.

    Returns:
      A list with one decoded item per requested item, in the order of
      `items` (or of self._items_to_handlers when all items are decoded).

    Raises:
      KeyError: if an item is not in self._items_to_handlers, or if a
        handler needs a key that is not in `keys_to_features`.
    """
    example = parsing_ops.parse_single_example(serialized_example,
                                               self._keys_to_features)

    # Reshape non-sparse elements just once, adding the reshape ops in
    # deterministic (sorted-key) order so graph construction is reproducible.
    for k in sorted(self._keys_to_features):
      v = self._keys_to_features[k]
      if isinstance(v, parsing_ops.FixedLenFeature):
        example[k] = array_ops.reshape(example[k], v.shape)

    if not items:
      items = self._items_to_handlers.keys()

    outputs = []
    for item in items:
      handler = self._items_to_handlers[item]
      # Give each handler only the tensors for the keys it declared.
      keys_to_tensors = {key: example[key] for key in handler.keys}
      outputs.append(handler.tensors_to_item(keys_to_tensors))
    return outputs
528