1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow-related utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools

import six

from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest
31
32
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Selects between `true_fn()` and `false_fn()` based on `pred`.

  When `pred` is a Python bool or has a statically known value, only the
  matching branch is called (through `smart_module.smart_cond`); otherwise
  `tf.cond` is used to route dynamically to both. A `Variable` predicate is
  always handed straight to `control_flow_ops.cond`.

  Arguments:
    pred: A scalar deciding whether the result comes from `true_fn` or
      `false_fn`.
    true_fn: Callable invoked when `pred` is true.
    false_fn: Callable invoked when `pred` is false.
    name: Optional name prefix used by `tf.cond`.

  Returns:
    The tensors produced by whichever of `true_fn` / `false_fn` was selected.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if isinstance(pred, variables.Variable):
    dispatch = control_flow_ops.cond
  else:
    dispatch = smart_module.smart_cond
  return dispatch(pred, true_fn=true_fn, false_fn=false_fn, name=name)
57
58
def constant_value(pred):
  """Returns the static bool value of `pred`, or None if it is dynamic.

  Arguments:
    pred: A scalar: a Python bool, the Python integer 1 or 0, or a TensorFlow
      boolean tensor or variable.

  Returns:
    True or False when `pred` has a constant boolean value, otherwise None.
    A `Variable` always yields None.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0.
  """
  # The integers 1 and 0 are accepted as spellings of True and False; any
  # other integer is passed through unchanged.
  if isinstance(pred, int):
    pred = True if pred == 1 else (False if pred == 0 else pred)

  # A Variable is never reported as constant, regardless of its value.
  return (None if isinstance(pred, variables.Variable) else
          smart_module.smart_constant_value(pred))
83
84
def is_tensor_or_tensor_list(v):
  """Returns whether (possibly nested) `v` starts with an `ops.Tensor`.

  Only the first element of `nest.flatten(v)` is inspected; an empty
  structure yields False.
  """
  flat_values = nest.flatten(v)
  return bool(flat_values) and isinstance(flat_values[0], ops.Tensor)
91
92
def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Walks the graph breadth-first: an `Operation` leads to its outputs and
  control outputs, a `Variable` to its `op`, and a tensor to its consumers.
  Instances of types registered via `register_symbolic_tensor_type` are kept
  in the result but not expanded further. Stops early once all `targets`
  (optional) have been found.

  Only valid in Symbolic mode, not Eager mode.

  Args:
    inputs: List (or nested structure) of tensors; flattened with
      `nest.flatten`.
    targets: Optional list of tensors. When given, the walk stops as soon as
      every one of them has been reached.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs themselves).

  Raises:
    TypeError: If an element that is not an `Operation`, `Variable` or tensor
      is encountered.
  """
  inputs = nest.flatten(inputs)
  reachable = set(inputs)
  if targets:
    targets = set(targets)
  # The registry of user-convertible types does not change during the walk,
  # so build the `isinstance` tuple once rather than on every iteration.
  user_types = tuple(_user_convertible_tensor_types)
  # FIFO queue: take from the right, enqueue on the left. `deque` makes both
  # O(1), whereas `list.insert(0, ...)` is O(n) per enqueue. The visiting
  # order is the same as with the list-based queue.
  queue = collections.deque(inputs)

  while queue:
    x = queue.pop()
    if isinstance(x, user_types):
      # Can't find consumers of user-specific types.
      continue

    if isinstance(x, ops.Operation):
      outputs = list(x.outputs)
      outputs += x._control_outputs  # pylint: disable=protected-access
    elif isinstance(x, variables.Variable):
      try:
        outputs = [x.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        outputs = []
    elif tensor_util.is_tensor(x):
      outputs = x.consumers()
    else:
      raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

    for y in outputs:
      if y not in reachable:
        reachable.add(y)
        queue.appendleft(y)

    if targets and targets.issubset(reachable):
      break
  return reachable
141
142
143# This function needs access to private functions of `nest`.
144#  pylint: disable=protected-access
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Applies `map_fn` to every atomic element of a nested structure.

  An element is atomic when `is_atomic_fn` returns a truthy value for it; it
  is then replaced by `map_fn(element)` and not descended into. Any other
  element must be a sequence according to `nest.is_sequence`, and is rebuilt
  with the same type via `nest._sequence_like`.

  Arguments:
    is_atomic_fn: Predicate telling whether an element of `nested` is atomic.
    map_fn: Function applied to each atomic element of `nested`.
    nested: A nested structure.

  Returns:
    A structure shaped like `nested`, with atomic elements replaced by the
    results of `map_fn`.

  Raises:
    ValueError: If an element that is neither atomic nor a sequence is
      encountered.
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  if not nest.is_sequence(nested):
    raise ValueError(
        'Received non-atomic and non-sequence element: {}'.format(nested))

  # Mapping values are visited in `nest._sorted` key order.
  if nest._is_mapping(nested):
    children = [nested[key] for key in nest._sorted(nested)]
  else:
    children = nested

  def _map_child(child):
    return map_structure_with_atomic(is_atomic_fn, map_fn, child)

  mapped_children = [_map_child(child) for child in children]
  return nest._sequence_like(nested, mapped_children)
176
177
178#  pylint: enable=protected-access
179
180
def convert_shapes(input_shape, to_tuples=True):
  """Converts nested shape representations to the requested format.

  With `to_tuples=True`, TensorShapes become tuples; with `to_tuples=False`,
  tuples of int or None become TensorShapes.

  Each of the following is treated as a single shape and converted:
  - TensorShapes
  - tuples (or lists) whose elements are int, `Dimension` or None
  - ints
  - None

  Anything else is recursed into as a nested structure.

  Arguments:
    input_shape: A nested structure of objects to be converted to TensorShapes.
    to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
      all tuples representing shapes to TensorShapes.

  Returns:
    Nested structure of shapes in desired format.
  """

  def _is_dim(value):
    return value is None or isinstance(value, (int, tensor_shape.Dimension))

  def _is_atomic_shape(candidate):
    # Leaves look like: TensorShape, (None, 10, 32), 5 or `None`.
    return (_is_dim(candidate) or
            isinstance(candidate, tensor_shape.TensorShape) or
            (isinstance(candidate, (tuple, list)) and
             all(_is_dim(dim) for dim in candidate)))

  def _convert_shape(shape):
    converted = tensor_shape.TensorShape(shape)
    return tuple(converted.as_list()) if to_tuples else converted

  return map_structure_with_atomic(_is_atomic_shape, _convert_shape,
                                   input_shape)
226
227
class ListWrapper(object):
  """Holds a list so that `nest` treats it as a single element.

  The wrapped list is stored by reference; no copy is made.
  """

  def __init__(self, list_to_wrap):
    """Wraps `list_to_wrap` without copying it."""
    self._list = list_to_wrap

  def as_list(self):
    """Returns the wrapped list object itself."""
    return self._list
236
237
def convert_inner_node_data(nested, wrap=False):
  """Either wraps or unwraps innermost node data lists in `ListWrapper` objects.

  Arguments:
    nested: A nested data structure.
    wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
      unwraps `ListWrapper` objects into lists.

  Returns:
    Structure of same type as nested, with lists wrapped/unwrapped.
  """

  def _is_node_data(item):
    """Returns `True` if `item` is a `ListWrapper` or a node-data list."""
    # Node data can be of form `[layer_name, node_id, tensor_id]` or
    # `[layer_name, node_id, tensor_id, kwargs]`.
    return isinstance(item, ListWrapper) or (
        isinstance(item, list) and len(item) in (3, 4) and
        isinstance(item[0], six.string_types))

  def _wrap(item):
    # Already-wrapped items are returned as they are.
    return item if isinstance(item, ListWrapper) else ListWrapper(item)

  def _unwrap(item):
    # Plain lists are returned as they are.
    return item.as_list() if isinstance(item, ListWrapper) else item

  converter = _wrap if wrap else _unwrap
  return map_structure_with_atomic(_is_node_data, converter, nested)
274
275
def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`. A non-None `input_shape` is
  converted to (nested) tuples before `fn` is called, and a non-None value
  returned by `fn` is converted back to (nested) `TensorShape`s.

  Arguments:
    fn: function to wrap, called as `fn(instance, input_shape)`.

  Returns:
    Wrapped function. `functools.wraps` is applied so the wrapper keeps the
    name, docstring and module of `fn`.
  """

  @functools.wraps(fn)
  def wrapper(instance, input_shape):
    # Pass shapes as tuples to `fn`
    # This preserves compatibility with external Keras.
    if input_shape is not None:
      input_shape = convert_shapes(input_shape, to_tuples=True)
    output_shape = fn(instance, input_shape)
    # Return shapes from `fn` as TensorShapes.
    if output_shape is not None:
      output_shape = convert_shapes(output_shape, to_tuples=False)
    return output_shape

  return wrapper
300
301
def are_all_symbolic_tensors(tensors):
  """Returns True iff `is_symbolic_tensor` holds for every element.

  Stops at the first non-symbolic element; an empty iterable yields True.
  """
  for tensor in tensors:
    if not is_symbolic_tensor(tensor):
      return False
  return True
304
305
# Registry of classes added through `register_symbolic_tensor_type`.
# `is_symbolic_tensor` converts instances of these types with
# `ops.convert_to_tensor` to decide whether they are symbolic, and
# `get_reachable_from_inputs` does not look for their consumers.
_user_convertible_tensor_types = set()
307
308
def is_symbolic_tensor(tensor):
  """Tells whether `tensor` is symbolic (from a TF graph) rather than eager.

  A Variable may be either: it counts as symbolic when it carries a
  `_keras_history` attribute or when we are not executing eagerly.

  Arguments:
    tensor: A tensor instance to test.

  Returns:
    A truthy value for symbolic tensors, a falsy one for eager tensors and
    for objects that are not tensors at all.
  """
  if isinstance(tensor, variables.Variable):
    # Variables that are output of a Keras Layer in Functional API mode
    # should be considered symbolic.
    # TODO(omalleyt): We need a better way to check this in order to
    # enable `run_eagerly=True` for Models containing Layers that
    # return Variables as outputs.
    keras_history = getattr(tensor, '_keras_history', False)
    result = keras_history or not context.executing_eagerly()
  elif isinstance(tensor, composite_tensor.CompositeTensor):
    result = tensor._is_graph_tensor  # pylint: disable=protected-access
  elif isinstance(tensor, ops.Tensor):
    result = hasattr(tensor, 'graph')
  elif isinstance(tensor, tuple(_user_convertible_tensor_types)):
    # User-registered types are judged by the tensor they convert to.
    result = hasattr(ops.convert_to_tensor(tensor), 'graph')
  else:
    result = False
  return result
336
337
def register_symbolic_tensor_type(cls):
  """Allows users to specify types regarded as symbolic `Tensor`s.

  Used in conjunction with `tf.register_tensor_conversion_function`, calling
  `tf.keras.utils.register_symbolic_tensor_type(cls)` allows non-`Tensor`
  objects to be plumbed through Keras layers.

  Example:

  ```python
  # One-time setup.
  class Foo(object):
    def __init__(self, input_):
      self._input = input_
    def value(self):
      return tf.constant(42.)

  tf.register_tensor_conversion_function(
      Foo, lambda x, *args, **kwargs: x.value())

  tf.keras.utils.register_symbolic_tensor_type(Foo)

  # User-land.
  layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
  ```

  Because `cls` is returned, this function can also be applied as a class
  decorator. Registering the same class more than once has no further effect.

  Arguments:
    cls: A `class` type which shall be regarded as a symbolic `Tensor`.

  Returns:
    `cls` itself, unchanged.
  """
  # The registry set is only mutated in place (never rebound), so no `global`
  # declaration is needed.
  _user_convertible_tensor_types.add(cls)
  return cls
369
370
def is_tensor_or_variable(x):
  """Checks whether `x` is a tensor or a `variables.Variable`.

  Returns the result of `tensor_util.is_tensor(x)` when it is truthy,
  otherwise whether `x` is a `Variable`.
  """
  result = tensor_util.is_tensor(x)
  if not result:
    result = isinstance(x, variables.Variable)
  return result
373