1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tools for deserializing `Function`s."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import re
23
24from tensorflow.core.framework import function_pb2
25from tensorflow.python.eager import def_function
26from tensorflow.python.eager import function as function_lib
27from tensorflow.python.framework import func_graph as func_graph_lib
28from tensorflow.python.framework import function_def_to_graph as function_def_lib
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.ops import resource_variable_ops
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.saved_model import nested_structure_coder
34from tensorflow.python.util import compat
35from tensorflow.python.util import nest
36from tensorflow.python.util import tf_decorator
37from tensorflow.python.util import tf_inspect
38
39
def _is_tensor(t):
  """Returns True if `t` is a Tensor or a ResourceVariable."""
  tensor_like_types = (ops.Tensor, resource_variable_ops.ResourceVariable)
  return isinstance(t, tensor_like_types)
42
43
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  This differs from `function.__call__` in that inputs and outputs are
  structured and that it casts inputs to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be
  done before via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
        `function.graph.structured_input_signature`.

  Returns:
    The structured function output, or None when the call result is a bare
    Operation (i.e. there is no value to return).
  """
  expected_structure = function.graph.structured_input_signature
  flat_inputs = nest.flatten_up_to(expected_structure, inputs)
  flat_expected = nest.flatten(expected_structure)
  # Only positions whose signature entry is a TensorSpec are fed to the
  # flat call; each such input is converted with the spec's dtype as a hint.
  tensor_inputs = [
      ops.convert_to_tensor(arg, dtype_hint=spec.dtype)
      for arg, spec in zip(flat_inputs, flat_expected)
      if isinstance(spec, tensor_spec.TensorSpec)
  ]
  result = function._call_flat(tensor_inputs)  # pylint: disable=protected-access
  return None if isinstance(result, ops.Operation) else result
72
73
def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns None or TensorSpec obtained if `arg` is converted to tensor."""
  try:
    # Note: run the conversion inside a throwaway FuncGraph to avoid
    # polluting the current context with intermediate ops.
    with func_graph_lib.FuncGraph(name="guess_conversion").as_default():
      converted = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(shape=converted.shape,
                                    dtype=converted.dtype)
  except (TypeError, ValueError):
    return None
83
84
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flat_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    # `inputs` does not even have the right nesting structure.
    return False

  for arg, expected in zip(flat_inputs, nest.flatten(expected_structure)):
    if not isinstance(expected, tensor_spec.TensorSpec):
      # Non-tensor positions must match the recorded value exactly.
      if arg != expected:
        return False
      continue
    if allow_conversion:
      arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
    # The (possibly converted) argument must be tensor-like, with a matching
    # dtype and a compatible shape.
    if not (_is_tensor(arg) or isinstance(arg, tensor_spec.TensorSpec)):
      return False
    if arg.dtype != expected.dtype:
      return False
    if not expected.shape.is_compatible_with(arg.shape):
      return False
  return True
106
107
def _deserialize_function_spec(function_spec_proto, coder):
  """Deserialize a FunctionSpec object from its proto representation."""
  # The decoded fullargspec loses its namedtuple type during serialization,
  # so rebuild a proper tf_inspect.FullArgSpec from its fields.
  decoded_argspec = coder.decode_proto(function_spec_proto.fullargspec)
  fullargspec = tf_inspect.FullArgSpec(
      args=decoded_argspec.args,
      varargs=decoded_argspec.varargs,
      varkw=decoded_argspec.varkw,
      defaults=decoded_argspec.defaults,
      kwonlyargs=decoded_argspec.kwonlyargs,
      kwonlydefaults=decoded_argspec.kwonlydefaults,
      annotations=decoded_argspec.annotations)
  args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
  kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
  input_signature = coder.decode_proto(function_spec_proto.input_signature)
  return function_lib.FunctionSpec(fullargspec,
                                   function_spec_proto.is_method,
                                   args_to_prepend,
                                   kwargs_to_include,
                                   input_signature)
125
126
127# TODO(allenl): The fact that we can't derive ConcreteFunction calling
128# conventions from the serialized input spec right now is unfortunate. Merging
129# these would be good, maybe by adding TensorSpec names to cache keys so renamed
130# keyword arguments would yield different ConcreteFunctions.
# TODO(allenl): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so renamed
# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable."""
  proto = saved_bare_concrete_function
  # Bare concrete functions accept only flat lists of Tensors with unique
  # names.
  concrete_function = concrete_functions[proto.concrete_function_name]
  # pylint: disable=protected-access
  concrete_function._arg_keywords = proto.argument_keywords
  concrete_function._num_positional_args = proto.allowed_positional_arguments
  # pylint: enable=protected-access
  concrete_function.add_to_graph()
  return concrete_function
146
147
class RestoredFunction(def_function.Function):
  """A `def_function.Function` recreated from serialized state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(mdan): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function, name, autograph=False)
    self._function_spec = function_spec
    self._concrete_functions = concrete_functions

  def _list_all_concrete_functions_for_serialization(self):
    # All traces were restored from the SavedModel; nothing new is traced.
    return self._concrete_functions
163
164
def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.
  function_spec = _deserialize_function_spec(
      saved_function.function_spec, nested_structure_coder.StructureCoder())

  def restored_function_body(*args, **kwargs):
    """Calls a restored function."""
    # This matches the format of function.graph.structured_input_signature;
    # at this point args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # Prefer a concrete function callable without input conversions, so a
    # more specific trace wins over a more expensive one that also accepts
    # tensors.
    for allow_conversion in (False, True):
      for candidate_name in saved_function.concrete_functions:
        candidate = concrete_functions[candidate_name]
        if _concrete_function_callable_with(candidate, inputs,
                                            allow_conversion):
          return _call_concrete_function(candidate, inputs)

    available_signatures = [
        concrete_functions[candidate_name].graph.structured_input_signature
        for candidate_name in saved_function.concrete_functions
    ]
    raise ValueError(
        "Could not find matching function to call for inputs %r. "
        "Only existing signatures are %r."
        % (inputs, available_signatures))

  concrete_function_objects = [
      concrete_functions[concrete_function_name]
      for concrete_function_name in saved_function.concrete_functions
  ]

  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      concrete_function_objects)

  # Expose the original argspec on the returned callable so signature
  # inspection matches the function that was saved.
  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)
222
223
def load_function_def_library(library):
  """Load a set of functions as concrete functions without captured inputs.

  Functions names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if functions dependencies have a cycle.
  """
  functions = {}

  # One suffix per load: functions within a load that use the same
  # `shared_name` still share, while separate loads get distinct names.
  load_shared_name_suffix = "_load_{}".format(ops.uid())
  # Topological order guarantees every dependency is already present in
  # `functions` before a function that references it is processed.
  for fdef in _sort_function_defs(library):
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

    func_graph = function_def_lib.function_def_to_graph(copy)
    # Dependency names are taken from the original `fdef`, since the copy's
    # references were rewritten by `_fix_fdef` to the new function names.
    for dep in _list_function_deps(fdef):
      functions[dep].add_to_graph(func_graph)
    func = function_lib.ConcreteFunction(func_graph)
    func.add_to_graph()

    # Keyed by the original (pre-rename) signature name, per the docstring.
    functions[fdef.signature.name] = func

    # Also register the gradients in the current root context.
    with ops.init_scope():
      func._register_gradient()  # pylint: disable=protected-access

  return functions
259
260
def _sort_function_defs(library):
  """Returns a topological sort of the FunctionDefs in `library`.

  Uses Kahn's algorithm: repeatedly emit a function with no unresolved
  dependencies and decrement the in-count of its dependents.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    A list of `FunctionDef` protos in which every function appears after all
    functions it depends on.

  Raises:
    ValueError: if the function dependency graph contains a cycle.
  """
  # edges[name] -> functions that depend on `name`;
  # in_count[name] -> number of unresolved dependencies of `name`.
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fdef in library.function:
    for dep in _list_function_deps(fdef):
      edges[dep].append(fdef.signature.name)
      in_count[fdef.signature.name] += 1

  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    # Bug fix: the original passed two arguments to ValueError (a stray
    # comma instead of string concatenation), which made the exception carry
    # a tuple rather than a single formatted message.
    raise ValueError("There is a cyclic-dependency between functions. "
                     "Could not resolve %r." % (failed_to_resolve,))

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]
292
293
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existent functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
  fdef = function_pb2.FunctionDef()
  fdef.CopyFrom(orig_fdef)
  for node_def in fdef.node_def:
    if "_gradient_op_type" in node_def.attr:
      if node_def.op in ("StatefulPartitionedCall", "PartitionedCall"):
        # TODO(andresp): This code assumes that the gradient registered for this
        # function call is the default gradient for the function and not a
        # custom one.
        called_name = node_def.attr["f"].func.name
        gradient_name = functions[called_name]._gradient_name  # pylint: disable=protected-access
        node_def.attr["_gradient_op_type"].s = compat.as_bytes(gradient_name)
      else:
        logging.warning("Importing a function (%s) with ops with custom "
                        "gradients. Will likely fail if a gradient is "
                        "requested.", fdef.signature.name)
    # Rewrite every function-valued attr (including "f" above, after its
    # original name was used for the gradient lookup) to the loaded name.
    for _, attr_value in node_def.attr.items():
      if attr_value.func.name:
        attr_value.func.name = functions[attr_value.func.name].name

    # TODO(b/124205571): Avoid accidental sharing and destruction of restored
    # resources. For now uniquify "shared_name" when loading functions to avoid
    # sharing.
    if "shared_name" in node_def.attr:
      node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return fdef
338
339
340def _list_function_deps(fdef):
341  # TODO(andresp): Recurse into list attributes and into NameAttrList attrs both
342  # when listing deps and when fixing them. `function_def_to_graph` also
343  # requires fixes.
344  deps = set()
345  for node_def in fdef.node_def:
346    for _, attr_value in node_def.attr.items():
347      if attr_value.WhichOneof("value") == "func":
348        deps.add(attr_value.func.name)
349  return deps
350
351
352def _clean_function_name(name):
353  """Vanity function to keep the function names comprehensible."""
354  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
355  # its name becomes "__inference_<orig>_xyz".
356  match = re.search(r"^__inference_(.*)_\d+$", name)
357  if match:
358    return match.group(1)
359  else:
360    return name
361