1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Importer for an exported TensorFlow model.
16
17This module provides a function to create a SessionBundle containing both the
18Session and MetaGraph.
19"""
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import os
25
26from tensorflow.contrib.session_bundle import constants
27from tensorflow.contrib.session_bundle import manifest_pb2
28from tensorflow.core.framework import graph_pb2
29from tensorflow.core.protobuf import meta_graph_pb2
30from tensorflow.python.client import session
31from tensorflow.python.framework import ops
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.training import saver as saver_lib
34from tensorflow.python.util.deprecation import deprecated
35
36
@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def maybe_session_bundle_dir(export_dir):
  """Reports whether `export_dir` looks like a session bundle export.

  The check is purely file based: it only looks for the meta graph file
  named by `constants.META_GRAPH_DEF_FILENAME` directly inside `export_dir`.

  Args:
    export_dir: string path to model checkpoint, for example 'model/00000123'

  Returns:
    The result of `file_io.file_exists` for the meta graph file inside
    `export_dir`, i.e. true if the session bundle meta graph file is present.
  """
  return file_io.file_exists(
      os.path.join(export_dir, constants.META_GRAPH_DEF_FILENAME))
52
53
@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def load_session_bundle_from_path(export_dir,
                                  target="",
                                  config=None,
                                  meta_graph_def=None):
  """Load session bundle from the given path.

  The function resets the default graph, imports the MetaGraphDef into it,
  creates a session on it and restores the exported variables (if any are
  found in export_dir). If the MetaGraphDef carries a serving init op, that op
  is run with each asset tensor fed the path of its file under the assets
  directory. If loading fails after the session was created, the session is
  closed before the exception propagates.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.Session()
    meta_graph_def: optional object of type MetaGraphDef. If this object is
      present, then it is used instead of parsing MetaGraphDef from
      export_dir. Note that it may be modified in place: if it contains a
      serving GraphDef collection, its `graph_def` field is overwritten with
      that GraphDef.

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
      fields, i.e. the exported model is invalid (missing meta graph file, or
      a serving GraphDef / init op collection without exactly one entry).
  """
  if not meta_graph_def:
    meta_graph_def = _read_meta_graph_def(export_dir)

  # Checkpoint name relative to export_dir, or "" when there is nothing to
  # restore.
  restore_files = _get_restore_files(export_dir)
  assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

  collection_def = meta_graph_def.collection_def
  if constants.GRAPH_KEY in collection_def:
    # Use serving graph_def in MetaGraphDef collection_def if exists.
    graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
    if len(graph_def_any) != 1:
      raise RuntimeError("Expected exactly one serving GraphDef in : %s" %
                         meta_graph_def)
    graph_def = graph_pb2.GraphDef()
    graph_def_any[0].Unpack(graph_def)
    # Replace the graph def in meta graph proto.
    meta_graph_def.graph_def.CopyFrom(graph_def)

  ops.reset_default_graph()
  sess = session.Session(target, graph=None, config=config)
  try:
    # Import the graph into the freshly reset default graph.
    saver = saver_lib.import_meta_graph(meta_graph_def)
    # Restore the session.
    if restore_files:
      saver.restore(sess, os.path.join(export_dir, restore_files))

    init_op_tensor = None
    if constants.INIT_OP_KEY in collection_def:
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      if len(init_ops) != 1:
        raise RuntimeError("Expected exactly one serving init op in : %s" %
                           meta_graph_def)
      init_op_tensor = ops.get_collection(constants.INIT_OP_KEY)[0]

    asset_tensor_dict = _get_asset_tensor_dict(collection_def, assets_dir)

    # Compare against None rather than relying on truthiness: if the
    # collection entry is a Tensor, TensorFlow raises TypeError when it is
    # used as a Python bool.
    if init_op_tensor is not None:
      # Run the init op.
      sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)
  except BaseException:
    # Loading failed part-way: close the session so it is not leaked, then
    # let the original exception propagate unchanged.
    sess.close()
    raise

  return sess, meta_graph_def


def _read_meta_graph_def(export_dir):
  """Parses the binary MetaGraphDef file stored in export_dir.

  Args:
    export_dir: the directory that contains files exported by exporter.

  Returns:
    A `meta_graph_pb2.MetaGraphDef` parsed from the meta graph file.

  Raises:
    RuntimeError: if the meta graph file does not exist.
  """
  meta_graph_filename = os.path.join(export_dir,
                                     constants.META_GRAPH_DEF_FILENAME)
  if not file_io.file_exists(meta_graph_filename):
    raise RuntimeError("Expected meta graph file missing %s" %
                       meta_graph_filename)
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  meta_graph_def.ParseFromString(
      file_io.read_file_to_string(meta_graph_filename, binary_mode=True))
  return meta_graph_def


def _get_restore_files(export_dir):
  """Returns the checkpoint name (relative to export_dir) to restore from.

  The lookup order is:
    1. V2 checkpoint (its index file exists): `constants.VARIABLES_FILENAME`
       if any file matches `constants.VARIABLES_FILENAME_PATTERN_V2`.
    2. Non-sharded V1 checkpoint: `constants.VARIABLES_FILENAME` if that file
       exists.
    3. Sharded V1 checkpoint: `constants.VARIABLES_FILENAME_PATTERN` if any
       file matches it.

  Args:
    export_dir: the directory that contains files exported by exporter.

  Returns:
    The relative checkpoint name, or "" if no variable files were found.
  """
  variables_index_filename = os.path.join(
      export_dir, constants.VARIABLES_INDEX_FILENAME_V2)
  if file_io.file_exists(variables_index_filename):
    # V2 checkpoints are restored through their common prefix, never through
    # the shard pattern itself.
    v2_pattern = os.path.join(export_dir,
                              constants.VARIABLES_FILENAME_PATTERN_V2)
    if file_io.get_matching_files(v2_pattern):
      return constants.VARIABLES_FILENAME
    return ""

  if file_io.file_exists(
      os.path.join(export_dir, constants.VARIABLES_FILENAME)):
    return constants.VARIABLES_FILENAME

  v1_pattern = os.path.join(export_dir, constants.VARIABLES_FILENAME_PATTERN)
  if file_io.get_matching_files(v1_pattern):
    return constants.VARIABLES_FILENAME_PATTERN
  return ""


def _get_asset_tensor_dict(collection_def, assets_dir):
  """Builds the feed dict mapping asset tensor names to asset file paths.

  Args:
    collection_def: the `collection_def` map of a MetaGraphDef.
    assets_dir: directory the asset filenames are resolved against.

  Returns:
    A dict from each asset's bound tensor name to
    `os.path.join(assets_dir, filename)`; empty if there is no assets
    collection.
  """
  asset_tensor_dict = {}
  if constants.ASSETS_KEY in collection_def:
    for asset in collection_def[constants.ASSETS_KEY].any_list.value:
      asset_pb = manifest_pb2.AssetFile()
      asset.Unpack(asset_pb)
      asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = os.path.join(
          assets_dir, asset_pb.filename)
  return asset_tensor_dict
174