1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions to convert SavedModel to frozen GraphDefs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.lite.python.convert import tensor_name
22from tensorflow.core.framework import types_pb2
23from tensorflow.python.client import session
24from tensorflow.python.framework import graph_util as tf_graph_util
25from tensorflow.python.framework import ops
26from tensorflow.python.platform import tf_logging as logging
27from tensorflow.python.saved_model import constants
28from tensorflow.python.saved_model import loader
29
30
def _log_tensor_details(tensor_info):
  """Logs the key, name, shape and dtype of every entry in a TensorInfo map.

  Args:
    tensor_info: Map from SignatureDef key to TensorInfo proto (for example
      `signature_def.inputs` or `signature_def.outputs`).
  """

  def _shape_text(proto_shape):
    # A shape of unknown rank has no dims to print.
    if proto_shape.unknown_rank:
      return "unknown_rank"
    return "({})".format(", ".join(str(d.size) for d in proto_shape.dim))

  for entry_key in tensor_info:
    info = tensor_info[entry_key]
    type_name = types_pb2.DataType.Name(info.dtype)
    shape_text = _shape_text(info.tensor_shape)

    logging.info("Tensor's key in saved_model's tensor_map: %s", entry_key)
    logging.info(" tensor name: %s, shape: %s, type: %s", info.name,
                 shape_text, type_name)
45
46
def get_meta_graph_def(saved_model_dir, tag_set):
  """Loads a SavedModel and returns the MetaGraphDef selected by `tag_set`.

  The model is loaded into a temporary session on a new, empty graph; the
  session is closed before returning and only the MetaGraphDef is kept.

  Args:
    saved_model_dir: Directory of the SavedModel to convert.
    tag_set: Set of tag(s) identifying the MetaGraphDef to load.

  Returns:
    The MetaGraphDef used for TFLite conversion.

  Raises:
    ValueError: The SavedModel has no MetaGraphDef for the given `tag_set`.
  """
  scratch_graph = ops.Graph()
  with session.Session(graph=scratch_graph) as scratch_sess:
    meta_graph_def = loader.load(scratch_sess, tag_set, saved_model_dir)
  return meta_graph_def
62
63
def get_signature_def(meta_graph, signature_key):
  """Gets the SignatureDef stored in `meta_graph` under `signature_key`.

  Args:
    meta_graph: MetaGraphDef whose `signature_def` map is searched.
    signature_key: Key of the desired SignatureDef in that map.

  Returns:
    The SignatureDef used for TFLite conversion.

  Raises:
    ValueError: `signature_key` is not a key of `meta_graph.signature_def`.
      The message lists the available keys in sorted order.
  """
  signature_def_map = meta_graph.signature_def
  # Sort the keys so the log line and the error message are deterministic.
  # Iterating a set gives a hash-dependent order that can vary between runs.
  signature_def_keys = sorted(signature_def_map.keys())
  logging.info(
      "The given SavedModel MetaGraphDef contains SignatureDefs with the "
      "following keys: %s", signature_def_keys)
  # Test membership with `in` rather than indexing: indexing a missing key of a
  # protobuf map field would insert an empty entry.
  if signature_key not in signature_def_map:
    raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible "
                     "values are '{}'.".format(signature_key,
                                               ",".join(signature_def_keys)))
  return signature_def_map[signature_key]
87
88
def get_inputs_outputs(signature_def):
  """Returns the input and output tensor names recorded in a SignatureDef.

  The details of every input and output TensorInfo are logged first.

  Args:
    signature_def: SignatureDef from the MetaGraphDef being converted.

  Returns:
    A tuple `(inputs, outputs)` of lists of tensor names, in the iteration
    order of `signature_def.inputs` and `signature_def.outputs` respectively.
  """
  input_infos = signature_def.inputs
  output_infos = signature_def.outputs

  logging.info("input tensors info: ")
  _log_tensor_details(input_infos)
  logging.info("output tensors info: ")
  _log_tensor_details(output_infos)

  input_names = [input_infos[key].name for key in input_infos]
  output_names = [output_infos[key].name for key in output_infos]
  return input_names, output_names
111
112
def _get_tensors(graph, signature_def_tensor_names=None,
                 user_tensor_names=None):
  """Gets the tensors associated with the tensor names.

  Either signature_def_tensor_names or user_tensor_names should be provided.
  User-provided names take precedence; otherwise the names stored in the
  SignatureDef are used. In both cases the names are sorted before lookup, so
  the returned tensors are ordered by tensor name, not by the order given.

  Args:
    graph: TensorFlow Graph (`ops.Graph`) that contains the tensors.
    signature_def_tensor_names: Tensor names stored in either the inputs or
      outputs of a SignatureDef. (default None)
    user_tensor_names: Tensor names provided by the user. (default None)

  Returns:
    List of tensors, sorted by tensor name.

  Raises:
    ValueError:
      signature_def_tensor_names and user_tensor_names are both undefined or
      empty.
      user_tensor_names are not valid.
  """
  if user_tensor_names:
    # User-provided names override the SignatureDef; invalid ones are
    # reported by get_tensors_from_tensor_names.
    return get_tensors_from_tensor_names(graph, sorted(user_tensor_names))

  if signature_def_tensor_names:
    return [
        graph.get_tensor_by_name(name)
        for name in sorted(signature_def_tensor_names)
    ]

  # Neither source of names was given (or both were empty).
  raise ValueError(
      "Specify either signature_def_tensor_names or user_tensor_names")
154
155
def get_tensors_from_tensor_names(graph, tensor_names):
  """Looks up the Tensors named by `tensor_names` in `graph`.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings naming tensors in the graph.

  Returns:
    A list of Tensor objects, in the same order as `tensor_names`.

  Raises:
    ValueError: At least one entry of `tensor_names` is not the name of a
      tensor in `graph`; every unknown name is listed in the message.
  """
  # Index every output tensor of every op in the graph by its name.
  tensors_by_name = {}
  for graph_op in graph.get_operations():
    for produced in graph_op.values():
      tensors_by_name[tensor_name(produced)] = produced

  # Report all unknown names at once rather than failing on the first one.
  unknown_names = [n for n in tensor_names if n not in tensors_by_name]
  if unknown_names:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(unknown_names)))

  return [tensors_by_name[n] for n in tensor_names]
191
192
def set_tensor_shapes(tensors, shapes):
  """Sets the static shape of each tensor named in `shapes`.

  Args:
    tensors: List of TensorFlow ops.Tensor objects whose shapes may be set.
    shapes: Dict mapping input tensor names to a list of integers giving the
      input shape (e.g., {"foo": [1, 16, 16, 3]}). Names mapped to None are
      still checked against `tensors` but their shape is left unchanged.
      Nothing happens when `shapes` is None or empty.

  Raises:
    ValueError:
      `shapes` contains a name that is not the name of any of `tensors`.
      `shapes` contains a shape that `Tensor.set_shape` rejects for a valid
      tensor.
  """
  if not shapes:
    return

  lookup = dict((tensor_name(t), t) for t in tensors)
  for name, new_shape in shapes.items():
    if name not in lookup:
      raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                       "map.".format(name))
    if new_shape is None:
      # None means "keep whatever shape the tensor already has".
      continue

    target = lookup[name]
    try:
      target.set_shape(new_shape)
    except ValueError as error:
      message = ("The shape of tensor '{0}' cannot be changed from {1} to "
                 "{2}. {3}".format(name, target.shape, new_shape, str(error)))
      raise ValueError(message)
220
221
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
  """Converts a SavedModel to a frozen graph.

  Args:
    saved_model_dir: SavedModel directory to convert.
    input_arrays: List of input tensors to freeze graph with. Uses input arrays
      from SignatureDef when none are provided.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
    output_arrays: List of output tensors to freeze graph with. Uses output
      arrays from SignatureDef when none are provided.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.

  Returns:
    frozen_graph_def: Frozen GraphDef, frozen with the ops that produce
      `out_tensors` as its output nodes.
    in_tensors: List of input tensors for the graph.
    out_tensors: List of output tensors for the graph.

  Raises:
    ValueError:
      SavedModel doesn't contain a MetaGraphDef identified by tag_set.
      signature_key is not in the MetaGraphDef.
      assets/ directory is in the MetaGraphDef.
      input_shapes names a tensor that is not one of the input tensors, or
        gives a shape incompatible with that tensor.
      input_arrays or output_arrays are not valid.
  """
  # Read SignatureDef.
  meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
  signature_def = get_signature_def(meta_graph, signature_key)
  inputs, outputs = get_inputs_outputs(signature_def)

  # Check SavedModel for assets directory.
  collection_def = meta_graph.collection_def
  if constants.ASSETS_KEY in collection_def:
    raise ValueError("SavedModels with assets/ directory are not supported.")

  graph = ops.Graph()
  with session.Session(graph=graph) as sess:
    loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

    # Gets input and output tensors.
    # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
    in_tensors = _get_tensors(graph, inputs, input_arrays)
    out_tensors = _get_tensors(graph, outputs, output_arrays)
    set_tensor_shapes(in_tensors, input_shapes)

    # Freeze against the tensors actually returned as outputs. When the user
    # supplies output_arrays these differ from the SignatureDef `outputs`.
    # Naming the SignatureDef outputs here would freeze the graph for the
    # wrong nodes and could omit the user's outputs.
    output_names = [tensor.op.name for tensor in out_tensors]
    frozen_graph_def = tf_graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_names)

    return frozen_graph_def, in_tensors, out_tensors
277