1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions shared between SavedModel saving/loading implementations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import itertools
21import threading
22import types
23
24from tensorflow.python.eager import context
25from tensorflow.python.keras import backend as K
26from tensorflow.python.keras.engine import base_layer_utils
27from tensorflow.python.keras.utils import control_flow_util
28from tensorflow.python.keras.utils import tf_contextlib
29from tensorflow.python.keras.utils import tf_inspect
30from tensorflow.python.keras.utils.generic_utils import LazyLoader
31from tensorflow.python.util import tf_decorator
32
33
# pylint:disable=g-inconsistent-quotes
# `training_lib` is resolved lazily through `LazyLoader` instead of a regular
# top-of-file import. NOTE(review): presumably this avoids a circular import
# with the Keras engine package -- confirm. In this file it is only used by
# `list_all_layers` (for the `training_lib.Model` isinstance check).
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
# pylint:enable=g-inconsistent-quotes
39
40
def use_wrapped_call(layer, call_fn, default_training_value=None,
                     return_method=False):
  """Wraps `call_fn` so its losses go to `layer` and only outputs are returned.

  Args:
    layer: A Keras layer object
    call_fn: tf.function that takes layer inputs (and possibly a training arg),
      and returns a tuple of (outputs, list of losses). May also be a LayerCall
      object exposing `original_call`; in that case its `__call__` is what gets
      invoked, while `original_call` supplies the signature.
    default_training_value: Default value of the training kwarg. If `None`, the
      default is `K.learning_phase()`.
    return_method: Whether to return a method bound to the layer.

  Returns:
    A callable that runs `call_fn`, passes the returned losses to
    `layer.add_loss`, and returns only the outputs. When `return_method` is
    True it is bound to `layer` with `types.MethodType`.
  """
  expects_training_arg = layer_uses_training_bool(layer)
  if not hasattr(call_fn, 'original_call'):
    signature_source = call_fn
  else:  # call_fn is a LayerCall object
    signature_source = call_fn.original_call
    # In Python 3, callable objects are not compatible with inspect.getargspec
    call_fn = call_fn.__call__
  fn, arg_spec = maybe_add_training_arg(
      signature_source, call_fn, expects_training_arg, default_training_value)

  # As a bound method the first positional argument is the layer itself, so the
  # inputs sit one slot further along; everything up to them is dropped.
  inputs_position = 1 if return_method else 0

  def return_outputs_and_add_losses(*args, **kwargs):
    """Returns the outputs from the call_fn, and adds the losses."""
    inputs = args[inputs_position]
    remaining_args = args[inputs_position + 1:]
    outputs, losses = fn(inputs, *remaining_args, **kwargs)
    layer.add_loss(losses, inputs=inputs)

    # TODO(kathywu): This is a temporary hack. When a network of layers is
    # revived from SavedModel, only the top-level layer will have losses. This
    # causes issues in eager mode because the child layers may have graph losses
    # (thus model.losses returns a mix of Eager and graph tensors). To fix this,
    # whenever eager losses are added to one layer, add eager losses to all
    # child layers. This causes `.losses` to only return eager losses.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      for sublayer in layer._flatten_layers():
        if sublayer is not layer:
          sublayer._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
    # pylint: enable=protected-access
    return outputs

  decorated = tf_decorator.make_decorator(
      target=call_fn,
      decorator_func=return_outputs_and_add_losses,
      decorator_argspec=arg_spec)
  return types.MethodType(decorated, layer) if return_method else decorated
98
99
def layer_uses_training_bool(layer):
  """Returns whether this layer or any of its children uses the training arg.

  Args:
    layer: A Keras layer or model. It must define `_expects_training_arg`
      (it is read directly, not via `getattr`).

  Returns:
    True if `layer`, or any layer reachable from it through `list_all_layers`,
    expects a training argument; False otherwise. Sublayers that lack an
    `_expects_training_arg` attribute are conservatively treated as expecting
    one.
  """
  if layer._expects_training_arg:  # pylint: disable=protected-access
    return True
  visited = {layer}
  # Copy before mutating: the loop pops from and extends this list, and for
  # models `list_all_layers` returns `Model.layers` as-is, which is not
  # guaranteed to be a fresh list owned by this function.
  to_visit = list(list_all_layers(layer))
  while to_visit:
    sublayer = to_visit.pop()
    if sublayer in visited:
      continue
    # Missing attribute -> assume the sublayer uses the training argument.
    if getattr(sublayer, '_expects_training_arg', True):
      return True
    visited.add(sublayer)
    to_visit.extend(list_all_layers(sublayer))
  return False
115
116
def list_all_layers(obj):
  """Returns the immediate sublayers of `obj` (without `obj` itself).

  Args:
    obj: A layer or a model.

  Returns:
    For models, the value of `obj.layers`; for any other layer, a new list of
    its direct (non-recursive) children.
  """
  if not isinstance(obj, training_lib.Model):
    return list(obj._flatten_layers(include_self=False, recursive=False))  # pylint: disable=protected-access
  # Models go through the public `layers` property so that special cases are
  # honored, e.g. Sequential, which doesn't return the `Input` layer.
  return obj.layers
124
125
def list_all_layers_and_sublayers(obj):
  """Returns `obj` together with every layer reachable from it.

  Reachability follows `list_all_layers` transitively.

  Args:
    obj: A layer or a model.

  Returns:
    A set containing `obj` and all of its direct and indirect sublayers.
  """
  # Iterative walk with an explicit stack and a visited set: each layer is
  # expanded exactly once, so a sublayer shared by several parents is not
  # re-traversed per parent, and a reference cycle cannot recurse forever.
  found = {obj}
  pending = [obj]
  while pending:
    current = pending.pop()
    for child in list_all_layers(current):
      if child not in found:
        found.add(child)
        pending.append(child)
  return found
131
132
def maybe_add_training_arg(
    original_call, wrapped_call, expects_training_arg, default_training_value):
  """Decorate call and optionally adds training argument.

  If a layer expects a training argument, this function ensures that 'training'
  is present in the layer args or kwonly args, with the default training value.

  Args:
    original_call: Original call function. Its argspec is used both to locate
      the 'training' argument at call time and to build the returned argspec.
    wrapped_call: Wrapped call function.
    expects_training_arg: Whether to include 'training' argument.
    default_training_value: Default value of the training kwarg to include in
      the arg spec. If `None`, the default is `K.learning_phase()`.

  Returns:
    Tuple of (
      function that calls `wrapped_call` and sets the training arg,
      Argspec of returned function or `None` if the argspec is unchanged)
  """
  if not expects_training_arg:
    return wrapped_call, None

  # The position of 'training' depends only on `original_call`'s signature, so
  # look it up once instead of on every call.
  training_arg_index = get_training_arg_index(original_call)

  def wrap_with_training_arg(*args, **kwargs):
    """Wrap the `wrapped_call` function, and set training argument."""
    training = get_training_arg(training_arg_index, args, kwargs)
    if training is None:
      # Compare with None explicitly: a falsy default such as `False` or `0`
      # is a real value and must not be replaced by the learning phase.
      if default_training_value is None:
        training = K.learning_phase()
      else:
        training = default_training_value

    args = list(args)
    kwargs = kwargs.copy()

    def replace_training_and_call(training):
      set_training_arg(training, training_arg_index, args, kwargs)
      return wrapped_call(*args, **kwargs)

    # Each branch calls `wrapped_call` with a concrete Python boolean.
    return control_flow_util.smart_cond(
        training, lambda: replace_training_and_call(True),
        lambda: replace_training_and_call(False))

  # Create arg spec for decorated function. If 'training' is not defined in the
  # args of the original arg spec, then add it to kwonlyargs.
  arg_spec = tf_inspect.getfullargspec(original_call)
  # Copy every field that gets edited below. The argspec may be an object
  # stored on a TF decorator (see `make_decorator(..., decorator_argspec=...)`),
  # and in-place edits would leak into later inspections of `original_call`.
  defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []
  kwonlyargs = list(arg_spec.kwonlyargs or [])
  kwonlydefaults = dict(arg_spec.kwonlydefaults or {})

  if 'training' in arg_spec.args:
    # Positional-or-keyword 'training': only replace an explicit `None` default.
    index = arg_spec.args.index('training')
    training_default_index = len(arg_spec.args) - index
    if (len(defaults) >= training_default_index and
        defaults[-training_default_index] is None):
      defaults[-training_default_index] = default_training_value
  elif 'training' in kwonlyargs:
    # Already keyword-only: don't append a duplicate entry; fill in the
    # default only when it is missing or `None`.
    if kwonlydefaults.get('training') is None:
      kwonlydefaults['training'] = default_training_value
  else:
    kwonlyargs.append('training')
    kwonlydefaults['training'] = default_training_value

  decorator_argspec = tf_inspect.FullArgSpec(
      args=arg_spec.args,
      varargs=arg_spec.varargs,
      varkw=arg_spec.varkw,
      defaults=defaults,
      kwonlyargs=kwonlyargs,
      kwonlydefaults=kwonlydefaults,
      annotations=arg_spec.annotations)
  return wrap_with_training_arg, decorator_argspec
201
202
def get_training_arg_index(call_fn):
  """Returns the index of 'training' in the layer call function arguments.

  Args:
    call_fn: Call function.

  Returns:
    - n: index of 'training' among the positional arguments of `call_fn`. When
      `call_fn` is a method, the first (receiver) argument is not counted.
    - -1: if 'training' is not one of those positional arguments (whether or
      not `call_fn` accepts variable keyword arguments). The setters/getters
      in this module then address the value through `kwargs['training']`.

    This function never returns `None`, although `get_training_arg`,
    `set_training_arg` and `remove_training_arg` also accept `None` to mean
    "no training argument".
  """
  arg_list = tf_inspect.getfullargspec(call_fn).args
  if tf_inspect.ismethod(call_fn):
    # Drop the receiver so the index lines up with the arguments callers pass.
    arg_list = arg_list[1:]
  try:
    return arg_list.index('training')
  except ValueError:
    return -1
222
223
def set_training_arg(training, index, args, kwargs):
  """Stores `training` in the call arguments, modifying them in place.

  Args:
    training: Value to store.
    index: Position of 'training' in `args`. `None` leaves both containers
      untouched; a negative or out-of-range index writes
      `kwargs['training']` instead.
    args: Mutable sequence of positional arguments (item assignment is used).
    kwargs: Dict of keyword arguments.

  Returns:
    The same `(args, kwargs)` objects that were passed in.
  """
  if index is None:
    return args, kwargs
  if 0 <= index < len(args):
    args[index] = training
  else:
    kwargs['training'] = training
  return args, kwargs
232
233
def get_training_arg(index, args, kwargs):
  """Reads the training value from the call arguments.

  Args:
    index: Position of 'training' in `args`, or `None` when there is no
      training argument. A negative or out-of-range index means the value is
      looked up in `kwargs`.
    args: Sequence of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    `args[index]` for an in-range non-negative index, otherwise
    `kwargs.get('training')`; always `None` when `index` is `None`.
  """
  if index is None:
    return None
  positional = 0 <= index < len(args)
  return args[index] if positional else kwargs.get('training', None)
241
242
def remove_training_arg(index, args, kwargs):
  """Deletes the training value from the call arguments in place.

  Args:
    index: Position of 'training' in `args`. `None` is a no-op; a negative or
      out-of-range index removes `kwargs['training']` if present.
    args: List of positional arguments (`pop` is used).
    kwargs: Dict of keyword arguments.
  """
  if index is None:
    return
  if 0 <= index < len(args):
    args.pop(index)
  else:
    kwargs.pop('training', None)
250
251
class SaveOptionsContext(threading.local):
  """Per-thread holder for Keras-specific SavedModel save options.

  Because this subclasses `threading.local`, every thread that touches an
  instance sees its own `save_traces` attribute, initialized by `__init__`.
  """

  def __init__(self):
    super(SaveOptionsContext, self).__init__()
    # Saving traced functions is on unless a caller turns it off.
    self.save_traces = True


# Module-wide instance; `keras_option_scope` writes it and
# `should_save_traces` reads it.
_save_options_context = SaveOptionsContext()
260
261
@tf_contextlib.contextmanager
def keras_option_scope(save_traces):
  """Temporarily overrides the current thread's `save_traces` option.

  Args:
    save_traces: Value that `should_save_traces()` reports, in this thread,
      while the scope is active.

  Yields:
    Nothing. On exit, normal or via an exception, the previous value is put
    back.
  """
  saved = _save_options_context.save_traces
  _save_options_context.save_traces = save_traces
  try:
    yield
  finally:
    _save_options_context.save_traces = saved
270
271
def should_save_traces():
  """Returns the current thread's `save_traces` option.

  This is True by default, or the value given to an active
  `keras_option_scope` in this thread.
  """
  options = _save_options_context
  return options.save_traces
274