1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Replaces a subgraph of a TensorFlow GraphDef with a single node.
16
In conjunction with TOCO's --allow_custom_op flag, this script allows selected
portions of a TensorFlow GraphDef to be executed by custom code.
19
20Example:
21
22bazel run tensorflow/lite/python:create_custom_op  -- \
23  --input_graph=/tmp/input.pb \
24  --output_graph=/tmp/output.pb \
25  --inputs=concat,concat_1 \
26  --outputs=detection_classes \
27  --op_definition='op:"PostProcessing" attr{key:"num" value:{i:10}}'
28
29The above will identify a subgraph starting at nodes 'concat' and 'concat_1',
30and ending at 'detection_classes'. All nodes in between will be removed and
31replaced by a new op called 'PostProcessing'.
32
33"""
34from __future__ import absolute_import
35from __future__ import division
36from __future__ import print_function
37import uuid as _uuid
38from absl import app
39from absl import flags
40from google.protobuf import text_format
41from tensorflow.contrib.framework.python.framework.graph_util import fuse_op
42from tensorflow.core.framework import graph_pb2
43from tensorflow.core.framework import node_def_pb2
44from tensorflow.core.framework import types_pb2
45from tensorflow.python.platform import gfile
46
FLAGS = flags.FLAGS

# Locations of the serialized GraphDef that is read and the one that is written.
flags.DEFINE_string(
    name="input_graph", default="", help="Binary graphdef to load.")
flags.DEFINE_string(
    name="output_graph", default="", help="Resulting binary graphdef.")

# Boundaries of the subgraph to collapse, and the replacement node's contents.
flags.DEFINE_string(
    name="inputs",
    default="",
    help="Comma-separated list of inputs to the subgraph.")
flags.DEFINE_string(
    name="outputs",
    default="",
    help="Comma-separated list of outputs of the subgraph.")
flags.DEFINE_string(
    name="op_definition",
    default="",
    help="A text NodeDef defining the contents of the custom op.")
58
59
def _read_graph_def(filename):
  """Loads a binary-serialized GraphDef from `filename`.

  Args:
    filename: Path to a file holding a serialized `GraphDef` proto.

  Returns:
    The parsed `graph_pb2.GraphDef`.

  Raises:
    ValueError: If `filename` is empty or names a file that does not exist.
  """
  # The --input_graph flag defaults to "", so report a missing argument
  # explicitly rather than probing the filesystem with an empty path.
  if not filename:
    raise ValueError("Input graph file not specified")
  if not gfile.Exists(filename):
    raise ValueError("Input graph file '" + filename + "' does not exist!")

  graph_def = graph_pb2.GraphDef()
  with gfile.GFile(filename, "rb") as f:
    graph_def.ParseFromString(f.read())
  return graph_def
68
69
def _write_graph_def(graph_def, filename):
  """Writes `graph_def` to `filename` in binary (serialized proto) form.

  Args:
    graph_def: The `GraphDef` proto to save.
    filename: Destination path; must be non-empty.

  Raises:
    ValueError: If `filename` is empty.
  """
  if filename:
    # gfile.GFile is the same class as gfile.Open; spelled like the reader.
    with gfile.GFile(filename, "wb") as out_file:
      out_file.write(graph_def.SerializeToString())
  else:
    raise ValueError("Output graph file not specified")
76
77
def _collapse_subgraph(graph_def, inputs, outputs, op_definition):
  """Substitute a custom op for the subgraph delimited by inputs and outputs.

  Args:
    graph_def: The `GraphDef` to rewrite.
    inputs: List of node names at which the subgraph starts.
    outputs: List of node names at which the subgraph ends.
    op_definition: A text-format `NodeDef` whose fields are merged into the
      new node, e.g. `op:"PostProcessing" attr{key:"num" value:{i:10}}`.
      If it sets `name`, the new node is created under that name; otherwise
      a random name is generated.

  Returns:
    A new `GraphDef` in which the subgraph is replaced by a single node.

  Raises:
    text_format.ParseError: If `op_definition` is not a valid text `NodeDef`.
    ValueError: If the resulting graph does not contain exactly one node with
      the new node's name (e.g. a user-supplied name clashes with a kept node).
  """
  # Parse first so a malformed definition is reported before the graph
  # rewrite runs.
  node_def = node_def_pb2.NodeDef()
  text_format.Parse(op_definition, node_def)

  # GraphDef nodes refer to their inputs by name, so a name requested in
  # op_definition must be given to fuse_op itself; renaming the node
  # afterwards via MergeFrom would leave references to the generated name
  # dangling. uuid4 is used instead of uuid1 so the generated name does not
  # embed this host's MAC address and a timestamp.
  name = node_def.name or _uuid.uuid4().hex
  # We need a default type, but it can be changed using 'op_definition'.
  default_type = types_pb2.DT_FLOAT
  new_graph = fuse_op(
      graph_def=graph_def,
      input_nodes=inputs,
      output_nodes=outputs,
      output_dtypes=[default_type for _ in outputs],
      output_quantized=False,
      op_name=name,
      op_type="CustomTfLiteOp")

  matches = [node for node in new_graph.node if node.name == name]
  if len(matches) != 1:
    raise ValueError(
        "Expected exactly one node named '%s' after fusing, found %d." %
        (name, len(matches)))
  matches[0].MergeFrom(node_def)
  return new_graph
97
98
def _split_node_names(value):
  """Splits a comma-separated flag value into node names, dropping blanks."""
  return [part.strip() for part in value.split(",") if part.strip()]


def main(argv):
  """Collapses the subgraph selected by the flags and writes the result.

  Reads --input_graph, replaces the nodes between --inputs and --outputs with
  the node described by --op_definition, and writes --output_graph.

  Args:
    argv: Positional command-line arguments; unused.

  Raises:
    app.UsageError: If --inputs or --outputs does not name any node.
  """
  del argv  # unused
  # "".split(",") yields [""], i.e. a single empty node name; blanks and
  # surrounding whitespace are dropped so that such input is rejected here.
  inputs = _split_node_names(FLAGS.inputs)
  outputs = _split_node_names(FLAGS.outputs)
  if not inputs:
    raise app.UsageError("--inputs must list at least one node name.")
  if not outputs:
    raise app.UsageError("--outputs must list at least one node name.")

  graph = _read_graph_def(filename=FLAGS.input_graph)
  graph = _collapse_subgraph(
      graph_def=graph,
      inputs=inputs,
      outputs=outputs,
      op_definition=FLAGS.op_definition)
  _write_graph_def(graph_def=graph, filename=FLAGS.output_graph)
108
109
if __name__ == "__main__":
  # absl's app.run parses the flags defined above, then calls main(argv).
  app.run(main=main)
112