1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Common util functions used by layers.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from collections import namedtuple
22from collections import OrderedDict
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import variables
28
# Public names exported by `from ... import *` on this module.
__all__ = [
    'collect_named_outputs',
    'constant_value',
    'static_cond',
    'smart_cond',
    'get_variable_collections',
    'two_element_tuple',
    'n_positive_integers',
    'channel_dimension',
    'last_dimension',
]
38
# Simple immutable record pairing a `name` with its `outputs`.
NamedOutputs = namedtuple('NamedOutputs', 'name outputs')
40
41
def collect_named_outputs(collections, alias, outputs):
  """Tags `outputs` with `alias` and adds it to the given collection(s).

  Handy for gathering end-points or tensors to summarize. Example:

    logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
    assert 'inception_v3/logits' in logits.aliases

  If `collections` is falsy (None or empty), `outputs` is returned untouched:
  no alias is recorded and nothing is added to any collection.

  Args:
    collections: A collection or list of collections. If None skip collection.
    alias: String to append to the list of aliases of `outputs`, for example
      'inception_v3/conv1'.
    outputs: Tensor, the output tensor to collect.

  Returns:
    `outputs` itself, so the call can be used inline.
  """
  if not collections:
    return outputs
  append_tensor_alias(outputs, alias)
  ops.add_to_collections(collections, outputs)
  return outputs
63
64
def append_tensor_alias(tensor, alias):
  """Append an alias to the list of aliases of the tensor.

  A single trailing '/' is stripped from `alias` before it is stored. The
  aliases live in a list attribute named `aliases` set directly on `tensor`;
  the list is created on first use and appended to afterwards.

  Args:
    tensor: A `Tensor` (any object that accepts attribute assignment).
    alias: String, to add to the list of aliases of the tensor.

  Returns:
    The tensor with a new alias appended to its list of aliases.
  """
  # Remove ending '/' if present. `endswith` (unlike `alias[-1]`) does not
  # raise IndexError when `alias` is the empty string.
  if alias.endswith('/'):
    alias = alias[:-1]
  if hasattr(tensor, 'aliases'):
    tensor.aliases.append(alias)
  else:
    tensor.aliases = [alias]
  return tensor
83
84
def gather_tensors_aliases(tensors):
  """Flattens the aliases of every tensor in `tensors` into a single list.

  Aliases appear in the order of `tensors`; each tensor contributes the list
  returned by `get_tensor_aliases` for it.

  Args:
    tensors: A list (or other iterable) of `Tensors`.

  Returns:
    A new list of alias strings covering all tensors.
  """
  return [alias
          for tensor in tensors
          for alias in get_tensor_aliases(tensor)]
98
99
def get_tensor_aliases(tensor):
  """Get a list with the aliases of the input tensor.

  If the tensor has no `aliases` attribute, the result defaults to a single
  name: `tensor.op.name` when the tensor's name ends in ':0', otherwise the
  tensor's full name.

  Args:
    tensor: A `Tensor`.

  Returns:
    A new list of strings with the aliases of the tensor. It is a copy, so
    mutating it does not change the aliases stored on `tensor`.
  """
  if hasattr(tensor, 'aliases'):
    # Copy so callers cannot accidentally mutate the tensor's own alias list.
    return list(tensor.aliases)
  if tensor.name.endswith(':0'):
    # Use op.name for tensor ending in :0
    return [tensor.op.name]
  return [tensor.name]
121
122
def convert_collection_to_dict(collection, clear_collection=False):
  """Maps every alias of every tensor in `collection` to that tensor.

  Tensors are visited in collection order and each alias from
  `get_tensor_aliases` becomes a key. If two tensors share an alias, the later
  tensor wins, but the key keeps the position where it first appeared.

  Args:
    collection: A collection.
    clear_collection: When True, the collection is cleared from the default
      graph after the dict has been built.

  Returns:
    An OrderedDict of {alias: tensor}
  """
  alias_to_tensor = OrderedDict()
  for tensor in ops.get_collection(collection):
    for alias in get_tensor_aliases(tensor):
      alias_to_tensor[alias] = tensor
  if clear_collection:
    ops.get_default_graph().clear_collection(collection)
  return alias_to_tensor
140
141
def constant_value(value_or_tensor_or_var, dtype=None):
  """Returns the constant value of `value_or_tensor_or_var`, if it has one.

  Plain (non-`Tensor`, non-`Variable`) values are returned unchanged. For a
  `Tensor`, the result of `tensor_util.constant_value` is returned. A
  `Variable` is never treated as constant here, so None is returned for it.

  Args:
    value_or_tensor_or_var: A value, a `Tensor` or a `Variable`.
    dtype: Optional `tf.dtype`. If set, a `Tensor` or `Variable` argument must
      have exactly this dtype.

  Returns:
    The constant value, or None if it is not constant.

  Raises:
    ValueError: if value_or_tensor_or_var is None, or if it is a `Tensor` or
      `Variable` whose dtype differs from `dtype`.
  """
  if value_or_tensor_or_var is None:
    raise ValueError('value_or_tensor_or_var cannot be None')
  if not isinstance(value_or_tensor_or_var, (ops.Tensor, variables.Variable)):
    return value_or_tensor_or_var
  if dtype and value_or_tensor_or_var.dtype != dtype:
    raise ValueError('It has the wrong type %s instead of %s' % (
        value_or_tensor_or_var.dtype, dtype))
  if isinstance(value_or_tensor_or_var, variables.Variable):
    return None
  return tensor_util.constant_value(value_or_tensor_or_var)
169
170
def static_cond(pred, fn1, fn2):
  """Calls `fn1` or `fn2`, chosen by the Python truthiness of `pred`.

  Has the same signature as `control_flow_ops.cond()`, but `pred` is evaluated
  with a plain Python truth test, so it should be a bool (or another ordinary
  Python value) rather than something only known when the graph runs.

  Args:
    pred: A value deciding whether the result of `fn1` or `fn2` is returned.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.

  Returns:
    Whatever the selected callable returns.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable (`fn1` is checked first).
  """
  if not callable(fn1):
    raise TypeError('fn1 must be callable.')
  if not callable(fn2):
    raise TypeError('fn2 must be callable.')
  selected = fn1 if pred else fn2
  return selected()
195
196
def smart_cond(pred, fn1, fn2, name=None):
  """Return either fn1() or fn2() based on the boolean predicate/value `pred`.

  If `constant_value(pred)` yields a value (e.g. `pred` is a Python bool, or a
  `Tensor` whose value is known statically), the branch is picked right away
  with `static_cond`. Otherwise `control_flow_ops.cond` is used.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix when using tf.cond.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.

  Raises:
    ValueError: if `pred` is None.
    TypeError: if `pred` has a constant value and `fn1` or `fn2` is not
      callable.
  """
  pred_value = constant_value(pred)
  if pred_value is not None:
    # Use static_cond if pred has a constant value.
    return static_cond(pred_value, fn1, fn2)
  # Use dynamic cond otherwise. Pass `name` by keyword: later TensorFlow
  # versions of `cond` place other parameters (e.g. `strict`) before `name`,
  # so a positional fourth argument would bind to the wrong parameter.
  return control_flow_ops.cond(pred, fn1, fn2, name=name)
218
219
def get_variable_collections(variables_collections, name):
  """Resolves which collections the variable called `name` should go into.

  Args:
    variables_collections: Either a dict keyed by variable name, or any other
      value (e.g. a collection, a list of collections, or None) that applies
      to every variable.
    name: Key looked up when `variables_collections` is a dict.

  Returns:
    `variables_collections.get(name)` (None when the key is missing) if
    `variables_collections` is a dict; otherwise `variables_collections`
    itself, unchanged.
  """
  if not isinstance(variables_collections, dict):
    return variables_collections
  return variables_collections.get(name)
226
227
def _get_dimension(shape, dim, min_rank=1):
  """Returns the `dim` dimension of `shape`, while checking it has `min_rank`.

  Args:
    shape: A `TensorShape`.
    dim: Integer, which dimension to return. Negative values count from the
      end, as with Python list indexing.
    min_rank: Integer, minimum rank of shape.

  Returns:
    The value of the `dim` dimension.

  Raises:
    ValueError: if the rank of `shape` is unknown, if it has fewer than
      `min_rank` dimensions, if `dim` is out of range for its rank, or if the
      `dim` dimension value is not defined.
  """
  dims = shape.dims
  if dims is None:
    raise ValueError('dims of shape must be known but is None')
  rank = len(dims)
  if rank < min_rank:
    raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
                                                                    rank))
  # `min_rank` alone does not guarantee `dim` is a valid index (e.g. dim=1 on
  # a rank-1 shape); report that as a ValueError instead of an IndexError.
  if not -rank <= dim < rank:
    raise ValueError('dimension %d is out of range for shape of rank %d: %s' %
                     (dim, rank, shape))
  value = dims[dim].value
  if value is None:
    raise ValueError(
        'dimension %d of shape must be known but is None: %s' % (dim, shape))
  return value
254
255
def channel_dimension(shape, data_format, min_rank=1):
  """Returns the channel dimension of shape, while checking it has min_rank.

  The channel axis is index 1 when `data_format` is 'channels_first'. Any
  other `data_format` value is treated as 'channels_last', i.e. the last axis.

  Args:
    shape: A `TensorShape`.
    data_format: `channels_first` or `channels_last`.
    min_rank: Integer, minimum rank of shape.

  Returns:
    The value of the channel dimension.

  Raises:
    ValueError: if the rank of `shape` is unknown or less than `min_rank`, or
      if the channel dimension value is not defined.
  """
  channel_axis = 1 if data_format == 'channels_first' else -1
  return _get_dimension(shape, channel_axis, min_rank=min_rank)
273
274
def last_dimension(shape, min_rank=1):
  """Returns the size of the final dimension of `shape`.

  Args:
    shape: A `TensorShape`.
    min_rank: Integer; `shape` must have at least this many dimensions.

  Returns:
    The value of the last dimension.

  Raises:
    ValueError: if the rank of `shape` is unknown or smaller than `min_rank`,
      or if the last dimension value is not defined.
  """
  last_axis = -1
  return _get_dimension(shape, last_axis, min_rank=min_rank)
290
291
def two_element_tuple(int_or_tuple):
  """Converts `int_or_tuple` to height, width.

  Several of the functions that follow accept arguments as either
  a tuple of 2 integers or a single integer.  A single integer
  indicates that the 2 values of the tuple are the same.

  This function normalizes the input value by always returning a tuple.

  Args:
    int_or_tuple: A list or tuple of 2 ints, a single int or a `TensorShape`
      of length 2.

  Returns:
    A tuple with 2 values. For a list, tuple or int these are Python ints; for
    a `TensorShape` they are its two entries as returned by indexing the
    shape, without conversion.

  Raises:
    ValueError: If `int_or_tuple` is not well formed.
  """
  if isinstance(int_or_tuple, (list, tuple)):
    if len(int_or_tuple) != 2:
      # Wrap in a 1-tuple: `%` with a bare tuple argument treats it as the
      # argument list and raises TypeError instead of this ValueError.
      raise ValueError('Must be a list with 2 elements: %s' % (int_or_tuple,))
    return int(int_or_tuple[0]), int(int_or_tuple[1])
  if isinstance(int_or_tuple, int):
    return int(int_or_tuple), int(int_or_tuple)
  if isinstance(int_or_tuple, tensor_shape.TensorShape):
    if len(int_or_tuple) == 2:
      return int_or_tuple[0], int_or_tuple[1]
  raise ValueError('Must be an int, a list with 2 elements or a TensorShape of '
                   'length 2')
321
322
def n_positive_integers(n, value):
  """Converts `value` to a sequence of `n` positive integers.

  `value` may be either be a sequence of values convertible to `int`, or a
  single value convertible to `int`, in which case the resulting integer is
  duplicated `n` times.  It may also be a TensorShape of rank `n`.

  Args:
    n: Length of sequence to return.
    value: Either a single value convertible to a positive `int` or an
      `n`-element sequence of values convertible to a positive `int`.

  Returns:
    A tuple of `n` positive integers.

  Raises:
    TypeError: If `n` has a type `int()` cannot convert, or if `value` is
      neither convertible to `int` nor supports `len()`.
    ValueError: If `n` or `value` are invalid.
  """

  n_orig = n
  n = int(n)
  # Reject non-positive n and non-integral n such as 2.5.
  if n < 1 or n != n_orig:
    raise ValueError('n must be a positive integer')

  try:
    value = int(value)
  except (TypeError, ValueError):
    # Not a single integer: treat `value` as a sequence of `n` integers.
    sequence_len = len(value)
    if sequence_len != n:
      raise ValueError(
          'Expected sequence of %d positive integers, but received %r' %
          (n, value))
    try:
      values = tuple(int(x) for x in value)
    except (TypeError, ValueError, OverflowError):
      # Only element conversion failures are reported as ValueError; a bare
      # `except:` would also swallow KeyboardInterrupt, SystemExit and
      # unrelated bugs.
      raise ValueError(
          'Expected sequence of %d positive integers, but received %r' %
          (n, value))
    for x in values:
      if x < 1:
        raise ValueError('expected positive integer, but received %d' % x)
    return values

  if value < 1:
    raise ValueError('expected positive integer, but received %d' % value)
  return (value,) * n
370