1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow-related utilities."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
import collections
import copy
import functools

import numpy as np
import six
24
25from tensorflow.python.data.experimental.ops import cardinality
26from tensorflow.python.eager import context
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import tensor_util
32from tensorflow.python.framework import type_spec
33from tensorflow.python.keras import backend as K
34from tensorflow.python.keras.engine import keras_tensor
35from tensorflow.python.keras.utils import object_identity
36from tensorflow.python.keras.utils import tf_contextlib
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.ops.ragged import ragged_tensor
40from tensorflow.python.ops.ragged import ragged_tensor_value
41from tensorflow.python.util import nest
42
43
def is_tensor_or_tensor_list(v):
  """Returns whether `v` is a `Tensor` or a structure led by a `Tensor`.

  Only the first element of the flattened structure is inspected; an empty
  structure yields `False`.

  Args:
    v: A `Tensor` or an arbitrary nested structure.

  Returns:
    `True` if the first flattened element of `v` is an `ops.Tensor`.
  """
  flat = nest.flatten(v)
  return bool(flat) and isinstance(flat[0], ops.Tensor)
50
51
def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Walks the graph breadth-first from `inputs`, following an Operation's
  outputs and control outputs, a Variable's op, and a Tensor's consumers.
  Stops early once all `targets` have been found (`targets` is optional);
  targets that are themselves among `inputs` count as already found.

  Only valid in Symbolic mode, not Eager mode.

  Args:
    inputs: List of tensors. Composite tensors are expanded into their
      component tensors.
    targets: List of tensors.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs themselves).
    When the walk stops early, the set contains every target but may omit
    other reachable nodes.

  Raises:
    TypeError: If a node that is not an Operation, Variable, Tensor, or a
      registered user-convertible type is encountered.
  """
  inputs = nest.flatten(inputs, expand_composites=True)
  reachable = object_identity.ObjectIdentitySet(inputs)
  if targets:
    remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
    # Inputs are trivially reachable. Without this, a target that is also an
    # input is never discarded below (inputs are already in `reachable`), so
    # the early exit could never fire.
    for x in inputs:
      remaining_targets.discard(x)
    if not remaining_targets:
      return reachable
  # The registry is not modified by this loop; build the tuple once instead
  # of on every iteration.
  user_types = tuple(_user_convertible_tensor_types)
  queue = collections.deque(inputs)

  while queue:
    x = queue.pop()
    if isinstance(x, user_types):
      # Can't find consumers of user-specific types.
      continue

    if isinstance(x, ops.Operation):
      outputs = x.outputs[:] or []
      outputs += x._control_outputs  # pylint: disable=protected-access
    elif isinstance(x, variables.Variable):
      try:
        outputs = [x.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        outputs = []
    elif tensor_util.is_tf_type(x):
      outputs = x.consumers()
    else:
      raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

    for y in outputs:
      if y not in reachable:
        reachable.add(y)
        if targets:
          remaining_targets.discard(y)
        # `pop()` takes from the right, so `appendleft` gives FIFO order.
        queue.appendleft(y)

    if targets and not remaining_targets:
      return reachable

  return reachable
103
104
105# This function needs access to private functions of `nest`.
106#  pylint: disable=protected-access
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Applies `map_fn` to each atomic element of `nested`, keeping structure.

  Recursion stops at any element for which `is_atomic_fn` returns true. Such
  an element may itself be a sequence (e.g. a shape tuple).

  Args:
    is_atomic_fn: Predicate deciding whether an element of `nested` is atomic.
    map_fn: The function to apply to atomic elements of `nested`.
    nested: A nested structure.

  Returns:
    A structure of the same kind as `nested`, whose atomic elements have been
    replaced by `map_fn` of themselves.

  Raises:
    ValueError: If an element that is neither atomic nor a sequence is
      encountered.
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  if not nest.is_nested(nested):
    raise ValueError(
        'Received non-atomic and non-sequence element: {}'.format(nested))

  # Gather children: mappings by sorted key, attrs classes by field order,
  # everything else in its own iteration order. Presumably this matches the
  # order `nest._sequence_like` uses when rebuilding.
  if nest.is_mapping(nested):
    children = [nested[key] for key in sorted(nested.keys())]
  elif nest.is_attrs(nested):
    children = _astuple(nested)
  else:
    children = nested

  rebuilt = []
  for child in children:
    rebuilt.append(map_structure_with_atomic(is_atomic_fn, map_fn, child))
  return nest._sequence_like(nested, rebuilt)  # pylint: disable=protected-access
140
141
def get_shapes(tensors):
  """Returns a structure mirroring `tensors` holding each leaf's `.shape`."""

  def _shape_of(leaf):
    return leaf.shape

  return nest.map_structure(_shape_of, tensors)
145
146
147#  pylint: enable=protected-access
148
149
def convert_shapes(input_shape, to_tuples=True):
  """Converts nested shape representations to desired format.

  Performs:

  TensorShapes -> tuples if `to_tuples=True`.
  tuples of int or None -> TensorShapes if `to_tuples=False`.

  Valid objects to be converted are:
  - TensorShapes
  - tuples (or lists) with elements of type int, NumPy integer, `Dimension`,
    or None.
  - ints (including NumPy integer scalars)
  - None

  Args:
    input_shape: A nested structure of objects to be converted to TensorShapes.
    to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
      all tuples representing shapes to TensorShapes.

  Returns:
    Nested structure of shapes in desired format.

  Raises:
    ValueError: when the input tensor shape can't be converted to tuples, eg
      unknown tensor shape.
  """

  def _is_shape_component(value):
    # NumPy integer scalars (e.g. `np.int64` produced by array arithmetic) are
    # not Python `int`s. Without `np.integer` here they were rejected as
    # non-atomic, non-sequence elements and raised `ValueError`.
    return value is None or isinstance(
        value, (int, np.integer, tensor_shape.Dimension))

  def _is_atomic_shape(value):
    # Ex: TensorShape or (None, 10, 32) or 5 or `None`
    if _is_shape_component(value):
      return True
    if isinstance(value, tensor_shape.TensorShape):
      return True
    return (isinstance(value, (tuple, list)) and
            all(_is_shape_component(ele) for ele in value))

  def _convert_shape(value):
    shape = tensor_shape.TensorShape(value)
    if to_tuples:
      # `as_list()` raises ValueError for shapes of unknown rank.
      return tuple(shape.as_list())
    return shape

  return map_structure_with_atomic(_is_atomic_shape, _convert_shape,
                                   input_shape)
199
200
class ListWrapper(object):
  """A wrapper for lists to be treated as elements for `nest`.

  `convert_inner_node_data` treats `ListWrapper` instances as atomic, so a
  wrapped list is handled as a single leaf rather than recursed into.
  """

  def __init__(self, list_to_wrap):
    # The list is stored by reference, not copied.
    self._list = list_to_wrap

  def as_list(self):
    """Returns the wrapped list (the same object that was wrapped)."""
    return self._list

  def __repr__(self):
    # Makes wrapped node data readable in error messages and debugging output.
    return '{}({!r})'.format(type(self).__name__, self._list)
209
210
def convert_inner_node_data(nested, wrap=False):
  """Wraps or unwraps the innermost node-data lists of a nested structure.

  Serialized node data is a list of the form `[layer_name, node_id,
  tensor_id]` or `[layer_name, node_id, tensor_id, kwargs]`. Wrapping turns
  each such list into a `ListWrapper`; unwrapping turns every `ListWrapper`
  back into its list. Everything else is left untouched.

  Args:
    nested: A nested data structure.
    wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
      unwraps `ListWrapper` objects into lists.

  Returns:
    Structure of same type as nested, with lists wrapped/unwrapped.
  """

  def _is_serialized_node_data(candidate):
    # A 3- or 4-element list whose first entry is a layer name.
    return (isinstance(candidate, list) and len(candidate) in (3, 4) and
            isinstance(candidate[0], six.string_types))

  def _is_atomic_nested(candidate):
    """Returns `True` if `candidate` should not be recursed into."""
    return (isinstance(candidate, ListWrapper) or
            _is_serialized_node_data(candidate) or
            not nest.is_nested(candidate))

  def _wrap_one(candidate):
    # Existing `ListWrapper`s are not lists, so they fall through unchanged.
    if _is_serialized_node_data(candidate):
      return ListWrapper(candidate)
    return candidate

  def _unwrap_one(candidate):
    if isinstance(candidate, ListWrapper):
      return candidate.as_list()
    return candidate

  converter = _wrap_one if wrap else _unwrap_one
  return map_structure_with_atomic(_is_atomic_nested, converter, nested)
254
255
def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`. A non-`None` `input_shape` is
  handed to `fn` as tuples (via `convert_shapes(..., to_tuples=True)`), and a
  non-`None` result from `fn` is converted to `TensorShape`s. `None` is passed
  through untouched in both directions.

  Args:
    fn: function to wrap, called as `fn(instance, input_shape)`.

  Returns:
    Wrapped function. It carries `fn`'s `__name__`, `__doc__` and other
    metadata via `functools.wraps`.
  """

  @functools.wraps(fn)
  def wrapper(instance, input_shape):
    # Pass shapes as tuples to `fn`
    # This preserves compatibility with external Keras.
    if input_shape is not None:
      input_shape = convert_shapes(input_shape, to_tuples=True)
    output_shape = fn(instance, input_shape)
    # Return shapes from `fn` as TensorShapes.
    if output_shape is not None:
      output_shape = convert_shapes(output_shape, to_tuples=False)
    return output_shape

  return wrapper
280
281
def are_all_symbolic_tensors(tensors):
  """Returns `True` if every element of `tensors` is a symbolic tensor.

  Short-circuits on the first non-symbolic element; an empty iterable
  yields `True`.
  """
  return all(is_symbolic_tensor(t) for t in tensors)
284
285
# Registry of user classes that Keras should regard as symbolic tensors.
# Filled by `register_symbolic_tensor_type`. It is read by `is_symbolic_tensor`
# (instances are converted to tensors and re-checked) and by
# `get_reachable_from_inputs` (instances are treated as dead ends).
_user_convertible_tensor_types = set()
287
288
def is_extension_type(tensor):
  """Returns whether a tensor is of an ExtensionType.

  See github.com/tensorflow/community/pull/269. Currently this works by
  checking whether `tensor` is a `CompositeTensor` instance. It will be
  changed to an appropriate ExtensionType protocol check once ExtensionType
  is made public.

  Args:
    tensor: An object to test.

  Returns:
    True if the tensor is an extension type object, False if not.
  """
  extension_base = composite_tensor.CompositeTensor
  return isinstance(tensor, extension_base)
304
305
def is_symbolic_tensor(tensor):
  """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.

  A Variable can be seen as either: it is considered symbolic
  when we are in a graph scope, and eager when we are in an eager scope.
  Instances of types registered with `register_symbolic_tensor_type` are
  converted to a tensor (or composite) and then checked.

  Args:
    tensor: A tensor instance to test.

  Returns:
    True for symbolic tensors, False for eager tensors (always a `bool`).
  """
  if isinstance(tensor, ops.Tensor):
    # The presence of a `graph` attribute is what this module uses to tell
    # graph tensors apart from eager ones.
    return hasattr(tensor, 'graph')
  if is_extension_type(tensor):
    component_tensors = nest.flatten(tensor, expand_composites=True)
    return any(hasattr(t, 'graph') for t in component_tensors)
  if isinstance(tensor, variables.Variable):
    # Variables that are output of a Keras Layer in Functional API mode
    # should be considered symbolic.
    # TODO(omalleyt): We need a better way to check this in order to
    # enable `run_eagerly=True` for Models containing Layers that
    # return Variables as outputs.
    # `bool` keeps the documented return type; previously this branch could
    # return the `_keras_history` object itself.
    return bool(getattr(tensor, '_keras_history', False) or
                not context.executing_eagerly())
  if isinstance(tensor, tuple(_user_convertible_tensor_types)):
    tensor = ops.convert_to_tensor_or_composite(tensor)
    return is_symbolic_tensor(tensor)
  return False
336
337
def register_symbolic_tensor_type(cls):
  """Allows users to specify types regarded as symbolic `Tensor`s.

  Used in conjunction with `tf.register_tensor_conversion_function`, calling
  `tf.keras.utils.register_symbolic_tensor_type(cls)` allows non-`Tensor`
  objects to be plumbed through Keras layers. Registering the same class
  more than once has no further effect.

  Example:

  ```python
  # One-time setup.
  class Foo(object):
    def __init__(self, input_):
      self._input = input_
    def value(self):
      return tf.constant(42.)

  tf.register_tensor_conversion_function(
      Foo, lambda x, *args, **kwargs: x.value())

  tf.keras.utils.register_symbolic_tensor_type(Foo)

  # User-land.
  layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
  ```

  Args:
    cls: A `class` type which shall be regarded as a symbolic `Tensor`.
  """
  # The registry set is only mutated here, never rebound, so no `global`
  # declaration is needed.
  if cls in _user_convertible_tensor_types:
    return
  keras_tensor.register_keras_tensor_specialization(
      cls, keras_tensor.UserRegisteredTypeKerasTensor)
  _user_convertible_tensor_types.add(cls)
372
373
def type_spec_from_value(value):
  """Returns a `TypeSpec` for `value` without converting array-likes to tensors.

  Extension types report their own `_type_spec`. Objects exposing both
  `shape` and `dtype` get a `TensorSpec` built directly from those
  attributes. Anything else is delegated to `type_spec.type_spec_from_value`.
  """
  if is_extension_type(value):
    return value._type_spec  # pylint: disable=protected-access
  looks_array_like = hasattr(value, 'shape') and hasattr(value, 'dtype')
  if not looks_array_like:
    return type_spec.type_spec_from_value(value)
  # Build the spec from metadata only; the data itself is never converted.
  return tensor_spec.TensorSpec(value.shape, value.dtype)
384
385
def is_ragged(tensor):
  """Returns true if `tensor` is a ragged tensor or ragged tensor value."""
  ragged_types = (ragged_tensor.RaggedTensor,
                  ragged_tensor_value.RaggedTensorValue)
  return isinstance(tensor, ragged_types)
391
392
def is_tensor_or_variable(x):
  """Returns whether `x` is a TF tensor-like value or a `Variable`."""
  if tensor_util.is_tf_type(x):
    return True
  return isinstance(x, variables.Variable)
395
396
def assert_no_legacy_layers(layers):
  """Prevent tf.layers.Layers from being used with Keras.

  Certain legacy layers inherit from their keras analogs; however they are
  not supported with keras and can lead to subtle and hard to diagnose bugs.
  A layer counts as legacy when its `_is_legacy_layer` attribute is truthy.

  Args:
    layers: A list of layers to check

  Raises:
    TypeError: If any elements of layers are tf.layers.Layers. The message
      lists every offending layer, one per line.
  """
  # A direct isinstance check against tf.layers.Layer would introduce a
  # circular dependency, so the `_is_legacy_layer` marker is used instead.
  offenders = []
  for candidate in layers:
    if getattr(candidate, '_is_legacy_layer', None):
      offenders.append(candidate)
  if not offenders:
    return

  listing = '\n'.join('  ' + str(layer) for layer in offenders)
  raise TypeError(
      'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
      'framework (for instance using the Network, Model, or Sequential '
      'classes), please use the tf.keras.layers implementation instead. '
      '(Or, if writing custom layers, subclass from tf.keras.layers rather '
      'than tf.layers)'.format(listing))
420
421
@tf_contextlib.contextmanager
def maybe_init_scope(layer):
  """Open an `init_scope` if in V2 mode and using the keras graph.

  No scope is opened in V1 mode (not executing eagerly outside functions),
  or when `layer` has a falsy `_keras_style` attribute. The latter is the
  case for legacy tf.layers.

  Args:
    layer: The Layer/Model that is currently active.

  Yields:
    None
  """
  wants_init_scope = (ops.executing_eagerly_outside_functions() and
                      getattr(layer, '_keras_style', True))
  if not wants_init_scope:
    yield
    return
  with ops.init_scope():
    yield
439
440
@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
  """Enters the Keras graph if any positional or keyword value is symbolic.

  Only the top-level argument values are checked with `is_symbolic_tensor`;
  nested structures are not flattened. Otherwise no context is entered.

  Yields:
    None
  """
  candidates = args + tuple(kwargs.values())
  if not any(is_symbolic_tensor(v) for v in candidates):
    yield
    return
  with K.get_graph().as_default():
    yield
449
450
def dataset_is_infinite(dataset):
  """True if the passed dataset is infinite.

  When executing eagerly outside functions, this returns the tensor produced
  by `math_ops.equal`. Otherwise (V1 graph mode) the cardinality is evaluated
  in the Keras session, and the comparison of that value is returned.
  """
  eager_outside_functions = ops.executing_eagerly_outside_functions()
  dataset_cardinality = cardinality.cardinality(dataset)
  if not eager_outside_functions:
    evaluated = K.get_session().run(dataset_cardinality)
    return evaluated == cardinality.INFINITE
  return math_ops.equal(dataset_cardinality, cardinality.INFINITE)
459
460
def get_tensor_spec(t, dynamic_batch=False, name=None):
  """Returns a `TypeSpec` given a single `Tensor`-like value or `TypeSpec`.

  Args:
    t: A `TypeSpec`, an extension-type value (`CompositeTensor`), an object
      with a `_keras_history` whose first element has a `_type_spec`, or any
      object exposing `shape` and `dtype`.
    dynamic_batch: If `True`, return a deep copy of the spec. When the spec's
      rank is known and positive, the copy's first dimension is set to `None`.
      This assumes the spec has a private `_shape` attribute.
    name: Name for the `TensorSpec` built from a `shape`/`dtype` object. It is
      not applied to specs obtained any other way.

  Returns:
    The spec, or `None` when `t` matches none of the cases above.
  """
  # pylint: disable=protected-access
  if isinstance(t, type_spec.TypeSpec):
    spec = t
  elif is_extension_type(t):
    # TODO(b/148821952): Should these specs have a name attr?
    spec = t._type_spec
  elif (hasattr(t, '_keras_history') and
        hasattr(t._keras_history[0], '_type_spec')):
    # NOTE(review): this spec is returned as-is; neither `dynamic_batch` nor
    # `name` is applied on this path. Confirm that is intended.
    return t._keras_history[0]._type_spec
  elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
    spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
  else:
    # Allow non-Tensors to pass through.
    return None

  if not dynamic_batch:
    return spec

  # Work on a copy so the caller's spec object is never mutated.
  relaxed_spec = copy.deepcopy(spec)
  # RaggedTensorSpec only has a private _shape.
  old_shape = relaxed_spec._shape
  if old_shape.rank:  # Known and non-zero rank.
    relaxed_spec._shape = tensor_shape.TensorShape(
        [None] + old_shape.as_list()[1:])
  return relaxed_spec
  # pylint: enable=protected-access
489
490
def to_numpy_or_python_type(tensors):
  """Converts a structure of `Tensor`s to `NumPy` arrays or Python scalar types.

  Every `ops.Tensor` leaf is materialized with `.numpy()`. Zero-dimensional
  results are turned into Python scalars with `.item()`, because Python types
  are easier to work with than NumPy scalars. This is especially true for
  bfloat16 NumPy scalars, which support fewer operations. Leaves that are not
  `ops.Tensor` (e.g. ragged or sparse tensors) are returned unchanged.

  Args:
    tensors: A structure of tensors.

  Returns:
    `tensors`, but scalar tensors are converted to Python types and non-scalar
    tensors are converted to Numpy arrays.
  """

  def _convert_leaf(leaf):
    if not isinstance(leaf, ops.Tensor):
      # Don't turn ragged or sparse tensors to NumPy.
      return leaf
    array = leaf.numpy()
    if np.ndim(array) == 0:
      return array.item()
    return array

  return nest.map_structure(_convert_leaf, tensors)
516
517
def _astuple(attrs):
  """Returns the field values of an attrs instance as a flat tuple.

  Values appear in `__attrs_attrs__` order and are not converted recursively.

  Args:
    attrs: An instance of an attrs-decorated class.

  Returns:
    A tuple of the instance's field values.

  Raises:
    ValueError: If the class of `attrs` has no `__attrs_attrs__`.
  """
  attrs_cls = type(attrs)
  attr_fields = getattr(attrs_cls, '__attrs_attrs__', None)
  if attr_fields is None:
    raise ValueError('%r is not an attrs-decorated class.' % attrs_cls)
  return tuple(getattr(attrs, f.name) for f in attr_fields)
528