1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Computes a header file to be used with SELECTIVE_REGISTRATION. 16 17See the executable wrapper, print_selective_registration_header.py, for more 18information. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import os 26import sys 27 28from google.protobuf import text_format 29 30from tensorflow.core.framework import graph_pb2 31from tensorflow.python import pywrap_tensorflow 32from tensorflow.python.platform import gfile 33from tensorflow.python.platform import tf_logging 34 35# Usually, we use each graph node to induce registration of an op and 36# corresponding kernel; nodes without a corresponding kernel (perhaps due to 37# attr types) generate a warning but are otherwise ignored. Ops in this set are 38# registered even if there's no corresponding kernel. 39OPS_WITHOUT_KERNEL_WHITELIST = frozenset([ 40 # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see 41 # core/common_runtime/accumulate_n_optimizer.cc. 42 'AccumulateNV2' 43]) 44 45 46def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): 47 """Gets the ops and kernels needed from the model files.""" 48 ops = set() 49 50 for proto_file in proto_files: 51 tf_logging.info('Loading proto file %s', proto_file) 52 # Load GraphDef. 53 file_data = gfile.GFile(proto_file, 'rb').read() 54 if proto_fileformat == 'rawproto': 55 graph_def = graph_pb2.GraphDef.FromString(file_data) 56 else: 57 assert proto_fileformat == 'textproto' 58 graph_def = text_format.Parse(file_data, graph_pb2.GraphDef()) 59 60 # Find all ops and kernels used by the graph. 61 for node_def in graph_def.node: 62 if not node_def.device: 63 node_def.device = '/cpu:0' 64 kernel_class = pywrap_tensorflow.TryFindKernelClass( 65 node_def.SerializeToString()) 66 op = str(node_def.op) 67 if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST: 68 op_and_kernel = (op, str(kernel_class.decode('utf-8')) 69 if kernel_class else None) 70 if op_and_kernel not in ops: 71 ops.add(op_and_kernel) 72 else: 73 print( 74 'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr) 75 76 # Add default ops. 77 if default_ops_str and default_ops_str != 'all': 78 for s in default_ops_str.split(','): 79 op, kernel = s.split(':') 80 op_and_kernel = (op, kernel) 81 if op_and_kernel not in ops: 82 ops.add(op_and_kernel) 83 84 return list(sorted(ops)) 85 86 87def get_header_from_ops_and_kernels(ops_and_kernels, 88 include_all_ops_and_kernels): 89 """Returns a header for use with tensorflow SELECTIVE_REGISTRATION. 90 91 Args: 92 ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include. 93 include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op 94 kernels are included. 95 96 Returns: 97 the string of the header that should be written as ops_to_register.h. 98 """ 99 ops = set([op for op, _ in ops_and_kernels]) 100 result_list = [] 101 102 def append(s): 103 result_list.append(s) 104 105 _, script_name = os.path.split(sys.argv[0]) 106 append('// This file was autogenerated by %s' % script_name) 107 append('#ifndef OPS_TO_REGISTER') 108 append('#define OPS_TO_REGISTER') 109 110 if include_all_ops_and_kernels: 111 append('#define SHOULD_REGISTER_OP(op) true') 112 append('#define SHOULD_REGISTER_OP_KERNEL(clz) true') 113 append('#define SHOULD_REGISTER_OP_GRADIENT true') 114 else: 115 line = ''' 116 namespace { 117 constexpr const char* skip(const char* x) { 118 return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x; 119 } 120 121 constexpr bool isequal(const char* x, const char* y) { 122 return (*skip(x) && *skip(y)) 123 ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1)) 124 : (!*skip(x) && !*skip(y)); 125 } 126 127 template<int N> 128 struct find_in { 129 static constexpr bool f(const char* x, const char* const y[N]) { 130 return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1); 131 } 132 }; 133 134 template<> 135 struct find_in<0> { 136 static constexpr bool f(const char* x, const char* const y[]) { 137 return false; 138 } 139 }; 140 } // end namespace 141 ''' 142 line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n' 143 for _, kernel_class in ops_and_kernels: 144 if kernel_class is None: continue 145 line += '"%s",\n' % kernel_class 146 line += '};' 147 append(line) 148 append('#define SHOULD_REGISTER_OP_KERNEL(clz) ' 149 '(find_in<sizeof(kNecessaryOpKernelClasses) ' 150 '/ sizeof(*kNecessaryOpKernelClasses)>::f(clz, ' 151 'kNecessaryOpKernelClasses))') 152 append('') 153 154 append('constexpr inline bool ShouldRegisterOp(const char op[]) {') 155 append(' return false') 156 for op in sorted(ops): 157 append(' || isequal(op, "%s")' % op) 158 append(' ;') 159 append('}') 160 append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)') 161 append('') 162 163 append('#define SHOULD_REGISTER_OP_GRADIENT ' + ( 164 'true' if 'SymbolicGradient' in ops else 'false')) 165 166 append('#endif') 167 return '\n'.join(result_list) 168 169 170def get_header(graphs, 171 proto_fileformat='rawproto', 172 default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'): 173 """Computes a header for use with tensorflow SELECTIVE_REGISTRATION. 174 175 Args: 176 graphs: a list of paths to GraphDef files to include. 177 proto_fileformat: optional format of proto file, either 'textproto' or 178 'rawproto' (default). 179 default_ops: optional comma-separated string of operator:kernel pairs to 180 always include implementation for. Pass 'all' to have all operators and 181 kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'. 182 Returns: 183 the string of the header that should be written as ops_to_register.h. 184 """ 185 ops_and_kernels = get_ops_and_kernels(proto_fileformat, graphs, default_ops) 186 if not ops_and_kernels: 187 print('Error reading graph!') 188 return 1 189 190 return get_header_from_ops_and_kernels(ops_and_kernels, default_ops == 'all') 191