1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A utility function for importing TensorFlow graphs.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import contextlib 21 22from tensorflow.core.framework import graph_pb2 23from tensorflow.python import pywrap_tensorflow as c_api 24from tensorflow.python import tf2 25from tensorflow.python.framework import c_api_util 26from tensorflow.python.framework import device as pydev 27from tensorflow.python.framework import errors 28from tensorflow.python.framework import function 29from tensorflow.python.framework import op_def_registry 30from tensorflow.python.framework import ops 31from tensorflow.python.ops import control_flow_util 32from tensorflow.python.util import compat 33from tensorflow.python.util.deprecation import deprecated_args 34from tensorflow.python.util.tf_export import tf_export 35 36 37def _IsControlInput(input_name): 38 # Expected format: '^operation_name' (control input). 39 return input_name.startswith('^') 40 41 42def _ParseTensorName(tensor_name): 43 """Parses a tensor name into an operation name and output index. 44 45 This function will canonicalize tensor names as follows: 46 47 * "foo:0" -> ("foo", 0) 48 * "foo:7" -> ("foo", 7) 49 * "foo" -> ("foo", 0) 50 * "foo:bar:baz" -> ValueError 51 52 Args: 53 tensor_name: The name of a tensor. 54 55 Returns: 56 A tuple containing the operation name, and the output index. 57 58 Raises: 59 ValueError: If `tensor_name' cannot be interpreted as the name of a tensor. 60 """ 61 components = tensor_name.split(':') 62 if len(components) == 2: 63 # Expected format: 'operation_name:output_index'. 64 try: 65 output_index = int(components[1]) 66 except ValueError: 67 raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,)) 68 return components[0], output_index 69 elif len(components) == 1: 70 # Expected format: 'operation_name' (implicit 0th output). 71 return components[0], 0 72 else: 73 raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,)) 74 75 76@contextlib.contextmanager 77def _MaybeDevice(device): 78 """Applies the given device only if device is not None or empty.""" 79 if device: 80 with ops.device(device): 81 yield 82 else: 83 yield 84 85 86def _ProcessGraphDefParam(graph_def, op_dict): 87 """Type-checks and possibly canonicalizes `graph_def`.""" 88 if not isinstance(graph_def, graph_pb2.GraphDef): 89 # `graph_def` could be a dynamically-created message, so try a duck-typed 90 # approach 91 try: 92 old_graph_def = graph_def 93 graph_def = graph_pb2.GraphDef() 94 graph_def.MergeFrom(old_graph_def) 95 except TypeError: 96 raise TypeError('graph_def must be a GraphDef proto.') 97 else: 98 # If we're using the graph_def provided by the caller, modify graph_def 99 # in-place to add attr defaults to the NodeDefs (this is visible to the 100 # caller). 101 # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py 102 # depends on. It might make sense to move this to meta_graph.py and have 103 # import_graph_def not modify the graph_def argument (we'd have to make sure 104 # this doesn't break anything else.) 105 for node in graph_def.node: 106 if node.op not in op_dict: 107 # Assume unrecognized ops are functions for now. TF_ImportGraphDef will 108 # report an error if the op is actually missing. 109 continue 110 op_def = op_dict[node.op] 111 _SetDefaultAttrValues(node, op_def) 112 113 return graph_def 114 115 116def _ProcessInputMapParam(input_map): 117 """Type-checks and possibly canonicalizes `input_map`.""" 118 if input_map is None: 119 input_map = {} 120 else: 121 if not (isinstance(input_map, dict) and all( 122 isinstance(k, compat.bytes_or_text_types) for k in input_map.keys())): 123 raise TypeError('input_map must be a dictionary mapping strings to ' 124 'Tensor objects.') 125 return input_map 126 127 128def _ProcessReturnElementsParam(return_elements): 129 """Type-checks and possibly canonicalizes `return_elements`.""" 130 if return_elements is None: 131 return None 132 if not all( 133 isinstance(x, compat.bytes_or_text_types) for x in return_elements): 134 raise TypeError('return_elements must be a list of strings.') 135 return tuple(compat.as_str(x) for x in return_elements) 136 137 138def _FindAttrInOpDef(attr_name, op_def): 139 for attr_def in op_def.attr: 140 if attr_name == attr_def.name: 141 return attr_def 142 return None 143 144 145def _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def): 146 """Removes unknown default attrs according to `producer_op_list`. 147 148 Removes any unknown attrs in `graph_def` (i.e. attrs that do not appear in 149 the OpDefs in `op_dict`) that have a default value in `producer_op_list`. 150 151 Args: 152 op_dict: dict mapping operation name to OpDef. 153 producer_op_list: OpList proto. 154 graph_def: GraphDef proto 155 """ 156 producer_op_dict = {op.name: op for op in producer_op_list.op} 157 for node in graph_def.node: 158 # Remove any default attr values that aren't in op_def. 159 if node.op in producer_op_dict: 160 op_def = op_dict[node.op] 161 producer_op_def = producer_op_dict[node.op] 162 # We make a copy of node.attr to iterate through since we may modify 163 # node.attr inside the loop. 164 for key in list(node.attr): 165 if _FindAttrInOpDef(key, op_def) is None: 166 # No attr_def in consumer, look in producer. 167 attr_def = _FindAttrInOpDef(key, producer_op_def) 168 if (attr_def and attr_def.HasField('default_value') and 169 node.attr[key] == attr_def.default_value): 170 # Unknown attr had default value in producer, delete it so it can be 171 # understood by consumer. 172 del node.attr[key] 173 174 175def _ConvertInputMapValues(name, input_map): 176 """Ensures all input map values are tensors. 177 178 This should be called from inside the import name scope. 179 180 Args: 181 name: the `name` argument passed to import_graph_def 182 input_map: the `input_map` argument passed to import_graph_def. 183 184 Returns: 185 An possibly-updated version of `input_map`. 186 187 Raises: 188 ValueError: if input map values cannot be converted due to empty name scope. 189 """ 190 if not all(isinstance(v, ops.Tensor) for v in input_map.values()): 191 if name == '': # pylint: disable=g-explicit-bool-comparison 192 raise ValueError( 193 'tf.import_graph_def() requires a non-empty `name` if `input_map` ' 194 'contains non-Tensor values. Try calling tf.convert_to_tensor() on ' 195 '`input_map` values before calling tf.import_graph_def().') 196 with ops.name_scope('_inputs'): 197 input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()} 198 return input_map 199 200 201def _PopulateTFImportGraphDefOptions(options, prefix, input_map, 202 return_elements): 203 """Populates the TF_ImportGraphDefOptions `options`.""" 204 c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix) 205 c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True) 206 207 for input_src, input_dst in input_map.items(): 208 input_src = compat.as_str(input_src) 209 if input_src.startswith('^'): 210 src_name = compat.as_str(input_src[1:]) 211 dst_op = input_dst._as_tf_output().oper # pylint: disable=protected-access 212 c_api.TF_ImportGraphDefOptionsRemapControlDependency( 213 options, src_name, dst_op) 214 else: 215 src_name, src_idx = _ParseTensorName(input_src) 216 src_name = compat.as_str(src_name) 217 dst_output = input_dst._as_tf_output() # pylint: disable=protected-access 218 c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_idx, 219 dst_output) 220 for name in return_elements or []: 221 if ':' in name: 222 op_name, index = _ParseTensorName(name) 223 op_name = compat.as_str(op_name) 224 c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index) 225 else: 226 c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, 227 compat.as_str(name)) 228 229 230def _ProcessNewOps(graph): 231 """Processes the newly-added TF_Operations in `graph`.""" 232 # Maps from a node to the names of the ops it's colocated with, if colocation 233 # is specified in the attributes. 234 colocation_pairs = {} 235 236 for new_op in graph._add_new_tf_operations(compute_devices=False): # pylint: disable=protected-access 237 original_device = new_op.device 238 new_op._set_device('') # pylint: disable=protected-access 239 colocation_names = _GetColocationNames(new_op) 240 if colocation_names: 241 colocation_pairs[new_op] = colocation_names 242 # Don't set a device for this op, since colocation constraints override 243 # device functions and the original device. Note that this op's device may 244 # still be set by the loop below. 245 # TODO(skyewm): why does it override the original device? 246 else: 247 with _MaybeDevice(original_device): 248 graph._apply_device_functions(new_op) # pylint: disable=protected-access 249 250 # The following loop populates the device field of ops that are colocated 251 # with another op. This is implied by the colocation attribute, but we 252 # propagate the device field for completeness. 253 for op, coloc_op_list in colocation_pairs.items(): 254 coloc_device = None 255 # Find any device in the list of colocated ops that have a device, if it 256 # exists. We assume that if multiple ops have devices, they refer to the 257 # same device. Otherwise, a runtime error will occur since the colocation 258 # property cannot be guaranteed. Note in TF2 colocations have been removed 259 # from the public API and will be considered a hint, so there is no runtime 260 # error. 261 # 262 # One possible improvement is to try to check for compatibility of all 263 # devices in this list at import time here, which would require 264 # implementing a compatibility function for device specs in python. 265 for coloc_op_name in coloc_op_list: 266 try: 267 coloc_op = graph._get_operation_by_name_unsafe(coloc_op_name) # pylint: disable=protected-access 268 except KeyError: 269 # Do not error in TF2 if the colocation cannot be guaranteed 270 if tf2.enabled() or control_flow_util.EnableControlFlowV2(graph): 271 continue 272 273 raise ValueError('Specified colocation to an op that ' 274 'does not exist during import: %s in %s' % 275 (coloc_op_name, op.name)) 276 if coloc_op.device: 277 coloc_device = pydev.DeviceSpec.from_string(coloc_op.device) 278 break 279 if coloc_device: 280 op._set_device(coloc_device) # pylint: disable=protected-access 281 282 283def _GetColocationNames(op): 284 """Returns names of the ops that `op` should be colocated with.""" 285 colocation_names = [] 286 try: 287 class_values = op.get_attr('_class') 288 except ValueError: 289 # No _class attr 290 return 291 for val in class_values: 292 val = compat.as_str(val) 293 if val.startswith('loc:@'): 294 colocation_node_name = val[len('loc:@'):] 295 if colocation_node_name != op.name: 296 colocation_names.append(colocation_node_name) 297 return colocation_names 298 299 300def _GatherReturnElements(requested_return_elements, graph, results): 301 """Returns the requested return elements from results. 302 303 Args: 304 requested_return_elements: list of strings of operation and tensor names 305 graph: Graph 306 results: wrapped TF_ImportGraphDefResults 307 308 Returns: 309 list of `Operation` and/or `Tensor` objects 310 """ 311 return_outputs = c_api.TF_ImportGraphDefResultsReturnOutputs(results) 312 return_opers = c_api.TF_ImportGraphDefResultsReturnOperations(results) 313 314 combined_return_elements = [] 315 outputs_idx = 0 316 opers_idx = 0 317 for name in requested_return_elements: 318 if ':' in name: 319 combined_return_elements.append( 320 graph._get_tensor_by_tf_output(return_outputs[outputs_idx])) # pylint: disable=protected-access 321 outputs_idx += 1 322 else: 323 combined_return_elements.append( 324 graph._get_operation_by_tf_operation(return_opers[opers_idx])) # pylint: disable=protected-access 325 opers_idx += 1 326 return combined_return_elements 327 328 329def _SetDefaultAttrValues(node_def, op_def): 330 """Set any default attr values in `node_def` that aren't present.""" 331 assert node_def.op == op_def.name 332 for attr_def in op_def.attr: 333 key = attr_def.name 334 if attr_def.HasField('default_value'): 335 value = node_def.attr[key] 336 if value is None or value.WhichOneof('value') is None: 337 node_def.attr[key].CopyFrom(attr_def.default_value) 338 339 340@tf_export('graph_util.import_graph_def', 'import_graph_def') 341@deprecated_args(None, 'Please file an issue at ' 342 'https://github.com/tensorflow/tensorflow/issues if you depend' 343 ' on this feature.', 'op_dict') 344def import_graph_def(graph_def, 345 input_map=None, 346 return_elements=None, 347 name=None, 348 op_dict=None, 349 producer_op_list=None): 350 """Imports the graph from `graph_def` into the current default `Graph`. 351 352 This function provides a way to import a serialized TensorFlow 353 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 354 protocol buffer, and extract individual objects in the `GraphDef` as 355 `tf.Tensor` and `tf.Operation` objects. Once extracted, 356 these objects are placed into the current default `Graph`. See 357 `tf.Graph.as_graph_def` for a way to create a `GraphDef` 358 proto. 359 360 Args: 361 graph_def: A `GraphDef` proto containing operations to be imported into 362 the default graph. 363 input_map: A dictionary mapping input names (as strings) in `graph_def` 364 to `Tensor` objects. The values of the named input tensors in the 365 imported graph will be re-mapped to the respective `Tensor` values. 366 return_elements: A list of strings containing operation names in 367 `graph_def` that will be returned as `Operation` objects; and/or 368 tensor names in `graph_def` that will be returned as `Tensor` objects. 369 name: (Optional.) A prefix that will be prepended to the names in 370 `graph_def`. Note that this does not apply to imported function names. 371 Defaults to `"import"`. 372 op_dict: (Optional.) Deprecated, do not use. 373 producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped) 374 list of `OpDef`s used by the producer of the graph. If provided, 375 unrecognized attrs for ops in `graph_def` that have their default value 376 according to `producer_op_list` will be removed. This will allow some more 377 `GraphDef`s produced by later binaries to be accepted by earlier binaries. 378 379 Returns: 380 A list of `Operation` and/or `Tensor` objects from the imported graph, 381 corresponding to the names in `return_elements`, 382 and None if `returns_elements` is None. 383 384 Raises: 385 TypeError: If `graph_def` is not a `GraphDef` proto, 386 `input_map` is not a dictionary mapping strings to `Tensor` objects, 387 or `return_elements` is not a list of strings. 388 ValueError: If `input_map`, or `return_elements` contains names that 389 do not appear in `graph_def`, or `graph_def` is not well-formed (e.g. 390 it refers to an unknown tensor). 391 """ 392 op_dict = op_def_registry.get_registered_ops() 393 394 graph_def = _ProcessGraphDefParam(graph_def, op_dict) 395 input_map = _ProcessInputMapParam(input_map) 396 return_elements = _ProcessReturnElementsParam(return_elements) 397 398 if producer_op_list is not None: 399 # TODO(skyewm): make a copy of graph_def so we're not mutating the argument? 400 _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def) 401 402 graph = ops.get_default_graph() 403 with ops.name_scope(name, 'import', input_map.values()) as scope: 404 # Save unique prefix generated by name_scope 405 if scope: 406 assert scope.endswith('/') 407 prefix = scope[:-1] 408 else: 409 prefix = '' 410 411 # Generate any input map tensors inside name scope 412 input_map = _ConvertInputMapValues(name, input_map) 413 414 scoped_options = c_api_util.ScopedTFImportGraphDefOptions() 415 options = scoped_options.options 416 _PopulateTFImportGraphDefOptions(options, prefix, input_map, 417 return_elements) 418 419 # _ProcessNewOps mutates the new operations. _mutation_lock ensures a 420 # Session.run call cannot occur between creating the TF_Operations in the 421 # TF_GraphImportGraphDefWithResults call and mutating the them in 422 # _ProcessNewOps. 423 with graph._mutation_lock(): # pylint: disable=protected-access 424 with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: 425 try: 426 results = c_api.TF_GraphImportGraphDefWithResults( 427 graph._c_graph, serialized, options) # pylint: disable=protected-access 428 results = c_api_util.ScopedTFImportGraphDefResults(results) 429 except errors.InvalidArgumentError as e: 430 # Convert to ValueError for backwards compatibility. 431 raise ValueError(str(e)) 432 433 # Create _DefinedFunctions for any imported functions. 434 # 435 # We do this by creating _DefinedFunctions directly from `graph_def`, and 436 # adding them to `graph`. Adding an existing function to a TF_Graph is a 437 # no-op, so this only has the effect of updating the Python state (usually 438 # _DefinedFunction.add_to_graph also adds the function to the TF_Graph). 439 # 440 # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph 441 # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph 442 443 _ProcessNewOps(graph) 444 445 if graph_def.library and graph_def.library.function: 446 functions = function.from_library(graph_def.library) 447 for f in functions: 448 f.add_to_graph(graph) 449 450 # Treat input mappings that don't appear in the graph as an error, because 451 # they are likely to be due to a typo. 452 missing_unused_input_keys = ( 453 c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( 454 results.results)) 455 if missing_unused_input_keys: 456 missing_unused_input_keys = [ 457 compat.as_str(s) for s in missing_unused_input_keys 458 ] 459 raise ValueError( 460 'Attempted to map inputs that were not found in graph_def: [%s]' % 461 ', '.join(missing_unused_input_keys)) 462 463 if return_elements is None: 464 return None 465 else: 466 return _GatherReturnElements(return_elements, graph, results.results) 467