1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tools for deserializing `Function`s."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import re
23from absl import logging
24
25from tensorflow.core.framework import function_pb2
26from tensorflow.core.protobuf import saved_object_graph_pb2
27from tensorflow.python.eager import def_function
28from tensorflow.python.eager import function as function_lib
29from tensorflow.python.framework import func_graph as func_graph_lib
30from tensorflow.python.framework import function_def_to_graph as function_def_lib
31from tensorflow.python.framework import op_def_registry
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_spec
34from tensorflow.python.framework import type_spec
35from tensorflow.python.ops import resource_variable_ops
36from tensorflow.python.saved_model import nested_structure_coder
37from tensorflow.python.util import compat
38from tensorflow.python.util import nest
39from tensorflow.python.util import tf_decorator
40from tensorflow.python.util import tf_inspect
41
42
def _is_tensor(t):
  """Returns whether `t` is a `Tensor` or a `BaseResourceVariable`."""
  tensor_like_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
  return isinstance(t, tensor_like_types)
45
46
47# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
48# structured signature.
def _call_concrete_function(function, inputs):
  """Invokes a restored ConcreteFunction using structured inputs.

  Unlike `function.__call__`, the inputs are given in structured form and any
  input whose expected spec is a `TensorSpec` is converted to a tensor first
  (using the spec's dtype as a hint).

  Note: non-tensor inputs are not checked against the signature here. Callers
  are expected to have validated them with `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
        `function.graph.structured_input_signature`.

  Returns:
    The output of the function, or None when the call yields an `Operation`.
  """
  signature = function.graph.structured_input_signature
  flat_args = nest.flatten_up_to(signature, inputs, expand_composites=True)
  flat_specs = nest.flatten(signature, expand_composites=True)
  # Only the positions described by a TensorSpec are forwarded to the flat
  # call; everything else in the signature is dropped.
  tensor_args = [
      ops.convert_to_tensor(value, dtype_hint=spec.dtype)
      for value, spec in zip(flat_args, flat_specs)
      if isinstance(spec, tensor_spec.TensorSpec)
  ]
  # pylint: disable=protected-access
  outputs = function._call_flat(tensor_args, function._captured_inputs)
  # pylint: enable=protected-access
  return None if isinstance(outputs, ops.Operation) else outputs
79
80
def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns the TensorSpec `arg` would have as a tensor, or None on failure.

  Args:
    arg: Value to attempt to convert with `ops.convert_to_tensor`.
    dtype_hint: dtype hint forwarded to the conversion.

  Returns:
    A `TensorSpec` with the shape and dtype of the converted value, or None if
    the conversion raised `TypeError` or `ValueError`.
  """
  try:
    # The conversion runs inside a throwaway FuncGraph so that it does not
    # pollute the current context.
    scratch_graph = func_graph_lib.FuncGraph(name="guess_conversion")
    with scratch_graph.as_default():
      converted = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      spec = tensor_spec.TensorSpec(
          shape=converted.shape, dtype=converted.dtype)
  except (TypeError, ValueError):
    return None
  return spec
90
91
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Checks whether `inputs` fit the structured signature of `function`.

  Args:
    function: ConcreteFunction whose
      `graph.structured_input_signature` is matched against.
    inputs: Structured inputs, in the same format as the signature.
    allow_conversion: If True, inputs at `TensorSpec` positions are first
      passed through `_try_convert_to_tensor_spec` before being compared.

  Returns:
    True if every flattened input is compatible with the corresponding entry
    of the signature, False otherwise (including when the structures differ).
  """
  signature = function.graph.structured_input_signature
  try:
    flat_args = nest.flatten_up_to(signature, inputs)
  except (TypeError, ValueError):
    # The inputs do not even have the nesting the signature expects.
    return False

  for value, spec in zip(flat_args, nest.flatten(signature)):
    if isinstance(spec, tensor_spec.TensorSpec):
      if allow_conversion:
        # Yields None when the value cannot be converted, which is then
        # rejected by the tensor-like check below.
        value = _try_convert_to_tensor_spec(value, dtype_hint=spec.dtype)
      tensor_like = (
          _is_tensor(value) or isinstance(value, tensor_spec.TensorSpec))
      if (not tensor_like or
          value.dtype != spec.dtype or
          not spec.shape.is_compatible_with(value.shape)):
        return False
    elif isinstance(spec, type_spec.TypeSpec):
      if not spec.is_compatible_with(value):
        return False
    elif _is_tensor(value):
      # Tensors at non-spec positions must be the very same object.
      if value is not spec:
        return False
    else:
      # Plain Python values are compared by equality.
      if value != spec:
        return False
  return True
120
121
def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
  """Builds a non-method `FunctionSpec` from its proto representation.

  Args:
    function_spec_proto: Serialized function spec; its `fullargspec`,
      `input_signature`, `is_method` and `jit_compile` fields are read.
    coder: Object whose `decode_proto` is used to decode the nested
      structures stored in the proto.

  Returns:
    A `function_lib.FunctionSpec` with `is_method=False`. If the proto
    describes a method, its first positional argument is dropped.

  Raises:
    NotImplementedError: If the proto describes a method that has no named
      positional arguments.
  """
  decoded_argspec = coder.decode_proto(function_spec_proto.fullargspec)

  # A method is turned into a plain function by discarding its leading
  # argument.
  positional_names = decoded_argspec.args
  if function_spec_proto.is_method:
    if not positional_names:
      raise NotImplementedError(
          "Missing support to deserialize a method function without a named "
          "'self' argument.")
    positional_names = positional_names[1:]

  fullargspec = tf_inspect.FullArgSpec(
      args=positional_names,
      varargs=decoded_argspec.varargs,
      varkw=decoded_argspec.varkw,
      defaults=decoded_argspec.defaults,
      kwonlyargs=decoded_argspec.kwonlyargs,
      kwonlydefaults=decoded_argspec.kwonlydefaults,
      annotations=decoded_argspec.annotations)
  input_signature = coder.decode_proto(function_spec_proto.input_signature)

  # See `tf.function` and the JitCompile proto for details. Values missing
  # from the table map to None.
  jit_compile_enum = saved_object_graph_pb2.FunctionSpec.JitCompile
  jit_compile_by_enum = {
      jit_compile_enum.DEFAULT: None,
      jit_compile_enum.ON: True,
      jit_compile_enum.OFF: False,
  }
  jit_compile = jit_compile_by_enum.get(function_spec_proto.jit_compile)

  return function_lib.FunctionSpec(
      fullargspec=fullargspec,
      is_method=False,
      input_signature=input_signature,
      jit_compile=jit_compile)
157
158
159# TODO(allenl): The fact that we can't derive ConcreteFunction calling
160# conventions from the serialized input spec right now is unfortunate. Merging
161# these would be good, maybe by adding TensorSpec names to cache keys so renamed
162# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Configures a restored bare concrete function so it can be called.

  Args:
    saved_bare_concrete_function: Proto describing the bare concrete function
      (its name, argument keywords, allowed positional arguments and,
      optionally, a function spec).
    concrete_functions: Map from function name to `ConcreteFunction`.

  Returns:
    The `ConcreteFunction` named by the proto, updated in place and added to
    the graph via `add_to_graph()`.
  """
  saved = saved_bare_concrete_function
  restored = concrete_functions[saved.concrete_function_name]
  # pylint: disable=protected-access
  restored._arg_keywords = saved.argument_keywords
  restored._num_positional_args = saved.allowed_positional_arguments
  if saved.HasField("function_spec"):
    coder = nested_structure_coder.StructureCoder()
    restored._set_function_spec(
        _deserialize_function_spec_as_nonmethod(saved.function_spec, coder))
  # pylint: enable=protected-access
  restored.add_to_graph()
  return restored
182
183
class RestoredFunction(def_function.Function):
  """A `def_function.Function` backed by functions restored from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    """Initializes the wrapper.

    Args:
      python_function: Python callable forwarded to the base class.
      name: Name forwarded to the base class.
      function_spec: Spec stored on this object; its `jit_compile` value is
        forwarded to the base class.
      concrete_functions: The restored concrete functions, stored as-is.
    """
    # TODO(mdan): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function,
        name,
        autograph=False,
        jit_compile=function_spec.jit_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

    # Keep RestoredFunction from flooding users with frequent-tracing
    # warnings.
    self._omit_frequent_tracing_warning = True

  @property
  def _run_functions_eagerly(self):
    """Always False: restored functions never take the eager call path."""
    # The original python function is not available, so the only meaningful
    # thing to do is to call the concrete function graphs under the hood.
    #
    # Calling the bespoke python function (i.e. `restored_function_body`)
    # works only as long as the user passes every required and optional
    # argument; a missing optional argument breaks the call. For this reason
    # the eager call path is skipped altogether, even if the user enabled
    # eager function execution via `tf.config.run_functions_eagerly`.
    return False

  def _list_all_concrete_functions_for_serialization(self):
    """Returns the concrete functions this object was constructed with."""
    return self.concrete_functions

  def _defun_with_scope(self, scope):
    """Returns the base class's result with this object's spec attached."""
    scoped = super(RestoredFunction, self)._defun_with_scope(scope)
    scoped._function_spec = self._function_spec  # pylint: disable=protected-access
    return scoped
223
224
def recreate_function(saved_function, concrete_functions):
  """Builds a `Function` out of a `SavedFunction` proto.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.
      As a side effect of this function, the `FunctionSpec` from
      `saved_function` is added to each `ConcreteFunction` in this map that
      the proto refers to.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.
  coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, since restored functions do
  # not behave as methods, i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec, coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    if not saved_function.concrete_functions:
      raise ValueError("Found zero restored functions for caller function.")
    # Same layout as function.graph.structured_input_signature. The args and
    # kwargs have already been canonicalized by the time we get here.
    inputs = (args, kwargs)

    # A first pass looks for a concrete function that accepts the inputs
    # without conversion, so that a more specific trace wins over a more
    # expensive one that supports tensors. Only then is conversion allowed.
    for allow_conversion in (False, True):
      for candidate_name in saved_function.concrete_functions:
        candidate = concrete_functions[candidate_name]
        if _concrete_function_callable_with(
            candidate, inputs, allow_conversion):
          return _call_concrete_function(candidate, inputs)

    # Nothing matched: describe what was received and every available option.
    def _pretty_format_positional(positional):
      count = len(positional)
      bullets = "\n    * ".join(str(a) for a in positional)
      return "Positional arguments ({} total):\n    * {}".format(count, bullets)

    signature_descriptions = []
    for option_number, candidate_name in enumerate(
        saved_function.concrete_functions, 1):
      positional, keyword = (
          concrete_functions[candidate_name].structured_input_signature)
      signature_descriptions.append(
          "Option {}:\n  {}\n  Keyword arguments: {}".format(
              option_number, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching function to call loaded from the SavedModel. "
        "Got:\n  {}\n  Keyword arguments: {}\n\nExpected "
        "these arguments to match one of the following {} option(s):\n\n{}"
        .format(_pretty_format_positional(args), kwargs,
                len(saved_function.concrete_functions),
                "\n\n".join(signature_descriptions)))

  # Resolve every referenced concrete function before mutating any of them.
  concrete_function_objects = [
      concrete_functions[concrete_function_name]
      for concrete_function_name in saved_function.concrete_functions
  ]
  for concrete_function in concrete_function_objects:
    concrete_function._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)
310
311
def load_function_def_library(library, load_shared_name_suffix=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  The functions are processed in dependency order (see `_sort_function_defs`),
  so every function a definition refers to has already been loaded by the time
  that definition is converted to a graph.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared
      names. Otherwise, a suffix of the form `_load_<uid>` is generated.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  library_function_names = set(fdef.signature.name for fdef in library.function)
  # `functions` is keyed by the name stored in the library, whereas
  # `renamed_functions` is keyed by the name of the newly created
  # ConcreteFunction (used to look up callees when restoring gradients).
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly by
  # fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())
  for fdef in _sort_function_defs(library, library_function_names):
    # Work on a fixed-up copy; the FunctionDef in `library` is left untouched.
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

    # There is no need to copy all functions into the function def graph. It
    # leads to an O(n^2) increase of memory when importing functions and the
    # extra function definitions are a no-op since they were already imported
    # as a function before and passed in explicitly (due to the topological
    # sort import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(copy)
    _restore_gradient_functions(func_graph, renamed_functions)

    # Register the direct dependencies of this function with its graph.
    for dep in _list_function_deps(fdef, library_function_names):
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). A ConcreteFunction that is part of a saved
    # function is set up later by recreate_function(); a bare ConcreteFunction
    # is set up by setup_bare_concrete_function().
    func = function_lib.ConcreteFunction(func_graph)
    func.add_to_graph(graph)

    functions[fdef.signature.name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
      # is fixed. Currently it's leaking memory to maintain bug compatibility
      # with previous behavior.
      func.add_to_graph(ops.get_default_graph())

  return functions
381
382
def _restore_gradient_functions(func_graph, renamed_functions):
  """Sets `_gradient_function` on the function-call ops of `func_graph`.

  Args:
    func_graph: Graph whose operations are scanned.
    renamed_functions: Map from (bytes) function name to the function object
      whose `_get_gradient_function()` provides the gradient.
  """
  call_op_types = ("StatefulPartitionedCall", "PartitionedCall")
  for op in func_graph.get_operations():
    # TODO(andresp): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type not in call_op_types:
      continue
    callee_name = compat.as_bytes(op.node_def.attr["f"].func.name)
    callee = renamed_functions[callee_name]
    op._gradient_function = callee._get_gradient_function()  # pylint: disable=protected-access
393
394
def _sort_function_defs(library, library_function_names):
  """Return a topological sort of the FunctionDefs in a library.

  Dependencies are obtained from `_list_function_deps`; a function is emitted
  only after every function it depends on has been emitted.

  Args:
    library: Object whose `function` field holds the FunctionDefs to sort.
    library_function_names: Names of the functions in the library, forwarded
      to `_list_function_deps`.

  Returns:
    A list with the FunctionDefs of `library`, dependencies first.

  Raises:
    ValueError: If some functions could not be ordered, which is reported as a
      cyclic dependency.
  """
  # edges[dep] lists the functions that depend on `dep`; in_count[name] is the
  # number of dependencies of `name` that have not been emitted yet.
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fdef in library.function:
    for dep in _list_function_deps(fdef, library_function_names):
      edges[dep].append(fdef.signature.name)
      in_count[fdef.signature.name] += 1

  # Start from the functions without dependencies. Looking them up also
  # creates an `in_count` entry for every function in the library.
  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    # The message must be a single string: passing the two halves as separate
    # arguments makes the exception render as a tuple.
    raise ValueError(
        "There is a cyclic-dependency between functions. "
        "Could not resolve %r." % (failed_to_resolve,))

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]
426
427
def _check_op_has_custom_gradients(node_def):
  """Returns True if the op in `node_def` has custom gradients.

  That is the case when the node carries a `_gradient_op_type` attribute and
  is not one of the function-call ops.
  """
  if "_gradient_op_type" not in node_def.attr:
    return False
  return node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]
432
433
def fix_node_def(node_def, functions, shared_name_suffix):
  """Rewrites function references and shared names of `node_def` in place.

  Args:
    node_def: Node to update. It is modified.
    functions: Map from original function name to an object whose `name` is
      the name to use instead.
    shared_name_suffix: Suffix appended to the node's `shared_name`, for ops
      that declare such an attribute.
  """
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    value_kind = attr_value.WhichOneof("value")
    if value_kind == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif value_kind == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so we have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if not op_def:
    return
  attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
  if not attr:
    return

  # Pick the base name: the explicit attribute value, else the op's default,
  # else the node name.
  shared_name = None
  if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
    shared_name = node_def.attr["shared_name"].s
  elif attr.default_value.s:
    shared_name = compat.as_bytes(attr.default_value.s)
  if not shared_name:
    shared_name = compat.as_bytes(node_def.name)

  suffix = compat.as_bytes(shared_name_suffix)
  node_def.attr["shared_name"].s = shared_name + suffix
472
473
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
  """Returns a copy of a FunctionDef adjusted for loading in this context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existing functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
  fixed = function_pb2.FunctionDef()
  fixed.CopyFrom(orig_fdef)

  has_custom_gradients = False
  for node_def in fixed.node_def:
    fix_node_def(node_def, functions, shared_name_suffix)
    # Once one node with custom gradients is found, stop checking.
    has_custom_gradients = (
        has_custom_gradients or _check_op_has_custom_gradients(node_def))

  if has_custom_gradients:
    logging.warning(
        "Importing a function (%s) with ops with custom gradients. Will likely "
        "fail if a gradient is requested.", fixed.signature.name)

  fixed.signature.name = _clean_function_name(fixed.signature.name)
  return fixed
506
507
def _list_function_deps(fdef, library_function_names):
  """Collects the names of the functions referenced by `fdef`.

  Args:
    fdef: FunctionDef whose nodes are scanned.
    library_function_names: Names of the functions in the library; a node
      whose op is one of them counts as a reference to that function.

  Returns:
    A set with the referenced function names.
  """
  # TODO(andresp): Recurse into list attributes and into NameAttrList attrs both
  # when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  referenced = set()
  for node_def in fdef.node_def:
    if node_def.op in library_function_names:
      referenced.add(node_def.op)
      continue
    # Otherwise look for functions passed to the node through its attributes.
    for _, attr_value in node_def.attr.items():
      value_kind = attr_value.WhichOneof("value")
      if value_kind == "func":
        referenced.add(attr_value.func.name)
      elif value_kind == "list":
        referenced.update(fn.name for fn in attr_value.list.func)

  return referenced
526
527
# Matches "<inference prefix><name>_<digits>" and captures "<name>".
# pylint: disable=protected-access
_FUNCTION_WRAPPER_NAME_REGEX = (
    r"^%s(.*)_\d+$" % function_lib._INFERENCE_PREFIX)
# pylint: enable=protected-access
530
531
def _clean_function_name(name):
  """Vanity function that strips wrapper decorations from a function name.

  Args:
    name: Function name, possibly matching `_FUNCTION_WRAPPER_NAME_REGEX`.

  Returns:
    The captured inner name if `name` matches the wrapper pattern, otherwise
    `name` unchanged.
  """
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  return match.group(1) if match else name
541