1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tools for deserializing `Function`s.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import re 23 24from tensorflow.core.framework import function_pb2 25from tensorflow.python.eager import def_function 26from tensorflow.python.eager import function as function_lib 27from tensorflow.python.framework import func_graph as func_graph_lib 28from tensorflow.python.framework import function_def_to_graph as function_def_lib 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_spec 31from tensorflow.python.ops import resource_variable_ops 32from tensorflow.python.platform import tf_logging as logging 33from tensorflow.python.saved_model import nested_structure_coder 34from tensorflow.python.util import compat 35from tensorflow.python.util import nest 36from tensorflow.python.util import tf_decorator 37from tensorflow.python.util import tf_inspect 38 39 40def _is_tensor(t): 41 return isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable)) 42 43 44def _call_concrete_function(function, inputs): 45 """Calls a restored Function with structured inputs. 46 47 This differs from `function.__call__` in that inputs and outputs are 48 structured and that it casts inputs to tensors if needed. 49 50 Note: this does not checks that non-tensor inputs match. That should be 51 done before via `_concrete_function_callable_with`. 52 53 Args: 54 function: ConcreteFunction to call. 55 inputs: Structured inputs compatible with 56 `function.graph.structured_input_signature`. 57 58 Returns: 59 The structured function output. 60 """ 61 expected_structure = function.graph.structured_input_signature 62 flatten_inputs = nest.flatten_up_to(expected_structure, inputs) 63 tensor_inputs = [] 64 for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): 65 if isinstance(expected, tensor_spec.TensorSpec): 66 tensor_inputs.append( 67 ops.convert_to_tensor(arg, dtype_hint=expected.dtype)) 68 result = function._call_flat(tensor_inputs) # pylint: disable=protected-access 69 if isinstance(result, ops.Operation): 70 return None 71 return result 72 73 74def _try_convert_to_tensor_spec(arg, dtype_hint): 75 """Returns None or TensorSpec obtained if `arg` is converted to tensor.""" 76 try: 77 # Note: try conversion in a FuncGraph to avoid poluting current context. 78 with func_graph_lib.FuncGraph(name="guess_conversion").as_default(): 79 result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint) 80 return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype) 81 except (TypeError, ValueError): 82 return None 83 84 85def _concrete_function_callable_with(function, inputs, allow_conversion): 86 """Returns whether concrete `function` can be called with `inputs`.""" 87 expected_structure = function.graph.structured_input_signature 88 try: 89 flatten_inputs = nest.flatten_up_to(expected_structure, inputs) 90 except (TypeError, ValueError): 91 return False 92 for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): 93 if isinstance(expected, tensor_spec.TensorSpec): 94 if allow_conversion: 95 arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype) 96 if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec): 97 return False 98 if arg.dtype != expected.dtype: 99 return False 100 if not expected.shape.is_compatible_with(arg.shape): 101 return False 102 else: 103 if arg != expected: 104 return False 105 return True 106 107 108def _deserialize_function_spec(function_spec_proto, coder): 109 """Deserialize a FunctionSpec object from its proto representation.""" 110 typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec) 111 fullargspec = tf_inspect.FullArgSpec( 112 args=typeless_fullargspec.args, 113 varargs=typeless_fullargspec.varargs, 114 varkw=typeless_fullargspec.varkw, 115 defaults=typeless_fullargspec.defaults, 116 kwonlyargs=typeless_fullargspec.kwonlyargs, 117 kwonlydefaults=typeless_fullargspec.kwonlydefaults, 118 annotations=typeless_fullargspec.annotations) 119 is_method = function_spec_proto.is_method 120 args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend) 121 kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include) 122 input_signature = coder.decode_proto(function_spec_proto.input_signature) 123 return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend, 124 kwargs_to_include, input_signature) 125 126 127# TODO(allenl): The fact that we can't derive ConcreteFunction calling 128# conventions from the serialized input spec right now is unfortunate. Merging 129# these would be good, maybe by adding TensorSpec names to cache keys so renamed 130# keyword arguments would yield different ConcreteFunctions. 131def setup_bare_concrete_function(saved_bare_concrete_function, 132 concrete_functions): 133 """Makes a restored bare concrete function callable.""" 134 # Bare concrete functions accept only flat lists of Tensors with unique 135 # names. 136 concrete_function = concrete_functions[ 137 saved_bare_concrete_function.concrete_function_name] 138 # pylint: disable=protected-access 139 concrete_function._arg_keywords = ( 140 saved_bare_concrete_function.argument_keywords) 141 concrete_function._num_positional_args = ( 142 saved_bare_concrete_function.allowed_positional_arguments) 143 # pylint: enable=protected-access 144 concrete_function.add_to_graph() 145 return concrete_function 146 147 148class RestoredFunction(def_function.Function): 149 """Wrapper class for a function that has been restored from saved state. 150 151 See `def_function.Function`. 152 """ 153 154 def __init__(self, python_function, name, function_spec, concrete_functions): 155 # TODO(mdan): We may enable autograph once exceptions are supported. 156 super(RestoredFunction, self).__init__( 157 python_function, name, autograph=False) 158 self._concrete_functions = concrete_functions 159 self._function_spec = function_spec 160 161 def _list_all_concrete_functions_for_serialization(self): 162 return self._concrete_functions 163 164 165def recreate_function(saved_function, concrete_functions): 166 """Creates a `Function` from a `SavedFunction`. 167 168 Args: 169 saved_function: `SavedFunction` proto. 170 concrete_functions: map from function name to `ConcreteFunction`. 171 172 Returns: 173 A `Function`. 174 """ 175 # TODO(andresp): Construct a `Function` with the cache populated 176 # instead of creating a new `Function` backed by a Python layer to 177 # glue things together. Current approach is nesting functions deeper for each 178 # serialization cycle. 179 180 coder = nested_structure_coder.StructureCoder() 181 function_spec = _deserialize_function_spec(saved_function.function_spec, 182 coder) 183 184 def restored_function_body(*args, **kwargs): 185 """Calls a restored function.""" 186 # This is the format of function.graph.structured_input_signature. At this 187 # point, the args and kwargs have already been canonicalized. 188 inputs = (args, kwargs) 189 190 # First try to find a concrete function that can be called without input 191 # conversions. This allows one to pick a more specific trace in case there 192 # was also a more expensive one that supported tensors. 193 for allow_conversion in [False, True]: 194 for function_name in saved_function.concrete_functions: 195 function = concrete_functions[function_name] 196 if _concrete_function_callable_with(function, inputs, allow_conversion): 197 return _call_concrete_function(function, inputs) 198 199 available_signatures = [ 200 concrete_functions[function_name].graph.structured_input_signature 201 for function_name in saved_function.concrete_functions 202 ] 203 raise ValueError( 204 "Could not find matching function to call for inputs %r. " 205 "Only existing signatures are %r." 206 % (inputs, available_signatures)) 207 208 concrete_function_objects = [] 209 for concrete_function_name in saved_function.concrete_functions: 210 concrete_function_objects.append(concrete_functions[concrete_function_name]) 211 212 restored_function = RestoredFunction( 213 restored_function_body, 214 restored_function_body.__name__, 215 function_spec, 216 concrete_function_objects) 217 218 return tf_decorator.make_decorator( 219 restored_function_body, 220 restored_function, 221 decorator_argspec=function_spec.fullargspec) 222 223 224def load_function_def_library(library): 225 """Load a set of functions as concrete functions without captured inputs. 226 227 Functions names are manipulated during load such that they do not overlap 228 with previously created ones. 229 230 Args: 231 library: FunctionDefLibrary proto message. 232 233 Returns: 234 Map of original function names in the library to instances of 235 `ConcreteFunction` without captured inputs. 236 237 Raises: 238 ValueError: if functions dependencies have a cycle. 239 """ 240 functions = {} 241 242 load_shared_name_suffix = "_load_{}".format(ops.uid()) 243 for fdef in _sort_function_defs(library): 244 copy = _fix_fdef(fdef, functions, load_shared_name_suffix) 245 246 func_graph = function_def_lib.function_def_to_graph(copy) 247 for dep in _list_function_deps(fdef): 248 functions[dep].add_to_graph(func_graph) 249 func = function_lib.ConcreteFunction(func_graph) 250 func.add_to_graph() 251 252 functions[fdef.signature.name] = func 253 254 # Also register the gradients in the current root context. 255 with ops.init_scope(): 256 func._register_gradient() # pylint: disable=protected-access 257 258 return functions 259 260 261def _sort_function_defs(library): 262 """Return a topologic sort of FunctionDefs in a library.""" 263 edges = collections.defaultdict(list) 264 in_count = collections.defaultdict(lambda: 0) 265 266 for fdef in library.function: 267 for dep in _list_function_deps(fdef): 268 edges[dep].append(fdef.signature.name) 269 in_count[fdef.signature.name] += 1 270 271 ready = [ 272 fdef.signature.name 273 for fdef in library.function 274 if in_count[fdef.signature.name] == 0 275 ] 276 output = [] 277 while ready: 278 node = ready.pop() 279 output.append(node) 280 for dest in edges[node]: 281 in_count[dest] -= 1 282 if not in_count[dest]: 283 ready.append(dest) 284 285 if len(output) != len(library.function): 286 failed_to_resolve = sorted(set(in_count.keys()) - set(output)) 287 raise ValueError("There is a cyclic-dependency between functions. ", 288 "Could not resolve %r." % (failed_to_resolve,)) 289 290 reverse = {fdef.signature.name: fdef for fdef in library.function} 291 return [reverse[x] for x in output] 292 293 294def _fix_fdef(orig_fdef, functions, shared_name_suffix): 295 """Fixes a FunctionDef proto to be loaded in current context. 296 297 In particular, when loading a function library into an eager context, one 298 must rename the functions to avoid conflicts with existent functions. 299 300 Args: 301 orig_fdef: FunctionDef proto to fix. It is not modified. 302 functions: map from function name to a ConcreteFunction instance. 303 shared_name_suffix: A unique string for this load which helps to avoid 304 `shared_name` collisions across loads. Two functions from the same load 305 using the same `shared_name` still need to share, but functions from 306 different loads with the same `shared_name` should not. 307 308 Returns: 309 A fixed copy of the original FunctionDef. 310 """ 311 fdef = function_pb2.FunctionDef() 312 fdef.CopyFrom(orig_fdef) 313 for node_def in fdef.node_def: 314 if "_gradient_op_type" in node_def.attr: 315 if node_def.op in ["StatefulPartitionedCall", "PartitionedCall"]: 316 # TODO(andresp): This code assumes that the gradient registered for this 317 # function call is the default gradient for the function and not a 318 # custom one. 319 fname = node_def.attr["f"].func.name 320 node_def.attr["_gradient_op_type"].s = compat.as_bytes( 321 functions[fname]._gradient_name) # pylint: disable=protected-access 322 else: 323 logging.warning("Importing a function (%s) with ops with custom " 324 "gradients. Will likely fail if a gradient is " 325 "requested.", fdef.signature.name) 326 for _, attr_value in node_def.attr.items(): 327 if attr_value.func.name: 328 attr_value.func.name = functions[attr_value.func.name].name 329 330 # TODO(b/124205571): Avoid accidental sharing and destruction of restored 331 # resources. For now uniquify "shared_name" when loading functions to avoid 332 # sharing. 333 if "shared_name" in node_def.attr: 334 node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix) 335 336 fdef.signature.name = _clean_function_name(fdef.signature.name) 337 return fdef 338 339 340def _list_function_deps(fdef): 341 # TODO(andresp): Recurse into list attributes and into NameAttrList attrs both 342 # when listing deps and when fixing them. `function_def_to_graph` also 343 # requires fixes. 344 deps = set() 345 for node_def in fdef.node_def: 346 for _, attr_value in node_def.attr.items(): 347 if attr_value.WhichOneof("value") == "func": 348 deps.add(attr_value.func.name) 349 return deps 350 351 352def _clean_function_name(name): 353 """Vanity function to keep the function names comprehensible.""" 354 # Note: each time a function is wrapped into `function_lib.ConcreteFunction` 355 # its name becomes "__inference_<orig>_xyz". 356 match = re.search(r"^__inference_(.*)_\d+$", name) 357 if match: 358 return match.group(1) 359 else: 360 return name 361