1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains private utilities used mainly by the base Layer class."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections as collections_lib
21import threading
22import enum
23
24from tensorflow.python.eager import context
25from tensorflow.python.framework import auto_control_deps
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.keras import backend
29from tensorflow.python.keras.utils import tf_utils
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import control_flow_util
32from tensorflow.python.ops import init_ops
33from tensorflow.python.ops import init_ops_v2
34from tensorflow.python.ops import variables as tf_variables
35from tensorflow.python.util import nest
36from tensorflow.python.util import tf_contextlib
37
# Thread-local state. Its `in_call` attribute is set by `call_context` and read
# by `is_in_call_context` to tell whether the current thread is inside a
# Layer/Model `__call__`.
_call_context = threading.local()
39
40
class CallConvention(enum.Enum):
  """Enumerates the ways a `Layer`'s inputs may be handed to `Layer.call`."""
  # `call` receives its inputs through a first argument literally named
  # "inputs", mirroring the signature of `Layer.__call__`. Layers that are not
  # subclassed Models are assumed to follow this convention.
  EXPLICIT_INPUTS_ARGUMENT = 1
  # `call` has exactly one positional argument whose name is not "inputs"; that
  # argument is handled as though it were the "inputs" argument.
  SINGLE_POSITIONAL_ARGUMENT = 2
  # `call` declares several positional arguments, and the Layer's inputs are
  # bound to them.
  POSITIONAL_ARGUMENTS_ARE_INPUTS = 3
53
54
def create_mean_metric(value, name=None):
  """Builds a `Mean` metric and immediately applies it to `value`.

  Arguments:
    value: Value passed to the newly created metric when it is called.
    name: Optional name given to the metric.

  Returns:
    A `(metric, result)` tuple: the `Mean` metric object and whatever calling
    it on `value` returned.
  """
  # TODO(psv): Remove this import when b/110718070 is fixed.
  from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  mean_metric = metrics_module.Mean(name=name)
  return mean_metric, mean_metric(value)
61
62
def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Temporary helper that creates a variable via `tf_variables.VariableV1`.

  Reuse-related technicalities keep us from calling
  `variable_scope.get_variable()` directly, so a less constrained constructor
  is used instead.

  Longer term, a comparable "default variable creator" should probably live in
  `Trackable`; once it does, this stop-gap can go away.

  TODO(fchollet): remove this method when no longer needed.

  Arguments:
    name: Variable name.
    shape: Variable shape. Only used when `initializer` is callable, in which
      case it is passed to the initializer.
    dtype: The type of the variable. Defaults to `float32`. Ignored when
      `initializer` is a concrete (non-callable) value.
    initializer: One of: an initializer instance (callable), an initializer
      class (instantiated with no arguments), or a non-callable initial value
      used as-is.
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. kernels, biases) or
      "non_trainable_variables" (e.g. BatchNorm mean, stddev). Forwarded
      unchanged to the variable constructor, which picks the default when
      this is `None`.
    caching_device: Passed to the variable constructor.
    validate_shape: Passed to the variable constructor.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`. `None` is treated as
      `True`.
    collections: List of graph collections keys the new variable is added to.
      Forwarded unchanged; the variable constructor picks the default when
      this is `None`.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. Defaults to `AUTO`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`. Defaults to `NONE`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  # Anything that is neither `None` nor callable is taken to be a concrete
  # initial value rather than an initializer.
  initializing_from_value = (
      initializer is not None and not callable(initializer))

  with ops.init_scope():
    if not initializing_from_value:
      initializer_class_types = (
          type(init_ops.Initializer), type(init_ops_v2.Initializer))
      if isinstance(initializer, initializer_class_types):
        # An initializer class (not an instance) was supplied: instantiate it.
        initializer = initializer()
      # Deferred so the initial value is only computed when it is requested.
      init_val = lambda: initializer(shape, dtype=dtype)
      variable_dtype = dtype.base_dtype
    else:
      init_val = initializer
      # No explicit dtype is passed alongside a concrete initial value.
      variable_dtype = None

  resource_flag = True if use_resource is None else use_resource

  # TODO(apassos,rohanj) figure out how to remove collections from here so we
  # can remove the V1.
  return tf_variables.VariableV1(
      initial_value=init_val,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=resource_flag,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation)
154
155
def get_default_graph_uid_map():
  """Returns the layer-name UID counter map of the current default graph.

  The map is looked up in `backend.PER_GRAPH_LAYER_NAME_UIDS`. When the default
  graph has no entry yet, a new `defaultdict(int)` is stored for it and
  returned.
  """
  # TODO(fchollet): refactor this into backend.
  default_graph = ops.get_default_graph()
  per_graph_maps = backend.PER_GRAPH_LAYER_NAME_UIDS
  uid_map = per_graph_maps.get(default_graph)
  if uid_map is not None:
    return uid_map
  uid_map = collections_lib.defaultdict(int)
  per_graph_maps[default_graph] = uid_map
  return uid_map
164
165
def unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='',
                      zero_based=False):
  """Derives a name from `name` that has not been handed out before.

  Uniqueness is tracked with a counter stored in `name_uid_map` under the key
  `(namespace, name)`; every candidate that is generated bumps that counter,
  including candidates rejected because they appear in `avoid_names`.

  Arguments:
    name: String name to make unique.
    name_uid_map: An optional defaultdict(int) to use when creating unique
      names. If None (default), uses a per-Graph dictionary.
    avoid_names: An optional set or dict with names which should not be used. If
      None (default) does not avoid any names.
    namespace: Gets a name which is unique within the (graph, namespace). Layers
      which are not Networks use a blank namespace and so get graph-global
      names.
    zero_based: If True, name sequences start with no suffix (e.g. "dense",
      "dense_1"). If False, naming is one-based ("dense_1", "dense_2").

  Returns:
    Unique string name.

  Example:

  ```python
  unique_layer_name('dense')  # dense_1
  unique_layer_name('dense')  # dense_2
  ```
  """
  uid_map = (
      get_default_graph_uid_map() if name_uid_map is None else name_uid_map)
  taken = set() if avoid_names is None else avoid_names
  counter_key = (namespace, name)

  candidate = None
  while candidate is None or candidate in taken:
    if not zero_based:
      # One-based: advance the counter first, then use its new value.
      uid_map[counter_key] += 1
      candidate = name + '_' + str(uid_map[counter_key])
      continue
    # Zero-based: the very first name is issued without a suffix.
    count = uid_map[counter_key]
    candidate = (name + '_' + str(count)) if count else name
    uid_map[counter_key] += 1
  return candidate
210
211
def collect_previous_mask(input_tensors):
  """Gathers the `_keras_mask` attached to each input tensor.

  Arguments:
      input_tensors: An arbitrary structure of Tensors.

  Returns:
      The result of mapping over `input_tensors` with `nest.map_structure`:
      for every tensor, its `_keras_mask` attribute, or `None` where the
      attribute is absent.
  """
  return nest.map_structure(
      lambda tensor: getattr(tensor, '_keras_mask', None), input_tensors)
226
227
def have_all_keras_metadata(tensors):
  """Returns whether every tensor carries a `_keras_history` attribute.

  Arguments:
    tensors: An arbitrary structure of Tensors; it is flattened before the
      check.

  Returns:
    `False` as soon as one flattened element lacks the attribute, `True`
    otherwise (including for an empty structure).
  """
  for tensor in nest.flatten(tensors):
    if not hasattr(tensor, '_keras_history'):
      return False
  return True
230
231
def generate_placeholders_from_shape(shape):
  """Creates a placeholder of `shape` whose dtype is `backend.floatx()`."""
  placeholder_shape = shape
  float_dtype = backend.floatx()
  return array_ops.placeholder(shape=placeholder_shape, dtype=float_dtype)
234
235
def create_keras_history(tensors):
  """Attaches Keras metadata to Tensors produced by raw TensorFlow ops.

  For every Tensor in `tensors` that lacks Keras metadata but descends from a
  Keras `Input` Layer, the raw TensorFlow Operations that produced it are
  replaced with `TensorFlowOpLayer` instances that create identical
  operations.

  Tensors that do not descend from a Keras `Input` Layer are treated as
  constants when the `TensorFlowOpLayer` instances are built.

  Arguments:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.
  """
  # Start from an empty record of wrapped ops; the helper fills it in.
  already_wrapped_ops = set()
  _create_keras_history_helper(tensors, already_wrapped_ops)
252
253
def _create_keras_history_helper(tensors, processed_ops=None):
  """Helper method for `create_keras_history`.

  Walks each Tensor that has no `_keras_history`, recursing through the inputs
  of the op that produced it, and wraps every op that has not been processed
  yet in a `TensorFlowOpLayer`.

  Arguments:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped
      in `TensorFlowOpLayer` instances. It is updated in place. If `None`
      (default), a new empty set is used.

  Returns:
    The updated set of TensorFlow Operations that have been wrapped
    in `TensorFlowOpLayer` instances.
  """
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  if processed_ops is None:
    # The signature advertises `None` as the default, so make that default
    # usable instead of failing on the membership test below.
    processed_ops = set()

  for tensor in nest.flatten(tensors):
    if getattr(tensor, '_keras_history', None) is not None:
      # Metadata is already attached; nothing to do for this Tensor.
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op in processed_ops:
      continue

    # Split the op's inputs into those that trace back to a `keras.Input`
    # (they become the layer's inputs) and everything else (constants, keyed
    # by their position in `op.inputs`).
    op_inputs = list(op.inputs)
    constants = {}
    layer_inputs = []
    for i, op_input in enumerate(op_inputs):
      if uses_keras_history(op_input):
        layer_inputs.append(op_input)
      else:
        # Treat any value not originating from a `keras.Input` as
        # a constant. Variables currently have `Placeholder` op type
        # when originating from an eager context, so they can't be
        # supported.
        constants[i] = backend.function([], op_input)([])

    # Recursively set `_keras_history` on the inputs before wrapping this op.
    processed_ops = _create_keras_history_helper(layer_inputs, processed_ops)
    name = op.name
    node_def = op.node_def.SerializeToString()
    op_layer = base_layer.TensorFlowOpLayer(
        node_def, constants=constants, name=name)
    op_layer._add_inbound_node(  # pylint: disable=protected-access
        layer_inputs, op.outputs)
    processed_ops.add(op)
  return processed_ops
298
299
def needs_keras_history(tensors):
  """Tells whether any Tensor must be wrapped in a `TensorFlowOpLayer`.

  Inside a sublayer (i.e. within a call context) the answer is always False,
  since sublayers do not need to create Keras History. Outside of one, the
  answer is True when at least one of `tensors` lacks `_keras_history` and
  `tensors` originate from a `keras.Input`.

  Arguments:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  flat_tensors = nest.flatten(tensors)
  if is_in_call_context():
    return False
  for tensor in flat_tensors:
    if getattr(tensor, '_keras_history', None) is None:
      # At least one Tensor lacks history: the outcome now depends on whether
      # the Tensors trace back to a `keras.Input`.
      return uses_keras_history(tensors)
  # KerasHistory already set on every Tensor.
  return False
321
322
def is_in_call_context():
  """Returns true if inside of a model/layer '__call__'.

  Reads the `in_call` attribute of the module's thread-local `_call_context`;
  an unset attribute counts as `False`.
  """
  try:
    return _call_context.in_call
  except AttributeError:
    # The flag has never been set on this thread.
    return False
326
327
def uses_keras_history(tensors):
  """Tells whether any Tensor descends from a `keras.Input`.

  The producers of `tensors` are explored level by level through
  `tensor.op.inputs`. A Tensor that descends from a `keras.Input` has some
  ancestor carrying a `_keras_history` attribute, so finding one ends the
  search with `True`. Tensors already flagged `_keras_history_checked` are
  skipped, and when nothing is found the given `tensors` are flagged the same
  way so later calls can skip them.

  Arguments:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  visited = set()
  frontier = nest.flatten(tensors)

  while frontier:
    upstream = set()
    for candidate in frontier:
      if getattr(candidate, '_keras_history_checked', None) is not None:
        # Previously established not to come from a `keras.Input`.
        continue
      if getattr(candidate, '_keras_history', None) is not None:
        return True

      try:
        upstream.update(candidate.op.inputs)
      except AttributeError:
        # In case `candidate` is a Variable created in an Eager context.
        pass

    visited.update(frontier)
    frontier = list(upstream.difference(visited))

  # Mark that these Tensors have been checked once for `_keras_history`,
  # and should not be checked again for performance reasons.
  mark_checked(tensors)
  return False
367
368
def mark_checked(tensors):
  """Flags Tensors so that they are not tracked.

  Sets `_keras_history_checked = True` on every Tensor in the structure, which
  prevents Layers from attempting to create TensorFlowOpLayers for them.

  Arguments:
    tensors: An arbitrary structure of Tensors.
  """

  def _flag_as_checked(tensor):
    setattr(tensor, '_keras_history_checked', True)

  nest.map_structure(_flag_as_checked, tensors)
383
384
@tf_contextlib.contextmanager
def call_context():
  """Scope that marks when we are currently inside a Layer/Model's `call`.

  Sets the thread-local `in_call` flag to `True` for the duration of the scope
  and puts back the value it had on entry when the scope exits, even if the
  body raised.
  """
  previous_flag = getattr(_call_context, 'in_call', False)
  _call_context.in_call = True
  try:
    yield
  finally:
    # Restore (rather than clear) so that nested scopes unwind correctly.
    _call_context.in_call = previous_flag
394
395
def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`.

  Arguments:
    argspec: Object whose `args` attribute lists the argument names of
      `call`, beginning with 'self' and 'inputs'.
    args: Positional arguments given after the inputs.
    kwargs: Keyword arguments given by the caller.

  Returns:
    `True` if `training` was supplied either positionally or by keyword.
  """
  # Skip the leading ['self', 'inputs'] so that names line up with `args`.
  call_arg_names = argspec.args[2:]
  supplied = dict(kwargs)
  for arg_name, arg_value in zip(call_arg_names, args):
    supplied.setdefault(arg_name, arg_value)
  return 'training' in supplied
402
403
class AutoAddUpdates(object):
  """Automatically track stateful ops with `add_update`.

  This context manager is used to automatically add stateful ops to a Layer
  or Model's `.updates`. This ensures that stateful ops are run in the Keras
  training loop. It also allows for these stateful ops to be disabled by
  setting `trainable=False`.

  Tracking only happens when not executing eagerly while
  `ops.executing_eagerly_outside_functions()` is true, and only when the
  default graph on exit is the same graph seen on entry and is named
  'keras_graph'. In every other case entering and exiting is a no-op.

  Example:

  ```
  with AutoAddUpdates(layer, inputs) as auto_updates:
    outputs = layer.call(inputs)
    auto_updates.set_outputs(outputs)
  ```

  Attributes:
    layer: Layer or Model instance to add the updates to.
    inputs: The inputs to this Layer or Model, to be used for input-conditional
      updates.
    outputs: The outputs of this Layer or Model.
  """

  def __init__(self, layer, inputs):
    """Stores the layer and its inputs; outputs are supplied later.

    Arguments:
      layer: Layer or Model instance to add the updates to.
      inputs: The inputs to this Layer or Model.
    """
    self.layer = layer
    self.inputs = inputs
    self.outputs = []
    # Explicit flag so that `set_outputs` never has to evaluate the truthiness
    # of `outputs`, which may be an object that does not support `bool()`.
    self._outputs_set = False
    # Populated by `__enter__` when tracking is active.
    self._graph = None
    self._num_operations = 0

  def set_outputs(self, outputs):
    """Records the outputs produced inside the scope.

    Arguments:
      outputs: The outputs of the Layer or Model.

    Raises:
      RuntimeError: If called more than once on the same instance.
    """
    if self._outputs_set:
      raise RuntimeError('`set_outputs` should only be called once on an '
                         '`AutoAddUpdates` instance.')
    self.outputs = outputs
    self._outputs_set = True

  def __enter__(self):
    """Snapshots the default graph and its current number of operations."""
    # Only run in V2 Function mode.
    if (context.executing_eagerly() or
        not ops.executing_eagerly_outside_functions()):
      return self

    self._graph = ops.get_default_graph()
    self._num_operations = len(self._graph.get_operations())
    return self

  def __exit__(self, error_type, unused_value, unused_traceback):
    """Adds the stateful ops created inside the scope as layer updates."""
    if error_type:
      # Allow errors that occurred inside this context manager to pass through
      # normally.
      return

    # Only run in V2 Function mode.
    if (context.executing_eagerly() or
        not ops.executing_eagerly_outside_functions()):
      return

    if (self._graph is not ops.get_default_graph() or
        self._graph.name != 'keras_graph'):
      # Only auto-track updates when the Keras Graph is the only one used.
      return

    # Everything appended to the graph since `__enter__`.
    new_operations = self._graph.get_operations()[self._num_operations:]
    new_stateful_ops = set()

    # pylint: disable=protected-access
    for op in new_operations:
      # While loop is not supported in general for automatic control
      # dependencies.
      if control_flow_util.IsInWhileLoop(op):
        continue

      # Track stateful ops via `add_update`. Ops whose type is not registered
      # with the graph are treated as stateful.
      is_stateful_op = (
          op.type not in self._graph._registered_ops or
          auto_control_deps.op_is_stateful(
              self._graph._registered_ops[op.type]))

      # Ignore ReadVariableOps as they are not needed to be run separately.
      # This ensures existing Layers don't get extra updates.
      if is_stateful_op and op.type != 'ReadVariableOp':
        new_stateful_ops.add(op)

    # Tuple entries are excluded here.
    explicit_updates = {
        u for u in self.layer._get_unfiltered_updates(check_trainable=False)
        if not isinstance(u, tuple)
    }
    # pylint: enable=protected-access

    # Don't add updates that will already be run by virtue of being consumed by
    # other stateful ops or by the Layer's outputs. This ensures that existing
    # Layers like `BatchNormalization` continue to return the same values for
    # `.update` calls.
    minimum_ops = set()
    targets = new_stateful_ops.union(
        set(nest.flatten(self.outputs)), explicit_updates)
    for op in new_stateful_ops:
      # Scrub any ops that are consumed by the outputs or other stateful ops.
      reachable = tf_utils.get_reachable_from_inputs(op)
      if not (targets - {op}).intersection(reachable):
        minimum_ops.add(op)
    new_stateful_ops = minimum_ops

    # Don't double-track updates added via explicitly calling `add_update`.
    # Also don't double-track updates already tracked in sublayers.
    new_stateful_ops = new_stateful_ops - explicit_updates

    # Decide whether to track as input-conditional or unconditional.
    input_reachable_ops = tf_utils.get_reachable_from_inputs(
        self.inputs, targets=new_stateful_ops)
    unconditional_updates = new_stateful_ops - input_reachable_ops
    conditional_updates = new_stateful_ops - unconditional_updates

    if unconditional_updates:
      self.layer.add_update(list(unconditional_updates))
    if conditional_updates:
      self.layer.add_update(list(conditional_updates), inputs=self.inputs)
519
520
def _get_var_read_dtype(input_list, should_cast):
  """Gets the dtype that AutoCastVariables should be read in.

  Arguments:
    input_list: The inputs to the layer; only the first element is inspected.
    should_cast: Whether AutoCastVariables should be casted.

  Returns:
    The base dtype of the first input when `should_cast` is truthy,
    `input_list` is non-empty and that input's dtype is floating; `None`
    otherwise.
  """
  if not (should_cast and input_list):
    return None
  first_input = input_list[0]
  if not first_input.dtype.is_floating:
    return None
  return first_input.dtype.base_dtype
527
528
def autocast_context_manager(input_list, should_cast):
  """Builds the context manager that controls AutoCastVariable casting.

  While the returned context manager is active, AutoCastVariables are casted
  when `should_cast` is True. With `should_cast` set to False they are not
  casted, which makes it possible to switch autocasting off when nested inside
  another `autocast_context_manager` scope.

  Args:
    input_list: The inputs to the layer with the AutoCastVariables.
    should_cast: Whether AutoCastVariables should be casted.

  Returns:
    A context manager to automatically cast AutoCastVariables.
  """
  read_dtype = _get_var_read_dtype(input_list, should_cast)
  default_graph = ops.get_default_graph()
  return default_graph._enable_auto_casting_variables(read_dtype)  # pylint: disable=protected-access
547