# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Replaces a subgraph of a TensorFlow GraphDef with a single node.

In conjunction with TOCO's --allow_custom_ops this script allows selected
portions of a TensorFlow GraphDef to be executed by custom code.

Example:

bazel run tensorflow/lite/python:create_custom_op -- \
  --input_graph=/tmp/input.pb \
  --output_graph=/tmp/output.pb \
  --inputs=concat,concat_1 \
  --outputs=detection_classes \
  --op_definition='op:"PostProcessing" attr{key:"num" value:{i:10}}'

The above will identify a subgraph starting at nodes 'concat' and 'concat_1',
and ending at 'detection_classes'. All nodes in between will be removed and
replaced by a new op called 'PostProcessing'.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import uuid as _uuid

from absl import app
from absl import flags
from google.protobuf import text_format
from tensorflow.contrib.framework.python.framework.graph_util import fuse_op
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.platform import gfile

FLAGS = flags.FLAGS

flags.DEFINE_string("input_graph", "", "Binary graphdef to load.")
flags.DEFINE_string("output_graph", "", "Resulting binary graphdef.")

flags.DEFINE_string("inputs", "",
                    "Comma-separated list of inputs to the subgraph.")
flags.DEFINE_string("outputs", "",
                    "Comma-separated list of outputs of the subgraph.")
flags.DEFINE_string("op_definition", "",
                    "A text NodeDef defining the contents of the custom op.")


def _read_graph_def(filename):
  """Loads a binary-serialized GraphDef from `filename`.

  Args:
    filename: Path (any gfile-supported filesystem) of the input graph.

  Returns:
    The parsed graph_pb2.GraphDef.

  Raises:
    ValueError: if `filename` does not exist.
  """
  if not gfile.Exists(filename):
    raise ValueError("Input graph file '" + filename + "' does not exist!")

  graph_def = graph_pb2.GraphDef()
  with gfile.GFile(filename, "rb") as f:
    graph_def.ParseFromString(f.read())
  return graph_def


def _write_graph_def(graph_def, filename):
  """Serializes `graph_def` in binary form to `filename`.

  Args:
    graph_def: The GraphDef to write.
    filename: Destination path; must be non-empty.

  Raises:
    ValueError: if `filename` is empty.
  """
  if not filename:
    raise ValueError("Output graph file not specified")

  # Same file API as _read_graph_def (gfile.Open is an alias of GFile).
  with gfile.GFile(filename, "wb") as f:
    f.write(graph_def.SerializeToString())


def _split_node_names(value):
  """Splits a comma-separated flag value into a list of node names.

  Surrounding whitespace is stripped and empty entries are dropped, so
  "a, b," yields ["a", "b"] and "" yields [] instead of [""].

  Args:
    value: The raw comma-separated string.

  Returns:
    A list of non-empty node names, in their original order.
  """
  return [name.strip() for name in value.split(",") if name.strip()]


def _collapse_subgraph(graph_def, inputs, outputs, op_definition):
  """Substitute a custom op for the subgraph delimited by inputs and outputs.

  Args:
    graph_def: The GraphDef containing the subgraph.
    inputs: List of node names at which the subgraph starts.
    outputs: List of node names at which the subgraph ends.
    op_definition: Text-format NodeDef that is merged into the newly created
      node (e.g. to set `op` and `attr`). An empty string leaves the node's
      op as "CustomTfLiteOp".

  Returns:
    A new GraphDef, as produced by `fuse_op`, with the merged definition
    applied to the fused node.

  Raises:
    ValueError: if `op_definition` sets `name` or `input`. The fused node's
      name is referenced by the rewired output nodes, and merging `input`
      would append to (not replace) the inputs wired up by `fuse_op`.
    text_format.ParseError: if `op_definition` is not a valid text NodeDef.
  """
  # Parse the user-supplied definition first so a malformed --op_definition
  # fails before the fusion is performed.
  node_def = node_def_pb2.NodeDef()
  text_format.Parse(op_definition, node_def)
  if node_def.name:
    raise ValueError("op_definition must not set 'name'; the fused node's "
                     "name is generated and referenced by its consumers.")
  if node_def.input:
    raise ValueError("op_definition must not set 'input'; the fused node's "
                     "inputs are taken from --inputs.")

  # uuid4 is random; uuid1 would embed the host's MAC address and a
  # timestamp into the saved graph.
  name = _uuid.uuid4().hex
  # We need a default type, but it can be changed using 'op_definition'.
  default_type = types_pb2.DT_FLOAT
  new_graph = fuse_op(
      graph_def=graph_def,
      input_nodes=inputs,
      output_nodes=outputs,
      output_dtypes=[default_type for _ in outputs],
      output_quantized=False,
      op_name=name,
      op_type="CustomTfLiteOp")
  for node in new_graph.node:
    if node.name == name:
      node.MergeFrom(node_def)
      break  # Node names are unique within a GraphDef.
  return new_graph


def main(argv):
  """Reads --input_graph, collapses the subgraph, and writes --output_graph.

  Raises:
    app.UsageError: if --inputs, --outputs or --output_graph is empty.
  """
  del argv  # unused
  inputs = _split_node_names(FLAGS.inputs)
  outputs = _split_node_names(FLAGS.outputs)
  # Validate all flags up front so mistakes are reported before any work.
  if not inputs:
    raise app.UsageError("--inputs must name at least one node.")
  if not outputs:
    raise app.UsageError("--outputs must name at least one node.")
  if not FLAGS.output_graph:
    raise app.UsageError("--output_graph must be specified.")

  graph = _read_graph_def(filename=FLAGS.input_graph)
  graph = _collapse_subgraph(
      graph_def=graph,
      inputs=inputs,
      outputs=outputs,
      op_definition=FLAGS.op_definition)
  _write_graph_def(graph_def=graph, filename=FLAGS.output_graph)


if __name__ == "__main__":
  app.run(main)