1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers to convert variables to constants in TensorFlow 2.0."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import graph_pb2
22from tensorflow.core.protobuf import config_pb2
23from tensorflow.core.protobuf import meta_graph_pb2
24from tensorflow.python.eager import function
25from tensorflow.python.framework import func_graph
26from tensorflow.python.framework import importer
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_util
29from tensorflow.python.grappler import tf_optimizer
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.training.saver import export_meta_graph
32
33
def _run_inline_graph_optimization(func):
  """Runs Grappler's function inlining pass over the graph of `func`.

  The graph of `func` is exported as a MetaGraphDef, the names of the
  function's input and output tensors are recorded in its 'train_op'
  collection, and Grappler is invoked with "function" as the only optimizer
  that is explicitly requested. This optimization does not work on models with
  control flow.

  Args:
    func: ConcreteFunction.

  Returns:
    GraphDef produced by the optimizer.
  """
  concrete_graph = func.graph
  meta_graph = export_meta_graph(
      graph_def=concrete_graph.as_graph_def(), graph=concrete_graph)

  # Grappler learns which tensors are the outputs from the 'train_op'
  # collection, so list every input and output tensor name there.
  train_op_collection = meta_graph_pb2.CollectionDef()
  train_op_collection.node_list.value.extend(
      [tensor.name for tensor in func.inputs + func.outputs])
  meta_graph.collection_def["train_op"].CopyFrom(train_op_collection)

  # Initialize RewriterConfig with everything disabled except function inlining.
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.optimizers.append("function")
  return tf_optimizer.OptimizeGraph(config, meta_graph)
60
61
def _get_tensors_from_graph(graph, tensors):
  """Looks up the counterparts of `tensors` in `graph` by tensor name.

  A tensor found in `graph` whose shape has an unknown rank is given the shape
  of the tensor it was looked up from.

  Args:
    graph: TensorFlow Graph.
    tensors: List of Tensors.

  Returns:
    List of Tensors from `graph`, in the same order as `tensors`.
  """

  def _lookup(reference):
    """Fetches the same-named tensor in `graph`, restoring a lost shape."""
    found = graph.get_tensor_by_name(reference.name)
    if found.shape.rank is None:
      found.set_shape(reference.shape)
    return found

  return [_lookup(reference) for reference in tensors]
79
80
def _construct_concrete_function(input_func, graph_def):
  """Creates a ConcreteFunction from the input function and frozen graph.

  The graph_def is imported into a new FuncGraph that is named after the graph
  of `input_func`. The inputs and outputs of the new graph are the tensors that
  carry the same names as the inputs and outputs of `input_func`. The
  structured signature, captured inputs, variables and argument bookkeeping
  (`_arg_keywords`, `_num_positional_args`) are copied from `input_func`.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.

  Returns:
    ConcreteFunction containing the graph_def.
  """
  output_graph = func_graph.FuncGraph(input_func.graph.name)
  with output_graph.as_default():
    # An empty name scope keeps the imported tensor names identical to those in
    # `input_func`, which the by-name lookups below rely on.
    importer.import_graph_def(graph_def, name="")
    output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                  input_func.inputs)
    output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                   input_func.outputs)

  output_graph.structured_outputs = input_func.graph.structured_outputs
  output_graph.structured_input_signature = (
      input_func.graph.structured_input_signature)

  # Create the ConcreteFunction and add it to the global context.
  output_func = function.ConcreteFunction(output_graph)
  output_func.add_to_graph()

  # Inject the captured inputs into the ConcreteFunction.
  output_func._captured_inputs = input_func.captured_inputs  # pylint: disable=protected-access
  output_func.graph.variables = input_func.graph.variables

  # Copy the argument bookkeeping. The positional-argument count has to be
  # stored under the same attribute name it is read from on `input_func`;
  # a differently spelled attribute would leave the real one untouched.
  output_func._arg_keywords = input_func._arg_keywords  # pylint: disable=protected-access
  output_func._num_positional_args = input_func._num_positional_args  # pylint: disable=protected-access

  # Register the gradients in the current root context.
  with ops.init_scope():
    output_func._register_gradient()  # pylint: disable=protected-access
  return output_func
118
119
def convert_variables_to_constants_v2(func):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
  the same values. This makes it possible to describe the network fully with a
  single GraphDef file, and allows the removal of a lot of ops related to
  loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The conversion rewrites three kinds of nodes and copies every other node
  unchanged:
    * a Placeholder that feeds a ReadVariableOp becomes a Const holding the
      variable's value,
    * an Identity between such a Placeholder and the ReadVariableOp has its "T"
      attr set to the variable's dtype,
    * the ReadVariableOp itself becomes an Identity.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  Args:
    func: ConcreteFunction.

  Returns:
    ConcreteFunction containing a simplified version of the original.

  Raises:
    ValueError: If the input chain of a ReadVariableOp does not end in a
      Placeholder op.
  """
  # TODO(nupurgarg): Replace ResourceGather with Gather.
  # TODO(nupurgarg): Change attr for Variables in control flow and functions.
  inlined_graph_def = _run_inline_graph_optimization(func)

  def _op_name(tensor_name):
    """Returns the part of `tensor_name` that precedes the first ':'."""
    return tensor_name.split(":")[0]

  nodes_by_name = {}
  for graph_node in inlined_graph_def.node:
    nodes_by_name[_op_name(graph_node.name)] = graph_node

  # TODO(b/125838789): Use `func.graph.captures`.
  # Map the name of each captured-input placeholder to its variable's value.
  # The trailing entries of `func.inputs` are taken to line up one-to-one with
  # `func.captured_inputs`.
  captured = func.captured_inputs
  capture_placeholders = func.inputs[-len(captured):]
  value_by_placeholder = {}
  for variable in func.graph.variables:
    placeholder = capture_placeholders[captured.index(variable.handle)]
    value_by_placeholder[_op_name(placeholder.name)] = variable.numpy()

  # Walk back from every ReadVariableOp to the Placeholder it reads from,
  # remembering the dtype for the Placeholder and for each Identity in between.
  identity_dtypes = {}
  placeholder_info = {}
  for graph_node in inlined_graph_def.node:
    if graph_node.op != "ReadVariableOp":
      continue

    dtype_attr = graph_node.attr["dtype"]
    source_name = _op_name(graph_node.input[0])
    source_node = nodes_by_name[source_name]
    while source_node.op == "Identity":
      identity_dtypes[source_name] = dtype_attr
      source_name = _op_name(source_node.input[0])
      source_node = nodes_by_name[source_name]

    if source_node.op != "Placeholder":
      raise ValueError("Cannot find the Placeholder op that is an input "
                       "to the ReadVariableOp.")

    placeholder_info[source_name] = {
        "dtype": dtype_attr,
        "data": value_by_placeholder[source_name],
    }

  # Reconstruct the graph with constants in place of variables.
  frozen_graph_def = graph_pb2.GraphDef()
  converted_count = 0

  for original_node in inlined_graph_def.node:
    frozen_node = frozen_graph_def.node.add()
    node_name = original_node.name

    # Placeholder feeding a ReadVariableOp: emit a Const with the value.
    if node_name in placeholder_info:
      info = placeholder_info[node_name]
      const_dtype = info["dtype"]
      const_value = info["data"]

      frozen_node.name = node_name
      frozen_node.op = "Const"
      frozen_node.attr["dtype"].CopyFrom(const_dtype)
      value_proto = tensor_util.make_tensor_proto(
          const_value, dtype=const_dtype.type, shape=const_value.shape)
      frozen_node.attr["value"].tensor.CopyFrom(value_proto)
      converted_count += 1
      continue

    # Identity on the path to a ReadVariableOp: keep it, but retype it.
    if node_name in identity_dtypes:
      frozen_node.CopyFrom(original_node)
      frozen_node.attr["T"].CopyFrom(identity_dtypes[node_name])
      continue

    # ReadVariableOp: replace it with an Identity of its first input.
    if original_node.op == "ReadVariableOp":
      frozen_node.name = node_name
      frozen_node.op = "Identity"
      frozen_node.input.append(original_node.input[0])
      frozen_node.attr["T"].CopyFrom(original_node.attr["dtype"])
      if "_class" in original_node.attr:
        frozen_node.attr["_class"].CopyFrom(original_node.attr["_class"])
      continue

    # Any other node is carried over untouched.
    frozen_node.CopyFrom(original_node)

  logging.info("Converted %d variables to const ops.", converted_count)
  # TODO(b/126613403): Use wrap_function.function_from_graph_def.
  return _construct_concrete_function(func, frozen_graph_def)
212