1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains private utilities used mainly by the base Layer class."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21import threading
22
23from tensorflow.python import tf2
24from tensorflow.python.distribute import distribution_strategy_context
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.keras import backend
32from tensorflow.python.keras.utils import control_flow_util
33from tensorflow.python.keras.utils import tf_inspect
34from tensorflow.python.keras.utils import tf_utils
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import variables as tf_variables
37from tensorflow.python.ops.ragged import ragged_tensor
38from tensorflow.python.training.tracking import base as tracking
39from tensorflow.python.util import keras_deps
40from tensorflow.python.util import nest
41from tensorflow.python.util.tf_export import keras_export
42
# Thread-local holder for the active `CallContext`. `call_context()` lazily
# creates one instance per thread and stores it on the `call_context`
# attribute of this object.
_call_context = threading.local()
44
45
def create_mean_metric(value, name=None):
  """Builds a `Mean` metric matching `value` and feeds `value` into it.

  Args:
    value: Value to aggregate. Its `dtype` is used as the metric's dtype.
    name: Optional name forwarded to the `Mean` metric.

  Returns:
    A `(metric, result)` tuple: the newly created `Mean` metric object and
    whatever calling that metric on `value` returns.
  """
  # Deferred import: importing keras imports base_layer and then this module,
  # and metrics relies on base_layer, so a top-level import would be cyclic.
  from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  mean_metric = metrics_module.Mean(name=name, dtype=value.dtype)
  result = mean_metric(value)
  return mean_metric, result
52
53
def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Temporary util that creates a variable via `VariableV1`.

  Some reuse-related technicalities prevent us from using
  `variable_scope.get_variable()` directly, so a path with fewer constraints
  is used instead.

  In the longer term, a similar "default variable creator" method should
  probably exist in `Trackable`. Once it does, this temporary solution can be
  removed.

  TODO(fchollet): remove this method when no longer needed.

  Args:
    name: Variable name.
    shape: Variable shape. Wrapped in a `TensorShape`; a falsy `TensorShape`
      is forwarded as `None`.
    dtype: The type of the variable. Defaults to `float32`. Only forwarded to
      the variable when `initializer` is callable.
    initializer: Either a callable initializer (an initializer class is
      instantiated with no arguments first), which is bound as
      `initializer(shape, dtype=dtype)`, or a non-callable value that is used
      directly as the initial value. Note that `None` is treated as a
      callable initializer and is rejected by `functools.partial`.
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. kernels, biases)
      or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
      Note, if the current variable scope is marked as non-trainable
      then this parameter is ignored and any added variables are also
      marked as non-trainable. `trainable` defaults to `True` unless
      `synchronization` is set to `ON_READ`.
    caching_device: Passed to `tf.Variable`.
    validate_shape: Passed to `tf.Variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`. `None` is treated as
      `True`.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  if initializer is not None and not callable(initializer):
    # A concrete value was supplied: use it as-is and let the variable infer
    # its dtype from that value.
    init_val = initializer
    variable_dtype = None
  else:
    # Instantiate initializer if provided initializer is a type object.
    if tf_inspect.isclass(initializer):
      initializer = initializer()
    # Bind shape and dtype now so the variable can call it with no arguments.
    init_val = functools.partial(initializer, shape, dtype=dtype)
    variable_dtype = dtype.base_dtype

  use_resource = True if use_resource is None else use_resource

  # TODO(apassos,rohanj) figure out how to remove collections from here so we
  # can remove the V1.
  variable_shape = tensor_shape.TensorShape(shape)
  return tf_variables.VariableV1(
      initial_value=init_val,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=use_resource,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=variable_shape or None)
143
144
def collect_previous_mask(input_tensors):
  """Looks up the Keras mask attached to each input tensor.

  Args:
      input_tensors: An arbitrary structure of Tensors.

  Returns:
      A structure shaped like `input_tensors` in which every entry is that
      tensor's `_keras_mask` attribute, or `None` when it has none.
  """
  return nest.map_structure(
      lambda tensor: getattr(tensor, '_keras_mask', None), input_tensors)
159
160
def have_all_keras_metadata(tensors):
  """Returns True if every flattened tensor has a `_keras_history` attribute.

  Args:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool. `True` for an empty structure.
  """
  for tensor in nest.flatten(tensors):
    if not hasattr(tensor, '_keras_history'):
      return False
  return True
163
164
def generate_placeholders_from_shape(shape):
  """Creates a placeholder of the given shape using the backend float type.

  Args:
    shape: Shape forwarded to `array_ops.placeholder`.

  Returns:
    The placeholder tensor, with dtype `backend.floatx()`.
  """
  float_dtype = backend.floatx()
  return array_ops.placeholder(dtype=float_dtype, shape=shape)
167
168
def create_keras_history(tensors):
  """Wraps TensorFlow Operations for compatibility with the Functional API.

  For every Tensor in `tensors` that lacks Keras metadata but originates from
  a Keras `Input` Layer, the raw TensorFlow Operations that produced it are
  replaced with `TensorFlowOpLayer` instances creating identical operations.

  Tensors that do not originate from a Keras `Input` Layer are treated as
  constants when the `TensorFlowOpLayer` instances are built.

  Args:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw Tensorflow operations.
  """
  # Start from empty accumulators; the helper returns
  # `(processed_ops, created_layers)` and only the layers are of interest.
  helper_result = _create_keras_history_helper(tensors, set(), [])
  return helper_result[1]
190
191
# Unsafe internal attribute.
# If True, Keras will not evaluate the constant-foldable inputs to tf op
# layers in TF1 graphs. This *might* speed up model construction time in
# certain settings, but it means the models will not be
# serializable/deserializable via `get_config` (only via SavedModel). It may
# also change the semantics of whether generated random numbers are generated
# once and re-used, or recomputed each time.
# Note: This path triggers for TPUEstimators / XLA-compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False
203
204
def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Walks the flattened `tensors`, and for each one that has no `_keras_history`
  wraps its producing op in a `TensorFlowOpLayer`. Inputs of that op which
  trace back to a `keras.Input` are processed recursively first, so their
  layers are appended to `created_layers` before the layer of the op that
  consumes them. `processed_ops` and `created_layers` are mutated in place.

  Args:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped in
      `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.

  Raises:
    ValueError: If any tensor lacking Keras history is a sparse or ragged
      tensor. The error is raised after all other tensors were processed.
  """
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  # Unsupported tensors are only collected here; they are reported together
  # in a single error once the loop is done.
  sparse_ops = []
  ragged_tensors = []
  for tensor in tensor_list:
    # Tensors that already carry Keras metadata need no wrapping.
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    if isinstance(
        tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      sparse_ops.append(tensor.op)
      continue
    if tf_utils.is_ragged(tensor):
      # Ragged tensors don't have an op property
      ragged_tensors.append(tensor)
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      # `constants` maps input index -> constant value (or the raw tensor when
      # evaluation is skipped); `layer_inputs` keeps the inputs that trace
      # back to a `keras.Input`.
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              distribution_strategy_context.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              ops.get_default_graph())
          if ds_with_session or using_xla or _UNSAFE_GRAPH_OP_LAYER_CREATION:
            # In Legacy Graph mode, evaluating here makes Session be
            # configured improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
            constants[i] = op_input
          else:
            # Evaluate the constant outside of any function-building scope.
            with ops.init_scope():
              if ops.executing_eagerly_outside_functions():
                constants[i] = backend.eval_in_eager_or_function(op_input)
              else:
                constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      # Wrap the upstream ops first so that their layers precede this op's
      # layer in `created_layers`.
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
          args=(layer_inputs,),
          kwargs={},
          outputs=op.outputs)
      # Remember the op so it is wrapped only once, even if several of its
      # outputs appear in `tensors`.
      processed_ops.update([op])
  if sparse_ops or ragged_tensors:
    lambda_example = """
    weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
    output = tf.keras.layers.Lambda(weights_mult)(input)
    """
    raise ValueError(
        'Tensorflow ops that generate ragged or sparse tensor '
        'outputs are currently not supported by Keras automatic '
        'op wrapping. Please wrap these ops in a Lambda layer: '
        '\n\n```\n{example}\n```\n'
        'Sparse ops encountered: {sparse_ops}\n'
        'Ragged tensors encountered: {ragged_tensors}\n'.format(
            example=lambda_example,
            sparse_ops=str(sparse_ops),
            ragged_tensors=str(ragged_tensors)))
  return processed_ops, created_layers
294
295
def unnest_if_single_tensor(input_tensors):
  """Unwraps a structure that holds exactly one element, unless it is a dict.

  This preserves compatibility with older configs. A dict is passed through
  untouched even when it has a single entry, on the assumption that the first
  layer expects a dict (as is the case with a DenseFeatures layer).

  Args:
    input_tensors: An arbitrary structure of Tensors.

  Returns:
    The single contained element when `input_tensors` is a non-dict structure
    flattening to exactly one element; otherwise `input_tensors` itself.
  """
  flattened = nest.flatten(input_tensors)
  is_lone_non_dict = (
      len(flattened) == 1 and not isinstance(input_tensors, dict))
  return flattened[0] if is_lone_non_dict else input_tensors
305
306
def needs_keras_history(tensors, ignore_call_context=False):
  """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

  Inside a sublayer's `call` this is always False (unless
  `ignore_call_context` is set), since sublayers do not need to create Keras
  History. Otherwise the result is True when at least one of `tensors`
  originates from a `keras.Input` and not all of them already have
  `_keras_history` set.

  Args:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to ignore the check of if currently
      outside of a `call` context. This is `True` when creating
      KerasHistory inside `Node`, where we always know that Tensors
      are being used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  flat_tensors = nest.flatten(tensors)
  if call_context().in_call and not ignore_call_context:
    return False
  # If every tensor already carries KerasHistory there is nothing to wrap.
  some_history_missing = any(
      getattr(tensor, '_keras_history', None) is None
      for tensor in flat_tensors)
  if not some_history_missing:
    return False
  return uses_keras_history(tensors)
334
335
def is_in_keras_graph():
  """Returns the `in_keras_graph` flag of the active call context."""
  active_context = call_context()
  return active_context.in_keras_graph
339
340
def is_in_eager_or_tf_function():
  """Returns whether executing eagerly or inside of a tf.function."""
  if context.executing_eagerly():
    return True
  return is_in_tf_function()
344
345
def is_in_tf_function():
  """Returns whether the current graph belongs to a user `tf.function`.

  The answer is False in V1 graph mode, outside of any function, inside the
  Keras FuncGraph, and inside a v1 `wrap_function` FuncGraph.
  """
  # V1 graph mode, or not inside a function at all.
  if not (ops.executing_eagerly_outside_functions() and
          ops.inside_function()):
    return False
  # The Keras FuncGraph does not count as a tf.function.
  if is_in_keras_graph():
    return False
  # A v1 `wrap_function` FuncGraph is recognised by its graph name.
  graph_name = getattr(ops.get_default_graph(), 'name', False)
  return not (graph_name and graph_name.startswith('wrapped_function'))
362
363
def uses_keras_history(tensors):
  """Check if at least one Tensor originates from a `keras.Input`.

  The producing ops are walked backwards, one level of inputs at a time. A
  Tensor originating from a `keras.Input` has some dependency Tensor with a
  `_keras_history` attribute attached. Tensors that were previously found not
  to originate from a `keras.Input` carry `_keras_history_checked` and are
  not explored again.

  Args:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  seen_ids = set()
  frontier = nest.flatten(tensors)

  while frontier:
    next_frontier = []
    for candidate in frontier:
      candidate_id = id(candidate)
      if candidate_id in seen_ids:
        continue
      seen_ids.add(candidate_id)

      already_cleared = (
          getattr(candidate, '_keras_history_checked', None) is not None)
      if already_cleared:
        continue
      if getattr(candidate, '_keras_history', None) is not None:
        return True

      try:
        next_frontier.extend(candidate.op.inputs)
      except AttributeError:
        # In case `candidate` is a Variable created in an Eager context.
        pass

    frontier = next_frontier

  # Nothing was found: mark the Tensors as checked so that later calls can
  # skip them, for performance reasons.
  mark_checked(tensors)
  return False
407
408
def mark_checked(tensors):
  """Marks that these Tensors should not be tracked.

  Sets `_keras_history_checked = True` on every Tensor of the structure, which
  prevents Layers from attempting to create TensorFlowOpLayers for them.

  Args:
    tensors: An arbitrary structure of Tensors.
  """
  for tensor in nest.flatten(tensors):
    tensor._keras_history_checked = True  # pylint: disable=protected-access
423
424
def call_context():
  """Returns the active `CallContext`, creating it on first use.

  The instance is stored on the module's thread-local `_call_context` object,
  so each thread gets its own `CallContext`.
  """
  existing = getattr(_call_context, 'call_context', None)
  if existing is not None:
    return existing
  created = CallContext()
  _call_context.call_context = created
  return created
432
433
# Inject the `call_context` function into `keras_deps` to remove the
# dependency from TFLite to Keras.
keras_deps.register_call_context_function(call_context)
437
438
class CallContext(object):
  """Keeps track of properties currently inside a Layer/Model's `call`.

  `in_call` is stored as a plain attribute because it is read on the hot path.
  The remaining per-call properties live in a single state dict that is
  replaced wholesale whenever a `call` is entered or exited.

  Attributes:
    in_call: Whether currently inside the `call` of a Layer.
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.
    frozen: Whether currently executing inside a `Layer` with `trainable` set to
      `False`.
    in_keras_graph: Whether executing inside the Keras Graph.
  """

  def __init__(self):
    # Handle `in_call` separately as it is the most-read attr and reading it is
    # on the hot path.
    self.in_call = False
    self._state = {
        'layer': None,
        'inputs': None,
        'build_graph': False,
        'training': None,
        'saving': None
    }
    # TODO(b/150169018): This logic can be replaced after the Functional API
    # refactor.
    self._in_keras_graph = False

  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Push a Layer and its inputs and state onto the current call context.

    Args:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      saving: Whether currently saving to SavedModel.

    Returns:
      Context manager.
    """
    state = {
        'layer': layer,
        'inputs': inputs,
        'build_graph': build_graph,
        'training': training,
        'saving': saving
    }
    return CallContextManager(self, state)

  @property
  def layer(self):
    """The `Layer` whose `call` is currently active, or `None`."""
    return self._state['layer']

  @property
  def inputs(self):
    """The inputs to the currently active `Layer`, or `None`."""
    return self._state['inputs']

  @property
  def build_graph(self):
    """The `build_graph` value of the current state (`False` initially)."""
    return self._state['build_graph']

  @property
  def training(self):
    """The `training` value of the current state, or `None`."""
    return self._state['training']

  @property
  def saving(self):
    """The `saving` value of the current state, or `None`."""
    return self._state['saving']

  @property
  def frozen(self):
    """Whether the active layer is non-trainable (`False` with no layer)."""
    layer = self._state['layer']
    # Compare against `None` explicitly: a layer object may itself be falsy
    # (e.g. it defines `__len__` and is empty) and must still report its own
    # `trainable` setting.
    if layer is None:
      return False
    return not layer.trainable

  @property
  def in_keras_graph(self):
    """Whether executing inside the Keras graph; always False when eager."""
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if context.executing_eagerly():
      return False
    return (self._in_keras_graph or
            getattr(backend.get_graph(), 'name', None) == 'keras_graph')
526
527
class CallContextManager(object):
  """Context manager that installs a call state on a `CallContext`.

  On entry the context's `in_call` flag and state dict are saved and replaced;
  on exit the saved values are put back. When the state has `build_graph` set,
  the context's `_in_keras_graph` flag is saved and restored as well.
  """

  def __init__(self, call_ctx, state):
    self._call_ctx = call_ctx
    self._state = state
    self._build_graph = state['build_graph']

  def __enter__(self):
    ctx = self._call_ctx
    # Snapshot what is about to be overwritten so `__exit__` can restore it.
    self._prev_in_call, self._prev_state = ctx.in_call, ctx._state
    ctx.in_call, ctx._state = True, self._state

    if not self._build_graph:
      return

    # TODO(b/150169018): This logic can be removed after the Functional API
    # refactor.
    self._prev_in_keras_graph = ctx._in_keras_graph
    if not ctx._in_keras_graph:
      # Only consult the backend graph when the flag is not already set.
      ctx._in_keras_graph = (
          getattr(backend.get_graph(), 'name', None) == 'keras_graph')

  def __exit__(self, *exc_info):
    ctx = self._call_ctx
    ctx.in_call, ctx._state = self._prev_in_call, self._prev_state

    if self._build_graph:
      ctx._in_keras_graph = self._prev_in_keras_graph
559
560
def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`.

  Args:
    argspec: Argspec of the layer's `call`; its `args` list is expected to
      start with `['self', 'inputs']`.
    args: Positional arguments that follow `inputs`.
    kwargs: Keyword arguments; these take precedence over positional values.

  Returns:
    Bool. `True` only if `training` was supplied with a value other than
    `None`.
  """
  # Skip `self` and `inputs` to line the names up with `args`.
  positional_names = argspec.args[2:]
  bound = {name: value for name, value in zip(positional_names, args)}
  bound.update(kwargs)
  return bound.get('training') is not None
567
568
def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model.

  The decision is based purely on the object's `__module__`: anything whose
  module path contains neither 'keras.engine' nor 'keras.layers' counts as
  subclassed.
  """
  module_name = layer.__module__
  return not ('keras.engine' in module_name or 'keras.layers' in module_name)
573
574
def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel.

  This is detected by looking for 'keras.saving.saved_model' in the object's
  `__module__`.
  """
  return 'keras.saving.saved_model' in layer.__module__
578
579
def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to `add_*` method match the Keras graph.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  This function raises a clear error message in such cases, with an example
  of the unsupported pattern and of a supported rewrite.

  Args:
    tensor: Tensor to check. Only inspected when `force_raise` is falsy; an
      object without a `graph` attribute never triggers the error.
    method: Caller method, used to select the error message. One of
      {'activity_regularizer', 'add_metric', 'add_loss', 'add_update'}; any
      other value gets the `add_update` examples.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  out_of_graph = force_raise or (
      ops.executing_eagerly_outside_functions() and
      hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)
  if not out_of_graph:
    return

  # `activity_regularizer` has its own message and raises right away.
  if method == 'activity_regularizer':
    bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
    correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
    raise RuntimeError(
        'You are using a layer with `activity_regularizer` in a control flow '
        'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
        'Please move your call to the layer with `activity_regularizer` out '
        'of the control flow branch, e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your outer model/layer dynamic'
        ' (eager-only) by passing `dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you to implement static '
        'shape inference in the `compute_output_shape(input_shape)` '
        'method.'.format(
            bad_example=bad_example, correct_example=correct_example))

  # The remaining methods share one message and differ only in the examples.
  if method == 'add_metric':
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
  elif method == 'add_loss':
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
  else:
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
  raise RuntimeError(
      'You are using the method `{method}` in a control flow branch '
      'in your layer, e.g.:\n{bad_example}\n'
      'This is not currently supported. '
      'Please move your call to {method} out of the control flow branch, '
      'e.g.:\n{correct_example}\n'
      'You can also resolve this by marking your layer '
      'as dynamic (eager-only) by passing '
      '`dynamic=True` to the layer constructor. '
      'Any kind of control flow is supported with dynamic layers. '
      'Note that using `dynamic=True` requires you '
      'to implement static shape inference '
      'in the `compute_output_shape(input_shape)` method.'.format(
          method=method,
          bad_example=bad_example,
          correct_example=correct_example))
702
703
def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps.

  Args:
    outputs: An arbitrary structure of values. Entries that are not
      TensorFlow types are passed through unchanged.
    acd: Object whose `mark_as_return(tensor)` method is applied to every
      tensor entry and to its Keras mask, if any.

  Returns:
    A structure shaped like `outputs` holding the marked tensors. Each marked
    tensor has `_keras_mask` set (to the marked mask or `None`), and keeps the
    `_tfp_distribution` attribute of the original when one is present.
  """

  def _mark_one(tensor):
    """Marks a single entry, carrying over its attached metadata."""
    if not tensor_util.is_tf_type(tensor):
      return tensor

    # pylint: disable=protected-access
    marked = acd.mark_as_return(tensor)
    mask = getattr(tensor, '_keras_mask', None)
    marked._keras_mask = None if mask is None else acd.mark_as_return(mask)

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    tfp_distribution = getattr(tensor, '_tfp_distribution', None)
    if tfp_distribution is not None:
      marked._tfp_distribution = tfp_distribution
    # pylint: enable=protected-access

    return marked

  return nest.map_structure(_mark_one, outputs)
728
729
# Tri-state override for the V2 dtype behavior: `None` defers to
# `tf2.enabled()`, while `True`/`False` are set by
# `enable_v2_dtype_behavior()` and `disable_v2_dtype_behavior()`.
V2_DTYPE_BEHAVIOR = None
731
732
@keras_export(v1=['keras.layers.enable_v2_dtype_behavior'])
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2, so this function
  is only useful if `tf.compat.v1.disable_v2_behavior` has been called. Since
  mixed precision requires V2 dtype behavior to be enabled, this function allows
  you to use mixed precision in Keras layers if `disable_v2_behavior` has been
  called.

  When enabled, the dtype of Keras layers defaults to floatx (which is typically
  float32) instead of None. In addition, layers will automatically cast
  floating-point inputs to the layer's dtype.

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
  >>> print(y.dtype.name)
  float32

  A layer author can opt-out their layer from the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting to
  floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype
  will default to the global policy instead of floatx. Layers will automatically
  cast inputs to the policy's compute_dtype.
  """
  # Explicit override: once set, `v2_dtype_behavior_enabled()` returns this
  # value instead of consulting `tf2.enabled()`.
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True
766
767
@keras_export(v1=['keras.layers.disable_v2_dtype_behavior'])
def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  Sets the module-level `V2_DTYPE_BEHAVIOR` flag to `False`, which
  `v2_dtype_behavior_enabled()` then returns instead of consulting
  `tf2.enabled()`.

  See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False
776
777
def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled.

  An explicit setting stored in `V2_DTYPE_BEHAVIOR` wins; when that is `None`
  the answer falls back to `tf2.enabled()`.
  """
  override = V2_DTYPE_BEHAVIOR
  if override is not None:
    return override
  return tf2.enabled()
783
784
class TrackableWeightHandler(object):
  """Keras wrapper for handling tracking.Trackable object saving and restoring.

  This class handles Trackables in both V1 and V2 modes, ensuring that they can
  be saved and restored with the correct data and without adding additional ops
  on every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    """Wraps `trackable` and prepares its weight getter and setter.

    Args:
      trackable: A `tracking.Trackable` exposing zero or one Saveable.

    Raises:
      ValueError: If `trackable` is not a `Trackable`, or if it has more than
        one Saveable.
    """
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = distribution_strategy_context.get_strategy()

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access

    # 'Saveables' won't exist when we're passed a legacy TF1 table like
    # a StaticHashTable: expose no tensors and ignore anything that is set.
    if not saveables:
      self._num_tensors = 0
      self._setter = lambda weights: None
      self._getter = lambda: []
      return

    saveable_count = len(saveables)
    if saveable_count != 1:
      raise ValueError('Only Trackables with one Saveable are supported. '
                       'The Trackable %s has %d Saveables.' %
                       (trackable, saveable_count))

    saveable = next(iter(saveables))

    if ops.executing_eagerly_outside_functions():
      # If we're in eager mode, we need to defer calling the Trackable's
      # saveable() callable until data export time.
      # However, it is safe to call the saveable as many times as we want, so
      # we will call it now to figure out how many tensors this Trackable will
      # produce.
      self._saveable = saveable
      self._num_tensors = len(self._saveable().specs)
      self._setter = lambda weights: self._saveable().restore(weights, None)
      self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
      return

    # If we're in Graph mode, we need to evaluate the Saveable only once and
    # cache the resulting restore graph. Failing to do this will result in
    # new assignment ops being added to the graph each time set_weights() is
    # called.
    def _placeholder_like(tensor):
      return array_ops.placeholder(tensor.dtype, tensor.shape)

    self._saveable = saveable()
    self._num_tensors = len(self._saveable.specs)
    # `spec.tensor` is read exactly once per spec, right before its
    # placeholder is created.
    self._placeholder_tensors = [
        _placeholder_like(spec.tensor) for spec in self._saveable.specs
    ]
    self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
    self._setter = self._set_weights_v1
    self._getter = lambda: [spec.tensor for spec in self._saveable.specs]

  @property
  def num_tensors(self):
    """Number of tensors expected by `set_weights`."""
    return self._num_tensors

  def set_weights(self, weights):
    """Restores the trackable from `weights`.

    Args:
      weights: Sequence holding exactly `num_tensors` values.

    Raises:
      ValueError: If `weights` has the wrong length.
    """
    received = len(weights)
    if received != self._num_tensors:
      raise ValueError(
          'Weight handler for trackable %s received the wrong number of '
          'weights: expected %s, got %s.' %
          (self._trackable, self._num_tensors, received))
    self._setter(weights)

  def get_tensors(self):
    """Returns the list of tensors produced by the configured getter."""
    return self._getter()

  def _set_weights_v1(self, weights):
    """Graph-mode setter: feeds `weights` into the cached assign op."""
    feed_dict = {
        self._placeholder_tensors[idx]: tensor
        for idx, tensor in enumerate(weights)
    }
    backend.get_session().run(self._assign_op, feed_dict)
866
867
def no_ragged_support(inputs, layer_name):
  """Raises an error if any of the inputs is a `RaggedTensor`.

  Args:
    inputs: An arbitrary structure of inputs; it is flattened before checking.
    layer_name: Name of the layer, used in the error message.

  Raises:
    ValueError: If at least one flattened input is a `RaggedTensor`.
  """
  for candidate in nest.flatten(inputs):
    if isinstance(candidate, ragged_tensor.RaggedTensor):
      raise ValueError('Layer %s does not support RaggedTensors as input. '
                       'Inputs received: %s. You can try converting your '
                       'input to an uniform tensor.' % (layer_name, inputs))
874
875
def is_split_variable(v):
  """Returns True if `v` is either a PartitionedVariable or a ShardedVariable.

  The check is duck-typed: `v` qualifies when it has a `_variable_list` or a
  `_variables` attribute.
  """
  return any(hasattr(v, attr) for attr in ('_variable_list', '_variables'))
879
880
def has_weights(obj):
  """Returns whether `obj` is an instance whose class declares weights.

  Both `trainable_weights` and `non_trainable_weights` must be attributes of
  the object's type (instance-only attributes do not count), and `obj` itself
  must not be a class.
  """
  if isinstance(obj, type):
    return False
  obj_type = type(obj)
  return all(
      hasattr(obj_type, attr)
      for attr in ('trainable_weights', 'non_trainable_weights'))
886
887
# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes issues
# in eager mode because the child layers may have graph losses
# (thus model.losses returns a mix of eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
# NOTE(review): presumably this message stands in for the child layers'
# losses; its usage is not visible in this file -- confirm against callers.
REVIVED_LOSS_PLACEHOLDER = (
    'This layer\'s losses have been added to the parent layer.')
896