1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tools for deserializing `Function`s."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21import collections
22import re
24from tensorflow.core.framework import function_pb2
25from tensorflow.python.eager import def_function
26from tensorflow.python.eager import function as function_lib
27from tensorflow.python.framework import func_graph as func_graph_lib
28from tensorflow.python.framework import function_def_to_graph as function_def_lib
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.ops import resource_variable_ops
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.saved_model import nested_structure_coder
34from tensorflow.python.util import compat
35from tensorflow.python.util import nest
36from tensorflow.python.util import tf_decorator
37from tensorflow.python.util import tf_inspect
def _is_tensor(t):
  """Returns True if `t` is an `ops.Tensor` or a `ResourceVariable`."""
  tensor_like_types = (ops.Tensor, resource_variable_ops.ResourceVariable)
  return isinstance(t, tensor_like_types)
def _call_concrete_function(function, inputs):
  """Invokes a restored ConcreteFunction on structured `inputs`.

  Unlike `function.__call__`, `inputs` follow the (possibly nested) layout of
  `function.graph.structured_input_signature`. Every entry sitting in a
  `TensorSpec` position is converted to a tensor before the call.

  Note: non-tensor inputs are not validated here. Callers are expected to
  have checked them beforehand with `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
        `function.graph.structured_input_signature`.

  Returns:
    The structured function output, or None when the call returns an
    `ops.Operation`.
  """
  signature = function.graph.structured_input_signature
  flat_args = nest.flatten_up_to(signature, inputs)
  flat_specs = nest.flatten(signature)
  # Only positions described by a TensorSpec are forwarded to `_call_flat`;
  # all other positions are dropped here.
  tensor_args = [
      ops.convert_to_tensor(arg, dtype_hint=spec.dtype)
      for arg, spec in zip(flat_args, flat_specs)
      if isinstance(spec, tensor_spec.TensorSpec)
  ]
  outputs = function._call_flat(tensor_args)  # pylint: disable=protected-access
  # An `ops.Operation` result is surfaced to callers as None.
  return None if isinstance(outputs, ops.Operation) else outputs
def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns the TensorSpec `arg` would have as a tensor, or None.

  Args:
    arg: Value to attempt converting with `ops.convert_to_tensor`.
    dtype_hint: Preferred dtype passed to `ops.convert_to_tensor`.

  Returns:
    A `TensorSpec` carrying the converted tensor's shape and dtype, or None
    if the attempt raised `TypeError` or `ValueError`.
  """
  try:
    # The trial conversion runs inside a throwaway FuncGraph so that it does
    # not pollute the current context.
    scratch_graph = func_graph_lib.FuncGraph(name="guess_conversion")
    with scratch_graph.as_default():
      converted = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      spec = tensor_spec.TensorSpec(shape=converted.shape,
                                    dtype=converted.dtype)
    return spec
  except (TypeError, ValueError):
    return None
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`.

  Args:
    function: ConcreteFunction whose `graph.structured_input_signature`
      describes the expected inputs.
    inputs: Structured inputs to check against that signature.
    allow_conversion: If True, values in `TensorSpec` positions are checked
      after a trial conversion to a tensor (`_try_convert_to_tensor_spec`).
      Otherwise they must already be tensors or `TensorSpec`s.

  Returns:
    True if all of the following hold:
      * `inputs` match the signature's structure;
      * every `TensorSpec` position holds a tensor-like value with the same
        dtype and a compatible shape;
      * every other position holds the value recorded in the signature.
  """
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False
  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        # Yields None when `arg` cannot be converted; rejected just below.
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif _is_tensor(arg):
      # A tensor passed where the signature recorded a non-tensor value.
      # Compare by identity: `!=` on tensors may be elementwise and return a
      # tensor, whose truth value is ambiguous (or unavailable in graph mode).
      if arg is not expected:
        return False
    else:
      if arg != expected:
        return False
  return True
def _deserialize_function_spec(function_spec_proto, coder):
  """Rebuilds a `function_lib.FunctionSpec` from its proto representation.

  Args:
    function_spec_proto: Proto holding the serialized FunctionSpec fields.
    coder: Object whose `decode_proto` method decodes the nested fields
      (e.g. a `nested_structure_coder.StructureCoder`).

  Returns:
    A `function_lib.FunctionSpec`.
  """
  decode = coder.decode_proto
  raw_argspec = decode(function_spec_proto.fullargspec)
  # Re-pack the decoded value field by field into a real FullArgSpec.
  argspec_fields = ("args", "varargs", "varkw", "defaults", "kwonlyargs",
                    "kwonlydefaults", "annotations")
  fullargspec = tf_inspect.FullArgSpec(
      **{field: getattr(raw_argspec, field) for field in argspec_fields})
  return function_lib.FunctionSpec(
      fullargspec,
      function_spec_proto.is_method,
      decode(function_spec_proto.args_to_prepend),
      decode(function_spec_proto.kwargs_to_include),
      decode(function_spec_proto.input_signature))
127# TODO(allenl): The fact that we can't derive ConcreteFunction calling
128# conventions from the serialized input spec right now is unfortunate. Merging
129# these would be good, maybe by adding TensorSpec names to cache keys so renamed
130# keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable.

  Bare concrete functions accept only flat lists of Tensors with unique
  names. The saved keyword names and the number of arguments allowed
  positionally are copied onto the restored function, which is then added
  to the graph.

  Args:
    saved_bare_concrete_function: Proto carrying `concrete_function_name`,
      `argument_keywords` and `allowed_positional_arguments`.
    concrete_functions: Map from function name to `ConcreteFunction`.

  Returns:
    The configured `ConcreteFunction` taken from `concrete_functions`.
  """
  saved = saved_bare_concrete_function
  restored = concrete_functions[saved.concrete_function_name]
  # pylint: disable=protected-access
  restored._arg_keywords = saved.argument_keywords
  restored._num_positional_args = saved.allowed_positional_arguments
  # pylint: enable=protected-access
  restored.add_to_graph()
  return restored
class RestoredFunction(def_function.Function):
  """A `def_function.Function` rebuilt from saved state.

  Keeps the concrete functions recovered at load time together with the
  deserialized FunctionSpec. See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(mdan): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function, name=name, autograph=False)
    # Assigned only after the base initializer has run.
    self._function_spec = function_spec
    self._concrete_functions = concrete_functions

  def _list_all_concrete_functions_for_serialization(self):
    # Hands back the same object that was passed to __init__.
    return self._concrete_functions
def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Each call of the returned callable goes to the first concrete function
  listed in `saved_function` whose input signature accepts the arguments.
  Matches that need no input conversion are tried before those that do.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. Current approach is nesting functions deeper for each
  # serialization cycle.

  coder = nested_structure_coder.StructureCoder()
  function_spec = _deserialize_function_spec(saved_function.function_spec,
                                             coder)

  # NOTE: the inner function's name and docstring are copied onto the
  # returned decorator by `make_decorator`, so both are kept unchanged.
  def restored_function_body(*args, **kwargs):
    """Calls a restored function."""
    # Same layout as `function.graph.structured_input_signature`; args and
    # kwargs have already been canonicalized by the time this runs.
    inputs = (args, kwargs)

    # Try conversion-free matches first, so a more specific trace wins over a
    # more expensive one that would also accept tensors.
    for allow_conversion in (False, True):
      for function_name in saved_function.concrete_functions:
        candidate = concrete_functions[function_name]
        if _concrete_function_callable_with(candidate, inputs,
                                            allow_conversion):
          return _call_concrete_function(candidate, inputs)

    available_signatures = [
        concrete_functions[function_name].graph.structured_input_signature
        for function_name in saved_function.concrete_functions
    ]
    raise ValueError(
        "Could not find matching function to call for inputs %r. "
        "Only existing signatures are %r."
        % (inputs, available_signatures))

  restored_concretes = [
      concrete_functions[concrete_name]
      for concrete_name in saved_function.concrete_functions
  ]
  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      restored_concretes)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)
def load_function_def_library(library):
  """Loads a set of functions as concrete functions without captured inputs.

  Function names are rewritten during load so that they do not overlap with
  previously created ones.

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  loaded = {}

  # A single suffix is computed per call and shared by every function
  # restored here.
  suffix = "_load_{}".format(ops.uid())
  for fdef in _sort_function_defs(library):
    fixed_fdef = _fix_fdef(fdef, loaded, suffix)

    func_graph = function_def_lib.function_def_to_graph(fixed_fdef)
    # The topological order guarantees every dependency is already loaded.
    for dependency in _list_function_deps(fdef):
      loaded[dependency].add_to_graph(func_graph)
    concrete = function_lib.ConcreteFunction(func_graph)
    concrete.add_to_graph()

    loaded[fdef.signature.name] = concrete

    # Also register the gradients in the current root context.
    with ops.init_scope():
      concrete._register_gradient()  # pylint: disable=protected-access

  return loaded
def _sort_function_defs(library):
  """Returns a topological sort of the FunctionDefs in a library.

  Each function appears after every function it references through `func`
  attributes (as reported by `_list_function_deps`).

  Args:
    library: FunctionDefLibrary proto message.

  Returns:
    A list of the FunctionDef protos of `library`, dependencies first.

  Raises:
    ValueError: if the dependencies cannot be resolved, e.g. because of a
      cycle between functions.
  """
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fdef in library.function:
    for dep in _list_function_deps(fdef):
      edges[dep].append(fdef.signature.name)
      in_count[fdef.signature.name] += 1

  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)

  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    # Build one message string: passing the two parts as separate arguments
    # made the ValueError render as a tuple.
    raise ValueError("There is a cyclic-dependency between functions. "
                     "Could not resolve %r." % (failed_to_resolve,))

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existing functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
  fdef = function_pb2.FunctionDef()
  fdef.CopyFrom(orig_fdef)
  for node_def in fdef.node_def:
    if "_gradient_op_type" in node_def.attr:
      if node_def.op in ["StatefulPartitionedCall", "PartitionedCall"]:
        # TODO(andresp): This code assumes that the gradient registered for this
        # function call is the default gradient for the function and not a
        # custom one.
        # Read before the rename loop below, so this is the original name.
        fname = node_def.attr["f"].func.name
        node_def.attr["_gradient_op_type"].s = compat.as_bytes(
            functions[fname]._gradient_name)  # pylint: disable=protected-access
      else:
        logging.warning("Importing a function (%s) with ops with custom "
                        "gradients. Will likely fail if a gradient is "
                        "requested.", fdef.signature.name)
    # Point function-valued attrs at the renamed functions loaded earlier.
    for _, attr_value in node_def.attr.items():
      if attr_value.func.name:
        attr_value.func.name = functions[attr_value.func.name].name

    # TODO(b/124205571): Avoid accidental sharing and destruction of restored
    # resources. For now uniquify "shared_name" when loading functions to avoid
    # sharing.
    if "shared_name" in node_def.attr:
      # Appending the suffix to an empty shared_name would give every such
      # node in this load the same name (just the suffix) and so make them
      # share one resource. Fall back to the node name to keep them distinct.
      shared_name = (node_def.attr["shared_name"].s or
                     compat.as_bytes(node_def.name))
      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return fdef
340def _list_function_deps(fdef):
341  # TODO(andresp): Recurse into list attributes and into NameAttrList attrs both
342  # when listing deps and when fixing them. `function_def_to_graph` also
343  # requires fixes.
344  deps = set()
345  for node_def in fdef.node_def:
346    for _, attr_value in node_def.attr.items():
347      if attr_value.WhichOneof("value") == "func":
348        deps.add(attr_value.func.name)
349  return deps
352def _clean_function_name(name):
353  """Vanity function to keep the function names comprehensible."""
354  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
355  # its name becomes "__inference_<orig>_xyz".
356  match = re.search(r"^__inference_(.*)_\d+$", name)
357  if match:
358    return match.group(1)
359  else:
360    return name