1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Importer for an exported TensorFlow model. 16 17This module provides a function to create a SessionBundle containing both the 18Session and MetaGraph. 19""" 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import os 25 26from tensorflow.contrib.session_bundle import constants 27from tensorflow.contrib.session_bundle import manifest_pb2 28from tensorflow.core.framework import graph_pb2 29from tensorflow.core.protobuf import meta_graph_pb2 30from tensorflow.python.client import session 31from tensorflow.python.framework import ops 32from tensorflow.python.lib.io import file_io 33from tensorflow.python.training import saver as saver_lib 34from tensorflow.python.util.deprecation import deprecated 35 36 37@deprecated("2017-06-30", 38 "No longer supported. Switch to SavedModel immediately.") 39def maybe_session_bundle_dir(export_dir): 40 """Checks if the model path contains session bundle model. 41 42 Args: 43 export_dir: string path to model checkpoint, for example 'model/00000123' 44 45 Returns: 46 true if path contains session bundle model files, ie META_GRAPH_DEF_FILENAME 47 """ 48 49 meta_graph_filename = os.path.join(export_dir, 50 constants.META_GRAPH_DEF_FILENAME) 51 return file_io.file_exists(meta_graph_filename) 52 53 54@deprecated("2017-06-30", 55 "No longer supported. Switch to SavedModel immediately.") 56def load_session_bundle_from_path(export_dir, 57 target="", 58 config=None, 59 meta_graph_def=None): 60 """Load session bundle from the given path. 61 62 The function reads input from the export_dir, constructs the graph data to the 63 default graph and restores the parameters for the session created. 64 65 Args: 66 export_dir: the directory that contains files exported by exporter. 67 target: The execution engine to connect to. See target in tf.Session() 68 config: A ConfigProto proto with configuration options. See config in 69 tf.Session() 70 meta_graph_def: optional object of type MetaGraphDef. If this object is 71 present, then it is used instead of parsing MetaGraphDef from export_dir. 72 73 Returns: 74 session: a tensorflow session created from the variable files. 75 meta_graph: a meta graph proto saved in the exporter directory. 76 77 Raises: 78 RuntimeError: if the required files are missing or contain unrecognizable 79 fields, i.e. the exported model is invalid. 80 """ 81 if not meta_graph_def: 82 meta_graph_filename = os.path.join(export_dir, 83 constants.META_GRAPH_DEF_FILENAME) 84 if not file_io.file_exists(meta_graph_filename): 85 raise RuntimeError("Expected meta graph file missing %s" % 86 meta_graph_filename) 87 # Reads meta graph file. 88 meta_graph_def = meta_graph_pb2.MetaGraphDef() 89 meta_graph_def.ParseFromString( 90 file_io.read_file_to_string(meta_graph_filename, binary_mode=True)) 91 92 variables_filename = "" 93 variables_filename_list = [] 94 checkpoint_sharded = False 95 96 variables_index_filename = os.path.join(export_dir, 97 constants.VARIABLES_INDEX_FILENAME_V2) 98 checkpoint_v2 = file_io.file_exists(variables_index_filename) 99 100 # Find matching checkpoint files. 101 if checkpoint_v2: 102 # The checkpoint is in v2 format. 103 variables_filename_pattern = os.path.join( 104 export_dir, constants.VARIABLES_FILENAME_PATTERN_V2) 105 variables_filename_list = file_io.get_matching_files( 106 variables_filename_pattern) 107 checkpoint_sharded = True 108 else: 109 variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME) 110 if file_io.file_exists(variables_filename): 111 variables_filename_list = [variables_filename] 112 else: 113 variables_filename = os.path.join(export_dir, 114 constants.VARIABLES_FILENAME_PATTERN) 115 variables_filename_list = file_io.get_matching_files(variables_filename) 116 checkpoint_sharded = True 117 118 # Prepare the files to restore a session. 119 if not variables_filename_list: 120 restore_files = "" 121 elif checkpoint_v2 or not checkpoint_sharded: 122 # For checkpoint v2 or v1 with non-sharded files, use "export" to restore 123 # the session. 124 restore_files = constants.VARIABLES_FILENAME 125 else: 126 restore_files = constants.VARIABLES_FILENAME_PATTERN 127 128 assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY) 129 130 collection_def = meta_graph_def.collection_def 131 graph_def = graph_pb2.GraphDef() 132 if constants.GRAPH_KEY in collection_def: 133 # Use serving graph_def in MetaGraphDef collection_def if exists 134 graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value 135 if len(graph_def_any) != 1: 136 raise RuntimeError("Expected exactly one serving GraphDef in : %s" % 137 meta_graph_def) 138 else: 139 graph_def_any[0].Unpack(graph_def) 140 # Replace the graph def in meta graph proto. 141 meta_graph_def.graph_def.CopyFrom(graph_def) 142 143 ops.reset_default_graph() 144 sess = session.Session(target, graph=None, config=config) 145 # Import the graph. 146 saver = saver_lib.import_meta_graph(meta_graph_def) 147 # Restore the session. 148 if restore_files: 149 saver.restore(sess, os.path.join(export_dir, restore_files)) 150 151 init_op_tensor = None 152 if constants.INIT_OP_KEY in collection_def: 153 init_ops = collection_def[constants.INIT_OP_KEY].node_list.value 154 if len(init_ops) != 1: 155 raise RuntimeError("Expected exactly one serving init op in : %s" % 156 meta_graph_def) 157 init_op_tensor = ops.get_collection(constants.INIT_OP_KEY)[0] 158 159 # Create asset input tensor list. 160 asset_tensor_dict = {} 161 if constants.ASSETS_KEY in collection_def: 162 assets_any = collection_def[constants.ASSETS_KEY].any_list.value 163 for asset in assets_any: 164 asset_pb = manifest_pb2.AssetFile() 165 asset.Unpack(asset_pb) 166 asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = os.path.join( 167 assets_dir, asset_pb.filename) 168 169 if init_op_tensor: 170 # Run the init op. 171 sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict) 172 173 return sess, meta_graph_def 174