1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A utility function for importing TensorFlow graphs."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import contextlib
21
22from tensorflow.core.framework import graph_pb2
23from tensorflow.python import pywrap_tensorflow as c_api
24from tensorflow.python import tf2
25from tensorflow.python.framework import c_api_util
26from tensorflow.python.framework import device as pydev
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import function
29from tensorflow.python.framework import op_def_registry
30from tensorflow.python.framework import ops
31from tensorflow.python.ops import control_flow_util
32from tensorflow.python.util import compat
33from tensorflow.python.util.deprecation import deprecated_args
34from tensorflow.python.util.tf_export import tf_export
35
36
37def _IsControlInput(input_name):
38  # Expected format: '^operation_name' (control input).
39  return input_name.startswith('^')
40
41
42def _ParseTensorName(tensor_name):
43  """Parses a tensor name into an operation name and output index.
44
45  This function will canonicalize tensor names as follows:
46
47  * "foo:0"       -> ("foo", 0)
48  * "foo:7"       -> ("foo", 7)
49  * "foo"         -> ("foo", 0)
50  * "foo:bar:baz" -> ValueError
51
52  Args:
53    tensor_name: The name of a tensor.
54
55  Returns:
56    A tuple containing the operation name, and the output index.
57
58  Raises:
59    ValueError: If `tensor_name' cannot be interpreted as the name of a tensor.
60  """
61  components = tensor_name.split(':')
62  if len(components) == 2:
63    # Expected format: 'operation_name:output_index'.
64    try:
65      output_index = int(components[1])
66    except ValueError:
67      raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))
68    return components[0], output_index
69  elif len(components) == 1:
70    # Expected format: 'operation_name' (implicit 0th output).
71    return components[0], 0
72  else:
73    raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))
74
75
76@contextlib.contextmanager
77def _MaybeDevice(device):
78  """Applies the given device only if device is not None or empty."""
79  if device:
80    with ops.device(device):
81      yield
82  else:
83    yield
84
85
def _ProcessGraphDefParam(graph_def, op_dict):
  """Type-checks and possibly canonicalizes `graph_def`.

  Args:
    graph_def: A `GraphDef` proto, or another message that can be merged into
      a `GraphDef`.
    op_dict: dict mapping op type names to `OpDef` protos.

  Returns:
    `graph_def` itself when it is already a `GraphDef`, after default attr
    values have been filled in place. Otherwise, a new `GraphDef` merged from
    `graph_def`.

  Raises:
    TypeError: If `graph_def` cannot be merged into a `GraphDef`.
  """
  if isinstance(graph_def, graph_pb2.GraphDef):
    # The caller's GraphDef is used directly, so attr defaults are added to its
    # NodeDefs in place (this is visible to the caller).
    # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py
    # depends on. It might make sense to move this to meta_graph.py and have
    # import_graph_def not modify the graph_def argument (we'd have to make sure
    # this doesn't break anything else.)
    for node in graph_def.node:
      # Unrecognized ops are assumed to be functions for now; TF_ImportGraphDef
      # will report an error if the op is actually missing.
      if node.op in op_dict:
        _SetDefaultAttrValues(node, op_dict[node.op])
    return graph_def

  # `graph_def` could be a dynamically-created message, so try a duck-typed
  # approach: copy it into a fresh GraphDef.
  try:
    canonical_graph_def = graph_pb2.GraphDef()
    canonical_graph_def.MergeFrom(graph_def)
  except TypeError:
    raise TypeError('graph_def must be a GraphDef proto.')
  return canonical_graph_def
114
115
116def _ProcessInputMapParam(input_map):
117  """Type-checks and possibly canonicalizes `input_map`."""
118  if input_map is None:
119    input_map = {}
120  else:
121    if not (isinstance(input_map, dict) and all(
122        isinstance(k, compat.bytes_or_text_types) for k in input_map.keys())):
123      raise TypeError('input_map must be a dictionary mapping strings to '
124                      'Tensor objects.')
125  return input_map
126
127
128def _ProcessReturnElementsParam(return_elements):
129  """Type-checks and possibly canonicalizes `return_elements`."""
130  if return_elements is None:
131    return None
132  if not all(
133      isinstance(x, compat.bytes_or_text_types) for x in return_elements):
134    raise TypeError('return_elements must be a list of strings.')
135  return tuple(compat.as_str(x) for x in return_elements)
136
137
138def _FindAttrInOpDef(attr_name, op_def):
139  for attr_def in op_def.attr:
140    if attr_name == attr_def.name:
141      return attr_def
142  return None
143
144
def _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def):
  """Removes unknown default attrs according to `producer_op_list`.

  Removes any unknown attrs in `graph_def` (i.e. attrs that do not appear in
  the OpDefs in `op_dict`) that have a default value in `producer_op_list`.

  Nodes whose op has no OpDef in `op_dict` are left untouched. Examples are
  function calls, or ops not registered in this binary. Without a consumer
  OpDef there is no way to tell which attrs are unknown. This matches
  `_ProcessGraphDefParam`, which also skips such nodes; TF_ImportGraphDef
  reports genuinely missing ops.

  Args:
    op_dict: dict mapping operation name to OpDef.
    producer_op_list: OpList proto.
    graph_def: GraphDef proto, modified in place.
  """
  producer_op_dict = {op.name: op for op in producer_op_list.op}
  for node in graph_def.node:
    # Only nodes known to both the producer and the consumer can be stripped.
    # Indexing `op_dict` unconditionally would raise KeyError for ops missing
    # from the consumer registry.
    if node.op not in producer_op_dict or node.op not in op_dict:
      continue
    op_def = op_dict[node.op]
    producer_op_def = producer_op_dict[node.op]
    # Iterate over a copy of the keys since node.attr may be modified below.
    for key in list(node.attr):
      if _FindAttrInOpDef(key, op_def) is not None:
        # The consumer understands this attr; keep it.
        continue
      # No attr_def in consumer, look in producer.
      attr_def = _FindAttrInOpDef(key, producer_op_def)
      if (attr_def and attr_def.HasField('default_value') and
          node.attr[key] == attr_def.default_value):
        # Unknown attr had its default value in the producer; delete it so the
        # node can be understood by the consumer.
        del node.attr[key]
173
174
175def _ConvertInputMapValues(name, input_map):
176  """Ensures all input map values are tensors.
177
178  This should be called from inside the import name scope.
179
180  Args:
181    name: the `name` argument passed to import_graph_def
182    input_map: the `input_map` argument passed to import_graph_def.
183
184  Returns:
185    An possibly-updated version of `input_map`.
186
187  Raises:
188    ValueError: if input map values cannot be converted due to empty name scope.
189  """
190  if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
191    if name == '':  # pylint: disable=g-explicit-bool-comparison
192      raise ValueError(
193          'tf.import_graph_def() requires a non-empty `name` if `input_map` '
194          'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
195          '`input_map` values before calling tf.import_graph_def().')
196    with ops.name_scope('_inputs'):
197      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
198  return input_map
199
200
def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements):
  """Populates the TF_ImportGraphDefOptions `options`.

  Args:
    options: The C `TF_ImportGraphDefOptions` to fill in.
    prefix: Name prefix to apply to imported nodes.
    input_map: dict mapping source names to Tensors. A key starting with '^'
      remaps a control dependency onto the value's op; any other key is a
      tensor name remapped to the value's output.
    return_elements: None, or names to return. Names containing ':' are
      registered as outputs; all others are registered as operations.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)

  for raw_src, dst in input_map.items():
    src = compat.as_str(raw_src)
    if not src.startswith('^'):
      # Data input: remap 'op_name:index' onto the destination tensor.
      op_name, output_index = _ParseTensorName(src)
      c_api.TF_ImportGraphDefOptionsAddInputMapping(
          options, compat.as_str(op_name), output_index,
          dst._as_tf_output())  # pylint: disable=protected-access
      continue
    # Control input: remap '^op_name' onto the destination tensor's op.
    control_src = compat.as_str(src[1:])
    control_dst = dst._as_tf_output().oper  # pylint: disable=protected-access
    c_api.TF_ImportGraphDefOptionsRemapControlDependency(
        options, control_src, control_dst)

  for element in return_elements or []:
    if ':' not in element:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(
          options, compat.as_str(element))
      continue
    op_name, output_index = _ParseTensorName(element)
    c_api.TF_ImportGraphDefOptionsAddReturnOutput(
        options, compat.as_str(op_name), output_index)
228
229
def _ProcessNewOps(graph):
  """Processes the newly-added TF_Operations in `graph`.

  Iterates over the operations returned by
  `graph._add_new_tf_operations(compute_devices=False)` and assigns devices:

  * Every new op first has its device cleared.
  * Ops with colocation constraints ('loc:@' entries in their `_class` attr)
    get no device from device functions. In a second pass, each is given the
    device of the first op it is colocated with that has a device; otherwise
    it keeps the cleared device.
  * All other ops have the graph's device functions applied, inside a device
    scope for their original device when they had one.

  Args:
    graph: The `Graph` whose new operations should be processed.

  Raises:
    ValueError: If an op is colocated with an op name that does not exist in
      `graph`, unless `tf2.enabled()` or control flow v2 is enabled for
      `graph`; in that case the missing colocation target is skipped.
  """
  # Maps from a node to the names of the ops it's colocated with, if colocation
  # is specified in the attributes.
  colocation_pairs = {}

  for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
    # Remember the device recorded for the op, then clear it; a device is
    # re-assigned below (via device functions or colocation).
    original_device = new_op.device
    new_op._set_device('')  # pylint: disable=protected-access
    colocation_names = _GetColocationNames(new_op)
    if colocation_names:
      colocation_pairs[new_op] = colocation_names
      # Don't set a device for this op, since colocation constraints override
      # device functions and the original device. Note that this op's device may
      # still be set by the loop below.
      # TODO(skyewm): why does it override the original device?
    else:
      # Apply the graph's device functions, scoped to the op's original device
      # when it had one (_MaybeDevice is a no-op for an empty device).
      with _MaybeDevice(original_device):
        graph._apply_device_functions(new_op)  # pylint: disable=protected-access

  # The following loop populates the device field of ops that are colocated
  # with another op.  This is implied by the colocation attribute, but we
  # propagate the device field for completeness.
  for op, coloc_op_list in colocation_pairs.items():
    coloc_device = None
    # Find any device in the list of colocated ops that have a device, if it
    # exists.  We assume that if multiple ops have devices, they refer to the
    # same device.  Otherwise, a runtime error will occur since the colocation
    # property cannot be guaranteed.  Note in TF2 colocations have been removed
    # from the public API and will be considered a hint, so there is no runtime
    # error.
    #
    # One possible improvement is to try to check for compatibility of all
    # devices in this list at import time here, which would require
    # implementing a compatibility function for device specs in python.
    for coloc_op_name in coloc_op_list:
      try:
        coloc_op = graph._get_operation_by_name_unsafe(coloc_op_name)  # pylint: disable=protected-access
      except KeyError:
        # Do not error in TF2 if the colocation cannot be guaranteed
        if tf2.enabled() or control_flow_util.EnableControlFlowV2(graph):
          continue

        raise ValueError('Specified colocation to an op that '
                         'does not exist during import: %s in %s' %
                         (coloc_op_name, op.name))
      # The first colocated op with a non-empty device wins.
      if coloc_op.device:
        coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
        break
    if coloc_device:
      op._set_device(coloc_device)  # pylint: disable=protected-access
281
282
283def _GetColocationNames(op):
284  """Returns names of the ops that `op` should be colocated with."""
285  colocation_names = []
286  try:
287    class_values = op.get_attr('_class')
288  except ValueError:
289    # No _class attr
290    return
291  for val in class_values:
292    val = compat.as_str(val)
293    if val.startswith('loc:@'):
294      colocation_node_name = val[len('loc:@'):]
295      if colocation_node_name != op.name:
296        colocation_names.append(colocation_node_name)
297  return colocation_names
298
299
def _GatherReturnElements(requested_return_elements, graph, results):
  """Returns the requested return elements from results.

  The C API reports requested tensors and requested operations as two
  separate lists. This code assumes each list is in the order its elements
  were requested. Walking `requested_return_elements` and taking the next
  entry from the matching list therefore rebuilds a single list in request
  order. A name containing ':' is a tensor; any other name is an operation.

  Args:
    requested_return_elements: list of strings of operation and tensor names
    graph: Graph
    results: wrapped TF_ImportGraphDefResults

  Returns:
    list of `Operation` and/or `Tensor` objects
  """
  tf_outputs = c_api.TF_ImportGraphDefResultsReturnOutputs(results)
  tf_opers = c_api.TF_ImportGraphDefResultsReturnOperations(results)

  next_output = 0
  next_oper = 0
  gathered = []
  for requested in requested_return_elements:
    if ':' not in requested:
      tf_oper = tf_opers[next_oper]
      next_oper += 1
      gathered.append(
          graph._get_operation_by_tf_operation(tf_oper))  # pylint: disable=protected-access
    else:
      tf_output = tf_outputs[next_output]
      next_output += 1
      gathered.append(
          graph._get_tensor_by_tf_output(tf_output))  # pylint: disable=protected-access
  return gathered
327
328
329def _SetDefaultAttrValues(node_def, op_def):
330  """Set any default attr values in `node_def` that aren't present."""
331  assert node_def.op == op_def.name
332  for attr_def in op_def.attr:
333    key = attr_def.name
334    if attr_def.HasField('default_value'):
335      value = node_def.attr[key]
336      if value is None or value.WhichOneof('value') is None:
337        node_def.attr[key].CopyFrom(attr_def.default_value)
338
339
@tf_export('graph_util.import_graph_def', 'import_graph_def')
@deprecated_args(None, 'Please file an issue at '
                 'https://github.com/tensorflow/tensorflow/issues if you depend'
                 ' on this feature.', 'op_dict')
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Note: when `graph_def` is a `GraphDef` proto it may be modified in place.
  Default attr values from the registered OpDefs are filled into its NodeDefs.
  When `producer_op_list` is given, unknown attrs that hold their producer
  default value are also removed.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
      A key of the form '^op_name' remaps control dependencies on `op_name`
      to the op of the corresponding value.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use. Its value is ignored; the ops
      registered in `op_def_registry` are always used.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    or None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor), or `name` is the empty string while
      `input_map` contains non-`Tensor` values.
  """
  # The deprecated `op_dict` argument is intentionally overridden: the
  # registered ops are always used.
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  # Build the C import options: name prefix, input remappings and the
  # requested return elements.
  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        results = c_api.TF_GraphImportGraphDefWithResults(
            graph._c_graph, serialized, options)  # pylint: disable=protected-access
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Assign devices (including colocation-derived devices) to the newly
    # added operations while still holding the mutation lock.
    _ProcessNewOps(graph)

  # Create _DefinedFunctions for any imported functions.
  #
  # We do this by creating _DefinedFunctions directly from `graph_def`, and
  # adding them to `graph`. Adding an existing function to a TF_Graph is a
  # no-op, so this only has the effect of updating the Python state (usually
  # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
  #
  # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
  # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
  if graph_def.library and graph_def.library.function:
    functions = function.from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: [%s]' %
        ', '.join(missing_unused_input_keys))

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)
467