# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to convert variables to constants in TensorFlow 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import object_identity

# Lazy load the single eager module to avoid introducing new dependencies for
# graph_util:convert_variables_to_constants (e.g. in
# tensorflow/contrib/session_bundle:session_bundle_py_test).
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")

_CONDITIONAL_OPS = set(["If", "StatelessIf"])
_LOOP_OPS = set(["While", "StatelessWhile"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)
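# V2 control flow ops. They are tracked both in the main graph and inside
# function bodies, since their attributes need to be updated during conversion
# (see _ConverterData._build_node_defs_list).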


class _TensorData(
    collections.namedtuple("_TensorData", ["numpy", "dtype", "index"])):
  """Data about a tensor that was converted to a constant."""
  __slots__ = ()

  @property
  def dtype_attr(self):
    return attr_value_pb2.AttrValue(type=self.dtype)


class _EndPoint(collections.namedtuple("_EndPoint", ["convertible", "index"])):
  """An endpoint in a graph."""
  __slots__ = ()

  def __str__(self):
    return "{}[{}]".format(self.convertible, self.index)


class _Edge(collections.namedtuple("_Edge", ["source", "destination"])):
  """A directed graph edge."""
  __slots__ = ()

  def __str__(self):
    return "{} -> {}".format(self.source, self.destination)


class _Convertible(object):
  """An entity that can have variables converted to constants."""

  def __init__(self, enclosing_graph):
    self._enclosing_graph = enclosing_graph
    self._outgoing_edges = []
    self._converted_self = None

  def converted_self(self):
    """A copy of this Convertible to be modified during conversion.

    Returns:
      Implementations should return the copied instance, which in turn should
      be contained in converted_enclosing_graph(). This instance is the one that
      will be modified during conversion. Its main use will be in the
      implementations of convert_variable_to_constant().
    """
    raise NotImplementedError()

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts a variable in this Convertible and its dependencies.

    This method should make sure that a converted copy of itself is present in
    the converted graph, and that all Convertibles depending on this one also go
    through the same process.

    Args:
      incoming_edge: The graph edge into this Convertible that is being
        converted to a constant.
      tensor_data: The tensor representing the constant.
    """
    raise NotImplementedError()

  def create_edges(self):
    """Calls add_outgoing_edge for all edges known to this Convertible.

    This is used to build the graph dependencies, so that conversion of
    variables to constants can be properly propagated through the graph. Usually
    this method will call add_outgoing_edge() for each of the Convertible's
    inputs.
    """
    raise NotImplementedError()

  def add_outgoing_edge(self, edge):
    """Adds an outgoing edge to the Convertible's list of edges.

    Args:
      edge: The outgoing edge (its source should be 'self').
    """
    self._outgoing_edges.append(edge)

  @property
  def converted_enclosing_graph(self):
    """The graph being converted."""
    return self._enclosing_graph.converted_self()

  @property
  def outgoing_edges(self):
    """The list of edges starting at this Convertible."""
    return self._outgoing_edges


class _Function(_Convertible):
  """A library function Convertible.

  Edges into functions are edges from node _inputs_ into function _inputs_:
  Functions get their input from their callers, not from node outputs, and the
  callers in turn get those values as inputs.
  """

  def __init__(self, function, enclosing_graph):
    super(_Function, self).__init__(enclosing_graph)
    self._function = function
    self._nodes = {
        n.name:
        _Node.new(node=n, function=self, enclosing_graph=enclosing_graph)
        for n in function.node_def
    }

  def __str__(self):
    return self.function.signature.name

  @property
  def function(self):
    return self._function

  @property
  def nodes(self):
    return self._nodes

  def converted_self(self):
    """The Function copy to be converted.

    The copy will be renamed according to the graph's converted_function_names
    map, to ensure the name does not match anything currently in TensorFlow's
    function cache.

    Returns:
      The function instance to be converted.
    """
    if self._converted_self is None:
      old_name = self.function.signature.name
      new_name = self._enclosing_graph.converted_function_names[old_name]
      self.converted_enclosing_graph.rename_function(old_name, new_name)
      self._converted_self = self.converted_enclosing_graph.functions[new_name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts one function argument into a constant.

    Args:
      incoming_edge: The edge into the argument to be converted.
      tensor_data: The constant value.
    """
    function = self.converted_self().function
    index = incoming_edge.destination.index
    function.signature.input_arg[index].type = tensor_data.dtype

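    # Propagate the type change to every node (or nested function) that reads
    # the converted argument.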
    for edge in self.outgoing_edges:
      if edge.source.index == index:
        edge.destination.convertible.convert_variable_to_constant(
            edge, tensor_data)

  def create_edges(self):
    for n in self._nodes.values():
      n.create_edges()


class _Node(_Convertible):
  """A Convertible NodeDef."""

  def __init__(self, node, function, enclosing_graph):
    super(_Node, self).__init__(enclosing_graph)
    self._node = node
    self._function = function

  def __str__(self):
    return self._node.name

  @staticmethod
  def new(node, function, enclosing_graph):
    """Creates a new _Node based on its operation type."""
    if node.op in ["VariableV2", "VarHandleOp", "Placeholder"]:
      return _VarHandle(node, function, enclosing_graph)
    elif node.op == "Case":
      return _Case(node, function, enclosing_graph)
    elif node.op == "Merge":
      return _Merge(node, function, enclosing_graph)
    elif node.op == "PartitionedCall":
      return _PartitionedCall(node, function, enclosing_graph)
    elif node.op == "ReadVariableOp":
      return _ReadVariable(node, function, enclosing_graph)
    elif node.op == "ResourceGather":
      return _ResourceGather(node, function, enclosing_graph)
    elif node.op == "ResourceGatherNd":
      return _ResourceGatherNd(node, function, enclosing_graph)
    elif node.op in ["If", "StatelessIf"]:
      return _If(node, function, enclosing_graph)
    elif node.op in ["While", "StatelessWhile"]:
      return _While(node, function, enclosing_graph)
    elif node.op in [
        "Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]:
      return _Intermediate(node, function, enclosing_graph)
    else:
      return _Node(node, function, enclosing_graph)

  @property
  def node(self):
    return self._node

  @property
  def container(self):
    """The node container (either a graph or a function)."""
    if self._function is not None:
      return self._function.function
    return self._enclosing_graph.graph_def

  def converted_self(self):
    """The NodeDef to be converted.

    Returns:
      The NodeDef to be converted, which can come from either a graph or a
      function. Derived classes should call this (via 'super') to make sure the
      node is retrieved from the right place.
    """
    if self._converted_self is None:
      source = self._function or self._enclosing_graph
      self._converted_self = source.converted_self().nodes[self._node.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    pass

  def create_edges(self):
    for index, name in enumerate(self._node.input):
      # Discard edges from control inputs.
      if name[0] == "^":
        continue
      source = self.resolve_input(name)
      source.convertible.add_outgoing_edge(
          _Edge(source, _EndPoint(self, index)))

  def resolve_input(self, input_name):
    """Resolves an input into its _EndPoint.

    A NodeDef's input name can refer to either global NodeDefs (in the
    GraphDef's node list), a NodeDef in a function's node list, or a Function
    (in the GraphDef's function library). The name can also carry semantic
    information, depending on whether it starts with "^". This method handles
    all that logic in order to find the object to which the input name refers.

    Args:
      input_name: The input name to resolve.

    Returns:
      The object referred to by 'input_name'.
    """

    # The logic below oversimplifies the semantics, but is good enough for the
    # purposes of converting to constants. The introduction of new types of
    # operations may change this, forcing the code to be more generic.
    #
    # In particular, we are assuming that the lack of an index suffix means
    # ":0", when it could mean "all the outputs of a node." This works now
    # because converting to constants relies very little on output types, and
    # when it does it specializes its treatment in dedicated classes.
    name_elts = input_name.split(":")
    source_name = name_elts[0]
    if source_name[0] == "^":
      source_name = source_name[1:]
    source_index = 0
    if len(name_elts) > 1 and name_elts[-1].isnumeric():
      source_index = int(name_elts[-1])

    if self._function is None:
      return _EndPoint(self._enclosing_graph.nodes[source_name], source_index)

    if source_index != 0 or source_name in self._function.nodes:
      return _EndPoint(self._function.nodes[source_name], source_index)

    inputs = [i.name for i in self._function.function.signature.input_arg]
    return _EndPoint(self._function, inputs.index(source_name))

  def update_dtype(self, attr_name, index, dtype):
    """Changes the type of a given input.

    Args:
      attr_name: The NodeDef attribute containing the type to change.
      index: The index of the input type to change.
      dtype: The type to change to.
    """
    attr = self._node.attr[attr_name]
    num_types = 0
    # Check for various 'oneof' possibilities, and update the type if the
    # index is in range.
    if attr.HasField("list"):
      types = attr.list.type
      num_types = len(types)
      if num_types > index:
        types[index] = dtype
        return
    elif attr.HasField("type"):
      num_types = 1
      if index == 0:
        attr.type = dtype
        return
    raise ValueError(
        "Index %d out of range for node(%s).attr(%s), which has %d elements." %
        (index, self._node.name, attr_name, num_types))


class _Intermediate(_Node):
  """Specialization of _Node to intermediate ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    node.update_dtype("T", incoming_edge.destination.index, tensor_data.dtype)
    if "_output_shapes" in node.node.attr:
      del node.node.attr["_output_shapes"]
    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _Merge(_Node):
  """Specialization of _Node to Merge ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # The Merge operation has a single type for all its inputs, the number of
    # which is reflected in the "N" attribute. For the time being, we assume
    # that unilaterally changing all of them at once is OK.
    super(_Merge, self).convert_variable_to_constant(
        _Edge(incoming_edge.source,
              _EndPoint(incoming_edge.destination.convertible, 0)),
        tensor_data)


class _VarHandle(_Node):
  """Specialization of _Node to VarHandleOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    tensor_proto = tensor_util.make_tensor_proto(tensor_data.numpy,
                                                 tensor_data.dtype,
                                                 tensor_data.numpy.shape)

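    # Rewrite this node in place as a Const holding the variable's value, then
    # propagate the conversion to every node reading from it.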
    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Const"
    node.attr["dtype"].CopyFrom(tensor_data.dtype_attr)
    node.attr["value"].tensor.CopyFrom(tensor_proto)

    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _ResourceGather(_Node):
  """Specialization of _Node to ResourceGather."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # We currently skip the conversion if this is inside a function.
    if self._function is not None:
      return
    if self._node.attr["batch_dims"].i != 0:
      raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
    axis_node_name = self._node.name + "/axis"
    axis_dtype = self._node.attr["Tindices"]
    axis_data = np.array(self._node.attr["batch_dims"].i)
    output_axis_node = self.converted_self().container.node.add()
    output_axis_node.name = axis_node_name
    output_axis_node.op = "Const"
    output_axis_node.attr["dtype"].CopyFrom(axis_dtype)
    tensor = tensor_util.make_tensor_proto(
        axis_data, dtype=axis_dtype.type, shape=axis_data.shape)
    output_axis_node.attr["value"].tensor.CopyFrom(tensor)

    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherV2"
    output_node.input.extend(
        [self._node.input[0], self._node.input[1], axis_node_name])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    output_node.attr["Taxis"].CopyFrom(axis_dtype)
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ResourceGatherNd(_Node):
  """Specialization of _Node to ResourceGatherNd."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherNd"
    output_node.input.extend([self._node.input[0], self._node.input[1]])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ReadVariable(_Node):
  """Specialization of _Node to ReadVariableOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Identity"

    node.input.append(self._node.input[0])
    node.attr["T"].CopyFrom(self._node.attr["dtype"])
    if "_class" in self._node.attr:
      node.attr["_class"].CopyFrom(self._node.attr["_class"])

    # If the ReadVariableOp is part of a function, then every node that has
    # this ReadVariableOp as an input will refer to it using a ":value"
    # suffix. We need to change that to ":output".
    if self._function is not None:
      for edge in self.outgoing_edges:
        index = edge.destination.index
        dest = edge.destination.convertible.converted_self()
        if isinstance(dest, _Node):
          input_name_parts = dest.node.input[index].split(":")
          if len(input_name_parts) > 1 and input_name_parts[1] == "value":
            input_name_parts[1] = "output"
            dest.node.input[index] = ":".join(input_name_parts)


class _FunctionCaller(_Node):
  """A base class for Convertibles that reference functions."""

  def __init__(self, node, function, enclosing_graph, first_function_input,
               type_attribute, function_attributes):
    """Initializes a _FunctionCaller.

    Args:
      node: As in _Node.
      function: As in _Node.
      enclosing_graph: As in _Node.
      first_function_input: The index of the first NodeDef input that is tied to
        the function inputs. It is assumed that the rest of the NodeDef inputs
        map one to one to function inputs.
      type_attribute: The name of the NodeDef attribute that defines the input
        types. It is assumed that the types listed here map one-to-one with the
        function inputs (that is, they do _not_ specify types for inputs that
        are not passed to functions).
      function_attributes: The names of the NodeDef attributes containing
        references to functions.
    """
    super(_FunctionCaller, self).__init__(node, function, enclosing_graph)
    self._first_function_input = first_function_input
    self._type_attribute = type_attribute
    self._function_attributes = function_attributes

  def converted_self(self):
    if self._converted_self is None:
      node = super(_FunctionCaller, self).converted_self().node
      converted_names = self._enclosing_graph.converted_function_names
      for attr_name in self._function_attributes:
        attr = node.attr[attr_name]
        if attr.HasField("func"):
          attr.func.name = converted_names[attr.func.name]
        elif attr.HasField("list"):
          for func in attr.list.func:
            func.name = converted_names[func.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    index = incoming_edge.destination.index
    if index >= self._first_function_input:
      node.update_dtype(self._type_attribute,
                        index - self._first_function_input, tensor_data.dtype)

    # The loop below is reasonable but not correct in general:
    # The outgoing edges going into the functions are correct, because the
    # inputs map to the function inputs. But the edges going into other nodes do
    # not take into account the logic of the body function, which may do
    # arbitrary things to the node's output:
    #
    #   while x < 0:
    #     return y
    #
    # In this case, the node's ":0" output may map to its ":1" input. For the
    # time being, then, we only process edges into functions.
    for edge in self.outgoing_edges:
      dest = edge.destination.convertible
      if edge.source.index == index and isinstance(dest, _Function):
        dest.convert_variable_to_constant(edge, tensor_data)

  def create_edges(self):
    """Creates edges related to a function caller.

    Edges from a function caller to its called functions are always edges from
    _inputs_ to _inputs_: a FunctionDef input is given by the caller, based on
    its own inputs.
    """
    super(_FunctionCaller, self).create_edges()
    for attr_name in self._function_attributes:
      attr = self._node.attr[attr_name]
      if attr.HasField("func"):
        function = self._enclosing_graph.functions[attr.func.name]
        for index in range(len(self._node.input) - self._first_function_input):
          self.add_outgoing_edge(
              _Edge(
                  _EndPoint(self, index + self._first_function_input),
                  _EndPoint(function, index)))
      elif attr.HasField("list"):
        for func in attr.list.func:
          function = self._enclosing_graph.functions[func.name]
          for index in range(
              len(self._node.input) - self._first_function_input):
            self.add_outgoing_edge(
                _Edge(
                    _EndPoint(self, index + self._first_function_input),
                    _EndPoint(function, index)))


class _If(_FunctionCaller):
  """Specialization of _Node to If-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_If, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["then_branch", "else_branch"])


class _Case(_FunctionCaller):
  """Specialization of _Node to Case-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_Case, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["branches"])


class _PartitionedCall(_FunctionCaller):
  """Specialization of _Node to PartitionedCall-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_PartitionedCall, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="Tin",
        function_attributes=["f"])


class _While(_FunctionCaller):
  """Specialization of _Node to While-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_While, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="T",
        function_attributes=["body", "cond"])

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    super(_While, self).convert_variable_to_constant(incoming_edge, tensor_data)
    node = self.converted_self()
    if node.node.attr["output_shapes"].list.shape:
      node.node.attr["output_shapes"].list.shape[
          incoming_edge.destination.index].CopyFrom(
              tensor_shape_pb2.TensorShapeProto(dim=[
                  tensor_shape_pb2.TensorShapeProto.Dim(size=dim)
                  for dim in tensor_data.numpy.shape
              ]))

    # The while's body inputs and outputs have the same type, so here we can go
    # ahead and change that function's output type.
    body_name = self._node.attr["body"].func.name
    body = self._enclosing_graph.functions[body_name].converted_self().function
    body.signature.output_arg[
        incoming_edge.destination.index].type = tensor_data.dtype


class _GraphDef(_Convertible):
  """A convertible GraphDef."""

  def __init__(self, graph_def):
    super(_GraphDef, self).__init__(enclosing_graph=None)
    self._graph_def = graph_def
    self._nodes = {
        n.name: _Node.new(node=n, function=None, enclosing_graph=self)
        for n in graph_def.node
    }
    self._functions = {
        f.signature.name: _Function(f, enclosing_graph=self)
        for f in graph_def.library.function
    }
    self.create_edges()
    self._converted_function_names = None

  @property
  def graph_def(self):
    return self._graph_def

  @property
  def nodes(self):
    return self._nodes

  @property
  def functions(self):
    return self._functions

  @property
  def converted_function_names(self):
    """Map from original to new function names.

    In order to avoid conflicts (two functions with the same name, one converted
    and one not), we need to change the name of every converted function to
    something that is hopefully unique.

    Returns:
      Map from original to new suggested function names.
    """
    if self._converted_function_names is None:
      parsed_names = []  # List of (id, base_name, original_name)
      for name in self.functions:
        elements = name.rsplit("_", 1)
        if len(elements) == 2 and elements[1].isnumeric():
          parsed_names.append((int(elements[1]), elements[0], name))
        else:
          parsed_names.append((-1, name, name))
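      # Sort the parsed names so that the renaming is deterministic for a
      # given set of function names.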
      self._converted_function_names = {
          name: "{}_frozen_{}".format(base_name, ops.uid())
          for (_, base_name, name) in sorted(parsed_names)
      }

    return self._converted_function_names

  def rename_function(self, old_name, new_name):
    func = self.functions.pop(old_name)
    func.function.signature.name = new_name
    self.functions[new_name] = func

  def converted_self(self):
    if self._converted_self is None:
      copied_graph = graph_pb2.GraphDef()
      copied_graph.CopyFrom(self._graph_def)
      self._converted_self = _GraphDef(copied_graph)
    return self._converted_self

  def create_edges(self):
    for n in self._nodes.values():
      n.create_edges()
    for f in self._functions.values():
      f.create_edges()


class _ConverterData(object):
  """Container for constant conversion supporting data.

  The data includes the graph being converted, and the pre-converted
  tensors. This class will be specialized for ConcreteFunction and Session-based
  conversions, as the means to obtain that data is different for each case.
  """

  def __init__(self,
               graph_def,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    self._graph_def = graph_def
    self._tensor_data = {}
    self._build_node_defs_list()
    self._variable_names_allowlist = variable_names_allowlist
    self._variable_names_denylist = variable_names_denylist

  @property
  def graph_def(self):
    """The graph to be converted."""
    return self._graph_def

  @property
  def node_defs(self):
    """All the node defs in the graph to be converted.

    Returns:
      A map from node name to the NodeDef for all NodeDefs in the graph, as well
      as all control flow NodeDefs in the functions.
    """
    return self._node_defs

  @property
  def tensor_data(self):
    """A map from tensor name to its converted _TensorData."""
    return self._tensor_data

  def _should_convert(self, name):
    """Checks whether to convert the given variable name to a constant."""
    return (self._variable_names_allowlist is None or
            name in self._variable_names_allowlist) and (
                self._variable_names_denylist is None or
                name not in self._variable_names_denylist)

  def _build_node_defs_list(self):
    """Builds the list of NodeDefs in the GraphDef.

    This list consists of all NodeDefs in the main graph as well as all control
    flow NodeDefs in the functions.

    The remaining NodeDefs in the functions are not included because the op
    names are not unique and the variables are handled differently than in the
    main graph. The control flow ops need to be extracted because their
    attributes need to be updated similarly to the control flow ops in the main
    graph.
    """
    self._node_defs = {node.name: node for node in self._graph_def.node}

    if self._graph_def.library:
      for func in self._graph_def.library.function:
        self._node_defs.update({
            node.name: node
            for node in func.node_def
            if node.op in _CONTROL_FLOW_OPS
        })


class _FunctionConverterData(_ConverterData):
  """Container for ConcreteFunction-based conversion data."""

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if function has stateful ops, not
        properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants.
    """

    self._func = func
    # Inline the graph in order to remove functions when possible.
    graph_def = _run_inline_graph_optimization(func, lower_control_flow,
                                               aggressive_inlining)
    super(_FunctionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)
    self._build_tensor_data()

  def _build_tensor_data(self):
    """Caches the tensor data for all Placeholders in the given function."""
    map_index_to_variable = {}
    for var in self._func.graph.variables:
      for idx, captured_input in enumerate(self._func.captured_inputs):
        if var.handle is captured_input:  # pylint: disable=protected-access
          map_index_to_variable[idx] = var
          break

    # Iterates through all captures which are represented as Placeholders.
    for idx, (val_tensor, name_tensor) in enumerate(self._func.graph.captures):
      tensor_name = name_tensor.name.split(":")[0]
      if not self._should_convert(tensor_name):
        continue
      if idx in map_index_to_variable:
        data = map_index_to_variable[idx].numpy()
      else:
        data = np.array(val_tensor.numpy())
      self._tensor_data[tensor_name] = _TensorData(
          numpy=data,
          dtype=dtypes.as_dtype(data.dtype).as_datatype_enum,
          index=idx)

    # Get data for VariableV2 ops (reference variables) that cannot be lifted.
    for node in self.node_defs.values():
      if node.op == "VariableV2":
        if not self._should_convert(node.name):
          continue
        if node.name not in self.tensor_data:
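          # Read the reference variable's value by adding an Identity on top
          # of it and pruning the function's graph down to just that output.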
          with self._func.graph.as_default():
            identity_node = array_ops.identity(
                self._func.graph.as_graph_element(node.name + ":0"))
          pruned_graph = self._func.prune([], [identity_node.name])()[0]
          self._tensor_data[node.name] = _TensorData(
              numpy=pruned_graph.numpy(),
              dtype=node.attr["dtype"].type,
              index=None)


class _SessionConverterData(_ConverterData):
  """Container for Session-based conversion data."""

  def __init__(self,
               session,
               graph_def,
               output_node_names,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
    super(_SessionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    nodes_to_convert = []
    tensor_names_to_convert = []
    for node in self.graph_def.node:
      if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
        tensor_name = node.name
        if not self._should_convert(tensor_name):
          continue
        if node.op == "VarHandleOp":
          tensor_name = tensor_name + "/Read/ReadVariableOp"
        nodes_to_convert.append(node)
        tensor_names_to_convert.append(tensor_name + ":0")

    if tensor_names_to_convert:
      converted_tensors = session.run(tensor_names_to_convert)
      for node, tensor_value in zip(nodes_to_convert, converted_tensors):
        self._tensor_data[node.name] = _TensorData(
            numpy=tensor_value, dtype=node.attr["dtype"].type, index=None)


def disable_lower_using_switch_merge(graph_def):
  """Set '_lower_using_switch_merge' attributes to False.

  Sets the attribute to False in the NodeDefs in the main graph and the NodeDefs
  in each function's graph.

  Args:
    graph_def: GraphDef proto.

  Returns:
    GraphDef
  """
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.CopyFrom(graph_def)

  def disable_control_flow_lowering(node):
    if node.op in _CONTROL_FLOW_OPS:
      node.attr["_lower_using_switch_merge"].b = False

  for node in output_graph_def.node:
    disable_control_flow_lowering(node)

  if output_graph_def.library:
    for func in output_graph_def.library.function:
      for node in func.node_def:
        disable_control_flow_lowering(node)
  return output_graph_def


def _run_inline_graph_optimization(func, lower_control_flow,
                                   aggressive_inlining):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops not
      properly connected to control outputs).

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)

  # In some cases, a secondary implementation of the function (e.g. for GPU) is
  # written to the "api_implements" attribute (e.g. `tf.keras.layers.LSTM` in
  # TF2 produces a CuDNN-based RNN for GPU).
  # This pass is supposed to inline all function calls, but "api_implements"
  # prevents this from happening. Removing the attribute solves the problem.
  # To learn more about "api_implements", see:
  #   tensorflow/core/grappler/optimizers/implementation_selector.h
  for function in graph_def.library.function:
    if "api_implements" in function.attr:
      del function.attr["api_implements"]

  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name for the variables collections, since it is not
  # needed after being saved to a SavedModel.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    for raw in meta_graph.collection_def[name].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")
  if aggressive_inlining:
    rewrite_options.function_optimization = (
        rewriter_config_pb2.RewriterConfig.AGGRESSIVE)
  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Constructs a concrete function from the `output_graph_def`.

  Args:
    func: ConcreteFunction.
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction.
  """
  # Create a ConcreteFunction from the new GraphDef.
  input_tensors = func.graph.internal_captures
  converted_inputs = object_identity.ObjectIdentitySet(
      [input_tensors[index] for index in converted_input_indices])
  not_converted_inputs = [
      tensor for tensor in func.inputs if tensor not in converted_inputs
  ]
  not_converted_inputs_map = {
      tensor.name: tensor for tensor in not_converted_inputs
  }

  new_input_names = [tensor.name for tensor in not_converted_inputs]
  new_output_names = [tensor.name for tensor in func.outputs]

  # Remove old functions to use updated functions from graph def.
  for f in output_graph_def.library.function:
    if context.context().has_function(f.signature.name):
      context.context().remove_function(f.signature.name)

  new_func = wrap_function.function_from_graph_def(output_graph_def,
                                                   new_input_names,
                                                   new_output_names)

  # Manually propagate shape for input tensors where the shape is not correctly
  # propagated. Scalar shapes are lost when wrapping the function.
  for input_tensor in new_func.inputs:
    input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape)
  return new_func


def _replace_variables_by_constants(converter_data):
  """Replaces variables by constants on a given graph.

  Given a _ConverterData instance with converted variables in its tensor_data
  field, create a new graph where the respective variables are replaced with the
  converted constants.

  Args:
    converter_data: A pre-populated _ConverterData instance.

  Returns:
    The converted graph, and the set of indices of the inputs that were
    converted to constants.
  """
  input_graph = _GraphDef(converter_data.graph_def)

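  # Start the conversion at each variable node; the edges built by _GraphDef
  # propagate the new constants through the rest of the graph.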
  for tensor_name, tensor_data in converter_data.tensor_data.items():
    input_graph.nodes[tensor_name].convert_variable_to_constant(
        None, tensor_data)

  converted_graph = input_graph.converted_self().graph_def

  converted_input_indices = {
      t.index
      for t in converter_data.tensor_data.values()
      if t.index is not None
  }

  return converted_graph, converted_input_indices


def convert_variables_to_constants_v2(func,
                                      lower_control_flow=True,
                                      aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
  the same values. This makes it possible to describe the network fully with a
  single GraphDef file, and allows the removal of a lot of ops related to
  loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

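  Example (a minimal sketch; the variable `v` and function `f` below are
  illustrative):

  ```python
  v = tf.Variable(2.0)

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
  def f(x):
    return v * x

  frozen_func = convert_variables_to_constants_v2(f.get_concrete_function())
  outputs = frozen_func(tf.constant(3.0))  # `v` is now a Const in the graph.
  ```
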
  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs). (default False)

  Returns:
    ConcreteFunction containing a simplified version of the original.
  """

  converter_data = _FunctionConverterData(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining)

  output_graph_def, converted_input_indices = _replace_variables_by_constants(
      converter_data=converter_data)

  return _construct_concrete_function(func, output_graph_def,
                                      converted_input_indices)


def convert_variables_to_constants_v2_as_graph(func,
                                               lower_control_flow=True,
                                               aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  This function works the same as convert_variables_to_constants_v2, but it
  returns the intermediate `GraphDef` as well. This `GraphDef` contains all the
  debug information after all the transformations in the frozen phase.

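  Example (a minimal sketch; `f` is an illustrative `tf.function`, as in the
  `convert_variables_to_constants_v2` example above):

  ```python
  frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(
      f.get_concrete_function())
  # `graph_def` also holds the node debug information added while freezing.
  ```
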
  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs).

  Returns:
    ConcreteFunction containing a simplified version of the original, and also
    the intermediate GraphDef containing the node debug information for the
    transformations in the frozen phase.
  """
  converter_data = _FunctionConverterData(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining)

  output_graph_def, converted_input_indices = _replace_variables_by_constants(
      converter_data=converter_data)

  frozen_func = _construct_concrete_function(func, output_graph_def,
                                             converted_input_indices)
  return frozen_func, output_graph_def


def convert_variables_to_constants_from_session_graph(
    session,
    graph_def,
    output_node_names,
    variable_names_allowlist=None,
    variable_names_denylist=None):
  """Replaces all the variables in a graph with constants of the same values.

  This function works similarly to convert_variables_to_constants_v2, but it
  retrieves the constant values from a Session instead of from a
  ConcreteFunction. This is useful when converting graphs generated from
  TensorFlow V1, where ConcreteFunctions are not available. This also differs
  from graph_util.convert_variables_to_constants in that it supports resource
  variables when V2 control flow constructs are present.

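  Example (a minimal sketch using TF1-style graph building; the variable and
  output names are illustrative):

  ```python
  with tf.Graph().as_default():
    v = tf.compat.v1.get_variable("v", initializer=1.0)
    out = tf.identity(v, name="out")
    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      frozen_graph_def = convert_variables_to_constants_from_session_graph(
          session=sess, graph_def=sess.graph_def, output_node_names=["out"])
  ```
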
  Args:
    session: Active TensorFlow session containing the variables.
    graph_def: A GraphDef to convert.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_allowlist: The set of variable names to convert (by default,
      all variables are converted).
    variable_names_denylist: The set of variable names to omit converting to
      constants.

  Returns:
    An optimized GraphDef.
  """
  # TODO(b/176982859): Find a more satisfying way to update shape information
  # than clearing it, or migrate users to a workflow that does not require
  # freezing.
  for function in graph_def.library.function:
    if "_input_shapes" in function.attr:
      for input_arg, shape_attribute in zip(
          function.signature.input_arg,
          function.attr["_input_shapes"].list.shape):
        if dtypes.as_dtype(input_arg.type) == dtypes.resource:
          shape_attribute.unknown_rank = True
  graph_def, _ = _replace_variables_by_constants(
      converter_data=_SessionConverterData(
          session=session,
          graph_def=graph_def,
          output_node_names=output_node_names,
          variable_names_allowlist=variable_names_allowlist,
          variable_names_denylist=variable_names_denylist))
  return graph_def
