1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers to manipulate a tensor graph in python.
16"""
17
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import re

import six

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python import _proto_comparators
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export
34
# `convert_to_constants` is bound through a LazyLoader rather than a regular
# import statement, because a normal import here would generate circular
# dependencies. It is used by `convert_variables_to_constants` in this module.
convert_to_constants = lazy_loader.LazyLoader(
    "convert_to_constants", globals(),
    "tensorflow.python.framework.convert_to_constants")
39
# Op type names that `_is_variable_op` reports as variable-related. When
# `must_run_on_cpu` is called with `pin_variables_on_cpu=True`, a node whose op
# type is in this set is reported as having to run on CPU.
_VARIABLE_OPS = {
    "Assign",
    "AssignAdd",
    "AssignSub",
    "Queue",
    "ScatterAdd",
    "ScatterSub",
    "ScatterUpdate",
    "TruncatedNormal",
    "Variable",
    "VariableV2",
}
52
# Op type names of control flow ops plus "Identity", as the constant's name
# indicates.
# NOTE(review): nothing in the code visible here reads this list; presumably it
# is consumed by other modules -- confirm before changing or removing it.
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY = [
    "Switch",
    "Enter",
    "Exit",
    "Identity",
    "Merge",
    "NextIteration",
]
61
62
def _is_variable_op(op):
  """Tells whether an op type name is one of the variable-related ops.

  Args:
    op: Op type name, e.g. the `op` field of a NodeDef.

  Returns:
    True if `op` is a member of `_VARIABLE_OPS`, False otherwise.
  """
  is_variable_related = op in _VARIABLE_OPS
  return is_variable_related
66
67
@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.must_run_on_cpu`")
@tf_export(v1=["graph_util.must_run_on_cpu"])
def must_run_on_cpu(node, pin_variables_on_cpu=False):
  """Decides whether the given node has to be placed on the CPU.

  A node is reported as CPU-only when any of the following holds:
    * `pin_variables_on_cpu` is True and the node is a variable-related op
      (its op type is in `_VARIABLE_OPS`);
    * it is a `Const` whose `dtype` attr is string or int32;
    * it is a `DynamicStitch` or `ParallelDynamicStitch` whose `T` attr is
      int32;
    * it is a `Cast` whose `SrcT` attr is int32.

  Args:
    node: The node to be assigned to a device. Either an `ops.Operation` or a
      `NodeDef`; anything else trips an assertion.
    pin_variables_on_cpu: If True, this function also returns True when the
      node is a variable-related op.

  Returns:
    True if the given node must run on CPU, otherwise False.
  """
  if not isinstance(node, ops.Operation):
    assert isinstance(node, node_def_pb2.NodeDef)
    node_def = node
  else:
    node_def = node.node_def

  op_type = node_def.op

  # Variable-related ops are pinned to CPU only when the caller asks for it.
  if pin_variables_on_cpu and _is_variable_op(op_type):
    return True

  # Pick the attr that carries the relevant dtype for this op type, together
  # with the dtypes for which the op is reported as CPU-only.
  if op_type == "Const":
    # Constants producing a string or int32 must run on CPU.
    attr_key = "dtype"
    cpu_only_dtypes = (dtypes.string, dtypes.int32)
  elif op_type in ("DynamicStitch", "ParallelDynamicStitch"):
    # (Parallel)DynamicStitch over int32 values is reported as CPU-only.
    attr_key = "T"
    cpu_only_dtypes = (dtypes.int32,)
  elif op_type == "Cast":
    # Cast from an int32 source is reported as CPU-only.
    attr_key = "SrcT"
    cpu_only_dtypes = (dtypes.int32,)
  else:
    return False

  dtype = node_def.attr[attr_key].type
  return any(dtype == cpu_dtype for cpu_dtype in cpu_only_dtypes)
114
115
116################################################################################
117#
118# device functions for use in with g.device(...)
119#
120################################################################################
121
122
123def _node_name(n):
124  if n.startswith("^"):
125    return n[1:]
126  else:
127    return n.split(":")[0]
128
129
130def _get_colocated_node_name(colocated_node_name):
131  """Decodes colocated node name and returns it without loc:@ prepended."""
132  colocated_node_decoded = colocated_node_name.decode("utf-8")
133  if colocated_node_decoded.startswith("loc:@"):
134    return colocated_node_decoded[5:]
135  return colocated_node_decoded
136
137
def _extract_graph_summary(graph_def):
  """Indexes the nodes of a graph by name.

  Args:
    graph_def: Object whose `node` field iterates over NodeDef-like entries
      (with `name`, `input` and `attr`).

  Returns:
    A tuple `(name_to_input_name, name_to_node, name_to_seq_num)` of dicts, all
    keyed by node name:
      * name_to_input_name: names of the nodes this node takes input from,
        followed by the nodes named in its `_class` attr, if any.
      * name_to_node: the node itself.
      * name_to_seq_num: position of the node in `graph_def.node`, so that
        callers can still output the operations in the original order.
  """
  name_to_input_name = {}
  name_to_node = {}
  name_to_seq_num = {}
  for position, node_def in enumerate(graph_def.node):
    key = _node_name(node_def.name)
    dependencies = [_node_name(raw_input) for raw_input in node_def.input]
    # Colocated nodes count as dependencies too, so that they are not lost.
    if "_class" in node_def.attr:
      dependencies.extend(
          _get_colocated_node_name(encoded_name)
          for encoded_name in node_def.attr["_class"].list.s)
    name_to_node[key] = node_def
    name_to_input_name[key] = dependencies
    name_to_seq_num[key] = position
  return name_to_input_name, name_to_node, name_to_seq_num
159
160
161def _assert_nodes_are_present(name_to_node, nodes):
162  """Assert that nodes are present in the graph."""
163  for d in nodes:
164    assert d in name_to_node, "%s is not in graph" % d
165
166
167def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
168  """Breadth first search for reachable nodes from target nodes."""
169  nodes_to_keep = set()
170  # Breadth first search to find all the nodes that we should keep.
171  next_to_visit = target_nodes[:]
172  while next_to_visit:
173    node = next_to_visit[0]
174    del next_to_visit[0]
175    if node in nodes_to_keep:
176      # Already visited this node.
177      continue
178    nodes_to_keep.add(node)
179    if node in name_to_input_name:
180      next_to_visit += name_to_input_name[node]
181  return nodes_to_keep
182
183
@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.extract_sub_graph`")
@tf_export(v1=["graph_util.extract_sub_graph"])
def extract_sub_graph(graph_def, dest_nodes):
  """Builds the sub-graph made of the nodes that `dest_nodes` depend on.

  Starting from `dest_nodes`, node inputs (and nodes named in `_class` attrs)
  are followed transitively. The nodes found are deep-copied into a new
  GraphDef in the order they have in `graph_def`; `library` and `versions` are
  copied over from `graph_def`.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    dest_nodes: A list of strings specifying the destination node names.

  Returns:
    The GraphDef of the sub-graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, or if
      'dest_nodes' is a single string rather than a list of names.
    AssertionError: If a name in 'dest_nodes' is not a node of 'graph_def'.
  """
  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
  if isinstance(dest_nodes, six.string_types):
    raise TypeError("dest_nodes must be a list.")

  input_names_by_node, node_by_name, position_by_name = (
      _extract_graph_summary(graph_def))
  _assert_nodes_are_present(node_by_name, dest_nodes)

  reachable_names = _bfs_for_reachable_nodes(dest_nodes, input_names_by_node)
  # Restore the order the nodes had in the input graph.
  ordered_names = sorted(reachable_names, key=position_by_name.__getitem__)

  sub_graph = graph_pb2.GraphDef()
  sub_graph.node.extend(
      [copy.deepcopy(node_by_name[name]) for name in ordered_names])
  sub_graph.library.CopyFrom(graph_def.library)
  sub_graph.versions.CopyFrom(graph_def.versions)
  return sub_graph
223
224
@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`"
)
@tf_export(v1=["graph_util.tensor_shape_from_node_def_name"])
def tensor_shape_from_node_def_name(graph, input_name):
  """Looks up the shape of the tensor named by a NodeDef input string.

  Args:
    graph: Object providing `get_tensor_by_name`.
    input_name: A NodeDef input string, with or without a port ("Mul:0" or
      "Mul").

  Returns:
    Whatever `get_shape()` returns for the tensor found under that name.
  """
  # Tensor names have the form <input>:<port>, for example 'Mul:0'. GraphDef
  # input strings may leave the port out, in which case port 0 is meant.
  canonical_name = input_name if ":" in input_name else input_name + ":0"
  return graph.get_tensor_by_name(canonical_name).get_shape()
242
243
@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
@tf_export(v1=["graph_util.convert_variables_to_constants"])
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
  """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  The work is delegated to
  `convert_to_constants.convert_variables_to_constants_from_session_graph`;
  this wrapper only maps the argument names and clears the `versions` field of
  the result.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
      all variables are converted). Forwarded as `variable_names_allowlist`.
    variable_names_blacklist: The set of variable names to omit converting to
      constants. Forwarded as `variable_names_denylist`.

  Returns:
    GraphDef containing a simplified version of the original.

  Raises:
    RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both
      denylisted AND whitelisted for freezing.
  """
  frozen_graph_def = (
      convert_to_constants.convert_variables_to_constants_from_session_graph(
          session=sess,
          graph_def=input_graph_def,
          output_node_names=output_node_names,
          variable_names_allowlist=variable_names_whitelist,
          variable_names_denylist=variable_names_blacklist))
  # Earlier versions of this function returned an empty versions field; keep
  # doing so for backwards compatibility.
  frozen_graph_def.versions.Clear()
  return frozen_graph_def
286
287
@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.remove_training_nodes`")
@tf_export(v1=["graph_util.remove_training_nodes"])
def remove_training_nodes(input_graph, protected_nodes=None):
  """Prunes out nodes that aren't needed for inference.

  There are nodes like Identity and CheckNumerics that are only useful
  during training, and can be removed in graphs that will be used for
  nothing but inference. Here we identify and remove them, returning an
  equivalent graph. To be specific, CheckNumerics nodes are always removed, and
  Identity nodes that aren't involved in control edges are spliced out so that
  their input and outputs are directly connected.

  Args:
    input_graph: Model to analyze and prune. Only its `node` field is read; the
      nodes are copied, not modified.
    protected_nodes: An optional list of names of nodes to be kept
      unconditionally. This is for example useful to preserve Identity output
      nodes.

  Returns:
    A new `GraphDef` holding copies of the remaining nodes in their original
    order. Only its `node` field is populated.
  """
  if not protected_nodes:
    protected_nodes = []

  input_nodes = input_graph.node

  # Pass 1: drop every unprotected CheckNumerics node, and drop any input
  # (control or data) of the remaining nodes that names one of them.
  # NOTE(review): such an input is removed, not rewired to the CheckNumerics
  # node's own input, and only inputs equal to the bare node name (optionally
  # prefixed with "^") are matched.
  types_to_remove = {"CheckNumerics"}
  names_to_remove = {
      node.name for node in input_nodes
      if node.op in types_to_remove and node.name not in protected_nodes
  }

  nodes_after_removal = []
  for node in input_nodes:
    if node.name in names_to_remove:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    del new_node.input[:]
    for full_input_name in node.input:
      # Strip the "^" of a control input before looking the name up.
      if re.sub(r"^\^", "", full_input_name) in names_to_remove:
        continue
      new_node.input.append(full_input_name)
    nodes_after_removal.append(new_node)

  # Pass 2: find out which nodes take part in control edges, on either end.
  types_to_splice = {"Identity"}
  control_input_names = set()
  node_names_with_control_input = set()
  for node in nodes_after_removal:
    for node_input in node.input:
      if node_input.startswith("^"):
        control_input_names.add(node_input[1:])
        node_names_with_control_input.add(node.name)

  # Maps the name of each Identity node to splice out to the input that its
  # consumers should read instead.
  names_to_splice = {}
  for node in nodes_after_removal:
    if node.op not in types_to_splice or node.name in protected_nodes:
      continue
    # We don't want to remove nodes that have control edge inputs, because
    # they might be involved in subtle dependency issues that removing them
    # will jeopardize.
    if node.name in node_names_with_control_input:
      continue
    # We also don't want to remove nodes which are used as control edge inputs.
    if node.name in control_input_names:
      continue
    # An Identity left without any input (e.g. its only input was a removed
    # CheckNumerics node) has nothing to be replaced with; keep it as is
    # instead of failing on `node.input[0]`.
    if not node.input:
      continue
    names_to_splice[node.name] = node.input[0]

  # Pass 3: drop the spliced Identity nodes and point their consumers at the
  # Identity's own input, following chains of spliced nodes.
  nodes_after_splicing = []
  for node in nodes_after_removal:
    if node.name in names_to_splice:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    del new_node.input[:]
    for full_input_name in node.input:
      input_name = re.sub(r"^\^", "", full_input_name)
      while input_name in names_to_splice:
        full_input_name = names_to_splice[input_name]
        input_name = re.sub(r"^\^", "", full_input_name)
      new_node.input.append(full_input_name)
    nodes_after_splicing.append(new_node)

  output_graph = graph_pb2.GraphDef()
  output_graph.node.extend(nodes_after_splicing)
  return output_graph
378
379
def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
                     graph_def_2: graph_pb2.GraphDef,
                     treat_nan_as_equal: bool = False) -> bool:
  """Returns True iff the graph def arguments are structurally equivalent.

  The notion of equivalence encoded here checks that the set of NodeDefs in
  the GraphDef's function library and main graph body are identical.
  Additionally, it checks that the functions in the function library are equal
  as sets.

  Both protos are serialized and handed to `_proto_comparators.EqualsGraphDef`
  for the actual comparison.

  Args:
    graph_def_1: Instance of `graph_pb2.GraphDef` to compare.
    graph_def_2: Instance of `graph_pb2.GraphDef` to compare.
    treat_nan_as_equal: Boolean indicating whether or not to treat nan
      floating-point values as equal. It is passed on through
      `_proto_comparators.ProtoComparisonOptions`.

  Returns:
    Boolean indicating structural equivalence as described above.

  Raises:
    TypeError: If either of the GraphDefs are not instances of
      `graph_pb2.GraphDef`.
  """
  if not isinstance(graph_def_1, graph_pb2.GraphDef):
    raise TypeError("graph_def_1 must be a graph_pb2.GraphDef proto.")
  if not isinstance(graph_def_2, graph_pb2.GraphDef):
    raise TypeError("graph_def_2 must be a graph_pb2.GraphDef proto.")

  comparison_options = _proto_comparators.ProtoComparisonOptions(
      treat_nan_as_equal)
  serialized_1 = graph_def_1.SerializeToString()
  serialized_2 = graph_def_2.SerializeToString()
  return _proto_comparators.EqualsGraphDef(serialized_1, serialized_2,
                                           comparison_options)
412